In [1]:
import argparse
import os
import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from data_loader.cityscapes import CityscapesDataset
import utils.transforms
from model.deeplabv3plus import DeepLabv3Plus

from model.metric import SegmentationMetrics

In [2]:
run_dir = "run2"
checkpoint_file = os.path.join("./checkpoints", run_dir, "best_model.pt")
log_dir = os.path.join("./runs", run_dir)
writer = SummaryWriter(log_dir=log_dir)

# Seed every RNG that can influence a run so results are reproducible:
# - torch drives DataLoader shuffling and (presumably) the model's weight init;
# - the random transforms may draw from Python's `random` or NumPy (their
#   implementation is not visible here), so all three are seeded.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

img_root = "./data/leftImg8bit"
mask_root = "./data/gtFine"
# Training augmentation: random scale + crop (parameters suggest a 0.75x-2.0x
# rescale and a 512 crop), then a horizontal flip with probability 0.5.
train_transform = utils.transforms.Compose([
    utils.transforms.RandomScaleCrop(
        scale_min=0.75, scale_max=2.0, crop_size=512, inference_size=(1024, 512)
    ),
    utils.transforms.RandomHorizontalFlip(flip_prob=0.5),
    utils.transforms.ToTensor()
])
# Validation: deterministic resize to (1024, 512) then a (512, 512) center crop.
# NOTE(review): validation metrics therefore only see the center crop, not the
# full frame — confirm this is intended.
val_transform = utils.transforms.Compose([
    utils.transforms.Resize((1024, 512)),
    utils.transforms.CenterCrop((512, 512)),
    utils.transforms.ToTensor()
])
train_set = CityscapesDataset("train", img_root, mask_root, transform=train_transform)
val_set = CityscapesDataset("val", img_root, mask_root, transform=val_transform)

batch_size = 4
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=batch_size, shuffle=False)

In [3]:
# Hardware and label space.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
num_classes = 20  # output channels for label ids 0..19
# Label id excluded from the loss; the original note says this is 19, i.e. it
# still has an output channel inside num_classes — confirm in CityscapesDataset.
ignore_idx = train_set.ignoreId

# SGD hyper-parameters.
learning_rate = 1e-3
momentum = 0.9
weight_decay = 5e-4

model = DeepLabv3Plus(num_classes).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=ignore_idx)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)

In [4]:
def save_checkpoint(epoch, model, optimizer, path):
    """Serialize the training state to ``path`` with ``torch.save``.

    The saved dict has the keys ``"epoch"``, ``"model_state_dict"`` and
    ``"optimizer_state_dict"``, so training can be resumed with
    ``model.load_state_dict`` / ``optimizer.load_state_dict``.

    Args:
        epoch: Epoch index stored alongside the weights.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        path: Destination file (str or path-like). Missing parent
            directories are created; an existing file is overwritten.
    """
    # torch.save does not create directories, so the first save into a fresh
    # ./checkpoints/<run_dir>/ would otherwise raise FileNotFoundError
    # in the middle of a long training run.
    parent_dir = os.path.dirname(path)
    if parent_dir:  # a bare filename goes to the cwd; makedirs("") would raise
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        path,
    )


def train_one_epoch(epoch, model, criterion, optimizer, data_loader, device):
    """Run one optimisation pass over ``data_loader`` and return the mean batch loss.

    Logging goes through the module-level ``writer``:
    - each batch loss under ``"train_iter_loss"`` at global step
      ``epoch * len(data_loader) + batch_index``;
    - the epoch mean under ``"train_epoch_avg_loss"`` at step ``epoch``.

    Args:
        epoch: Zero-based epoch index, used for the logging step numbers.
        model: Network to train; switched to train mode first.
        criterion: Loss called as ``criterion(model_output, masks)``.
        optimizer: Optimizer stepping the model's parameters.
        data_loader: Sized iterable of dicts with ``"image"`` and ``"mask"``
            tensors.
        device: Device each batch is moved to.

    Returns:
        Mean of the per-batch loss values (ZeroDivisionError if the loader
        is empty).
    """
    model.train()

    steps_per_epoch = len(data_loader)
    step_offset = steps_per_epoch * epoch
    running_loss = 0
    for batch_idx, batch in enumerate(tqdm(data_loader)):
        inputs = batch["image"].to(device)
        targets = batch["mask"].to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        # One host sync per step; the same float feeds the sum and the log.
        loss_value = batch_loss.item()
        running_loss += loss_value
        writer.add_scalar("train_iter_loss", loss_value, step_offset + batch_idx)

    epoch_loss = running_loss / steps_per_epoch
    writer.add_scalar("train_epoch_avg_loss", epoch_loss, epoch)
    return epoch_loss


def validate(epoch, model, criterion, data_loader, num_classes, device):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    The model is in eval mode for the pass and is switched back to train mode
    afterwards.

    Module-level names read here:
    - ``ignore_idx``, forwarded to ``SegmentationMetrics``;
    - ``writer``, which logs ``"val_epoch_avg_loss"`` and ``"val_epoch_mIoU"``
      at step ``epoch``.

    Args:
        epoch: Epoch index used as the logging step.
        model: Network to evaluate.
        criterion: Loss called as ``criterion(model_output, masks)``.
        data_loader: Sized iterable of dicts with ``"image"`` and ``"mask"``
            tensors.
        num_classes: Class count forwarded to ``SegmentationMetrics``.
        device: Device each batch is moved to.

    Returns:
        ``(avg_loss, ious, mIoU)``. ``avg_loss`` is the mean per-batch loss;
        ``ious`` and ``mIoU`` are the two values returned by the metrics
        object's ``iou()``.
    """
    model.eval()

    loss_sum = 0
    batch_count = len(data_loader)
    seg_metrics = SegmentationMetrics(num_classes=num_classes, ignore_idx=ignore_idx)
    with torch.no_grad():
        for batch in tqdm(data_loader):
            inputs = batch["image"].to(device)
            targets = batch["mask"].to(device)

            logits = model(inputs)
            loss_sum += criterion(logits, targets).item()
            # Hard prediction = highest-scoring channel along dim 1.
            seg_metrics.update(logits.argmax(dim=1), targets)

    model.train()

    mean_loss = loss_sum / batch_count
    class_ious, mean_iou = seg_metrics.iou()
    writer.add_scalar("val_epoch_avg_loss", mean_loss, epoch)
    writer.add_scalar("val_epoch_mIoU", mean_iou, epoch)
    return mean_loss, class_ious, mean_iou

In [5]:
# %load_ext tensorboard
# %tensorboard --logdir=runs --host=0.0.0.0 --port=6006

In [None]:
# Early stopping on validation mIoU:
# - keep the checkpoint with the highest mIoU seen so far;
# - stop once `patience` epochs pass without improvement.
max_epochs = 1000
patience = 15
best_loss = np.inf  # validation loss of the best-mIoU epoch
best_mIoU = 0
best_epoch = 0
for epoch in range(max_epochs):
    train_loss = train_one_epoch(
        epoch, model, criterion, optimizer, train_loader, device
    )
    val_loss, ious, mIoU = validate(
        epoch, model, criterion, val_loader, num_classes, device
    )
    print(
        f"Epoch {epoch}: train_loss {train_loss:.4f} | "
        f"val_loss {val_loss:.4f} | mIoU {mIoU:.4f}"
    )
    # Higher mIoU is better. best_mIoU starts at 0, so any epoch with a
    # positive mIoU becomes the first best and gets checkpointed.
    if mIoU > best_mIoU:
        best_mIoU = mIoU
        best_loss = val_loss
        best_epoch = epoch
        save_checkpoint(epoch, model, optimizer, checkpoint_file)
    elif epoch - best_epoch >= patience:
        print(f"Early Stopping at epoch {epoch}")
        break

 45%|████▌     | 336/744 [03:51<04:39,  1.46it/s]

In [None]:
# Flush any pending TensorBoard events and close the SummaryWriter.
# Run this after the training cell finishes or is interrupted.
writer.close()

In [None]:
# !kill tensorboard_pid