# Classification

In [None]:
from __future__ import print_function 
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import torchvision
from torchvision import datasets, models, transforms, test
from torchvision.models import resnet

import matplotlib.pyplot as plt
import time
import os
import copy
import glob
from pathlib import Path
import shutil
from tqdm import tqdm
from functools import partial
from torch.utils.data import DataLoader, Dataset
import seaborn as sns
from PIL import Image

from mocotools.mocoutil import ModelBase, ModelMoCo, test, ImageFolderWithPaths

# Log library versions so results can be traced back to the environment used.
for _lib_label, _lib_version in (
    ("PyTorch Version: ", torch.__version__),
    ("Torchvision Version: ", torchvision.__version__),
):
    print(_lib_label, _lib_version)

Set the batch size, tile size, and number of classes (findings) for each magnification here.

In [None]:
# Inference settings: tiles per batch and the square crop size (pixels) fed to the CNN.
batch_size, img_size = 256, 224

# Number of output classes (findings) of the classifier head, per magnification key.
num_classes = dict(mgn2x=4, mgn5x=8, mgn20x=8, mgn20x_4=8)

In [None]:
# Build the MoCo model and keep only its query encoder. Its classification head is
# rebuilt per magnification (and all weights reloaded from a checkpoint) in featextract,
# so `arch` must match the checkpoints; dim/K/m/T presumably only configure parts of
# ModelMoCo that are not used here — confirm against mocotools.
model = ModelMoCo(dim=128, K=4096, m=0.99, T=0.1, arch='resnet18').encoder_q


def get_device(use_gpu):
    """Return the torch device to run on.

    Returns ``torch.device("cuda")`` when *use_gpu* is true and CUDA is available,
    and in that case also sets ``torch.backends.cudnn.deterministic = True``.
    Otherwise returns ``torch.device("cpu")``.
    """
    if not (use_gpu and torch.cuda.is_available()):
        return torch.device("cpu")
    torch.backends.cudnn.deterministic = True
    return torch.device("cuda")
    
device = get_device(use_gpu=True)

# Per-channel RGB normalisation (presumably the tile dataset's statistics — confirm).
normalize = transforms.Normalize(
    mean=[0.85, 0.7, 0.78],
    std=[0.15, 0.24, 0.2],
)

# Preprocessing: square center crop of img_size pixels -> tensor -> normalisation.
_preprocess_steps = [transforms.CenterCrop(img_size), transforms.ToTensor(), normalize]
transform = transforms.Compose(_preprocess_steps)

Function to classify the findings of each tile and export the predictions as a CSV file.

The model checkpoint for each magnification has to be specified.

In [None]:
# Directory where featextract writes one "<slide>_<mgn>_<epoch>.csv" per slide and magnification.
data_out_dir = Path('/path/to/my/output/') # location of csv export

def _move_square_tiles(img_dir):
    """Move the square ``*.jpeg`` tiles found directly in *img_dir* into ``img_dir/squarefiles``.

    ImageFolder-style datasets read images from sub-folders of their root, so tiles
    left at the top level of *img_dir* (the non-square ones) are not classified.
    (Assumes ImageFolderWithPaths follows torchvision's ImageFolder layout — confirm.)
    """
    square_dir = Path(img_dir).joinpath('squarefiles')
    os.makedirs(square_dir, exist_ok=True)
    for tile in glob.glob(f'{img_dir}/*.jpeg'):
        # Read the size inside a context manager so the file is closed before moving it.
        with Image.open(tile) as im:
            is_square = im.height == im.width
        if is_square:
            shutil.move(tile, square_dir.joinpath(Path(tile).name))


def _attach_classifier_head(model, n_classes):
    """Replace everything after the AdaptiveAvgPool2d of ``model.net`` with a linear head.

    The new ``model.net`` is ``[children up to and including the pool] + Flatten(1) +
    Linear(512, n_classes)``. The pool is kept and everything after it discarded, so
    this can be applied repeatedly to the same model (once per magnification).

    Raises:
        ValueError: if ``model.net`` has no ``nn.AdaptiveAvgPool2d`` child.
    """
    layers = []
    for module in model.net.children():
        layers.append(module)
        if isinstance(module, nn.AdaptiveAvgPool2d):
            # 512 = backbone feature width for arch='resnet18' used when building the
            # model; update if the architecture changes.
            layers.extend([nn.Flatten(1), nn.Linear(512, n_classes)])
            model.net = nn.Sequential(*layers)
            return
    raise ValueError('model.net has no nn.AdaptiveAvgPool2d layer to attach a head to')


def featextract(case_dir, model, checkpoints=None, check_complete=False):
    """Classify the square tiles of one slide at each magnification and export CSVs.

    For every magnification in *checkpoints* this:

    1. skips it if ``<slide>_<mgn>_<epoch>.csv`` already exists in ``data_out_dir``;
    2. moves square tiles of ``<case_dir>/<level>/`` into a ``squarefiles`` sub-folder;
    3. rebuilds the head of ``model.net`` with ``num_classes[mgn]`` outputs and loads
       the checkpoint's ``'state_dict'``;
    4. predicts one class per tile and writes path, prediction, case, slide and tile
       x / y to the CSV.

    Args:
        case_dir: Tile directory of one slide (a ``*_files`` directory).
        model: Encoder whose ``net`` contains an ``nn.AdaptiveAvgPool2d`` child.
            NOTE: ``model.net`` is replaced in place, so the caller's model is modified.
        checkpoints: Optional mapping ``{mgn: checkpoint path}``. Keys must be among
            ``'mgn2x'``, ``'mgn5x'`` and ``'mgn20x_4'``. Defaults to the paths below.
        check_complete: If True, skip the case when the ``.dzi`` file next to
            *case_dir* is missing (tiling not finished). Defaults to False.

    Returns:
        None; results are written to ``data_out_dir``.
    """
    # Magnification key -> tile sub-directory inside case_dir.
    mgn_dir = {
        'mgn2x': '2.5',
        'mgn5x': '5.0',
        'mgn20x_4': '20.0',
    }
    if checkpoints is None:
        checkpoints = {
            'mgn2x': '/path/to/mgn2x-cnn-model.pth',
            'mgn5x': '/path/to/mgn5x-cnn-model.pth',
            'mgn20x_4': '/path/to/mgn20x-cnn-model.pth',
        }

    case_dir = Path(case_dir)

    # Optionally skip cases whose tiling is not completed ("<slide>.dzi" missing).
    if check_complete:
        completefile = case_dir.name.replace('_files', '.dzi')
        if not case_dir.parent.joinpath(completefile).exists():
            return None

    print(f'------------ {case_dir.name}  ------------')

    for mgn, cp in checkpoints.items():
        cp = Path(cp)
        epoch = cp.stem

        # Use a single filename for both the "already done" check and the export;
        # they used to differ by epoch.zfill(3), so existing results could be missed.
        out_csv = Path(data_out_dir).joinpath(f'{case_dir.name}_{mgn}_{epoch.zfill(3)}.csv')
        if out_csv.exists():
            print(f'The result files already exist: {out_csv.name}')
            continue

        print(f'{mgn=}, {epoch=}')

        img_files = str(case_dir.joinpath(mgn_dir[mgn]))
        _move_square_tiles(img_files)

        # Build the classifier for this magnification and load its checkpoint.
        _attach_classifier_head(model, num_classes[mgn])
        model = model.to(device)
        cp_loaded = torch.load(cp, map_location=device)
        model.load_state_dict(cp_loaded['state_dict'])
        # Inference mode (BatchNorm running statistics); harmless if test() sets it too.
        model.eval()

        # ImageFolder-style datasets raise RuntimeError (older torchvision) or
        # FileNotFoundError (newer) when no tile is found.
        try:
            dataset = ImageFolderWithPaths(img_files, transform=transform)
        except (RuntimeError, FileNotFoundError):
            print('Dataset was not enough to extract features')
            continue

        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                 num_workers=16, pin_memory=True)
        feat, path = test(model, data_loader)
        cls_predict = torch.max(feat, 1).indices.to('cpu').numpy()

        # Expected tile paths: <case_dir>/<level>/squarefiles/<x>_<y>.jpeg
        slide = [Path(p).parent.parent.parent.name for p in path]
        case = [s.split('_')[0] + '_' + s.split('_')[1] for s in slide]
        stems = [Path(p).stem.split('_') for p in path]
        x = [parts[-2] for parts in stems]
        y = [parts[-1] for parts in stems]

        df = pd.DataFrame({
            'path': path,
            'predict': cls_predict,
            'case': case,
            'slide': slide,
            'x': x,
            'y': y
        })
        df.to_csv(out_csv)

In [None]:
# Run the classifier over every tiled slide (each "*_files" directory in tile_dir).
tile_dir = Path('/path/to/my/tiles/')
cases1 = [*tile_dir.glob('*_files')]

for case_path in cases1:
    featextract(case_path, model)

# Mapping the findings

Note that each tile's filename is `<x coord>_<y coord>.jpeg`.


In [None]:
# Helper functions for drawing the prediction maps ---

def get_img_shape(pos_cls):
    """Return the array shape needed to draw every tile in *pos_cls*.

    Args:
        pos_cls: Mapping ``{(x, y): class index}`` of tile grid coordinates.

    Returns:
        ``(x_max + 1, y_max + 1, 3)``, for an array indexed ``[x, y, channel]``.

    Raises:
        ValueError: if *pos_cls* is empty.
    """
    x_max = max(pos[0] for pos in pos_cls)
    y_max = max(pos[1] for pos in pos_cls)
    # +1: the largest coordinate itself must be a valid index, otherwise the last
    # column/row of tiles is silently cut off from the map.
    return x_max + 1, y_max + 1, 3

def resize_px(img, px=20):
    """Upscale the PIL image *img* by the integer factor *px* in both directions (box resampling)."""
    new_size = (img.width * px, img.height * px)
    return img.resize(new_size, resample=Image.BOX)

def transpose3d(img_array):
    """Swap the first two axes of a 3-D array, keeping its first three channels.

    Returns a new float64 array of shape ``(img_array.shape[1], img_array.shape[0], 3)``
    whose channel ``c`` is ``img_array[:, :, c].T``.
    """
    channels = [img_array[:, :, c].T for c in range(3)]
    return np.stack(channels, axis=-1).astype(np.float64)
    
def get_img(pos_cls, size, palette):
    """Render tile-level class predictions as an RGB PIL image.

    Args:
        pos_cls: Mapping ``{(x, y): class index}``; the index selects a colour in *palette*.
        size: Upscaling factor passed to ``resize_px`` (each tile becomes size x size pixels).
        palette: Sequence of RGB triples with components in [0, 1], indexed by class.

    Returns:
        PIL image; grid positions without a prediction are drawn light grey (0.99).
    """
    canvas = np.zeros(get_img_shape(pos_cls))
    background = (.99, .99, .99)

    for x in range(canvas.shape[0]):
        for y in range(canvas.shape[1]):
            cluster = pos_cls.get((x, y))
            colour = background if cluster is None else palette[cluster]
            canvas[x, y, 0], canvas[x, y, 1], canvas[x, y, 2] = colour

    # The canvas is indexed [x, y]; transpose to image order (rows = y, columns = x).
    rgb = np.uint8(transpose3d(canvas) * 255)
    return resize_px(Image.fromarray(rgb), size)

# Colour palette per magnification. Entry i is the RGB colour drawn for predicted
# class i (get_img indexes palette[cluster]), so the order must follow the
# classifier's class indices (same order as the column rename table further below).
pal = {
    'mgn2x': [
        (sns.color_palette('autumn'))[4],  # 0 Acellular fibrosis: orange
        (sns.color_palette('RdPu', 7))[6],  # 1 Cellular fibrosis: darkest shade of the red-purple RdPu palette
        (.72, .72, .72),  # 2 Near Normal: mid grey
        (.9, .9, .9)  ,                    # 3 Other: light grey
    ],
    'mgn5x':  [
        sns.color_palette('OrRd')[4], # 0 Acellular fibrosis: orange-red
        sns.color_palette('YlGn',24)[12], # 1 Cellular fibrotic IP: green
        sns.color_palette('Blues',24)[12], # 2 Cellular IP, NSIP: light blue
        (.65, .65, .65), # 3 Complete Normal: dark grey
        sns.color_palette('Blues')[5], # 4 Lymphoid follicle: dark blue
        (.9, .9, .9),                  # 5 other: light grey
        sns.color_palette('RdPu',24)[8], # 6 Edge: pink
        sns.color_palette('YlOrBr',24)[3], # 7 Pale tissue: pale yellow
    ],
    'mgn20x_4':[
        sns.husl_palette(24)[1], # 0 Dense fibrosis: orange
        sns.husl_palette(24)[-3], # 1 Immature fibroblasts: pink
        sns.color_palette('summer')[3], # 2 Elastosis: light green
        sns.color_palette('YlOrBr')[0], # 3 Fat: light yellow
        sns.color_palette('RdBu_r')[0], # 4 Lymphocytes: dark blue (blue end of reversed RdBu)
        sns.color_palette('pink')[3], # 5 Mucin: light brown
        (.9, .9, .9),                    # 6 other: light grey
        sns.color_palette('gist_heat')[1]  , # 7 Respiratory epithelium: dark brown

    ]
}

In [None]:
# Output root for the PNG maps (one sub-directory per slide).
img_preview_dir = Path('/path/to/my/dir/maps/')

# Input: every CSV written by the classification step.
all_results = Path('/path/to/my/output/').glob('*.csv')

for slide in tqdm(all_results):
    c_name = Path(slide).stem
    # Expected stem: "<slide>_files_<mgn>_<epoch>". "mgn20x_4" itself contains "_",
    # so for that magnification the second-to-last field is just "4".
    name_parts = c_name.split('_')
    epoch = name_parts[-1]
    mgn = "mgn20x_4" if name_parts[-2] == '4' else name_parts[-2]
    slide_name = c_name.split('_files_')[0]

    df = pd.read_csv(slide)
    pos_cls = dict(zip(zip(df['x'], df['y']), df['predict']))
    img = get_img(pos_cls, 15, pal[mgn])

    dir_to_save = img_preview_dir / slide_name
    os.makedirs(dir_to_save, exist_ok=True)
    img.save(dir_to_save / f'{slide_name}_{mgn}.png')

# Calculate the frequency of each finding for each case (used for subsequent analysis)

In [None]:
def combine_case(df_all, mgn):
    """Compute, per case, the fraction of tiles predicted as each class.

    Args:
        df_all: Iterable of CSV paths written by ``featextract`` (each with at least
            ``case`` and ``predict`` columns). Previously this argument was ignored and
            the global ``allcsvfiles`` was read instead.
        mgn: Prefix for the output columns, named ``f'{mgn}_{class index}'``.

    Returns:
        DataFrame indexed by ``case`` with one column per predicted class holding
        row-normalised tile fractions. For ``mgn == 'mgn5x'`` class 3 is removed
        before normalising.

    Raises:
        ValueError: if *df_all* yields no files (nothing to concatenate).
    """
    frames = [pd.read_csv(csvfile) for csvfile in df_all]
    data = pd.concat(frames)
    data['predict'] = data['predict'].map(lambda x: f'{mgn}_' + str(x))

    counts = pd.crosstab(data['case'], data['predict'])

    if mgn == 'mgn5x':
        # Exclude class 3 from the denominator; errors='ignore' avoids a KeyError
        # when no tile in these files was predicted as class 3.
        counts = counts.drop('mgn5x_3', axis=1, errors='ignore')

    # Row-normalise the counts into per-case fractions.
    return counts.div(counts.sum(axis=1), axis=0)

# featextract writes "<slide>_<mgn>_<epoch>.csv"; a plain "*mgn2x.csv" pattern never
# matches those names. The trailing "*" accepts both that name and files renamed to
# "<slide>_<mgn>.csv" ("_mgn2x" cannot match "_mgn20x_4").
allcsvfiles = data_out_dir.glob('*_mgn2x*.csv')
data_2x = combine_case(allcsvfiles, 'mgn2x')

allcsvfiles = data_out_dir.glob('*_mgn5x*.csv')
data_5x = combine_case(allcsvfiles, 'mgn5x')

allcsvfiles = data_out_dir.glob('*_mgn20x_4*.csv')
# 20x columns are prefixed "mgn20x_" (not "mgn20x_4_") to match the rename table below.
data_20x = combine_case(allcsvfiles, 'mgn20x')

In [None]:
# Preview the per-case 20x class-frequency table.
data_20x

In [None]:
# Merge the per-magnification frequency tables into one per-case feature table.
# concat(axis=1) outer-joins on the case index, so a case missing at one
# magnification gets NaN in that magnification's columns.
df = pd.concat([data_2x, data_5x, data_20x], axis=1)

# Drop the 'other' class of each magnification. errors='ignore' avoids a KeyError
# when no tile in the cohort was predicted as 'other' (the column then doesn't exist).
df = df.drop(['mgn2x_3', 'mgn5x_5', 'mgn20x_6'], axis=1, errors='ignore')

# Human-readable column names. Entries for already-dropped columns are harmless.
# NOTE(review): 'Accelular_fibrosis_5x' is spelled differently from
# 'Acellular_fibrosis_2x'; kept as-is because downstream analysis may rely on it.
col_rename = {
    'mgn2x_0': 'Acellular_fibrosis_2x',
    'mgn2x_1': 'Cellular_fibrosis_2x',
    'mgn2x_2': 'Near_Normal_2x',
    'mgn2x_3': 'other_2x',

    'mgn5x_0': 'Accelular_fibrosis_5x',
    'mgn5x_1': 'Cellular_fibrotic_IP_5x',
    'mgn5x_2': 'CellularIP_NSIP_5x',
    'mgn5x_3': 'Complete_Normal_5x',
    'mgn5x_4': 'Lymphoid_follicle_5x',
    'mgn5x_5': 'other_5x',
    'mgn5x_6': 'Edge_5x',
    'mgn5x_7': 'Pale_5x',

    'mgn20x_0': 'Dense_fibrosis_20x',
    'mgn20x_1': 'Immature_fibrosis_20x',
    'mgn20x_2': 'Elastosis_20x',
    'mgn20x_3': 'Fat_20x',
    'mgn20x_4': 'Lymphocytes_20x',
    'mgn20x_5': 'Mucous_20x',
    'mgn20x_6': 'other_20x',
    'mgn20x_7': 'Resp_epithelium_20x',
}

df = df.rename(col_rename, axis=1)

df.to_csv('/path/to/my/project/features_cases.csv')