In [1]:
import os
import cv2
import torch
import random
import logging
import argparse
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms
from arch.unet import UNet20, UNet256
from torch.utils.data import Dataset, DataLoader
%matplotlib inline

In [None]:


class ICVGIPDataset(Dataset):
    """Paired image / label-map dataset read from two directory trees.

    Every file below ``image_dir`` is treated as an input image. Every file below
    ``labels_dir`` whose name contains ``"inst_label.png"`` is treated as a label
    map. Both lists are sorted independently and zipped together, so pairing relies
    on image and label file names sorting in the same order.

    Args:
        image_dir: Root directory walked recursively for input images.
        labels_dir: Root directory walked recursively for label maps.
        print_dataset: If True, print every (image, label) pair after loading.
        input_img_size: ``cv2.resize`` ``dsize`` for images, i.e. ``(width, height)``.
        output_img_size: ``cv2.resize`` ``dsize`` for label maps, ``(width, height)``.
        max_label: Label ids greater than this are clamped to it.

    Raises:
        FileNotFoundError: If either directory does not exist.
        ValueError: If the number of images and label maps differ.
    """

    # ImageNet channel statistics, RGB order (images are converted to RGB below).
    NORM_MEAN = (0.485, 0.456, 0.406)
    NORM_STD = (0.229, 0.224, 0.225)

    def __init__(
        self,
        image_dir="data/leftImg8bit/train",
        labels_dir="data/gtFine/train",
        print_dataset=False,
        input_img_size=(388, 388),
        output_img_size=(388, 388),
        max_label=26,
    ):
        for directory in (image_dir, labels_dir):
            # os.walk silently yields nothing for a bad path; fail loudly instead.
            if not os.path.isdir(directory):
                raise FileNotFoundError(f"Dataset directory not found: {directory!r}")

        images = sorted(self._walk_files(image_dir))
        labels = sorted(
            path
            for path in self._walk_files(labels_dir)
            if "inst_label.png" in os.path.basename(path)
        )
        if len(images) != len(labels):
            raise ValueError(
                f"Found {len(images)} images in {image_dir!r} but "
                f"{len(labels)} label maps in {labels_dir!r}"
            )

        self.samples = list(zip(images, labels))
        self.input_img_size = input_img_size
        self.output_img_size = output_img_size
        self.max_label = max_label
        if print_dataset:
            self.print_dataset()

    @staticmethod
    def _walk_files(root_dir):
        """Return the paths of all files below ``root_dir`` (recursively)."""
        return [
            os.path.join(root, name)
            for root, _dirs, files in os.walk(root_dir)
            for name in files
        ]

    @staticmethod
    def _imread(path, flags):
        """Read ``path`` with OpenCV, raising instead of returning None on failure."""
        array = cv2.imread(path, flags)
        if array is None:
            raise FileNotFoundError(f"OpenCV could not read {path!r}")
        return array

    def __len__(self):
        return len(self.samples)

    def print_dataset(self):
        """Print each (image_path, label_path) pair on its own line."""
        for X, y in self.samples:
            print(X, y)

    def __getitem__(self, index):
        """Return ``(image, labels)`` for sample ``index``.

        ``image`` is a normalised float32 tensor of shape ``(3, H, W)``.
        ``labels`` is an int64 tensor of shape ``(H', W')`` with ids clamped to
        ``max_label``. ``(W, H)`` and ``(W', H')`` are the configured resize sizes.
        """
        image_path, label_path = self.samples[index]

        image = self._imread(image_path, cv2.IMREAD_COLOR)
        # OpenCV decodes to BGR; the normalisation statistics are RGB-ordered.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, self.input_img_size, interpolation=cv2.INTER_NEAREST)
        image = image.astype(np.float32) / 255.0
        # HWC -> CHW needs a transpose; a reshape would scramble the pixels.
        image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))

        labels = self._imread(label_path, cv2.IMREAD_GRAYSCALE)
        # `interpolation` must be passed by keyword: the third positional argument
        # of cv2.resize is `dst`. Nearest-neighbour keeps class ids from blending.
        labels = cv2.resize(labels, self.output_img_size, interpolation=cv2.INTER_NEAREST)
        labels = np.minimum(labels.astype(np.int64), self.max_label)
        labels = torch.from_numpy(labels)

        image = self.transform(image)
        return image, labels

    def transform(self, image):
        """Normalise a CHW float tensor in [0, 1] with the ImageNet RGB statistics."""
        return transforms.Normalize(mean=self.NORM_MEAN, std=self.NORM_STD)(image)


def get_dataloader(
    image_dir="data/leftImg8bit/train",
    labels_dir="data/gtFine/train",
    print_dataset=False,
    batch_size=8,
    input_img_size=(388, 388),
    output_img_size=(388, 388),
    **loader_kwargs,
):
    """Build an ``ICVGIPDataset`` and wrap it in a ``DataLoader``.

    Args:
        image_dir: Root directory of input images.
        labels_dir: Root directory of label maps.
        print_dataset: If True, the dataset prints every sample pair on load.
        batch_size: Samples per batch.
        input_img_size: Resize ``dsize`` for images, ``(width, height)``.
        output_img_size: Resize ``dsize`` for label maps, ``(width, height)``.
        **loader_kwargs: Extra keyword arguments forwarded to
            ``torch.utils.data.DataLoader`` (e.g. ``shuffle=True``,
            ``num_workers=4``). When none are given, PyTorch's defaults apply,
            including ``shuffle=False``.

    Returns:
        A ``DataLoader`` over the constructed dataset.
    """
    dataset = ICVGIPDataset(
        image_dir=image_dir,
        labels_dir=labels_dir,
        print_dataset=print_dataset,
        input_img_size=input_img_size,
        output_img_size=output_img_size,
    )
    return DataLoader(dataset, batch_size=batch_size, **loader_kwargs)



In [None]:
# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Training configuration.
train_img_dir = "data/leftImg8bit/train"
train_label_dir = "data/gtFine/train"
num_classes = 27  # the dataset clamps label ids to 26, so valid targets are 0..26
train_batch_size = 8
epochs = 100
lr = 1e-3
seed = 42

# Seed every RNG that can influence weight initialisation / batching so a
# Restart-&-Run-All reproduces the same run.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Images and label maps are resized to the same size so the per-pixel loss lines up.
# NOTE(review): presumably UNet256 expects 256x256 inputs and returns logits at
# the same resolution — TODO confirm against arch/unet.py.
train_img_size = (256, 256)
train_dataloader = get_dataloader(
    image_dir=train_img_dir,
    labels_dir=train_label_dir,
    batch_size=train_batch_size,
    print_dataset=False,  # True prints every sample path and floods the notebook
    input_img_size=train_img_size,
    output_img_size=train_img_size,
)
model = UNet256(num_classes=num_classes).to(device)

In [None]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Cosine-decay the learning rate from `lr` to 1e-6 over `epochs` scheduler steps.
# `args` was never defined in this notebook, so T_max comes from the config cell.
# The deprecated `verbose` flag is dropped; use scheduler.get_last_lr() to inspect the LR.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs, eta_min=1e-6
)

In [None]:
# Notebook-local logger. Attach a handler once so INFO messages are shown and
# re-running the cell does not duplicate output.
logger = logging.getLogger("train")
logger.setLevel(logging.INFO)
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())

# File-name tags. The original read the undefined `args.model` and `identifier`.
model_name = type(model).__name__
identifier = "lr{}_bs{}_ep{}".format(lr, train_batch_size, epochs)

step_losses = []
epoch_losses = []

model.train()
for epoch in tqdm(range(epochs), desc="epochs"):
    running_loss = 0.0
    for X, y in tqdm(train_dataloader, desc="epoch {}".format(epoch), leave=False):
        optimizer.zero_grad()
        output = model(X.to(device))
        loss = criterion(output, y.to(device))
        loss.backward()
        optimizer.step()
        step_loss = loss.item()
        running_loss += step_loss
        step_losses.append(step_loss)
    scheduler.step()
    # Guard against an empty loader rather than dividing by zero.
    epoch_loss = running_loss / max(len(train_dataloader), 1)
    logger.info("Average Loss: {}".format(epoch_loss))
    epoch_losses.append(epoch_loss)
    # Saves the whole pickled module, as before.
    # NOTE(review): consider saving model.state_dict(). Loading a whole module with
    # torch.load needs weights_only=False on recent PyTorch.
    torch.save(model, "checkpoint_{}_{}_{}.pth".format(model_name, epoch, identifier))

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].plot(step_losses)
axes[0].set(title="Training loss per step", xlabel="step", ylabel="loss")
axes[1].plot(epoch_losses)
axes[1].set(title="Average training loss per epoch", xlabel="epoch", ylabel="loss")
fig.tight_layout()
fig.savefig("{}_{}_train_analysis.png".format(model_name, identifier))
plt.show()