In [3]:
import os
import numpy as np
import torch
import torchvision
import PIL
import importlib
import matplotlib.pyplot as plt
%matplotlib inline
import copy
#
import config.config_flags as Config
import data_load.data_provider as dp
import runner as runner
import utils.task_helper as th
import utils.helper as helper
import datasetconf as DC
import TaskClass as TaskClass
import Task as Task
import TestNets as TestNets
import maml as MAML
#
# Reload every project module, in the same order as before, so that edits made
# on disk are picked up without restarting the kernel. Each return value is
# discarded into `_`.
for _module in (Config, th, dp, runner, helper, DC, TaskClass, Task, TestNets, MAML):
    _ = importlib.reload(_module)

## Loading tasks from local img folders.

In [2]:
# Side length (pixels) shared by Resize, CenterCrop and the img_size argument.
IMG_SIZE = 224


def _build_transformer():
    """Return a fresh resize -> center-crop -> to-tensor -> normalize pipeline.

    A new Compose (with new mean/std lists) is built on every call so that no
    transform object is shared between tasks, exactly as in the original cell.
    """
    return torchvision.transforms.Compose([
        torchvision.transforms.Resize(IMG_SIZE),
        torchvision.transforms.CenterCrop(IMG_SIZE),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])])


def _create_boat_task(task_friendly_name, class_names, start_class_id):
    """Create a task on the 'boat' dataset using the settings shared by all tasks.

    Args:
        task_friendly_name: Display name forwarded to Task.create_task_given.
        class_names: Class names that make up the task (n-way).
        start_class_id: Forwarded unchanged as ``start_class_id``.

    Returns:
        Whatever Task.create_task_given returns for these arguments.
    """
    return Task.create_task_given(
        task_friendly_name=task_friendly_name,
        dataset_name='boat',
        class_names=class_names,  # n-way
        len_support_dataset=3,  # k-shot
        len_query_dataset=2,  # query set size
        # Keyword is spelled `let_test_dataset` in the existing call; presumably a
        # typo for `len_` inside the Task API -- confirm there before renaming.
        let_test_dataset=5,  # Number of test cases per class
        transformer=_build_transformer(),
        img_size=IMG_SIZE,
        start_class_id=start_class_id)


# Each task is created and then reset before the next one is built, preserving
# the original statement order.
task1 = _create_boat_task(
    'Task1', ['Gondola', 'Motopontonerettangolare'], start_class_id=0)
task1.reset_train_session()
#
# start_class_id=2 follows the two classes given to Task1 (presumably so class
# ids do not collide across tasks -- confirm against Task.create_task_given).
task2 = _create_boat_task(
    'Task2', ['Raccoltarifiuti', 'Water'], start_class_id=2)
task2.reset_train_session()
#
TASKS = [task1, task2]
# Total number of classes over all tasks.
LEN_CLASSES = sum(len(task.task_classes) for task in TASKS)
# Friendly class names, flattened in task order.
CLASSES_NAMES = [taskclass.class_friendly_name for task in TASKS for taskclass in task.task_classes]
NUM_IN_CHANNELS = 3 # RGB

## Create Task Objects

In [None]:
# Select the tensor device: the first CUDA GPU when available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Generate Task DB or load from filesystem
print("Generating tasks and compute their alpha weights")
runner.populate_db()

# Perform top-m filtering; yields the alpha weights and the training task DB.
print("Filtering top-m alpha weights and training tasks")
alpha_weights, train_db = runner.top_m_filtering()

# Get target task from filesystem
# NOTE(review): runner.unpickle presumably deserializes with pickle -- only load
# files from a trusted location.
test_path = helper.get_task_dataset_path("test")
test_db = runner.unpickle(test_path)

# Get train data embeddings and build the training tasks from them.
print("Fetching train embeddings")
train_provider = dp.DataProvider("train", debug=False, verbose=False)
train_tr_size = Config.TRAINING_NUM_OF_EXAMPLES_PER_CLASS
train_val_size = Config.VALIDATION_NUM_OF_EXAMPLES_PER_CLASS
print("Generating training tasks")
train_tasks = th.generate_tasks(train_db, train_provider, train_tr_size, train_val_size, device)

# Get test data embeddings and build the test tasks from them.
print("Fetching test embeddings")
test_provider = dp.DataProvider("test", debug=False, verbose=False)
# 0 is passed as the training-size argument for the test tasks.
test_tr_size = 0
test_val_size = Config.TEST_VALIDATION_NUM_OF_EXAMPLES_PER_CLASS
print("Generating test tasks")
test_tasks = th.generate_tasks(test_db, test_provider, test_tr_size, test_val_size, device)

In [9]:
# Sanity-check the filtering result: the shape of the weight array and the
# number of generated training tasks.
print(alpha_weights.shape)
print(len(train_tasks))
# Show the entry at index 1, first as stored and then unwrapped to a Python
# scalar with .item().
second_weight = alpha_weights[1]
print(second_weight)
print(second_weight.item())

(100, 1)
100
[0.23874423]
0.2387442287581507


In [5]:
# Build the network under test, sized by the channel count and the total
# number of classes computed when the tasks were created.
test_net = TestNets.TestNet224x224(
    n_classes=LEN_CLASSES,
    n_channels_in=NUM_IN_CHANNELS,
)

In [None]:
# Meta-train test_net with MAML on the locally loaded TASKS for 20 meta-epochs.
# Signature notes carried over from the original cell (arguments not passed here):
#   inner_epochs: int = 1    -- not implemented yet
#   inner_lr: float = 0.001
#   outer_lr: float = 0.001
#   loss_function = torch.nn.CrossEntropyLoss()
MAML.maml_nn_classifier_learn(test_net=test_net, tasks=TASKS, meta_epochs=20)

In [None]:
# Repeat until convergence:
# Train test_net on the filtered training tasks together with their alpha
# weights (presumably one TASML epoch per call, going by the function name).
# NOTE(review): `TASML` is not bound by the import cell visible in this
# notebook -- confirm the module is imported before running this cell.
# Signature notes carried over from the original cell (arguments not passed here):
#   inner_epochs: int = 1    -- feature not implemented
#   inner_lr: float = 0.001
#   outer_lr: float = 0.001
#   loss_function = torch.nn.CrossEntropyLoss()
# Fix: the call previously ended with "):", which is a SyntaxError.
TASML.tasml_epoch_nn_classifier_learn(
    test_net=test_net,
    tasks=train_tasks,
    alpha_weights=alpha_weights,
)