In [None]:
import sys
sys.path.append("..")

In [None]:
import os

In [None]:
import torch
import torch.optim as optim

In [None]:
import train_semantic_segmentation as semseg
from train_semantic_segmentation import (
    call_many,
    save_model_on_better_miou,
    save_segmentations_for_image,
    save_interesting_images
)

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Root directory for this run's artifacts: the training cell writes its logs
# and the saved model checkpoint underneath it.
EXPERIMENT_NAME = "experiments/overfit-no-chandrop-small"

In [None]:
# Root of the extracted PASCAL VOC 2012 devkit, relative to this notebook.
# Change this one constant to point the notebook at a different data location.
VOC_ROOT = os.path.join('..', '..', 'data', 'VOCdevkit', 'VOC2012')

# Inputs for semseg.load_data:
#   source       - directory of input JPEG images
#   segmentation - directory of ground-truth segmentation masks
#   train / val  - text files listing the image ids of each split
DATA_PATHS = {
    'source': os.path.join(VOC_ROOT, 'JPEGImages'),
    'segmentation': os.path.join(VOC_ROOT, 'SegmentationClass'),
    'train': os.path.join(VOC_ROOT, 'ImageSets', 'Segmentation', 'train.txt'),
    'val': os.path.join(VOC_ROOT, 'ImageSets', 'Segmentation', 'val.txt'),
}

In [None]:
# Arguments, in order: image directory, mask directory, then the id lists for
# (presumably) the train / val / test loaders. The 'val' list is passed twice,
# so test_loader is built from the same images as val_loader — presumably
# because VOC2012 does not distribute test-set annotations.
train_loader, val_loader, test_loader = semseg.load_data(
    *(DATA_PATHS[key] for key in ('source', 'segmentation', 'train', 'val', 'val'))
)

In [None]:
# The validation dataset re-wrapped with "viewable" transforms; the training
# cell uses it to render segmentation snapshots. Despite the name, it is
# indexed like a dataset (`[i]["image"]`), not iterated like a DataLoader.
val_loader_with_viewable_transforms = (
    val_loader
    .dataset
    .with_viewable_transforms()
)

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Training hyperparameters.
learning_rate = 0.007  # base LR; the optimizer cell scales it per parameter group
epochs = 50

# Seed torch's global RNG so that model initialisation (and, presumably,
# DataLoader shuffling) is repeatable between runs of this experiment.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# RGB input (3 channels) and 21 output classes: the 20 PASCAL VOC object
# categories plus background.
model = semseg.DeepLabModel(
    input_channels=3,
    num_classes=21,
).to(device)

In [None]:
# Pixel-wise cross-entropy. 255 is the "void"/boundary value in the VOC
# SegmentationClass masks, so those pixels are excluded from the loss.
# NOTE(review): size_average is deprecated in torch.nn losses in favour of
# `reduction`; confirm how semseg.segmentation_cross_entropy_loss uses it.
criterion = semseg.segmentation_cross_entropy_loss(
    size_average=None,
    ignore_index=255,
    device=device,
)

# SGD over two parameter groups. The number paired with each group of modules
# is presumably a multiplier on `learning_rate`: 1x for the feature-detection
# layers, 10x for the spatial pyramid pooling + decoder head.
optimizer = optim.SGD(
    semseg.differential_learning_rates(
        model,
        [
            ((model.feature_detection_layers,), 1),
            ((model.spatial_pyramid_pooling, model.decoder), 10),
        ],
        learning_rate,
    ),
    momentum=0.9,
    weight_decay=5e-4,
    nesterov=False,
)

# Polynomial learning-rate schedule, given the epoch count and the number of
# batches per epoch.
# NOTE(review): the scheduler receives only the base learning_rate; confirm it
# rescales each param group's own lr instead of overwriting all groups (which
# would silently drop the 10x head multiplier).
scheduler = semseg.PolynomialLearningRateScheduler(
    optimizer,
    learning_rate,
    epochs,
    len(train_loader),
)

In [None]:
def log_statistics_to_notebook(stats):
    """Print a statistics object to the notebook's output.

    Looks intended as a ``statistics_callback`` for ``semseg.training_loop``,
    but the training cell in this notebook does not currently pass it.

    Args:
        stats: any object; it is rendered with ``str()``.
    """
    rendered = str(stats)
    print(rendered)

In [None]:
# Artifacts for this run are written under EXPERIMENT_NAME:
#   logs/statistics           - statistics log (semseg.log_statistics)
#   logs/segmentations/*.png  - predicted segmentations of a few fixed val images
#   logs/interesting/image.png - output of save_interesting_images
#   saved/model.pt            - model checkpoint (save_model_on_better_miou)
# NOTE(review): confirm the semseg helpers create these directories if missing.
LOG_DIR = os.path.join(EXPERIMENT_NAME, "logs")

# How many validation images get their segmentation re-rendered.
NUM_SNAPSHOT_IMAGES = 3


def segmentation_snapshot_callback(index):
    """Build the segmentation-snapshot callback for one validation image.

    The dataset item is fetched once, so the image and its label come from the
    same sample. Indexing the dataset twice would reload the sample, and with
    any randomness in the transforms could pair an image with a different
    label.

    Args:
        index: position of the sample in ``val_loader_with_viewable_transforms``.

    Returns:
        The value of ``save_segmentations_for_image`` (it is passed to
        ``call_many`` as part of the statistics callback).
    """
    sample = val_loader_with_viewable_transforms[index]
    return save_segmentations_for_image(model,
                                        sample["image"].to(device),
                                        sample["label"].to(device),
                                        os.path.join(LOG_DIR,
                                                     "segmentations",
                                                     "image_{}.png".format(index)))


statistics_callback = call_many(
    semseg.log_statistics(os.path.join(LOG_DIR, "statistics"), False),
    # Snapshot the first NUM_SNAPSHOT_IMAGES items of the validation dataset
    # (items 0..N-1 of the dataset, not one image per batch).
    *[segmentation_snapshot_callback(i) for i in range(NUM_SNAPSHOT_IMAGES)]
)

epoch_end_callback = call_many(
    # The second argument (0) is presumably the initial best mIoU to beat.
    save_model_on_better_miou(os.path.join(EXPERIMENT_NAME, "saved", "model.pt"), 0),
    save_interesting_images(os.path.join(LOG_DIR, "interesting", "image.png"), device)
)

semseg.training_loop(model,
                     train_loader,
                     val_loader,
                     criterion,
                     optimizer,
                     scheduler,
                     device,
                     epochs=epochs,
                     statistics_callback=statistics_callback,
                     epoch_end_callback=epoch_end_callback,
                     start_epoch=0)