# Semantic Segmentation Demo

This is a notebook for running the benchmark semantic segmentation network from the [ADE20K MIT Scene Parsing Benchmark](http://sceneparsing.csail.mit.edu/).

The code for this notebook is available here
https://github.com/CSAILVision/semantic-segmentation-pytorch/tree/master/notebooks

It can be run on Colab at this URL https://colab.research.google.com/github/CSAILVision/semantic-segmentation-pytorch/blob/master/notebooks/DemoSegmenter.ipynb

### Environment Setup

First, download the code and pretrained models if we are on colab.

In [None]:
# %%bash
# # Colab-specific setup
# !(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
# pip install yacs 2>&1 >> install.log
# git init 2>&1 >> install.log
# git remote add origin https://github.com/CSAILVision/semantic-segmentation-pytorch.git 2>> install.log
# git pull origin master 2>&1 >> install.log
# DOWNLOAD_ONLY=1 ./demo_test.sh 2>> install.log

In [None]:
# import torch
# print(torch.__version__)  # Should show the installed PyTorch version
# print(torch.cuda.is_available())  # Should return True if CUDA is working
# print(torch.version.cuda)

In [None]:
#!git clone https://github.com/CSAILVision/semantic-segmentation-pytorch.git

In [None]:
# ls

In [None]:
# !cd semantic-segmentation-pytorch

In [None]:
import sys
# Add the repository checkout directory to the module search path. The path is
# relative, so it resolves against the notebook's current working directory.
# NOTE(review): presumably this is what makes the `mit_semseg` imports below
# resolve — confirm the repository has been cloned into this folder.
sys.path.append('./semantic-segmentation-pytorch')

In [None]:
ls semantic-segmentation-pytorch/

## Imports and utility functions

We need pytorch, numpy, and the code for the segmentation model.  And some utilities for visualizing the data.

In [None]:
# System libs
import os, csv, torch, numpy, scipy.io, PIL.Image, torchvision.transforms
# Our libs
from mit_semseg.models import ModelBuilder, SegmentationModule
from mit_semseg.utils import colorEncode

# colors = scipy.io.loadmat('data/color150.mat')['colors']
# names = {}

# with open('data/object150_info.csv') as f:
#     reader = csv.reader(f)
#     next(reader)
#     for row in reader:
#         names[int(row[0])] = row[5].split(";")[0]

# def visualize_result(img, pred, index=None):
#     # filter prediction class if requested
#     if index is not None:
#         pred = pred.copy()
#         pred[pred != index] = -1
#         print(f'{names[index+1]}:')

#     # colorize prediction
#     pred_color = colorEncode(pred, colors).astype(numpy.uint8)

#     # aggregate images and save
#     im_vis = numpy.concatenate((img, pred_color), axis=1)
#     display(PIL.Image.fromarray(im_vis))

In [None]:
#mkdir -p ckpt/ade20k-resnet50dilated-ppm_deepsup

In [None]:
# import os
# import urllib.request

# # Create directory if it doesn't exist
# ckpt_dir = 'ckpt/ade20k-resnet50dilated-ppm_deepsup'
# os.makedirs(ckpt_dir, exist_ok=True)

# # Download encoder weights
# encoder_url = "http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth"
# encoder_path = f"{ckpt_dir}/encoder_epoch_20.pth"
# urllib.request.urlretrieve(encoder_url, encoder_path)

# # Download decoder weights
# decoder_url = "http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth"
# decoder_path = f"{ckpt_dir}/decoder_epoch_20.pth"
# urllib.request.urlretrieve(decoder_url, decoder_path)

## Loading the segmentation model

Here we load a pretrained segmentation model.  Like any pytorch model, we can call it like a function, or examine the parameters in all the layers.

After loading, we put it on the GPU.  And since we are doing inference, not training, we put the model in eval mode.

In [None]:
# Network Builders
# Build the encoder ('resnet50dilated') and point it at its pretrained
# checkpoint file; the path is relative to the current working directory.
net_encoder = ModelBuilder.build_encoder(
    arch='resnet50dilated',
    fc_dim=2048,
    weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth')
# Build the matching decoder ('ppm_deepsup') for 150 output classes.
# NOTE(review): use_softmax=True presumably switches the decoder to its
# inference output (per-class probabilities) — confirm against mit_semseg.
net_decoder = ModelBuilder.build_decoder(
    arch='ppm_deepsup',
    fc_dim=2048,
    num_class=150,
    weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth',
    use_softmax=True)

# Loss object handed to SegmentationModule's constructor; targets equal to -1
# are ignored by NLLLoss.
crit = torch.nn.NLLLoss(ignore_index=-1)
segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
# Inference only: switch layers such as batch norm / dropout to eval behaviour.
segmentation_module.eval()
# Move the model to the GPU; this call requires a CUDA device.
segmentation_module.cuda()

## Segmentation Mask

In [None]:
# Transform pipeline: convert an image to a tensor (ToTensor), then normalize
# each of the three channels with the mean/std below. It does not load the
# image or add a batch dimension itself; that is left to the code applying it.
pil_to_tensor = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406], # These are RGB mean+std values
        std=[0.229, 0.224, 0.225])  # across a large photo dataset.
])



## Run the Model

Finally we just pass the test image to the segmentation model.

The segmentation model is coded as a function that takes a dictionary as input, because it wants to know both the input batch image data as well as the desired output segmentation resolution.  We ask for full resolution output.

Then we use the previously-defined visualize_result function to render the segmentation map.

## Showing classes individually

To see which colors are which, here we visualize individual classes, one at a time.

In [None]:
# # Top classes in answer
# predicted_classes = numpy.bincount(pred.flatten()).argsort()[::-1]
# for c in predicted_classes[:15]:
#     visualize_result(img_original, pred, c)

In [None]:
# pred.shape

# Autoencoder

In [None]:
import torch
import torch.nn as nn

class MaskAutoencoder(nn.Module):
    """Convolutional autoencoder with a linear bottleneck of size ``k``.

    The encoder applies two stride-2 3x3 convolutions (1 -> 32 -> 64 channels),
    each followed by batch norm and LeakyReLU. The resulting feature map is
    flattened and projected to ``k`` values. The decoder mirrors this with a
    linear layer and two stride-2 transposed convolutions ending in a sigmoid.

    Args:
        height: Height of the input the model is built for.
        width: Width of the input the model is built for.
        k: Size of the encoded vector.

    Attributes:
        encoded_dim: Number of features in the flattened encoder output.
    """

    def __init__(self, height, width, k):
        super().__init__()

        # Encoder: (conv -> batch norm -> LeakyReLU) for each channel step.
        encoder_layers = []
        for in_channels, out_channels in ((1, 32), (32, 64)):
            encoder_layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2),
            ])
        self.encoder = nn.Sequential(*encoder_layers)
        self.flatten = nn.Flatten()

        # Spatial size after the two stride-2 convolutions
        # (kernel 3, padding 1): out = (in + 2*1 - 3) // 2 + 1, applied twice.
        feat_h, feat_w = height, width
        for _ in range(2):
            feat_h = (feat_h + 2 * 1 - 3) // 2 + 1
            feat_w = (feat_w + 2 * 1 - 3) // 2 + 1

        self.encoded_dim = feat_h * feat_w * 64
        self.projection = nn.Linear(self.encoded_dim, k)

        # Decoder: undo the projection, restore the feature-map shape, upsample.
        self.unprojection = nn.Linear(k, self.encoded_dim)
        self.unflatten = nn.Unflatten(1, (64, feat_h, feat_w))
        upsample = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, **upsample),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, **upsample),
            nn.Sigmoid(),  # keeps reconstructed values in [0, 1]
        )

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            A ``(projected, reconstructed)`` tuple: the k-dimensional code and
            the reconstruction cropped to the spatial size of ``x``.
        """
        projected = self.projection(self.flatten(self.encoder(x)))
        decoded = self.decoder(self.unflatten(self.unprojection(projected)))

        # Each transposed convolution doubles the size, so the output can be
        # larger than the input; crop it back to the input's height and width.
        in_h, in_w = x.shape[2], x.shape[3]
        return projected, decoded[:, :, :in_h, :in_w]

# Example reconstruction loss calculation
def reconstruction_loss(original, reconstructed):
    """Return the mean squared error between ``reconstructed`` and ``original``."""
    return nn.functional.mse_loss(reconstructed, original)


In [None]:
# pil_to_tensor = torchvision.transforms.Compose([
#     torchvision.transforms.ToTensor(),
#     torchvision.transforms.Normalize(
#         mean=[0.485, 0.456, 0.406], # These are RGB mean+std values
#         std=[0.229, 0.224, 0.225])  # across a large photo dataset.
# ])

# def extract_and_save_segmentations(image_dir, segmentation_module, output_dir, batch_size=4):
#     """
#     Extract segmentations for all images in a directory and save them.

#     Args:
#         image_dir: Directory containing input images
#         segmentation_module: Pretrained segmentation model
#         output_dir: Directory to save segmentation masks
#         batch_size: Batch size for processing
#     """
#     os.makedirs(output_dir, exist_ok=True)

#     # Image transformation pipeline
#     transform = torchvision.transforms.Compose([
#         torchvision.transforms.ToTensor(),
#         torchvision.transforms.Normalize(
#             mean=[0.485, 0.456, 0.406],
#             std=[0.229, 0.224, 0.225]
#         )
#     ])

#     # Get list of all images
#     image_files = [f for f in os.listdir(image_dir)]

#     for img_file in tqdm(image_files, desc="Processing images"):
#         try:
#             # Load and process image
#             img_path = os.path.join(image_dir, img_file)
#             pil_image = PIL.Image.open(img_path).convert('RGB')
#             img_original = numpy.array(pil_image)
#             img_data = pil_to_tensor(pil_image)
#             singleton_batch = {'img_data': img_data[None].cuda()}
#             output_size = img_data.shape[1:]

#             # Generate segmentation
#             with torch.no_grad():
#                 scores = segmentation_module(singleton_batch, segSize=output_size)

#             # Get prediction
#             _, pred = torch.max(scores, dim=1)
#             pred = pred.cpu()[0].numpy()
#             pred_tensor = torch.tensor(pred, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

#             # Save segmentation mask
#             output_path = os.path.join(output_dir, f"{os.path.splitext(img_file)[0]}_seg.npy")
#             np.save(output_path, pred)

#         except Exception as e:
#             print(f"Error processing {img_file}: {str(e)}")
#             continue

In [None]:
# # Colab
# image_dir = '/content/drive/My Drive/Capstone - AI Guide Dog 2024/output_frames'
# segmentation_output_dir = '/content/drive/My Drive/Capstone - AI Guide Dog 2024/segmentation_masks'



In [None]:
# segmentation_output_dir = 'segmentation_masks/'

In [None]:
# #### NO RERUN #######
# inp = 'output_frames'
# out = 'output2_masks/'
# extract_and_save_segmentations(
#     image_dir=inp,
#     segmentation_module=segmentation_module,  # Your pretrained segmentation module
#     output_dir=out
# )

In [None]:
import torch
import torch.nn as nn
import torchvision
import PIL.Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
from torch.utils.data import random_split

class SegmentationDataset(Dataset):
    """Dataset that loads saved masks from disk and returns them as tensors.

    Args:
        segmentation_dir: Directory that contains the mask files.
        file_list: File names (relative to ``segmentation_dir``), one per item.

    Each item is the array read by ``np.load``, converted to a ``float32``
    tensor with an extra leading dimension of size 1.
    """

    def __init__(self, segmentation_dir, file_list):
        self.segmentation_dir, self.file_list = segmentation_dir, file_list

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        mask_path = os.path.join(self.segmentation_dir, self.file_list[idx])
        mask = torch.tensor(np.load(mask_path), dtype=torch.float32)
        return mask[None]

In [None]:
# with train-test split
def train_autoencoder(autoencoder, segmentation_dir, num_epochs=100, learning_rate=0.001, batch_size=16, eval_frequency=10):
    """Train ``autoencoder`` on the masks in ``segmentation_dir`` with an 80/20 split.

    Args:
        autoencoder: Module whose forward pass returns ``(encoded, reconstructed)``.
        segmentation_dir: Directory of mask files; every entry in it is loaded
            through ``SegmentationDataset``.
        num_epochs: Number of passes over the training split.
        learning_rate: Learning rate for the Adam optimizer.
        batch_size: Batch size used by both data loaders.
        eval_frequency: Validation runs every ``eval_frequency`` epochs.

    Returns:
        The autoencoder after training (on CUDA when available, otherwise CPU).

    Raises:
        ValueError: If the directory holds too few files to leave at least one
            training sample after the 80/20 split.

    Side effects:
        Prints per-epoch losses and writes
        ``autoencoder_checkpoint_epoch_<N>.pth`` to the current working
        directory every 10th epoch.
    """
    # Create dataset
    seg_files = os.listdir(segmentation_dir)
    full_dataset = SegmentationDataset(segmentation_dir, seg_files)

    # Split dataset
    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    if train_size == 0:
        # Fail early with a clear message instead of a sampler error or a
        # division by zero when averaging the loss over an empty loader.
        raise ValueError(
            f"Need at least 2 mask files in {segmentation_dir!r} to build a "
            f"train/test split, found {len(full_dataset)}"
        )
    train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    autoencoder = autoencoder.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=learning_rate)

    def prepare(batch):
        # Training and validation must see identically scaled inputs,
        # otherwise their losses are not comparable.
        return batch.to(device).float() / 255.0

    # Training loop with test evaluation
    for epoch in range(num_epochs):
        # Training phase
        autoencoder.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
            seg_tensor = prepare(batch)

            optimizer.zero_grad()
            _, reconstructed = autoencoder(seg_tensor)
            loss = criterion(reconstructed, seg_tensor)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Print epoch statistics
        avg_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch+1}")
        print(f"Train Loss: {avg_loss:.6f}")

        # Validation phase
        if (epoch + 1) % eval_frequency == 0:
            autoencoder.eval()
            val_loss = 0.0
            with torch.no_grad():
                for batch in test_loader:
                    seg_tensor = prepare(batch)
                    _, reconstructed = autoencoder(seg_tensor)
                    val_loss += criterion(reconstructed, seg_tensor).item()
            print(f"Validation Loss: {val_loss/len(test_loader):.6f}")

        if (epoch + 1) % 10 == 0:
            checkpoint_path = f'autoencoder_checkpoint_epoch_{epoch+1}.pth'
            torch.save({
                'epoch': epoch,
                'model_state_dict': autoencoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Average training loss of this epoch; previously an
                # undefined name was referenced here.
                'loss': avg_loss,
            }, checkpoint_path)

    return autoencoder

In [None]:
k = 648  # Encoded dimension size
# Build the autoencoder for inputs of height 270 and width 480 with a
# k-dimensional bottleneck.
autoencoder = MaskAutoencoder(height=270, width=480, k=k)


In [None]:
from torch.optim.lr_scheduler import CyclicLR


In [None]:
from datetime import datetime

# Train the autoencoder on the saved masks with an 80/20 train/test split,
# Adam plus a cyclic learning rate, validation every `eval_frequency` epochs
# and a timestamped checkpoint every 20 epochs.
segmentation_dir = 'segmentation_input/extracted_masks'
seg_files = os.listdir(segmentation_dir)
full_dataset = SegmentationDataset(segmentation_dir, seg_files)
num_epochs = 100
learning_rate = 0.0005  # only used by the commented-out fixed-LR optimizer below
base_lr = 0.0001
max_lr = 0.001
batch_size = 16
eval_frequency = 10 # number of epochs per eval

# Split dataset
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
autoencoder = autoencoder.to(device)
criterion = nn.MSELoss()
# optimizer = torch.optim.Adam(autoencoder.parameters(), lr=learning_rate)
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=base_lr, weight_decay=1e-5)
# NOTE(review): the scheduler is stepped once per batch, so step_size_up=4
# gives a full LR cycle every 8 batches — confirm this is intended rather
# than a cycle measured in epochs.
scheduler = CyclicLR(
    optimizer,
    base_lr=base_lr,
    max_lr=max_lr,
    step_size_up=4,
    mode='triangular',
    cycle_momentum=False  # Disable cycle_momentum
)


train_losses = []
test_losses = []

# Training loop with test evaluation
for epoch in range(num_epochs):
    # Training phase
    autoencoder.train()
    train_loss = 0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        seg_tensor = batch.to(device)
        seg_tensor = seg_tensor.float() / 255.0

        optimizer.zero_grad()
        encoded, reconstructed = autoencoder(seg_tensor)
        loss = criterion(reconstructed, seg_tensor)
        loss.backward()
        optimizer.step()

        scheduler.step()

        train_loss += loss.item()

    # Print epoch statistics
    avg_loss = train_loss / len(train_loader)
    train_losses.append(avg_loss)
    print(f"Epoch {epoch+1}, Train Loss: {avg_loss:.6f}")

    # Validation phase
    if (epoch + 1) % eval_frequency == 0:
        autoencoder.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in test_loader:
                seg_tensor = batch.to(device)
                # Apply the same scaling as in training; without it the
                # validation loss is computed on unscaled inputs and is not
                # comparable with the training loss.
                seg_tensor = seg_tensor.float() / 255.0
                encoded, reconstructed = autoencoder(seg_tensor)
                loss = criterion(reconstructed, seg_tensor)
                val_loss += loss.item()
        # Optional visual check of the last validation batch:
        # plt.figure(figsize=(10, 5))
        # for i in range(4):
        #     plt.subplot(2, 4, i + 1)
        #     plt.imshow(seg_tensor[i, 0].cpu().numpy(), cmap='gray')
        #     plt.title('Original')
        #     plt.axis('off')
        #     plt.subplot(2, 4, i + 5)
        #     plt.imshow(reconstructed[i, 0].cpu().numpy(), cmap='gray')
        #     plt.title('Reconstructed')
        #     plt.axis('off')
        # plt.show()
        avg_val_loss = val_loss / len(test_loader)
        test_losses.append(avg_val_loss)
        print(f"Validation Loss: {avg_val_loss:.6f}")

    if (epoch + 1) % 20 == 0:
        # Timestamp the file name so reruns do not overwrite earlier checkpoints.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        checkpoint_path = f'autoencoder_checkpoint_epoch_{epoch+1}_{timestamp}.pth'
        torch.save({
            'epoch': epoch,
            'model_state_dict': autoencoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)

In [None]:
# Move the autoencoder to the GPU.
# NOTE(review): this call needs a CUDA device, whereas the training cell picks
# `device` with a CPU fallback — confirm the unconditional .cuda() is intended.
autoencoder.cuda()

In [None]:
# train_losses - normalize tensor, dynamic learning rate
# 0.010598669550978603,
#  0.009087359782127912,
#  0.004185134204014329,
#  0.0036549259360582485,
#  0.0034981117865065395,
#  0.0033929046035141943,
#  0.003267997186604431,
#  0.0031347709271441243,
#  0.0031136135601218355,
#  0.0030799311230176704,
#  0.0029925421501199403,
#  0.0029478789870348624,
#  0.002950479022089254,
#  0.002869488670095334,
#  0.002872124812472877,
#  0.0027697143573420084,
#  0.0027105508848462266,
#  0.0026728716898481346,
#  0.0026305040820207223,
#  0.002590992351983305,
#  0.0025186698235030104,
#  0.0024867024474334175,
#  0.0025063614902269635,
#  0.002402767481645861,
#  0.0023577577461676934,
#  0.002360801220449627,
#  0.0022824687652417226,
#  0.0023271376243858806,
#  0.0022554369334879722,
#  0.0022152050453346974,
#  0.00219290150943719,
#  0.0022095415984649223,
#  0.0021607008453543453,
#  0.0022139052241547112,
#  0.0021012918201097396,
#  0.0021487404303784426,
#  0.0021049527768470196,
#  0.002043176349186129,
#  0.0021780415015571096,
#  0.0020342982106144273,
#  0.0020528297162909284,
#  0.002011137314345104,
#  0.0020772715877248053,
#  0.00196996894137711,
#  0.001978731192807611,
#  0.002019311662264529,
#  0.0021013033349449053,
#  0.0019783886884832676,
#  0.0019796134206198387,
#  0.001969414494170189,
#  0.001913465900768518,
#  0.001958244603241335,
#  0.001977357328497935,
#  0.0018932805781350856,
#  0.001982925513141027,
#  0.0019326269619320746,
#  0.0018998307940096427,
#  0.0019092814302989986,
#  0.001899283706672235,
#  0.0018762577307858952,
#  0.0019800624560745183,
#  0.0018605668813473097,
#  0.0018876566902530031,
#  0.0018942271691927436,
#  0.0018847197260405135,
#  0.0018821798951200546,
#  0.001874544504868063,
#  0.0018958703324893898,
#  0.001916885769002649,
#  0.0018252110996938966,
#  0.0018747157100163548,
#  0.0018779334510700442,
#  0.0018953005942460308,
#  0.0018827289892121768,
#  0.0018879924637154064,
#  0.001826623885194503,
#  0.0018644464750140056,
#  0.001829887837326775,
#  0.001967043435450463,
#  0.0018982560956914327,
#  0.0018301724260235126,
#  0.001860489436137381,
#  0.0018488050468645437,
#  0.001918041868336987,
#  0.0018233922399723758,
#  0.0018673976959575741,
#  0.0018486381648373698,
#  0.001847924745709309,
#  0.0017870324852868405,
#  0.0019125953823144939,
#  0.0018307854515298714,
#  0.0018192705097329668,
#  0.0018259260439953404,
#  0.001847274343422123,
#  0.0018665836162154066,
#  0.0018549766855850349,
#  0.0018079144353819136,
#  0.0018595705990231148,
#  0.0018406131396340763,
#  0.0018173889850698773]

In [None]:
# test_losses
# [674.6184588345615,
#  669.6901964707808,
#  669.3669828935103,
#  669.2363950555974,
#  669.2811737060547,
#  668.5389347076416,
#  668.6799944097346,
#  668.8191734660755,
#  670.0550876964222,
#  668.7530928525058]

In [None]:
# # norm tensor, lr = 0.0005
##### train loss
# [0.012919229262925026,
#  0.010495140084577591,
#  0.010502334075855182,
#  0.010507502728420446,
#  0.010485618315541591,
#  0.010513703299001751,
#  0.010506574628328071,
#  0.010474914830162102,
#  0.010476896819921259,
#  0.010497446879717997,
#  0.010483136958445282,
#  0.010488326502751633,
#  0.010519665383343767,
#  0.010484851373391401,
#  0.010485852921270153,
#  0.010498622436331123,
#  0.01049230278514729,
#  0.010421480214026918,
#  0.010359131299925071,
#  0.010377319968780475,
#  0.010368558291632395,
#  0.010366258009116662,
#  0.010376635081853791,
#  0.010358538960857566,
#  0.010361543597860469,
#  0.010378243408279966,
#  0.010363919505891933,
#  0.010391937849374536,
#  0.01036856722078708,
#  0.010374640451579691,
#  0.010385230904646705,
#  0.01036395997995049,
#  0.010369975140392567,
#  0.010380298144016892,
#  0.010357583891547303,
#  0.010378568361584958,
#  0.010395116816166035,
#  0.010365820313376664,
#  0.01037341430372004,
#  0.010382045142649754,
#  0.010369892828458353,
#  0.01036810908968059,
#  0.010363976306618958,
#  0.010389761303410868,
#  0.01036483648119022,
#  0.010375067760154308,
#  0.010382358115458293,
#  0.01036999827155318,
#  0.01035941466145175,
#  0.010365696513144422,
#  0.010379880772866414,
#  0.010377261351178429,
#  0.010380936196164103,
#  0.010368056516388585,
#  0.01036732699297517,
#  0.010383301340919146,
#  0.010361180566463546,
#  0.010379377396331511,
#  0.010371307840443447,
#  0.010383967946991961,
#  0.010379769261581477,
#  0.010364630549450927,
#  0.010393445259968397,
#  0.010372201000682564,
#  0.010361896019211395,
#  0.010377790731778554,
#  0.01038492830374684,
#  0.010386664957667772,
#  0.010391447969056942,
#  0.010383025772254011,
#  0.010374613995783223,
#  0.010389668922048815,
#  0.010372335946272284,
#  0.010381413854531243,
#  0.010374731563318234,
#  0.010384616603215154,
#  0.01036797573609527,
#  0.010367941154443004,
#  0.010362121386007749,
#  0.010378099776316232,
#  0.01037243392419314,
#  0.010412328782361372,
#  0.010378307600550757,
#  0.010391010809955093,
#  0.010383097464855537,
#  0.010361636474070556,
#  0.01039646823553533,
#  0.01037824592630026,
#  0.01038795979653732,
#  0.01035960474346056,
#  0.010373186517076997,
#  0.010373493545739102,
#  0.010364184195860328,
#  0.010390874065400402,
#  0.010368071978067996,
#  0.010388104152199719,
#  0.010358121579093726,
#  0.010385539870910834,
#  0.010366724525442999,
#  0.010368037610009652]

In [None]:
# test_losses
# [689.2446632385254,
#  689.2478457364169,
#  689.2478457364169,
#  689.2478457364169,
#  689.2478457364169,
#  689.2478457364169,
#  689.2478457364169,
#  689.2478457364169,
#  689.2478457364169,
#  689.2478457364169]

In [None]:
# #### lr = 0.0005, normalize in autoencoder layers
# Epoch 1, Train Loss: 652.069466
# Training Epoch 2: 100%|█████████████████████████████████████| 351/351 [01:39<00:00,  3.53it/s]
# Epoch 2, Train Loss: 652.692586
# Training Epoch 3: 100%|█████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 3, Train Loss: 652.600141
# Training Epoch 4: 100%|█████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 4, Train Loss: 651.894051
# Training Epoch 5: 100%|█████████████████████████████████████| 351/351 [01:39<00:00,  3.55it/s]
# Epoch 5, Train Loss: 651.884299
# Training Epoch 6: 100%|█████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 6, Train Loss: 651.157909
# Training Epoch 7: 100%|█████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 7, Train Loss: 652.808403
# Training Epoch 8: 100%|█████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 8, Train Loss: 651.930558
# Training Epoch 9: 100%|█████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 9, Train Loss: 652.140237
# Training Epoch 10: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.53it/s]
# Epoch 10, Train Loss: 651.177875
# Validation Loss: 659.643295
# Training Epoch 11: 100%|████████████████████████████████████| 351/351 [01:38<00:00,  3.55it/s]
# Epoch 11, Train Loss: 652.893720
# Training Epoch 12: 100%|████████████████████████████████████| 351/351 [01:38<00:00,  3.55it/s]
# Epoch 12, Train Loss: 651.975086
# Training Epoch 13: 100%|████████████████████████████████████| 351/351 [01:38<00:00,  3.55it/s]
# Epoch 13, Train Loss: 652.931841
# Training Epoch 14: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 14, Train Loss: 652.330461
# Training Epoch 15: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 15, Train Loss: 653.820897
# Training Epoch 16: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 16, Train Loss: 651.770907
# Training Epoch 17: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.53it/s]
# Epoch 17, Train Loss: 651.725226
# Training Epoch 18: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 18, Train Loss: 652.241266
# Training Epoch 19: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 19, Train Loss: 651.556436
# Training Epoch 20: 100%|████████████████████████████████████| 351/351 [01:38<00:00,  3.55it/s]
# Epoch 20, Train Loss: 651.854807
# Validation Loss: 659.645702
# Training Epoch 21: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 21, Train Loss: 651.743324
# Training Epoch 22: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.55it/s]
# Epoch 22, Train Loss: 652.704887
# Training Epoch 23: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.53it/s]
# Epoch 23, Train Loss: 652.056867
# Training Epoch 24: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.53it/s]
# Epoch 24, Train Loss: 652.186908
# Training Epoch 25: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 25, Train Loss: 652.562346
# Training Epoch 26: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 26, Train Loss: 652.013083
# Training Epoch 27: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 27, Train Loss: 652.361789
# Training Epoch 28: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 28, Train Loss: 651.564174
# Training Epoch 29: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 29, Train Loss: 651.654328
# Training Epoch 30: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 30, Train Loss: 651.677792
# Validation Loss: 659.642267
# Training Epoch 31: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 31, Train Loss: 652.280122
# Training Epoch 32: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 32, Train Loss: 652.841016
# Training Epoch 33: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.53it/s]
# Epoch 33, Train Loss: 652.137129
# Training Epoch 34: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 34, Train Loss: 652.149949
# Training Epoch 35: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.53it/s]
# Epoch 35, Train Loss: 651.757967
# Training Epoch 36: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 36, Train Loss: 653.631322
# Training Epoch 37: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 37, Train Loss: 652.113598
# Training Epoch 38: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 38, Train Loss: 651.450653
# Training Epoch 39: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 39, Train Loss: 652.112651
# Training Epoch 40: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 40, Train Loss: 653.152022
# Validation Loss: 659.641706
# Training Epoch 41: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 41, Train Loss: 651.838199
# Training Epoch 42: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 42, Train Loss: 653.216685
# Training Epoch 43: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 43, Train Loss: 651.788312
# Training Epoch 44: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.55it/s]
# Epoch 44, Train Loss: 652.100813
# Training Epoch 45: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 45, Train Loss: 651.938989
# Training Epoch 46: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.53it/s]
# Epoch 46, Train Loss: 652.252912
# Training Epoch 47: 100%|████████████████████████████████████| 351/351 [01:38<00:00,  3.55it/s]
# Epoch 47, Train Loss: 652.736892
# Training Epoch 48: 100%|████████████████████████████████████| 351/351 [01:39<00:00,  3.54it/s]
# Epoch 48, Train Loss: 652.466110


In [None]:
#### lr = 0.005, no normalization
# Training Epoch 1: 100%|█████████████████████████████████████| 351/351 [02:08<00:00,  2.74it/s]
# Epoch 1, Train Loss: 656.253363
# Training Epoch 2: 100%|█████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 2, Train Loss: 657.944037
# Training Epoch 3: 100%|█████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 3, Train Loss: 656.778741
# Training Epoch 4: 100%|█████████████████████████████████████| 351/351 [01:37<00:00,  3.58it/s]
# Epoch 4, Train Loss: 656.744963
# Training Epoch 5: 100%|█████████████████████████████████████| 351/351 [01:37<00:00,  3.58it/s]
# Epoch 5, Train Loss: 656.823661
# Training Epoch 6: 100%|█████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 6, Train Loss: 660.347832
# Training Epoch 7: 100%|█████████████████████████████████████| 351/351 [01:37<00:00,  3.60it/s]
# Epoch 7, Train Loss: 656.921275
# Training Epoch 8: 100%|█████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 8, Train Loss: 656.122115
# Training Epoch 9: 100%|█████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 9, Train Loss: 656.077839
# Training Epoch 10: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 10, Train Loss: 656.742585
# Validation Loss: 639.907631
# Training Epoch 11: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 11, Train Loss: 656.353283
# Training Epoch 12: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 12, Train Loss: 656.320721
# Training Epoch 13: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 13, Train Loss: 657.266287
# Training Epoch 14: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 14, Train Loss: 656.678734
# Training Epoch 15: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 15, Train Loss: 657.124233
# Training Epoch 16: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.58it/s]
# Epoch 16, Train Loss: 656.573825
# Training Epoch 17: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.60it/s]
# Epoch 17, Train Loss: 657.191130
# Training Epoch 18: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 18, Train Loss: 656.645940
# Training Epoch 19: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 19, Train Loss: 657.434529
# Training Epoch 20: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 20, Train Loss: 657.011962
# Validation Loss: 639.907631
# Training Epoch 21: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.60it/s]
# Epoch 21, Train Loss: 658.020220
# Training Epoch 22: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 22, Train Loss: 657.597698
# Training Epoch 23: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 23, Train Loss: 656.368929
# Training Epoch 24: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 24, Train Loss: 656.939704
# Training Epoch 25: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 25, Train Loss: 656.370394
# Training Epoch 26: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 26, Train Loss: 656.970797
# Training Epoch 27: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 27, Train Loss: 655.919122
# Training Epoch 28: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 28, Train Loss: 657.775658
# Training Epoch 29: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 29, Train Loss: 656.966953
# Training Epoch 30: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.59it/s]
# Epoch 30, Train Loss: 657.083070
# Validation Loss: 639.907631
# Training Epoch 31: 100%|████████████████████████████████████| 351/351 [01:37<00:00,  3.60it/s]
# Epoch 31, Train Loss: 656.580937
# Training Epoch 32:  17%|██████▍                              | 61/351 [00:17<01:22,  3.54it/s]

In [None]:
# Release unoccupied GPU memory held by PyTorch's caching allocator; tensors
# that are still referenced are not freed by this call.
torch.cuda.empty_cache()

In [None]:
# Save only the model weights (state_dict) to the current working directory.
# NOTE(review): `trained_autoencoder` is not assigned in any visible cell — the
# training cell updates `autoencoder` in place and `train_autoencoder(...)` is
# never called here. Confirm which object should be saved.
torch.save(trained_autoencoder.state_dict(), 'trained_autoencoder_final.pth')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskAutoencoder2(nn.Module):
    """Convolutional autoencoder for single-channel masks with a dense bottleneck.

    Layout:
        * encoder: three Conv2d -> BatchNorm2d -> LeakyReLU(0.2) -> Dropout(0.1)
          groups (1->32 and 32->64 with stride 2, 64->128 with stride 1);
        * bottleneck: flatten -> fc1 -> fc2 (latent of size ``k``) -> fc3 -> fc4;
        * decoder: two stride-2 ConvTranspose2d groups (each doubles the
          spatial size) followed by a 3x3 Conv2d to one channel and a Sigmoid.

    The linear layers are sized from ``height``/``width``, so inputs must have
    exactly that spatial size. ``fc1`` and ``fc4`` each hold
    ``encoded_dim * 2k`` weights, which grows quickly with the input size.

    Args:
        height: Input mask height in pixels.
        width: Input mask width in pixels.
        k: Size of the latent vector returned by :meth:`encode`.
    """

    def __init__(self, height, width, k):
        super().__init__()

        # Down-sampling feature extractor (layer order fixes state_dict keys).
        self.encoder = nn.Sequential(
            *self._down_group(1, 32, stride=2),
            *self._down_group(32, 64, stride=2),
            *self._down_group(64, 128, stride=1),
        )

        # Only the two stride-2 convolutions shrink the feature map; the
        # stride-1 convolution with padding 1 leaves its size unchanged.
        enc_h, enc_w = height, width
        for _ in range(2):
            enc_h = self._conv_out(enc_h)
            enc_w = self._conv_out(enc_w)

        self.encoded_dim = enc_h * enc_w * 128

        # Dense bottleneck: encoded_dim -> 2k -> k -> 2k -> encoded_dim.
        hidden = k * 2
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(self.encoded_dim, hidden)
        self.fc2 = nn.Linear(hidden, k)
        self.fc3 = nn.Linear(k, hidden)
        self.fc4 = nn.Linear(hidden, self.encoded_dim)

        # Restore the (channels, height, width) layout expected by the decoder.
        self.unflatten = nn.Unflatten(1, (128, enc_h, enc_w))

        self.decoder = nn.Sequential(
            *self._up_group(128, 64),
            *self._up_group(64, 32),
            nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _conv_out(size, kernel_size=3, stride=2, padding=1):
        """Return the output length of a convolution along one dimension."""
        return (size + 2 * padding - kernel_size) // stride + 1

    @staticmethod
    def _down_group(in_ch, out_ch, stride):
        """Return the layers of one encoder group, in order."""
        return (
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
        )

    @staticmethod
    def _up_group(in_ch, out_ch):
        """Return the layers of one 2x up-sampling decoder group, in order."""
        return (
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
        )

    def encode(self, x):
        """Map a batch of masks to latent vectors of size ``k``."""
        features = self.flatten(self.encoder(x))
        hidden = F.leaky_relu(self.fc1(features), negative_slope=0.2)
        return self.fc2(hidden)

    def decode(self, z):
        """Map latent vectors back to single-channel images in [0, 1]."""
        hidden = F.leaky_relu(self.fc3(z), negative_slope=0.2)
        features = F.leaky_relu(self.fc4(hidden), negative_slope=0.2)
        return self.decoder(self.unflatten(features))

    def forward(self, x):
        """Return ``(latent, reconstruction)`` for the input batch ``x``.

        The decoder output can be slightly larger than the input when the
        input size is not a multiple of 4, so it is bilinearly resized to the
        input's spatial size.
        """
        latent = self.encode(x)
        target_size = (x.shape[2], x.shape[3])
        resized = F.interpolate(
            self.decode(latent),
            size=target_size,
            mode='bilinear',
            align_corners=False,
        )
        return latent, resized





In [None]:
# Size of the latent vector produced by MaskAutoencoder2.encode().
k = 648  # Encoded dimension size
# Build the model for 270x480 (height x width) single-channel masks.
# NOTE(review): at this input size the encoder output flattens to
# 68 * 120 * 128 = 1,044,480 features, so fc1 and fc4 each hold
# 1,044,480 * 1,296 (about 1.35e9) weights, roughly 5.4 GB each in float32.
# Confirm this fits in memory (plus gradients and Adam state) before training.
autoencoder2 = MaskAutoencoder2(height=270, width=480, k=k)


In [None]:

from datetime import datetime

# Training cell for `autoencoder2`: random 80/20 train/test split, MSE
# reconstruction loss, Adam with a triangular cyclic learning rate, validation
# every `eval_frequency` epochs and a checkpoint every `checkpoint_frequency`
# epochs. Per-epoch averages are collected in `train_losses` / `test_losses`.

segmentation_dir = 'segmentation_input/extracted_masks'
# Sorted so the index -> file mapping is reproducible; os.listdir returns
# entries in arbitrary order.
seg_files = sorted(os.listdir(segmentation_dir))
full_dataset = SegmentationDataset(segmentation_dir, seg_files)
num_epochs = 100
learning_rate = 0.0005  # only used by the commented-out fixed-LR optimizer below
base_lr = 0.0001
max_lr = 0.001
batch_size = 16
eval_frequency = 10 # number of epochs per eval
checkpoint_frequency = 20  # number of epochs per checkpoint
lr_ramp_epochs = 4  # epochs spent ramping base_lr -> max_lr (half an LR cycle)

# Split dataset
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
if train_size == 0:
    # Fail with an explicit message instead of an opaque sampler error.
    raise ValueError(
        f"Not enough samples to build a training split: found "
        f"{len(full_dataset)} in {segmentation_dir!r}"
    )
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
autoencoder2 = autoencoder2.to(device)
criterion = nn.MSELoss()
# optimizer = torch.optim.Adam(autoencoder2.parameters(), lr=learning_rate)
optimizer = torch.optim.Adam(autoencoder2.parameters(), lr=base_lr, weight_decay=1e-5)

# scheduler.step() is called once per batch and CyclicLR measures
# `step_size_up` in those iterations, so the ramp length is converted from
# epochs to batches. A literal step_size_up=4 would sweep
# base_lr -> max_lr -> base_lr every 8 batches.
# NOTE(review): assumes the original "4" was meant as epochs - confirm.
steps_per_epoch = len(train_loader)
scheduler = CyclicLR(
    optimizer,
    base_lr=base_lr,
    max_lr=max_lr,
    step_size_up=lr_ramp_epochs * steps_per_epoch,
    mode='triangular',
    cycle_momentum=False  # Disable cycle_momentum
)


train_losses = []
test_losses = []

# Training loop with test evaluation
for epoch in range(num_epochs):
    # Training phase
    autoencoder2.train()
    train_loss = 0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        # Rescale by 1/255 so targets match the model's sigmoid output range.
        # NOTE(review): assumes the dataset yields values in 0..255 - confirm.
        seg_tensor = batch.to(device).float() / 255.0

        optimizer.zero_grad()
        encoded, reconstructed = autoencoder2(seg_tensor)
        loss = criterion(reconstructed, seg_tensor)
        loss.backward()
        optimizer.step()

        # The cyclic schedule advances once per optimizer step (per batch).
        scheduler.step()

        train_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    avg_loss = train_loss / len(train_loader)
    train_losses.append(avg_loss)
    print(f"Epoch {epoch+1}, Train Loss: {avg_loss:.6f}")

    # Validation phase
    if (epoch + 1) % eval_frequency == 0:
        autoencoder2.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in test_loader:
                seg_tensor = batch.to(device).float() / 255.0
                encoded, reconstructed = autoencoder2(seg_tensor)
                loss = criterion(reconstructed, seg_tensor)
                val_loss += loss.item()
        # Optional visual check of the last validation batch (uncomment to use):
        # plt.figure(figsize=(10, 5))
        # for i in range(4):
        #     plt.subplot(2, 4, i + 1)
        #     plt.imshow(seg_tensor[i, 0].cpu().numpy(), cmap='gray')
        #     plt.title('Original')
        #     plt.axis('off')
        #     plt.subplot(2, 4, i + 5)
        #     plt.imshow(reconstructed[i, 0].cpu().numpy(), cmap='gray')
        #     plt.title('Reconstructed')
        #     plt.axis('off')
        # plt.show()
        avg_val_loss = val_loss / len(test_loader)
        test_losses.append(avg_val_loss)
        print(f"Validation Loss: {avg_val_loss:.6f}")

    # Periodic checkpoint; the timestamp keeps reruns from overwriting files.
    if (epoch + 1) % checkpoint_frequency == 0:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        checkpoint_path = f'autoencoder2_checkpoint_epoch_{epoch+1}_{timestamp}.pth'
        torch.save({
            # Zero-based index of the epoch just finished; the file name
            # carries the one-based epoch number.
            'epoch': epoch,
            'model_state_dict': autoencoder2.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)