In [1]:
import os
os.chdir("..")

In [2]:
import time
import pandas as pd
from lightning_lite.utilities.seed import seed_everything
from shell_data.dataset.dataset import get_train_val_test_subsets
import torch
import os
from shell_data.utils.config import (
    ShELLDataSharingConfig,
    DatasetConfig,
    TaskModelConfig,
    TrainingConfig,
    ExperienceReplayConfig,
    DataValuationConfig,
    RouterConfig,
    BoltzmanExplorationConfig,
)
from shell_data.utils.record import Record, snapshot_perf, snapshot_conf_mat
import numpy as np
from shell_data.shell_agent.shell_agent_classification import ShELLClassificationAgent
from itertools import combinations
import umap
from copy import deepcopy
from functools import partial


from shell_data.utils.utils import train
import matplotlib.pyplot as plt
# import mplcyberpunk
# plt.style.use("cyberpunk")
# plt.style.use('bmh')
import seaborn as sns
plt.style.use("fivethirtyeight")
# plt.style.use("xkcd")
sns.set_style("whitegrid")
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.manifold import TSNE
import random
# from IPython.core.interactiveshell import InteractiveShell
# InteractiveShell.ast_node_interactivity = "all"


os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(True)

SEED = 69
seed_everything(SEED)

2023-02-16 16:47:35.060157: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-16 16:47:36.189535: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-02-16 16:47:36.189700: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
Global seed set to 69


69

In [3]:
from shell_data.task_model.task_model import TaskModel, SupervisedLearningTaskModel
import torch.nn as nn

## Setup

In [4]:
num_cls_per_task = 5
n_agents = 2
num_task_per_life = 2
buffer_integration_size = 50000  # sample all!
batch_size = 32
size = 64
routing_method = "random"

dataset_name = "mnist"

In [5]:
train_subsets, val_subsets, test_subsets = get_train_val_test_subsets(
        dataset_name)

In [6]:
cfg = ShELLDataSharingConfig(
        n_agents=n_agents,
        dataset=DatasetConfig(
            name=dataset_name,
            train_size=size,
            test_size=1.0,
            val_size=size//2,
            num_task_per_life=num_task_per_life,
            num_cls_per_task=num_cls_per_task,
        ),
        task_model=TaskModelConfig(
            name=dataset_name,
        ),
        training=TrainingConfig(
            n_epochs=50,
            batch_size=batch_size,
            patience=1000,
            val_every_n_epoch=1,
        ),
        experience_replay=ExperienceReplayConfig(
            buffer_size=buffer_integration_size,
        ),
          router=RouterConfig(
            strategy=routing_method,  # control how the sender decides which data point to send
            num_batches=1,
            estimator_task_model=TaskModelConfig(
                name=dataset_name,
            ),
            n_heads=n_agents,
          ),
    )

In [7]:
receiver = ShELLClassificationAgent(
        train_subsets, val_subsets, test_subsets, cfg)

train_size: 64, num_cls_per_task: 5


In [8]:
sender_cfg = deepcopy(cfg)
# sender_cfg.dataset.train_size = 1.0 # all of the data for testing purposes...
sender = ShELLClassificationAgent(
        train_subsets, val_subsets, test_subsets, sender_cfg)

train_size: 64, num_cls_per_task: 5


In [11]:
receiver.ll_dataset.perm = torch.tensor([0, 1, 3, 4, 9,     2, 5, 6, 7, 8])
sender.ll_dataset.perm = torch.tensor([0, 4, 9, 2, 5,       1, 3, 6, 7, 8])

# intersection be 0, 4, 9 (and 2 out of distribution!)

receiver.init()
sender.init()

AttributeError: 'ShELLClassificationAgent' object has no attribute 'init'

## AutoEncoder

In [12]:
class MNISTAutoEncoder(nn.Module):
    """
    Auto-encoder in order to do 
    (1) image similarity search by computing the distance 
    between the latent representations.
    (2) outlier detection by computing the reconstruction error.
    """
    def __init__(self, num_classes) -> None:
        # use architecture here 
        # https://medium.com/dataseries/convolutional-autoencoder-in-pytorch-on-mnist-dataset-d65145c132ac
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=3, padding=1),  # b, 16, 10, 10
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),  # b, 16, 5, 5
            nn.Conv2d(16, 8, 3, stride=2, padding=1),  # b, 8, 3, 3
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=1),  # b, 8, 2, 2
        )
        # output size: 8 * 2 * 2 = 32
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, stride=2),  # b, 16, 5, 5
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),  # b, 8, 15, 15
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1),  # b, 1, 28, 28
            nn.Sigmoid(),
        )

        self.linear = nn.Linear(32, num_classes)


    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [13]:
from typing import (
    Optional,
    Union,
    Tuple,
)

class ReconstructionTaskModel(TaskModel):
    def __init__(self):
        self.net = MNISTAutoEncoder()
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=1e-3)
        self.device = "cuda"
        self.net.to(self.device)
    

    def train_step(self, batch: Tuple[torch.Tensor, torch.Tensor], head_id=None):
        x, _ = self.to_device(batch)
        self.net.train()
        self.optimizer.zero_grad()
        reconstructed = self.net(x)
        loss = self.criterion(reconstructed, x)
        loss.backward()
        self.optimizer.step()
        return loss.item()
    
    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor]):
        x, _ = self.to_device(batch)
        self.net.eval()
        with torch.no_grad():
            reconstructed = self.net(x)
            loss = self.criterion(reconstructed, x)
        return loss.item()
    
    def val_step(self, batch) -> float:
        return self.test_step(batch)

In [14]:
autoencoder = ReconstructionTaskModel()

TypeError: __init__() missing 1 required positional argument: 'num_classes'

In [None]:
receiver_data = receiver.ll_dataset.get_train_dataset(0, kind="all")
len(receiver_data)

In [None]:
sender_data = sender.ll_dataset.get_train_dataset(0, kind="all")
len(sender_data)

In [None]:
train_dataloader = torch.utils.data.DataLoader(
    receiver_data,
    batch_size=16,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

In [None]:
def val_func(early_stopping, global_step, epoch, train_loss, record):
    print(f"Epoch: {epoch}, global_step: {global_step}, train_loss: {train_loss}")
    # total_train_loss = train on the entire dataset

    total_dataloader = torch.utils.data.DataLoader(
        receiver_data, batch_size=len(receiver_data), shuffle=False, num_workers=0, pin_memory=True
    )
    total_batch = next(iter(total_dataloader))
    total_loss = autoencoder.test_step(total_batch)
    record.write({
        "epoch": epoch,
        "global_step": global_step,
        "train_loss": train_loss,
        "total_loss": total_loss
    })
    return early_stopping.step(total_loss)

In [None]:
record = Record("mnist_autoencoder.csv")

In [None]:
# """
# https://medium.com/dataseries/convolutional-autoencoder-in-pytorch-on-mnist-dataset-d65145c132ac
# Should train for about 30 epochs
# """
# if os.path.exists("mnist_autoencoder.pt"):
#     print("loading...")
#     autoencoder.net.load_state_dict(torch.load("mnist_autoencoder.pt")) 
# else:
#     print('training...')
#     train(autoencoder, train_dataloader, val_dataloader=None, n_epochs=500, val_every_n_epoch=1,
#         patience=20, delta=0.0,
#         val_func=partial(val_func, record=record), val_before=False);
#     torch.save(autoencoder.net.state_dict(), "mnist_autoencoder.pt")
#     record.save()

In [None]:
# df = pd.read_csv(f"mnist_autoencoder.csv")
# # shortened = df.iloc[15:]
# df.plot(x="epoch", y=["train_loss", "total_loss"])

In [None]:
# df["total_loss"].min()

In [None]:
# # pick some random images and see how the autoencoder reconstructs them
# import matplotlib.pyplot as plt
# import numpy as np

# num_images = 10

# rand_idx = np.random.randint(0, len(receiver_data), num_images)
# images = [receiver_data[i][0] for i in rand_idx]

# fig, axes = plt.subplots(nrows=2, ncols=num_images, figsize=(20, 4))
# for i, image in enumerate(images):
#     axes[0, i].imshow(image.squeeze(), cmap="gray")
#     axes[0, i].axis("off")
#     axes[1, i].imshow(autoencoder.net(image.unsqueeze(0).to(autoencoder.device)).squeeze().cpu().detach(), cmap="gray")
#     axes[1, i].axis("off")

# plt.show();

Reconstruction looks pretty good! Reconstruction error is about 0.02, which is what you'd expect for MNIST.

In [None]:
def to_features(X):
    return X.view(X.size(0), -1)

In [None]:

def clustering_reducer(reducer, X):
    return torch.tensor(reducer.transform(to_features(X).cpu().numpy()))

Maybe need other metric:
https://medium.com/analytics-vidhya/image-similarity-model-6b89a22e2f1a
(add tsne embedding on top of the encoded X) or use cosine similarity
instead of 2D distance!

In [None]:
def autoencoder_reducer(autoencoder, X):
    X = X.to(autoencoder.device)
    X = autoencoder.net.encoder(X)
    X = to_features(X)
    return X

In [None]:
import torch.nn.functional as F
def image_search(queries, database, reducer_callable, n_neighbors=10, p=2, metric="distance"):
    query_embed = reducer_callable(X=queries)
    database_embed = reducer_callable(X=database)
    if metric == "distance":
        dist = torch.cdist(query_embed, database_embed, p=p)
    elif metric == "cosine":
        dist = 1 - torch.stack([F.cosine_similarity(query_embed[i], database_embed) for i in range(len(query_embed))])
    else:
        raise ValueError(f"metric {metric} is not supported")
    closest_dist, closest_idx = torch.topk(dist, k=n_neighbors, dim=1, largest=False)
    return closest_dist, closest_idx

In [None]:
X_receiver = torch.stack([x for x, _ in receiver_data])
X_sender = torch.stack([x for x, _ in sender_data])

In [None]:
y_sender = torch.tensor([y for _, y in sender_data])
y_receiver = torch.tensor([y for _, y in receiver_data])

In [None]:
y_receiver

## Image search

In [None]:
"""
Pick some random images from receiver and plot the neighbors from sender
returned above
"""

def viz_image_search(queries, database, closest_idx, closest_dist):
    num_images = 10
    n_neighbors = 10
    x_idx = np.random.randint(0, len(queries), num_images)
    x = queries[x_idx]
    closest_idx = closest_idx[x_idx]

    fig, axes = plt.subplots(nrows=num_images, ncols=n_neighbors + 1, figsize=( 10, 12));
    for i, image in enumerate(x):
        axes[i, 0].imshow(image.cpu().squeeze(), cmap="gray");
        axes[i, 0].axis("off");
        for j in range(10):
            axes[i, j+1].imshow(database[closest_idx[i, j]].cpu().squeeze(), cmap="gray");
            axes[i, j+1].title.set_text(f"{closest_dist[i, j]:.2f}");
            axes[i, j+1].axis("off");
    
    plt.show();

In [None]:
def compute_img_search_quality(queries, database, query_y, database_y, closest_idx, n_neighbors=5):
    # for each query get the k nearest neighbors
    closest_idx = closest_idx.cpu()
    neighbor_idx = closest_idx[:, :n_neighbors]
    # compute the accuracy which is defined the fraction of database_y of neighbor_idx match with query_y
    neighbor_y = database_y[neighbor_idx]
    accuracy = (neighbor_y == query_y.unsqueeze(1)).sum(dim=1) / n_neighbors
    # accuracy per sample
    return accuracy

In [None]:
reducer = umap.UMAP(n_neighbors=10, min_dist=0.0, n_components=2, random_state=42)
reducer.fit(to_features(X_receiver).cpu())

In [None]:
# training
closest_dist, closest_idx = image_search(X_receiver, X_receiver, reducer_callable=partial(autoencoder_reducer, autoencoder=autoencoder),
metric="cosine")
viz_image_search(X_receiver, X_receiver, closest_idx, closest_dist)

In [None]:
ae_training_retrieval_acc = compute_img_search_quality(X_receiver, X_receiver, y_receiver, y_receiver, closest_idx)
print(ae_training_retrieval_acc.mean())

In [None]:
closest_dist, closest_idx = image_search(X_receiver, X_receiver, reducer_callable=partial(clustering_reducer, reducer=reducer))
viz_image_search(X_receiver, X_receiver, closest_idx, closest_dist)

In [None]:
clustering_training_retrieval_acc = compute_img_search_quality(X_receiver, X_receiver, y_receiver, y_receiver, closest_idx)
print(clustering_training_retrieval_acc.mean())

In [None]:
closest_dist, closest_idx = image_search(X_sender, X_receiver,  reducer_callable=partial(autoencoder_reducer, autoencoder=autoencoder), metric="cosine")
viz_image_search(X_sender, X_receiver, closest_idx, closest_dist)

In [None]:
ae_testing_retrieval_acc = compute_img_search_quality(X_sender, X_receiver, y_sender, y_receiver, closest_idx)
print(ae_testing_retrieval_acc.mean())
# filter out all rows that have 2 and 5 labels (OOD)
in_dist_testing_acc = ae_testing_retrieval_acc[(y_sender != 2) & (y_sender != 5)]
print(in_dist_testing_acc.mean())

In [None]:
closest_dist, closest_idx = image_search(X_sender, X_receiver,  reducer_callable=partial(clustering_reducer, reducer=reducer))
viz_image_search(X_sender, X_receiver, closest_idx, closest_dist)

In [None]:
clustering_testing_retrieval_acc = compute_img_search_quality(X_sender, X_receiver, y_sender, y_receiver, closest_idx)
print(clustering_testing_retrieval_acc.mean())
# filter out all rows that have 2 and 5 labels (OOD)
in_dist_testing_acc = clustering_testing_retrieval_acc[(y_sender != 2) & (y_sender != 5)]
print(in_dist_testing_acc.mean())

## Outliers

In [None]:
# get the loss distribution of the autoencoder on the training data
train_dataloader = torch.utils.data.DataLoader(
    receiver_data,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

losses = []
for batch in train_dataloader:
    losses.append(autoencoder.test_step(batch))
len(losses)

In [None]:
# identify the upper bound of the loss distribution as k-standard deviations away from the mean
k = 3
mean = np.mean(losses)
std = np.std(losses)
upper_bound = mean + k * std

In [None]:
# plot the loss distribution
plt.hist(losses, bins=100);
# draw a vertical line at the upper bound
plt.axvline(x=upper_bound, color="red");

In [None]:
def get_outliers(model, x, upper_bound):
    # x = (batch_size, 1, 28, 28)
    x = x.to(model.device)
    model.net.eval()
    criterion = nn.MSELoss(reduce=False)
    with torch.no_grad():
        reconstructed = model.net(x)
        loss = criterion(reconstructed, x).mean(dim=(1, 2, 3))
    return (loss > upper_bound).cpu(), loss.cpu()

In [None]:
def viz_outliers(outliers, losses, outlier_idx, X):
    n_samples = min(10, len(outliers))
    print(f"Found {len(outliers)} outliers")
    # plot the outliers
    fig, axes = plt.subplots(nrows=2, ncols=n_samples, figsize=(20, 2));
    for i in range(n_samples):
        if n_samples < len(outliers):
            random_idx = np.random.randint(0, len(outliers))
        else:
            random_idx = i
        axes[0, i].imshow(X[outlier_idx][random_idx].cpu().squeeze(), cmap="gray");
        axes[0, i].title.set_text(f"{losses[outlier_idx][random_idx]:.4f}");
        # plot the reconstructed image
        axes[1, i].imshow(autoencoder.net(X[outlier_idx][random_idx].unsqueeze(0).to(autoencoder.device)).squeeze().cpu().detach(), cmap="gray");
        axes[0, i].axis("off");
        axes[1, i].axis("off");

In [None]:
receiver_outliers_idx, losses = get_outliers(autoencoder, X_receiver, upper_bound)
receiver_outliers = X_receiver[receiver_outliers_idx]
viz_outliers(receiver_outliers, losses, receiver_outliers_idx, X_receiver)

In [None]:
torch.unique(y_receiver[receiver_outliers_idx], return_counts=True)

In [None]:
sender_outliers_idx, losses = get_outliers(autoencoder, X_sender, upper_bound)
sender_outliers = X_sender[sender_outliers_idx]
print(len(sender_outliers))
print(min(losses))

In [None]:
torch.unique(y_sender[sender_outliers_idx], return_counts=True)

In [None]:
viz_outliers(sender_outliers, losses, sender_outliers_idx, X_sender)

In [None]:
# try the ae image similarity again, this time with the outlier filter...
closest_dist, closest_idx = image_search(X_sender, X_receiver,  reducer_callable=partial(autoencoder_reducer, autoencoder=autoencoder))
viz_image_search(X_sender, X_receiver, closest_idx, closest_dist)

In [None]:
ae_testing_retrieval_acc = compute_img_search_quality(X_sender, X_receiver, y_sender, y_receiver, closest_idx)
print(ae_testing_retrieval_acc.mean())
# filter out all rows that have 2 and 5 labels (OOD)
out_dist_testing_acc = ae_testing_retrieval_acc[sender_outliers_idx]
in_dist_testing_acc = ae_testing_retrieval_acc[~sender_outliers_idx]
print(out_dist_testing_acc.mean())
print(in_dist_testing_acc.mean())

## Contrasive Loss

In [None]:
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.06, contrast_mode='one',
                 base_temperature=0.06):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature
    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)
        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float()
        else:
            mask = mask.float()
        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1),
            0
        )
        mask = mask * logits_mask
        # compute log_prob
        logits = torch.clamp(logits, min=-20)
        exp_logits = torch.exp(logits) * logits_mask
        exp_logits = torch.where(exp_logits > 1e-6, exp_logits, torch.tensor(0).float())
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-6)
        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()
        return loss

In [None]:
train_dataloader = torch.utils.data.DataLoader(
    receiver_data,
    batch_size=16,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

In [None]:
class MNISTContrastiveEncoder(nn.Module):
    """
    Auto-encoder in order to do 
    (1) image similarity search by computing the distance 
    between the latent representations.
    (2) outlier detection by computing the reconstruction error.
    """
    def __init__(self) -> None:
        # use architecture here 
        # https://medium.com/dataseries/convolutional-autoencoder-in-pytorch-on-mnist-dataset-d65145c132ac
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=3, padding=1),  # b, 16, 10, 10
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),  # b, 16, 5, 5
            nn.Conv2d(16, 8, 3, stride=2, padding=1),  # b, 8, 3, 3
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=1),  # b, 8, 2, 2
        )
         # output size: 8 * 2 * 2 = 32
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, stride=2),  # b, 16, 5, 5
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),  # b, 8, 15, 15
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1),  # b, 1, 28, 28
            nn.Sigmoid(),
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [None]:
import torchvision.transforms as transforms

network = MNISTContrastiveEncoder()
losses = []
cont_losses = []
rec_losses = []

n_epochs = 1

optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)

scl = SupConLoss()
rl = nn.MSELoss()

for i in range(n_epochs):
    for batch in train_dataloader:
        x, y = batch
        train_transform = transforms.Compose([
                        # transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
                        transforms.RandomResizedCrop(size=28, scale=(0.2, 1.)),
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomApply([
                            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                        ], p=0.8),
                        transforms.RandomGrayscale(p=0.2),
                    ])
        print(x.shape, train_transform(x).shape)

        encoded_transformed_images = network.encoder(train_transform(x))
        encoded_images = network.encoder(x)
        print(encoded_transformed_images.shape, encoded_images.shape)

        encoded_transformed_images = encoded_transformed_images.view(
            encoded_transformed_images.shape[0], -1)
        encoded_images = encoded_images.view(
            encoded_images.shape[0], -1)

        features = torch.cat(
            [encoded_transformed_images.unsqueeze(1), 
                encoded_images.unsqueeze(1)], dim=1)

        cont_loss = scl(features, y)
        rec_loss = rl(x, network(x))
        loss = cont_loss + rec_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        cont_losses.append(cont_loss.item())
        rec_losses.append(rec_loss.item())

In [None]:
plt.plot(losses);
plt.plot(cont_losses, color='red');
plt.plot(rec_losses, color='green');

In [None]:
plt.plot(cont_losses, color='red');
print(min(cont_losses))

In [None]:
plt.plot(rec_losses, color='green');
print(min(rec_losses))

In [None]:
def contrastive_callable(network, X):
    """
    This function is called by the ContrastiveModel
    """
    encoded_images = network.encoder(X).view(X.shape[0], -1)
    return encoded_images

In [None]:
# training
closest_dist, closest_idx = image_search(X_receiver, X_receiver, reducer_callable=partial(contrastive_callable, network=network),
metric="cosine")
viz_image_search(X_receiver, X_receiver, closest_idx, closest_dist)

In [None]:
contrastive_training_retrieval_acc = compute_img_search_quality(X_receiver, X_receiver, y_receiver, y_receiver, closest_idx)
print(contrastive_training_retrieval_acc.mean())

In [None]:
closest_dist, closest_idx = image_search(X_sender, X_receiver,  reducer_callable=partial(contrastive_callable, network=network),
metric="cosine")
viz_image_search(X_sender, X_receiver, closest_idx, closest_dist)

In [None]:
contrastive_testing_retrieval_acc = compute_img_search_quality(X_sender, X_receiver, y_sender, y_receiver, closest_idx)
print(contrastive_testing_retrieval_acc.mean())
# filter out all rows that have 2 and 5 labels (OOD)
in_dist_testing_acc = contrastive_testing_retrieval_acc[(y_sender != 2) & (y_sender != 5)]
print(in_dist_testing_acc.mean())

In [None]:
# get the loss distribution of the autoencoder on the training data
train_dataloader = torch.utils.data.DataLoader(
    receiver_data,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

losses = []
for batch in train_dataloader:
    x, y = batch
    train_transform = transforms.Compose([
                    transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomApply([
                        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                    ], p=0.8),
                    transforms.RandomGrayscale(p=0.2),
                ])

    encoded_transformed_images = network.encoder(train_transform(x))
    encoded_images = network.encoder(x)

    encoded_transformed_images = encoded_transformed_images.view(
        encoded_transformed_images.shape[0], -1)
    encoded_images = encoded_images.view(
        encoded_images.shape[0], -1)

    features = torch.cat(
        [encoded_transformed_images.unsqueeze(1), 
            encoded_images.unsqueeze(1)], dim=1)

    cont_loss = scl(features, y)
    rec_loss = rl(x, network(x))
    loss = cont_loss + rec_loss
    losses.append(loss.item())
len(losses)

In [None]:
# identify the upper bound of the loss distribution as k-standard deviations away from the mean
k = 2
mean = np.mean(losses)
std = np.std(losses)
upper_bound = mean + k * std

In [None]:
# plot the loss distribution
plt.hist(losses, bins=100);
# draw a vertical line at the upper bound
plt.axvline(x=upper_bound, color="red");

In [None]:
def get_outliers(model, x, upper_bound):
    # x = (batch_size, 1, 28, 28)
    # x = x.to(model.device)
    criterion = nn.MSELoss()
    model.eval()
    criterion = nn.MSELoss(reduce=False)
    with torch.no_grad():
        reconstructed = model(x)
        loss = criterion(reconstructed, x).mean(dim=(1, 2, 3))
    return (loss > upper_bound).cpu(), loss.cpu()

In [None]:
def viz_outliers(outliers, losses, outlier_idx, X):
    n_samples = min(10, len(outliers))
    print(f"Found {len(outliers)} outliers")
    # plot the outliers
    fig, axes = plt.subplots(nrows=2, ncols=n_samples, figsize=(20, 2));
    for i in range(n_samples):
        if n_samples < len(outliers):
            random_idx = np.random.randint(0, len(outliers))
        else:
            random_idx = i
        axes[0, i].imshow(X[outlier_idx][random_idx].cpu().squeeze(), cmap="gray");
        axes[0, i].title.set_text(f"{losses[outlier_idx][random_idx]:.4f}");
        # plot the reconstructed image
        axes[1, i].imshow(autoencoder.net(X[outlier_idx][random_idx].unsqueeze(0).to(autoencoder.device)).squeeze().cpu().detach(), cmap="gray");
        axes[0, i].axis("off");
        axes[1, i].axis("off");

In [None]:
receiver_outliers_idx, losses = get_outliers(network, X_receiver, upper_bound)
receiver_outliers = X_receiver[receiver_outliers_idx]
viz_outliers(receiver_outliers, losses, receiver_outliers_idx, X_receiver)

In [None]:
from pytorch_ood.detector import (
     ODIN,
     EnergyBased,
     KLMatching,
     Mahalanobis,
     MaxLogit,
     MaxSoftmax,
     ViM,
     MCD,
 )

In [None]:
receiver.load_model("./results/ood.pt")

In [None]:
# remap receiver_data so that y classes are 0, 1, 2, 3, 4
receiver.model.net.to("cpu")
y_receiver_remap = torch.zeros_like(y_receiver)
for i, y in enumerate(torch.unique(y_receiver)):
    y_receiver_remap[y_receiver == y] = i

receiver_data_remapped = torch.utils.data.TensorDataset(X_receiver, y_receiver_remap)

detector = Mahalanobis(receiver.model.net.features)
receiver_dataloader = torch.utils.data.DataLoader(
    receiver_data_remapped,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)
detector.fit(receiver_dataloader, device="cpu")

In [None]:
# training
with torch.no_grad():
    train_scores = detector(X_receiver)

In [None]:
# plot the scores
# higher values indicate more likely to be OOD
m = torch.mean(train_scores.cpu())
print("mean", m)
plt.hist(train_scores.cpu().numpy(), bins=100);

k = 0.5
upper_bound = m + k * torch.std(train_scores.cpu())
print("upper bound", upper_bound)
plt.axvline(m, color="green", label="mean")
plt.axvline(upper_bound, color="red", label="upper bound");

In [None]:
# testing
with torch.no_grad():
    test_scores = detector(X_sender)

In [None]:
sender_outliers_idx = (test_scores > upper_bound).cpu()
print("Number of outliers:", sender_outliers_idx.sum())
y_outlier = y_sender[sender_outliers_idx]
sender_outliers = X_sender[sender_outliers_idx]
print("Outlier labels:", torch.unique(y_outlier, return_counts=True))

In [None]:
outlier_classes = [2, 5]
num_total_outliers = len(outlier_classes) * 64
num_outliers_caught = 0
for clz in outlier_classes:
    num_outliers_caught += (y_outlier == clz).sum().item() 

acc = num_outliers_caught / len(y_outlier)
false_positive_rate = 1 - acc
detection_rate = num_outliers_caught / num_total_outliers
print("acc:", acc)
print("detecion rate:", detection_rate)

In [None]:
# sender_outliers = X_sender[sender_outliers_idx]
# print(len(sender_outliers))
# print(min(losses))

In [None]:
torch.unique(y_sender[sender_outliers_idx], return_counts=True)

In [None]:
viz_outliers(sender_outliers, losses, sender_outliers_idx, X_sender)

In [None]:
contrastive_testing_retrieval_acc = compute_img_search_quality(X_sender, X_receiver, y_sender, y_receiver, closest_idx)
print(contrastive_testing_retrieval_acc.mean())
# filter out all OOD rows (according to the estimated OOD detector)
in_dist_testing_acc = contrastive_testing_retrieval_acc[~sender_outliers_idx]
print(in_dist_testing_acc.shape, in_dist_testing_acc.mean())
out_dist_testing_acc = contrastive_testing_retrieval_acc[sender_outliers_idx]
print(out_dist_testing_acc.shape, out_dist_testing_acc.mean())
in_dist_testing_acc = contrastive_testing_retrieval_acc[(y_sender != 2) & (y_sender != 5)]
print("oracle", in_dist_testing_acc.shape, in_dist_testing_acc.mean())
out_dist_testing_acc = contrastive_testing_retrieval_acc[(y_sender == 2) | (y_sender == 5)]
print("oracle", out_dist_testing_acc.shape, out_dist_testing_acc.mean())