In [1]:
import argparse
import logging
import sys
from pathlib import Path
import numpy as np
import glob
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from unet.evaluate import evaluate
from segmentation_experiments.data_loading import SegmentationDataSet
from segmentation_experiments import data_loading
from utils.dice_score import dice_loss
from unet import UNet
from unet import simpleUNet

In [2]:
# Development aid: re-import the project modules so that edits made to their
# source files take effect without restarting the kernel.
# Note: names bound earlier with `from ... import ...` (e.g. `SegmentationDataSet`)
# still point at the pre-reload objects, so reloaded code must be reached through
# the module object (`data_loading.<name>`, `simpleUNet.<name>`).
from importlib import reload
reload(data_loading)
# Last expression of the cell: the reloaded module object is echoed as the cell output.
reload(simpleUNet)

<module 'unet.simpleUNet' from '/home/rahulv/codes/robustdg/unet/simpleUNet.py'>

In [4]:
# Dataset archives. Judging by the file names, training uses the small domain-2
# split while validation uses the domain-1 test split — TODO confirm that
# evaluating on another domain's test set during training is intended.
TRAIN_NPZ = 'data/syntheticSegmentation/small_train_dom2.npz'
VAL_NPZ = 'data/syntheticSegmentation/test_dom1.npz'

# Go through the `data_loading` module object (rather than a directly imported
# class) so that a preceding `reload(data_loading)` is honoured here.
train_set = data_loading.SegmentationDataSet(TRAIN_NPZ)
val_set = data_loading.SegmentationDataSet(VAL_NPZ)

batch_size = 32
dataloaders = {
    # Training samples are reshuffled every epoch.
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    # Evaluation gains nothing from shuffling; a fixed order keeps the validation
    # batches identical from epoch to epoch and from run to run.
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0),
}

In [5]:
# Sanity check: draw a single batch from the training loader and print the
# shapes of the image tensor and of the mask tensor.
first_batch = next(iter(dataloaders['train']))
inputs, masks = first_batch
print(inputs.shape, masks.shape)

torch.Size([32, 1, 256, 256]) torch.Size([32, 2, 256, 256])


In [6]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the U-Net with n_class=2 and place it on the selected device.
model = simpleUNet.UNet(n_class=2).to(device)

# Keras-like layer-by-layer summary (output shapes and parameter counts)
# for a single-channel 256x256 input, via torchsummary.
from torchsummary import summary
summary(model, input_size=(1, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             640
              ReLU-2         [-1, 64, 256, 256]               0
            Conv2d-3         [-1, 64, 256, 256]          36,928
              ReLU-4         [-1, 64, 256, 256]               0
         MaxPool2d-5         [-1, 64, 128, 128]               0
            Conv2d-6        [-1, 128, 128, 128]          73,856
              ReLU-7        [-1, 128, 128, 128]               0
            Conv2d-8        [-1, 128, 128, 128]         147,584
              ReLU-9        [-1, 128, 128, 128]               0
        MaxPool2d-10          [-1, 128, 64, 64]               0
           Conv2d-11          [-1, 256, 64, 64]         295,168
             ReLU-12          [-1, 256, 64, 64]               0
           Conv2d-13          [-1, 256, 64, 64]         590,080
             ReLU-14          [-1, 256,

In [None]:
# Train the segmentation model and persist the recorded loss values.
# Relies on names from earlier cells: `torch`/`optim`/`np`/`Path` (import cell),
# `reload`, `model` and `dataloaders`.
from torch.optim import lr_scheduler
from unet import training_loop
reload(training_loop)  # pick up on-disk edits to the training loop without a kernel restart

# Training hyper-parameters.
LEARNING_RATE = 1e-4
LR_STEP_SIZE = 30  # scheduler steps (presumably epochs — confirm in training_loop) between decays
LR_GAMMA = 0.1     # multiplicative factor applied to the learning rate at each decay
NUM_EPOCHS = 60

# np.save always writes the .npy format and appends ".npy" to any file name that
# lacks it, so a ".npz" name would silently end up as "<name>.npz.npy" on disk.
# Give the file its real extension instead.
METRICS_PATH = Path('checkpoints/baseSegmentationModelMetrics.npy')
# Create the output directory up front so that a missing folder cannot throw
# away the metrics of a finished training run.
METRICS_PATH.parent.mkdir(parents=True, exist_ok=True)

# Hand only the trainable (requires_grad) parameters to the optimizer.
optimizer_ft = optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LEARNING_RATE,
)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

model, loss_values = training_loop.train_model(
    model, optimizer_ft, exp_lr_scheduler, dataloaders, num_epochs=NUM_EPOCHS
)

# NOTE(review): if `loss_values` is not a plain numeric array (e.g. a dict of
# lists), np.save pickles it and np.load will need allow_pickle=True — confirm.
np.save(METRICS_PATH, loss_values)

Epoch 0/59
----------
LR 0.0001




train: bce: 0.559628, dice: 0.562207, loss: 0.560918
val: bce: 0.288309, dice: 0.500054, loss: 0.394181
saving best model
0m 44s
Epoch 1/59
----------
LR 0.0001
train: bce: 0.172381, dice: 0.445269, loss: 0.308825
val: bce: 0.074010, dice: 0.329069, loss: 0.201540
saving best model
0m 49s
Epoch 2/59
----------
LR 0.0001
train: bce: 0.052607, dice: 0.232489, loss: 0.142548
val: bce: 0.046320, dice: 0.095507, loss: 0.070914
saving best model
0m 50s
Epoch 3/59
----------
LR 0.0001
train: bce: 0.019582, dice: 0.063545, loss: 0.041564
val: bce: 0.004885, dice: 0.043716, loss: 0.024301
saving best model
0m 50s
Epoch 4/59
----------
LR 0.0001
train: bce: 0.008976, dice: 0.050182, loss: 0.029579
val: bce: 0.006121, dice: 0.044114, loss: 0.025118
0m 50s
Epoch 5/59
----------
LR 0.0001
train: bce: 0.004683, dice: 0.036537, loss: 0.020610
val: bce: 0.006377, dice: 0.037233, loss: 0.021805
saving best model
0m 50s
Epoch 6/59
----------
LR 0.0001
train: bce: 0.006636, dice: 0.037414, loss: 0.022025