# Feature extraction and clustering

In [None]:
import json
from pathlib import Path

import torch
import torchvision
from tqdm import tqdm
from PIL import Image
import torch.nn as nn
from torch.nn import functional as F
from functools import partial
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision import datasets
from torchvision.datasets.utils import download_url
from torchvision.models import resnet

import numpy as np
import pathlib
import pandas as pd
import os
import shutil
import joblib
import argparse

from sklearn.cluster import KMeans
from sklearn import preprocessing

from scipy.spatial import distance

from multiprocessing import Pool, cpu_count
import matplotlib.pyplot as plt

from mocotools import mocoutil 

# Print the status of GPU 0. `!cmd` is an IPython shell escape: it runs
# nvidia-smi and captures its output as a list of lines, which are then
# re-joined into one string for printing. Only valid inside IPython/Jupyter.
gpu_info = !nvidia-smi -i 0
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

In [None]:
parser = argparse.ArgumentParser(description='Feature extraction')

# General config
parser.add_argument('--batch-size', default=64, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--pretrained_parameters', default=None, type=str)

# moco specific configs:
# args.bn_splits is passed to mocoutil.ModelMoCo in the feature-extraction
# cell; without this option that attribute lookup raises AttributeError.
# NOTE(review): default 8 follows the MoCo CIFAR demo; it should presumably
# match the value used when the checkpoints were trained — confirm.
parser.add_argument('--bn-splits', default=8, type=int,
                    help='simulate multi-gpu behavior of BatchNorm in one gpu; 1 is SyncBatchNorm in multi-gpu')
parser.add_argument('--symmetric', action='store_true', help='use a symmetric loss function that backprops to both crops')

# Running in ipynb: parse an explicit empty argv so the kernel's own
# sys.argv is ignored and all defaults are used.
args = parser.parse_args([])
print(args)

In [None]:
# Per-channel normalisation statistics. These are hard-coded; presumably they
# were estimated from the tile data — confirm where the numbers come from.
normalize = transforms.Normalize(
    mean=[0.85, 0.7, 0.78],
    std=[0.15, 0.24, 0.2],
)

# Tile preprocessing: Resize(224) scales the shorter edge to 224 px, then the
# image is converted to a tensor and normalised with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        normalize,
    ]
)

# Root directory of the tile images for each magnification.
img_file = {
    '2x': '/path/to/my/tiles/training_patch280_2x_e70',
    '5x': '/path/to/my/tiles/training5x_all/',
    '20x': '/path/to/my/tiles/train20x_4000each/',
}

# Trained MoCo checkpoint to load for each magnification.
checkpoint_file = {
    '2x': '/path/to/my/checkpoints/2x_epoch200.pth',
    '5x': '/path/to/my/checkpoints/5x_epoch160.pth',
    '20x': '/path/to/my/checkpoints/20x_epoch200.pth',
}

def get_device(use_gpu):
    """Pick the torch device to run on.

    Args:
        use_gpu: request CUDA. It is honoured only if CUDA is actually available.

    Returns:
        ``torch.device("cuda")`` when a GPU is requested and available,
        otherwise ``torch.device("cpu")``. Selecting CUDA also switches cuDNN
        to deterministic mode as a side effect.
    """
    cuda_ok = use_gpu and torch.cuda.is_available()
    if not cuda_ok:
        return torch.device("cpu")
    # Trade some speed for reproducible convolution results.
    torch.backends.cudnn.deterministic = True
    return torch.device("cuda")

# Feature extraction
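Extract MoCo encoder features from the tiles at each magnification (2x, 5x, 20x) and save them, standardised, to one pickle per magnification.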

In [None]:
feat_result_dir = '/path/to/features'
# Create the output directory up front so joblib.dump does not fail on a fresh run.
os.makedirs(feat_result_dir, exist_ok=True)

# The device does not depend on the magnification, so resolve it once.
device = get_device(use_gpu=True)
print(f'device: {device}')

for mgn in ['2x', '5x', '20x']:
    # Fresh MoCo model per magnification; its weights come from the
    # magnification-specific checkpoint loaded below.
    # NOTE(review): mocoutil.test may move batches to CUDA itself — confirm it
    # also works when `device` is the CPU.
    model = mocoutil.ModelMoCo(dim=128, K=4096, m=0.99, T=0.1, arch='resnet18',
                               bn_splits=args.bn_splits,
                               symmetric=args.symmetric).to(device)

    dataset = mocoutil.ImageFolderWithPaths(img_file[mgn], transform=transform)
    # shuffle=False gives a deterministic iteration order over the tiles.
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                             num_workers=16, pin_memory=True)

    # map_location lets a GPU-saved checkpoint load on whichever device we use.
    model_all = torch.load(checkpoint_file[mgn], map_location=device)
    print(f'Checkpoint loaded: Magnification: {mgn}, epoch: {model_all["epoch"]}')

    # Printing the result shows any missing/unexpected keys.
    print(model.load_state_dict(model_all['state_dict']))

    # Features are taken from the query encoder only.
    feat, path = mocoutil.test(model.encoder_q, data_loader)
    feat_np = feat.detach().cpu().numpy().copy()

    # Standardise each feature dimension (zero mean, unit variance); the
    # scaler is fitted separately for every magnification.
    feat_np_std = preprocessing.StandardScaler().fit_transform(feat_np)

    data_to_export = {'feat': feat_np_std, 'filename': path}
    target_file = Path(feat_result_dir) / f'{mgn}.pkl'
    joblib.dump(data_to_export, target_file)

# K-means clustering

In [None]:
cls_result_csv_dir = '/path/to/cluster_results'
# Create the output directory up front so to_csv does not fail on a fresh run.
os.makedirs(cls_result_csv_dir, exist_ok=True)

for mgn in tqdm(['2x', '5x', '20x']):
    # Load the standardised features and tile paths written by the
    # feature-extraction step (joblib pickle produced by this notebook).
    feat = joblib.load(Path(feat_result_dir) / f'{mgn}.pkl')
    feature = feat['feat']

    # Output table: one column of tile paths plus one label column per k.
    cluster = {'path': feat['filename']}
    for k in tqdm([3, 30, 50, 80, 100, 120, 150, 200]):
        # A fixed random_state makes the labels reproducible between runs.
        clustering = KMeans(n_clusters=k, random_state=300).fit(feature)
        cluster[f'k{k}'] = clustering.labels_

    pd.DataFrame(cluster).to_csv(f'{cls_result_csv_dir}/mgn{mgn}.csv')

# Preview cluster images

In [None]:
def safe_sample(df, n):
    """Randomly sample at most ``n`` rows from ``df``.

    ``DataFrame.sample`` raises when asked for more rows than exist (without
    replacement), so the request is capped at the frame's row count. In that
    case every row is returned, in random order.
    """
    available = df.shape[0]
    return df.sample(min(n, available))

def save_images(path, title, filename, fontsize=8):
    """Render up to 15 x 8 = 120 tile images into one grid figure and save it.

    Args:
        path: sequence of image file paths. Only the first 120 are drawn.
        title: per-image titles. Currently unused; every subplot gets an
            empty title.
        filename: output path passed to ``Figure.savefig``.
        fontsize: font size applied to the (empty) subplot titles.

    Returns:
        None. When ``path`` is empty nothing is created or written.

    Raises:
        Whatever ``Image.open`` raises for an unreadable file. The figure is
        still closed in that case.
    """
    row = 15
    col = 8

    # Return before creating a figure, so an empty cluster leaks nothing.
    if len(path) == 0:
        return None

    fig = plt.figure(figsize=(15, 30), dpi=100)
    try:
        for i, image_path in enumerate(path[:row * col], start=1):
            ax = fig.add_subplot(row, col, i)
            # The context manager closes the file handle. imshow converts the
            # PIL image to an array, so the pixels outlive the handle.
            with Image.open(image_path) as image:
                ax.imshow(image)
            ax.set_title('', fontsize=fontsize)
            ax.axis('off')
        # Save once, after all tiles are drawn.
        fig.savefig(filename)
    finally:
        # Close this specific figure; many previews are rendered per process.
        plt.close(fig)
    return None
            
def save_preview(cluster_list, title, path, cluster, n_cluster, target_dir):
    """Write one preview sheet of up to 120 random tiles from a single cluster.

    Args:
        cluster_list: cluster label for each tile.
        title: per-tile title, parallel to ``cluster_list``.
        path: per-tile image path, parallel to ``cluster_list``.
        cluster: the label whose tiles are previewed.
        n_cluster: not used by this function.
        target_dir: directory for the output image. It is created if missing.

    The sheet is written to ``<target_dir>/cls<cluster zero-padded to 2>.jpg``.
    """
    tiles = pd.DataFrame({'cluster': cluster_list, 'title': title, 'path': path})
    members = safe_sample(tiles.loc[tiles['cluster'] == cluster], 120)

    os.makedirs(target_dir, exist_ok=True)
    out_file = Path(target_dir) / f'cls{str(cluster).zfill(2)}.jpg'
    save_images(list(members['path']), list(members['title']), out_file)
            
def save_preview_wrapper(args):
    """Call ``save_preview`` with one packed argument tuple.

    ``Pool.map`` hands each worker a single item, so the item is unpacked
    here into ``save_preview``'s positional parameters.
    """
    packed = tuple(args)
    return save_preview(*packed)

In [None]:
previw_img_dir = '/path/to/preview'

# Leave one core for the notebook, but never request 0 workers:
# cpu_count() - 1 is 0 on a single-core machine and Pool raises ValueError.
n_workers = max(1, cpu_count() - 1)

# One pool for the whole cell. The context manager tears the workers down even
# if a map() call raises; every map() has completed its tasks before the
# block exits.
with Pool(processes=n_workers) as pool:
    for mgn in tqdm(['2x', '5x', '20x']):
        df = pd.read_csv(f'{cls_result_csv_dir}/mgn{mgn}.csv')

        # Independent of k, so compute once per magnification.
        path = list(df['path'])
        # Each tile's parent directory name is used as its preview title.
        case_name = [Path(p).parent.name for p in path]

        for n_cluster in tqdm([3, 30, 50, 80, 100]):
            cluster = list(df[f'k{n_cluster}'])
            target = f'{previw_img_dir}/mgn{mgn}_k{n_cluster}'
            # One task per cluster id; each task renders one preview sheet.
            values = [(cluster, case_name, path, c, n_cluster, target) for c in range(n_cluster)]
            pool.map(save_preview_wrapper, values)