In [None]:
import torch
import os
import torchvision.models as models
import torch.nn as nn
from torch.optim import AdamW
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision.datasets import Cityscapes, wrap_dataset_for_transforms_v2
from torchvision.utils import make_grid
from torchvision.transforms.v2 import (
    Compose,
    Normalize,
    Resize,
    ToImage,
    ToDtype,
)


# NOTE(review): `deeplabv3` is only created in the later "Import deeplabv3 ..." cell, so on a
# fresh kernel (Restart & Run All) this line raises NameError. `model` is rebound later to
# `deeplabv3.to(device)` before training, so this alias is not what the training loop uses.
model = deeplabv3
def count_parameters(model, trainable_only=False):
    """Return the number of scalar parameters in ``model``.

    Args:
        model: any ``torch.nn.Module``.
        trainable_only: if True, count only parameters with ``requires_grad=True``
            (useful to verify the effect of freezing layers). Defaults to False,
            which counts every parameter, frozen or not.

    Returns:
        The total number of parameter elements, as an int.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

# Freeze the backbone except for its last ResNet stage: a backbone parameter stays trainable
# only when its qualified name lies under "layer4.".
# NOTE(review): `deeplabv3` is created in the model-definition cell further down, so this
# cell only runs after that one has been executed.
for param_name, tensor in deeplabv3.backbone.named_parameters():
    tensor.requires_grad = param_name.startswith("layer4.")

# Parameter counts (frozen and trainable alike) of the backbone, then of the segmentation head
for part in (deeplabv3.backbone, deeplabv3.classifier):
    print(count_parameters(part))

# Lookup: Cityscapes label ID -> train ID, read from Cityscapes.classes (255 marks ignored classes)
id_to_trainid = dict((label.id, label.train_id) for label in Cityscapes.classes)
def convert_to_train_id(
    label_img: torch.Tensor,
    mapping: "dict[int, int] | None" = None,
    ignore_index: int = 255,
) -> torch.Tensor:
    """Map Cityscapes label IDs to train IDs using a vectorised lookup table.

    Unlike ``Tensor.apply_`` (CPU-only, one Python call per element, in-place),
    this works on any device and leaves ``label_img`` untouched.

    Args:
        label_img: integer tensor of label IDs (any shape, any device).
        mapping: label ID -> train ID. Defaults to the module-level ``id_to_trainid``.
        ignore_index: train ID assigned to label IDs that are not in ``mapping``.

    Returns:
        A new tensor with the same shape, dtype and device as ``label_img``.
    """
    if mapping is None:
        mapping = id_to_trainid
    ids = label_img.long()
    if ids.numel() == 0:
        return label_img.clone()
    # Size the table to cover every ID in the image or the mapping; the offset lets
    # negative IDs (such as -1) be used as table indices.
    offset = min([int(ids.min())] + list(mapping))
    upper = max([int(ids.max())] + list(mapping))
    lut = torch.full(
        (upper - offset + 1,), ignore_index, dtype=torch.long, device=label_img.device
    )
    for class_id, train_id in mapping.items():
        lut[class_id - offset] = train_id
    return lut[ids - offset].to(label_img.dtype)

# Colour palette for visualisation: train ID -> RGB. Every class whose train ID is not 255
# keeps its Cityscapes colour; the ignore ID 255 is drawn black.
train_id_to_color = dict(
    (label.train_id, label.color)
    for label in Cityscapes.classes
    if label.train_id != 255
)
train_id_to_color[255] = (0, 0, 0)  # Assign black to ignored labels

def convert_train_id_to_color(
    prediction: torch.Tensor,
    palette: "dict[int, tuple] | None" = None,
) -> torch.Tensor:
    """Colourise a batch of train-ID maps for visualisation.

    Args:
        prediction: integer tensor of shape (batch, C, height, width); only channel 0 is read.
        palette: train ID -> (R, G, B). Defaults to the module-level ``train_id_to_color``.

    Returns:
        uint8 tensor of shape (batch, 3, height, width) on the same device as ``prediction``.
        Pixels whose ID is not in the palette stay black.
    """
    if palette is None:
        palette = train_id_to_color
    batch, _, height, width = prediction.shape
    # Allocate on prediction's device: indexing a CPU tensor with a CUDA mask raises an error.
    color_image = torch.zeros(
        (batch, 3, height, width), dtype=torch.uint8, device=prediction.device
    )
    ids = prediction[:, 0]

    for train_id, color in palette.items():
        mask = ids == train_id
        for channel in range(3):
            color_image[:, channel][mask] = color[channel]

    return color_image



23508032
16130323


In [None]:
# Import DeepLabV3 and change the last layer to 19 classes (Cityscapes train IDs) instead of 21
deeplabv3 = models.segmentation.deeplabv3_resnet50()  # Use resnet50 because it is smaller than resnet101
deeplabv3.classifier[4] = nn.Conv2d(256, 19, kernel_size=(1, 1))
nn.init.xavier_normal_(deeplabv3.classifier[4].weight)  # Initialize weights of the new head

# Switch the backbone from output stride 8 to output stride 16.
# NOTE(review): based on torchvision building this ResNet-50 with
# replace_stride_with_dilation=[False, True, True] (layer3 and layer4 unstrided and dilated);
# the assert below checks that layout before it is modified.
backbone = deeplabv3.backbone
assert backbone.layer3[0].conv2.stride == (1, 1) and backbone.layer4[0].conv2.dilation == (2, 2), \
    "unexpected DeepLabV3 backbone layout; output-stride conversion below would be wrong"

# layer3: downsample by 2 again in its first block (main path and shortcut), no dilation
backbone.layer3[0].conv2.stride = (2, 2)
backbone.layer3[0].downsample[0].stride = (2, 2)
for block in backbone.layer3[1:]:
    block.conv2.dilation = (1, 1)
    block.conv2.padding = (1, 1)  # padding == dilation keeps the 3x3 conv size-preserving

# layer4: stays unstrided; dilation 1 in its first block and 2 afterwards (previously 2 and 4)
backbone.layer4[0].conv2.dilation = (1, 1)
backbone.layer4[0].conv2.padding = (1, 1)
for block in backbone.layer4[1:]:
    block.conv2.dilation = (2, 2)
    block.conv2.padding = (2, 2)

In [3]:
# One Cityscapes class per line: the single-line list dump is unreadable and gets truncated
print(*Cityscapes.classes, sep="\n")

[CityscapesClass(name='unlabeled', id=0, train_id=255, category='void', category_id=0, has_instances=False, ignore_in_eval=True, color=(0, 0, 0)), CityscapesClass(name='ego vehicle', id=1, train_id=255, category='void', category_id=0, has_instances=False, ignore_in_eval=True, color=(0, 0, 0)), CityscapesClass(name='rectification border', id=2, train_id=255, category='void', category_id=0, has_instances=False, ignore_in_eval=True, color=(0, 0, 0)), CityscapesClass(name='out of roi', id=3, train_id=255, category='void', category_id=0, has_instances=False, ignore_in_eval=True, color=(0, 0, 0)), CityscapesClass(name='static', id=4, train_id=255, category='void', category_id=0, has_instances=False, ignore_in_eval=True, color=(0, 0, 0)), CityscapesClass(name='dynamic', id=5, train_id=255, category='void', category_id=0, has_instances=False, ignore_in_eval=True, color=(111, 74, 0)), CityscapesClass(name='ground', id=6, train_id=255, category='void', category_id=0, has_instances=False, ignore_

[CityscapesClass(name='unlabeled', id=0, train_id=255, category='void', category_id=0, has_instances=False, ignore_in_eval=True, color=(0, 0, 0)),                  Ignored 255
CityscapesClass(name='ego vehicle', id=1, train_id=255, category='void', category_id=0, has_instances=False, ignore_in_eval=True, color=(0, 0, 0)),                 Ignored 255
CityscapesClass(name='rectification border', id=2, train_id=255, category='void', category_id=0, has_instances=False, ignore_in_eval=True, color=(0, 0, 0)),        Ignored 255
CityscapesClass(name='out of roi', id=3, train_id=255, category='void', category_id=0, has_instances=False, ignore_in_eval=True, color=(0, 0, 0)),                  Ignored 255
CityscapesClass(name='static', id=4, train_id=255, category='void', category_id=0, has_instances=False, ignore_in_eval=True, color=(0, 0, 0)),                      Ignored 255
CityscapesClass(name='dynamic', id=5, train_id=255, category='void', category_id=0, has_instances=False, ignore_in_eval=True, color=(111, 74, 0)),                  Ignored 255
CityscapesClass(name='ground', id=6, train_id=255, category='void', category_id=0, has_instances=False, ignore_in_eval=True, color=(81, 0, 81)),                    Ignored 255
CityscapesClass(name='road', id=7, train_id=0, category='flat', category_id=1, has_instances=False, ignore_in_eval=False, color=(128, 64, 128)),                    0
CityscapesClass(name='sidewalk', id=8, train_id=1, category='flat', category_id=1, has_instances=False, ignore_in_eval=False, color=(244, 35, 232)),                1
CityscapesClass(name='parking', id=9, train_id=255, category='flat', category_id=1, has_instances=False, ignore_in_eval=True, color=(250, 170, 160)),               Ignored 255
CityscapesClass(name='rail track', id=10, train_id=255, category='flat', category_id=1, has_instances=False, ignore_in_eval=True, color=(230, 150, 140)),           Ignored 255
CityscapesClass(name='building', id=11, train_id=2, category='construction', category_id=2, has_instances=False, ignore_in_eval=False, color=(70, 70, 70)),         2
CityscapesClass(name='wall', id=12, train_id=3, category='construction', category_id=2, has_instances=False, ignore_in_eval=False, color=(102, 102, 156)),          3
CityscapesClass(name='fence', id=13, train_id=4, category='construction', category_id=2, has_instances=False, ignore_in_eval=False, color=(190, 153, 153)),         4
CityscapesClass(name='guard rail', id=14, train_id=255, category='construction', category_id=2, has_instances=False, ignore_in_eval=True, color=(180, 165, 180)),   Ignored 255
CityscapesClass(name='bridge', id=15, train_id=255, category='construction', category_id=2, has_instances=False, ignore_in_eval=True, color=(150, 100, 100)),       Ignored 255
CityscapesClass(name='tunnel', id=16, train_id=255, category='construction', category_id=2, has_instances=False, ignore_in_eval=True, color=(150, 120, 90)),        Ignored 255
CityscapesClass(name='pole', id=17, train_id=5, category='object', category_id=3, has_instances=False, ignore_in_eval=False, color=(153, 153, 153)),                5
CityscapesClass(name='polegroup', id=18, train_id=255, category='object', category_id=3, has_instances=False, ignore_in_eval=True, color=(153, 153, 153)),          Ignored 255
CityscapesClass(name='traffic light', id=19, train_id=6, category='object', category_id=3, has_instances=False, ignore_in_eval=False, color=(250, 170, 30)),        6
CityscapesClass(name='traffic sign', id=20, train_id=7, category='object', category_id=3, has_instances=False, ignore_in_eval=False, color=(220, 220, 0)),          7
CityscapesClass(name='vegetation', id=21, train_id=8, category='nature', category_id=4, has_instances=False, ignore_in_eval=False, color=(107, 142, 35)),           8
CityscapesClass(name='terrain', id=22, train_id=9, category='nature', category_id=4, has_instances=False, ignore_in_eval=False, color=(152, 251, 152)),             9
CityscapesClass(name='sky', id=23, train_id=10, category='sky', category_id=5, has_instances=False, ignore_in_eval=False, color=(70, 130, 180)),                    10
CityscapesClass(name='person', id=24, train_id=11, category='human', category_id=6, has_instances=True, ignore_in_eval=False, color=(220, 20, 60)),                 11
CityscapesClass(name='rider', id=25, train_id=12, category='human', category_id=6, has_instances=True, ignore_in_eval=False, color=(255, 0, 0)),                    12
CityscapesClass(name='car', id=26, train_id=13, category='vehicle', category_id=7, has_instances=True, ignore_in_eval=False, color=(0, 0, 142)),                    13
CityscapesClass(name='truck', id=27, train_id=14, category='vehicle', category_id=7, has_instances=True, ignore_in_eval=False, color=(0, 0, 70)),                   14
CityscapesClass(name='bus', id=28, train_id=15, category='vehicle', category_id=7, has_instances=True, ignore_in_eval=False, color=(0, 60, 100)),                   15
CityscapesClass(name='caravan', id=29, train_id=255, category='vehicle', category_id=7, has_instances=True, ignore_in_eval=True, color=(0, 0, 90)),                 Ignored 255
CityscapesClass(name='trailer', id=30, train_id=255, category='vehicle', category_id=7, has_instances=True, ignore_in_eval=True, color=(0, 0, 110)),                Ignored 255
CityscapesClass(name='train', id=31, train_id=16, category='vehicle', category_id=7, has_instances=True, ignore_in_eval=False, color=(0, 80, 100)),                 16
CityscapesClass(name='motorcycle', id=32, train_id=17, category='vehicle', category_id=7, has_instances=True, ignore_in_eval=False, color=(0, 0, 230)),             17
CityscapesClass(name='bicycle', id=33, train_id=18, category='vehicle', category_id=7, has_instances=True, ignore_in_eval=False, color=(119, 11, 32)),              18
CityscapesClass(name='license plate', id=-1, train_id=-1, category='vehicle', category_id=7, has_instances=False, ignore_in_eval=True, color=(0, 0, 142))]          -1

Classes (19 train IDs): Road, sidewalk, building, wall, fence, pole, traffic light, traffic sign, vegetation, terrain, sky, person, rider, car, truck, bus, train, motorcycle, bicycle
Excluded (train ID 255, or -1 for license plate): Unlabeled, ego vehicle, rectification border, out of roi, static, dynamic, ground, parking, rail track, guard rail, bridge, tunnel, polegroup, caravan, trailer, license plate

In [4]:
# Set seeds for reproducibility.
# If you add other sources of randomness (NumPy, Random),
# make sure to set their seeds as well.
SEED = 42
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
# Keep cuDNN autotuning off: benchmark mode may pick different algorithms from run to run
torch.backends.cudnn.benchmark = False


In [17]:
# 100 random factors spread uniformly over roughly [0.7, 0.99)
randomgetal = 0.7 + 0.29 * torch.rand(100)
randomgetal


tensor([0.7015, 0.7890, 0.7338, 0.9640, 0.8868, 0.9051, 0.8909, 0.8425, 0.9585,
        0.7420, 0.8541, 0.7460, 0.8897, 0.7951, 0.8894, 0.8148, 0.9653, 0.7591,
        0.7585, 0.7585, 0.9754, 0.8933, 0.9845, 0.7253, 0.7012, 0.7316, 0.7475,
        0.9037, 0.8969, 0.9655, 0.7701, 0.7462, 0.9219, 0.7864, 0.9330, 0.8106,
        0.9279, 0.7323, 0.7718, 0.8892, 0.8757, 0.8080, 0.9314, 0.9436, 0.7398,
        0.7676, 0.9778, 0.7961, 0.7936, 0.7047, 0.7620, 0.8812, 0.8259, 0.7397,
        0.8484, 0.7460, 0.7220, 0.7652, 0.7181, 0.7527, 0.9899, 0.8724, 0.8897,
        0.7098, 0.7498, 0.7967, 0.8677, 0.7174, 0.7825, 0.7582, 0.8454, 0.7910,
        0.8350, 0.7467, 0.7455, 0.7604, 0.7954, 0.7306, 0.9666, 0.8162, 0.9698,
        0.8902, 0.7222, 0.9453, 0.8051, 0.7894, 0.7246, 0.7008, 0.8865, 0.8133,
        0.9015, 0.7260, 0.9527, 0.7386, 0.8200, 0.8753, 0.9199, 0.9621, 0.9771,
        0.7300])

In [6]:
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [None]:
# Define the transforms to apply to the data
# (Normalize uses the standard ImageNet channel mean/std)
transform = Compose([
    ToImage(),
    Resize((256, 256)),
    ToDtype(torch.float32, scale=True),
    Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

CITYSCAPES_ROOT = "data/cityscapes"
BATCH_SIZE = 64
NUM_WORKERS = 9


def make_cityscapes_split(split: str):
    """Load one fine-annotation Cityscapes split with semantic targets.

    The dataset uses ``transform`` as its joint image/target transform and is wrapped
    with ``wrap_dataset_for_transforms_v2`` for the torchvision transforms v2 API.
    """
    dataset = Cityscapes(
        CITYSCAPES_ROOT,
        split=split,
        mode="fine",
        target_type="semantic",
        transforms=transform,
    )
    return wrap_dataset_for_transforms_v2(dataset)


# Load the training and validation splits
train_dataset = make_cityscapes_split("train")
valid_dataset = make_cityscapes_split("val")

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

RuntimeError: Dataset not found or incomplete. Please make sure all required folders for the specified "split" and "mode" are inside the "root" directory

In [12]:
from torch.optim import lr_scheduler
# Define the model
model = deeplabv3.to(device)

# Define the loss function; pixels with train ID 255 (void/ignored classes) are skipped
criterion = nn.CrossEntropyLoss(ignore_index=255)

# Optimize every parameter that still requires gradients: the new classifier head plus
# any backbone layers left unfrozen (layer4). Restricting the optimizer to
# model.classifier.parameters() would compute gradients for layer4 but never apply them.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=0.001)

# MultiplicativeLR needs a callable returning the factor for each step (a bare float fails
# on the first scheduler.step()). Each step multiplies the learning rate by 0.7;
# call scheduler.step() once per epoch for that to take effect.
scheduler = lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.7)



In [23]:
import os

# Checkpoints are written to output_dir, which is created if needed.
# NOTE(review): `output_dir` was used here without ever being defined; adjust the path as needed.
output_dir = "checkpoints"
os.makedirs(output_dir, exist_ok=True)
NUM_EPOCHS = 10

# Training loop
best_valid_loss = float('inf')
current_best_model_path = None
# NOTE(review): `scheduler` is not stepped in this loop, so the learning rate stays constant.
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1:04}/{NUM_EPOCHS:04}")

    # Training
    model.train()
    for i, (images, labels) in enumerate(train_dataloader):

        labels = convert_to_train_id(labels)  # Convert class IDs to train IDs
        images, labels = images.to(device), labels.to(device)

        labels = labels.long().squeeze(1)  # Remove channel dimension

        optimizer.zero_grad()
        outputs = model(images)['out']
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    with torch.no_grad():
        losses = []
        for i, (images, labels) in enumerate(valid_dataloader):

            labels = convert_to_train_id(labels)  # Convert class IDs to train IDs
            images, labels = images.to(device), labels.to(device)

            labels = labels.long().squeeze(1)  # Remove channel dimension

            outputs = model(images)['out']
            loss = criterion(outputs, labels)
            losses.append(loss.item())

            if i == 0:
                # Colourise the first validation batch (prediction vs. ground truth) for inspection.
                # argmax of the logits selects the same class as argmax of their softmax.
                predictions = outputs.argmax(1).unsqueeze(1)
                labels = labels.unsqueeze(1)

                predictions = convert_train_id_to_color(predictions)
                labels = convert_train_id_to_color(labels)

                predictions_img = make_grid(predictions.cpu(), nrow=8)
                labels_img = make_grid(labels.cpu(), nrow=8)

                predictions_img = predictions_img.permute(1, 2, 0).numpy()
                labels_img = labels_img.permute(1, 2, 0).numpy()

        if not losses:
            raise RuntimeError("valid_dataloader produced no batches; cannot compute a validation loss")
        valid_loss = sum(losses) / len(losses)
        print(f"  validation loss: {valid_loss:.4f}")

        # Keep only the best checkpoint so far on disk
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            if current_best_model_path:
                os.remove(current_best_model_path)
            current_best_model_path = os.path.join(
                output_dir,
                f"best_model-epoch={epoch:04}-val_loss={valid_loss:.4f}.pth"
            )
            torch.save(model.state_dict(), current_best_model_path)

print("Training complete!")

# Save the final model
torch.save(
    model.state_dict(),
    os.path.join(
        output_dir,
        f"final_model-epoch={epoch:04}-val_loss={valid_loss:.4f}.pth"
    )
)


Epoch 0001/0010


NameError: name 'wandb' is not defined