In [None]:
import torch
import sys
import os
# Put the parent of the current working directory on the import path.
# Later cells import from the `SegCropNet` package; presumably it lives one level
# above this notebook -- TODO confirm against the repository layout.
sys.path.append(os.path.abspath('..'))

In [None]:
# Logging
import logging
# Route log records of level INFO and above into 'segnet.log', then create the
# logger that the remaining cells write their progress messages to.
logging.basicConfig(level=logging.INFO, filename='segnet.log')
logger = logging.getLogger(__name__)

In [None]:
# Hyperparameters - according to paper
NUM_EPOCHS = 1000        # upper bound of the epoch range used by the training loop
BATCH_SIZE = 2           # batch size handed to both data loaders
LEARNING_RATE = 1e-3     # learning rate handed to the optimizer
IMG_SIZE = (512, 256)    # forwarded to the dataset as `img_size`
NUM_WORKERS = 0  # 16    # DataLoader worker processes (0 = load in the main process)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

print(DEVICE)
logger.info(f'Device: {DEVICE}')

In [None]:
# Init: Datasets and Dataloader
import os
from SegCropNet.dataloader.data_loaders import TusimpleSet
from torch.utils.data import DataLoader

# Both splits sit under <cwd>/../datasets/ma_dataset/combined/.
dataset_root = os.path.join(os.path.abspath(''), os.pardir, 'datasets', 'ma_dataset', 'combined')

# Training split: the dataset is created with transform=True.
train_path = os.path.join(dataset_root, 'train')
train_dataset = TusimpleSet(train_path, img_size=IMG_SIZE, transform=True)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Validation split: the dataset is created with transform=False.
val_path = os.path.join(dataset_root, 'val')
val_dataset = TusimpleSet(val_path, img_size=IMG_SIZE, transform=False)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

logger.info('created datasets and dataloader')

In [None]:
# Init: Model
import torch.optim as optim
from SegCropNet.model.SegCropNet.SegCropNet import SegCropNet
from SegCropNet.model.SegCropNet.average_meter import AverageMeter
# Bookkeeping defaults for a fresh run; the checkpoint-loading cell may overwrite
# `epoch_start` and `best_loss` when resuming.
epoch_start = 0
best_loss = float('inf')

# Passed through to compute_loss() by the training loop.
loss_type = 'CrossEntropyLoss'

# No learning-rate scheduler is used; the training loop skips stepping it when None.
scheduler = None

# Build the network on the selected device and attach an Adam optimizer to it.
model = SegCropNet(arch='UNet')
model = model.to(DEVICE)
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=0.0005,
)

logger.info('created model and optimizer')

In [None]:
# Load Model if already existing
# Resume from the periodic checkpoint in the current working directory, if present.
# Restores the model weights, the optimizer state, the best validation loss seen so
# far and the epoch index at which training should continue.
checkpoint_path = os.path.join(os.path.abspath(''), 'insta_checkpoint.pt.tar')
if os.path.isfile(checkpoint_path):
    logger.info('found existing model')
    # map_location lets a checkpoint written on one device be loaded on another.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The training cell stores the zero-based index of the epoch that has just been
    # COMPLETED, so training must continue with the following epoch; using the stored
    # value directly would train the last finished epoch a second time.
    epoch_start = checkpoint['epoch'] + 1
    best_loss = checkpoint['loss']
    logger.info('loaded existing model for continuation of training')

In [None]:
# Training Cycle
from tqdm import tqdm
import matplotlib.pyplot as plt
from SegCropNet.model.SegCropNet.loss import compute_loss

def _unpack_batch(batch):
  """Return the sample dict contained in a collated batch.

  The training loop originally indexed the batch directly (``batch['input']``) while
  the validation loop indexed ``batch[0]['input']``, although both loaders wrap the
  same dataset class. Accepting both layouts keeps whichever one the dataset really
  yields working instead of crashing in one of the two loops.
  NOTE(review): confirm what TusimpleSet.__getitem__ returns and drop the unused branch.
  """
  if isinstance(batch, (list, tuple)):
    return batch[0]
  return batch


def _batch_to_device(sample):
  """Cast the three tensors the loss needs and move them to DEVICE.

  Returns (inputs, binaries, instances): 'input' and 'instance' as float32,
  'binary' as int64 (long).
  """
  inputs = sample['input'].to(DEVICE, dtype=torch.float32)
  binaries = sample['binary'].to(DEVICE, dtype=torch.long)
  instances = sample['instance'].to(DEVICE, dtype=torch.float32)
  return inputs, binaries, instances


for epoch in range(epoch_start, NUM_EPOCHS, 1):
  # ---------------------------- training pass ----------------------------
  model.train()
  epoch_loss = AverageMeter()
  epoch_loss_bin = AverageMeter()
  epoch_loss_inst = AverageMeter()
  train_iou = AverageMeter()
  for train_batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS}'):
    inputs, binaries, instances = _batch_to_device(_unpack_batch(train_batch))

    optimizer.zero_grad()

    train_preds = model(inputs)
    # compute_loss returns an indexable result; this loop uses [0] as the total loss
    # (the one that is back-propagated), [1] as the binary loss, [2] as the instance
    # loss and feeds [4] into the IoU meter.
    batch_loss = compute_loss(train_preds, binaries, instances, loss_type)

    batch_loss[0].backward()
    optimizer.step()
    if scheduler is not None:
      scheduler.step()

    # Weight every running average by the number of samples in the batch.
    n_samples = inputs.size(0)
    epoch_loss.update(batch_loss[0].item(), n_samples)
    epoch_loss_bin.update(batch_loss[1].item(), n_samples)
    epoch_loss_inst.update(batch_loss[2].item(), n_samples)
    train_iou.update(batch_loss[4], n_samples)

  # --------------------------- validation pass ---------------------------
  model.eval()
  val_loss = AverageMeter()
  val_loss_bin = AverageMeter()
  val_loss_inst = AverageMeter()
  val_iou = AverageMeter()
  with torch.no_grad():  # no gradients are needed for evaluation
    for val_batch in tqdm(val_loader, desc='Validation'):
      inputs, binaries, instances = _batch_to_device(_unpack_batch(val_batch))

      val_preds = model(inputs)
      val_batch_loss = compute_loss(val_preds, binaries, instances, loss_type)

      n_samples = inputs.size(0)
      val_loss.update(val_batch_loss[0].item(), n_samples)
      val_loss_bin.update(val_batch_loss[1].item(), n_samples)
      val_loss_inst.update(val_batch_loss[2].item(), n_samples)
      val_iou.update(val_batch_loss[4], n_samples)

  print(f'train_loss: {epoch_loss.avg} | val_loss: {val_loss.avg}')

  # Keep the weights with the lowest average validation loss seen so far.
  if val_loss.avg < best_loss:
    best_loss = val_loss.avg
    torch.save(model.state_dict(), 'best_insta.pt')
    print('New Model saved')
    # The message text is kept exactly as before; the leading whitespace of the
    # continuation lines below is part of the logged string.
    logger.info(f'updated best_model in epoch {epoch+1} with \n \
                  training_loss: total: {epoch_loss.avg} binary: {epoch_loss_bin.avg} instance: {epoch_loss_inst.avg} and \
                  validation_loss: total: {val_loss.avg} binary: {val_loss_bin.avg} instance: {val_loss_inst.avg}')

  # Every 10th epoch write a resumable checkpoint. 'epoch' holds the zero-based index
  # of the epoch that has just completed; 'loss' holds the best validation loss so far.
  if (epoch+1) % 10 == 0:
    torch.save(
      {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': best_loss,
      },
      'insta_checkpoint.pt.tar',
    )