In [1]:
import os
import random
import numpy as np
from PIL import Image

import torch

from unet import UNet, convert_trainid_mask
from cityscapes import get_transforms

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# end-to-end on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Release cached GPU memory left over from earlier runs in this kernel.
torch.cuda.empty_cache()

In [2]:
# Checkpoint file whose weights are loaded into the model.
model_weights_path = "./saved/unet_e_50.pth"

# Settings read when building the transforms and the model.
# NOTE(review): mean 0 / std 1 presumably leaves pixel values unnormalized;
# confirm against get_transforms and the training configuration.
cfg = dict(
    train_crop_size=[1024, 1024],
    norm_mean=[0.0, 0.0, 0.0],
    norm_std=[1.0, 1.0, 1.0],
    num_classes=20,
    model_weights_path=model_weights_path,
)

In [3]:
# Build the preprocessing pipelines; only the val/test one is used for
# inference in this notebook.
transform_train, transform_val_test = get_transforms(
    cfg["train_crop_size"],
    cfg["norm_mean"],
    cfg["norm_std"],
)

# Instantiate the network and restore the trained weights. The checkpoint is
# deserialized onto the CPU first, then the whole model is moved to DEVICE.
model = UNet(num_classes=cfg['num_classes'])
model_state_dict = torch.load(
    cfg['model_weights_path'],
    map_location='cpu',
    weights_only=True,
)
model.load_state_dict(model_state_dict)
# Module.to() and Module.eval() both return the module itself, so the cell
# still ends with the model as its displayed value.
model.to(DEVICE).eval()

UNet(
  (b0): ConvBlock(
    (conv): Conv2d(3, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (b1): ConvBlock(
    (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (b2): ConvBlock(
    (conv): Conv2d(48, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (b3): ConvBlock(
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (b4): ConvBlock(
    (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn):

In [4]:
def predict(image_path):
    """Run the segmentation model on one image file.

    Args:
        image_path: Path of the image to open with PIL.

    Returns:
        tuple: ``(image, pred_mask_color)`` where ``image`` is the input
        converted to an RGB PIL image and ``pred_mask_color`` is the
        predicted mask rendered as an RGB PIL image.
    """
    # Close the underlying file deterministically; convert() returns an
    # independent image, so it stays usable after the handle is closed.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    # unsqueeze(0) adds the batch dimension in front of the transformed image.
    img = transform_val_test(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        pred_logits = model(img)
        # argmax over dim 1 (assumed to be the class dimension — confirm
        # against UNet), then drop ONLY the batch dimension added above.
        # squeeze(0) instead of squeeze() keeps size-1 spatial dimensions.
        pred_mask = torch.argmax(pred_logits, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
        pred_mask_color = convert_trainid_mask(
            pred_mask,
            to="color",
            name_to_trainId_path='./cityscapes/name_to_trainId.json',
            name_to_color_path='./cityscapes/name_to_color.json',
            name_to_labelId_path='./cityscapes/name_to_labelId.json',
        ).astype(np.uint8)
    pred_mask_color = Image.fromarray(pred_mask_color).convert('RGB')
    return image, pred_mask_color

In [5]:
# Collect every validation frame, one sub-directory per city.
val_root = './data/leftImg8bit/val/'
file_list = []
# os.listdir returns entries in arbitrary order, so sort at both levels;
# otherwise the seeded shuffle below would select different images on
# different machines or filesystems.
for city in sorted(os.listdir(val_root)):
    img_dir = os.path.join(val_root, city)
    if not os.path.isdir(img_dir):
        # Skip stray files in the root (e.g. .DS_Store) instead of crashing
        # when trying to list them as directories.
        continue
    for file_name in sorted(os.listdir(img_dir)):
        if file_name.endswith('_leftImg8bit.png'):
            img_path = os.path.join(img_dir, file_name)
            file_list.append(img_path)

print(f"Found {len(file_list)} images.")
# Fixed seed + sorted input => the same shuffled order on every run.
random.seed(42)
random.shuffle(file_list)

Found 500 images.


In [6]:
# Run inference on the first 50 shuffled images and keep one overlay per image.
subset = file_list[0:50]
results = []
for i, file_path in enumerate(subset):
    # '\r' returns the cursor to the line start so the counter updates in place.
    print(f"Processing {i + 1}/{len(subset)}", end='\r')
    image, pred_mask_color = predict(file_path)
    # alpha=0.3: 70% original image, 30% colorized prediction.
    blended = Image.blend(image, pred_mask_color, 0.3)
    results.append(blended)

Processing 50/50

In [7]:
# Write the blended frames as a single animated GIF.
gif_path = "./outputs/unet.gif"
# Create the output directory up front so saving also works on a fresh
# checkout where ./outputs/ does not exist yet.
os.makedirs(os.path.dirname(gif_path), exist_ok=True)
results[0].save(
    gif_path,
    format="GIF",
    save_all=True,
    # All frames after the first; `results` holds at most 50 entries, so this
    # matches the previous [1:50] slice without repeating the magic number.
    append_images=results[1:],
    duration=5000,  # display time per frame, in milliseconds
    loop=0,  # 0 = repeat forever
)