In [1]:
import torch
import os
import torchvision.models as models
import torch.nn as nn
from torch.optim import AdamW
from torch.optim import lr_scheduler
from model import Model
from torch.utils.data import DataLoader
from torchvision.datasets import Cityscapes, wrap_dataset_for_transforms_v2
from torchvision.utils import make_grid
from torchvision.transforms.v2 import (
    Compose,
    Normalize,
    Resize,
    ToImage,
    ToDtype,
    RandomHorizontalFlip,
    RandomVerticalFlip,
)

# Mapping class IDs to train IDs
id_to_trainid = {cls.id: cls.train_id for cls in Cityscapes.classes}
def convert_to_train_id(label_img: torch.Tensor) -> torch.Tensor:
    return label_img.apply_(lambda x: id_to_trainid[x])

# Mapping train IDs to color
train_id_to_color = {cls.train_id: cls.color for cls in Cityscapes.classes if cls.train_id != 255}
train_id_to_color[255] = (0, 0, 0)  # Assign black to ignored labels

def convert_train_id_to_color(prediction: torch.Tensor) -> torch.Tensor:
    batch, _, height, width = prediction.shape
    color_image = torch.zeros((batch, 3, height, width), dtype=torch.uint8)

    for train_id, color in train_id_to_color.items():
        mask = prediction[:, 0] == train_id

        for i in range(3):
            color_image[:, i][mask] = color[i]

    return color_image

In [2]:
# Set seed for reproducability
# If you add other sources of randomness (NumPy, Random), 
# make sure to set their seeds as well
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Define the transforms to apply to the data
class PaintingByNumbersTransform:
      def __init__(self, id_to_color=None):
          self.id_to_color = id_to_color  # Dictionary mapping class IDs to colors
  
      def random_recolor(self, label_img):
          """Assigns random colors to segmentation labels."""
          h, w = label_img.shape[1:]
          recolored = torch.zeros((3, h, w), dtype=torch.uint8)  # Create an empty RGB image
                  
          unique_labels = label_img.unique()
          color_map = {label.item(): torch.randint(0, 256, (3,), dtype=torch.uint8) for label in unique_labels}

          for label, color in color_map.items():
              mask = (label_img[0] == label)  # label_img shape is [1, h, w]
              for c in range(3):
                  recolored[c][mask] = color[c]
             
          return recolored
  
      def __call__(self, img, target):
          if torch.rand(1).item() > 0.5:
              # Load the actual ground truth color image
              gt_color = self.random_recolor(target)
  
              # Blend image and color segmentation map
              alpha = torch.rand(1).item() * 0.29 + 0.7  # Random alpha between 0.7 and 0.99
              blended_img = alpha * img + (1 - alpha) * gt_color.float() / 255.0
              return blended_img, target
          
          return img, target  # If not applying transformation, return original


transform = Compose([
    ToImage(),
    Resize((256, 256)),
    ToDtype(torch.float32, scale=True),
    Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    PaintingByNumbersTransform(),
    RandomHorizontalFlip(p=0.5),
    RandomVerticalFlip(p=0.5),
])

# Load the dataset and make a split for training and validation
train_dataset = Cityscapes(
    "data/cityscapes", 
    split="train", 
    mode="fine", 
    target_type="semantic", 
    transforms=transform
)
valid_dataset = Cityscapes(
    "data/cityscapes", 
    split="val", 
    mode="fine", 
    target_type="semantic", 
    transforms=transform
)

train_dataset = wrap_dataset_for_transforms_v2(train_dataset)
valid_dataset = wrap_dataset_for_transforms_v2(valid_dataset)

train_dataloader = DataLoader(
    train_dataset, 
    batch_size=64, 
    shuffle=True,
    num_workers=10
)
valid_dataloader = DataLoader(
    valid_dataset, 
    batch_size=64, 
    shuffle=False,
    num_workers=10
)

In [8]:
# Define the loss function
criterion = nn.CrossEntropyLoss(ignore_index=255)  # Ignore the void class

# Define the optimizer
lr1 = 0.01
lr2 = 0.001

model = Model().to(device)

optimizer1 = AdamW([
{"params": model.model.classifier.parameters(), "lr": lr1}  # Higher LR for classifier
])

optimizer2 = AdamW([
{"params": model.model.backbone.parameters(), "lr": lr2}  # Lower LR for backbone
])

scheduler = lr_scheduler.StepLR(optimizer1, 2, gamma=0.5, last_epoch=-1)



for param in model.model.backbone.parameters():
    param.requires_grad = True  # Unfreeze the backbone
    
for param in model.model.classifier.parameters():
    param.requires_grad = True



In [10]:
# Training loop
best_valid_loss = float('inf')
current_best_model_path = None
for epoch in range(10):
    print(f"Epoch {epoch+1:04}/{10:04}")

    # Training
    model.train()
    for i, (images, labels) in enumerate(train_dataloader):

        labels = convert_to_train_id(labels)  # Convert class IDs to train IDs
        images, labels = images.to(device), labels.to(device)

        labels = labels.long().squeeze(1)  # Remove channel dimension

        optimizer1.zero_grad()
        outputs = model.model(images)['out']
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer1.step()
        
    # Validation
    model.eval()
    with torch.no_grad():
        losses = []
        for i, (images, labels) in enumerate(valid_dataloader):
            labels = convert_to_train_id(labels)  # Convert class IDs to train IDs
            images, labels = images.to(device), labels.to(device)

            labels = labels.long().squeeze(1)  # Remove channel dimension

            outputs = model.model(images)['out']
            loss = criterion(outputs, labels)
            losses.append(loss.item())
        
            if i == 0:
                predictions = outputs.softmax(1).argmax(1)

                predictions = predictions.unsqueeze(1)
                labels = labels.unsqueeze(1)

                predictions = convert_train_id_to_color(predictions)
                labels = convert_train_id_to_color(labels)

                predictions_img = make_grid(predictions.cpu(), nrow=8)
                labels_img = make_grid(labels.cpu(), nrow=8)

                predictions_img = predictions_img.permute(1, 2, 0).numpy()
                labels_img = labels_img.permute(1, 2, 0).numpy()

        
        valid_loss = sum(losses) / len(losses)
       
        # if valid_loss < best_valid_loss:
        #     best_valid_loss = valid_loss
        #     if current_best_model_path:
        #         os.remove(current_best_model_path)
        #     current_best_model_path = os.path.join(
        #         output_dir, 
        #         f"best_model-epoch={epoch:04}-val_loss={valid_loss:04}.pth"
        #     )
        #     torch.save(model.state_dict(), current_best_model_path)
    
print("Training complete!")

Epoch 0001/0010
Epoch 0002/0010


KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), f"model_from_notebook.pth")

In [18]:
from diceloss import DiceLoss
# Get one batch from the dataloader
model.eval()  # No dropout or batchnorm updates
with torch.no_grad():
    for images, labels in train_dataloader:
        labels = convert_to_train_id(labels)
        images, labels = images.to(device), labels.to(device)
        labels = labels.long().squeeze(1)
        
        outputs = model.model(images)['out']
        break

In [19]:
torch.save({
    'images': images.cpu(),
    'labels': labels.cpu(),
    'outputs': outputs.cpu()
}, 'batch_dump.pt')