In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import Cityscapes

In [2]:
# Choose where tensors and the model live: the default CUDA GPU when one is
# visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Load pre-trained DeepLabv3 model (ResNet-101 backbone).
# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights=` enum.
# Use the enum when this torchvision build exposes it and fall back to the legacy
# flag on older releases; per the torchvision docs both resolve to the same
# default checkpoint.
_deeplab_weights = getattr(torchvision.models.segmentation, "DeepLabV3_ResNet101_Weights", None)
if _deeplab_weights is not None:
    model = torchvision.models.segmentation.deeplabv3_resnet101(weights=_deeplab_weights.DEFAULT)
else:
    model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)

Downloading: "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth" to /Users/rajkrishnanv/.cache/torch/hub/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth


  0%|          | 0.00/233M [00:00<?, ?B/s]

In [4]:
# Freeze every parameter of the pretrained model so no gradients are computed for
# them during fine-tuning. The trailing ';' keeps Jupyter from echoing the module.
model.requires_grad_(False);

In [5]:
# Inspect the pretrained segmentation head (the module the next cell modifies).
model.classifier

DeepLabHead(
  (0): ASPP(
    (convs): ModuleList(
      (0): Sequential(
        (0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (1): ASPPConv(
        (0): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(12, 12), dilation=(12, 12), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (2): ASPPConv(
        (0): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(24, 24), dilation=(24, 24), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (3): ASPPConv(
        (0): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(36, 36), dilation=(36, 36), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_

In [6]:
# Replace the last classification layer with a new one.
# Only the final 1x1 convolution is swapped, so the pretrained ASPP module and the
# layers after it are reused instead of being discarded. The new layer is freshly
# constructed, so it is trainable even though the rest of the model was frozen above.
# NOTE(review): assumes model.classifier is torchvision's DeepLabHead (printed above,
# starting with ASPP) whose last element is the Conv2d that produces class logits.
# NOTE(review): the retained ASPP has a global-pooling branch with BatchNorm on a 1x1
# map, so a train-mode batch must hold more than one sample.
num_classes = 10  # Example number of classes
final_layer = model.classifier[-1]
model.classifier[-1] = nn.Conv2d(final_layer.in_channels, num_classes, kernel_size=1)

In [8]:
# Load dataset for fine-tuning (example).
# Following the torchvision segmentation-model docs for these pretrained weights:
# - load images into the range [0, 1];
# - normalize them with mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225];
# - the weights were trained on images resized so their shorter side is 520 px.
CITYSCAPES_ROOT = "data/cityscapes"  # TODO: point at your local Cityscapes root
MIN_SIZE = 520
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]


def mask_to_class_indices(mask: torch.Tensor) -> torch.Tensor:
    """Turn a (1, H, W) integer label tensor into an (H, W) int64 class-index map.

    nn.CrossEntropyLoss expects int64 targets without a channel dimension.
    """
    return mask.squeeze(0).long()


image_transform = transforms.Compose([
    transforms.Resize(MIN_SIZE),  # int size: shorter side -> MIN_SIZE, aspect ratio kept
    transforms.ToTensor(),        # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Label masks are resized with nearest-neighbour so label ids are never blended.
target_transform = transforms.Compose([
    transforms.Resize(MIN_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
    mask_to_class_indices,
])

# NOTE(review): 'semantic' targets are raw Cityscapes label ids, which go well
# beyond 10. Any id >= num_classes makes nn.CrossEntropyLoss fail.
# TODO: remap ids to your class set, or change num_classes.
train_dataset = Cityscapes(
    CITYSCAPES_ROOT,
    split="train",
    mode="fine",
    target_type="semantic",
    transform=image_transform,
    target_transform=target_transform,
)

In [None]:
# Batch and shuffle the training set.
# drop_last=True skips a trailing partial batch. A final batch holding a single
# sample can make train-mode BatchNorm fail on 1x1 feature maps.
BATCH_SIZE = 4
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [None]:
# Set loss function and optimizer.
LEARNING_RATE = 1e-3

criterion = nn.CrossEntropyLoss()

# Move the model before constructing its optimizer, as the torch.optim docs advise.
# Give Adam only the parameters that will actually be trained (the frozen ones have
# requires_grad=False), so its param groups and state reflect what is fine-tuned.
model.to(device)
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=LEARNING_RATE)

In [None]:
# Train the model.
num_epochs = 10  # Example number of epochs
model.to(device)

# Set training mode explicitly, since an earlier run may have left the model in
# eval mode. BatchNorm layers whose affine parameters are all frozen stay in eval
# mode: in train mode they would keep overwriting running_mean / running_var with
# statistics from these small batches, silently altering the "frozen" layers.
model.train()
for module in model.modules():
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        bn_params = list(module.parameters())
        if bn_params and not any(p.requires_grad for p in bn_params):
            module.eval()

for epoch in range(num_epochs):
    for i, batch in enumerate(train_loader):
        # Get batch data and move to device.
        # NOTE(review): assumes labels are int64 class-index maps of shape (N, H, W).
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass and calculate loss. The model returns a dict; 'out' holds the
        # class scores along dim 1.
        outputs = model(inputs)['out']
        loss = criterion(outputs, labels)

        # Backward pass and update parameters
        loss.backward()
        optimizer.step()

        # Print loss and pixel accuracy every 10 batches
        if i % 10 == 0:
            acc = (outputs.argmax(1) == labels).float().mean().item()
            print(f"Epoch {epoch}, Batch {i}: Loss={loss.item():.4f}, Accuracy={acc:.4f}")


In [None]:
# Persist only the learned weights (the state_dict), not the pickled module object.
# NOTE(review): the network is DeepLabV3, not a U-Net; the file name is kept as-is
# so existing loaders that expect it keep working.
checkpoint_path = 'fine_tuned_unet.pt'
torch.save(model.state_dict(), checkpoint_path)