In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

# The star import stays ahead of the explicit unet imports so that the
# explicitly imported names keep precedence, as in the original ordering.
from unet.unet_helpers import *
from unet.unet import UNet
from unet.utils import visualize_sample

In [None]:
# Dataset configuration
labels = [0,1,2]
label_thresh = [10,240]
# Labels and thresholds are handed to the dataset as one [labels, thresholds] pair.
mask_label_info = [labels, label_thresh]
rescale_to = (128,128)
split_samples = None
# Boolean forwarded as MicroscopeImageDataset(read_top=...). It has its own
# name so the `top` directory path defined below cannot overwrite it.
read_top = True

# Image and mask directories of the selected image set
use_set = 'set23'
bottom = './images/etching/'+use_set+'/bottom/'
top = './images/etching/'+use_set+'/top/'
mask = './images/etching/'+use_set+'/masks/'


def _build_etching_dataset(transf):
    """Create a MicroscopeImageDataset over the selected image set.

    Only the transform pipeline differs between the dataset variants below;
    directories, label info, the read_top flag and split_samples are shared
    and read from the notebook-level configuration above.

    Parameters
    ----------
    transf : transform pipeline passed to the dataset as `transf`.

    Returns
    -------
    MicroscopeImageDataset
    """
    return MicroscopeImageDataset(img_dir=[bottom, top],
                                  mask_dir=mask,
                                  mask_label_info=mask_label_info,
                                  read_top=read_top,
                                  transf=transf,
                                  split_samples=split_samples)


# Original images (rescaled only)
transformations = transforms.Compose([Rescale(rescale_to), ToTensor()])
etching_dataset = _build_etching_dataset(transformations)

# Brightness/contrast augmentation, applied here in 'brightness' mode
transformations = transforms.Compose([Rescale(rescale_to),
                                      BrightnessContrastAdjustment((-0.3,1.3), 'brightness'),
                                      ToTensor()])
etching_dataset_bc = _build_etching_dataset(transformations)

# Rotated by 90 degrees
transformations = transforms.Compose([Rescale(rescale_to),
                                      Rotate(90),
                                      ToTensor()])
etching_dataset_rotate = _build_etching_dataset(transformations)

# Training set = original + brightness-adjusted + rotated variants
combined_dset = ConcatDatasets(etching_dataset, etching_dataset_bc, etching_dataset_rotate)
print("Total nbr of samples: {0}".format(len(combined_dset)))

In [None]:
# Enable the call below when the network should take only a single input channel
#combined_dset.top()

In [None]:
# Build the U-Net.
# The input shape is read off the image tensor of the first dataset sample.
first_image = combined_dset[0]['image']
input_shape = list(first_image.shape)
print(input_shape)

depth = 3
n_filters = 2  # forwarded to UNet as `wf`

model = UNet(
    input_shape,
    n_classes=len(labels),  # one output class per entry in `labels`
    depth=depth,
    wf=n_filters,
    padding=2,
    kernel_size=5,
    up_mode='upconv',
    batch_norm=True,
    name='test',
)
model.summary()

In [None]:
# Training hyperparameters
batch_size = 23
epochs = 25
eta = 3e-2/batch_size   # SGD learning rate, scaled down by the batch size
lambda_l2 = 1e-3        # forwarded to train_unet as lambda_
gamma = 0.7             # SGD momentum

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

optim = torch.optim.SGD(model.parameters(), lr=eta, momentum=gamma)
# NOTE(review): loss_weights is defined but never handed to the criterion, so
# the loss below is unweighted. If class weighting was intended, it would be
# CrossEntropyLoss(weight=loss_weights.to(device)) -- confirm before enabling,
# since it changes the training results.
loss_weights = torch.Tensor([1,4,2])
criterion = torch.nn.CrossEntropyLoss()
dataloader = DataLoader(combined_dset, batch_size=batch_size,
                        shuffle=True, num_workers=4)

# lambda_ now comes from lambda_l2 instead of a duplicated literal, so editing
# the hyperparameter above actually reaches train_unet.
# NOTE(review): presumably lambda_ is ignored while reg_type=None -- confirm
# in train_unet.
avg_epoch_loss, model = train_unet(model, device, optim, criterion, dataloader,
                                   epochs=epochs, lambda_=lambda_l2, reg_type=None,
                                   save=False)


In [None]:
# Plot the average training loss recorded for each epoch
epoch_losses = np.array(avg_epoch_loss)
fig, ax = plt.subplots(figsize=(14,8))
# The x axis is derived from the recorded losses rather than from the global
# `epochs`, so the plot still works if `epochs` was edited after training.
ax.plot(np.arange(len(epoch_losses)), epoch_losses, linewidth=2)
ax.set_xlabel("Epoch")
ax.set_ylabel("CrossEntropyLoss");

In [None]:
# Run inference on a single sample of the combined dataset and plot the result
i = 27  # index of the sample within combined_dset
fig = plt.figure(figsize=(14, 8))
inference = combined_dset.infer(i, model)
sample, infered_mask = inference
axes = visualize_sample(fig, sample, infered_mask)


Inference on the test dataset

In [None]:
# Test set: bottom/top image folders, no mask directory.
# NOTE(review): images are rescaled to 256x256 here while the training data
# uses `rescale_to` -- confirm the model accepts this input size.
test_image_dirs = ['./images/etching/test/bottom/', './images/etching/test/top/']
transformations = transforms.Compose([Rescale((256,256)), ToTensor()])
test_dataset = MicroscopeImageDataset(
    img_dir=test_image_dirs,
    mask_dir=None,
    mask_label_info=mask_label_info,
    read_top=True,
    transf=transformations,
    split_samples=split_samples,
)

In [None]:
# NOTE(review): presumably switches test_dataset to serve only the bottom
# image -- confirm in MicroscopeImageDataset, and check that the resulting
# channel count matches the input shape the model was built with.
test_dataset.bottom()

In [None]:
# Run inference on one test sample and display it
i = 2  # index of the sample within test_dataset
fig = plt.figure(figsize=(14, 8))
inference = test_dataset.infer(i, model)
sample, infered_mask = inference
axes = visualize_sample(fig, sample, infered_mask)
plt.show()