In [1]:
import os
import argparse
from tqdm import tqdm
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from data_loader.cityscapes import CityscapesDataset
import utils.transforms
from model.deeplabv3plus import DeepLabv3Plus

from model.metrics import iou

In [2]:
# Logging sink and checkpoint destination.
writer = SummaryWriter()
checkpoint_file = os.path.join("./checkpoints", "best_model.pt")

# Root directories handed to the dataset for images and masks.
img_root = "./data/leftImg8bit"
mask_root = "./data/gtFine"

# Preprocessing pipeline shared by both splits:
# Resize((512, 512)) followed by ToTensor().
transform = utils.transforms.Compose(
    [
        utils.transforms.Resize((512, 512)),
        utils.transforms.ToTensor(),
    ]
)

train_set = CityscapesDataset("train", img_root, mask_root, transform=transform)
val_set = CityscapesDataset("val", img_root, mask_root, transform=transform)

# Only the training loader shuffles; validation keeps dataset order.
batch_size = 4
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

In [3]:
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Number of output classes and the label id excluded from the loss.
num_classes = 20  # 0 -> 19
ignore_idx = train_set.ignoreId  # ignore 19

# Network (moved to the target device before the optimizer is built),
# loss function and optimizer.
model = DeepLabv3Plus(num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss(ignore_index=ignore_idx)
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [4]:
def save_checkpoint(epoch, model, optimizer, path):
    """Write a training checkpoint to ``path`` using ``torch.save``.

    The saved object is a dict with exactly three keys: ``"epoch"``,
    ``"model_state_dict"`` and ``"optimizer_state_dict"``.

    Args:
        epoch: Epoch number stored in the checkpoint.
        model: Object whose ``state_dict()`` is stored.
        optimizer: Object whose ``state_dict()`` is stored.
        path: Destination of the checkpoint. When it is a filesystem path,
            any missing parent directory is created first; other values
            (e.g. file-like objects) are passed to ``torch.save`` untouched.
    """
    # torch.save does not create directories, so a missing "./checkpoints"
    # would abort training at the first save.
    if isinstance(path, (str, bytes, os.PathLike)):
        parent_dir = os.path.dirname(path)
        if parent_dir:  # empty for a bare file name: nothing to create
            os.makedirs(parent_dir, exist_ok=True)

    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        path,
    )


def train_one_epoch(epoch, model, criterion, optimizer, data_loader, device):
    """Run one optimization pass over ``data_loader`` and return the mean batch loss.

    Every batch is a mapping with ``"image"`` and ``"mask"`` entries that are
    moved to ``device``. The per-batch loss is logged to the module-level
    ``writer`` under ``"train_iter_loss"`` with a step that keeps increasing
    across epochs, and the epoch mean is logged under
    ``"train_epoch_avg_loss"``.

    Args:
        epoch: Index of the current epoch; used for the logging steps.
        model: Network to optimize; switched to training mode here.
        criterion: Callable taking ``(prediction, mask)`` and returning the loss.
        optimizer: Optimizer stepped once per batch.
        data_loader: Sized iterable of batches.
        device: Device the batch tensors are moved to.

    Returns:
        Sum of the per-batch loss values divided by ``len(data_loader)``.
    """
    model.train()

    num_batches = len(data_loader)
    # Step of the first batch of this epoch on the per-iteration curve.
    first_step = num_batches * epoch
    running_loss = 0
    for step, batch in enumerate(tqdm(data_loader), start=first_step):
        inputs = batch["image"].to(device)
        targets = batch["mask"].to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        running_loss += loss_value
        writer.add_scalar("train_iter_loss", loss_value, step)

    mean_loss = running_loss / num_batches
    writer.add_scalar("train_epoch_avg_loss", mean_loss, epoch)
    return mean_loss


def evaluate(epoch, model, criterion, data_loader, num_classes, device):
    """Evaluate ``model`` on ``data_loader`` and log loss and mIoU.

    The model is put in eval mode for the pass and switched back to training
    mode afterwards. Every batch is a mapping with ``"image"`` and ``"mask"``
    entries. Predictions (arg-max over dim 1 of the model output) and masks
    of all batches are gathered on the CPU and handed to ``iou``. The mean
    loss and the mIoU are logged to the module-level ``writer`` under
    ``"val_epoch_avg_loss"`` and ``"val_epoch_mIoU"``.

    Args:
        epoch: Index of the current epoch; used as the logging step.
        model: Network to evaluate.
        criterion: Callable taking ``(prediction, mask)`` and returning the loss.
        data_loader: Sized iterable of batches.
        num_classes: Not used by this function; kept for interface
            compatibility.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, ious, mIoU)`` where ``avg_loss`` is the summed batch
        loss divided by ``len(data_loader)`` and ``ious, mIoU`` are the values
        returned by ``iou(preds, gts)``.
    """
    model.eval()

    total_loss = 0
    n_batches = len(data_loader)
    # Per-batch results are collected and concatenated at the end. The former
    # preallocated buffers were indexed with the *batch* index as if it were a
    # sample offset, so consecutive batches overwrote each other and most rows
    # stayed zero; they also hard-coded a 512x512 resolution.
    pred_chunks = []
    gt_chunks = []
    # Same dtype the torch.zeros buffers had, so iou() sees the same tensor type.
    store_dtype = torch.get_default_dtype()

    with torch.no_grad():
        for sample in tqdm(data_loader):
            images = sample["image"].to(device)
            masks = sample["mask"].to(device)

            pred = model(images)
            loss = criterion(pred, masks)
            total_loss += loss.item()

            gt_chunks.append(masks.cpu().to(store_dtype))
            pred_chunks.append(torch.argmax(pred, dim=1).cpu().to(store_dtype))

    model.train()
    avg_loss = total_loss / n_batches
    preds = torch.cat(pred_chunks, dim=0)
    gts = torch.cat(gt_chunks, dim=0)
    ious, mIoU = iou(preds, gts)
    writer.add_scalar("val_epoch_avg_loss", avg_loss, epoch)
    writer.add_scalar("val_epoch_mIoU", mIoU, epoch)
    return avg_loss, ious, mIoU

In [5]:
# %load_ext tensorboard
# %tensorboard --logdir=runs --host=0.0.0.0 --port=6006

In [6]:
# Early-stopping configuration: stop once `patience` epochs have passed since
# the last improvement of the validation mIoU, or after `max_epochs` epochs.
patience = 8
max_epochs = 1000
best_loss = np.inf  # not read in the visible cells; kept so the name stays defined
best_mIoU = 0
best_epoch = 0
for epoch in range(max_epochs):
    train_loss = train_one_epoch(
        epoch, model, criterion, optimizer, train_loader, device
    )
    val_loss, ious, mIoU = evaluate(
        epoch, model, criterion, val_loader, num_classes, device
    )
    print(f"Epoch {epoch}: val_loss {val_loss:.4f} | mIoU {mIoU: .4f}")
    # Higher mIoU is better. The previous `<` comparison could never be true
    # starting from best_mIoU = 0, so no checkpoint was ever written.
    if mIoU > best_mIoU:
        best_mIoU = mIoU
        best_epoch = epoch
        save_checkpoint(epoch, model, optimizer, checkpoint_file)
    elif epoch - best_epoch >= patience:
        print(f"Early Stopping at epoch {epoch}")
        break

100%|██████████| 744/744 [08:10<00:00,  1.52it/s]
100%|██████████| 125/125 [00:47<00:00,  2.64it/s]
100%|██████████| 744/744 [08:08<00:00,  1.52it/s]
100%|██████████| 125/125 [00:47<00:00,  2.63it/s]
100%|██████████| 744/744 [08:08<00:00,  1.52it/s]
100%|██████████| 125/125 [00:46<00:00,  2.66it/s]
100%|██████████| 744/744 [08:08<00:00,  1.52it/s]
100%|██████████| 125/125 [00:47<00:00,  2.61it/s]
100%|██████████| 744/744 [08:11<00:00,  1.51it/s]
100%|██████████| 125/125 [00:47<00:00,  2.63it/s]
100%|██████████| 744/744 [08:08<00:00,  1.52it/s]
100%|██████████| 125/125 [00:47<00:00,  2.63it/s]
100%|██████████| 744/744 [08:08<00:00,  1.52it/s]
100%|██████████| 125/125 [00:47<00:00,  2.66it/s]
100%|██████████| 744/744 [08:07<00:00,  1.53it/s]
100%|██████████| 125/125 [00:47<00:00,  2.64it/s]
100%|██████████| 744/744 [08:07<00:00,  1.53it/s]
100%|██████████| 125/125 [00:47<00:00,  2.64it/s]
100%|██████████| 744/744 [08:08<00:00,  1.52it/s]
100%|██████████| 125/125 [00:46<00:00,  2.66it/s]


Early Stopping at epoch 32





In [7]:
# Training is over: close the SummaryWriter that was created at the top of the notebook.
writer.close()

In [8]:
# !kill tensorboard_pid