In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchmetrics

from pathlib import Path
from torchvision.models import convnext
from torchvision.transforms import v2 as transforms
from tqdm import tqdm

In [2]:
model = torchvision.models.convnext_tiny(weights=convnext.ConvNeXt_Tiny_Weights.DEFAULT)
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=

## Experiment: Classify a single layer through all its 32 heads' attention maps.

In [3]:
# Create dataset
LAYERS_INCLUDED = [1, 11, 21, 31]
NUM_LAYER_LLM = len(LAYERS_INCLUDED)

class AttentionMapDataset(torch.utils.data.Dataset):
    def __init__(self, root, annotations_file, transform=None, target_transform=None):
        self.root = Path(root)
        self.files = pd.read_csv(annotations_file)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.files) * NUM_LAYER_LLM

    def __getitem__(self, idx):
        f_idx = int(idx // NUM_LAYER_LLM)
        pt_path = self.root / self.files['filename'][f_idx]
        f = torch.load(pt_path)
        random_layer = idx % NUM_LAYER_LLM
        layer_idx = LAYERS_INCLUDED[random_layer]
        heads = f[layer_idx]
        if self.transform:
            heads = self.transform(heads)
        bucket_label = torch.tensor(random_layer)
        heads.unsqueeze(dim=0)
        bucket_label.unsqueeze(dim=0)
        return heads.to(torch.float32), bucket_label


In [4]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomChannelPermutation(),
    transforms.Normalize([0.5], [0.25]),
])

dataset = AttentionMapDataset(Path.home() / "Downloads/mmlu_output/", Path.home() / "Downloads/mmlu_attention_files_list.txt", transform=transform)

# Randomly split into training and test set
# indices = torch.randperm(len(dataset)).tolist()
indices = torch.arange(len(dataset)).tolist()
split_idx = int((len(indices) // NUM_LAYER_LLM) * 0.8) * NUM_LAYER_LLM
train_data = torch.utils.data.Subset(dataset, indices[:split_idx])
test_data = torch.utils.data.Subset(dataset, indices[split_idx:])
len(train_data), len(test_data)

(5552, 1388)

In [5]:
split_idx

5552

In [6]:
dataset[split_idx+3]

  f = torch.load(pt_path)


(tensor([[[ 2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
          [-1.1953,  1.1875, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
          [ 0.9375, -1.9062, -1.0234,  ..., -2.0000, -2.0000, -2.0000],
          ...,
          [ 0.3750, -1.9922, -2.0000,  ..., -0.8203, -2.0000, -2.0000],
          [ 1.0625, -2.0000, -2.0000,  ..., -1.9141, -1.1875, -2.0000],
          [-0.2578, -2.0000, -2.0000,  ..., -1.8125, -1.2734, -0.7734]],
 
         [[ 2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
          [ 1.8906, -1.8906, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
          [ 1.9688, -1.9844, -1.9844,  ..., -2.0000, -2.0000, -2.0000],
          ...,
          [ 0.9062, -1.9609, -1.9766,  ..., -1.8984, -2.0000, -2.0000],
          [ 0.4062, -1.9844, -1.9922,  ..., -1.9375, -1.9375, -2.0000],
          [ 0.2812, -1.9844, -1.9922,  ..., -1.9531, -1.9531, -1.8750]],
 
         [[ 2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
          [-2.0000,  2.0000,

In [7]:
train_data[-1]

  f = torch.load(pt_path)


(tensor([[[-0.2500, -1.9922, -1.7188,  ..., -1.9688, -1.9062, -1.9375],
          [ 0.1719, -1.9922, -1.8125,  ..., -1.8438, -1.7109, -2.0000],
          [-0.0234, -1.9766, -1.5000,  ..., -1.9375, -2.0000, -2.0000],
          ...,
          [ 0.7656, -1.8984, -0.8594,  ..., -2.0000, -2.0000, -2.0000],
          [ 1.4531, -1.4531, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
          [ 2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000, -2.0000]],
 
         [[-0.4844, -2.0000, -2.0000,  ..., -1.8906, -1.7656,  0.0938],
          [-1.1250, -2.0000, -2.0000,  ..., -1.8906,  1.0000, -2.0000],
          [-1.2500, -2.0000, -2.0000,  ...,  1.1562, -2.0000, -2.0000],
          ...,
          [-0.4922, -1.7500,  0.2344,  ..., -2.0000, -2.0000, -2.0000],
          [ 0.7188, -0.7188, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
          [ 2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000, -2.0000]],
 
         [[-1.2734, -2.0000, -1.9922,  ..., -1.5859, -1.4062, -1.3281],
          [-0.7266, -2.0000,

In [8]:
test_data[-1]

  f = torch.load(pt_path)


(tensor([[[ 0.8906, -1.9844, -1.9766,  ..., -2.0000, -2.0000, -1.3047],
          [-2.0000, -1.3672, -1.9297,  ..., -1.9922, -1.9922,  0.4844],
          [-2.0000, -2.0000, -1.4219,  ..., -1.9922, -1.9688,  0.8750],
          ...,
          [-2.0000, -2.0000, -2.0000,  ..., -1.8984, -1.9062,  1.8125],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000,  0.8438, -0.8438],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000,  2.0000]],
 
         [[ 0.0156, -1.8984, -1.9766,  ..., -2.0000, -2.0000, -0.2422],
          [-2.0000, -0.0156, -1.8828,  ..., -2.0000, -2.0000, -0.2422],
          [-2.0000, -2.0000, -1.0312,  ..., -2.0000, -2.0000,  0.9062],
          ...,
          [-2.0000, -2.0000, -2.0000,  ..., -1.0781, -1.9062,  0.9844],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -0.4844,  0.4844],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000,  2.0000]],
 
         [[-1.5625, -1.7500, -1.8594,  ..., -2.0000, -2.0000,  0.9688],
          [-2.0000, -1.1094,

In [9]:
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=1)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1)

In [10]:
# Change features
IN_CHANNELS = 32
NUM_CLASSES = NUM_LAYER_LLM
device = 'cuda'

accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=NUM_CLASSES).to(device)
recall = torchmetrics.Recall(task='multiclass', num_classes=NUM_CLASSES).to(device)
precision = torchmetrics.Precision(task='multiclass', num_classes=NUM_CLASSES).to(device)
auroc = torchmetrics.AUROC(task='multiclass', num_classes=NUM_CLASSES).to(device)

model.features[0][0] = nn.Conv2d(IN_CHANNELS, 96, kernel_size=(4,4), stride=(4,4))
model.classifier[2] = nn.Linear(768, NUM_CLASSES)
model.to(device)
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(32, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features

In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

num_epochs = 3
for epoch in range(num_epochs):
    tr_loss = 0
    nb_tr_steps = 0
    for x, y in tqdm(train_dataloader, desc='Training: '):
        x = x.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()

        try:
            logits = model(x)
        except Exception as e:
            print(x.shape, y)
            raise e

        loss = F.cross_entropy(logits, y)

        loss.backward()
        
        optimizer.step()

        tr_loss += loss.item()
        nb_tr_steps += 1
    scheduler.step()

    print(f'Epoch {epoch}: Training Loss:', tr_loss / nb_tr_steps)

    model.eval()
    val_acc = []
    val_prec = []
    val_rec = []
    val_auroc = []
    for x, y in tqdm(test_dataloader, desc='Validation'):
        x = x.to(device)
        y = y.to(device)
        
        logits = model(x)
        prediction = torch.argmax(logits, dim=-1)

        val_acc.append(accuracy(prediction, y))
        val_prec.append(precision(prediction, y))
        val_rec.append(recall(prediction, y))
        # val_auroc.append(auroc(prediction, y))

    print(f'Epoch {epoch}: Validation Accuracy:', sum(val_acc) / len(val_acc))
    print(f'Epoch {epoch}: Validation Precision:', sum(val_prec) / len(val_prec))
    print(f'Epoch {epoch}: Validation Recall:', sum(val_rec) / len(val_rec))
    # print(f'Epoch {epoch}: Validation AUROC:', sum(val_auroc) / len(val_auroc))

  f = torch.load(pt_path)

raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5552/5552 [02:52<00:00, 32.13it/s]

Epoch 0: Training Loss: 1.639172484960985



alidation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1388/1388 [00:20<00:00, 67.40it/s]

Epoch 0: Validation Accuracy: tensor(0.2500, device='cuda:0')
Epoch 0: Validation Precision: tensor(0.2500, device='cuda:0')
Epoch 0: Validation Recall: tensor(0.2500, device='cuda:0')



raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5552/5552 [02:41<00:00, 34.37it/s]

Epoch 1: Training Loss: 1.4343337866152912



alidation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1388/1388 [00:20<00:00, 66.92it/s]

Epoch 1: Validation Accuracy: tensor(0.2500, device='cuda:0')
Epoch 1: Validation Precision: tensor(0.2500, device='cuda:0')
Epoch 1: Validation Recall: tensor(0.2500, device='cuda:0')



raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5552/5552 [02:46<00:00, 33.33it/s]

Epoch 2: Training Loss: 1.396711540131988


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1388/1388 [00:20<00:00, 66.14it/s]

Epoch 2: Validation Accuracy: tensor(0.2500, device='cuda:0')
Epoch 2: Validation Precision: tensor(0.2500, device='cuda:0')
Epoch 2: Validation Recall: tensor(0.2500, device='cuda:0')





## Experiment: Classify Correctness

In [66]:
# Create dataset
LAYERS_INCLUDED = [1, 11, 21, 31]
NUM_LAYER_LLM = len(LAYERS_INCLUDED)

class AttentionMapCorrectnessDataset(torch.utils.data.Dataset):
    def __init__(self, root, annotations_df, transform=None, target_transform=None):
        self.root = Path(root)
        self.files = annotations_df
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        f_idx = int(idx)
        pt_path = self.root / self.files['filename'][f_idx]
        f = torch.load(pt_path)
        random_layer = 3
        layer_idx = LAYERS_INCLUDED[random_layer]  ## We can select the specific layer on which we want to test correctness.
        heads = f[layer_idx]
        if self.transform:
            heads = self.transform(heads)
        bucket_label = torch.tensor(self.files['prediction'][f_idx] == self.files['correct'][f_idx], dtype=torch.long)
        heads.unsqueeze(dim=0)
        bucket_label.unsqueeze(dim=0)
        return heads.to(torch.float32), bucket_label


In [68]:
dataset_df = pd.read_csv(Path.home() / "Downloads/mmlu_attention_files_list.txt")
dataset_df = dataset_df.sample(frac=1).reset_index()  # shuffle the dataset
dataset_df

Unnamed: 0,index,filename,prediction,correct
0,1435,auxiliary_train\science_elementary\attentions\...,A,A
1,1518,auxiliary_train\science_elementary\attentions\...,C,C
2,1619,auxiliary_train\science_elementary\attentions\...,B,D
3,449,auxiliary_train\arc_hard\attentions\399_attent...,B,B
4,59,auxiliary_train\arc_hard\attentions\1053_atten...,C,A
...,...,...,...,...
1730,361,auxiliary_train\arc_hard\attentions\319_attent...,A,A
1731,455,auxiliary_train\arc_hard\attentions\403_attent...,C,D
1732,895,auxiliary_train\arc_hard\attentions\7_attentio...,D,D
1733,798,auxiliary_train\arc_hard\attentions\712_attent...,B,D


In [69]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomChannelPermutation(),
    transforms.Normalize([0.5], [0.25]),
])

dataset = AttentionMapCorrectnessDataset(Path.home() / "Downloads/mmlu_output/", dataset_df, transform=None)

# Randomly split into training and test set
# indices = torch.randperm(len(dataset)).tolist()
indices = torch.arange(len(dataset)).tolist()
split_idx = int((len(indices)) * 0.8)
train_data = torch.utils.data.Subset(dataset, indices[:split_idx])
test_data = torch.utils.data.Subset(dataset, indices[split_idx:])
len(train_data), len(test_data)

(1388, 347)

In [70]:
split_idx

1388

In [71]:
dataset[split_idx-1]

  f = torch.load(pt_path)


(tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [6.5234e-01, 3.4570e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [5.4688e-01, 5.8838e-02, 3.9258e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [3.9453e-01, 1.5991e-02, 5.2795e-03,  ..., 3.4375e-01,
           0.0000e+00, 0.0000e+00],
          [2.3145e-01, 1.4267e-03, 3.1128e-03,  ..., 1.3477e-01,
           4.8828e-01, 0.0000e+00],
          [2.5586e-01, 8.9264e-04, 9.8419e-04,  ..., 4.1748e-02,
           7.5684e-02, 5.2734e-01]],
 
         [[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [7.1094e-01, 2.8906e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [8.7109e-01, 4.2236e-02, 8.7402e-02,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [4.9805e-01, 4.3640e-03, 1.2894e-03,  ..., 2.832

In [72]:
train_data[1]

  f = torch.load(pt_path)


(tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [6.5625e-01, 3.4180e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [5.5859e-01, 5.2979e-02, 3.8867e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [4.0234e-01, 1.1230e-02, 2.7161e-03,  ..., 3.7695e-01,
           0.0000e+00, 0.0000e+00],
          [2.4805e-01, 7.9727e-04, 5.0354e-03,  ..., 1.4062e-01,
           5.3125e-01, 0.0000e+00],
          [2.7734e-01, 8.7357e-04, 4.0283e-03,  ..., 5.5176e-02,
           6.9824e-02, 5.1953e-01]],
 
         [[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [7.1484e-01, 2.8516e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [8.7891e-01, 3.8574e-02, 8.2520e-02,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [5.1562e-01, 6.8359e-03, 1.2054e-03,  ..., 2.890

In [73]:
test_data[1]

  f = torch.load(pt_path)


(tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [6.5625e-01, 3.4180e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [5.5469e-01, 5.4199e-02, 3.9062e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [4.5312e-01, 4.7913e-03, 1.5564e-03,  ..., 3.5742e-01,
           0.0000e+00, 0.0000e+00],
          [2.4512e-01, 4.1008e-04, 1.8997e-03,  ..., 1.1572e-01,
           5.6641e-01, 0.0000e+00],
          [2.6953e-01, 5.1498e-04, 3.0060e-03,  ..., 5.4443e-02,
           8.0078e-02, 4.6680e-01]],
 
         [[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [7.1484e-01, 2.8516e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [8.7500e-01, 3.9551e-02, 8.4961e-02,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [5.7031e-01, 1.7014e-03, 4.1771e-04,  ..., 2.402

In [74]:
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=1)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1)

In [75]:
# Change features
IN_CHANNELS = 32
NUM_CLASSES = 2
device = 'cuda'

accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=NUM_CLASSES).to(device)
recall = torchmetrics.Recall(task='multiclass', num_classes=NUM_CLASSES).to(device)
precision = torchmetrics.Precision(task='multiclass', num_classes=NUM_CLASSES).to(device)
auroc = torchmetrics.AUROC(task='multiclass', num_classes=NUM_CLASSES).to(device)

model = torchvision.models.convnext_tiny(weights=convnext.ConvNeXt_Tiny_Weights.DEFAULT)
model.features[0][0] = nn.Conv2d(IN_CHANNELS, 96, kernel_size=(4,4), stride=(4,4))
model.classifier[2] = nn.Linear(768, NUM_CLASSES)
model.to(device)
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(32, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features

In [76]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

num_epochs = 3
for epoch in range(num_epochs):
    tr_loss = 0
    nb_tr_steps = 0
    for x, y in tqdm(train_dataloader, desc='Training: '):
        x = x.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()

        try:
            logits = model(x)
        except Exception as e:
            print(x.shape, y)
            raise e

        loss = F.cross_entropy(logits, y)

        loss.backward()
        
        optimizer.step()

        tr_loss += loss.item()
        nb_tr_steps += 1
    scheduler.step()

    print(f'Epoch {epoch}: Training Loss:', tr_loss / nb_tr_steps)

    model.eval()
    val_acc = []
    val_prec = []
    val_rec = []
    val_auroc = []
    for x, y in tqdm(test_dataloader, desc='Validation'):
        x = x.to(device)
        y = y.to(device)
        
        logits = model(x)
        prediction = torch.argmax(logits, dim=-1)

        val_acc.append(accuracy(prediction, y))
        val_prec.append(precision(prediction, y))
        val_rec.append(recall(prediction, y))
        # val_auroc.append(auroc(prediction, y))

    print(f'Epoch {epoch}: Validation Accuracy:', sum(val_acc) / len(val_acc))
    print(f'Epoch {epoch}: Validation Precision:', sum(val_prec) / len(val_prec))
    print(f'Epoch {epoch}: Validation Recall:', sum(val_rec) / len(val_rec))
    # print(f'Epoch {epoch}: Validation AUROC:', sum(val_auroc) / len(val_auroc))

  f = torch.load(pt_path)

raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1388/1388 [00:42<00:00, 32.84it/s]

Epoch 0: Training Loss: 1.244797036087731



alidation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 347/347 [00:05<00:00, 65.03it/s]

Epoch 0: Validation Accuracy: tensor(0.7666, device='cuda:0')
Epoch 0: Validation Precision: tensor(0.7666, device='cuda:0')
Epoch 0: Validation Recall: tensor(0.7666, device='cuda:0')



raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1388/1388 [00:41<00:00, 33.47it/s]

Epoch 1: Training Loss: 0.6366801301050282



alidation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 347/347 [00:05<00:00, 64.71it/s]

Epoch 1: Validation Accuracy: tensor(0.7666, device='cuda:0')
Epoch 1: Validation Precision: tensor(0.7666, device='cuda:0')
Epoch 1: Validation Recall: tensor(0.7666, device='cuda:0')



raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1388/1388 [00:42<00:00, 32.78it/s]

Epoch 2: Training Loss: 0.6105254676396977


Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 347/347 [00:05<00:00, 65.83it/s]

Epoch 2: Validation Accuracy: tensor(0.7666, device='cuda:0')
Epoch 2: Validation Precision: tensor(0.7666, device='cuda:0')
Epoch 2: Validation Recall: tensor(0.7666, device='cuda:0')





## Experiment: Classify Dataset

In [128]:
# Create dataset
LAYERS_INCLUDED = [1, 11, 21, 31]
NUM_LAYER_LLM = len(LAYERS_INCLUDED)

class AttentionMapDFDataset(torch.utils.data.Dataset):
    def __init__(self, root, annotations_df, transform=None, target_transform=None):
        self.root = Path(root)
        self.files = annotations_df
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        f_idx = int(idx)
        pt_path = self.root / self.files['filename'][f_idx]
        f = torch.load(pt_path)
        random_layer = 1
        layer_idx = LAYERS_INCLUDED[random_layer]
        heads = f[layer_idx]
        if self.transform:
            heads = self.transform(heads)
        bucket_label = torch.tensor(self.files['dataset'][f_idx] == 'science_elementary', dtype=torch.long)
        heads.unsqueeze(dim=0)
        bucket_label.unsqueeze(dim=0)
        return heads.to(torch.float32), bucket_label


In [129]:
dataset_df = pd.read_csv(Path.home() / "Downloads/mmlu_attention_files_list.txt")
dataset_df['dataset'] = dataset_df['filename'].apply(lambda filename: str(Path(filename).parent.parent.name))
dataset_es = dataset_df[dataset_df['dataset'] == 'science_elementary'].reset_index()
dataset_df = dataset_df.sample(frac=1).reset_index()
dataset_df

Unnamed: 0,index,filename,prediction,correct,dataset
0,757,auxiliary_train\arc_hard\attentions\676_attent...,D,C,arc_hard
1,47,auxiliary_train\arc_hard\attentions\1042_atten...,B,B,arc_hard
2,1582,auxiliary_train\science_elementary\attentions\...,C,C,science_elementary
3,1484,auxiliary_train\science_elementary\attentions\...,B,B,science_elementary
4,1121,auxiliary_train\science_elementary\attentions\...,B,B,science_elementary
...,...,...,...,...,...
1730,897,auxiliary_train\arc_hard\attentions\801_attent...,B,B,arc_hard
1731,1246,auxiliary_train\science_elementary\attentions\...,A,A,science_elementary
1732,1101,auxiliary_train\arc_hard\attentions\986_attent...,C,C,arc_hard
1733,310,auxiliary_train\arc_hard\attentions\273_attent...,A,A,arc_hard


In [130]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomChannelPermutation(),
    transforms.Normalize([0.5], [0.25]),
])

dataset = AttentionMapDFDataset(Path.home() / "Downloads/mmlu_output/", dataset_df, transform=transform)

# Randomly split into training and test set
# indices = torch.randperm(len(dataset)).tolist()
indices = torch.arange(len(dataset)).tolist()
split_idx = int((len(indices)) * 0.8)
train_data = torch.utils.data.Subset(dataset, indices[:split_idx])
test_data = torch.utils.data.Subset(dataset, indices[split_idx:])
len(train_data), len(test_data)

(1388, 347)

In [131]:
split_idx

1388

In [132]:
dataset[split_idx]

  f = torch.load(pt_path)


(tensor([[[-1.7500, -1.3828, -1.3438,  ..., -1.9766, -1.9844,  0.2344],
          [-2.0000, -1.7500, -1.2500,  ..., -1.9531, -1.9531,  0.5469],
          [-2.0000, -2.0000, -1.4609,  ..., -1.6094, -1.8750, -0.7812],
          ...,
          [-2.0000, -2.0000, -2.0000,  ..., -1.6719, -1.2812,  0.9531],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -1.8047,  1.7969],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000,  2.0000]],
 
         [[-1.8906, -1.9297, -1.6562,  ..., -1.9766, -1.9609, -0.0078],
          [-2.0000, -1.8438, -1.6797,  ..., -1.9531, -1.9297,  0.1094],
          [-2.0000, -2.0000, -1.6797,  ..., -1.8828, -1.9062, -0.2812],
          ...,
          [-2.0000, -2.0000, -2.0000,  ..., -1.8750, -1.8828,  1.7656],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -1.8828,  1.8906],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000,  2.0000]],
 
         [[-1.8906, -1.7734, -1.5234,  ..., -1.8125, -1.8672, -0.7109],
          [-2.0000, -1.8750,

In [133]:
train_data[42]

  f = torch.load(pt_path)


(tensor([[[-0.7578, -1.9766, -1.9688,  ..., -1.6797, -1.9297, -1.8984],
          [-0.5156, -1.9375, -1.9531,  ..., -1.7500, -1.8984, -2.0000],
          [-0.2578, -1.8359, -1.8438,  ..., -1.8750, -2.0000, -2.0000],
          ...,
          [ 1.6875, -1.9062, -1.7812,  ..., -2.0000, -2.0000, -2.0000],
          [ 1.9219, -1.9141, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
          [ 2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000, -2.0000]],
 
         [[-0.0781, -1.9453, -1.9531,  ..., -1.8516, -1.9219, -1.7812],
          [ 0.6719, -1.7812, -1.8125,  ..., -1.9141, -1.9375, -2.0000],
          [ 0.7031, -1.6875, -1.8594,  ..., -1.7656, -2.0000, -2.0000],
          ...,
          [ 1.7656, -1.9609, -1.8125,  ..., -2.0000, -2.0000, -2.0000],
          [ 1.5156, -1.5156, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
          [ 2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000, -2.0000]],
 
         [[ 0.5938, -1.9844, -1.9922,  ..., -1.8750, -1.7266, -1.7734],
          [ 0.5156, -1.9844,

In [134]:
test_data[1]

  f = torch.load(pt_path)


(tensor([[[-2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000,  2.0000],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -1.6562,  1.6562],
          [-2.0000, -2.0000, -2.0000,  ..., -1.6406, -1.8281,  1.4688],
          ...,
          [-2.0000, -2.0000, -1.4688,  ..., -1.9219, -1.8906, -0.0391],
          [-2.0000, -1.8203, -1.7969,  ..., -1.8906, -1.8984, -0.1953],
          [-1.8203, -1.9141, -1.8359,  ..., -1.9609, -1.9375, -0.1328]],
 
         [[-2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000,  2.0000],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -1.8438,  1.8438],
          [-2.0000, -2.0000, -2.0000,  ..., -1.6562, -1.8359,  1.5000],
          ...,
          [-2.0000, -2.0000, -1.9062,  ..., -1.9844, -1.9844,  0.6094],
          [-2.0000, -1.7812, -1.7734,  ..., -2.0000, -1.9844,  1.3750],
          [-1.5469, -1.2109, -1.9219,  ..., -2.0000, -1.9922,  0.5938]],
 
         [[-2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000,  2.0000],
          [-2.0000, -2.0000,

In [135]:
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=1)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1)

In [136]:
# Change features
IN_CHANNELS = 32
NUM_CLASSES = 2
device = 'cuda'

accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=NUM_CLASSES).to(device)
recall = torchmetrics.Recall(task='multiclass', num_classes=NUM_CLASSES).to(device)
precision = torchmetrics.Precision(task='multiclass', num_classes=NUM_CLASSES).to(device)
auroc = torchmetrics.AUROC(task='multiclass', num_classes=NUM_CLASSES).to(device)

model = torchvision.models.convnext_tiny(weights=convnext.ConvNeXt_Tiny_Weights.DEFAULT)
model.features[0][0] = nn.Conv2d(IN_CHANNELS, 96, kernel_size=(4,4), stride=(4,4))
model.classifier[2] = nn.Linear(768, NUM_CLASSES)
model.to(device)
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(32, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features

In [137]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

num_epochs = 3
for epoch in range(num_epochs):
    tr_loss = 0
    nb_tr_steps = 0
    for x, y in tqdm(train_dataloader, desc='Training: '):
        x = x.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()

        try:
            logits = model(x)
        except Exception as e:
            print(x.shape, y)
            raise e

        loss = F.cross_entropy(logits, y)

        loss.backward()
        
        optimizer.step()

        tr_loss += loss.item()
        nb_tr_steps += 1
    scheduler.step()

    print(f'Epoch {epoch}: Training Loss:', tr_loss / nb_tr_steps)

    model.eval()
    val_acc = []
    val_prec = []
    val_rec = []
    val_auroc = []
    for x, y in tqdm(test_dataloader, desc='Validation'):
        x = x.to(device)
        y = y.to(device)
        
        logits = model(x)
        prediction = torch.argmax(logits, dim=-1)

        val_acc.append(accuracy(prediction, y))
        val_prec.append(precision(prediction, y))
        val_rec.append(recall(prediction, y))
        # val_auroc.append(auroc(prediction, y))

    print(f'Epoch {epoch}: Validation Accuracy:', sum(val_acc) / len(val_acc))
    print(f'Epoch {epoch}: Validation Precision:', sum(val_prec) / len(val_prec))
    print(f'Epoch {epoch}: Validation Recall:', sum(val_rec) / len(val_rec))
    # print(f'Epoch {epoch}: Validation AUROC:', sum(val_auroc) / len(val_auroc))

  f = torch.load(pt_path)

raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1388/1388 [00:47<00:00, 29.31it/s]

Epoch 0: Training Loss: 1.287624101500548



alidation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 347/347 [00:05<00:00, 64.92it/s]

Epoch 0: Validation Accuracy: tensor(0.6340, device='cuda:0')
Epoch 0: Validation Precision: tensor(0.6340, device='cuda:0')
Epoch 0: Validation Recall: tensor(0.6340, device='cuda:0')



raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1388/1388 [00:41<00:00, 33.54it/s]

Epoch 1: Training Loss: 1.1428112855701176



alidation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 347/347 [00:05<00:00, 63.42it/s]

Epoch 1: Validation Accuracy: tensor(0.6340, device='cuda:0')
Epoch 1: Validation Precision: tensor(0.6340, device='cuda:0')
Epoch 1: Validation Recall: tensor(0.6340, device='cuda:0')



raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1388/1388 [00:40<00:00, 34.00it/s]

Epoch 2: Training Loss: 0.655031361136725


Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 347/347 [00:05<00:00, 64.97it/s]

Epoch 2: Validation Accuracy: tensor(0.6340, device='cuda:0')
Epoch 2: Validation Precision: tensor(0.6340, device='cuda:0')
Epoch 2: Validation Recall: tensor(0.6340, device='cuda:0')





## Experiment: Classify dataset using all layers

In [142]:
# Create dataset
LAYERS_INCLUDED = [1, 11, 21, 31]
NUM_LAYER_LLM = len(LAYERS_INCLUDED)

class AttentionMapBatchDataset(torch.utils.data.Dataset):
    def __init__(self, root, annotations_df, transform=None, target_transform=None):
        self.root = Path(root)
        self.files = annotations_df
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        f_idx = int(idx)
        pt_path = self.root / self.files['filename'][f_idx]
        f = torch.load(pt_path)
        heads = f.flatten(0, 1)
        if self.transform:
            heads = self.transform(heads)
        bucket_label = torch.tensor(self.files['dataset'][f_idx] == 'science_elementary', dtype=torch.long)
        # heads = heads.unsqueeze(dim=0)
        # bucket_labels = bucket_label.unsqueeze(dim=0)
        return heads.to(torch.float32), bucket_label


In [143]:
dataset_df = pd.read_csv(Path.home() / "Downloads/mmlu_attention_files_list.txt")
dataset_df['dataset'] = dataset_df['filename'].apply(lambda filename: str(Path(filename).parent.parent.name))
dataset_es = dataset_df[dataset_df['dataset'] == 'science_elementary'].reset_index()
dataset_df = dataset_df.sample(frac=1).reset_index()
dataset_df

Unnamed: 0,index,filename,prediction,correct,dataset
0,1052,auxiliary_train\arc_hard\attentions\941_attent...,D,C,arc_hard
1,467,auxiliary_train\arc_hard\attentions\414_attent...,C,C,arc_hard
2,1660,auxiliary_train\science_elementary\attentions\...,C,C,science_elementary
3,1174,auxiliary_train\science_elementary\attentions\...,A,A,science_elementary
4,1368,auxiliary_train\science_elementary\attentions\...,C,C,science_elementary
...,...,...,...,...,...
1730,1325,auxiliary_train\science_elementary\attentions\...,D,D,science_elementary
1731,1553,auxiliary_train\science_elementary\attentions\...,A,A,science_elementary
1732,818,auxiliary_train\arc_hard\attentions\730_attent...,A,A,arc_hard
1733,159,auxiliary_train\arc_hard\attentions\137_attent...,B,B,arc_hard


In [144]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomChannelPermutation(),
    transforms.Normalize([0.5], [0.25]),
])

dataset = AttentionMapBatchDataset(Path.home() / "Downloads/mmlu_output/", dataset_df, transform=None)

# Randomly split into training and test set
# indices = torch.randperm(len(dataset)).tolist()
indices = torch.arange(len(dataset)).tolist()
split_idx = int((len(indices)) * 0.8)
train_data = torch.utils.data.Subset(dataset, indices[:split_idx])
test_data = torch.utils.data.Subset(dataset, indices[split_idx:])
len(train_data), len(test_data)

(1384, 351)

In [145]:
split_idx

1384

In [146]:
dataset[split_idx-42]

  f = torch.load(pt_path)


(tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [6.9531e-01, 3.0469e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [5.3125e-01, 4.1992e-01, 5.0293e-02,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [6.7139e-03, 1.2398e-05, 7.8201e-05,  ..., 2.6562e-01,
           0.0000e+00, 0.0000e+00],
          [4.9805e-02, 6.2943e-04, 5.3787e-04,  ..., 4.8047e-01,
           8.0078e-02, 0.0000e+00],
          [2.0294e-03, 5.0068e-06, 8.0466e-06,  ..., 7.6953e-01,
           5.1025e-02, 1.7944e-02]],
 
         [[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [9.0625e-01, 9.5215e-02, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [8.0078e-01, 1.7871e-01, 2.2461e-02,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [4.8828e-02, 6.0425e-03, 4.6387e-03,  ..., 4.052

In [147]:
train_data[0]

  f = torch.load(pt_path)


(tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [6.9531e-01, 3.0469e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [5.3125e-01, 4.1992e-01, 5.0293e-02,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [2.5940e-03, 2.5511e-05, 3.8719e-04,  ..., 1.6895e-01,
           0.0000e+00, 0.0000e+00],
          [2.8687e-02, 3.1090e-04, 2.0294e-03,  ..., 3.9258e-01,
           8.6426e-02, 0.0000e+00],
          [7.2861e-04, 2.6822e-07, 2.6584e-05,  ..., 7.1484e-01,
           5.8105e-02, 2.0630e-02]],
 
         [[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [9.0625e-01, 9.5215e-02, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [8.0078e-01, 1.7871e-01, 2.2461e-02,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [4.7119e-02, 3.6926e-03, 4.4250e-03,  ..., 5.590

In [148]:
test_data[0][0]

  f = torch.load(pt_path)


tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [6.9531e-01, 3.0469e-01, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [5.3125e-01, 4.1992e-01, 5.0293e-02,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         ...,
         [1.5747e-02, 2.0409e-04, 4.9210e-04,  ..., 2.9102e-01,
          0.0000e+00, 0.0000e+00],
         [7.0312e-02, 8.0872e-04, 5.9128e-04,  ..., 4.9609e-01,
          8.2520e-02, 0.0000e+00],
         [5.7373e-03, 2.8968e-05, 4.7386e-06,  ..., 7.9688e-01,
          5.0537e-02, 1.8188e-02]],

        [[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [9.0625e-01, 9.5215e-02, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [8.0078e-01, 1.7676e-01, 2.2461e-02,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         ...,
         [7.3242e-02, 1.7700e-02, 4.6082e-03,  ..., 3.5400e-02,
          0.000

In [149]:
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=1)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1)

In [150]:
# Change features
IN_CHANNELS = 1024
NUM_CLASSES = 2
device = 'cuda'

accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=NUM_CLASSES).to(device)
recall = torchmetrics.Recall(task='multiclass', num_classes=NUM_CLASSES).to(device)
precision = torchmetrics.Precision(task='multiclass', num_classes=NUM_CLASSES).to(device)
auroc = torchmetrics.AUROC(task='multiclass', num_classes=NUM_CLASSES).to(device)

model = torchvision.models.convnext_tiny()
model.features[0][0] = nn.Conv2d(IN_CHANNELS, 96, kernel_size=(4,4), stride=(4,4))
model.classifier[2] = nn.Linear(768, NUM_CLASSES)
model.to(device)
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(1024, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_featur

In [154]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.25)

num_epochs = 6
for epoch in range(num_epochs):
    tr_loss = 0
    nb_tr_steps = 0
    for x, y in tqdm(train_dataloader, desc='Training: '):
        x = x.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()

        try:
            logits = model(x)
        except Exception as e:
            print(x.shape, y)
            raise e

        loss = F.cross_entropy(logits, y)

        loss.backward()
        
        optimizer.step()

        tr_loss += loss.item()
        nb_tr_steps += 1
    scheduler.step()

    print(f'Epoch {epoch}: Training Loss:', tr_loss / nb_tr_steps)

    model.eval()
    val_acc = []
    val_prec = []
    val_rec = []
    val_auroc = []
    for x, y in tqdm(test_dataloader, desc='Validation'):
        x = x.to(device)
        y = y.to(device)
        
        logits = model(x)
        prediction = torch.argmax(logits, dim=-1)

        val_acc.append(accuracy(prediction, y))
        val_prec.append(precision(prediction, y))
        val_rec.append(recall(prediction, y))
        # val_auroc.append(auroc(prediction, y))

    print(f'Epoch {epoch}: Validation Accuracy:', sum(val_acc) / len(val_acc))
    print(f'Epoch {epoch}: Validation Precision:', sum(val_prec) / len(val_prec))
    print(f'Epoch {epoch}: Validation Recall:', sum(val_rec) / len(val_rec))
    # print(f'Epoch {epoch}: Validation AUROC:', sum(val_auroc) / len(val_auroc))

  f = torch.load(pt_path)

raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1384/1384 [00:42<00:00, 32.88it/s]

Epoch 0: Training Loss: 1.1579424295978606



alidation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 351/351 [00:05<00:00, 62.07it/s]

Epoch 0: Validation Accuracy: tensor(0.3276, device='cuda:0')
Epoch 0: Validation Precision: tensor(0.3276, device='cuda:0')
Epoch 0: Validation Recall: tensor(0.3276, device='cuda:0')



raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1384/1384 [00:43<00:00, 31.48it/s]

Epoch 1: Training Loss: 1.6788336689970602



alidation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 351/351 [00:05<00:00, 64.11it/s]

Epoch 1: Validation Accuracy: tensor(0.3276, device='cuda:0')
Epoch 1: Validation Precision: tensor(0.3276, device='cuda:0')
Epoch 1: Validation Recall: tensor(0.3276, device='cuda:0')



raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1384/1384 [00:40<00:00, 34.36it/s]

Epoch 2: Training Loss: 0.811341903745736



alidation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 351/351 [00:05<00:00, 64.64it/s]

Epoch 2: Validation Accuracy: tensor(0.5954, device='cuda:0')
Epoch 2: Validation Precision: tensor(0.5954, device='cuda:0')
Epoch 2: Validation Recall: tensor(0.5954, device='cuda:0')



raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1384/1384 [00:41<00:00, 33.26it/s]

Epoch 3: Training Loss: 0.6798322808584252



alidation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 351/351 [00:05<00:00, 64.70it/s]

Epoch 3: Validation Accuracy: tensor(0.5954, device='cuda:0')
Epoch 3: Validation Precision: tensor(0.5954, device='cuda:0')
Epoch 3: Validation Recall: tensor(0.5954, device='cuda:0')



raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1384/1384 [00:43<00:00, 31.74it/s]

Epoch 4: Training Loss: 0.6448912463134768



alidation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 351/351 [00:05<00:00, 62.56it/s]

Epoch 4: Validation Accuracy: tensor(0.5954, device='cuda:0')
Epoch 4: Validation Precision: tensor(0.5954, device='cuda:0')
Epoch 4: Validation Recall: tensor(0.5954, device='cuda:0')



raining: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1384/1384 [00:40<00:00, 34.18it/s]

Epoch 5: Training Loss: 0.6464135715619505


Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 351/351 [00:05<00:00, 64.18it/s]

Epoch 5: Validation Accuracy: tensor(0.5954, device='cuda:0')
Epoch 5: Validation Precision: tensor(0.5954, device='cuda:0')
Epoch 5: Validation Recall: tensor(0.5954, device='cuda:0')





# END