In [None]:
!pip install segmentation-models-pytorch

In [3]:
# Colab-only: mount the user's Google Drive under /content/drive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [66]:
import os
import sys

# Placeholder: set this to the Drive folder that contains the `cityscapes/` dataset
# directory. The per-fold model checkpoints are also written into this folder.
GOOGLE_DRIVE_PATH = "path/to/your/google/drive/folder/"

# Guarded so that re-running this cell does not append duplicate sys.path entries.
if GOOGLE_DRIVE_PATH not in sys.path:
    sys.path.append(GOOGLE_DRIVE_PATH)

In [67]:
import os
import glob
import numpy as np
from PIL import Image
import torch
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import KFold
import segmentation_models_pytorch as smp

In [68]:
class CityscapesDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from parallel lists of file paths.

    Args:
        images: Paths to input images; each is loaded and converted to RGB.
        targets: Paths to label images, aligned index-for-index with ``images``;
            each is loaded and converted to single-channel ('L') mode.
        preprocessing: Optional callable applied to the image as a numpy array
            (``np.array`` of the RGB image, i.e. H x W x C). Its result is permuted
            to C x H x W and cast to a float32 tensor.
        transform: Optional callable applied to the PIL image *before*
            ``preprocessing``. It is NOT applied to the label, so it must not change
            geometry (crop / flip / resize) or image and label stop lining up.
        target_transform: Optional callable applied to the PIL label. Without it the
            label becomes an int64 tensor of the raw pixel values.

    Note:
        If neither ``transform`` nor ``preprocessing`` produces a tensor, the image is
        returned as a PIL image, which the default DataLoader collate cannot batch.

    Raises:
        ValueError: if ``images`` and ``targets`` have different lengths.
    """

    def __init__(self, images, targets, preprocessing=None, transform=None, target_transform=None):
        # Fail at construction time rather than with an IndexError / silent truncation later.
        if len(images) != len(targets):
            raise ValueError(
                f'images and targets must have the same length, got {len(images)} and {len(targets)}'
            )
        self.images = images
        self.targets = targets
        self.preprocessing = preprocessing
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # The context managers close the files; convert() returns an independent, loaded copy.
        with Image.open(self.images[index]) as image_file:
            image = image_file.convert('RGB')
        with Image.open(self.targets[index]) as label_file:
            label = label_file.convert('L')

        if self.transform:
            image = self.transform(image)

        if self.preprocessing:
            processed = self.preprocessing(np.array(image))
            # H x W x C -> C x H x W, float32 for the model.
            image = torch.as_tensor(np.ascontiguousarray(processed)).permute(2, 0, 1).float()

        if self.target_transform:
            label = self.target_transform(label)
        else:
            label = torch.from_numpy(np.array(label, dtype=np.int64))

        return image, label

In [69]:
# Input normalisation matching the "resnet34" encoder with its "imagenet" pretrained
# weights, which is the same encoder the U-Net in the training cell uses.
preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder_name="resnet34", pretrained="imagenet")

In [70]:
def get_data_filepaths(root, split):
    """Return aligned image and label file paths for one Cityscapes split.

    Images are ``<root>/leftImg8bit/<split>/**/*_leftImg8bit.png`` and labels are
    ``<root>/gtFine/<split>/**/*_gtFine_labelTrainIds.png``. Each image is paired with
    the label that has the same sub-directory and frame prefix. The two lists therefore
    stay aligned even when label files are missing or extra, unlike sorting two
    independent globs.

    Args:
        root: Dataset root containing the ``leftImg8bit`` and ``gtFine`` directories.
        split: Split sub-directory name, e.g. ``'train'`` or ``'val'``.

    Returns:
        ``(images, targets)``: ``images`` sorted by path, and ``targets[i]`` is the
        label for ``images[i]``. Both lists are empty when no images are found.

    Raises:
        FileNotFoundError: if any image has no matching ``*_gtFine_labelTrainIds.png``.
            These files are typically generated by cityscapesScripts
            (``createTrainIdLabelImgs``), not shipped in the raw download.
    """
    image_suffix = '_leftImg8bit.png'
    target_suffix = '_gtFine_labelTrainIds.png'
    image_dir = os.path.join(root, 'leftImg8bit', split)
    target_dir = os.path.join(root, 'gtFine', split)

    images = sorted(glob.glob(os.path.join(image_dir, '**', '*' + image_suffix), recursive=True))

    # Key each label by "<sub-dir>/<frame prefix>" so it can be looked up from its image.
    target_by_key = {
        os.path.relpath(path, target_dir)[:-len(target_suffix)]: path
        for path in glob.glob(os.path.join(target_dir, '**', '*' + target_suffix), recursive=True)
    }

    targets = []
    missing = []
    for image_path in images:
        key = os.path.relpath(image_path, image_dir)[:-len(image_suffix)]
        target_path = target_by_key.get(key)
        if target_path is None:
            missing.append(image_path)
        else:
            targets.append(target_path)

    if missing:
        raise FileNotFoundError(
            f'{len(missing)} of {len(images)} images in {image_dir!r} have no matching '
            f'*{target_suffix} label (first: {missing[0]!r})'
        )
    return images, targets

In [83]:
dataset_path = os.path.join(GOOGLE_DRIVE_PATH, 'cityscapes')

train_images, train_targets = get_data_filepaths(dataset_path, 'train')
val_images, val_targets = get_data_filepaths(dataset_path, 'val')

# Pool the official train and val splits; the cross-validation cell re-splits them into folds.
images = train_images + val_images
targets = train_targets + val_targets

# Fail fast with a readable message instead of an opaque error later in KFold / the DataLoader.
assert images, f'No *_leftImg8bit.png images found under {dataset_path!r}; check GOOGLE_DRIVE_PATH.'
assert len(images) == len(targets), (
    f'Found {len(images)} images but {len(targets)} label maps under {dataset_path!r}.'
)

In [85]:
# Debug setting: makes CUDA kernel launches synchronous so errors surface at the failing op,
# at the cost of training speed. It only takes effect if set before CUDA is initialised,
# so it is set before the first CUDA query below. Remove or set to "0" once debugging is done.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [86]:
# Build one (train, val) dataset pair per fold over the pooled `images` / `targets` lists.
# NOTE(review): KFold splits individual frames, so frames from the same city can land in
# both the train and val side of a fold; sklearn's GroupKFold with the city as the group
# would give a stricter estimate.
NUM_FOLDS = 5

train_datasets = []
val_datasets = []
kfold = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)

for fold, (train_idx, val_idx) in enumerate(kfold.split(images)):
    # Fold-local names: do not overwrite train_images / val_images from the loading cell.
    fold_train_images = [images[i] for i in train_idx]
    fold_train_targets = [targets[i] for i in train_idx]
    fold_val_images = [images[i] for i in val_idx]
    fold_val_targets = [targets[i] for i in val_idx]

    train_datasets.append(CityscapesDataset(fold_train_images, fold_train_targets, preprocessing=preprocessing_fn))
    val_datasets.append(CityscapesDataset(fold_val_images, fold_val_targets, preprocessing=preprocessing_fn))
    print(f'Fold {fold+1}: {len(train_idx)} train / {len(val_idx)} val samples')



Fold 1
Fold 2
Fold 3
Fold 4
Fold 5


In [None]:
# Hyper-parameters for the cross-validation run.
NUM_EPOCHS = 10
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
NUM_CLASSES = 19      # number of segmentation classes (Cityscapes trainIds)
IGNORE_INDEX = 255    # label value excluded from the loss
SEED = 42
# Cap worker processes at the visible CPU count (small on e.g. Colab VMs); DataLoader
# warns when num_workers exceeds its suggested maximum.
TRAIN_WORKERS = min(8, os.cpu_count() or 1)
VAL_WORKERS = min(4, os.cpu_count() or 1)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        outputs = model(batch_images)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate(model, loader, criterion, device):
    """Return the mean per-batch loss of ``model`` on ``loader`` without tracking gradients."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            total_loss += criterion(model(batch_images), batch_labels).item()
    return total_loss / len(loader)


# Batch loops live inside the helpers above, so their variables stay local and cannot
# overwrite notebook globals such as `images` (the file list the K-fold cell splits).
for fold, (train_dataset, val_dataset) in enumerate(zip(train_datasets, val_datasets)):
    print(f'Fold {fold+1}')
    # Seeds decoder initialisation and the shuffle order so each fold is reproducible.
    torch.manual_seed(SEED + fold)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=TRAIN_WORKERS)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=VAL_WORKERS)

    model = smp.Unet(
        encoder_name="resnet34",
        encoder_weights="imagenet",
        classes=NUM_CLASSES,
        activation=None,  # raw logits, as CrossEntropyLoss expects
    )
    model = model.to(device)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_val_loss = float('inf')
    model_file_path = os.path.join(GOOGLE_DRIVE_PATH, f'segmentation_fold{fold+1}.pth')

    for epoch in range(NUM_EPOCHS):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {train_loss}")

        average_val_loss = _evaluate(model, val_loader, criterion, device)
        print(f'Validation Loss: {average_val_loss}')

        if average_val_loss < best_val_loss:
            best_val_loss = average_val_loss
            # Pickles the whole nn.Module. Loading it needs segmentation_models_pytorch
            # importable and, on recent PyTorch (default weights_only=True), an explicit
            # torch.load(path, weights_only=False). Saving model.state_dict() would be more
            # portable but changes the checkpoint format.
            torch.save(model, model_file_path)
            print(f'Current best model saved with validation loss: {best_val_loss}')
