In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, utils

from unet.unet import UNet
from unet.unet_helpers import MicroscopeImageDataset, ToTensor, Rescale, train_unet

In [None]:
# Mask label ids and the grey-level thresholds handed to the dataset with them.
# NOTE(review): how the thresholds map onto the labels is decided inside
# MicroscopeImageDataset — confirm there.
labels = [0, 1, 2]
label_thresh = [20, 220]
mask_label_info = [labels, label_thresh]

# Per-sample transform pipeline: Rescale to (128, 256), then convert with ToTensor.
transformations = transforms.Compose([
    Rescale((128, 256)),
    ToTensor(),
])

# Two image directories (trench bottom and top) share a single mask directory.
etching_dataset = MicroscopeImageDataset(
    img_dir=[
        './images/etching/set2/bottom/',
        './images/etching/set2/top/',
    ],
    mask_dir='./images/etching/set2/masks/',
    mask_label_info=mask_label_info,
    read_top=True,
    transf=transformations,
    split_samples=(3, 3),
)

In [None]:
# Show the samples: trench-bottom channel, top channel and mask, one row per sample.
N_SHOW = 10  # maximum number of samples drawn in the grid

# Quick shape check over the whole dataset.
for i in range(len(etching_dataset)):
    sample = etching_dataset[i]
    print(i, sample['image'].shape, sample['mask'].shape)

n_show = min(N_SHOW, len(etching_dataset))
if n_show == 0:
    print('Dataset is empty - nothing to plot.')
else:
    # Size the grid to the number of rows actually drawn (width/row height match
    # the original 10x3 figure of (50, 35)). squeeze=False keeps `axes` 2-D even
    # when n_show == 1.
    fig, axes = plt.subplots(n_show, 3, figsize=(50, 3.5 * n_show), squeeze=False)
    for i, row_axes in enumerate(axes):
        sample = etching_dataset[i]
        image, mask = sample['image'], sample['mask']
        # (data, vmax, title template): image channels are shown on [0, 1],
        # masks on [0, 2] to cover label ids 0..2.
        panels = [
            (image[0, :, :], 1, 'Sample #{} trench bottom'),
            (image[1, :, :], 1, 'Sample #{} top'),
            (mask, 2, 'Sample #{} mask'),
        ]
        for ax, (data, vmax, title) in zip(row_axes, panels):
            ax.imshow(data, cmap='gray', vmin=0, vmax=vmax)
            ax.set_title(title.format(i))
            ax.axis('off')
    # Lay out once, after every panel exists, instead of on every iteration.
    fig.tight_layout()
    plt.show()

In [None]:
# Train parameters
epochs = 25
eta = 5e-3          # Adam learning rate
lambda_l2 = 1e-3    # regularisation weight passed to train_unet with reg_type='l2'
batch_size = 4
num_workers = 4
seed = 0

# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# repeat across runs (some CUDA kernels may still be nondeterministic).
torch.manual_seed(seed)

# U-Net definition
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(in_channels=2, n_classes=3, depth=3, wf=3, padding=True,
             up_mode='upconv', batch_norm=True).to(device)
optim = torch.optim.Adam(model.parameters(), lr=eta)
criterion = torch.nn.CrossEntropyLoss()
dataloader = DataLoader(etching_dataset, batch_size=batch_size,
                        shuffle=True, num_workers=num_workers)

# Pass the configured lambda_l2 (previously a duplicated hard-coded 1e-3, so
# editing lambda_l2 had no effect).
avg_epoch_loss, model = train_unet(model, device, optim, criterion, dataloader,
                                   epochs=epochs, lambda_=lambda_l2, reg_type='l2',
                                   save=False)


In [None]:
# Visualize the per-epoch average loss returned by train_unet.
# numpy is imported in the top import cell.
epoch_losses = np.asarray(avg_epoch_loss)
fig, ax = plt.subplots(figsize=(14, 8))
# Derive the x-axis from the returned losses rather than `epochs`, so the plot
# cannot fail on a length mismatch.
ax.plot(np.arange(len(epoch_losses)), epoch_losses, linewidth=2)
ax.set_xlabel("Epoch")
ax.set_ylabel("CrossEntropyLoss")
ax.set_title("Average training loss per epoch")
plt.show()