In [2]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchmetrics

from pathlib import Path
from torchvision.models import convnext
from torchvision.transforms import v2 as transforms
from tqdm import tqdm

In [3]:
model = torchvision.models.convnext_tiny(weights=convnext.ConvNeXt_Tiny_Weights.DEFAULT)
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=

## Experiment: Classify a single layer through all its 32 heads' attention maps.

In [4]:
# Create dataset
LAYERS_INCLUDED = [1, 11, 21, 31]
NUM_LAYER_LLM = len(LAYERS_INCLUDED)

class AttentionMapDataset(torch.utils.data.Dataset):
    def __init__(self, root, annotations_file, transform=None, target_transform=None):
        self.root = Path(root)
        self.files = pd.read_csv(annotations_file)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.files) * NUM_LAYER_LLM

    def __getitem__(self, idx):
        f_idx = int(idx // NUM_LAYER_LLM)
        pt_path = self.root / self.files['filename'][f_idx]
        f = torch.load(pt_path)
        random_layer = idx % NUM_LAYER_LLM
        layer_idx = LAYERS_INCLUDED[random_layer]
        heads = f[layer_idx]
        if self.transform:
            heads = self.transform(heads)
        bucket_label = torch.tensor(random_layer)
        heads.unsqueeze(dim=0)
        bucket_label.unsqueeze(dim=0)
        return heads.to(torch.float32), bucket_label


In [13]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomChannelPermutation(),
    transforms.Normalize([0.5], [0.25]),
])

dataset = AttentionMapDataset(Path.home() / "Downloads/mmlu_output/", Path.home() / "Downloads/mmlu_attention_files_list.txt", transform=transform)

# Randomly split into training and test set
# indices = torch.randperm(len(dataset)).tolist()
indices = torch.arange(len(dataset)).tolist()
split_idx = int((len(indices) // NUM_LAYER_LLM) * 0.8) * NUM_LAYER_LLM
train_data = torch.utils.data.Subset(dataset, indices[:split_idx])
test_data = torch.utils.data.Subset(dataset, indices[split_idx:])
len(train_data), len(test_data)

(5552, 1388)

In [14]:
split_idx

5552

In [15]:
dataset[split_idx+3]

  f = torch.load(pt_path)


(tensor([[[-1.5781, -1.5156, -1.8047,  ..., -2.0000, -2.0000,  0.7344],
          [-2.0000, -1.0312, -1.6172,  ..., -2.0000, -2.0000,  0.4688],
          [-2.0000, -2.0000, -1.6484,  ..., -2.0000, -2.0000,  1.3281],
          ...,
          [-2.0000, -2.0000, -2.0000,  ..., -1.6250, -1.8750,  1.5000],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -1.6875,  1.6875],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000,  2.0000]],
 
         [[-0.1953, -1.6719, -1.7812,  ..., -1.9922, -2.0000, -0.8906],
          [-2.0000,  0.1875, -1.5859,  ..., -1.9922, -2.0000, -0.9922],
          [-2.0000, -2.0000, -0.5781,  ..., -1.9844, -1.9844, -0.3672],
          ...,
          [-2.0000, -2.0000, -2.0000,  ..., -0.4531, -1.7891,  0.2344],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -0.6172,  0.6250],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000,  2.0000]],
 
         [[-1.8203, -1.7188, -1.8594,  ..., -1.9844, -1.9688,  0.8594],
          [-2.0000, -1.7969,

In [16]:
train_data[-1]

  f = torch.load(pt_path)


(tensor([[[-0.4219, -1.6875, -1.6875,  ..., -2.0000, -2.0000, -0.4141],
          [-2.0000,  0.2344, -1.7344,  ..., -2.0000, -2.0000, -0.6328],
          [-2.0000, -2.0000,  0.4531,  ..., -2.0000, -2.0000, -0.7266],
          ...,
          [-2.0000, -2.0000, -2.0000,  ..., -0.0156, -1.9062, -0.0781],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -0.2422,  0.2500],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000,  2.0000]],
 
         [[-1.1406, -1.9453, -1.9688,  ..., -2.0000, -2.0000,  0.9531],
          [-2.0000, -0.8750, -1.8594,  ..., -2.0000, -2.0000,  0.5938],
          [-2.0000, -2.0000, -1.0312,  ..., -2.0000, -2.0000,  0.9062],
          ...,
          [-2.0000, -2.0000, -2.0000,  ..., -1.0625, -1.9062,  0.9688],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -0.4766,  0.4688],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000,  2.0000]],
 
         [[ 0.3750, -1.9844, -1.9922,  ..., -2.0000, -2.0000, -0.8750],
          [-2.0000, -1.3594,

In [17]:
test_data[-1]

  f = torch.load(pt_path)


(tensor([[[-2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000,  2.0000],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000,  2.0000, -2.0000],
          [-2.0000, -2.0000, -2.0000,  ...,  2.0000, -2.0000, -2.0000],
          ...,
          [-2.0000, -2.0000,  1.9531,  ..., -2.0000, -2.0000, -2.0000],
          [-2.0000,  1.8906, -1.9531,  ..., -2.0000, -2.0000, -2.0000],
          [ 1.9844, -1.9844, -2.0000,  ..., -2.0000, -2.0000, -2.0000]],
 
         [[-2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000,  2.0000],
          [-2.0000, -2.0000, -2.0000,  ..., -2.0000, -0.1562,  0.1562],
          [-2.0000, -2.0000, -2.0000,  ..., -1.2656, -1.6953,  0.9688],
          ...,
          [-2.0000, -2.0000, -1.7188,  ..., -1.9922, -2.0000,  1.5625],
          [-2.0000, -1.1875, -1.8672,  ..., -1.9922, -1.9922,  0.9062],
          [-0.9844, -1.9141, -1.9922,  ..., -2.0000, -2.0000,  0.8438]],
 
         [[-2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000,  2.0000],
          [-2.0000, -2.0000,

In [18]:
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=1)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1)

In [19]:
# Change features
IN_CHANNELS = 32
NUM_CLASSES = NUM_LAYER_LLM
device = 'cuda'

accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=NUM_CLASSES).to(device)
recall = torchmetrics.Recall(task='multiclass', num_classes=NUM_CLASSES).to(device)
precision = torchmetrics.Precision(task='multiclass', num_classes=NUM_CLASSES).to(device)
auroc = torchmetrics.AUROC(task='multiclass', num_classes=NUM_CLASSES).to(device)

model.features[0][0] = nn.Conv2d(IN_CHANNELS, 96, kernel_size=(4,4), stride=(4,4))
model.classifier[2] = nn.Linear(768, NUM_CLASSES)
model.to(device)
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(32, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features

In [20]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

num_epochs = 10
for epoch in range(num_epochs):
    tr_loss = 0
    nb_tr_steps = 0
    for x, y in tqdm(train_dataloader, desc='Training: '):
        x = x.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()

        try:
            logits = model(x)
        except Exception as e:
            print(x.shape, y)
            raise e

        loss = F.cross_entropy(logits, y)

        loss.backward()
        
        optimizer.step()

        tr_loss += loss.item()
        nb_tr_steps += 1
    scheduler.step()

    print(f'Epoch {epoch}: Training Loss:', tr_loss / nb_tr_steps)

    model.eval()
    val_acc = []
    val_prec = []
    val_rec = []
    val_auroc = []
    for x, y in tqdm(test_dataloader, desc='Validation'):
        x = x.to(device)
        y = y.to(device)
        
        logits = model(x)
        prediction = torch.argmax(logits, dim=-1)

        val_acc.append(accuracy(prediction, y))
        val_prec.append(precision(prediction, y))
        val_rec.append(recall(prediction, y))
        # val_auroc.append(auroc(prediction, y))

    print(f'Epoch {epoch}: Validation Accuracy:', sum(val_acc) / len(val_acc))
    print(f'Epoch {epoch}: Validation Precision:', sum(val_prec) / len(val_prec))
    print(f'Epoch {epoch}: Validation Recall:', sum(val_rec) / len(val_rec))
    # print(f'Epoch {epoch}: Validation AUROC:', sum(val_auroc) / len(val_auroc))

  f = torch.load(pt_path)

raining: 100%|████████████████████████████████████████████████████████████████████| 5552/5552 [02:55<00:00, 31.63it/s]

Epoch 0: Training Loss: 1.5283158838820057



alidation: 100%|██████████████████████████████████████████████████████████████████| 1388/1388 [00:24<00:00, 57.35it/s]

Epoch 0: Validation Accuracy: tensor(0.2500, device='cuda:0')
Epoch 0: Validation Precision: tensor(0.2500, device='cuda:0')
Epoch 0: Validation Recall: tensor(0.2500, device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 5552/5552 [02:57<00:00, 31.33it/s]

Epoch 1: Training Loss: 1.5376778817298635



alidation: 100%|██████████████████████████████████████████████████████████████████| 1388/1388 [00:23<00:00, 59.80it/s]

Epoch 1: Validation Accuracy: tensor(0.2500, device='cuda:0')
Epoch 1: Validation Precision: tensor(0.2500, device='cuda:0')
Epoch 1: Validation Recall: tensor(0.2500, device='cuda:0')



raining:   1%|▊                                                                     | 67/5552 [00:02<02:48, 32.51it/s]

KeyboardInterrupt: 

## Experiment: Classify correctness by dataset

In [None]:
# Create dataset
LAYERS_INCLUDED = [1, 11, 21, 31]
NUM_LAYER_LLM = len(LAYERS_INCLUDED)

class AttentionMapDFDataset(torch.utils.data.Dataset):
    def __init__(self, root, annotations_file, transform=None, target_transform=None):
        self.root = Path(root)
        self.files = df
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.files) * NUM_LAYER_LLM

    def __getitem__(self, idx):
        f_idx = int(idx // NUM_LAYER_LLM)
        pt_path = self.root / self.files['filename'][f_idx]
        f = torch.load(pt_path)
        random_layer = idx % NUM_LAYER_LLM
        layer_idx = LAYERS_INCLUDED[random_layer]
        heads = f[layer_idx]
        if self.transform:
            heads = self.transform(heads)
        bucket_label = torch.tensor(random_layer)
        heads.unsqueeze(dim=0)
        bucket_label.unsqueeze(dim=0)
        return heads.to(torch.float32), bucket_label


In [None]:
dataset_df = pd.read_csv(Path.home() / "Downloads/mmlu_attention_files_list.txt")
dataset_df['dataset'] = dataset_df['filename'].apply(lambda filename: str(Path(filename).parent.parent.name))
dataset_es = dataset_df[dataset_df['dataset'] == 'elementary_science']

dataset_es = AttentionMapDFDataset(Path.home / "Downloads/mmlu_output/", dataset_es, transform=None)

## Experiment: Classify correctness by layer

In [7]:

class AttentionMapBatchDataset(torch.utils.data.Dataset):
    def __init__(self, root, annotations_file, transform=None, target_transform=None):
        self.root = Path(root)
        self.files = pd.read_csv(annotations_file)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        pt_path = self.root / self.files['filename'][idx]
        f = torch.load(pt_path)
        heads = torch.stack(f).squeeze()
        if self.transform:
            heads = self.transform(heads)
        bucket_labels = torch.arange(NUM_LAYER_LLM)
        return heads.to(torch.float32), bucket_labels.to(torch.long)


In [8]:

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomChannelPermutation(),
    transforms.Normalize([0.5], [0.25]),
])

dataset = AttentionMapBatchDataset(Path.home() / "Downloads/mmlu_output/", Path.home() / "Downloads/mmlu_attention_files_list.txt", transform=transform)

# Randomly split into training and test set
indices = torch.randperm(len(dataset)).tolist()
split_idx = int(len(indices) * 0.8)
training_data = torch.utils.data.Subset(dataset, indices[:split_idx])
test_data = torch.utils.data.Subset(dataset, indices[split_idx:])

In [9]:
len(training_data), len(test_data)

(1068, 267)

In [10]:
train_dataloader = torch.utils.data.DataLoader(training_data, batch_size=1)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1)

In [11]:
training_data.__getitem__(0)

  f = torch.load(pt_path)


(tensor([[[[ 2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
           [ 1.6250, -1.6250, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
           [ 1.8750, -1.9062, -1.9688,  ..., -2.0000, -2.0000, -2.0000],
           ...,
           [-1.7500, -1.8594, -1.7578,  ..., -1.7266, -2.0000, -2.0000],
           [ 0.1562, -1.9766, -1.9922,  ..., -1.8984, -1.9609, -2.0000],
           [-1.6250, -1.9297, -1.9375,  ..., -1.1016, -1.8281, -1.8906]],
 
          [[ 2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
           [ 1.5000, -1.5000, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
           [ 1.5781, -1.8438, -1.7266,  ..., -2.0000, -2.0000, -2.0000],
           ...,
           [-1.5703, -2.0000, -1.9688,  ..., -1.8516, -2.0000, -2.0000],
           [-1.3438, -2.0000, -1.9219,  ..., -1.9219, -1.3672, -2.0000],
           [-1.5312, -2.0000, -1.9609,  ..., -1.7500, -1.6953, -1.8906]],
 
          [[ 2.0000, -2.0000, -2.0000,  ..., -2.0000, -2.0000, -2.0000],
           [ 

In [12]:
model(training_data.__getitem__(0)[0])

  f = torch.load(pt_path)


tensor([[-0.0271, -0.0674, -0.1068,  ..., -0.0028, -0.1077, -0.1471],
        [-0.0835, -0.0467, -0.1045,  ...,  0.0459, -0.1579, -0.1153],
        [-0.0839, -0.0469, -0.1067,  ...,  0.0483, -0.1555, -0.1108],
        ...,
        [-0.0828, -0.0502, -0.1122,  ...,  0.0479, -0.1581, -0.1150],
        [-0.0315, -0.0713, -0.1114,  ..., -0.0012, -0.1055, -0.1488],
        [-0.0809, -0.0500, -0.0982,  ...,  0.0465, -0.1564, -0.1143]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

In [13]:
import torch.nn.functional as F
import torchmetrics
from tqdm import tqdm
NUM_CLASSES = 32
accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=NUM_CLASSES).to(device)
recall = torchmetrics.Recall(task='multiclass', num_classes=NUM_CLASSES).to(device)
precision = torchmetrics.Precision(task='multiclass', num_classes=NUM_CLASSES).to(device)
auroc = torchmetrics.AUROC(task='multiclass', num_classes=NUM_CLASSES).to(device)

In [None]:
params = model.parameters()
optimizer = torch.optim.Adam(params, lr=5e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

num_epochs = 10
for epoch in range(num_epochs):
    tr_loss = 0
    nb_tr_steps = 0
    for x, y in tqdm(train_dataloader, desc='Training: '):
        x = x.squeeze().to(device)
        y = y.squeeze().to(device)
        
        optimizer.zero_grad()

        try:
            logits = model(x)
        except Exception as e:
            print(x.shape, y)
            raise e

        loss = F.cross_entropy(logits, y)

        loss.backward()
        
        optimizer.step()

        tr_loss += loss.item()
        nb_tr_steps += 1
    scheduler.step()

    print(f'Epoch {epoch}: Training Loss:', tr_loss / nb_tr_steps)

    model.eval()
    val_acc = []
    val_prec = []
    val_rec = []
    val_auroc = []
    for x, y in tqdm(test_dataloader, desc='Validation'):
        x = x.squeeze().to(device)
        y = y.squeeze().to(device)
        
        logits = model(x)
        prediction = torch.argmax(logits, dim=-1)

        val_acc.append(accuracy(prediction, y))
        val_prec.append(precision(prediction, y))
        val_rec.append(recall(prediction, y))
        # val_auroc.append(auroc(prediction, y))

    print(f'Epoch {epoch}: Validation Accuracy:', sum(val_acc) / len(val_acc))
    print(f'Epoch {epoch}: Validation Precision:', sum(val_prec) / len(val_prec))
    print(f'Epoch {epoch}: Validation Recall:', sum(val_rec) / len(val_rec))
    # print(f'Epoch {epoch}: Validation AUROC:', sum(val_auroc) / len(val_auroc))

  f = torch.load(pt_path)

raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [18:00<00:00,  1.01s/it]

Epoch 0: Training Loss: 3.5466350113854426



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:34<00:00,  7.83it/s]

Epoch 0: Validation Accuracy: tensor(0.0312, device='cuda:0')
Epoch 0: Validation Precision: tensor(0.0312, device='cuda:0')
Epoch 0: Validation Recall: tensor(0.0312, device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [19:48<00:00,  1.11s/it]

Epoch 1: Training Loss: 3.4660963371451876



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:34<00:00,  7.80it/s]

Epoch 1: Validation Accuracy: tensor(0.0312, device='cuda:0')
Epoch 1: Validation Precision: tensor(0.0312, device='cuda:0')
Epoch 1: Validation Recall: tensor(0.0312, device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [19:49<00:00,  1.11s/it]

Epoch 2: Training Loss: 3.465740100721295



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:34<00:00,  7.77it/s]

Epoch 2: Validation Accuracy: tensor(0.0312, device='cuda:0')
Epoch 2: Validation Precision: tensor(0.0312, device='cuda:0')
Epoch 2: Validation Recall: tensor(0.0312, device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [19:46<00:00,  1.11s/it]

Epoch 3: Training Loss: 3.4659348172641424



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:33<00:00,  7.89it/s]

Epoch 3: Validation Accuracy: tensor(0.0312, device='cuda:0')
Epoch 3: Validation Precision: tensor(0.0312, device='cuda:0')
Epoch 3: Validation Recall: tensor(0.0312, device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [19:50<00:00,  1.11s/it]

Epoch 4: Training Loss: 3.4657431532827654



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:34<00:00,  7.84it/s]

Epoch 4: Validation Accuracy: tensor(0.0312, device='cuda:0')
Epoch 4: Validation Precision: tensor(0.0312, device='cuda:0')
Epoch 4: Validation Recall: tensor(0.0312, device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [19:47<00:00,  1.11s/it]

Epoch 5: Training Loss: 3.4658019937826006



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:34<00:00,  7.84it/s]

Epoch 5: Validation Accuracy: tensor(0.0312, device='cuda:0')
Epoch 5: Validation Precision: tensor(0.0312, device='cuda:0')
Epoch 5: Validation Recall: tensor(0.0312, device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [19:46<00:00,  1.11s/it]

Epoch 6: Training Loss: 3.465737959865327



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:34<00:00,  7.81it/s]

Epoch 6: Validation Accuracy: tensor(0.0312, device='cuda:0')
Epoch 6: Validation Precision: tensor(0.0312, device='cuda:0')
Epoch 6: Validation Recall: tensor(0.0312, device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [19:47<00:00,  1.11s/it]

Epoch 7: Training Loss: 3.4657572952102633



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:34<00:00,  7.81it/s]

Epoch 7: Validation Accuracy: tensor(0.0312, device='cuda:0')
Epoch 7: Validation Precision: tensor(0.0312, device='cuda:0')
Epoch 7: Validation Recall: tensor(0.0312, device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [19:46<00:00,  1.11s/it]

Epoch 8: Training Loss: 3.4657369651151506



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:33<00:00,  7.86it/s]

Epoch 8: Validation Accuracy: tensor(0.0312, device='cuda:0')
Epoch 8: Validation Precision: tensor(0.0312, device='cuda:0')
Epoch 8: Validation Recall: tensor(0.0312, device='cuda:0')


Training:   9%|██████                                                                | 93/1068 [01:45<16:20,  1.01s/it]

In [25]:
# model = torchvision.models.convnext_tiny(weights=convnext.ConvNeXt_Tiny_Weights.DEFAULT)
model = torchvision.models.convnext_tiny()
# Change features
IN_CHANNELS = 32*32
device = 'cuda'
model.features[0][0] = nn.Conv2d(IN_CHANNELS, 96, kernel_size=(4,4), stride=(4,4))
model.classifier[2] = nn.Linear(768, 2)
model.to(device)
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(1024, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_featur

In [36]:
class AllAttentionMapDataset(torch.utils.data.Dataset):
    def __init__(self, root, annotations_file, transform=None, target_transform=None):
        self.root = Path(root)
        self.files = pd.read_csv(annotations_file)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        item = self.files.iloc[idx]
        pt_path = self.root / item['filename']
        heads = torch.load(pt_path)
        heads = torch.stack(heads, dim=0).flatten(start_dim=0, end_dim=2)  # 2 is exclusive, only flattens dim0 and dim1
        if self.transform:
            heads = self.transform(heads)
        return heads.to(torch.float32), torch.tensor(item['prediction'] == item['correct']).to(torch.long)

In [37]:
from torchvision.transforms import v2 as transforms
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomChannelPermutation(),
    transforms.Normalize([0.5], [0.25]),
])

# dataset = AttentionMapDataset(Path.home() / "Downloads/mmlu_output/", Path.home() / "Downloads/mmlu_attention_files_list.txt", transform=transform)
dataset = AllAttentionMapDataset(Path.home() / "Downloads/mmlu_output/", Path.home() / "Downloads/mmlu_attention_files_list.txt", transform=transform)

# Randomly split into training and test set
indices = torch.randperm(len(dataset)).tolist()
split_idx = int(len(indices) * 0.8)
training_data = torch.utils.data.Subset(dataset, indices[:split_idx])
validation_data = torch.utils.data.Subset(dataset, indices[split_idx:])

In [38]:
len(training_data), len(test_data)

(1068, 267)

In [39]:
train_dataloader = torch.utils.data.DataLoader(training_data, batch_size=1)
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=1)

In [40]:
training_data.__getitem__(0)[0].shape

  heads = torch.load(pt_path)


torch.Size([1024, 74, 74])

In [41]:
model(training_data.__getitem__(0)[0].unsqueeze(0))

  heads = torch.load(pt_path)


tensor([[-0.3492,  0.3039]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [42]:
import torch.nn.functional as F
import torchmetrics
from tqdm import tqdm
accuracy = torchmetrics.Accuracy(task='binary').to(device)
recall = torchmetrics.Recall(task='binary').to(device)
precision = torchmetrics.Precision(task='binary').to(device)
auroc = torchmetrics.AUROC(task='binary').to(device)

In [43]:
params = model.parameters()
optimizer = torch.optim.Adam(params, lr=5e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

num_epochs = 10
for epoch in range(num_epochs):
    tr_loss = 0
    nb_tr_steps = 0
    for x, y in tqdm(train_dataloader, desc='Training: '):
        x = x.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()

        try:
            logits = model(x)
        except Exception as e:
            print(x.shape, y, f)
            raise e

        loss = F.cross_entropy(logits, y)

        loss.backward()
        
        optimizer.step()

        tr_loss += loss.item()
        nb_tr_steps += 1
    scheduler.step()

    print(f'Epoch {epoch}: Training Loss:', tr_loss / nb_tr_steps)

    model.eval()
    val_acc = []
    val_prec = []
    val_rec = []
    val_auroc = []
    for x, y in tqdm(validation_dataloader, desc='Validation'):
        x = x.to(device)
        y = y.to(device)
        
        logits = model(x)
        prediction = torch.argmax(logits, dim=-1)

        val_acc.append(accuracy(prediction, y))
        val_prec.append(precision(prediction, y))
        val_rec.append(recall(prediction, y))
        val_auroc.append(auroc(prediction, y))

    print(f'Epoch {epoch}: Validation Accuracy:', sum(val_acc) / len(val_acc))
    print(f'Epoch {epoch}: Validation Precision:', sum(val_prec) / len(val_prec))
    print(f'Epoch {epoch}: Validation Recall:', sum(val_rec) / len(val_rec))
    print(f'Epoch {epoch}: Validation AUROC:', sum(val_auroc) / len(val_auroc))

  heads = torch.load(pt_path)

raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [00:59<00:00, 17.89it/s]

Epoch 0: Training Loss: 1.2979160130226597



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:12<00:00, 20.76it/s]

Epoch 0: Validation Accuracy: tensor(0.4944, device='cuda:0')
Epoch 0: Validation Precision: tensor(0.4944, device='cuda:0')
Epoch 0: Validation Recall: tensor(0.4944, device='cuda:0')
Epoch 0: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [00:55<00:00, 19.36it/s]

Epoch 1: Training Loss: 0.7849293898534699



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:13<00:00, 20.32it/s]

Epoch 1: Validation Accuracy: tensor(0.4944, device='cuda:0')
Epoch 1: Validation Precision: tensor(0.4944, device='cuda:0')
Epoch 1: Validation Recall: tensor(0.4944, device='cuda:0')
Epoch 1: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [00:52<00:00, 20.16it/s]

Epoch 2: Training Loss: 0.7033707734741522



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:15<00:00, 17.22it/s]

Epoch 2: Validation Accuracy: tensor(0.4944, device='cuda:0')
Epoch 2: Validation Precision: tensor(0.4944, device='cuda:0')
Epoch 2: Validation Recall: tensor(0.4944, device='cuda:0')
Epoch 2: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [00:55<00:00, 19.30it/s]

Epoch 3: Training Loss: 0.6957453842020214



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:17<00:00, 15.35it/s]

Epoch 3: Validation Accuracy: tensor(0.4944, device='cuda:0')
Epoch 3: Validation Precision: tensor(0.4944, device='cuda:0')
Epoch 3: Validation Recall: tensor(0.4944, device='cuda:0')
Epoch 3: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [00:52<00:00, 20.33it/s]

Epoch 4: Training Loss: 0.6949262813794033



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:19<00:00, 13.56it/s]

Epoch 4: Validation Accuracy: tensor(0.4944, device='cuda:0')
Epoch 4: Validation Precision: tensor(0.4944, device='cuda:0')
Epoch 4: Validation Recall: tensor(0.4944, device='cuda:0')
Epoch 4: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [00:55<00:00, 19.36it/s]

Epoch 5: Training Loss: 0.6937967861971158



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:21<00:00, 12.18it/s]

Epoch 5: Validation Accuracy: tensor(0.4944, device='cuda:0')
Epoch 5: Validation Precision: tensor(0.4944, device='cuda:0')
Epoch 5: Validation Recall: tensor(0.4944, device='cuda:0')
Epoch 5: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [00:53<00:00, 20.05it/s]

Epoch 6: Training Loss: 0.6934781480259663



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:24<00:00, 11.09it/s]

Epoch 6: Validation Accuracy: tensor(0.5056, device='cuda:0')
Epoch 6: Validation Precision: tensor(0., device='cuda:0')
Epoch 6: Validation Recall: tensor(0., device='cuda:0')
Epoch 6: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [00:55<00:00, 19.37it/s]

Epoch 7: Training Loss: 0.6932987929060218



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:26<00:00, 10.03it/s]

Epoch 7: Validation Accuracy: tensor(0.5056, device='cuda:0')
Epoch 7: Validation Precision: tensor(0., device='cuda:0')
Epoch 7: Validation Recall: tensor(0., device='cuda:0')
Epoch 7: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [00:53<00:00, 20.07it/s]

Epoch 8: Training Loss: 0.6931857895315363



alidation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:28<00:00,  9.45it/s]

Epoch 8: Validation Accuracy: tensor(0.5056, device='cuda:0')
Epoch 8: Validation Precision: tensor(0., device='cuda:0')
Epoch 8: Validation Recall: tensor(0., device='cuda:0')
Epoch 8: Validation AUROC: tensor(0., device='cuda:0')



raining: 100%|████████████████████████████████████████████████████████████████████| 1068/1068 [00:55<00:00, 19.33it/s]

Epoch 9: Training Loss: 0.6931075874562567


Validation: 100%|████████████████████████████████████████████████████████████████████| 267/267 [00:30<00:00,  8.68it/s]

Epoch 9: Validation Accuracy: tensor(0.5618, device='cuda:0')
Epoch 9: Validation Precision: tensor(0.1236, device='cuda:0')
Epoch 9: Validation Recall: tensor(0.1236, device='cuda:0')
Epoch 9: Validation AUROC: tensor(0., device='cuda:0')





In [None]:
# Test set
test_dataset = AllAttentionMapDataset(Path.home() / "Downloads/mmlu_output/", Path.home() / "Downloads/mmlu_attention_files_testset.txt", transform=transform)

# Randomly split into training and test set
indices = torch.randperm(len(dataset)).tolist()
split_idx = int(len(indices) * 0.8)
training_data = torch.utils.data.Subset(dataset, indices[:split_idx])
validation_data = torch.utils.data.Subset(dataset, indices[split_idx:])

In [None]:
model.eval()
val_acc = []
val_prec = []
val_rec = []
val_auroc = []
for x, y in tqdm(test_dataloader, desc='Validation'):
    x = x.to(device)
    y = y.to(device)
    
    logits = model(x)
    prediction = torch.argmax(logits, dim=-1)

    val_acc.append(accuracy(prediction, y))
    val_prec.append(precision(prediction, y))
    val_rec.append(recall(prediction, y))
    val_auroc.append(auroc(prediction, y))

print(f'Epoch {epoch}: Validation Accuracy:', sum(val_acc) / len(val_acc))
print(f'Epoch {epoch}: Validation Precision:', sum(val_prec) / len(val_prec))
print(f'Epoch {epoch}: Validation Recall:', sum(val_rec) / len(val_rec))
print(f'Epoch {epoch}: Validation AUROC:', sum(val_auroc) / len(val_auroc))