In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from torch.autograd import Variable

import csv
import os
import sys
import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
%matplotlib inline
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload

%autoreload 2
plt.ion()   # interactive mode

In [2]:
# Prepare the Dataset: write an index CSV pairing every image in IMAGE_DIR with
# the file of the same name in MASK_DIR (assumes masks share the image's
# filename — not checked here).
IMAGE_DIR = "Dataset/2d_images/"
MASK_DIR = "Dataset/2d_masks/"
INDEX_CSV = "Dataset/Lung_CT_Dataset.csv"

# The csv module needs a binary file on Python 2 but a text file opened with
# newline='' on Python 3; 'wb' alone makes writerow raise TypeError on Python 3.
if sys.version_info[0] >= 3:
    csv_file = open(INDEX_CSV, "w", newline="")
else:
    csv_file = open(INDEX_CSV, "wb")

with csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(["filename", "mask"])
    # sorted() makes the row order independent of the filesystem's listing
    # order, so the downstream seeded split is reproducible across machines.
    for image_name in sorted(os.listdir(IMAGE_DIR)):
        writer.writerow([os.path.join(IMAGE_DIR, image_name),
                         os.path.join(MASK_DIR, image_name)])


In [3]:
# Shuffle the index with a fixed seed and split it into train/validation CSVs.
SPLIT_SEED = 0          # fixed so the split is identical on every run
TRAIN_FRACTION = 0.7    # share of rows that go to the training set

lung_ct_df = pd.read_csv("Dataset/Lung_CT_Dataset.csv")
shuffled_df = lung_ct_df.iloc[np.random.RandomState(SPLIT_SEED).permutation(len(lung_ct_df))]
n_train = int(len(shuffled_df) * TRAIN_FRACTION)
# Named *_df because a later cell defines a `train()` function, which would
# otherwise silently overwrite the training DataFrame.
train_df, val_df = shuffled_df.iloc[:n_train], shuffled_df.iloc[n_train:]
assert len(train_df) + len(val_df) == len(lung_ct_df), "split lost rows"
train_df.to_csv("Dataset/Lung_CT_Train.csv", index=False)
val_df.to_csv("Dataset/Lung_CT_Validation.csv", index=False)

In [4]:
from dataset import LungCTDataset

# Loader settings shared by the train and validation loaders.
BATCH_SIZE = 100
NUM_WORKERS = 4

lung_ct_train_dataset = LungCTDataset(csv_file='Dataset/Lung_CT_Train.csv', root_dir='./')
lung_ct_val_dataset = LungCTDataset(csv_file='Dataset/Lung_CT_Validation.csv', root_dir='./')
# Batches are dicts with "image" and "mask" entries (as consumed by train/test).
# Only the training loader is shuffled; validation runs in a fixed order so its
# per-batch outputs are comparable from epoch to epoch.
train_dataloader = DataLoader(lung_ct_train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=NUM_WORKERS)
val_dataloader = DataLoader(lung_ct_val_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=NUM_WORKERS)

In [5]:
# Explicit import instead of `from model import *`, so it is clear where UNet
# comes from and nothing already in the namespace is silently shadowed.
# NOTE(review): if the weights_init cell is re-enabled, import that helper
# explicitly too (presumably it lives in model.py — confirm).
from model import UNet

# NOTE(review): the two args look like (input channels, output channels) for a
# 1-channel CT slice and a 1-channel mask — confirm against model.py.
model_instance = UNet(1, 1)

In [6]:
# model_instance.apply(weights_init)

In [7]:
# Binary cross-entropy between the predicted and true masks. nn.BCELoss expects
# the model output to already be probabilities in [0, 1] — presumably UNet ends
# in a sigmoid (confirm in model.py).
criterion = nn.BCELoss()
# Adam over all model parameters with a very small step size (1e-6).
optimizer = optim.Adam(params=model_instance.parameters(), lr=1e-6)

In [8]:
def train(model, epoch, loader=None, opt=None, loss_fn=None, log_interval=1):
    """Run one training epoch over ``loader`` and log per-batch losses.

    Args:
        model: network to train; put in train mode and called on float images.
        epoch: epoch number, used only in the progress message.
        loader: iterable of dict batches with "image" and "mask" tensors that
            also exposes ``.dataset`` and ``len()``; defaults to the global
            ``train_dataloader``.
        opt: optimizer over ``model``'s parameters; defaults to the global
            ``optimizer``.
        loss_fn: loss called as ``loss_fn(output, target)``; defaults to the
            global ``criterion``.
        log_interval: print a progress line every ``log_interval`` batches
            (must be >= 1; the default logs every batch).

    Returns:
        The mean of the per-batch losses for this epoch, or ``nan`` if the
        loader yielded no batches.
    """
    loader = train_dataloader if loader is None else loader
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    model.train()
    loss_sum = 0.0
    n_batches = 0
    for batch_idx, batch in enumerate(loader):
        images = Variable(batch["image"]).float()
        masks = Variable(batch["mask"]).float()
        opt.zero_grad()
        # Call the module instead of .forward() so registered hooks still run.
        output = model(images)
        loss = loss_fn(output.float(), masks)
        loss.backward()
        opt.step()
        # .item() exists from PyTorch 0.4 on; .data[0] is the 0.3 spelling and
        # raises IndexError on the 0-dim loss tensors of later releases.
        loss_value = loss.item() if hasattr(loss, "item") else loss.data[0]
        loss_sum += loss_value
        n_batches += 1
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(images), len(loader.dataset),
                100. * batch_idx / len(loader), loss_value))
    return loss_sum / n_batches if n_batches else float("nan")
    

In [18]:
def test(model):
    model.eval()
    test_loss = 0
    for data in val_dataloader:
        data, target = Variable(data["image"], volatile=True), Variable(data["mask"])
        output = model(data.float())
        # print(output.data[0])
        test_loss += criterion(output.float(), target.float()).data[0] # sum up batch loss
    test_loss /= len(val_dataloader.dataset)
    print("Average Loss: ", test_loss)


In [20]:
# Train for N_EPOCHS epochs, printing the validation loss after each one.
# Only the BCE loss is reported; no accuracy or overlap metric is computed here.
N_EPOCHS = 14
for epoch in range(1, N_EPOCHS + 1):
    train(model_instance, epoch)
    test(model_instance)


(0 ,.,.) = 
  0.6797  0.6462  0.8109  ...   0.7261  0.6412  0.7083
  0.5893  0.5546  0.6293  ...   0.5366  0.5382  0.5871
  0.5374  0.5041  0.6811  ...   0.5835  0.5755  0.5888
           ...             ⋱             ...          
  0.5542  0.5194  0.5386  ...   0.5647  0.5818  0.6232
  0.5729  0.5098  0.5380  ...   0.5746  0.5912  0.5867
  0.5910  0.5803  0.5620  ...   0.5699  0.5325  0.5868
[torch.FloatTensor of size 1x32x32]

Average Loss:  0.0113601431251

(0 ,.,.) = 
  0.6745  0.6244  0.7889  ...   0.7119  0.6290  0.7008
  0.5869  0.5493  0.5853  ...   0.5312  0.5339  0.5869
  0.5438  0.4936  0.6518  ...   0.5745  0.5750  0.5888
           ...             ⋱             ...          
  0.5524  0.5172  0.5431  ...   0.5672  0.5795  0.6070
  0.5686  0.5080  0.5499  ...   0.5878  0.5916  0.5869
  0.5891  0.5651  0.5580  ...   0.5717  0.5332  0.5869
[torch.FloatTensor of size 1x32x32]

Average Loss:  0.0113555587828

(0 ,.,.) = 
  0.6806  0.6526  0.8214  ...   0.7282  0.6423  0.7103


In [21]:
# Save the whole model object (not just its state_dict) so the artifact format
# is unchanged for existing loaders. NOTE: unpickling it later requires the
# UNet class to be importable from the same module path.
MODEL_PATH = './saved_models/mini_unet.pth'
model_dir = os.path.dirname(MODEL_PATH)
# Create the target directory first; torch.save does not, and a fresh checkout
# may not have it.
if model_dir and not os.path.isdir(model_dir):
    os.makedirs(model_dir)
torch.save(model_instance, MODEL_PATH)