In [None]:
import torch
import os
import torchvision.models as models
import torch.nn as nn
from torch.optim import AdamW
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision.datasets import Cityscapes, wrap_dataset_for_transforms_v2
from torchvision.utils import make_grid
from torchvision.transforms.v2 import (
    Compose,
    Normalize,
    Resize,
    ToImage,
    ToDtype,
)

# Build DeepLabV3 on a ResNet-50 backbone (chosen over ResNet-101 because it is smaller)
deeplabv3 = models.segmentation.deeplabv3_resnet50()

# Swap the final 1x1 classification conv so it emits 19 channels instead of 21
deeplabv3.classifier[4] = nn.Conv2d(in_channels=256, out_channels=19, kernel_size=(1, 1))

model = deeplabv3
# Xavier-normal initialization for the freshly created output layer's weights
nn.init.xavier_normal_(model.classifier[4].weight)
def count_parameters(model):
    """Return the total number of scalar values across all parameters of `model`.

    Every tensor yielded by `model.parameters()` is counted, whether or not it
    currently requires gradients.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total

# Freeze the whole backbone first ...
deeplabv3.backbone.requires_grad_(False)

# ... then re-enable gradients for the last ResNet stage only
deeplabv3.backbone.layer4.requires_grad_(True)

# Report parameter counts for the backbone and for the classifier head
for submodule in (deeplabv3.backbone, deeplabv3.classifier):
    print(count_parameters(submodule))

# Lookup from each Cityscapes class `id` to its `train_id`
id_to_trainid = dict((entry.id, entry.train_id) for entry in Cityscapes.classes)
def convert_to_train_id(label_img: torch.Tensor) -> torch.Tensor:
    """Replace every class ID in `label_img` with its train ID, in place.

    The mapping comes from the module-level `id_to_trainid` dict. The lookup is
    done with a vectorized table instead of `Tensor.apply_`, which only works
    on CPU tensors and calls back into Python once per element.

    Args:
        label_img: Integer-valued tensor of class IDs (any shape, any device).

    Returns:
        `label_img` itself, overwritten with train IDs (dtype is unchanged).

    Raises:
        KeyError: If `label_img` contains a value that is not a key of
            `id_to_trainid`; the first offending value is the exception argument.
    """
    class_ids = torch.tensor(list(id_to_trainid.keys()), dtype=torch.long)
    train_ids = torch.tensor(list(id_to_trainid.values()), dtype=torch.long)

    # Shift by the smallest class ID so that negative IDs get a valid table slot
    lowest = int(class_ids.min())
    # Sentinel strictly below every real train ID marks slots with no mapping
    sentinel = int(train_ids.min()) - 1
    lookup = torch.full((int(class_ids.max()) - lowest + 1,), sentinel, dtype=torch.long)
    lookup[class_ids - lowest] = train_ids
    lookup = lookup.to(label_img.device)

    indices = label_img.long() - lowest
    in_table = (indices >= 0) & (indices < lookup.numel())
    # Clamp only to keep the gather in bounds; out-of-range values are rejected below
    mapped = lookup[indices.clamp(0, lookup.numel() - 1)]
    known = in_table & (mapped != sentinel)
    if not bool(known.all()):
        raise KeyError(label_img[~known].flatten()[0].item())

    # copy_ writes into the caller's tensor, preserving the original in-place behavior
    return label_img.copy_(mapped)

# Colour palette keyed by train ID; classes whose train ID is 255 are skipped here
train_id_to_color = dict(
    (entry.train_id, entry.color)
    for entry in Cityscapes.classes
    if entry.train_id != 255
)
# Ignored labels (train ID 255) are drawn in black
train_id_to_color.update({255: (0, 0, 0)})

def convert_train_id_to_color(prediction: torch.Tensor) -> torch.Tensor:
    """Render a batch of train-ID maps as RGB images.

    Colours are taken from the module-level `train_id_to_color` palette.
    Pixels whose value is not a key of that palette stay black (zeros).

    Args:
        prediction: Tensor of shape (batch, channels, height, width); only
            channel 0 is read and is interpreted as train IDs.

    Returns:
        A uint8 tensor of shape (batch, 3, height, width) on the same device
        as `prediction`.
    """
    batch, _, height, width = prediction.shape
    # Allocate on the input's device: a CPU canvas cannot be indexed with a
    # boolean mask that lives on the GPU.
    color_image = torch.zeros(
        (batch, 3, height, width), dtype=torch.uint8, device=prediction.device
    )

    id_map = prediction[:, 0]  # loop-invariant, so slice it once
    for train_id, color in train_id_to_color.items():
        mask = id_map == train_id

        for i in range(3):
            color_image[:, i][mask] = color[i]

    return color_image



In [5]:
# Reproducibility settings.
# Only PyTorch's RNG is seeded here; any other source of randomness that gets
# added later (NumPy, Python's `random`) needs its own seed as well.
# NOTE(review): the classifier head in the first cell is initialized before
# this seed is set, so its initial weights are not covered by it.
torch.backends.cudnn.deterministic = True
torch.manual_seed(42)


In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [None]:
# Data pipeline settings shared by the training and validation splits
DATA_ROOT = "data/cityscapes"
BATCH_SIZE = 64
NUM_WORKERS = 9

# Define the transforms to apply to the data
transform = Compose([
    ToImage(),
    Resize((256, 256)),
    ToDtype(torch.float32, scale=True),
    Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Load the dataset and make a split for training and validation.
# Both splits use the fine annotations with semantic targets and the same transform.
dataset_kwargs = dict(mode="fine", target_type="semantic", transforms=transform)
train_dataset = Cityscapes(DATA_ROOT, split="train", **dataset_kwargs)
valid_dataset = Cityscapes(DATA_ROOT, split="val", **dataset_kwargs)

# Wrap the datasets for use with the torchvision.transforms.v2 pipeline above
train_dataset = wrap_dataset_for_transforms_v2(train_dataset)
valid_dataset = wrap_dataset_for_transforms_v2(valid_dataset)

# Only the training loader shuffles; validation keeps a fixed order.
# (The original cell had a stray leading space before `train_dataloader`,
# which made the whole cell fail with an IndentationError.)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

RuntimeError: Dataset not found or incomplete. Please make sure all required folders for the specified "split" and "mode" are inside the "root" directory

In [12]:
from torch.optim import lr_scheduler
# Define the model
model = deeplabv3.to(device)

# Define the loss function
criterion = nn.CrossEntropyLoss(ignore_index=255)  # Ignore the void class

# Define the optimizer.
# Optimize every parameter that still requires gradients, not just the
# classifier head: the first cell unfreezes `backbone.layer4`, and passing only
# `model.classifier.parameters()` would compute its gradients but never apply them.
trainable_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_parameters, lr=0.001)

# Multiply the learning rate by 0.7 on every scheduler step.
# `lr_lambda` must be a callable taking the epoch index; a bare float is accepted
# by the constructor but fails with a TypeError when the scheduler is stepped.
# The decay only takes effect if `scheduler.step()` is called once per epoch.
scheduler = lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.7)



In [23]:
import os

# Training loop
NUM_EPOCHS = 10

# Checkpoint directory. No other cell in this notebook defines `output_dir`,
# so it is set and created here; without it the first checkpoint save raises
# a NameError.
output_dir = "checkpoints"
os.makedirs(output_dir, exist_ok=True)


def _prepare_batch(images, labels):
    """Convert labels to train IDs, move the batch to `device`, and drop the label channel axis.

    Returns the `(images, labels)` pair ready for the model and the loss:
    labels are cast to int64 and squeezed along dimension 1.
    """
    labels = convert_to_train_id(labels)  # Convert class IDs to train IDs
    images, labels = images.to(device), labels.to(device)
    return images, labels.long().squeeze(1)  # Remove channel dimension


# NOTE(review): `scheduler` is created in an earlier cell but is never stepped
# in this loop, so the learning rate stays at its initial value — confirm
# whether a per-epoch `scheduler.step()` was intended.
best_valid_loss = float('inf')
current_best_model_path = None
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1:04}/{NUM_EPOCHS:04}")

    # Training
    model.train()
    for images, labels in train_dataloader:
        images, labels = _prepare_batch(images, labels)

        optimizer.zero_grad()
        outputs = model(images)['out']
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    with torch.no_grad():
        losses = []
        for i, (images, labels) in enumerate(valid_dataloader):
            images, labels = _prepare_batch(images, labels)

            outputs = model(images)['out']
            loss = criterion(outputs, labels)
            losses.append(loss.item())

            if i == 0:
                # Build colour previews of the first validation batch.
                # argmax over the logits equals argmax over their softmax.
                # Tensors are moved to the CPU before colouring so the preview
                # works regardless of which device the model runs on.
                predictions = outputs.argmax(1).unsqueeze(1).cpu()
                labels = labels.unsqueeze(1).cpu()

                predictions = convert_train_id_to_color(predictions)
                labels = convert_train_id_to_color(labels)

                predictions_img = make_grid(predictions, nrow=8)
                labels_img = make_grid(labels, nrow=8)

                # Channels-last NumPy arrays; nothing in this cell consumes
                # them, they are left in the namespace for inspection.
                predictions_img = predictions_img.permute(1, 2, 0).numpy()
                labels_img = labels_img.permute(1, 2, 0).numpy()

        # Mean of the per-batch validation losses
        valid_loss = sum(losses) / len(losses)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            # Keep only one "best" checkpoint on disk at a time
            if current_best_model_path:
                os.remove(current_best_model_path)
            current_best_model_path = os.path.join(
                output_dir,
                # `.4f` gives a short fixed-precision loss; the previous `04`
                # spec only zero-pads and left the full float repr in the name.
                f"best_model-epoch={epoch:04}-val_loss={valid_loss:.4f}.pth"
            )
            torch.save(model.state_dict(), current_best_model_path)

print("Training complete!")

# Save the model (weights after the final epoch, regardless of validation loss)
torch.save(
    model.state_dict(),
    os.path.join(
        output_dir,
        f"final_model-epoch={epoch:04}-val_loss={valid_loss:.4f}.pth"
    )
)


Epoch 0001/0010


NameError: name 'wandb' is not defined