In [1]:
%load_ext autoreload
%autoreload 2

import sys, os

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from meta_neural_network_architectures import VGGReLUNormNetwork, ResNet12
from few_shot_learning_system import MAMLFewShotClassifier
from prompters import padding
from utils.parser_utils import get_args
from data import MetaLearningSystemDataLoader
from data import FewShotLearningDatasetParallel
from experiment_builder import ExperimentBuilder
import prompters

import easydict

import torch
import torch.nn as nn
import numpy as np

import torch.backends.cudnn as cudnn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim as optim

In [2]:
os.environ['DATASET_DIR'] = os.path.join(os.getcwd(), "datasets")

args = easydict.EasyDict(
{
  "batch_size":2,
  "image_height":84,
  "image_width":84,
  "image_channels":3,
  "gpu_to_use":0,
  "num_dataprovider_workers":4,
  "max_models_to_save":5,
  "dataset_name":"mini_imagenet_full_size",
  "dataset_path":"mini_imagenet_full_size",
  "reset_stored_paths":False,
  "experiment_name":"alfa+maml",
  "train_seed": 0, "val_seed": 0,
  "indexes_of_folders_indicating_class": [-3, -2],
  "sets_are_pre_split": True,
  "train_val_test_split": [0.64, 0.16, 0.20],
  "evaluate_on_test_set_only": False,

  "total_epochs": 100,
  "total_iter_per_epoch":500, "continue_from_epoch": -2,
  "num_evaluation_tasks":600,
  "multi_step_loss_num_epochs": 15,
  "minimum_per_task_contribution": 0.01,
  "learnable_per_layer_per_step_inner_loop_learning_rate": False,
  "enable_inner_loop_optimizable_bn_params": False,
  "evalute_on_test_set_only": False,

  "max_pooling": True,
  "per_step_bn_statistics": False,
  "learnable_batch_norm_momentum": False,
  "load_into_memory": False,
  "init_inner_loop_learning_rate": 0.01,
  "init_inner_loop_weight_decay": 0.0005,
  "learnable_bn_gamma": True,
  "learnable_bn_beta": True,

  "dropout_rate_value":0.0,
  "min_learning_rate":0.001,
  "meta_learning_rate":0.001,   "total_epochs_before_pause": 100,
  "first_order_to_second_order_epoch":-1,
  "weight_decay": 0.0,

  "norm_layer":"batch_norm",
  "cnn_num_filters":48,
  "num_stages":4,
  "conv_padding": True,
  "number_of_training_steps_per_iter":5,
  "number_of_evaluation_steps_per_iter":5,
  "cnn_blocks_per_stage":1,
  "num_classes_per_set":5,
  "num_samples_per_class":5,
  "num_target_samples": 15,
    "samples_per_iter" : 1,

  "second_order": True,
  "use_multi_step_loss_optimization":False,
  "attenuate": False,
  "alfa": True,
  "random_init": False,
  "backbone": "4-CONV"
}
)

device = torch.cuda.current_device()
args.im_shape = (2, 3, args.image_height, args.image_width)

args.use_cuda = torch.cuda.is_available()
args.seed = 104
args.reverse_channels=False
args.labels_as_int=False
args.reset_stored_filepaths=False
args.num_of_gpus=1

In [3]:
preprocess = transforms.Compose([
    transforms.Resize(84),
    transforms.ToTensor()
])

train_dataset = CIFAR100("./data", transform=preprocess,
                          download=True, train=True)

val_dataset = CIFAR100("./data", transform=preprocess,
                        download=True, train=False)

train_loader = DataLoader(train_dataset,
                          batch_size=25, pin_memory=True,
                          num_workers=16, shuffle=True)

class_names = train_dataset.classes

Files already downloaded and verified
Files already downloaded and verified


In [4]:
images, targets = next(iter(train_loader))
images = images.to(device)

targets = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4])
targets = torch.Tensor(targets)
targets = targets.type(torch.LongTensor)
targets = targets.to(device)

In [5]:
print(images.shape)
print(targets)

torch.Size([25, 3, 84, 84])
tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,
        4], device='cuda:0')


In [112]:
class MAMLFewShotClassifier(nn.Module):
    
    def __init__(self, im_shape, device, args):
        
        super(MAMLFewShotClassifier, self).__init__()
        self.args = args
        self.device = device
        self.batch_size = args.batch_size
        self.use_cuda = args.use_cuda
        self.im_shape = im_shape
        self.current_epoch = 0
        
        self.classifier = VGGReLUNormNetwork(im_shape=self.im_shape, num_output_classes=self.args.
                                                 num_classes_per_set,
                                                 args=args, device=device, meta_classifier=True).to(device=self.device)
        
        
        self.prompter = prompters.padding(args=args, prompt_size=10, image_size=self.im_shape)
        
        
        self.optimizer_all = optim.Adam(self.trainable_parameters(), lr=args.meta_learning_rate, amsgrad=False)
        self.scheduler_all = optim.lr_scheduler.CosineAnnealingLR(optimizer=self.optimizer_all, T_max=self.args.total_epochs,
                                                              eta_min=self.args.min_learning_rate)
        
    def trainable_parameters(self):
        for param in self.parameters():
            if param.requires_grad:
                yield param
                
                
    def get_inner_loop_parameter_dict(self, params):
        
        param_dict = dict()
        for name, param in params:
            if param.requires_grad:
                param_dict[name] = param.to(device=device)

        return param_dict
    
    def forward(self, x, y):
        
        loss, prompt_preds = self.net_forward(x,y)
        observer = self.meta_update(loss)
        return observer
    
    
    def forward_grad(self, x, y):
        
        loss, prompt_preds = self.net_forward(x,y)
        
        prompt_grads = self.meta_update_grad(loss)
        
        return prompt_grads
        
    def forward_prompt_backward(self, x, y):
        
        loss, prompt_preds = self.net_forward(x,y)
        
        prompt_grads = self.meta_prompt_update(loss)
        
        return prompt_grads
    
    def forward_classifier_backward(self, x, y):
        
        loss, prompt_preds = self.net_forward(x,y)
        
        prompt_grads = self.meta_classifier_update(loss)
        
        return prompt_grads
    
        
    def net_forward(self, x, y):
        
        names_prompt_weights_copy = self.get_inner_loop_parameter_dict(self.prompter.named_parameters())
        names_weights_copy = self.get_inner_loop_parameter_dict(self.classifier.named_parameters())

        names_prompt_weights_copy = {
                        name.replace('module.', ''): value.unsqueeze(0).repeat(
                            [1] + [1 for i in range(len(value.shape))]) for
                        name, value in names_prompt_weights_copy.items()}

        names_weights_copy = {
                        name.replace('module.', ''): value.unsqueeze(0).repeat(
                            [1] + [1 for i in range(len(value.shape))]) for
                        name, value in names_weights_copy.items()}
        
        prompted_images = self.prompter.forward(x=images, params=names_prompt_weights_copy)
        prompt_preds = self.classifier.forward(prompted_images, params=names_weights_copy,num_step=4)
        
        loss = F.cross_entropy(input=prompt_preds, target=y)
        
        return loss, prompt_preds
        
    
    def meta_update_grad(self, loss):
        
        prompt_grads = torch.autograd.grad(loss, self.prompter.parameters(), allow_unused=True)
        
        return prompt_grads
        
        
    def meta_prompt_update(self, loss):
        
        print("meta_prompt_update")
        
        for param in self.classifier.parameters():
            param.requires_grad = False
            
        for param in self.prompter.parameters():
            param.requires_grad = True
        
        self.optimizer_all.zero_grad()
        loss.backward()
        
        for name, param in self.prompter.named_parameters():
            print(param.grad)
        
        self.optimizer_all.step()
        
        
    def meta_classifier_update(self, loss):
        
        print("meta_classifier_update")
        
        for param in self.classifier.parameters():
            param.requires_grad = True
        
        for param in self.prompter.parameters():
            param.requires_grad = False
        
        self.optimizer_all.zero_grad()
        loss.backward()
        
        for name, param in self.classifier.named_parameters():
            print(param.grad)
        
        self.optimizer_all.step()
    
    
    def unify_test(self, images, targets):
        prompt_grads1 = model.forward_grad(images, targets)
        
        
        ## loss.bac
        model.forward_prompt_backward(images, targets)
        model.forward_classifier_backward(images, targets)
        
        return  prompt_grads1
    
    
    def meta_update(self, loss):
        
        print("loss == ", loss)
        
        #print(self.classifier.state_dict())
        print(self.prompter.state_dict())
        print("==============================")
        
        #### Prompter Update ####
        
        for param in self.classifier.parameters():
            param.requires_grad = False

        for param in self.prompter.parameters():
            param.requires_grad = True
        
        self.optimizer_all.zero_grad()
        
        loss.backward(retain_graph=True)
        self.optimizer_all.step()

        print(self.prompter.state_dict())
        print("==============================")
        
        
        #### Classifier Update ####        
                
        self.classifier.zero_grad()
        self.prompter.zero_grad()
        self.optimizer_all.zero_grad()
        
        for param in self.classifier.parameters():
            param.requires_grad = True

        for param in self.prompter.parameters():
            param.requires_grad = False
            
        loss.backward()
        self.optimizer_all.step()


        print(self.prompter.state_dict())
        print("==============================")
        
        
#########
model = MAMLFewShotClassifier(args=args, device=device, im_shape=(2, 3, args.image_height, args.image_width))

Using max pooling
No inner loop params
torch.Size([2, 48, 84, 84])
No inner loop params
No inner loop params
torch.Size([2, 48, 42, 42])
No inner loop params
No inner loop params
torch.Size([2, 48, 21, 21])
No inner loop params
No inner loop params
torch.Size([2, 48, 10, 10])
No inner loop params
(VGGReLUNormNetwork) meta network params
layer_dict.conv0.conv.weight torch.Size([48, 3, 3, 3])
layer_dict.conv0.conv.bias torch.Size([48])
layer_dict.conv0.norm_layer.running_mean torch.Size([48])
layer_dict.conv0.norm_layer.running_var torch.Size([48])
layer_dict.conv0.norm_layer.bias torch.Size([48])
layer_dict.conv0.norm_layer.weight torch.Size([48])
layer_dict.conv1.conv.weight torch.Size([48, 48, 3, 3])
layer_dict.conv1.conv.bias torch.Size([48])
layer_dict.conv1.norm_layer.running_mean torch.Size([48])
layer_dict.conv1.norm_layer.running_var torch.Size([48])
layer_dict.conv1.norm_layer.bias torch.Size([48])
layer_dict.conv1.norm_layer.weight torch.Size([48])
layer_dict.conv2.conv.weight

### 1) retain_graph= True 사용법 확인

In [113]:
observer = model.forward(images, targets)

loss ==  tensor(3.2313, device='cuda:0', grad_fn=<NllLossBackward0>)
OrderedDict([('pad_dict.pad_up', tensor([[[-0.4360,  0.2818, -0.8179,  ..., -1.3155,  2.1867, -0.2194],
         [ 0.4819,  0.0524, -0.2031,  ..., -1.4800, -0.1848, -0.0638],
         [-0.5514,  0.3375, -1.8276,  ..., -0.7464,  1.6744,  0.9878],
         ...,
         [ 0.5290, -0.1302,  0.2830,  ..., -0.2333,  0.7336,  0.3237],
         [-0.7211, -0.3225,  0.2932,  ..., -1.4422, -1.9310,  0.6447],
         [-0.4023,  0.4491, -0.8455,  ...,  0.9185, -0.4634,  0.2466]],

        [[-0.6048,  0.3230, -1.4740,  ..., -2.1402,  0.7514,  0.2727],
         [-1.4480, -0.1694,  0.2857,  ...,  0.9131,  0.5614,  0.8324],
         [ 0.8564,  0.3845,  1.2634,  ...,  0.6458,  0.3530, -1.0482],
         ...,
         [-0.0256,  0.1351, -0.3309,  ...,  1.2774, -1.0258, -0.0634],
         [ 1.6523, -0.1813, -0.0339,  ...,  0.1125,  1.5760, -1.4010],
         [-1.2190,  1.3813,  0.7423,  ..., -0.2176, -1.7850,  2.6484]],

        [[-0.7

In [80]:
before_update_classifier = observer['before_update_classifier']
before_update_prompter = observer['before_update_prompter']

one_update_classifier = observer['one_update_classifier']
one_update_prompter = observer['one_update_prompter']

two_update_classifier = observer['two_update_classifier']
two_update_prompter = observer['two_update_prompter']


for param in before_update_classifier.parameters():
    print(param)

Parameter containing:
tensor([[[[ 0.0075,  0.0732, -0.0628],
          [-0.0165, -0.0671, -0.0209],
          [ 0.0997, -0.1113, -0.1042]],

         [[ 0.0764,  0.0563,  0.0163],
          [ 0.0509,  0.0345, -0.0792],
          [-0.0655,  0.0283, -0.0187]],

         [[ 0.0037,  0.0686,  0.0595],
          [-0.0564, -0.0613,  0.1009],
          [-0.0316,  0.0289,  0.0627]]],


        [[[ 0.0891, -0.0179,  0.1122],
          [-0.0462, -0.1015,  0.0760],
          [ 0.0137, -0.0129,  0.0560]],

         [[-0.0984,  0.1083,  0.0061],
          [ 0.0952,  0.0551, -0.0016],
          [ 0.0244,  0.0427,  0.0426]],

         [[-0.0700,  0.0334,  0.0726],
          [-0.0606, -0.0467, -0.0784],
          [-0.0316,  0.0130, -0.0647]]],


        [[[ 0.0661, -0.0717, -0.1110],
          [ 0.1130,  0.0614, -0.0839],
          [-0.0838,  0.0761, -0.0360]],

         [[ 0.0804,  0.0934,  0.0663],
          [ 0.0562,  0.0989, -0.0300],
          [-0.0761,  0.0960, -0.0105]],

         [[-0.0455, -0

In [81]:
for param in one_update_classifier.parameters():
    print(param)

Parameter containing:
tensor([[[[ 0.0075,  0.0732, -0.0628],
          [-0.0165, -0.0671, -0.0209],
          [ 0.0997, -0.1113, -0.1042]],

         [[ 0.0764,  0.0563,  0.0163],
          [ 0.0509,  0.0345, -0.0792],
          [-0.0655,  0.0283, -0.0187]],

         [[ 0.0037,  0.0686,  0.0595],
          [-0.0564, -0.0613,  0.1009],
          [-0.0316,  0.0289,  0.0627]]],


        [[[ 0.0891, -0.0179,  0.1122],
          [-0.0462, -0.1015,  0.0760],
          [ 0.0137, -0.0129,  0.0560]],

         [[-0.0984,  0.1083,  0.0061],
          [ 0.0952,  0.0551, -0.0016],
          [ 0.0244,  0.0427,  0.0426]],

         [[-0.0700,  0.0334,  0.0726],
          [-0.0606, -0.0467, -0.0784],
          [-0.0316,  0.0130, -0.0647]]],


        [[[ 0.0661, -0.0717, -0.1110],
          [ 0.1130,  0.0614, -0.0839],
          [-0.0838,  0.0761, -0.0360]],

         [[ 0.0804,  0.0934,  0.0663],
          [ 0.0562,  0.0989, -0.0300],
          [-0.0761,  0.0960, -0.0105]],

         [[-0.0455, -0

In [82]:
for param in two_update_classifier.parameters():
    print(param)

Parameter containing:
tensor([[[[ 0.0075,  0.0732, -0.0628],
          [-0.0165, -0.0671, -0.0209],
          [ 0.0997, -0.1113, -0.1042]],

         [[ 0.0764,  0.0563,  0.0163],
          [ 0.0509,  0.0345, -0.0792],
          [-0.0655,  0.0283, -0.0187]],

         [[ 0.0037,  0.0686,  0.0595],
          [-0.0564, -0.0613,  0.1009],
          [-0.0316,  0.0289,  0.0627]]],


        [[[ 0.0891, -0.0179,  0.1122],
          [-0.0462, -0.1015,  0.0760],
          [ 0.0137, -0.0129,  0.0560]],

         [[-0.0984,  0.1083,  0.0061],
          [ 0.0952,  0.0551, -0.0016],
          [ 0.0244,  0.0427,  0.0426]],

         [[-0.0700,  0.0334,  0.0726],
          [-0.0606, -0.0467, -0.0784],
          [-0.0316,  0.0130, -0.0647]]],


        [[[ 0.0661, -0.0717, -0.1110],
          [ 0.1130,  0.0614, -0.0839],
          [-0.0838,  0.0761, -0.0360]],

         [[ 0.0804,  0.0934,  0.0663],
          [ 0.0562,  0.0989, -0.0300],
          [-0.0761,  0.0960, -0.0105]],

         [[-0.0455, -0

In [72]:
for param in before_update_prompter.parameters():
    print(param)

Parameter containing:
tensor([[[-2.0942, -1.1419,  0.9164,  ..., -0.5809,  0.0054,  1.0477],
         [ 1.5359,  1.6130, -1.3872,  ...,  0.5636, -0.6456,  1.3802],
         [ 0.2366,  1.2955, -0.4094,  ...,  1.3922, -0.4090,  0.7554],
         ...,
         [-0.0964, -0.2958, -0.4110,  ...,  0.2037, -1.2620, -0.7342],
         [ 1.2579,  1.1338,  0.1011,  ..., -0.9341, -0.2614,  0.8536],
         [ 0.8272,  0.1948,  1.2514,  ..., -0.4523,  0.3193,  1.6686]],

        [[ 1.3553, -0.0542, -1.7229,  ..., -0.9602, -1.2888,  0.4893],
         [-0.0821,  2.5575,  0.5522,  ..., -1.1204,  1.8127, -1.9828],
         [-0.4700, -1.2364,  2.6644,  ...,  0.7980, -0.2313,  0.0909],
         ...,
         [-0.6481, -1.3625, -0.3220,  ...,  1.5608, -0.7971,  0.1644],
         [ 0.5518,  0.1963,  0.6133,  ...,  1.6888, -0.6343, -0.8982],
         [ 2.4009,  0.8844,  0.2474,  ..., -1.1102, -0.0202, -0.7911]],

        [[ 0.1863,  0.1599,  0.7091,  ...,  1.1835, -1.1085,  1.8267],
         [-0.5668,  1.6

In [73]:
for param in one_update_prompter.parameters():
    print(param)

Parameter containing:
tensor([[[-2.0942, -1.1419,  0.9164,  ..., -0.5809,  0.0054,  1.0477],
         [ 1.5359,  1.6130, -1.3872,  ...,  0.5636, -0.6456,  1.3802],
         [ 0.2366,  1.2955, -0.4094,  ...,  1.3922, -0.4090,  0.7554],
         ...,
         [-0.0964, -0.2958, -0.4110,  ...,  0.2037, -1.2620, -0.7342],
         [ 1.2579,  1.1338,  0.1011,  ..., -0.9341, -0.2614,  0.8536],
         [ 0.8272,  0.1948,  1.2514,  ..., -0.4523,  0.3193,  1.6686]],

        [[ 1.3553, -0.0542, -1.7229,  ..., -0.9602, -1.2888,  0.4893],
         [-0.0821,  2.5575,  0.5522,  ..., -1.1204,  1.8127, -1.9828],
         [-0.4700, -1.2364,  2.6644,  ...,  0.7980, -0.2313,  0.0909],
         ...,
         [-0.6481, -1.3625, -0.3220,  ...,  1.5608, -0.7971,  0.1644],
         [ 0.5518,  0.1963,  0.6133,  ...,  1.6888, -0.6343, -0.8982],
         [ 2.4009,  0.8844,  0.2474,  ..., -1.1102, -0.0202, -0.7911]],

        [[ 0.1863,  0.1599,  0.7091,  ...,  1.1835, -1.1085,  1.8267],
         [-0.5668,  1.6

In [74]:
for param in two_update_prompter.parameters():
    print(param)

Parameter containing:
tensor([[[-2.0942, -1.1419,  0.9164,  ..., -0.5809,  0.0054,  1.0477],
         [ 1.5359,  1.6130, -1.3872,  ...,  0.5636, -0.6456,  1.3802],
         [ 0.2366,  1.2955, -0.4094,  ...,  1.3922, -0.4090,  0.7554],
         ...,
         [-0.0964, -0.2958, -0.4110,  ...,  0.2037, -1.2620, -0.7342],
         [ 1.2579,  1.1338,  0.1011,  ..., -0.9341, -0.2614,  0.8536],
         [ 0.8272,  0.1948,  1.2514,  ..., -0.4523,  0.3193,  1.6686]],

        [[ 1.3553, -0.0542, -1.7229,  ..., -0.9602, -1.2888,  0.4893],
         [-0.0821,  2.5575,  0.5522,  ..., -1.1204,  1.8127, -1.9828],
         [-0.4700, -1.2364,  2.6644,  ...,  0.7980, -0.2313,  0.0909],
         ...,
         [-0.6481, -1.3625, -0.3220,  ...,  1.5608, -0.7971,  0.1644],
         [ 0.5518,  0.1963,  0.6133,  ...,  1.6888, -0.6343, -0.8982],
         [ 2.4009,  0.8844,  0.2474,  ..., -1.1102, -0.0202, -0.7911]],

        [[ 0.1863,  0.1599,  0.7091,  ...,  1.1835, -1.1085,  1.8267],
         [-0.5668,  1.6

### 2) torch.autograd.grad와 loss.backward를 사용해서 gradient를 구한 것이 같은지 확인

In [42]:
prompt_grads1 = model.unify_test(images, targets)

meta_prompt_update
tensor([[[-2.0216e-03, -3.6958e-04, -1.5783e-03,  ...,  2.0131e-04,
          -2.1512e-03,  2.7885e-03],
         [ 9.7590e-03,  6.5974e-03, -2.7928e-05,  ..., -2.6555e-03,
          -8.1571e-03, -4.3589e-03],
         [ 2.8159e-03,  5.3855e-03, -4.5202e-03,  ..., -3.0622e-03,
           1.1536e-04,  5.4496e-03],
         ...,
         [-3.7085e-03, -1.0084e-02, -1.4233e-03,  ...,  5.3646e-03,
           2.3510e-03, -5.5829e-03],
         [-1.1532e-03,  9.8780e-04, -3.1116e-03,  ...,  3.5715e-03,
           3.8425e-03,  2.9738e-03],
         [-6.1892e-03, -5.0816e-03,  1.2909e-03,  ..., -2.0270e-03,
          -1.1112e-02,  7.5673e-03]],

        [[-3.1656e-03,  1.3641e-02,  1.5022e-03,  ...,  2.8501e-03,
           3.8098e-03, -1.5951e-03],
         [-4.4339e-03, -6.2720e-03, -6.3043e-04,  ..., -1.8463e-03,
           4.5801e-03,  2.2205e-03],
         [-2.8521e-03, -1.0274e-02,  7.1907e-04,  ...,  3.0277e-04,
          -1.1531e-03,  1.0226e-03],
         ...,
      