In [59]:
import torch
from torch import nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms

import pandas as pd
from pathlib import Path
from typing import List, Dict
from collections import defaultdict
from PIL import Image
from tqdm import tqdm

import clip
from clip.simple_tokenizer import SimpleTokenizer

In [45]:
# Load the pretrained CLIP model together with its matching image transform.
# jit=False loads the regular (non-TorchScript) nn.Module.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# clip.load leaves the weights in half precision when loading onto CUDA.
# This notebook updates the weights directly with a plain optimizer and no
# loss scaling, which is numerically unstable in fp16, so train in fp32.
# On CPU the weights are already fp32 and this is a no-op.
model = model.float()

In [50]:
class PathologyPairDataset(data.Dataset):
    """Dataset of (image patch, report text) pairs grouped by case id.

    Every ``*.png`` found under a sub-directory of ``img_dir`` is paired with
    the ``*.txt`` report of the case it belongs to. The case id of a report is
    its file stem up to the first ``.``; the case id of an image is the first
    three ``-``-separated fields of its file stem (up to the first ``.``).
    """

    def __init__(
        self,
        img_dir: Path,
        txt_dir: Path,
        case_ids: List[str],
        transform = None,
    ):
        """
        Args:
            img_dir: root directory searched recursively for ``*/*.png`` images.
            txt_dir: root directory searched recursively for ``*.txt`` reports.
            case_ids: case ids to keep. They are compared against ids parsed
                from file names, so they must be strings. Use this to perform
                a train/test split up front.
            transform: optional callable applied to each loaded PIL image.

        Images whose case is not in ``case_ids`` are skipped, and so are images
        whose case has no report file, since they cannot form a pair.
        """
        self.img_dir = img_dir
        self.txt_dir = txt_dir
        self.case_ids = case_ids
        self.transform = transform

        # A set gives O(1) membership tests; the list scan was repeated for
        # every file found on disk.
        wanted_cases = set(case_ids)

        # len determined by number of image paths
        self.img_paths = list(img_dir.rglob("*/*.png"))
        self.txt_paths = list(txt_dir.rglob("*.txt"))

        self.case2txt = defaultdict(str)
        for txt_path in self.txt_paths:
            case_id = str(txt_path.stem).split(".")[0]
            if case_id in wanted_cases:
                self.case2txt[case_id] = txt_path

        # amortize case lookup
        self.img_txt_pairs = []
        for img_path in self.img_paths:
            case_id = "-".join(str(img_path.stem).split(".")[0].split("-")[:3])
            if case_id not in wanted_cases:
                continue
            # .get() avoids inserting an empty default for unknown cases. An
            # image without a report would otherwise be paired with "" and
            # fail later when __getitem__ tries to open it.
            txt_path = self.case2txt.get(case_id)
            if txt_path is None:
                continue
            self.img_txt_pairs.append((img_path, txt_path))

    def __len__(self):
        """Number of (image, report) pairs."""
        return len(self.img_txt_pairs)

    def __getitem__(self, idx):
        """Return ``(image, report_text)`` for pair ``idx``.

        The image is the opened PIL image, passed through ``transform`` when
        one was given; the text is the full content of the report file.
        """
        img_path, txt_path = self.img_txt_pairs[idx]

        img = Image.open(img_path)
        if self.transform is not None:
            img = self.transform(img)

        with open(txt_path, "r") as f:
            txt = f.read()

        return img, txt


In [51]:
# Locations of the image patches, the report texts and the label table.
data_root = Path("./data")
img_dir = data_root / "patches" / "10.0_224"
txt_dir = data_root / "reports"
labels_path = data_root / "labels_40.csv"

# Every case listed in the "Case ID" column of the CSV is used.
labels_df = pd.read_csv(labels_path)
all_cases = list(labels_df["Case ID"])

# Images go through CLIP's own preprocessing transform.
dataset = PathologyPairDataset(img_dir, txt_dir, all_cases, transform=preprocess)
print("Len of dataset: ", len(dataset))

Len of dataset:  76252


In [52]:
# Create the data loader. The settings are kept in a dict so they can be
# inspected or logged alongside a run.
batch_size: int = 64
shuffle: bool = True  # was annotated as int; this is a flag
num_workers: int = 4

loader_params = {
    'batch_size': batch_size,
    'shuffle': shuffle,
    'num_workers': num_workers,
}
data_loader = data.DataLoader(dataset, **loader_params)

In [56]:
def preprocess_text(texts, model):
    """Tokenize raw strings into a fixed-width tensor of CLIP token ids.

    Args:
        texts: a string or an iterable of strings.
        model: object exposing ``context_length``, the row width of the result.

    Returns:
        A ``torch.long`` tensor of shape ``(len(texts), model.context_length)``.
        Each row is ``<|startoftext|>`` + encoded text + ``<|endoftext|>``,
        right-padded with zeros. Sequences that are too long are cut to
        ``context_length`` and their last token is overwritten with
        ``<|endoftext|>``.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)

    # Constructing SimpleTokenizer is loop-invariant work, and this function is
    # called once per batch, so build the tokenizer once and reuse it.
    tokenizer = getattr(preprocess_text, "_tokenizer", None)
    if tokenizer is None:
        tokenizer = SimpleTokenizer()
        preprocess_text._tokenizer = tokenizer

    sot_token = tokenizer.encoder["<|startoftext|>"]
    eot_token = tokenizer.encoder["<|endoftext|>"]
    context_length = model.context_length

    result = torch.zeros(len(texts), context_length, dtype=torch.long)
    for i, text in enumerate(texts):
        tokens = [sot_token] + tokenizer.encode(text) + [eot_token]
        # will trim tokens to match context length
        if len(tokens) > context_length:
            tokens = tokens[:context_length]
            tokens[-1] = eot_token
        result[i, :len(tokens)] = torch.tensor(tokens)
    return result

In [72]:
# Optimisation hyperparameters.
lr: float = 1e-4
momentum: float = 0.9
epochs: int = 4

# Logging / checkpointing cadence, both counted in batches.
log_interval: int = 10
save_interval: int = 100
save_dir: Path = Path("./checkpoints")
model_name: str = "sample_0"

# Follow the `device` chosen when the model was loaded instead of hard-coding
# CUDA, so the loss lives wherever the model does.
criterion = nn.CrossEntropyLoss().to(device)
# Alternative optimizer kept for experimentation:
# optimizer = optim.AdamW(model.parameters(), lr=lr)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

In [75]:
def train(model, loader, device, criterion, optimizer, **kwargs):
    """Run the training loop, logging the loss and saving checkpoints.

    Args:
        model: model called as ``model(images, token_ids)`` by ``train_batch``.
        loader: iterable yielding ``(images, texts)`` batches of raw strings.
        device: device the batches are moved to.
        criterion: loss applied to both logit matrices.
        optimizer: optimizer stepping ``model``'s parameters.
        **kwargs: optional overrides for the notebook-level settings of the
            same name: ``epochs``, ``log_interval``, ``save_interval``,
            ``save_dir`` and ``model_name``. A setting that is not passed
            falls back to the global value.

    Checkpoints go to ``save_dir / model_name / checkpoint_<batch_ct>.pt``
    every ``save_interval`` batches.
    """
    # Explicit keyword settings win; otherwise use the notebook globals.
    n_epochs = kwargs["epochs"] if "epochs" in kwargs else epochs
    report_freq = kwargs["log_interval"] if "log_interval" in kwargs else log_interval
    ckpt_freq = kwargs["save_interval"] if "save_interval" in kwargs else save_interval
    ckpt_root = kwargs["save_dir"] if "save_dir" in kwargs else save_dir
    run_name = kwargs["model_name"] if "model_name" in kwargs else model_name

    model_save_dir = ckpt_root / run_name
    model_save_dir.mkdir(exist_ok=True, parents=True)

    # Make sure train-mode layers (dropout, batch norm) behave as such; a
    # freshly loaded pretrained model may have been left in eval mode.
    model.train()

    example_ct = 0  # number of examples seen
    batch_ct = 0    # number of batches seen, across all epochs
    # Summed loss since the last report. It is deliberately NOT reset per
    # epoch: batch_ct keeps counting across epochs, so every report covers
    # exactly `report_freq` batches and dividing by `report_freq` is correct.
    running_loss = 0.0

    for epoch in range(n_epochs):
        for images, texts in tqdm(loader):
            token_ids = preprocess_text(texts, model)

            # perform step for a single batch
            loss = train_batch(images, token_ids, model, device, criterion, optimizer)
            example_ct += len(images)
            batch_ct += 1
            running_loss += loss.item()

            # Report metrics every `report_freq` batches
            if (batch_ct % report_freq) == 0:
                train_log(running_loss / report_freq, example_ct, epoch)
                running_loss = 0.0

            if (batch_ct % ckpt_freq) == 0:
                model_path = model_save_dir / f"checkpoint_{str(batch_ct)}.pt"
                # Announce the checkpoint only once it has been written.
                save(model, model_path)
                print("Saved checkpoint to: ", model_path)
                
def train_batch(images, texts, model, device, criterion, optimizer):
    """Run one optimisation step on a single batch and return its loss.

    ``model(images, texts)`` must return ``(logits_per_image, logits_per_text)``.
    The i-th image is treated as the match of the i-th text, so the target
    for both logit matrices is ``0..batch_size-1``. The returned loss is the
    mean of the image-side and text-side criterion values.
    """
    images = images.to(device)
    texts = texts.to(device)

    image_logits, text_logits = model(images, texts)

    # NOTE(review): these targets assume the texts within a batch are distinct.
    # Identical texts in one batch would be scored as negatives of each other.
    targets = torch.arange(len(images)).to(device)

    image_loss = criterion(image_logits, targets)
    text_loss = criterion(text_logits, targets)
    batch_loss = (image_loss + text_loss) / 2

    # Clear stale gradients, backpropagate, then update the weights.
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

    return batch_loss

def train_log(loss, example_ct, epoch):
    """Print the loss with the zero-padded count of examples seen so far.

    ``epoch`` is accepted for the caller's convenience but is not printed.
    """
    seen = str(example_ct).zfill(5)
    print("Loss after {} examples: {:.3f}".format(seen, float(loss)))
    
def save(model, path):
    """Write ``model``'s state dict (not the whole module) to ``path``."""
    weights = model.state_dict()
    torch.save(weights, path)

In [77]:
# Start training with the model, data loader, loss and optimizer defined above.
train(
    model=model,
    loader=data_loader,
    device=device,
    criterion=criterion,
    optimizer=optimizer,
)

  0%|          | 0/1192 [00:00<?, ?it/s]

texts: tensor([[49406, 15715,   281,  ...,  1150, 12544, 49407],
        [49406, 15715,   281,  ...,     0,     0,     0],
        [49406, 15715, 20616,  ...,     0,     0,     0],
        ...,
        [49406, 49407,     0,  ...,     0,     0,     0],
        [49406, 15715,   281,  ...,  1150, 12544, 49407],
        [49406, 49407,     0,  ...,     0,     0,     0]])


  0%|          | 1/1192 [00:01<21:00,  1.06s/it]

texts: tensor([[49406, 49407,     0,  ...,     0,     0,     0],
        [49406, 15715,   281,  ...,  8437,  9313, 49407],
        [49406, 49407,     0,  ...,     0,     0,     0],
        ...,
        [49406, 15715,   281,  ...,  4178,   631, 49407],
        [49406, 15715,   267,  ...,  6262, 15715, 49407],
        [49406, 15715,    25,  ...,   593,  1653, 49407]])


  0%|          | 2/1192 [00:01<13:35,  1.46it/s]

texts: tensor([[49406, 49407,     0,  ...,     0,     0,     0],
        [49406, 49407,     0,  ...,     0,     0,     0],
        [49406, 15715,   281,  ..., 15368,   269, 49407],
        ...,
        [49406, 15715,   281,  ...,  4178,   631, 49407],
        [49406, 49407,     0,  ...,     0,     0,     0],
        [49406, 15715, 20616,  ...,     0,     0,     0]])


  0%|          | 3/1192 [00:01<11:21,  1.75it/s]

texts: tensor([[49406, 15715,    25,  ...,   593,  1653, 49407],
        [49406, 49407,     0,  ...,     0,     0,     0],
        [49406, 15715,    25,  ...,   593,  1653, 49407],
        ...,
        [49406, 15715,   281,  ...,  1150, 12544, 49407],
        [49406, 49407,     0,  ...,     0,     0,     0],
        [49406, 15715,  1474,  ...,  5306, 40534, 49407]])


  0%|          | 4/1192 [00:02<10:40,  1.85it/s]

texts: tensor([[49406, 15715,   281,  ...,  4178,   631, 49407],
        [49406, 49407,     0,  ...,     0,     0,     0],
        [49406, 15715,   281,  ...,  1150, 12544, 49407],
        ...,
        [49406, 15715,   281,  ..., 22616, 16439, 49407],
        [49406, 49407,     0,  ...,     0,     0,     0],
        [49406, 49407,     0,  ...,     0,     0,     0]])


  0%|          | 5/1192 [00:02<09:55,  1.99it/s]

texts: tensor([[49406, 49407,     0,  ...,     0,     0,     0],
        [49406, 15715,   281,  ...,  8437,  9313, 49407],
        [49406, 49407,     0,  ...,     0,     0,     0],
        ...,
        [49406, 49407,     0,  ...,     0,     0,     0],
        [49406, 49407,     0,  ...,     0,     0,     0],
        [49406, 15715, 20616,  ...,     0,     0,     0]])


  1%|          | 6/1192 [00:03<09:25,  2.10it/s]

texts: tensor([[49406, 15715,   261,  ..., 21010,   533, 49407],
        [49406, 15715,   261,  ..., 21010,   533, 49407],
        [49406, 49407,     0,  ...,     0,     0,     0],
        ...,
        [49406, 15715,  5896,  ...,     0,     0,     0],
        [49406, 15715,   261,  ..., 21010,   533, 49407],
        [49406, 49407,     0,  ...,     0,     0,     0]])


  1%|          | 6/1192 [00:03<12:18,  1.61it/s]


KeyboardInterrupt: 