In [226]:
import os
from glob import glob
import time

import numpy as np
import h5py
import cv2

import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset,Subset
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F

import matplotlib as mpl
import matplotlib.pyplot as plt

from torchvision.models import mobilenet_v3_small
from sklearn.metrics import jaccard_score
from torchvision.models.mobilenetv3 import MobileNet_V3_Small_Weights
from torchvision.models import vgg16, VGG16_Weights



import csv

from datetime import datetime

from PIL import Image
import torch.nn.functional as F


from labels import labels

%matplotlib inline

curr_dir=os.getcwd()
root= os.path.join(curr_dir,"cityscapes_dataset")
curr_dir,root

('/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes',
 '/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/cityscapes_dataset')

In [227]:

import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
# from albumentations.augmentations.transforms import RandomShadow

class Normalize(object):
    """ Normalizes RGB image to  0-mean 1-std_dev """ 
    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], depth_norm=5, max_depth=250):
        self.mean = mean
        self.std = std
        self.depth_norm = depth_norm
        self.max_depth = max_depth

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']
            
        return {'left': TF.normalize(left, self.mean, self.std), 
                'mask': mask, 
                'depth' : torch.clip( # saftey clip :)
                            torch.log(torch.clip(depth, 0, self.max_depth))/self.depth_norm, 
                            0, 
                            self.max_depth)}


class AddColorJitter(object):
    """Convert a color image to grayscale and normalize the color range to [0,1].""" 
    def __init__(self, brightness, contrast, saturation, hue):
        ''' Applies brightness, constrast, saturation, and hue jitter to image ''' 
        self.color_jitter = transforms.ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']

        return {'left': self.color_jitter(left), 
                'mask': mask, 
                'depth' : depth}


class Rescale(object):
    """ Rescales images with bilinear interpolation and masks with nearest interpolation """

    def __init__(self, h, w):
        self.h, self.w = h, w

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']
# mask interpolation Nearest is import to have smoothness
        return {'left': TF.resize(left, (self.h, self.w)), 
                'mask': TF.resize(mask.unsqueeze(0), (self.h, self.w), transforms.InterpolationMode.NEAREST), 
                'depth' : TF.resize(depth.unsqueeze(0), (self.h, self.w))}


class RandomCrop(object):
    def __init__(self, h, w, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)):
        self.h = h
        self.w = w
        self.scale = scale
        self.ratio = ratio

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']
        i, j, h, w = transforms.RandomResizedCrop.get_params(left, scale=self.scale, ratio=self.ratio)

        return {'left': TF.resized_crop(left, i, j, h, w, (self.h, self.w)), 
                'mask': TF.resized_crop(mask.unsqueeze(0), i, j, h, w, (self.h, self.w), interpolation=TF.InterpolationMode.NEAREST),
                'depth' : TF.resized_crop(depth.unsqueeze(0), i, j, h, w, (self.h, self.w))}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""
    def __call__(self, sample):
         
        left, mask, depth = sample['left'], sample['mask'], sample['depth']

        return {'left': transforms.ToTensor()(left), 
                'mask': torch.as_tensor(mask, dtype=torch.int64),
                'depth' : transforms.ToTensor()(depth).type(torch.float32)}
    

class ElasticTransform(object):
    def __init__(self, alpha=25.0, sigma=5.0, prob=0.5):
        self.alpha = [1.0, alpha]
        self.sigma = [1, sigma]
        self.prob = prob

    def __call__(self, sample):
        
        if torch.rand(1) < self.prob:

            left, mask, depth = sample['left'], sample['mask'], sample['depth']
            _, H, W = mask.shape
            displacement = transforms.ElasticTransform.get_params(self.alpha, self.sigma, [H, W])

            # # TEMP
            # print(TF.elastic_transform(left, displacement).shape)
            # print(TF.elastic_transform(mask.unsqueeze(0), displacement, interpolation=TF.InterpolationMode.NEAREST).shape)
            # print(torch.clip(TF.elastic_transform(depth, displacement), 0, depth.max()).shape)

            return {'left': TF.elastic_transform(left, displacement), 
                    'mask': TF.elastic_transform(mask.unsqueeze(0), displacement, interpolation=TF.InterpolationMode.NEAREST), 
                    'depth' : torch.clip(TF.elastic_transform(depth, displacement), 0, depth.max())} 
        
        else:
            return sample

        
    

# new transform to rotate the images
class RandomRotate(object):
    def __init__(self, angle):
        if not isinstance(angle, (list, tuple)):
            self.angle = (-abs(angle), abs(angle))
        else:
            self.angle = angle

    def __call__(self, sample):
        left, mask, depth = sample['left'], sample['mask'], sample['depth']

        angle = transforms.RandomRotation.get_params(self.angle)

        return {'left': TF.rotate(left, angle), 
                'mask': TF.rotate(mask.unsqueeze(0), angle), 
                'depth' : TF.rotate(depth, angle)}
    
    
class RandomHorizontalFlip(object):
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, sample):
        
        if torch.rand(1) < self.prob:
            left, mask, depth = sample['left'], sample['mask'], sample['depth']
            return {'left': TF.hflip(left), 
                    'mask': TF.hflip(mask), 
                    'depth' : TF.hflip(depth)}
        else:
            return sample
        

class RandomVerticalFlip(object):
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, sample):
        if torch.rand(1) < self.prob:
            left, mask, depth = sample['left'], sample['mask'], sample['depth']
            return {'left': TF.vflip(left), 
                    'mask': TF.vflip(mask), 
                    'depth' : TF.vflip(depth)}
        else:
            return sample
        

In [228]:

def convert_to_numpy(image):
    if not isinstance(image, np.ndarray):
        if len(image.shape) == 2:
            image = image.detach().cpu().numpy()
        else:
            image = image.detach().cpu().numpy().transpose(1, 2, 0)

    return image

def get_color_mask(mask, labels, id_type='id'):
    try:
        h, w = mask.shape
    except ValueError:
        mask = mask.squeeze(-1)
        h, w = mask.shape

    color_mask = np.zeros((h, w, 3), dtype=np.uint8)

    if id_type == 'id':
        for lbl in labels:
            color_mask[mask == lbl.id] = lbl.color
    elif id_type == 'trainId':
        for lbl in labels:
            if (lbl.trainId != 255) | (lbl.trainId != -1):
                color_mask[mask == lbl.trainId] = lbl.color

    return color_mask


def plot_items(left, mask, depth, labels=None, num_seg_labels=34, id_type='id'):
    left = convert_to_numpy(left)
    mask = convert_to_numpy(mask)
    depth = convert_to_numpy(depth)

    # unnormalize left image
    left = (left*np.array([0.229, 0.224, 0.225])) + np.array([0.485, 0.456, 0.406])

    # cmaps: 'prism', 'terrain', 'turbo', 'gist_rainbow_r', 'nipy_spectral_r'
    
    
    _, ax = plt.subplots(1, 3, figsize=(15,10))
    ax[0].imshow(left)
    ax[0].set_title("Left Image")

    if labels:
        color_mask = get_color_mask(mask, labels, id_type)
        ax[1].imshow(color_mask)
    else:
        cmap = mpl.colormaps.get_cmap('nipy_spectral_r').resampled(num_seg_labels)
        ax[1].imshow(mask, cmap=cmap)

    ax[1].set_title("Seg Mask")
    ax[2].imshow(depth, cmap='plasma')
    ax[2].set_title("Depth")

In [229]:
def scale_invariant_depth_loss(pred, target, lambda_weight=0.1):
    if pred.shape != target.shape:
        pred = F.interpolate(pred, size=target.shape[1:], mode='bilinear', align_corners=False)
    
    diff = pred - target
    n = diff.numel()
    mse = torch.sum(diff**2) / n
    scale_invariant = mse - (lambda_weight / (n**2)) * (torch.sum(diff))**2
    return scale_invariant

def depth_smoothness_loss(pred, img, alpha=1.0):
    depth_grad_x = torch.abs(pred[:, :, :, :-1] - pred[:, :, :, 1:])
    depth_grad_y = torch.abs(pred[:, :, :-1, :] - pred[:, :, 1:, :])
    img_grad_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), dim=1, keepdim=True)
    img_grad_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), dim=1, keepdim=True)
    smoothness_x = depth_grad_x * torch.exp(-alpha * img_grad_x)
    smoothness_y = depth_grad_y * torch.exp(-alpha * img_grad_y)
    return smoothness_x.mean() + smoothness_y.mean()


def inv_huber_loss(pred, target, delta=0.1):
    """
    Inverse Huber loss for depth prediction.
    Args:
        pred (Tensor): Predicted depth map.
        target (Tensor): Ground truth depth map.
        delta (float): Threshold for switching between quadratic and linear terms.
    Returns:
        Tensor: Inverse Huber loss.
    """
    abs_diff = torch.abs(pred - target)
    delta_tensor = torch.tensor(delta, dtype=abs_diff.dtype, device=abs_diff.device)  # Convert delta to tensor
    quadratic = torch.minimum(abs_diff, delta_tensor)
    linear = abs_diff - quadratic
    return (0.5 * quadratic**2 + delta_tensor * linear).mean()


def mean_iou(pred, target, num_classes):
    pred = torch.argmax(pred, dim=1)
    intersection = torch.logical_and(pred == target, target != 255).float()  # Ignore class 255
    union = torch.logical_or(pred == target, target != 255).float()
    iou = torch.sum(intersection) / torch.sum(union)
    return iou



def contrastive_loss(pred, target, margin=1.0):
    """
    Contrastive loss to ensure the depth map predictions are closer to the target.
    """
    # Flatten the tensors for element-wise operations
    pred_flat = pred.view(pred.size(0), -1)  # Flatten except for the batch dimension
    target_flat = target.view(target.size(0), -1)  # Flatten except for the batch dimension

    # Compute the pairwise distances
    distances = torch.sqrt(torch.sum((pred_flat - target_flat) ** 2, dim=1))  # Batch-wise distances

    # Create labels for contrastive loss
    labels = (torch.abs(pred_flat - target_flat).mean(dim=1) < margin).float()  # Batch-wise labels

    # Calculate contrastive loss
    similar_loss = labels * distances**2
    dissimilar_loss = (1 - labels) * torch.clamp(margin - distances, min=0)**2
    loss = (similar_loss + dissimilar_loss).mean()

    return loss


def dice_loss(predictions, targets, smooth=1e-6):
    """
    Calculate Dice Loss for segmentation.
    Args:
        predictions (torch.Tensor): The predicted segmentation map (logits or probabilities).
                                    Shape: [batch_size, num_classes, height, width]
        targets (torch.Tensor): The ground truth segmentation map (one-hot encoded or integer labels).
                                Shape: [batch_size, height, width]
        smooth (float): Smoothing factor to avoid division by zero.
    Returns:
        torch.Tensor: Dice Loss (scalar).
    """
    # Convert integer labels to one-hot if needed
    if predictions.shape != targets.shape:
        targets = F.one_hot(targets, num_classes=predictions.shape[1]).permute(0, 3, 1, 2).float()
    
    # Apply softmax to predictions for multi-class segmentation
    predictions = torch.softmax(predictions, dim=1)
    
    # Flatten tensors to calculate intersection and union
    predictions_flat = predictions.view(predictions.shape[0], predictions.shape[1], -1)
    targets_flat = targets.view(targets.shape[0], targets.shape[1], -1)
    
    # Calculate intersection and union
    intersection = (predictions_flat * targets_flat).sum(dim=-1)
    union = predictions_flat.sum(dim=-1) + targets_flat.sum(dim=-1)
    
    # Calculate Dice Coefficient
    dice_coeff = (2 * intersection + smooth) / (union + smooth)
    
    # Dice Loss
    return 1 - dice_coeff.mean()

In [230]:
def plot_loss(train_losses, valid_losses, save_dir):
    epochs = range(1, len(train_losses["seg"]) + 1)

    # Plot Segmentation Loss
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses["seg"], label="Train Seg Loss")
    plt.plot(epochs, valid_losses["seg"], label="Valid Seg Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Segmentation Loss")
    plt.legend()
    plt.title("Segmentation Loss Over Epochs")
    plt.savefig(os.path.join(save_dir, "segmentation_loss.png"))
    plt.close()

    # Plot Depth Loss
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses["depth"], label="Train Depth Loss")
    plt.plot(epochs, valid_losses["depth"], label="Valid Depth Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Depth Loss")
    plt.legend()
    plt.title("Depth Loss Over Epochs")
    plt.savefig(os.path.join(save_dir, "depth_loss.png"))
    plt.close()

    # Plot Combined Loss
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses["combined"], label="Train Combined Loss")
    plt.plot(epochs, valid_losses["combined"], label="Valid Combined Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Combined Loss")
    plt.legend()
    plt.title("Combined Loss Over Epochs")
    plt.savefig(os.path.join(save_dir, "combined_loss.png"))
    plt.close()


In [231]:
def save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels):
    inputs = inputs.detach().cpu()
    seg_output = torch.argmax(seg_output, dim=1).detach().cpu()
    depth_output = depth_output.detach().cpu()
    seg_labels = seg_labels.detach().cpu()
    depth_labels = depth_labels.detach().cpu()
    
#     inputs_rgb = (inputs - inputs.min()) / (inputs.max() - inputs.min() + 1e-5)  # Normalize inputs to [0, 1]
    
#     # Normalize depth maps for visualization
#     depth_labels_vis = (depth_labels - depth_labels.min()) / (depth_labels.max() - depth_labels.min() + 1e-5)
#     depth_preds_vis = (depth_output - depth_output.min()) / (depth_output.max() - depth_output.min() + 1e-5)



    batch_size = min(4, inputs.size(0))  # Limit to 4 samples for visualization
    fig, axes = plt.subplots(batch_size, 5, figsize=(15, 4 * batch_size))

    for i in range(batch_size):
        
        inputs_temp = inputs[i]
        # print(f"inputs_temp: {inputs_temp.shape}")
        inputs_rgb = (inputs_temp - inputs_temp.min()) / (inputs_temp.max() - inputs_temp.min() + 1e-5)  # Normalize inputs to [0, 1]
        
        depth_labels_vis = (depth_labels[i] - depth_labels[i].min()) / (depth_labels[i].max() - depth_labels[i].min() + 1e-5)
        depth_preds = depth_output[i]
        depth_preds_vis = (depth_preds - depth_preds.min()) / (depth_preds.max() - depth_preds.min() + 1e-5)
        # print(f"depth_labels_vis: {depth_labels_vis.shape}")
        # print(f"depth_preds_vis: {depth_preds_vis.shape}")

    
        
        # Row 1: Ground truth
        axes[i, 0].imshow(inputs_rgb.permute(1, 2, 0))
        axes[i, 0].set_title("RGB Image")
        axes[i, 0].axis("off")

        axes[i, 1].imshow(seg_labels[i], cmap="tab20")
        axes[i, 1].set_title("GT Segmentation")
        axes[i, 1].axis("off")

        axes[i, 2].imshow(depth_labels_vis.squeeze(), cmap="inferno")
        axes[i, 2].set_title("GT Depth")
        axes[i, 2].axis("off")

        # Row 2: Predictions
        axes[i, 3].imshow(seg_output[i], cmap="tab20")
        axes[i, 3].set_title("Generated Segmentation")
        axes[i, 3].axis("off")

        axes[i, 4].imshow(depth_preds_vis.squeeze(), cmap="inferno")
        axes[i, 4].set_title("Generated Depth")
        axes[i, 4].axis("off")
        
    # Remove axes for cleaner visualization
    for ax in axes.flat:
        ax.axis("off")


    # plt.tight_layout()
    fig.tight_layout()
    fig.canvas.draw()
    
    # # Save current epoch as an image for GIF
    # epoch_img_path = os.path.join(gif_path, f"epoch_{epoch}.png")
    # os.makedirs(gif_path, exist_ok=True)
    # plt.savefig(epoch_img_path)
    # plt.close()
    
    
    # return epoch_img_path
    frame = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)  # Updated to buffer_rgba
    frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (4,))  # RGBA has 4 channels
    plt.close(fig)

    # Convert to PIL.Image for GIF
    frame_rgb = frame[:, :, :3] 

    # Return as PIL.Image for GIF creation
    # return Image.fromarray(frame)
    return Image.fromarray(frame_rgb)




In [232]:
import torchvision.models as models

# Define the ResBlock
class ResBlock(nn.Module):
    def __init__(self, channels):
        super(ResBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        return x + self.conv_block(x)

# Define the CRPBlock
class CRPBlock(nn.Module):
    def __init__(self, in_chans, out_chans, n_stages=4, groups=False):
        super(CRPBlock, self).__init__()
        self.n_stages = n_stages
        groups = in_chans if groups else 1
        self.mini_blocks = nn.ModuleList()
        for _ in range(n_stages):
            self.mini_blocks.append(nn.MaxPool2d(kernel_size=5, stride=1, padding=2))
            self.mini_blocks.append(nn.Conv2d(in_chans, out_chans, kernel_size=1, bias=False, groups=groups))
    
    def forward(self, x):
        out = x
        for block in self.mini_blocks:
            out = block(out)
            x = x + out
        return x

class ResNetBackbone(nn.Module):
    def __init__(self, pretrained=True, feature_dim=256):
        super(ResNetBackbone, self).__init__()
        base_model = models.resnet18(pretrained=pretrained)

        # Freeze pre-trained layers
        for param in base_model.parameters():
            param.requires_grad = False

        # Extract ResNet layers and modify strides/pooling to preserve spatial dimensions
        layers = list(base_model.children())[:-2]  # Remove FC and AvgPool layers
        for layer in layers:
            if isinstance(layer, nn.Conv2d):
                layer.stride = (1, 1)  # Set stride to 1
            elif isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
                layer.stride = (1, 1)  # Avoid reducing dimensions with pooling layers

        self.features = nn.Sequential(*layers)

        # Adjust final feature dimension using a 1x1 convolution
        self.feature_dim = feature_dim
        self.adjust_channels = nn.Conv2d(base_model.fc.in_features, feature_dim, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.features(x)  # Extract features without changing spatial dimensions
        x = self.adjust_channels(x)  # Adjust feature channels
        return x



In [233]:

class CityScapesDataset(Dataset):
    def __init__(self, root, transform=None, split='train', label_map='id', crop=True):
        """
        
        """
        self.root = root
        self.transform = transform
        self.label_map = label_map
        self.crop = crop

        self.left_paths = glob(os.path.join(root, 'leftImg8bit', split, '**/*.png'))
        self.mask_paths = glob(os.path.join(root, 'gtFine', split, '**/*gtFine_labelIds.png'))
        self.depth_paths = glob(os.path.join(root, 'crestereo_depth2', split, '**/*.npy'))

        sorted(self.left_paths)
        sorted(self.mask_paths)
        sorted(self.depth_paths)

        # get label mappings
        self.id_2_train = {}
        self.id_2_cat = {}
        self.train_2_id = {}
        self.id_2_name = {-1 : 'unlabeled'}
        self.trainid_2_name = {19 : 'unlabeled'} # {255 : 'unlabeled', -1 : 'unlabeled'}

        for lbl in labels:
            self.id_2_train.update({lbl.id : lbl.trainId})
            self.id_2_cat.update({lbl.id : lbl.categoryId})
            if lbl.trainId != 19: # (lbl.trainId > 0) and (lbl.trainId != 255):
                self.trainid_2_name.update({lbl.trainId : lbl.name})
                self.train_2_id.update({lbl.trainId : lbl.id})
            if lbl.id > 0:
                self.id_2_name.update({lbl.id : lbl.name})


    def __getitem__(self, idx):
        left = cv2.cvtColor(cv2.imread(self.left_paths[idx]), cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_UNCHANGED).astype(np.uint8)
        depth = np.load(self.depth_paths[idx]) # data is type float16

        if self.crop:
            left = left[:800, :, :]
            mask = mask[:800, :]
            depth = depth[:800, :]

        # get label id
        if self.label_map == 'id':
            mask[mask==-1] == 0
        elif self.label_map == 'trainId':
            for _id, train_id in self.id_2_train.items():
                mask[mask==_id] = train_id
        elif self.label_map == 'categoryId':
            for _id, train_id in self.id_2_cat.items():
                mask[mask==_id] = train_id

        sample = {'left' : left, 'mask' : mask, 'depth' : depth}

        if self.transform:
            sample = self.transform(sample)

        # ensure that no depth values are less than 0
        depth[depth < 0] = 0

        return sample
    

    def __len__(self):
        print(f"Number of RGB images: {len(self.left_paths)}")
        print(f"Number of Mask images: {len(self.mask_paths)}")
        print(f"Number of depth images: {len(self.depth_paths)}")
        return len(self.left_paths)
    
    

In [236]:
OG_W, OG_H = 2048, 800 # OG width and height after crop
W, H = OG_W//4, OG_H//4 # resize w,h for training

transform = transforms.Compose([
    ToTensor(),
    RandomCrop(H, W),
    # ElasticTransform(alpha=100.0, sigma=25.0, prob=0.5),
    AddColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
    RandomHorizontalFlip(0.5),
    RandomVerticalFlip(0.2),
    # RandomRotate((-30, 30)),
    Normalize()
])

valid_transform = transforms.Compose([
    ToTensor(),
    Rescale(H, W),
    Normalize()
])

test_transform = transforms.Compose([
    ToTensor(),
    Normalize()
])


BATCH_SIZE = 8
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'



train_dataset = CityScapesDataset(root, transform=transform, split='train', label_map='trainId') # 'trainId')
train_subset = Subset(train_dataset, indices=range(2968)) #2968
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=True)
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=True)


valid_dataset = CityScapesDataset(root, transform=valid_transform, split='val', label_map='trainId')
val_subset = Subset(valid_dataset, indices=range(496)) #496 
# valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=False)
valid_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=False)


# shared Generator

In [237]:
# class SharedGenerator(nn.Module):
#     def __init__(self):
#         """
#         Shared Generator for both tasks.
#         Contains shared layers for skip connection processing and refinement.
#         """
#         super(SharedGenerator, self).__init__()
#         # Shared convolution layers to process each skip connection
#         self.shared_conv1 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l11_out (1/32)
#         self.shared_conv2 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l7_out (1/16)
#         self.shared_conv3 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l3_out (1/8)
#         self.shared_conv4 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l1_out (1/4)

#         # Shared CRP blocks for refinement
#         self.shared_crp1 = CRPBlock(256, 256, n_stages=4)  # CRP for 1/32
#         self.shared_crp2 = CRPBlock(256, 256, n_stages=4)  # CRP for 1/16
#         self.shared_crp3 = CRPBlock(256, 256, n_stages=4)  # CRP for 1/8
#         self.shared_crp4 = CRPBlock(256, 256, n_stages=4)  # CRP for 1/4

#     def forward(self, skips):
#         """
#         Process skips with shared layers for task-specific generation.
#         Args:
#             skips (dict): Skip connections from the encoder.
#         Returns:
#             dict: Processed skip connections.
#         """
#         x1 = self.shared_crp1(self.shared_conv1(skips["l11_out"]))
#         x2 = self.shared_crp2(self.shared_conv2(skips["l7_out"]))
#         x3 = self.shared_crp3(self.shared_conv3(skips["l3_out"]))
#         x4 = self.shared_crp4(self.shared_conv4(skips["l1_out"]))

#         return {"x1": x1, "x2": x2, "x3": x3, "x4": x4}


In [238]:
# class SharedPix2PixGenerator(nn.Module):
#     def __init__(self, seg_output_channels=20, depth_output_channels=1):
#         """
#         Shared Pix2Pix Generator for Segmentation and Depth tasks.
#         Args:
#             seg_output_channels (int): Number of output channels for segmentation.
#             depth_output_channels (int): Number of output channels for depth estimation.
#         """
#         super(SharedPix2PixGenerator, self).__init__()
#         self.shared_generator = SharedGenerator()
#         self.seg_output_layer = TaskOutputLayer(output_channels=seg_output_channels)
#         self.depth_output_layer = TaskOutputLayer(output_channels=depth_output_channels)

#     def forward(self, skips, input_size):
#         """
#         Forward pass for both tasks.
#         Args:
#             skips (dict): Skip connections from the encoder.
#             input_size (tuple): Original input size (H, W).
#         Returns:
#             dict: Outputs for segmentation and depth tasks.
#         """
#         shared_features = self.shared_generator(skips)

#         # Task-specific outputs
#         seg_output = self.seg_output_layer(shared_features["x4"], input_size)
#         depth_output = self.depth_output_layer(shared_features["x4"], input_size)

#         return {
#             "seg_output": seg_output,
#             "depth_output": depth_output
#         }


# Saving batch gif code And function to plot al losses


In [239]:
def save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels):
    inputs = inputs.detach().cpu()
    seg_output = torch.argmax(seg_output, dim=1).detach().cpu()
    depth_output = depth_output.detach().cpu()
    seg_labels = seg_labels.detach().cpu()
    depth_labels = depth_labels.detach().cpu()
    
#     inputs_rgb = (inputs - inputs.min()) / (inputs.max() - inputs.min() + 1e-5)  # Normalize inputs to [0, 1]
    
#     # Normalize depth maps for visualization
#     depth_labels_vis = (depth_labels - depth_labels.min()) / (depth_labels.max() - depth_labels.min() + 1e-5)
#     depth_preds_vis = (depth_output - depth_output.min()) / (depth_output.max() - depth_output.min() + 1e-5)



    batch_size = min(4, inputs.size(0))  # Limit to 4 samples for visualization
    fig, axes = plt.subplots(batch_size, 5, figsize=(15, 4 * batch_size))

    for i in range(batch_size):
        
        inputs_temp = inputs[i]
        # print(f"inputs_temp: {inputs_temp.shape}")
        inputs_rgb = (inputs_temp - inputs_temp.min()) / (inputs_temp.max() - inputs_temp.min() + 1e-5)  # Normalize inputs to [0, 1]
        
        depth_labels_vis = (depth_labels[i] - depth_labels[i].min()) / (depth_labels[i].max() - depth_labels[i].min() + 1e-5)
        depth_preds = depth_output[i]
        depth_preds_vis = (depth_preds - depth_preds.min()) / (depth_preds.max() - depth_preds.min() + 1e-5)
        # print(f"depth_labels_vis: {depth_labels_vis.shape}")
        # print(f"depth_preds_vis: {depth_preds_vis.shape}")

    
        
        # Row 1: Ground truth
        axes[i, 0].imshow(inputs_rgb.permute(1, 2, 0))
        axes[i, 0].set_title("RGB Image")
        axes[i, 0].axis("off")

        axes[i, 1].imshow(seg_labels[i], cmap="tab20")
        axes[i, 1].set_title("GT Segmentation")
        axes[i, 1].axis("off")

        axes[i, 2].imshow(depth_labels_vis.squeeze(), cmap="inferno")
        axes[i, 2].set_title("GT Depth")
        axes[i, 2].axis("off")

        # Row 2: Predictions
        axes[i, 3].imshow(seg_output[i], cmap="tab20")
        axes[i, 3].set_title("Generated Segmentation")
        axes[i, 3].axis("off")

        axes[i, 4].imshow(depth_preds_vis.squeeze(), cmap="inferno")
        axes[i, 4].set_title("Generated Depth")
        axes[i, 4].axis("off")
        
    # Remove axes for cleaner visualization
    for ax in axes.flat:
        ax.axis("off")


    # plt.tight_layout()
    fig.tight_layout()
    fig.canvas.draw()
    
    # # Save current epoch as an image for GIF
    # epoch_img_path = os.path.join(gif_path, f"epoch_{epoch}.png")
    # os.makedirs(gif_path, exist_ok=True)
    # plt.savefig(epoch_img_path)
    # plt.close()
    
    
    # return epoch_img_path
    frame = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)  # Updated to buffer_rgba
    frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (4,))  # RGBA has 4 channels
    plt.close(fig)

    # Convert to PIL.Image for GIF
    frame_rgb = frame[:, :, :3] 

    # Return as PIL.Image for GIF creation
    # return Image.fromarray(frame)
    return Image.fromarray(frame_rgb)

def plot_all_losses(train_losses,valid_losses,save_dir):
    # Plot training and validation losses
    for key in train_losses.keys():
        plt.figure()
        plt.plot(train_losses[key], label=f"Train {key}")
        plt.plot(valid_losses[key], label=f"Valid {key}")
        plt.xlabel("Epoch")
        plt.ylabel(key.replace("_", " ").title())
        plt.legend()
        plt.savefig(os.path.join(save_dir, f"{key}_loss.png"))
        plt.close()


# Loss Function

In [240]:
def compute_adversarial_losses(
    discriminator_model,  # Pass the actual model
    discriminator_real, discriminator_fake, hinge_loss=True, gradient_penalty=True, real_inputs=None
):
    """
    Computes adversarial losses for WGAN-GP and Hinge Loss.
    """
    # Generator Loss
    generator_loss = -torch.mean(discriminator_fake)

    # Discriminator Loss
    if hinge_loss:
        discriminator_real_loss = torch.mean(torch.relu(1.0 - discriminator_real))
        discriminator_fake_loss = torch.mean(torch.relu(1.0 + discriminator_fake))
        discriminator_loss = discriminator_real_loss + discriminator_fake_loss
    else:
        discriminator_loss = torch.mean(discriminator_fake) - torch.mean(discriminator_real)

    # Gradient Penalty (WGAN-GP)
    if gradient_penalty and not hinge_loss:
        alpha = torch.rand(real_inputs.size(0), 1, 1, 1, device=real_inputs.device)
        interpolates = alpha * real_inputs + (1 - alpha) * real_inputs.detach()
        interpolates.requires_grad_(True)
        # interpolates_output = discriminator_real(interpolates)
        
        # Use the actual discriminator model for gradient penalty computation
        interpolates_output = discriminator_model(interpolates)

        gradients = torch.autograd.grad(
            outputs=interpolates_output,
            inputs=interpolates,
            grad_outputs=torch.ones_like(interpolates_output),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        # gradient_penalty_term = torch.mean((gradients.view(gradients.size(0), -1).norm(2, dim=1) - 1) ** 2)
        gradient_penalty_term = torch.mean((gradients.flatten(start_dim=1).norm(2, dim=1) - 1) ** 2)

        discriminator_loss += 10.0 * gradient_penalty_term

    return generator_loss, discriminator_loss


In [241]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class PerceptualLoss(nn.Module):
    def __init__(self, pretrained_model="vgg16", layers=["relu3_3"], device="cuda"):
        """
        Perceptual loss class.

        Args:
            pretrained_model (str): Pretrained model to use (e.g., "vgg16").
            layers (list of str): Layers to extract features from.
            device (str): Device to load the pretrained model on ("cuda" or "cpu").
        """
        super().__init__()

        # Load pretrained model
        if pretrained_model == "vgg16":
            vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
        else:
            raise ValueError(f"Unsupported pretrained model: {pretrained_model}")

        # Freeze the parameters
        for param in vgg.parameters():
            param.requires_grad = False

        # Select layers
        self.layers = layers
        self.feature_extractor = nn.ModuleDict({
            layer: vgg[:i] for i, layer in enumerate(vgg._modules.keys()) if layer in self.layers
        })

    def forward(self, generated, target):
        """
        Compute perceptual loss between generated and target images.

        Args:
            generated (torch.Tensor): Generated image batch.
            target (torch.Tensor): Target image batch.

        Returns:
            torch.Tensor: MSE loss between extracted features.
        """
        loss = 0.0
        for layer_name, extractor in self.feature_extractor.items():
            gen_features = extractor(generated)
            target_features = extractor(target)
            loss += F.mse_loss(gen_features, target_features)
        return loss

def scale_invariant_depth_loss(pred, target, lambda_weight=0.1):
    if pred.shape != target.shape:
        pred = F.interpolate(pred, size=target.shape[1:], mode='bilinear', align_corners=False)
    
    diff = pred - target
    n = diff.numel()
    mse = torch.sum(diff**2) / n
    scale_invariant = mse - (lambda_weight / (n**2)) * (torch.sum(diff))**2
    return scale_invariant

def depth_smoothness_loss(pred, img, alpha=1.0):
    depth_grad_x = torch.abs(pred[:, :, :, :-1] - pred[:, :, :, 1:])
    depth_grad_y = torch.abs(pred[:, :, :-1, :] - pred[:, :, 1:, :])
    img_grad_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), dim=1, keepdim=True)
    img_grad_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), dim=1, keepdim=True)
    smoothness_x = depth_grad_x * torch.exp(-alpha * img_grad_x)
    smoothness_y = depth_grad_y * torch.exp(-alpha * img_grad_y)
    return smoothness_x.mean() + smoothness_y.mean()


def inv_huber_loss(pred, target, delta=0.1):
    """
    Inverse Huber loss for depth prediction.
    Args:
        pred (Tensor): Predicted depth map.
        target (Tensor): Ground truth depth map.
        delta (float): Threshold for switching between quadratic and linear terms.
    Returns:
        Tensor: Inverse Huber loss.
    """
    abs_diff = torch.abs(pred - target)
    delta_tensor = torch.tensor(delta, dtype=abs_diff.dtype, device=abs_diff.device)  # Convert delta to tensor
    quadratic = torch.minimum(abs_diff, delta_tensor)
    linear = abs_diff - quadratic
    return (0.5 * quadratic**2 + delta_tensor * linear).mean()


def mean_iou(pred, target, num_classes):
    pred = torch.argmax(pred, dim=1)
    intersection = torch.logical_and(pred == target, target != 255).float()  # Ignore class 255
    union = torch.logical_or(pred == target, target != 255).float()
    iou = torch.sum(intersection) / torch.sum(union)
    return iou



def contrastive_loss(pred, target, margin=1.0):
    """
    Contrastive loss to ensure the depth map predictions are closer to the target.
    """
    # Flatten the tensors for element-wise operations
    pred_flat = pred.view(pred.size(0), -1)  # Flatten except for the batch dimension
    target_flat = target.view(target.size(0), -1)  # Flatten except for the batch dimension

    # Compute the pairwise distances
    distances = torch.sqrt(torch.sum((pred_flat - target_flat) ** 2, dim=1))  # Batch-wise distances

    # Create labels for contrastive loss
    labels = (torch.abs(pred_flat - target_flat).mean(dim=1) < margin).float()  # Batch-wise labels

    # Calculate contrastive loss
    similar_loss = labels * distances**2
    dissimilar_loss = (1 - labels) * torch.clamp(margin - distances, min=0)**2
    loss = (similar_loss + dissimilar_loss).mean()

    return loss


def dice_loss(predictions, targets, smooth=1e-6):
    """
    Calculate Dice Loss for segmentation.
    Args:
        predictions (torch.Tensor): The predicted segmentation map (logits or probabilities).
                                    Shape: [batch_size, num_classes, height, width]
        targets (torch.Tensor): The ground truth segmentation map (one-hot encoded or integer labels).
                                Shape: [batch_size, height, width]
        smooth (float): Smoothing factor to avoid division by zero.
    Returns:
        torch.Tensor: Dice Loss (scalar).
    """
    # Convert integer labels to one-hot if needed
    if predictions.shape != targets.shape:
        targets = F.one_hot(targets, num_classes=predictions.shape[1]).permute(0, 3, 1, 2).float()
    
    # Apply softmax to predictions for multi-class segmentation
    predictions = torch.softmax(predictions, dim=1)
    
    # Flatten tensors to calculate intersection and union
    predictions_flat = predictions.view(predictions.shape[0], predictions.shape[1], -1)
    targets_flat = targets.view(targets.shape[0], targets.shape[1], -1)
    
    # Calculate intersection and union
    intersection = (predictions_flat * targets_flat).sum(dim=-1)
    union = predictions_flat.sum(dim=-1) + targets_flat.sum(dim=-1)
    
    # Calculate Dice Coefficient
    dice_coeff = (2 * intersection + smooth) / (union + smooth)
    
    # Dice Loss
    return 1 - dice_coeff.mean()


def initialize_optimizers_and_schedulers(model, lr_gen=1e-4, lr_disc=1e-4, weight_decay=1e-4):
    """
    Initialize optimizers and schedulers for all generators and discriminators.
    
    Args:
        model (nn.Module): MultiTaskModel instance.
        lr_gen (float): Learning rate for generators.
        lr_disc (float): Learning rate for discriminators.
        weight_decay (float): Weight decay for optimizers.
    
    Returns:
        dict: Optimizers and schedulers for generators and discriminators.
    """
    # Optimizers for shared generator
    optimizer_shared_gen = torch.optim.AdamW(
        model.feature_generator.parameters(),
        lr=lr_gen,
        weight_decay=weight_decay
    )
    scheduler_shared_gen = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_shared_gen, T_max=50, eta_min=1e-6
    )

    # Optimizer and scheduler for the shared generator's refinement layer
    optimizer_shared_refine = torch.optim.AdamW(
        model.shared_generator.parameters(),
        lr=lr_gen,
        weight_decay=weight_decay
    )
    scheduler_shared_refine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_shared_refine, T_max=50, eta_min=1e-6
    )

    # Optimizers and schedulers for task-specific generators
    optimizer_seg_gen = torch.optim.AdamW(
        model.seg_output_layer.parameters(),
        lr=lr_gen,
        weight_decay=weight_decay
    )
    scheduler_seg_gen = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_seg_gen, mode='min', factor=0.5, patience=5
    )

    optimizer_depth_gen = torch.optim.AdamW(
        model.depth_output_layer.parameters(),
        lr=lr_gen,
        weight_decay=weight_decay
    )
    scheduler_depth_gen = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_depth_gen, mode='min', factor=0.5, patience=5
    )

    # Optimizers and schedulers for task-specific discriminators
    optimizer_seg_disc = torch.optim.AdamW(
        model.seg_discriminator.parameters(),
        lr=lr_disc,
        weight_decay=weight_decay
    )
    scheduler_seg_disc = torch.optim.lr_scheduler.StepLR(
        optimizer_seg_disc, step_size=20, gamma=0.1
    )

    optimizer_depth_disc = torch.optim.AdamW(
        model.depth_discriminator.parameters(),
        lr=lr_disc,
        weight_decay=weight_decay
    )
    scheduler_depth_disc = torch.optim.lr_scheduler.StepLR(
        optimizer_depth_disc, step_size=20, gamma=0.1
    )

    # Optimizer and scheduler for the multi-task discriminator
    optimizer_multi_task_disc = torch.optim.AdamW(
        model.multi_task_discriminator.parameters(),
        lr=lr_disc,
        weight_decay=weight_decay
    )
    scheduler_multi_task_disc = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_multi_task_disc, T_max=50, eta_min=1e-6
    )

    return {
        "optimizers": {
            "shared_gen": optimizer_shared_gen,
            "shared_refine": optimizer_shared_refine,
            "seg_gen": optimizer_seg_gen,
            "depth_gen": optimizer_depth_gen,
            "seg_disc": optimizer_seg_disc,
            "depth_disc": optimizer_depth_disc,
            "multi_task_disc": optimizer_multi_task_disc
        },
        "schedulers": {
            "shared_gen": scheduler_shared_gen,
            "shared_refine": scheduler_shared_refine,
            "seg_gen": scheduler_seg_gen,
            "depth_gen": scheduler_depth_gen,
            "seg_disc": scheduler_seg_disc,
            "depth_disc": scheduler_depth_disc,
            "multi_task_disc": scheduler_multi_task_disc
        }
    }


## MultiTaskModel

In [242]:
class MobileNetV3Backbone(nn.Module):
    def __init__(self, backbone):
        super().__init__()

        self.backbone = backbone
        self.proj_l1 = nn.Conv2d(16, 576, kernel_size=1, bias=False)   # For l1_out (1/4 resolution)
        self.proj_l3 = nn.Conv2d(24, 576, kernel_size=1, bias=False)  # For l3_out (1/8 resolution)
        self.proj_l7 = nn.Conv2d(48, 576, kernel_size=1, bias=False)  # For l7_out (1/16 resolution)
        self.proj_l11 = nn.Conv2d(96, 576, kernel_size=1, bias=False) # For l11_out (1/32 resolution)

    
    def forward(self, x):
        """ Passes input theough MobileNetV3 backbone feature extraction layers
            layers to add connections to (0 indexed)
                - 1:  1/4 res
                - 3:  1/8 res
                - 7, 8:  1/16 res
                - 10, 11: 1/32 res
           """
        # skips = nn.ParameterDict()
        # for i in range(len(self.backbone) - 1):
        #     x = self.backbone[i](x)
        #     # add skip connection outputs
        #     if i in [1, 3, 7, 11]:
        #         skips.update({f"l{i}_out" : x})

        # return skips
        skips = {}  # Dictionary to store skip connections

        for i, layer in enumerate(self.backbone):
            x = layer(x)
            # Add skip connections for specific layers
            if i == 1:
                skips["l1_out"] = self.proj_l1(x)  # Project l1_out
            elif i == 3:
                skips["l3_out"] = self.proj_l3(x)  # Project l3_out
            elif i == 7:
                skips["l7_out"] = self.proj_l7(x)  # Project l7_out
            elif i == 11:
                skips["l11_out"] = self.proj_l11(x)  # Project l11_out

        return skips

In [243]:
class EnhancedSharedGenerator(nn.Module):
    def __init__(self):
        """
        Enhanced Shared Generator for Segmentation and Depth tasks.
        Includes additional refinement layers for better generalization.
        """
        super(EnhancedSharedGenerator, self).__init__()
        # Shared convolution layers to process each skip connection
        self.shared_conv1 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l11_out (1/32)
        self.shared_conv2 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l7_out (1/16)
        self.shared_conv3 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l3_out (1/8)
        self.shared_conv4 = nn.Conv2d(576, 256, kernel_size=1, bias=False)  # Process l1_out (1/4)

        # CRP blocks for refinement
        self.shared_crp1 = CRPBlock(256, 256, n_stages=4)
        self.shared_crp2 = CRPBlock(256, 256, n_stages=4)
        self.shared_crp3 = CRPBlock(256, 256, n_stages=4)
        self.shared_crp4 = CRPBlock(256, 256, n_stages=4)

        # Additional refinement layers
        self.refine = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1)
        )

    def forward(self, skips):
        """
        Process skips with shared layers for task-specific generation.
        Args:
            skips (dict): Skip connections from the encoder.
        Returns:
            dict: Processed skip connections.
        """
        x1 = self.shared_crp1(self.shared_conv1(skips["l11_out"]))
        x1 = self.refine(x1)  # Extra refinement

        x2 = self.shared_crp2(self.shared_conv2(skips["l7_out"]))
        x2 = self.refine(x2)

        x3 = self.shared_crp3(self.shared_conv3(skips["l3_out"]))
        x3 = self.refine(x3)

        x4 = self.shared_crp4(self.shared_conv4(skips["l1_out"]))
        x4 = self.refine(x4)

        return {"x1": x1, "x2": x2, "x3": x3, "x4": x4}


In [244]:
class TaskOutputLayer(nn.Module):
    def __init__(self, output_channels):
        """
        Task-specific output layers for generating final predictions.
        Args:
            output_channels (int): Number of output channels (e.g., 20 for segmentation, 1 for depth).
        """
        super(TaskOutputLayer, self).__init__()
        self.final_conv = nn.Conv2d(256, output_channels, kernel_size=3, padding=1)

    def forward(self, x, input_size):
        """
        Generate task-specific output.
        Args:
            x (Tensor): Input feature map.
            input_size (tuple): Original input size (H, W).
        Returns:
            Tensor: Task-specific output.
        """
        x = self.final_conv(x)
        return nn.functional.interpolate(x, size=input_size, mode="bilinear", align_corners=False)



In [245]:
class TaskSpecificDiscriminator(nn.Module):
    def __init__(self, input_channels):
        super(TaskSpecificDiscriminator, self).__init__()
        self.adapt_conv = nn.Conv2d(input_channels+input_channels, input_channels, kernel_size=1, bias=False)
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, task_output, labels=None):
        """
        Forward pass through the discriminator.

        Args:
            task_output (Tensor): Output from the generator (e.g., seg_output or depth_output).
            labels (Tensor, optional): Ground truth labels. If provided, aligns channels with task_output.

        Returns:
            Tensor: Discriminator's prediction.
        """
        if labels is not None:
            # Ensure labels match the shape of task_output
            if labels.dim() < task_output.dim():
                labels = labels.unsqueeze(1)  # Add channel dimension if needed
            if labels.size(1) != task_output.size(1):
                labels = torch.nn.functional.one_hot(labels.squeeze(1), num_classes=task_output.size(1))
                labels = labels.permute(0, 3, 1, 2).float().to(task_output.device)
            combined = torch.cat([task_output, labels], dim=1)
            combined = self.adapt_conv(combined)
        else:
            combined = task_output

        return self.model(combined)


In [246]:

class MultiTaskDiscriminator(nn.Module):
    def __init__(self, input_channels):
        """
        Multi-Task Discriminator for evaluating all task-specific outputs.
        Args:
            input_channels (int): Number of input channels for concatenated features and outputs.
        """
        super(MultiTaskDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, inputs):
        """
        Evaluate input image and task-specific outputs.
        Args:
            inputs (Tensor): Input image.
            outputs (list[Tensor]): List of task-specific outputs.
        Returns:
            Tensor: Discriminator output.
        """
        # combined = torch.cat([inputs] + outputs, dim=1)
        return self.model(inputs)


In [247]:
class MultiTaskModel(nn.Module):
    def __init__(self, backbone, num_seg_classes=20, depth_channels=1):
        """
        Multi-task model with shared Pix2Pix Generator, task-specific discriminators,
        and a multi-task discriminator.
        Args:
            backbone (nn.Module): Encoder backbone for feature extraction.
            num_seg_classes (int): Number of segmentation classes.
            depth_channels (int): Number of output channels for depth.
        """
        super(MultiTaskModel, self).__init__()
        self.feature_generator = MobileNetV3Backbone(backbone)
        self.shared_generator = EnhancedSharedGenerator()
        self.seg_output_layer = TaskOutputLayer(output_channels=num_seg_classes)
        self.depth_output_layer = TaskOutputLayer(output_channels=depth_channels)

        # Task-specific discriminators
        self.seg_discriminator = TaskSpecificDiscriminator(input_channels=num_seg_classes)
        self.depth_discriminator = TaskSpecificDiscriminator(input_channels=depth_channels)

        # Multi-task discriminator
        self.multi_task_discriminator = MultiTaskDiscriminator(input_channels= num_seg_classes + depth_channels)
        
    def forward(self, inputs, input_size, seg_labels=None, depth_labels=None, return_discriminator_outputs=False):
        # Extract features from the encoder
        skips = self.feature_generator(inputs)
        shared_features = self.shared_generator(skips)

        # Task-specific outputs
        seg_output = self.seg_output_layer(shared_features["x4"], input_size)
        depth_output = self.depth_output_layer(shared_features["x4"], input_size)

        output_dict = {
            "seg_output": seg_output,
            "depth_output": depth_output,
        }

        if return_discriminator_outputs:
            
            # Detach outputs to prevent discriminator backward from interfering with the generator
            seg_output_detached = seg_output.detach()
            depth_output_detached = depth_output.detach()
            
            # Adversarial feedback from task-specific discriminators
            seg_real_disc = self.seg_discriminator(seg_output_detached, seg_labels) if seg_labels is not None else None
            seg_fake_disc = self.seg_discriminator(seg_output_detached, None)

            depth_real_disc = self.depth_discriminator(depth_output_detached, depth_labels) if depth_labels is not None else None
            depth_fake_disc = self.depth_discriminator(depth_output_detached, None)

            # Multi-task discriminator feedback
            combined_real_input = torch.cat([ seg_labels, depth_labels], dim=1) if seg_labels is not None and depth_labels is not None else None
            combined_fake_input = torch.cat([ seg_output, depth_output], dim=1)

            combined_real_disc = self.multi_task_discriminator(combined_real_input) if combined_real_input is not None else None
            combined_fake_disc = self.multi_task_discriminator(combined_fake_input.detach())

            output_dict.update({
                "seg_real_disc": seg_real_disc,
                "seg_fake_disc": seg_fake_disc,
                "depth_real_disc": depth_real_disc,
                "depth_fake_disc": depth_fake_disc,
                "combined_real_disc": combined_real_disc,
                "combined_fake_disc": combined_fake_disc,
            })

        return output_dict


#     def forward(self, inputs, input_size, seg_labels=None, depth_labels=None, return_discriminator_outputs=False):
#         """
#         Forward pass for multi-task model.
#         Args:
#             inputs (Tensor): Input images.
#             input_size (tuple): Original input size.
#             seg_labels (Tensor, optional): Ground truth segmentation labels. Required for discriminator feedback.
#             depth_labels (Tensor, optional): Ground truth depth labels. Required for discriminator feedback.
#             return_discriminator_outputs (bool): If True, returns discriminator outputs for adversarial loss.
#         Returns:
#             dict: Outputs for segmentation and depth tasks, and optionally discriminator outputs.
#         """
#         skips = self.feature_generator(inputs)
#         shared_features = self.shared_generator(skips)

#         # Task-specific outputs
#         seg_output = self.seg_output_layer(shared_features["x4"], input_size)
#         depth_output = self.depth_output_layer(shared_features["x4"], input_size)

#         if return_discriminator_outputs:
#             # Adversarial feedback from task-specific discriminators
#             seg_real_disc = self.seg_discriminator(seg_output, seg_labels) if seg_labels is not None else None
#             depth_real_disc = self.depth_discriminator(depth_output, depth_labels) if depth_labels is not None else None

#             # Multi-task discriminator feedback
#             combined_real_disc = self.multi_task_discriminator(inputs, [seg_output, depth_output])

#             return {
#                 "seg_output": seg_output,
#                 "depth_output": depth_output,
#                 "seg_real_disc": seg_real_disc,
#                 "depth_real_disc": depth_real_disc,
#                 "combined_real_disc": combined_real_disc
#             }

#         return {
#             "seg_output": seg_output,
#             "depth_output": depth_output
#         }


In [248]:
from PIL import Image, ImageSequence
import os

def combine_training_gifs(model_dir, save_dir2, output_path):
    """
    Combine two training visualization GIFs into one.
    
    Args:
        model_dir: Directory containing the first training GIF.
        save_dir2: Directory containing the second training GIF.
        output_path: Path to save the combined GIF.
    """
    # Find the GIF files
    model_dir_gif = [file for file in os.listdir(model_dir) if file.startswith("training_visualization") and file.endswith(".gif")]
    save_dir2_gif = [file for file in os.listdir(save_dir2) if file.startswith("training_visualization") and file.endswith(".gif")]
    
    if not model_dir_gif or not save_dir2_gif:
        raise FileNotFoundError("Could not find training_visualization_*.gif in one of the directories.")
    
    model_dir_gif_path = os.path.join(model_dir, model_dir_gif[0])
    save_dir2_gif_path = os.path.join(save_dir2, save_dir2_gif[0])

    # Open the GIFs
    gif1 = Image.open(model_dir_gif_path)
    gif2 = Image.open(save_dir2_gif_path)

    # Collect all frames from both GIFs
    combined_frames = []
    for frame in ImageSequence.Iterator(gif1):
        combined_frames.append(frame.copy())
    for frame in ImageSequence.Iterator(gif2):
        combined_frames.append(frame.copy())

    # Save the combined GIF
    combined_frames[0].save(
        output_path,
        save_all=True,
        append_images=combined_frames[1:],
        duration=gif1.info.get("duration", 500),  # Use duration from the first GIF
        loop=0
    )

    print(f"Combined GIF saved to {output_path}")


# Saving loss charts

In [249]:
def plot_all_losses(epoch, train_losses,valid_losses,save_dir):
    # Plot training and validation losses
    for key in train_losses.keys():
        plt.figure()
        plt.plot(train_losses[key], label=f"Train {key}")
        plt.plot(valid_losses[key], label=f"Valid {key}")
        plt.xlabel("Epoch")
        plt.ylabel(key.replace("_", " ").title())
        plt.legend()
        plt.savefig(os.path.join(save_dir, f"{key}_loss_after_epoch_{epoch}.png"))
        plt.close()

# saving checkpoints

In [250]:
# Save checkpoint including model, optimizer, and scheduler states
def save_checkpoint(model, opt_sched, save_path, epoch, best_loss):
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_states": {name: opt.state_dict() for name, opt in opt_sched["optimizers"].items()},
        "scheduler_states": {name: sched.state_dict() for name, sched in opt_sched["schedulers"].items()},
        "epoch": epoch,
        "best_loss": best_loss
    }
    torch.save(checkpoint, save_path, _use_new_zipfile_serialization=True)
    
def load_checkpoint(model, opt_sched, checkpoint_path, device):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    
    for name, opt in opt_sched["optimizers"].items():
        if name in checkpoint["optimizer_states"]:
            opt.load_state_dict(checkpoint["optimizer_states"][name])
            
    for name, sched in opt_sched["schedulers"].items():
        if name in checkpoint["scheduler_states"]:
            sched.load_state_dict(checkpoint["scheduler_states"][name])
            
    # return checkpoint["epoch"], checkpoint["best_loss"]
    return checkpoint.get("epoch", 0), checkpoint.get("best_loss", float("inf"))



In [251]:
# import pandas as pd

def combine_and_plot_loss_data(model_dir, save_dir2, combined_save_dir="all_data_from_prev_curr_epoch"):
    """
    Combines loss data from previous and current training sessions and plots combined graphs.
    
    Args:
        model_dir: Path to the directory containing the previous loss-tracking CSV.
        save_dir2: Path to the directory containing the current loss-tracking CSV.
        combined_save_dir: Path to save the combined data and plots.

    Returns:
        combined_df: A pandas DataFrame containing the combined loss data.
    """
    # Ensure the save directory exists
    combined_save_dir = os.path.join(save_dir2, combined_save_dir)
    os.makedirs(combined_save_dir, exist_ok=True)

    # Locate CSV files
    previous_csv = os.path.join(model_dir, [file for file in os.listdir(model_dir) if file.endswith(".csv")][0])
    current_csv = os.path.join(save_dir2, [file for file in os.listdir(save_dir2) if file.endswith(".csv")][0])

    # Load data into pandas DataFrames
    previous_df = pd.read_csv(previous_csv)
    current_df = pd.read_csv(current_csv)

    # Update epoch numbers in the current DataFrame
    max_prev_epoch = previous_df["epoch"].max()
    current_df["epoch"] += max_prev_epoch

    # Combine the DataFrames
    combined_df = pd.concat([previous_df, current_df], ignore_index=True)

    # Save the combined DataFrame
    combined_csv_path = os.path.join(combined_save_dir, "combined_loss_tracking.csv")
    combined_df.to_csv(combined_csv_path, index=False)
    print(f"Combined loss data saved to {combined_csv_path}")

    # # Generate plots for each loss type
    # loss_columns = ["train_seg_loss", "train_depth_loss", "train_combined_loss", "train_adv_loss",
    #                 "valid_seg_loss", "valid_depth_loss", "valid_combined_loss", "valid_adv_loss"]
    # for col in loss_columns:
    #     plt.figure()
    #     plt.plot(combined_df["epoch"], combined_df[col], label=col)
    #     plt.xlabel("Epoch")
    #     plt.ylabel("Loss")
    #     plt.title(f"{col.replace('_', ' ').title()} Over Epochs")
    #     plt.legend()
    #     plot_path = os.path.join(combined_save_dir, f"{col}_plot.png")
    #     plt.savefig(plot_path)
    #     plt.close()
    #     print(f"Plot saved to {plot_path}")
    # Generate combined plots for train and valid losses
    
    loss_types = ["seg_loss", "depth_loss", "combined_loss", "adv_loss"]
    for loss_type in loss_types:
        train_loss_col = f"train_{loss_type}"
        valid_loss_col = f"valid_{loss_type}"

        plt.figure()
        plt.plot(combined_df["epoch"], combined_df[train_loss_col], label=f"Train {loss_type.capitalize()}")
        plt.plot(combined_df["epoch"], combined_df[valid_loss_col], label=f"Valid {loss_type.capitalize()}")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title(f"{loss_type.replace('_', ' ').capitalize()} Over Epochs")
        plt.legend()
        plot_path = os.path.join(combined_save_dir, f"{loss_type}_train_valid_plot.png")
        plt.savefig(plot_path)
        plt.close()
        print(f"Combined plot saved to {plot_path}")


    return combined_df


In [252]:
# testing combine

In [253]:
os.getcwd(),'results_test8'

('/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes',
 'results_test8')

# training

In [258]:
def train_model_with_adversarial_loss_tracking(
    model, train_loader, valid_loader, num_epochs, device, opt_sched, save_dir="results"
):
    """
    Trains a multi-task model with adversarial feedback and tracks losses.
    
    Args:
        model: Multi-task model with integrated generators and discriminators.
        train_loader: DataLoader for training data.
        valid_loader: DataLoader for validation data.
        num_epochs: Number of epochs to train.
        device: Device for training ("cuda" or "cpu").
        opt_sched: Dictionary of optimizers and schedulers for generators and discriminators.
        save_dir: Directory to save results.
    
    Returns:
        train_losses, valid_losses: Lists of losses for training and validation.
    """
    # Create directories for saving results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = os.path.join(save_dir, timestamp)
    os.makedirs(save_dir, exist_ok=True)

    # Prepare CSV file for loss tracking
    csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
    gif_path = os.path.join(save_dir, f"training_visualization_{timestamp}.gif")
    
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
            "train_adv_loss", 
            # "train_seg_iou",
            "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
            "valid_adv_loss", 
            # "valid_seg_iou"
        ])

    # Initialize tracking variables
    train_losses = {"seg": [], "depth": [], "combined": [], "adv": []}
    valid_losses = {"seg": [], "depth": [], "combined": [], "adv": []}
    best_combined_loss = float("inf")
    gif_frames =[]
    perceptual_loss_fn = PerceptualLoss(pretrained_model="vgg16").to(device)

    # Start training loop
    for epoch in range(num_epochs):
        model.train()
        epoch_train = {key: 0.0 for key in train_losses.keys()}
        num_batches = 0
        
        
        
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs} - Training", unit="batch") as pbar:
            for batch in train_loader:
                inputs, seg_labels, depth_labels = (
                    batch["left"].to(device),
                    batch["mask"].to(device),
                    batch["depth"].to(device),
                )
                input_size = inputs.size()[-2:]

                # Preprocess seg_labels to one-hot encoding
                if seg_labels.size(1) == 1:  # If class indices are given
                    seg_labels = torch.nn.functional.one_hot(seg_labels.squeeze(1), num_classes=20)
                    seg_labels = seg_labels.permute(0, 3, 1, 2).float().to(device)  # Convert to [B, C, H, W]

                # Ensure depth_labels has correct dimensions
                if depth_labels.dim() == 5:  # If depth_labels has extra dimensions
                    depth_labels = depth_labels.squeeze(2)

                # Zero gradients
                for optimizer in opt_sched["optimizers"].values():
                    optimizer.zero_grad()

                # Forward pass with discriminator outputs
                outputs = model(
                    inputs,
                    input_size=input_size,
                    seg_labels=seg_labels,
                    depth_labels=depth_labels,
                    return_discriminator_outputs=True,
                )

                # Generator losses
                seg_loss = nn.CrossEntropyLoss()(outputs["seg_output"], seg_labels) + \
                           dice_loss(outputs["seg_output"], seg_labels)
                depth_loss = scale_invariant_depth_loss(outputs["depth_output"], depth_labels) + \
                             inv_huber_loss(outputs["depth_output"], depth_labels) 
                # + \
                #              depth_smoothness_loss(outputs["depth_output"], inputs)
                
                seg_perceptual_loss = perceptual_loss_fn(outputs["seg_output"], seg_labels.unsqueeze(1))
                depth_perceptual_loss = perceptual_loss_fn(outputs["depth_output"], depth_labels)
                
                seg_loss = seg_loss + 0.1 * seg_perceptual_loss
                depth_loss = depth_loss + 0.1 * depth_perceptual_loss
                
                
                # Adversarial Losses
                gen_adv_loss_seg, disc_loss_seg = compute_adversarial_losses(
                    model.seg_discriminator,
                    outputs["seg_real_disc"],
                    outputs["seg_fake_disc"].detach(),
                    hinge_loss=False,  # Switch to False for WGAN-GP
                    gradient_penalty=True,
                    real_inputs=seg_labels,  # Use only for WGAN-GP
                )

                gen_adv_loss_depth, disc_loss_depth = compute_adversarial_losses(
                    model.depth_discriminator,
                    outputs["depth_real_disc"],
                    outputs["depth_fake_disc"].detach(),
                    hinge_loss=False,
                    gradient_penalty=True,
                    real_inputs=depth_labels,
                )

                gen_adv_loss_combined, disc_loss_combined = compute_adversarial_losses(
                    model.multi_task_discriminator,
                    outputs["combined_real_disc"],
                    outputs["combined_fake_disc"].detach(),
                    hinge_loss=False,
                    gradient_penalty=True,
                    real_inputs=torch.cat([seg_labels, depth_labels], dim=1) if seg_labels is not None else None,
                )
                

                # adv_loss = -(
                #     torch.mean(outputs["seg_real_disc"]) +
                #     torch.mean(outputs["depth_real_disc"]) +
                #     torch.mean(outputs["combined_real_disc"])
                # )
                # Total Generator Loss
                gen_adv_loss =1.5 * gen_adv_loss_seg + gen_adv_loss_depth + gen_adv_loss_combined
                combined_loss = 2 * seg_loss + depth_loss +  0.1 * gen_adv_loss

                # combined_loss = seg_loss + depth_loss + 0.01 * adv_loss

                # Backpropagation for generators
                combined_loss.backward(retain_graph=True)
                # opt_sched["optimizers"]["generator"].step()
                opt_sched["optimizers"]["shared_gen"].step()
                opt_sched["optimizers"]["shared_refine"].step()
                opt_sched["optimizers"]["seg_gen"].step()
                opt_sched["optimizers"]["depth_gen"].step()


                # # Update task-specific discriminators
                # for task, disc_optimizer in [
                #     ("seg", "seg_disc"),
                #     ("depth", "depth_disc"),
                # ]:
                #     opt_sched["optimizers"][disc_optimizer].zero_grad()
                #     real_disc_loss = torch.mean(
                #         (outputs[f"{task}_real_disc"] - 1) ** 2
                #     )
                #     fake_disc_loss = torch.mean(
                #         (outputs[f"{task}_fake_disc"].detach()) ** 2
                #     )
                #     disc_loss = (real_disc_loss + fake_disc_loss) / 2
                #     disc_loss.backward()
                #     opt_sched["optimizers"][disc_optimizer].step()
                
                # Backpropagation for discriminators
                for task, disc_loss in [("seg", disc_loss_seg), ("depth", disc_loss_depth)]:
                    opt_sched["optimizers"][f"{task}_disc"].zero_grad()
                    disc_loss.backward()
                    opt_sched["optimizers"][f"{task}_disc"].step()

                # Update multi-task discriminator
                opt_sched["optimizers"]["multi_task_disc"].zero_grad()
                disc_loss_combined.backward()
#                 real_combined_loss = torch.mean(
#                     (outputs["combined_real_disc"] - 1) ** 2
#                 )
#                 fake_combined_loss = torch.mean(
#                     (outputs["combined_fake_disc"].detach()) ** 2
#                 )
                
#                 combined_disc_loss = (real_combined_loss + fake_combined_loss) / 2
#                 combined_disc_loss.backward()
                opt_sched["optimizers"]["multi_task_disc"].step()

                # Update training metrics
                epoch_train["seg"] += seg_loss.item()
                epoch_train["depth"] += depth_loss.item()
                epoch_train["combined"] += combined_loss.item()
                epoch_train["adv"] += gen_adv_loss.item()
                # epoch_train["iou"] += mean_iou(outputs["seg_output"], seg_labels, num_classes=20).item()
                num_batches += 1
                pbar.update(1)

            # Average training metrics
            for key in epoch_train.keys():
                train_losses[key].append(epoch_train[key] / num_batches)

            # Validation loop
            model.eval()
            epoch_valid = {key: 0.0 for key in valid_losses.keys()}
            num_valid_batches = 0

            with torch.no_grad():
                for batch in valid_loader:
                    inputs, seg_labels, depth_labels = (
                        batch["left"].to(device),
                        batch["mask"].to(device),
                        batch["depth"].to(device),
                    )
                    input_size = inputs.size()[-2:]

                    # Preprocess seg_labels to one-hot encoding
                    if seg_labels.size(1) == 1:
                        seg_labels = torch.nn.functional.one_hot(seg_labels.squeeze(1), num_classes=20)
                        seg_labels = seg_labels.permute(0, 3, 1, 2).float().to(device)

                    # Ensure depth_labels has correct dimensions
                    if depth_labels.dim() == 5:
                        depth_labels = depth_labels.squeeze(2)


                    # Forward pass
                    outputs = model(
                        inputs,
                        input_size=input_size,
                        seg_labels=seg_labels,
                        depth_labels=depth_labels,
                        return_discriminator_outputs=True,
                    )

                    # Validation loss calculations
                    seg_loss = nn.CrossEntropyLoss()(outputs["seg_output"], seg_labels) + \
                               dice_loss(outputs["seg_output"], seg_labels)
                    depth_loss = scale_invariant_depth_loss(outputs["depth_output"], depth_labels) + \
                                 inv_huber_loss(outputs["depth_output"], depth_labels) 
                    # + \
                    #              depth_smoothness_loss(outputs["depth_output"], inputs)
                    
                    seg_perceptual_loss = perceptual_loss_fn(outputs["seg_output"], seg_labels.unsqueeze(1))
                    depth_perceptual_loss = perceptual_loss_fn(outputs["depth_output"], depth_labels)

                    seg_loss = seg_loss + 0.1 * seg_perceptual_loss
                    depth_loss = depth_loss + 0.1 * depth_perceptual_loss
                    
                    # Adversarial Losses (No Backpropagation)
                    gen_adv_loss_seg, _ = compute_adversarial_losses(
                        model.seg_discriminator,
                        outputs["seg_real_disc"], outputs["seg_fake_disc"], hinge_loss=False,gradient_penalty=False,
                    )
                    gen_adv_loss_depth, _ = compute_adversarial_losses(
                        model.depth_discriminator,
                        outputs["depth_real_disc"], outputs["depth_fake_disc"], hinge_loss=False,gradient_penalty=False,
                    )
                    gen_adv_loss_combined, _ = compute_adversarial_losses(
                        model.multi_task_discriminator,
                        outputs["combined_real_disc"], outputs["combined_fake_disc"], hinge_loss=False,gradient_penalty=False,
                        real_inputs=torch.cat([seg_labels, depth_labels], dim=1) if seg_labels is not None else None,
                    )

                    gen_adv_loss = 1.5 * gen_adv_loss_seg + gen_adv_loss_depth + gen_adv_loss_combined
                    combined_loss = 2 * seg_loss + depth_loss + 0.1 * gen_adv_loss


#                     adv_loss = -(
#                         torch.mean(outputs["seg_real_disc"]) +
#                         torch.mean(outputs["depth_real_disc"]) +
#                         torch.mean(outputs["combined_real_disc"])
#                     )

#                     combined_loss = seg_loss + depth_loss + 0.01 * adv_loss

                    # Update validation metrics
                    epoch_valid["seg"] += seg_loss.item()
                    epoch_valid["depth"] += depth_loss.item()
                    epoch_valid["combined"] += combined_loss.item()
                    epoch_valid["adv"] += gen_adv_loss.item()
                    # epoch_valid["iou"] += mean_iou(outputs["seg_output"], seg_labels, num_classes=20).item()
                    num_valid_batches += 1

        # Average validation metrics
        for key in epoch_valid.keys():
            valid_losses[key].append(epoch_valid[key] / num_valid_batches)

        # Save best model
        valid_combined_loss = (epoch_valid["combined"] / num_valid_batches) + 0.01 * (epoch_valid["adv"] / num_valid_batches)
        if valid_combined_loss < best_combined_loss:
            best_combined_loss = valid_combined_loss
            # torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))
            checkpoint_path = os.path.join(save_dir, "best_model_checkpoint.pth")
            save_checkpoint(model, opt_sched, checkpoint_path, epoch + 1, best_combined_loss)
            print(f"Best model saved at epoch {epoch+1} with combined loss {best_combined_loss:.4f}")
            
        frame = save_training_visualization_as_gif2(epoch, inputs, outputs["seg_output"], outputs["depth_output"], torch.argmax(seg_labels, dim=1), depth_labels)
        gif_frames.append(frame)
        
        

        # Append metrics to CSV
        with open(csv_path, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([
                epoch + 1,
                epoch_train["seg"] / num_batches,
                epoch_train["depth"] / num_batches,
                epoch_train["combined"] / num_batches,
                epoch_train["adv"] / num_batches,
                # epoch_train["iou"] / num_batches,
                epoch_valid["seg"] / num_valid_batches,
                epoch_valid["depth"] / num_valid_batches,
                epoch_valid["combined"] / num_valid_batches,
                epoch_valid["adv"] / num_valid_batches,
                # epoch_valid["iou"] / num_valid_batches,
            ])
            
        # Print epoch results
        print(f"Epoch {epoch + 1}/{num_epochs} Results:")

        # Print training losses
        print(f"  Train Losses - Segmentation: {epoch_train['seg']/num_batches:.4f}, Depth: {epoch_train['depth']/num_batches:.4f}, "
              f"Combined: {epoch_train['combined']/num_batches:.4f}, Adversarial: {epoch_train['adv']/num_batches:.4f}")

        # Print validation losses
        print(f"  Valid Losses - Segmentation: {epoch_valid['seg']/ num_valid_batches:.4f}, Depth: {epoch_valid['depth']/ num_valid_batches:.4f}, "
              f"Combined: {epoch_valid['combined']/ num_valid_batches:.4f}, Adversarial: {epoch_valid['adv']/ num_valid_batches:.4f}")


        # Update schedulers
        for name, scheduler in opt_sched["schedulers"].items():
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # Pass the appropriate metric to ReduceLROnPlateau
                scheduler.step(valid_losses["combined"][-1])  # Use the most recent validation combined loss
            else:
                scheduler.step()
            
        if epoch %10 == 0:
            gif_path2 =os.path.join(save_dir,f"viz_epoch_{epoch}.gif")
            gif_frames[0].save(gif_path2, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)
            plot_all_losses(epoch, train_losses,valid_losses,save_dir)
            
    
    gif_frames[0].save(gif_path, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)
    print(f"Training visualization saved as GIF at {gif_path}")
    plot_all_losses(epoch, train_losses,valid_losses,save_dir)

    
    return train_losses, valid_losses, save_dir


In [259]:

def resume_training_with_loss_tracking(
    model_class,
    model_dir,
    train_loader,
    valid_loader,
    num_additional_epochs,
    device,
    opt_sched,
    save_dir,
):
    """
    Resumes training a multi-task model, appends loss data to the existing CSV file,
    and generates graphs for the combined training history.

    Args:
        model_class: The model class to instantiate.
        model_dir: Path to the directory containing the saved model and loss CSV file.
        train_loader: DataLoader for training data.
        valid_loader: DataLoader for validation data.
        num_additional_epochs: Number of additional epochs to train.
        device: Device for training ("cuda" or "cpu").
        opt_sched: Dictionary of optimizers and schedulers.
        save_dir: Directory to save the updated results.

    Returns:
        Updated train and validation losses.
    """
    # # Load the best model
    # best_model_path = os.path.join(model_dir, "best_model.pth")
    # if not os.path.exists(best_model_path):
    #     raise FileNotFoundError(f"Best model not found at {best_model_path}")
        
    # Load the checkpoint
    checkpoint_path = os.path.join(model_dir, "best_model_checkpoint.pth")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
    
        
    model = model_class().to(device)
    
    start_epoch, best_loss = load_checkpoint(model, opt_sched, checkpoint_path, device)

    # model.load_state_dict(torch.load(best_model_path, map_location=device))

    # Locate the existing loss CSV
    csv_path = os.path.join(model_dir, [file for file in os.listdir(model_dir) if file.endswith(".csv")][0])

    # Parse existing CSV data
    existing_train_losses = {"seg": [], "depth": [], "combined": [], "adv": []}
    existing_valid_losses = {"seg": [], "depth": [], "combined": [], "adv": []}
    current_epoch = 0
    with open(csv_path, "r") as f:
        reader = csv.reader(f)
        next(reader)  # Skip header
        for row in reader:
            current_epoch = int(row[0])
            existing_train_losses["seg"].append(float(row[1]))
            existing_train_losses["depth"].append(float(row[2]))
            existing_train_losses["combined"].append(float(row[3]))
            existing_train_losses["adv"].append(float(row[4]))
            existing_valid_losses["seg"].append(float(row[5]))
            existing_valid_losses["depth"].append(float(row[6]))
            existing_valid_losses["combined"].append(float(row[7]))
            existing_valid_losses["adv"].append(float(row[8]))

    # Train for additional epochs
    
    train_losses, valid_losses,save_dir2 = train_model_with_adversarial_loss_tracking(
        model=model,
        train_loader=train_loader,
        valid_loader=valid_loader,
        num_epochs=num_additional_epochs,
        device=device,
        opt_sched=opt_sched,
        save_dir=save_dir,
    )

    # Combine the new losses with the existing ones
    for key in existing_train_losses.keys():
        existing_train_losses[key].extend(train_losses[key])
        existing_valid_losses[key].extend(valid_losses[key])
        
    save_dir3 = os.path.join(save_dir2,"combined_result")
    os.makedirs(save_dir3, exist_ok=True)
    
    total_epochs = len(existing_train_losses["seg"])
        
    updated_csv_path = os.path.join(save_dir3, f"loss_tracking_updated_{total_epochs}.csv")
    os.makedirs(save_dir3, exist_ok=True)

    # Write combined losses to the updated CSV
    with open(updated_csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
            "train_adv_loss", 
            "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
            "valid_adv_loss", 
        ])
        for epoch in range(len(existing_train_losses["seg"])):
            writer.writerow([
                epoch + 1,
                existing_train_losses["seg"][epoch],
                existing_train_losses["depth"][epoch],
                existing_train_losses["combined"][epoch],
                existing_train_losses["adv"][epoch],
                existing_valid_losses["seg"][epoch],
                existing_valid_losses["depth"][epoch],
                existing_valid_losses["combined"][epoch],
                existing_valid_losses["adv"][epoch],
            ])

    # Generate graphs
    for key in existing_train_losses.keys():
        plt.figure()
        plt.plot(range(len(existing_train_losses[key])), existing_train_losses[key], label=f"Train {key.capitalize()}")
        plt.plot(range(len(existing_valid_losses[key])), existing_valid_losses[key], label=f"Valid {key.capitalize()}")
        plt.xlabel("Epoch")
        plt.ylabel(f"{key.capitalize()} Loss")
        plt.legend()
        plt.title(f"{key.capitalize()} Loss Over Epochs")
        plt.savefig(os.path.join(save_dir3, f"{key}_loss_graph_epoch_{total_epochs}.png"))
        plt.close()
        
    
    output_path = os.path.join(save_dir3,'combined_results.gif')
    combine_training_gifs(model_dir, save_dir2, output_path)

    return existing_train_losses, existing_valid_losses,save_dir2


# for visualization

In [256]:
def list_pth_files(directory):
    """
    Lists all .pth files in a specified directory.

    Args:
        directory (str): Path to the directory to search.

    Returns:
        List[str]: List of .pth file paths.
    """
    pth_files = [os.path.join(directory, file) for file in os.listdir(directory) if file.endswith('.pth')]
    return pth_files

In [257]:
best_model_dir = os.path.join(os.getcwd(),'Best_models')
best_model_dir
pth_files = list_pth_files(best_model_dir)
pth_files

['/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/Best_models/test11.pth',
 '/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/Best_models/test13.pth',
 '/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/Best_models/test8.pth',
 '/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/Best_models/test9.pth']

In [210]:
print("List of .pth files:")
model_map ={}
for file in pth_files:
    filename = file.split('/')[-1] 
    name_without_extension = filename.split('.')[0]  # Removes '.pth'
    model_map[name_without_extension] = file
    
model_map

List of .pth files:


{'test11': '/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/Best_models/test11.pth',
 'test13': '/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/Best_models/test13.pth',
 'test8': '/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/Best_models/test8.pth',
 'test9': '/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/Best_models/test9.pth'}

In [211]:
class MultiTaskModel_old(nn.Module):
    def __init__(self, backbone, num_seg_classes=20, depth_channels=1):
        """
        Multi-task model with shared Pix2Pix Generator, task-specific discriminators,
        and a multi-task discriminator.
        Args:
            backbone (nn.Module): Encoder backbone for feature extraction.
            num_seg_classes (int): Number of segmentation classes.
            depth_channels (int): Number of output channels for depth.
        """
        super(MultiTaskModel_old, self).__init__()
        self.feature_generator = MobileNetV3Backbone(backbone)
        self.shared_generator = EnhancedSharedGenerator()
        self.seg_output_layer = TaskOutputLayer(output_channels=num_seg_classes)
        self.depth_output_layer = TaskOutputLayer(output_channels=depth_channels)

        # Task-specific discriminators
        self.seg_discriminator = TaskSpecificDiscriminator(input_channels=num_seg_classes)
        self.depth_discriminator = TaskSpecificDiscriminator(input_channels=depth_channels)

        # Multi-task discriminator
        self.multi_task_discriminator = MultiTaskDiscriminator(input_channels=3 + num_seg_classes + depth_channels)
        
    def forward(self, inputs, input_size, seg_labels=None, depth_labels=None, return_discriminator_outputs=False):
        # Extract features from the encoder
        skips = self.feature_generator(inputs)
        shared_features = self.shared_generator(skips)

        # Task-specific outputs
        seg_output = self.seg_output_layer(shared_features["x4"], input_size)
        depth_output = self.depth_output_layer(shared_features["x4"], input_size)

        output_dict = {
            "seg_output": seg_output,
            "depth_output": depth_output,
        }

        if return_discriminator_outputs:
            
            # Detach outputs to prevent discriminator backward from interfering with the generator
            seg_output_detached = seg_output.detach()
            depth_output_detached = depth_output.detach()
            
            # Adversarial feedback from task-specific discriminators
            seg_real_disc = self.seg_discriminator(seg_output_detached, seg_labels) if seg_labels is not None else None
            seg_fake_disc = self.seg_discriminator(seg_output_detached, None)

            depth_real_disc = self.depth_discriminator(depth_output_detached, depth_labels) if depth_labels is not None else None
            depth_fake_disc = self.depth_discriminator(depth_output_detached, None)

            # Multi-task discriminator feedback
            combined_real_input = torch.cat([inputs, seg_labels, depth_labels], dim=1) if seg_labels is not None and depth_labels is not None else None
            combined_fake_input = torch.cat([inputs, seg_output, depth_output], dim=1)

            combined_real_disc = self.multi_task_discriminator(combined_real_input) if combined_real_input is not None else None
            combined_fake_disc = self.multi_task_discriminator(combined_fake_input.detach())

            output_dict.update({
                "seg_real_disc": seg_real_disc,
                "seg_fake_disc": seg_fake_disc,
                "depth_real_disc": depth_real_disc,
                "depth_fake_disc": depth_fake_disc,
                "combined_real_disc": combined_real_disc,
                "combined_fake_disc": combined_fake_disc,
            })

        return output_dict


In [212]:
# def save_visualizations(models, val_loader, device, output_dir):
#     """
#     Save visualizations of segmentation and depth outputs as PNG files.

#     Args:
#         models (list): List of PyTorch models.
#         val_loader: Validation DataLoader.
#         device: Device for inference ("cuda" or "cpu").
#         output_dir: Directory to save the visualizations.
#     """
#     # Ensure output directory exists
#     os.makedirs(output_dir, exist_ok=True)

#     # Get the 4th element from val_loader
#     sample_index = 3
#     inputs, seg_labels, depth_labels = None, None, None
#     for idx, batch in enumerate(val_loader):
#         if idx == sample_index:
#             inputs = batch["left"].to(device)
#             seg_labels = batch["mask"].to(device)
#             depth_labels = batch["depth"].to(device)
#             break

#     # Ensure the sample was found
#     if inputs is None or seg_labels is None or depth_labels is None:
#         raise ValueError(f"Sample with index {sample_index} not found in val_loader.")

#     # Get predictions from all models
#     seg_outputs = []
#     depth_outputs = []

#     for model in models:
#         model.eval()
#         with torch.no_grad():
#             outputs = model(inputs)
#             seg_outputs.append(outputs["seg_output"])  # Segmentation output
#             depth_outputs.append(outputs["depth_output"])  # Depth output

#     # Convert segmentation labels to one-hot if necessary
#     if seg_labels.size(1) == 1:  # If labels are class indices
#         seg_labels = torch.nn.functional.one_hot(seg_labels.squeeze(1), num_classes=20)
#         seg_labels = seg_labels.permute(0, 3, 1, 2).float().to(device)

#     # Save segmentation visualization
#     plt.figure(figsize=(15, 5))

#     # Original Image
#     plt.subplot(1, len(models) + 2, 1)
#     plt.imshow(inputs[0].permute(1, 2, 0).cpu().numpy())
#     plt.title("Original Image")
#     plt.axis("off")

#     # Original Segmentation Labels
#     plt.subplot(1, len(models) + 2, 2)
#     plt.imshow(torch.argmax(seg_labels[0], dim=0).cpu().numpy(), cmap="jet")
#     plt.title("Segmentation Labels")
#     plt.axis("off")

#     # Model Outputs for Segmentation
#     for i, seg_output in enumerate(seg_outputs):
#         plt.subplot(1, len(models) + 2, i + 3)
#         plt.imshow(torch.argmax(seg_output[0], dim=0).cpu().numpy(), cmap="jet")
#         plt.title(f"Model {i + 1}")
#         plt.axis("off")

#     seg_output_path = os.path.join(output_dir, "segmentation_visualization.png")
#     plt.savefig(seg_output_path)
#     plt.close()
#     print(f"Segmentation visualization saved at: {seg_output_path}")

#     # Save depth visualization
#     plt.figure(figsize=(15, 5))

#     # Original Image
#     plt.subplot(1, len(models) + 2, 1)
#     plt.imshow(inputs[0].permute(1, 2, 0).cpu().numpy())
#     plt.title("Original Image")
#     plt.axis("off")

#     # Depth Labels
#     plt.subplot(1, len(models) + 2, 2)
#     plt.imshow(depth_labels[0].squeeze(0).cpu().numpy(), cmap="plasma")
#     plt.title("Depth Labels")
#     plt.axis("off")

#     # Model Outputs for Depth
#     for i, depth_output in enumerate(depth_outputs):
#         plt.subplot(1, len(models) + 2, i + 3)
#         plt.imshow(depth_output[0].squeeze(0).cpu().numpy(), cmap="plasma")
#         plt.title(f"Model {i + 1}")
#         plt.axis("off")

#     depth_output_path = os.path.join(output_dir, "depth_visualization.png")
#     plt.savefig(depth_output_path)
#     plt.close()
#     print(f"Depth visualization saved at: {depth_output_path}")

# # def save_visualizations(models, val_loader, device, output_dir):
# #     """
# #     Save visualizations of segmentation and depth outputs as PNG files.

# #     Args:
# #         models (list): List of PyTorch models.
# #         val_loader: Validation DataLoader.
# #         device: Device for inference ("cuda" or "cpu").
# #         output_dir: Directory to save the visualizations.
# #     """
# #     # Ensure output directory exists
# #     os.makedirs(output_dir, exist_ok=True)

# #     # Get the 4th element from val_loader
# #     sample_index = 3
# #     for idx, batch in enumerate(val_loader):
# #         if idx == sample_index:
# #             inputs = batch["left"].to(device)
# #             seg_labels = batch["mask"].to(device)
# #             depth_labels = batch["depth"].to(device)
# #             break

# #     # Get predictions from all models
# #     seg_outputs = []
# #     depth_outputs = []

# #     for model in models:
# #         model.eval()
# #         with torch.no_grad():
# #             outputs = model(inputs)
# #             seg_outputs.append(outputs["seg_output"])  # Segmentation output
# #             depth_outputs.append(outputs["depth_output"])  # Depth output

# #     # Convert segmentation labels to one-hot if necessary
# #     if seg_labels.size(1) == 1:  # If labels are class indices
# #         seg_labels = torch.nn.functional.one_hot(seg_labels.squeeze(1), num_classes=20)
# #         seg_labels = seg_labels.permute(0, 3, 1, 2).float().to(device)

# #     # Save segmentation visualization
# #     plt.figure(figsize=(15, 5))

# #     # Original Image
# #     plt.subplot(1, len(models) + 2, 1)
# #     plt.imshow(inputs[0].permute(1, 2, 0).cpu().numpy())
# #     plt.title("Original Image")
# #     plt.axis("off")

# #     # Original Segmentation Labels
# #     plt.subplot(1, len(models) + 2, 2)
# #     plt.imshow(torch.argmax(seg_labels[0], dim=0).cpu().numpy(), cmap="jet")
# #     plt.title("Segmentation Labels")
# #     plt.axis("off")

# #     # Model Outputs for Segmentation
# #     for i, seg_output in enumerate(seg_outputs):
# #         plt.subplot(1, len(models) + 2, i + 3)
# #         plt.imshow(torch.argmax(seg_output[0], dim=0).cpu().numpy(), cmap="jet")
# #         plt.title(f"Model {i + 1}")
# #         plt.axis("off")

# #     seg_output_path = os.path.join(output_dir, "segmentation_visualization.png")
# #     plt.savefig(seg_output_path)
# #     plt.close()
# #     print(f"Segmentation visualization saved at: {seg_output_path}")

# #     # Save depth visualization
# #     plt.figure(figsize=(15, 5))

# #     # Original Image
# #     plt.subplot(1, len(models) + 2, 1)
# #     plt.imshow(inputs[0].permute(1, 2, 0).cpu().numpy())
# #     plt.title("Original Image")
# #     plt.axis("off")

# #     # Depth Labels
# #     plt.subplot(1, len(models) + 2, 2)
# #     plt.imshow(depth_labels[0].squeeze(0).cpu().numpy(), cmap="plasma")
# #     plt.title("Depth Labels")
# #     plt.axis("off")

# #     # Model Outputs for Depth
# #     for i, depth_output in enumerate(depth_outputs):
# #         plt.subplot(1, len(models) + 2, i + 3)
# #         plt.imshow(depth_output[0].squeeze(0).cpu().numpy(), cmap="plasma")
# #         plt.title(f"Model {i + 1}")
# #         plt.axis("off")

# #     depth_output_path = os.path.join(output_dir, "depth_visualization.png")
# #     plt.savefig(depth_output_path)
# #     plt.close()
# #     print(f"Depth visualization saved at: {depth_output_path}")


In [213]:
# def save_batch_visualizations(models, val_loader, device, output_dir, batch_index=0):
#     """
#     Save visualizations of segmentation and depth outputs for an entire batch as PNG files.

#     Args:
#         models (list): List of PyTorch models.
#         val_loader: Validation DataLoader.
#         device: Device for inference ("cuda" or "cpu").
#         output_dir: Directory to save the visualizations.
#         batch_index: Index of the batch to visualize.
#     """
#     # Ensure output directory exists
#     os.makedirs(output_dir, exist_ok=True)

#     # Get the specified batch
#     for idx, batch in enumerate(val_loader):
#         if idx == batch_index:
#             inputs = batch["left"].to(device)
#             seg_labels = batch["mask"].to(device)
#             depth_labels = batch["depth"].to(device)
#             input_size = inputs.size()[-2:]
#             break

#     # Ensure the batch was found
#     if inputs is None or seg_labels is None or depth_labels is None:
#         raise ValueError(f"Batch with index {batch_index} not found in val_loader.")
        
#     # Move models to the correct device
#     for model in models:
#         model.to(device)

#     # Get predictions from all models
#     seg_outputs = []
#     depth_outputs = []

#     for model in models:
#         model.eval()
#         with torch.no_grad():
#             outputs = model(inputs,input_size=input_size)
#             seg_outputs.append(outputs["seg_output"])  # Segmentation output
#             depth_outputs.append(outputs["depth_output"])  # Depth output

#     # Iterate over the batch
#     batch_size = inputs.size(0)

#     for i in range(batch_size):
#         # Convert segmentation labels to one-hot if necessary
#         if seg_labels.size(1) == 1:  # If labels are class indices
#             single_seg_label = torch.nn.functional.one_hot(seg_labels[i].squeeze(0), num_classes=20)
#             single_seg_label = single_seg_label.permute(2, 0, 1).float().to(device)
#         else:
#             single_seg_label = seg_labels[i]

#         single_depth_label = depth_labels[i]

#         # Save segmentation visualization for this image
#         plt.figure(figsize=(15, 5))

#         # Original Image
#         plt.subplot(1, len(models) + 2, 1)
#         plt.imshow(inputs[i].permute(1, 2, 0).cpu().numpy())
#         plt.title("Original Image")
#         plt.axis("off")

#         # Original Segmentation Labels
#         plt.subplot(1, len(models) + 2, 2)
#         plt.imshow(torch.argmax(single_seg_label, dim=0).cpu().numpy(), cmap="jet")
#         plt.title("Segmentation Labels")
#         plt.axis("off")

#         # Model Outputs for Segmentation
#         for j, seg_output in enumerate(seg_outputs):
#             plt.subplot(1, len(models) + 2, j + 3)
#             plt.imshow(torch.argmax(seg_output[i], dim=0).cpu().numpy(), cmap="jet")
#             plt.title(f"Model {j + 1}")
#             plt.axis("off")

#         seg_output_path = os.path.join(output_dir, f"segmentation_visualization_image_{i}.png")
#         plt.savefig(seg_output_path)
#         plt.close()
#         print(f"Segmentation visualization for image {i} saved at: {seg_output_path}")

#         # Save depth visualization for this image
#         plt.figure(figsize=(15, 5))

#         # Original Image
#         plt.subplot(1, len(models) + 2, 1)
#         plt.imshow(inputs[i].permute(1, 2, 0).cpu().numpy())
#         plt.title("Original Image")
#         plt.axis("off")

#         # Depth Labels
#         plt.subplot(1, len(models) + 2, 2)
#         plt.imshow(single_depth_label.squeeze(0).cpu().numpy(), cmap="plasma")
#         plt.title("Depth Labels")
#         plt.axis("off")

#         # Model Outputs for Depth
#         for j, depth_output in enumerate(depth_outputs):
#             plt.subplot(1, len(models) + 2, j + 3)
#             plt.imshow(depth_output[i].squeeze(0).cpu().numpy(), cmap="plasma")
#             plt.title(f"Model {j + 1}")
#             plt.axis("off")

#         depth_output_path = os.path.join(output_dir, f"depth_visualization_image_{i}.png")
#         plt.savefig(depth_output_path)
#         plt.close()
#         print(f"Depth visualization for image {i} saved at: {depth_output_path}")


In [214]:
def save_batch_visualizations(models, val_loader, device, output_dir, batch_index=0):
    """
    Save visualizations of segmentation and depth outputs for an entire batch as separate PNG files.

    Args:
        models (list): List of PyTorch models.
        val_loader: Validation DataLoader.
        device: Device for inference ("cuda" or "cpu").
        output_dir: Directory to save the visualizations.
        batch_index: Index of the batch to visualize.
    """
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Get the specified batch
    inputs, seg_labels, depth_labels = None, None, None
    for idx, batch in enumerate(val_loader):
        if idx == batch_index:
            inputs = batch["left"].detach().cpu()
            seg_labels = batch["mask"].detach().cpu()
            depth_labels = batch["depth"].detach().cpu()
            
            # Preprocess seg_labels to one-hot encoding
            if seg_labels.size(1) == 1:
                seg_labels = torch.nn.functional.one_hot(seg_labels.squeeze(1), num_classes=20)
                seg_labels = seg_labels.permute(0, 3, 1, 2).float().to(device)

            # Ensure depth_labels has correct dimensions
            if depth_labels.dim() == 5:
                depth_labels = depth_labels.squeeze(2)
                
            seg_labels = torch.argmax(seg_labels, dim=1)
            break

    # Ensure the batch was found
    if inputs is None or seg_labels is None or depth_labels is None:
        raise ValueError(f"Batch with index {batch_index} not found in val_loader.")

    # Move models to the correct device
    for model in models:
        model.to(device)

    # Get predictions from all models
    seg_outputs = []
    depth_outputs = []

    for model in models:
        model.eval()
        with torch.no_grad():
            outputs = model(inputs, input_size=inputs.shape[-2:])
            seg_outputs.append(outputs["seg_output"].detach().cpu())  # Segmentation output
            depth_outputs.append(outputs["depth_output"].detach().cpu())  # Depth output
    
    batch_size = min(4, inputs.size(0))
    
    # Visualize segmentation
    for i in range(inputs.size(0)):
        inputs_temp = inputs[i].detach().cpu()
        inputs_rgb = (inputs_temp - inputs_temp.min()) / (inputs_temp.max() - inputs_temp.min() + 1e-5)

        seg_label = seg_labels[i].detach().cpu()
        seg_outputs_vis = [torch.argmax(output[i], dim=0).detach().cpu() for output in seg_outputs]

        # fig, axes = plt.subplots(1, len(models) + 2, figsize=(15, 5))
        fig, axes = plt.subplots(batch_size, len(models) + 2, figsize=(15, 4 * batch_size))
        
        for i in range(batch_size):
        
            inputs_temp = inputs[i]
            # print(f"inputs_temp: {inputs_temp.shape}")
            inputs_rgb = (inputs_temp - inputs_temp.min()) / (inputs_temp.max() - inputs_temp.min() + 1e-5)  # Normalize inputs to [0, 1]

            # depth_labels_vis = (depth_labels[i] - depth_labels[i].min()) / (depth_labels[i].max() - depth_labels[i].min() + 1e-5)
            # depth_preds = depth_output[i]
            # depth_preds_vis = (depth_preds - depth_preds.min()) / (depth_preds.max() - depth_preds.min() + 1e-5) 

            # Original Image
            axes[i,0].imshow(inputs_rgb.permute(1, 2, 0))
            axes[i,0].set_title("Original Image (RGB)")
            axes[i,0].axis("off")

            # Segmentation Labels
            axes[i,1].imshow(seg_label, cmap="tab20")
            axes[i,1].set_title("Original Segmentation Label")
            axes[i,1].axis("off")

            # Model Segmentation Outputs
            for j, seg_output in enumerate(seg_outputs_vis):
                axes[i,j + 2].imshow(seg_output, cmap="tab20")
                axes[i,j + 2].set_title(f"Model {j + 1} Segmentation Output")
                axes[i,j + 2].axis("off")

            seg_output_path = os.path.join(output_dir, f"segmentation_visualization_image_{i}.png")
            plt.savefig(seg_output_path)
            plt.close()
            print(f"Segmentation visualization for image {i} saved at: {seg_output_path}")

        for i in range(batch_size):
        
            inputs_temp = inputs[i]
            # print(f"inputs_temp: {inputs_temp.shape}")
            inputs_rgb = (inputs_temp - inputs_temp.min()) / (inputs_temp.max() - inputs_temp.min() + 1e-5)  # Normalize inputs to [0, 1]

            depth_labels_vis = (depth_labels[i] - depth_labels[i].min()) / (depth_labels[i].max() - depth_labels[i].min() + 1e-5)
            depth_preds = depth_outputs[i]
            depth_preds_vis = (depth_preds - depth_preds.min()) / (depth_preds.max() - depth_preds.min() + 1e-5) 

            # Original Image
            axes[i,0].imshow(inputs_rgb.permute(1, 2, 0))
            axes[i,0].set_title("Original Image (RGB)")
            axes[i,0].axis("off")

            # Segmentation Labels
            axes[i,1].imshow(depth_labels_vis.squeeze(), cmap="tab20")
            axes[i,1].set_title("Original Depth Map")
            axes[i,1].axis("off")

            # Model Depth Outputs
            for j, depth_output in enumerate(depth_preds_vis):
                axes[i,j + 2].imshow(depth_output.squeeze(), cmap="inferno")
                axes[i,j + 2].set_title(f"Model {j + 1} Depth Output")
                axes[i,j + 2].axis("off")

            depth_output_path = os.path.join(output_dir, f"depth_visualization_image_{i}.png")
            plt.savefig(depth_output_path)
            plt.close()
            print(f"Depth visualization for image {i} saved at: {depth_output_path}")

            

In [215]:
# device =  'cpu'

# # Initialize the model class and loaders
# mobilenet_backbone = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
# model_class = lambda: MultiTaskModel(backbone=mobilenet_backbone.features, num_seg_classes=20, depth_channels=1)

# # # mobilenet_backbone = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
# model_class2 = lambda: MultiTaskModel_old(backbone=mobilenet_backbone.features, num_seg_classes=20, depth_channels=1)

# # opt_sched = initialize_optimizers_and_schedulers(model)

# # Initialize optimizers and schedulers
# model1 = model_class2()
# opt_sched1 = initialize_optimizers_and_schedulers(model1)
# start_epoch, best_loss = load_checkpoint(model1, opt_sched1, model_map['test8'], device)

# model2 = model_class()
# opt_sched2 = initialize_optimizers_and_schedulers(model2)
# start_epoch, best_loss = load_checkpoint(model2, opt_sched2, model_map['test9'], device)

# model3 = model_class()
# opt_sched3 = initialize_optimizers_and_schedulers(model3)
# start_epoch, best_loss = load_checkpoint(model3, opt_sched3, model_map['test11'], device)

# model4 = model_class()
# opt_sched4 = initialize_optimizers_and_schedulers(model4)
# start_epoch, best_loss = load_checkpoint(model4, opt_sched4, model_map['test13'], device)




In [216]:
def extract_batch_data(valid_loader, batch_index):
    # Get the specified batch
    inputs, seg_labels, depth_labels = None, None, None
    for idx, batch in enumerate(valid_loader):
        print(idx)
        if idx == batch_index:
            inputs = batch["left"]
            seg_labels = batch["mask"]
            depth_labels = batch["depth"]
            
            input_size = inputs.size()[-2:]
            
            # Preprocess seg_labels to one-hot encoding
            if seg_labels.size(1) == 1:
                seg_labels = torch.nn.functional.one_hot(seg_labels.squeeze(1), num_classes=20)
                seg_labels = seg_labels.permute(0, 3, 1, 2).float().to(device)

            # Ensure depth_labels has correct dimensions
            if depth_labels.dim() == 5:
                depth_labels = depth_labels.squeeze(2)
                
            # seg_labels = torch.argmax(seg_labels, dim=1)
            break
    print("shape of data in extract_batch_data")
    print(inputs.shape,seg_labels.shape,depth_labels.shape)
    return inputs.detach().cpu(),input_size,seg_labels.detach().cpu(),depth_labels.detach().cpu()

In [217]:
def get_seg_depth_output(model,inputs,input_size,seg_labels,depth_labels):
    outputs = model(
                    inputs,
                    input_size=input_size,
                    seg_labels=seg_labels,
                    depth_labels=depth_labels,
                    return_discriminator_outputs=False,
                    )
    print("inside get_seg_depth_output")
    print(outputs["seg_output"].shape,outputs["depth_output"].shape)
    return outputs["seg_output"],outputs["depth_output"]

In [219]:
def save_batch_visualization3(inputs,input_size,seg_outputs,depth_outputs,seg_labels, depth_labels,best_model_dir):
    inputs = inputs.detach().cpu()
    seg_labels = seg_labels.detach().cpu()
    depth_labels = depth_labels.detach().cpu()
    
    batch_size = min(4, inputs.size(0))  # Limit to 4 samples for visualization
    fig, axes = plt.subplots(batch_size, 5, figsize=(15, 4 * batch_size))
    
    for i in range(batch_size):
        inputs_temp = inputs[i]
        # print(f"inputs_temp: {inputs_temp.shape}")
        inputs_rgb = (inputs_temp - inputs_temp.min()) / (inputs_temp.max() - inputs_temp.min() + 1e-5)  # Normalize inputs to [0, 1]
        
        # Row 1: Ground truth
        axes[i, 0].imshow(inputs_rgb.permute(1, 2, 0))
        axes[i, 0].set_title("RGB Image")
        axes[i, 0].axis("off")

        axes[i, 1].imshow(seg_labels[i], cmap="tab20")
        axes[i, 1].set_title("GT Segmentation")
        axes[i, 1].axis("off")
        
        j=2
        for index,seg_out in enumerate(seg_outputs):
            seg_output = torch.argmax(seg_output, dim=1).detach().cpu()
            axes[i, j].imshow(seg_output[i], cmap="tab20")
            axes[i, j].set_title(f"Model_{index}")
            axes[i, j].axis("off")
            j+=1
            
    for ax in axes.flat:
        ax.axis("off")
        
    fig.tight_layout() 
    
    seg_output_path = os.path.join(best_model_dir, "segmentation_visualization_image.png")
    plt.savefig(seg_output_path)
    plt.close()
    print(f"Segmentation visualization for image {i} saved at: {seg_output_path}")
        
        
        
    

In [220]:
def save_batch_visualizations2(models, valid_loader, device, best_model_dir, batch_index=0):
    seg_outputs = []
    depth_outputs = []
    
    inputs,input_size,seg_labels,depth_labels = extract_batch_data(valid_loader,batch_index)
    
    for model in models:
        model.eval()
        seg_output, depth_output= get_seg_depth_output(model,inputs,input_size,seg_labels,depth_labels)
        seg_outputs.append(seg_output)
        depth_outputs.append(depth_output)
        
    print("save_batch_visualizations2")
    print(type(seg_output),type(seg_outputs),len(seg_outputs))
    
    save_batch_visualization3(inputs,input_size,seg_outputs,depth_outputs,torch.argmax(seg_labels, dim=1), depth_labels,best_model_dir)
    

In [221]:
device =  'cpu'

# Initialize the model class and loaders
mobilenet_backbone = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
model_class = lambda: MultiTaskModel(backbone=mobilenet_backbone.features, num_seg_classes=20, depth_channels=1)

# # mobilenet_backbone = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
model_class2 = lambda: MultiTaskModel_old(backbone=mobilenet_backbone.features, num_seg_classes=20, depth_channels=1)

# opt_sched = initialize_optimizers_and_schedulers(model)

# Initialize optimizers and schedulers
model1 = model_class2()
opt_sched1 = initialize_optimizers_and_schedulers(model1)
start_epoch, best_loss = load_checkpoint(model1, opt_sched1, model_map['test8'], device)

model2 = model_class()
opt_sched2 = initialize_optimizers_and_schedulers(model2)
start_epoch, best_loss = load_checkpoint(model2, opt_sched2, model_map['test9'], device)

model3 = model_class()
opt_sched3 = initialize_optimizers_and_schedulers(model3)
start_epoch, best_loss = load_checkpoint(model3, opt_sched3, model_map['test11'], device)

model4 = model_class()
opt_sched4 = initialize_optimizers_and_schedulers(model4)
start_epoch, best_loss = load_checkpoint(model4, opt_sched4, model_map['test13'], device)


models = [model1,model2,model3,model4]
# Visualize outputs
# visualize_model_outputs(models, val_loader, device)
# Save visualizations
# save_visualizations(models, valid_loader, device, best_model_dir)
# save_batch_visualizations2(models, valid_loader, device, best_model_dir, batch_index=0)

seg_outputs = []
depth_outputs = []
batch_index = 12 
inputs,input_size,seg_labels,depth_labels = extract_batch_data(valid_loader,batch_index)

for model in models:
    model.eval()
    seg_output, depth_output= get_seg_depth_output(model,inputs,input_size,seg_labels,depth_labels)
    seg_outputs.append(seg_output.detach().cpu())
    depth_outputs.append(depth_output.detach().cpu())

print("save_batch_visualizations2")
print(type(seg_output),type(seg_outputs),len(seg_outputs))

  checkpoint = torch.load(checkpoint_path, map_location=device)


0
1
2
3
4
5
6
7
8
9
10
11
12
shape of data in extract_batch_data
torch.Size([8, 3, 200, 512]) torch.Size([8, 20, 200, 512]) torch.Size([8, 1, 200, 512])
inside get_seg_depth_output
torch.Size([8, 20, 200, 512]) torch.Size([8, 1, 200, 512])
inside get_seg_depth_output
torch.Size([8, 20, 200, 512]) torch.Size([8, 1, 200, 512])
inside get_seg_depth_output
torch.Size([8, 20, 200, 512]) torch.Size([8, 1, 200, 512])
inside get_seg_depth_output
torch.Size([8, 20, 200, 512]) torch.Size([8, 1, 200, 512])
save_batch_visualizations2
<class 'torch.Tensor'> <class 'list'> 4


In [222]:
inputs = inputs.detach().cpu()
seg_labels = seg_labels.detach().cpu()
seg_labels =torch.argmax(seg_labels, dim=1)
depth_labels = depth_labels.detach().cpu()

In [223]:
print(seg_labels.shape)

torch.Size([8, 200, 512])


In [224]:


batch_size = min(4, inputs.size(0))  # Limit to 4 samples for visualization
fig, axes = plt.subplots(batch_size, 6, figsize=(15, 6 ))

for i in range(batch_size):
    inputs_temp = inputs[i]
    # print(f"inputs_temp: {inputs_temp.shape}")
    inputs_rgb = (inputs_temp - inputs_temp.min()) / (inputs_temp.max() - inputs_temp.min() + 1e-5)  # Normalize inputs to [0, 1]

    # Row 1: Ground truth
    axes[i, 0].imshow(inputs_rgb.permute(1, 2, 0))
    axes[i, 0].set_title("RGB Image")
    axes[i, 0].axis("off")

    axes[i, 1].imshow(seg_labels[i], cmap="tab20")
    axes[i, 1].set_title("GT Segmentation")
    axes[i, 1].axis("off")

    j=2
    for index,seg_output in enumerate(seg_outputs):
        seg_output = torch.argmax(seg_output, dim=0).detach().cpu()
        print(seg_output.shape)
        axes[i, j].imshow(seg_output[i], cmap="tab20")
        axes[i, j].set_title(f"Model_{index}")
        axes[i, j].axis("off")
        j+=1

for ax in axes.flat:
    ax.axis("off")

fig.tight_layout() 

seg_output_path = os.path.join(best_model_dir, f"segmentation_visualization_image_{batch_index}.png")
plt.savefig(seg_output_path)
plt.close()
print(f"Segmentation visualization for images at: {seg_output_path}")

        

torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
torch.Size([20, 200, 512])
Segmentation visualization for images at: /home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/Best_models/segmentation_visualization_image_12.png


In [225]:

batch_size = min(8, inputs.size(0))  # Limit to 4 samples for visualization
fig, axes = plt.subplots(batch_size, 6, figsize=(15, 8 ))

# depth_labels_vis = (depth_labels - depth_labels.min()) / (depth_labels.max() - depth_labels.min() + 1e-5)

for i in range(batch_size):
    inputs_temp = inputs[i]
    # print(f"inputs_temp: {inputs_temp.shape}")
    inputs_rgb = (inputs_temp - inputs_temp.min()) / (inputs_temp.max() - inputs_temp.min() + 1e-5)  # Normalize inputs to [0, 1]
    
    depth_labels_vis = (depth_labels[i] - depth_labels[i].min()) / (depth_labels[i].max() - depth_labels[i].min() + 1e-5)
    # depth_preds = depth_output[i]
    # depth_preds_vis = (depth_preds - depth_preds.min()) / (depth_preds.max() - depth_preds.min() + 1e-5)

    # Row 1: Ground truth
    axes[i, 0].imshow(inputs_rgb.permute(1, 2, 0))
    axes[i, 0].set_title("RGB Image")
    axes[i, 0].axis("off")

    axes[i, 1].imshow(depth_labels_vis.squeeze(), cmap="inferno")
    axes[i, 1].set_title("GT Depth")
    axes[i, 1].axis("off")

    j=2
    for index,depth_output in enumerate(depth_outputs):
        depth_preds = depth_output[i]
        depth_preds_vis = (depth_preds - depth_preds.min()) / (depth_preds.max() - depth_preds.min() + 1e-5)
        print(depth_preds_vis.shape,depth_preds_vis.squeeze(0).shape)
        axes[i, j].imshow(depth_preds_vis.squeeze(0), cmap="inferno")
        axes[i, j].set_title(f"Model_{index}")
        axes[i, j].axis("off")
        j+=1

for ax in axes.flat:
    ax.axis("off")

fig.tight_layout() 

depth_output_path = os.path.join(best_model_dir, f"depth_visualization_image_{batch_index}.png")
plt.savefig(depth_output_path)
plt.close()
print(f"Depth visualization for images saved at: {depth_output_path}")


torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 512]) torch.Size([200, 512])
torch.Size([1, 200, 

In [52]:
# test12:Train FFF, Valid-FFF

# For first instance

In [53]:
# Initialize the model
# Instantiate Models

BATCH_SIZE = 8
EPOCHS = 100
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


mobilenet_backbone = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
# encoder = MobileNetV3Backbone(mobilenet_backbone.features)
model = MultiTaskModel(backbone=mobilenet_backbone.features, num_seg_classes=20, depth_channels=1)
model.to(DEVICE)

# Initialize optimizers and schedulers
opt_sched = initialize_optimizers_and_schedulers(model)

# Access optimizers
optimizers = opt_sched["optimizers"]
schedulers = opt_sched["schedulers"]

# Prepare Data Loaders (Ensure train_loader and valid_loader are ready)
train_losses, valid_losses,save_dir2 = train_model_with_adversarial_loss_tracking(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    num_epochs=EPOCHS,
    device=DEVICE,
    opt_sched=opt_sched,
    save_dir="results_test12_final3_MTL_disc_channel_21"
)

Epoch 1/100 - Training:   0%|          | 0/125 [02:56<?, ?batch/s]


Best model saved at epoch 1 with combined loss 2.7408
Epoch 1/100 Results:
  Train Losses - Segmentation: 2.6150, Depth: 0.0706, Combined: 2.6842, Adversarial: -0.1380
  Valid Losses - Segmentation: 2.6908, Depth: 0.0518, Combined: 2.7417, Adversarial: -0.0856


Epoch 2/100 - Training:   0%|          | 0/125 [02:50<?, ?batch/s]


Best model saved at epoch 2 with combined loss 2.4960
Epoch 2/100 Results:
  Train Losses - Segmentation: 2.2913, Depth: 0.0360, Combined: 2.3263, Adversarial: -0.0984
  Valid Losses - Segmentation: 2.4602, Depth: 0.0351, Combined: 2.4957, Adversarial: 0.0338


Epoch 3/100 - Training:   0%|          | 0/125 [02:51<?, ?batch/s]


Best model saved at epoch 3 with combined loss 2.3005
Epoch 3/100 Results:
  Train Losses - Segmentation: 2.2053, Depth: 0.0360, Combined: 2.2420, Adversarial: 0.0671
  Valid Losses - Segmentation: 2.2584, Depth: 0.0439, Combined: 2.3014, Adversarial: -0.0897


Epoch 4/100 - Training:   0%|          | 0/125 [02:52<?, ?batch/s]


Epoch 4/100 Results:
  Train Losses - Segmentation: 2.1675, Depth: 0.0340, Combined: 2.2000, Adversarial: -0.1528
  Valid Losses - Segmentation: 2.4102, Depth: 0.0448, Combined: 2.4532, Adversarial: -0.1734


Epoch 5/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 5/100 Results:
  Train Losses - Segmentation: 2.1272, Depth: 0.0345, Combined: 2.1597, Adversarial: -0.1904
  Valid Losses - Segmentation: 2.4627, Depth: 0.0722, Combined: 2.5328, Adversarial: -0.2148


Epoch 6/100 - Training:   0%|          | 0/125 [02:49<?, ?batch/s]


Best model saved at epoch 6 with combined loss 2.2204
Epoch 6/100 Results:
  Train Losses - Segmentation: 2.0647, Depth: 0.0359, Combined: 2.0983, Adversarial: -0.2359
  Valid Losses - Segmentation: 2.1868, Depth: 0.0388, Combined: 2.2230, Adversarial: -0.2614


Epoch 7/100 - Training:   0%|          | 0/125 [02:49<?, ?batch/s]


Epoch 7/100 Results:
  Train Losses - Segmentation: 2.0296, Depth: 0.0314, Combined: 2.0578, Adversarial: -0.3109
  Valid Losses - Segmentation: 2.6012, Depth: 0.0542, Combined: 2.6515, Adversarial: -0.3802


Epoch 8/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Epoch 8/100 Results:
  Train Losses - Segmentation: 2.0570, Depth: 0.0317, Combined: 2.0857, Adversarial: -0.3009
  Valid Losses - Segmentation: 2.3027, Depth: 0.0483, Combined: 2.3477, Adversarial: -0.3320


Epoch 9/100 - Training:   0%|          | 0/125 [02:52<?, ?batch/s]


Epoch 9/100 Results:
  Train Losses - Segmentation: 2.0521, Depth: 0.0332, Combined: 2.0818, Adversarial: -0.3465
  Valid Losses - Segmentation: 2.2114, Depth: 0.1791, Combined: 2.3872, Adversarial: -0.3262


Epoch 10/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Best model saved at epoch 10 with combined loss 2.2176
Epoch 10/100 Results:
  Train Losses - Segmentation: 1.9860, Depth: 0.0331, Combined: 2.0153, Adversarial: -0.3863
  Valid Losses - Segmentation: 2.1724, Depth: 0.0534, Combined: 2.2217, Adversarial: -0.4115


Epoch 11/100 - Training:   0%|          | 0/125 [02:52<?, ?batch/s]


Best model saved at epoch 11 with combined loss 2.2058
Epoch 11/100 Results:
  Train Losses - Segmentation: 2.0062, Depth: 0.0294, Combined: 2.0312, Adversarial: -0.4377
  Valid Losses - Segmentation: 2.1751, Depth: 0.0398, Combined: 2.2103, Adversarial: -0.4591


Epoch 12/100 - Training:   0%|          | 0/125 [02:51<?, ?batch/s]


Epoch 12/100 Results:
  Train Losses - Segmentation: 1.9893, Depth: 0.0323, Combined: 2.0168, Adversarial: -0.4823
  Valid Losses - Segmentation: 2.2646, Depth: 0.0618, Combined: 2.3214, Adversarial: -0.4944


Epoch 13/100 - Training:   0%|          | 0/125 [02:52<?, ?batch/s]


Epoch 13/100 Results:
  Train Losses - Segmentation: 1.9949, Depth: 0.0314, Combined: 2.0213, Adversarial: -0.5002
  Valid Losses - Segmentation: 2.1810, Depth: 0.0433, Combined: 2.2194, Adversarial: -0.4907


Epoch 14/100 - Training:   0%|          | 0/125 [02:49<?, ?batch/s]


Epoch 14/100 Results:
  Train Losses - Segmentation: 1.9433, Depth: 0.0287, Combined: 1.9667, Adversarial: -0.5385
  Valid Losses - Segmentation: 2.2179, Depth: 0.0742, Combined: 2.2866, Adversarial: -0.5486


Epoch 15/100 - Training:   0%|          | 0/125 [02:55<?, ?batch/s]


Epoch 15/100 Results:
  Train Losses - Segmentation: 1.9483, Depth: 0.0297, Combined: 1.9721, Adversarial: -0.5870
  Valid Losses - Segmentation: 2.2561, Depth: 0.0633, Combined: 2.3133, Adversarial: -0.6085


Epoch 16/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Epoch 16/100 Results:
  Train Losses - Segmentation: 1.9182, Depth: 0.0276, Combined: 1.9394, Adversarial: -0.6385
  Valid Losses - Segmentation: 2.1791, Depth: 0.0834, Combined: 2.2563, Adversarial: -0.6160


Epoch 17/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Best model saved at epoch 17 with combined loss 2.1961
Epoch 17/100 Results:
  Train Losses - Segmentation: 1.8744, Depth: 0.0269, Combined: 1.8949, Adversarial: -0.6390
  Valid Losses - Segmentation: 2.1659, Depth: 0.0439, Combined: 2.2029, Adversarial: -0.6859


Epoch 18/100 - Training:   0%|          | 0/125 [02:50<?, ?batch/s]


Best model saved at epoch 18 with combined loss 2.1160
Epoch 18/100 Results:
  Train Losses - Segmentation: 1.9095, Depth: 0.0269, Combined: 1.9293, Adversarial: -0.7126
  Valid Losses - Segmentation: 2.0674, Depth: 0.0624, Combined: 2.1229, Adversarial: -0.6921


Epoch 19/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 19/100 Results:
  Train Losses - Segmentation: 1.8914, Depth: 0.0292, Combined: 1.9132, Adversarial: -0.7366
  Valid Losses - Segmentation: 2.2280, Depth: 0.0439, Combined: 2.2645, Adversarial: -0.7346


Epoch 20/100 - Training:   0%|          | 0/125 [02:55<?, ?batch/s]


Epoch 20/100 Results:
  Train Losses - Segmentation: 1.8988, Depth: 0.0281, Combined: 1.9192, Adversarial: -0.7750
  Valid Losses - Segmentation: 2.1467, Depth: 0.0525, Combined: 2.1915, Adversarial: -0.7707


Epoch 21/100 - Training:   0%|          | 0/125 [02:50<?, ?batch/s]


Epoch 21/100 Results:
  Train Losses - Segmentation: 1.8841, Depth: 0.0263, Combined: 1.9023, Adversarial: -0.8129
  Valid Losses - Segmentation: 2.2782, Depth: 0.0588, Combined: 2.3294, Adversarial: -0.7634


Epoch 22/100 - Training:   0%|          | 0/125 [02:56<?, ?batch/s]


Epoch 22/100 Results:
  Train Losses - Segmentation: 1.8712, Depth: 0.0259, Combined: 1.8894, Adversarial: -0.7703
  Valid Losses - Segmentation: 2.1423, Depth: 0.0572, Combined: 2.1912, Adversarial: -0.8244


Epoch 23/100 - Training:   0%|          | 0/125 [02:56<?, ?batch/s]


Epoch 23/100 Results:
  Train Losses - Segmentation: 1.8614, Depth: 0.0260, Combined: 1.8789, Adversarial: -0.8418
  Valid Losses - Segmentation: 2.2147, Depth: 0.0675, Combined: 2.2736, Adversarial: -0.8560


Epoch 24/100 - Training:   0%|          | 0/125 [02:55<?, ?batch/s]


Epoch 24/100 Results:
  Train Losses - Segmentation: 1.8697, Depth: 0.0256, Combined: 1.8867, Adversarial: -0.8536
  Valid Losses - Segmentation: 2.1454, Depth: 0.0429, Combined: 2.1796, Adversarial: -0.8691


Epoch 25/100 - Training:   0%|          | 0/125 [02:55<?, ?batch/s]


Epoch 25/100 Results:
  Train Losses - Segmentation: 1.8666, Depth: 0.0267, Combined: 1.8848, Adversarial: -0.8465
  Valid Losses - Segmentation: 2.1969, Depth: 0.0433, Combined: 2.2316, Adversarial: -0.8555


Epoch 26/100 - Training:   0%|          | 0/125 [02:57<?, ?batch/s]


Epoch 26/100 Results:
  Train Losses - Segmentation: 1.8471, Depth: 0.0251, Combined: 1.8636, Adversarial: -0.8647
  Valid Losses - Segmentation: 2.1772, Depth: 0.0612, Combined: 2.2298, Adversarial: -0.8579


Epoch 27/100 - Training:   0%|          | 0/125 [02:56<?, ?batch/s]


Epoch 27/100 Results:
  Train Losses - Segmentation: 1.8551, Depth: 0.0248, Combined: 1.8712, Adversarial: -0.8698
  Valid Losses - Segmentation: 2.1224, Depth: 0.0455, Combined: 2.1594, Adversarial: -0.8522


Epoch 28/100 - Training:   0%|          | 0/125 [02:55<?, ?batch/s]


Epoch 28/100 Results:
  Train Losses - Segmentation: 1.8327, Depth: 0.0256, Combined: 1.8497, Adversarial: -0.8563
  Valid Losses - Segmentation: 2.1522, Depth: 0.0400, Combined: 2.1836, Adversarial: -0.8703


Epoch 29/100 - Training:   0%|          | 0/125 [02:52<?, ?batch/s]


Best model saved at epoch 29 with combined loss 2.0468
Epoch 29/100 Results:
  Train Losses - Segmentation: 1.8214, Depth: 0.0260, Combined: 1.8386, Adversarial: -0.8842
  Valid Losses - Segmentation: 2.0364, Depth: 0.0337, Combined: 2.0584, Adversarial: -1.1626


Epoch 30/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 30/100 Results:
  Train Losses - Segmentation: 1.8219, Depth: 0.0241, Combined: 1.8369, Adversarial: -0.9122
  Valid Losses - Segmentation: 2.1400, Depth: 0.0812, Combined: 2.2123, Adversarial: -0.8973


Epoch 31/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 31/100 Results:
  Train Losses - Segmentation: 1.8102, Depth: 0.0250, Combined: 1.8262, Adversarial: -0.9082
  Valid Losses - Segmentation: 2.0956, Depth: 0.0750, Combined: 2.1614, Adversarial: -0.9197


Epoch 32/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Epoch 32/100 Results:
  Train Losses - Segmentation: 1.7997, Depth: 0.0250, Combined: 1.8157, Adversarial: -0.9006
  Valid Losses - Segmentation: 2.1175, Depth: 0.0517, Combined: 2.1603, Adversarial: -0.8903


Epoch 33/100 - Training:   0%|          | 0/125 [02:55<?, ?batch/s]


Epoch 33/100 Results:
  Train Losses - Segmentation: 1.8017, Depth: 0.0227, Combined: 1.8154, Adversarial: -0.9036
  Valid Losses - Segmentation: 2.1378, Depth: 0.0404, Combined: 2.1692, Adversarial: -0.9073


Epoch 34/100 - Training:   0%|          | 0/125 [02:51<?, ?batch/s]


Epoch 34/100 Results:
  Train Losses - Segmentation: 1.7755, Depth: 0.0237, Combined: 1.7901, Adversarial: -0.9088
  Valid Losses - Segmentation: 2.0855, Depth: 0.0562, Combined: 2.1328, Adversarial: -0.8882


Epoch 35/100 - Training:   0%|          | 0/125 [02:47<?, ?batch/s]


Epoch 35/100 Results:
  Train Losses - Segmentation: 1.7619, Depth: 0.0240, Combined: 1.7767, Adversarial: -0.9189
  Valid Losses - Segmentation: 2.0954, Depth: 0.0464, Combined: 2.1325, Adversarial: -0.9294


Epoch 36/100 - Training:   0%|          | 0/125 [02:50<?, ?batch/s]


Epoch 36/100 Results:
  Train Losses - Segmentation: 1.7920, Depth: 0.0226, Combined: 1.8054, Adversarial: -0.9186
  Valid Losses - Segmentation: 2.1290, Depth: 0.0389, Combined: 2.1587, Adversarial: -0.9179


Epoch 37/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 37/100 Results:
  Train Losses - Segmentation: 1.7925, Depth: 0.0228, Combined: 1.8062, Adversarial: -0.9056
  Valid Losses - Segmentation: 2.0595, Depth: 0.0324, Combined: 2.0828, Adversarial: -0.9034


Epoch 38/100 - Training:   0%|          | 0/125 [02:50<?, ?batch/s]


Epoch 38/100 Results:
  Train Losses - Segmentation: 1.7756, Depth: 0.0226, Combined: 1.7893, Adversarial: -0.8990
  Valid Losses - Segmentation: 2.1205, Depth: 0.0345, Combined: 2.1459, Adversarial: -0.9106


Epoch 39/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Epoch 39/100 Results:
  Train Losses - Segmentation: 1.7627, Depth: 0.0224, Combined: 1.7758, Adversarial: -0.9256
  Valid Losses - Segmentation: 2.1022, Depth: 0.0381, Combined: 2.1309, Adversarial: -0.9379


Epoch 40/100 - Training:   0%|          | 0/125 [02:55<?, ?batch/s]


Epoch 40/100 Results:
  Train Losses - Segmentation: 1.7724, Depth: 0.0224, Combined: 1.7854, Adversarial: -0.9423
  Valid Losses - Segmentation: 2.0764, Depth: 0.0398, Combined: 2.1068, Adversarial: -0.9339


Epoch 41/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Epoch 41/100 Results:
  Train Losses - Segmentation: 1.7514, Depth: 0.0220, Combined: 1.7639, Adversarial: -0.9427
  Valid Losses - Segmentation: 2.0222, Depth: 0.0508, Combined: 2.0636, Adversarial: -0.9438


Epoch 42/100 - Training:   0%|          | 0/125 [02:55<?, ?batch/s]


Epoch 42/100 Results:
  Train Losses - Segmentation: 1.7491, Depth: 0.0222, Combined: 1.7618, Adversarial: -0.9419
  Valid Losses - Segmentation: 2.0888, Depth: 0.0439, Combined: 2.1232, Adversarial: -0.9394


Epoch 43/100 - Training:   0%|          | 0/125 [02:59<?, ?batch/s]


Epoch 43/100 Results:
  Train Losses - Segmentation: 1.7708, Depth: 0.0221, Combined: 1.7834, Adversarial: -0.9420
  Valid Losses - Segmentation: 2.0865, Depth: 0.0458, Combined: 2.1228, Adversarial: -0.9448


Epoch 44/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 44/100 Results:
  Train Losses - Segmentation: 1.7467, Depth: 0.0210, Combined: 1.7582, Adversarial: -0.9515
  Valid Losses - Segmentation: 2.0731, Depth: 0.0421, Combined: 2.1057, Adversarial: -0.9572


Epoch 45/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 45/100 Results:
  Train Losses - Segmentation: 1.7504, Depth: 0.0216, Combined: 1.7624, Adversarial: -0.9552
  Valid Losses - Segmentation: 2.0851, Depth: 0.0472, Combined: 2.1228, Adversarial: -0.9521


Epoch 46/100 - Training:   0%|          | 0/125 [02:50<?, ?batch/s]


Epoch 46/100 Results:
  Train Losses - Segmentation: 1.7406, Depth: 0.0213, Combined: 1.7524, Adversarial: -0.9553
  Valid Losses - Segmentation: 2.0568, Depth: 0.0418, Combined: 2.0890, Adversarial: -0.9559


Epoch 47/100 - Training:   0%|          | 0/125 [02:51<?, ?batch/s]


Epoch 47/100 Results:
  Train Losses - Segmentation: 1.7545, Depth: 0.0219, Combined: 1.7668, Adversarial: -0.9534
  Valid Losses - Segmentation: 2.0638, Depth: 0.0428, Combined: 2.0970, Adversarial: -0.9560


Epoch 48/100 - Training:   0%|          | 0/125 [02:51<?, ?batch/s]


Epoch 48/100 Results:
  Train Losses - Segmentation: 1.7108, Depth: 0.0213, Combined: 1.7225, Adversarial: -0.9570
  Valid Losses - Segmentation: 2.0568, Depth: 0.0443, Combined: 2.0915, Adversarial: -0.9595


Epoch 49/100 - Training:   0%|          | 0/125 [02:51<?, ?batch/s]


Epoch 49/100 Results:
  Train Losses - Segmentation: 1.7526, Depth: 0.0217, Combined: 1.7647, Adversarial: -0.9616
  Valid Losses - Segmentation: 2.0744, Depth: 0.0456, Combined: 2.1104, Adversarial: -0.9620


Epoch 50/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 50/100 Results:
  Train Losses - Segmentation: 1.7629, Depth: 0.0218, Combined: 1.7751, Adversarial: -0.9631
  Valid Losses - Segmentation: 2.0692, Depth: 0.0440, Combined: 2.1036, Adversarial: -0.9604


Epoch 51/100 - Training:   0%|          | 0/125 [02:50<?, ?batch/s]


Epoch 51/100 Results:
  Train Losses - Segmentation: 1.7550, Depth: 0.0214, Combined: 1.7668, Adversarial: -0.9629
  Valid Losses - Segmentation: 2.0772, Depth: 0.0402, Combined: 2.1078, Adversarial: -0.9640


Epoch 52/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Epoch 52/100 Results:
  Train Losses - Segmentation: 1.7402, Depth: 0.0212, Combined: 1.7517, Adversarial: -0.9621
  Valid Losses - Segmentation: 2.0711, Depth: 0.0433, Combined: 2.1048, Adversarial: -0.9634


Epoch 53/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 53/100 Results:
  Train Losses - Segmentation: 1.7235, Depth: 0.0218, Combined: 1.7357, Adversarial: -0.9624
  Valid Losses - Segmentation: 2.0881, Depth: 0.0444, Combined: 2.1229, Adversarial: -0.9612


Epoch 54/100 - Training:   0%|          | 0/125 [02:55<?, ?batch/s]


Epoch 54/100 Results:
  Train Losses - Segmentation: 1.7415, Depth: 0.0209, Combined: 1.7528, Adversarial: -0.9645
  Valid Losses - Segmentation: 2.0673, Depth: 0.0418, Combined: 2.0995, Adversarial: -0.9655


Epoch 55/100 - Training:   0%|          | 0/125 [02:52<?, ?batch/s]


Epoch 55/100 Results:
  Train Losses - Segmentation: 1.7696, Depth: 0.0216, Combined: 1.7816, Adversarial: -0.9625
  Valid Losses - Segmentation: 2.0558, Depth: 0.0448, Combined: 2.0911, Adversarial: -0.9610


Epoch 56/100 - Training:   0%|          | 0/125 [02:52<?, ?batch/s]


Epoch 56/100 Results:
  Train Losses - Segmentation: 1.7609, Depth: 0.0216, Combined: 1.7728, Adversarial: -0.9627
  Valid Losses - Segmentation: 2.0739, Depth: 0.0483, Combined: 2.1125, Adversarial: -0.9678


Epoch 57/100 - Training:   0%|          | 0/125 [02:52<?, ?batch/s]


Epoch 57/100 Results:
  Train Losses - Segmentation: 1.7404, Depth: 0.0215, Combined: 1.7522, Adversarial: -0.9670
  Valid Losses - Segmentation: 2.0578, Depth: 0.0497, Combined: 2.0978, Adversarial: -0.9676


Epoch 58/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 58/100 Results:
  Train Losses - Segmentation: 1.7229, Depth: 0.0215, Combined: 1.7347, Adversarial: -0.9681
  Valid Losses - Segmentation: 2.1173, Depth: 0.0529, Combined: 2.1605, Adversarial: -0.9669


Epoch 59/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Epoch 59/100 Results:
  Train Losses - Segmentation: 1.7357, Depth: 0.0213, Combined: 1.7474, Adversarial: -0.9575
  Valid Losses - Segmentation: 2.1280, Depth: 0.0508, Combined: 2.1688, Adversarial: -0.9912


Epoch 60/100 - Training:   0%|          | 0/125 [02:50<?, ?batch/s]


Epoch 60/100 Results:
  Train Losses - Segmentation: 1.7565, Depth: 0.0227, Combined: 1.7698, Adversarial: -0.9441
  Valid Losses - Segmentation: 2.0624, Depth: 0.0403, Combined: 2.0933, Adversarial: -0.9437


Epoch 61/100 - Training:   0%|          | 0/125 [02:50<?, ?batch/s]


Epoch 61/100 Results:
  Train Losses - Segmentation: 1.7715, Depth: 0.0217, Combined: 1.7835, Adversarial: -0.9737
  Valid Losses - Segmentation: 2.0956, Depth: 0.0417, Combined: 2.1277, Adversarial: -0.9670


Epoch 62/100 - Training:   0%|          | 0/125 [02:52<?, ?batch/s]


Epoch 62/100 Results:
  Train Losses - Segmentation: 1.7410, Depth: 0.0219, Combined: 1.7532, Adversarial: -0.9744
  Valid Losses - Segmentation: 2.0988, Depth: 0.0395, Combined: 2.1284, Adversarial: -0.9898


Epoch 63/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 63/100 Results:
  Train Losses - Segmentation: 1.7375, Depth: 0.0216, Combined: 1.7491, Adversarial: -1.0019
  Valid Losses - Segmentation: 2.0806, Depth: 0.0524, Combined: 2.1228, Adversarial: -1.0188


Epoch 64/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 64/100 Results:
  Train Losses - Segmentation: 1.7486, Depth: 0.0222, Combined: 1.7612, Adversarial: -0.9562
  Valid Losses - Segmentation: 2.0663, Depth: 0.0568, Combined: 2.1129, Adversarial: -1.0121


Epoch 65/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 65/100 Results:
  Train Losses - Segmentation: 1.7643, Depth: 0.0222, Combined: 1.7762, Adversarial: -1.0374
  Valid Losses - Segmentation: 2.0642, Depth: 0.0332, Combined: 2.0881, Adversarial: -0.9372


Epoch 66/100 - Training:   0%|          | 0/125 [02:52<?, ?batch/s]


Epoch 66/100 Results:
  Train Losses - Segmentation: 1.7519, Depth: 0.0222, Combined: 1.7641, Adversarial: -1.0042
  Valid Losses - Segmentation: 2.1443, Depth: 0.0402, Combined: 2.1703, Adversarial: -1.4194


Epoch 67/100 - Training:   0%|          | 0/125 [02:49<?, ?batch/s]


Epoch 67/100 Results:
  Train Losses - Segmentation: 1.7454, Depth: 0.0226, Combined: 1.7579, Adversarial: -1.0085
  Valid Losses - Segmentation: 2.0982, Depth: 0.0375, Combined: 2.1254, Adversarial: -1.0345


Epoch 68/100 - Training:   0%|          | 0/125 [02:52<?, ?batch/s]


Epoch 68/100 Results:
  Train Losses - Segmentation: 1.7637, Depth: 0.0231, Combined: 1.7768, Adversarial: -1.0016
  Valid Losses - Segmentation: 2.0261, Depth: 0.0491, Combined: 2.0658, Adversarial: -0.9470


Epoch 69/100 - Training:   0%|          | 0/125 [02:51<?, ?batch/s]


Epoch 69/100 Results:
  Train Losses - Segmentation: 1.7657, Depth: 0.0234, Combined: 1.7796, Adversarial: -0.9505
  Valid Losses - Segmentation: 2.0412, Depth: 0.0307, Combined: 2.0631, Adversarial: -0.8812


Epoch 70/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Epoch 70/100 Results:
  Train Losses - Segmentation: 1.7841, Depth: 0.0239, Combined: 1.7983, Adversarial: -0.9748
  Valid Losses - Segmentation: 2.0410, Depth: 0.0382, Combined: 2.0689, Adversarial: -1.0323


Epoch 71/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 71/100 Results:
  Train Losses - Segmentation: 1.7391, Depth: 0.0220, Combined: 1.7513, Adversarial: -0.9859
  Valid Losses - Segmentation: 2.0632, Depth: 0.0443, Combined: 2.0976, Adversarial: -0.9948


Epoch 72/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Epoch 72/100 Results:
  Train Losses - Segmentation: 1.7797, Depth: 0.0238, Combined: 1.7932, Adversarial: -1.0305
  Valid Losses - Segmentation: 2.0909, Depth: 0.0424, Combined: 2.1242, Adversarial: -0.9133


Epoch 73/100 - Training:   0%|          | 0/125 [02:49<?, ?batch/s]


Epoch 73/100 Results:
  Train Losses - Segmentation: 1.7831, Depth: 0.0240, Combined: 1.7974, Adversarial: -0.9797
  Valid Losses - Segmentation: 2.0376, Depth: 0.0305, Combined: 2.0577, Adversarial: -1.0420


Epoch 74/100 - Training:   0%|          | 0/125 [02:49<?, ?batch/s]


Epoch 74/100 Results:
  Train Losses - Segmentation: 1.7762, Depth: 0.0235, Combined: 1.7892, Adversarial: -1.0451
  Valid Losses - Segmentation: 2.2511, Depth: 0.0891, Combined: 2.3294, Adversarial: -1.0860


Epoch 75/100 - Training:   0%|          | 0/125 [02:52<?, ?batch/s]


Epoch 75/100 Results:
  Train Losses - Segmentation: 1.7778, Depth: 0.0230, Combined: 1.7901, Adversarial: -1.0709
  Valid Losses - Segmentation: 2.0475, Depth: 0.0448, Combined: 2.0816, Adversarial: -1.0660


Epoch 76/100 - Training:   0%|          | 0/125 [02:51<?, ?batch/s]


Epoch 76/100 Results:
  Train Losses - Segmentation: 1.7810, Depth: 0.0241, Combined: 1.7946, Adversarial: -1.0513
  Valid Losses - Segmentation: 2.0962, Depth: 0.0346, Combined: 2.1207, Adversarial: -1.0011


Epoch 77/100 - Training:   0%|          | 0/125 [02:51<?, ?batch/s]


Epoch 77/100 Results:
  Train Losses - Segmentation: 1.7599, Depth: 0.0232, Combined: 1.7721, Adversarial: -1.1022
  Valid Losses - Segmentation: 2.0749, Depth: 0.0574, Combined: 2.1221, Adversarial: -1.0185


Epoch 78/100 - Training:   0%|          | 0/125 [02:55<?, ?batch/s]


Epoch 78/100 Results:
  Train Losses - Segmentation: 1.8006, Depth: 0.0237, Combined: 1.8143, Adversarial: -0.9912
  Valid Losses - Segmentation: 2.0648, Depth: 0.0604, Combined: 2.1160, Adversarial: -0.9213


Epoch 79/100 - Training:   0%|          | 0/125 [02:55<?, ?batch/s]


Epoch 79/100 Results:
  Train Losses - Segmentation: 1.7682, Depth: 0.0245, Combined: 1.7824, Adversarial: -1.0326
  Valid Losses - Segmentation: 2.0481, Depth: 0.0830, Combined: 2.1239, Adversarial: -0.7139


Epoch 80/100 - Training:   0%|          | 0/125 [02:52<?, ?batch/s]


Epoch 80/100 Results:
  Train Losses - Segmentation: 1.7780, Depth: 0.0268, Combined: 1.7941, Adversarial: -1.0657
  Valid Losses - Segmentation: 2.1529, Depth: 0.0514, Combined: 2.1934, Adversarial: -1.0952


Epoch 81/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Epoch 81/100 Results:
  Train Losses - Segmentation: 1.7838, Depth: 0.0241, Combined: 1.7971, Adversarial: -1.0747
  Valid Losses - Segmentation: 2.1563, Depth: 0.0361, Combined: 2.1823, Adversarial: -1.0062


Epoch 82/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Epoch 82/100 Results:
  Train Losses - Segmentation: 1.7473, Depth: 0.0252, Combined: 1.7618, Adversarial: -1.0705
  Valid Losses - Segmentation: 2.1386, Depth: 0.0430, Combined: 2.1714, Adversarial: -1.0167


Epoch 83/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Best model saved at epoch 83 with combined loss 2.0152
Epoch 83/100 Results:
  Train Losses - Segmentation: 1.7534, Depth: 0.0236, Combined: 1.7659, Adversarial: -1.1074
  Valid Losses - Segmentation: 2.0024, Depth: 0.0354, Combined: 2.0265, Adversarial: -1.1271


Epoch 84/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Epoch 84/100 Results:
  Train Losses - Segmentation: 1.7822, Depth: 0.0238, Combined: 1.7945, Adversarial: -1.1516
  Valid Losses - Segmentation: 2.2594, Depth: 0.0393, Combined: 2.2871, Adversarial: -1.1568


Epoch 85/100 - Training:   0%|          | 0/125 [02:51<?, ?batch/s]


Epoch 85/100 Results:
  Train Losses - Segmentation: 1.7967, Depth: 0.0246, Combined: 1.8097, Adversarial: -1.1545
  Valid Losses - Segmentation: 2.1953, Depth: 0.0344, Combined: 2.2177, Adversarial: -1.1991


Epoch 86/100 - Training:   0%|          | 0/125 [02:55<?, ?batch/s]


Best model saved at epoch 86 with combined loss 2.0068
Epoch 86/100 Results:
  Train Losses - Segmentation: 1.8043, Depth: 0.0249, Combined: 1.8181, Adversarial: -1.1117
  Valid Losses - Segmentation: 1.9948, Depth: 0.0288, Combined: 2.0152, Adversarial: -0.8425


Epoch 87/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Epoch 87/100 Results:
  Train Losses - Segmentation: 1.7761, Depth: 0.0251, Combined: 1.7892, Adversarial: -1.2052
  Valid Losses - Segmentation: 2.0116, Depth: 0.0397, Combined: 2.0391, Adversarial: -1.2246


Epoch 88/100 - Training:   0%|          | 0/125 [02:50<?, ?batch/s]


Epoch 88/100 Results:
  Train Losses - Segmentation: 1.7745, Depth: 0.0244, Combined: 1.7862, Adversarial: -1.2660
  Valid Losses - Segmentation: 1.9988, Depth: 0.0384, Combined: 2.0254, Adversarial: -1.1842


Epoch 89/100 - Training:   0%|          | 0/125 [02:50<?, ?batch/s]


Epoch 89/100 Results:
  Train Losses - Segmentation: 1.7858, Depth: 0.0238, Combined: 1.7972, Adversarial: -1.2402
  Valid Losses - Segmentation: 2.1381, Depth: 0.0451, Combined: 2.1706, Adversarial: -1.2621


Epoch 90/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 90/100 Results:
  Train Losses - Segmentation: 1.7677, Depth: 0.0247, Combined: 1.7803, Adversarial: -1.1998
  Valid Losses - Segmentation: 2.0033, Depth: 0.0736, Combined: 2.0643, Adversarial: -1.2626


Epoch 91/100 - Training:   0%|          | 0/125 [02:50<?, ?batch/s]


Epoch 91/100 Results:
  Train Losses - Segmentation: 1.7997, Depth: 0.0271, Combined: 1.8144, Adversarial: -1.2458
  Valid Losses - Segmentation: 2.1587, Depth: 0.0360, Combined: 2.1830, Adversarial: -1.1754


Epoch 92/100 - Training:   0%|          | 0/125 [02:51<?, ?batch/s]


Epoch 92/100 Results:
  Train Losses - Segmentation: 1.7589, Depth: 0.0244, Combined: 1.7706, Adversarial: -1.2675
  Valid Losses - Segmentation: 2.0181, Depth: 0.0544, Combined: 2.0583, Adversarial: -1.4157


Epoch 93/100 - Training:   0%|          | 0/125 [02:52<?, ?batch/s]


Epoch 93/100 Results:
  Train Losses - Segmentation: 1.7702, Depth: 0.0250, Combined: 1.7825, Adversarial: -1.2707
  Valid Losses - Segmentation: 2.0902, Depth: 0.0403, Combined: 2.1181, Adversarial: -1.2421


Epoch 94/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 94/100 Results:
  Train Losses - Segmentation: 1.8032, Depth: 0.0252, Combined: 1.8162, Adversarial: -1.2214
  Valid Losses - Segmentation: 2.4374, Depth: 0.0338, Combined: 2.4580, Adversarial: -1.3257


Epoch 95/100 - Training:   0%|          | 0/125 [02:56<?, ?batch/s]


Epoch 95/100 Results:
  Train Losses - Segmentation: 1.7755, Depth: 0.0258, Combined: 1.7882, Adversarial: -1.3132
  Valid Losses - Segmentation: 2.0545, Depth: 0.0360, Combined: 2.0781, Adversarial: -1.2402


Epoch 96/100 - Training:   0%|          | 0/125 [02:58<?, ?batch/s]


Epoch 96/100 Results:
  Train Losses - Segmentation: 1.7972, Depth: 0.0244, Combined: 1.8091, Adversarial: -1.2513
  Valid Losses - Segmentation: 2.3038, Depth: 0.0606, Combined: 2.3509, Adversarial: -1.3549


Epoch 97/100 - Training:   0%|          | 0/125 [02:56<?, ?batch/s]


Epoch 97/100 Results:
  Train Losses - Segmentation: 1.7732, Depth: 0.0257, Combined: 1.7836, Adversarial: -1.5392
  Valid Losses - Segmentation: 2.1059, Depth: 0.0473, Combined: 2.1325, Adversarial: -2.0689


Epoch 98/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 98/100 Results:
  Train Losses - Segmentation: 1.7531, Depth: 0.0249, Combined: 1.7651, Adversarial: -1.2852
  Valid Losses - Segmentation: 2.2048, Depth: 0.0976, Combined: 2.2886, Adversarial: -1.3782


Epoch 99/100 - Training:   0%|          | 0/125 [02:53<?, ?batch/s]


Epoch 99/100 Results:
  Train Losses - Segmentation: 1.7714, Depth: 0.0248, Combined: 1.7829, Adversarial: -1.3315
  Valid Losses - Segmentation: 2.0458, Depth: 0.0572, Combined: 2.0898, Adversarial: -1.3195


Epoch 100/100 - Training:   0%|          | 0/125 [02:54<?, ?batch/s]


Epoch 100/100 Results:
  Train Losses - Segmentation: 1.7919, Depth: 0.0262, Combined: 1.8045, Adversarial: -1.3633
  Valid Losses - Segmentation: 2.0056, Depth: 0.0539, Combined: 2.0454, Adversarial: -1.4092
Training visualization saved as GIF at results_test12_final3_MTL_disc_channel_21/20241203_042841/training_visualization_20241203_042841.gif


# For second instance

In [260]:
root_save_dir = os.path.join(os.getcwd(),'results_test12_final3_MTL_disc_channel_21')
model_dir =  os.path.join(root_save_dir,'20241203_042841')
# model_dir = os.path.join(model_dir,'combined_result')
model_dir

'/home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/results_test12_final3_MTL_disc_channel_21/20241203_042841'

In [261]:
# root_save_dir = os.path.join(os.getcwd(),'results_test8_final2')
# model_dir =  os.path.join(root_save_dir,'20241129_030455')
num_additional_epochs = 70  # Number of epochs to continue training
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize the model class and loaders
mobilenet_backbone = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
model_class = lambda: MultiTaskModel(backbone=mobilenet_backbone.features, num_seg_classes=20, depth_channels=1)

# Initialize optimizers and schedulers
model = model_class()
opt_sched = initialize_optimizers_and_schedulers(model)


# Call the resume training function
resume_training_with_loss_tracking(
    model_class=model_class,
    model_dir=model_dir,
    train_loader=train_loader,
    valid_loader=valid_loader,
    num_additional_epochs=num_additional_epochs,
    device=device,
    opt_sched=opt_sched,
    save_dir="results_test8_final3_MTL_disc_channel_21"
)

root_save_dir,model_dir

  checkpoint = torch.load(checkpoint_path, map_location=device)
Epoch 1/70 - Training:   0%|          | 0/371 [00:01<?, ?batch/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 66.00 MiB. GPU 0 has a total capacity of 15.77 GiB of which 22.62 MiB is free. Including non-PyTorch memory, this process has 3.94 GiB memory in use. Process 1124940 has 11.49 GiB memory in use. Of the allocated memory 3.51 GiB is allocated by PyTorch, and 65.91 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
input("Enter timestamp folder value")

In [1]:
import os
import csv
import torch
import matplotlib.pyplot as plt


In [49]:
# root = os.path.join(os.getcwd(), 'results_test8_final' )
# model_dir = os.path.join(root,'20241129_015944')
# save_dir2 = os.path.join(root,'20241129_024547')

# output_path = os.path.join(save_dir2,'combined_results.gif')
# combine_training_gifs(model_dir, save_dir2, output_path)

Combined GIF saved to /home/rmajumd/2024/ML_in_image_synthesis/Cityscapes/Cityscapes/results_test8_final/20241129_024547/combined_results.gif


In [None]:
# check with perpetual loss later

In [None]:
# import os
# import csv
# from datetime import datetime
# import matplotlib.pyplot as plt
# import torch
# from torch.utils.data import DataLoader
# from torchvision.utils import save_image

# def train_model_with_loss_tracking(
#     model, train_loader, valid_loader, num_epochs, device, opt_sched, save_dir="results"
# ):
#     """
#     Trains a multi-task model with Conditional GANs, structural consistency, and perceptual loss.

#     Args:
#         model: The multi-task model to train.
#         train_loader: DataLoader for training data.
#         valid_loader: DataLoader for validation data.
#         num_epochs: Number of epochs to train.
#         device: Device for training ("cuda" or "cpu").
#         opt_sched: Dictionary of optimizers and schedulers.
#         save_dir: Directory to save results.

#     Returns:
#         train_losses, valid_losses: Lists of losses for training and validation.
#     """
#     # Create directories for saving results
#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     save_dir = os.path.join(save_dir, timestamp)
#     os.makedirs(save_dir, exist_ok=True)

#     # Prepare CSV file
#     csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
#     with open(csv_path, "w", newline="") as f:
#         writer = csv.writer(f)
#         writer.writerow([
#             "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
#             "train_depth_sidl", "train_depth_smooth", "train_seg_iou", "train_seg_perceptual_loss",
#             "train_depth_perceptual_loss", "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
#             "valid_depth_sidl", "valid_depth_smooth", "valid_seg_iou", "valid_seg_perceptual_loss",
#             "valid_depth_perceptual_loss"
#         ])

#     # Initialize tracking variables
#     train_losses = {
#         "seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": [],
#         "seg_perceptual": [], "depth_perceptual": []
#     }
#     valid_losses = {
#         "seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": [],
#         "seg_perceptual": [], "depth_perceptual": []
#     }
#     best_combined_loss = float("inf")
#     gif_frames = []

#     # Perceptual Loss (example using VGG features)
#     perceptual_loss_fn = PerceptualLoss(pretrained_model="vgg16").to(device)

#     for epoch in range(num_epochs):
#         model.train()
#         epoch_train = {key: 0.0 for key in train_losses.keys()}
#         num_batches = 0

#         for batch in train_loader:
#             inputs, seg_labels, depth_labels = (
#                 batch["left"].to(device),
#                 batch["mask"].to(device),
#                 batch["depth"].to(device)
#             )
#             latent_noise = torch.randn(inputs.size(0), 3).to(device)

#             # Zero gradients
#             for optimizer in opt_sched["optimizers"].values():
#                 optimizer.zero_grad()

#             # Forward pass
#             outputs = model(inputs, input_size=inputs.size()[-2:])

#             # Loss calculations
#             seg_loss_task = nn.CrossEntropyLoss()(outputs["seg_output"], seg_labels) + dice_loss(outputs["seg_output"], seg_labels)
#             seg_perceptual_loss = perceptual_loss_fn(outputs["seg_output"], seg_labels.unsqueeze(1))
#             seg_loss = seg_loss_task + 0.1 * seg_perceptual_loss

#             depth_loss_sidl = scale_invariant_depth_loss(outputs["depth_output"], depth_labels)
#             depth_loss_huber = inv_huber_loss(outputs["depth_output"], depth_labels)
#             depth_loss_smooth = depth_smoothness_loss(outputs["depth_output"], inputs)
#             depth_perceptual_loss = perceptual_loss_fn(outputs["depth_output"], depth_labels)
#             depth_loss = depth_loss_sidl + depth_loss_huber + depth_loss_smooth + 0.1 * depth_perceptual_loss

#             combined_loss = seg_loss + depth_loss

#             # Backpropagation
#             combined_loss.backward()
#             for optimizer in opt_sched["optimizers"].values():
#                 optimizer.step()

#             # Update training metrics
#             epoch_train["seg"] += seg_loss.item()
#             epoch_train["depth"] += depth_loss.item()
#             epoch_train["combined"] += combined_loss.item()
#             epoch_train["iou"] += mean_iou(outputs["seg_output"], seg_labels, num_classes=20).item()
#             epoch_train["depth_sidl"] += depth_loss_sidl.item()
#             epoch_train["depth_smooth"] += depth_loss_smooth.item()
#             epoch_train["seg_perceptual"] += seg_perceptual_loss.item()
#             epoch_train["depth_perceptual"] += depth_perceptual_loss.item()
#             num_batches += 1

#         # Average training metrics
#         for key in epoch_train.keys():
#             train_losses[key].append(epoch_train[key] / num_batches)

#         # Validation loop
#         model.eval()
#         epoch_valid = {key: 0.0 for key in valid_losses.keys()}
#         num_valid_batches = 0

#         with torch.no_grad():
#             for batch in valid_loader:
#                 inputs, seg_labels, depth_labels = (
#                     batch["left"].to(device),
#                     batch["mask"].to(device),
#                     batch["depth"].to(device)
#                 )
#                 latent_noise = torch.randn(inputs.size(0), 3).to(device)

#                 # Forward pass
#                 outputs = model(inputs, input_size=inputs.size()[-2:])

#                 # Validation loss calculations
#                 seg_loss_task = nn.CrossEntropyLoss()(outputs["seg_output"], seg_labels) + dice_loss(outputs["seg_output"], seg_labels)
#                 seg_perceptual_loss = perceptual_loss_fn(outputs["seg_output"], seg_labels.unsqueeze(1))
#                 seg_loss = seg_loss_task + 0.1 * seg_perceptual_loss

#                 depth_loss_sidl = scale_invariant_depth_loss(outputs["depth_output"], depth_labels)
#                 depth_loss_huber = inv_huber_loss(outputs["depth_output"], depth_labels)
#                 depth_loss_smooth = depth_smoothness_loss(outputs["depth_output"], inputs)
#                 depth_perceptual_loss = perceptual_loss_fn(outputs["depth_output"], depth_labels)
#                 depth_loss = depth_loss_sidl + depth_loss_huber + depth_loss_smooth + 0.1 * depth_perceptual_loss

#                 combined_loss = seg_loss + depth_loss

#                 # Update validation metrics
#                 epoch_valid["seg"] += seg_loss.item()
#                 epoch_valid["depth"] += depth_loss.item()
#                 epoch_valid["combined"] += combined_loss.item()
#                 epoch_valid["iou"] += mean_iou(outputs["seg_output"], seg_labels, num_classes=20).item()
#                 epoch_valid["depth_sidl"] += depth_loss_sidl.item()
#                 epoch_valid["depth_smooth"] += depth_loss_smooth.item()
#                 epoch_valid["seg_perceptual"] += seg_perceptual_loss.item()
#                 epoch_valid["depth_perceptual"] += depth_perceptual_loss.item()
#                 num_valid_batches += 1

#         # Average validation metrics
#         for key in epoch_valid.keys():
#             valid_losses[key].append(epoch_valid[key] / num_valid_batches)

#         # Save best model
#         valid_combined_loss = epoch_valid["combined"] / num_valid_batches
#         if valid_combined_loss < best_combined_loss:
#             best_combined_loss = valid_combined_loss
#             torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))

#         # Append metrics to CSV
#         with open(csv_path, "a", newline="") as f:
#             writer = csv.writer(f)
#             writer.writerow([
#                 epoch + 1,
#                 epoch_train["seg"] / num_batches,
#                 epoch_train["depth"] / num_batches,
#                 epoch_train["combined"] / num_batches,
#                 epoch_train["depth_sidl"] / num_batches,
#                 epoch_train["depth_smooth"] / num_batches,
#                 epoch_train["iou"] / num_batches,
#                 epoch_train["seg_perceptual"] / num_batches,
#                 epoch_train["depth_perceptual"] / num_batches,
#                 epoch_valid["seg"] / num_valid_batches,
#                 epoch_valid["depth"] / num_valid_batches,
#                 epoch_valid["combined"] / num_valid_batches,
#                 epoch_valid["depth_sidl"] / num_valid_batches,
#                 epoch_valid["depth_smooth"] / num_valid_batches,
#                 epoch_valid["iou"] / num_valid_batches,
#                 epoch_valid["seg_perceptual"] / num_valid_batches,
#                 epoch_valid["depth_perceptual"] / num_valid_batches,
#             ])

#         # Update schedulers
#         for scheduler in opt_sched["schedulers"].values():
#             scheduler.step()
            
#     plot_all_losses(train_losses,valid_losses)

    

#     return train_losses, valid_losses


In [None]:
# import os
# import csv
# import torch
# import torch.nn as nn
# from datetime import datetime
# from torchvision.utils import save_image
# from PIL import Image
# import matplotlib.pyplot as plt

# # Updated Train Function
# def train_model_with_loss_tracking_and_gif(
#     model, train_loader, valid_loader, num_epochs, device, save_dir="training_output_bicycle_and_pix2pix"
# ):
#     # Create directories for saving models and outputs
#     os.makedirs(save_dir, exist_ok=True)
#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     save_dir = os.path.join(save_dir, timestamp)
#     os.makedirs(save_dir, exist_ok=True)

#     csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
#     gif_path = os.path.join(save_dir, f"training_visualization_{timestamp}.gif")

#     # Initialize CSV for saving loss data
#     with open(csv_path, "w", newline="") as f:
#         writer = csv.writer(f)
#         writer.writerow([
#             "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
#             "train_depth_sidl", "train_depth_smooth", "train_seg_iou",
#             "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
#             "valid_depth_sidl", "valid_depth_smooth", "valid_seg_iou"
#         ])

#     best_combined_loss = float("inf")  # Initialize best combined loss for saving the best model
#     train_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     valid_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     gif_frames = []

#     for epoch in range(num_epochs):
#         torch.cuda.empty_cache()
#         model.train()
#         epoch_train = {key: 0.0 for key in train_losses.keys()}
#         num_batches = 0

#         # Training Loop
#         for batch in train_loader:
#             inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)
#             latent_noise = torch.randn(inputs.size(0), 3).to(device)

#             model.optimizer_stage1.zero_grad()
#             model.optimizer_stage2.zero_grad()

#             # Forward Pass
#             outputs = model(inputs, seg_labels, depth_labels, latent_noise)
#             seg_output = outputs["seg_output"]
#             depth_output = outputs["depth_output"]

#             # Loss Calculations
#             seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels.squeeze(1))
#             seg_dice = dice_loss(seg_output, seg_labels)
#             seg_iou = mean_iou(seg_output, seg_labels, num_classes=20)
#             seg_loss_total = 0.6 * seg_loss + 0.4 * seg_dice

#             depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#             depth_smooth = depth_smoothness_loss(depth_output, inputs)
#             depth_loss_total = depth_sidl + depth_smooth

#             # Combined Loss
#             total_loss = seg_loss_total + depth_loss_total
#             total_loss.backward()

#             # Optimizers Step
#             model.optimizer_stage1.step()
#             model.optimizer_stage2.step()

#             # Accumulate Training Metrics
#             epoch_train["seg"] += seg_loss.item()
#             epoch_train["depth"] += depth_loss_total.item()
#             epoch_train["combined"] += total_loss.item()
#             epoch_train["iou"] += seg_iou.item()
#             epoch_train["depth_sidl"] += depth_sidl.item()
#             epoch_train["depth_smooth"] += depth_smooth.item()
#             num_batches += 1

#             # Save training images for visualization
#             if num_batches % 10 == 0:
#                 img_grid = torch.cat([inputs[0], seg_output[0].argmax(0, keepdim=True), depth_output[0]], dim=2)
#                 save_image(img_grid, os.path.join(save_dir, f"train_{epoch}_{num_batches}.png"))
#                 gif_frames.append(Image.open(os.path.join(save_dir, f"train_{epoch}_{num_batches}.png")))

#         model.scheduler_stage1.step()
#         model.scheduler_stage2.step(epoch_train["combined"] / num_batches)

#         # Average Training Losses
#         for key in epoch_train.keys():
#             train_losses[key].append(epoch_train[key] / num_batches)

#         # Validation Loop
#         model.eval()
#         epoch_valid = {key: 0.0 for key in valid_losses.keys()}
#         num_valid_batches = 0
#         with torch.no_grad():
#             for batch in valid_loader:
#                 inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)
#                 latent_noise = torch.randn(inputs.size(0), 3).to(device)

#                 outputs = model(inputs, seg_labels, depth_labels, latent_noise)
#                 seg_output = outputs["seg_output"]
#                 depth_output = outputs["depth_output"]

#                 # Validation Loss Calculations
#                 seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels.squeeze(1))
#                 seg_dice = dice_loss(seg_output, seg_labels)
#                 seg_iou = mean_iou(seg_output, seg_labels, num_classes=20)
#                 seg_loss_total = 0.6 * seg_loss + 0.4 * seg_dice

#                 depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#                 depth_smooth = depth_smoothness_loss(depth_output, inputs)
#                 depth_loss_total = depth_sidl + depth_smooth

#                 # Accumulate Validation Metrics
#                 epoch_valid["seg"] += seg_loss.item()
#                 epoch_valid["depth"] += depth_loss_total.item()
#                 epoch_valid["combined"] += (seg_loss_total + depth_loss_total).item()
#                 epoch_valid["iou"] += seg_iou.item()
#                 epoch_valid["depth_sidl"] += depth_sidl.item()
#                 epoch_valid["depth_smooth"] += depth_smooth.item()
#                 num_valid_batches += 1
            
#             frame = save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels)
#             gif_frames.append(frame)
              

#         for key in epoch_valid.keys():
#             valid_losses[key].append(epoch_valid[key] / num_valid_batches)

#         # Save Model if Validation Loss Improves
#         valid_combined_loss = epoch_valid["combined"] / num_valid_batches
#         if valid_combined_loss < best_combined_loss:
#             best_combined_loss = valid_combined_loss
#             torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))
#             print(f"Best model saved at epoch {epoch+1} with combined loss {best_combined_loss:.4f}")
            
#         if epoch%10==0:
#             gif_path2 =os.path.join(save_dir,f"viz_epoch_{epoch}.gif")
#             gif_frames[0].save(gif_path2, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

            

#         # Append Validation Metrics to CSV
#         with open(csv_path, "a", newline="") as f:
#             writer = csv.writer(f)
#             writer.writerow([
#                 epoch + 1,
#                 epoch_train["seg"] / num_batches,
#                 epoch_train["depth"] / num_batches,
#                 epoch_train["combined"] / num_batches,
#                 epoch_train["depth_sidl"] / num_batches,
#                 epoch_train["depth_smooth"] / num_batches,
#                 epoch_train["iou"] / num_batches,
#                 epoch_valid["seg"] / num_valid_batches,
#                 epoch_valid["depth"] / num_valid_batches,
#                 epoch_valid["combined"] / num_valid_batches,
#                 epoch_valid["depth_sidl"] / num_valid_batches,
#                 epoch_valid["depth_smooth"] / num_valid_batches,
#                 epoch_valid["iou"] / num_valid_batches,
#             ])

#     # Save GIF
#     gif_frames[0].save(
#         gif_path, save_all=True, append_images=gif_frames[1:], duration=200, loop=0
#     )

#     return train_losses, valid_losses


In [28]:


# def train_model_with_loss_tracking_and_gif(
#     model, train_loader, valid_loader, num_epochs, device, save_dir="training_output_bicycle_and_pix2pix"):
#     # Create directory for saving models and outputs
#     os.makedirs(save_dir, exist_ok=True)
#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     save_dir = os.path.join(save_dir, timestamp)
#     os.makedirs(save_dir, exist_ok=True)

#     csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
#     gif_path = os.path.join(save_dir, f"training_visualization_{timestamp}.gif")

#     # Initialize CSV for saving loss data
#     with open(csv_path, "w", newline="") as f:
#         writer = csv.writer(f)
#         writer.writerow([
#             "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
#             "train_depth_sidl", "train_depth_inv_huber", "train_depth_contrastive", "train_depth_smooth",
#             "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
#             "valid_depth_sidl", "valid_depth_inv_huber", "valid_depth_contrastive", "valid_depth_smooth"
#         ])

#     best_combined_loss = float("inf")  # Initialize best combined loss for saving the best model

#     train_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     valid_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     # , "depth_inv_huber": [], "depth_contrastive": []

#     gif_frames = []
#     num_classes = 20
#     # Optimizer for latent noise
    

#     for epoch in range(num_epochs):
#         torch.cuda.empty_cache()
#         torch.autograd.set_detect_anomaly(True)

#         model.train()
#         # epoch_train_seg_loss = 0
#         # epoch_train_depth_loss = 0
#         # epoch_train_iou = 0
#         # epoch_train_combined_loss = 0
#         # epoch_train_depth_sidl = 0
#         # epoch_train_depth_inv_huber = 0
#         # epoch_train_depth_contrastive = 0
#         # epoch_train_depth_smooth = 0
#         epoch_train = {key: 0.0 for key in train_losses.keys()}
#         num_batches = 0

#         reconstruction_layer = nn.Conv2d(256, 3, kernel_size=1).to(device)
        
#         # scaler = torch.cuda.amp.GradScaler()

#         # Training Loop
#         for batch in train_loader:
#             inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)
#             print(inputs.shape,seg_labels.shape,depth_labels.shape) # torch.Size([8, 3, 200, 512]) torch.Size([8, 1, 200, 512]) torch.Size([8, 1, 1, 200, 512])
#             return
        
                       


#             # Forward pass
#             seg_output, depth_output, backbone_features = model(...)
            


#             seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels)
#             seg_dice = dice_loss(seg_output, seg_labels)
#             seg_iou = mean_iou(seg_output, seg_labels, num_classes)
#             seg_loss_total = 0.6 * seg_loss  + 0.4 * seg_dice
            
#             depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#             depth_inv_huber = inv_huber_loss(depth_output, depth_labels)
#             depth_smooth = depth_smoothness_loss(depth_output, inputs)
#             depth_loss_total = depth_sidl + depth_inv_huber + depth_smooth
            
           
           
#             # Combined Loss
#             total_loss = bicycle_loss + pix2pix_total_loss

#             # Single backward pass
#             total_loss.backward()

#             # Update both optimizers
#             model.optimizer_stage1.step()
#             model.optimizer_stage2.step()
#             latent_optimizer.step()

#             # Accumulate Training Metrics
#             epoch_train["seg"] += seg_loss.item()
#             epoch_train["depth"] += (depth_sidl + depth_smooth).item()
#             epoch_train["combined"] += total_loss.item()
#             epoch_train["iou"] += seg_iou.item()
#             epoch_train["depth_sidl"] += depth_sidl.item()
#             epoch_train["depth_smooth"] += depth_smooth.item()
#             num_batches += 1
            
#         model.scheduler_stage1.step()
#         model.scheduler_stage2.step(epoch_train["combined"]/num_batches)


#         # Average Training Losses
#         for key in epoch_train.keys():
#             train_losses[key].append(epoch_train[key] / num_batches)

#         print(
#             f"Epoch {epoch+1}/{num_epochs} - Train Seg Loss: {epoch_train['seg']:.4f}, "
#             f"Train Depth Loss: {epoch_train['depth']:.4f}, Train Combined Loss: {epoch_train['combined']:.4f}, "
#             f"Train mIOU: {epoch_train['iou']:.4f}, Train sidl Loss: {epoch_train['depth_sidl']:.4f}, "
#             f"Train depth smooth: {epoch_train['depth_smooth']:.4f}"
#     )       

#         # Validation Loop
#         model.eval()
#         epoch_valid = {key: 0.0 for key in valid_losses.keys()}
#         num_valid_batches = 0

#         with torch.no_grad():
#             for batch in valid_loader:
#                 # print("inside valid")
#                 inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)

#                 # Ensure depth_labels and segmentation labels have correct dimensions
                

               


#                 seg_output_old =seg_output
#                 # Resize seg_output to match the spatial dimensions of seg_labels
#                 seg_output_resized = F.interpolate(seg_output, size=seg_labels.shape[1:], mode='bilinear', align_corners=False)
#                 seg_output = seg_output_resized

#                 depth_output_old = depth_output
#                 depth_output_resized = F.interpolate(depth_output, size=depth_labels.shape[-2:], mode='bilinear', align_corners=False)
#                 depth_output =depth_output_resized


#                 # Segmentation Loss
#                 seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels)
#                 seg_dice = dice_loss(seg_output, seg_labels)
#                 seg_iou = mean_iou(seg_output, seg_labels, num_classes)
#                 seg_loss_total = 0.6 * seg_loss  + 0.4 * seg_dice
                
#                 depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#                 depth_inv_huber = inv_huber_loss(depth_output, depth_labels)
#                 depth_smooth = depth_smoothness_loss(depth_output, inputs)
#                 depth_loss_total = depth_sidl + depth_inv_huber + depth_smooth

#                 pix2pix_loss = seg_loss_total + depth_loss_total

#                 # Combined Validation Loss
#                 combined_loss = pix2pix_loss

#                 # Accumulate Validation Metrics
#                 epoch_valid["seg"] += seg_loss.item()
#                 epoch_valid["depth"] += (depth_sidl + depth_smooth).item()
#                 epoch_valid["combined"] += combined_loss.item()
#                 epoch_valid["iou"] += seg_iou.item()
#                 epoch_valid["depth_sidl"] += depth_sidl.item()
#                 epoch_valid["depth_smooth"] += depth_smooth.item()
                
#                 num_valid_batches += 1
                
#             frame = save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels)
#             gif_frames.append(frame)
                
                
#         # Calculate epoch averages
#         # Average Validation Losses
#         for key in epoch_valid.keys():
#             valid_losses[key].append(epoch_valid[key] / num_valid_batches)

#         print(
#             f"Epoch {epoch+1}/{num_epochs} - Valid Seg Loss: {epoch_valid['seg']:.4f}, "
#             f"Valid Depth Loss: {epoch_valid['depth']:.4f}, Valid Combined Loss: {epoch_valid['combined']:.4f}, "
#             f"Valid mIOU: {epoch_valid['iou']:.4f}, Valid sidl Loss: {epoch_valid['depth_sidl']:.4f}, "
#             f"Valid depth smooth: {epoch_valid['depth_smooth']:.4f}"
#         )

#         # Write the losses to CSV
#         with open(csv_path, "a", newline="") as f:
#             writer = csv.writer(f)
#             writer.writerow([
#                 epoch + 1,
#                 train_losses["seg"], train_losses["depth"], train_losses["combined"],
#                 train_losses["depth_sidl"], 0,0,
#                 # avg_train_depth_inv_huber, avg_train_depth_contrastive,
#                 train_losses["depth_smooth"],
#                 valid_losses["seg"], valid_losses["depth"], valid_losses['combined'],
#                 valid_losses["depth_sidl"],0,0,
#                 # avg_valid_depth_inv_huber, avg_valid_depth_contrastive, 
#                 valid_losses["depth_smooth"]
#             ])

       
#         # Save best model
#         if valid_losses["combined"][-1] < best_combined_loss:
#             best_combined_loss = valid_losses["combined"][-1]
#             torch.save(model, os.path.join(save_dir, "best_model_resnetBackbone.pth"))
#             print(f"Best model saved at epoch {epoch+1} with combined loss {best_combined_loss:.4f}")
            
#         if epoch%10==0:
#             gif_path2 =os.path.join(save_dir,f"viz_epoch_{epoch}.gif")
#             gif_frames[0].save(gif_path2, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

    
    
#     plot_loss(train_losses, valid_losses, save_dir)
#     gif_frames[0].save(gif_path, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

    
    
    
#     return train_losses,valid_losses


In [29]:

# # Create your model instance
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # device = 'cpu'
# model = MultiTaskModel(num_seg_classes=20, feature_channels=256).to(device)

In [30]:
# # Set the number of epochs
# num_epochs = 10

# # Call the training function
# train_losses, valid_losses = train_model_with_loss_tracking_and_gif(
#     model=model,
#     train_loader=train_loader,
#     valid_loader=valid_loader,
#     num_epochs=num_epochs,
#     device=device,
#     save_dir="test7_res"
# )

In [31]:


# def train_model_with_loss_tracking_and_gif(
#     model, train_loader, valid_loader, num_epochs, device, save_dir="training_output_bicycle_and_pix2pix"):
#     # Create directory for saving models and outputs
#     os.makedirs(save_dir, exist_ok=True)
#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     save_dir = os.path.join(save_dir, timestamp)
#     os.makedirs(save_dir, exist_ok=True)

#     csv_path = os.path.join(save_dir, f"loss_tracking_{timestamp}.csv")
#     gif_path = os.path.join(save_dir, f"training_visualization_{timestamp}.gif")

#     # Initialize CSV for saving loss data
#     with open(csv_path, "w", newline="") as f:
#         writer = csv.writer(f)
#         writer.writerow([
#             "epoch", "train_seg_loss", "train_depth_loss", "train_combined_loss",
#             "train_depth_sidl", "train_depth_inv_huber", "train_depth_contrastive", "train_depth_smooth",
#             "valid_seg_loss", "valid_depth_loss", "valid_combined_loss",
#             "valid_depth_sidl", "valid_depth_inv_huber", "valid_depth_contrastive", "valid_depth_smooth"
#         ])

#     best_combined_loss = float("inf")  # Initialize best combined loss for saving the best model

#     train_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     valid_losses = {"seg": [], "depth": [], "combined": [], "iou": [], "depth_sidl": [], "depth_smooth": []}
#     # , "depth_inv_huber": [], "depth_contrastive": []

#     gif_frames = []
#     num_classes = 20
#     # Optimizer for latent noise
    

#     for epoch in range(num_epochs):
#         torch.cuda.empty_cache()
#         torch.autograd.set_detect_anomaly(True)

#         model.train()
#         # epoch_train_seg_loss = 0
#         # epoch_train_depth_loss = 0
#         # epoch_train_iou = 0
#         # epoch_train_combined_loss = 0
#         # epoch_train_depth_sidl = 0
#         # epoch_train_depth_inv_huber = 0
#         # epoch_train_depth_contrastive = 0
#         # epoch_train_depth_smooth = 0
#         epoch_train = {key: 0.0 for key in train_losses.keys()}
#         num_batches = 0

#         reconstruction_layer = nn.Conv2d(256, 3, kernel_size=1).to(device)
        
#         # scaler = torch.cuda.amp.GradScaler()

#         # Training Loop
#         for batch in train_loader:
#             inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)

#             # Ensure depth_labels and segmentation labels have correct dimensions
#             if depth_labels.dim() == 5:
#                 depth_labels = depth_labels.squeeze(2)
#             if seg_labels.dim() == 4 and seg_labels.shape[1] == 1:
#                 seg_labels = seg_labels.squeeze(1)

#             # Transform depth labels
#             # depth_labels = torch.log(depth_labels.flatten(start_dim=1)) / 5
#             # depth_labels = depth_labels.view_as(depth_labels)  # Restore shape
#             # depth_labels = torch.clamp(depth_labels, min=1e-5) 
#             # depth_labels = torch.log(depth_labels + 1e-5) / 5  # Avoid log(0)

#             # print(f'seg_labels shape : {seg_labels.shape}')
#             # print(f'depth_labels shape: {depth_labels.shape}')

#             # Start with random noise as latent condition
#             if epoch == 0:
#                 latent_noise = torch.randn_like(inputs).to(device)
#                 # print(f"latent_noise: {latent_noise.shape}")
#                 latent_noise.requires_grad = True  # Make it trainable
#                 latent_optimizer = torch.optim.Adam([latent_noise], lr=1e-3)
            
            


#             # Stage 1: Train BicycleGAN (Backbone Features)
            

#             # Reset gradients for both optimizers
#             model.optimizer_stage1.zero_grad()
#             model.optimizer_stage2.zero_grad()
#             latent_optimizer.zero_grad()

#             # Forward pass
#             seg_output, depth_output, backbone_features = model(inputs, latent_noise)
#             # print(f'seg_ouput shape : {seg_output.shape}')
#             # print(f'depth_output shape: {depth_output.shape}')
#             # print(backbone_features.shape)
#             # return


#             seg_output_old =seg_output
#             # Resize seg_output to match the spatial dimensions of seg_labels
#             seg_output_resized = F.interpolate(seg_output, size=seg_labels.shape[1:], mode='bilinear', align_corners=False)
#             seg_output = seg_output_resized

#             # print(f"depth_output shape before resize: {depth_output.shape}")
#             # print(f"depth_labels shape: {depth_labels.shape}")
#             # return

#             depth_output_old = depth_output
#             depth_output_resized = F.interpolate(depth_output, size=depth_labels.shape[-2:], mode='bilinear', align_corners=False)
#             depth_output = depth_output_resized


#             # Pix2Pix Losses
#             seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels)
#             seg_dice = dice_loss(seg_output, seg_labels)
#             seg_iou = mean_iou(seg_output, seg_labels, num_classes)
#             seg_loss_total = 0.6 * seg_loss  + 0.4 * seg_dice
            
#             depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#             depth_inv_huber = inv_huber_loss(depth_output, depth_labels)
#             depth_smooth = depth_smoothness_loss(depth_output, inputs)
#             depth_loss_total = depth_sidl + depth_inv_huber + depth_smooth
            
#             pix2pix_loss = seg_loss_total + depth_loss_total

#             # Reconstruction loss
#             # inputs_resized = F.interpolate(inputs, size=(backbone_features.size(2), backbone_features.size(3)))
#             # reconstructed_image = reconstruction_layer(backbone_features)
#             # recon_loss = nn.L1Loss()(reconstructed_image, inputs_resized)
#             # adaptive_weight = 1 / (1 + torch.exp(-recon_loss))
#             # adaptive_weight_value = adaptive_weight.item() 

#             # loss_stage1 = nn.MSELoss()(real_validity, torch.ones_like(real_validity).to(device)) + recon_loss
#             # loss_stage1.backward(retain_graph=True)
#             # model.optimizer_stage1.step()

#             # Pix2Pix Adversarial Losses
#             seg_validity = model.segmentation_discriminator(seg_output)
#             depth_validity = model.depth_discriminator(depth_output)
#             adv_seg_loss = nn.MSELoss()(seg_validity, torch.ones_like(seg_validity))
#             adv_depth_loss = nn.MSELoss()(depth_validity, torch.ones_like(depth_validity))
#             pix2pix_total_loss = pix2pix_loss + adv_seg_loss + adv_depth_loss


#             # BicycleGAN Loss with Pix2Pix Condition
#             # real_validity = model.bicycle_discriminator(backbone_features)
#             # recon_loss = nn.L1Loss()(backbone_features, inputs)
#             # bicycle_loss = nn.MSELoss()(real_validity, torch.ones_like(real_validity)) + recon_loss
#             # conditional_bicycle_loss = bicycle_loss + pix2pix_loss
#             # conditional_bicycle_loss.backward(retain_graph=True)
#             # model.optimizer_stage1.step()

#             # BicycleGAN Loss with Pix2Pix Condition
#             real_validity = model.bicycle_discriminator(backbone_features,latent_noise)

#             # Resize inputs to match backbone_features
#             inputs_resized = F.interpolate(inputs, size=backbone_features.shape[-2:], mode='bilinear', align_corners=False)

#             # print(f"backbone_features shape: {backbone_features.shape}, inputs shape: {inputs.shape}")
#             # print(f"inputs_resized shape: {inputs_resized.shape}")
#             # recon_loss = nn.L1Loss()(backbone_features, inputs_resized)
#             bicycle_loss = adv_seg_loss + adv_depth_loss
#             # + recon_loss

#             # Combined Loss
#             total_loss = bicycle_loss + pix2pix_total_loss

#             # Single backward pass
#             total_loss.backward()

#             # Update both optimizers
#             model.optimizer_stage1.step()
#             model.optimizer_stage2.step()
#             latent_optimizer.step()

#             # Accumulate Training Metrics
#             epoch_train["seg"] += seg_loss.item()
#             epoch_train["depth"] += (depth_sidl + depth_smooth).item()
#             epoch_train["combined"] += total_loss.item()
#             epoch_train["iou"] += seg_iou.item()
#             epoch_train["depth_sidl"] += depth_sidl.item()
#             epoch_train["depth_smooth"] += depth_smooth.item()
#             num_batches += 1
            
#         model.scheduler_stage1.step()
#         model.scheduler_stage2.step(epoch_train["combined"]/num_batches)


#         # Average Training Losses
#         for key in epoch_train.keys():
#             train_losses[key].append(epoch_train[key] / num_batches)

#         print(
#             f"Epoch {epoch+1}/{num_epochs} - Train Seg Loss: {epoch_train['seg']:.4f}, "
#             f"Train Depth Loss: {epoch_train['depth']:.4f}, Train Combined Loss: {epoch_train['combined']:.4f}, "
#             f"Train mIOU: {epoch_train['iou']:.4f}, Train sidl Loss: {epoch_train['depth_sidl']:.4f}, "
#             f"Train depth smooth: {epoch_train['depth_smooth']:.4f}"
#     )       

#         # Validation Loop
#         model.eval()
#         # epoch_valid_seg_loss = 0
#         # epoch_valid_depth_loss = 0
#         # epoch_valid_iou =0
#         # epoch_valid_combined_loss = 0
#         # epoch_valid_depth_sidl = 0
#         # epoch_valid_depth_inv_huber = 0
#         # epoch_valid_depth_contrastive = 0
#         # epoch_valid_depth_smooth = 0
#         epoch_valid = {key: 0.0 for key in valid_losses.keys()}
#         num_valid_batches = 0

#         with torch.no_grad():
#             for batch in valid_loader:
#                 # print("inside valid")
#                 inputs, seg_labels, depth_labels = batch["left"].to(device), batch["mask"].to(device), batch["depth"].to(device)

#                 # Ensure depth_labels and segmentation labels have correct dimensions
#                 if depth_labels.dim() == 5:
#                     depth_labels = depth_labels.squeeze(2)
#                 if seg_labels.dim() == 4 and seg_labels.shape[1] == 1:
#                     seg_labels = seg_labels.squeeze(1)

#                 # Transform depth labels
#                 # depth_labels = torch.log(depth_labels.flatten(start_dim=1)) / 5
#                 # depth_labels = depth_labels.view_as(depth_labels)  # Restore shape
#                 # depth_labels = torch.clamp(depth_labels, min=1e-5) 
#                 # depth_labels = torch.log(depth_labels + 1e-5) / 5  # Avoid log(0)

#                 # Latent noise for validation
#                 latent_noise = torch.randn_like(inputs).to(device)
#                 seg_output, depth_output, backbone_features = model(inputs, latent_noise)

                
                

#                 seg_output_old =seg_output
#                 # Resize seg_output to match the spatial dimensions of seg_labels
#                 seg_output_resized = F.interpolate(seg_output, size=seg_labels.shape[1:], mode='bilinear', align_corners=False)
#                 seg_output = seg_output_resized

#                 depth_output_old = depth_output
#                 depth_output_resized = F.interpolate(depth_output, size=depth_labels.shape[-2:], mode='bilinear', align_corners=False)
#                 depth_output =depth_output_resized


#                 # Segmentation Loss
#                 seg_loss = nn.CrossEntropyLoss()(seg_output, seg_labels)
#                 seg_dice = dice_loss(seg_output, seg_labels)
#                 seg_iou = mean_iou(seg_output, seg_labels, num_classes)
#                 seg_loss_total = 0.6 * seg_loss  + 0.4 * seg_dice
                
#                 depth_sidl = scale_invariant_depth_loss(depth_output, depth_labels)
#                 depth_inv_huber = inv_huber_loss(depth_output, depth_labels)
#                 depth_smooth = depth_smoothness_loss(depth_output, inputs)
#                 depth_loss_total = depth_sidl + depth_inv_huber + depth_smooth

#                 pix2pix_loss = seg_loss_total + depth_loss_total

#                 # Combined Validation Loss
#                 combined_loss = pix2pix_loss

#                 # Accumulate Validation Metrics
#                 epoch_valid["seg"] += seg_loss.item()
#                 epoch_valid["depth"] += (depth_sidl + depth_smooth).item()
#                 epoch_valid["combined"] += combined_loss.item()
#                 epoch_valid["iou"] += seg_iou.item()
#                 epoch_valid["depth_sidl"] += depth_sidl.item()
#                 epoch_valid["depth_smooth"] += depth_smooth.item()
                
#                 num_valid_batches += 1
                
#                 # epoch, inputs, seg_output, depth_output, seg_labels, depth_labels, gif_frames
#             frame = save_training_visualization_as_gif2(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels)
#             gif_frames.append(frame)
                
                
#         # Calculate epoch averages
#         # Average Validation Losses
#         for key in epoch_valid.keys():
#             valid_losses[key].append(epoch_valid[key] / num_valid_batches)

        
        
# # train_losses = { "depth_sidl": [], "depth_inv_huber": [], "depth_contrastive": [], "depth_smooth": []}
#         print(
#             f"Epoch {epoch+1}/{num_epochs} - Valid Seg Loss: {epoch_valid['seg']:.4f}, "
#             f"Valid Depth Loss: {epoch_valid['depth']:.4f}, Valid Combined Loss: {epoch_valid['combined']:.4f}, "
#             f"Valid mIOU: {epoch_valid['iou']:.4f}, Valid sidl Loss: {epoch_valid['depth_sidl']:.4f}, "
#             f"Valid depth smooth: {epoch_valid['depth_smooth']:.4f}"
#         )

#         # Write the losses to CSV
#         with open(csv_path, "a", newline="") as f:
#             writer = csv.writer(f)
#             writer.writerow([
#                 epoch + 1,
#                 train_losses["seg"], train_losses["depth"], train_losses["combined"],
#                 train_losses["depth_sidl"], 0,0,
#                 # avg_train_depth_inv_huber, avg_train_depth_contrastive,
#                 train_losses["depth_smooth"],
#                 valid_losses["seg"], valid_losses["depth"], valid_losses['combined'],
#                 valid_losses["depth_sidl"],0,0,
#                 # avg_valid_depth_inv_huber, avg_valid_depth_contrastive, 
#                 valid_losses["depth_smooth"]
#             ])

#         # Save GIF visualization frames
#         # save_training_visualization_as_gif(epoch, inputs, seg_output, depth_output, seg_labels, depth_labels, gif_frames)

#         # Save best model
#         if valid_losses["combined"][-1] < best_combined_loss:
#             best_combined_loss = valid_losses["combined"][-1]
#             torch.save(model, os.path.join(save_dir, "best_model_resnetBackbone.pth"))
#             print(f"Best model saved at epoch {epoch+1} with combined loss {best_combined_loss:.4f}")
            
#         if epoch%10==0:
#             gif_path2 =os.path.join(save_dir,f"viz_epoch_{epoch}.gif")
#             gif_frames[0].save(gif_path2, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

#     # Save GIF
#     # gif_frames[0].save(gif_path, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)
    
#     plot_loss(train_losses, valid_losses, save_dir)
#     gif_frames[0].save(gif_path, save_all=True, append_images=gif_frames[1:], duration=500, loop=0)

    
    
    
#     return train_losses,valid_losses
