In [116]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from sklearn.metrics import roc_auc_score, precision_score, recall_score
import clip
from pathlib import Path

# load in clip
import os
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from typing import List, Optional
from PIL import Image

from collections import defaultdict
from tqdm import tqdm

In [3]:
import cv2
import openslide

In [4]:
clip.__file__

'/opt/conda/envs/pytorch/lib/python3.9/site-packages/clip/__init__.py'

### Get Sample Data

In [7]:
data_dir = Path("./data")

In [5]:
# load clip vit models -- this should have pre-trained clip weights
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

In [159]:
# read in from csv, sample one from each class
labels_path = data_dir / "labels_40.csv"
labels_df = pd.read_csv(labels_path)
# add column to labels 
labels_df["label"] = labels_df.apply(lambda x: x['Label'].split(" ")[0].lower(),axis=1)

sample_label_cases = list(labels_df.groupby("Label").first()['Case ID'])
labels_sample = labels_df[labels_df['Case ID'].isin(sample_label_cases)]

In [160]:
labels_sample

Unnamed: 0,Case ID,Label,label
0,TCGA-CJ-4642,Adenomas and Adenocarcinomas,adenomas
10,TCGA-BA-4077,Squamous Cell Neoplasms,squamous
20,TCGA-A8-A084,Ductal and Lobular Neoplasms,ductal
30,TCGA-06-0209,Gliomas,gliomas


In [161]:
# Get all labels 
all_cases = list(labels_df["Case ID"])

In [36]:
def get_svs_paths(all_svs: List[Path], cases: List[str]):
    # Returns single svs for each patient given a set of cases and a list of all svs
    svs_counts = defaultdict(int)
    svs_paths = []
    for svs_path in all_svs: 
        for sample_case in cases: 
            if sample_case in str(svs_path): 
                if svs_counts[sample_case] == 0: 
                    svs_paths.append(svs_path)
                svs_counts[sample_case] += 1
    return svs_paths

all_svs = list(data_dir.rglob("./*.svs"))
svs_paths = get_svs_paths(all_svs=all_svs, cases=sample_label_cases)
            

In [52]:
svs_paths

[PosixPath('data/TCGA-CJ-4642/TCGA-CJ-4642-01B-01-BS1.5a1225ca-0cb1-4be4-852d-701b1b1a4e67.svs'),
 PosixPath('data/TCGA-BA-4077/TCGA-BA-4077-01B-01-TS1.f5a59d10-d032-4d7a-9244-71c31954b122.svs'),
 PosixPath('data/TCGA-A8-A084/TCGA-A8-A084-01Z-00-DX1.2B52D1B8-5AD4-4BD6-ADF7-9D65B8EE2621.svs'),
 PosixPath('data/TCGA-06-0209/TCGA-06-0209-01Z-00-DX8.4a540299-b778-43e4-b80b-f70d5f222378.svs')]

### Extract Patches

In [53]:
# basic parse some 224x224 patches from some slides
# THIS WAS NOT USED
def extract_patches(paths: List[Path], output_dir: Path, resolution: Optional[int] = None):
    """
    Extract 224x224 patches from SVS files that are not masked out by an Otsu threshold.

    :param paths: list of paths to SVS files
    :param output_dir: output directory to save patches
    :param resolution: resolution level to extract patches from
    """
    # Create the output directory if it does not exist
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    for path in tqdm(paths):
        # Load the SVS file with OpenSlide
        slide = openslide.open_slide(str(path))

        if resolution is None: 
            resolution = slide.level_count - 1

        # Compute the dimensions of the slide at the given resolution level
        w, h = slide.level_dimensions[resolution]

        # Compute the dimensions of a 224x224 patch at the given resolution level
        patch_size = 224
        pw = (w - patch_size) // patch_size + 1
        ph = (h - patch_size) // patch_size + 1

        # Iterate over the patches and extract those that are not masked out by an Otsu threshold
        for x in range(pw):
            for y in range(ph):
                # Compute the coordinates of the patch in the slide
                x0 = x * patch_size
                y0 = y * patch_size
                x1 = x0 + patch_size
                y1 = y0 + patch_size

                # Read the region from the slide
                region = slide.read_region((x0, y0), resolution, (patch_size, patch_size))

                # Convert the region to a NumPy array and extract the grayscale channel
                img = np.array(region.convert('L'))

                # Compute the Otsu threshold and binarize the image
                threshold, mask = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

                # If the mask is all white, the patch is not masked out
                # if np.all(mask == 255):
                # Save the patch to the output directory with the desired file name
                patch_name = f"{path.stem}_{x}_{y}.png"
                patch_path = Path(output_dir) / patch_name
                cv2.imwrite(str(patch_path), img)


In [51]:
# output_dir = data_dir / "patches"
# extract_patches(paths=svs_paths, output_dir=output_dir)

100%|██████████| 4/4 [00:02<00:00,  1.40it/s]


## Inference

### Load Patches

In [61]:
patch_dir = data_dir / "patches"
patch_paths = list(patch_dir.rglob("./*.png"))

sample_image = preprocess(Image.open(patch_paths[0])).unsqueeze(0).to(device)

In [64]:
with torch.no_grad(): 
    image_features = model.encode_image(sample_image)

In [85]:
# run model on full test dataset
test_data_dir = patch_dir
test_dataset = datasets.ImageFolder(test_data_dir, transform=preprocess)
# Create a DataLoader to load the test dataset in batches
batch_size = 32
num_workers = 4
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

### Linear Evaluation

In [122]:
logits_list = []
labels_list = []
for batch_idx, (images, labels) in enumerate(test_loader):
    with torch.no_grad():
        images = images.to(device)
        features = model.encode_image(images)
    logits_list.append(features)
    labels_list.append(labels)
logits = torch.cat(logits_list, dim=0)
labels = torch.cat(labels_list, dim=0)

In [123]:
# Create a DataLoader to load the test dataset in batches
num_ftrs = 512
num_classes = 4
topk = 5

# Define the linear classifier and the loss function
linear_classifier = nn.Linear(num_ftrs, num_classes)
linear_classifier = linear_classifier.to(device)
criterion = nn.CrossEntropyLoss()

# Define the optimizer for the linear classifier
optimizer = optim.SGD(linear_classifier.parameters(), lr=0.01, momentum=0.9)

# Set the linear classifier to training mode
linear_classifier.train()

Linear(in_features=512, out_features=4, bias=True)

In [124]:
# Train the linear classifier on the pre-trained logits using the test dataset
num_epochs = 30
for epoch in range(num_epochs):
    running_loss = 0.0
    # train for each of the logits and labels, want to batch this if possible
    for i, (features, targets) in enumerate(zip(logits, labels)):
        optimizer.zero_grad()

        features = features.to(device)
        targets = targets.to(device)

        outputs = linear_classifier(features.float())
        loss = criterion(outputs.unsqueeze(0), targets.unsqueeze(0))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, num_epochs, running_loss))


Epoch [1/30], Loss: 641.9414
Epoch [2/30], Loss: 3942.0457
Epoch [3/30], Loss: 3229.3350
Epoch [4/30], Loss: 1966.2836
Epoch [5/30], Loss: 1020.1500
Epoch [6/30], Loss: 1009.6738
Epoch [7/30], Loss: 1297.4641
Epoch [8/30], Loss: 873.2217
Epoch [9/30], Loss: 503.1060
Epoch [10/30], Loss: 421.4751
Epoch [11/30], Loss: 339.7691
Epoch [12/30], Loss: 375.5216
Epoch [13/30], Loss: 301.1740
Epoch [14/30], Loss: 205.1505
Epoch [15/30], Loss: 241.1382
Epoch [16/30], Loss: 274.3821
Epoch [17/30], Loss: 231.8254
Epoch [18/30], Loss: 231.4314
Epoch [19/30], Loss: 160.9533
Epoch [20/30], Loss: 117.9594
Epoch [21/30], Loss: 95.8594
Epoch [22/30], Loss: 137.1429
Epoch [23/30], Loss: 98.7828
Epoch [24/30], Loss: 124.5679
Epoch [25/30], Loss: 114.5041
Epoch [26/30], Loss: 97.3875
Epoch [27/30], Loss: 81.7098
Epoch [28/30], Loss: 45.9641
Epoch [29/30], Loss: 61.1156
Epoch [30/30], Loss: 49.0063


In [145]:

# compute the model predictions for the test dataset
y_pred = []
y_true = []
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        output = model.encode_image(images) # clip 
        output = linear_classifier(output.float())
        
        y_pred.append(output)
        y_true.append(labels)

In [146]:
# assume y_pred is a tensor of shape (num_examples, num_classes) containing the outputs of the linear classifier
# and y_true is a tensor of shape (num_examples) containing the categorical labels
y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)

y_prob = torch.softmax(y_pred, dim=1)
y_true = y_true.detach().cpu().numpy()
y_prob = y_prob.detach().cpu().numpy()
# apply softmax to the outputs of the linear classifier to obtain the predicted probabilities

# compute the AUROC for each class using the predicted probabilities and the corresponding categorical labels
aurocs = []
for c in range(num_classes):
    y_true_c = (y_true == c)
    y_prob_c = y_prob[:, c]
    auroc_c = roc_auc_score(y_true_c, y_prob_c)
    aurocs.append(auroc_c)

# compute the mean AUROC over all classes
mean_auroc = sum(aurocs) / num_classes

In [149]:
print("multiclass AUROC: ", aurocs, mean_auroc)

multiclass AUROC:  [0.9942294662831576, 0.9977590617273394, 0.9804071770480807, 0.9525653519752824] 0.981240264258465


In [150]:
# acc@k, prec@k, rec@k
def top_k_acc_prec_rec(y_prob, y_true, k):
    # y_prob is a numpy array of shape (num_examples, num_classes) containing the predicted probabilities
    # y_true is a numpy array of shape (num_examples) containing the true labels
    # k is the number of classes to consider for top-k accuracy, precision, and recall

    # get the indices of the top k predicted probabilities for each example
    top_k_preds = np.argsort(y_prob, axis=1)[:, -k:]

    # compute accuracy@k, precision@k, and recall@k for each example
    acc_k = np.mean(np.any(top_k_preds == y_true[:, np.newaxis], axis=1))
    prec_k = np.mean([np.any(top_k_preds[i] == y_true[i]) for i in range(len(y_true))])
    rec_k = np.mean([np.any(y_true[i] == top_k_preds[i]) for i in range(len(y_true))])

    return acc_k, prec_k, rec_k

In [156]:
acc_k, prec_k, rec_k = top_k_acc_prec_rec(y_prob, y_true, k=1)

In [157]:
print("acc@k: ", acc_k)
print("prec@k: ", prec_k)
print("rec@k: ", rec_k)

acc@k:  0.7682047584715213
prec@k:  0.7682047584715213
rec@k:  0.7682047584715213


In [148]:
print(mean_auroc)

0.981240264258465


In [134]:
# convert the predictions to numpy arrays and compute the AUC
y_pred = y_pred.detach().cpu().numpy()
y_true = y_true.detach().cpu().numpy()
auc = roc_auc_score(y_true, y_pred, multi_class='ovr')

# compute the accuracy@k, precision@k, and recall@k for each example in the test dataset
num_examples = len(y_true)
acc_topk = np.zeros(num_examples)
prec_topk = np.zeros(num_examples)
rec_topk = np.zeros(num_examples)
for i in range(num_examples):
    y_true_i = y_true[i]
    y_pred_i = y_pred[i]
    topk_preds = np.argsort(y_pred_i)[-topk:]
    acc_topk[i] = y_true_i in topk_preds
    prec_topk[i] = precision_score(y_true_i, topk_preds, average='micro')
    rec_topk[i] = recall_score(y_true_i, topk_preds, average='micro')

# compute the average metrics over the entire test set
acc_topk_mean = np.mean(acc_topk)
prec_topk_mean = np.mean(prec_topk)
rec_topk_mean = np.mean(rec_topk)

ValueError: Target scores need to be probabilities for multiclass roc_auc, i.e. they should sum up to 1.0 over classes

In [111]:
# Evaluate the linear classifier on the test dataset
linear_classifier.eval()
correct = 0
total = 0

topk = 5

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        features = model.encode_image(images)
        outputs = linear_classifier(features.float())
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100

In [115]:
print('Accuracy: ', correct / total)
# TODO: test all the necessary evaluation metrics -- i.e. top k stuff

Accuracy:  0.589041095890411


In [None]:
class PatchLoader(data.Dataset):
    def __init__(self, svs_folder, label_df_path, otsu_threshold=0.5, patch_size=224):
        self.svs_folder = svs_folder
        self.label_df = pd.read_csv(label_df_path)
        self.otsu_threshold = otsu_threshold
        self.patch_size = patch_size

        # Generate list of valid patch coordinates and corresponding labels
        self.valid_coords, self.labels = self.generate_valid_coords()

    def __len__(self):
        return len(self.valid_coords)

    def __getitem__(self, idx):
        # Get patch coordinates
        svs_file, x, y = self.valid_coords[idx]

        # Read patch from slide
        slide = openslide.OpenSlide(os.path.join(self.svs_folder, svs_file))
        patch = slide.read_region((x, y), 0, (self.patch_size, self.patch_size)).convert("RGB")
        patch = np.asarray(patch)[:, :, :3]  # Remove alpha channel if present

        # Apply Otsu threshold
        gray_patch = cv2.cvtColor(patch, cv2.COLOR_RGB2GRAY)
        _, mask = cv2.threshold(gray_patch, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        mask = mask / 255  # Normalize to [0, 1]
        if np.mean(mask) >= self.otsu_threshold:
            # If patch is too dark, generate a new patch
            return self.__getitem__(np.random.randint(len(self)))

        # Apply random augmentations if desired
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(90),
        ])
        patch = transform(patch)

        # Convert patch to tensor
        patch = transforms.ToTensor()(patch)

        # Get corresponding label
        label = self.labels[idx]

        return patch, label

    def generate_valid_coords(self):
        valid_coords = []
        labels = []
        for svs_file in os.listdir(self.svs_folder):
            if not svs_file.endswith('.svs'):
                continue
            label = self.label_df.loc[svs_file[:-4]]['label']
            slide = openslide.OpenSlide(os.path.join(self.svs_folder, svs_file))
            for y in range(self.patch_size // 2, slide.dimensions[1] - self.patch_size // 2, self.patch_size):
                for x in range(self.patch_size // 2, slide.dimensions[0] - self.patch_size // 2, self.patch_size):
                    # Read patch from slide
                    patch = slide.read_region((x, y), 0, (self.patch_size, self.patch_size)).convert("RGB")
                    patch = np.asarray(patch)[:, :, :3]  # Remove alpha channel if present

                    # Apply Otsu threshold
                    gray_patch = cv2.cvtColor(patch, cv2.COLOR_RGB2GRAY)
                    _, mask = cv2.threshold(gray_patch, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
                    mask = mask / 255  # Normalize to [0, 1]
                    if np.mean(mask) < self.otsu_threshold:
                        valid_coords.append((svs_file, x, y))
                        labels.append(label)
        return valid_coords, labels


In [None]:
# create data loader
svs_folder = "../data/sample/svs/"
label_df_path = "../data/sample/labels.csv"

patch_dataset = PatchLoader(
    svs_folder=svs_folder, 
    label_df_path=label_df_path,
)

# Define the batch size and other DataLoader options
batch_size = 32
num_workers = 4
shuffle = False

# Create the data loader
my_dataloader = DataLoader(
    dataset=patch_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=shuffle,
    # Additional DataLoader options go here
)

In [None]:
import torch
import numpy as np
from sklearn.metrics import roc_auc_score

def evaluate(model, data_loader, k=5):
    # Set model to eval mode
    model.eval()

    # Initialize variables for metrics
    total_images = 0
    total_top_k_correct = 0
    total_precision = 0
    total_recall = 0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        # Iterate over the data loader
        for images, labels in data_loader:
            # Send images and labels to device
            images = images.to(device)
            labels = labels.to(device)

            # Get predictions from model
            outputs = model(images)
            # TODO: go from model logits --> predictions
            
            preds = torch.topk(outputs, k=k, dim=1)[1]

            # Compute metrics
            total_images += len(labels)
            total_top_k_correct += (preds == labels.view(-1,1)).sum().item()

            precision, recall = top_k_precision_recall(preds, labels.view(-1,1), k=k)
            total_precision += precision
            total_recall += recall

            all_labels.append(labels.cpu().numpy())
            all_preds.append(outputs.cpu().numpy())

    # Compute multiclass AUROC
    all_labels = np.concatenate(all_labels)
    all_preds = np.concatenate(all_preds)
    auroc = roc_auc_score(y_true=all_labels, y_score=all_preds, multi_class='ovo')

    # Compute top-k accuracy, precision, and recall
    top_k_accuracy = total_top_k_correct / total_images
    top_k_precision = total_precision / len(data_loader)
    top_k_recall = total_recall / len(data_loader)

    return top_k_accuracy, auroc, top_k_precision, top_k_recall

def top_k_precision_recall(preds, labels, k=5):
    # Compute precision and recall for top-k predictions
    tp = 0
    fp = 0
    fn = 0
    for i in range(len(preds)):
        pred = preds[i][:k]
        label = labels[i]
        if label in pred:
            tp += 1
            fp += k - 1
            fn += 0
        else:
            tp += 0
            fp += k
            fn += 1
    precision = tp / (tp + fp) if tp + fp > 0 else 0
    recall = tp / (tp + fn) if tp + fn > 0 else 0
    return precision, recall
