In [1]:
import os
import argparse
from tqdm import tqdm
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from data_loader.cityscapes import CityscapesDataset
import utils.transforms
from model.deeplabv3plus import DeepLabv3Plus

In [2]:
# TensorBoard writer used as a module-level global by the train/eval helpers.
writer = SummaryWriter()

# Location of the best checkpoint and of the Cityscapes images / annotations.
checkpoint_file = os.path.join('./checkpoints', 'best_model.pt')
img_root, mask_root = './data/leftImg8bit', './data/gtFine'

# The same preprocessing pipeline is applied to both splits.
transform = utils.transforms.Compose([
    utils.transforms.Resize((224, 224)),
    utils.transforms.ToTensor(),
])

train_set, val_set = (
    CityscapesDataset(split, img_root, mask_root, transform=transform)
    for split in ("train", "val")
)

# Only the training split is shuffled; validation keeps dataset order.
batch_size = 8
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

In [3]:
# Hyperparameters.
num_classes = 19
learning_rate = 0.001
# Label id the dataset exposes for pixels the loss must skip.
ignore_idx = train_set.ignoreId

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = DeepLabv3Plus(num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss(ignore_index=ignore_idx)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [4]:
def save_checkpoint(epoch, model, optimizer, path):
    """Serialize the current training state to ``path`` with ``torch.save``.

    The saved object is a dict with the keys ``"epoch"``,
    ``"model_state_dict"`` and ``"optimizer_state_dict"``.

    Args:
        epoch: Epoch index stored under ``"epoch"``.
        model: Object whose ``state_dict()`` is stored.
        optimizer: Object whose ``state_dict()`` is stored.
        path: Destination of the checkpoint. When it is a filesystem path,
            missing parent directories are created first.
    """
    # torch.save does not create missing directories, so the very first save
    # on a fresh checkout (no ./checkpoints yet) would otherwise fail.
    # File-like objects are passed through untouched.
    if isinstance(path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()
        },
        path
    )


def train_one_epoch(epoch, model, criterion, optimizer, data_loader, device):
    """Run one optimization pass over ``data_loader`` and log the losses.

    Each batch is a mapping with ``"image"`` and ``"mask"`` entries that are
    moved to ``device`` before the forward pass. The per-batch loss is logged
    as ``train_iter_loss`` and the summed loss as ``train_epoch_loss`` on the
    module-level ``writer``.

    Args:
        epoch: Zero-based epoch index, used to compute logging steps.
        model: Network to train; switched to training mode.
        criterion: Callable ``criterion(pred, masks)`` returning the loss.
        optimizer: Optimizer stepping the model's parameters.
        data_loader: Sized iterable of batches.
        device: Device the batch tensors are moved to.

    Returns:
        The sum (not the mean) of the per-batch loss values.
    """
    batches_per_epoch = len(data_loader)
    model.train()

    running_loss = 0
    for step, batch in enumerate(tqdm(data_loader)):
        inputs = batch["image"].to(device)
        targets = batch["mask"].to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        running_loss += loss_value

        # Global iteration index so curves from successive epochs line up.
        global_step = step + batches_per_epoch * epoch
        writer.add_scalar('train_iter_loss', loss_value, global_step)

    writer.add_scalar('train_epoch_loss', running_loss, epoch)
    return running_loss


def evaluate(epoch, model, criterion, data_loader, num_classes, device):
    """Compute the summed validation loss over ``data_loader`` and log it.

    The model is put in eval mode, every batch is evaluated under
    ``torch.no_grad()``, and the model is switched back to training mode
    before returning. The summed loss is logged as ``val_epoch_loss`` on the
    module-level ``writer``.

    Args:
        epoch: Epoch index used as the logging step.
        model: Network to evaluate.
        criterion: Callable ``criterion(pred, masks)`` returning the loss.
        data_loader: Iterable of batches, each a mapping with ``"image"`` and
            ``"mask"`` entries.
        num_classes: Unused; kept so existing call sites keep working.
        device: Device the batch tensors are moved to.

    Returns:
        The sum (not the mean) of the per-batch loss values.
    """
    model.eval()

    total_loss = 0
    with torch.no_grad():
        for sample in tqdm(data_loader):
            images = sample["image"].to(device)
            masks = sample["mask"].to(device)

            # The criterion must see the raw model output, exactly as in
            # training. Feeding it torch.argmax(...) (an integer class map)
            # makes nn.CrossEntropyLoss raise, because it expects per-class
            # scores rather than already-decided class indices.
            logits = model(images)
            total_loss += criterion(logits, masks).item()

    # Leave the model in training mode for the next epoch.
    model.train()
    writer.add_scalar('val_epoch_loss', total_loss, epoch)
    return total_loss

In [5]:
# Start TensorBoard inside the notebook. The `runs` log directory presumably
# matches where the argument-less SummaryWriter() created earlier writes its
# event files -- TODO confirm.
%load_ext tensorboard
%tensorboard --logdir=runs

Reusing TensorBoard on port 6006 (pid 27147), started 0:01:31 ago. (Use '!kill 27147' to kill it.)

In [None]:
# Train for up to 100 epochs, checkpointing whenever the validation loss
# improves and stopping once it has not improved for 5 epochs.
best_loss, best_epoch = np.inf, 0
for epoch in range(100):
    train_loss = train_one_epoch(
        epoch, model, criterion, optimizer, train_loader, device
    )
    val_loss = evaluate(
        epoch, model, criterion, val_loader, num_classes, device
    )

    if val_loss < best_loss:
        # New best validation loss: remember it and persist the weights.
        best_loss, best_epoch = val_loss, epoch
        save_checkpoint(epoch, model, optimizer, checkpoint_file)
        continue

    # No improvement this epoch; give up once the best epoch is 5+ behind.
    if epoch - best_epoch >= 5:
        print(f"Early Stopping at epoch {epoch}")
        break

 44%|████▍     | 164/372 [01:52<02:23,  1.44it/s]

In [None]:
# Training is over: close the SummaryWriter created in the setup cell.
writer.close()