In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torch import nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms

import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Dict
from collections import defaultdict
from PIL import Image
from tqdm import tqdm
import wandb 
import random

import clip
from clip.simple_tokenizer import SimpleTokenizer
from loader import PathologyPairDataset
# from builder import PathZero

In [2]:
def extract_case_from_img(img_path: Path): 
    return "-".join(str(img_path.stem).split(".")[0].split("-")[:3])

In [3]:
# get valid cases, match images with texts
img_dir = Path("../../../../data/data/patches/pt")
txt_dir = Path("../../../../data/data/reports")

patch_folders = list(img_dir.glob("*.pt"))
txt_paths = list(txt_dir.rglob("*.txt"))

# get unique case ids from img_paths
unique_case_ids = [extract_case_from_img(patch_folder) for patch_folder in patch_folders]
assert(len(np.unique(unique_case_ids)) == len(unique_case_ids))

In [4]:
# read in from csv, sample one from each class
labels_path = Path("../../../../data/data") / "labels_40.csv"
labels_df = pd.read_csv(labels_path)
all_cases = list(labels_df["Case ID"])

dataset = PathologyPairDataset(
    img_dir=img_dir, 
    txt_dir=txt_dir,
    case_ids=unique_case_ids,
    transform=None, 
)
print("Len of dataset: ", len(dataset))

Len of dataset:  1332


In [148]:
# create data loader
batch_size: int = 8
shuffle: int = False
num_workers: int = 0

def collate_fn(batch):
    # Get the maximum sequence length in the batch for the first output
    max_len = max([x[0].shape[0] for x in batch])

    # Pad all sequences in the batch to the same length for the first output
    padded_batch_0 = []
    for x in batch:
        padded_seq = torch.zeros((max_len, 2048), dtype=torch.float)
        padded_seq[:x[0].shape[0], :] = x[0]
        padded_batch_0.append(padded_seq)

    # Stack the padded sequences into a tensor for the first output
    padded_batch_0 = torch.stack(padded_batch_0)

    # Stack the second and third outputs into tensors
    padded_batch_1 = [x[1] for x in batch]
    padded_batch_2 = [x[2] for x in batch]

    return padded_batch_0, padded_batch_1, padded_batch_2

loader_params = {
    'batch_size': batch_size, 
    'shuffle': shuffle, 
    'num_workers': num_workers, 
    'collate_fn': collate_fn, 
}

data_loader = data.DataLoader(dataset, **loader_params)

In [6]:
# # load CLIP / model
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

In [188]:
from collections import OrderedDict
from os.path import join
import pdb
from turtle import forward

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional
import clip
from clip.simple_tokenizer import SimpleTokenizer

###########################
### PathZero Implementation ###
###########################
class PathZero(nn.Module): 
    def __init__(
        self, 
        image_input_dim: int = 2048, # d_i 
        embedding_dim: int = 256, # d_e
        logits_dim: int = 256, # d_l
        dropout: float = 0.25,
        clip_model: Optional[nn.Module] = None
    ): 
        super(PathZero, self).__init__()

        self.image_input_dim = image_input_dim 
        self.embedding_dim = embedding_dim # dim for coattn
        self.logits_dim = logits_dim # dim for contrastive
        self.dropout = dropout
        
        ## FC over WSI bag --> convert to embedding dim
        fc = [nn.Linear(image_input_dim, embedding_dim), nn.ReLU()]
        fc.append(nn.Dropout(dropout))
        self.wsi_net = nn.Sequential(*fc)

        ## Text encoder -- CLIP
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if clip_model is None:
            self.clip_model, _ = clip.load("ViT-B/32", device=self.device, jit=False)
        else: 
            self.clip_model = clip_model
        self.clip_model.train()

        # Text Linear Layer, map to embedding dim NOTE: 512 = clip.text_encoder output dim
        txt_fc = [nn.Linear(512, embedding_dim), nn.ReLU()]
        txt_fc.append(nn.Dropout(dropout))
        self.text_net = nn.Sequential(*txt_fc) # for initial text features going into Co-Attn
        # text logits
        txt_logits_fc = [nn.Linear(512, logits_dim), nn.ReLU(), nn.Dropout(dropout)]
        self.text_logits_net = nn.Sequential(*txt_logits_fc) # for final text features before going into CLIP

        ## Co-Attention
        ### Multihead Attention
        self.coattn = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=1, batch_first=True)

        ### WSI Transformer + Attention Head
        path_encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=8, dim_feedforward=logits_dim, dropout=dropout, activation='relu')
        self.path_transformer = nn.TransformerEncoder(path_encoder_layer, num_layers=2)
        self.path_attention_head = Attn_Net_Gated(L=logits_dim, D=logits_dim, dropout=dropout, n_classes=1)
        self.path_rho = nn.Sequential(*[nn.Linear(logits_dim, embedding_dim), nn.ReLU(), nn.Dropout(dropout)]) # linear layer
        
    def encode_text(self, text: torch.Tensor): 
        """
        Modification of original CLIP encode_text that results in
        embeddings for each token instead of taking the features from the
        eot embedding. 

        See reference https://github.com/openai/CLIP/blob/main/clip/model.py

        """
        x = self.clip_model.token_embedding(text).type(self.clip_model.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(self.clip_model.dtype)
        return x

    def forward(self, img: torch.Tensor, txt: torch.Tensor, verbose: bool = False): 
        """
        Given image and text, returns both image and text logits. Also returns
        attention scores from co-attention, image, and text. 

        Args: 
            img: N x P x d_i x 1 x 1
            txt: N x T x 1
        Note: 
            batch_size is always equal to 1
        
        """

        ### Image linear layer
        # print("img shape: ", img.shape)
        # shape: (N x P x d_i)
        img_bag = self.wsi_net(img)  ### path embeddings are fed through a FC layer
        if len(img_bag.shape) == 2: 
            img_bag = img_bag.unsqueeze(1)
        # print("img_bag shape: ", img_bag.shape)
        # shape: (N x P x d_e)

        ### Text encoder
        # print("txt.shape: ", txt.shape)
        text_bag_init = self.encode_text(txt) # obtain clip text encoder embeddings
        text_bag_init_float = text_bag_init.float()
        # shape: (N x T x d_t)
        # print("text_bag.shape (before linear): ", text_bag_init_float.shape)
        
        # d_t --> d_e embedding dim to make same as img dim for coattn
        text_bag = self.text_net(text_bag_init_float) 
        # shape: (N x T x d_e)
        # print("text_bag.shape (after linear): ", text_bag.shape)
        
        ### Apply co-attn -- args: query, key, value
        img_coattn, A_coattn = self.coattn(text_bag, img_bag, img_bag)
        # print("img_coattn.shape: ", img_coattn.shape) # shape: (N x T x d_e)
        # print("A_coattn.shape: ", A_coattn.shape) # shape: (P x T)
        
        ### Apply WSI transformer --> changes dim from embed_dim (d_e) to logits_dim (d_l)
        img_trans = self.path_transformer(img_coattn) # shape: (N x T x d_l)
        # print("img_trans.shape: ", img_trans.shape)

        ## Apply Attention over the transformer features per token
        A_img, img_features = self.path_attention_head(img_trans.squeeze(1)) # attention mechanism
        # print("A_img.shape (before transpose): ", A_img.shape)
        # A_img = torch.squeeze(A_img, 2) # remove last dim of 1 before softmax and matrix multiply
        A_img = torch.transpose(A_img, 1, 2) # swap last two dims, 1 and token_length
        # print("A_img.shape (after transpose): ", A_img.shape)
        # print("img_features.shape: ", img_features.shape)
        img_features = torch.bmm(F.softmax(A_img, dim=1), img_features) # (N x 1 x d_l)
        img_features = torch.squeeze(img_features)
        # print("img_features.shape (after mm): ", img_features.shape) # shape: (N x d_l)
        img_features = self.path_rho(img_features) # shape: (N x d_l)

        ### Normalize features
        image_features = img_features / img_features.norm(dim=1, keepdim=True) # (N x d_l)
        # aggregate tokens --> single text representation (see CLIP https://github.com/openai/CLIP/blob/main/clip/model.py#L354)
        text_agg = text_bag_init[torch.arange(text_bag_init.shape[0]), txt.argmax(dim=-1)] @ self.clip_model.text_projection
        text_agg = text_agg.float() # shape: (N x d_e)
        # print("text_agg.shape: ", text_agg.shape, text_agg.dtype)
        # convert from d_e --> d_l for text agg embeddings
        text_agg = self.text_logits_net(text_agg) # shape: (N x d_l)
        # linear layer to get into logits shape
        text_features = text_agg / text_agg.norm(dim=1, keepdim=True) # shape: (N x d_l)
        # print("image_features.shape: ", image_features.shape)
        # print("text_features.shape: ", text_features.shape)

        # image_features = torch.squeeze(image_features)
        # text_features = torch.squeeze(text_features)
        # print("image_features.shape (after squeeze): ", image_features.shape)
        # print("text_features.shape (after squeeze): ", text_features.shape)
        
        ### Obtain dot product logits
        # cosine similarity as logits
        logit_scale = self.clip_model.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t() # shape: (N x N)
        logits_per_text = logits_per_image.t() # shape: (N x N)

        # print("logits_per_image.shape: ", logits_per_image.shape)
        # print("logits_per_text.shape: ", logits_per_text.shape)

        # shape = [global_batch_size, global_batch_size]
        attention_scores = {'coattn': A_coattn, 'img': A_img}
        return logits_per_image, logits_per_text, attention_scores

class Attn_Net_Gated(nn.Module):
    def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):
        r"""
        See reference https://github.com/mahmoodlab/MCAT/blob/master/models/model_utils.py
        Attention Network with Sigmoid Gating (3 fc layers)
        args:
            L (int): input feature dimension
            D (int): hidden layer dimension
            dropout (bool): whether to apply dropout (p = 0.25)
            n_classes (int): number of classes
        """
        super(Attn_Net_Gated, self).__init__()
        self.attention_a = [
            nn.Linear(L, D),
            nn.Tanh()]
        
        self.attention_b = [nn.Linear(L, D), nn.Sigmoid()]
        if dropout:
            self.attention_a.append(nn.Dropout(0.25))
            self.attention_b.append(nn.Dropout(0.25))

        self.attention_a = nn.Sequential(*self.attention_a)
        self.attention_b = nn.Sequential(*self.attention_b)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        a = self.attention_a(x)
        b = self.attention_b(x)
        A = a.mul(b)
        A = self.attention_c(A)  # N x n_classes
        return A, x

In [189]:
# Experiment with PathZero model
model = PathZero(clip_model=clip_model)
model = model.to(device)

In [190]:
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"The model has {num_params} trainable parameters.")

The model has 153316866 trainable parameters.


In [172]:
def preprocess_text(texts, model):
    _tokenizer = SimpleTokenizer()
    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), model.context_length, dtype=torch.long)
    
    # will trim tokens to match context length
    for i, tokens in enumerate(all_tokens):
        if len(tokens) > model.context_length:
            tokens = tokens[:model.context_length]
            tokens[model.context_length - 1] = eot_token
        result[i, :len(tokens)] = torch.tensor(tokens)
    return result

In [173]:
for image, text, _ in data_loader:
    text = preprocess_text(text, model.clip_model)
    image = image.to(device)
    text = text.to(device)
    
    image_logits, text_logits, attention_scores = model(image, text)

    # training
    print(image_logits)
    print(text_logits)
    break

tensor([[30.1023, 21.9178, 22.7231, 27.2820, 22.4517, 24.2126, 30.1182, 24.7481],
        [23.9379, 25.9622, 24.1611, 21.9616, 18.0911, 22.6735, 25.9819, 24.1141],
        [29.6417, 31.0994, 26.9708, 27.7571, 20.5740, 27.0729, 31.9664, 25.8329],
        [25.3947, 19.7767, 16.5627, 22.5389, 18.4853, 20.5879, 25.6718, 19.0675],
        [21.9395, 27.3583, 21.4484, 22.0770, 17.7129, 18.1122, 21.6196, 18.3286],
        [30.0496, 24.8225, 24.5744, 28.0990, 19.1271, 25.2953, 31.1733, 22.8713],
        [23.7202, 23.9463, 19.4798, 22.0033, 14.2205, 16.1057, 25.7246, 18.9781],
        [23.9488, 23.6184, 17.0340, 20.5861, 19.3954, 16.2704, 27.3494, 22.9458]],
       device='cuda:0', grad_fn=<MmBackward0>)
tensor([[30.1023, 23.9379, 29.6417, 25.3947, 21.9395, 30.0496, 23.7202, 23.9488],
        [21.9178, 25.9622, 31.0994, 19.7767, 27.3583, 24.8225, 23.9463, 23.6184],
        [22.7231, 24.1611, 26.9708, 16.5627, 21.4484, 24.5744, 19.4798, 17.0340],
        [27.2820, 21.9616, 27.7571, 22.5389, 22.07

In [177]:
lr: float = 1e-4
momentum: float = 0.9
epochs: int = 4
log_interval: int = 10
save_interval: int = 150
save_dir: Path = Path("./checkpoints")
model_name: str = "path_zero_v0"

criterion = nn.CrossEntropyLoss().cuda()
# optimizer = optim.AdamW(model.parameters(), lr=lr)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# init wandb config
config = {
    "lr": lr,
    "momentum": momentum, 
    "epochs": epochs, 
    "log_interval": log_interval, 
    "save_interval": save_interval, 
    "save_dir": "save_dir", 
    "model_name": "sample_0", 
}

# wandb.init(
#     # set the wandb project where this run will be logged
#     project="path-zero",
#     # track hyperparameters and run metadata
#     config=config
# )

In [178]:
def train(model, loader, device, criterion, optimizer, config): 
    model_save_dir = save_dir / model_name
    model_save_dir.mkdir(exist_ok=True, parents=True)
    
    # Run training
    total_batches = len(loader) * epochs
    example_ct = 0  # number of examples seen
    batch_ct = 0
    report_freq = log_interval
    highest_val_auc = 0 # save highest mean auc
    
    for epoch in range(epochs):
        running_loss = 0.0 # running loss over batch
        for data in tqdm(loader):
            # get the images
            images, texts, _ = data
            texts = preprocess_text(texts, model.clip_model) 
            
            # perform step for a single batch
            loss = train_batch(images, texts, model, device, criterion, optimizer)
            example_ct +=  len(images)
            batch_ct += 1
            running_loss += loss.item()

            # Report metrics every `report_freq` batch
            if (batch_ct % report_freq) == 0:
                train_log(running_loss / report_freq, example_ct, epoch)
                running_loss = 0.0
            
            if (batch_ct % save_interval) == 0: 
                model_path = model_save_dir / f"checkpoint_{str(batch_ct)}.pt"
                print("Saved checkpoint to: ", model_path)
                save(model, model_path)
                
def train_batch(images, texts, model, device, criterion, optimizer):
    images, texts = images.to(device), texts.to(device)
    
    # Forward pass ➡
    logits_per_image, logits_per_text, attention_scores = model(images, texts)

    # Create labels
    batch_size = images.shape[0]
    labels = torch.arange(batch_size).to(device)
    
    # Compute loss
    loss_img = criterion(logits_per_image, labels)
    loss_txt = criterion(logits_per_text, labels)
    loss = (loss_img + loss_txt)/2 # avg. img and txt loss

    # Backward pass ⬅
    optimizer.zero_grad()
    loss.backward()
    
    # Step with optimizer
    optimizer.step()
        
    return loss

def train_log(loss, example_ct, epoch):
    loss = float(loss)
    # save to wandb 
    wandb.log({"epoch": epoch, "loss": loss})
    # print to log
    print(f"Loss after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}")
    
def save(model, path): 
    torch.save(model.state_dict(), path)

In [179]:
train(model, data_loader, device, criterion, optimizer, config)
wandb.finish()

  6%|▌         | 10/167 [00:06<01:44,  1.50it/s]

Loss after 00080 examples: 3.126


 12%|█▏        | 20/167 [00:17<01:35,  1.54it/s]

Loss after 00160 examples: 2.304


 18%|█▊        | 30/167 [00:30<03:33,  1.56s/it]

Loss after 00240 examples: 2.121


 24%|██▍       | 40/167 [00:43<04:01,  1.90s/it]

Loss after 00320 examples: 2.120


 30%|██▉       | 50/167 [00:58<02:09,  1.11s/it]

Loss after 00400 examples: 2.100


 36%|███▌      | 60/167 [01:09<02:17,  1.28s/it]

Loss after 00480 examples: 2.096


 42%|████▏     | 70/167 [01:26<02:16,  1.40s/it]

Loss after 00560 examples: 2.096


 48%|████▊     | 80/167 [01:36<01:25,  1.01it/s]

Loss after 00640 examples: 2.090


 54%|█████▍    | 90/167 [01:49<01:17,  1.00s/it]

Loss after 00720 examples: 2.096


 60%|█████▉    | 100/167 [02:00<01:06,  1.01it/s]

Loss after 00800 examples: 2.085


 66%|██████▌   | 110/167 [02:07<00:45,  1.25it/s]

Loss after 00880 examples: 2.077


 72%|███████▏  | 120/167 [02:20<01:10,  1.50s/it]

Loss after 00960 examples: 2.079


 78%|███████▊  | 130/167 [02:33<00:39,  1.05s/it]

Loss after 01040 examples: 2.079


 84%|████████▍ | 140/167 [02:43<00:24,  1.08it/s]

Loss after 01120 examples: 2.084


 89%|████████▉ | 149/167 [02:54<00:22,  1.24s/it]

Loss after 01200 examples: 2.084
Saved checkpoint to:  checkpoints/path_zero_v0/checkpoint_150.pt


 96%|█████████▌| 160/167 [03:07<00:06,  1.15it/s]

Loss after 01280 examples: 2.080


100%|██████████| 167/167 [03:17<00:00,  1.18s/it]
  2%|▏         | 3/167 [00:01<01:36,  1.70it/s]

Loss after 01356 examples: 0.624


  8%|▊         | 13/167 [00:17<06:01,  2.35s/it]

Loss after 01436 examples: 2.081


 14%|█▍        | 23/167 [00:25<02:29,  1.04s/it]

Loss after 01516 examples: 2.076


 20%|█▉        | 33/167 [00:39<03:13,  1.45s/it]

Loss after 01596 examples: 2.080


 26%|██▌       | 43/167 [00:55<04:17,  2.08s/it]

Loss after 01676 examples: 2.081


 32%|███▏      | 53/167 [01:05<02:10,  1.15s/it]

Loss after 01756 examples: 2.080


 38%|███▊      | 63/167 [01:17<02:25,  1.40s/it]

Loss after 01836 examples: 2.076


 44%|████▎     | 73/167 [01:32<02:01,  1.29s/it]

Loss after 01916 examples: 2.084


 50%|████▉     | 83/167 [01:41<01:14,  1.12it/s]

Loss after 01996 examples: 2.077


 56%|█████▌    | 93/167 [01:54<01:05,  1.13it/s]

Loss after 02076 examples: 2.080


 62%|██████▏   | 103/167 [02:03<00:37,  1.69it/s]

Loss after 02156 examples: 2.080


 68%|██████▊   | 113/167 [02:13<01:12,  1.34s/it]

Loss after 02236 examples: 2.077


 74%|███████▎  | 123/167 [02:27<01:17,  1.77s/it]

Loss after 02316 examples: 2.078


 79%|███████▉  | 132/167 [02:37<00:36,  1.04s/it]

Loss after 02396 examples: 2.083
Saved checkpoint to:  checkpoints/path_zero_v0/checkpoint_300.pt


 86%|████████▌ | 143/167 [02:47<00:20,  1.15it/s]

Loss after 02476 examples: 2.078


 92%|█████████▏| 153/167 [03:02<00:20,  1.49s/it]

Loss after 02556 examples: 2.081


 98%|█████████▊| 163/167 [03:13<00:07,  1.81s/it]

Loss after 02636 examples: 2.081


100%|██████████| 167/167 [03:17<00:00,  1.19s/it]
  4%|▎         | 6/167 [00:03<01:50,  1.45it/s]

Loss after 02712 examples: 1.246


 10%|▉         | 16/167 [00:19<03:19,  1.32s/it]

Loss after 02792 examples: 2.080


 16%|█▌        | 26/167 [00:28<02:35,  1.10s/it]

Loss after 02872 examples: 2.083


 22%|██▏       | 36/167 [00:39<01:41,  1.29it/s]

Loss after 02952 examples: 2.080


 28%|██▊       | 46/167 [00:57<02:58,  1.48s/it]

Loss after 03032 examples: 2.080


 34%|███▎      | 56/167 [01:08<02:07,  1.14s/it]

Loss after 03112 examples: 2.078


 40%|███▉      | 66/167 [01:23<03:15,  1.94s/it]

Loss after 03192 examples: 2.080


 46%|████▌     | 76/167 [01:35<01:44,  1.15s/it]

Loss after 03272 examples: 2.079


 51%|█████▏    | 86/167 [01:48<01:56,  1.44s/it]

Loss after 03352 examples: 2.077


 57%|█████▋    | 96/167 [01:57<01:10,  1.01it/s]

Loss after 03432 examples: 2.078


 63%|██████▎   | 106/167 [02:05<00:44,  1.38it/s]

Loss after 03512 examples: 2.081


 69%|██████▉   | 115/167 [02:13<00:51,  1.00it/s]

Loss after 03592 examples: 2.081
Saved checkpoint to:  checkpoints/path_zero_v0/checkpoint_450.pt


 75%|███████▌  | 126/167 [02:24<00:34,  1.19it/s]

Loss after 03672 examples: 2.078


 81%|████████▏ | 136/167 [02:33<00:28,  1.10it/s]

Loss after 03752 examples: 2.077


 87%|████████▋ | 146/167 [02:44<00:32,  1.55s/it]

Loss after 03832 examples: 2.087


 93%|█████████▎| 156/167 [02:56<00:10,  1.06it/s]

Loss after 03912 examples: 2.079


 99%|█████████▉| 166/167 [03:10<00:01,  1.49s/it]

Loss after 03992 examples: 2.077


100%|██████████| 167/167 [03:10<00:00,  1.14s/it]
  5%|▌         | 9/167 [00:07<02:32,  1.04it/s]

Loss after 04068 examples: 1.870


 11%|█▏        | 19/167 [00:20<02:05,  1.18it/s]

Loss after 04148 examples: 2.081


 17%|█▋        | 29/167 [00:31<03:12,  1.40s/it]

Loss after 04228 examples: 2.081


 23%|██▎       | 39/167 [00:43<02:48,  1.32s/it]

Loss after 04308 examples: 2.077


 29%|██▉       | 49/167 [01:00<02:23,  1.22s/it]

Loss after 04388 examples: 2.078


 35%|███▌      | 59/167 [01:09<01:26,  1.25it/s]

Loss after 04468 examples: 2.073


 41%|████▏     | 69/167 [01:27<02:30,  1.54s/it]

Loss after 04548 examples: 2.085


 47%|████▋     | 79/167 [01:38<01:51,  1.27s/it]

Loss after 04628 examples: 2.085


 53%|█████▎    | 89/167 [01:51<01:23,  1.07s/it]

Loss after 04708 examples: 2.082


 59%|█████▊    | 98/167 [02:00<01:27,  1.26s/it]

Loss after 04788 examples: 2.081
Saved checkpoint to:  checkpoints/path_zero_v0/checkpoint_600.pt


 65%|██████▌   | 109/167 [02:08<00:51,  1.12it/s]

Loss after 04868 examples: 2.080


 71%|███████▏  | 119/167 [02:20<01:04,  1.33s/it]

Loss after 04948 examples: 2.079


 75%|███████▌  | 126/167 [02:31<00:49,  1.21s/it]


KeyboardInterrupt: 

## Init Train Config

In [60]:
lr: float = 1e-4
momentum: float = 0.9
epochs: int = 4
log_interval: int = 100
save_interval: int = 1000
save_dir: Path = Path("./checkpoints")
model_name: str = "sample_0"

criterion = nn.CrossEntropyLoss().cuda()
# optimizer = optim.AdamW(model.parameters(), lr=lr)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# init wandb config
config = {
    "lr": lr,
    "momentum": momentum, 
    "epochs": epochs, 
    "log_interval": log_interval, 
    "save_interval": save_interval, 
    "save_dir": "save_dir", 
    "model_name": "sample_0", 
}

wandb.init(
    # set the wandb project where this run will be logged
    project="path-clip-v0",
    
    # track hyperparameters and run metadata
    config=config
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl

In [64]:
def train(model, loader, device, criterion, optimizer, config): 
    model_save_dir = save_dir / model_name
    model_save_dir.mkdir(exist_ok=True, parents=True)
    
    # Run training
    total_batches = len(loader) * epochs
    example_ct = 0  # number of examples seen
    batch_ct = 0
    report_freq = log_interval
    highest_val_auc = 0 # save highest mean auc
    
    for epoch in range(epochs):
        running_loss = 0.0 # running loss over batch
        for data in tqdm(loader):
            # get the images
            images, texts, _ = data
            texts = preprocess_text(texts, model) 
            
            # perform step for a single batch
            loss = train_batch(images, texts, model, device, criterion, optimizer)
            example_ct +=  len(images)
            batch_ct += 1
            running_loss += loss.item()

            # Report metrics every `report_freq` batch
            if (batch_ct % report_freq) == 0:
                train_log(running_loss / report_freq, example_ct, epoch)
                running_loss = 0.0
            
            if (batch_ct % save_interval) == 0: 
                model_path = model_save_dir / f"checkpoint_{str(batch_ct)}.pt"
                print("Saved checkpoint to: ", model_path)
                save(model, model_path)
                
def train_batch(images, texts, model, device, criterion, optimizer):
    images, texts = images.to(device), texts.to(device)
    
    # Forward pass ➡
    logits_per_image, logits_per_text = model(images, texts)
    print("logits_per_image.shape: ", logits_per_image.shape)
    print("logits_per_text.shape: ", logits_per_text.shape)
    
    # Create labels
    batch_size = images.shape[0]
    labels = torch.arange(batch_size).to(device)
    
    # Compute loss
    loss_img = criterion(logits_per_image, labels)
    loss_txt = criterion(logits_per_text, labels)
    loss = (loss_img + loss_txt)/2 # avg. img and txt loss

    # Backward pass ⬅
    optimizer.zero_grad()
    loss.backward()
    
    # Step with optimizer
    optimizer.step()
        
    return loss

def train_log(loss, example_ct, epoch):
    loss = float(loss)
    # save to wandb 
    wandb.log({"epoch": epoch, "loss": loss})
    # print to log
    print(f"Loss after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}")
    
def save(model, path): 
    torch.save(model.state_dict(), path)

In [65]:
train(clip_model, data_loader, device, criterion, optimizer, config)
wandb.finish()

  0%|          | 0/1332 [00:00<?, ?it/s]


RuntimeError: Given groups=1, weight of size [768, 3, 32, 32], expected input[1, 1, 178, 2048] to have 3 channels, but got 1 channels instead