In [2]:
import torch
from torch import nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms

import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Dict
from collections import defaultdict
from PIL import Image
from tqdm import tqdm
import wandb 
import random

import clip
from clip.simple_tokenizer import SimpleTokenizer

In [11]:
def extract_case_from_img(img_path: Path):
    """Return the case id encoded in a patch/slide path name.

    Everything from the stem's first "." onwards is discarded, and the case id is
    the first three "-"-separated fields of what remains, e.g.
    ``TCGA-AB-1234-01Z-00-DX1.p0.png`` -> ``TCGA-AB-1234``.
    """
    base_name = str(img_path.stem).partition(".")[0]
    leading_fields = base_name.split("-", 3)[:3]
    return "-".join(leading_fields)

In [24]:
# Load the pretrained CLIP ViT-B/32 model and the image preprocessing that matches it.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# NOTE: the openai/clip loader appears to keep fp16 weights when loading onto CUDA
# (it only casts to fp32 on CPU). Optimizer steps applied directly to fp16 weights
# can round small updates away and are prone to NaNs (especially with Adam-style
# eps), so fine-tune with fp32 weights. Harmless if the weights are already fp32.
# Verify with: next(model.parameters()).dtype
model = model.float()

In [25]:
class PathologyPairDataset(data.Dataset):
    """(image, report text) pairs for contrastive image/text training.

    Every ``*.png`` found at least one directory level below ``img_dir`` is
    paired with the ``*.txt`` report (searched recursively under ``txt_dir``)
    of the same case. Case ids are derived from file names:

    * report: the part of the file stem before its first "."
    * patch:  the first three "-"-separated fields of the part of the file
      stem before its first "."

    Only cases listed in ``case_ids`` are kept, which lets the caller do a
    case-level train/test split up front. Patches whose case has no report
    are skipped, and their case ids are collected in ``self.cases_without_txt``.
    Without the skip they would fail later in ``__getitem__``.
    """

    def __init__(
        self,
        img_dir: Path,
        txt_dir: Path,
        case_ids: List[str],
        transform=None,
    ):
        """
        Args:
            img_dir: root directory searched for ``*/*.png`` patch images.
            txt_dir: root directory searched recursively for ``*.txt`` reports.
            case_ids: case ids to include (any iterable of str).
            transform: optional callable applied to each PIL image in
                ``__getitem__`` (e.g. CLIP's ``preprocess``).
        """
        self.img_dir = img_dir
        self.txt_dir = txt_dir
        self.case_ids = case_ids
        self.transform = transform

        # Membership is tested once per image (potentially millions), so use a
        # set instead of scanning the list each time.
        wanted_cases = set(case_ids)

        self.img_paths = list(img_dir.rglob("*/*.png"))
        self.txt_paths = list(txt_dir.rglob("*.txt"))

        self.case2txt = defaultdict(str)
        for txt_path in self.txt_paths:
            case_id = txt_path.stem.split(".")[0]
            if case_id in wanted_cases:
                self.case2txt[case_id] = txt_path

        # Resolve each image's report once, so __getitem__ is a plain index lookup.
        self.img_txt_pairs = []
        self.cases_without_txt = set()
        for img_path in self.img_paths:
            case_id = "-".join(img_path.stem.split(".")[0].split("-")[:3])
            if case_id not in wanted_cases:
                continue
            # .get() avoids the defaultdict inserting "" for a missing report.
            txt_path = self.case2txt.get(case_id)
            if not txt_path:
                self.cases_without_txt.add(case_id)
                continue
            self.img_txt_pairs.append((img_path, txt_path))

    def __len__(self):
        """Number of (image, report) pairs."""
        return len(self.img_txt_pairs)

    def __getitem__(self, idx):
        """Return ``(image, text)`` for pair ``idx``.

        ``image`` is the PIL image, passed through ``self.transform`` when one
        is set. ``text`` is the full contents of the matching report file.
        """
        img_path, txt_path = self.img_txt_pairs[idx]

        # Copy the decoded pixels so the file handle is released immediately
        # instead of lingering until garbage collection.
        with Image.open(img_path) as src:
            img = src.copy()
        if self.transform is not None:
            img = self.transform(img)

        with open(txt_path, "r", encoding="utf-8") as f:
            txt = f.read()

        return img, txt


In [20]:
# Inputs: patch folders under img_dir (presumably one per slide) and free-text reports under txt_dir.
img_dir = Path("../../../../data/data/patches/10.0_224")
txt_dir = Path("../../../../data/data/reports")

patch_folders = list(img_dir.glob("*"))
txt_paths = list(txt_dir.rglob("*.txt"))

# Map every patch folder to its case id. The rest of the notebook assumes no two
# folders share a case, so fail fast if any id repeats.
unique_case_ids = list(map(extract_case_from_img, patch_folders))
assert len(set(unique_case_ids)) == len(unique_case_ids)

In [27]:
# Case-level label table.
# NOTE(review): `all_cases` is loaded but not used below. The dataset is built
# from every case that has a patch folder, with no per-class sampling. Confirm
# whether the labels should restrict `case_ids`.
labels_path = Path("../../../../data/data") / "labels_40.csv"
labels_df = pd.read_csv(labels_path)
all_cases = list(labels_df["Case ID"])

# Pair each patch with its case report; images go through CLIP's preprocessing.
dataset = PathologyPairDataset(
    case_ids=unique_case_ids,
    img_dir=img_dir,
    txt_dir=txt_dir,
    transform=preprocess,
)
print("Len of dataset: ", len(dataset))

Len of dataset:  1999969


In [35]:
# DataLoader over the (image, text) pairs. With the default collate_fn, each
# batch is (images, texts), where texts is a sequence of report strings.
# Batch size matters beyond throughput: the contrastive loss in train_batch
# treats the other items of a batch as negatives.
batch_size: int = 64
shuffle: bool = True
num_workers: int = 4

loader_params = {
    "batch_size": batch_size,
    "shuffle": shuffle,
    "num_workers": num_workers,
}
data_loader = data.DataLoader(dataset, **loader_params)

In [36]:
_clip_tokenizer = None


def _get_clip_tokenizer():
    """Return a shared ``SimpleTokenizer``, constructing it on first use.

    NOTE: in the openai/clip package, construction loads and parses the BPE
    vocabulary, so it is built once rather than on every training batch.
    """
    global _clip_tokenizer
    if _clip_tokenizer is None:
        _clip_tokenizer = SimpleTokenizer()
    return _clip_tokenizer


def preprocess_text(texts, model, tokenizer=None):
    """Tokenize ``texts`` into a zero-padded ``(len(texts), model.context_length)`` LongTensor.

    Each text becomes ``[<|startoftext|>] + encoded tokens + [<|endoftext|>]``,
    left-aligned in its row. A sequence longer than the context length is cut to
    fit, and its last position is overwritten with ``<|endoftext|>``.

    Args:
        texts: iterable of strings (e.g. the text part of a DataLoader batch).
        model: object exposing an integer ``context_length`` attribute.
        tokenizer: optional tokenizer with an ``encoder`` dict and an ``encode``
            method. Defaults to a cached ``SimpleTokenizer``.

    Returns:
        torch.LongTensor of token ids.
    """
    if tokenizer is None:
        tokenizer = _get_clip_tokenizer()
    sot_token = tokenizer.encoder["<|startoftext|>"]
    eot_token = tokenizer.encoder["<|endoftext|>"]
    context_length = model.context_length

    all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            # Truncate, but keep an end-of-text marker in the final slot.
            tokens = tokens[:context_length]
            tokens[-1] = eot_token
        result[i, :len(tokens)] = torch.tensor(tokens)
    return result

In [42]:
# Optimisation / logging hyperparameters.
lr: float = 1e-4
momentum: float = 0.9
epochs: int = 4
log_interval: int = 100     # optimizer steps between loss reports
save_interval: int = 1000   # optimizer steps between checkpoints
save_dir: Path = Path("./checkpoints")
model_name: str = "sample_0"

# Follow the detected device instead of hard-coding CUDA, so the notebook also runs on CPU.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# Hyperparameters recorded with the wandb run. Values come from the variables
# above, so the logged config cannot drift from what is actually used.
config = {
    "lr": lr,
    "momentum": momentum,
    "epochs": epochs,
    "log_interval": log_interval,
    "save_interval": save_interval,
    "save_dir": str(save_dir),
    "model_name": model_name,
    "batch_size": batch_size,
    "optimizer": type(optimizer).__name__,
}

wandb.init(
    # set the wandb project where this run will be logged
    project="path-clip-v0",

    # track hyperparameters and run metadata
    config=config
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mekintiu[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [43]:
def train(model, loader, device, criterion, optimizer, config):
    """Fine-tune ``model`` on (image, text) batches from ``loader``.

    ``epochs``, ``log_interval`` and ``save_interval`` are read from ``config``
    when present, otherwise from the notebook globals of the same name.
    Checkpoints are written under ``save_dir / model_name`` (notebook globals).

    Every ``log_interval`` steps, the mean loss of the preceding ``log_interval``
    steps is reported via ``train_log``. The model's state_dict is saved as
    ``checkpoint_<step>.pt`` every ``save_interval`` steps. It is saved once more
    after the final step if that step was not already checkpointed.

    Args:
        model: model callable as ``model(images, tokens)`` and exposing
            ``context_length`` (used by ``preprocess_text``).
        loader: iterable of ``(images, texts)`` batches, where ``texts`` are strings.
        device: device the batch is moved to inside ``train_batch``.
        criterion: loss applied to both logit matrices in ``train_batch``.
        optimizer: optimizer over ``model``'s parameters.
        config: mapping of hyperparameters (see above).
    """
    n_epochs = config.get("epochs", epochs)
    report_freq = config.get("log_interval", log_interval)
    ckpt_freq = config.get("save_interval", save_interval)

    model_save_dir = save_dir / model_name
    model_save_dir.mkdir(exist_ok=True, parents=True)

    def checkpoint(step):
        # Save first, then announce, so the message never names a file that failed to write.
        model_path = model_save_dir / f"checkpoint_{step}.pt"
        save(model, model_path)
        print("Saved checkpoint to: ", model_path)

    # The model may arrive in eval mode (e.g. straight from clip.load);
    # make sure training-mode behaviour is active for every step.
    model.train()

    example_ct = 0  # number of examples seen
    batch_ct = 0    # number of optimizer steps taken
    # Accumulated across epoch boundaries. Reports fire on the global step count,
    # so resetting this per epoch would under-report the first average of each
    # new epoch whenever len(loader) is not a multiple of report_freq.
    running_loss = 0.0

    for epoch in range(n_epochs):
        for images, texts in tqdm(loader, desc=f"epoch {epoch}"):
            tokens = preprocess_text(texts, model)

            loss = train_batch(images, tokens, model, device, criterion, optimizer)
            example_ct += len(images)
            batch_ct += 1
            running_loss += loss.item()

            if batch_ct % report_freq == 0:
                train_log(running_loss / report_freq, example_ct, epoch)
                running_loss = 0.0

            if batch_ct % ckpt_freq == 0:
                checkpoint(batch_ct)

    # Keep the weights from any steps taken after the last periodic checkpoint.
    if batch_ct % ckpt_freq != 0:
        checkpoint(batch_ct)
                
def train_batch(images, texts, model, device, criterion, optimizer):
    """Run one contrastive optimisation step and return the scalar loss tensor.

    ``model(images, texts)`` must return ``(logits_per_image, logits_per_text)``.
    Row i of each logit matrix is scored against target class i (the i-th item
    of the other modality). The two cross-entropy losses are averaged.
    """
    images = images.to(device)
    texts = texts.to(device)

    logits_per_image, logits_per_text = model(images, texts)

    # The matching partner of item i is item i of the other modality.
    targets = torch.arange(images.shape[0], device=device)
    loss = (criterion(logits_per_image, targets) + criterion(logits_per_text, targets)) / 2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

def train_log(loss, example_ct, epoch):
    """Send the averaged training loss to wandb and echo it to stdout.

    ``example_ct`` (examples seen so far) is zero-padded to at least 5 characters
    in the printed line.
    """
    loss_value = float(loss)
    wandb.log({"epoch": epoch, "loss": loss_value})
    seen = str(example_ct).zfill(5)
    print("Loss after " + seen + f" examples: {loss_value:.3f}")
    
def save(model, path):
    """Write ``model``'s state_dict (parameters and buffers) to ``path`` with ``torch.save``."""
    weights = model.state_dict()
    torch.save(weights, path)

In [45]:
# Run training; always close the wandb run, even if training raises or is
# interrupted (e.g. KeyboardInterrupt from stopping the cell).
try:
    train(model, data_loader, device, criterion, optimizer, config)
finally:
    wandb.finish()

  0%|          | 100/31250 [01:03<5:29:16,  1.58it/s]

Loss after 06400 examples: 3.527


  1%|          | 200/31250 [02:04<5:20:54,  1.61it/s]

Loss after 12800 examples: 2.362


  1%|          | 300/31250 [03:05<5:18:08,  1.62it/s]

Loss after 19200 examples: 1.627


  1%|▏         | 400/31250 [04:07<5:44:51,  1.49it/s]

Loss after 25600 examples: 1.189


  2%|▏         | 500/31250 [05:09<5:07:01,  1.67it/s]

Loss after 32000 examples: 0.912


  2%|▏         | 600/31250 [06:12<5:03:42,  1.68it/s]

Loss after 38400 examples: 0.810


  2%|▏         | 700/31250 [07:12<5:09:27,  1.65it/s]

Loss after 44800 examples: 0.665


  3%|▎         | 800/31250 [08:13<5:09:09,  1.64it/s]

Loss after 51200 examples: 0.602


  3%|▎         | 831/31250 [08:32<5:15:06,  1.61it/s]