In [1]:

import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageNet
from torchvision.models.vision_transformer import vit_b_16
from torchvision.models import ViT_B_16_Weights

from nsa import estimators, evaluators, utils


  from tqdm.autonotebook import tqdm


# Data Preparation

In [2]:
# Pick the compute device through the project helper, then build the
# preprocessing transform that ships with the ViT-B/16 IMAGENET1K_V1 weights.
device = utils.get_device()
DEFAULT_TRANSFORMATION = ViT_B_16_Weights.IMAGENET1K_V1.transforms()

# Bare trailing expression so the notebook displays the selected device.
device

'cuda'

In [3]:

# --- Tunable settings for the data pipeline ---------------------------------
# NOTE(review): absolute, machine-specific path; adjust for your environment.
IMAGENET_ROOT = "/datasets/imagenet"
SPLIT_SEED = 42          # seed for sampling the training subset
N_TRAIN_SAMPLES = 10000  # number of training images kept for `dl_train`
BATCH_SIZE = 128         # you can adjust the batch size
NUM_WORKERS = 12         # DataLoader worker processes

# Dedicated generator so the sampled subset does not depend on the global RNG state.
trng = torch.Generator()
trng.manual_seed(SPLIT_SEED)

ds_train_full = ImageNet(
    root=IMAGENET_ROOT,
    split="train",
    transform=DEFAULT_TRANSFORMATION,
)

# Clamp the subset size so the remainder length passed to random_split can
# never become negative if the dataset holds fewer than N_TRAIN_SAMPLES items.
n_train = min(N_TRAIN_SAMPLES, len(ds_train_full))

# random_split needs lengths that cover the whole dataset, hence the remainder.
# It is bound to a named throwaway instead of `_`, which IPython also uses for
# the last cell output.
ds_train, _ds_train_rest = random_split(
    ds_train_full,
    [n_train, len(ds_train_full) - n_train],
    generator=trng,
)

ds_val = ImageNet(
    root=IMAGENET_ROOT,
    split="val",
    transform=DEFAULT_TRANSFORMATION,
)

# Create the DataLoaders; both share the same settings and keep shuffle off so
# batches are produced in a fixed order.
loader_kwargs = dict(shuffle=False, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE)
dl_train = DataLoader(ds_train, **loader_kwargs)
dl_val = DataLoader(ds_val, **loader_kwargs)

# Model Loading

In [4]:
# Build ViT-B/16 with the IMAGENET1K_V1 pretrained weights and switch it to
# evaluation mode in the same expression (eval() returns the module itself).
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1).eval()

# Move the model to the selected device; as the cell's last expression this
# also renders the module repr in the notebook output.
model.to(device)

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

# Estimate Covariance

In [5]:
# Layer identifier passed as `layer=` to the nsa estimator and to both
# evaluators in the cells below.
# NOTE(review): the model repr in the previous output names the encoder blocks
# `encoder_layer_<i>`, so "encoder.layers.8" presumably relies on nsa resolving
# "8" positionally — confirm against the nsa layer-lookup code.
layer = "encoder.layers.8"

In [6]:
# Estimate the covariance matrix at `layer` with the nsa estimator, iterating
# over the training-subset loader on the selected device.
cov = estimators.estimate_cov_mat_at_layer(
    model=model, layer=layer, dataloader=dl_train, device=device
)

[layer=encoder.layers.8] estimating covariance matrix:   0%|          | 0/79 [00:00<?, ?it/s]

In [7]:
# Eigendecomposition of the estimated covariance via the project's `utils.eigh`.
# NOTE(review): the evaluators below receive `eigvecs` as `U` together with the
# rank cut-offs `arr_ks`; presumably they keep the first k columns, which would
# require eigenvectors sorted by decreasing eigenvalue — confirm the ordering
# returned by utils.eigh (torch.linalg.eigh itself returns ascending order).
eigvals, eigvecs = utils.eigh(cov)

# Evaluation: Reconstruction

In [8]:

# Ranks k at which the low-rank projection is evaluated.
recon_ranks = [128, 256, 768]

# Run the reconstruction-error evaluator at `layer` on the validation loader,
# using the covariance eigenvectors as the projection basis U.
eval_recon = evaluators.ReconstructionErrorWithLowRankProjectionEvaluator()
df1 = eval_recon.evaluate(
    model=model,
    layer=layer,
    dataloader=dl_val,
    U=eigvecs,
    arr_ks=recon_ranks,
    device=device,
)

[layer=encoder.layers.8] evaluating reconstruction error:   0%|          | 0/391 [00:00<?, ?it/s]

In [9]:
# Display the reconstruction results; the saved output has one row per k with
# the columns k, norm, recon_err, cossim and d.
df1

Unnamed: 0,k,norm,recon_err,cossim,d
0,128,19.223833,27.485861,0.391339,768
1,256,19.223833,16.84824,0.701184,768
2,768,19.223833,0.002044,1.0,768


# Evaluation: Accuracy

In [10]:

# Ranks k at which the low-rank projection is evaluated.
acc_ranks = [128, 256, 768]

# Run the accuracy evaluator (1000 classes) at `layer` on the validation
# loader, using the covariance eigenvectors as the projection basis U.
eval_acc = evaluators.AccuracyWithLowRankProjectionEvaluator(num_classes=1000)
df2 = eval_acc.evaluate(
    model=model,
    layer=layer,
    dataloader=dl_val,
    U=eigvecs,
    arr_ks=acc_ranks,
    device=device,
)

# Bare trailing expression so the notebook renders the results table.
df2

[layer=encoder.layers.8] evaluating accuracy:   0%|          | 0/391 [00:00<?, ?it/s]

Unnamed: 0,k,acc,xent,d
0,128,0.56048,2.639793,768
1,256,0.75188,1.262053,768
2,768,0.81066,0.838437,768


In [11]:
# Completion marker only: none of the cells shown here contains an assertion,
# so this signals that every cell above ran without raising, not that the
# results were checked against expected values.
print("all passed!")

all passed!
