In [1]:
import os
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import models
from torchvision.datasets import Cityscapes
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# python 3.9.11 / cuda / windows
# pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121 
# ipython kernel install --name "cuda-train" --user
# pip install ipython matplotlib
# print("Cuda available: ", torch.cuda.is_available())
# print("Device name:", torch.cuda.get_device_name())

Cuda available:  True
Device name: NVIDIA GeForce RTX 2080 SUPER


In [2]:
# Root folder of the Cityscapes dataset. Set the CITYSCAPES_ROOT environment
# variable to point elsewhere instead of editing this machine-specific default.
data_path = os.environ.get(
    "CITYSCAPES_ROOT",
    "/Users/severin/Documents/GitHub/u-net-segmentation-of-streets-and-cars/train/cityscapes",
)

In [3]:
def transform_image(img):
    """Turn a mask image into an int64 tensor of per-pixel class ids.

    Args:
        img: a PIL image, or anything else ``np.array`` accepts.

    Returns:
        torch.Tensor with the same shape as ``np.array(img)`` and dtype int64.
        The pixel data is copied, so the tensor never aliases ``img``.
    """
    pixels = np.array(img)  # np.array copies by default -> fresh, writable buffer
    return torch.from_numpy(pixels).to(dtype=torch.int64)

In [4]:
# Image pipeline: resize to the training resolution, then ToTensor (float tensor).
# NOTE(review): torchvision's pretrained segmentation weights typically expect
# mean/std-normalized inputs; consider adding transforms.Normalize (and
# un-normalizing before plotting). Verify against the weights' documented
# preprocessing. It is not added here because existing checkpoints and the
# plotting cell were produced with unnormalized inputs.
data_transforms = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize images (smooth interpolation is fine for RGB)
    transforms.ToTensor(),  # Convert to tensor
])

# Target pipeline. Mask pixels are integer class ids, so they must be resized
# with nearest-neighbour interpolation. Resize's default smooth interpolation
# would blend neighbouring ids into values that belong to neither class.
# transform_image (defined in an earlier cell) then yields an int64 H x W tensor
# that CrossEntropyLoss can use directly.
target_transforms = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.Lambda(transform_image),
])

# Create Cityscapes dataset instance
dataset = Cityscapes(root=data_path, split='train', mode='fine',
                     target_type='semantic', transform=data_transforms,
                     target_transform=target_transforms)

In [None]:
# Sanity check: show the first training sample next to its segmentation mask.
img, smnt = dataset[0]

img_np = img.permute(1, 2, 0).numpy()  # channels-first tensor -> H x W x C array for imshow
smnt_np = np.array(smnt)

fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(10, 5))

panels = (
    (image_ax, 'Image', img_np),
    (mask_ax, 'Segmentation Mask', smnt_np),
)
for ax, panel_title, panel_data in panels:
    ax.set_title(panel_title)
    ax.imshow(panel_data)
    ax.axis('off')

fig.tight_layout()
plt.show()

In [6]:
# Create DataLoader for easy iteration
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Use GPU if available

# Start from the pretrained FCN-ResNet101 weights.
model = models.segmentation.fcn_resnet101(weights='DEFAULT')

# Replace the final 1x1 conv of the main classifier so it emits one logit per Cityscapes class.
num_classes = len(Cityscapes.classes)
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
# NOTE(review): with pretrained weights torchvision may also build model.aux_classifier,
# whose last layer keeps its original class count. Only outputs['out'] is used for
# training, so it is unused compute. It is left as-is so older checkpoints still load.

# Optionally resume from a saved state dict. The training cell writes
# model_weights_epoch_<N>.pth; copy/rename the one you want, or change this path.
checkpoint_path = 'model_weights.pth'
if os.path.exists(checkpoint_path):
    # map_location='cpu': the model is still on the CPU here, and this lets
    # GPU-saved checkpoints load on CPU-only machines.
    # weights_only=True: only tensors/containers are unpickled, never arbitrary objects.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    model.load_state_dict(checkpoint)
    print(f'Resumed weights from {checkpoint_path}')
else:
    print(f'No checkpoint at {checkpoint_path}; starting from the pretrained weights')

# Ensure the model is in training mode
model.train()

# Move the model to the GPU (before building the optimizer, so it tracks the moved parameters)
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [7]:
# Training loop. num_epochs=None trains until the cell is interrupted. Weights are
# saved after every completed epoch, so an interrupt loses at most the epoch in progress.
num_epochs = None
epoch = 0  # epochs already completed; raise it when resuming so checkpoint names keep counting

total_batches = len(dataloader)
if total_batches == 0:
    # Otherwise the average below divides by zero (and an unbounded run would spin forever).
    raise RuntimeError('dataloader yields no batches - check data_path and the dataset split')

# Compare against None explicitly: num_epochs=0 must mean "no epochs", not "forever".
while num_epochs is None or epoch < num_epochs:
    epoch += 1
    print('Epoch {}/{}'.format(epoch, num_epochs if num_epochs is not None else "∞"))
    print('-' * 10)

    model.train()  # re-assert training mode in case another cell switched the model to eval()
    total_loss = 0.0

    for batch_count, (inputs, labels) in enumerate(dataloader, start=1):
        inputs = inputs.to(device)
        labels = labels.to(device).long()  # CrossEntropyLoss needs integer class-id targets

        optimizer.zero_grad()
        outputs = model(inputs)['out']  # the model returns a dict; 'out' holds the per-pixel logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        print('Batch {}/{} - Loss: {:.4f}'.format(batch_count, total_batches, batch_loss))

        # Drop references so device memory can be reclaimed before the next batch.
        del loss, outputs, labels, inputs

    average_loss = total_loss / total_batches
    print('Epoch completed. Average Loss: {:.4f}'.format(average_loss))

    # Save model weights after each epoch
    torch.save(model.state_dict(), f'model_weights_epoch_{epoch}.pth')

print('Training completed.')

Epoch 1/∞
----------
