In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn.functional as F
from torchvision import transforms, utils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from unet.unet_helpers import *
from unet.unet import UNet
from unet.utils import infere_model, visualize_sample

In [3]:
# Mask label values and the grey-level thresholds handed to MicroscopeImageDataset
# (exact thresholding semantics live in unet.unet_helpers — TODO confirm).
labels = [0,1,2]
label_thresh = [20,220]
mask_label_info = [labels, label_thresh]
split_samples = None

# All training variants read the same etching set; only the transform differs.
ETCHING_IMG_DIRS = ['./images/etching/set2/bottom/', './images/etching/set2/top/']
ETCHING_MASK_DIR = './images/etching/set2/masks/'


def make_etching_dataset(transf):
    """Build a MicroscopeImageDataset over the etching set2 images using `transf`.

    Centralizes the directory and label arguments so the augmented variants
    below cannot drift apart when paths or label settings are changed.

    Args:
        transf: transform pipeline applied to every sample.

    Returns:
        A new MicroscopeImageDataset instance.
    """
    return MicroscopeImageDataset(img_dir=list(ETCHING_IMG_DIRS),  # fresh list per dataset, as before
                                  mask_dir=ETCHING_MASK_DIR,
                                  mask_label_info=mask_label_info, read_top=True,
                                  transf=transf, split_samples=split_samples)


# Original images
transformations = transforms.Compose([Rescale((256,256)), ToTensor()])
etching_dataset = make_etching_dataset(transformations)

# Contrast, brightness adjusted
transformations = transforms.Compose([Rescale((256,256)),
                                      BrightnessContrastAdjustment((-0.3,1.3), 'brightness'),
                                      ToTensor()])
etching_dataset_bc = make_etching_dataset(transformations)

# Rotated by 90 degrees
transformations = transforms.Compose([Rescale((256,256)),
                                      Rotate(90),
                                      ToTensor()])
etching_dataset_rotate = make_etching_dataset(transformations)

# Concatenate the original and augmented copies into a single training set.
combined_dset = ConcatDatasets(etching_dataset, etching_dataset_bc, etching_dataset_rotate)

In [4]:
# Train parameters
epochs = 20
eta = 5e-3          # SGD learning rate
lambda_l2 = 1e-3    # regularization strength, forwarded to train_unet as lambda_
gamma = 0.9         # SGD momentum
depth = 4
n_filters = 3       # passed to UNet as wf
batch_size = 3
seed = 0

# Seed torch's RNG so weight initialization and DataLoader shuffling are
# repeatable under Restart & Run All.
torch.manual_seed(seed)

# U-Net definition
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# in_channels=2: training batches arrive with 2 channels, presumably one per
# image dir (bottom/top, read_top=True) — TODO confirm if the dataset changes.
# n_classes follows the mask label list so the two cannot drift apart.
model = UNet(in_channels=2, n_classes=len(labels), depth=depth, wf=n_filters,
             padding=True, up_mode='upconv', batch_norm=True).to(device)
optim = torch.optim.SGD(model.parameters(), lr=eta, momentum=gamma)
criterion = torch.nn.CrossEntropyLoss()
dataloader = DataLoader(combined_dset, batch_size=batch_size,
                        shuffle=True, num_workers=4)

# reg_type=None presumably disables the extra regularization term, leaving
# lambda_ inert for this run — TODO confirm in train_unet.
avg_epoch_loss, model = train_unet(model, device, optim, criterion, dataloader,
                                   epochs=epochs, lambda_=lambda_l2, reg_type=None,
                                   save=False)


Epoch 0
<class 'dict'>
torch.Size([3, 2, 256, 256])
<class 'dict'>
torch.Size([3, 2, 256, 256])
<class 'dict'>
torch.Size([3, 2, 256, 256])
<class 'dict'>
torch.Size([3, 2, 256, 256])
<class 'dict'>
torch.Size([3, 2, 256, 256])
<class 'dict'>
torch.Size([3, 2, 256, 256])


KeyboardInterrupt: 

In [None]:
# Visualize the per-epoch training loss returned by train_unet.
# The x-axis is derived from the recorded losses rather than `epochs`, so the
# plot stays consistent if `epochs` is edited after training or train_unet
# returns a different number of entries.
fig, ax = plt.subplots(figsize=(14,8))
ax.plot(np.arange(len(avg_epoch_loss)), np.array(avg_epoch_loss), linewidth=2)
ax.set_xlabel("Epoch")
ax.set_ylabel("CrossEntropyLoss")
ax.set_title("Average training loss per epoch");

In [None]:
# Infer a sample and visualize the prediction next to it (visualize_sample
# receives the full sample dict).
# NOTE: the index is into combined_dset, i.e. data the model was trained on.
sample_idx = 4
sample = combined_dset[sample_idx]
image = sample['image']
# The model uses BatchNorm and is left in training mode by the training cell;
# switch to inference mode so running statistics are used and not updated.
# Harmless if infere_model already does this, and the training cell rebuilds
# the model, so this does not leak into a re-run of training.
model.eval()
infered_mask = infere_model(model, image)

fig = plt.figure(figsize=(15,10))

axes = visualize_sample(fig, sample, infered_mask)

In [None]:
# Test images: same resize/tensor preprocessing as the un-augmented training
# set, but no mask directory is passed.
transformations = transforms.Compose([Rescale((256,256)), ToTensor()])
test_dataset = MicroscopeImageDataset(
    transf=transformations,
    img_dir=['./images/etching/test/bottom/',
             './images/etching/test/top/'],
    mask_dir=None,
    read_top=True,
    mask_label_info=mask_label_info,
    split_samples=split_samples,
)

In [None]:
# Predict on a test image. test_dataset was built with mask_dir=None, so the
# sample's mask is not unpacked here (it was unused, and the key may be absent).
test_idx = 1
sample = test_dataset[test_idx]
image = sample['image']
# BatchNorm inference mode: use running statistics instead of this single
# image's batch statistics. Harmless if infere_model already does this.
model.eval()
infered_mask = infere_model(model, image)

fig = plt.figure(figsize=(15,10))

axes = visualize_sample(fig, sample, infered_mask)