In [1]:
import os, sys
sys.path.append( '..' )

In [2]:
import argparse
import pickle as pkl
import pandas as pd
import numpy as np
from collections import Counter
#from wsi_data import subset_wsi_dataset

import torch
from torch import nn, optim
import torch.multiprocessing as mp
from torch.utils.data import Subset

import copy as cp
import yaml
from sklearn.model_selection import KFold, StratifiedKFold

import matplotlib.pyplot as plt

In [3]:
from model import HE2Omics, fit, predict
from experiment import Experiment
from copy import deepcopy

In [4]:
from wsi_data import patient_stratified_kfold
#from image_materials.model import ImageToSubtype

#from image_materials.experiment import ImageSetup
#from image_materials.train import *

In [5]:
# Directory that holds the YAML settings file ("config_main_50genes.yaml" is joined to it when the config is loaded).
rootdir = '/gstore/data/tmai/Yasin/TMONC/HE2OMICS/TRL/2024_1_8'
# Folder under which the train/validation slide-id pickles are written by the cross-validation loop.
savedir = './FOLDS/'
# Forwarded to fit() as its `logdir` argument.
logdir = './exp'

In [6]:
def subset_wsi_dataset(dataset, subset_ix):
    """Return a deep copy of ``dataset`` restricted to the entries in ``subset_ix``.

    Every attribute of ``dataset`` is carried over by the copy; only
    ``samples``, ``patients``, ``trials`` and ``subtypes`` are replaced by the
    elements selected (in the given order) by ``subset_ix``. The input dataset
    is not modified.

    Returns the new dataset with ``samples`` as a list, ``patients`` as a
    NumPy array, and ``trials`` / ``subtypes`` as freshly indexed pandas Series.
    """
    def _pick(values):
        # values[ix] for every requested index, in the order requested
        return [values[ix] for ix in subset_ix]

    subset = deepcopy(dataset)
    subset.samples = _pick(dataset.samples)
    subset.patients = np.array(_pick(dataset.patients))
    subset.trials = pd.Series(_pick(dataset.trials))
    subset.subtypes = pd.Series(_pick(dataset.subtypes))
    return subset

In [7]:
def convert_to_subtype_prediction(dataset,class_names,scaleFeatures=True,n_tiles=False):
    """
    Change (image,expression) tuples to (image,subtype), and make AlltilesWSI dataset

    Parameters
    ----------
    dataset : object exposing ``samples`` (a sequence of ``(image, target)``
        pairs) and ``subtypes`` (one label per sample, in the same order).
    class_names : list of str
        Ordered subtype names; each sample's new label is the position of
        ``str(subtype)`` in this list.
    scaleFeatures, n_tiles :
        Forwarded unchanged to ``AlltilesWSI``.

    Returns
    -------
    The ``AlltilesWSI`` instance built from the relabelled samples.

    Raises
    ------
    ValueError
        If a subtype of ``dataset`` is not present in ``class_names``.
    """
    # Labels are read from the dataset that was passed in, not from a
    # notebook-level variable, so the function works for any dataset.
    subtype_indices = [class_names.index(str(name)) for name in dataset.subtypes]

    # Only the samples are needed; deep-copy them so the new dataset does not
    # share sample objects with ``dataset``.
    samples = deepcopy(dataset.samples)
    new_samples = [(image, int(subtype_indices[idx])) for idx, (image, _) in enumerate(samples)]

    # NOTE(review): AlltilesWSI is not imported in the visible cells — confirm
    # it is in scope before calling this function.
    return AlltilesWSI(new_samples, scaleFeatures=scaleFeatures, n_tiles=n_tiles)

### Load the configuration file (uses the 50 most discriminant genes in HE2Omics)

In [8]:
# Parse the YAML settings file; the `with` block closes the file handle as
# soon as parsing is done instead of leaving it open for the kernel's lifetime.
# NOTE(review): yaml.FullLoader is kept as-is; if the file only holds plain
# scalars/mappings, yaml.safe_load would be the stricter choice — confirm.
with open(os.path.join(rootdir, "config_main_50genes.yaml"), "r") as config_file:
    confi = yaml.load(config_file, Loader=yaml.FullLoader)
print('Using configuration defined in {}'.format(confi['config']))

Using configuration defined in /gstore/data/tmai/Yasin/TMONC/HE2OMICS/TRL/2024_1_8/configs/IMvigor_3trial_50genes.ini


### Build dataset and permute patients stratifying on trial and subtype

In [9]:
# The 'config' entry of the YAML settings is what Experiment is built from
config = confi['config']
exp = Experiment(config)

# Architecture and training parameters as read by the experiment
model_params = exp._read_architecture()
training_params = exp._read_training_params()

# Slide-level dataset
dataset = exp._build_dataset()

# One entry per patient; uniq_indices holds the first position of each
# patient in dataset.patients
pat_ids_unique, uniq_indices = np.unique(dataset.patients, return_index=True)
TOT_pat = len(pat_ids_unique)
print(f'Number of patients: {TOT_pat}')

# Stratification label per slide: "<trial>_NMF<subtype>"
class_list = np.array([
    trial + '_NMF' + str(subtype)
    for trial, subtype in zip(dataset.trials, dataset.subtypes)
])
# Label of each unique patient, taken at that patient's first occurrence
class_list_patients = class_list[uniq_indices]

# Number of random patient permutations and number of CV folds inside group 1
Nsplit = 10
kfold_group1 = 5

# One seeded random permutation of the patient indices per split
# (row i of ind_testing is the patient order used for split i)
np.random.seed(42)
ind_testing = np.stack(
    [np.random.permutation(TOT_pat) for _ in range(Nsplit)]
).astype(int)

logTransform: False, scaleVariables: False, ensureNonnegativity: True
Number of variables in input: 51
Number of genes found in transcriptome data: 51
Filtered out 161 slides less than 1000 tiles
number of images with paired omic data: 2285
Number of patients: 1810


In [None]:
# SEPARATE GROUP1 AND GROUP2 PATIENTS
# Group 1 gets the patients for transcriptomic learning
# Group 2 is the patient set where the image-based model will be compared with transfer learning (the transcriptomic model from group 1)

# Resume point: splits below start_split are not visited, and inside
# start_split the folds below start_fold are skipped.
start_split = 2
start_fold = 1


def _slide_ids(wsi_dataset, indices):
    """Return the sample file names (basename without '.npy') at ``indices``."""
    return [
        os.path.basename(wsi_dataset.samples[idx][0]).replace('.npy', '')
        for idx in indices
    ]


# loop from start_split to Nsplit
for rs in range(start_split, Nsplit):
    ind = ind_testing[rs, :]
    shuffled_patient_classes = class_list_patients[ind]
    shuffled_patients = pat_ids_unique[ind]

    # Two stratified halves of the patients; shuffling is done earlier with permutations
    patient_skf = StratifiedKFold(n_splits=2, shuffle=False)

    # Transcriptomic learning gets n_splits-1 folds (group 1), classifier training gets 1 fold (group 2)
    group1_index, group2_index = next(iter(patient_skf.split(ind, shuffled_patient_classes)))

    print(f'Number of group 1 patients: {len(group1_index)}, Number of group 2 patients: {len(group2_index)}')

    pat_G1 = shuffled_patients[group1_index]
    pat_G2 = shuffled_patients[group2_index]
    pat_G1_classes = shuffled_patient_classes[group1_index]
    pat_G2_classes = shuffled_patient_classes[group2_index]

    # Make group 1 and group 2 datasets: a slide belongs to a group when its
    # patient does. np.isin tests all slides at once instead of scanning the
    # patient array once per slide.
    slide_ix_group1 = np.flatnonzero(np.isin(dataset.patients, pat_G1)).tolist()
    slide_ix_group2 = np.flatnonzero(np.isin(dataset.patients, pat_G2)).tolist()
    group1_dataset = subset_wsi_dataset(dataset, slide_ix_group1)
    group2_dataset = subset_wsi_dataset(dataset, slide_ix_group2)

    print(f'Number of group 1 images: {len(group1_dataset)}, Number of group 2 images: {len(group2_dataset)}')

    # Stratified KFold to train the gene prediction model within group 1 slides
    # Cross-validation will be done over slides (not patients) because inference will be on group 2 slides
    skf = StratifiedKFold(n_splits=kfold_group1, shuffle=True, random_state=1)

    # Stratify by TRIAL and SUBTYPE
    g1_class_list = np.array([
        trial + '_NMF' + str(subtype)
        for trial, subtype in zip(group1_dataset.trials, group1_dataset.subtypes)
    ])

    # Creating a folder for each split (models and slide-id pickles)
    model_savedir = os.path.join(exp.savedir, 'split' + str(rs))
    os.makedirs(model_savedir, exist_ok=True)
    # Slide ids are stored per split as well: fold numbers repeat in every
    # split, so a path keyed by fold alone would already exist from the first
    # split and the ids of every later split would never be written.
    split_ids_dir = os.path.join(savedir, 'split' + str(rs))
    os.makedirs(split_ids_dir, exist_ok=True)

    # Run cross-validation
    for fold, (train_index, valid_index) in enumerate(skf.split(np.zeros((len(group1_dataset), 1)), g1_class_list)):

        if rs == start_split and fold < start_fold:
            continue
        print(f'Running split {str(rs)}, fold {str(fold)}')

        # If pickle file for samples does not exist, define and save it
        path_pickle_file = os.path.join(split_ids_dir, 'train_and_valid_set_ids_fold' + str(fold) + '.pkl')
        if not os.path.exists(path_pickle_file):
            # save train and validation slide identifiers
            he_tuple = (_slide_ids(group1_dataset, train_index), _slide_ids(group1_dataset, valid_index))
            with open(path_pickle_file, 'wb') as handle:
                pkl.dump(he_tuple, handle)

        # subset train and validation sets
        train_set = Subset(group1_dataset, train_index)
        valid_set = Subset(group1_dataset, valid_index)
        valid_ctype_integers = np.zeros((len(valid_index))) # all zeroes since all samples are from the same cancer type
        model_params['input_dim'] = group1_dataset.dim # 2048
        model_params['output_dim'] = len(dataset.variables)
        print('Features starting with: {}'.format(dataset.variables[0][:4]))

        # Initialize bias of the last layer with the average target value on the train set.
        # The fallback converts each target with .numpy() when averaging the raw targets raises ValueError.
        # NOTE(review): .cuda() requires a CUDA device — confirm this notebook is only run on GPU nodes.
        try:
            model_params['bias_init'] = torch.nn.Parameter(torch.Tensor(np.mean([sample[1] for sample in train_set], axis=0)).cuda())
        except ValueError:
            model_params['bias_init'] = torch.nn.Parameter(torch.Tensor(np.mean([sample[1].numpy() for sample in train_set], axis=0)).cuda())

        model = HE2Omics(**model_params)
        optimizer = exp._setup_optimization(model)
        # Multiply the learning rate by 0.9 every 5 scheduler steps
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

        print("about to train model")
        test_preds, test_labels, valid_preds, valid_labels = fit(model,
                            train_set,
                            valid_set,
                            valid_ctype_integers,
                            params=training_params,
                            optimizer=optimizer,
                            scheduler=scheduler,
                            logdir=logdir,
                            path=model_savedir,
                            fold=fold)

Number of group 1 patients: 905, Number of group 2 patients: 905
Number of group 1 images: 1159, Number of group 2 images: 1126
Running split 2, fold 1
Features starting with: GENE
about to train model
Mean spearman r: 0.026


  0%|          | 0/29 [00:00<?, ?it/s]

Num epochs since best: 1


100%|██████████| 29/29 [01:11<00:00,  2.46s/it]


Epoch 1/300 - 119.76s
loss: 0.9922, val loss: 1.0465
Mean spearman r: 0.238


  0%|          | 0/29 [00:00<?, ?it/s]

Num epochs since best: 1


100%|██████████| 29/29 [01:11<00:00,  2.46s/it]


Epoch 2/300 - 95.11s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.8949, val loss: 0.9067
Mean spearman r: 0.344
Num epochs since best: 1


100%|██████████| 29/29 [01:07<00:00,  2.34s/it]


Epoch 3/300 - 89.52s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.8421, val loss: 0.8383
Mean spearman r: 0.416
Num epochs since best: 1


100%|██████████| 29/29 [01:08<00:00,  2.35s/it]


Epoch 4/300 - 90.38s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.7899, val loss: 0.8047
Mean spearman r: 0.461
Num epochs since best: 1


100%|██████████| 29/29 [01:08<00:00,  2.36s/it]


Epoch 5/300 - 90.70s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.7573, val loss: 0.7698
Mean spearman r: 0.490
Num epochs since best: 1


100%|██████████| 29/29 [01:08<00:00,  2.35s/it]


Epoch 6/300 - 91.03s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.7352, val loss: 0.7674
Mean spearman r: 0.500
Num epochs since best: 1


100%|██████████| 29/29 [01:18<00:00,  2.70s/it]


Epoch 7/300 - 103.19s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.7027, val loss: 0.7357
Mean spearman r: 0.523
Num epochs since best: 1


100%|██████████| 29/29 [01:20<00:00,  2.79s/it]


Epoch 8/300 - 105.87s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.6905, val loss: 0.7199
Mean spearman r: 0.540
Num epochs since best: 1


100%|██████████| 29/29 [01:18<00:00,  2.71s/it]


Epoch 9/300 - 104.39s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.6844, val loss: 0.7567
Mean spearman r: 0.536
Num epochs since best: 2


100%|██████████| 29/29 [01:19<00:00,  2.74s/it]


Epoch 10/300 - 106.74s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.6870, val loss: 0.7542
Mean spearman r: 0.542
Num epochs since best: 1


100%|██████████| 29/29 [01:19<00:00,  2.74s/it]


Epoch 11/300 - 106.42s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.6615, val loss: 0.7095
Mean spearman r: 0.550
Num epochs since best: 1


100%|██████████| 29/29 [01:20<00:00,  2.78s/it]


Epoch 12/300 - 106.09s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.6512, val loss: 0.7051
Mean spearman r: 0.559
Num epochs since best: 1


100%|██████████| 29/29 [01:20<00:00,  2.78s/it]


Epoch 13/300 - 106.79s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.6455, val loss: 0.7097
Mean spearman r: 0.555
Num epochs since best: 2


100%|██████████| 29/29 [01:20<00:00,  2.76s/it]


Epoch 14/300 - 107.21s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.6370, val loss: 0.7036
Mean spearman r: 0.562
Num epochs since best: 1


100%|██████████| 29/29 [02:10<00:00,  4.48s/it]


Epoch 15/300 - 169.34s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.6250, val loss: 0.6923
Mean spearman r: 0.565
Num epochs since best: 1


100%|██████████| 29/29 [01:21<00:00,  2.80s/it]


Epoch 16/300 - 106.95s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.6175, val loss: 0.7084
Mean spearman r: 0.561
Num epochs since best: 2


100%|██████████| 29/29 [02:15<00:00,  4.67s/it]


Epoch 17/300 - 161.76s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.6091, val loss: 0.7076
Mean spearman r: 0.559
Num epochs since best: 3


100%|██████████| 29/29 [01:37<00:00,  3.35s/it]


Epoch 18/300 - 150.83s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.6048, val loss: 0.6961
Mean spearman r: 0.563
Num epochs since best: 4


100%|██████████| 29/29 [01:19<00:00,  2.75s/it]


Epoch 19/300 - 106.81s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.6015, val loss: 0.7071
Mean spearman r: 0.562
Num epochs since best: 5


100%|██████████| 29/29 [01:17<00:00,  2.67s/it]


Epoch 20/300 - 103.68s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5972, val loss: 0.6908
Mean spearman r: 0.567
Num epochs since best: 1


100%|██████████| 29/29 [01:22<00:00,  2.85s/it]


Epoch 21/300 - 108.67s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5851, val loss: 0.6823
Mean spearman r: 0.571
Num epochs since best: 1


100%|██████████| 29/29 [01:15<00:00,  2.61s/it]


Epoch 22/300 - 101.73s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5759, val loss: 0.6804
Mean spearman r: 0.573
Num epochs since best: 1


100%|██████████| 29/29 [01:18<00:00,  2.72s/it]


Epoch 23/300 - 105.58s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5663, val loss: 0.6882
Mean spearman r: 0.573
Num epochs since best: 1


100%|██████████| 29/29 [01:19<00:00,  2.75s/it]


Epoch 24/300 - 106.17s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5767, val loss: 0.6847
Mean spearman r: 0.575
Num epochs since best: 1


100%|██████████| 29/29 [01:20<00:00,  2.76s/it]


Epoch 25/300 - 106.59s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5582, val loss: 0.6750
Mean spearman r: 0.574
Num epochs since best: 2


100%|██████████| 29/29 [01:19<00:00,  2.73s/it]


Epoch 26/300 - 104.06s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5602, val loss: 0.6813
Mean spearman r: 0.571
Num epochs since best: 3


100%|██████████| 29/29 [01:19<00:00,  2.75s/it]


Epoch 27/300 - 105.27s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5517, val loss: 0.7070
Mean spearman r: 0.568
Num epochs since best: 4


100%|██████████| 29/29 [01:20<00:00,  2.78s/it]


Epoch 28/300 - 108.11s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5443, val loss: 0.6813
Mean spearman r: 0.571
Num epochs since best: 5


100%|██████████| 29/29 [01:22<00:00,  2.83s/it]


Epoch 29/300 - 109.26s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5408, val loss: 0.6802
Mean spearman r: 0.576
Num epochs since best: 1


100%|██████████| 29/29 [01:21<00:00,  2.80s/it]


Epoch 30/300 - 107.70s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5359, val loss: 0.7030
Mean spearman r: 0.578
Num epochs since best: 1


100%|██████████| 29/29 [01:20<00:00,  2.77s/it]


Epoch 31/300 - 106.40s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5379, val loss: 0.6720
Mean spearman r: 0.575
Num epochs since best: 2


100%|██████████| 29/29 [01:20<00:00,  2.77s/it]


Epoch 32/300 - 107.67s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5321, val loss: 0.6882
Mean spearman r: 0.575
Num epochs since best: 3


100%|██████████| 29/29 [01:19<00:00,  2.75s/it]


Epoch 33/300 - 107.62s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5318, val loss: 0.6892
Mean spearman r: 0.581
Num epochs since best: 1


100%|██████████| 29/29 [01:21<00:00,  2.82s/it]


Epoch 34/300 - 109.53s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5192, val loss: 0.6783
Mean spearman r: 0.577
Num epochs since best: 2


100%|██████████| 29/29 [01:17<00:00,  2.68s/it]


Epoch 35/300 - 103.30s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5134, val loss: 0.6700
Mean spearman r: 0.578
Num epochs since best: 3


100%|██████████| 29/29 [01:21<00:00,  2.82s/it]


Epoch 36/300 - 109.06s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5083, val loss: 0.6791
Mean spearman r: 0.580
Num epochs since best: 4


100%|██████████| 29/29 [01:20<00:00,  2.79s/it]


Epoch 37/300 - 106.30s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5055, val loss: 0.6765
Mean spearman r: 0.577
Num epochs since best: 5


100%|██████████| 29/29 [01:18<00:00,  2.71s/it]


Epoch 38/300 - 105.66s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4983, val loss: 0.6798
Mean spearman r: 0.576
Num epochs since best: 6


100%|██████████| 29/29 [01:21<00:00,  2.80s/it]


Epoch 39/300 - 110.19s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5036, val loss: 0.6713
Mean spearman r: 0.579
Num epochs since best: 7


100%|██████████| 29/29 [01:22<00:00,  2.84s/it]


Epoch 40/300 - 109.60s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4998, val loss: 0.6691
Mean spearman r: 0.579
Num epochs since best: 8


100%|██████████| 29/29 [01:20<00:00,  2.79s/it]


Epoch 41/300 - 107.67s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4865, val loss: 0.6820
Mean spearman r: 0.578
Num epochs since best: 9


100%|██████████| 29/29 [01:21<00:00,  2.83s/it]


Epoch 42/300 - 109.79s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4886, val loss: 0.6835
Mean spearman r: 0.575
Num epochs since best: 10


100%|██████████| 29/29 [01:20<00:00,  2.79s/it]


Epoch 43/300 - 105.64s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4865, val loss: 0.6804
Mean spearman r: 0.571
Num epochs since best: 11


100%|██████████| 29/29 [01:20<00:00,  2.78s/it]


Epoch 44/300 - 107.69s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.5022, val loss: 0.6655
Mean spearman r: 0.583
Num epochs since best: 1


100%|██████████| 29/29 [01:18<00:00,  2.69s/it]


Epoch 45/300 - 107.59s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4822, val loss: 0.6767
Mean spearman r: 0.578
Num epochs since best: 2


100%|██████████| 29/29 [01:20<00:00,  2.76s/it]


Epoch 46/300 - 108.36s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4816, val loss: 0.6769
Mean spearman r: 0.574
Num epochs since best: 3


100%|██████████| 29/29 [01:22<00:00,  2.84s/it]


Epoch 47/300 - 108.96s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4743, val loss: 0.6664
Mean spearman r: 0.582
Num epochs since best: 4


100%|██████████| 29/29 [01:21<00:00,  2.82s/it]


Epoch 48/300 - 108.75s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4710, val loss: 0.6648
Mean spearman r: 0.581
Num epochs since best: 5


100%|██████████| 29/29 [01:21<00:00,  2.80s/it]


Epoch 49/300 - 108.53s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4679, val loss: 0.6714
Mean spearman r: 0.576
Num epochs since best: 6


100%|██████████| 29/29 [01:22<00:00,  2.85s/it]


Epoch 50/300 - 109.93s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4707, val loss: 0.6697
Mean spearman r: 0.580
Num epochs since best: 7


100%|██████████| 29/29 [01:15<00:00,  2.60s/it]


Epoch 51/300 - 104.25s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4630, val loss: 0.6800
Mean spearman r: 0.579
Num epochs since best: 8


100%|██████████| 29/29 [01:13<00:00,  2.53s/it]


Epoch 52/300 - 97.88s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4619, val loss: 0.6791
Mean spearman r: 0.577
Num epochs since best: 9


100%|██████████| 29/29 [01:20<00:00,  2.76s/it]


Epoch 53/300 - 104.01s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4586, val loss: 0.6680
Mean spearman r: 0.580
Num epochs since best: 10


100%|██████████| 29/29 [01:20<00:00,  2.77s/it]


Epoch 54/300 - 107.52s


  0%|          | 0/29 [00:00<?, ?it/s]

loss: 0.4554, val loss: 0.6664
Mean spearman r: 0.579
Num epochs since best: 11


 86%|████████▌ | 25/29 [01:15<00:11,  2.99s/it]