# Environment Setup

In [None]:
# General
import os
import time
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# PyTorch
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchinfo import summary
from torchmetrics import Accuracy, JaccardIndex

# CubiCasa
from floortrans.loaders import FloorplanSVG
from floortrans.loaders.augmentations import (RandomCropToSizeTorch,
                                              ResizePaddedTorch,
                                              Compose,
                                              DictToTensor,
                                              ColorJitterTorch,
                                              RandomRotations)

# Own modules
from models.deeplabv3plus import DeepLabV3Plus
from evaluation_metrics import Metrics, timer

# Drop any cached CUDA allocations left over from earlier work in this kernel
torch.cuda.empty_cache()
print("GPU memory has been released.")

# Pick the compute device: the first CUDA GPU when PyTorch can see one, else the CPU
has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")
device_label = torch.cuda.get_device_name(0) if has_cuda else "CPU"
print(f"Using device: {device_label}")


print("Setup complete.")

# Data Preprocessing and Augmentations

In [2]:
IMAGE_SIZE = (256, 256)

# RandomChoice applies exactly one of the two size-normalising transforms per sample
crop_or_resize = transforms.RandomChoice([
    RandomCropToSizeTorch(data_format='dict', size=IMAGE_SIZE),
    ResizePaddedTorch((0, 0), data_format='dict', size=IMAGE_SIZE),
])

# Training-time augmentation pipeline; the stages run in the order listed
aug = Compose([
    crop_or_resize,
    RandomRotations(format='cubi'),
    DictToTensor(),
    ColorJitterTorch(),
])

# Dataset

In [None]:
# Dataset location, split files and on-disk format
DATA_PATH, FORMAT = 'data/cubicasa5k/', 'lmdb'
TRAIN_PATH, VAL_PATH = 'train.txt', 'val.txt'

# Training split goes through the augmentation pipeline
train_set = FloorplanSVG(DATA_PATH, TRAIN_PATH, format=FORMAT, augmentations=aug)

# To iterate faster (or to stop the kernel from dying), train on a slice instead:
# train_set = Subset(train_set, list(range(1000)))

# Validation split is only converted to tensors, no augmentation
val_set = FloorplanSVG(DATA_PATH, VAL_PATH, format=FORMAT, augmentations=DictToTensor())

print('Train set size:', len(train_set))
print('Validation set size:', len(val_set))

In [None]:
# Draw one random training example to sanity-check what the dataset returns
sample_index = np.random.randint(0, len(train_set))
sample = train_set[sample_index]
sample_image, sample_label = sample['image'], sample['label']

print('Image shape:', sample_image.shape)
print('Label shape:', sample_label.shape)

# The first label channel is treated as the room map, the second as the icon map
print('\nLabel shape (rooms): ', sample_label[0].shape)
print('Label shape (icons): ', sample_label[1].shape)

In [None]:
# Dump the raw image tensor so its value range can be inspected by eye
image_tensor = sample['image']
print('Image: ', image_tensor)

In [None]:
# Dump the raw label tensor so the class ids it contains can be inspected by eye
label_tensor = sample['label']
print('Label: ', label_tensor)

# Visualize Images and Labels

In [None]:
# Scale the image to [0, 255] and cast to unsigned 8-bit integers for display
# NOTE(review): assumes sample['image'] holds values in [0, 1] - confirm
tensor_image = sample['image'] * 255.0
np_image = tensor_image.numpy().astype(np.uint8)

# matplotlib expects channels last: [3, H, W] -> [H, W, 3]
np_image = np.transpose(np_image, (1, 2, 0))

# One row with three panels: input image, room map, icon map
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = (
    (np_image, 'Input Image'),
    (sample['label'][0], 'Room Labels'),
    (sample['label'][1], 'Icon Labels'),
)

for panel_ax, (panel_data, panel_title) in zip(axes, panels):
    panel_ax.imshow(panel_data)
    panel_ax.axis('off')  # Hide ticks and frame
    panel_ax.set_title(panel_title)

plt.tight_layout()
plt.show()


# DataLoader

In [None]:
NUM_WORKERS = 0
BATCH_SIZE = 16
# Validation runs one floorplan at a time.
# NOTE(review): presumably because validation samples are not resized to a common
# shape and therefore cannot be stacked into a batch - confirm.
VAL_BATCH_SIZE = 1
# Pinned host memory only helps host-to-GPU copies; requesting it without a GPU
# just makes DataLoader emit a warning.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
    pin_memory=PIN_MEMORY
)

val_loader = DataLoader(
    val_set,
    batch_size=VAL_BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

print(f'Length of train dataloader: {len(train_loader)} batches of size {BATCH_SIZE}')
# Report the batch size the validation loader actually uses (not the training one)
print(f'Length of val dataloader: {len(val_loader)} batches of size {VAL_BATCH_SIZE}')

# Pull a single training batch to confirm the collated tensor shapes
batch_sample = next(iter(train_loader))
print('\nBatch image shape: ', batch_sample['image'].shape)
print('Batch label shape: ', batch_sample['label'].shape)

# Model Setup

For reference, here are the 23 classes:  

- **Rooms (12):** "Background", "Outdoor", "Wall", "Kitchen", "Living Room", "Bed Room", "Bath", "Entry", "Railing", "Storage", "Garage", "Undefined"  

- **Icons (11):** "No Icon", "Window", "Door", "Closet", "Electrical Appliance", "Toilet", "Sink", "Sauna Bench", "Fire Place", "Bathtub", "Chimney"

In [None]:
# Build the two-headed DeepLabV3+ network and place it on the selected device
model = DeepLabV3Plus(backbone='mobilenetv2', attention=False)
model.to(device)

# Layer-by-layer summary for one training-sized batch of RGB images
image_height, image_width = IMAGE_SIZE
summary(model, input_size=(BATCH_SIZE, 3, image_height, image_width))

In [None]:
# Verifying output shapes with a random batch shaped like a real training batch
sample_input = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE[0], IMAGE_SIZE[1])

# This is a pure shape check, so run it in eval mode without autograd: random noise
# must not update normalisation statistics, and no graph needs to be kept around.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        room_output, icon_output = model(sample_input.to(device))
finally:
    # Leave the model in whatever mode it was in before the check
    model.train(was_training)

print("Room Output Shape: ", room_output.shape)  # Expected: [BATCH_SIZE, 12, H, W]
print("Icon Output Shape: ", icon_output.shape)  # Expected: [BATCH_SIZE, 11, H, W]

# Loss Function

In [10]:
# def multitask_loss(room_output, icon_output, room_labels, icon_labels, alpha=1.0, beta=1.0):
#     """
#     Compute the multitask loss for room and icon segmentation.
    
#     Args:
#         room_output: Model's room segmentation output.
#         icon_output: Model's icon segmentation output.
#         room_labels: Ground truth for room segmentation.
#         icon_labels: Ground truth for icon segmentation.
#         alpha: Weight for room segmentation loss.
#         beta: Weight for icon segmentation loss.
    
#     Returns:
#         total_loss: Combined loss for room and icon segmentation.
#     """

#     room_labels = room_labels.long()
#     icon_labels = icon_labels.long()

#     # Cross-Entropy Loss for room and icon segmentation
#     room_loss = F.cross_entropy(room_output, room_labels)
#     icon_loss = F.cross_entropy(icon_output, icon_labels)
    
#     # Combine losses using weights
#     total_loss = alpha * room_loss + beta * icon_loss
#     return total_loss

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTaskLossWrapper(nn.Module):
    """Uncertainty-weighted sum of the room and icon segmentation losses.

    Each task's cross-entropy loss ``L`` is combined as

        1 / (2 * sigma^2) * L + log(sigma)

    where ``log(sigma)`` is a learnable scalar per task. Storing the logarithm
    keeps ``sigma`` positive and the weight numerically stable. The two scalars
    are ``nn.Parameter``s of this module, so they are only learned if this
    module's parameters are handed to the optimizer alongside the model's.

    Args:
        task_num: Kept for interface compatibility; this wrapper always weighs
            exactly two tasks (rooms and icons).
    """

    def __init__(self, task_num=2):
        super().__init__()
        self.task_num = task_num
        # Learnable log(sigma) per task; 0.0 means sigma = 1, i.e. a weight of 0.5
        self.log_sigma_room = nn.Parameter(torch.tensor(0.0))  # Room segmentation
        self.log_sigma_icon = nn.Parameter(torch.tensor(0.0))  # Icon segmentation

    @staticmethod
    def _weigh(task_loss, log_sigma):
        """Scale one task loss by 1 / (2 * sigma^2) and add the log(sigma) regulariser."""
        # 1 / (2 * sigma^2) == 0.5 * exp(-2 * log_sigma)
        return 0.5 * torch.exp(-2.0 * log_sigma) * task_loss + log_sigma

    def forward(self, room_output, icon_output, room_labels, icon_labels):
        """
        Compute the multi-task loss with uncertainty-driven weighting for room and icon segmentation.

        Args:
            room_output: Model's room segmentation output (class logits).
            icon_output: Model's icon segmentation output (class logits).
            room_labels: Ground truth for room segmentation; cast to long class indices.
            icon_labels: Ground truth for icon segmentation; cast to long class indices.

        Returns:
            total_loss: Scalar tensor, the sum of the two uncertainty-weighted losses.
        """
        # Cross-entropy needs integer class indices as targets
        room_loss = F.cross_entropy(room_output, room_labels.long())
        icon_loss = F.cross_entropy(icon_output, icon_labels.long())

        weighted_room = self._weigh(room_loss, self.log_sigma_room)
        weighted_icon = self._weigh(icon_loss, self.log_sigma_icon)

        return weighted_room + weighted_icon


# Training Hyperparameters

In [12]:
# EPOCHS = 100

# CRITERION = torch.nn.CrossEntropyLoss()

# initial_lr = 0.001
# OPTIMIZER = torch.optim.SGD(model.parameters(), lr=initial_lr, momentum=0.95, weight_decay=1e-4, nesterov=True)

# # Poly learning rate policy (used in DeepLabV3+ paper)
# class PolyLR(torch.optim.lr_scheduler._LRScheduler):
#     def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1):
#         self.max_iters = max_iters
#         self.power = power
#         super(PolyLR, self).__init__(optimizer, last_epoch)

#     def get_lr(self):
#         return [base_lr * (1 - self.last_epoch / self.max_iters) ** self.power for base_lr in self.base_lrs]

# max_iters = EPOCHS * len(train_loader)
# SCHEDULER = PolyLR(OPTIMIZER, max_iters)

# Training and Validation Loop

In [None]:
EPOCHS = 400

# The loss wrapper owns learnable task-weighting parameters, so it is built first,
# moved to the same device as the model, and its parameters are optimised too.
# Passing only model.parameters() would leave those weights at their initial value.
CRITERION = MultiTaskLossWrapper(task_num=2).to(device)
OPTIMIZER = torch.optim.Adam(
    list(model.parameters()) + list(CRITERION.parameters()),
    lr=0.0001,
)

def timer(start_time=None):
    """Return the current time, or the seconds elapsed since ``start_time``.

    NOTE: this definition shadows the ``timer`` imported from
    ``evaluation_metrics`` at the top of the notebook.

    Args:
        start_time: A value previously returned by ``timer()``. When omitted
            (``None``) the current ``time.time()`` timestamp is returned.

    Returns:
        float: The current timestamp, or ``time.time() - start_time``.
    """
    now = time.time()
    # Identity check: a start time of 0.0 is a valid timestamp, not "missing"
    return now if start_time is None else now - start_time


def _build_task_metrics(num_classes, device):
    """Create the four metrics tracked for one segmentation task, placed on ``device``."""
    return {
        'cpa': Accuracy(task='multiclass', num_classes=num_classes, average=None).to(device),
        'mpa': Accuracy(task='multiclass', num_classes=num_classes, average='macro').to(device),
        'miou': JaccardIndex(task='multiclass', num_classes=num_classes, average='macro').to(device),
        'fwiou': JaccardIndex(task='multiclass', num_classes=num_classes, average='weighted').to(device),
    }


def _collect_task_metrics(metrics):
    """Compute every metric for the finished epoch, then reset it for the next one."""
    values = {}
    for name, metric in metrics.items():
        value = metric.compute()
        # 'cpa' uses average=None and is therefore kept as a tensor; it is moved to
        # the CPU so the history does not pin device memory for every epoch.
        # All other metrics are converted to plain Python floats.
        values[name] = value.detach().cpu() if name == 'cpa' else value.item()
        metric.reset()
    return values


def _run_epoch(model, loader, device, loss_fn, metrics, optimizer=None):
    """Run one pass over ``loader``: a training pass if ``optimizer`` is given, else evaluation.

    Returns:
        (mean loss per batch, {'room': {...}, 'icon': {...}} metric values).
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is equivalent to eval()
    running_loss = 0.0

    with torch.set_grad_enabled(is_training):
        for batch in tqdm(loader):
            images = batch['image'].to(device)
            # Label channel 0 is used as the room target, channel 1 as the icon target
            targets = {
                'room': batch['label'][:, 0].to(device),
                'icon': batch['label'][:, 1].to(device),
            }

            if is_training:
                # Reset gradients since PyTorch accumulates previous gradients
                optimizer.zero_grad()

            room_output, icon_output = model(images)
            loss = loss_fn(room_output, icon_output, targets['room'], targets['icon'])
            running_loss += loss.item()

            if is_training:
                loss.backward()
                optimizer.step()

            outputs = {'room': room_output, 'icon': icon_output}
            for task, task_metrics in metrics.items():
                for metric in task_metrics.values():
                    metric.update(outputs[task], targets[task])

    mean_loss = running_loss / len(loader)
    return mean_loss, {task: _collect_task_metrics(m) for task, m in metrics.items()}


def _print_split_results(split_label, mean_loss, task_values):
    """Print one summary line per task for a split ('Train' or 'Val')."""
    # The loss shown on both lines is the same combined multi-task loss
    for task in ('room', 'icon'):
        v = task_values[task]
        print(f'{split_label} {task.capitalize()} - Loss: {mean_loss:.3f}, MPA: {v["mpa"]:.3f}, mIOU: {v["miou"]:.3f}, fwIOU: {v["fwiou"]:.3f}')


def train_evaluate(model,
                   train_loader,
                   val_loader,
                   device,
                   loss_fn,
                   optimizer,
                   epochs,
                   early_stop_threshold=15,
                   save_prefix='deeplabv3plus',
                   save_path='saved_models',
                   num_room_classes=12,
                   num_icon_classes=11):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    Every epoch runs one training pass and one validation pass. The model's
    ``state_dict`` is saved to ``<save_path>/<save_prefix>.pt`` whenever the mean
    validation loss improves, and training stops once it has failed to improve
    for ``early_stop_threshold`` consecutive epochs.

    Args:
        model: Network returning ``(room_output, icon_output)`` for a batch of images.
        train_loader: Iterable of batches with ``'image'`` and ``'label'`` entries.
        val_loader: Same batch layout as ``train_loader``; used without gradient updates.
        device: Device the batches and metrics are moved to.
        loss_fn: Callable ``(room_output, icon_output, room_labels, icon_labels) -> loss``.
        optimizer: Optimizer stepped once per training batch.
        epochs: Maximum number of epochs.
        early_stop_threshold: Number of consecutive non-improving epochs that stops training.
        save_prefix: File name (without extension) of the saved checkpoint.
        save_path: Directory for the checkpoint; created if missing.
        num_room_classes: Number of room classes used to configure the metrics.
        num_icon_classes: Number of icon classes used to configure the metrics.

    Returns:
        dict: ``'{split}_loss'`` and ``'{split}_{task}_{metric}'`` lists with one
        entry per completed epoch, for split in (train, val), task in (room, icon)
        and metric in (cpa, mpa, miou, fwiou). ``cpa`` entries are CPU tensors of
        per-class accuracy; every other entry is a float.
    """
    task_classes = {'room': num_room_classes, 'icon': num_icon_classes}
    metric_names = ('cpa', 'mpa', 'miou', 'fwiou')
    separator = '\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'

    # Store results, to be returned (key order: train_* first, then val_*)
    history = {}
    for split in ('train', 'val'):
        history[f'{split}_loss'] = []
        for task in task_classes:
            for name in metric_names:
                history[f'{split}_{task}_{name}'] = []

    # Separate metric objects per split so their accumulated state never mixes
    train_metrics = {task: _build_task_metrics(n, device) for task, n in task_classes.items()}
    val_metrics = {task: _build_task_metrics(n, device) for task, n in task_classes.items()}

    best_loss = float('inf')
    not_improving = 0

    # Save models in this directory
    os.makedirs(save_path, exist_ok=True)
    checkpoint_file = os.path.join(save_path, f'{save_prefix}.pt')

    train_start = timer()
    print('Start training process...')

    for epoch in range(1, epochs + 1):
        epoch_start = timer()

        print(f'Epoch {epoch} train process started...')
        train_loss, train_values = _run_epoch(
            model, train_loader, device, loss_fn, train_metrics, optimizer)

        print(f'Epoch {epoch} validation process started...')
        val_loss, val_values = _run_epoch(
            model, val_loader, device, loss_fn, val_metrics)

        # Print epoch results
        print(f'Epoch {epoch} train process is completed.')
        print(separator)
        print(f'\nEpoch {epoch} train process results:\n')

        # Measured from the start of the epoch, so it covers the validation pass too
        print(f'Train Time: {timer(epoch_start):.3f} secs')
        _print_split_results('Train', train_loss, train_values)

        print('\nVal process results:')
        _print_split_results('Val', val_loss, val_values)

        # Append results
        for split, split_loss, split_values in (('train', train_loss, train_values),
                                                ('val', val_loss, val_values)):
            history[f'{split}_loss'].append(split_loss)
            for task, task_values in split_values.items():
                for name, value in task_values.items():
                    history[f'{split}_{task}_{name}'].append(value)

        # Save model if validation loss is improved
        if val_loss < best_loss:
            print(f'\nLoss decreased from {best_loss:.3f} to {val_loss:.3f}!')
            best_loss = val_loss
            not_improving = 0  # Reset counter

            print('Saving the model with the best loss value...')
            torch.save(model.state_dict(), checkpoint_file)

        else:
            not_improving += 1
            print(f'\nLoss did not decrease for {not_improving} epoch(s)!')

            if not_improving == early_stop_threshold:
                print(f'Stopping training process because loss did not decrease for {early_stop_threshold} epochs!')
                break

        print(separator)

    print(f'Train process is completed in {(timer(train_start)) / 60:.3f} minutes.')

    return history


# Run training and validation process
history = train_evaluate(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    loss_fn=CRITERION,
    optimizer=OPTIMIZER,
    epochs=EPOCHS,
    # Checkpoint file name carries the backbone so runs with different backbones don't collide
    save_prefix=f'deeplabv3plus_{model.backbone_name}',
)

# Visualizing Performance

In [16]:
# class Plot():
#     def __init__(self, results):
#         self.results = results

#         self.visualize(metric1="tr_iou", 
#                        metric2="val_iou", 
#                        label1="Train IoU",
#                        label2 ="Validation IoU", 
#                        title="Mean Intersection Over Union Learning Curve", 
#                        ylabel="mIoU Score")

#         self.visualize(metric1="tr_pa", 
#                        metric2="val_pa", 
#                        label1="Train PA",
#                        label2="Validation PA", 
#                        title="Pixel Accuracy Learning Curve", 
#                        ylabel="PA Score")

#         self.visualize(metric1="tr_loss", 
#                        metric2="val_loss", 
#                        label1="Train Loss",
#                        label2="Validation Loss", 
#                        title="Loss Learning Curve", 
#                        ylabel="Loss Value")

#     def plot(self, metric, label): 
#         plt.plot(self.results[metric], label=label)

#     def decorate(self, ylabel, title): 
#         plt.title(title)
#         plt.xlabel("Epochs")
#         plt.ylabel(ylabel)
#         plt.legend()
#         plt.show()

#     def visualize(self, metric1, metric2, label1, label2, title, ylabel):
#         plt.figure(figsize=(10, 5))
#         self.plot(metric1, label1)
#         self.plot(metric2, label2)
#         self.decorate(ylabel, title)


# Plot(history)

# Testing 

In [17]:
# class Test():
#     def __init__(self, model, test_loader, loss_fn, device):
#         self.model = model
#         self.test_loader = test_loader
#         self.loss_fn = loss_fn
#         self.device = device
    
#     def run(self):
#         self.model.eval()
#         test_loss = 0
#         test_iou = 0
#         test_pixel_acc = 0
#         test_len = len(self.test_loader)

#         imgs = []
#         gts = []
#         preds = []

#         with torch.no_grad():
#             for batch in tqdm(self.test_loader):
#                 imgs_batch = batch['image']
#                 gts_batch = batch['label']
#                 imgs_batch, gts_batch = imgs_batch.to(self.device), gts_batch.to(self.device)

#                 # Forward pass
#                 preds_batch = self.model(imgs_batch)
                
#                 # Calculate metrics
#                 metrics = Metrics(preds_batch, gts_batch, self.loss_fn)
#                 test_loss += metrics.loss().item()
#                 test_iou += metrics.mIOU()
#                 test_pixel_acc += metrics.PixelAcc()

#                 # Collect data for visualization
#                 preds_batch = torch.argmax(preds_batch, dim=1)
#                 imgs.extend(imgs_batch.cpu())
#                 gts.extend(gts_batch.cpu())
#                 preds.extend(preds_batch.cpu())

#         # Calculate average metrics
#         test_loss /= test_len
#         test_iou /= test_len
#         test_pixel_acc /= test_len

#         return imgs, gts, preds, test_loss, test_iou, test_pixel_acc


# test = Test(model, test_loader, CRITERION, device)
# imgs, gts, preds, test_loss, test_iou, test_pixel_acc = test.run()

# print(f"Test Loss: {test_loss:.4f}")
# print(f"Test mIoU: {test_iou:.4f}")
# print(f"Test PA: {test_pixel_acc:.4f}")