In [None]:
!pip install torchio

In [None]:
import torchio as tio

import os
import torch
torch.set_grad_enabled(False)
import pandas as pd
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader

In [None]:
# Set paths to the competition data.
data_dir   = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/'
train_dir  = data_dir+'train/'

# Number of patients kept for this quick experiment; set to None to use all of them.
N_PATIENTS = 42

# Labels are indexed by the integer patient id (first CSV column).
labels_df = pd.read_csv(data_dir+'train_labels.csv', index_col=0)

# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so that the
# subset taken below is the same on every run.
patients = sorted(os.listdir(train_dir))

# Remove cases the competition host said to exclude
# https://www.kaggle.com/c/rsna-miccai-brain-tumor-radiogenomic-classification/discussion/262046
for excluded_patient in ('00109', '00123', '00709'):
    patients.remove(excluded_patient)  # raises ValueError if the folder is missing

if N_PATIENTS is not None:
    patients = patients[:N_PATIENTS]

print(f'Total patients: {len(patients)}\n\n')

In [None]:
# Build one torchio SubjectsDataset per MRI sequence (scan type).
# To restrict the run to a single sequence use e.g. scan_types = ['T1wCE'].
scan_types = ['FLAIR', 'T1w', 'T1wCE', 'T2w']

# Preprocessing shared by every scan type (loop-invariant, so built once):
# canonical reorientation, resampling to spacing 1, z-normalisation masked to voxels
# above the mean intensity, then crop/pad to a fixed (128, 128, 64) shape so volumes can be batched.
transforms = [
    tio.ToCanonical(),
    tio.Resample(1),
    tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    tio.CropOrPad((128, 128, 64)),
]
transform = tio.Compose(transforms)

cases = {}          # patient id -> Subject of the most recently processed scan type
dataset_types = {}  # scan type -> tio.SubjectsDataset for that scan type only
dataset_cases = {}  # scan type -> list of tio.Subject for that scan type only

for scan_type in scan_types:
    # A new list per scan type. A single list shared across iterations would keep growing,
    # and every dataset (which stores a reference to the list) would end up containing
    # the scans of all sequences.
    subjects_list = []

    for patient in patients:
        # Folder names are zero-padded strings; the label index is the integer id.
        label = labels_df.at[int(patient), 'MGMT_value']

        subject = tio.Subject(
            image=tio.ScalarImage(f'{train_dir}{patient}/{scan_type}'),
            label=label,
        )
        cases[patient] = subject
        subjects_list.append(subject)

    subjects_dataset = tio.SubjectsDataset(subjects_list, transform=transform)

    dataset_types[scan_type] = subjects_dataset
    dataset_cases[scan_type] = subjects_list

In [None]:
# Patch-based training loop, one model per scan type.
# https://torchio.readthedocs.io/data/patch_training.html
# NOTE(review): `model` is a torch.nn.Identity placeholder and there is no loss or optimizer,
# so nothing is actually trained yet; `targets` and `logits` are computed but unused.

# Sampling settings shared by every scan type.
# patch_size equals the CropOrPad size used in preprocessing, so each sampled patch is the
# whole volume and the `samples_per_volume` patches per subject are identical — TODO confirm intent.
patch_size = (128, 128, 64)
queue_length = 300
samples_per_volume = 2
num_epochs = 2
sampler = tio.data.UniformSampler(patch_size)

for scan_type in scan_types:
    print(f'{scan_type}')

    # The Queue loads and samples subjects in its own worker processes; per the torchio docs
    # the DataLoader wrapping it must keep num_workers=0 (the default).
    patches_queue = tio.Queue(
        dataset_types[scan_type],
        queue_length,
        samples_per_volume,
        sampler,
        num_workers=4,
    )
    patches_loader = DataLoader(patches_queue, batch_size=16)

    model = torch.nn.Identity()
    #model = torch.hub.load('fepegar/highresnet', 'highres3dnet', pretrained=True)
    model.train()

    # Autograd is switched off globally in the imports cell (torch.set_grad_enabled(False));
    # re-enable it here so a real model would build a graph and could be trained.
    with torch.enable_grad():
        for epoch_index in range(num_epochs):
            for patches_batch in patches_loader:
                inputs = patches_batch['image'][tio.DATA]  # key 'image' is in subject
                targets = patches_batch['label']           # key 'label' is in subject
                logits = model(inputs)

    # Save the model: https://pytorch.org/tutorials/beginner/saving_loading_models.html
    print(model)
    torch.save(model.state_dict(), f'./{scan_type}_state_dict')
    torch.save(model, f'./{scan_type}_model')

In [None]:
# Inference with a pretrained model — currently DISABLED: the code below is wrapped in a
# triple-quoted string, so it is never executed.
# The goal here is to train an existing model with the new dataset, but as written this only
# runs patch-based inference: https://torchio.readthedocs.io/data/patch_inference.html
# NOTE(review): `logits.argmax(dim=tio.CHANNELS_DIMENSION, ...)` aggregated by a GridAggregator
# yields a per-voxel label map (segmentation output), not a per-patient MGMT prediction.
# NOTE(review): `output_tensor` is reassigned for every subject, so only the last subject's
# result survives each scan type.
# NOTE(review): `torch.cuda.amp.autocast()` is used even when `device` is 'cpu' — confirm this
# behaves as intended on CPU-only runs.
'''
# pull a pre trained model
repo = 'fepegar/highresnet'
model_name = 'highres3dnet'
model = torch.hub.load(repo, model_name, pretrained=True)
device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
print('Device:', device)
model.to(device).eval();

# loop thru scan types
for scan_type in scan_types:
    print(f'{scan_type}')
    # to cut down on processing power we use a GridSampler and GridAggregator
    for preprocessed in dataset_types[scan_type]:
        patch_overlap = 4
        patch_size = (128, 128, 64)
        grid_sampler = tio.inference.GridSampler(
            preprocessed,
            patch_size,
            patch_overlap,
        )
        patch_loader = torch.utils.data.DataLoader(grid_sampler)
        aggregator = tio.inference.GridAggregator(grid_sampler)
        preprocessed.clear_history()

        for patches_batch in tqdm(patch_loader, unit='batch'):
            input_tensor = patches_batch['image'][tio.DATA].to(device)
            locations = patches_batch[tio.LOCATION]
            with torch.cuda.amp.autocast():
                logits = model(input_tensor)
            labels = logits.argmax(dim=tio.CHANNELS_DIMENSION, keepdim=True)
            aggregator.add_batch(labels, locations)
        output_tensor = aggregator.get_output_tensor()
'''