In [None]:
import argparse
import os
import random
import warnings
import pandas as pd
import numpy as np
import time
import shutil
import datetime
from itertools import chain

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torch.nn.functional as F

import pickle
from sklearn.cluster import KMeans

import matplotlib.pyplot as plt
import matplotlib.colors as colors

from utils.logging import *
from data.process import *

### Parameters

In [None]:
# Notebook-wide settings. Stored in a pandas Series so the rest of the notebook
# can read them by attribute (args.seed, args.n_clusters, ...).
args = pd.Series(dict(
    # Paths
    checkpoint='checkpoint/',
    version='1.0',
    image_dir='Data/',
    patch_label='Metadata/PatchLabels.csv',
    # Prediction target; `prediction` is one of
    # 'regression', 'binary classification', 'classification'
    predicting_var='response',
    prediction='binary classification',
    # Data selection and train/validation split
    cohort='Cohort1',
    magnification='10X',
    upsample=False,
    train_val_split=0.7,
    # Feature extraction epoch and clustering options
    base_epoch=19,
    normalize=True,
    n_clusters=4,
    # Runtime options
    batch_size=256,
    workers=4,
    seed=0,
    gpu=0,
))

### Code

In [None]:
def load_flat_embeddings(slides, embeddings_path, normalize=False):
    """Load the saved embeddings of several slides and stack them into one tensor.

    Args:
        slides: iterable of file names (one per slide) inside ``embeddings_path``.
        embeddings_path: directory holding the per-slide files written with
            ``torch.save``; each file must contain a ``'slide_embeddings'`` entry.
        normalize: if True, L2-normalise every row (``dim=1``) of each slide's
            embeddings before stacking.

    Returns:
        A single tensor with the slides' embeddings concatenated along dim 0,
        in the same order as ``slides``.
    """
    # Load onto the GPU when one is present, otherwise fall back to the CPU so
    # the function still works on CUDA-less machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    all_embeddings = []
    for slide in slides:
        slide_file = os.path.join(embeddings_path, slide)
        saved = torch.load(slide_file, map_location=device)
        slide_embeddings = saved['slide_embeddings']

        if normalize:
            slide_embeddings = F.normalize(slide_embeddings, p=2, dim=1)

        all_embeddings.append(slide_embeddings)

    return torch.cat(all_embeddings, dim=0)

def slide_clusters_dict(cluster_labels, slides, patch_labels):
    """Split a flat array of cluster labels into one array per slide.

    Args:
        cluster_labels: array of labels, indexed by row position.
        slides: iterable of slide identifiers to keep.
        patch_labels: DataFrame with a ``slide`` column.

    Returns:
        dict mapping each slide to the entries of ``cluster_labels`` found at
        the row positions that slide occupies in ``patch_labels`` once it has
        been restricted to ``slides``.

    NOTE(review): this assumes ``cluster_labels`` is ordered exactly like the
    filtered rows of ``patch_labels`` -- confirm against how the labels were
    produced.
    """
    # Renumber the kept rows 0..n-1 so the index doubles as a position into cluster_labels
    relevant = patch_labels.loc[patch_labels.slide.isin(slides)].reset_index(drop=True)
    print(f'Have {len(cluster_labels)} cluster labels for {len(slides)} slides.')
    return {
        name: cluster_labels[relevant.loc[relevant.slide == name].index]
        for name in slides
    }

def save_clusters(slide_clusters, save_folder, n_clusters, normalize=False):
    """Pickle the per-slide cluster assignments into ``save_folder``.

    The file is named ``{n_clusters}_clusters.p``, or
    ``{n_clusters}_normalized_clusters.p`` when ``normalize`` is True.

    Args:
        slide_clusters: object to pickle (the slide -> cluster labels dict).
        save_folder: destination directory; created if it does not exist.
        n_clusters: number of clusters, used in the file name.
        normalize: whether the clusters were computed on normalised embeddings;
            only affects the file name.
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists test
    os.makedirs(save_folder, exist_ok=True)
    norm_str = 'normalized_' if normalize else ''
    out_file = os.path.join(save_folder, f'{n_clusters}_{norm_str}clusters.p')
    # Context manager guarantees the handle is flushed and closed
    with open(out_file, 'wb') as handle:
        pickle.dump(slide_clusters, handle)
    print(f'Clusters saved to {save_folder}.')

def save_cluster_model(model, cluster_path, n_clusters, normalize=False):
    """Pickle a fitted clustering model under ``<cluster_path>/models``.

    The file is named ``{n_clusters}_clusters_KMeans_model.p``, prefixed with
    ``normalized_`` when ``normalize`` is True.

    Args:
        model: the object to pickle.
        cluster_path: parent directory; a ``models`` sub-directory is created
            inside it if needed.
        n_clusters: number of clusters, used in the file name.
        normalize: whether the model was fitted on normalised embeddings; only
            affects the file name.
    """
    norm_str = 'normalized_' if normalize else ''
    models_dir = os.path.join(cluster_path, 'models')
    os.makedirs(models_dir, exist_ok=True)
    model_file = os.path.join(models_dir, f'{norm_str}{n_clusters}_clusters_KMeans_model.p')
    # Context manager guarantees the handle is flushed and closed
    with open(model_file, 'wb') as handle:
        pickle.dump(model, handle)

In [None]:
def main_worker(gpu, args):
    """Fit KMeans on the training-slide embeddings and label both data splits.

    Steps:
      1. Read the patch table, keep the configured magnification and drop rows
         with no value for ``args.predicting_var``.
      2. Split it into train/validation with ``split_train_val``.
      3. Fit KMeans on the training embeddings, then predict clusters for the
         validation embeddings with the same model.
      4. Pickle the per-slide cluster labels of both splits and the fitted model
         under ``<checkpoint>/BaselineResNet<version>/Clusters/epoch_<base_epoch>``.

    Args:
        gpu: GPU id to use, or None; written back to ``args.gpu``.
        args: settings object read by attribute (see the parameters cell).

    Returns:
        Tuple ``(train_slide_clusters, val_slide_clusters)``, each a dict as
        produced by ``slide_clusters_dict``.
    """
    args.gpu = gpu

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
        torch.cuda.set_device(args.gpu)

    # Load data
    patch_labels = pd.read_csv(args.patch_label, index_col=0)
    patch_labels = patch_labels[patch_labels.magnification == args.magnification]
    patch_labels = patch_labels.dropna(subset=[args.predicting_var])

    # Only the two patch tables are needed here; the other return values are unused.
    train_patch_labels, val_patch_labels, _val_cases, _ = split_train_val(
        patch_labels, args.cohort, args.train_val_split, args.seed,
        args.prediction, args.predicting_var, args.upsample)

    train_slides = train_patch_labels.slide.unique()
    print(f'{len(train_slides)} training slides')
    val_slides = val_patch_labels.slide.unique()
    print(f'{len(val_slides)} validation slides')

    # Features are read from, and clusters written to, the same run directory.
    run_dir = os.path.join(args.checkpoint, f'BaselineResNet{args.version}')
    features_path = os.path.join(run_dir, 'Features', f'epoch_{args.base_epoch}')
    cluster_path = os.path.join(run_dir, 'Clusters', f'epoch_{args.base_epoch}')

    # Load train features and fit KMeans model
    train_features = load_flat_embeddings(train_slides, features_path, normalize=args.normalize)

    # NOTE(review): random_state is fixed to 0 here, independent of args.seed.
    kmeans = KMeans(n_clusters=args.n_clusters, random_state=0).fit(
        train_features.cpu().detach().numpy())
    train_cluster_labels = kmeans.labels_
    del train_features  # release the (potentially large) tensor before loading validation data
    slide_clusters = slide_clusters_dict(train_cluster_labels, train_slides, train_patch_labels)

    save_clusters(slide_clusters, os.path.join(cluster_path, 'Train'), args.n_clusters, args.normalize)
    print(f'Train clusters distribution: {np.unique(train_cluster_labels,return_counts=True)}')

    # Load validation features and apply KMeans model
    val_features = load_flat_embeddings(val_slides, features_path, normalize=args.normalize)

    val_cluster_labels = kmeans.predict(val_features.cpu().detach().numpy())
    del val_features
    val_slide_clusters = slide_clusters_dict(val_cluster_labels, val_slides, val_patch_labels)

    save_clusters(val_slide_clusters, os.path.join(cluster_path, 'Validation'), args.n_clusters, args.normalize)
    print(f'Val clusters distribution: {np.unique(val_cluster_labels,return_counts=True)}')

    save_cluster_model(kmeans, cluster_path, args.n_clusters, args.normalize)

    return slide_clusters, val_slide_clusters

In [None]:
def main():
    """Seed the RNGs, emit the configuration warnings and run ``main_worker``.

    Settings are taken from the notebook-level ``args`` object.

    Returns:
        Whatever ``main_worker(args.gpu, args)`` returns.
    """
    seed = args.seed
    gpu = args.gpu

    if seed is not None:
        # Seed Python's and torch's generators and force deterministic cuDNN kernels
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        seed_message = ('You have chosen to seed training. '
                        'This will turn on the CUDNN deterministic setting, '
                        'which can slow down your training considerably! '
                        'You may see unexpected behavior when restarting '
                        'from checkpoints.')
        warnings.warn(seed_message)

    if gpu is not None:
        gpu_message = ('You have chosen a specific GPU. This will completely '
                       'disable data parallelism.')
        warnings.warn(gpu_message)

    return main_worker(gpu, args)

In [None]:
# Run the whole pipeline. main() returns the two per-slide cluster dicts built by
# main_worker (training first, validation second); the pickles are written as a side effect.
train_slide_clusters, val_slide_clusters = main()

## Plot cluster distribution

In [None]:
# Flatten the per-slide label arrays into one list of cluster ids
train_labels = list(chain.from_iterable(train_slide_clusters.values()))

# Cluster ids are integers 0..n_clusters-1, so use one unit-wide bin centred on
# each id rather than the default 10 equal-width bins.
fig, ax = plt.subplots()
ax.hist(train_labels, bins=np.arange(args.n_clusters + 1) - 0.5, rwidth=0.9)
ax.set(title='Training set cluster labels', xlabel='Cluster label', ylabel='Count',
       xticks=range(args.n_clusters))
plt.show()

In [None]:
# Flatten the per-slide label arrays into one list of cluster ids
val_labels = list(chain.from_iterable(val_slide_clusters.values()))

# Cluster ids are integers 0..n_clusters-1, so use one unit-wide bin centred on
# each id rather than the default 10 equal-width bins.
fig, ax = plt.subplots()
ax.hist(val_labels, bins=np.arange(args.n_clusters + 1) - 0.5, rwidth=0.9)
ax.set(title='Validation set cluster labels', xlabel='Cluster label', ylabel='Count',
       xticks=range(args.n_clusters))
plt.show()

## Explore cluster centres

In [None]:
def load_cluster_model(cluster_path, n_clusters, normalize=False):
    """Load a clustering model previously written by ``save_cluster_model``.

    Args:
        cluster_path: directory containing the ``models`` sub-directory.
        n_clusters: number of clusters, used in the file name.
        normalize: whether to load the model fitted on normalised embeddings
            (file name prefixed with ``normalized_``).

    Returns:
        The unpickled model object.

    Raises:
        FileNotFoundError: if no matching model file exists.
    """
    # Must mirror the naming in save_cluster_model: the prefix carries its own
    # trailing underscore, so the un-normalised file has no leading '_'.
    norm_str = 'normalized_' if normalize else ''
    model_file = os.path.join(cluster_path, 'models',
                              f'{norm_str}{n_clusters}_clusters_KMeans_model.p')
    # SECURITY: pickle.load can execute arbitrary code; only load files this
    # notebook (or another trusted source) produced.
    with open(model_file, 'rb') as handle:
        return pickle.load(handle)

# Same directory layout main_worker() uses when it saves the clusters and the model:
# <checkpoint>/BaselineResNet<version>/Clusters/epoch_<base_epoch>
cluster_path = os.path.join(args.checkpoint, f'BaselineResNet{args.version}', 'Clusters', 
                                f'epoch_{args.base_epoch}')
# Reload the pickled model for the configured cluster count / normalisation setting
cluster_model = load_cluster_model(cluster_path, n_clusters=args.n_clusters, normalize=args.normalize)

In [None]:
# Show the centres stored on the reloaded model (for a scikit-learn KMeans this is
# an array with one row per cluster).
cluster_model.cluster_centers_