In [2]:
import torch
from glob import glob
from ast import literal_eval
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List
from IPython.core.display import display, HTML
from sklearn.manifold import TSNE
from tqdm.notebook import tqdm
from time import time

In [3]:
# Work from the mammoth repo root: the project-local imports below
# (`backbone`, `utils`) presumably rely on this working directory.
REPO_DIR = os.path.normpath('Continual/Shared/mammoth/')
# The path is relative, so only chdir when we are not already there; otherwise
# re-running this cell would try to enter mammoth/Continual/Shared/mammoth/ and fail.
if not os.path.normpath(os.getcwd()).endswith(REPO_DIR):
    os.chdir(REPO_DIR)
# Root for datasets and checkpoints; set MAMMOTH_DATA to point elsewhere.
base_path = os.environ.get('MAMMOTH_DATA', '/nas/softechict-nas-2/efrascaroli/mammoth-data/')

In [4]:
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from backbone.ResNet18 import resnet18
from utils.spectral_analysis import laplacian_analysis

In [5]:
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed torch's global RNG so the random init (reported as "Init acc") and the
# shuffled buffer drawn later are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
model = resnet18(100).to(device)  # 100 is presumably the number of classes (CIFAR-100)
# ToTensor, then per-channel Normalize(mean, std) with the constants used for CIFAR-100.
C100_TRANSFORM = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5071, 0.4867, 0.4408),
                          (0.2675, 0.2565, 0.2761))])

def c100_test(model: torch.nn.Module, batch_size: int = 32) -> float:
    """Return the top-1 accuracy of ``model`` on the CIFAR-100 test split.

    Args:
        model: classifier whose forward pass returns per-class logits. Each
            batch is moved to the global ``device``, so the model must
            already live there.
        batch_size: evaluation batch size (default 32, as before).

    Returns:
        Fraction of test images whose arg-max logit equals the label, in [0, 1].

    The model is switched to eval mode for the pass. Afterwards it is restored
    to the train/eval mode it had on entry, even if evaluation raises.
    """
    ds = CIFAR100(base_path, transform=C100_TRANSFORM, train=False)
    dl = DataLoader(ds, batch_size)
    was_training = model.training
    model.eval()
    correct = 0
    try:
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            for x, y in dl:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                correct += (logits.argmax(1) == y).sum().item()
    finally:
        model.train(was_training)
    return correct / len(ds)

In [6]:
def load_buffer(size=100, seed=None):
    """Draw a random batch of ``size`` CIFAR-100 training images to use as a buffer.

    Args:
        size: number of samples to draw (default 100). If it exceeds the
            dataset length, the whole dataset is returned.
        seed: optional seed for the shuffle. When ``None`` (the default) the
            shuffle uses torch's global RNG, exactly as before.

    Returns:
        ``(x, y)``: the images (transformed with ``C100_TRANSFORM``) and their
        labels, both moved to ``device``.
    """
    ds = CIFAR100(base_path, transform=C100_TRANSFORM)  # train split (torchvision default)
    gen = None
    if seed is not None:
        # A private generator makes the draw reproducible without touching global RNG state.
        gen = torch.Generator()
        gen.manual_seed(seed)
    dl = DataLoader(ds, size, shuffle=True, generator=gen)
    x, y = next(iter(dl))
    return x.to(device), y.to(device)

In [7]:
print(f'Init acc: {c100_test(model)*100:.2f}')
ckpt_path = os.path.join(base_path, 'checkpoints', 'rs18_cifar100_new.pth')
# NOTE(review): torch.load unpickles its input -- only load trusted checkpoints.
# map_location lets a GPU-saved checkpoint load on the current device (incl. CPU).
saved_dict = torch.load(ckpt_path, map_location=device)
# strict=False tolerates key mismatches. Report them so that a checkpoint whose
# keys don't match the model does not pass silently. load_state_dict copies
# into the existing parameters, so the model stays on `device`.
load_result = model.load_state_dict(saved_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f'missing keys: {load_result.missing_keys}')
    print(f'unexpected keys: {load_result.unexpected_keys}')
print(f'Pre trained acc: {c100_test(model)*100:.2f}')

Init acc: 1.09
Pre trained acc: 67.59


In [8]:
buffer_size = 1000
x_buffer, y_buffer = load_buffer(buffer_size)
knn = 10       # neighbours per node for the graph built by laplacian_analysis (assumed)
n_evects = 20  # number of eigenpairs requested (n_pairs)
# c100_test leaves the model in train mode. A train-mode forward would make any
# BatchNorm layers use batch statistics and update their running stats, so
# extract features in eval mode and then restore the previous mode.
was_training = model.training
model.eval()
with torch.no_grad():
    latents = model.features(x_buffer)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # don't bill pending feature kernels to the timer
    t1 = time()
    energy, eigenvalues, eigenvectors, L, _ = laplacian_analysis(latents, norm_lap=True, knn=knn, n_pairs=n_evects)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before stopping the clock
    t2 = time()
model.train(was_training)
print(f'computing time: {t2-t1:.4f}')

computing time: 1.9502


In [9]:
# Sanity check: one eigenvector row per buffer sample. The graph presumably has
# buffer_size nodes, so eigenvectors is likely (buffer_size, n_evects) -- TODO confirm.
assert len(eigenvectors) == buffer_size
len(eigenvectors)


1000