In [1]:
import os
import numpy as np
import torch
import torchvision
import PIL
import importlib
import matplotlib.pyplot as plt
%matplotlib inline
import copy
import time
#
import config.config_flags as Config
import data_load.data_provider as dp
import runner as runner
import utils.task_helper as th
import utils.helper as helper
import datasetconf as DC
import TaskClass as TaskClass
import Task as Task
import TestNets as TestNets
import maml as MAML
import tasml as TASML
import testing_routines as TESTING_ROUTINES
#
# Re-execute the project modules so edits made on disk are picked up without
# restarting the kernel. They are reloaded in the same fixed order as listed;
# the order may matter if these modules import one another.
_MODULES_TO_RELOAD = (
    Config,
    th,
    dp,
    runner,
    helper,
    DC,
    TaskClass,
    Task,
    TestNets,
    MAML,
    TASML,
    TESTING_ROUTINES,
)
for _module in _MODULES_TO_RELOAD:
    # Assign to `_` so the reloaded module repr is not echoed as cell output.
    _ = importlib.reload(_module)

## Create Task Objects

In [2]:
# Select the tensor device: the first CUDA GPU when available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# The task DB is loaded from the filesystem. To regenerate the tasks and their
# alpha weights instead, uncomment the two lines below.
#print("Generating tasks and compute their alpha weights")
#runner.populate_db()

# Perform top-m filtering. `alpha_weights` is indexed as [training_task, test_task];
# `train_db[n]` holds the top-m training-task records for test task n.
print("Filtering top-m alpha weights and training tasks")
alpha_weights, train_db = runner.top_m_filtering()

# Get the target (test) task DB from the filesystem.
# NOTE(review): runner.unpickle presumably deserialises with pickle — only load trusted files.
test_path = helper.get_task_dataset_path("test")
test_db = runner.unpickle(test_path)

# Build the training tasks from the train-split embeddings.
print("Fetching train embeddings")
train_provider = dp.DataProvider("train", debug=False, verbose=False)
train_tr_size = Config.TRAINING_NUM_OF_EXAMPLES_PER_CLASS
train_val_size = Config.VALIDATION_NUM_OF_EXAMPLES_PER_CLASS
print("Generating training tasks")
# One column of alpha_weights per test task.
num_test_tasks = alpha_weights.shape[1]
train_tasks = []  # train_tasks[n] is the list of top-m training tasks for test task n
for n in range(num_test_tasks):
    print("Generating top m training tasks for test task " + str(n))
    train_tasks.append(th.generate_tasks(train_db[n], train_provider, train_tr_size, train_val_size, device))
del train_db, train_provider  # Free up memory

# Build the target tasks from the test-split embeddings.
print("Fetching test embeddings")
test_provider = dp.DataProvider("test", debug=False, verbose=False)
# The per-class training count reuses the training-split setting; only the
# validation count has a test-specific setting.
test_tr_size = Config.TRAINING_NUM_OF_EXAMPLES_PER_CLASS
test_val_size = Config.TEST_VALIDATION_NUM_OF_EXAMPLES_PER_CLASS
print("Generating test tasks")
test_tasks = th.generate_tasks(test_db, test_provider, test_tr_size, test_val_size, device)  # Target tasks with only tests populated
del test_db, test_provider  # Free up memory


Filtering top-m alpha weights and training tasks
Fetching train embeddings
Path fetched: ../embeddings/tieredImageNet/center/train_embeddings.pkl
 23%|██▎       | 23/100 [00:00<00:00, 197.80it/s]Generating training tasks
Generating top m training tasks for test task 0
100%|██████████| 100/100 [00:00<00:00, 424.44it/s]
 75%|███████▌  | 75/100 [00:00<00:00, 749.29it/s]Generating top m training tasks for test task 1
100%|██████████| 100/100 [00:00<00:00, 355.87it/s]
100%|██████████| 100/100 [00:00<00:00, 724.82it/s]
  0%|          | 0/100 [00:00<?, ?it/s]Generating top m training tasks for test task 2
Generating top m training tasks for test task 3
100%|██████████| 100/100 [00:00<00:00, 758.24it/s]
100%|██████████| 100/100 [00:00<00:00, 730.63it/s]
  0%|          | 0/100 [00:00<?, ?it/s]Generating top m training tasks for test task 4
Generating top m training tasks for test task 5
100%|██████████| 100/100 [00:00<00:00, 750.58it/s]
100%|██████████| 100/100 [00:00<00:00, 777.46it/s]
  0%|  

In [3]:
# Create network test instance.
def get_test_net(input_len: int = 640, n_classes=None):
    """Create a fresh ``TestNets.MAMLModule1`` test network.

    Args:
        input_len: Length of each input feature vector. Defaults to 640, the
            value previously hard-coded here (presumably the dimension of the
            precomputed embeddings — TODO confirm).
        n_classes: Number of output classes. When ``None`` (default), uses
            ``Config.NUM_OF_CLASSES``. It is read at call time rather than
            bound as a default, so a reloaded ``Config`` takes effect.

    Returns:
        A new ``TestNets.MAMLModule1`` instance.
    """
    if n_classes is None:
        n_classes = Config.NUM_OF_CLASSES
    return TestNets.MAMLModule1(input_len=input_len, n_classes=n_classes)

In [4]:
# For every target (test) task: fetch its top-m training tasks and their alpha
# weights, then train and evaluate the base learner, MAML and TASML (each with
# and without meta fine-tuning) from identical starting weights.

def _copy_of(source_net):
    """Return a new test network loaded with a copy of ``source_net``'s weights."""
    net = get_test_net()
    net.load_state_dict(copy.deepcopy(source_net.state_dict()))
    return net


for test_task_num, target_task in enumerate(test_tasks):
    # Column `test_task_num` holds one alpha weight per top-m training task of this target.
    alpha_weights_for_target = alpha_weights[:, test_task_num]
    training_tasks_for_target = train_tasks[test_task_num]  # list of training tasks

    # Remap the target task so both its support and query sets use only the
    # target's training (support) images. The unmodified target task is passed
    # separately as `test_target_task`.
    training_target_task = Task.Task(task_friendly_name=target_task.task_friendly_name, batch_size=target_task.batch_size)
    training_target_task.supp_train = target_task.supp_train
    training_target_task.supp_targets = target_task.supp_targets
    training_target_task.query_train = target_task.supp_train
    training_target_task.query_targets = target_task.supp_targets
    test_target_task = target_task

    # All five networks start from the same weights so the comparison is fair.
    # `*_meta_ft` nets run with isMetaFinetuned=True, `*_no_meta_ft` nets with
    # isMetaFinetuned=False. Networks are created in the same order as before,
    # so initialisation draws happen in the same sequence.
    test_net_base = get_test_net()
    test_net_maml_meta_ft = _copy_of(test_net_base)
    test_net_maml_no_meta_ft = _copy_of(test_net_base)
    test_net_tasml_meta_ft = _copy_of(test_net_base)
    test_net_tasml_no_meta_ft = _copy_of(test_net_base)

    TESTING_ROUTINES.run_baselearner(test_net_base, training_tasks_for_target, training_target_task, test_target_task)
    TESTING_ROUTINES.run_maml(test_net_maml_meta_ft, training_tasks_for_target, training_target_task, test_target_task, isMetaFinetuned=True)
    TESTING_ROUTINES.run_maml(test_net_maml_no_meta_ft, training_tasks_for_target, training_target_task, test_target_task, isMetaFinetuned=False)
    TESTING_ROUTINES.run_tasml(test_net_tasml_meta_ft, training_tasks_for_target, training_target_task, alpha_weights_for_target, test_target_task, isMetaFinetuned=True)
    TESTING_ROUTINES.run_tasml(test_net_tasml_no_meta_ft, training_tasks_for_target, training_target_task, alpha_weights_for_target, test_target_task, isMetaFinetuned=False)

BASE, 0, 20.0, -, 24.77, 47.13, 10.51775312423706, 0.05827784538269043, -, 10.595863103866577
MAML, 0, 20.0, -, 20.2, 60.67, 28.422596216201782, 0.28466105461120605, -, 28.72720217704773
MAML, 0, 20.0, -, 20.2, 43.53, 28.391676902770996, 0.05005002021789551, -, 28.46204400062561
TASML, 0, 20.0, 20.2, 28.13, 54.93, 28.051140785217285, 150.63891410827637, 0.2902259826660156, 179.0208237171173
TASML, 0, 20.0, 20.2, 28.13, 50.93, 29.27805733680725, 157.00564813613892, 0.06998610496520996, 186.3978853225708


KeyboardInterrupt: 

## Loading tasks from local img folders.
```python
task1 = Task.create_task_given(
    task_friendly_name='Task1',
    dataset_name='boat',
    class_names=['Gondola', 'Motopontonerettangolare'], #n-way
    len_support_dataset=3, #k-shot
    len_query_dataset=2,   #k-shot
    let_test_dataset=5, # Number of test cases per class
    transformer=torchvision.transforms.Compose([
        torchvision.transforms.Resize(224),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])]),
    img_size=224,
    start_class_id=0)
task1.reset_train_session()
#
task2 = Task.create_task_given(
    task_friendly_name='Task2',
    dataset_name='boat',
    class_names=['Raccoltarifiuti', 'Water'],
    len_support_dataset=3,
    len_query_dataset=2,
    let_test_dataset=5, # Number of test cases per class
    transformer=torchvision.transforms.Compose([
        torchvision.transforms.Resize(224),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])]),
    img_size=224,
    start_class_id=2)
task2.reset_train_session()
#
TASKS = [task1, task2]
LEN_CLASSES = sum(len(task.task_classes) for task in TASKS)
CLASSES_NAMES = [taskclass.class_friendly_name for task in TASKS for taskclass in task.task_classes]
NUM_IN_CHANNELS = 3 # RGB
```

### MAML Runner:
```python
MAML.maml_nn_classifier_learn(
    test_net: torch.nn.Module,
    tasks: list[Task.Task],
    convergence_diff: float = 0.0001,
    max_meta_epochs = 10,
    inner_epochs: int = 1,
    inner_lr: float = 0.001,
    outer_lr: float = 0.001,
    loss_function = torch.nn.CrossEntropyLoss()):
```

### TASML Runner:
```python
TASML.tasml_nn_classifier_learn(
    test_net: torch.nn.Module,
    tasks: list[Task.Task],
    target_task: Task.Task,
    alpha_weights: torch.Tensor,
    convergence_diff: float = 0.0001,
    max_meta_epochs = 10,
    inner_epochs: int = 1,
    inner_lr: float = 0.001,
    outer_lr: float = 0.001,
    loss_function = torch.nn.CrossEntropyLoss()):
```