In [1]:
#imports random
import os
import glob
#from livelossplot import PlotLosses
import math
import shutil

import numpy as np
#from nilearn import plotting
import matplotlib.pyplot as plt
import nibabel as nib

#import torch
import torch
import torch.nn as nn
#from torchsummary import summary
import torch.optim as optim
from ipywidgets import widgets, interact

In [2]:
DEBUG = True

# Device: prefer the first GPU but fall back to CPU. torch.device("cuda:0") can be
# constructed on any host, but moving tensors to it fails where CUDA is unavailable.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Determinism: define the seed once and reuse it, so the value passed to
# set_determinism and RSEED can never drift apart.
RSEED = 12345678
# NOTE(review): set_determinism is not imported in the imports cell; presumably it is
# monai.utils.set_determinism — add that import or this cell fails on a fresh kernel.
set_determinism(seed=RSEED)

# Paths: glob pattern for the dataset root (files are discovered recursively below it).
DATA_DIR = 'DS1_scans/*'

# Dataset and Dataloader
NUM_WORKERS = 0
BATCH_SIZE = 2
SHUFFLE = True

# Training
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
MOMENTUM = 0.95
VAL_INTERVAL = 1  # do validation for every epoch

In [3]:
def percorrePasta(pasta):
    """Recursively collect ``.gz`` files reachable from a glob pattern.

    Args:
        pasta: a glob pattern (e.g. ``'DS1_scans/*'``) or a directory path.
            Every matched directory is descended into; every matched file whose
            last extension is ``.gz`` (e.g. ``scan.nii.gz``) is collected.

    Returns:
        list[str]: the collected file paths in a deterministic depth-first
        order (entries are sorted within each directory).
    """
    files = []
    # glob yields entries in arbitrary filesystem order. Sorting makes the file list,
    # and everything derived from it (positional scan/mask pairing, the seeded
    # train/val split), reproducible across machines and runs.
    for f in sorted(glob.glob(pasta)):
        if os.path.isdir(f):
            # Escape the real directory name so characters such as '[' are not
            # reinterpreted as glob syntax when recursing; os.path.join also avoids
            # the double slash produced by appending "/*" to a path ending in "/".
            files += percorrePasta(os.path.join(glob.escape(f), "*"))
        elif os.path.splitext(f)[1] == ".gz":
            files.append(f)
    return files

def separaMasks(arr):
    """Split file paths into mask and scan lists.

    Only paths containing the substring "generate" are kept. Among those, a
    path containing "mask" anywhere (directory components included) is treated
    as a mask; every other kept path is treated as a scan. Input order is
    preserved within each list.

    Returns:
        dict: ``{"masks": [...], "scans": [...]}``.
    """
    generated = [path for path in arr if "generate" in path]
    return {
        "masks": [path for path in generated if "mask" in path],
        "scans": [path for path in generated if "mask" not in path],
    }

# Collect every .gz file under the configured dataset root and split into scans/masks.
ficheirosTodos = percorrePasta(DATA_DIR)
arr = separaMasks(ficheirosTodos)

# Downstream code pairs scans and masks by list position, so fail loudly here
# (rather than deep inside the split/loader) if the lists cannot line up.
assert arr["scans"], f"no scans found under {DATA_DIR!r}"
assert len(arr["scans"]) == len(arr["masks"]), (
    f"{len(arr['scans'])} scans vs {len(arr['masks'])} masks"
)

In [4]:
def paths(arr, max_cases=6, test_size=0.2, random_state=42):
    """Split paired scan/mask paths into train and validation sets.

    Args:
        arr: dict with "scans" and "masks" lists. Element i of each list is
            assumed to belong to the same case (pairing is positional).
        max_cases: cap applied to the returned case dicts: ``train_set`` keeps
            its first ``max_cases`` entries and ``val_set`` its last
            ``max_cases``. ``None`` keeps everything. Defaults to 6, the value
            that was previously hard-coded (handy for quick debug runs).
        test_size: fraction of cases sent to validation.
        random_state: seed for the split.

    Returns:
        tuple: (train_scans_files, train_masks_files, train_scans, val_scans,
        train_masks, val_masks, train_set, val_set). The ``*_scans``/``*_masks``
        lists hold the full split; ``train_set``/``val_set`` are lists of
        ``{"scan": ..., "mask": ...}`` dicts, capped by ``max_cases``.
    """
    train_scans_files = arr["scans"]
    train_masks_files = arr["masks"]

    # NOTE(review): train_test_split is not imported in this notebook; presumably
    # sklearn.model_selection.train_test_split — add that import.
    train_scans, val_scans, train_masks, val_masks = train_test_split(
        train_scans_files, train_masks_files,
        test_size=test_size, random_state=random_state,
    )

    train_set = [{"scan": scan, "mask": mask} for scan, mask in zip(train_scans, train_masks)]
    val_set = [{"scan": scan, "mask": mask} for scan, mask in zip(val_scans, val_masks)]

    if max_cases is not None:
        # Slice the tail by explicit start index: val_set[-0:] would return the
        # whole list, so max_cases=0 must not use negative slicing.
        train_set = train_set[:max_cases]
        val_set = val_set[max(len(val_set) - max_cases, 0):]

    print("nº total de imagens: ", len(train_scans_files))
    print("nº casos de treino: ", len(train_scans))
    print("nº casos de validação: ", len(val_scans))
    # The sets actually returned may be smaller than the split (see max_cases).
    print("nº casos usados (treino/validação): ", len(train_set), "/", len(val_set))

    return train_scans_files, train_masks_files, train_scans, val_scans, train_masks, val_masks, train_set, val_set
# Split the discovered files into train/validation lists and case dicts.
(train_scans_files, train_masks_files,
 train_scans, val_scans,
 train_masks, val_masks,
 train_set, val_set) = paths(arr)

nº total de imagens:  852
nº casos de treino:  681
nº casos de validação:  171


In [5]:
def transforms():
    """Build the dictionary-transform pipelines used for the "scan"/"mask" case dicts.

    Every pipeline loads both keys, adds a channel dimension, rotates 90° in
    spatial axes (0, 1) and converts to tensors. Validation uses the same
    geometry as training so the model is evaluated in the orientation it was
    trained on. (The previous val pipeline skipped the rotation.)

    Returns:
        tuple: (train_transform, val_transform, proc_transform)
          * train_transform: load -> add channel -> rotate -> tensor.
          * val_transform: identical steps to train_transform.
          * proc_transform: as train_transform plus NormalizeIntensityd on the
            scan (nonzero=True, channel_wise=False).
    """
    keys = ["scan", "mask"]

    def _pipeline(*extra):
        # Shared load/channel/orientation steps, then pipeline-specific steps, then tensors.
        # NOTE(review): AddChanneld is deprecated in newer MONAI releases in favour of
        # EnsureChannelFirstd — confirm against the installed MONAI version.
        return Compose([
            LoadImaged(keys=keys),
            AddChanneld(keys=keys),
            Rotate90d(keys=keys, k=1, spatial_axes=(0, 1)),
            *extra,
            ToTensord(keys=keys),
        ])

    train_transform = _pipeline()
    proc_transform = _pipeline(NormalizeIntensityd(keys="scan", nonzero=True, channel_wise=False))
    val_transform = _pipeline()

    return train_transform, val_transform, proc_transform


train_transform, val_transform, proc_transform = transforms()

In [6]:
def dataloader(train_files, train_transform, val_transform, val_files):
    """Wrap the train/validation case lists in datasets and loaders.

    Args:
        train_files: list of case dicts for training.
        train_transform: transform applied to each training case.
        val_transform: transform applied to each validation case.
        val_files: list of case dicts for validation.

    Returns:
        tuple: (train_ds, val_ds, train_loader, val_loader). Batch size and
        worker count come from BATCH_SIZE / NUM_WORKERS; only the training
        loader is given the SHUFFLE setting (validation keeps the loader default).
    """
    train_ds = Dataset(data=train_files, transform=train_transform)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=NUM_WORKERS)

    val_ds = Dataset(data=val_files, transform=val_transform)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

    return train_ds, val_ds, train_loader, val_loader

# Build datasets/loaders for the (possibly capped) train and validation case lists.
train_ds, val_ds, train_loader, val_loader = dataloader(
    train_files=train_set,
    train_transform=train_transform,
    val_transform=val_transform,
    val_files=val_set,
)

In [9]:
# Fetch ONE batch up front. Calling first(train_loader) inside the slider callback
# re-loaded data on every slider move and, with SHUFFLE=True, could presumably
# return a different case each time — so the slices shown were not from one volume.
check_data = first(train_loader)
check_image, check_label = check_data["scan"][0][0], check_data["mask"][0][0]
print(check_data["scan"].shape)
print(f"image shape: {check_image.shape}, label shape: {check_label.shape}")


@interact
def check_transforms_train(i=(0, check_image.shape[2] - 1)):
    """Show slice ``i`` (third axis) of the first scan and its mask from the cached batch.

    The slider range is derived from the loaded volume, so it cannot index
    past the last slice.
    """
    fig, (ax_img, ax_lbl) = plt.subplots(1, 2, figsize=(12, 6))
    ax_img.set_title("image")
    ax_img.imshow(check_image[:, :, i], cmap="gray")
    ax_lbl.set_title("label")
    ax_lbl.imshow(check_label[:, :, i])
    plt.show()

interactive(children=(IntSlider(value=83, description='i', max=166), Output()), _dom_classes=('widget-interact…