In [5]:
import os
from glob import glob
import time

import numpy as np
import h5py
import cv2

import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset,Subset
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F

import matplotlib as mpl
import matplotlib.pyplot as plt

from torchvision.models import mobilenet_v3_small
from sklearn.metrics import jaccard_score
from torchvision.models.mobilenetv3 import MobileNet_V3_Small_Weights
from torchvision.models import vgg16, VGG16_Weights



import csv

from datetime import datetime

from PIL import Image
import torch.nn.functional as F


from labels import labels

%matplotlib inline

curr_dir=os.getcwd()
root= os.path.join(curr_dir,"cityscapes_dataset")
curr_dir,root

('/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes',
 '/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/cityscapes_dataset')

In [6]:

import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
# from albumentations.augmentations.transforms import RandomShadow

class Normalize(object):
    """ Normalizes RGB image to  0-mean 1-std_dev """ 
    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], depth_norm=5, max_depth=250):
        self.mean = mean
        self.std = std
        self.depth_norm = depth_norm
        self.max_depth = max_depth

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']
            
        return {'left': TF.normalize(left, self.mean, self.std), 
                'mask': mask, 
                'depth' : torch.clip( # saftey clip :)
                            torch.log(torch.clip(depth, 0, self.max_depth))/self.depth_norm, 
                            0, 
                            self.max_depth)}


class AddColorJitter(object):
    """Convert a color image to grayscale and normalize the color range to [0,1].""" 
    def __init__(self, brightness, contrast, saturation, hue):
        ''' Applies brightness, constrast, saturation, and hue jitter to image ''' 
        self.color_jitter = transforms.ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']

        return {'left': self.color_jitter(left), 
                'mask': mask, 
                'depth' : depth}


class Rescale(object):
    """ Rescales images with bilinear interpolation and masks with nearest interpolation """

    def __init__(self, h, w):
        self.h, self.w = h, w

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']
# mask interpolation Nearest is import to have smoothness
        return {'left': TF.resize(left, (self.h, self.w)), 
                'mask': TF.resize(mask.unsqueeze(0), (self.h, self.w), transforms.InterpolationMode.NEAREST), 
                'depth' : TF.resize(depth.unsqueeze(0), (self.h, self.w))}


class RandomCrop(object):
    def __init__(self, h, w, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)):
        self.h = h
        self.w = w
        self.scale = scale
        self.ratio = ratio

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']
        i, j, h, w = transforms.RandomResizedCrop.get_params(left, scale=self.scale, ratio=self.ratio)

        return {'left': TF.resized_crop(left, i, j, h, w, (self.h, self.w)), 
                'mask': TF.resized_crop(mask.unsqueeze(0), i, j, h, w, (self.h, self.w), interpolation=TF.InterpolationMode.NEAREST),
                'depth' : TF.resized_crop(depth.unsqueeze(0), i, j, h, w, (self.h, self.w))}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""
    def __call__(self, sample):
         
        left, mask, depth = sample['left'], sample['mask'], sample['depth']

        return {'left': transforms.ToTensor()(left), 
                'mask': torch.as_tensor(mask, dtype=torch.int64),
                'depth' : transforms.ToTensor()(depth).type(torch.float32)}
    

class ElasticTransform(object):
    def __init__(self, alpha=25.0, sigma=5.0, prob=0.5):
        self.alpha = [1.0, alpha]
        self.sigma = [1, sigma]
        self.prob = prob

    def __call__(self, sample):
        
        if torch.rand(1) < self.prob:

            left, mask, depth = sample['left'], sample['mask'], sample['depth']
            _, H, W = mask.shape
            displacement = transforms.ElasticTransform.get_params(self.alpha, self.sigma, [H, W])

            # # TEMP
            # print(TF.elastic_transform(left, displacement).shape)
            # print(TF.elastic_transform(mask.unsqueeze(0), displacement, interpolation=TF.InterpolationMode.NEAREST).shape)
            # print(torch.clip(TF.elastic_transform(depth, displacement), 0, depth.max()).shape)

            return {'left': TF.elastic_transform(left, displacement), 
                    'mask': TF.elastic_transform(mask.unsqueeze(0), displacement, interpolation=TF.InterpolationMode.NEAREST), 
                    'depth' : torch.clip(TF.elastic_transform(depth, displacement), 0, depth.max())} 
        
        else:
            return sample

        
    

# new transform to rotate the images
class RandomRotate(object):
    def __init__(self, angle):
        if not isinstance(angle, (list, tuple)):
            self.angle = (-abs(angle), abs(angle))
        else:
            self.angle = angle

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']

        angle = transforms.RandomRotation.get_params(self.angle)

        return {'left': TF.rotate(left, angle), 
                'mask': TF.rotate(mask.unsqueeze(0), angle), 
                'depth' : TF.rotate(depth, angle)}
    
    
class RandomHorizontalFlip(object):
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, sample):
        
        if torch.rand(1) < self.prob:
            left, mask, depth = sample['left'], sample['mask'], sample['depth']
            return {'left': TF.hflip(left), 
                    'mask': TF.hflip(mask), 
                    'depth' : TF.hflip(depth)}
        else:
            return sample
        

class RandomVerticalFlip(object):
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, sample):
        if torch.rand(1) < self.prob:
            left, mask, depth = sample['left'], sample['mask'], sample['depth']
            return {'left': TF.vflip(left), 
                    'mask': TF.vflip(mask), 
                    'depth' : TF.vflip(depth)}
        else:
            return sample
        

In [7]:

def convert_to_numpy(image):
    if not isinstance(image, np.ndarray):
        if len(image.shape) == 2:
            image = image.detach().cpu().numpy()
        else:
            image = image.detach().cpu().numpy().transpose(1, 2, 0)

    return image

def get_color_mask(mask, labels, id_type='id'):
    try:
        h, w = mask.shape
    except ValueError:
        mask = mask.squeeze(-1)
        h, w = mask.shape

    color_mask = np.zeros((h, w, 3), dtype=np.uint8)

    if id_type == 'id':
        for lbl in labels:
            color_mask[mask == lbl.id] = lbl.color
    elif id_type == 'trainId':
        for lbl in labels:
            if (lbl.trainId != 255) | (lbl.trainId != -1):
                color_mask[mask == lbl.trainId] = lbl.color

    return color_mask


def plot_items(left, mask, depth, labels=None, num_seg_labels=34, id_type='id'):
    left = convert_to_numpy(left)
    mask = convert_to_numpy(mask)
    depth = convert_to_numpy(depth)

    # unnormalize left image
    left = (left*np.array([0.229, 0.224, 0.225])) + np.array([0.485, 0.456, 0.406])

    # cmaps: 'prism', 'terrain', 'turbo', 'gist_rainbow_r', 'nipy_spectral_r'
    
    
    _, ax = plt.subplots(1, 3, figsize=(15,10))
    ax[0].imshow(left)
    ax[0].set_title("Left Image")

    if labels:
        color_mask = get_color_mask(mask, labels, id_type)
        ax[1].imshow(color_mask)
    else:
        cmap = mpl.colormaps.get_cmap('nipy_spectral_r').resampled(num_seg_labels)
        ax[1].imshow(mask, cmap=cmap)

    ax[1].set_title("Seg Mask")
    ax[2].imshow(depth, cmap='plasma')
    ax[2].set_title("Depth")

In [8]:
def scale_invariant_depth_loss(pred, target, lambda_weight=0.1):
    if pred.shape != target.shape:
        pred = F.interpolate(pred, size=target.shape[1:], mode='bilinear', align_corners=False)
    
    diff = pred - target
    n = diff.numel()
    mse = torch.sum(diff**2) / n
    scale_invariant = mse - (lambda_weight / (n**2)) * (torch.sum(diff))**2
    return scale_invariant

def depth_smoothness_loss(pred, img, alpha=1.0):
    depth_grad_x = torch.abs(pred[:, :, :, :-1] - pred[:, :, :, 1:])
    depth_grad_y = torch.abs(pred[:, :, :-1, :] - pred[:, :, 1:, :])
    img_grad_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), dim=1, keepdim=True)
    img_grad_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), dim=1, keepdim=True)
    smoothness_x = depth_grad_x * torch.exp(-alpha * img_grad_x)
    smoothness_y = depth_grad_y * torch.exp(-alpha * img_grad_y)
    return smoothness_x.mean() + smoothness_y.mean()


def inv_huber_loss(pred, target, delta=0.1):
    """
    Inverse Huber loss for depth prediction.
    Args:
        pred (Tensor): Predicted depth map.
        target (Tensor): Ground truth depth map.
        delta (float): Threshold for switching between quadratic and linear terms.
    Returns:
        Tensor: Inverse Huber loss.
    """
    abs_diff = torch.abs(pred - target)
    delta_tensor = torch.tensor(delta, dtype=abs_diff.dtype, device=abs_diff.device)  # Convert delta to tensor
    quadratic = torch.minimum(abs_diff, delta_tensor)
    linear = abs_diff - quadratic
    return (0.5 * quadratic**2 + delta_tensor * linear).mean()


def mean_iou(pred, target, num_classes):
    pred = torch.argmax(pred, dim=1)
    intersection = torch.logical_and(pred == target, target != 255).float()  # Ignore class 255
    union = torch.logical_or(pred == target, target != 255).float()
    iou = torch.sum(intersection) / torch.sum(union)
    return iou



def contrastive_loss(pred, target, margin=1.0):
    """
    Contrastive loss to ensure the depth map predictions are closer to the target.
    """
    # Flatten the tensors for element-wise operations
    pred_flat = pred.view(pred.size(0), -1)  # Flatten except for the batch dimension
    target_flat = target.view(target.size(0), -1)  # Flatten except for the batch dimension

    # Compute the pairwise distances
    distances = torch.sqrt(torch.sum((pred_flat - target_flat) ** 2, dim=1))  # Batch-wise distances

    # Create labels for contrastive loss
    labels = (torch.abs(pred_flat - target_flat).mean(dim=1) < margin).float()  # Batch-wise labels

    # Calculate contrastive loss
    similar_loss = labels * distances**2
    dissimilar_loss = (1 - labels) * torch.clamp(margin - distances, min=0)**2
    loss = (similar_loss + dissimilar_loss).mean()

    return loss


def dice_loss(predictions, targets, smooth=1e-6):
    """
    Calculate Dice Loss for segmentation.
    Args:
        predictions (torch.Tensor): The predicted segmentation map (logits or probabilities).
                                    Shape: [batch_size, num_classes, height, width]
        targets (torch.Tensor): The ground truth segmentation map (one-hot encoded or integer labels).
                                Shape: [batch_size, height, width]
        smooth (float): Smoothing factor to avoid division by zero.
    Returns:
        torch.Tensor: Dice Loss (scalar).
    """
    # Convert integer labels to one-hot if needed
    if predictions.shape != targets.shape:
        targets = F.one_hot(targets, num_classes=predictions.shape[1]).permute(0, 3, 1, 2).float()
    
    # Apply softmax to predictions for multi-class segmentation
    predictions = torch.softmax(predictions, dim=1)
    
    # Flatten tensors to calculate intersection and union
    predictions_flat = predictions.view(predictions.shape[0], predictions.shape[1], -1)
    targets_flat = targets.view(targets.shape[0], targets.shape[1], -1)
    
    # Calculate intersection and union
    intersection = (predictions_flat * targets_flat).sum(dim=-1)
    union = predictions_flat.sum(dim=-1) + targets_flat.sum(dim=-1)
    
    # Calculate Dice Coefficient
    dice_coeff = (2 * intersection + smooth) / (union + smooth)
    
    # Dice Loss
    return 1 - dice_coeff.mean()

In [9]:
def plot_loss(train_losses, valid_losses, save_dir):
    epochs = range(1, len(train_losses["seg"]) + 1)

    # Plot Segmentation Loss
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses["seg"], label="Train Seg Loss")
    plt.plot(epochs, valid_losses["seg"], label="Valid Seg Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Segmentation Loss")
    plt.legend()
    plt.title("Segmentation Loss Over Epochs")
    plt.savefig(os.path.join(save_dir, "segmentation_loss.png"))
    plt.close()

    # Plot Depth Loss
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses["depth"], label="Train Depth Loss")
    plt.plot(epochs, valid_losses["depth"], label="Valid Depth Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Depth Loss")
    plt.legend()
    plt.title("Depth Loss Over Epochs")
    plt.savefig(os.path.join(save_dir, "depth_loss.png"))
    plt.close()

    # Plot Combined Loss
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses["combined"], label="Train Combined Loss")
    plt.plot(epochs, valid_losses["combined"], label="Valid Combined Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Combined Loss")
    plt.legend()
    plt.title("Combined Loss Over Epochs")
    plt.savefig(os.path.join(save_dir, "combined_loss.png"))
    plt.close()


In [10]:
def save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels):
    inputs = inputs.detach().cpu()
    seg_output = torch.argmax(seg_output, dim=1).detach().cpu()
    depth_output = depth_output.detach().cpu()
    seg_labels = seg_labels.detach().cpu()
    depth_labels = depth_labels.detach().cpu()
    
#     inputs_rgb = (inputs - inputs.min()) / (inputs.max() - inputs.min() + 1e-5)  # Normalize inputs to [0, 1]
    
#     # Normalize depth maps for visualization
#     depth_labels_vis = (depth_labels - depth_labels.min()) / (depth_labels.max() - depth_labels.min() + 1e-5)
#     depth_preds_vis = (depth_output - depth_output.min()) / (depth_output.max() - depth_output.min() + 1e-5)



    batch_size = min(4, inputs.size(0))  # Limit to 4 samples for visualization
    fig, axes = plt.subplots(batch_size, 5, figsize=(15, 4 * batch_size))

    for i in range(batch_size):
        
        inputs_temp = inputs[i]
        # print(f"inputs_temp: {inputs_temp.shape}")
        inputs_rgb = (inputs_temp - inputs_temp.min()) / (inputs_temp.max() - inputs_temp.min() + 1e-5)  # Normalize inputs to [0, 1]
        
        depth_labels_vis = (depth_labels[i] - depth_labels[i].min()) / (depth_labels[i].max() - depth_labels[i].min() + 1e-5)
        depth_preds = depth_output[i]
        depth_preds_vis = (depth_preds - depth_preds.min()) / (depth_preds.max() - depth_preds.min() + 1e-5)
        # print(f"depth_labels_vis: {depth_labels_vis.shape}")
        # print(f"depth_preds_vis: {depth_preds_vis.shape}")

    
        
        # Row 1: Ground truth
        axes[i, 0].imshow(inputs_rgb.permute(1, 2, 0))
        axes[i, 0].set_title("RGB Image")
        axes[i, 0].axis("off")

        axes[i, 1].imshow(seg_labels[i], cmap="tab20")
        axes[i, 1].set_title("GT Segmentation")
        axes[i, 1].axis("off")

        axes[i, 2].imshow(depth_labels_vis.squeeze(), cmap="inferno")
        axes[i, 2].set_title("GT Depth")
        axes[i, 2].axis("off")

        # Row 2: Predictions
        axes[i, 3].imshow(seg_output[i], cmap="tab20")
        axes[i, 3].set_title("Generated Segmentation")
        axes[i, 3].axis("off")

        axes[i, 4].imshow(depth_preds_vis.squeeze(), cmap="inferno")
        axes[i, 4].set_title("Generated Depth")
        axes[i, 4].axis("off")
        
    # Remove axes for cleaner visualization
    for ax in axes.flat:
        ax.axis("off")


    # plt.tight_layout()
    fig.tight_layout()
    fig.canvas.draw()
    
    # # Save current epoch as an image for GIF
    # epoch_img_path = os.path.join(gif_path, f"epoch_{epoch}.png")
    # os.makedirs(gif_path, exist_ok=True)
    # plt.savefig(epoch_img_path)
    # plt.close()
    
    
    # return epoch_img_path
    frame = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)  # Updated to buffer_rgba
    frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (4,))  # RGBA has 4 channels
    plt.close(fig)

    # Convert to PIL.Image for GIF
    frame_rgb = frame[:, :, :3] 

    # Return as PIL.Image for GIF creation
    # return Image.fromarray(frame)
    return Image.fromarray(frame_rgb)




In [11]:
import torchvision.models as models

# Define the ResBlock
class ResBlock(nn.Module):
    def __init__(self, channels):
        super(ResBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        return x + self.conv_block(x)

# Define the CRPBlock
class CRPBlock(nn.Module):
    def __init__(self, in_chans, out_chans, n_stages=4, groups=False):
        super(CRPBlock, self).__init__()
        self.n_stages = n_stages
        groups = in_chans if groups else 1
        self.mini_blocks = nn.ModuleList()
        for _ in range(n_stages):
            self.mini_blocks.append(nn.MaxPool2d(kernel_size=5, stride=1, padding=2))
            self.mini_blocks.append(nn.Conv2d(in_chans, out_chans, kernel_size=1, bias=False, groups=groups))
    
    def forward(self, x):
        out = x
        for block in self.mini_blocks:
            out = block(out)
            x = x + out
        return x

class ResNetBackbone(nn.Module):
    def __init__(self, pretrained=True, feature_dim=256):
        super(ResNetBackbone, self).__init__()
        base_model = models.resnet18(pretrained=pretrained)

        # Freeze pre-trained layers
        for param in base_model.parameters():
            param.requires_grad = False

        # Extract ResNet layers and modify strides/pooling to preserve spatial dimensions
        layers = list(base_model.children())[:-2]  # Remove FC and AvgPool layers
        for layer in layers:
            if isinstance(layer, nn.Conv2d):
                layer.stride = (1, 1)  # Set stride to 1
            elif isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
                layer.stride = (1, 1)  # Avoid reducing dimensions with pooling layers

        self.features = nn.Sequential(*layers)

        # Adjust final feature dimension using a 1x1 convolution
        self.feature_dim = feature_dim
        self.adjust_channels = nn.Conv2d(base_model.fc.in_features, feature_dim, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.features(x)  # Extract features without changing spatial dimensions
        x = self.adjust_channels(x)  # Adjust feature channels
        return x



In [12]:

class CityScapesDataset(Dataset):
    def __init__(self, root, transform=None, split='train', label_map='id', crop=True):
        """
        
        """
        self.root = root
        self.transform = transform
        self.label_map = label_map
        self.crop = crop

        self.left_paths = glob(os.path.join(root, 'leftImg8bit', split, '**/*.png'))
        self.mask_paths = glob(os.path.join(root, 'gtFine', split, '**/*gtFine_labelIds.png'))
        self.depth_paths = glob(os.path.join(root, 'crestereo_depth2', split, '**/*.npy'))

        sorted(self.left_paths)
        sorted(self.mask_paths)
        sorted(self.depth_paths)

        # get label mappings
        self.id_2_train = {}
        self.id_2_cat = {}
        self.train_2_id = {}
        self.id_2_name = {-1 : 'unlabeled'}
        self.trainid_2_name = {19 : 'unlabeled'} # {255 : 'unlabeled', -1 : 'unlabeled'}

        for lbl in labels:
            self.id_2_train.update({lbl.id : lbl.trainId})
            self.id_2_cat.update({lbl.id : lbl.categoryId})
            if lbl.trainId != 19: # (lbl.trainId > 0) and (lbl.trainId != 255):
                self.trainid_2_name.update({lbl.trainId : lbl.name})
                self.train_2_id.update({lbl.trainId : lbl.id})
            if lbl.id > 0:
                self.id_2_name.update({lbl.id : lbl.name})


    def __getitem__(self, idx):
        left = cv2.cvtColor(cv2.imread(self.left_paths[idx]), cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_UNCHANGED).astype(np.uint8)
        depth = np.load(self.depth_paths[idx]) # data is type float16

        if self.crop:
            left = left[:800, :, :]
            mask = mask[:800, :]
            depth = depth[:800, :]

        # get label id
        if self.label_map == 'id':
            mask[mask==-1] == 0
        elif self.label_map == 'trainId':
            for _id, train_id in self.id_2_train.items():
                mask[mask==_id] = train_id
        elif self.label_map == 'categoryId':
            for _id, train_id in self.id_2_cat.items():
                mask[mask==_id] = train_id

        sample = {'left' : left, 'mask' : mask, 'depth' : depth}

        if self.transform:
            sample = self.transform(sample)

        # ensure that no depth values are less than 0
        depth[depth < 0] = 0

        return sample
    

    def __len__(self):
        print(f"Number of RGB images: {len(self.left_paths)}")
        print(f"Number of Mask images: {len(self.mask_paths)}")
        print(f"Number of depth images: {len(self.depth_paths)}")
        return len(self.left_paths)
    
    

In [13]:
OG_W, OG_H = 2048, 800 # OG width and height after crop
W, H = OG_W//4, OG_H//4 # resize w,h for training

transform = transforms.Compose([
    ToTensor(),
    RandomCrop(H, W),
    # ElasticTransform(alpha=100.0, sigma=25.0, prob=0.5),
    AddColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
    RandomHorizontalFlip(0.5),
    RandomVerticalFlip(0.2),
    # RandomRotate((-30, 30)),
    Normalize()
])

valid_transform = transforms.Compose([
    ToTensor(),
    Rescale(H, W),
    Normalize()
])

test_transform = transforms.Compose([
    ToTensor(),
    Normalize()
])


BATCH_SIZE = 8
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'



train_dataset = CityScapesDataset(root, transform=transform, split='train', label_map='trainId') # 'trainId')
train_subset = Subset(train_dataset, indices=range(2968)) #2968
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=True)
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=True)


valid_dataset = CityScapesDataset(root, transform=valid_transform, split='val', label_map='trainId')
val_subset = Subset(valid_dataset, indices=range(496)) #496 
# valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=False)
valid_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=False)


# shared Generator

In [14]:
# class SharedGenerator(nn.Module):
#     def __init__(self):
#         """
#         Shared Generator for both tasks.
#         Contains shared layers for skip connection processing and refinement.
#         """
#         super(SharedGenerator, self).__init__()
#         # Shared convolution layers to process each skip connection
#         self.shared_conv1 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l11_out (1/32)
#         self.shared_conv2 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l7_out (1/16)
#         self.shared_conv3 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l3_out (1/8)
#         self.shared_conv4 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l1_out (1/4)

#         # Shared CRP blocks for refinement
#         self.shared_crp1 = CRPBlock(256, 256, n_stages=4)  # CRP for 1/32
#         self.shared_crp2 = CRPBlock(256, 256, n_stages=4)  # CRP for 1/16
#         self.shared_crp3 = CRPBlock(256, 256, n_stages=4)  # CRP for 1/8
#         self.shared_crp4 = CRPBlock(256, 256, n_stages=4)  # CRP for 1/4

#     def forward(self, skips):
#         """
#         Process skips with shared layers for task-specific generation.
#         Args:
#             skips (dict): Skip connections from the encoder.
#         Returns:
#             dict: Processed skip connections.
#         """
#         x1 = self.shared_crp1(self.shared_conv1(skips["l11_out"]))
#         x2 = self.shared_crp2(self.shared_conv2(skips["l7_out"]))
#         x3 = self.shared_crp3(self.shared_conv3(skips["l3_out"]))
#         x4 = self.shared_crp4(self.shared_conv4(skips["l1_out"]))

#         return {"x1": x1, "x2": x2, "x3": x3, "x4": x4}


In [15]:
# class SharedPix2PixGenerator(nn.Module):
#     def __init__(self, seg_output_channels=20, depth_output_channels=1):
#         """
#         Shared Pix2Pix Generator for Segmentation and Depth tasks.
#         Args:
#             seg_output_channels (int): Number of output channels for segmentation.
#             depth_output_channels (int): Number of output channels for depth estimation.
#         """
#         super(SharedPix2PixGenerator, self).__init__()
#         self.shared_generator = SharedGenerator()
#         self.seg_output_layer = TaskOutputLayer(output_channels=seg_output_channels)
#         self.depth_output_layer = TaskOutputLayer(output_channels=depth_output_channels)

#     def forward(self, skips, input_size):
#         """
#         Forward pass for both tasks.
#         Args:
#             skips (dict): Skip connections from the encoder.
#             input_size (tuple): Original input size (H, W).
#         Returns:
#             dict: Outputs for segmentation and depth tasks.
#         """
#         shared_features = self.shared_generator(skips)

#         # Task-specific outputs
#         seg_output = self.seg_output_layer(shared_features["x4"], input_size)
#         depth_output = self.depth_output_layer(shared_features["x4"], input_size)

#         return {
#             "seg_output": seg_output,
#             "depth_output": depth_output
#         }


# Saving batch gif code And function to plot al losses


In [16]:
def save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels):
    inputs = inputs.detach().cpu()
    seg_output = torch.argmax(seg_output, dim=1).detach().cpu()
    depth_output = depth_output.detach().cpu()
    seg_labels = seg_labels.detach().cpu()
    depth_labels = depth_labels.detach().cpu()
    
#     inputs_rgb = (inputs - inputs.min()) / (inputs.max() - inputs.min() + 1e-5)  # Normalize inputs to [0, 1]
    
#     # Normalize depth maps for visualization
#     depth_labels_vis = (depth_labels - depth_labels.min()) / (depth_labels.max() - depth_labels.min() + 1e-5)
#     depth_preds_vis = (depth_output - depth_output.min()) / (depth_output.max() - depth_output.min() + 1e-5)



    batch_size = min(4, inputs.size(0))  # Limit to 4 samples for visualization
    fig, axes = plt.subplots(batch_size, 5, figsize=(15, 4 * batch_size))

    for i in range(batch_size):
        
        inputs_temp = inputs[i]
        # print(f"inputs_temp: {inputs_temp.shape}")
        inputs_rgb = (inputs_temp - inputs_temp.min()) / (inputs_temp.max() - inputs_temp.min() + 1e-5)  # Normalize inputs to [0, 1]
        
        depth_labels_vis = (depth_labels[i] - depth_labels[i].min()) / (depth_labels[i].max() - depth_labels[i].min() + 1e-5)
        depth_preds = depth_output[i]
        depth_preds_vis = (depth_preds - depth_preds.min()) / (depth_preds.max() - depth_preds.min() + 1e-5)
        # print(f"depth_labels_vis: {depth_labels_vis.shape}")
        # print(f"depth_preds_vis: {depth_preds_vis.shape}")

    
        
        # Row 1: Ground truth
        axes[i, 0].imshow(inputs_rgb.permute(1, 2, 0))
        axes[i, 0].set_title("RGB Image")
        axes[i, 0].axis("off")

        axes[i, 1].imshow(seg_labels[i], cmap="tab20")
        axes[i, 1].set_title("GT Segmentation")
        axes[i, 1].axis("off")

        axes[i, 2].imshow(depth_labels_vis.squeeze(), cmap="inferno")
        axes[i, 2].set_title("GT Depth")
        axes[i, 2].axis("off")

        # Row 2: Predictions
        axes[i, 3].imshow(seg_output[i], cmap="tab20")
        axes[i, 3].set_title("Generated Segmentation")
        axes[i, 3].axis("off")

        axes[i, 4].imshow(depth_preds_vis.squeeze(), cmap="inferno")
        axes[i, 4].set_title("Generated Depth")
        axes[i, 4].axis("off")
        
    # Remove axes for cleaner visualization
    for ax in axes.flat:
        ax.axis("off")


    # plt.tight_layout()
    fig.tight_layout()
    fig.canvas.draw()
    
    # # Save current epoch as an image for GIF
    # epoch_img_path = os.path.join(gif_path, f"epoch_{epoch}.png")
    # os.makedirs(gif_path, exist_ok=True)
    # plt.savefig(epoch_img_path)
    # plt.close()
    
    
    # return epoch_img_path
    frame = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)  # Updated to buffer_rgba
    frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (4,))  # RGBA has 4 channels
    plt.close(fig)

    # Convert to PIL.Image for GIF
    frame_rgb = frame[:, :, :3] 

    # Return as PIL.Image for GIF creation
    # return Image.fromarray(frame)
    return Image.fromarray(frame_rgb)

def plot_all_losses(train_losses,valid_losses,save_dir):
    # Plot training and validation losses
    for key in train_losses.keys():
        plt.figure()
        plt.plot(train_losses[key], label=f"Train {key}")
        plt.plot(valid_losses[key], label=f"Valid {key}")
        plt.xlabel("Epoch")
        plt.ylabel(key.replace("_", " ").title())
        plt.legend()
        plt.savefig(os.path.join(save_dir, f"{key}_loss.png"))
        plt.close()


# Loss Function

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class PerceptualLoss(nn.Module):
    def __init__(self, pretrained_model="vgg16", layers=["relu3_3"], device="cuda"):
        """
        Perceptual loss class.

        Args:
            pretrained_model (str): Pretrained model to use (e.g., "vgg16").
            layers (list of str): Layers to extract features from.
            device (str): Device to load the pretrained model on ("cuda" or "cpu").
        """
        super().__init__()

        # Load pretrained model
        if pretrained_model == "vgg16":
            vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
        else:
            raise ValueError(f"Unsupported pretrained model: {pretrained_model}")

        # Freeze the parameters
        for param in vgg.parameters():
            param.requires_grad = False

        # Select layers
        self.layers = layers
        self.feature_extractor = nn.ModuleDict({
            layer: vgg[:i] for i, layer in enumerate(vgg._modules.keys()) if layer in self.layers
        })

    def forward(self, generated, target):
        """
        Compute perceptual loss between generated and target images.

        Args:
            generated (torch.Tensor): Generated image batch.
            target (torch.Tensor): Target image batch.

        Returns:
            torch.Tensor: MSE loss between extracted features.
        """
        loss = 0.0
        for layer_name, extractor in self.feature_extractor.items():
            gen_features = extractor(generated)
            target_features = extractor(target)
            loss += F.mse_loss(gen_features, target_features)
        return loss

def scale_invariant_depth_loss(pred, target, lambda_weight=0.1):
    if pred.shape != target.shape:
        pred = F.interpolate(pred, size=target.shape[1:], mode='bilinear', align_corners=False)
    
    diff = pred - target
    n = diff.numel()
    mse = torch.sum(diff**2) / n
    scale_invariant = mse - (lambda_weight / (n**2)) * (torch.sum(diff))**2
    return scale_invariant

def depth_smoothness_loss(pred, img, alpha=1.0):
    depth_grad_x = torch.abs(pred[:, :, :, :-1] - pred[:, :, :, 1:])
    depth_grad_y = torch.abs(pred[:, :, :-1, :] - pred[:, :, 1:, :])
    img_grad_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), dim=1, keepdim=True)
    img_grad_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), dim=1, keepdim=True)
    smoothness_x = depth_grad_x * torch.exp(-alpha * img_grad_x)
    smoothness_y = depth_grad_y * torch.exp(-alpha * img_grad_y)
    return smoothness_x.mean() + smoothness_y.mean()


def inv_huber_loss(pred, target, delta=0.1):
    """
    Inverse Huber loss for depth prediction.
    Args:
        pred (Tensor): Predicted depth map.
        target (Tensor): Ground truth depth map.
        delta (float): Threshold for switching between quadratic and linear terms.
    Returns:
        Tensor: Inverse Huber loss.
    """
    abs_diff = torch.abs(pred - target)
    delta_tensor = torch.tensor(delta, dtype=abs_diff.dtype, device=abs_diff.device)  # Convert delta to tensor
    quadratic = torch.minimum(abs_diff, delta_tensor)
    linear = abs_diff - quadratic
    return (0.5 * quadratic**2 + delta_tensor * linear).mean()


def mean_iou(pred, target, num_classes):
    pred = torch.argmax(pred, dim=1)
    intersection = torch.logical_and(pred == target, target != 255).float()  # Ignore class 255
    union = torch.logical_or(pred == target, target != 255).float()
    iou = torch.sum(intersection) / torch.sum(union)
    return iou



def contrastive_loss(pred, target, margin=1.0):
    """
    Contrastive loss to ensure the depth map predictions are closer to the target.
    """
    # Flatten the tensors for element-wise operations
    pred_flat = pred.view(pred.size(0), -1)  # Flatten except for the batch dimension
    target_flat = target.view(target.size(0), -1)  # Flatten except for the batch dimension

    # Compute the pairwise distances
    distances = torch.sqrt(torch.sum((pred_flat - target_flat) ** 2, dim=1))  # Batch-wise distances

    # Create labels for contrastive loss
    labels = (torch.abs(pred_flat - target_flat).mean(dim=1) < margin).float()  # Batch-wise labels

    # Calculate contrastive loss
    similar_loss = labels * distances**2
    dissimilar_loss = (1 - labels) * torch.clamp(margin - distances, min=0)**2
    loss = (similar_loss + dissimilar_loss).mean()

    return loss


def dice_loss(predictions, targets, smooth=1e-6):
    """
    Calculate Dice Loss for segmentation.
    Args:
        predictions (torch.Tensor): The predicted segmentation map (logits or probabilities).
                                    Shape: [batch_size, num_classes, height, width]
        targets (torch.Tensor): The ground truth segmentation map (one-hot encoded or integer labels).
                                Shape: [batch_size, height, width]
        smooth (float): Smoothing factor to avoid division by zero.
    Returns:
        torch.Tensor: Dice Loss (scalar).
    """
    # Convert integer labels to one-hot if needed
    if predictions.shape != targets.shape:
        targets = F.one_hot(targets, num_classes=predictions.shape[1]).permute(0, 3, 1, 2).float()
    
    # Apply softmax to predictions for multi-class segmentation
    predictions = torch.softmax(predictions, dim=1)
    
    # Flatten tensors to calculate intersection and union
    predictions_flat = predictions.view(predictions.shape[0], predictions.shape[1], -1)
    targets_flat = targets.view(targets.shape[0], targets.shape[1], -1)
    
    # Calculate intersection and union
    intersection = (predictions_flat * targets_flat).sum(dim=-1)
    union = predictions_flat.sum(dim=-1) + targets_flat.sum(dim=-1)
    
    # Calculate Dice Coefficient
    dice_coeff = (2 * intersection + smooth) / (union + smooth)
    
    # Dice Loss
    return 1 - dice_coeff.mean()


def initialize_optimizers_and_schedulers(model, lr_gen=1e-4, lr_disc=1e-4, weight_decay=1e-4):
    """
    Initialize optimizers and schedulers for all generators and discriminators.
    
    Args:
        model (nn.Module): MultiTaskModel instance.
        lr_gen (float): Learning rate for generators.
        lr_disc (float): Learning rate for discriminators.
        weight_decay (float): Weight decay for optimizers.
    
    Returns:
        dict: Optimizers and schedulers for generators and discriminators.
    """
    # Optimizers for shared generator
    optimizer_shared_gen = torch.optim.AdamW(
        model.feature_generator.parameters(),
        lr=lr_gen,
        weight_decay=weight_decay
    )
    scheduler_shared_gen = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_shared_gen, T_max=50, eta_min=1e-6
    )

    # Optimizer and scheduler for the shared generator's refinement layer
    optimizer_shared_refine = torch.optim.AdamW(
        model.shared_generator.parameters(),
        lr=lr_gen,
        weight_decay=weight_decay
    )
    scheduler_shared_refine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_shared_refine, T_max=50, eta_min=1e-6
    )

    # Optimizers and schedulers for task-specific generators
    optimizer_seg_gen = torch.optim.AdamW(
        model.seg_output_layer.parameters(),
        lr=lr_gen,
        weight_decay=weight_decay
    )
    scheduler_seg_gen = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_seg_gen, mode='min', factor=0.5, patience=5
    )

    optimizer_depth_gen = torch.optim.AdamW(
        model.depth_output_layer.parameters(),
        lr=lr_gen,
        weight_decay=weight_decay
    )
    scheduler_depth_gen = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_depth_gen, mode='min', factor=0.5, patience=5
    )

    # Optimizers and schedulers for task-specific discriminators
    optimizer_seg_disc = torch.optim.AdamW(
        model.seg_discriminator.parameters(),
        lr=lr_disc,
        weight_decay=weight_decay
    )
    scheduler_seg_disc = torch.optim.lr_scheduler.StepLR(
        optimizer_seg_disc, step_size=20, gamma=0.1
    )

    optimizer_depth_disc = torch.optim.AdamW(
        model.depth_discriminator.parameters(),
        lr=lr_disc,
        weight_decay=weight_decay
    )
    scheduler_depth_disc = torch.optim.lr_scheduler.StepLR(
        optimizer_depth_disc, step_size=20, gamma=0.1
    )

    # Optimizer and scheduler for the multi-task discriminator
    optimizer_multi_task_disc = torch.optim.AdamW(
        model.multi_task_discriminator.parameters(),
        lr=lr_disc,
        weight_decay=weight_decay
    )
    scheduler_multi_task_disc = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_multi_task_disc, T_max=50, eta_min=1e-6
    )

    return {
        "optimizers": {
            "shared_gen": optimizer_shared_gen,
            "shared_refine": optimizer_shared_refine,
            "seg_gen": optimizer_seg_gen,
            "depth_gen": optimizer_depth_gen,
            "seg_disc": optimizer_seg_disc,
            "depth_disc": optimizer_depth_disc,
            "multi_task_disc": optimizer_multi_task_disc
        },
        "schedulers": {
            "shared_gen": scheduler_shared_gen,
            "shared_refine": scheduler_shared_refine,
            "seg_gen": scheduler_seg_gen,
            "depth_gen": scheduler_depth_gen,
            "seg_disc": scheduler_seg_disc,
            "depth_disc": scheduler_depth_disc,
            "multi_task_disc": scheduler_multi_task_disc
        }
    }


## MultiTaskModel

In [18]:
class MobileNetV3Backbone(nn.Module):
    def __init__(self, backbone):
        super().__init__()

        self.backbone = backbone
        self.proj_l1 = nn.Conv2d(16, 576, kernel_size=1, bias=False)   # For l1_out (1/4 resolution)
        self.proj_l3 = nn.Conv2d(24, 576, kernel_size=1, bias=False)  # For l3_out (1/8 resolution)
        self.proj_l7 = nn.Conv2d(48, 576, kernel_size=1, bias=False)  # For l7_out (1/16 resolution)
        self.proj_l11 = nn.Conv2d(96, 576, kernel_size=1, bias=False) # For l11_out (1/32 resolution)

    
    def forward(self, x):
        """ Passes input theough MobileNetV3 backbone feature extraction layers
            layers to add connections to (0 indexed)
                - 1:  1/4 res
                - 3:  1/8 res
                - 7, 8:  1/16 res
                - 10, 11: 1/32 res
           """
        # skips = nn.ParameterDict()
        # for i in range(len(self.backbone) - 1):
        #     x = self.backbone[i](x)
        #     # add skip connection outputs
        #     if i in [1, 3, 7, 11]:
        #         skips.update({f"l{i}_out" : x})

        # return skips
        skips = {}  # Dictionary to store skip connections

        for i, layer in enumerate(self.backbone):
            x = layer(x)
            # Add skip connections for specific layers
            if i == 1:
                skips["l1_out"] = self.proj_l1(x)  # Project l1_out
            elif i == 3:
                skips["l3_out"] = self.proj_l3(x)  # Project l3_out
            elif i == 7:
                skips["l7_out"] = self.proj_l7(x)  # Project l7_out
            elif i == 11:
                skips["l11_out"] = self.proj_l11(x)  # Project l11_out

        return skips

In [19]:
class EnhancedSharedGenerator(nn.Module):
    def __init__(self):
        """
        Enhanced Shared Generator for Segmentation and Depth tasks.
        Includes additional refinement layers for better generalization.
        """
        super(EnhancedSharedGenerator, self).__init__()
        # Shared convolution layers to process each skip connection
        self.shared_conv1 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l11_out (1/32)
        self.shared_conv2 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l7_out (1/16)
        self.shared_conv3 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l3_out (1/8)
        self.shared_conv4 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l1_out (1/4)

        # CRP blocks for refinement
        self.shared_crp1 = CRPBlock(256, 256, n_stages=4)
        self.shared_crp2 = CRPBlock(256, 256, n_stages=4)
        self.shared_crp3 = CRPBlock(256, 256, n_stages=4)
        self.shared_crp4 = CRPBlock(256, 256, n_stages=4)

        # Additional refinement layers
        self.refine = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1)
        )

    def forward(self, skips):
        """
        Process skips with shared layers for task-specific generation.
        Args:
            skips (dict): Skip connections from the encoder.
        Returns:
            dict: Processed skip connections.
        """
        x1 = self.shared_crp1(self.shared_conv1(skips["l11_out"]))
        x1 = self.refine(x1)  # Extra refinement

        x2 = self.shared_crp2(self.shared_conv2(skips["l7_out"]))
        x2 = self.refine(x2)

        x3 = self.shared_crp3(self.shared_conv3(skips["l3_out"]))
        x3 = self.refine(x3)

        x4 = self.shared_crp4(self.shared_conv4(skips["l1_out"]))
        x4 = self.refine(x4)

        return {"x1": x1, "x2": x2, "x3": x3, "x4": x4}


In [20]:
class TaskOutputLayer(nn.Module):
    def __init__(self, output_channels):
        """
        Task-specific output layers for generating final predictions.
        Args:
            output_channels (int): Number of output channels (e.g., 20 for segmentation, 1 for depth).
        """
        super(TaskOutputLayer, self).__init__()
        self.final_conv = nn.Conv2d(256, output_channels, kernel_size=3, padding=1)

    def forward(self, x, input_size):
        """
        Generate task-specific output.
        Args:
            x (Tensor): Input feature map.
            input_size (tuple): Original input size (H, W).
        Returns:
            Tensor: Task-specific output.
        """
        x = self.final_conv(x)
        return nn.functional.interpolate(x, size=input_size, mode="bilinear", align_corners=False)



In [21]:
class TaskSpecificDiscriminator(nn.Module):
    def __init__(self, input_channels):
        super(TaskSpecificDiscriminator, self).__init__()
        self.adapt_conv = nn.Conv2d(input_channels+input_channels, input_channels, kernel_size=1, bias=False)
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, task_output, labels=None):
        """
        Forward pass through the discriminator.

        Args:
            task_output (Tensor): Output from the generator (e.g., seg_output or depth_output).
            labels (Tensor, optional): Ground truth labels. If provided, aligns channels with task_output.

        Returns:
            Tensor: Discriminator's prediction.
        """
        if labels is not None:
            # Ensure labels match the shape of task_output
            if labels.dim() < task_output.dim():
                labels = labels.unsqueeze(1)  # Add channel dimension if needed
            if labels.size(1) != task_output.size(1):
                labels = torch.nn.functional.one_hot(labels.squeeze(1), num_classes=task_output.size(1))
                labels = labels.permute(0, 3, 1, 2).float().to(task_output.device)
            combined = torch.cat([task_output, labels], dim=1)
            combined = self.adapt_conv(combined)
        else:
            combined = task_output

        return self.model(combined)


In [22]:

class MultiTaskDiscriminator(nn.Module):
    def __init__(self, input_channels):
        """
        Multi-Task Discriminator for evaluating all task-specific outputs.
        Args:
            input_channels (int): Number of input channels for concatenated features and outputs.
        """
        super(MultiTaskDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, inputs):
        """
        Evaluate input image and task-specific outputs.
        Args:
            inputs (Tensor): Input image.
            outputs (list[Tensor]): List of task-specific outputs.
        Returns:
            Tensor: Discriminator output.
        """
        # combined = torch.cat([inputs] + outputs, dim=1)
        return self.model(inputs)


In [23]:
class MultiTaskModel(nn.Module):
    def __init__(self, backbone, num_seg_classes=20, depth_channels=1):
        """
        Multi-task model with shared Pix2Pix Generator, task-specific discriminators,
        and a multi-task discriminator.
        Args:
            backbone (nn.Module): Encoder backbone for feature extraction.
            num_seg_classes (int): Number of segmentation classes.
            depth_channels (int): Number of output channels for depth.
        """
        super(MultiTaskModel, self).__init__()
        self.feature_generator = MobileNetV3Backbone(backbone)
        self.shared_generator = EnhancedSharedGenerator()
        self.seg_output_layer = TaskOutputLayer(output_channels=num_seg_classes)
        self.depth_output_layer = TaskOutputLayer(output_channels=depth_channels)

        # Task-specific discriminators
        self.seg_discriminator = TaskSpecificDiscriminator(input_channels=num_seg_classes)
        self.depth_discriminator = TaskSpecificDiscriminator(input_channels=depth_channels)

        # Multi-task discriminator
        self.multi_task_discriminator = MultiTaskDiscriminator(input_channels=3 + num_seg_classes + depth_channels)
        
    def forward(self, inputs, input_size, seg_labels=None, depth_labels=None, return_discriminator_outputs=False):
        # Extract features from the encoder
        skips = self.feature_generator(inputs)
        shared_features = self.shared_generator(skips)

        # Task-specific outputs
        seg_output = self.seg_output_layer(shared_features["x4"], input_size)
        depth_output = self.depth_output_layer(shared_features["x4"], input_size)

        output_dict = {
            "seg_output": seg_output,
            "depth_output": depth_output,
        }

        if return_discriminator_outputs:
            
            # Detach outputs to prevent discriminator backward from interfering with the generator
            seg_output_detached = seg_output.detach()
            depth_output_detached = depth_output.detach()
            
            # Adversarial feedback from task-specific discriminators
            seg_real_disc = self.seg_discriminator(seg_output_detached, seg_labels) if seg_labels is not None else None
            seg_fake_disc = self.seg_discriminator(seg_output_detached, None)

            depth_real_disc = self.depth_discriminator(depth_output_detached, depth_labels) if depth_labels is not None else None
            depth_fake_disc = self.depth_discriminator(depth_output_detached, None)

            # Multi-task discriminator feedback
            combined_real_input = torch.cat([inputs, seg_labels, depth_labels], dim=1) if seg_labels is not None and depth_labels is not None else None
            combined_fake_input = torch.cat([inputs, seg_output, depth_output], dim=1)

            combined_real_disc = self.multi_task_discriminator(combined_real_input) if combined_real_input is not None else None
            combined_fake_disc = self.multi_task_discriminator(combined_fake_input.detach())

            output_dict.update({
                "seg_real_disc": seg_real_disc,
                "seg_fake_disc": seg_fake_disc,
                "depth_real_disc": depth_real_disc,
                "depth_fake_disc": depth_fake_disc,
                "combined_real_disc": combined_real_disc,
                "combined_fake_disc": combined_fake_disc,
            })

        return output_dict


#     def forward(self, inputs, input_size, seg_labels=None, depth_labels=None, return_discriminator_outputs=False):
#         """
#         Forward pass for multi-task model.
#         Args:
#             inputs (Tensor): Input images.
#             input_size (tuple): Original input size.
#             seg_labels (Tensor, optional): Ground truth segmentation labels. Required for discriminator feedback.
#             depth_labels (Tensor, optional): Ground truth depth labels. Required for discriminator feedback.
#             return_discriminator_outputs (bool): If True, returns discriminator outputs for adversarial loss.
#         Returns:
#             dict: Outputs for segmentation and depth tasks, and optionally discriminator outputs.
#         """
#         skips = self.feature_generator(inputs)
#         shared_features = self.shared_generator(skips)

#         # Task-specific outputs
#         seg_output = self.seg_output_layer(shared_features["x4"], input_size)
#         depth_output = self.depth_output_layer(shared_features["x4"], input_size)

#         if return_discriminator_outputs:
#             # Adversarial feedback from task-specific discriminators
#             seg_real_disc = self.seg_discriminator(seg_output, seg_labels) if seg_labels is not None else None
#             depth_real_disc = self.depth_discriminator(depth_output, depth_labels) if depth_labels is not None else None

#             # Multi-task discriminator feedback
#             combined_real_disc = self.multi_task_discriminator(inputs, [seg_output, depth_output])

#             return {
#                 "seg_output": seg_output,
#                 "depth_output": depth_output,
#                 "seg_real_disc": seg_real_disc,
#                 "depth_real_disc": depth_real_disc,
#                 "combined_real_disc": combined_real_disc
#             }

#         return {
#             "seg_output": seg_output,
#             "depth_output": depth_output
#         }


In [24]:
from PIL import Image, ImageSequence
import os

def combine_training_gifs(model_dir, save_dir2, output_path):
    """
    Combine two training visualization GIFs into one.
    
    Args:
        model_dir: Directory containing the first training GIF.
        save_dir2: Directory containing the second training GIF.
        output_path: Path to save the combined GIF.
    """
    # Find the GIF files
    model_dir_gif = [file for file in os.listdir(model_dir) if file.startswith("training_visualization") and file.endswith(".gif")]
    save_dir2_gif = [file for file in os.listdir(save_dir2) if file.startswith("training_visualization") and file.endswith(".gif")]
    
    if not model_dir_gif or not save_dir2_gif:
        raise FileNotFoundError("Could not find training_visualization_*.gif in one of the directories.")
    
    model_dir_gif_path = os.path.join(model_dir, model_dir_gif[0])
    save_dir2_gif_path = os.path.join(save_dir2, save_dir2_gif[0])

    # Open the GIFs
    gif1 = Image.open(model_dir_gif_path)
    gif2 = Image.open(save_dir2_gif_path)

    # Collect all frames from both GIFs
    combined_frames = []
    for frame in ImageSequence.Iterator(gif1):
        combined_frames.append(frame.copy())
    for frame in ImageSequence.Iterator(gif2):
        combined_frames.append(frame.copy())

    # Save the combined GIF
    combined_frames[0].save(
        output_path,
        save_all=True,
        append_images=combined_frames[1:],
        duration=gif1.info.get("duration", 500),  # Use duration from the first GIF
        loop=0
    )

    print(f"Combined GIF saved to {output_path}")


# Saving loss charts

In [25]:
def plot_all_losses(epoch, train_losses,valid_losses,save_dir):
    # Plot training and validation losses
    for key in train_losses.keys():
        plt.figure()
        plt.plot(train_losses[key], label=f"Train {key}")
        plt.plot(valid_losses[key], label=f"Valid {key}")
        plt.xlabel("Epoch")
        plt.ylabel(key.replace("_", " ").title())
        plt.legend()
        plt.savefig(os.path.join(save_dir, f"{key}_loss_after_epoch_{epoch}.png"))
        plt.close()

# saving checkpoints

In [26]:
# Save checkpoint including model, optimizer, and scheduler states
def save_checkpoint(model, opt_sched, save_path, epoch, best_loss):
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_states": {name: opt.state_dict() for name, opt in opt_sched["optimizers"].items()},
        "scheduler_states": {name: sched.state_dict() for name, sched in opt_sched["schedulers"].items()},
        "epoch": epoch,
        "best_loss": best_loss
    }
    torch.save(checkpoint, save_path, _use_new_zipfile_serialization=True)
    
def load_checkpoint(model, opt_sched, checkpoint_path, device):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    
    for name, opt in opt_sched["optimizers"].items():
        if name in checkpoint["optimizer_states"]:
            opt.load_state_dict(checkpoint["optimizer_states"][name])
            
    for name, sched in opt_sched["schedulers"].items():
        if name in checkpoint["scheduler_states"]:
            sched.load_state_dict(checkpoint["scheduler_states"][name])
            
    # return checkpoint["epoch"], checkpoint["best_loss"]
    return checkpoint.get("epoch", 0), checkpoint.get("best_loss", float("inf"))



In [27]:
# import pandas as pd

def combine_and_plot_loss_data(model_dir, save_dir2, combined_save_dir="all_data_from_prev_curr_epoch"):
    """
    Combines loss data from previous and current training sessions and plots combined graphs.
    
    Args:
        model_dir: Path to the directory containing the previous loss-tracking CSV.
        save_dir2: Path to the directory containing the current loss-tracking CSV.
        combined_save_dir: Path to save the combined data and plots.

    Returns:
        combined_df: A pandas DataFrame containing the combined loss data.
    """
    # Ensure the save directory exists
    combined_save_dir = os.path.join(save_dir2, combined_save_dir)
    os.makedirs(combined_save_dir, exist_ok=True)

    # Locate CSV files
    previous_csv = os.path.join(model_dir, [file for file in os.listdir(model_dir) if file.endswith(".csv")][0])
    current_csv = os.path.join(save_dir2, [file for file in os.listdir(save_dir2) if file.endswith(".csv")][0])

    # Load data into pandas DataFrames
    previous_df = pd.read_csv(previous_csv)
    current_df = pd.read_csv(current_csv)

    # Update epoch numbers in the current DataFrame
    max_prev_epoch = previous_df["epoch"].max()
    current_df["epoch"] += max_prev_epoch

    # Combine the DataFrames
    combined_df = pd.concat([previous_df, current_df], ignore_index=True)

    # Save the combined DataFrame
    combined_csv_path = os.path.join(combined_save_dir, "combined_loss_tracking.csv")
    combined_df.to_csv(combined_csv_path, index=False)
    print(f"Combined loss data saved to {combined_csv_path}")

    # # Generate plots for each loss type
    # loss_columns = ["train_seg_loss", "train_depth_loss", "train_combined_loss", "train_adv_loss",
    #                 "valid_seg_loss", "valid_depth_loss", "valid_combined_loss", "valid_adv_loss"]
    # for col in loss_columns:
    #     plt.figure()
    #     plt.plot(combined_df["epoch"], combined_df[col], label=col)
    #     plt.xlabel("Epoch")
    #     plt.ylabel("Loss")
    #     plt.title(f"{col.replace('_', ' ').title()} Over Epochs")
    #     plt.legend()
    #     plot_path = os.path.join(combined_save_dir, f"{col}_plot.png")
    #     plt.savefig(plot_path)
    #     plt.close()
    #     print(f"Plot saved to {plot_path}")
    # Generate combined plots for train and valid losses
    
    loss_types = ["seg_loss", "depth_loss", "combined_loss", "adv_loss"]
    for loss_type in loss_types:
        train_loss_col = f"train_{loss_type}"
        valid_loss_col = f"valid_{loss_type}"

        plt.figure()
        plt.plot(combined_df["epoch"], combined_df[train_loss_col], label=f"Train {loss_type.capitalize()}")
        plt.plot(combined_df["epoch"], combined_df[valid_loss_col], label=f"Valid {loss_type.capitalize()}")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title(f"{loss_type.replace('_', ' ').capitalize()} Over Epochs")
        plt.legend()
        plot_path = os.path.join(combined_save_dir, f"{loss_type}_train_valid_plot.png")
        plt.savefig(plot_path)
        plt.close()
        print(f"Combined plot saved to {plot_path}")


    return combined_df


In [28]:
# testing combine

In [29]:
os.getcwd(),'results_test8'

('/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes',
 'results_test8')

# training

In [30]:
def train_model_with_adversarial_loss_tracking(
    model, train_loader, valid_loader, num_epochs, device, opt_sched, save_dir="results"
):
    """
    Trains a multi-task model with adversarial feedback and tracks losses.
    
    Args:
        model: Multi-task model with integrated generators and discriminators.
        train_loader: DataLoader for training data.
        valid_loader: DataLoader for validation data.
        num_epochs: Number of epochs to train.
        device: Device for training ("cuda" or "cpu").
        opt_sched: Dictionary of optimizers and schedulers for generators and discriminators.
        save_dir: Directory to save results.
    
    Returns:
        train_losses, valid_losses: Lists of losses for training and validation.
    """
    # Create directories for saving results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = os.path.join(save_dir, timestamp)
    os.makedirs(save_dir, exist_ok=True)

    # Prepare CSV file for loss tracking
    csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
    gif_path = os.path.join(save_dir, f"training_visualization_{timestamp}.gif")
    
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
            "train_adv_loss", 
            # "train_seg_iou",
            "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
            "valid_adv_loss", 
            # "valid_seg_iou"
        ])

    # Initialize tracking variables
    train_losses = {"seg": [], "depth": [], "combined": [], "adv": []}
    valid_losses = {"seg": [], "depth": [], "combined": [], "adv": []}
    best_combined_loss = float("inf")
    gif_frames =[]
    perceptual_loss_fn = PerceptualLoss(pretrained_model="vgg16").to(device)

    # Start training loop
    for epoch in range(num_epochs):
        model.train()
        epoch_train = {key: 0.0 for key in train_losses.keys()}
        num_batches = 0
        
        
        
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs} - Training", unit="batch") as pbar:
            for batch in train_loader:
                inputs, seg_labels, depth_labels = (
                    batch["left"].to(device),
                    batch["mask"].to(device),
                    batch["depth"].to(device),
                )
                input_size = inputs.size()[-2:]

                # Preprocess seg_labels to one-hot encoding
                if seg_labels.size(1) == 1:  # If class indices are given
                    seg_labels = torch.nn.functional.one_hot(seg_labels.squeeze(1), num_classes=20)
                    seg_labels = seg_labels.permute(0, 3, 1, 2).float().to(device)  # Convert to [B, C, H, W]

                # Ensure depth_labels has correct dimensions
                if depth_labels.dim() == 5:  # If depth_labels has extra dimensions
                    depth_labels = depth_labels.squeeze(2)

                # Zero gradients
                for optimizer in opt_sched["optimizers"].values():
                    optimizer.zero_grad()

                # Forward pass with discriminator outputs
                outputs = model(
                    inputs,
                    input_size=input_size,
                    seg_labels=seg_labels,
                    depth_labels=depth_labels,
                    return_discriminator_outputs=True,
                )

                # Generator losses
                seg_loss = nn.CrossEntropyLoss()(outputs["seg_output"], seg_labels) + \
                           dice_loss(outputs["seg_output"], seg_labels)
                depth_loss = scale_invariant_depth_loss(outputs["depth_output"], depth_labels) + \
                             inv_huber_loss(outputs["depth_output"], depth_labels) + \
                             depth_smoothness_loss(outputs["depth_output"], inputs)
                
                seg_perceptual_loss = perceptual_loss_fn(outputs["seg_output"], seg_labels.unsqueeze(1))
                depth_perceptual_loss = perceptual_loss_fn(outputs["depth_output"], depth_labels)
                
                seg_loss = seg_loss + 0.1 * seg_perceptual_loss
                depth_loss = depth_loss + 0.1 * depth_perceptual_loss
                

                adv_loss = -(
                    torch.mean(outputs["seg_real_disc"]) +
                    torch.mean(outputs["depth_real_disc"]) +
                    torch.mean(outputs["combined_real_disc"])
                )

                combined_loss = seg_loss + depth_loss + 0.01 * adv_loss

                # Backpropagation for generators
                combined_loss.backward(retain_graph=True)
                # opt_sched["optimizers"]["generator"].step()
                opt_sched["optimizers"]["shared_gen"].step()
                opt_sched["optimizers"]["shared_refine"].step()
                opt_sched["optimizers"]["seg_gen"].step()
                opt_sched["optimizers"]["depth_gen"].step()


                # Update task-specific discriminators
                for task, disc_optimizer in [
                    ("seg", "seg_disc"),
                    ("depth", "depth_disc"),
                ]:
                    opt_sched["optimizers"][disc_optimizer].zero_grad()
                    real_disc_loss = torch.mean(
                        (outputs[f"{task}_real_disc"] - 1) ** 2
                    )
                    fake_disc_loss = torch.mean(
                        (outputs[f"{task}_fake_disc"].detach()) ** 2
                    )
                    disc_loss = (real_disc_loss + fake_disc_loss) / 2
                    disc_loss.backward()
                    opt_sched["optimizers"][disc_optimizer].step()

                # Update multi-task discriminator
                opt_sched["optimizers"]["multi_task_disc"].zero_grad()
                real_combined_loss = torch.mean(
                    (outputs["combined_real_disc"] - 1) ** 2
                )
                fake_combined_loss = torch.mean(
                    (outputs["combined_fake_disc"].detach()) ** 2
                )
                combined_disc_loss = (real_combined_loss + fake_combined_loss) / 2
                combined_disc_loss.backward()
                opt_sched["optimizers"]["multi_task_disc"].step()

                # Update training metrics
                epoch_train["seg"] += seg_loss.item()
                epoch_train["depth"] += depth_loss.item()
                epoch_train["combined"] += combined_loss.item()
                epoch_train["adv"] += adv_loss.item()
                # epoch_train["iou"] += mean_iou(outputs["seg_output"], seg_labels, num_classes=20).item()
                num_batches += 1

            # Average training metrics
            for key in epoch_train.keys():
                train_losses[key].append(epoch_train[key] / num_batches)

            # Validation loop
            model.eval()
            epoch_valid = {key: 0.0 for key in valid_losses.keys()}
            num_valid_batches = 0

            with torch.no_grad():
                for batch in valid_loader:
                    inputs, seg_labels, depth_labels = (
                        batch["left"].to(device),
                        batch["mask"].to(device),
                        batch["depth"].to(device),
                    )
                    input_size = inputs.size()[-2:]

                    # Preprocess seg_labels to one-hot encoding
                    if seg_labels.size(1) == 1:
                        seg_labels = torch.nn.functional.one_hot(seg_labels.squeeze(1), num_classes=20)
                        seg_labels = seg_labels.permute(0, 3, 1, 2).float().to(device)

                    # Ensure depth_labels has correct dimensions
                    if depth_labels.dim() == 5:
                        depth_labels = depth_labels.squeeze(2)


                    # Forward pass
                    outputs = model(
                        inputs,
                        input_size=input_size,
                        seg_labels=seg_labels,
                        depth_labels=depth_labels,
                        return_discriminator_outputs=True,
                    )

                    # Validation loss calculations
                    seg_loss = nn.CrossEntropyLoss()(outputs["seg_output"], seg_labels) + \
                               dice_loss(outputs["seg_output"], seg_labels)
                    depth_loss = scale_invariant_depth_loss(outputs["depth_output"], depth_labels) + \
                                 inv_huber_loss(outputs["depth_output"], depth_labels) + \
                                 depth_smoothness_loss(outputs["depth_output"], inputs)
                    
                    seg_perceptual_loss = perceptual_loss_fn(outputs["seg_output"], seg_labels.unsqueeze(1))
                    depth_perceptual_loss = perceptual_loss_fn(outputs["depth_output"], depth_labels)

                    seg_loss = seg_loss + 0.1 * seg_perceptual_loss
                    depth_loss = depth_loss + 0.1 * depth_perceptual_loss


                    adv_loss = -(
                        torch.mean(outputs["seg_real_disc"]) +
                        torch.mean(outputs["depth_real_disc"]) +
                        torch.mean(outputs["combined_real_disc"])
                    )

                    combined_loss = seg_loss + depth_loss + 0.01 * adv_loss

                    # Update validation metrics
                    epoch_valid["seg"] += seg_loss.item()
                    epoch_valid["depth"] += depth_loss.item()
                    epoch_valid["combined"] += combined_loss.item()
                    epoch_valid["adv"] += adv_loss.item()
                    # epoch_valid["iou"] += mean_iou(outputs["seg_output"], seg_labels, num_classes=20).item()
                    num_valid_batches += 1

        # Average validation metrics
        for key in epoch_valid.keys():
            valid_losses[key].append(epoch_valid[key] / num_valid_batches)

        # Save best model
        valid_combined_loss = epoch_valid["combined"] / num_valid_batches
        if valid_combined_loss < best_combined_loss:
            best_combined_loss = valid_combined_loss
            # torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))
            checkpoint_path = os.path.join(save_dir, "best_model_checkpoint.pth")
            save_checkpoint(model, opt_sched, checkpoint_path, epoch + 1, best_combined_loss)
            print(f"Best model saved at epoch {epoch+1} with combined loss {best_combined_loss:.4f}")
            
        frame = save_training_visualization_as_gif2(epoch, inputs, outputs["seg_output"], outputs["depth_output"], torch.argmax(seg_labels, dim=1), depth_labels)
        gif_frames.append(frame)
        
        

        # Append metrics to CSV
        with open(csv_path, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([
                epoch + 1,
                epoch_train["seg"] / num_batches,
                epoch_train["depth"] / num_batches,
                epoch_train["combined"] / num_batches,
                epoch_train["adv"] / num_batches,
                # epoch_train["iou"] / num_batches,
                epoch_valid["seg"] / num_valid_batches,
                epoch_valid["depth"] / num_valid_batches,
                epoch_valid["combined"] / num_valid_batches,
                epoch_valid["adv"] / num_valid_batches,
                # epoch_valid["iou"] / num_valid_batches,
            ])
            
        # Print epoch results
        print(f"Epoch {epoch + 1}/{num_epochs} Results:")

        # Print training losses
        print(f"  Train Losses - Segmentation: {epoch_train['seg']/num_batches:.4f}, Depth: {epoch_train['depth']/num_batches:.4f}, "
              f"Combined: {epoch_train['combined']/num_batches:.4f}, Adversarial: {epoch_train['adv']/num_batches:.4f}")

        # Print validation losses
        print(f"  Valid Losses - Segmentation: {epoch_valid['seg']/ num_valid_batches:.4f}, Depth: {epoch_valid['depth']/ num_valid_batches:.4f}, "
              f"Combined: {epoch_valid['combined']/ num_valid_batches:.4f}, Adversarial: {epoch_valid['adv']/ num_valid_batches:.4f}")


        # Update schedulers
        for name, scheduler in opt_sched["schedulers"].items():
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # Pass the appropriate metric to ReduceLROnPlateau
                scheduler.step(valid_losses["combined"][-1])  # Use the most recent validation combined loss
            else:
                scheduler.step()
            
        if epoch %10 == 0:
            gif_path2 =os.path.join(save_dir,f"viz_epoch_{epoch}.gif")
            gif_frames[0].save(gif_path2, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)
            plot_all_losses(epoch, train_losses,valid_losses,save_dir)
            
    
    gif_frames[0].save(gif_path, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)
    print(f"Training visualization saved as GIF at {gif_path}")
    plot_all_losses(epoch, train_losses,valid_losses,save_dir)

    
    return train_losses, valid_losses, save_dir


In [31]:

def resume_training_with_loss_tracking(
    model_class,
    model_dir,
    train_loader,
    valid_loader,
    num_additional_epochs,
    device,
    opt_sched,
    save_dir,
):
    """
    Resumes training a multi-task model, appends loss data to the existing CSV file,
    and generates graphs for the combined training history.

    Args:
        model_class: The model class to instantiate.
        model_dir: Path to the directory containing the saved model and loss CSV file.
        train_loader: DataLoader for training data.
        valid_loader: DataLoader for validation data.
        num_additional_epochs: Number of additional epochs to train.
        device: Device for training ("cuda" or "cpu").
        opt_sched: Dictionary of optimizers and schedulers.
        save_dir: Directory to save the updated results.

    Returns:
        Updated train and validation losses.
    """
    # # Load the best model
    # best_model_path = os.path.join(model_dir, "best_model.pth")
    # if not os.path.exists(best_model_path):
    #     raise FileNotFoundError(f"Best model not found at {best_model_path}")
        
    # Load the checkpoint
    checkpoint_path = os.path.join(model_dir, "best_model_checkpoint.pth")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
    
        
    model = model_class().to(device)
    
    start_epoch, best_loss = load_checkpoint(model, opt_sched, checkpoint_path, device)

    # model.load_state_dict(torch.load(best_model_path, map_location=device))

    # Locate the existing loss CSV
    csv_path = os.path.join(model_dir, [file for file in os.listdir(model_dir) if file.endswith(".csv")][0])

    # Parse existing CSV data
    existing_train_losses = {"seg": [], "depth": [], "combined": [], "adv": []}
    existing_valid_losses = {"seg": [], "depth": [], "combined": [], "adv": []}
    current_epoch = 0
    with open(csv_path, "r") as f:
        reader = csv.reader(f)
        next(reader)  # Skip header
        for row in reader:
            current_epoch = int(row[0])
            existing_train_losses["seg"].append(float(row[1]))
            existing_train_losses["depth"].append(float(row[2]))
            existing_train_losses["combined"].append(float(row[3]))
            existing_train_losses["adv"].append(float(row[4]))
            existing_valid_losses["seg"].append(float(row[5]))
            existing_valid_losses["depth"].append(float(row[6]))
            existing_valid_losses["combined"].append(float(row[7]))
            existing_valid_losses["adv"].append(float(row[8]))

    # Train for additional epochs
    
    train_losses, valid_losses,save_dir2 = train_model_with_adversarial_loss_tracking(
        model=model,
        train_loader=train_loader,
        valid_loader=valid_loader,
        num_epochs=num_additional_epochs,
        device=device,
        opt_sched=opt_sched,
        save_dir=save_dir,
    )

    # Combine the new losses with the existing ones
    for key in existing_train_losses.keys():
        existing_train_losses[key].extend(train_losses[key])
        existing_valid_losses[key].extend(valid_losses[key])
        
    save_dir3 = os.path.join(save_dir2,"combined_result")
    os.makedirs(save_dir3, exist_ok=True)
    
    total_epochs = len(existing_train_losses["seg"])
        
    updated_csv_path = os.path.join(save_dir3, f"loss_tracking_updated_{total_epochs}.csv")
    os.makedirs(save_dir3, exist_ok=True)

    # Write combined losses to the updated CSV
    with open(updated_csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
            "train_adv_loss", 
            "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
            "valid_adv_loss", 
        ])
        for epoch in range(len(existing_train_losses["seg"])):
            writer.writerow([
                epoch + 1,
                existing_train_losses["seg"][epoch],
                existing_train_losses["depth"][epoch],
                existing_train_losses["combined"][epoch],
                existing_train_losses["adv"][epoch],
                existing_valid_losses["seg"][epoch],
                existing_valid_losses["depth"][epoch],
                existing_valid_losses["combined"][epoch],
                existing_valid_losses["adv"][epoch],
            ])

    # Generate graphs
    for key in existing_train_losses.keys():
        plt.figure()
        plt.plot(range(len(existing_train_losses[key])), existing_train_losses[key], label=f"Train {key.capitalize()}")
        plt.plot(range(len(existing_valid_losses[key])), existing_valid_losses[key], label=f"Valid {key.capitalize()}")
        plt.xlabel("Epoch")
        plt.ylabel(f"{key.capitalize()} Loss")
        plt.legend()
        plt.title(f"{key.capitalize()} Loss Over Epochs")
        plt.savefig(os.path.join(save_dir3, f"{key}_loss_graph_epoch_{total_epochs}.png"))
        plt.close()
        
    
    output_path = os.path.join(save_dir3,'combined_results.gif')
    combine_training_gifs(model_dir, save_dir2, output_path)

    return existing_train_losses, existing_valid_losses,save_dir2


# For first instance

In [59]:
# Initialize the model
# Instantiate Models

BATCH_SIZE = 8
EPOCHS = 40
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


mobilenet_backbone = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
# encoder = MobileNetV3Backbone(mobilenet_backbone.features)
model = MultiTaskModel(backbone=mobilenet_backbone.features, num_seg_classes=20, depth_channels=1)
model.to(DEVICE)

# Initialize optimizers and schedulers
opt_sched = initialize_optimizers_and_schedulers(model)

# Access optimizers
optimizers = opt_sched["optimizers"]
schedulers = opt_sched["schedulers"]

# Prepare Data Loaders (Ensure train_loader and valid_loader are ready)
train_losses, valid_losses,save_dir2 = train_model_with_adversarial_loss_tracking(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    num_epochs=EPOCHS,
    device=DEVICE,
    opt_sched=opt_sched,
    save_dir="results_test8_final2"
)

Epoch 1/40 - Training:   0%|          | 0/371 [09:37<?, ?batch/s]


Best model saved at epoch 1 with combined loss 2.4004
Epoch 1/40 Results:
  Train Losses - Segmentation: 2.4219, Depth: 0.0579, Combined: 2.4500, Adversarial: -2.9780
  Valid Losses - Segmentation: 2.3837, Depth: 0.0466, Combined: 2.4004, Adversarial: -2.9848


Epoch 2/40 - Training:   0%|          | 0/371 [08:15<?, ?batch/s]


Best model saved at epoch 2 with combined loss 2.1478
Epoch 2/40 Results:
  Train Losses - Segmentation: 2.1883, Depth: 0.0444, Combined: 2.2027, Adversarial: -2.9958
  Valid Losses - Segmentation: 2.1091, Depth: 0.0686, Combined: 2.1478, Adversarial: -2.9893


Epoch 3/40 - Training:   0%|          | 0/371 [08:10<?, ?batch/s]


Epoch 3/40 Results:
  Train Losses - Segmentation: 2.1075, Depth: 0.0420, Combined: 2.1196, Adversarial: -2.9967
  Valid Losses - Segmentation: 2.2348, Depth: 0.0638, Combined: 2.2686, Adversarial: -2.9987


Epoch 4/40 - Training:   0%|          | 0/371 [08:08<?, ?batch/s]


Best model saved at epoch 4 with combined loss 2.0933
Epoch 4/40 Results:
  Train Losses - Segmentation: 2.0822, Depth: 0.0409, Combined: 2.0932, Adversarial: -2.9972
  Valid Losses - Segmentation: 2.0770, Depth: 0.0465, Combined: 2.0933, Adversarial: -3.0239


Epoch 5/40 - Training:   0%|          | 0/371 [08:10<?, ?batch/s]


Best model saved at epoch 5 with combined loss 2.0564
Epoch 5/40 Results:
  Train Losses - Segmentation: 2.0276, Depth: 0.0376, Combined: 2.0352, Adversarial: -2.9978
  Valid Losses - Segmentation: 2.0396, Depth: 0.0468, Combined: 2.0564, Adversarial: -2.9943


Epoch 6/40 - Training:   0%|          | 0/371 [08:10<?, ?batch/s]


Epoch 6/40 Results:
  Train Losses - Segmentation: 1.9941, Depth: 0.0376, Combined: 2.0017, Adversarial: -2.9980
  Valid Losses - Segmentation: 2.1045, Depth: 0.0499, Combined: 2.1242, Adversarial: -3.0172


Epoch 7/40 - Training:   0%|          | 0/371 [08:08<?, ?batch/s]


Epoch 7/40 Results:
  Train Losses - Segmentation: 1.9874, Depth: 0.0376, Combined: 1.9950, Adversarial: -2.9980
  Valid Losses - Segmentation: 2.0622, Depth: 0.0585, Combined: 2.0907, Adversarial: -3.0026


Epoch 8/40 - Training:   0%|          | 0/371 [08:08<?, ?batch/s]


Epoch 8/40 Results:
  Train Losses - Segmentation: 1.9509, Depth: 0.0359, Combined: 1.9569, Adversarial: -2.9984
  Valid Losses - Segmentation: 2.0578, Depth: 0.0521, Combined: 2.0798, Adversarial: -3.0080


Epoch 9/40 - Training:   0%|          | 0/371 [08:10<?, ?batch/s]


Epoch 9/40 Results:
  Train Losses - Segmentation: 1.9244, Depth: 0.0348, Combined: 1.9292, Adversarial: -2.9985
  Valid Losses - Segmentation: 2.1080, Depth: 0.0699, Combined: 2.1480, Adversarial: -2.9941


Epoch 10/40 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Best model saved at epoch 10 with combined loss 2.0522
Epoch 10/40 Results:
  Train Losses - Segmentation: 1.9190, Depth: 0.0340, Combined: 1.9230, Adversarial: -2.9985
  Valid Losses - Segmentation: 2.0426, Depth: 0.0397, Combined: 2.0522, Adversarial: -3.0051


Epoch 11/40 - Training:   0%|          | 0/371 [08:08<?, ?batch/s]


Epoch 11/40 Results:
  Train Losses - Segmentation: 1.8911, Depth: 0.0349, Combined: 1.8960, Adversarial: -2.9987
  Valid Losses - Segmentation: 2.0226, Depth: 0.0957, Combined: 2.0882, Adversarial: -3.0053


Epoch 12/40 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Best model saved at epoch 12 with combined loss 2.0495
Epoch 12/40 Results:
  Train Losses - Segmentation: 1.8991, Depth: 0.0341, Combined: 1.9032, Adversarial: -2.9987
  Valid Losses - Segmentation: 2.0182, Depth: 0.0612, Combined: 2.0495, Adversarial: -2.9843


Epoch 13/40 - Training:   0%|          | 0/371 [08:10<?, ?batch/s]


Epoch 13/40 Results:
  Train Losses - Segmentation: 1.8789, Depth: 0.0343, Combined: 1.8832, Adversarial: -2.9989
  Valid Losses - Segmentation: 2.1770, Depth: 0.0433, Combined: 2.1905, Adversarial: -2.9839


Epoch 14/40 - Training:   0%|          | 0/371 [08:11<?, ?batch/s]


Best model saved at epoch 14 with combined loss 2.0154
Epoch 14/40 Results:
  Train Losses - Segmentation: 1.8733, Depth: 0.0348, Combined: 1.8781, Adversarial: -2.9989
  Valid Losses - Segmentation: 2.0055, Depth: 0.0399, Combined: 2.0154, Adversarial: -2.9997


Epoch 15/40 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Best model saved at epoch 15 with combined loss 1.9091
Epoch 15/40 Results:
  Train Losses - Segmentation: 1.8466, Depth: 0.0328, Combined: 1.8494, Adversarial: -2.9990
  Valid Losses - Segmentation: 1.8968, Depth: 0.0422, Combined: 1.9091, Adversarial: -2.9922


Epoch 16/40 - Training:   0%|          | 0/371 [08:11<?, ?batch/s]


Epoch 16/40 Results:
  Train Losses - Segmentation: 1.8445, Depth: 0.0325, Combined: 1.8470, Adversarial: -2.9991
  Valid Losses - Segmentation: 1.9076, Depth: 0.0414, Combined: 1.9190, Adversarial: -3.0054


Epoch 17/40 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Epoch 17/40 Results:
  Train Losses - Segmentation: 1.8260, Depth: 0.0322, Combined: 1.8282, Adversarial: -2.9992
  Valid Losses - Segmentation: 1.9399, Depth: 0.0510, Combined: 1.9608, Adversarial: -3.0046


Epoch 18/40 - Training:   0%|          | 0/371 [08:10<?, ?batch/s]


Epoch 18/40 Results:
  Train Losses - Segmentation: 1.8323, Depth: 0.0320, Combined: 1.8343, Adversarial: -2.9993
  Valid Losses - Segmentation: 1.9881, Depth: 0.0564, Combined: 2.0144, Adversarial: -2.9979


Epoch 19/40 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Epoch 19/40 Results:
  Train Losses - Segmentation: 1.8182, Depth: 0.0309, Combined: 1.8191, Adversarial: -2.9992
  Valid Losses - Segmentation: 1.9535, Depth: 0.0825, Combined: 2.0059, Adversarial: -3.0069


Epoch 20/40 - Training:   0%|          | 0/371 [08:10<?, ?batch/s]


Epoch 20/40 Results:
  Train Losses - Segmentation: 1.8197, Depth: 0.0315, Combined: 1.8212, Adversarial: -2.9994
  Valid Losses - Segmentation: 2.0076, Depth: 0.0464, Combined: 2.0240, Adversarial: -2.9984


Epoch 21/40 - Training:   0%|          | 0/371 [08:10<?, ?batch/s]


Best model saved at epoch 21 with combined loss 1.9031
Epoch 21/40 Results:
  Train Losses - Segmentation: 1.8005, Depth: 0.0312, Combined: 1.8017, Adversarial: -2.9993
  Valid Losses - Segmentation: 1.8859, Depth: 0.0473, Combined: 1.9031, Adversarial: -3.0017


Epoch 22/40 - Training:   0%|          | 0/371 [08:11<?, ?batch/s]


Epoch 22/40 Results:
  Train Losses - Segmentation: 1.7869, Depth: 0.0306, Combined: 1.7875, Adversarial: -2.9995
  Valid Losses - Segmentation: 1.9414, Depth: 0.0394, Combined: 1.9506, Adversarial: -3.0133


Epoch 23/40 - Training:   0%|          | 0/371 [08:08<?, ?batch/s]


Epoch 23/40 Results:
  Train Losses - Segmentation: 1.7969, Depth: 0.0302, Combined: 1.7971, Adversarial: -2.9995
  Valid Losses - Segmentation: 1.9068, Depth: 0.0351, Combined: 1.9119, Adversarial: -3.0020


Epoch 24/40 - Training:   0%|          | 0/371 [08:12<?, ?batch/s]


Epoch 24/40 Results:
  Train Losses - Segmentation: 1.7738, Depth: 0.0301, Combined: 1.7740, Adversarial: -2.9995
  Valid Losses - Segmentation: 1.8993, Depth: 0.0468, Combined: 1.9161, Adversarial: -2.9975


Epoch 25/40 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Best model saved at epoch 25 with combined loss 1.8873
Epoch 25/40 Results:
  Train Losses - Segmentation: 1.7873, Depth: 0.0302, Combined: 1.7875, Adversarial: -2.9996
  Valid Losses - Segmentation: 1.8715, Depth: 0.0458, Combined: 1.8873, Adversarial: -3.0002


Epoch 26/40 - Training:   0%|          | 0/371 [08:08<?, ?batch/s]


Epoch 26/40 Results:
  Train Losses - Segmentation: 1.7700, Depth: 0.0294, Combined: 1.7694, Adversarial: -2.9996
  Valid Losses - Segmentation: 1.8556, Depth: 0.0677, Combined: 1.8932, Adversarial: -3.0097


Epoch 27/40 - Training:   0%|          | 0/371 [08:11<?, ?batch/s]


Best model saved at epoch 27 with combined loss 1.8724
Epoch 27/40 Results:
  Train Losses - Segmentation: 1.7521, Depth: 0.0286, Combined: 1.7507, Adversarial: -2.9996
  Valid Losses - Segmentation: 1.8617, Depth: 0.0407, Combined: 1.8724, Adversarial: -3.0084


Epoch 28/40 - Training:   0%|          | 0/371 [08:10<?, ?batch/s]


Epoch 28/40 Results:
  Train Losses - Segmentation: 1.7422, Depth: 0.0293, Combined: 1.7414, Adversarial: -2.9997
  Valid Losses - Segmentation: 1.8772, Depth: 0.0406, Combined: 1.8877, Adversarial: -3.0022


Epoch 29/40 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Epoch 29/40 Results:
  Train Losses - Segmentation: 1.7655, Depth: 0.0287, Combined: 1.7642, Adversarial: -2.9997
  Valid Losses - Segmentation: 1.8699, Depth: 0.0484, Combined: 1.8884, Adversarial: -2.9909


Epoch 30/40 - Training:   0%|          | 0/371 [08:07<?, ?batch/s]


Best model saved at epoch 30 with combined loss 1.8357
Epoch 30/40 Results:
  Train Losses - Segmentation: 1.7476, Depth: 0.0286, Combined: 1.7461, Adversarial: -2.9997
  Valid Losses - Segmentation: 1.8267, Depth: 0.0390, Combined: 1.8357, Adversarial: -3.0002


Epoch 31/40 - Training:   0%|          | 0/371 [08:13<?, ?batch/s]


Epoch 31/40 Results:
  Train Losses - Segmentation: 1.7423, Depth: 0.0286, Combined: 1.7409, Adversarial: -2.9997
  Valid Losses - Segmentation: 1.8337, Depth: 0.0495, Combined: 1.8531, Adversarial: -3.0079


Epoch 32/40 - Training:   0%|          | 0/371 [08:14<?, ?batch/s]


Epoch 32/40 Results:
  Train Losses - Segmentation: 1.7374, Depth: 0.0284, Combined: 1.7358, Adversarial: -2.9998
  Valid Losses - Segmentation: 1.8474, Depth: 0.0336, Combined: 1.8510, Adversarial: -3.0017


Epoch 33/40 - Training:   0%|          | 0/371 [08:10<?, ?batch/s]


Epoch 33/40 Results:
  Train Losses - Segmentation: 1.7330, Depth: 0.0284, Combined: 1.7314, Adversarial: -2.9998
  Valid Losses - Segmentation: 1.9529, Depth: 0.0483, Combined: 1.9713, Adversarial: -2.9979


Epoch 34/40 - Training:   0%|          | 0/371 [08:10<?, ?batch/s]


Epoch 34/40 Results:
  Train Losses - Segmentation: 1.7091, Depth: 0.0275, Combined: 1.7065, Adversarial: -2.9997
  Valid Losses - Segmentation: 1.8489, Depth: 0.0395, Combined: 1.8584, Adversarial: -3.0020


Epoch 35/40 - Training:   0%|          | 0/371 [08:06<?, ?batch/s]


Best model saved at epoch 35 with combined loss 1.8156
Epoch 35/40 Results:
  Train Losses - Segmentation: 1.7201, Depth: 0.0276, Combined: 1.7177, Adversarial: -2.9998
  Valid Losses - Segmentation: 1.8088, Depth: 0.0367, Combined: 1.8156, Adversarial: -2.9956


Epoch 36/40 - Training:   0%|          | 0/371 [08:07<?, ?batch/s]


Epoch 36/40 Results:
  Train Losses - Segmentation: 1.7099, Depth: 0.0277, Combined: 1.7076, Adversarial: -2.9998
  Valid Losses - Segmentation: 1.8385, Depth: 0.0515, Combined: 1.8601, Adversarial: -2.9964


Epoch 37/40 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Epoch 37/40 Results:
  Train Losses - Segmentation: 1.7058, Depth: 0.0276, Combined: 1.7034, Adversarial: -2.9998
  Valid Losses - Segmentation: 1.8251, Depth: 0.0382, Combined: 1.8332, Adversarial: -3.0029


Epoch 38/40 - Training:   0%|          | 0/371 [08:12<?, ?batch/s]


Epoch 38/40 Results:
  Train Losses - Segmentation: 1.6985, Depth: 0.0272, Combined: 1.6957, Adversarial: -2.9998
  Valid Losses - Segmentation: 1.8412, Depth: 0.0463, Combined: 1.8575, Adversarial: -3.0026


Epoch 39/40 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Epoch 39/40 Results:
  Train Losses - Segmentation: 1.6937, Depth: 0.0276, Combined: 1.6914, Adversarial: -2.9998
  Valid Losses - Segmentation: 1.8581, Depth: 0.0476, Combined: 1.8757, Adversarial: -3.0022


Epoch 40/40 - Training:   0%|          | 0/371 [08:09<?, ?batch/s]


Epoch 40/40 Results:
  Train Losses - Segmentation: 1.6973, Depth: 0.0270, Combined: 1.6943, Adversarial: -2.9998
  Valid Losses - Segmentation: 1.8487, Depth: 0.0364, Combined: 1.8551, Adversarial: -2.9977
Training visualization saved as GIF at results_test8_final2/20241129_030455/training_visualization_20241129_030455.gif


# For second instance

In [32]:
root_save_dir = os.path.join(os.getcwd(),'results_test8_final2')
model_dir =  os.path.join(root_save_dir,'20241129_145526')
model_dir = os.path.join(model_dir,'combined_result')
model_dir

'/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/results_test8_final2/20241129_145526/combined_result'

In [33]:
# root_save_dir = os.path.join(os.getcwd(),'results_test8_final2')
# model_dir =  os.path.join(root_save_dir,'20241129_030455')
num_additional_epochs = 60  # Number of epochs to continue training
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize the model class and loaders
mobilenet_backbone = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
model_class = lambda: MultiTaskModel(backbone=mobilenet_backbone.features, num_seg_classes=20, depth_channels=1)

# Initialize optimizers and schedulers
model = model_class()
opt_sched = initialize_optimizers_and_schedulers(model)


# Call the resume training function
resume_training_with_loss_tracking(
    model_class=model_class,
    model_dir=model_dir,
    train_loader=train_loader,
    valid_loader=valid_loader,
    num_additional_epochs=num_additional_epochs,
    device=device,
    opt_sched=opt_sched,
    save_dir="results_test8_final2"
)

root_save_dir,model_dir

  checkpoint = torch.load(checkpoint_path, map_location=device)
Epoch 1/60 - Training:   0%|          | 0/371 [07:20<?, ?batch/s]


Best model saved at epoch 1 with combined loss 1.8251
Epoch 1/60 Results:
  Train Losses - Segmentation: 1.7052, Depth: 0.0282, Combined: 1.7034, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8172, Depth: 0.0379, Combined: 1.8251, Adversarial: -2.9962


Epoch 2/60 - Training:   0%|          | 0/371 [07:16<?, ?batch/s]


Epoch 2/60 Results:
  Train Losses - Segmentation: 1.6914, Depth: 0.0280, Combined: 1.6895, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8304, Depth: 0.0365, Combined: 1.8369, Adversarial: -2.9965


Epoch 3/60 - Training:   0%|          | 0/371 [07:18<?, ?batch/s]


Epoch 3/60 Results:
  Train Losses - Segmentation: 1.6896, Depth: 0.0283, Combined: 1.6880, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8311, Depth: 0.0380, Combined: 1.8392, Adversarial: -2.9964


Epoch 4/60 - Training:   0%|          | 0/371 [07:23<?, ?batch/s]


Epoch 4/60 Results:
  Train Losses - Segmentation: 1.6928, Depth: 0.0285, Combined: 1.6914, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8314, Depth: 0.0378, Combined: 1.8392, Adversarial: -2.9964


Epoch 5/60 - Training:   0%|          | 0/371 [07:20<?, ?batch/s]


Best model saved at epoch 5 with combined loss 1.8211
Epoch 5/60 Results:
  Train Losses - Segmentation: 1.6947, Depth: 0.0282, Combined: 1.6929, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8132, Depth: 0.0378, Combined: 1.8211, Adversarial: -2.9961


Epoch 6/60 - Training:   0%|          | 0/371 [07:15<?, ?batch/s]


Epoch 6/60 Results:
  Train Losses - Segmentation: 1.6891, Depth: 0.0281, Combined: 1.6873, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8373, Depth: 0.0380, Combined: 1.8453, Adversarial: -2.9965


Epoch 7/60 - Training:   0%|          | 0/371 [07:14<?, ?batch/s]


Best model saved at epoch 7 with combined loss 1.8161
Epoch 7/60 Results:
  Train Losses - Segmentation: 1.7043, Depth: 0.0283, Combined: 1.7026, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8082, Depth: 0.0379, Combined: 1.8161, Adversarial: -2.9961


Epoch 8/60 - Training:   0%|          | 0/371 [07:15<?, ?batch/s]


Epoch 8/60 Results:
  Train Losses - Segmentation: 1.6945, Depth: 0.0281, Combined: 1.6926, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8299, Depth: 0.0378, Combined: 1.8377, Adversarial: -2.9964


Epoch 9/60 - Training:   0%|          | 0/371 [07:14<?, ?batch/s]


Epoch 9/60 Results:
  Train Losses - Segmentation: 1.6939, Depth: 0.0283, Combined: 1.6923, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8182, Depth: 0.0389, Combined: 1.8271, Adversarial: -2.9960


Epoch 10/60 - Training:   0%|          | 0/371 [07:16<?, ?batch/s]


Best model saved at epoch 10 with combined loss 1.8158
Epoch 10/60 Results:
  Train Losses - Segmentation: 1.7057, Depth: 0.0279, Combined: 1.7037, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8084, Depth: 0.0373, Combined: 1.8158, Adversarial: -2.9961


Epoch 11/60 - Training:   0%|          | 0/371 [07:15<?, ?batch/s]


Epoch 11/60 Results:
  Train Losses - Segmentation: 1.6983, Depth: 0.0286, Combined: 1.6970, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8115, Depth: 0.0372, Combined: 1.8188, Adversarial: -2.9960


Epoch 12/60 - Training:   0%|          | 0/371 [07:15<?, ?batch/s]


Best model saved at epoch 12 with combined loss 1.8099
Epoch 12/60 Results:
  Train Losses - Segmentation: 1.6894, Depth: 0.0281, Combined: 1.6876, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8027, Depth: 0.0371, Combined: 1.8099, Adversarial: -2.9961


Epoch 13/60 - Training:   0%|          | 0/371 [07:15<?, ?batch/s]


Epoch 13/60 Results:
  Train Losses - Segmentation: 1.6858, Depth: 0.0279, Combined: 1.6838, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8238, Depth: 0.0383, Combined: 1.8322, Adversarial: -2.9963


Epoch 14/60 - Training:   0%|          | 0/371 [07:18<?, ?batch/s]


Epoch 14/60 Results:
  Train Losses - Segmentation: 1.7090, Depth: 0.0285, Combined: 1.7075, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8250, Depth: 0.0373, Combined: 1.8324, Adversarial: -2.9964


Epoch 15/60 - Training:   0%|          | 0/371 [07:16<?, ?batch/s]


Epoch 15/60 Results:
  Train Losses - Segmentation: 1.6954, Depth: 0.0283, Combined: 1.6937, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8210, Depth: 0.0387, Combined: 1.8297, Adversarial: -2.9962


Epoch 16/60 - Training:   0%|          | 0/371 [07:13<?, ?batch/s]


Epoch 16/60 Results:
  Train Losses - Segmentation: 1.6855, Depth: 0.0285, Combined: 1.6840, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8074, Depth: 0.0381, Combined: 1.8156, Adversarial: -2.9960


Epoch 17/60 - Training:   0%|          | 0/371 [07:13<?, ?batch/s]


Epoch 17/60 Results:
  Train Losses - Segmentation: 1.6916, Depth: 0.0282, Combined: 1.6899, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8316, Depth: 0.0378, Combined: 1.8394, Adversarial: -2.9965


Epoch 18/60 - Training:   0%|          | 0/371 [07:13<?, ?batch/s]


Epoch 18/60 Results:
  Train Losses - Segmentation: 1.6856, Depth: 0.0281, Combined: 1.6837, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8327, Depth: 0.0390, Combined: 1.8417, Adversarial: -2.9963


Epoch 19/60 - Training:   0%|          | 0/371 [07:13<?, ?batch/s]


Epoch 19/60 Results:
  Train Losses - Segmentation: 1.6988, Depth: 0.0282, Combined: 1.6971, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8087, Depth: 0.0383, Combined: 1.8171, Adversarial: -2.9960


Epoch 20/60 - Training:   0%|          | 0/371 [07:07<?, ?batch/s]


Epoch 20/60 Results:
  Train Losses - Segmentation: 1.6874, Depth: 0.0280, Combined: 1.6855, Adversarial: -2.9947
  Valid Losses - Segmentation: 1.8098, Depth: 0.0385, Combined: 1.8184, Adversarial: -2.9960


Epoch 21/60 - Training:   0%|          | 0/371 [07:11<?, ?batch/s]


Epoch 21/60 Results:
  Train Losses - Segmentation: 1.6815, Depth: 0.0283, Combined: 1.6798, Adversarial: -2.9947
  Valid Losses - Segmentation: 1.8188, Depth: 0.0388, Combined: 1.8276, Adversarial: -2.9961


Epoch 22/60 - Training:   0%|          | 0/371 [07:06<?, ?batch/s]


Epoch 22/60 Results:
  Train Losses - Segmentation: 1.6947, Depth: 0.0285, Combined: 1.6933, Adversarial: -2.9952
  Valid Losses - Segmentation: 1.8380, Depth: 0.0380, Combined: 1.8460, Adversarial: -2.9964


Epoch 23/60 - Training:   0%|          | 0/371 [07:07<?, ?batch/s]


Epoch 23/60 Results:
  Train Losses - Segmentation: 1.6873, Depth: 0.0276, Combined: 1.6849, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8213, Depth: 0.0388, Combined: 1.8301, Adversarial: -2.9961


Epoch 24/60 - Training:   0%|          | 0/371 [07:12<?, ?batch/s]


Epoch 24/60 Results:
  Train Losses - Segmentation: 1.6919, Depth: 0.0282, Combined: 1.6902, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8160, Depth: 0.0382, Combined: 1.8242, Adversarial: -2.9962


Epoch 25/60 - Training:   0%|          | 0/371 [07:11<?, ?batch/s]


Epoch 25/60 Results:
  Train Losses - Segmentation: 1.7011, Depth: 0.0280, Combined: 1.6991, Adversarial: -2.9951
  Valid Losses - Segmentation: 1.8149, Depth: 0.0374, Combined: 1.8223, Adversarial: -2.9962


Epoch 26/60 - Training:   0%|          | 0/371 [07:12<?, ?batch/s]


Epoch 26/60 Results:
  Train Losses - Segmentation: 1.6856, Depth: 0.0283, Combined: 1.6839, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8211, Depth: 0.0380, Combined: 1.8292, Adversarial: -2.9962


Epoch 27/60 - Training:   0%|          | 0/371 [07:12<?, ?batch/s]


Epoch 27/60 Results:
  Train Losses - Segmentation: 1.6902, Depth: 0.0280, Combined: 1.6882, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8087, Depth: 0.0386, Combined: 1.8174, Adversarial: -2.9960


Epoch 28/60 - Training:   0%|          | 0/371 [07:11<?, ?batch/s]


Epoch 28/60 Results:
  Train Losses - Segmentation: 1.6876, Depth: 0.0284, Combined: 1.6860, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8224, Depth: 0.0389, Combined: 1.8313, Adversarial: -2.9962


Epoch 29/60 - Training:   0%|          | 0/371 [07:18<?, ?batch/s]


Epoch 29/60 Results:
  Train Losses - Segmentation: 1.6923, Depth: 0.0280, Combined: 1.6904, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8179, Depth: 0.0383, Combined: 1.8262, Adversarial: -2.9962


Epoch 30/60 - Training:   0%|          | 0/371 [07:12<?, ?batch/s]


Epoch 30/60 Results:
  Train Losses - Segmentation: 1.6821, Depth: 0.0277, Combined: 1.6798, Adversarial: -2.9951
  Valid Losses - Segmentation: 1.8176, Depth: 0.0378, Combined: 1.8254, Adversarial: -2.9963


Epoch 31/60 - Training:   0%|          | 0/371 [07:12<?, ?batch/s]


Epoch 31/60 Results:
  Train Losses - Segmentation: 1.6762, Depth: 0.0281, Combined: 1.6744, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8109, Depth: 0.0374, Combined: 1.8183, Adversarial: -2.9961


Epoch 32/60 - Training:   0%|          | 0/371 [07:08<?, ?batch/s]


Epoch 32/60 Results:
  Train Losses - Segmentation: 1.6981, Depth: 0.0281, Combined: 1.6963, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8130, Depth: 0.0385, Combined: 1.8215, Adversarial: -2.9961


Epoch 33/60 - Training:   0%|          | 0/371 [07:09<?, ?batch/s]


Epoch 33/60 Results:
  Train Losses - Segmentation: 1.6866, Depth: 0.0278, Combined: 1.6844, Adversarial: -2.9951
  Valid Losses - Segmentation: 1.8139, Depth: 0.0378, Combined: 1.8217, Adversarial: -2.9961


Epoch 34/60 - Training:   0%|          | 0/371 [07:15<?, ?batch/s]


Epoch 34/60 Results:
  Train Losses - Segmentation: 1.7002, Depth: 0.0282, Combined: 1.6985, Adversarial: -2.9946
  Valid Losses - Segmentation: 1.8115, Depth: 0.0377, Combined: 1.8192, Adversarial: -2.9961


Epoch 35/60 - Training:   0%|          | 0/371 [07:16<?, ?batch/s]


Epoch 35/60 Results:
  Train Losses - Segmentation: 1.7054, Depth: 0.0284, Combined: 1.7039, Adversarial: -2.9947
  Valid Losses - Segmentation: 1.8132, Depth: 0.0380, Combined: 1.8212, Adversarial: -2.9961


Epoch 36/60 - Training:   0%|          | 0/371 [07:15<?, ?batch/s]


Epoch 36/60 Results:
  Train Losses - Segmentation: 1.6911, Depth: 0.0280, Combined: 1.6892, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8169, Depth: 0.0378, Combined: 1.8247, Adversarial: -2.9962


Epoch 37/60 - Training:   0%|          | 0/371 [07:18<?, ?batch/s]


Epoch 37/60 Results:
  Train Losses - Segmentation: 1.6843, Depth: 0.0282, Combined: 1.6826, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8204, Depth: 0.0374, Combined: 1.8278, Adversarial: -2.9962


Epoch 38/60 - Training:   0%|          | 0/371 [07:15<?, ?batch/s]


Epoch 38/60 Results:
  Train Losses - Segmentation: 1.6901, Depth: 0.0282, Combined: 1.6884, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8102, Depth: 0.0373, Combined: 1.8175, Adversarial: -2.9961


Epoch 39/60 - Training:   0%|          | 0/371 [07:16<?, ?batch/s]


Epoch 39/60 Results:
  Train Losses - Segmentation: 1.6810, Depth: 0.0278, Combined: 1.6788, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8132, Depth: 0.0379, Combined: 1.8211, Adversarial: -2.9961


Epoch 40/60 - Training:   0%|          | 0/371 [07:19<?, ?batch/s]


Epoch 40/60 Results:
  Train Losses - Segmentation: 1.7059, Depth: 0.0285, Combined: 1.7044, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8158, Depth: 0.0374, Combined: 1.8233, Adversarial: -2.9962


Epoch 41/60 - Training:   0%|          | 0/371 [07:17<?, ?batch/s]


Epoch 41/60 Results:
  Train Losses - Segmentation: 1.6956, Depth: 0.0281, Combined: 1.6938, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8061, Depth: 0.0379, Combined: 1.8140, Adversarial: -2.9960


Epoch 42/60 - Training:   0%|          | 0/371 [07:15<?, ?batch/s]


Epoch 42/60 Results:
  Train Losses - Segmentation: 1.7153, Depth: 0.0282, Combined: 1.7136, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8158, Depth: 0.0377, Combined: 1.8235, Adversarial: -2.9961


Epoch 43/60 - Training:   0%|          | 0/371 [07:28<?, ?batch/s]


Epoch 43/60 Results:
  Train Losses - Segmentation: 1.6899, Depth: 0.0280, Combined: 1.6879, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8179, Depth: 0.0380, Combined: 1.8259, Adversarial: -2.9961


Epoch 44/60 - Training:   0%|          | 0/371 [07:19<?, ?batch/s]


Epoch 44/60 Results:
  Train Losses - Segmentation: 1.6826, Depth: 0.0282, Combined: 1.6808, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8149, Depth: 0.0379, Combined: 1.8229, Adversarial: -2.9961


Epoch 45/60 - Training:   0%|          | 0/371 [07:16<?, ?batch/s]


Epoch 45/60 Results:
  Train Losses - Segmentation: 1.6970, Depth: 0.0279, Combined: 1.6949, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8094, Depth: 0.0376, Combined: 1.8171, Adversarial: -2.9961


Epoch 46/60 - Training:   0%|          | 0/371 [07:12<?, ?batch/s]


Epoch 46/60 Results:
  Train Losses - Segmentation: 1.6922, Depth: 0.0283, Combined: 1.6905, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8151, Depth: 0.0379, Combined: 1.8231, Adversarial: -2.9961


Epoch 47/60 - Training:   0%|          | 0/371 [07:13<?, ?batch/s]


Epoch 47/60 Results:
  Train Losses - Segmentation: 1.6871, Depth: 0.0281, Combined: 1.6852, Adversarial: -2.9952
  Valid Losses - Segmentation: 1.8092, Depth: 0.0373, Combined: 1.8165, Adversarial: -2.9961


Epoch 48/60 - Training:   0%|          | 0/371 [07:11<?, ?batch/s]


Epoch 48/60 Results:
  Train Losses - Segmentation: 1.6940, Depth: 0.0280, Combined: 1.6920, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8122, Depth: 0.0377, Combined: 1.8200, Adversarial: -2.9961


Epoch 49/60 - Training:   0%|          | 0/371 [07:10<?, ?batch/s]


Epoch 49/60 Results:
  Train Losses - Segmentation: 1.6892, Depth: 0.0286, Combined: 1.6878, Adversarial: -2.9951
  Valid Losses - Segmentation: 1.8110, Depth: 0.0376, Combined: 1.8186, Adversarial: -2.9961


Epoch 50/60 - Training:   0%|          | 0/371 [07:09<?, ?batch/s]


Epoch 50/60 Results:
  Train Losses - Segmentation: 1.6966, Depth: 0.0284, Combined: 1.6951, Adversarial: -2.9946
  Valid Losses - Segmentation: 1.8104, Depth: 0.0375, Combined: 1.8180, Adversarial: -2.9961


Epoch 51/60 - Training:   0%|          | 0/371 [07:11<?, ?batch/s]


Epoch 51/60 Results:
  Train Losses - Segmentation: 1.6829, Depth: 0.0285, Combined: 1.6814, Adversarial: -2.9951
  Valid Losses - Segmentation: 1.8170, Depth: 0.0376, Combined: 1.8246, Adversarial: -2.9961


Epoch 52/60 - Training:   0%|          | 0/371 [07:13<?, ?batch/s]


Epoch 52/60 Results:
  Train Losses - Segmentation: 1.6748, Depth: 0.0281, Combined: 1.6730, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8122, Depth: 0.0379, Combined: 1.8202, Adversarial: -2.9961


Epoch 53/60 - Training:   0%|          | 0/371 [07:12<?, ?batch/s]


Epoch 53/60 Results:
  Train Losses - Segmentation: 1.6889, Depth: 0.0285, Combined: 1.6874, Adversarial: -2.9948
  Valid Losses - Segmentation: 1.8130, Depth: 0.0376, Combined: 1.8206, Adversarial: -2.9962


Epoch 54/60 - Training:   0%|          | 0/371 [07:12<?, ?batch/s]


Epoch 54/60 Results:
  Train Losses - Segmentation: 1.6893, Depth: 0.0280, Combined: 1.6874, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8115, Depth: 0.0379, Combined: 1.8194, Adversarial: -2.9961


Epoch 55/60 - Training:   0%|          | 0/371 [07:12<?, ?batch/s]


Epoch 55/60 Results:
  Train Losses - Segmentation: 1.6719, Depth: 0.0285, Combined: 1.6704, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8162, Depth: 0.0380, Combined: 1.8243, Adversarial: -2.9961


Epoch 56/60 - Training:   0%|          | 0/371 [07:15<?, ?batch/s]


Epoch 56/60 Results:
  Train Losses - Segmentation: 1.6844, Depth: 0.0281, Combined: 1.6826, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8159, Depth: 0.0379, Combined: 1.8238, Adversarial: -2.9961


Epoch 57/60 - Training:   0%|          | 0/371 [07:14<?, ?batch/s]


Epoch 57/60 Results:
  Train Losses - Segmentation: 1.6914, Depth: 0.0281, Combined: 1.6895, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8150, Depth: 0.0376, Combined: 1.8226, Adversarial: -2.9961


Epoch 58/60 - Training:   0%|          | 0/371 [07:13<?, ?batch/s]


Epoch 58/60 Results:
  Train Losses - Segmentation: 1.6949, Depth: 0.0278, Combined: 1.6927, Adversarial: -2.9950
  Valid Losses - Segmentation: 1.8182, Depth: 0.0379, Combined: 1.8261, Adversarial: -2.9961


Epoch 59/60 - Training:   0%|          | 0/371 [07:15<?, ?batch/s]


Epoch 59/60 Results:
  Train Losses - Segmentation: 1.6877, Depth: 0.0278, Combined: 1.6856, Adversarial: -2.9949
  Valid Losses - Segmentation: 1.8140, Depth: 0.0378, Combined: 1.8218, Adversarial: -2.9961


Epoch 60/60 - Training:   0%|          | 0/371 [07:13<?, ?batch/s]


Epoch 60/60 Results:
  Train Losses - Segmentation: 1.6838, Depth: 0.0280, Combined: 1.6818, Adversarial: -2.9951
  Valid Losses - Segmentation: 1.8186, Depth: 0.0382, Combined: 1.8269, Adversarial: -2.9961
Training visualization saved as GIF at results_test8_final2/20241130_195345/training_visualization_20241130_195345.gif
Combined GIF saved to results_test8_final2/20241130_195345/combined_result/combined_results.gif


('/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/results_test8_final2',
 '/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/results_test8_final2/20241129_145526/combined_result')

In [None]:
input("Enter timestamp folder value")

In [1]:
import os
import csv
import torch
import matplotlib.pyplot as plt


In [49]:
# root = os.path.join(os.getcwd(), 'results_test8_final' )
# model_dir = os.path.join(root,'20241129_015944')
# save_dir2 = os.path.join(root,'20241129_024547')

# output_path = os.path.join(save_dir2,'combined_results.gif')
# combine_training_gifs(model_dir, save_dir2, output_path)

Combined GIF saved to /home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/results_test8_final/20241129_024547/combined_results.gif


In [None]:
# check with perpetual loss later

In [None]:
# import os
# import csv
# from datetime import datetime
# import matplotlib.pyplot as plt
# import torch
# from torch.utils.data import DataLoader
# from torchvision.utils import save_image

# def train_model_with_loss_tracking(
#     model, train_loader, valid_loader, num_epochs, device, opt_sched, save_dir="results"
# ):
#     """
#     Trains a multi-task model with Conditional GANs, structural consistency, and perceptual loss.

#     Args:
#         model: The multi-task model to train.
#         train_loader: DataLoader for training data.
#         valid_loader: DataLoader for validation data.
#         num_epochs: Number of epochs to train.
#         device: Device for training ("cuda" or "cpu").
#         opt_sched: Dictionary of optimizers and schedulers.
#         save_dir: Directory to save results.

#     Returns:
#         train_losses, valid_losses: Lists of losses for training and validation.
#     """
#     # Create directories for saving results
#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     save_dir = os.path.join(save_dir, timestamp)
#     os.makedirs(save_dir, exist_ok=True)

#     # Prepare CSV file
#     csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
#     with open(csv_path, "w", newline="") as f:
#         writer = csv.writer(f)
#         writer.writerow([
#             "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
#             "train_depth_sidl", "train_depth_smooth", "train_seg_iou", "train_seg_perceptual_loss",
#             "train_depth_perceptual_loss", "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
#             "valid_depth_sidl", "valid_depth_smooth", "valid_seg_iou", "valid_seg_perceptual_loss",
#             "valid_depth_perceptual_loss"
#         ])

#     # Initialize tracking variables
#     train_losses = {
#         "seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": [],
#         "seg_perceptual": [], "depth_perceptual": []
#     }
#     valid_losses = {
#         "seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": [],
#         "seg_perceptual": [], "depth_perceptual": []
#     }
#     best_combined_loss = float("inf")
#     gif_frames = []

#     # Perceptual Loss (example using VGG features)
#     perceptual_loss_fn = PerceptualLoss(pretrained_model="vgg16").to(device)

#     for epoch in range(num_epochs):
#         model.train()
#         epoch_train = {key: 0.0 for key in train_losses.keys()}
#         num_batches = 0

#         for batch in train_loader:
#             inputs, seg_labels, depth_labels = (
#                 batch["left"].to(device),
#                 batch["mask"].to(device),
#                 batch["depth"].to(device)
#             )
#             latent_noise = torch.randn(inputs.size(0), 3).to(device)

#             # Zero gradients
#             for optimizer in opt_sched["optimizers"].values():
#                 optimizer.zero_grad()

#             # Forward pass
#             outputs = model(inputs, input_size=inputs.size()[-2:])

#             # Loss calculations
#             seg_loss_task = nn.CrossEntropyLoss()(outputs["seg_output"], seg_labels) + dice_loss(outputs["seg_output"], seg_labels)
#             seg_perceptual_loss = perceptual_loss_fn(outputs["seg_output"], seg_labels.unsqueeze(1))
#             seg_loss = seg_loss_task + 0.1 * seg_perceptual_loss

#             depth_loss_sidl = scale_invariant_depth_loss(outputs["depth_output"], depth_labels)
#             depth_loss_huber = inv_huber_loss(outputs["depth_output"], depth_labels)
#             depth_loss_smooth = depth_smoothness_loss(outputs["depth_output"], inputs)
#             depth_perceptual_loss = perceptual_loss_fn(outputs["depth_output"], depth_labels)
#             depth_loss = depth_loss_sidl + depth_loss_huber + depth_loss_smooth + 0.1 * depth_perceptual_loss

#             combined_loss = seg_loss + depth_loss

#             # Backpropagation
#             combined_loss.backward()
#             for optimizer in opt_sched["optimizers"].values():
#                 optimizer.step()

#             # Update training metrics
#             epoch_train["seg"] += seg_loss.item()
#             epoch_train["depth"] += depth_loss.item()
#             epoch_train["combined"] += combined_loss.item()
#             epoch_train["iou"] += mean_iou(outputs["seg_output"], seg_labels, num_classes=20).item()
#             epoch_train["depth_sidl"] += depth_loss_sidl.item()
#             epoch_train["depth_smooth"] += depth_loss_smooth.item()
#             epoch_train["seg_perceptual"] += seg_perceptual_loss.item()
#             epoch_train["depth_perceptual"] += depth_perceptual_loss.item()
#             num_batches += 1

#         # Average training metrics
#         for key in epoch_train.keys():
#             train_losses[key].append(epoch_train[key] / num_batches)

#         # Validation loop
#         model.eval()
#         epoch_valid = {key: 0.0 for key in valid_losses.keys()}
#         num_valid_batches = 0

#         with torch.no_grad():
#             for batch in valid_loader:
#                 inputs, seg_labels, depth_labels = (
#                     batch["left"].to(device),
#                     batch["mask"].to(device),
#                     batch["depth"].to(device)
#                 )
#                 latent_noise = torch.randn(inputs.size(0), 3).to(device)

#                 # Forward pass
#                 outputs = model(inputs, input_size=inputs.size()[-2:])

#                 # Validation loss calculations
#                 seg_loss_task = nn.CrossEntropyLoss()(outputs["seg_output"], seg_labels) + dice_loss(outputs["seg_output"], seg_labels)
#                 seg_perceptual_loss = perceptual_loss_fn(outputs["seg_output"], seg_labels.unsqueeze(1))
#                 seg_loss = seg_loss_task + 0.1 * seg_perceptual_loss

#                 depth_loss_sidl = scale_invariant_depth_loss(outputs["depth_output"], depth_labels)
#                 depth_loss_huber = inv_huber_loss(outputs["depth_output"], depth_labels)
#                 depth_loss_smooth = depth_smoothness_loss(outputs["depth_output"], inputs)
#                 depth_perceptual_loss = perceptual_loss_fn(outputs["depth_output"], depth_labels)
#                 depth_loss = depth_loss_sidl + depth_loss_huber + depth_loss_smooth + 0.1 * depth_perceptual_loss

#                 combined_loss = seg_loss + depth_loss

#                 # Update validation metrics
#                 epoch_valid["seg"] += seg_loss.item()
#                 epoch_valid["depth"] += depth_loss.item()
#                 epoch_valid["combined"] += combined_loss.item()
#                 epoch_valid["iou"] += mean_iou(outputs["seg_output"], seg_labels, num_classes=20).item()
#                 epoch_valid["depth_sidl"] += depth_loss_sidl.item()
#                 epoch_valid["depth_smooth"] += depth_loss_smooth.item()
#                 epoch_valid["seg_perceptual"] += seg_perceptual_loss.item()
#                 epoch_valid["depth_perceptual"] += depth_perceptual_loss.item()
#                 num_valid_batches += 1

#         # Average validation metrics
#         for key in epoch_valid.keys():
#             valid_losses[key].append(epoch_valid[key] / num_valid_batches)

#         # Save best model
#         valid_combined_loss = epoch_valid["combined"] / num_valid_batches
#         if valid_combined_loss < best_combined_loss:
#             best_combined_loss = valid_combined_loss
#             torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))

#         # Append metrics to CSV
#         with open(csv_path, "a", newline="") as f:
#             writer = csv.writer(f)
#             writer.writerow([
#                 epoch + 1,
#                 epoch_train["seg"] / num_batches,
#                 epoch_train["depth"] / num_batches,
#                 epoch_train["combined"] / num_batches,
#                 epoch_train["depth_sidl"] / num_batches,
#                 epoch_train["depth_smooth"] / num_batches,
#                 epoch_train["iou"] / num_batches,
#                 epoch_train["seg_perceptual"] / num_batches,
#                 epoch_train["depth_perceptual"] / num_batches,
#                 epoch_valid["seg"] / num_valid_batches,
#                 epoch_valid["depth"] / num_valid_batches,
#                 epoch_valid["combined"] / num_valid_batches,
#                 epoch_valid["depth_sidl"] / num_valid_batches,
#                 epoch_valid["depth_smooth"] / num_valid_batches,
#                 epoch_valid["iou"] / num_valid_batches,
#                 epoch_valid["seg_perceptual"] / num_valid_batches,
#                 epoch_valid["depth_perceptual"] / num_valid_batches,
#             ])

#         # Update schedulers
#         for scheduler in opt_sched["schedulers"].values():
#             scheduler.step()
            
#     plot_all_losses(train_losses,valid_losses)

    

#     return train_losses, valid_losses


In [None]:
# import os
# import csv
# import torch
# import torch.nn as nn
# from datetime import datetime
# from torchvision.utils import save_image
# from PIL import Image
# import matplotlib.pyplot as plt

# # Updated Train Function
# def train_model_with_loss_tracking_and_gif(
#     model, train_loader, valid_loader, num_epochs, device, save_dir="training_output_bicycle_and_pix2pix"
# ):
#     # Create directories for saving models and outputs
#     os.makedirs(save_dir, exist_ok=True)
#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     save_dir = os.path.join(save_dir, timestamp)
#     os.makedirs(save_dir, exist_ok=True)

#     csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
#     gif_path = os.path.join(save_dir, f"training_visualization_{timestamp}.gif")

#     # Initialize CSV for saving loss data
#     with open(csv_path, "w", newline="") as f:
#         writer = csv.writer(f)
#         writer.writerow([
#             "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
#             "train_depth_sidl", "train_depth_smooth", "train_seg_iou",
#             "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
#             "valid_depth_sidl", "valid_depth_smooth", "valid_seg_iou"
#         ])

#     best_combined_loss = float("inf")  # Initialize best combined loss for saving the best model
#     train_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     valid_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     gif_frames = []

#     for epoch in range(num_epochs):
#         torch.cuda.empty_cache()
#         model.train()
#         epoch_train = {key: 0.0 for key in train_losses.keys()}
#         num_batches = 0

#         # Training Loop
#         for batch in train_loader:
#             inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)
#             latent_noise = torch.randn(inputs.size(0), 3).to(device)

#             model.optimizer_stage1.zero_grad()
#             model.optimizer_stage2.zero_grad()

#             # Forward Pass
#             outputs = model(inputs, seg_labels, depth_labels, latent_noise)
#             seg_output = outputs["seg_output"]
#             depth_output = outputs["depth_output"]

#             # Loss Calculations
#             seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels.squeeze(1))
#             seg_dice = dice_loss(seg_output, seg_labels)
#             seg_iou = mean_iou(seg_output, seg_labels, num_classes=20)
#             seg_loss_total = 0.6 * seg_loss + 0.4 * seg_dice

#             depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#             depth_smooth = depth_smoothness_loss(depth_output, inputs)
#             depth_loss_total = depth_sidl + depth_smooth

#             # Combined Loss
#             total_loss = seg_loss_total + depth_loss_total
#             total_loss.backward()

#             # Optimizers Step
#             model.optimizer_stage1.step()
#             model.optimizer_stage2.step()

#             # Accumulate Training Metrics
#             epoch_train["seg"] += seg_loss.item()
#             epoch_train["depth"] += depth_loss_total.item()
#             epoch_train["combined"] += total_loss.item()
#             epoch_train["iou"] += seg_iou.item()
#             epoch_train["depth_sidl"] += depth_sidl.item()
#             epoch_train["depth_smooth"] += depth_smooth.item()
#             num_batches += 1

#             # Save training images for visualization
#             if num_batches % 10 == 0:
#                 img_grid = torch.cat([inputs[0], seg_output[0].argmax(0, keepdim=True), depth_output[0]], dim=2)
#                 save_image(img_grid, os.path.join(save_dir, f"train_{epoch}_{num_batches}.png"))
#                 gif_frames.append(Image.open(os.path.join(save_dir, f"train_{epoch}_{num_batches}.png")))

#         model.scheduler_stage1.step()
#         model.scheduler_stage2.step(epoch_train["combined"] / num_batches)

#         # Average Training Losses
#         for key in epoch_train.keys():
#             train_losses[key].append(epoch_train[key] / num_batches)

#         # Validation Loop
#         model.eval()
#         epoch_valid = {key: 0.0 for key in valid_losses.keys()}
#         num_valid_batches = 0
#         with torch.no_grad():
#             for batch in valid_loader:
#                 inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)
#                 latent_noise = torch.randn(inputs.size(0), 3).to(device)

#                 outputs = model(inputs, seg_labels, depth_labels, latent_noise)
#                 seg_output = outputs["seg_output"]
#                 depth_output = outputs["depth_output"]

#                 # Validation Loss Calculations
#                 seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels.squeeze(1))
#                 seg_dice = dice_loss(seg_output, seg_labels)
#                 seg_iou = mean_iou(seg_output, seg_labels, num_classes=20)
#                 seg_loss_total = 0.6 * seg_loss + 0.4 * seg_dice

#                 depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#                 depth_smooth = depth_smoothness_loss(depth_output, inputs)
#                 depth_loss_total = depth_sidl + depth_smooth

#                 # Accumulate Validation Metrics
#                 epoch_valid["seg"] += seg_loss.item()
#                 epoch_valid["depth"] += depth_loss_total.item()
#                 epoch_valid["combined"] += (seg_loss_total + depth_loss_total).item()
#                 epoch_valid["iou"] += seg_iou.item()
#                 epoch_valid["depth_sidl"] += depth_sidl.item()
#                 epoch_valid["depth_smooth"] += depth_smooth.item()
#                 num_valid_batches += 1
            
#             frame = save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels)
#             gif_frames.append(frame)
              

#         for key in epoch_valid.keys():
#             valid_losses[key].append(epoch_valid[key] / num_valid_batches)

#         # Save Model if Validation Loss Improves
#         valid_combined_loss = epoch_valid["combined"] / num_valid_batches
#         if valid_combined_loss < best_combined_loss:
#             best_combined_loss = valid_combined_loss
#             torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))
#             print(f"Best model saved at epoch {epoch+1} with combined loss {best_combined_loss:.4f}")
            
#         if epoch%10==0:
#             gif_path2 =os.path.join(save_dir,f"viz_epoch_{epoch}.gif")
#             gif_frames[0].save(gif_path2, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

            

#         # Append Validation Metrics to CSV
#         with open(csv_path, "a", newline="") as f:
#             writer = csv.writer(f)
#             writer.writerow([
#                 epoch + 1,
#                 epoch_train["seg"] / num_batches,
#                 epoch_train["depth"] / num_batches,
#                 epoch_train["combined"] / num_batches,
#                 epoch_train["depth_sidl"] / num_batches,
#                 epoch_train["depth_smooth"] / num_batches,
#                 epoch_train["iou"] / num_batches,
#                 epoch_valid["seg"] / num_valid_batches,
#                 epoch_valid["depth"] / num_valid_batches,
#                 epoch_valid["combined"] / num_valid_batches,
#                 epoch_valid["depth_sidl"] / num_valid_batches,
#                 epoch_valid["depth_smooth"] / num_valid_batches,
#                 epoch_valid["iou"] / num_valid_batches,
#             ])

#     # Save GIF
#     gif_frames[0].save(
#         gif_path, save_all=True, append_images=gif_frames[1:], duration=200, loop=0
#     )

#     return train_losses, valid_losses


In [28]:


# def train_model_with_loss_tracking_and_gif(
#     model, train_loader, valid_loader, num_epochs, device, save_dir="training_output_bicycle_and_pix2pix"):
#     # Create directory for saving models and outputs
#     os.makedirs(save_dir, exist_ok=True)
#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     save_dir = os.path.join(save_dir, timestamp)
#     os.makedirs(save_dir, exist_ok=True)

#     csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
#     gif_path = os.path.join(save_dir, f"training_visualization_{timestamp}.gif")

#     # Initialize CSV for saving loss data
#     with open(csv_path, "w", newline="") as f:
#         writer = csv.writer(f)
#         writer.writerow([
#             "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
#             "train_depth_sidl", "train_depth_inv_huber", "train_depth_contrastive", "train_depth_smooth",
#             "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
#             "valid_depth_sidl", "valid_depth_inv_huber", "valid_depth_contrastive", "valid_depth_smooth"
#         ])

#     best_combined_loss = float("inf")  # Initialize best combined loss for saving the best model

#     train_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     valid_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     # , "depth_inv_huber": [], "depth_contrastive": []

#     gif_frames = []
#     num_classes = 20
#     # Optimizer for latent noise
    

#     for epoch in range(num_epochs):
#         torch.cuda.empty_cache()
#         torch.autograd.set_detect_anomaly(True)

#         model.train()
#         # epoch_train_seg_loss = 0
#         # epoch_train_depth_loss = 0
#         # epoch_train_iou = 0
#         # epoch_train_combined_loss = 0
#         # epoch_train_depth_sidl = 0
#         # epoch_train_depth_inv_huber = 0
#         # epoch_train_depth_contrastive = 0
#         # epoch_train_depth_smooth = 0
#         epoch_train = {key: 0.0 for key in train_losses.keys()}
#         num_batches = 0

#         reconstruction_layer = nn.Conv2d(256, 3, kernel_size=1).to(device)
        
#         # scaler = torch.cuda.amp.GradScaler()

#         # Training Loop
#         for batch in train_loader:
#             inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)
#             print(inputs.shape,seg_labels.shape,depth_labels.shape) # torch.Size([8, 3, 200, 512]) torch.Size([8, 1, 200, 512]) torch.Size([8, 1, 1, 200, 512])
#             return
        
                       


#             # Forward pass
#             seg_output, depth_output, backbone_features = model(...)
            


#             seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels)
#             seg_dice = dice_loss(seg_output, seg_labels)
#             seg_iou = mean_iou(seg_output, seg_labels, num_classes)
#             seg_loss_total = 0.6 * seg_loss  + 0.4 * seg_dice
            
#             depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#             depth_inv_huber = inv_huber_loss(depth_output, depth_labels)
#             depth_smooth = depth_smoothness_loss(depth_output, inputs)
#             depth_loss_total = depth_sidl + depth_inv_huber + depth_smooth
            
           
           
#             # Combined Loss
#             total_loss = bicycle_loss + pix2pix_total_loss

#             # Single backward pass
#             total_loss.backward()

#             # Update both optimizers
#             model.optimizer_stage1.step()
#             model.optimizer_stage2.step()
#             latent_optimizer.step()

#             # Accumulate Training Metrics
#             epoch_train["seg"] += seg_loss.item()
#             epoch_train["depth"] += (depth_sidl + depth_smooth).item()
#             epoch_train["combined"] += total_loss.item()
#             epoch_train["iou"] += seg_iou.item()
#             epoch_train["depth_sidl"] += depth_sidl.item()
#             epoch_train["depth_smooth"] += depth_smooth.item()
#             num_batches += 1
            
#         model.scheduler_stage1.step()
#         model.scheduler_stage2.step(epoch_train["combined"]/num_batches)


#         # Average Training Losses
#         for key in epoch_train.keys():
#             train_losses[key].append(epoch_train[key] / num_batches)

#         print(
#             f"Epoch {epoch+1}/{num_epochs} - Train Seg Loss: {epoch_train['seg']:.4f}, "
#             f"Train Depth Loss: {epoch_train['depth']:.4f}, Train Combined Loss: {epoch_train['combined']:.4f}, "
#             f"Train mIOU: {epoch_train['iou']:.4f}, Train sidl Loss: {epoch_train['depth_sidl']:.4f}, "
#             f"Train depth smooth: {epoch_train['depth_smooth']:.4f}"
#     )       

#         # Validation Loop
#         model.eval()
#         epoch_valid = {key: 0.0 for key in valid_losses.keys()}
#         num_valid_batches = 0

#         with torch.no_grad():
#             for batch in valid_loader:
#                 # print("inside valid")
#                 inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)

#                 # Ensure depth_labels and segmentation labels have correct dimensions
                

               


#                 seg_output_old =seg_output
#                 # Resize seg_output to match the spatial dimensions of seg_labels
#                 seg_output_resized = F.interpolate(seg_output, size=seg_labels.shape[1:], mode='bilinear', align_corners=False)
#                 seg_output = seg_output_resized

#                 depth_output_old = depth_output
#                 depth_output_resized = F.interpolate(depth_output, size=depth_labels.shape[-2:], mode='bilinear', align_corners=False)
#                 depth_output =depth_output_resized


#                 # Segmentation Loss
#                 seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels)
#                 seg_dice = dice_loss(seg_output, seg_labels)
#                 seg_iou = mean_iou(seg_output, seg_labels, num_classes)
#                 seg_loss_total = 0.6 * seg_loss  + 0.4 * seg_dice
                
#                 depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#                 depth_inv_huber = inv_huber_loss(depth_output, depth_labels)
#                 depth_smooth = depth_smoothness_loss(depth_output, inputs)
#                 depth_loss_total = depth_sidl + depth_inv_huber + depth_smooth

#                 pix2pix_loss = seg_loss_total + depth_loss_total

#                 # Combined Validation Loss
#                 combined_loss = pix2pix_loss

#                 # Accumulate Validation Metrics
#                 epoch_valid["seg"] += seg_loss.item()
#                 epoch_valid["depth"] += (depth_sidl + depth_smooth).item()
#                 epoch_valid["combined"] += combined_loss.item()
#                 epoch_valid["iou"] += seg_iou.item()
#                 epoch_valid["depth_sidl"] += depth_sidl.item()
#                 epoch_valid["depth_smooth"] += depth_smooth.item()
                
#                 num_valid_batches += 1
                
#             frame = save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels)
#             gif_frames.append(frame)
                
                
#         # Calculate epoch averages
#         # Average Validation Losses
#         for key in epoch_valid.keys():
#             valid_losses[key].append(epoch_valid[key] / num_valid_batches)

#         print(
#             f"Epoch {epoch+1}/{num_epochs} - Valid Seg Loss: {epoch_valid['seg']:.4f}, "
#             f"Valid Depth Loss: {epoch_valid['depth']:.4f}, Valid Combined Loss: {epoch_valid['combined']:.4f}, "
#             f"Valid mIOU: {epoch_valid['iou']:.4f}, Valid sidl Loss: {epoch_valid['depth_sidl']:.4f}, "
#             f"Valid depth smooth: {epoch_valid['depth_smooth']:.4f}"
#         )

#         # Write the losses to CSV
#         with open(csv_path, "a", newline="") as f:
#             writer = csv.writer(f)
#             writer.writerow([
#                 epoch + 1,
#                 train_losses["seg"], train_losses["depth"], train_losses["combined"],
#                 train_losses["depth_sidl"], 0,0,
#                 # avg_train_depth_inv_huber, avg_train_depth_contrastive,
#                 train_losses["depth_smooth"],
#                 valid_losses["seg"], valid_losses["depth"], valid_losses['combined'],
#                 valid_losses["depth_sidl"],0,0,
#                 # avg_valid_depth_inv_huber, avg_valid_depth_contrastive, 
#                 valid_losses["depth_smooth"]
#             ])

       
#         # Save best model
#         if valid_losses["combined"][-1] < best_combined_loss:
#             best_combined_loss = valid_losses["combined"][-1]
#             torch.save(model, os.path.join(save_dir, "best_model_resnetBackbone.pth"))
#             print(f"Best model saved at epoch {epoch+1} with combined loss {best_combined_loss:.4f}")
            
#         if epoch%10==0:
#             gif_path2 =os.path.join(save_dir,f"viz_epoch_{epoch}.gif")
#             gif_frames[0].save(gif_path2, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

    
    
#     plot_loss(train_losses, valid_losses, save_dir)
#     gif_frames[0].save(gif_path, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

    
    
    
#     return train_losses,valid_losses


In [29]:

# # Create your model instance
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # device = 'cpu'
# model = MultiTaskModel(num_seg_classes=20, feature_channels=256).to(device)

In [30]:
# # Set the number of epochs
# num_epochs = 10

# # Call the training function
# train_losses, valid_losses = train_model_with_loss_tracking_and_gif(
#     model=model,
#     train_loader=train_loader,
#     valid_loader=valid_loader,
#     num_epochs=num_epochs,
#     device=device,
#     save_dir="test7_res"
# )

In [31]:


# def train_model_with_loss_tracking_and_gif(
#     model, train_loader, valid_loader, num_epochs, device, save_dir="training_output_bicycle_and_pix2pix"):
#     # Create directory for saving models and outputs
#     os.makedirs(save_dir, exist_ok=True)
#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     save_dir = os.path.join(save_dir, timestamp)
#     os.makedirs(save_dir, exist_ok=True)

#     csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
#     gif_path = os.path.join(save_dir, f"training_visualization_{timestamp}.gif")

#     # Initialize CSV for saving loss data
#     with open(csv_path, "w", newline="") as f:
#         writer = csv.writer(f)
#         writer.writerow([
#             "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
#             "train_depth_sidl", "train_depth_inv_huber", "train_depth_contrastive", "train_depth_smooth",
#             "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
#             "valid_depth_sidl", "valid_depth_inv_huber", "valid_depth_contrastive", "valid_depth_smooth"
#         ])

#     best_combined_loss = float("inf")  # Initialize best combined loss for saving the best model

#     train_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     valid_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     # , "depth_inv_huber": [], "depth_contrastive": []

#     gif_frames = []
#     num_classes = 20
#     # Optimizer for latent noise
    

#     for epoch in range(num_epochs):
#         torch.cuda.empty_cache()
#         torch.autograd.set_detect_anomaly(True)

#         model.train()
#         # epoch_train_seg_loss = 0
#         # epoch_train_depth_loss = 0
#         # epoch_train_iou = 0
#         # epoch_train_combined_loss = 0
#         # epoch_train_depth_sidl = 0
#         # epoch_train_depth_inv_huber = 0
#         # epoch_train_depth_contrastive = 0
#         # epoch_train_depth_smooth = 0
#         epoch_train = {key: 0.0 for key in train_losses.keys()}
#         num_batches = 0

#         reconstruction_layer = nn.Conv2d(256, 3, kernel_size=1).to(device)
        
#         # scaler = torch.cuda.amp.GradScaler()

#         # Training Loop
#         for batch in train_loader:
#             inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)

#             # Ensure depth_labels and segmentation labels have correct dimensions
#             if depth_labels.dim() == 5:
#                 depth_labels = depth_labels.squeeze(2)
#             if seg_labels.dim() == 4 and seg_labels.shape[1] == 1:
#                 seg_labels = seg_labels.squeeze(1)

#             # Transform depth labels
#             # depth_labels = torch.log(depth_labels.flatten(start_dim=1)) / 5
#             # depth_labels = depth_labels.view_as(depth_labels)  # Restore shape
#             # depth_labels = torch.clamp(depth_labels, min=1e-5) 
#             # depth_labels = torch.log(depth_labels + 1e-5) / 5  # Avoid log(0)

#             # print(f'seg_labels shape : {seg_labels.shape}')
#             # print(f'depth_labels shape: {depth_labels.shape}')

#             # Start with random noise as latent condition
#             if epoch == 0:
#                 latent_noise = torch.randn_like(inputs).to(device)
#                 # print(f"latent_noise: {latent_noise.shape}")
#                 latent_noise.requires_grad = True  # Make it trainable
#                 latent_optimizer = torch.optim.Adam([latent_noise], lr=1e-3)
            
            


#             # Stage 1: Train BicycleGAN (Backbone Features)
            

#             # Reset gradients for both optimizers
#             model.optimizer_stage1.zero_grad()
#             model.optimizer_stage2.zero_grad()
#             latent_optimizer.zero_grad()

#             # Forward pass
#             seg_output, depth_output, backbone_features = model(inputs, latent_noise)
#             # print(f'seg_ouput shape : {seg_output.shape}')
#             # print(f'depth_output shape: {depth_output.shape}')
#             # print(backbone_features.shape)
#             # return


#             seg_output_old =seg_output
#             # Resize seg_output to match the spatial dimensions of seg_labels
#             seg_output_resized = F.interpolate(seg_output, size=seg_labels.shape[1:], mode='bilinear', align_corners=False)
#             seg_output = seg_output_resized

#             # print(f"depth_output shape before resize: {depth_output.shape}")
#             # print(f"depth_labels shape: {depth_labels.shape}")
#             # return

#             depth_output_old = depth_output
#             depth_output_resized = F.interpolate(depth_output, size=depth_labels.shape[-2:], mode='bilinear', align_corners=False)
#             depth_output = depth_output_resized


#             # Pix2Pix Losses
#             seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels)
#             seg_dice = dice_loss(seg_output, seg_labels)
#             seg_iou = mean_iou(seg_output, seg_labels, num_classes)
#             seg_loss_total = 0.6 * seg_loss  + 0.4 * seg_dice
            
#             depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#             depth_inv_huber = inv_huber_loss(depth_output, depth_labels)
#             depth_smooth = depth_smoothness_loss(depth_output, inputs)
#             depth_loss_total = depth_sidl + depth_inv_huber + depth_smooth
            
#             pix2pix_loss = seg_loss_total + depth_loss_total

#             # Reconstruction loss
#             # inputs_resized = F.interpolate(inputs, size=(backbone_features.size(2), backbone_features.size(3)))
#             # reconstructed_image = reconstruction_layer(backbone_features)
#             # recon_loss = nn.L1Loss()(reconstructed_image, inputs_resized)
#             # adaptive_weight = 1 / (1 + torch.exp(-recon_loss))
#             # adaptive_weight_value = adaptive_weight.item() 

#             # loss_stage1 = nn.MSELoss()(real_validity, torch.ones_like(real_validity).to(device)) + recon_loss
#             # loss_stage1.backward(retain_graph=True)
#             # model.optimizer_stage1.step()

#             # Pix2Pix Adversarial Losses
#             seg_validity = model.segmentation_discriminator(seg_output)
#             depth_validity = model.depth_discriminator(depth_output)
#             adv_seg_loss = nn.MSELoss()(seg_validity, torch.ones_like(seg_validity))
#             adv_depth_loss = nn.MSELoss()(depth_validity, torch.ones_like(depth_validity))
#             pix2pix_total_loss = pix2pix_loss + adv_seg_loss + adv_depth_loss


#             # BicycleGAN Loss with Pix2Pix Condition
#             # real_validity = model.bicycle_discriminator(backbone_features)
#             # recon_loss = nn.L1Loss()(backbone_features, inputs)
#             # bicycle_loss = nn.MSELoss()(real_validity, torch.ones_like(real_validity)) + recon_loss
#             # conditional_bicycle_loss = bicycle_loss + pix2pix_loss
#             # conditional_bicycle_loss.backward(retain_graph=True)
#             # model.optimizer_stage1.step()

#             # BicycleGAN Loss with Pix2Pix Condition
#             real_validity = model.bicycle_discriminator(backbone_features,latent_noise)

#             # Resize inputs to match backbone_features
#             inputs_resized = F.interpolate(inputs, size=backbone_features.shape[-2:], mode='bilinear', align_corners=False)

#             # print(f"backbone_features shape: {backbone_features.shape}, inputs shape: {inputs.shape}")
#             # print(f"inputs_resized shape: {inputs_resized.shape}")
#             # recon_loss = nn.L1Loss()(backbone_features, inputs_resized)
#             bicycle_loss = adv_seg_loss + adv_depth_loss
#             # + recon_loss

#             # Combined Loss
#             total_loss = bicycle_loss + pix2pix_total_loss

#             # Single backward pass
#             total_loss.backward()

#             # Update both optimizers
#             model.optimizer_stage1.step()
#             model.optimizer_stage2.step()
#             latent_optimizer.step()

#             # Accumulate Training Metrics
#             epoch_train["seg"] += seg_loss.item()
#             epoch_train["depth"] += (depth_sidl + depth_smooth).item()
#             epoch_train["combined"] += total_loss.item()
#             epoch_train["iou"] += seg_iou.item()
#             epoch_train["depth_sidl"] += depth_sidl.item()
#             epoch_train["depth_smooth"] += depth_smooth.item()
#             num_batches += 1
            
#         model.scheduler_stage1.step()
#         model.scheduler_stage2.step(epoch_train["combined"]/num_batches)


#         # Average Training Losses
#         for key in epoch_train.keys():
#             train_losses[key].append(epoch_train[key] / num_batches)

#         print(
#             f"Epoch {epoch+1}/{num_epochs} - Train Seg Loss: {epoch_train['seg']:.4f}, "
#             f"Train Depth Loss: {epoch_train['depth']:.4f}, Train Combined Loss: {epoch_train['combined']:.4f}, "
#             f"Train mIOU: {epoch_train['iou']:.4f}, Train sidl Loss: {epoch_train['depth_sidl']:.4f}, "
#             f"Train depth smooth: {epoch_train['depth_smooth']:.4f}"
#     )       

#         # Validation Loop
#         model.eval()
#         # epoch_valid_seg_loss = 0
#         # epoch_valid_depth_loss = 0
#         # epoch_valid_iou =0
#         # epoch_valid_combined_loss = 0
#         # epoch_valid_depth_sidl = 0
#         # epoch_valid_depth_inv_huber = 0
#         # epoch_valid_depth_contrastive = 0
#         # epoch_valid_depth_smooth = 0
#         epoch_valid = {key: 0.0 for key in valid_losses.keys()}
#         num_valid_batches = 0

#         with torch.no_grad():
#             for batch in valid_loader:
#                 # print("inside valid")
#                 inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)

#                 # Ensure depth_labels and segmentation labels have correct dimensions
#                 if depth_labels.dim() == 5:
#                     depth_labels = depth_labels.squeeze(2)
#                 if seg_labels.dim() == 4 and seg_labels.shape[1] == 1:
#                     seg_labels = seg_labels.squeeze(1)

#                 # Transform depth labels
#                 # depth_labels = torch.log(depth_labels.flatten(start_dim=1)) / 5
#                 # depth_labels = depth_labels.view_as(depth_labels)  # Restore shape
#                 # depth_labels = torch.clamp(depth_labels, min=1e-5) 
#                 # depth_labels = torch.log(depth_labels + 1e-5) / 5  # Avoid log(0)

#                 # Latent noise for validation
#                 latent_noise = torch.randn_like(inputs).to(device)
#                 seg_output, depth_output, backbone_features = model(inputs, latent_noise)

                
                

#                 seg_output_old =seg_output
#                 # Resize seg_output to match the spatial dimensions of seg_labels
#                 seg_output_resized = F.interpolate(seg_output, size=seg_labels.shape[1:], mode='bilinear', align_corners=False)
#                 seg_output = seg_output_resized

#                 depth_output_old = depth_output
#                 depth_output_resized = F.interpolate(depth_output, size=depth_labels.shape[-2:], mode='bilinear', align_corners=False)
#                 depth_output =depth_output_resized


#                 # Segmentation Loss
#                 seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels)
#                 seg_dice = dice_loss(seg_output, seg_labels)
#                 seg_iou = mean_iou(seg_output, seg_labels, num_classes)
#                 seg_loss_total = 0.6 * seg_loss  + 0.4 * seg_dice
                
#                 depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#                 depth_inv_huber = inv_huber_loss(depth_output, depth_labels)
#                 depth_smooth = depth_smoothness_loss(depth_output, inputs)
#                 depth_loss_total = depth_sidl + depth_inv_huber + depth_smooth

#                 pix2pix_loss = seg_loss_total + depth_loss_total

#                 # Combined Validation Loss
#                 combined_loss = pix2pix_loss

#                 # Accumulate Validation Metrics
#                 epoch_valid["seg"] += seg_loss.item()
#                 epoch_valid["depth"] += (depth_sidl + depth_smooth).item()
#                 epoch_valid["combined"] += combined_loss.item()
#                 epoch_valid["iou"] += seg_iou.item()
#                 epoch_valid["depth_sidl"] += depth_sidl.item()
#                 epoch_valid["depth_smooth"] += depth_smooth.item()
                
#                 num_valid_batches += 1
                
#                 # epoch, inputs, seg_output, depth_output, seg_labels, depth_labels, gif_frames
#             frame = save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels)
#             gif_frames.append(frame)
                
                
#         # Calculate epoch averages
#         # Average Validation Losses
#         for key in epoch_valid.keys():
#             valid_losses[key].append(epoch_valid[key] / num_valid_batches)

        
        
# # train_losses = { "depth_sidl": [], "depth_inv_huber": [], "depth_contrastive": [], "depth_smooth": []}
#         print(
#             f"Epoch {epoch+1}/{num_epochs} - Valid Seg Loss: {epoch_valid['seg']:.4f}, "
#             f"Valid Depth Loss: {epoch_valid['depth']:.4f}, Valid Combined Loss: {epoch_valid['combined']:.4f}, "
#             f"Valid mIOU: {epoch_valid['iou']:.4f}, Valid sidl Loss: {epoch_valid['depth_sidl']:.4f}, "
#             f"Valid depth smooth: {epoch_valid['depth_smooth']:.4f}"
#         )

#         # Write the losses to CSV
#         with open(csv_path, "a", newline="") as f:
#             writer = csv.writer(f)
#             writer.writerow([
#                 epoch + 1,
#                 train_losses["seg"], train_losses["depth"], train_losses["combined"],
#                 train_losses["depth_sidl"], 0,0,
#                 # avg_train_depth_inv_huber, avg_train_depth_contrastive,
#                 train_losses["depth_smooth"],
#                 valid_losses["seg"], valid_losses["depth"], valid_losses['combined'],
#                 valid_losses["depth_sidl"],0,0,
#                 # avg_valid_depth_inv_huber, avg_valid_depth_contrastive, 
#                 valid_losses["depth_smooth"]
#             ])

#         # Save GIF visualization frames
#         # save_training_visualization_as_gif(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels, gif_frames)

#         # Save best model
#         if valid_losses["combined"][-1] < best_combined_loss:
#             best_combined_loss = valid_losses["combined"][-1]
#             torch.save(model, os.path.join(save_dir, "best_model_resnetBackbone.pth"))
#             print(f"Best model saved at epoch {epoch+1} with combined loss {best_combined_loss:.4f}")
            
#         if epoch%10==0:
#             gif_path2 =os.path.join(save_dir,f"viz_epoch_{epoch}.gif")
#             gif_frames[0].save(gif_path2, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

#     # Save GIF
#     # gif_frames[0].save(gif_path, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)
    
#     plot_loss(train_losses, valid_losses, save_dir)
#     gif_frames[0].save(gif_path, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

    
    
    
#     return train_losses,valid_losses
