In [None]:
    %matplotlib inline
    import os

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms

    from dataset import NucleiDataset
    from utils import RandomCrop, show_images, iterate #TODO evaluate
    from losses import focal_loss, dice_loss, get_iou
    from unet import UNet

In [None]:
# Root directory of the nuclei training data. Set the NUCLEI_TRAIN_DATA
# environment variable to run outside the original cluster; otherwise the
# path this notebook was written against is used.
TRAIN_DATA_PATH = os.environ.get(
    'NUCLEI_TRAIN_DATA',
    '/g/kreshuk/zinchenk/courses/EMBL_BTM_2019/advanced_machine_learning/nuclei_train_data',
)
BATCH_SIZE = 1    # batch_size handed to the training DataLoader
NUM_LAYERS = 3    # second positional argument of UNet
IN_FILTERS = 32   # first positional argument of UNet
GAMMA = 2         # gamma argument of focal_loss
SEED = 42         # seed for torch's global RNG

# Seed torch's global RNG so weight initialisation and torch-driven sampling
# are reproducible across Restart & Run All. NOTE(review): RandomCrop may draw
# from a different RNG (numpy / random) -- confirm in utils and seed it too.
torch.manual_seed(SEED)

In [None]:
# Size argument passed to RandomCrop; its exact meaning is defined in utils.
CROP_SIZE = 256

train_data = NucleiDataset(TRAIN_DATA_PATH, RandomCrop(CROP_SIZE))
# shuffle=True reshuffles the training samples every epoch so SGD does not
# see the images in the same fixed order each pass.
# NOTE(review): assumes NucleiDataset is a map-style dataset (shuffle is not
# supported for IterableDataset) -- confirm in dataset.py.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Build the network from the config constants (positional order: IN_FILTERS,
# NUM_LAYERS -- argument meanings are defined in unet.UNet).
model = UNet(IN_FILTERS, NUM_LAYERS)

# SGD with momentum over every parameter of the model.
optimizer = torch.optim.SGD(params=model.parameters(), momentum=0.9, lr=1e-3)

In [None]:
def train(model, optimizer, dataloader, gamma, num_epochs=10, show_every=10, threshold=0.5):
    """Train `model` with focal loss and report per-epoch mean loss, IoU and pixel accuracy.

    Args:
        model: network being optimised; called as ``model(images)``. Its outputs
            are compared against `threshold` to obtain binary predictions, so they
            are presumably per-pixel probabilities -- TODO confirm against UNet.
        optimizer: optimizer over `model`'s parameters; zeroed and stepped once
            per batch.
        dataloader: batch source, consumed through ``iterate(dataloader)``, which
            must yield ``(images, masks)`` pairs.
        gamma: forwarded unchanged to ``focal_loss`` as its third argument.
        num_epochs: number of passes over `dataloader`.
        show_every: call ``show_images`` on every `show_every`-th batch of an
            epoch; pass 0 or None to disable plotting.
        threshold: cut-off applied to `outputs` to produce boolean predictions.

    Returns:
        The same `model` object, updated in place.

    Raises:
        ValueError: if ``iterate(dataloader)`` yields no batches in an epoch,
            since the epoch averages would then be undefined.
    """
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)
        model.train()
        train_loss = 0.0
        train_accuracy = 0.0
        train_iou = 0.0
        num_batches = 0
        for num_batches, (images, masks) in enumerate(iterate(dataloader), start=1):
            optimizer.zero_grad()
            outputs = model(images)
            predictions = outputs > threshold
            # Binarise the ground truth once (any non-zero value -> True) so that
            # accuracy and IoU are measured against the same targets.
            targets = masks.type(torch.bool)
            if show_every and num_batches % show_every == 0:
                show_images(images, masks, predictions)
            loss = focal_loss(outputs, masks, gamma)
            accuracy = torch.mean((predictions == targets).float())
            iou = get_iou(predictions, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_accuracy += accuracy.item()
            train_iou += iou.item()
        # Average over the batches actually processed, not len(dataloader):
        # `iterate` is the real source of batches and may not match it.
        if num_batches == 0:
            raise ValueError('dataloader yielded no batches; cannot compute epoch metrics')
        epoch_loss = train_loss / num_batches
        epoch_accuracy = train_accuracy / num_batches
        epoch_iou = train_iou / num_batches
        print('Training loss is {:.6f}, iou is {:.6f}, accuracy is {:.6f}'.format(epoch_loss, epoch_iou, epoch_accuracy))
    return model

In [None]:
# Run the training loop; `train` returns the same (now trained) model object.
model = train(
    model,
    optimizer,
    dataloader=train_dataloader,
    gamma=GAMMA,
    num_epochs=10,
)