In [1]:
import os

# Move from this notebook's directory up to the project root (three levels up),
# presumably so that the `shell.*` imports in the next cell resolve — TODO confirm.
# The guard makes the cell idempotent: re-running it must not keep climbing the
# directory tree. A kernel restart clears the marker and resets the cwd.
if "_PROJECT_ROOT" not in globals():
    os.chdir("../../..")
    _PROJECT_ROOT = os.getcwd()

In [2]:
import omegaconf
from shell.utils.experiment_utils import *
from shell.utils.metric import *
import matplotlib.pyplot as plt
from shell.fleet.network import TopologyGenerator
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
import pandas as pd
from shell.fleet.fleet import Agent
from shell.fleet.data.data_utilize import *
from shell.fleet.data.recv import *

from sklearn.manifold import TSNE
from torchvision.utils import make_grid
import logging
# Route INFO-level records from the root logger (dataset class sequences, per-task
# shapes, per-epoch task losses/accuracies) into the cell output.
logging.basicConfig(level=logging.INFO)

In [3]:
# Seed before building the receiver's dataset so its class sequence is reproducible.
# `seed_everything` comes from a star import; presumably it seeds random/numpy/torch
# — confirm in shell.utils.
seed_everything(0)

In [4]:
# --- Experiment knobs shared by the cells below -------------------------------
use_contrastive = True   # forwarded to get_dataset(...) and the agent config
num_tasks = 5            # number of tasks in both the receiver and sender splits
num_epochs = 20          # epochs per training run (train_cfg and the joint _train call)

# Label-noise settings for the data received from the sender
# (forwarded to RandomFlippedDataset; `jerk` semantics are defined there).
use_flip = True
flip_probability = 0.1
jerk = True

In [5]:

# Receiver-side task split: this `dataset` is the one handed to RecvDataAgent.
data_cfg = dict(
    dataset_name="mnist",
    num_tasks=num_tasks,
    num_train_per_task=128,
    num_val_per_task=102,
    remap_labels=True,
    use_contrastive=use_contrastive,
)
dataset = get_dataset(**data_cfg)

INFO:root:Class sequence: [5 0 3 4 2 6 7 1 8 9]
INFO:root:task 0 :(128, 1, 28, 28)
INFO:root:task 1 :(128, 1, 28, 28)
INFO:root:task 2 :(128, 1, 28, 28)
INFO:root:task 3 :(128, 1, 28, 28)
INFO:root:task 4 :(128, 1, 28, 28)


In [6]:
# Re-seed with a different value before building the sender's dataset from the same
# config, so the sender gets a different class sequence than the receiver (compare
# the two "Class sequence" log lines). This presumably also changes the global RNG
# state seen by later cells.
seed_everything(7)
sender_dataset = get_dataset(**data_cfg)

INFO:root:Class sequence: [9 5 7 8 4 3 6 1 0 2]
INFO:root:task 0 :(128, 1, 28, 28)
INFO:root:task 1 :(128, 1, 28, 28)
INFO:root:task 2 :(128, 1, 28, 28)
INFO:root:task 3 :(128, 1, 28, 28)
INFO:root:task 4 :(128, 1, 28, 28)


In [7]:
# Network architecture. `num_classes` is taken from the dataset instead of being
# hard-coded, so the output head stays in sync with the task split if `num_tasks`
# is changed. This assumes `dataset.num_classes_per_task` is an int; the notebook
# also passes it as `num_classes` to RandomFlippedDataset. TODO confirm.
net_cfg = {
    'depth': 4,
    'layer_size': 64,
    'num_init_tasks': num_tasks,
    'i_size': 28,  # matches the (1, 28, 28) task shapes logged by get_dataset
    'num_classes': dataset.num_classes_per_task,
    'num_tasks': num_tasks,
    'dropout': 0.0,
}

# Agent settings. `memory_size` is presumably the replay-buffer size for the ER
# learners — confirm.
agent_cfg = {
    'memory_size': 64,
    'use_contrastive': use_contrastive,
    'save_dir': 'test',
}

In [8]:
## MODULAR
NetCls = MLPSoftLLDynamic
LearnerCls = CompositionalDynamicER

## MONOLITHIC
NetCls = MLP
LearnerCls = NoComponentsER

sharing_cfg = DictConfig({
    "scorer": "cross_entropy",
    "num_queries": 4,
    "query_score_threshold": 0.0,
})
train_cfg = {
    # "num_epochs": 40,
    "num_epochs": num_epochs,
}

agent = RecvDataAgent(0, 0, dataset,
                NetCls, LearnerCls, net_cfg, agent_cfg, train_cfg, 
                sharing_cfg)

INFO:root:Agent: node_id: 0, seed: 0


In [9]:
class ConcatenatedDataset(torch.utils.data.Dataset):
    """Chain several per-task datasets into one flat, indexable dataset.

    Global index ``i`` is routed to the dataset that owns it, and the item is
    returned as ``(x, y, task_id)``. ``task_id`` is the position of the owning
    dataset in ``datasets``. Negative indices count from the end, as for a list.
    Indices outside ``[-len(self), len(self))`` raise ``IndexError``.

    The member datasets are assumed not to change length after construction,
    because their cumulative sizes are cached.

    Args:
        datasets: Iterable of map-style datasets whose items unpack to ``(x, y)``.
    """

    def __init__(self, datasets):
        # Materialize so that a one-shot iterable (e.g. a generator) still works.
        self.datasets = list(datasets)
        # cumulative_sizes[k] is the total item count of datasets[0..k]. It lets
        # __getitem__ find the owning dataset with a binary search instead of a
        # linear scan on every access.
        self.cumulative_sizes = []
        total = 0
        for dataset in self.datasets:
            total += len(dataset)
            self.cumulative_sizes.append(total)

    def __len__(self):
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, index):
        from bisect import bisect_right

        length = len(self)
        if index < 0:
            # List-style negative indexing over the whole concatenation.
            index += length
        if not 0 <= index < length:
            # Reject explicitly. Otherwise an index past the end would be mapped
            # into the last dataset and silently return the wrong item.
            raise IndexError(
                f"index out of range for ConcatenatedDataset of length {length}"
            )

        # The first dataset whose cumulative size exceeds `index` owns the item.
        # bisect_right also skips correctly over empty datasets.
        task_id = bisect_right(self.cumulative_sizes, index)
        offset = self.cumulative_sizes[task_id - 1] if task_id > 0 else 0
        x, y = self.datasets[task_id][index - offset]
        return x, y, task_id


def concatenate_datasets(sender_trainset):
    """Flatten a list of per-task training sets into one dataset.

    Args:
        sender_trainset: Sequence of datasets, one per task (e.g.
            ``sender_dataset.trainset``).

    Returns:
        A ``ConcatenatedDataset`` whose items are
        ``(x, y_source, task_source_id)``.
    """
    combined = ConcatenatedDataset(sender_trainset)
    return combined

In [10]:
# Flatten the sender's per-task training sets into a single dataset.
combined_data = concatenate_datasets(sender_dataset.trainset)

# Map sender labels onto the receiver's label space. The two datasets were built
# with different class sequences; the argument names suggest a source->target
# remapping — confirm in utilize_global_labels.
transformed_data = utilize_global_labels(
    combined_data,
    source_class_sequence=sender_dataset.class_sequence,
    target_class_sequence=dataset.class_sequence,
    num_classes_per_task=dataset.num_classes_per_task,
)

# Optionally wrap with random label flipping, presumably to simulate a noisy sender.
if use_flip:
    flipped_data = RandomFlippedDataset(
        transformed_data,
        flip_probability=flip_probability,
        num_classes=dataset.num_classes_per_task,
        jerk=jerk,
    )

In [11]:
# before additional data
for t in range(num_tasks):
    agent.train(t)

INFO:root:epochs: 0, training task: 0
INFO:root:	task: 0	loss: 0.69148300	acc: 0.524
INFO:root:	task: avg	loss: 0.69148300	acc: 0.524
INFO:root:epochs: 0, training task: 0
INFO:root:	task: 0	loss: 0.69148300	acc: 0.524
INFO:root:	task: avg	loss: 0.69148300	acc: 0.524
INFO:root:epochs: 0, training task: 1
INFO:root:	task: 0	loss: 0.69148300	acc: 0.524
INFO:root:	task: 1	loss: 0.69385989	acc: 0.493
INFO:root:	task: avg	loss: 0.69267144	acc: 0.508
INFO:root:epochs: 0, training task: 1
INFO:root:	task: 0	loss: 0.69148300	acc: 0.524
INFO:root:	task: 1	loss: 0.69385989	acc: 0.493
INFO:root:	task: avg	loss: 0.69267144	acc: 0.508
INFO:root:epochs: 0, training task: 2
INFO:root:	task: 0	loss: 0.69148300	acc: 0.524
INFO:root:	task: 1	loss: 0.69385989	acc: 0.493
INFO:root:	task: 2	loss: 0.69981068	acc: 0.481
INFO:root:	task: avg	loss: 0.69505119	acc: 0.499
INFO:root:epochs: 0, training task: 2
INFO:root:	task: 0	loss: 0.69148300	acc: 0.524
INFO:root:	task: 1	loss: 0.69385989	acc: 0.493
INFO:root:

In [12]:
# Loader over all of the sender's relabelled data (label-flipped when enabled).
mega_loader = torch.utils.data.DataLoader(
    flipped_data if use_flip else transformed_data,
    batch_size=64,
    shuffle=True,
)

# Train as the final task and evaluate on the test set of every task up to it.
task_id = num_tasks - 1
testloaders = {
    task: torch.utils.data.DataLoader(
        testset,
        batch_size=256,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )
    for task, testset in enumerate(agent.dataset.testset[:task_id + 1])
}

# NOTE(review): calls the learner's private `_train` directly to bypass the
# agent's normal per-task training entry point.
agent.agent._train(
    mega_loader,
    start_epoch=0,
    num_epochs=num_epochs,
    task_id=task_id,
    testloaders=testloaders,
)

INFO:root:epochs: 1, training task: 4
INFO:root:	task: 0	loss: 0.65558522	acc: 0.524
INFO:root:	task: 1	loss: 0.66470089	acc: 0.793
INFO:root:	task: 2	loss: 0.64564788	acc: 0.670
INFO:root:	task: 3	loss: 0.66581983	acc: 0.732
INFO:root:	task: 4	loss: 0.67722365	acc: 0.703
INFO:root:	task: avg	loss: 0.66179549	acc: 0.684
INFO:root:epochs: 2, training task: 4
INFO:root:	task: 0	loss: 0.65172012	acc: 0.524
INFO:root:	task: 1	loss: 0.66742114	acc: 0.894
INFO:root:	task: 2	loss: 0.65261554	acc: 0.503
INFO:root:	task: 3	loss: 0.66239737	acc: 0.950
INFO:root:	task: 4	loss: 0.67684971	acc: 0.866
INFO:root:	task: avg	loss: 0.66220078	acc: 0.747
INFO:root:epochs: 3, training task: 4
INFO:root:	task: 0	loss: 0.64396297	acc: 0.525
INFO:root:	task: 1	loss: 0.65502059	acc: 0.949
INFO:root:	task: 2	loss: 0.64749241	acc: 0.526
INFO:root:	task: 3	loss: 0.65279588	acc: 0.971
INFO:root:	task: 4	loss: 0.67276792	acc: 0.916
INFO:root:	task: avg	loss: 0.65440795	acc: 0.778
INFO:root:epochs: 4, training task