### Setup Environment

In [1]:
!source ../venv/bin/activate

In [2]:
from data.abo_spin_dataset import AboSpinsDataset
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
from sys import path
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import classification_report

In [3]:
from models.self_attention_pool import SelfAttentionalPooler
from models.seq_attn import SeqAttention

In [4]:
from transformers import AutoProcessor, AutoModel

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

### Fetch and Load Data

In [6]:
abo_metadata = "../../spirit/spins/metadata/spins.csv.gz"
abo_images = "../../spirit/spins/original/"

In [7]:
df = pd.read_csv(abo_metadata)
df.head()

Unnamed: 0,spin_id,azimuth,image_id,height,width,path
0,61c91265,0,41wqHws7a6L,248,1075,61/61c91265/61c91265_00.jpg
1,61c91265,1,41++eZZHP9L,248,1075,61/61c91265/61c91265_01.jpg
2,61c91265,2,41YF86LhGDL,248,1075,61/61c91265/61c91265_02.jpg
3,61c91265,3,41I5Zz-kbAL,248,1075,61/61c91265/61c91265_03.jpg
4,61c91265,4,41lAQM2Ys5L,248,1075,61/61c91265/61c91265_04.jpg


In [8]:
unique_spins = list(df['spin_id'].unique())
unique_spins = np.random.permutation(unique_spins)

In [9]:
n_spins = len(unique_spins)
seq_length = 10

abo_train = AboSpinsDataset(df[df['spin_id'].isin(unique_spins[:int(n_spins*0.6)])], image_dir=abo_images, seq_len=seq_length)
abo_val = AboSpinsDataset(df[df['spin_id'].isin(unique_spins[int(n_spins*0.6):int(n_spins*0.8)])], mode="val", image_dir=abo_images, seq_len=seq_length)
abo_test = AboSpinsDataset(df[df['spin_id'].isin(unique_spins[int(n_spins*0.8):])], mode="test", image_dir=abo_images, seq_len=seq_length)

print(len(abo_train), len(abo_val), len(abo_test))

4925 1642 1642


In [10]:
# ***Run to check negative sampling ratio***
# print("Negative Sample Ratio in Train: %f" % (np.mean([1 if x[1] == False else 0 for x in abo_train])))
# print("Negative Sample Ratio in Val: %f" % (np.mean([1 if x[1] == False else 0 for x in abo_val])))
# print("Negative Sample Ratio in Test: %f" % (np.mean([1 if x[1] == False else 0 for x in abo_test])))

In [11]:
processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
clip_model = AutoModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
clip_model.to(device)

def collate_fn(batch):
    features = []
    labels = []
    for x, label in batch:
        imgs = [Image.open(path) for path in x]
        inputs = processor(images=imgs, return_tensors='pt')
        inputs.to(device)
        
        features.append(clip_model.get_image_features(**inputs).unsqueeze(0).detach())
        labels += [1] if label else [0]

    return torch.cat(features), labels

In [12]:
train_loader = DataLoader(abo_train, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(abo_val, batch_size=16, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(abo_test, batch_size=16, shuffle=True, collate_fn=collate_fn)

### Trainer

In [13]:
def train(model, dataloader, criterion, optimizer):
    model.train()
    
    predictions = []
    targets = []
    
    epoch_loss = 0
    batch_loss = 0
    for i, batch in enumerate(tqdm(dataloader)):
        out = model(batch[0].to(device))
        out = torch.squeeze(torch.sigmoid(out).cpu(), 1)

        loss = criterion(out, torch.tensor(batch[1], dtype=torch.float))

        preds = torch.where(out >= 0.5, 1, 0)
        predictions += preds.detach().tolist()
        targets += batch[1]

        epoch_loss += loss.item()
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        
        batch_loss += loss.item()
        if (i+1) % 100 == 0:
            print(f"Loss at step {i+1}/{len(dataloader)}: {batch_loss/100}")
            batch_loss = 0

    print(f"Loss: {epoch_loss/len(dataloader)}")
    
    print(classification_report(targets, predictions))
    return

### Evaluation

In [14]:
def validate(model, dataloader):
    model.eval()
    
    predictions = []
    targets = []
    
    total_loss = 0
    for batch in tqdm(dataloader):
        with torch.no_grad():
            out = model(batch[0].to(device))
            out = torch.squeeze(torch.sigmoid(out).cpu(), 1)

            loss = criterion(out, torch.tensor(batch[1], dtype=torch.float))

            preds = torch.where(out >= 0.5, 1, 0)
            predictions += preds.detach().tolist()
            targets += batch[1]

            total_loss += loss.item()

    print(f"Loss: {total_loss/len(dataloader)}")
    
    print(classification_report(targets, predictions))
    return

### Instantiate Model

In [15]:
# model = SelfAttentionalPooler(dim=1024, seq_len=seq_length, heads=8, dim_head=64, depth=1, mlp_dim=256, proj_dim=1, dropout=0.4)
model = SeqAttention(dim=1024, seq_len=seq_length, heads=8, dim_head=64, depth=1, lstm_dim=256, lstm_layers=2, bidirectional=True, mlp_dim=512, proj_dim=1, dropout = 0.4)

### Train Model

In [16]:
model.to(device)

n_epochs = 5
criterion = nn.BCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [17]:
for epoch in range(n_epochs):
    print("-"*10, f"TRAIN EPOCH: {epoch}", "-"*10)
    train(model, train_loader, criterion, optimizer)
    print("-"*10, f"VALID EPOCH: {epoch}", "-"*10)
    validate(model, val_loader)
    print()

---------- TRAIN EPOCH: 0 ----------


 32%|███▏      | 100/308 [07:27<15:11,  4.38s/it]

Loss at step 100/308: 0.7478242415189743


 65%|██████▍   | 200/308 [14:58<07:56,  4.41s/it]

Loss at step 200/308: 0.7107023245096207


 97%|█████████▋| 300/308 [22:25<00:35,  4.45s/it]

Loss at step 300/308: 0.710358037352562


100%|██████████| 308/308 [23:00<00:00,  4.48s/it]


Loss: 0.7221521713903972
              precision    recall  f1-score   support

           0       0.50      0.47      0.48      2436
           1       0.51      0.53      0.52      2489

    accuracy                           0.50      4925
   macro avg       0.50      0.50      0.50      4925
weighted avg       0.50      0.50      0.50      4925

---------- VALID EPOCH: 0 ----------


100%|██████████| 103/103 [06:52<00:00,  4.01s/it]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Loss: 0.7027467934830675
              precision    recall  f1-score   support

           0       0.51      1.00      0.68       845
           1       0.00      0.00      0.00       797

    accuracy                           0.51      1642
   macro avg       0.26      0.50      0.34      1642
weighted avg       0.26      0.51      0.35      1642


---------- TRAIN EPOCH: 1 ----------


 32%|███▏      | 100/308 [07:27<15:24,  4.44s/it]

Loss at step 100/308: 0.6995111566781997


 65%|██████▍   | 200/308 [14:58<07:50,  4.36s/it]

Loss at step 200/308: 0.6982843059301377


 97%|█████████▋| 300/308 [22:28<00:36,  4.56s/it]

Loss at step 300/308: 0.6973169517517089


100%|██████████| 308/308 [23:02<00:00,  4.49s/it]


Loss: 0.6982060214141746
              precision    recall  f1-score   support

           0       0.51      0.57      0.54      2495
           1       0.50      0.44      0.47      2430

    accuracy                           0.51      4925
   macro avg       0.50      0.50      0.50      4925
weighted avg       0.50      0.51      0.50      4925

---------- VALID EPOCH: 1 ----------


100%|██████████| 103/103 [06:52<00:00,  4.01s/it]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Loss: 0.6974951474411973
              precision    recall  f1-score   support

           0       0.00      0.00      0.00       823
           1       0.50      1.00      0.67       819

    accuracy                           0.50      1642
   macro avg       0.25      0.50      0.33      1642
weighted avg       0.25      0.50      0.33      1642


---------- TRAIN EPOCH: 2 ----------


 32%|███▏      | 100/308 [07:26<15:17,  4.41s/it]

Loss at step 100/308: 0.6980994862318038


 65%|██████▍   | 200/308 [14:56<07:54,  4.39s/it]

Loss at step 200/308: 0.6976942718029022


 97%|█████████▋| 300/308 [22:24<00:35,  4.45s/it]

Loss at step 300/308: 0.6952277302742005


100%|██████████| 308/308 [22:59<00:00,  4.48s/it]


Loss: 0.696885261055711
              precision    recall  f1-score   support

           0       0.48      0.36      0.41      2393
           1       0.51      0.63      0.56      2532

    accuracy                           0.50      4925
   macro avg       0.50      0.50      0.49      4925
weighted avg       0.50      0.50      0.49      4925

---------- VALID EPOCH: 2 ----------


100%|██████████| 103/103 [06:51<00:00,  4.00s/it]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Loss: 0.6933911269150891
              precision    recall  f1-score   support

           0       0.51      1.00      0.68       837
           1       0.00      0.00      0.00       805

    accuracy                           0.51      1642
   macro avg       0.25      0.50      0.34      1642
weighted avg       0.26      0.51      0.34      1642


---------- TRAIN EPOCH: 3 ----------


 32%|███▏      | 100/308 [07:27<15:13,  4.39s/it]

Loss at step 100/308: 0.6971805536746979


 65%|██████▍   | 200/308 [15:00<08:39,  4.81s/it]

Loss at step 200/308: 0.6989988607168197


 97%|█████████▋| 300/308 [22:27<00:37,  4.74s/it]

Loss at step 300/308: 0.6988480448722839


100%|██████████| 308/308 [23:01<00:00,  4.48s/it]


Loss: 0.6982746133943657
              precision    recall  f1-score   support

           0       0.48      0.47      0.48      2443
           1       0.49      0.51      0.50      2482

    accuracy                           0.49      4925
   macro avg       0.49      0.49      0.49      4925
weighted avg       0.49      0.49      0.49      4925

---------- VALID EPOCH: 3 ----------


100%|██████████| 103/103 [06:52<00:00,  4.00s/it]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Loss: 0.6939754214101624
              precision    recall  f1-score   support

           0       0.00      0.00      0.00       836
           1       0.49      1.00      0.66       806

    accuracy                           0.49      1642
   macro avg       0.25      0.50      0.33      1642
weighted avg       0.24      0.49      0.32      1642


---------- TRAIN EPOCH: 4 ----------


 32%|███▏      | 100/308 [07:30<15:35,  4.50s/it]

Loss at step 100/308: 0.6988257193565368


 65%|██████▍   | 200/308 [14:59<08:03,  4.48s/it]

Loss at step 200/308: 0.6991364145278931


 97%|█████████▋| 300/308 [22:30<00:35,  4.50s/it]

Loss at step 300/308: 0.6978613805770874


100%|██████████| 308/308 [23:06<00:00,  4.50s/it]


Loss: 0.69844455726735
              precision    recall  f1-score   support

           0       0.49      0.41      0.45      2442
           1       0.50      0.57      0.53      2483

    accuracy                           0.49      4925
   macro avg       0.49      0.49      0.49      4925
weighted avg       0.49      0.49      0.49      4925

---------- VALID EPOCH: 4 ----------


100%|██████████| 103/103 [06:52<00:00,  4.00s/it]

Loss: 0.6987388463853632
              precision    recall  f1-score   support

           0       0.49      1.00      0.66       807
           1       0.00      0.00      0.00       835

    accuracy                           0.49      1642
   macro avg       0.25      0.50      0.33      1642
weighted avg       0.24      0.49      0.32      1642





  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [18]:
validate(model, test_loader)

100%|██████████| 103/103 [06:52<00:00,  4.01s/it]

Loss: 0.6949231422063217
              precision    recall  f1-score   support

           0       0.51      1.00      0.68       845
           1       0.00      0.00      0.00       797

    accuracy                           0.51      1642
   macro avg       0.26      0.50      0.34      1642
weighted avg       0.26      0.51      0.35      1642




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [32]:
for i, batch in enumerate(tqdm(train_loader)):
    out = model(batch[0].to(device))
    print(out)
    out = torch.squeeze(torch.sigmoid(out).cpu(), 1)
    print(out)
    loss = criterion(out, torch.tensor(batch[1], dtype=torch.float))
    print(loss)
    preds = torch.where(out >= 0.5, 1, 0)
    print(preds)
    print(batch[1])

    break

  0%|          | 0/1232 [00:01<?, ?it/s]

tensor([[-0.0314],
        [-0.0078],
        [-0.0403],
        [ 0.0338]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor([0.4922, 0.4981, 0.4899, 0.5085], grad_fn=<SqueezeBackward1>)
tensor(0.6791, grad_fn=<BinaryCrossEntropyBackward0>)
tensor([0, 0, 0, 1])
[0, 0, 0, 1]



