In [25]:
%run custom_datasets.ipynb

Imported classes.


In [26]:
import os.path as osp

import fcn
import torch.nn as nn

from livelossplot import PlotLosses

### Function that returns dataloaders and device

In [27]:
def get_dataloaders(projects, batch_size=24, dim="2D"): 
    
    if dim.lower() == "2d": 
        data = PrepareData2D(projects)
        
        train_ds = Dataset2D(data.train)
        train_dl = DataLoader(train_ds, batch_size=batch_size)

        val_ds = Dataset2D(data.val)
        val_dl = DataLoader(val_ds, batch_size=batch_size)

        test_ds = Dataset2D(data.test)
        test_dl = DataLoader(test_ds, batch_size=batch_size)
    
    elif dim.lower() == "3d": 
        data = PrepareData3D(projects)
        
        train_ds = Dataset3D(data.train)
        train_dl = DataLoader(train_ds, batch_size=batch_size)

        val_ds = Dataset3D(data.val)
        val_dl = DataLoader(val_ds, batch_size=batch_size)

        test_ds = Dataset3D(data.test)
        test_dl = DataLoader(test_ds, batch_size=batch_size)
    else:
        raise(Exception("Specify a valid dimension. Choose from: [\"2D\", \"3D\"]."))
    
    return train_dl, val_dl, test_dl

def set_device():
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('----------------------------------')
    print('Using device for training:', DEVICE)
    print('----------------------------------')
    
    return DEVICE 

### Test models

In [28]:
class FCN2d(nn.Module):

    def __init__(self, n_class=2):
        super(FCN2d, self).__init__()
        
        self.model = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.Conv2d(16, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.ConvTranspose2d(64, 8, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(8, 1, 4, stride=2, padding=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        return self.model(x)
    
class FCN3d(nn.Module):

    def __init__(self, n_class=2):
        super(FCN3d, self).__init__()
        
        self.model = nn.Sequential(
            nn.Conv3d(1, 16, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
            
            nn.Conv3d(16, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(2, stride=2, ceil_mode=True),
            
            nn.ConvTranspose3d(64, 8, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose3d(8, 1, 4, stride=2, padding=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        out = self.model(x)
        return out
    
    
def init_model(dim): 
    if dim.lower() == "2d": 
        model = FCN2d().to(DEVICE)

    elif dim.lower() == "3d": 
        model = FCN3d().to(DEVICE)

    else:
        raise(Exception("Specify a valid dimension. Choose from: [\"2D\", \"3D\"]."))
    
    return model

### Train function

In [29]:
from tqdm import tqdm

def train(model, epochs, train_dl, val_dl, optimizer, criterion, DEVICE): 
    liveloss = PlotLosses()
    
    for epoch in range(epochs):
        print(f"Starting epoch: {epoch}")
        
        t_loss, t_dice = 0, 0
        v_loss, v_dice = 0, 0
        
#         for i, batch in enumerate(tqdm(train_dl)):
        for i, batch in enumerate(train_dl):
            model.train()

            inputs, labels = batch[2:]
            inputs = inputs.float().to(DEVICE)
            labels = labels.float().to(DEVICE)
            
            out = model(inputs)    
            loss = criterion(out, labels)    

            t_loss += loss.item()
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            t_dice += dice_loss(out, labels).item()
    
        t_loss /= train_dl.__len__()
        t_dice /= i + 1
                                    
        mem("Finished training for epoch " + str(epoch))
        
#         for i, batch in enumerate(tqdm(val_dl)):            
        for i, batch in enumerate(val_dl):            
            model.eval()

            inputs, labels = batch[2:]
            inputs = inputs.float().to(DEVICE)
            labels = labels.float().to(DEVICE)

            out = model(inputs)  
            
            loss = criterion(out, labels)
            
            v_loss += loss.item()        
            
            v_dice += dice_loss(out, labels).item()
        
        mem("Finished validation for epoch " + str(epoch))
        
        v_loss /= val_dl.__len__()
        v_dice /= i + 1

        liveloss.update({"loss": t_loss, "val_loss": v_loss, "dice score": t_dice, "val_dice score": v_dice})
        liveloss.send()
        

def dice_loss(pred, target):
    
    smooth = 0.

    pred = torch.round(pred)

    pflat = pred.flatten()
    tflat = target.flatten()
    intersection = (pflat * tflat).sum()

    A_sum = torch.sum(pflat * pflat)
    B_sum = torch.sum(tflat * tflat)
    
    return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth) )

In [30]:
def mem(step):
    div = 1024*1024*1024
    
    t = torch.cuda.get_device_properties(0).total_memory
    c = torch.cuda.memory_cached(0)
    a = torch.cuda.memory_allocated(0)

    print(step)
    print("Memory cached:", round(c/div, 3))
    print("Memory allocated:", round(c/div, 3))
    print()

mem("test")

test
Memory cached: 1.551
Memory allocated: 1.551



## Here the training starts

In [31]:
DEVICE = set_device()

projects = ["Aorta Volunteers", "Aorta BaV", "Aorta Resvcue", "Aorta CoA"]
dim = "3d"
epochs = 1000

----------------------------------
Using device for training: cuda
----------------------------------


In [32]:
train_dl, val_dl, test_dl = get_dataloaders(projects, batch_size=12, dim=dim)

model = init_model(dim)
model.load_state_dict(torch.load("/scratch/ptenkaate/Models/3d_valdice_166.pth.tar"))

optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.BCELoss()

In [33]:
batch = next(iter(train_dl))
print(batch[0])
# print(batch[0])
# print(batch[1].shape)
print(batch[1])
print(batch[2].shape)
print(batch[2])
# print(batch[3].shape)
# print(batch[3])

('RESV_109.npy', '16-09-28_126_done1.npy', 'RESV_017.npy', '16-05-20_217_done.npy', '16-05-11_208_done.npy', '16-09-16_243_done1.npy', '16-04-20_105_done.npy', '16-01-22_Jarik_kt-pca_done2.npy', '16-03-11_Mart_kt-pca_done.npy', '16-08-24_120_done1.npy', 'RESV_016.npy', '16-08-24_239_done1.npy')
('Aorta Resvcue', 'Aorta BaV', 'Aorta Resvcue', 'Aorta CoA', 'Aorta CoA', 'Aorta CoA', 'Aorta BaV', 'Aorta Volunteers', 'Aorta Volunteers', 'Aorta BaV', 'Aorta Resvcue', 'Aorta CoA')
torch.Size([12, 1, 24, 128, 128])
tensor([[[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],

          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,

In [37]:
mem("Memory before starting train function")

%matplotlib inline

train(model, epochs, train_dl, val_dl, optimizer, criterion, DEVICE)

KeyboardInterrupt: 

In [None]:
def tensor_to_array(tensor, dim):
    images = tensor.clone()
    
    if dim.lower() == "2d":
        images = images.detach().squeeze().permute(1,2,0).numpy()
        
    elif dim.lower() == "3d":
        lst = []
        images = images.detach().squeeze()
        
        for i in range(images.shape[0]):
            lst.append(images[i].permute(1,2,0).numpy())
        images = np.dstack(lst)

    return images


def show_model_output(batch, labels, dim):

    dice_losses = []
    
    for i in range(len(batch[0])):
        dice_losses.append(round(dice_loss(batch[3][i], labels[i].cpu()).item(), 2))
        
    
    if dim == "2d":
        headers = [f"{proj} : {subj}, Dice Loss {dl}" for proj, subj, dl in zip(batch[0], batch[1], dice_losses)]
    
    elif dim == "3d": 
        projs, subjs, dl_extended = [], [], []
        
        
        n_samples = batch[2].shape[0]
        n_slices = batch[2].shape[2]

        for sample_idx in range(n_samples): 
            for i in range(n_slices): 
                projs.append(batch[0][sample_idx])
                subjs.append(batch[1][sample_idx])
                dl_extended.append(dice_losses[sample_idx])
        
        headers = [f"{proj} : {subj}, Dice Loss {dl}" for proj, subj, dl in zip(projs, subjs, dl_extended)]
        
    pcmra = tensor_to_array(batch[2], dim)
    mask = tensor_to_array(batch[3], dim)
    prediction = tensor_to_array(labels.cpu().round(), dim)

    
    print(pcmra.shape)
    show = Show_images(headers, 
                       (pcmra, "pcmra"), 
                       (mask, "mask"), 
                       (pcmra + mask, "pcmra + mask"),
                       (prediction, "prediction"), 
                       (pcmra + prediction, "pcmra + prediction"))
    
    return show

In [None]:
%matplotlib qt

batch = next(iter(val_dl))
out = model(batch[2].float().to(DEVICE))
print(out.shape)

show = show_model_output(batch, out, dim)

In [None]:
mem("sadf")

In [None]:
torch.save(model.state_dict(), "/scratch/ptenkaate/Models/3d_valdice_166.pth.tar")

In [35]:
torch.cuda.empty_cache()