In [None]:
import os
import shutil
import sys
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
from google.colab import drive

# --- 1. MOUNT DRIVE ---
drive.mount('/content/drive')

# --- 2. RESTORE MODEL CODE ---
# Clone the project repository into a scratch folder and keep only its `models/` package.
if not os.path.exists('/content/models'):
    import subprocess
    # A scratch clone left over from an interrupted run would make `git clone` fail.
    if os.path.exists('temp_repo'):
        shutil.rmtree('temp_repo')
    # check=True: a failed clone stops here instead of surfacing later as a copytree error.
    subprocess.run(
        ['git', 'clone', 'https://github.com/Gabrysse/MLDL2024_project1.git', 'temp_repo'],
        check=True,
    )
    shutil.copytree('temp_repo/models', '/content/models')
    shutil.rmtree('temp_repo')
    print("Models restored.")

# --- 3. RESTORE DATASET ---
# Check if dataset exists, if not, unzip it
if not os.path.exists('/content/dataset/project_data/gta5'):
    print("Dataset not found. checking for zip file...")

    # Candidate archive locations, checked in order (semseg folder first, then Drive root)
    zip_candidates = [
        '/content/drive/MyDrive/semseg/project_data.zip',
        '/content/drive/MyDrive/project_data.zip',
    ]
    zip_path = next((p for p in zip_candidates if os.path.exists(p)), None)
    if zip_path is None:
        # Stop here: every later cell reads the dataset and would otherwise fail with a
        # much less obvious error.
        raise FileNotFoundError("Error: 'project_data.zip' not found in Drive. Please check your path.")
    print(f"Unzipping from {zip_path}...")
    shutil.unpack_archive(zip_path, '/content/dataset')
    print("Dataset extracted!")
else:
    print("Dataset is ready.")

# Add models to path (guarded so re-running the cell does not append duplicates)
if '/content/models' not in sys.path:
    sys.path.append('/content/models')

Mounted at /content/drive
Cloning into 'temp_repo'...
remote: Enumerating objects: 34, done.[K
remote: Counting objects: 100% (21/21), done.[K
remote: Compressing objects: 100% (18/18), done.[K
Receiving objects: 100% (34/34), 11.29 KiB | 11.29 MiB/s, done.
Resolving deltas: 100% (9/9), done.
remote: Total 34 (delta 9), reused 3 (delta 3), pack-reused 13 (from 1)[K
Models restored.
Dataset not found. checking for zip file...
Unzipping from /content/drive/MyDrive/semseg/project_data.zip...
Dataset extracted!


In [None]:
import torch.fft

def extract_ampl_phase(fft_im):
    """Split a real/imag-stacked spectrum into amplitude and phase.

    Args:
        fft_im: 5-D tensor whose last axis holds (real, imag) pairs, i.e. the legacy
            ``torch.rfft`` layout (original note: b x 3 x h x 2w).

    Returns:
        ``(amplitude, phase)``, each with the last axis removed.
    """
    real_part = fft_im[:, :, :, :, 0]
    imag_part = fft_im[:, :, :, :, 1]
    amplitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
    phase = torch.atan2(imag_part, real_part)
    return amplitude, phase

def low_freq_mutate(amp_src, amp_trg, L=0.1):
    """Copy the four corner windows of ``amp_trg`` into ``amp_src`` (in place).

    In an un-shifted *full* 2-D spectrum the four corners hold the lowest positive and
    negative frequencies, so this transplants the target's low-frequency amplitude.

    Args:
        amp_src: 4-D amplitude tensor [Batch, Channel, Height, Width]; modified in place.
        amp_trg: amplitude tensor indexable with the same corner windows as ``amp_src``.
        L: window side as a fraction of min(Height, Width), floored to an integer.

    Returns:
        ``amp_src`` itself, after mutation.
    """
    _, _, height, width = amp_src.size()
    side = int(np.floor(min(height, width) * L))
    row_windows = (slice(0, side), slice(height - side, height))  # top, bottom
    col_windows = (slice(0, side), slice(width - side, width))    # left, right
    # Same write order as top-left, top-right, bottom-left, bottom-right.
    for rows in row_windows:
        for cols in col_windows:
            amp_src[:, :, rows, cols] = amp_trg[:, :, rows, cols]
    return amp_src

def FDA_source_to_target(src_img, trg_img, L=0.1):
    """Fourier Domain Adaptation: give ``src_img`` the low-frequency amplitude of ``trg_img``.

    Args:
        src_img: source image tensor [C, H, W].
        trg_img: style image tensor [C, H, W]; assumed to have the same shape as ``src_img``.
        L: side of the swapped low-frequency window as a fraction of min(H, W) (FDA beta).

    Returns:
        Real tensor [C, H, W] combining the source phase with the target's low-frequency
        amplitude. Values are not clamped and may fall slightly outside the input range.
    """
    # 1. Add a fake batch dimension to support 3D inputs [C,H,W] -> [1,C,H,W]
    src_img = src_img.unsqueeze(0)
    trg_img = trg_img.unsqueeze(0)

    # Full (two-sided) 2-D FFT. low_freq_mutate swaps all four spectrum corners, which are
    # the low frequencies only in a full, un-shifted spectrum. In the one-sided rfft2 output
    # the right-hand columns are the *highest* horizontal frequencies (near Nyquist), so
    # using rfft2 here also copied the target's high-frequency content into the source.
    fft_src = torch.fft.fft2(src_img)
    fft_trg = torch.fft.fft2(trg_img)

    # extract amplitude and phase; only the target's amplitude is needed
    amp_src, pha_src = torch.abs(fft_src), torch.angle(fft_src)
    amp_trg = torch.abs(fft_trg)

    # mutate the amplitude part of source with target (amp_src is a fresh tensor,
    # so mutating it in place does not touch the inputs)
    amp_src_ = low_freq_mutate(amp_src, amp_trg, L=L)

    # mutated fft of source
    fft_src_ = torch.polar(amp_src_, pha_src)

    # The swapped windows are not exactly Hermitian-symmetric, so the inverse carries a tiny
    # imaginary residue; keep the real part, as the NumPy reference FDA implementation does.
    src_in_trg = torch.fft.ifft2(fft_src_).real

    # 2. Remove the fake batch dimension to return [C,H,W]
    return src_in_trg.squeeze(0)

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import os
import random
import sys
from models.bisenet.build_bisenet import BiSeNet

# --- CONFIGURATION ---
# Training hyper-parameters
EPOCHS = 50
BATCH_SIZE = 8
L_BETA = 0.05  # FDA window fraction; controls how much style is swapped (0.01 - 0.09 is standard)
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Paths: local dataset copy, checkpoint on the mounted Drive
CHECKPOINT_NAME = 'bisenet_fda_checkpoint.pth'
GTA_PATH = '/content/dataset/project_data/gta5'
CITYSCAPES_PATH = '/content/dataset/project_data/cityscapes'
SAVE_PATH = os.path.join('/content/drive/MyDrive/semseg', CHECKPOINT_NAME)

# --- DATASET WITH FDA ---
class GTA5_FDA_Dataset(Dataset):
    """GTA5 training set whose images are restyled towards Cityscapes with FDA.

    Each item pairs a GTA5 frame with a randomly drawn Cityscapes training frame and
    transplants the latter's low-frequency amplitude onto it via ``FDA_source_to_target``
    with window ``L_BETA`` (both looked up as notebook globals at call time).

    Args:
        gta_root: folder with ``images/`` and ``labels/`` sub-folders whose files share names.
        city_root: Cityscapes root; style images are ``leftImg8bit/train/<city>/*_leftImg8bit.png``.

    Raises:
        FileNotFoundError: if no Cityscapes style images are found.

    Items are ``(image, mask)``: a normalised float tensor [3, 720, 1280] and a long tensor
    of train ids 0-18 with 255 for every unmapped label id (assumes single-channel label PNGs).
    """
    def __init__(self, gta_root, city_root):
        self.gta_images_dir = os.path.join(gta_root, 'images')
        self.gta_masks_dir = os.path.join(gta_root, 'labels')
        self.gta_images = sorted(os.listdir(self.gta_images_dir))

        # Load Cityscapes images list (to steal style from). Sorted so the list order does
        # not depend on the filesystem's listing order.
        self.city_images_dir = os.path.join(city_root, 'leftImg8bit', 'train')
        self.city_images = []
        if os.path.exists(self.city_images_dir):
            for city in sorted(os.listdir(self.city_images_dir)):
                c_path = os.path.join(self.city_images_dir, city)
                if os.path.isdir(c_path):
                    for f in sorted(os.listdir(c_path)):
                        if f.endswith('_leftImg8bit.png'):
                            self.city_images.append(os.path.join(c_path, f))
        if not self.city_images:
            # Fail here rather than with an opaque "empty range" error inside a DataLoader worker.
            raise FileNotFoundError(f"No '*_leftImg8bit.png' style images found under {self.city_images_dir}")

        self.normalize = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # GTA5/Cityscapes label id -> 19-class train id; everything else becomes 255 (ignored).
        self.id_mapping = {
            7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5,
            19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12,
            26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18
        }

        self.to_tensor = transforms.ToTensor()

    def __len__(self): return len(self.gta_images)

    def __getitem__(self, idx):
        # 1. Load GTA Image & Mask (mask file shares the image's file name)
        img_path = os.path.join(self.gta_images_dir, self.gta_images[idx])
        mask_path = os.path.join(self.gta_masks_dir, self.gta_images[idx])

        gta_image = Image.open(img_path).convert('RGB').resize((1280, 720), Image.BILINEAR)
        gta_mask = Image.open(mask_path).resize((1280, 720), Image.NEAREST)

        # 2. Load Random Cityscapes Image (Target Style), resized to the same size as the source
        city_path = random.choice(self.city_images)
        city_image = Image.open(city_path).convert('RGB').resize((1280, 720), Image.BILINEAR)

        # 3. Convert to Tensor
        gta_t = self.to_tensor(gta_image)
        city_t = self.to_tensor(city_image)

        # 4. Apply FDA (Style Transfer)
        # Only tensor-level failures (RuntimeError) fall back to the un-stylized image. A broad
        # `except Exception` would also hide e.g. a NameError when the FDA cell was not run,
        # silently training on plain GTA5 images for the whole run.
        try:
            gta_stylized = FDA_source_to_target(gta_t, city_t, L=L_BETA)
        except RuntimeError as e:
            print(f"FDA Error: {e}")
            gta_stylized = gta_t # Fallback to original if FDA fails

        # 5. Normalize
        gta_stylized = self.normalize(gta_stylized)

        # 6. Process Mask
        mask_np = np.array(gta_mask)
        target_mask = np.full(mask_np.shape, 255, dtype=np.uint8)
        for k, v in self.id_mapping.items(): target_mask[mask_np == k] = v

        return gta_stylized, torch.from_numpy(target_mask).long()

# --- TRAINING LOOP ---
print(f"Starting FDA Training (Beta={L_BETA})...")

model = BiSeNet(num_classes=19, context_path='resnet18').to(DEVICE)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(ignore_index=255)

dataset = GTA5_FDA_Dataset(GTA_PATH, CITYSCAPES_PATH)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

if os.path.exists(SAVE_PATH):
    print("Resuming from checkpoint...")
    checkpoint = torch.load(SAVE_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Restore SGD momentum buffers so a resumed run continues the same optimisation;
    # older checkpoints only stored the model, hence the guard.
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint.get('epoch', 0) + 1
else:
    start_epoch = 0

for epoch in range(start_epoch, EPOCHS):
    model.train()
    for i, (img, lbl) in enumerate(loader):
        img, lbl = img.to(DEVICE), lbl.to(DEVICE)
        optimizer.zero_grad()
        out = model(img)
        # out[0] is the main prediction; out[1], out[2] are presumably BiSeNet's auxiliary
        # heads (weighted 0.1 each) — confirm against models/bisenet.
        loss = criterion(out[0], lbl) + 0.1 * criterion(out[1], lbl) + 0.1 * criterion(out[2], lbl)
        loss.backward()
        optimizer.step()

        if i % 50 == 0:
            print(f"Epoch [{epoch+1}/{EPOCHS}] Step [{i}/{len(loader)}] Loss: {loss.item():.4f}")

    # Save every epoch. Write to a temp file and swap it in with os.replace so an
    # interruption during the Drive write cannot leave a truncated checkpoint.
    tmp_path = SAVE_PATH + '.tmp'
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch}, tmp_path)
    os.replace(tmp_path, SAVE_PATH)
    print(f"Epoch {epoch+1} Saved.")

Starting FDA Training (Beta=0.05)...
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 122MB/s]


Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth


100%|██████████| 171M/171M [00:01<00:00, 90.3MB/s]


Epoch [1/50] Step [0/313] Loss: 4.0714
Epoch [1/50] Step [50/313] Loss: 0.9045
Epoch [1/50] Step [100/313] Loss: 0.8016
Epoch [1/50] Step [150/313] Loss: 0.6505
Epoch [1/50] Step [200/313] Loss: 0.6071
Epoch [1/50] Step [250/313] Loss: 0.6166
Epoch [1/50] Step [300/313] Loss: 0.5493
Epoch 1 Saved.
Epoch [2/50] Step [0/313] Loss: 0.5026
Epoch [2/50] Step [50/313] Loss: 0.5384
Epoch [2/50] Step [100/313] Loss: 0.4582
Epoch [2/50] Step [150/313] Loss: 0.4515
Epoch [2/50] Step [200/313] Loss: 0.4586
Epoch [2/50] Step [250/313] Loss: 0.3605
Epoch [2/50] Step [300/313] Loss: 0.3965
Epoch 2 Saved.
Epoch [3/50] Step [0/313] Loss: 0.4182
Epoch [3/50] Step [50/313] Loss: 0.4937
Epoch [3/50] Step [100/313] Loss: 0.3657
Epoch [3/50] Step [150/313] Loss: 0.3702
Epoch [3/50] Step [200/313] Loss: 0.3456
Epoch [3/50] Step [250/313] Loss: 0.4183
Epoch [3/50] Step [300/313] Loss: 0.3498
Epoch 3 Saved.
Epoch [4/50] Step [0/313] Loss: 0.3558
Epoch [4/50] Step [50/313] Loss: 0.3469
Epoch [4/50] Step [100/3

In [None]:
from models.bisenet.build_bisenet import BiSeNet

# Configuration
# Training hyper-parameters
EPOCHS = 50
BATCH_SIZE = 8
L_BETA = 0.05  # FDA low-frequency window fraction
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Paths: local dataset copy, checkpoint on the mounted Drive
CHECKPOINT_NAME = 'bisenet_fda_checkpoint.pth'
GTA_PATH = '/content/dataset/project_data/gta5'
CITYSCAPES_PATH = '/content/dataset/project_data/cityscapes'
SAVE_PATH = os.path.join('/content/drive/MyDrive/semseg', CHECKPOINT_NAME)

class GTA5_FDA_Dataset(Dataset):
    """GTA5 training set whose images are restyled towards Cityscapes with FDA.

    Each item pairs a GTA5 frame with a randomly drawn Cityscapes training frame and
    transplants the latter's low-frequency amplitude onto it via ``FDA_source_to_target``
    with window ``L_BETA`` (both looked up as notebook globals at call time).

    Args:
        gta_root: folder with ``images/`` and ``labels/`` sub-folders whose files share names.
        city_root: Cityscapes root; style images are ``leftImg8bit/train/<city>/*_leftImg8bit.png``.

    Raises:
        FileNotFoundError: if no Cityscapes style images are found.

    Items are ``(image, mask)``: a normalised float tensor [3, 720, 1280] and a long tensor
    of train ids 0-18 with 255 for every unmapped label id (assumes single-channel label PNGs).
    """
    def __init__(self, gta_root, city_root):
        self.gta_images_dir = os.path.join(gta_root, 'images')
        self.gta_masks_dir = os.path.join(gta_root, 'labels')
        self.gta_images = sorted(os.listdir(self.gta_images_dir))

        # Style images, sorted so the list order does not depend on filesystem listing order.
        self.city_images_dir = os.path.join(city_root, 'leftImg8bit', 'train')
        self.city_images = []
        if os.path.exists(self.city_images_dir):
            for city in sorted(os.listdir(self.city_images_dir)):
                c_path = os.path.join(self.city_images_dir, city)
                if os.path.isdir(c_path):
                    for f in sorted(os.listdir(c_path)):
                        if f.endswith('_leftImg8bit.png'):
                            self.city_images.append(os.path.join(c_path, f))
        if not self.city_images:
            # Fail here rather than with an opaque "empty range" error inside a DataLoader worker.
            raise FileNotFoundError(f"No '*_leftImg8bit.png' style images found under {self.city_images_dir}")

        self.normalize = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # GTA5/Cityscapes label id -> 19-class train id; everything else becomes 255 (ignored).
        self.id_mapping = {
            7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5,
            19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12,
            26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18
        }
        self.to_tensor = transforms.ToTensor()

    def __len__(self): return len(self.gta_images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.gta_images_dir, self.gta_images[idx])
        mask_path = os.path.join(self.gta_masks_dir, self.gta_images[idx])

        gta_image = Image.open(img_path).convert('RGB').resize((1280, 720), Image.BILINEAR)
        gta_mask = Image.open(mask_path).resize((1280, 720), Image.NEAREST)

        city_path = random.choice(self.city_images)
        city_image = Image.open(city_path).convert('RGB').resize((1280, 720), Image.BILINEAR)

        gta_t = self.to_tensor(gta_image)
        city_t = self.to_tensor(city_image)

        # Only tensor-level failures (RuntimeError) fall back to the un-stylized image. A broad
        # `except Exception` would also hide e.g. a NameError when the FDA cell was not run,
        # silently training on plain GTA5 images for the whole run.
        try:
            gta_stylized = FDA_source_to_target(gta_t, city_t, L=L_BETA)
        except RuntimeError as e:
            print(f"FDA Error: {e}")
            gta_stylized = gta_t

        gta_stylized = self.normalize(gta_stylized)
        mask_np = np.array(gta_mask)
        target_mask = np.full(mask_np.shape, 255, dtype=np.uint8)
        for k, v in self.id_mapping.items(): target_mask[mask_np == k] = v
        return gta_stylized, torch.from_numpy(target_mask).long()

print(f"Resuming FDA Training (Beta={L_BETA})...")

model = BiSeNet(num_classes=19, context_path='resnet18').to(DEVICE)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(ignore_index=255)

if os.path.exists(SAVE_PATH):
    print("Loading Checkpoint...")
    checkpoint = torch.load(SAVE_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Restore SGD momentum buffers so the resumed run continues the same optimisation;
    # older checkpoints only stored the model, hence the guard.
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint.get('epoch', 0) + 1
    print(f"Starting from Epoch {start_epoch + 1}")
else:
    print("No checkpoint found! Starting from Epoch 1")
    start_epoch = 0

dataset = GTA5_FDA_Dataset(GTA_PATH, CITYSCAPES_PATH)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

for epoch in range(start_epoch, EPOCHS):
    model.train()
    for i, (img, lbl) in enumerate(loader):
        img, lbl = img.to(DEVICE), lbl.to(DEVICE)
        optimizer.zero_grad()
        out = model(img)
        # Main output plus two outputs weighted 0.1 (presumably BiSeNet auxiliary heads).
        loss = criterion(out[0], lbl) + 0.1 * criterion(out[1], lbl) + 0.1 * criterion(out[2], lbl)
        loss.backward()
        optimizer.step()

        if i % 50 == 0:
            print(f"Epoch [{epoch+1}/{EPOCHS}] Step [{i}/{len(loader)}] Loss: {loss.item():.4f}")

    # Write to a temp file and swap it in with os.replace so an interruption during the
    # Drive write cannot leave a truncated checkpoint.
    tmp_path = SAVE_PATH + '.tmp'
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch}, tmp_path)
    os.replace(tmp_path, SAVE_PATH)
    print(f"Epoch {epoch+1} Saved.")

Resuming FDA Training (Beta=0.05)...
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 200MB/s]


Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth


100%|██████████| 171M/171M [00:01<00:00, 126MB/s]


Loading Checkpoint...
Starting from Epoch 25
Epoch [25/50] Step [0/313] Loss: 0.1825
Epoch [25/50] Step [50/313] Loss: 0.2239
Epoch [25/50] Step [100/313] Loss: 0.2482
Epoch [25/50] Step [150/313] Loss: 0.1816
Epoch [25/50] Step [200/313] Loss: 0.2578
Epoch [25/50] Step [250/313] Loss: 0.2024
Epoch [25/50] Step [300/313] Loss: 0.1912
Epoch 25 Saved.
Epoch [26/50] Step [0/313] Loss: 0.1703
Epoch [26/50] Step [50/313] Loss: 0.1905
Epoch [26/50] Step [100/313] Loss: 0.1911
Epoch [26/50] Step [150/313] Loss: 0.1717
Epoch [26/50] Step [200/313] Loss: 0.1811
Epoch [26/50] Step [250/313] Loss: 0.2127
Epoch [26/50] Step [300/313] Loss: 0.1817
Epoch 26 Saved.
Epoch [27/50] Step [0/313] Loss: 0.1939
Epoch [27/50] Step [50/313] Loss: 0.1629
Epoch [27/50] Step [100/313] Loss: 0.1733
Epoch [27/50] Step [150/313] Loss: 0.1810
Epoch [27/50] Step [200/313] Loss: 0.1659
Epoch [27/50] Step [250/313] Loss: 0.1810
Epoch [27/50] Step [300/313] Loss: 0.1686
Epoch 27 Saved.
Epoch [28/50] Step [0/313] Loss: 0

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as transforms
import os
import random
from models.bisenet.build_bisenet import BiSeNet

# Configuration
# Training hyper-parameters
EPOCHS = 50
BATCH_SIZE = 8
L_BETA = 0.05  # FDA low-frequency window fraction
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Paths: local dataset copy, checkpoint on the mounted Drive
CHECKPOINT_NAME = 'bisenet_fda_checkpoint.pth'
GTA_PATH = '/content/dataset/project_data/gta5'
CITYSCAPES_PATH = '/content/dataset/project_data/cityscapes'
SAVE_PATH = os.path.join('/content/drive/MyDrive/semseg', CHECKPOINT_NAME)

class GTA5_FDA_Dataset(Dataset):
    """GTA5 training set whose images are restyled towards Cityscapes with FDA.

    Each item pairs a GTA5 frame with a randomly drawn Cityscapes training frame and
    transplants the latter's low-frequency amplitude onto it via ``FDA_source_to_target``
    with window ``L_BETA`` (both looked up as notebook globals at call time).

    Args:
        gta_root: folder with ``images/`` and ``labels/`` sub-folders whose files share names.
        city_root: Cityscapes root; style images are ``leftImg8bit/train/<city>/*_leftImg8bit.png``.

    Raises:
        FileNotFoundError: if no Cityscapes style images are found.

    Items are ``(image, mask)``: a normalised float tensor [3, 720, 1280] and a long tensor
    of train ids 0-18 with 255 for every unmapped label id (assumes single-channel label PNGs).
    """
    def __init__(self, gta_root, city_root):
        self.gta_images_dir = os.path.join(gta_root, 'images')
        self.gta_masks_dir = os.path.join(gta_root, 'labels')
        self.gta_images = sorted(os.listdir(self.gta_images_dir))

        # Style images, sorted so the list order does not depend on filesystem listing order.
        self.city_images_dir = os.path.join(city_root, 'leftImg8bit', 'train')
        self.city_images = []
        if os.path.exists(self.city_images_dir):
            for city in sorted(os.listdir(self.city_images_dir)):
                c_path = os.path.join(self.city_images_dir, city)
                if os.path.isdir(c_path):
                    for f in sorted(os.listdir(c_path)):
                        if f.endswith('_leftImg8bit.png'):
                            self.city_images.append(os.path.join(c_path, f))
        if not self.city_images:
            # Fail here rather than with an opaque "empty range" error inside a DataLoader worker.
            raise FileNotFoundError(f"No '*_leftImg8bit.png' style images found under {self.city_images_dir}")

        self.normalize = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # GTA5/Cityscapes label id -> 19-class train id; everything else becomes 255 (ignored).
        self.id_mapping = {
            7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5,
            19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12,
            26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18
        }
        self.to_tensor = transforms.ToTensor()

    def __len__(self): return len(self.gta_images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.gta_images_dir, self.gta_images[idx])
        mask_path = os.path.join(self.gta_masks_dir, self.gta_images[idx])

        gta_image = Image.open(img_path).convert('RGB').resize((1280, 720), Image.BILINEAR)
        gta_mask = Image.open(mask_path).resize((1280, 720), Image.NEAREST)

        city_path = random.choice(self.city_images)
        city_image = Image.open(city_path).convert('RGB').resize((1280, 720), Image.BILINEAR)

        gta_t = self.to_tensor(gta_image)
        city_t = self.to_tensor(city_image)

        # Only tensor-level failures (RuntimeError) fall back to the un-stylized image, and the
        # fallback is reported. The previous bare `except Exception` silently hid every error,
        # including a NameError when the FDA cell was not run.
        try:
            gta_stylized = FDA_source_to_target(gta_t, city_t, L=L_BETA)
        except RuntimeError as e:
            print(f"FDA Error: {e}")
            gta_stylized = gta_t

        gta_stylized = self.normalize(gta_stylized)
        mask_np = np.array(gta_mask)
        target_mask = np.full(mask_np.shape, 255, dtype=np.uint8)
        for k, v in self.id_mapping.items(): target_mask[mask_np == k] = v
        return gta_stylized, torch.from_numpy(target_mask).long()

print("Resuming FDA Training...")

model = BiSeNet(num_classes=19, context_path='resnet18').to(DEVICE)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(ignore_index=255)

if os.path.exists(SAVE_PATH):
    print("Found checkpoint. Loading...")
    checkpoint = torch.load(SAVE_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Restore SGD momentum buffers so the resumed run continues the same optimisation;
    # older checkpoints only stored the model, hence the guard.
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint.get('epoch', 0) + 1
    # 1-based, matching the "Epoch [n/EPOCHS]" progress lines (start_epoch itself is 0-based).
    print(f"Starting from Epoch {start_epoch + 1}")
else:
    print("No checkpoint found. Starting from Epoch 1")
    start_epoch = 0

dataset = GTA5_FDA_Dataset(GTA_PATH, CITYSCAPES_PATH)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

for epoch in range(start_epoch, EPOCHS):
    model.train()
    for i, (img, lbl) in enumerate(loader):
        img, lbl = img.to(DEVICE), lbl.to(DEVICE)
        optimizer.zero_grad()
        out = model(img)
        # Main output plus two outputs weighted 0.1 (presumably BiSeNet auxiliary heads).
        loss = criterion(out[0], lbl) + 0.1 * criterion(out[1], lbl) + 0.1 * criterion(out[2], lbl)
        loss.backward()
        optimizer.step()

        if i % 50 == 0:
            print(f"Epoch [{epoch+1}/{EPOCHS}] Step [{i}/{len(loader)}] Loss: {loss.item():.4f}")

    # Write to a temp file and swap it in with os.replace so an interruption during the
    # Drive write cannot leave a truncated checkpoint.
    tmp_path = SAVE_PATH + '.tmp'
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch}, tmp_path)
    os.replace(tmp_path, SAVE_PATH)
    print(f"Epoch {epoch+1} Saved.")

Resuming FDA Training...
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 209MB/s]


Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth


100%|██████████| 171M/171M [00:01<00:00, 169MB/s]


Found checkpoint. Loading...
Starting from Epoch 48
Epoch [49/50] Step [0/313] Loss: 0.1600
Epoch [49/50] Step [50/313] Loss: 0.1380
Epoch [49/50] Step [100/313] Loss: 0.1621
Epoch [49/50] Step [150/313] Loss: 0.1383
Epoch [49/50] Step [200/313] Loss: 0.1651
Epoch [49/50] Step [250/313] Loss: 0.1861
Epoch [49/50] Step [300/313] Loss: 0.1486
Epoch 49 Saved.
Epoch [50/50] Step [0/313] Loss: 0.1524
Epoch [50/50] Step [50/313] Loss: 0.1645
Epoch [50/50] Step [100/313] Loss: 0.1374
Epoch [50/50] Step [150/313] Loss: 0.1317
Epoch [50/50] Step [200/313] Loss: 0.2614
Epoch [50/50] Step [250/313] Loss: 0.1706
Epoch [50/50] Step [300/313] Loss: 0.1393
Epoch 50 Saved.


In [None]:
import torch
import os
import sys
import numpy as np
from torch.utils.data import DataLoader, Dataset
from models.bisenet.build_bisenet import BiSeNet
from PIL import Image
from tqdm import tqdm
import torchvision.transforms as transforms

# Configuration
# Data / checkpoint locations
CITYSCAPES_PATH = '/content/dataset/project_data/cityscapes'
CHECKPOINT_PATH = os.path.join('/content/drive/MyDrive/semseg', 'bisenet_fda_checkpoint.pth')

# Evaluation settings
NUM_CLASSES = 19
BATCH_SIZE = 4
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Class names in the order the per-class IoU report indexes them (iou[i] <-> CLASSES[i])
CLASSES = [
    "Road", "Sidewalk", "Building", "Wall", "Fence", "Pole",
    "Traffic Light", "Traffic Sign", "Vegetation", "Terrain", "Sky",
    "Person", "Rider", "Car", "Truck", "Bus", "Train", "Motorcycle", "Bicycle"
]

if os.path.exists('/content/models'):
    sys.path.append('/content/models')

class CityscapesDataset(Dataset):
    """Cityscapes split as (image, train-id mask) pairs resized to 1024x512.

    Images are discovered as ``leftImg8bit/<split>/<city>/*_leftImg8bit.png`` (cities and
    files in sorted order). Each mask path is derived by swapping the suffix for
    ``_gtFine_labelTrainIds.png`` under ``gtFine/<split>/<city>/``; mask files are not
    checked for existence here. A missing image folder yields an empty dataset.

    Args:
        root: Cityscapes root directory.
        split: sub-folder name such as ``'val'``.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, root, split='val', transform=None):
        self.root = root
        self.transform = transform
        self.images_dir = os.path.join(root, 'leftImg8bit', split)
        self.masks_dir = os.path.join(root, 'gtFine', split)
        self.images = []
        self.masks = []
        if not os.path.exists(self.images_dir):
            return
        image_suffix = '_leftImg8bit.png'
        for city in sorted(os.listdir(self.images_dir)):
            city_img_dir = os.path.join(self.images_dir, city)
            if not os.path.isdir(city_img_dir):
                continue
            city_mask_dir = os.path.join(self.masks_dir, city)
            matches = [name for name in sorted(os.listdir(city_img_dir)) if name.endswith(image_suffix)]
            self.images.extend(os.path.join(city_img_dir, name) for name in matches)
            self.masks.extend(
                os.path.join(city_mask_dir, name.replace(image_suffix, '_gtFine_labelTrainIds.png'))
                for name in matches
            )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        target_size = (1024, 512)
        image = Image.open(self.images[idx]).convert('RGB').resize(target_size, Image.BILINEAR)
        mask = Image.open(self.masks[idx]).resize(target_size, Image.NEAREST)
        if self.transform:
            image = self.transform(image)
        mask_tensor = torch.from_numpy(np.array(mask)).long()
        return image, mask_tensor

print(f"Evaluating: {CHECKPOINT_PATH}")
model = BiSeNet(num_classes=NUM_CLASSES, context_path='resnet18')
model.to(DEVICE)

if not os.path.exists(CHECKPOINT_PATH):
    # sys.exit() inside a notebook only raises a bare SystemExit; a FileNotFoundError states
    # the problem and still stops "Run All" at this cell.
    raise FileNotFoundError(f"Checkpoint not found at {CHECKPOINT_PATH}")
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
print("Model loaded.")

model.eval()
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = CityscapesDataset(CITYSCAPES_PATH, split='val', transform=transform)
if len(dataset) == 0:
    # Without this, an empty split silently produces an all-NaN IoU table.
    raise FileNotFoundError(f"No validation images found under {dataset.images_dir}")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Confusion matrix: row = ground-truth class, column = predicted class.
hist = np.zeros((NUM_CLASSES, NUM_CLASSES))
print("Processing validation images...")

with torch.no_grad():
    for images, labels in tqdm(dataloader):
        images = images.to(DEVICE)
        labels = labels.numpy()
        output = model(images)
        if isinstance(output, (tuple, list)): output = output[0]
        preds = torch.argmax(output, dim=1).cpu().numpy()
        # Keep only pixels with a valid class id (drops e.g. the 255 "ignore" label).
        mask = (labels >= 0) & (labels < NUM_CLASSES)
        hist += np.bincount(
            NUM_CLASSES * labels[mask].astype(int) + preds[mask],
            minlength=NUM_CLASSES ** 2
        ).reshape(NUM_CLASSES, NUM_CLASSES)

# IoU = TP / (TP + FP + FN). A class absent from both ground truth and predictions gives
# 0/0 = NaN, which nanmean skips; silence the corresponding warning.
with np.errstate(divide='ignore', invalid='ignore'):
    iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
miou = np.nanmean(iou)

print(f"\nFinal mIoU (FDA): {miou * 100:.2f}%")
print("-" * 30)
for i, class_name in enumerate(CLASSES):
    print(f"{class_name:15s}: {iou[i] * 100:.2f}%")
print("-" * 30)

Evaluating: /content/drive/MyDrive/semseg/bisenet_fda_checkpoint.pth
Model loaded.
Processing validation images...


100%|██████████| 125/125 [00:56<00:00,  2.23it/s]


Final mIoU (FDA): 26.98%
------------------------------
Road           : 71.53%
Sidewalk       : 25.06%
Building       : 67.71%
Wall           : 15.36%
Fence          : 12.14%
Pole           : 21.71%
Traffic Light  : 16.41%
Traffic Sign   : 12.93%
Vegetation     : 79.07%
Terrain        : 17.43%
Sky            : 70.83%
Person         : 36.45%
Rider          : 0.29%
Car            : 47.13%
Truck          : 8.06%
Bus            : 0.47%
Train          : 5.19%
Motorcycle     : 4.78%
Bicycle        : 0.02%
------------------------------



