In [None]:
import os
import random
import numpy as np
from PIL import Image

import torch

from deeplabv3 import DeepLabV3, convert_trainid_mask
from cityscapes import get_transforms

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA (a hard-coded 'cuda' makes `model.to(DEVICE)` raise).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
if DEVICE == 'cuda':
    # Release unused cached GPU memory held by PyTorch's caching allocator.
    torch.cuda.empty_cache()

In [None]:
# Checkpoint to evaluate and the output stride passed to the model constructor.
output_stride = 16
model_weights_path = "./saved/dlv3_os_16_e_30.pth"

# Single configuration mapping read by the model/transform setup cell.
cfg = dict(
    train_crop_size=[1024, 1024],
    norm_mean=[0.485, 0.456, 0.406],
    norm_std=[0.229, 0.224, 0.225],
    backbone="resnet50",
    num_classes=20,
    output_stride=output_stride,
    model_weights_path=model_weights_path,
)

In [None]:
# Build the train and val/test transforms from the crop size and the
# normalisation statistics stored in cfg.
transform_train, transform_val_test = get_transforms(
    cfg["train_crop_size"],
    cfg["norm_mean"],
    cfg["norm_std"],
)

# Instantiate the network from the architecture settings stored in cfg.
model_kwargs = {key: cfg[key] for key in ('backbone', 'num_classes', 'output_stride')}
model = DeepLabV3(**model_kwargs)

# Deserialize the checkpoint onto the CPU; the module itself is moved to
# DEVICE only after the weights have been loaded into it.
model_state_dict = torch.load(
    cfg['model_weights_path'],
    map_location='cpu',
    weights_only=True,
)
model.load_state_dict(model_state_dict)
model.to(DEVICE)
# Switch to evaluation mode for inference.
model.eval()

In [None]:
def predict(image_path):
    """Run the segmentation model on one image and colourise the prediction.

    Args:
        image_path: Path of the image file to open with PIL.

    Returns:
        A tuple ``(image, pred_mask_color)`` of PIL RGB images: the input
        image as loaded from disk and the predicted mask rendered in colour.
    """
    # Open through a context manager so the file handle is closed right away;
    # convert() returns a loaded image that outlives the `with` block.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    img = transform_val_test(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        pred_logits = model(img)
        # argmax over dim=1 (assumed to be the class dimension — TODO confirm),
        # then drop only the batch dimension added by unsqueeze(0) above.
        # A bare squeeze() would also collapse a height or width of size 1.
        pred_mask = torch.argmax(pred_logits, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)

    # Pure NumPy/PIL post-processing; it does not need the no_grad context.
    pred_mask_color = convert_trainid_mask(
        pred_mask,
        to="color",
        name_to_trainId_path='./cityscapes/name_to_trainId.json',
        name_to_color_path='./cityscapes/name_to_color.json',
        name_to_labelId_path='./cityscapes/name_to_labelId.json',
    ).astype(np.uint8)
    pred_mask_color = Image.fromarray(pred_mask_color).convert('RGB')
    return image, pred_mask_color

In [None]:
# Root of the validation images; the loop below treats each sub-directory as a city.
VAL_IMAGE_DIR = './data/leftImg8bit/val/'

file_list = []
# os.listdir() returns entries in arbitrary order, so sort them first:
# otherwise the seeded shuffle below selects different images on different
# machines even though the seed is fixed.
for city in sorted(os.listdir(VAL_IMAGE_DIR)):
    img_dir = os.path.join(VAL_IMAGE_DIR, city)
    if not os.path.isdir(img_dir):
        # Skip stray files (e.g. .DS_Store) that would make listdir() raise.
        continue
    for file_name in sorted(os.listdir(img_dir)):
        if file_name.endswith('_leftImg8bit.png'):
            img_path = os.path.join(img_dir, file_name)
            file_list.append(img_path)

print(f"Found {len(file_list)} images.")
# Fixed seed + sorted input => the same shuffled order on every run.
random.seed(42)
random.shuffle(file_list)

In [None]:
# Overlay the predicted colour mask on each of the first 50 shuffled images.
selected_files = file_list[0:50]
total = len(selected_files)

results = []
for i, file_path in enumerate(selected_files):
    # '\r' keeps the progress counter on a single, continuously updated line.
    print(f"Processing {i + 1}/{total}", end='\r')
    image, pred_mask_color = predict(file_path)
    # alpha=0.3 -> 70% original image, 30% colour mask.
    blended = Image.blend(image, pred_mask_color, 0.3)
    results.append(blended)

In [None]:
# Write all blended frames into one animated GIF.
output_gif_path = "./outputs/predictions_os_16.gif"
# Image.save() does not create missing directories, so make sure the target
# folder exists before writing (no-op when it is already there).
os.makedirs(os.path.dirname(output_gif_path), exist_ok=True)

results[0].save(
    output_gif_path,
    format="GIF",
    save_all=True,  # write every frame, not just the first one
    append_images=results[1:50],
    duration=5000,  # display time per frame, in milliseconds
    loop=0,  # 0 = loop forever
)