In [1]:
import os
from glob import glob
import time

import numpy as np
import h5py
import cv2

import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset,Subset
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F

import matplotlib as mpl
import matplotlib.pyplot as plt

from torchvision.models import mobilenet_v3_small
from sklearn.metrics import jaccard_score
from torchvision.models.mobilenetv3 import MobileNet_V3_Small_Weights
from torchvision.models import vgg16, VGG16_Weights



import csv

from datetime import datetime

from PIL import Image
import torch.nn.functional as F


from labels import labels

%matplotlib inline

curr_dir=os.getcwd()
root= os.path.join(curr_dir,"cityscapes_dataset")
curr_dir,root

('/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes',
 '/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/cityscapes_dataset')

In [2]:

import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
# from albumentations.augmentations.transforms import RandomShadow

class Normalize(object):
    """ Normalizes RGB image to  0-mean 1-std_dev """ 
    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], depth_norm=5, max_depth=250):
        self.mean = mean
        self.std = std
        self.depth_norm = depth_norm
        self.max_depth = max_depth

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']
            
        return {'left': TF.normalize(left, self.mean, self.std), 
                'mask': mask, 
                'depth' : torch.clip( # saftey clip :)
                            torch.log(torch.clip(depth, 0, self.max_depth))/self.depth_norm, 
                            0, 
                            self.max_depth)}


class AddColorJitter(object):
    """Convert a color image to grayscale and normalize the color range to [0,1].""" 
    def __init__(self, brightness, contrast, saturation, hue):
        ''' Applies brightness, constrast, saturation, and hue jitter to image ''' 
        self.color_jitter = transforms.ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']

        return {'left': self.color_jitter(left), 
                'mask': mask, 
                'depth' : depth}


class Rescale(object):
    """ Rescales images with bilinear interpolation and masks with nearest interpolation """

    def __init__(self, h, w):
        self.h, self.w = h, w

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']
# mask interpolation Nearest is import to have smoothness
        return {'left': TF.resize(left, (self.h, self.w)), 
                'mask': TF.resize(mask.unsqueeze(0), (self.h, self.w), transforms.InterpolationMode.NEAREST), 
                'depth' : TF.resize(depth.unsqueeze(0), (self.h, self.w))}


class RandomCrop(object):
    def __init__(self, h, w, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)):
        self.h = h
        self.w = w
        self.scale = scale
        self.ratio = ratio

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']
        i, j, h, w = transforms.RandomResizedCrop.get_params(left, scale=self.scale, ratio=self.ratio)

        return {'left': TF.resized_crop(left, i, j, h, w, (self.h, self.w)), 
                'mask': TF.resized_crop(mask.unsqueeze(0), i, j, h, w, (self.h, self.w), interpolation=TF.InterpolationMode.NEAREST),
                'depth' : TF.resized_crop(depth.unsqueeze(0), i, j, h, w, (self.h, self.w))}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""
    def __call__(self, sample):
         
        left, mask, depth = sample['left'], sample['mask'], sample['depth']

        return {'left': transforms.ToTensor()(left), 
                'mask': torch.as_tensor(mask, dtype=torch.int64),
                'depth' : transforms.ToTensor()(depth).type(torch.float32)}
    

class ElasticTransform(object):
    def __init__(self, alpha=25.0, sigma=5.0, prob=0.5):
        self.alpha = [1.0, alpha]
        self.sigma = [1, sigma]
        self.prob = prob

    def __call__(self, sample):
        
        if torch.rand(1) < self.prob:

            left, mask, depth = sample['left'], sample['mask'], sample['depth']
            _, H, W = mask.shape
            displacement = transforms.ElasticTransform.get_params(self.alpha, self.sigma, [H, W])

            # # TEMP
            # print(TF.elastic_transform(left, displacement).shape)
            # print(TF.elastic_transform(mask.unsqueeze(0), displacement, interpolation=TF.InterpolationMode.NEAREST).shape)
            # print(torch.clip(TF.elastic_transform(depth, displacement), 0, depth.max()).shape)

            return {'left': TF.elastic_transform(left, displacement), 
                    'mask': TF.elastic_transform(mask.unsqueeze(0), displacement, interpolation=TF.InterpolationMode.NEAREST), 
                    'depth' : torch.clip(TF.elastic_transform(depth, displacement), 0, depth.max())} 
        
        else:
            return sample

        
    

# new transform to rotate the images
class RandomRotate(object):
    def __init__(self, angle):
        if not isinstance(angle, (list, tuple)):
            self.angle = (-abs(angle), abs(angle))
        else:
            self.angle = angle

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']

        angle = transforms.RandomRotation.get_params(self.angle)

        return {'left': TF.rotate(left, angle), 
                'mask': TF.rotate(mask.unsqueeze(0), angle), 
                'depth' : TF.rotate(depth, angle)}
    
    
class RandomHorizontalFlip(object):
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, sample):
        
        if torch.rand(1) < self.prob:
            left, mask, depth = sample['left'], sample['mask'], sample['depth']
            return {'left': TF.hflip(left), 
                    'mask': TF.hflip(mask), 
                    'depth' : TF.hflip(depth)}
        else:
            return sample
        

class RandomVerticalFlip(object):
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, sample):
        if torch.rand(1) < self.prob:
            left, mask, depth = sample['left'], sample['mask'], sample['depth']
            return {'left': TF.vflip(left), 
                    'mask': TF.vflip(mask), 
                    'depth' : TF.vflip(depth)}
        else:
            return sample
        

In [3]:

def convert_to_numpy(image):
    if not isinstance(image, np.ndarray):
        if len(image.shape) == 2:
            image = image.detach().cpu().numpy()
        else:
            image = image.detach().cpu().numpy().transpose(1, 2, 0)

    return image

def get_color_mask(mask, labels, id_type='id'):
    try:
        h, w = mask.shape
    except ValueError:
        mask = mask.squeeze(-1)
        h, w = mask.shape

    color_mask = np.zeros((h, w, 3), dtype=np.uint8)

    if id_type == 'id':
        for lbl in labels:
            color_mask[mask == lbl.id] = lbl.color
    elif id_type == 'trainId':
        for lbl in labels:
            if (lbl.trainId != 255) | (lbl.trainId != -1):
                color_mask[mask == lbl.trainId] = lbl.color

    return color_mask


def plot_items(left, mask, depth, labels=None, num_seg_labels=34, id_type='id'):
    left = convert_to_numpy(left)
    mask = convert_to_numpy(mask)
    depth = convert_to_numpy(depth)

    # unnormalize left image
    left = (left*np.array([0.229, 0.224, 0.225])) + np.array([0.485, 0.456, 0.406])

    # cmaps: 'prism', 'terrain', 'turbo', 'gist_rainbow_r', 'nipy_spectral_r'
    
    
    _, ax = plt.subplots(1, 3, figsize=(15,10))
    ax[0].imshow(left)
    ax[0].set_title("Left Image")

    if labels:
        color_mask = get_color_mask(mask, labels, id_type)
        ax[1].imshow(color_mask)
    else:
        cmap = mpl.colormaps.get_cmap('nipy_spectral_r').resampled(num_seg_labels)
        ax[1].imshow(mask, cmap=cmap)

    ax[1].set_title("Seg Mask")
    ax[2].imshow(depth, cmap='plasma')
    ax[2].set_title("Depth")

In [4]:
def scale_invariant_depth_loss(pred, target, lambda_weight=0.1):
    if pred.shape != target.shape:
        pred = F.interpolate(pred, size=target.shape[1:], mode='bilinear', align_corners=False)
    
    diff = pred - target
    n = diff.numel()
    mse = torch.sum(diff**2) / n
    scale_invariant = mse - (lambda_weight / (n**2)) * (torch.sum(diff))**2
    return scale_invariant

def depth_smoothness_loss(pred, img, alpha=1.0):
    depth_grad_x = torch.abs(pred[:, :, :, :-1] - pred[:, :, :, 1:])
    depth_grad_y = torch.abs(pred[:, :, :-1, :] - pred[:, :, 1:, :])
    img_grad_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), dim=1, keepdim=True)
    img_grad_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), dim=1, keepdim=True)
    smoothness_x = depth_grad_x * torch.exp(-alpha * img_grad_x)
    smoothness_y = depth_grad_y * torch.exp(-alpha * img_grad_y)
    return smoothness_x.mean() + smoothness_y.mean()


def inv_huber_loss(pred, target, delta=0.1):
    """
    Inverse Huber loss for depth prediction.
    Args:
        pred (Tensor): Predicted depth map.
        target (Tensor): Ground truth depth map.
        delta (float): Threshold for switching between quadratic and linear terms.
    Returns:
        Tensor: Inverse Huber loss.
    """
    abs_diff = torch.abs(pred - target)
    delta_tensor = torch.tensor(delta, dtype=abs_diff.dtype, device=abs_diff.device)  # Convert delta to tensor
    quadratic = torch.minimum(abs_diff, delta_tensor)
    linear = abs_diff - quadratic
    return (0.5 * quadratic**2 + delta_tensor * linear).mean()


def mean_iou(pred, target, num_classes):
    pred = torch.argmax(pred, dim=1)
    intersection = torch.logical_and(pred == target, target != 255).float()  # Ignore class 255
    union = torch.logical_or(pred == target, target != 255).float()
    iou = torch.sum(intersection) / torch.sum(union)
    return iou



def contrastive_loss(pred, target, margin=1.0):
    """
    Contrastive loss to ensure the depth map predictions are closer to the target.
    """
    # Flatten the tensors for element-wise operations
    pred_flat = pred.view(pred.size(0), -1)  # Flatten except for the batch dimension
    target_flat = target.view(target.size(0), -1)  # Flatten except for the batch dimension

    # Compute the pairwise distances
    distances = torch.sqrt(torch.sum((pred_flat - target_flat) ** 2, dim=1))  # Batch-wise distances

    # Create labels for contrastive loss
    labels = (torch.abs(pred_flat - target_flat).mean(dim=1) < margin).float()  # Batch-wise labels

    # Calculate contrastive loss
    similar_loss = labels * distances**2
    dissimilar_loss = (1 - labels) * torch.clamp(margin - distances, min=0)**2
    loss = (similar_loss + dissimilar_loss).mean()

    return loss


def dice_loss(predictions, targets, smooth=1e-6):
    """
    Calculate Dice Loss for segmentation.
    Args:
        predictions (torch.Tensor): The predicted segmentation map (logits or probabilities).
                                    Shape: [batch_size, num_classes, height, width]
        targets (torch.Tensor): The ground truth segmentation map (one-hot encoded or integer labels).
                                Shape: [batch_size, height, width]
        smooth (float): Smoothing factor to avoid division by zero.
    Returns:
        torch.Tensor: Dice Loss (scalar).
    """
    # Convert integer labels to one-hot if needed
    if predictions.shape != targets.shape:
        targets = F.one_hot(targets, num_classes=predictions.shape[1]).permute(0, 3, 1, 2).float()
    
    # Apply softmax to predictions for multi-class segmentation
    predictions = torch.softmax(predictions, dim=1)
    
    # Flatten tensors to calculate intersection and union
    predictions_flat = predictions.view(predictions.shape[0], predictions.shape[1], -1)
    targets_flat = targets.view(targets.shape[0], targets.shape[1], -1)
    
    # Calculate intersection and union
    intersection = (predictions_flat * targets_flat).sum(dim=-1)
    union = predictions_flat.sum(dim=-1) + targets_flat.sum(dim=-1)
    
    # Calculate Dice Coefficient
    dice_coeff = (2 * intersection + smooth) / (union + smooth)
    
    # Dice Loss
    return 1 - dice_coeff.mean()

In [5]:
def plot_loss(train_losses, valid_losses, save_dir):
    epochs = range(1, len(train_losses["seg"]) + 1)

    # Plot Segmentation Loss
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses["seg"], label="Train Seg Loss")
    plt.plot(epochs, valid_losses["seg"], label="Valid Seg Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Segmentation Loss")
    plt.legend()
    plt.title("Segmentation Loss Over Epochs")
    plt.savefig(os.path.join(save_dir, "segmentation_loss.png"))
    plt.close()

    # Plot Depth Loss
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses["depth"], label="Train Depth Loss")
    plt.plot(epochs, valid_losses["depth"], label="Valid Depth Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Depth Loss")
    plt.legend()
    plt.title("Depth Loss Over Epochs")
    plt.savefig(os.path.join(save_dir, "depth_loss.png"))
    plt.close()

    # Plot Combined Loss
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses["combined"], label="Train Combined Loss")
    plt.plot(epochs, valid_losses["combined"], label="Valid Combined Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Combined Loss")
    plt.legend()
    plt.title("Combined Loss Over Epochs")
    plt.savefig(os.path.join(save_dir, "combined_loss.png"))
    plt.close()


In [6]:
def save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels):
    inputs = inputs.detach().cpu()
    seg_output = torch.argmax(seg_output, dim=1).detach().cpu()
    depth_output = depth_output.detach().cpu()
    seg_labels = seg_labels.detach().cpu()
    depth_labels = depth_labels.detach().cpu()
    
#     inputs_rgb = (inputs - inputs.min()) / (inputs.max() - inputs.min() + 1e-5)  # Normalize inputs to [0, 1]
    
#     # Normalize depth maps for visualization
#     depth_labels_vis = (depth_labels - depth_labels.min()) / (depth_labels.max() - depth_labels.min() + 1e-5)
#     depth_preds_vis = (depth_output - depth_output.min()) / (depth_output.max() - depth_output.min() + 1e-5)



    batch_size = min(4, inputs.size(0))  # Limit to 4 samples for visualization
    fig, axes = plt.subplots(batch_size, 5, figsize=(15, 4 * batch_size))

    for i in range(batch_size):
        
        inputs_temp = inputs[i]
        # print(f"inputs_temp: {inputs_temp.shape}")
        inputs_rgb = (inputs_temp - inputs_temp.min()) / (inputs_temp.max() - inputs_temp.min() + 1e-5)  # Normalize inputs to [0, 1]
        
        depth_labels_vis = (depth_labels[i] - depth_labels[i].min()) / (depth_labels[i].max() - depth_labels[i].min() + 1e-5)
        depth_preds = depth_output[i]
        depth_preds_vis = (depth_preds - depth_preds.min()) / (depth_preds.max() - depth_preds.min() + 1e-5)
        # print(f"depth_labels_vis: {depth_labels_vis.shape}")
        # print(f"depth_preds_vis: {depth_preds_vis.shape}")

    
        
        # Row 1: Ground truth
        axes[i, 0].imshow(inputs_rgb.permute(1, 2, 0))
        axes[i, 0].set_title("RGB Image")
        axes[i, 0].axis("off")

        axes[i, 1].imshow(seg_labels[i], cmap="tab20")
        axes[i, 1].set_title("GT Segmentation")
        axes[i, 1].axis("off")

        axes[i, 2].imshow(depth_labels_vis.squeeze(), cmap="inferno")
        axes[i, 2].set_title("GT Depth")
        axes[i, 2].axis("off")

        # Row 2: Predictions
        axes[i, 3].imshow(seg_output[i], cmap="tab20")
        axes[i, 3].set_title("Generated Segmentation")
        axes[i, 3].axis("off")

        axes[i, 4].imshow(depth_preds_vis.squeeze(), cmap="inferno")
        axes[i, 4].set_title("Generated Depth")
        axes[i, 4].axis("off")
        
    # Remove axes for cleaner visualization
    for ax in axes.flat:
        ax.axis("off")


    # plt.tight_layout()
    fig.tight_layout()
    fig.canvas.draw()
    
    # # Save current epoch as an image for GIF
    # epoch_img_path = os.path.join(gif_path, f"epoch_{epoch}.png")
    # os.makedirs(gif_path, exist_ok=True)
    # plt.savefig(epoch_img_path)
    # plt.close()
    
    
    # return epoch_img_path
    frame = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)  # Updated to buffer_rgba
    frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (4,))  # RGBA has 4 channels
    plt.close(fig)

    # Convert to PIL.Image for GIF
    frame_rgb = frame[:, :, :3] 

    # Return as PIL.Image for GIF creation
    # return Image.fromarray(frame)
    return Image.fromarray(frame_rgb)




In [7]:
import torchvision.models as models

# Define the ResBlock
class ResBlock(nn.Module):
    def __init__(self, channels):
        super(ResBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        return x + self.conv_block(x)

# Define the CRPBlock
class CRPBlock(nn.Module):
    def __init__(self, in_chans, out_chans, n_stages=4, groups=False):
        super(CRPBlock, self).__init__()
        self.n_stages = n_stages
        groups = in_chans if groups else 1
        self.mini_blocks = nn.ModuleList()
        for _ in range(n_stages):
            self.mini_blocks.append(nn.MaxPool2d(kernel_size=5, stride=1, padding=2))
            self.mini_blocks.append(nn.Conv2d(in_chans, out_chans, kernel_size=1, bias=False, groups=groups))
    
    def forward(self, x):
        out = x
        for block in self.mini_blocks:
            out = block(out)
            x = x + out
        return x

class ResNetBackbone(nn.Module):
    def __init__(self, pretrained=True, feature_dim=256):
        super(ResNetBackbone, self).__init__()
        base_model = models.resnet18(pretrained=pretrained)

        # Freeze pre-trained layers
        for param in base_model.parameters():
            param.requires_grad = False

        # Extract ResNet layers and modify strides/pooling to preserve spatial dimensions
        layers = list(base_model.children())[:-2]  # Remove FC and AvgPool layers
        for layer in layers:
            if isinstance(layer, nn.Conv2d):
                layer.stride = (1, 1)  # Set stride to 1
            elif isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
                layer.stride = (1, 1)  # Avoid reducing dimensions with pooling layers

        self.features = nn.Sequential(*layers)

        # Adjust final feature dimension using a 1x1 convolution
        self.feature_dim = feature_dim
        self.adjust_channels = nn.Conv2d(base_model.fc.in_features, feature_dim, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.features(x)  # Extract features without changing spatial dimensions
        x = self.adjust_channels(x)  # Adjust feature channels
        return x



In [8]:

class CityScapesDataset(Dataset):
    def __init__(self, root, transform=None, split='train', label_map='id', crop=True):
        """
        
        """
        self.root = root
        self.transform = transform
        self.label_map = label_map
        self.crop = crop

        self.left_paths = glob(os.path.join(root, 'leftImg8bit', split, '**/*.png'))
        self.mask_paths = glob(os.path.join(root, 'gtFine', split, '**/*gtFine_labelIds.png'))
        self.depth_paths = glob(os.path.join(root, 'crestereo_depth2', split, '**/*.npy'))

        sorted(self.left_paths)
        sorted(self.mask_paths)
        sorted(self.depth_paths)

        # get label mappings
        self.id_2_train = {}
        self.id_2_cat = {}
        self.train_2_id = {}
        self.id_2_name = {-1 : 'unlabeled'}
        self.trainid_2_name = {19 : 'unlabeled'} # {255 : 'unlabeled', -1 : 'unlabeled'}

        for lbl in labels:
            self.id_2_train.update({lbl.id : lbl.trainId})
            self.id_2_cat.update({lbl.id : lbl.categoryId})
            if lbl.trainId != 19: # (lbl.trainId > 0) and (lbl.trainId != 255):
                self.trainid_2_name.update({lbl.trainId : lbl.name})
                self.train_2_id.update({lbl.trainId : lbl.id})
            if lbl.id > 0:
                self.id_2_name.update({lbl.id : lbl.name})


    def __getitem__(self, idx):
        left = cv2.cvtColor(cv2.imread(self.left_paths[idx]), cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_UNCHANGED).astype(np.uint8)
        depth = np.load(self.depth_paths[idx]) # data is type float16

        if self.crop:
            left = left[:800, :, :]
            mask = mask[:800, :]
            depth = depth[:800, :]

        # get label id
        if self.label_map == 'id':
            mask[mask==-1] == 0
        elif self.label_map == 'trainId':
            for _id, train_id in self.id_2_train.items():
                mask[mask==_id] = train_id
        elif self.label_map == 'categoryId':
            for _id, train_id in self.id_2_cat.items():
                mask[mask==_id] = train_id

        sample = {'left' : left, 'mask' : mask, 'depth' : depth}

        if self.transform:
            sample = self.transform(sample)

        # ensure that no depth values are less than 0
        depth[depth < 0] = 0

        return sample
    

    def __len__(self):
        print(f"Number of RGB images: {len(self.left_paths)}")
        print(f"Number of Mask images: {len(self.mask_paths)}")
        print(f"Number of depth images: {len(self.depth_paths)}")
        return len(self.left_paths)
    
    

In [9]:
OG_W, OG_H = 2048, 800 # OG width and height after crop
W, H = OG_W//4, OG_H//4 # resize w,h for training

transform = transforms.Compose([
    ToTensor(),
    RandomCrop(H, W),
    # ElasticTransform(alpha=100.0, sigma=25.0, prob=0.5),
    AddColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
    RandomHorizontalFlip(0.5),
    RandomVerticalFlip(0.2),
    # RandomRotate((-30, 30)),
    Normalize()
])

valid_transform = transforms.Compose([
    ToTensor(),
    Rescale(H, W),
    Normalize()
])

test_transform = transforms.Compose([
    ToTensor(),
    Normalize()
])


BATCH_SIZE = 8
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'



train_dataset = CityScapesDataset(root, transform=transform, split='train', label_map='trainId') # 'trainId')
train_subset = Subset(train_dataset, indices=range(2968)) #2968
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=True)
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=True)


valid_dataset = CityScapesDataset(root, transform=valid_transform, split='val', label_map='trainId')
val_subset = Subset(valid_dataset, indices=range(496)) #496
# valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=False)
valid_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=False)


# shared Generator

In [10]:
# class SharedGenerator(nn.Module):
#     def __init__(self):
#         """
#         Shared Generator for both tasks.
#         Contains shared layers for skip connection processing and refinement.
#         """
#         super(SharedGenerator, self).__init__()
#         # Shared convolution layers to process each skip connection
#         self.shared_conv1 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l11_out (1/32)
#         self.shared_conv2 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l7_out (1/16)
#         self.shared_conv3 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l3_out (1/8)
#         self.shared_conv4 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l1_out (1/4)

#         # Shared CRP blocks for refinement
#         self.shared_crp1 = CRPBlock(256, 256, n_stages=4)  # CRP for 1/32
#         self.shared_crp2 = CRPBlock(256, 256, n_stages=4)  # CRP for 1/16
#         self.shared_crp3 = CRPBlock(256, 256, n_stages=4)  # CRP for 1/8
#         self.shared_crp4 = CRPBlock(256, 256, n_stages=4)  # CRP for 1/4

#     def forward(self, skips):
#         """
#         Process skips with shared layers for task-specific generation.
#         Args:
#             skips (dict): Skip connections from the encoder.
#         Returns:
#             dict: Processed skip connections.
#         """
#         x1 = self.shared_crp1(self.shared_conv1(skips["l11_out"]))
#         x2 = self.shared_crp2(self.shared_conv2(skips["l7_out"]))
#         x3 = self.shared_crp3(self.shared_conv3(skips["l3_out"]))
#         x4 = self.shared_crp4(self.shared_conv4(skips["l1_out"]))

#         return {"x1": x1, "x2": x2, "x3": x3, "x4": x4}


In [11]:
# class SharedPix2PixGenerator(nn.Module):
#     def __init__(self, seg_output_channels=20, depth_output_channels=1):
#         """
#         Shared Pix2Pix Generator for Segmentation and Depth tasks.
#         Args:
#             seg_output_channels (int): Number of output channels for segmentation.
#             depth_output_channels (int): Number of output channels for depth estimation.
#         """
#         super(SharedPix2PixGenerator, self).__init__()
#         self.shared_generator = SharedGenerator()
#         self.seg_output_layer = TaskOutputLayer(output_channels=seg_output_channels)
#         self.depth_output_layer = TaskOutputLayer(output_channels=depth_output_channels)

#     def forward(self, skips, input_size):
#         """
#         Forward pass for both tasks.
#         Args:
#             skips (dict): Skip connections from the encoder.
#             input_size (tuple): Original input size (H, W).
#         Returns:
#             dict: Outputs for segmentation and depth tasks.
#         """
#         shared_features = self.shared_generator(skips)

#         # Task-specific outputs
#         seg_output = self.seg_output_layer(shared_features["x4"], input_size)
#         depth_output = self.depth_output_layer(shared_features["x4"], input_size)

#         return {
#             "seg_output": seg_output,
#             "depth_output": depth_output
#         }


# Saving batch gif code And function to plot al losses


In [12]:
def save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels):
    inputs = inputs.detach().cpu()
    seg_output = torch.argmax(seg_output, dim=1).detach().cpu()
    depth_output = depth_output.detach().cpu()
    seg_labels = seg_labels.detach().cpu()
    depth_labels = depth_labels.detach().cpu()
    
#     inputs_rgb = (inputs - inputs.min()) / (inputs.max() - inputs.min() + 1e-5)  # Normalize inputs to [0, 1]
    
#     # Normalize depth maps for visualization
#     depth_labels_vis = (depth_labels - depth_labels.min()) / (depth_labels.max() - depth_labels.min() + 1e-5)
#     depth_preds_vis = (depth_output - depth_output.min()) / (depth_output.max() - depth_output.min() + 1e-5)



    batch_size = min(4, inputs.size(0))  # Limit to 4 samples for visualization
    fig, axes = plt.subplots(batch_size, 5, figsize=(15, 4 * batch_size))

    for i in range(batch_size):
        
        inputs_temp = inputs[i]
        # print(f"inputs_temp: {inputs_temp.shape}")
        inputs_rgb = (inputs_temp - inputs_temp.min()) / (inputs_temp.max() - inputs_temp.min() + 1e-5)  # Normalize inputs to [0, 1]
        
        depth_labels_vis = (depth_labels[i] - depth_labels[i].min()) / (depth_labels[i].max() - depth_labels[i].min() + 1e-5)
        depth_preds = depth_output[i]
        depth_preds_vis = (depth_preds - depth_preds.min()) / (depth_preds.max() - depth_preds.min() + 1e-5)
        # print(f"depth_labels_vis: {depth_labels_vis.shape}")
        # print(f"depth_preds_vis: {depth_preds_vis.shape}")

    
        
        # Row 1: Ground truth
        axes[i, 0].imshow(inputs_rgb.permute(1, 2, 0))
        axes[i, 0].set_title("RGB Image")
        axes[i, 0].axis("off")

        axes[i, 1].imshow(seg_labels[i], cmap="tab20")
        axes[i, 1].set_title("GT Segmentation")
        axes[i, 1].axis("off")

        axes[i, 2].imshow(depth_labels_vis.squeeze(), cmap="inferno")
        axes[i, 2].set_title("GT Depth")
        axes[i, 2].axis("off")

        # Row 2: Predictions
        axes[i, 3].imshow(seg_output[i], cmap="tab20")
        axes[i, 3].set_title("Generated Segmentation")
        axes[i, 3].axis("off")

        axes[i, 4].imshow(depth_preds_vis.squeeze(), cmap="inferno")
        axes[i, 4].set_title("Generated Depth")
        axes[i, 4].axis("off")
        
    # Remove axes for cleaner visualization
    for ax in axes.flat:
        ax.axis("off")


    # plt.tight_layout()
    fig.tight_layout()
    fig.canvas.draw()
    
    # # Save current epoch as an image for GIF
    # epoch_img_path = os.path.join(gif_path, f"epoch_{epoch}.png")
    # os.makedirs(gif_path, exist_ok=True)
    # plt.savefig(epoch_img_path)
    # plt.close()
    
    
    # return epoch_img_path
    frame = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)  # Updated to buffer_rgba
    frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (4,))  # RGBA has 4 channels
    plt.close(fig)

    # Convert to PIL.Image for GIF
    frame_rgb = frame[:, :, :3] 

    # Return as PIL.Image for GIF creation
    # return Image.fromarray(frame)
    return Image.fromarray(frame_rgb)

def plot_all_losses(train_losses,valid_losses,save_dir):
    # Plot training and validation losses
    for key in train_losses.keys():
        plt.figure()
        plt.plot(train_losses[key], label=f"Train {key}")
        plt.plot(valid_losses[key], label=f"Valid {key}")
        plt.xlabel("Epoch")
        plt.ylabel(key.replace("_", " ").title())
        plt.legend()
        plt.savefig(os.path.join(save_dir, f"{key}_loss.png"))
        plt.close()


# Loss Function

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class PerceptualLoss(nn.Module):
    def __init__(self, pretrained_model="vgg16", layers=["relu3_3"], device="cuda"):
        """
        Perceptual loss class.

        Args:
            pretrained_model (str): Pretrained model to use (e.g., "vgg16").
            layers (list of str): Layers to extract features from.
            device (str): Device to load the pretrained model on ("cuda" or "cpu").
        """
        super().__init__()

        # Load pretrained model
        if pretrained_model == "vgg16":
            vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
        else:
            raise ValueError(f"Unsupported pretrained model: {pretrained_model}")

        # Freeze the parameters
        for param in vgg.parameters():
            param.requires_grad = False

        # Select layers
        self.layers = layers
        self.feature_extractor = nn.ModuleDict({
            layer: vgg[:i] for i, layer in enumerate(vgg._modules.keys()) if layer in self.layers
        })

    def forward(self, generated, target):
        """
        Compute perceptual loss between generated and target images.

        Args:
            generated (torch.Tensor): Generated image batch.
            target (torch.Tensor): Target image batch.

        Returns:
            torch.Tensor: MSE loss between extracted features.
        """
        loss = 0.0
        for layer_name, extractor in self.feature_extractor.items():
            gen_features = extractor(generated)
            target_features = extractor(target)
            loss += F.mse_loss(gen_features, target_features)
        return loss

def scale_invariant_depth_loss(pred, target, lambda_weight=0.1):
    if pred.shape != target.shape:
        pred = F.interpolate(pred, size=target.shape[1:], mode='bilinear', align_corners=False)
    
    diff = pred - target
    n = diff.numel()
    mse = torch.sum(diff**2) / n
    scale_invariant = mse - (lambda_weight / (n**2)) * (torch.sum(diff))**2
    return scale_invariant

def depth_smoothness_loss(pred, img, alpha=1.0):
    depth_grad_x = torch.abs(pred[:, :, :, :-1] - pred[:, :, :, 1:])
    depth_grad_y = torch.abs(pred[:, :, :-1, :] - pred[:, :, 1:, :])
    img_grad_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), dim=1, keepdim=True)
    img_grad_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), dim=1, keepdim=True)
    smoothness_x = depth_grad_x * torch.exp(-alpha * img_grad_x)
    smoothness_y = depth_grad_y * torch.exp(-alpha * img_grad_y)
    return smoothness_x.mean() + smoothness_y.mean()


def inv_huber_loss(pred, target, delta=0.1):
    """
    Inverse Huber loss for depth prediction.
    Args:
        pred (Tensor): Predicted depth map.
        target (Tensor): Ground truth depth map.
        delta (float): Threshold for switching between quadratic and linear terms.
    Returns:
        Tensor: Inverse Huber loss.
    """
    abs_diff = torch.abs(pred - target)
    delta_tensor = torch.tensor(delta, dtype=abs_diff.dtype, device=abs_diff.device)  # Convert delta to tensor
    quadratic = torch.minimum(abs_diff, delta_tensor)
    linear = abs_diff - quadratic
    return (0.5 * quadratic**2 + delta_tensor * linear).mean()


def mean_iou(pred, target, num_classes):
    pred = torch.argmax(pred, dim=1)
    intersection = torch.logical_and(pred == target, target != 255).float()  # Ignore class 255
    union = torch.logical_or(pred == target, target != 255).float()
    iou = torch.sum(intersection) / torch.sum(union)
    return iou



def contrastive_loss(pred, target, margin=1.0):
    """
    Contrastive loss to ensure the depth map predictions are closer to the target.
    """
    # Flatten the tensors for element-wise operations
    pred_flat = pred.view(pred.size(0), -1)  # Flatten except for the batch dimension
    target_flat = target.view(target.size(0), -1)  # Flatten except for the batch dimension

    # Compute the pairwise distances
    distances = torch.sqrt(torch.sum((pred_flat - target_flat) ** 2, dim=1))  # Batch-wise distances

    # Create labels for contrastive loss
    labels = (torch.abs(pred_flat - target_flat).mean(dim=1) < margin).float()  # Batch-wise labels

    # Calculate contrastive loss
    similar_loss = labels * distances**2
    dissimilar_loss = (1 - labels) * torch.clamp(margin - distances, min=0)**2
    loss = (similar_loss + dissimilar_loss).mean()

    return loss


def dice_loss(predictions, targets, smooth=1e-6):
    """
    Calculate Dice Loss for segmentation.
    Args:
        predictions (torch.Tensor): The predicted segmentation map (logits or probabilities).
                                    Shape: [batch_size, num_classes, height, width]
        targets (torch.Tensor): The ground truth segmentation map (one-hot encoded or integer labels).
                                Shape: [batch_size, height, width]
        smooth (float): Smoothing factor to avoid division by zero.
    Returns:
        torch.Tensor: Dice Loss (scalar).
    """
    # Convert integer labels to one-hot if needed
    if predictions.shape != targets.shape:
        targets = F.one_hot(targets, num_classes=predictions.shape[1]).permute(0, 3, 1, 2).float()
    
    # Apply softmax to predictions for multi-class segmentation
    predictions = torch.softmax(predictions, dim=1)
    
    # Flatten tensors to calculate intersection and union
    predictions_flat = predictions.view(predictions.shape[0], predictions.shape[1], -1)
    targets_flat = targets.view(targets.shape[0], targets.shape[1], -1)
    
    # Calculate intersection and union
    intersection = (predictions_flat * targets_flat).sum(dim=-1)
    union = predictions_flat.sum(dim=-1) + targets_flat.sum(dim=-1)
    
    # Calculate Dice Coefficient
    dice_coeff = (2 * intersection + smooth) / (union + smooth)
    
    # Dice Loss
    return 1 - dice_coeff.mean()


def initialize_optimizers_and_schedulers(model, lr_gen=1e-4, lr_disc=1e-4, weight_decay=1e-4):
    """
    Initialize optimizers and schedulers for all generators and discriminators.
    
    Args:
        model (nn.Module): MultiTaskModel instance.
        lr_gen (float): Learning rate for generators.
        lr_disc (float): Learning rate for discriminators.
        weight_decay (float): Weight decay for optimizers.
    
    Returns:
        dict: Optimizers and schedulers for generators and discriminators.
    """
    # Optimizers for shared generator
    optimizer_shared_gen = torch.optim.AdamW(
        model.feature_generator.parameters(),
        lr=lr_gen,
        weight_decay=weight_decay
    )
    scheduler_shared_gen = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_shared_gen, T_max=50, eta_min=1e-6
    )

    # Optimizer and scheduler for the shared generator's refinement layer
    optimizer_shared_refine = torch.optim.AdamW(
        model.shared_generator.parameters(),
        lr=lr_gen,
        weight_decay=weight_decay
    )
    scheduler_shared_refine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_shared_refine, T_max=50, eta_min=1e-6
    )

    # Optimizers and schedulers for task-specific generators
    optimizer_seg_gen = torch.optim.AdamW(
        model.seg_output_layer.parameters(),
        lr=lr_gen,
        weight_decay=weight_decay
    )
    scheduler_seg_gen = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_seg_gen, mode='min', factor=0.5, patience=5
    )

    optimizer_depth_gen = torch.optim.AdamW(
        model.depth_output_layer.parameters(),
        lr=lr_gen,
        weight_decay=weight_decay
    )
    scheduler_depth_gen = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_depth_gen, mode='min', factor=0.5, patience=5
    )

    # Optimizers and schedulers for task-specific discriminators
    optimizer_seg_disc = torch.optim.AdamW(
        model.seg_discriminator.parameters(),
        lr=lr_disc,
        weight_decay=weight_decay
    )
    scheduler_seg_disc = torch.optim.lr_scheduler.StepLR(
        optimizer_seg_disc, step_size=20, gamma=0.1
    )

    optimizer_depth_disc = torch.optim.AdamW(
        model.depth_discriminator.parameters(),
        lr=lr_disc,
        weight_decay=weight_decay
    )
    scheduler_depth_disc = torch.optim.lr_scheduler.StepLR(
        optimizer_depth_disc, step_size=20, gamma=0.1
    )

    # Optimizer and scheduler for the multi-task discriminator
    optimizer_multi_task_disc = torch.optim.AdamW(
        model.multi_task_discriminator.parameters(),
        lr=lr_disc,
        weight_decay=weight_decay
    )
    scheduler_multi_task_disc = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_multi_task_disc, T_max=50, eta_min=1e-6
    )

    return {
        "optimizers": {
            "shared_gen": optimizer_shared_gen,
            "shared_refine": optimizer_shared_refine,
            "seg_gen": optimizer_seg_gen,
            "depth_gen": optimizer_depth_gen,
            "seg_disc": optimizer_seg_disc,
            "depth_disc": optimizer_depth_disc,
            "multi_task_disc": optimizer_multi_task_disc
        },
        "schedulers": {
            "shared_gen": scheduler_shared_gen,
            "shared_refine": scheduler_shared_refine,
            "seg_gen": scheduler_seg_gen,
            "depth_gen": scheduler_depth_gen,
            "seg_disc": scheduler_seg_disc,
            "depth_disc": scheduler_depth_disc,
            "multi_task_disc": scheduler_multi_task_disc
        }
    }


## MultiTaskModel

In [14]:
class MobileNetV3Backbone(nn.Module):
    def __init__(self, backbone):
        super().__init__()

        self.backbone = backbone
        self.proj_l1 = nn.Conv2d(16, 576, kernel_size=1, bias=False)   # For l1_out (1/4 resolution)
        self.proj_l3 = nn.Conv2d(24, 576, kernel_size=1, bias=False)  # For l3_out (1/8 resolution)
        self.proj_l7 = nn.Conv2d(48, 576, kernel_size=1, bias=False)  # For l7_out (1/16 resolution)
        self.proj_l11 = nn.Conv2d(96, 576, kernel_size=1, bias=False) # For l11_out (1/32 resolution)

    
    def forward(self, x):
        """ Passes input theough MobileNetV3 backbone feature extraction layers
            layers to add connections to (0 indexed)
                - 1:  1/4 res
                - 3:  1/8 res
                - 7, 8:  1/16 res
                - 10, 11: 1/32 res
           """
        # skips = nn.ParameterDict()
        # for i in range(len(self.backbone) - 1):
        #     x = self.backbone[i](x)
        #     # add skip connection outputs
        #     if i in [1, 3, 7, 11]:
        #         skips.update({f"l{i}_out" : x})

        # return skips
        skips = {}  # Dictionary to store skip connections

        for i, layer in enumerate(self.backbone):
            x = layer(x)
            # Add skip connections for specific layers
            if i == 1:
                skips["l1_out"] = self.proj_l1(x)  # Project l1_out
            elif i == 3:
                skips["l3_out"] = self.proj_l3(x)  # Project l3_out
            elif i == 7:
                skips["l7_out"] = self.proj_l7(x)  # Project l7_out
            elif i == 11:
                skips["l11_out"] = self.proj_l11(x)  # Project l11_out

        return skips

In [15]:
class EnhancedSharedGenerator(nn.Module):
    def __init__(self):
        """
        Enhanced Shared Generator for Segmentation and Depth tasks.
        Includes additional refinement layers for better generalization.
        """
        super(EnhancedSharedGenerator, self).__init__()
        # Shared convolution layers to process each skip connection
        self.shared_conv1 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l11_out (1/32)
        self.shared_conv2 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l7_out (1/16)
        self.shared_conv3 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l3_out (1/8)
        self.shared_conv4 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l1_out (1/4)

        # CRP blocks for refinement
        self.shared_crp1 = CRPBlock(256, 256, n_stages=4)
        self.shared_crp2 = CRPBlock(256, 256, n_stages=4)
        self.shared_crp3 = CRPBlock(256, 256, n_stages=4)
        self.shared_crp4 = CRPBlock(256, 256, n_stages=4)

        # Additional refinement layers
        self.refine = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1)
        )

    def forward(self, skips):
        """
        Process skips with shared layers for task-specific generation.
        Args:
            skips (dict): Skip connections from the encoder.
        Returns:
            dict: Processed skip connections.
        """
        x1 = self.shared_crp1(self.shared_conv1(skips["l11_out"]))
        x1 = self.refine(x1)  # Extra refinement

        x2 = self.shared_crp2(self.shared_conv2(skips["l7_out"]))
        x2 = self.refine(x2)

        x3 = self.shared_crp3(self.shared_conv3(skips["l3_out"]))
        x3 = self.refine(x3)

        x4 = self.shared_crp4(self.shared_conv4(skips["l1_out"]))
        x4 = self.refine(x4)

        return {"x1": x1, "x2": x2, "x3": x3, "x4": x4}


In [16]:
class TaskOutputLayer(nn.Module):
    def __init__(self, output_channels):
        """
        Task-specific output layers for generating final predictions.
        Args:
            output_channels (int): Number of output channels (e.g., 20 for segmentation, 1 for depth).
        """
        super(TaskOutputLayer, self).__init__()
        self.final_conv = nn.Conv2d(256, output_channels, kernel_size=3, padding=1)

    def forward(self, x, input_size):
        """
        Generate task-specific output.
        Args:
            x (Tensor): Input feature map.
            input_size (tuple): Original input size (H, W).
        Returns:
            Tensor: Task-specific output.
        """
        x = self.final_conv(x)
        return nn.functional.interpolate(x, size=input_size, mode="bilinear", align_corners=False)



In [17]:
class TaskSpecificDiscriminator(nn.Module):
    def __init__(self, input_channels):
        super(TaskSpecificDiscriminator, self).__init__()
        self.adapt_conv = nn.Conv2d(input_channels+input_channels, input_channels, kernel_size=1, bias=False)
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, task_output, labels=None):
        """
        Forward pass through the discriminator.

        Args:
            task_output (Tensor): Output from the generator (e.g., seg_output or depth_output).
            labels (Tensor, optional): Ground truth labels. If provided, aligns channels with task_output.

        Returns:
            Tensor: Discriminator's prediction.
        """
        if labels is not None:
            # Ensure labels match the shape of task_output
            if labels.dim() < task_output.dim():
                labels = labels.unsqueeze(1)  # Add channel dimension if needed
            if labels.size(1) != task_output.size(1):
                labels = torch.nn.functional.one_hot(labels.squeeze(1), num_classes=task_output.size(1))
                labels = labels.permute(0, 3, 1, 2).float().to(task_output.device)
            combined = torch.cat([task_output, labels], dim=1)
            combined = self.adapt_conv(combined)
        else:
            combined = task_output

        return self.model(combined)


In [18]:

class MultiTaskDiscriminator(nn.Module):
    def __init__(self, input_channels):
        """
        Multi-Task Discriminator for evaluating all task-specific outputs.
        Args:
            input_channels (int): Number of input channels for concatenated features and outputs.
        """
        super(MultiTaskDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, inputs):
        """
        Evaluate input image and task-specific outputs.
        Args:
            inputs (Tensor): Input image.
            outputs (list[Tensor]): List of task-specific outputs.
        Returns:
            Tensor: Discriminator output.
        """
        # combined = torch.cat([inputs] + outputs, dim=1)
        return self.model(inputs)


In [19]:
class MultiTaskModel(nn.Module):
    def __init__(self, backbone, num_seg_classes=20, depth_channels=1):
        """
        Multi-task model with shared Pix2Pix Generator, task-specific discriminators,
        and a multi-task discriminator.
        Args:
            backbone (nn.Module): Encoder backbone for feature extraction.
            num_seg_classes (int): Number of segmentation classes.
            depth_channels (int): Number of output channels for depth.
        """
        super(MultiTaskModel, self).__init__()
        self.feature_generator = MobileNetV3Backbone(backbone)
        self.shared_generator = EnhancedSharedGenerator()
        self.seg_output_layer = TaskOutputLayer(output_channels=num_seg_classes)
        self.depth_output_layer = TaskOutputLayer(output_channels=depth_channels)

        # Task-specific discriminators
        self.seg_discriminator = TaskSpecificDiscriminator(input_channels=num_seg_classes)
        self.depth_discriminator = TaskSpecificDiscriminator(input_channels=depth_channels)

        # Multi-task discriminator
        self.multi_task_discriminator = MultiTaskDiscriminator(input_channels=3 + num_seg_classes + depth_channels)
        
    def forward(self, inputs, input_size, seg_labels=None, depth_labels=None, return_discriminator_outputs=False):
        # Extract features from the encoder
        skips = self.feature_generator(inputs)
        shared_features = self.shared_generator(skips)

        # Task-specific outputs
        seg_output = self.seg_output_layer(shared_features["x4"], input_size)
        depth_output = self.depth_output_layer(shared_features["x4"], input_size)

        output_dict = {
            "seg_output": seg_output,
            "depth_output": depth_output,
        }

        if return_discriminator_outputs:
            
            # Detach outputs to prevent discriminator backward from interfering with the generator
            seg_output_detached = seg_output.detach()
            depth_output_detached = depth_output.detach()
            
            # Adversarial feedback from task-specific discriminators
            seg_real_disc = self.seg_discriminator(seg_output_detached, seg_labels) if seg_labels is not None else None
            seg_fake_disc = self.seg_discriminator(seg_output_detached, None)

            depth_real_disc = self.depth_discriminator(depth_output_detached, depth_labels) if depth_labels is not None else None
            depth_fake_disc = self.depth_discriminator(depth_output_detached, None)

            # Multi-task discriminator feedback
            combined_real_input = torch.cat([inputs, seg_labels, depth_labels], dim=1) if seg_labels is not None and depth_labels is not None else None
            combined_fake_input = torch.cat([inputs, seg_output, depth_output], dim=1)

            combined_real_disc = self.multi_task_discriminator(combined_real_input) if combined_real_input is not None else None
            combined_fake_disc = self.multi_task_discriminator(combined_fake_input.detach())

            output_dict.update({
                "seg_real_disc": seg_real_disc,
                "seg_fake_disc": seg_fake_disc,
                "depth_real_disc": depth_real_disc,
                "depth_fake_disc": depth_fake_disc,
                "combined_real_disc": combined_real_disc,
                "combined_fake_disc": combined_fake_disc,
            })

        return output_dict


#     def forward(self, inputs, input_size, seg_labels=None, depth_labels=None, return_discriminator_outputs=False):
#         """
#         Forward pass for multi-task model.
#         Args:
#             inputs (Tensor): Input images.
#             input_size (tuple): Original input size.
#             seg_labels (Tensor, optional): Ground truth segmentation labels. Required for discriminator feedback.
#             depth_labels (Tensor, optional): Ground truth depth labels. Required for discriminator feedback.
#             return_discriminator_outputs (bool): If True, returns discriminator outputs for adversarial loss.
#         Returns:
#             dict: Outputs for segmentation and depth tasks, and optionally discriminator outputs.
#         """
#         skips = self.feature_generator(inputs)
#         shared_features = self.shared_generator(skips)

#         # Task-specific outputs
#         seg_output = self.seg_output_layer(shared_features["x4"], input_size)
#         depth_output = self.depth_output_layer(shared_features["x4"], input_size)

#         if return_discriminator_outputs:
#             # Adversarial feedback from task-specific discriminators
#             seg_real_disc = self.seg_discriminator(seg_output, seg_labels) if seg_labels is not None else None
#             depth_real_disc = self.depth_discriminator(depth_output, depth_labels) if depth_labels is not None else None

#             # Multi-task discriminator feedback
#             combined_real_disc = self.multi_task_discriminator(inputs, [seg_output, depth_output])

#             return {
#                 "seg_output": seg_output,
#                 "depth_output": depth_output,
#                 "seg_real_disc": seg_real_disc,
#                 "depth_real_disc": depth_real_disc,
#                 "combined_real_disc": combined_real_disc
#             }

#         return {
#             "seg_output": seg_output,
#             "depth_output": depth_output
#         }


# Saving loss charts

In [20]:
def plot_all_losses(epoch, train_losses,valid_losses,save_dir):
    # Plot training and validation losses
    for key in train_losses.keys():
        plt.figure()
        plt.plot(train_losses[key], label=f"Train {key}")
        plt.plot(valid_losses[key], label=f"Valid {key}")
        plt.xlabel("Epoch")
        plt.ylabel(key.replace("_", " ").title())
        plt.legend()
        plt.savefig(os.path.join(save_dir, f"{key}_loss_after_epoch_{epoch}.png"))
        plt.close()

# training

In [21]:
def train_model_with_adversarial_loss_tracking(
    model, train_loader, valid_loader, num_epochs, device, opt_sched, save_dir="results"
):
    """
    Trains a multi-task model with adversarial feedback and tracks losses.
    
    Args:
        model: Multi-task model with integrated generators and discriminators.
        train_loader: DataLoader for training data.
        valid_loader: DataLoader for validation data.
        num_epochs: Number of epochs to train.
        device: Device for training ("cuda" or "cpu").
        opt_sched: Dictionary of optimizers and schedulers for generators and discriminators.
        save_dir: Directory to save results.
    
    Returns:
        train_losses, valid_losses: Lists of losses for training and validation.
    """
    # Create directories for saving results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = os.path.join(save_dir, timestamp)
    os.makedirs(save_dir, exist_ok=True)

    # Prepare CSV file for loss tracking
    csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
    gif_path = os.path.join(save_dir, f"training_visualization_{timestamp}.gif")
    
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
            "train_adv_loss", 
            # "train_seg_iou",
            "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
            "valid_adv_loss", 
            # "valid_seg_iou"
        ])

    # Initialize tracking variables
    train_losses = {"seg": [], "depth": [], "combined": [], "adv": []}
    valid_losses = {"seg": [], "depth": [], "combined": [], "adv": []}
    best_combined_loss = float("inf")
    gif_frames =[]
    perceptual_loss_fn = PerceptualLoss(pretrained_model="vgg16").to(device)

    # Start training loop
    for epoch in range(num_epochs):
        model.train()
        epoch_train = {key: 0.0 for key in train_losses.keys()}
        num_batches = 0
        
        
        
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs} - Training", unit="batch") as pbar:
            for batch in train_loader:
                inputs, seg_labels, depth_labels = (
                    batch["left"].to(device),
                    batch["mask"].to(device),
                    batch["depth"].to(device),
                )
                input_size = inputs.size()[-2:]

                # Preprocess seg_labels to one-hot encoding
                if seg_labels.size(1) == 1:  # If class indices are given
                    seg_labels = torch.nn.functional.one_hot(seg_labels.squeeze(1), num_classes=20)
                    seg_labels = seg_labels.permute(0, 3, 1, 2).float().to(device)  # Convert to [B, C, H, W]

                # Ensure depth_labels has correct dimensions
                if depth_labels.dim() == 5:  # If depth_labels has extra dimensions
                    depth_labels = depth_labels.squeeze(2)

                # Zero gradients
                for optimizer in opt_sched["optimizers"].values():
                    optimizer.zero_grad()

                # Forward pass with discriminator outputs
                outputs = model(
                    inputs,
                    input_size=input_size,
                    seg_labels=seg_labels,
                    depth_labels=depth_labels,
                    return_discriminator_outputs=True,
                )

                # Generator losses
                seg_loss = nn.CrossEntropyLoss()(outputs["seg_output"], seg_labels) + \
                           dice_loss(outputs["seg_output"], seg_labels)
                depth_loss = scale_invariant_depth_loss(outputs["depth_output"], depth_labels) + \
                             inv_huber_loss(outputs["depth_output"], depth_labels) + \
                             depth_smoothness_loss(outputs["depth_output"], inputs)
                
                seg_perceptual_loss = perceptual_loss_fn(outputs["seg_output"], seg_labels.unsqueeze(1))
                depth_perceptual_loss = perceptual_loss_fn(outputs["depth_output"], depth_labels)
                
                seg_loss = seg_loss + 0.1 * seg_perceptual_loss
                depth_loss = depth_loss + 0.1 * depth_perceptual_loss
                

                adv_loss = -(
                    torch.mean(outputs["seg_real_disc"]) +
                    torch.mean(outputs["depth_real_disc"]) +
                    torch.mean(outputs["combined_real_disc"])
                )

                combined_loss = seg_loss + depth_loss + 0.01 * adv_loss

                # Backpropagation for generators
                combined_loss.backward(retain_graph=True)
                # opt_sched["optimizers"]["generator"].step()
                opt_sched["optimizers"]["shared_gen"].step()
                opt_sched["optimizers"]["shared_refine"].step()
                opt_sched["optimizers"]["seg_gen"].step()
                opt_sched["optimizers"]["depth_gen"].step()


                # Update task-specific discriminators
                for task, disc_optimizer in [
                    ("seg", "seg_disc"),
                    ("depth", "depth_disc"),
                ]:
                    opt_sched["optimizers"][disc_optimizer].zero_grad()
                    real_disc_loss = torch.mean(
                        (outputs[f"{task}_real_disc"] - 1) ** 2
                    )
                    fake_disc_loss = torch.mean(
                        (outputs[f"{task}_fake_disc"].detach()) ** 2
                    )
                    disc_loss = (real_disc_loss + fake_disc_loss) / 2
                    disc_loss.backward()
                    opt_sched["optimizers"][disc_optimizer].step()

                # Update multi-task discriminator
                opt_sched["optimizers"]["multi_task_disc"].zero_grad()
                real_combined_loss = torch.mean(
                    (outputs["combined_real_disc"] - 1) ** 2
                )
                fake_combined_loss = torch.mean(
                    (outputs["combined_fake_disc"].detach()) ** 2
                )
                combined_disc_loss = (real_combined_loss + fake_combined_loss) / 2
                combined_disc_loss.backward()
                opt_sched["optimizers"]["multi_task_disc"].step()

                # Update training metrics
                epoch_train["seg"] += seg_loss.item()
                epoch_train["depth"] += depth_loss.item()
                epoch_train["combined"] += combined_loss.item()
                epoch_train["adv"] += adv_loss.item()
                # epoch_train["iou"] += mean_iou(outputs["seg_output"], seg_labels, num_classes=20).item()
                num_batches += 1

            # Average training metrics
            for key in epoch_train.keys():
                train_losses[key].append(epoch_train[key] / num_batches)

            # Validation loop
            model.eval()
            epoch_valid = {key: 0.0 for key in valid_losses.keys()}
            num_valid_batches = 0

            with torch.no_grad():
                for batch in valid_loader:
                    inputs, seg_labels, depth_labels = (
                        batch["left"].to(device),
                        batch["mask"].to(device),
                        batch["depth"].to(device),
                    )
                    input_size = inputs.size()[-2:]

                    # Preprocess seg_labels to one-hot encoding
                    if seg_labels.size(1) == 1:
                        seg_labels = torch.nn.functional.one_hot(seg_labels.squeeze(1), num_classes=20)
                        seg_labels = seg_labels.permute(0, 3, 1, 2).float().to(device)

                    # Ensure depth_labels has correct dimensions
                    if depth_labels.dim() == 5:
                        depth_labels = depth_labels.squeeze(2)


                    # Forward pass
                    outputs = model(
                        inputs,
                        input_size=input_size,
                        seg_labels=seg_labels,
                        depth_labels=depth_labels,
                        return_discriminator_outputs=True,
                    )

                    # Validation loss calculations
                    seg_loss = nn.CrossEntropyLoss()(outputs["seg_output"], seg_labels) + \
                               dice_loss(outputs["seg_output"], seg_labels)
                    depth_loss = scale_invariant_depth_loss(outputs["depth_output"], depth_labels) + \
                                 inv_huber_loss(outputs["depth_output"], depth_labels) + \
                                 depth_smoothness_loss(outputs["depth_output"], inputs)
                    
                    seg_perceptual_loss = perceptual_loss_fn(outputs["seg_output"], seg_labels.unsqueeze(1))
                    depth_perceptual_loss = perceptual_loss_fn(outputs["depth_output"], depth_labels)

                    seg_loss = seg_loss + 0.1 * seg_perceptual_loss
                    depth_loss = depth_loss + 0.1 * depth_perceptual_loss


                    adv_loss = -(
                        torch.mean(outputs["seg_real_disc"]) +
                        torch.mean(outputs["depth_real_disc"]) +
                        torch.mean(outputs["combined_real_disc"])
                    )

                    combined_loss = seg_loss + depth_loss + 0.01 * adv_loss

                    # Update validation metrics
                    epoch_valid["seg"] += seg_loss.item()
                    epoch_valid["depth"] += depth_loss.item()
                    epoch_valid["combined"] += combined_loss.item()
                    epoch_valid["adv"] += adv_loss.item()
                    # epoch_valid["iou"] += mean_iou(outputs["seg_output"], seg_labels, num_classes=20).item()
                    num_valid_batches += 1

        # Average validation metrics
        for key in epoch_valid.keys():
            valid_losses[key].append(epoch_valid[key] / num_valid_batches)

        # Save best model
        valid_combined_loss = epoch_valid["combined"] / num_valid_batches
        if valid_combined_loss < best_combined_loss:
            best_combined_loss = valid_combined_loss
            torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))
            print(f"Best model saved at epoch {epoch+1} with combined loss {best_combined_loss:.4f}")
            
        frame = save_training_visualization_as_gif2(epoch, inputs, outputs["seg_output"], outputs["depth_output"], torch.argmax(seg_labels, dim=1), depth_labels)
        gif_frames.append(frame)
        
        

        # Append metrics to CSV
        with open(csv_path, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([
                epoch + 1,
                epoch_train["seg"] / num_batches,
                epoch_train["depth"] / num_batches,
                epoch_train["combined"] / num_batches,
                epoch_train["adv"] / num_batches,
                # epoch_train["iou"] / num_batches,
                epoch_valid["seg"] / num_valid_batches,
                epoch_valid["depth"] / num_valid_batches,
                epoch_valid["combined"] / num_valid_batches,
                epoch_valid["adv"] / num_valid_batches,
                # epoch_valid["iou"] / num_valid_batches,
            ])
            
        # Print epoch results
        print(f"Epoch {epoch + 1}/{num_epochs} Results:")

        # Print training losses
        print(f"  Train Losses - Segmentation: {epoch_train['seg']/num_batches:.4f}, Depth: {epoch_train['depth']/num_batches:.4f}, "
              f"Combined: {epoch_train['combined']/num_batches:.4f}, Adversarial: {epoch_train['adv']/num_batches:.4f}")

        # Print validation losses
        print(f"  Valid Losses - Segmentation: {epoch_valid['seg']/ num_valid_batches:.4f}, Depth: {epoch_valid['depth']/ num_valid_batches:.4f}, "
              f"Combined: {epoch_valid['combined']/ num_valid_batches:.4f}, Adversarial: {epoch_valid['adv']/ num_valid_batches:.4f}")


        # Update schedulers
        for name, scheduler in opt_sched["schedulers"].items():
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # Pass the appropriate metric to ReduceLROnPlateau
                scheduler.step(valid_losses["combined"][-1])  # Use the most recent validation combined loss
            else:
                scheduler.step()
            
        if epoch %10 == 0:
            gif_path2 =os.path.join(save_dir,f"viz_epoch_{epoch}.gif")
            gif_frames[0].save(gif_path2, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)
            plot_all_losses(epoch, train_losses,valid_losses,save_dir)
            
    
    gif_frames[0].save(gif_path, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)
    print(f"Training visualization saved as GIF at {gif_path}")

    
    return train_losses, valid_losses


# Visualize using torchviz

In [24]:
# !pip install torchviz


In [25]:
from torchviz import make_dot

In [26]:
# Load the model
mobilenet_backbone = mobilenet_v3_small(weights="IMAGENET1K_V1").features
model = MultiTaskModel(backbone=mobilenet_backbone, num_seg_classes=20, depth_channels=1)
model.eval()


MultiTaskModel(
  (feature_generator): MobileNetV3Backbone(
    (backbone): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (2): Hardswish()
      )
      (1): InvertedResidual(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
            (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
            (activation): ReLU()
            (scale_activation): Hardsigmoid()
          

In [29]:
# Create a dummy input tensor (adjust shape based on your model input requirements)
dummy_input = torch.randn(1, 3, 224, 224)  # Example: batch_size=1, channels=3, height=224, width=224
input_size = dummy_input.size()[-2:]
input_size

torch.Size([224, 224])

In [30]:
# Forward pass
output = model(dummy_input,input_size=input_size)

In [32]:
# Visualize the model
dot = make_dot(output["seg_output"], params=dict(model.named_parameters()))  # Adjust output key if necessary
dot.format = "png"
dot.render("multitask_model_graph.png")  # This saves the graph as multitask_model_graph.png

print("Model visualization saved as multitask_model_graph.png")

Model visualization saved as multitask_model_graph.png


In [33]:
# Visualize the model
dot = make_dot(output["seg_output"], params=dict(model.named_parameters()), show_attrs=True, show_saved=True)  # Adjust output key if necessary
dot.format = "png"
dot.render("multitask_model_graph2.png")  # This saves the graph as multitask_model_graph.png

print("Model visualization saved as multitask_model_graph2.png")

Model visualization saved as multitask_model_graph2.png


In [38]:
# Visualize the model
dot = make_dot(output["seg_output"], params=dict(model.named_parameters()), show_attrs=True, show_saved=True)  # Adjust output key if necessary
dot.format = "png"
dot.render("model_with_mobilenetv3_as_block_seg")  # This saves the graph as multitask_model_graph.png

print("Model visualization saved as model_with_mobilenetv3_as_block_seg.png")

Model visualization saved as model_with_mobilenetv3_as_block_seg.png


In [39]:
# Visualize the model
dot = make_dot(output["depth_output"], params=dict(model.named_parameters()), show_attrs=True, show_saved=True)  # Adjust output key if necessary
dot.format = "png"
dot.render("model_with_mobilenetv3_as_block_depth.png")  # This saves the graph as multitask_model_graph.png

print("Model visualization saved as model_with_mobilenetv3_as_block_depth.png")

Model visualization saved as model_with_mobilenetv3_as_block_depth.png


In [40]:

# Instantiate the model
mobilenet_backbone = mobilenet_v3_small(weights="IMAGENET1K_V1").features
model = MultiTaskModel(backbone=mobilenet_backbone, num_seg_classes=20, depth_channels=1)
model.eval()

# Create a dummy input tensor
dummy_input = torch.randn(1, 3, 224, 224)

# Forward pass
outputs = model(dummy_input, input_size=(224, 224))

# Visualize the model
dot = make_dot(outputs["seg_output"], params=dict(model.named_parameters()))
dot.format = "png"
dot.render("multi_task_model")

'multi_task_model.png'

# running the model

In [59]:
# Initialize the model
# Instantiate Models

BATCH_SIZE = 8
EPOCHS = 30
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


mobilenet_backbone = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
# encoder = MobileNetV3Backbone(mobilenet_backbone.features)
model = MultiTaskModel(backbone=mobilenet_backbone.features, num_seg_classes=20, depth_channels=1)
model.to(DEVICE)

# Initialize optimizers and schedulers
opt_sched = initialize_optimizers_and_schedulers(model)

# Access optimizers
optimizers = opt_sched["optimizers"]
schedulers = opt_sched["schedulers"]

In [60]:
# Prepare Data Loaders (Ensure train_loader and valid_loader are ready)
train_losses, valid_losses = train_model_with_adversarial_loss_tracking(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    num_epochs=EPOCHS,
    device=DEVICE,
    opt_sched=opt_sched,
    save_dir="results_test8_final"
)

Epoch 1/30 - Training:   0%|          | 0/371 [08:56<?, ?batch/s]


Best model saved at epoch 1 with combined loss 2.4547
Epoch 1/30 Results:
  Train Losses - Segmentation: 2.4241, Depth: 0.0622, Combined: 2.4565, Adversarial: -2.9767
  Valid Losses - Segmentation: 2.3414, Depth: 0.1430, Combined: 2.4547, Adversarial: -2.9770


Epoch 2/30 - Training:   0%|          | 0/371 [08:07<?, ?batch/s]


Best model saved at epoch 2 with combined loss 2.2811
Epoch 2/30 Results:
  Train Losses - Segmentation: 2.2065, Depth: 0.0461, Combined: 2.2226, Adversarial: -2.9956
  Valid Losses - Segmentation: 2.2700, Depth: 0.0409, Combined: 2.2811, Adversarial: -2.9745


Epoch 3/30 - Training:   0%|          | 0/371 [08:07<?, ?batch/s]


Best model saved at epoch 3 with combined loss 2.2520
Epoch 3/30 Results:
  Train Losses - Segmentation: 2.1411, Depth: 0.0424, Combined: 2.1535, Adversarial: -2.9968
  Valid Losses - Segmentation: 2.2153, Depth: 0.0670, Combined: 2.2520, Adversarial: -3.0336


Epoch 4/30 - Training:   0%|          | 0/371 [08:07<?, ?batch/s]


Best model saved at epoch 4 with combined loss 2.0698
Epoch 4/30 Results:
  Train Losses - Segmentation: 2.0640, Depth: 0.0398, Combined: 2.0739, Adversarial: -2.9975
  Valid Losses - Segmentation: 2.0565, Depth: 0.0431, Combined: 2.0698, Adversarial: -2.9766


Epoch 5/30 - Training:   0%|          | 0/371 [08:06<?, ?batch/s]


Epoch 5/30 Results:
  Train Losses - Segmentation: 2.0353, Depth: 0.0392, Combined: 2.0445, Adversarial: -2.9978
  Valid Losses - Segmentation: 2.0531, Depth: 0.0663, Combined: 2.0896, Adversarial: -2.9907


Epoch 6/30 - Training:   0%|          | 0/371 [08:07<?, ?batch/s]


Epoch 6/30 Results:
  Train Losses - Segmentation: 2.0003, Depth: 0.0376, Combined: 2.0080, Adversarial: -2.9981
  Valid Losses - Segmentation: 2.1355, Depth: 0.0628, Combined: 2.1683, Adversarial: -2.9978


Epoch 7/30 - Training:   0%|          | 0/371 [08:07<?, ?batch/s]


Best model saved at epoch 7 with combined loss 2.0074
Epoch 7/30 Results:
  Train Losses - Segmentation: 1.9766, Depth: 0.0361, Combined: 1.9827, Adversarial: -2.9980
  Valid Losses - Segmentation: 1.9935, Depth: 0.0438, Combined: 2.0074, Adversarial: -2.9877


Epoch 8/30 - Training:   0%|          | 0/371 [08:08<?, ?batch/s]


Epoch 8/30 Results:
  Train Losses - Segmentation: 1.9690, Depth: 0.0366, Combined: 1.9757, Adversarial: -2.9984
  Valid Losses - Segmentation: 1.9993, Depth: 0.0763, Combined: 2.0457, Adversarial: -2.9892


Epoch 9/30 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Best model saved at epoch 9 with combined loss 1.9606
Epoch 9/30 Results:
  Train Losses - Segmentation: 1.9558, Depth: 0.0352, Combined: 1.9610, Adversarial: -2.9984
  Valid Losses - Segmentation: 1.9211, Depth: 0.0696, Combined: 1.9606, Adversarial: -3.0070


Epoch 10/30 - Training:   0%|          | 0/371 [08:11<?, ?batch/s]


Epoch 10/30 Results:
  Train Losses - Segmentation: 1.9284, Depth: 0.0344, Combined: 1.9328, Adversarial: -2.9985
  Valid Losses - Segmentation: 2.0639, Depth: 0.0570, Combined: 2.0910, Adversarial: -2.9877


Epoch 11/30 - Training:   0%|          | 0/371 [08:08<?, ?batch/s]


Epoch 11/30 Results:
  Train Losses - Segmentation: 1.9304, Depth: 0.0344, Combined: 1.9348, Adversarial: -2.9987
  Valid Losses - Segmentation: 2.0771, Depth: 0.0677, Combined: 2.1145, Adversarial: -3.0219


Epoch 12/30 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Epoch 12/30 Results:
  Train Losses - Segmentation: 1.9135, Depth: 0.0354, Combined: 1.9189, Adversarial: -2.9988
  Valid Losses - Segmentation: 1.9600, Depth: 0.0468, Combined: 1.9768, Adversarial: -3.0046


Epoch 13/30 - Training:   0%|          | 0/371 [08:08<?, ?batch/s]


Epoch 13/30 Results:
  Train Losses - Segmentation: 1.8651, Depth: 0.0335, Combined: 1.8686, Adversarial: -2.9989
  Valid Losses - Segmentation: 1.9520, Depth: 0.0414, Combined: 1.9635, Adversarial: -2.9953


Epoch 14/30 - Training:   0%|          | 0/371 [08:08<?, ?batch/s]


Epoch 14/30 Results:
  Train Losses - Segmentation: 1.8722, Depth: 0.0327, Combined: 1.8749, Adversarial: -2.9990
  Valid Losses - Segmentation: 1.9996, Depth: 0.0552, Combined: 2.0248, Adversarial: -2.9971


Epoch 15/30 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Epoch 15/30 Results:
  Train Losses - Segmentation: 1.8755, Depth: 0.0331, Combined: 1.8787, Adversarial: -2.9991
  Valid Losses - Segmentation: 1.9893, Depth: 0.0478, Combined: 2.0071, Adversarial: -3.0014


Epoch 16/30 - Training:   0%|          | 0/371 [08:06<?, ?batch/s]


Epoch 16/30 Results:
  Train Losses - Segmentation: 1.8393, Depth: 0.0318, Combined: 1.8411, Adversarial: -2.9991
  Valid Losses - Segmentation: 2.0337, Depth: 0.0464, Combined: 2.0501, Adversarial: -3.0044


Epoch 17/30 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Best model saved at epoch 17 with combined loss 1.9592
Epoch 17/30 Results:
  Train Losses - Segmentation: 1.8255, Depth: 0.0317, Combined: 1.8272, Adversarial: -2.9993
  Valid Losses - Segmentation: 1.9454, Depth: 0.0436, Combined: 1.9592, Adversarial: -2.9847


Epoch 18/30 - Training:   0%|          | 0/371 [08:07<?, ?batch/s]


Epoch 18/30 Results:
  Train Losses - Segmentation: 1.8452, Depth: 0.0313, Combined: 1.8464, Adversarial: -2.9992
  Valid Losses - Segmentation: 2.0587, Depth: 0.0494, Combined: 2.0780, Adversarial: -3.0008


Epoch 19/30 - Training:   0%|          | 0/371 [08:07<?, ?batch/s]


Best model saved at epoch 19 with combined loss 1.8837
Epoch 19/30 Results:
  Train Losses - Segmentation: 1.8174, Depth: 0.0309, Combined: 1.8183, Adversarial: -2.9993
  Valid Losses - Segmentation: 1.8717, Depth: 0.0418, Combined: 1.8837, Adversarial: -2.9841


Epoch 20/30 - Training:   0%|          | 0/371 [08:06<?, ?batch/s]


Epoch 20/30 Results:
  Train Losses - Segmentation: 1.8296, Depth: 0.0316, Combined: 1.8312, Adversarial: -2.9994
  Valid Losses - Segmentation: 2.0162, Depth: 0.0428, Combined: 2.0291, Adversarial: -2.9927


Epoch 21/30 - Training:   0%|          | 0/371 [08:07<?, ?batch/s]


Epoch 21/30 Results:
  Train Losses - Segmentation: 1.8039, Depth: 0.0301, Combined: 1.8040, Adversarial: -2.9994
  Valid Losses - Segmentation: 1.9842, Depth: 0.0625, Combined: 2.0167, Adversarial: -2.9956


Epoch 22/30 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Best model saved at epoch 22 with combined loss 1.8619
Epoch 22/30 Results:
  Train Losses - Segmentation: 1.7949, Depth: 0.0302, Combined: 1.7950, Adversarial: -2.9995
  Valid Losses - Segmentation: 1.8484, Depth: 0.0435, Combined: 1.8619, Adversarial: -2.9881


Epoch 23/30 - Training:   0%|          | 0/371 [08:07<?, ?batch/s]


Epoch 23/30 Results:
  Train Losses - Segmentation: 1.7947, Depth: 0.0294, Combined: 1.7942, Adversarial: -2.9995
  Valid Losses - Segmentation: 1.9298, Depth: 0.0499, Combined: 1.9496, Adversarial: -3.0075


Epoch 24/30 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Epoch 24/30 Results:
  Train Losses - Segmentation: 1.7751, Depth: 0.0295, Combined: 1.7746, Adversarial: -2.9996
  Valid Losses - Segmentation: 1.9406, Depth: 0.0457, Combined: 1.9563, Adversarial: -3.0065


Epoch 25/30 - Training:   0%|          | 0/371 [08:08<?, ?batch/s]


Epoch 25/30 Results:
  Train Losses - Segmentation: 1.7948, Depth: 0.0300, Combined: 1.7948, Adversarial: -2.9996
  Valid Losses - Segmentation: 1.9196, Depth: 0.0580, Combined: 1.9476, Adversarial: -2.9954


Epoch 26/30 - Training:   0%|          | 0/371 [08:08<?, ?batch/s]


Epoch 26/30 Results:
  Train Losses - Segmentation: 1.7758, Depth: 0.0295, Combined: 1.7752, Adversarial: -2.9996
  Valid Losses - Segmentation: 1.8867, Depth: 0.0347, Combined: 1.8913, Adversarial: -2.9978


Epoch 27/30 - Training:   0%|          | 0/371 [08:08<?, ?batch/s]


Epoch 27/30 Results:
  Train Losses - Segmentation: 1.7570, Depth: 0.0295, Combined: 1.7565, Adversarial: -2.9996
  Valid Losses - Segmentation: 1.8660, Depth: 0.0414, Combined: 1.8773, Adversarial: -3.0068


Epoch 28/30 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Epoch 28/30 Results:
  Train Losses - Segmentation: 1.7395, Depth: 0.0284, Combined: 1.7379, Adversarial: -2.9997
  Valid Losses - Segmentation: 1.9218, Depth: 0.0453, Combined: 1.9371, Adversarial: -3.0008


Epoch 29/30 - Training:   0%|          | 0/371 [08:08<?, ?batch/s]


Epoch 29/30 Results:
  Train Losses - Segmentation: 1.7425, Depth: 0.0278, Combined: 1.7403, Adversarial: -2.9997
  Valid Losses - Segmentation: 1.8592, Depth: 0.0434, Combined: 1.8727, Adversarial: -2.9902


Epoch 30/30 - Training:   0%|          | 0/371 [08:06<?, ?batch/s]


Epoch 30/30 Results:
  Train Losses - Segmentation: 1.7296, Depth: 0.0284, Combined: 1.7280, Adversarial: -2.9997
  Valid Losses - Segmentation: 1.9274, Depth: 0.0409, Combined: 1.9383, Adversarial: -3.0075
Training visualization saved as GIF at results_test8_final/20241127_132230/training_visualization_20241127_132230.gif


In [None]:
# check with perpetual loss later

In [None]:
# import os
# import csv
# from datetime import datetime
# import matplotlib.pyplot as plt
# import torch
# from torch.utils.data import DataLoader
# from torchvision.utils import save_image

# def train_model_with_loss_tracking(
#     model, train_loader, valid_loader, num_epochs, device, opt_sched, save_dir="results"
# ):
#     """
#     Trains a multi-task model with Conditional GANs, structural consistency, and perceptual loss.

#     Args:
#         model: The multi-task model to train.
#         train_loader: DataLoader for training data.
#         valid_loader: DataLoader for validation data.
#         num_epochs: Number of epochs to train.
#         device: Device for training ("cuda" or "cpu").
#         opt_sched: Dictionary of optimizers and schedulers.
#         save_dir: Directory to save results.

#     Returns:
#         train_losses, valid_losses: Lists of losses for training and validation.
#     """
#     # Create directories for saving results
#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     save_dir = os.path.join(save_dir, timestamp)
#     os.makedirs(save_dir, exist_ok=True)

#     # Prepare CSV file
#     csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
#     with open(csv_path, "w", newline="") as f:
#         writer = csv.writer(f)
#         writer.writerow([
#             "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
#             "train_depth_sidl", "train_depth_smooth", "train_seg_iou", "train_seg_perceptual_loss",
#             "train_depth_perceptual_loss", "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
#             "valid_depth_sidl", "valid_depth_smooth", "valid_seg_iou", "valid_seg_perceptual_loss",
#             "valid_depth_perceptual_loss"
#         ])

#     # Initialize tracking variables
#     train_losses = {
#         "seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": [],
#         "seg_perceptual": [], "depth_perceptual": []
#     }
#     valid_losses = {
#         "seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": [],
#         "seg_perceptual": [], "depth_perceptual": []
#     }
#     best_combined_loss = float("inf")
#     gif_frames = []

#     # Perceptual Loss (example using VGG features)
#     perceptual_loss_fn = PerceptualLoss(pretrained_model="vgg16").to(device)

#     for epoch in range(num_epochs):
#         model.train()
#         epoch_train = {key: 0.0 for key in train_losses.keys()}
#         num_batches = 0

#         for batch in train_loader:
#             inputs, seg_labels, depth_labels = (
#                 batch["left"].to(device),
#                 batch["mask"].to(device),
#                 batch["depth"].to(device)
#             )
#             latent_noise = torch.randn(inputs.size(0), 3).to(device)

#             # Zero gradients
#             for optimizer in opt_sched["optimizers"].values():
#                 optimizer.zero_grad()

#             # Forward pass
#             outputs = model(inputs, input_size=inputs.size()[-2:])

#             # Loss calculations
#             seg_loss_task = nn.CrossEntropyLoss()(outputs["seg_output"], seg_labels) + dice_loss(outputs["seg_output"], seg_labels)
#             seg_perceptual_loss = perceptual_loss_fn(outputs["seg_output"], seg_labels.unsqueeze(1))
#             seg_loss = seg_loss_task + 0.1 * seg_perceptual_loss

#             depth_loss_sidl = scale_invariant_depth_loss(outputs["depth_output"], depth_labels)
#             depth_loss_huber = inv_huber_loss(outputs["depth_output"], depth_labels)
#             depth_loss_smooth = depth_smoothness_loss(outputs["depth_output"], inputs)
#             depth_perceptual_loss = perceptual_loss_fn(outputs["depth_output"], depth_labels)
#             depth_loss = depth_loss_sidl + depth_loss_huber + depth_loss_smooth + 0.1 * depth_perceptual_loss

#             combined_loss = seg_loss + depth_loss

#             # Backpropagation
#             combined_loss.backward()
#             for optimizer in opt_sched["optimizers"].values():
#                 optimizer.step()

#             # Update training metrics
#             epoch_train["seg"] += seg_loss.item()
#             epoch_train["depth"] += depth_loss.item()
#             epoch_train["combined"] += combined_loss.item()
#             epoch_train["iou"] += mean_iou(outputs["seg_output"], seg_labels, num_classes=20).item()
#             epoch_train["depth_sidl"] += depth_loss_sidl.item()
#             epoch_train["depth_smooth"] += depth_loss_smooth.item()
#             epoch_train["seg_perceptual"] += seg_perceptual_loss.item()
#             epoch_train["depth_perceptual"] += depth_perceptual_loss.item()
#             num_batches += 1

#         # Average training metrics
#         for key in epoch_train.keys():
#             train_losses[key].append(epoch_train[key] / num_batches)

#         # Validation loop
#         model.eval()
#         epoch_valid = {key: 0.0 for key in valid_losses.keys()}
#         num_valid_batches = 0

#         with torch.no_grad():
#             for batch in valid_loader:
#                 inputs, seg_labels, depth_labels = (
#                     batch["left"].to(device),
#                     batch["mask"].to(device),
#                     batch["depth"].to(device)
#                 )
#                 latent_noise = torch.randn(inputs.size(0), 3).to(device)

#                 # Forward pass
#                 outputs = model(inputs, input_size=inputs.size()[-2:])

#                 # Validation loss calculations
#                 seg_loss_task = nn.CrossEntropyLoss()(outputs["seg_output"], seg_labels) + dice_loss(outputs["seg_output"], seg_labels)
#                 seg_perceptual_loss = perceptual_loss_fn(outputs["seg_output"], seg_labels.unsqueeze(1))
#                 seg_loss = seg_loss_task + 0.1 * seg_perceptual_loss

#                 depth_loss_sidl = scale_invariant_depth_loss(outputs["depth_output"], depth_labels)
#                 depth_loss_huber = inv_huber_loss(outputs["depth_output"], depth_labels)
#                 depth_loss_smooth = depth_smoothness_loss(outputs["depth_output"], inputs)
#                 depth_perceptual_loss = perceptual_loss_fn(outputs["depth_output"], depth_labels)
#                 depth_loss = depth_loss_sidl + depth_loss_huber + depth_loss_smooth + 0.1 * depth_perceptual_loss

#                 combined_loss = seg_loss + depth_loss

#                 # Update validation metrics
#                 epoch_valid["seg"] += seg_loss.item()
#                 epoch_valid["depth"] += depth_loss.item()
#                 epoch_valid["combined"] += combined_loss.item()
#                 epoch_valid["iou"] += mean_iou(outputs["seg_output"], seg_labels, num_classes=20).item()
#                 epoch_valid["depth_sidl"] += depth_loss_sidl.item()
#                 epoch_valid["depth_smooth"] += depth_loss_smooth.item()
#                 epoch_valid["seg_perceptual"] += seg_perceptual_loss.item()
#                 epoch_valid["depth_perceptual"] += depth_perceptual_loss.item()
#                 num_valid_batches += 1

#         # Average validation metrics
#         for key in epoch_valid.keys():
#             valid_losses[key].append(epoch_valid[key] / num_valid_batches)

#         # Save best model
#         valid_combined_loss = epoch_valid["combined"] / num_valid_batches
#         if valid_combined_loss < best_combined_loss:
#             best_combined_loss = valid_combined_loss
#             torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))

#         # Append metrics to CSV
#         with open(csv_path, "a", newline="") as f:
#             writer = csv.writer(f)
#             writer.writerow([
#                 epoch + 1,
#                 epoch_train["seg"] / num_batches,
#                 epoch_train["depth"] / num_batches,
#                 epoch_train["combined"] / num_batches,
#                 epoch_train["depth_sidl"] / num_batches,
#                 epoch_train["depth_smooth"] / num_batches,
#                 epoch_train["iou"] / num_batches,
#                 epoch_train["seg_perceptual"] / num_batches,
#                 epoch_train["depth_perceptual"] / num_batches,
#                 epoch_valid["seg"] / num_valid_batches,
#                 epoch_valid["depth"] / num_valid_batches,
#                 epoch_valid["combined"] / num_valid_batches,
#                 epoch_valid["depth_sidl"] / num_valid_batches,
#                 epoch_valid["depth_smooth"] / num_valid_batches,
#                 epoch_valid["iou"] / num_valid_batches,
#                 epoch_valid["seg_perceptual"] / num_valid_batches,
#                 epoch_valid["depth_perceptual"] / num_valid_batches,
#             ])

#         # Update schedulers
#         for scheduler in opt_sched["schedulers"].values():
#             scheduler.step()
            
#     plot_all_losses(train_losses,valid_losses)

    

#     return train_losses, valid_losses


In [None]:
# import os
# import csv
# import torch
# import torch.nn as nn
# from datetime import datetime
# from torchvision.utils import save_image
# from PIL import Image
# import matplotlib.pyplot as plt

# # Updated Train Function
# def train_model_with_loss_tracking_and_gif(
#     model, train_loader, valid_loader, num_epochs, device, save_dir="training_output_bicycle_and_pix2pix"
# ):
#     # Create directories for saving models and outputs
#     os.makedirs(save_dir, exist_ok=True)
#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     save_dir = os.path.join(save_dir, timestamp)
#     os.makedirs(save_dir, exist_ok=True)

#     csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
#     gif_path = os.path.join(save_dir, f"training_visualization_{timestamp}.gif")

#     # Initialize CSV for saving loss data
#     with open(csv_path, "w", newline="") as f:
#         writer = csv.writer(f)
#         writer.writerow([
#             "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
#             "train_depth_sidl", "train_depth_smooth", "train_seg_iou",
#             "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
#             "valid_depth_sidl", "valid_depth_smooth", "valid_seg_iou"
#         ])

#     best_combined_loss = float("inf")  # Initialize best combined loss for saving the best model
#     train_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     valid_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     gif_frames = []

#     for epoch in range(num_epochs):
#         torch.cuda.empty_cache()
#         model.train()
#         epoch_train = {key: 0.0 for key in train_losses.keys()}
#         num_batches = 0

#         # Training Loop
#         for batch in train_loader:
#             inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)
#             latent_noise = torch.randn(inputs.size(0), 3).to(device)

#             model.optimizer_stage1.zero_grad()
#             model.optimizer_stage2.zero_grad()

#             # Forward Pass
#             outputs = model(inputs, seg_labels, depth_labels, latent_noise)
#             seg_output = outputs["seg_output"]
#             depth_output = outputs["depth_output"]

#             # Loss Calculations
#             seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels.squeeze(1))
#             seg_dice = dice_loss(seg_output, seg_labels)
#             seg_iou = mean_iou(seg_output, seg_labels, num_classes=20)
#             seg_loss_total = 0.6 * seg_loss + 0.4 * seg_dice

#             depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#             depth_smooth = depth_smoothness_loss(depth_output, inputs)
#             depth_loss_total = depth_sidl + depth_smooth

#             # Combined Loss
#             total_loss = seg_loss_total + depth_loss_total
#             total_loss.backward()

#             # Optimizers Step
#             model.optimizer_stage1.step()
#             model.optimizer_stage2.step()

#             # Accumulate Training Metrics
#             epoch_train["seg"] += seg_loss.item()
#             epoch_train["depth"] += depth_loss_total.item()
#             epoch_train["combined"] += total_loss.item()
#             epoch_train["iou"] += seg_iou.item()
#             epoch_train["depth_sidl"] += depth_sidl.item()
#             epoch_train["depth_smooth"] += depth_smooth.item()
#             num_batches += 1

#             # Save training images for visualization
#             if num_batches % 10 == 0:
#                 img_grid = torch.cat([inputs[0], seg_output[0].argmax(0, keepdim=True), depth_output[0]], dim=2)
#                 save_image(img_grid, os.path.join(save_dir, f"train_{epoch}_{num_batches}.png"))
#                 gif_frames.append(Image.open(os.path.join(save_dir, f"train_{epoch}_{num_batches}.png")))

#         model.scheduler_stage1.step()
#         model.scheduler_stage2.step(epoch_train["combined"] / num_batches)

#         # Average Training Losses
#         for key in epoch_train.keys():
#             train_losses[key].append(epoch_train[key] / num_batches)

#         # Validation Loop
#         model.eval()
#         epoch_valid = {key: 0.0 for key in valid_losses.keys()}
#         num_valid_batches = 0
#         with torch.no_grad():
#             for batch in valid_loader:
#                 inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)
#                 latent_noise = torch.randn(inputs.size(0), 3).to(device)

#                 outputs = model(inputs, seg_labels, depth_labels, latent_noise)
#                 seg_output = outputs["seg_output"]
#                 depth_output = outputs["depth_output"]

#                 # Validation Loss Calculations
#                 seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels.squeeze(1))
#                 seg_dice = dice_loss(seg_output, seg_labels)
#                 seg_iou = mean_iou(seg_output, seg_labels, num_classes=20)
#                 seg_loss_total = 0.6 * seg_loss + 0.4 * seg_dice

#                 depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#                 depth_smooth = depth_smoothness_loss(depth_output, inputs)
#                 depth_loss_total = depth_sidl + depth_smooth

#                 # Accumulate Validation Metrics
#                 epoch_valid["seg"] += seg_loss.item()
#                 epoch_valid["depth"] += depth_loss_total.item()
#                 epoch_valid["combined"] += (seg_loss_total + depth_loss_total).item()
#                 epoch_valid["iou"] += seg_iou.item()
#                 epoch_valid["depth_sidl"] += depth_sidl.item()
#                 epoch_valid["depth_smooth"] += depth_smooth.item()
#                 num_valid_batches += 1
            
#             frame = save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels)
#             gif_frames.append(frame)
              

#         for key in epoch_valid.keys():
#             valid_losses[key].append(epoch_valid[key] / num_valid_batches)

#         # Save Model if Validation Loss Improves
#         valid_combined_loss = epoch_valid["combined"] / num_valid_batches
#         if valid_combined_loss < best_combined_loss:
#             best_combined_loss = valid_combined_loss
#             torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))
#             print(f"Best model saved at epoch {epoch+1} with combined loss {best_combined_loss:.4f}")
            
#         if epoch%10==0:
#             gif_path2 =os.path.join(save_dir,f"viz_epoch_{epoch}.gif")
#             gif_frames[0].save(gif_path2, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

            

#         # Append Validation Metrics to CSV
#         with open(csv_path, "a", newline="") as f:
#             writer = csv.writer(f)
#             writer.writerow([
#                 epoch + 1,
#                 epoch_train["seg"] / num_batches,
#                 epoch_train["depth"] / num_batches,
#                 epoch_train["combined"] / num_batches,
#                 epoch_train["depth_sidl"] / num_batches,
#                 epoch_train["depth_smooth"] / num_batches,
#                 epoch_train["iou"] / num_batches,
#                 epoch_valid["seg"] / num_valid_batches,
#                 epoch_valid["depth"] / num_valid_batches,
#                 epoch_valid["combined"] / num_valid_batches,
#                 epoch_valid["depth_sidl"] / num_valid_batches,
#                 epoch_valid["depth_smooth"] / num_valid_batches,
#                 epoch_valid["iou"] / num_valid_batches,
#             ])

#     # Save GIF
#     gif_frames[0].save(
#         gif_path, save_all=True, append_images=gif_frames[1:], duration=200, loop=0
#     )

#     return train_losses, valid_losses


In [28]:


# def train_model_with_loss_tracking_and_gif(
#     model, train_loader, valid_loader, num_epochs, device, save_dir="training_output_bicycle_and_pix2pix"):
#     # Create directory for saving models and outputs
#     os.makedirs(save_dir, exist_ok=True)
#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     save_dir = os.path.join(save_dir, timestamp)
#     os.makedirs(save_dir, exist_ok=True)

#     csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
#     gif_path = os.path.join(save_dir, f"training_visualization_{timestamp}.gif")

#     # Initialize CSV for saving loss data
#     with open(csv_path, "w", newline="") as f:
#         writer = csv.writer(f)
#         writer.writerow([
#             "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
#             "train_depth_sidl", "train_depth_inv_huber", "train_depth_contrastive", "train_depth_smooth",
#             "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
#             "valid_depth_sidl", "valid_depth_inv_huber", "valid_depth_contrastive", "valid_depth_smooth"
#         ])

#     best_combined_loss = float("inf")  # Initialize best combined loss for saving the best model

#     train_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     valid_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     # , "depth_inv_huber": [], "depth_contrastive": []

#     gif_frames = []
#     num_classes = 20
#     # Optimizer for latent noise
    

#     for epoch in range(num_epochs):
#         torch.cuda.empty_cache()
#         torch.autograd.set_detect_anomaly(True)

#         model.train()
#         # epoch_train_seg_loss = 0
#         # epoch_train_depth_loss = 0
#         # epoch_train_iou = 0
#         # epoch_train_combined_loss = 0
#         # epoch_train_depth_sidl = 0
#         # epoch_train_depth_inv_huber = 0
#         # epoch_train_depth_contrastive = 0
#         # epoch_train_depth_smooth = 0
#         epoch_train = {key: 0.0 for key in train_losses.keys()}
#         num_batches = 0

#         reconstruction_layer = nn.Conv2d(256, 3, kernel_size=1).to(device)
        
#         # scaler = torch.cuda.amp.GradScaler()

#         # Training Loop
#         for batch in train_loader:
#             inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)
#             print(inputs.shape,seg_labels.shape,depth_labels.shape) # torch.Size([8, 3, 200, 512]) torch.Size([8, 1, 200, 512]) torch.Size([8, 1, 1, 200, 512])
#             return
        
                       


#             # Forward pass
#             seg_output, depth_output, backbone_features = model(...)
            


#             seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels)
#             seg_dice = dice_loss(seg_output, seg_labels)
#             seg_iou = mean_iou(seg_output, seg_labels, num_classes)
#             seg_loss_total = 0.6 * seg_loss  + 0.4 * seg_dice
            
#             depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#             depth_inv_huber = inv_huber_loss(depth_output, depth_labels)
#             depth_smooth = depth_smoothness_loss(depth_output, inputs)
#             depth_loss_total = depth_sidl + depth_inv_huber + depth_smooth
            
           
           
#             # Combined Loss
#             total_loss = bicycle_loss + pix2pix_total_loss

#             # Single backward pass
#             total_loss.backward()

#             # Update both optimizers
#             model.optimizer_stage1.step()
#             model.optimizer_stage2.step()
#             latent_optimizer.step()

#             # Accumulate Training Metrics
#             epoch_train["seg"] += seg_loss.item()
#             epoch_train["depth"] += (depth_sidl + depth_smooth).item()
#             epoch_train["combined"] += total_loss.item()
#             epoch_train["iou"] += seg_iou.item()
#             epoch_train["depth_sidl"] += depth_sidl.item()
#             epoch_train["depth_smooth"] += depth_smooth.item()
#             num_batches += 1
            
#         model.scheduler_stage1.step()
#         model.scheduler_stage2.step(epoch_train["combined"]/num_batches)


#         # Average Training Losses
#         for key in epoch_train.keys():
#             train_losses[key].append(epoch_train[key] / num_batches)

#         print(
#             f"Epoch {epoch+1}/{num_epochs} - Train Seg Loss: {epoch_train['seg']:.4f}, "
#             f"Train Depth Loss: {epoch_train['depth']:.4f}, Train Combined Loss: {epoch_train['combined']:.4f}, "
#             f"Train mIOU: {epoch_train['iou']:.4f}, Train sidl Loss: {epoch_train['depth_sidl']:.4f}, "
#             f"Train depth smooth: {epoch_train['depth_smooth']:.4f}"
#     )       

#         # Validation Loop
#         model.eval()
#         epoch_valid = {key: 0.0 for key in valid_losses.keys()}
#         num_valid_batches = 0

#         with torch.no_grad():
#             for batch in valid_loader:
#                 # print("inside valid")
#                 inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)

#                 # Ensure depth_labels and segmentation labels have correct dimensions
                

               


#                 seg_output_old =seg_output
#                 # Resize seg_output to match the spatial dimensions of seg_labels
#                 seg_output_resized = F.interpolate(seg_output, size=seg_labels.shape[1:], mode='bilinear', align_corners=False)
#                 seg_output = seg_output_resized

#                 depth_output_old = depth_output
#                 depth_output_resized = F.interpolate(depth_output, size=depth_labels.shape[-2:], mode='bilinear', align_corners=False)
#                 depth_output =depth_output_resized


#                 # Segmentation Loss
#                 seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels)
#                 seg_dice = dice_loss(seg_output, seg_labels)
#                 seg_iou = mean_iou(seg_output, seg_labels, num_classes)
#                 seg_loss_total = 0.6 * seg_loss  + 0.4 * seg_dice
                
#                 depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#                 depth_inv_huber = inv_huber_loss(depth_output, depth_labels)
#                 depth_smooth = depth_smoothness_loss(depth_output, inputs)
#                 depth_loss_total = depth_sidl + depth_inv_huber + depth_smooth

#                 pix2pix_loss = seg_loss_total + depth_loss_total

#                 # Combined Validation Loss
#                 combined_loss = pix2pix_loss

#                 # Accumulate Validation Metrics
#                 epoch_valid["seg"] += seg_loss.item()
#                 epoch_valid["depth"] += (depth_sidl + depth_smooth).item()
#                 epoch_valid["combined"] += combined_loss.item()
#                 epoch_valid["iou"] += seg_iou.item()
#                 epoch_valid["depth_sidl"] += depth_sidl.item()
#                 epoch_valid["depth_smooth"] += depth_smooth.item()
                
#                 num_valid_batches += 1
                
#             frame = save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels)
#             gif_frames.append(frame)
                
                
#         # Calculate epoch averages
#         # Average Validation Losses
#         for key in epoch_valid.keys():
#             valid_losses[key].append(epoch_valid[key] / num_valid_batches)

#         print(
#             f"Epoch {epoch+1}/{num_epochs} - Valid Seg Loss: {epoch_valid['seg']:.4f}, "
#             f"Valid Depth Loss: {epoch_valid['depth']:.4f}, Valid Combined Loss: {epoch_valid['combined']:.4f}, "
#             f"Valid mIOU: {epoch_valid['iou']:.4f}, Valid sidl Loss: {epoch_valid['depth_sidl']:.4f}, "
#             f"Valid depth smooth: {epoch_valid['depth_smooth']:.4f}"
#         )

#         # Write the losses to CSV
#         with open(csv_path, "a", newline="") as f:
#             writer = csv.writer(f)
#             writer.writerow([
#                 epoch + 1,
#                 train_losses["seg"], train_losses["depth"], train_losses["combined"],
#                 train_losses["depth_sidl"], 0,0,
#                 # avg_train_depth_inv_huber, avg_train_depth_contrastive,
#                 train_losses["depth_smooth"],
#                 valid_losses["seg"], valid_losses["depth"], valid_losses['combined'],
#                 valid_losses["depth_sidl"],0,0,
#                 # avg_valid_depth_inv_huber, avg_valid_depth_contrastive, 
#                 valid_losses["depth_smooth"]
#             ])

       
#         # Save best model
#         if valid_losses["combined"][-1] < best_combined_loss:
#             best_combined_loss = valid_losses["combined"][-1]
#             torch.save(model, os.path.join(save_dir, "best_model_resnetBackbone.pth"))
#             print(f"Best model saved at epoch {epoch+1} with combined loss {best_combined_loss:.4f}")
            
#         if epoch%10==0:
#             gif_path2 =os.path.join(save_dir,f"viz_epoch_{epoch}.gif")
#             gif_frames[0].save(gif_path2, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

    
    
#     plot_loss(train_losses, valid_losses, save_dir)
#     gif_frames[0].save(gif_path, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

    
    
    
#     return train_losses,valid_losses


In [29]:

# # Create your model instance
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # device = 'cpu'
# model = MultiTaskModel(num_seg_classes=20, feature_channels=256).to(device)

In [30]:
# # Set the number of epochs
# num_epochs = 10

# # Call the training function
# train_losses, valid_losses = train_model_with_loss_tracking_and_gif(
#     model=model,
#     train_loader=train_loader,
#     valid_loader=valid_loader,
#     num_epochs=num_epochs,
#     device=device,
#     save_dir="test7_res"
# )

In [31]:


# def train_model_with_loss_tracking_and_gif(
#     model, train_loader, valid_loader, num_epochs, device, save_dir="training_output_bicycle_and_pix2pix"):
#     # Create directory for saving models and outputs
#     os.makedirs(save_dir, exist_ok=True)
#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     save_dir = os.path.join(save_dir, timestamp)
#     os.makedirs(save_dir, exist_ok=True)

#     csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
#     gif_path = os.path.join(save_dir, f"training_visualization_{timestamp}.gif")

#     # Initialize CSV for saving loss data
#     with open(csv_path, "w", newline="") as f:
#         writer = csv.writer(f)
#         writer.writerow([
#             "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
#             "train_depth_sidl", "train_depth_inv_huber", "train_depth_contrastive", "train_depth_smooth",
#             "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
#             "valid_depth_sidl", "valid_depth_inv_huber", "valid_depth_contrastive", "valid_depth_smooth"
#         ])

#     best_combined_loss = float("inf")  # Initialize best combined loss for saving the best model

#     train_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     valid_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     # , "depth_inv_huber": [], "depth_contrastive": []

#     gif_frames = []
#     num_classes = 20
#     # Optimizer for latent noise
    

#     for epoch in range(num_epochs):
#         torch.cuda.empty_cache()
#         torch.autograd.set_detect_anomaly(True)

#         model.train()
#         # epoch_train_seg_loss = 0
#         # epoch_train_depth_loss = 0
#         # epoch_train_iou = 0
#         # epoch_train_combined_loss = 0
#         # epoch_train_depth_sidl = 0
#         # epoch_train_depth_inv_huber = 0
#         # epoch_train_depth_contrastive = 0
#         # epoch_train_depth_smooth = 0
#         epoch_train = {key: 0.0 for key in train_losses.keys()}
#         num_batches = 0

#         reconstruction_layer = nn.Conv2d(256, 3, kernel_size=1).to(device)
        
#         # scaler = torch.cuda.amp.GradScaler()

#         # Training Loop
#         for batch in train_loader:
#             inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)

#             # Ensure depth_labels and segmentation labels have correct dimensions
#             if depth_labels.dim() == 5:
#                 depth_labels = depth_labels.squeeze(2)
#             if seg_labels.dim() == 4 and seg_labels.shape[1] == 1:
#                 seg_labels = seg_labels.squeeze(1)

#             # Transform depth labels
#             # depth_labels = torch.log(depth_labels.flatten(start_dim=1)) / 5
#             # depth_labels = depth_labels.view_as(depth_labels)  # Restore shape
#             # depth_labels = torch.clamp(depth_labels, min=1e-5) 
#             # depth_labels = torch.log(depth_labels + 1e-5) / 5  # Avoid log(0)

#             # print(f'seg_labels shape : {seg_labels.shape}')
#             # print(f'depth_labels shape: {depth_labels.shape}')

#             # Start with random noise as latent condition
#             if epoch == 0:
#                 latent_noise = torch.randn_like(inputs).to(device)
#                 # print(f"latent_noise: {latent_noise.shape}")
#                 latent_noise.requires_grad = True  # Make it trainable
#                 latent_optimizer = torch.optim.Adam([latent_noise], lr=1e-3)
            
            


#             # Stage 1: Train BicycleGAN (Backbone Features)
            

#             # Reset gradients for both optimizers
#             model.optimizer_stage1.zero_grad()
#             model.optimizer_stage2.zero_grad()
#             latent_optimizer.zero_grad()

#             # Forward pass
#             seg_output, depth_output, backbone_features = model(inputs, latent_noise)
#             # print(f'seg_ouput shape : {seg_output.shape}')
#             # print(f'depth_output shape: {depth_output.shape}')
#             # print(backbone_features.shape)
#             # return


#             seg_output_old =seg_output
#             # Resize seg_output to match the spatial dimensions of seg_labels
#             seg_output_resized = F.interpolate(seg_output, size=seg_labels.shape[1:], mode='bilinear', align_corners=False)
#             seg_output = seg_output_resized

#             # print(f"depth_output shape before resize: {depth_output.shape}")
#             # print(f"depth_labels shape: {depth_labels.shape}")
#             # return

#             depth_output_old = depth_output
#             depth_output_resized = F.interpolate(depth_output, size=depth_labels.shape[-2:], mode='bilinear', align_corners=False)
#             depth_output = depth_output_resized


#             # Pix2Pix Losses
#             seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels)
#             seg_dice = dice_loss(seg_output, seg_labels)
#             seg_iou = mean_iou(seg_output, seg_labels, num_classes)
#             seg_loss_total = 0.6 * seg_loss  + 0.4 * seg_dice
            
#             depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#             depth_inv_huber = inv_huber_loss(depth_output, depth_labels)
#             depth_smooth = depth_smoothness_loss(depth_output, inputs)
#             depth_loss_total = depth_sidl + depth_inv_huber + depth_smooth
            
#             pix2pix_loss = seg_loss_total + depth_loss_total

#             # Reconstruction loss
#             # inputs_resized = F.interpolate(inputs, size=(backbone_features.size(2), backbone_features.size(3)))
#             # reconstructed_image = reconstruction_layer(backbone_features)
#             # recon_loss = nn.L1Loss()(reconstructed_image, inputs_resized)
#             # adaptive_weight = 1 / (1 + torch.exp(-recon_loss))
#             # adaptive_weight_value = adaptive_weight.item() 

#             # loss_stage1 = nn.MSELoss()(real_validity, torch.ones_like(real_validity).to(device)) + recon_loss
#             # loss_stage1.backward(retain_graph=True)
#             # model.optimizer_stage1.step()

#             # Pix2Pix Adversarial Losses
#             seg_validity = model.segmentation_discriminator(seg_output)
#             depth_validity = model.depth_discriminator(depth_output)
#             adv_seg_loss = nn.MSELoss()(seg_validity, torch.ones_like(seg_validity))
#             adv_depth_loss = nn.MSELoss()(depth_validity, torch.ones_like(depth_validity))
#             pix2pix_total_loss = pix2pix_loss + adv_seg_loss + adv_depth_loss


#             # BicycleGAN Loss with Pix2Pix Condition
#             # real_validity = model.bicycle_discriminator(backbone_features)
#             # recon_loss = nn.L1Loss()(backbone_features, inputs)
#             # bicycle_loss = nn.MSELoss()(real_validity, torch.ones_like(real_validity)) + recon_loss
#             # conditional_bicycle_loss = bicycle_loss + pix2pix_loss
#             # conditional_bicycle_loss.backward(retain_graph=True)
#             # model.optimizer_stage1.step()

#             # BicycleGAN Loss with Pix2Pix Condition
#             real_validity = model.bicycle_discriminator(backbone_features,latent_noise)

#             # Resize inputs to match backbone_features
#             inputs_resized = F.interpolate(inputs, size=backbone_features.shape[-2:], mode='bilinear', align_corners=False)

#             # print(f"backbone_features shape: {backbone_features.shape}, inputs shape: {inputs.shape}")
#             # print(f"inputs_resized shape: {inputs_resized.shape}")
#             # recon_loss = nn.L1Loss()(backbone_features, inputs_resized)
#             bicycle_loss = adv_seg_loss + adv_depth_loss
#             # + recon_loss

#             # Combined Loss
#             total_loss = bicycle_loss + pix2pix_total_loss

#             # Single backward pass
#             total_loss.backward()

#             # Update both optimizers
#             model.optimizer_stage1.step()
#             model.optimizer_stage2.step()
#             latent_optimizer.step()

#             # Accumulate Training Metrics
#             epoch_train["seg"] += seg_loss.item()
#             epoch_train["depth"] += (depth_sidl + depth_smooth).item()
#             epoch_train["combined"] += total_loss.item()
#             epoch_train["iou"] += seg_iou.item()
#             epoch_train["depth_sidl"] += depth_sidl.item()
#             epoch_train["depth_smooth"] += depth_smooth.item()
#             num_batches += 1
            
#         model.scheduler_stage1.step()
#         model.scheduler_stage2.step(epoch_train["combined"]/num_batches)


#         # Average Training Losses
#         for key in epoch_train.keys():
#             train_losses[key].append(epoch_train[key] / num_batches)

#         print(
#             f"Epoch {epoch+1}/{num_epochs} - Train Seg Loss: {epoch_train['seg']:.4f}, "
#             f"Train Depth Loss: {epoch_train['depth']:.4f}, Train Combined Loss: {epoch_train['combined']:.4f}, "
#             f"Train mIOU: {epoch_train['iou']:.4f}, Train sidl Loss: {epoch_train['depth_sidl']:.4f}, "
#             f"Train depth smooth: {epoch_train['depth_smooth']:.4f}"
#     )       

#         # Validation Loop
#         model.eval()
#         # epoch_valid_seg_loss = 0
#         # epoch_valid_depth_loss = 0
#         # epoch_valid_iou =0
#         # epoch_valid_combined_loss = 0
#         # epoch_valid_depth_sidl = 0
#         # epoch_valid_depth_inv_huber = 0
#         # epoch_valid_depth_contrastive = 0
#         # epoch_valid_depth_smooth = 0
#         epoch_valid = {key: 0.0 for key in valid_losses.keys()}
#         num_valid_batches = 0

#         with torch.no_grad():
#             for batch in valid_loader:
#                 # print("inside valid")
#                 inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)

#                 # Ensure depth_labels and segmentation labels have correct dimensions
#                 if depth_labels.dim() == 5:
#                     depth_labels = depth_labels.squeeze(2)
#                 if seg_labels.dim() == 4 and seg_labels.shape[1] == 1:
#                     seg_labels = seg_labels.squeeze(1)

#                 # Transform depth labels
#                 # depth_labels = torch.log(depth_labels.flatten(start_dim=1)) / 5
#                 # depth_labels = depth_labels.view_as(depth_labels)  # Restore shape
#                 # depth_labels = torch.clamp(depth_labels, min=1e-5) 
#                 # depth_labels = torch.log(depth_labels + 1e-5) / 5  # Avoid log(0)

#                 # Latent noise for validation
#                 latent_noise = torch.randn_like(inputs).to(device)
#                 seg_output, depth_output, backbone_features = model(inputs, latent_noise)

                
                

#                 seg_output_old =seg_output
#                 # Resize seg_output to match the spatial dimensions of seg_labels
#                 seg_output_resized = F.interpolate(seg_output, size=seg_labels.shape[1:], mode='bilinear', align_corners=False)
#                 seg_output = seg_output_resized

#                 depth_output_old = depth_output
#                 depth_output_resized = F.interpolate(depth_output, size=depth_labels.shape[-2:], mode='bilinear', align_corners=False)
#                 depth_output =depth_output_resized


#                 # Segmentation Loss
#                 seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels)
#                 seg_dice = dice_loss(seg_output, seg_labels)
#                 seg_iou = mean_iou(seg_output, seg_labels, num_classes)
#                 seg_loss_total = 0.6 * seg_loss  + 0.4 * seg_dice
                
#                 depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#                 depth_inv_huber = inv_huber_loss(depth_output, depth_labels)
#                 depth_smooth = depth_smoothness_loss(depth_output, inputs)
#                 depth_loss_total = depth_sidl + depth_inv_huber + depth_smooth

#                 pix2pix_loss = seg_loss_total + depth_loss_total

#                 # Combined Validation Loss
#                 combined_loss = pix2pix_loss

#                 # Accumulate Validation Metrics
#                 epoch_valid["seg"] += seg_loss.item()
#                 epoch_valid["depth"] += (depth_sidl + depth_smooth).item()
#                 epoch_valid["combined"] += combined_loss.item()
#                 epoch_valid["iou"] += seg_iou.item()
#                 epoch_valid["depth_sidl"] += depth_sidl.item()
#                 epoch_valid["depth_smooth"] += depth_smooth.item()
                
#                 num_valid_batches += 1
                
#                 # epoch, inputs, seg_output, depth_output, seg_labels, depth_labels, gif_frames
#             frame = save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels)
#             gif_frames.append(frame)
                
                
#         # Calculate epoch averages
#         # Average Validation Losses
#         for key in epoch_valid.keys():
#             valid_losses[key].append(epoch_valid[key] / num_valid_batches)

        
        
# # train_losses = { "depth_sidl": [], "depth_inv_huber": [], "depth_contrastive": [], "depth_smooth": []}
#         print(
#             f"Epoch {epoch+1}/{num_epochs} - Valid Seg Loss: {epoch_valid['seg']:.4f}, "
#             f"Valid Depth Loss: {epoch_valid['depth']:.4f}, Valid Combined Loss: {epoch_valid['combined']:.4f}, "
#             f"Valid mIOU: {epoch_valid['iou']:.4f}, Valid sidl Loss: {epoch_valid['depth_sidl']:.4f}, "
#             f"Valid depth smooth: {epoch_valid['depth_smooth']:.4f}"
#         )

#         # Write the losses to CSV
#         with open(csv_path, "a", newline="") as f:
#             writer = csv.writer(f)
#             writer.writerow([
#                 epoch + 1,
#                 train_losses["seg"], train_losses["depth"], train_losses["combined"],
#                 train_losses["depth_sidl"], 0,0,
#                 # avg_train_depth_inv_huber, avg_train_depth_contrastive,
#                 train_losses["depth_smooth"],
#                 valid_losses["seg"], valid_losses["depth"], valid_losses['combined'],
#                 valid_losses["depth_sidl"],0,0,
#                 # avg_valid_depth_inv_huber, avg_valid_depth_contrastive, 
#                 valid_losses["depth_smooth"]
#             ])

#         # Save GIF visualization frames
#         # save_training_visualization_as_gif(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels, gif_frames)

#         # Save best model
#         if valid_losses["combined"][-1] < best_combined_loss:
#             best_combined_loss = valid_losses["combined"][-1]
#             torch.save(model, os.path.join(save_dir, "best_model_resnetBackbone.pth"))
#             print(f"Best model saved at epoch {epoch+1} with combined loss {best_combined_loss:.4f}")
            
#         if epoch%10==0:
#             gif_path2 =os.path.join(save_dir,f"viz_epoch_{epoch}.gif")
#             gif_frames[0].save(gif_path2, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

#     # Save GIF
#     # gif_frames[0].save(gif_path, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)
    
#     plot_loss(train_losses, valid_losses, save_dir)
#     gif_frames[0].save(gif_path, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

    
    
    
#     return train_losses,valid_losses
