In [123]:
# Automatically re-import edited project modules (e.g. meta_neural_network_architectures,
# inner_loop_optimizers) before each cell is executed.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [124]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from meta_neural_network_architectures import VGGReLUNormNetwork, ResNet12
from inner_loop_optimizers import LSLRGradientDescentLearningRule

In [125]:
#!pip install easydict
import easydict

In [126]:
args = easydict.EasyDict(
{
  "batch_size":2,
  "image_height":84,
  "image_width":84,
  "image_channels":3,
  # NOTE(review): not consulted when choosing `device` below (the current CUDA device is used).
  "gpu_to_use":0,
  "num_dataprovider_workers":4,
  "max_models_to_save":5,
  "dataset_name":"mini_imagenet_full_size",
  "dataset_path":"mini_imagenet_full_size",
  "reset_stored_paths":False,
  "experiment_name":"alfa+maml",
  "train_seed": 0, "val_seed": 0,
  "indexes_of_folders_indicating_class": [-3, -2],
  "sets_are_pre_split": True,
  "train_val_test_split": [0.64, 0.16, 0.20],
  "evaluate_on_test_set_only": False,

  "total_epochs": 100,
  "total_iter_per_epoch":500, "continue_from_epoch": -2,
  "num_evaluation_tasks":600,
  "multi_step_loss_num_epochs": 15,
  "minimum_per_task_contribution": 0.01,
  # NOTE(review): MAMLFewShotClassifier_Test derives learnable inner-loop rates from "alfa", not this flag.
  "learnable_per_layer_per_step_inner_loop_learning_rate": False,
  "enable_inner_loop_optimizable_bn_params": False,
  # NOTE(review): misspelled duplicate of "evaluate_on_test_set_only" above; kept in case
  # downstream code reads this spelling — confirm and remove.
  "evalute_on_test_set_only": False,

  "max_pooling": True,
  "per_step_bn_statistics": False,
  "learnable_batch_norm_momentum": False,
  "load_into_memory": False,
  "init_inner_loop_learning_rate": 0.01,
  "init_inner_loop_weight_decay": 0.0005,
  "learnable_bn_gamma": True,
  "learnable_bn_beta": True,

  "dropout_rate_value":0.0,
  "min_learning_rate":0.001,
  "meta_learning_rate":0.001,   "total_epochs_before_pause": 100,
  "first_order_to_second_order_epoch":-1,
  "weight_decay": 0.0,

  "norm_layer":"batch_norm",
  "cnn_num_filters":48,
  "num_stages":4,
  "conv_padding": True,
  "number_of_training_steps_per_iter":5,
  "number_of_evaluation_steps_per_iter":5,
  "cnn_blocks_per_stage":1,
  "num_classes_per_set":5,
  "num_samples_per_class":5,
  "num_target_samples": 15,

  "second_order": True,
  "use_multi_step_loss_optimization":False,
  "attenuate": False,
  "alfa": True,
  "random_init": False,
  # Any value other than 'ResNet12' makes the classifier use VGGReLUNormNetwork.
  "backbone": "4-CONV"
}
)

args.use_cuda = torch.cuda.is_available()
args.seed = 104

# torch.cuda.current_device() raises when CUDA is unavailable, so only call it on a GPU machine
# and fall back to the CPU otherwise.
device = torch.cuda.current_device() if args.use_cuda else torch.device("cpu")
# (batch, channels, height, width) of the input used to build the backbone.
im_shape = (2, 3, args.image_height, args.image_width)

In [127]:
def set_torch_seed(seed):
    """
    Derive a torch seed from ``seed`` and apply it for the current experiment run.

    The first draw of ``np.random.RandomState(seed)`` in ``[0, 999999)`` is handed to
    ``torch.manual_seed``. The generator, already advanced by that draw, is returned
    so callers can keep sampling from the same reproducible stream.

    :param seed: The seed (int)
    :return: The numpy ``RandomState`` that produced the torch seed
    """
    numpy_rng = np.random.RandomState(seed=seed)
    torch.manual_seed(seed=numpy_rng.randint(0, 999999))
    return numpy_rng

class MAMLFewShotClassifier_Test(nn.Module):
    """Scaffold of a MAML/ALFA few-shot classifier, used here to inspect its parameters.

    Constructing an instance does the following:

    - builds the backbone classifier;
    - builds the inner-loop optimizer (``LSLRGradientDescentLearningRule``);
    - when ``args.attenuate`` is set, builds an L2F-style attenuator network;
    - prints the inner-loop and outer-loop parameters.

    ``forward`` is only a stub in this notebook.

    :param im_shape: Input shape handed to the backbone constructor.
    :param device: Device the modules are moved to.
    :param args: Experiment configuration (an ``easydict.EasyDict`` in this notebook).
    """

    def __init__(self, im_shape, device, args):
        super(MAMLFewShotClassifier_Test, self).__init__()
        self.args = args
        self.device = device
        self.batch_size = args.batch_size
        self.use_cuda = args.use_cuda
        self.im_shape = im_shape
        self.current_epoch = 0

        # Seeds torch; the returned numpy RandomState is kept for later sampling.
        self.rng = set_torch_seed(seed=args.seed)

        # Any backbone other than 'ResNet12' (e.g. '4-CONV') falls back to the VGG-style network.
        backbone_cls = ResNet12 if self.args.backbone == 'ResNet12' else VGGReLUNormNetwork
        self.classifier = backbone_cls(im_shape=self.im_shape,
                                       num_output_classes=self.args.num_classes_per_set,
                                       args=args, device=device,
                                       meta_classifier=True).to(device=self.device)

        self.task_learning_rate = args.init_inner_loop_learning_rate

        # The inner-loop update rule lives in this optimizer. With args.alfa enabled, both the
        # learning rates and the weight decays are made learnable.
        self.inner_loop_optimizer = LSLRGradientDescentLearningRule(
            device=device,
            init_learning_rate=self.task_learning_rate,
            init_weight_decay=args.init_inner_loop_weight_decay,
            total_num_inner_loop_steps=self.args.number_of_training_steps_per_iter,
            use_learnable_weight_decay=self.args.alfa,
            use_learnable_learning_rates=self.args.alfa,
            alfa=self.args.alfa, random_init=self.args.random_init)

        # Trainable classifier parameters that the inner loop adapts.
        # NOTE: despite the name, these are references to the live parameters, not copies.
        names_weights_copy = self.get_inner_loop_parameter_dict(self.classifier.named_parameters())

        if self.args.attenuate:
            # L2F: while solving each task, a task-conditioned network produces one attenuation
            # value per layer to counteract conflicts. Its input and output sizes both equal
            # the number of inner-loop parameter tensors.
            num_layers = len(names_weights_copy)
            self.attenuator = nn.Sequential(
                nn.Linear(num_layers, num_layers),
                nn.ReLU(inplace=True),
                nn.Linear(num_layers, num_layers),
                nn.Sigmoid()  # keeps every attenuation value in (0, 1)
            ).to(device=self.device)

        # Let the optimizer create its per-parameter hyperparameters for these weights.
        # Presumably one initial value per inner-loop step; see inner_loop_optimizers to confirm.
        self.inner_loop_optimizer.initialise(names_weights_dict=names_weights_copy)

        # Inspect the inner-loop (optimizer) parameters.
        print("Inner Loop parameters")
        for key, value in self.inner_loop_optimizer.named_parameters():
            print(key, value.shape)

        self.to(device)

        # Inspect the outer-loop (meta-trainable) parameters.
        print("Outer Loop parameters")
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(name, param.shape, param.device, param.requires_grad)

    def get_inner_loop_parameter_dict(self, params):
        """
        Select the parameters that the inner loop should adapt.

        Only parameters with ``requires_grad`` are kept. Unless
        ``args.enable_inner_loop_optimizable_bn_params`` is set, parameters whose name
        contains ``"norm_layer"`` are excluded as well.

        :param params: Iterable of ``(name, parameter)`` pairs, e.g. ``module.named_parameters()``.
        :return: Dict mapping name -> parameter moved to ``self.device``. When the parameter
                 is already on that device, this is the same tensor object.
        """
        return {
            name: param.to(device=self.device)
            for name, param in params
            if param.requires_grad
            and (self.args.enable_inner_loop_optimizable_bn_params or "norm_layer" not in name)
        }

    def forward(self, data_batch, epoch, use_second_order, use_multi_step_loss_optimization, num_steps, training_phase):
        """Placeholder forward pass: it only reports that it was called and returns None.

        TODO: implement the inner/outer-loop step; all arguments are currently unused.
        """
        print("MAMLFewShotClassifier_Test forward")

In [128]:
# Reuse the im_shape defined in the config cell instead of rebuilding the same tuple here.
model = MAMLFewShotClassifier_Test(im_shape=im_shape, device=device, args=args)

Using max pooling
(MetaConvNormLayerReLU build_block) out.shape==  torch.Size([2, 48, 84, 84])
(MetaConvNormLayerReLU build_block) out.shape==  torch.Size([2, 48, 42, 42])
(MetaConvNormLayerReLU build_block) out.shape==  torch.Size([2, 48, 21, 21])
(MetaConvNormLayerReLU build_block) out.shape==  torch.Size([2, 48, 10, 10])
MetaLinearLayer forward
VGGNetwork build out.shape ===  torch.Size([2, 5])
VGGNetwork build out ===  tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], grad_fn=<AddmmBackward0>)
meta network params
layer_dict.conv0.conv.weight torch.Size([48, 3, 3, 3])
layer_dict.conv0.conv.bias torch.Size([48])
layer_dict.conv0.norm_layer.running_mean torch.Size([48])
layer_dict.conv0.norm_layer.running_var torch.Size([48])
layer_dict.conv0.norm_layer.bias torch.Size([48])
layer_dict.conv0.norm_layer.weight torch.Size([48])
layer_dict.conv1.conv.weight torch.Size([48, 48, 3, 3])
layer_dict.conv1.conv.bias torch.Size([48])
layer_dict.conv1.norm_layer.running_mean torch.Siz

In [129]:
# Check torch.ones(self.total_num_inner_loop_steps + 1) * self.init_learning_rate
# (an expression from LSLRGradientDescentLearningRule, presumably its initial per-step
# learning rates) using the values this notebook passes to the optimizer.
torch.ones(args.number_of_training_steps_per_iter + 1) * args.init_inner_loop_learning_rate

tensor([0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100])