In [1]:
from rich.logging import RichHandler
import logging
from lightning_lite.utilities.seed import seed_everything
from shell_data.dataset.dataset import get_train_val_test_subsets
import torch
import os
from shell_data.utils.config import (
    ShELLDataSharingConfig,
    DatasetConfig,
    TaskModelConfig,
    TrainingConfig,
)
from shell_data.shell_agent.shell_agent_classification import ShELLClassificationAgent

import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
import trimap
import umap
from sklearn.manifold import TSNE
# Fixed cuBLAS workspace size. NOTE(review): PyTorch's documentation for
# torch.use_deterministic_algorithms lists ":16:8" / ":4096:8" as the values
# needed for deterministic cuBLAS on CUDA >= 10.2 - confirm for this setup.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
# Ask PyTorch to use deterministic algorithms (and error out where none exists).
torch.use_deterministic_algorithms(True)
# Seed with 0; as the last expression of the cell, the returned seed is displayed.
seed_everything(0)

2023-01-17 19:08:39.900592: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-17 19:08:40.498664: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-01-17 19:08:40.498714: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
Global seed set to 0


0

In [None]:
def to_features(X):
    """Flatten every dimension after the first into a single feature axis.

    Args:
        X: tensor whose first dimension indexes samples.

    Returns:
        A tensor of shape ``(X.size(0), -1)``. It shares storage with ``X``
        when the layout allows it and is a copy otherwise.
    """
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. a batch
    # that was permuted from channels-first to channels-last.
    return X.reshape(X.size(0), -1)

def dist(X, X2, p=2):
    """Pairwise p-norm distances between the samples of ``X`` and of ``X2``.

    Both inputs are flattened with ``to_features`` and then handed to
    ``torch.cdist``; ``p`` selects the norm (2 by default).
    """
    flat_first = to_features(X)
    flat_second = to_features(X2)
    return torch.cdist(flat_first, flat_second, p=p)

def get_xy(dataset):
    """Return the entire dataset collated as one batch.

    A DataLoader whose batch size equals ``len(dataset)`` is built and its
    first (and only) batch is returned.
    """
    n_items = len(dataset)
    single_batch_loader = torch.utils.data.DataLoader(dataset, batch_size=n_items)
    batches = iter(single_batch_loader)
    return next(batches)

def cifar10_to_backbone_embedding(model, X, batch_size=128, device=None):
    """Run ``X`` through ``model`` in mini-batches and concatenate the outputs.

    Args:
        model: callable applied to each batch (typically an ``nn.Module``).
        X: tensor of inputs, indexed by sample along the first dimension.
            Must be non-empty, otherwise ``torch.cat`` receives an empty list.
        batch_size: number of samples per forward pass (default 128).
        device: device the batches are moved to. When ``None`` the device of
            the model's first parameter is used; if the model exposes no
            parameters, ``"cuda"`` is used when available and ``"cpu"``
            otherwise.

    Returns:
        The model outputs for all of ``X``, concatenated along dimension 0 and
        computed without gradient tracking.
    """
    if device is None:
        # Follow the model instead of assuming "cuda", so the function also
        # works for CPU models and on hosts without a GPU.
        params = getattr(model, "parameters", None)
        first_param = next(params(), None) if callable(params) else None
        if first_param is not None:
            device = first_param.device
        elif torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"

    dataloader = torch.utils.data.DataLoader(X, batch_size=batch_size)
    embeddings = []
    with torch.no_grad():
        for batch in dataloader:
            embeddings.append(model(batch.to(device)))
    return torch.cat(embeddings)

In [None]:
# Dataset selection and its train/val/test subsets.
dataset_name = "cifar10"
train_subsets, val_subsets, test_subsets = get_train_val_test_subsets(dataset_name)

# Experiment knobs reused by the config below.
size = 512
num_cls_per_task = 2

cfg = ShELLDataSharingConfig(
    n_agents=1,
    dataset=DatasetConfig(
        name=dataset_name,
        train_size=size,
        # Never ask for more validation samples than the smallest subset holds.
        val_size=min(size, min(len(d) for d in val_subsets)),
        num_task_per_life=1,
        num_cls_per_task=num_cls_per_task,
    ),
    task_model=TaskModelConfig(name=dataset_name),
    training=TrainingConfig(n_epochs=100),
)

In [None]:
# Names of the saved model and buffer loaded into the receiver.
model_name = f"{dataset_name}_128_2.pt"
buffer_name = f"{dataset_name}_buffer"

# Alternative pair, kept for reference:
# model_name = f"{dataset_name}_128_10.pt"
# buffer_name = f"{dataset_name}_10_buffer"

receiver = ShELLClassificationAgent(
    train_subsets,
    val_subsets,
    test_subsets,
    cfg,
    enable_validate_config=False,
)
receiver.load_model(model_name)
receiver.load_buffer(buffer_name)
# Report what was loaded: one length per buffer, then the recorded past tasks.
print("buffer len:", list(map(len, receiver.buffer.buffers)))
print("past tasks:", receiver.buffer.past_tasks)

In [None]:
# Second agent, constructed from the same subsets and config; nothing is loaded into it.
sender = ShELLClassificationAgent(
    train_subsets,
    val_subsets,
    test_subsets,
    cfg,
    enable_validate_config=False,
)
# Overwrite the dataset's `perm` attribute with the fixed pair (5, 7).
sender.ll_dataset.perm = torch.tensor([5, 7])  # should send 7

In [None]:
# Training dataset at index 0 for each agent (sender first, then receiver).
sender_data, receiver_data = (
    agent.ll_dataset.get_train_dataset(0) for agent in (sender, receiver)
)

In [None]:
# plot some random images from sender_data
n_samples = 5
# squeeze=False always returns a 2-D grid of axes, so taking row 0 yields a
# 1-D array even when n_samples == 1 (the default would return a bare Axes).
fig, ax = plt.subplots(1, n_samples, figsize=(10, 10), squeeze=False)
ax = ax[0]
for i in range(n_samples):
    random_idx = np.random.randint(0, len(sender_data))
    # Fetch the sample once so the image and its title come from the same lookup.
    sample = sender_data[random_idx]
    # permute(1, 2, 0) moves the first axis last, the layout imshow expects for
    # colour images (assumes channels-first samples - TODO confirm).
    ax[i].imshow(sample[0].permute(1, 2, 0))
    ax[i].set_title(f"Label: {sample[1]}")
plt.show()

In [None]:
# plot some random images from receiver_data
n_samples = 10
# squeeze=False always returns a 2-D grid of axes, so taking row 0 yields a
# 1-D array even when n_samples == 1 (the default would return a bare Axes).
fig, ax = plt.subplots(1, n_samples, figsize=(10, 10), squeeze=False)
ax = ax[0]
for i in range(n_samples):
    random_idx = np.random.randint(0, len(receiver_data))
    # Fetch the sample once so the image and its title come from the same lookup.
    sample = receiver_data[random_idx]
    # permute(1, 2, 0) moves the first axis last, the layout imshow expects for
    # colour images (assumes channels-first samples - TODO confirm).
    ax[i].imshow(sample[0].permute(1, 2, 0))
    ax[i].set_title(f"Label: {sample[1]}")
plt.show()

In [None]:
# Activations captured by forward hooks, keyed by the name given to get_features.
features = {"feats": []}


def get_features(name):
    """Build a forward hook that records a layer's outputs under ``name``.

    Args:
        name: key in the module-level ``features`` dict to append to.

    Returns:
        A ``hook(model, input, output)`` callable suitable for
        ``register_forward_hook``. Each call appends ``output.detach()`` to
        ``features[name]``. ``features`` is looked up when the hook fires, so
        rebinding the global to a fresh dict resets what is collected.
    """
    def hook(model, input, output):
        # setdefault lets a name that was never pre-seeded start its own list
        # instead of raising KeyError inside the forward pass.
        features.setdefault(name, []).append(output.detach())
    return hook

In [None]:
# Network held by the receiver's model; the bare name below displays it as the cell output.
backbone = receiver.model.net
backbone

In [None]:
# Re-running this cell would otherwise stack a second hook on the same layer
# and record every activation twice, so remove the previous hook first.
if "feats_hook_handle" in globals():
    feats_hook_handle.remove()
# Capture the outputs of the first entry of `backbone.fcs` under "feats".
feats_hook_handle = backbone.fcs[0].register_forward_hook(get_features("feats"))
feats_hook_handle

In [None]:
# backbone.conv_layers.register_forward_hook(get_features('feats'))

In [None]:
from umap.parametric_umap import ParametricUMAP

# Name of the dimensionality-reduction method to use; swap in one of the
# commented alternatives to compare.
clustering = "param_umap"
# clustering = "umap"
# clustering = "tsne"
# clustering = "trimap"

# Constructors are wrapped in lambdas so that only the selected reducer is built.
reducer_factories = {
    "trimap": lambda: trimap.TRIMAP(),
    "umap": lambda: umap.UMAP(),
    "param_umap": lambda: ParametricUMAP(),
    "pca": lambda: PCA(n_components=2),
    "tsne": lambda: TSNE(n_components=2, init="pca", random_state=0),
}
if clustering not in reducer_factories:
    # Fail loudly: otherwise `reducer` would be undefined on a fresh kernel or,
    # worse, silently keep the object left over from a previous run.
    raise ValueError(
        f"unknown clustering {clustering!r}; expected one of {sorted(reducer_factories)}"
    )
reducer = reducer_factories[clustering]()
reducer

In [None]:
# Alternative source for the receiver side (its training dataset instead of its buffer):
# receiver_x, receiver_y = get_xy(receiver_data)
# Request len(receiver.buffer) items from the receiver's buffer.
receiver_x, receiver_y = receiver.buffer.get_data(len(receiver.buffer))
# Sender side: the whole training dataset collated as a single batch.
sender_x, sender_y = get_xy(sender_data)

In [None]:
# Display both input shapes as the cell output.
receiver_x.shape, sender_x.shape

In [None]:
if dataset_name == "cifar10":
    def _hooked_activations(inputs):
        """Run ``inputs`` through ``backbone`` and return what the forward hook captured.

        The module-level ``features`` store is reset first, so only the
        activations produced by this call are concatenated and returned.

        Raises:
            RuntimeError: if the hook captured nothing, or captured a different
                number of rows than there are input samples.
        """
        global features
        features = {"feats": []}
        # The return value is ignored on purpose: the representation of
        # interest is the hooked layer's output, collected in `features`.
        cifar10_to_backbone_embedding(backbone, inputs)
        if not features["feats"]:
            raise RuntimeError(
                "no activations captured; register the forward hook on the backbone first"
            )
        captured = torch.cat(features["feats"])
        if captured.shape[0] != inputs.shape[0]:
            raise RuntimeError(
                f"captured {captured.shape[0]} activations for {inputs.shape[0]} inputs; "
                "is the forward hook registered more than once?"
            )
        return captured

    receiver_x = _hooked_activations(receiver_x)
    sender_x = _hooked_activations(sender_x)
    print("shape:", receiver_x.shape, sender_x.shape)

In [None]:
# Flatten each sample to a single feature vector (first dimension is kept).
receiver_x = to_features(receiver_x)
# print("shape:", receiver_x.shape)
sender_x = to_features(sender_x)
print("shape:", receiver_x.shape, sender_x.shape)

In [None]:
# Stack receiver rows first, then sender rows, for both features and labels.
joint_x = torch.cat([receiver_x, sender_x])
joint_y = torch.cat([receiver_y, sender_y])

# Receiver-only variant:
# joint_x = torch.cat([receiver_x])
# joint_y = torch.cat([receiver_y])
# Display both shapes as the cell output.
joint_x.shape, joint_y.shape

In [None]:
# https://umap-learn.readthedocs.io/en/latest/supervised.html
import inspect

joint_x_np = joint_x.cpu().numpy()
# Pass the labels only when the selected reducer's fit_transform declares a
# `y` parameter, so reducers without one do not fail on an unexpected keyword.
# NOTE(review): TriMap's fit_transform appears not to take `y` - verify against
# the installed version.
fit_kwargs = {}
if "y" in inspect.signature(reducer.fit_transform).parameters:
    fit_kwargs["y"] = joint_y.cpu().numpy()
joint_embedding = reducer.fit_transform(joint_x_np, **fit_kwargs)
# Show only the shape; the full array would flood the cell output.
joint_embedding.shape

In [None]:
# plot the embedding
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
# Not every reducer stores its result as `embedding_` (sklearn's PCA does not);
# fall back to projecting the fitted data in that case.
if hasattr(reducer, "embedding_"):
    embedding_2d = reducer.embedding_
else:
    embedding_2d = reducer.transform(joint_x.cpu().numpy())
# Work on a numpy copy of the labels so the masks index the numpy embedding directly.
labels_np = joint_y.cpu().numpy()
# Iterate over the labels that actually occur (sorted ascending) instead of
# over every sample index, and build each mask once.
for i in np.unique(labels_np):
    mask = labels_np == i
    ax.scatter(embedding_2d[mask, 0], embedding_2d[mask, 1], label=int(i))
ax.legend();

# fig, ax = plt.subplots(figsize=(10, 10))
# red_colors = ["red", "darkred", "salmon", "chocolate"]
# blue_colors = ["blue", "skyblue", "navy"]
# receiver_embed = reducer.embedding_[:len(receiver_y)]
# sender_embed = reducer.embedding_[len(receiver_y):]

# for i in range(10):
#     # plot receiver with bluish color and sender with reddish color
#     if len(receiver_y[receiver_y == i]) > 0:
#         ax.scatter(receiver_embed[receiver_y == i, 0], receiver_embed[receiver_y == i, 1],
#                         label=f"receiver y={i}", color=blue_colors.pop())

#     if len(sender_y[sender_y == i]) > 0:
#         ax.scatter(sender_embed[sender_y == i, 0], sender_embed[sender_y == i, 1],
#                         label=f"sender y={i}", color=red_colors.pop())
# ax.legend();