In [None]:

import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageNet
from torchvision.models import ResNet18_Weights, resnet18

from nsa import estimators, evaluators, utils


  from tqdm.autonotebook import tqdm


# Data Preparation

In [2]:
# Resolve the compute device once; every later cell passes it along explicitly.
device = utils.get_device()

# Preprocessing pipeline bundled with the pretrained ResNet-18 weights, so the
# inputs are transformed the way those weights expect.
DEFAULT_TRANSFORMATION = ResNet18_Weights.IMAGENET1K_V1.transforms()

# Echo the selected device as the cell output.
device

'cuda'

In [3]:

# Tunable settings for this cell, gathered in one place so the train and
# validation pipelines cannot silently drift apart.
IMAGENET_ROOT = "/datasets/imagenet"  # NOTE(review): machine-specific absolute path
SPLIT_SEED = 42
NUM_TRAIN_SAMPLES = 10000  # size of the random training subset
BATCH_SIZE = 512  # adjust to fit the available memory
NUM_WORKERS = 12

# Dedicated, seeded generator so the subset drawn below is reproducible.
trng = torch.Generator()
trng.manual_seed(SPLIT_SEED)

# Full training split; kept under its own name so the subset below does not
# shadow it.
ds_train_full = ImageNet(
    root=IMAGENET_ROOT,
    split="train",
    transform=DEFAULT_TRANSFORMATION,
)

# Keep a fixed-size random subset of the training split; the rest is discarded.
ds_train, _ = random_split(
    ds_train_full,
    [NUM_TRAIN_SAMPLES, len(ds_train_full) - NUM_TRAIN_SAMPLES],
    generator=trng,
)

ds_val = ImageNet(
    root=IMAGENET_ROOT,
    split="val",
    transform=DEFAULT_TRANSFORMATION,
)

# Both loaders share the same settings and iterate in a fixed order (no shuffling).
loader_kwargs = dict(shuffle=False, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE)
dl_train = DataLoader(ds_train, **loader_kwargs)
dl_val = DataLoader(ds_val, **loader_kwargs)

# Model Loading

In [4]:
# ResNet-18 with the IMAGENET1K_V1 weights, placed on the selected device.
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).to(device)
# Switch to evaluation mode; `eval()` returns the module, so the cell still
# displays the architecture as its output.
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

# Estimate Covariance

In [5]:
# Estimate the covariance matrix at `layer1`, using batches from the training
# subset loader.
cov = estimators.estimate_cov_mat_at_layer(
    model=model, layer="layer1", dataloader=dl_train, device=device
)

[layer=layer1] estimating covariance matrix:   0%|          | 0/20 [00:00<?, ?it/s]

In [6]:
# Eigendecomposition of the estimated covariance matrix; `eigvecs` is used as
# the projection basis `U` in the evaluation cells.
# NOTE(review): the evaluators presumably keep the leading k columns of U, so
# confirm that utils.eigh orders eigenvectors by descending eigenvalue.
eigvals, eigvecs = utils.eigh(cov)

# Evaluation: Reconstruction

In [7]:

# Reconstruction evaluation at `layer1` on the validation loader, using the
# eigenvector basis at several values of k.
eval_recon = evaluators.ReconstructionErrorWithLowRankProjectionEvaluator()

df1 = eval_recon.evaluate(
    model=model, layer="layer1", dataloader=dl_val,
    U=eigvecs, arr_ks=[1, 5, 32, 64],
    device=device,
)

[layer=layer1] evaluating reconstruction error:   0%|          | 0/98 [00:00<?, ?it/s]

In [8]:
# Display the reconstruction results table (one row per evaluated k).
df1

Unnamed: 0,k,norm,recon_err,cossim,d
0,1,87.876297,92.015701,0.097613,64
1,5,87.876297,85.461563,0.362044,64
2,32,87.876297,18.675001,0.976147,64
3,64,87.876297,0.021808,1.0,64


# Evaluation: Accuracy

In [9]:

# Accuracy evaluation at `layer1` on the validation loader, using the same
# eigenvector basis and the same values of k as the reconstruction evaluation.
eval_acc = evaluators.AccuracyWithLowRankProjectionEvaluator(num_classes=1000)

df2 = eval_acc.evaluate(
    model=model, layer="layer1", dataloader=dl_val,
    U=eigvecs, arr_ks=[1, 5, 32, 64],
    device=device,
)

# Show the accuracy table as the cell output.
df2

[layer=layer1] evaluating accuracy:   0%|          | 0/98 [00:00<?, ?it/s]

Unnamed: 0,k,acc,xent,d
0,1,0.00086,7.666257,64
1,5,0.03558,6.746778,64
2,32,0.6799,1.320652,64
3,64,0.69758,1.246911,64


In [10]:
# Check the result tables before declaring success, so the final message is
# backed by something more than "no cell raised".
# NOTE(review): assumes df1/df2 are pandas DataFrames with the columns shown in
# their displayed output (k, recon_err, acc) — confirm against the evaluators.
recon_by_k = df1.sort_values("k")
assert len(recon_by_k) > 0, "reconstruction table is empty"
assert len(df2) > 0, "accuracy table is empty"
# Keeping more directions of the same basis should not increase the error.
assert recon_by_k["recon_err"].is_monotonic_decreasing, (
    "reconstruction error should not grow with k"
)
# Accuracy is a fraction, so every value must lie in [0, 1].
assert df2["acc"].between(0.0, 1.0).all(), "accuracy values must lie in [0, 1]"

print("all passed!")

all passed!
