In [1]:
import os

# The imports in the next cell (`shell.*`) are resolved from the working
# directory, so the notebook moves two levels up before importing.
# A bare relative chdir is not idempotent: re-running this cell in a live
# kernel would climb two *more* directories. Only move when the `shell`
# package directory is not already visible from the current directory.
# NOTE(review): assumes `shell/` sits directly in the target directory — TODO confirm.
if not os.path.isdir("shell"):
    os.chdir("../..")

In [2]:
'''
File: /grad_sanity_check2.py
Project: experiments
Created Date: Thursday September 7th 2023
Author: Long Le (vlongle@seas.upenn.edu)

Copyright (c) 2023 Long Le
'''
'''
File: /grad_sanity_check.py
Project: experiments
Created Date: Tuesday September 5th 2023
Author: Long Le (vlongle@seas.upenn.edu)

Copyright (c) 2023 Long Le
'''

"""
1. Does plain averaging destroy the structure of the learned weights?
2. In the toy domain, does using the Fisher information fix that?

Setting: 2 agents with 2 tasks. We would like to distill them into one model that solves both tasks.

Checklist
[] Copying works.
[] See how long it takes to bounce back.
[] Fisher information.

"""


from shell.fleet.grad.fisher_monograd import ModelFisherSyncAgent
from copy import deepcopy
from shell.datasets.datasets import get_dataset
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from shell.utils.utils import seed_everything, viz_embedding
import torch
import subprocess
import torch.nn as nn
import torch
from omegaconf import DictConfig
from shell.datasets.datasets import get_dataset
from shell.utils.utils import seed_everything
from pprint import pprint
from shell.fleet.network import TopologyGenerator
from shell.models.cnn_soft_lifelong_dynamic import CNNSoftLLDynamic
from shell.models.cnn import CNN
from shell.models.mlp import MLP
from shell.models.mlp_soft_lifelong_dynamic import MLPSoftLLDynamic
from shell.learners.er_dynamic import CompositionalDynamicER
from shell.learners.er_nocomponents import NoComponentsER
from shell.utils.experiment_utils import eval_net
from shell.utils.experiment_utils import setup_experiment
from sklearn.manifold import TSNE
import logging
import seaborn as sns
import pickle
logging.basicConfig(level=logging.INFO)

# Two agents trained separately.

# Task count shared by the network and dataset configurations below.
num_tasks = 10

# Network configuration.
net_cfg = dict(
    depth=2,
    layer_size=64,
    num_init_tasks=2,
    i_size=28,
    num_classes=2,
    num_tasks=num_tasks,
    dropout=0.0,
)

# Learner (agent) configuration.
agent_cfg = dict(
    memory_size=64,
    use_contrastive=False,
    save_dir='',
)

# Dataset configuration; unpacked into get_dataset(...) further down.
data_cfg = dict(
    dataset_name="mnist",
    num_tasks=num_tasks,
    num_train_per_task=128,
    num_val_per_task=102,
    remap_labels=True,
    use_contrastive=False,
    with_replacement=True,
)
def _swap_tasks(dataset, task_a, task_b):
    """Swap two tasks of ``dataset`` in place.

    Exchanges the per-task entries of ``trainset``, ``testset`` and ``valset``
    and the matching labels in ``class_sequence``.

    ``class_sequence`` is treated as a flat sequence holding the same number
    of labels for every task (the logged sequence of the recorded run has 20
    entries for 10 tasks), so whole per-task chunks are exchanged. Swapping
    single elements 0<->2 and 1<->3 would instead exchange the labels of
    task 0 and task 1. With one label per task this reduces to a plain
    element swap.
    """
    for split in (dataset.trainset, dataset.testset, dataset.valset):
        split[task_a], split[task_b] = split[task_b], split[task_a]

    seq = dataset.class_sequence
    per_task = len(seq) // len(dataset.trainset)
    chunk_a = slice(task_a * per_task, (task_a + 1) * per_task)
    chunk_b = slice(task_b * per_task, (task_b + 1) * per_task)
    # Copy one chunk out first: slices of an array are views, so a plain
    # tuple swap would read values that were already overwritten.
    saved = list(seq[chunk_a])
    seq[chunk_a] = seq[chunk_b]
    seq[chunk_b] = saved


print(num_tasks)
bob_dataset = get_dataset(**data_cfg)
print(len(bob_dataset.trainset))

# alice's dataset is the same as bob's, but tasks 0,1 are switched with 2,3
alice_dataset = deepcopy(bob_dataset)
for task_a, task_b in ((0, 2), (1, 3)):
    _swap_tasks(alice_dataset, task_a, task_b)

# alice = CompositionalDynamicER(MLPSoftLLDynamic(**net_cfg), **agent_cfg)
# bob = CompositionalDynamicER(MLPSoftLLDynamic(**net_cfg), **agent_cfg)

train_cfg = {'num_epochs': 10}
sharing_cfg = {}

# The first two positional arguments appear in the agent's log line as
# node_id and seed; they are followed by the dataset, the network class, the
# learner class, and net_cfg / agent_cfg / train_cfg / sharing_cfg.
# Alice works on the dataset whose tasks 0,1 were switched with 2,3.
alice = ModelFisherSyncAgent(0, 0, alice_dataset, MLPSoftLLDynamic, CompositionalDynamicER, net_cfg, agent_cfg,
                             train_cfg, sharing_cfg)

# Bob works on the original task order. He was previously handed
# alice_dataset as well, which gave both agents identical data and left
# bob_dataset unused.
bob = ModelFisherSyncAgent(1, 0, bob_dataset, MLPSoftLLDynamic, CompositionalDynamicER, net_cfg, agent_cfg,
                           train_cfg, sharing_cfg)

# synchronizing the random projection to make sure they're the same.
# These assignments rebind alice's attributes to bob's weight/bias objects
# (no copy), so both networks refer to the same objects afterwards.
alice.net.random_linear_projection.weight = bob.net.random_linear_projection.weight
alice.net.random_linear_projection.bias = bob.net.random_linear_projection.bias


# Train alice on tasks 0 and 1 (the argument is the task id, cf. the
# "training task: N" lines in the log).
alice.train(0)
alice.train(1)

# alice.model = alice.prepare_model()
# fisher = alice.prepare_fisher_diag()

# print(fisher)

# # training ...
# for task_id in range(2):
#     testloaders = {task: torch.utils.data.DataLoader(testset,
#                                                      batch_size=128,
#                                                      shuffle=False,
#                                                      num_workers=0,
#                                                      pin_memory=True,
#                                                      ) for task, testset in enumerate(bob_dataset.testset[:(task_id+1)])}

#     trainloader = torch.utils.data.DataLoader(bob_dataset.trainset[task_id],
#                                               batch_size=64,
#                                               shuffle=True,
#                                               num_workers=0,
#                                               pin_memory=True,
#                                               )
#     valloader = torch.utils.data.DataLoader(bob_dataset.valset[task_id],
#                                             batch_size=256,
#                                             shuffle=False,
#                                             num_workers=4,
#                                             pin_memory=True,
#                                             )

#     bob.train(trainloader, task_id=task_id, num_epochs=10, testloaders=testloaders,
#               valloader=valloader, save_freq=1)


# print('\n\n\n')
# for task_id in range(2):
#     testloaders = {task: torch.utils.data.DataLoader(testset,
#                                                      batch_size=128,
#                                                      shuffle=False,
#                                                      num_workers=0,
#                                                      pin_memory=True,
#                                                      ) for task, testset in enumerate(alice_dataset.testset[:(task_id+1+2)])}

#     trainloader = torch.utils.data.DataLoader(alice_dataset.trainset[task_id],
#                                               batch_size=64,
#                                               shuffle=True,
#                                               num_workers=0,
#                                               pin_memory=True,
#                                               )
#     valloader = torch.utils.data.DataLoader(alice_dataset.valset[task_id],
#                                             batch_size=256,
#                                             shuffle=False,
#                                             num_workers=4,
#                                             pin_memory=True,
#                                             )

#     alice.train(trainloader, task_id=task_id, num_epochs=10, testloaders=testloaders,
#                 valloader=valloader, save_freq=1)

# # print(alice.net)


# # manually plug bob weights into alice...
# alice.net.add_tmp_module(task_id=2)
# alice.net.add_tmp_module(task_id=2)

# # turn the dict testloaders into one combined_testloader
# # combined_testloader = torch.utils.data.DataLoader(
# #     torch.utils.data.ConcatDataset(
# #         [testset.dataset for testset in testloaders.values()]),
# #     batch_size=128,
# #     shuffle=False,
# #     num_workers=0,
# #     pin_memory=True,
# # )

# alice_testloaders = {task: torch.utils.data.DataLoader(testset,
#                                                        batch_size=128,
#                                                        shuffle=False,
#                                                        num_workers=0,
#                                                        pin_memory=True,
#                                                        ) for task, testset in enumerate(alice_dataset.testset[:(task_id+1)])}
# print(alice_testloaders)
# ewc = EWC(alice.net, alice_testloaders)
# print("precision matrices")
# print(ewc._precision_matrices)

# """
# TODO: re-examine EWC.
# """

# # # print(alice.net)
# # alice.net.components[2].weight = bob.net.components[0].weight
# # alice.net.components[3].weight = bob.net.components[1].weight

# # alice.net.components[2].bias = bob.net.components[0].bias
# # alice.net.components[3].bias = bob.net.components[1].bias

# # # bob has 2 modules and alice has 4 modules. Black out modules 1 and 2
# # # of alice, and copy over the structure of bob's modules 1 and 2 into alice's modules 3 and 4
# # # 1. set all structure to -inf
# # alice.net.structure[2].data = -np.inf * torch.ones(4, 2)
# # alice.net.structure[3].data = -np.inf * torch.ones(4, 2)

# # # 2. copy over the structure of bob's modules 1 and 2 into alice's modules 3 and 4
# # alice.net.structure[2].data[2:, :] = bob.net.structure[0].data
# # alice.net.structure[3].data[2:, :] = bob.net.structure[1].data

# # # need to copy over the decoder as well
# # alice.net.decoder[2].weight = bob.net.decoder[0].weight
# # alice.net.decoder[3].weight = bob.net.decoder[1].weight
# # alice.net.decoder[2].bias = bob.net.decoder[0].bias
# # alice.net.decoder[3].bias = bob.net.decoder[1].bias

# # # print(alice.net.structure[2])
# # print(eval_net(alice.net, testloaders))


INFO:root:Class sequence: [1 2 4 6 4 8 8 7 8 2 7 3 3 0 4 5 1 0 4 2]
INFO:root:task 0 :(128, 1, 28, 28)
INFO:root:task 1 :(128, 1, 28, 28)
INFO:root:task 2 :(128, 1, 28, 28)
INFO:root:task 3 :(128, 1, 28, 28)
INFO:root:task 4 :(128, 1, 28, 28)


10


INFO:root:task 5 :(128, 1, 28, 28)
INFO:root:task 6 :(128, 1, 28, 28)
INFO:root:task 7 :(128, 1, 28, 28)
INFO:root:task 8 :(128, 1, 28, 28)
INFO:root:task 9 :(128, 1, 28, 28)
INFO:root:Agent: node_id: 0, seed: 0


10


INFO:root:Agent: node_id: 1, seed: 1000
INFO:root:epochs: 0, training task: 0
INFO:root:	task: 0	loss: 0.695	acc: 0.502
INFO:root:	task: avg	loss: 0.695	acc: 0.502
INFO:root:epochs: 0, training task: 0
INFO:root:	task: 0	loss: 0.695	acc: 0.502
INFO:root:	task: avg	loss: 0.695	acc: 0.502
INFO:root:final components: 2
INFO:root:epochs: 0, training task: 1
INFO:root:	task: 0	loss: 0.695	acc: 0.502
INFO:root:	task: 1	loss: 0.694	acc: 0.492
INFO:root:	task: avg	loss: 0.695	acc: 0.497
INFO:root:epochs: 1, training task: 1
INFO:root:	task: 0	loss: 0.690	acc: 0.510
INFO:root:	task: 1	loss: 0.689	acc: 0.649
INFO:root:	task: avg	loss: 0.689	acc: 0.579
INFO:root:epochs: 2, training task: 1
INFO:root:	task: 0	loss: 0.686	acc: 0.518
INFO:root:	task: 1	loss: 0.682	acc: 0.534
INFO:root:	task: avg	loss: 0.684	acc: 0.526
INFO:root:epochs: 3, training task: 1
INFO:root:	task: 0	loss: 0.681	acc: 0.530
INFO:root:	task: 1	loss: 0.676	acc: 0.515
INFO:root:	task: avg	loss: 0.679	acc: 0.523
INFO:root:epochs: 

AttributeError: 'collections.OrderedDict' object has no attribute 'named_parameters'

In [3]:
# Workaround for the AttributeError recorded in the previous cell's output
# ('collections.OrderedDict' object has no attribute 'named_parameters'):
# point `alice.model` at the network object itself.
# NOTE(review): presumably prepare_fisher_diag() reads `alice.model` — confirm in ModelFisherSyncAgent.
alice.model = alice.net

In [4]:
# The result is a dict keyed by parameter name (e.g. 'components.0.weight')
# with one tensor per entry, as the displayed output of the next cell shows.
# NOTE(review): presumably the diagonal of the Fisher information — confirm in ModelFisherSyncAgent.
fisher = alice.prepare_fisher_diag()

In [5]:
# Display the whole dict.
# NOTE(review): this prints every tensor; consider showing only names and shapes to keep the saved output small.
fisher

{'components.0.weight': tensor([[2.1598e-09, 5.1071e-07, 2.2493e-07,  ..., 5.7379e-08, 1.5239e-08,
          1.1850e-08],
         [1.3021e-11, 6.1282e-10, 2.6502e-10,  ..., 3.6133e-09, 3.9926e-09,
          5.6248e-10],
         [9.1271e-12, 1.8683e-09, 1.3990e-10,  ..., 6.1934e-12, 1.0104e-09,
          1.2679e-10],
         ...,
         [1.8940e-09, 3.2291e-07, 1.1065e-07,  ..., 1.5047e-09, 6.3026e-08,
          2.4780e-08],
         [2.6458e-09, 2.6954e-08, 1.2392e-09,  ..., 1.0581e-09, 1.0557e-08,
          6.8356e-09],
         [4.4564e-10, 3.8562e-08, 7.2080e-11,  ..., 1.1137e-08, 1.1507e-08,
          9.7732e-09]], device='cuda:0'),
 'components.0.bias': tensor([2.7003e-06, 4.8618e-08, 7.3080e-09, 5.4496e-08, 2.8545e-08, 6.5768e-07,
         2.9294e-07, 6.7019e-07, 1.2863e-06, 1.9150e-06, 6.8796e-08, 1.1867e-06,
         0.0000e+00, 8.0132e-08, 0.0000e+00, 1.3524e-07, 2.6895e-07, 7.2265e-08,
         8.0536e-08, 7.0292e-07, 3.2829e-08, 1.7377e-06, 9.1122e-07, 1.6109e-06,
     

In [None]:
# Take the first item of the dataset behind alice's memory loader at index 0;
# the item unpacks into three values.
# NOTE(review): presumably (input, label, task id) or similar — TODO confirm
x, y, z = alice.agent.memory_loaders[0].dataset[0]

In [None]:
# Shapes of the three parts of the sample unpacked from the memory loader.
tuple(part.shape for part in (x, y, z))

In [None]:
# Inspect the second value of the sample unpacked from the memory loader.
y

In [None]:
# First entry of alice's replay buffer at index/key 1.
# NOTE(review): presumably buffers are keyed by task id — confirm against the learner.
alice.agent.replay_buffers[1][0]