In [21]:
# Run inference only
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
import torch.utils.data as data_utils
from PIL import Image
from Reinhard import Normalizer
from data_augment import RandomAugment
from dataset_generic import CustomDataset, load_dataframe, get_indep_test_sets, MultipleInferenceDataset
from model_zoo import Encoder, get_ozan_ciga, freeze_encoder
from ssc_utils import SSCMaxPoolClassifier, train_model_w_ssc, SSCGatedAttentionClassifier, get_scores_ssc, \
    get_scores_ssc_multiple_inference, train_model_w_ssc_multiple_inference
from training import train_model, train_model_multiple_inference
from utils import get_scores, set_device_and_seed, load_model_weights, get_scores_multiple_inference, get_model, get_model_ssc


In [39]:
configuration = {
    # Training Parameters
    "num_epochs": 1,
    "batch_size": 6,
    "gpu_name": "cuda:0",
    "seed": 42,
    "multi_gpu": False, # if True, use all available JPUs with DataParallel

    # DataLoading parameters
    "tile_path":"/data/MSc_students_accounts/sneha/tiles_summary_5x_2.5mm_50%_adaptive_threshold.csv", # abs path to csv file with tiling info
    "data_dir": "/data/goldin/images_raw/nafld_liver_biopsies_sirius_red/", # abs path to loc of .ndpi files
    "multiple_inference": False, # test with multiple inference dataset
    "bag_size": 10, # tiles per bag


    # Preprocessing parameters
    "resize": True, # resize images before trainnig
    "img_size": 224, # resize to img_size x img_size patches
    "stain_normalization": # apply reinhard stain normalisation based on source image in source_dir
    {
        "apply_reinhard": True, 
        "source_dir":"/data/MSc_students_accounts/sneha/sneha/sirius_red-master/reinhard_source.jpg"
    },
    "transform_color_jitter": None, # hsv color jitter to apply
    "hsv": False, # input img color space
    "cmyk": False,# input img color space

    # Model Parameters
    "encoder": "se_resnet18", # se_resnet18, se_resnet34, resnet18, resnet34, simclr
    "image_net_pretrained": True, # load imgnet pretrained weights
    "load_path": False, # if not False, location of pretrained encoder weights to load
    "freeze": False, # True, False or "part" - encoder weights to freeze
    "dropout": False,
    "num_layers": None, # only for max pool models
    "aggregation": "gated_attention", # gated_attention or simple_attention or max_pool
    "best_model_weights_path": "/data/MSc_students_accounts/sneha/temp_imgs/best_model_fold_0_5vvb5ni5.h5",

    # SSC module parameters
    "use_ssc": False,
    "ssc_reconst_loss": "l2",
    "ssc_num_routings": 3, # R
    "ssc_lr": 0.1, # initial lr for ssc_module
    "apply_ssc_scheduler": (0.1,30), # Decay by factor of x0.1 every 30 epochs. False if not applying scheduler
    "ssc_num_stains": 2, # S
    "ssc_num_groups": 6, # M
    "ssc_group_width": 3, # N
    "ssc_use_od": True, # apply OD transformation in SSC capsule
    "ssc_in_channels": 3 # num channels in input image (3=RGB/HSV, 4 = CMYK)
}

In [40]:


def run_inference(configuration):
    print("Configuration for run:", configuration)
    # set random seed and device CPU / GPU
    device = set_device_and_seed(GPU=True, seed=configuration["seed"],
                                 gpu_name=configuration["gpu_name"])
    # info to load the data
    tiles_summary_data = configuration["tile_path"]
    
    data_dir = configuration["data_dir"]
    num_workers =16
    pin_memory=False
    resize=configuration["resize"]
    img_size = configuration["img_size"]

    # make dummy labels for dataloading
    test_df = pd.read_csv(os.path.abspath(tiles_summary_data)).dropna(subset=["ndpi_file", "mostly_tissue"])
    
    
    
    test_df["stage"] = test_df.index
#     test_df = load_dataframe(os.path.abspath(tiles_summary_data), label_map)
    print("TEST DF", test_df[["ndpi_file", "mostly_tissue","stage"]])
    # remove for experiment running
    
    # get normalizer
    if configuration["stain_normalization"]["apply_reinhard"]:
        source_dir = configuration["stain_normalization"]["source_dir"]
        normalizer = Normalizer(source_path = source_dir)
    else:
        normalizer = None


    # cv_results
    test_results = {"test_f1": [], "test_accuracy": [], "test_cm": []}

    # create datasets
    if configuration["multiple_inference"]:
        test_dataset = MultipleInferenceDataset(test_df, data_dir,  keep_top=configuration["bag_size"], verbose=False, hsv=configuration["hsv"], cmyk=configuration["cmyk"], transform=None, train=False,  resize=resize, img_size=img_size,normalizer=normalizer)
        print(f"Created MI datasets: Testing on {len(test_dataset)} WSIs")
    else:
        test_dataset = CustomDataset(test_df, data_dir,  keep_top=configuration["bag_size"], verbose=False, hsv=configuration["hsv"], cmyk=configuration["cmyk"], transform=None, train=False,  resize=resize, img_size=img_size,normalizer=normalizer)
        print(f"Created Normal datasets: Testing on {len(test_dataset)} WSIs")

    # create dataloaders
    test_loader = data_utils.DataLoader(test_dataset, batch_size = configuration["batch_size"], shuffle=True, num_workers=num_workers, pin_memory=pin_memory)

    print(f"Created Dataloader with num workers {num_workers} and pinned memory {pin_memory}")
    if configuration["use_ssc"]:
        print("Getting SSC model")
        model = get_model_ssc(configuration, device)
    else:
        print("Getting non-SSC model")
        model = get_model(configuration, device)
    print(model)

    # track model weights
    criterion = nn.BCELoss().to(device)
    print(f"Loading best model")
    # reload best model
    if configuration["multi_gpu"]:
        model.module.load_state_dict(torch.load(configuration["best_model_weights_path"]))
    else:
        model.load_state_dict(torch.load(configuration["best_model_weights_path"], map_location=configuration["gpu_name"]))

        
    predictions = []
    paths = []
    color = []
    y = []
    visualise=True
    model.eval()
    test_bar = tqdm(enumerate(test_loader))
    with torch.no_grad():
        for i, (data, label) in test_bar:
            test_bar.set_description('WSI: [{}/{}]'.format(i, len(test_loader)))
            # get predictions
            data = data.to(device)
            if configuration["multiple_inference"]:
                print(label, len(label))
                bag_label, path = label[0], label[1]
                label = bag_label.to(device)
                path = torch.tensor(list(path))
                path = path.to(device)
                paths.append(path)
            else:
                
                path =  label.to(device)
                # we have labelled paths as the dataframe index
                paths.append(path)
                label = label.to(device)

            if configuration["use_ssc"]:
                outputs, reconst, normed_input, A = model(data, return_attn=True)
            else:
                outputs, A = model(data, return_attn=True)

            pred = outputs > 0.5
            pred = pred.long()
            predictions.append(pred)
            reshaped_data = data.reshape(data.shape[0]*data.shape[1], data.shape[2], data.shape[3], data.shape[4])
            # add attention code here
            
            
    if configuration["multiple_inference"]:
        # get predictions per path - gather in dataframe and compute one-wins-all aggregation
        predictions = torch.cat(predictions, dim=0).cpu().numpy()
        paths = torch.cat(paths, dim=0).cpu().numpy()
        predictions_df = pd.DataFrame({"prediction": predictions, "wsi_path": paths}, columns=["prediction", "wsi_path"])
        wsi_preds = []
        wsi_paths = []
        for wsi in np.unique(predictions_df["wsi_path"].tolist()):
            wsi_df = predictions_df[predictions_df["wsi_path"]==wsi]
            wsi_prediction = wsi_df["prediction"].max()
            wsi_preds.append(wsi_prediction)
            wsi_paths.append(wsi)
            
        # convert integer wsi paths back to strings
        wsi_paths =  test_dataset.tiles_summary.iloc[paths]["ndpi_file"].to_list()
        print(wsi_paths)
        return list(zip(wsi_paths, wsi_preds))
    
    else:
        print(paths, predictions , sep="\n")
        predictions = torch.cat(predictions, dim=0).cpu().numpy()
        paths = torch.cat(paths, dim=0).cpu().numpy()
        
        # convert paths back to wsi_paths
        wsi_paths = test_dataset.tiles_summary.iloc[paths]["ndpi_file"].to_list()
        return list(zip(wsi_paths, predictions))
        
            
    
        
    

    
   
    
run_inference(configuration)


Configuration for run: {'num_epochs': 1, 'batch_size': 6, 'gpu_name': 'cuda:0', 'seed': 42, 'multi_gpu': False, 'tile_path': '/data/MSc_students_accounts/sneha/tiles_summary_5x_2.5mm_50%_adaptive_threshold.csv', 'data_dir': '/data/goldin/images_raw/nafld_liver_biopsies_sirius_red/', 'multiple_inference': False, 'bag_size': 10, 'resize': True, 'img_size': 224, 'stain_normalization': {'apply_reinhard': True, 'source_dir': '/data/MSc_students_accounts/sneha/sneha/sirius_red-master/reinhard_source.jpg'}, 'transform_color_jitter': None, 'hsv': False, 'cmyk': False, 'encoder': 'se_resnet18', 'image_net_pretrained': True, 'load_path': False, 'freeze': False, 'dropout': False, 'num_layers': None, 'aggregation': 'gated_attention', 'best_model_weights_path': '/data/MSc_students_accounts/sneha/temp_imgs/best_model_fold_0_5vvb5ni5.h5', 'use_ssc': False, 'ssc_reconst_loss': 'l2', 'ssc_num_routings': 3, 'ssc_lr': 0.1, 'apply_ssc_scheduler': (0.1, 30), 'ssc_num_stains': 2, 'ssc_num_groups': 6, 'ssc_g

WSI: [0/1]: : 1it [00:02,  2.94s/it]

[tensor([1, 0], device='cuda:0')]
[tensor([0, 0], device='cuda:0')]





[('/data/goldin/images_raw/nafld_liver_biopsies_sirius_red/MS13-17926.SR.ndpi',
  0),
 ('/data/goldin/images_raw/nafld_liver_biopsies_sirius_red/MS15-19086.SR.ndpi',
  0)]