In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision.models import convnext

In [2]:
model = torchvision.models.convnext_tiny(weights=convnext.ConvNeXt_Tiny_Weights.DEFAULT)
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=

In [3]:
from torchvision import transforms
preprocess = transforms.Compose([
    transforms.ToTensor(),
])
from PIL import Image
input_image = Image.open("dog.jpg")
input_tensor = preprocess(input_image)
input_tensor

tensor([[[0.0510, 0.0667, 0.0588,  ..., 0.0745, 0.0784, 0.0706],
         [0.0627, 0.0706, 0.0588,  ..., 0.0745, 0.0902, 0.0824],
         [0.0667, 0.0706, 0.0667,  ..., 0.0667, 0.0902, 0.0863],
         ...,
         [0.4039, 0.2824, 0.2235,  ..., 0.2510, 0.3412, 0.4863],
         [0.4353, 0.3373, 0.2863,  ..., 0.2588, 0.2980, 0.3725],
         [0.3686, 0.2784, 0.2627,  ..., 0.2118, 0.2235, 0.2980]],

        [[0.1176, 0.1333, 0.1137,  ..., 0.1216, 0.1294, 0.1216],
         [0.1059, 0.1137, 0.1020,  ..., 0.1216, 0.1373, 0.1294],
         [0.0902, 0.0941, 0.0902,  ..., 0.1137, 0.1373, 0.1333],
         ...,
         [0.4863, 0.3843, 0.3451,  ..., 0.3686, 0.4196, 0.5294],
         [0.4824, 0.3922, 0.3647,  ..., 0.3765, 0.3569, 0.4000],
         [0.4157, 0.3333, 0.3412,  ..., 0.3294, 0.2824, 0.3255]],

        [[0.0471, 0.0706, 0.0549,  ..., 0.0667, 0.0902, 0.0824],
         [0.0745, 0.0902, 0.0784,  ..., 0.0431, 0.0588, 0.0510],
         [0.0745, 0.0784, 0.0745,  ..., 0.0353, 0.0588, 0.

In [4]:
input_batch = input_tensor.unsqueeze(0)

with torch.no_grad():
    output = model(input_batch)
print(output[0])

tensor([ 7.4499e-02,  2.6175e-01, -3.7133e-01, -4.0203e-01, -4.7379e-01,
         8.9643e-02,  1.7304e-01,  4.1468e-01,  5.9469e-01,  2.5760e-01,
         7.0624e-01,  3.3094e-01,  2.3178e-01,  8.5878e-01,  3.3874e-01,
         3.8268e-01,  1.8625e-01,  1.6838e-01,  1.1058e+00,  6.1064e-01,
         5.3451e-01,  2.2470e-01,  1.9841e-01,  9.0423e-01,  2.9987e-01,
        -3.0218e-01,  2.1656e-01,  1.3040e-02, -1.4743e-01,  2.3028e-01,
         3.6111e-02, -1.3268e-01,  3.4113e-01, -2.3189e-01, -1.9823e-01,
        -1.2049e-01, -3.2940e-02, -1.9624e-01,  7.2636e-02, -1.8214e-01,
         1.9801e-01, -1.8248e-01,  1.0644e-01, -2.4149e-01, -1.0319e-01,
         4.1211e-02,  4.3062e-01, -9.9799e-01, -2.0288e-01,  2.0561e-01,
         3.6348e-01, -8.0752e-01, -2.4530e-02,  5.2951e-02, -1.3128e-01,
        -3.8987e-01, -1.5092e-01, -1.4595e-01,  3.0709e-01, -1.3847e-01,
         2.1710e-01,  5.7917e-02, -2.8042e-01,  7.5014e-01, -1.3664e-01,
        -7.1815e-01, -2.8412e-01, -2.2380e-01, -4.9

In [6]:
# Change features
IN_CHANNELS = 32
NUM_CLASSES = 32
device = 'cuda'
model.features[0][0] = nn.Conv2d(IN_CHANNELS, 96, kernel_size=(4,4), stride=(4,4))
model.classifier[2] = nn.Linear(768, NUM_CLASSES)
model.to(device)
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(32, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features

In [7]:
# Create dataset
import pandas as pd
from pathlib import Path
NUM_LAYER_LLM = 32

class AttentionMapDataset(torch.utils.data.Dataset):
    def __init__(self, root, annotations_file, transform=None, target_transform=None):
        self.root = Path(root)
        self.files = pd.read_csv(annotations_file)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.files) * NUM_LAYER_LLM

    def __getitem__(self, idx):
        f_idx = int(idx // NUM_LAYER_LLM)
        pt_path = self.root / self.files['filename'][f_idx]
        f = torch.load(pt_path)
        random_layer = idx % NUM_LAYER_LLM
        heads = f[random_layer][0]
        if self.transform:
            heads = self.transform(heads)
        bucket_label = int(random_layer // 4)
        return heads.to(torch.float32), bucket_label


In [8]:

class AttentionMapBatchDataset(torch.utils.data.Dataset):
    def __init__(self, root, annotations_file, transform=None, target_transform=None):
        self.root = Path(root)
        self.files = pd.read_csv(annotations_file)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        pt_path = self.root / self.files['filename'][idx]
        f = torch.load(pt_path)
        heads = torch.stack(f).squeeze()
        if self.transform:
            heads = self.transform(heads)
        bucket_labels = torch.arange(NUM_LAYER_LLM)
        return heads.to(torch.float32), bucket_labels.to(torch.long)


In [9]:
from torchvision.transforms import v2 as transforms
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomChannelPermutation(),
    transforms.Normalize([0.5], [0.25]),
])

dataset = AttentionMapBatchDataset(Path.home() / "Downloads/mmlu_output/", Path.home() / "Downloads/mmlu_attention_files_list.txt", transform=transform)

# Randomly split into training and test set
indices = torch.randperm(len(dataset)).tolist()
split_idx = int(len(indices) * 0.8)
training_data = torch.utils.data.Subset(dataset, indices[:split_idx])
test_data = torch.utils.data.Subset(dataset, indices[split_idx:])

In [10]:
len(training_data), len(test_data)

(1068, 267)

In [11]:
train_dataloader = torch.utils.data.DataLoader(training_data, batch_size=1)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1)

In [12]:
training_data.__getitem__(0)

  f = torch.load(pt_path)


(tensor([[[[-1.9688, -2.0000, -2.0000,  ..., -1.6094, -0.2891, -0.2578],
           [-1.9922, -2.0000, -2.0000,  ...,  1.7656, -1.8906, -2.0000],
           [-1.9609, -2.0000, -2.0000,  ...,  0.4531, -2.0000, -2.0000],
           ...,
           [-1.7266,  1.5938, -1.8672,  ..., -2.0000, -2.0000, -2.0000],
           [-0.5156,  0.5156, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
           [ 2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000, -2.0000]],
 
          [[ 1.3906, -2.0000, -2.0000,  ..., -1.9844, -2.0000, -1.9922],
           [ 1.1562, -1.9922, -2.0000,  ..., -1.9844, -1.9766, -2.0000],
           [ 1.4531, -2.0000, -2.0000,  ..., -1.9766, -2.0000, -2.0000],
           ...,
           [ 1.4219, -1.7188, -1.7031,  ..., -2.0000, -2.0000, -2.0000],
           [ 1.7656, -1.7734, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
           [ 2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000, -2.0000]],
 
          [[-1.9531, -2.0000, -2.0000,  ..., -1.8828,  0.9688, -1.6094],
           [-

In [13]:
model(training_data.__getitem__(0)[0])

  f = torch.load(pt_path)


tensor([[-0.0205,  0.1398,  0.1044,  ...,  0.2152, -0.2854, -0.0206],
        [-0.1604,  0.2341,  0.1607,  ...,  0.3198, -0.2941,  0.0870],
        [ 0.0227,  0.1755, -0.0862,  ...,  0.2575, -0.1685,  0.1081],
        ...,
        [-0.0512,  0.2394,  0.0767,  ...,  0.3011, -0.2062,  0.0551],
        [-0.0148,  0.1947,  0.0511,  ...,  0.2493, -0.2675, -0.0404],
        [-0.1721,  0.1384,  0.1960,  ...,  0.3207,  0.0209,  0.0763]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

In [14]:
import torch.nn.functional as F
import torchmetrics
from tqdm import tqdm
NUM_CLASSES = 32
accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=NUM_CLASSES).to(device)
recall = torchmetrics.Recall(task='multiclass', num_classes=NUM_CLASSES).to(device)
precision = torchmetrics.Precision(task='multiclass', num_classes=NUM_CLASSES).to(device)
auroc = torchmetrics.AUROC(task='multiclass', num_classes=NUM_CLASSES).to(device)

In [None]:
params = model.parameters()
optimizer = torch.optim.Adam(params, lr=5e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

num_epochs = 10
for epoch in range(num_epochs):
    tr_loss = 0
    nb_tr_steps = 0
    for x, y in tqdm(train_dataloader, desc='Training: '):
        x = x.squeeze().to(device)
        y = y.squeeze().to(device)
        
        optimizer.zero_grad()

        try:
            logits = model(x)
        except Exception as e:
            print(x.shape, y)
            raise e

        loss = F.cross_entropy(logits, y)

        loss.backward()
        
        optimizer.step()

        tr_loss += loss.item()
        nb_tr_steps += 1
    scheduler.step()

    print(f'Epoch {epoch}: Training Loss:', tr_loss / nb_tr_steps)

    model.eval()
    val_acc = []
    val_prec = []
    val_rec = []
    val_auroc = []
    for x, y in tqdm(test_dataloader, desc='Validation'):
        x = x.squeeze().to(device)
        y = y.squeeze().to(device)
        
        logits = model(x)
        prediction = torch.argmax(logits, dim=-1)

        val_acc.append(accuracy(prediction, y))
        val_prec.append(precision(prediction, y))
        val_rec.append(recall(prediction, y))
        # val_auroc.append(auroc(prediction, y))

    print(f'Epoch {epoch}: Validation Accuracy:', sum(val_acc) / len(val_acc))
    print(f'Epoch {epoch}: Validation Precision:', sum(val_prec) / len(val_prec))
    print(f'Epoch {epoch}: Validation Recall:', sum(val_rec) / len(val_rec))
    # print(f'Epoch {epoch}: Validation AUROC:', sum(val_auroc) / len(val_auroc))

  f = torch.load(pt_path)
Training:  21%|██████████████▊                                                      | 229/1068 [02:10<03:51,  3.62it/s]

In [79]:
model = torchvision.models.convnext_tiny(weights=convnext.ConvNeXt_Tiny_Weights.DEFAULT)
# Change features
IN_CHANNELS = 32*32
device = 'cuda'
model.features[0][0] = nn.Conv2d(IN_CHANNELS, 96, kernel_size=(4,4), stride=(4,4))
model.classifier[2] = nn.Linear(768, 2)
model.to(device)
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(1024, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_featur

In [80]:
class AllAttentionMapDataset(torch.utils.data.Dataset):
    def __init__(self, root, annotations_file, transform=None, target_transform=None):
        self.root = Path(root)
        self.files = pd.read_csv(annotations_file)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        item = self.files.iloc[idx]
        pt_path = self.root / item['filename']
        heads = torch.load(pt_path)
        heads = torch.stack(heads, dim=0).flatten(start_dim=0, end_dim=2)  # 2 is exclusive, only flattens dim0 and dim1
        if self.transform:
            heads = self.transform(heads)
        return heads.to(torch.float32), torch.tensor(item['prediction'] == item['correct']).to(torch.long)

In [81]:
from torchvision.transforms import v2 as transforms
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomChannelPermutation(),
    transforms.Normalize([0.5], [0.25]),
])

# dataset = AttentionMapDataset(Path.home() / "Downloads/mmlu_output/", Path.home() / "Downloads/mmlu_attention_files_list.txt", transform=transform)
dataset = AllAttentionMapDataset(Path.home() / "Downloads/mmlu_output/", Path.home() / "Downloads/mmlu_attention_files_list.txt", transform=transform)

# Randomly split into training and test set
indices = torch.randperm(len(dataset)).tolist()
split_idx = int(len(indices) * 0.8)
training_data = torch.utils.data.Subset(dataset, indices[:split_idx])
test_data = torch.utils.data.Subset(dataset, indices[split_idx:])

In [82]:
len(training_data), len(test_data)

(1068, 267)

In [83]:
train_dataloader = torch.utils.data.DataLoader(training_data, batch_size=1)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1)

In [84]:
training_data.__getitem__(0)[0].shape

  heads = torch.load(pt_path)


torch.Size([1024, 55, 55])

In [85]:
model(training_data.__getitem__(0)[0].unsqueeze(0))

  heads = torch.load(pt_path)


tensor([[ 0.2379, -0.0539]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [93]:
import torch.nn.functional as F
import torchmetrics
from tqdm import tqdm
accuracy = torchmetrics.Accuracy(task='binary').to(device)
recall = torchmetrics.Recall(task='binary').to(device)
precision = torchmetrics.Precision(task='binary').to(device)
auroc = torchmetrics.AUROC(task='binary').to(device)

In [95]:
params = model.parameters()
optimizer = torch.optim.Adam(params, lr=5e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

num_epochs = 10
for epoch in range(num_epochs):
    tr_loss = 0
    nb_tr_steps = 0
    for x, y in tqdm(train_dataloader, desc='Training: '):
        x = x.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()

        try:
            logits = model(x)
        except Exception as e:
            print(x.shape, y, f)
            raise e

        loss = F.cross_entropy(logits, y)

        loss.backward()
        
        optimizer.step()

        tr_loss += loss.item()
        nb_tr_steps += 1
    scheduler.step()

    print(f'Epoch {epoch}: Training Loss:', tr_loss / nb_tr_steps)

    model.eval()
    val_acc = []
    val_prec = []
    val_rec = []
    val_auroc = []
    for x, y in tqdm(test_dataloader, desc='Validation'):
        x = x.to(device)
        y = y.to(device)
        
        logits = model(x)
        prediction = torch.argmax(logits, dim=-1)

        val_acc.append(accuracy(prediction, y))
        val_prec.append(precision(prediction, y))
        val_rec.append(recall(prediction, y))
        val_auroc.append(auroc(prediction, y))

    print(f'Epoch {epoch}: Validation Accuracy:', sum(val_acc) / len(val_acc))
    print(f'Epoch {epoch}: Validation Precision:', sum(val_prec) / len(val_prec))
    print(f'Epoch {epoch}: Validation Recall:', sum(val_rec) / len(val_rec))
    print(f'Epoch {epoch}: Validation AUROC:', sum(val_auroc) / len(val_auroc))

  heads = torch.load(pt_path)

raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [01:09<00:00, 15.26it/s]

Epoch 0: Training Loss: 1.415288276960271



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:10<00:00, 24.71it/s]

Epoch 0: Validation Accuracy: tensor(0.5768, device='cuda:0')
Epoch 0: Validation Precision: tensor(0.3071, device='cuda:0')
Epoch 0: Validation Recall: tensor(0.3071, device='cuda:0')
Epoch 0: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [01:09<00:00, 15.45it/s]

Epoch 1: Training Loss: 1.3076700279855575



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:12<00:00, 20.63it/s]

Epoch 1: Validation Accuracy: tensor(0.5768, device='cuda:0')
Epoch 1: Validation Precision: tensor(0.3071, device='cuda:0')
Epoch 1: Validation Recall: tensor(0.3071, device='cuda:0')
Epoch 1: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [01:08<00:00, 15.51it/s]

Epoch 2: Training Loss: 0.690339772302783



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:14<00:00, 17.81it/s]

Epoch 2: Validation Accuracy: tensor(0.5768, device='cuda:0')
Epoch 2: Validation Precision: tensor(0.3071, device='cuda:0')
Epoch 2: Validation Recall: tensor(0.3071, device='cuda:0')
Epoch 2: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [01:08<00:00, 15.53it/s]

Epoch 3: Training Loss: 0.6880884663889024



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:17<00:00, 15.65it/s]

Epoch 3: Validation Accuracy: tensor(0.5768, device='cuda:0')
Epoch 3: Validation Precision: tensor(0.3071, device='cuda:0')
Epoch 3: Validation Recall: tensor(0.3071, device='cuda:0')
Epoch 3: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [01:08<00:00, 15.55it/s]

Epoch 4: Training Loss: 0.6843108450540443



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:19<00:00, 13.87it/s]

Epoch 4: Validation Accuracy: tensor(0.5768, device='cuda:0')
Epoch 4: Validation Precision: tensor(0.3071, device='cuda:0')
Epoch 4: Validation Recall: tensor(0.3071, device='cuda:0')
Epoch 4: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [01:08<00:00, 15.59it/s]

Epoch 5: Training Loss: 0.6842318316189091



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:21<00:00, 12.24it/s]

Epoch 5: Validation Accuracy: tensor(0.5768, device='cuda:0')
Epoch 5: Validation Precision: tensor(0.3071, device='cuda:0')
Epoch 5: Validation Recall: tensor(0.3071, device='cuda:0')
Epoch 5: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [01:08<00:00, 15.66it/s]

Epoch 6: Training Loss: 0.6836288277241175



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:23<00:00, 11.28it/s]

Epoch 6: Validation Accuracy: tensor(0.5768, device='cuda:0')
Epoch 6: Validation Precision: tensor(0.3071, device='cuda:0')
Epoch 6: Validation Recall: tensor(0.3071, device='cuda:0')
Epoch 6: Validation AUROC: tensor(0., device='cuda:0')



raining:   6%|████                                                                  | 62/1068 [00:03<01:02, 16.17it/s]

KeyboardInterrupt: 