## Bagging wth resnet 50, varying number of samples

In [1]:
import os
import sys
import glob
import random
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data
import torchvision.models as models
from torchvision import datasets, models, transforms
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *

SEED = 101
np.random.seed(SEED)
from torchvision import datasets, models, transforms

%reload_ext autoreload
%autoreload 2
%matplotlib inline

# Add the src directory for functions
src_dir = os.path.join(os.path.dirname(os.path.dirname(os.getcwd())), 'src')
print(src_dir)
sys.path.append(src_dir)

# import my functions:
from WSI_pytorch_utils import*

fast_ai_dir = '/media/rene/Data/fastai/'
sys.path.append(fast_ai_dir)

torch.cuda.set_device(1)
print(torch.cuda.is_available())
print(torch.cuda.current_device())

/media/rene/Data/camelyon/src
True
1


In [2]:
# get data
data_loc = '/media/rene/Data/camelyon_out/tiles_224_100t_all'
all_imgs = glob.glob(data_loc+'/*')
img_names = [loc.rsplit('/', 1)[-1] for loc in all_imgs]

# load train valid split
ttv_split = np.load('/media/rene/Data/camelyon/other/ttv_split.p')
normal_valid = ttv_split['normal_vaild_idx']
tumor_valid = ttv_split['tumor_vaild_idx']
normal_train = list(range(1, 161))
normal_train = [num for num in normal_train if num not in normal_valid]
tumor_train = list(range(1, 111))
tumor_train = [num for num in tumor_train if num not in tumor_valid]

In [3]:
# get classes corresponding to each file
def get_label(name):
    if name.rsplit('_', 1)[-2][-5:]=='tumor':
        label = 'tumor'
    else: 
        label = 'normal'
    return label

classes = [get_label(name) for name in img_names]

# make csv
labels_df = pd.DataFrame(
    {'file_name': img_names,
     'label': classes
    })

# labels_df.to_csv('/media/rene/Data/camelyon_out/tiles_224_100t_all_other/tiles_224_100t_all_labels.csv')

In [4]:
# get the validation indices from the big list
def is_validation(name, normal_valid, tumor_valid):
    normal_valid = [int(x) for x in normal_valid]
    tumor_valid = [int(x) for x in tumor_valid]
    num = int(name.split('_', 1)[1].split('_', 1)[0])
    
    if classes[idx] == 'normal':
        return num in normal_valid
    
    elif classes[idx] == 'tumor':
        return num in tumor_valid
    else:
        raise ValueError("tile isn't tumor or non tumor")

valid_idxs = []
for idx, name in enumerate(img_names):
    if is_validation(name, normal_valid, tumor_valid):
        valid_idxs.append(idx)
        
# with open('/media/rene/Data/camelyon_out/tiles_224_100t_all_other/tiles_224_100t_all_val_idxs.p', 'wb') as fp:
#     pickle.dump(valid_idxs, fp)

def make_validation_mask(data_df, normal_valid, tumor_valid):
    img_names = data_df['file_name'].tolist()
    valid_idxs = [is_validation(name, normal_valid, tumor_valid) for name in img_names]
    return valid_idxs

In [5]:
# get the indices of slides to use
def make_sample(labels_df, normal_train, tumor_train, downsample=2):
    normal_train_subset = random.sample(normal_train, int(len(normal_train)/int(downsample)))
    normal_train_subset = [int(x) for x in normal_train_subset]
    
    tumor_train_subset = random.sample(tumor_train, int(len(tumor_train)/int(downsample)))
    tumor_train_subset = [int(x) for x in tumor_train_subset]
    
    drop_rows = []
    for index, row in labels_df.iterrows():
        label = row['label']
        num = int(row['file_name'].split('_', 1)[1].split('_', 1)[0])
        
        if label == 'normal' and num not in normal_train_subset:
            drop_rows.append(index)
        elif label == 'tumor' and num not in tumor_train_subset:
            drop_rows.append(index)
    
    subset_labels_df = labels_df.drop(labels_df.index[drop_rows])
    return subset_labels_df

In [None]:
models_arch = resnet50
models_name = 'resnet50'

lr=np.array([1e-3,5e-3,1e-2])

sz = 224
PATH = '/media/rene/Data/camelyon_out/tiles_224_100t_all_other'
train_folder = '/media/rene/Data/camelyon_out/tiles_224_100t_all'
# val_idxs = pickle.load( open( "/media/rene/Data/camelyon_out/tiles_224_100t_all_other/tiles_224_100t_all_val_idxs.p", "rb" ) )

for i in tqdm(range(10)):
    # create the new csv and save it
    subset_labels_df = make_sample(labels_df, normal_train, tumor_train, downsample=2)
    subset_labels_df.to_csv('/media/rene/Data/camelyon_out/tiles_224_100t_all_other/tiles_224_100t_all_subset_tmp2.csv',  
                            index = False)
    csv_fname = '/media/rene/Data/camelyon_out/tiles_224_100t_all_other/tiles_224_100t_all_subset_tmp2.csv'
    
    # get the validation indices based on this
    val_idxs = make_validation_mask(subset_labels_df, normal_valid, tumor_valid)

    tfms = tfms_from_model(models_arch, sz, aug_tfms=transforms_top_down, max_zoom=1)
    data = ImageClassifierData.from_csv(PATH, train_folder, csv_fname, tfms=tfms, 
                                        val_idxs=val_idxs, bs=64)
    learn = ConvLearner.pretrained(models_arch, data, precompute=False)
    
    lr =.001
    learn.fit(lr, 1, cycle_len=1, cycle_mult=1) # train last few layers
    lrs = np.array([lr/4,lr/2,lr])
    learn.unfreeze()
    learn.fit(lrs, 3, cycle_len=1, cycle_mult=2, best_save_name=models_name+'_half_wsi_'+str(i)) # train whole model

  0%|          | 0/10 [00:00<?, ?it/s]

HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

  0%|          | 0/587 [00:00<?, ?it/s]




epoch      trn_loss   val_loss   accuracy                    
    0      0.262329   0.223346   0.911861  



HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.208303   0.185835   0.926514  
    1      0.191932   0.160223   0.939209                    
    2      0.163955   0.153961   0.940674                    
    3      0.168058   0.14088    0.945801                    
    4      0.146847   0.137606   0.947998                    
    5      0.141507   0.129301   0.950684                    
    6      0.136285   0.129445   0.951904                    


 10%|█         | 1/10 [33:09<4:58:28, 1989.82s/it]




HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.267798   0.235747   0.901984  



HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.227203   0.188417   0.920722  
    1      0.182155   0.161996   0.93497                     
    2      0.184459   0.158538   0.93714                     
    3      0.158938   0.144103   0.944159                    
    4      0.145542   0.142144   0.946763                    
    5      0.132528   0.139232   0.946255                    
    6      0.14332    0.138325   0.947557                    


 20%|██        | 2/10 [1:05:43<4:22:55, 1971.91s/it]




HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.266833   0.222873   0.90734   



HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.221553   0.176262   0.928589  
    1      0.193404   0.145565   0.941307                    
    2      0.16546    0.14371    0.94276                     
    3      0.168801   0.133416   0.947847                    
    4      0.148051   0.12305    0.950391                    
    5      0.138405   0.1232     0.954388                    
                                                             

 30%|███       | 3/10 [1:37:22<3:47:13, 1947.66s/it]

    6      0.140337   0.127227   0.951118  



HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.262889   0.218582   0.911615  



HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.2083     0.185892   0.923951  
    1      0.19502    0.15946    0.936012                    
 61%|██████    | 359/589 [02:25<01:33,  2.47it/s, loss=0.176]

In [None]:
models_arch = resnet50
models_name = 'resnet50'

lr=np.array([1e-3,5e-3,1e-2])

sz = 224
PATH = '/media/rene/Data/camelyon_out/tiles_224_100t_all_other'
train_folder = '/media/rene/Data/camelyon_out/tiles_224_100t_all'
# val_idxs = pickle.load( open( "/media/rene/Data/camelyon_out/tiles_224_100t_all_other/tiles_224_100t_all_val_idxs.p", "rb" ) )

for i in tqdm(range(10)):
    # create the new csv and save it
    subset_labels_df = make_sample(labels_df, normal_train, tumor_train, downsample=4)
    subset_labels_df.to_csv('/media/rene/Data/camelyon_out/tiles_224_100t_all_other/tiles_224_100t_all_subset_tmp2.csv',  
                            index = False)
    csv_fname = '/media/rene/Data/camelyon_out/tiles_224_100t_all_other/tiles_224_100t_all_subset_tmp2.csv'
    
    # get the validation indices based on this
    val_idxs = make_validation_mask(subset_labels_df, normal_valid, tumor_valid)

    tfms = tfms_from_model(models_arch, sz, aug_tfms=transforms_top_down, max_zoom=1)
    data = ImageClassifierData.from_csv(PATH, train_folder, csv_fname, tfms=tfms, 
                                        val_idxs=val_idxs, bs=64)
    learn = ConvLearner.pretrained(models_arch, data, precompute=False)
    
    lr =.001
    learn.fit(lr, 1, cycle_len=1, cycle_mult=1) # train last few layers
    lrs = np.array([lr/4,lr/2,lr])
    learn.unfreeze()
    learn.fit(lrs, 3, cycle_len=1, cycle_mult=2, best_save_name=models_name+'_quarter_wsi_'+str(i)) # train whole model

In [6]:
models_arch = resnet50
models_name = 'resnet50'

lr=np.array([1e-3,5e-3,1e-2])

sz = 224
PATH = '/media/rene/Data/camelyon_out/tiles_224_100t_all_other'
train_folder = '/media/rene/Data/camelyon_out/tiles_224_100t_all'
# val_idxs = pickle.load( open( "/media/rene/Data/camelyon_out/tiles_224_100t_all_other/tiles_224_100t_all_val_idxs.p", "rb" ) )

for i in tqdm(range(3, 10)):
    # create the new csv and save it
    subset_labels_df = make_sample(labels_df, normal_train, tumor_train, downsample=8)
    subset_labels_df.to_csv('/media/rene/Data/camelyon_out/tiles_224_100t_all_other/tiles_224_100t_all_subset_tmp2.csv',  
                            index = False)
    csv_fname = '/media/rene/Data/camelyon_out/tiles_224_100t_all_other/tiles_224_100t_all_subset_tmp2.csv'
    
    # get the validation indices based on this
    val_idxs = make_validation_mask(subset_labels_df, normal_valid, tumor_valid)

    tfms = tfms_from_model(models_arch, sz, aug_tfms=transforms_top_down, max_zoom=1)
    data = ImageClassifierData.from_csv(PATH, train_folder, csv_fname, tfms=tfms, 
                                        val_idxs=val_idxs, bs=64)
    learn = ConvLearner.pretrained(models_arch, data, precompute=False)
    
    lr =.001
    learn.fit(lr, 1, cycle_len=1, cycle_mult=1) # train last few layers
    lrs = np.array([lr/4,lr/2,lr])
    learn.unfreeze()
    learn.fit(lrs, 3, cycle_len=1, cycle_mult=2, best_save_name=models_name+'_eighth_wsi_'+str(i)) # train whole model

  0%|          | 0/7 [00:00<?, ?it/s]

HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

  0%|          | 0/147 [00:00<?, ?it/s]




epoch      trn_loss   val_loss   accuracy                    
    0      0.325046   0.206593   0.927734  



HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.253826   0.149591   0.943359  
    1      0.19455    0.114476   0.949219                    
    2      0.166934   0.106272   0.962891                    
    3      0.137974   0.10423    0.962891                    
    4      0.131908   0.091828   0.972656                    
    5      0.126399   0.09561    0.974609                    
                                                             

 14%|█▍        | 1/7 [07:34<45:27, 454.57s/it]

    6      0.106828   0.091631   0.956752  



HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.290788   0.218866   0.919     



HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.251772   0.170955   0.93679   
    1      0.195534   0.140784   0.943182                    
    2      0.181554   0.137885   0.946023                    
    3      0.162888   0.137963   0.949574                    
    4      0.139038   0.124318   0.953125                    
    5      0.127484   0.117779   0.953835                    
    6      0.130003   0.118255   0.955256                    


 29%|██▊       | 2/7 [14:44<36:50, 442.09s/it]




HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.276907   0.253477   0.901339  



HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.233152   0.202621   0.92753   
    1      0.18295    0.162353   0.944792                    
    2      0.17623    0.165807   0.942708                    
    3      0.163096   0.148665   0.947917                    
    4      0.141063   0.13776    0.951042                    
    5      0.131605   0.134847   0.950446                    
                                                             

 43%|████▎     | 3/7 [21:34<28:45, 431.39s/it]

    6      0.12496    0.145981   0.94628   



HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.283951   0.198252   0.919677  



HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.220947   0.161576   0.941761  
    1      0.181819   0.140162   0.947443                    
    2      0.185009   0.131746   0.945279                    
    3      0.167852   0.121377   0.955222                    
    4      0.152452   0.118568   0.953835                    
    5      0.139008   0.112765   0.957386                    
                                                             

 57%|█████▋    | 4/7 [27:55<20:56, 418.95s/it]

    6      0.132593   0.114059   0.953125  



HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.30774    0.183016   0.925781  



HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.259914   0.158487   0.935547  
    1      0.206153   0.120345   0.951172                    
    2      0.178511   0.118698   0.955078                    
    3      0.152007   0.107302   0.957031                    
    4      0.139541   0.096573   0.957031                    
    5      0.126523   0.088085   0.964844                    
                                                             

 71%|███████▏  | 5/7 [35:14<14:05, 422.98s/it]

    6      0.124722   0.089871   0.964844  



HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.284164   0.199176   0.91769   



HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.199776   0.115555   0.970703  
    1      0.167617   0.095438   0.970703                    
    2      0.134185   0.08613    0.970703                    
    3      0.118087   0.072722   0.976562                    
    4      0.113105   0.072709   0.980469                    
    5      0.099509   0.066261   0.978516                     
    6      0.093223   0.064853   0.982422                     


 86%|████████▌ | 6/7 [42:25<07:04, 424.27s/it]




HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.35417    0.256219   0.891933  



HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                    
    0      0.287616   0.185909   0.928417  
    1      0.210554   0.138298   0.952278                    
    2      0.17637    0.130712   0.954433                    
    3      0.162616   0.110209   0.963054                    
    4      0.135962   0.100832   0.966287                    
    5      0.122037   0.094353   0.964671                    
                                                             

100%|██████████| 7/7 [52:36<00:00, 450.90s/it]

    6      0.124115   0.095085   0.964671  



In [None]:
##