In [None]:
import argparse
import os
import random
import warnings
import pandas as pd
import numpy as np
import time
import shutil
import datetime

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import StepLR
import torchvision
import torchvision.transforms as transforms

import torch.nn.functional as F
import pickle

from utils.logging import *
from data.process import *

### Define parameters

In [None]:
# Run configuration for feature extraction. Kept in a pandas Series so each
# setting can be read either as ``args.name`` or ``args['name']``.
args = pd.Series(dict(
    # Root folder; the baseline checkpoint and the saved features live beneath it.
    checkpoint='checkpoint/',
    # Appended to 'BaselineResNet' to form the model folder name.
    version='1.0',
    # Image folder handed to the patch dataset.
    image_dir='Data/',
    # CSV of patch labels read with pandas.
    patch_label='Metadata/PatchLabels.csv',
    # Label column; rows with a missing value in it are dropped.
    predicting_var='response',
    # One of: 'regression', 'binary classification', 'classification'.
    prediction='binary classification',
    cohort='Cohort1',
    # Only rows with this magnification are kept.
    magnification='10X',
    # Output size of the replacement fully connected layer.
    num_classes=1,
    upsample=False,
    train_val_split=0.7,
    random_crop=False,
    # Input size of the replacement fully connected layer.
    features=2048,
    # Epoch of the baseline checkpoint to load; also names the feature folder.
    base_epoch=19,
    # DataLoader batch size and worker count.
    batch_size=256,
    workers=4,
    seed=0,
    # CUDA device index; None means inputs are not moved to a GPU.
    gpu=0,
))

### Functions for generating and saving features

In [None]:
def process_slides(slides, patch_labels, data_transforms, model):
    """Generate and cache features for every slide in ``slides``.

    For each slide name, the rows of ``patch_labels`` whose ``slide`` column
    equals that name are selected (with the index reset), wrapped in a
    ``PatchPathDataset`` and handed to ``save_slide_features``.

    Args:
        slides: iterable of slide names.
        patch_labels: DataFrame with a ``slide`` column.
        data_transforms: transform passed to the dataset as ``transform``.
        model: model forwarded unchanged to ``save_slide_features``.
    """
    for slide_name in slides:
        belongs_to_slide = patch_labels.slide == slide_name
        rows_for_slide = patch_labels[belongs_to_slide].reset_index(drop=True)
        slide_dataset = PatchPathDataset(
            patch_labels=rows_for_slide,
            image_folder=args.image_dir,
            predicting_var=args.predicting_var,
            transform=data_transforms,
        )
        save_slide_features(slide_name, slide_dataset, model)
        

def save_slide_features(slide, dataset, model):
    """Compute the features of one slide and cache them on disk.

    The file is written to
    ``<args.checkpoint>/BaselineResNet<args.version>/Features/epoch_<args.base_epoch>/<slide>``
    and holds a dict with the keys ``'slide_embeddings'`` and ``'patch_paths'``
    as returned by ``generate_features``. If that file already exists nothing
    is computed.

    The result is first written to a temporary sibling file and then moved
    into place, so an interrupted or failed save never leaves a partial file
    at the final path (which later runs would mistake for a finished cache
    entry and skip).

    Args:
        slide: slide name; also used as the file name.
        dataset: dataset of the slide's patches, forwarded to ``generate_features``.
        model: feature extractor, forwarded to ``generate_features``.
    """
    feature_path = os.path.join(args.checkpoint, f'BaselineResNet{args.version}', 'Features',
                                f'epoch_{args.base_epoch}')
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(feature_path, exist_ok=True)
    slide_feature_path = os.path.join(feature_path, slide)
    if os.path.exists(slide_feature_path):
        print(f'Features already exist for slide {slide} so won\'t be generated again')
        return

    print(f'Generating and saving features for slide {slide} at {slide_feature_path}')
    slide_embeddings, patch_paths = generate_features(slide, dataset, model)
    slide_embeddings_paths = {'slide_embeddings': slide_embeddings, 'patch_paths': patch_paths}

    # Write next to the target so the final os.replace stays on one filesystem.
    partial_path = f'{slide_feature_path}.partial'
    try:
        torch.save(slide_embeddings_paths, partial_path)
        os.replace(partial_path, slide_feature_path)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a half-written file behind.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise

    
def generate_features(slide, dataset, model):
    """Run ``model`` over every item of ``dataset`` and collect the outputs.

    The dataset is read in order (no shuffling) in batches of
    ``args.batch_size``. Each batch is moved to GPU ``args.gpu`` when that
    setting is not None, passed through the model under ``torch.no_grad()``,
    and the squeezed output is kept. The collected outputs are stacked with
    ``torch.vstack``; they are not moved off the device the model produced
    them on.

    Args:
        slide: slide name (not used inside this function).
        dataset: dataset whose items unpack as ``(image, <ignored>, path)``.
        model: module applied to each image batch; switched to eval mode here.

    Returns:
        Tuple ``(slide_embeddings, patch_paths)``: the stacked model outputs
        and the list of paths in the same order.
    """
    model.eval()
    batch_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        sampler=None,
    )

    embedding_batches = []
    patch_paths = []

    with torch.no_grad():
        for images, _, paths in batch_loader:
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            # NOTE(review): squeeze() removes every size-1 dim, including the batch
            # dim of a one-item batch; vstack below promotes such 1-D results back
            # to a single row.
            embedding_batches.append(model(images).squeeze())
            patch_paths.extend(paths)

    return torch.vstack(embedding_batches), patch_paths

### Save feature embeddings

In [None]:
def main_worker(gpu, args):
    """Extract and save baseline-ResNet features for all training and validation slides.

    Steps:
      1. Pick the device: GPU ``gpu`` when CUDA is available and ``gpu`` is not
         None, otherwise the CPU.
      2. Build a ResNet-50, replace its final layer with an
         ``args.features -> args.num_classes`` linear layer, load the baseline
         checkpoint into it, and drop that final layer to get the feature model.
      3. Read the patch labels, filter them by magnification, non-missing
         ``args.predicting_var`` and cohort, and split them into train and
         validation sets.
      4. Check that the validation cases equal those stored in the checkpoint.
      5. Save per-slide features for both splits through ``process_slides``.

    Args:
        gpu: CUDA device index to run on, or None to run on the CPU.
        args: settings object; ``args.gpu`` is overwritten with the device
            actually used (None when running on the CPU).
    """
    # generate_features() only moves input batches to a GPU when args.gpu is not None,
    # so the CPU fallback is recorded in args.gpu to keep model and inputs together.
    use_cuda = torch.cuda.is_available() and gpu is not None
    args.gpu = gpu if use_cuda else None

    if use_cuda:
        print("Use GPU: {} for training".format(args.gpu))
        torch.cuda.set_device(args.gpu)
        map_location = torch.device('cuda')
    else:
        print('using CPU, this will be slow')
        map_location = torch.device('cpu')

    # Load baseline ResNet
    print('Loading baseline ResNet')
    # No pretrained weights are requested: the strict load_state_dict below replaces
    # every parameter and buffer, so downloading ImageNet weights would be wasted.
    base_model = torchvision.models.resnet50()
    base_model.fc = nn.Linear(in_features=args.features, out_features=args.num_classes, bias=True)
    basenet_path = os.path.join(args.checkpoint, f'BaselineResNet{args.version}', f'epoch_{args.base_epoch}',
                                'checkpoint.pth.tar')
    print(f'Using baseline model: at {basenet_path}')
    assert os.path.isfile(basenet_path), f'Baseline checkpoint not found at {basenet_path}'
    resnet_state = torch.load(basenet_path, map_location=map_location)
    resnet_state_dict = resnet_state['state_dict']
    saved_val_cases = resnet_state['val_cases']
    base_model.load_state_dict(resnet_state_dict, strict=True)

    # Keep every child module except the last one (the fc layer replaced above).
    feature_model = torch.nn.Sequential(*(list(base_model.children())[:-1]))
    feature_model = feature_model.cuda(args.gpu) if use_cuda else feature_model.cpu()
    del base_model
    feature_model.eval()

    # Load data
    patch_labels = pd.read_csv(args.patch_label, index_col=0)
    patch_labels = patch_labels[patch_labels.magnification == args.magnification]
    patch_labels = patch_labels.dropna(subset=[args.predicting_var])
    patch_labels = select_cohort(patch_labels, args.cohort)

    # The fourth value returned by split_train_val is not needed here.
    train_patch_labels, val_patch_labels, val_cases, _ = split_train_val(patch_labels, args.cohort,
                                                                         args.train_val_split, args.seed,
                                                                         args.prediction, args.predicting_var,
                                                                         args.upsample)
    # check saved_val_cases from baseline model are same as val_cases for attention model
    assert (val_cases == saved_val_cases).all(), \
        'Validation cases differ from those stored in the baseline checkpoint'

    train_slides = train_patch_labels.slide.unique()
    print(f'{len(train_slides)} training slides')
    val_slides = val_patch_labels.slide.unique()
    print(f'{len(val_slides)} validation slides')

    cudnn.benchmark = True

    full_transforms, lim_transforms = image_transforms(args.random_crop)

    # Training slides use the first transform set, validation slides the second.
    process_slides(train_slides, train_patch_labels, full_transforms, feature_model)
    process_slides(val_slides, val_patch_labels, lim_transforms, feature_model)

    del feature_model
    print('Saved all features.')