In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from sklearn.metrics import roc_auc_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
import clip
from pathlib import Path
import shutil

# load in clip
import os
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from typing import List, Optional
from PIL import Image

from collections import defaultdict
from tqdm import tqdm

In [2]:
import cv2
import openslide

In [3]:
clip.__file__

'/opt/conda/envs/pytorch/lib/python3.9/site-packages/clip/__init__.py'

### Get Sample Data

In [4]:
data_dir = Path("./data")

In [5]:
# load clip vit models -- this should have pre-trained clip weights
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

In [6]:
# read in from csv, sample one from each class
labels_path = data_dir / "labels_40.csv"
labels_df = pd.read_csv(labels_path)
# add column to labels 
labels_df["label"] = labels_df.apply(lambda x: x['Label'].split(" ")[0].lower(),axis=1)

sample_label_cases = list(labels_df.groupby("Label").first()['Case ID'])
labels_sample = labels_df[labels_df['Case ID'].isin(sample_label_cases)]

In [7]:
labels_sample

Unnamed: 0,Case ID,Label,label
0,TCGA-CJ-4642,Adenomas and Adenocarcinomas,adenomas
10,TCGA-BA-4077,Squamous Cell Neoplasms,squamous
20,TCGA-A8-A084,Ductal and Lobular Neoplasms,ductal
30,TCGA-06-0209,Gliomas,gliomas


In [8]:
# Get all labels 
all_cases = list(labels_df["Case ID"])

In [9]:
def get_svs_paths(all_svs: List[Path], cases: List[str]):
    # Returns single svs for each patient given a set of cases and a list of all svs
    svs_counts = defaultdict(int)
    svs_paths = []
    for svs_path in all_svs: # iterate through each svs
        for sample_case in cases: # for each case
            if sample_case in str(svs_path) and "DX" in str(svs_path): 
                if svs_counts[sample_case] == 0: 
                    svs_paths.append(svs_path)
                svs_counts[sample_case] += 1
    return svs_paths

all_svs = list(data_dir.rglob("./*.svs"))
svs_paths = get_svs_paths(all_svs=all_svs, cases=all_cases)
            

In [10]:
len(svs_paths)

36

In [11]:
# # make copies in separate folder
# slides_dir = data_dir / "slides"
# for svs in tqdm(svs_paths): 
#     shutil.copy(svs, slides_dir)

In [12]:
# move patches into correct folders
# do train test split

### Extract Patches

In [13]:
# basic parse some 224x224 patches from some slides
# THIS WAS NOT USED
def extract_patches(paths: List[Path], output_dir: Path, resolution: Optional[int] = None):
    """
    Extract 224x224 patches from SVS files that are not masked out by an Otsu threshold.

    :param paths: list of paths to SVS files
    :param output_dir: output directory to save patches
    :param resolution: resolution level to extract patches from
    """
    # Create the output directory if it does not exist
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    for path in tqdm(paths):
        # Load the SVS file with OpenSlide
        slide = openslide.open_slide(str(path))

        if resolution is None: 
            resolution = slide.level_count - 1

        # Compute the dimensions of the slide at the given resolution level
        w, h = slide.level_dimensions[resolution]

        # Compute the dimensions of a 224x224 patch at the given resolution level
        patch_size = 224
        pw = (w - patch_size) // patch_size + 1
        ph = (h - patch_size) // patch_size + 1

        # Iterate over the patches and extract those that are not masked out by an Otsu threshold
        for x in range(pw):
            for y in range(ph):
                # Compute the coordinates of the patch in the slide
                x0 = x * patch_size
                y0 = y * patch_size
                x1 = x0 + patch_size
                y1 = y0 + patch_size

                # Read the region from the slide
                region = slide.read_region((x0, y0), resolution, (patch_size, patch_size))

                # Convert the region to a NumPy array and extract the grayscale channel
                img = np.array(region.convert('L'))

                # Compute the Otsu threshold and binarize the image
                threshold, mask = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

                # If the mask is all white, the patch is not masked out
                # if np.all(mask == 255):
                # Save the patch to the output directory with the desired file name
                patch_name = f"{path.stem}_{x}_{y}.png"
                patch_path = Path(output_dir) / patch_name
                cv2.imwrite(str(patch_path), img)


In [14]:
# output_dir = data_dir / "patches"
# extract_patches(paths=svs_paths, output_dir=output_dir)

## Inference

### Data Loader

In [31]:
class PatchDataset(data.Dataset):
    def __init__(self, patch_paths: List[Path] , labels_df: pd.DataFrame, transform = None):
        """
        Notes:
            assumes `label_df` has columns `Case ID` and `label` mapping
            TCGA id to corresponding class. 
        """
        self.labels_df = labels_df
        self.patch_paths = patch_paths
        self.transform = transform
        
        self.last_idx = 0
        self.labels_to_idx = dict()

    def __len__(self):
        return len(self.patch_paths)

    def __getitem__(self, idx):
        # Get patch and preprocess
        patch_path = self.patch_paths[idx]
        patch_stem = patch_path.stem
        case_id = "-".join(str(patch_stem).split(".")[0].split("-")[:3])
        patch = Image.open(patch_path)
        if self.transform is not None:
            patch = self.transform(patch)
        
        # Get corresponding label
        label = self.labels_df[self.labels_df["Case ID"] == case_id]["label"]
        if len(label) == 0: 
            raise ValueError("Can't find corresponding label a given case.")
        label = label.values[0] # extract single label
        
#         print(label)
        if label not in self.labels_to_idx: 
            self.labels_to_idx[label] = self.last_idx
            self.last_idx += 1

        label_idx = self.labels_to_idx[label]
            
        return patch, label_idx

### Load Patches

In [16]:
patch_dir = data_dir / "patches" / "10.0_224"
patch_paths = list(patch_dir.rglob("./*.png"))
print("Number of paths: ", len(patch_paths))
# sample_image = preprocess(Image.open(patch_paths[0])).unsqueeze(0).to(device)

Number of paths:  76252


In [17]:
unq_ids = np.unique(["-".join(str(patch_path.stem).split(".")[0].split("-")[:3]) for patch_path in patch_paths])

In [18]:
train_ids, test_ids = train_test_split(unq_ids, test_size=0.2, random_state=42)

In [24]:
# get patches containing each id
def get_patches_from_ids(ids, patch_paths): 
    paths = []
    for id in ids: 
        for patch in patch_paths: 
            if id in str(patch): 
                paths.append(patch)
    return paths

train_paths = get_patches_from_ids(ids=train_ids, patch_paths=patch_paths)
test_paths = get_patches_from_ids(ids=test_ids, patch_paths=patch_paths)

In [28]:
print("Num train paths: ", len(train_paths))

Num train paths:  60164


In [29]:
print("Num test paths: ", len(test_paths))

Num test paths:  16088


In [104]:
# with torch.no_grad(): 
#     image_features = model.encode_image(sample_image)

In [None]:
# get patches containing these 

In [32]:
# run model on full test dataset
test_data_dir = patch_dir
# test_dataset = datasets.ImageFolder(test_data_dir, transform=preprocess)
train_dataset = PatchDataset(
    patch_paths=train_paths,
    labels_df=labels_df, 
    transform=preprocess,
)

test_dataset = PatchDataset(
    patch_paths=test_paths, 
    labels_df=labels_df, 
    transform=preprocess,
)

# Create a DataLoader to load the test dataset in batches
batch_size = 32
num_workers = 4
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

### Linear Evaluation
* First, train a linear model on a train dataset
* Second, evaluate linear model on a test dataset

In [33]:
logits_list = []
labels_list = []
for batch_idx, (images, labels) in tqdm(enumerate(train_loader), total=len(train_loader)):
    with torch.no_grad():
        images = images.to(device)
        features = model.encode_image(images)
        # move to gpu
        features.to(device)
        labels.to(device)
        
    logits_list.append(features)
    labels_list.append(labels)
logits = torch.cat(logits_list, dim=0)
labels = torch.cat(labels_list, dim=0)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1881/1881 [01:01<00:00, 30.76it/s]


In [34]:
# Create a DataLoader to load the test dataset in batches
num_ftrs = 512
num_classes = 4

# Define the linear classifier and the loss function
linear_classifier = nn.Linear(num_ftrs, num_classes)
linear_classifier = linear_classifier.to(device)
criterion = nn.CrossEntropyLoss()

# Define the optimizer for the linear classifier
optimizer = optim.SGD(linear_classifier.parameters(), lr=0.01, momentum=0.9)

# Set the linear classifier to training mode
linear_classifier.train()

Linear(in_features=512, out_features=4, bias=True)

In [35]:
# Train the linear classifier on the pre-trained logits using the test dataset
num_epochs = 30
for epoch in tqdm(range(num_epochs)):
    running_loss = 0.0
    # train for each of the logits and labels, want to batch this if possible
    for i, (features, targets) in enumerate(zip(logits, labels)):
        optimizer.zero_grad()

        features = features.to(device)
        targets = targets.to(device)

        outputs = linear_classifier(features.float())
        loss = criterion(outputs.unsqueeze(0), targets.unsqueeze(0))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, num_epochs, running_loss))


  3%|███▊                                                                                                              | 1/30 [00:27<13:22, 27.67s/it]

Epoch [1/30], Loss: 21104.3469


  7%|███████▌                                                                                                          | 2/30 [00:55<12:55, 27.69s/it]

Epoch [2/30], Loss: 16809.4639


 10%|███████████▍                                                                                                      | 3/30 [01:22<12:23, 27.55s/it]

Epoch [3/30], Loss: 13853.9231


 13%|███████████████▏                                                                                                  | 4/30 [01:50<11:57, 27.61s/it]

Epoch [4/30], Loss: 13155.4045


 17%|███████████████████                                                                                               | 5/30 [02:18<11:37, 27.92s/it]

Epoch [5/30], Loss: 13232.0555


 20%|██████████████████████▊                                                                                           | 6/30 [02:46<11:10, 27.96s/it]

Epoch [6/30], Loss: 11967.5044


 23%|██████████████████████████▌                                                                                       | 7/30 [03:14<10:40, 27.86s/it]

Epoch [7/30], Loss: 11474.8261


 27%|██████████████████████████████▍                                                                                   | 8/30 [03:42<10:10, 27.76s/it]

Epoch [8/30], Loss: 10956.5090


 30%|██████████████████████████████████▏                                                                               | 9/30 [04:09<09:41, 27.69s/it]

Epoch [9/30], Loss: 10453.4697


 33%|█████████████████████████████████████▋                                                                           | 10/30 [04:37<09:14, 27.73s/it]

Epoch [10/30], Loss: 10026.1160


 33%|█████████████████████████████████████▋                                                                           | 10/30 [04:48<09:37, 28.86s/it]


KeyboardInterrupt: 

In [50]:
# logits_list = []
# labels_list = []

# TODO: move into a function
num_epochs = 30
for epoch in range(num_epochs):
    running_loss = 0.0
    for batch_idx, (images, labels) in tqdm(enumerate(train_loader), total=len(train_loader)):
        with torch.no_grad():
            images = images.to(device)
            features = model.encode_image(images)
            features.to(device)
            labels_one_hot = torch.nn.functional.one_hot(labels, num_classes=4).float().to(device) # convert labels to one-hot encoding
            outputs = linear_classifier(features.float())
            loss = criterion(outputs, labels_one_hot)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
    
    print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, num_epochs, running_loss))

# logits = torch.cat(logits_list, dim=0)
# labels = torch.cat(labels_list, dim=0)

  0%|                                                                                                                        | 0/1881 [00:00<?, ?it/s]


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [39]:
# compute the model predictions for the test dataset
y_pred = []
y_true = []
model.eval()
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)

        output = model.encode_image(images) # clip 
        output = linear_classifier(output.float())
        
        y_pred.append(output)
        y_true.append(labels)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 503/503 [00:16<00:00, 31.19it/s]


In [40]:
# assume y_pred is a tensor of shape (num_examples, num_classes) containing the outputs of the linear classifier
# and y_true is a tensor of shape (num_examples) containing the categorical labels
y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)

y_prob = torch.softmax(y_pred, dim=1)
y_true = y_true.detach().cpu().numpy()
y_prob = y_prob.detach().cpu().numpy()
# apply softmax to the outputs of the linear classifier to obtain the predicted probabilities

# compute the AUROC for each class using the predicted probabilities and the corresponding categorical labels
aurocs = []
for c in range(num_classes):
    y_true_c = (y_true == c)
    y_prob_c = y_prob[:, c]
    auroc_c = roc_auc_score(y_true_c, y_prob_c)
    aurocs.append(auroc_c)

# compute the mean AUROC over all classes
mean_auroc = sum(aurocs) / num_classes

In [41]:
print("multiclass AUROC: ", aurocs, mean_auroc)

multiclass AUROC:  [0.9173501362357952, 0.6423820917906181, 0.85675162672, 0.35392481043158375] 0.6926021662944992


In [42]:
# acc@k, prec@k, rec@k
def top_k_acc_prec_rec(y_prob, y_true, k):
    # y_prob is a numpy array of shape (num_examples, num_classes) containing the predicted probabilities
    # y_true is a numpy array of shape (num_examples) containing the true labels
    # k is the number of classes to consider for top-k accuracy, precision, and recall

    # get the indices of the top k predicted probabilities for each example
    top_k_preds = np.argsort(y_prob, axis=1)[:, -k:]

    # compute accuracy@k, precision@k, and recall@k for each example
    acc_k = np.mean(np.any(top_k_preds == y_true[:, np.newaxis], axis=1))
    prec_k = np.mean([np.any(top_k_preds[i] == y_true[i]) for i in range(len(y_true))])
    rec_k = np.mean([np.any(y_true[i] == top_k_preds[i]) for i in range(len(y_true))])

    return acc_k, prec_k, rec_k

In [45]:
acc_k, prec_k, rec_k = top_k_acc_prec_rec(y_prob, y_true, k=2)

In [46]:
print("acc@k: ", acc_k)
print("prec@k: ", prec_k)
print("rec@k: ", rec_k)

acc@k:  0.4175783192441571
prec@k:  0.4175783192441571
rec@k:  0.4175783192441571


In [49]:
print(mean_auroc)

0.6926021662944992
