In [None]:
import os
import random 

import pandas as pd
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2

# --- Configuration -----------------------------------------------------------
# Root folder of the dataset; metadata.csv lives directly inside it.
DATA_DIR = "./roads_dataset"
# Base seed for the Python, NumPy and PyTorch global RNGs.
SEED = 42

# Seed the global RNGs so weight initialisation and data shuffling start from
# a known state after "Restart Kernel -> Run All".
# NOTE(review): augmentation pipelines may keep their own RNG, so runs are not
# guaranteed to be bit-for-bit identical - confirm if exact reproducibility is
# required.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Metadata and splits -----------------------------------------------------
# One row per image; the 'split' column assigns each row to train / val / test.
data_df = pd.read_csv(os.path.join(DATA_DIR, "metadata.csv"))

train_df = data_df.loc[data_df['split'] == 'train']

val_df = data_df.loc[data_df['split'] == 'val']

test_df = data_df.loc[data_df['split'] == 'test']

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class RoadsDataset(Dataset):
    """Road-segmentation dataset yielding ``(image, mask)`` tensor pairs.

    Every image and its label mask are resized to ``image_size x image_size``.
    Label pixels equal to 255 become 1.0 in the mask, everything else 0.0.

    * Training mode (``is_train=True``): one item per dataframe row. A random
      ``patch_size`` crop plus photometric / noise augmentations are applied.
    * Evaluation mode (``is_train=False``): each image is split into a fixed
      grid of non-overlapping ``patch_size`` patches, one item per patch.

    Args:
        df: DataFrame with ``tiff_image_path`` and ``tif_label_path`` columns
            holding paths relative to ``data_dir``.
        is_train: Selects training or evaluation behaviour (see above).
        data_dir: Folder the paths in ``df`` are relative to.
        image_size: Side length every image/mask is resized to.
        patch_size: Side length of the patches returned by ``__getitem__``.

    Attributes:
        patches: Evaluation mode only. List of
            ``(image_index, row_offset, col_offset)`` tuples, one per patch.
    """

    def __init__(self, df, is_train=False, data_dir='./roads_dataset',
                 image_size=1536, patch_size=512):
        self.df = df
        self.is_train = is_train
        self.data_dir = data_dir
        self.image_size = image_size
        self.patch_size = patch_size
        if not is_train:
            offsets = range(0, image_size - patch_size + 1, patch_size)
            # Tuple layout: (image index, row offset, column offset).
            # The row offset varies fastest within an image.
            self.patches = [
                (idx, row, col)
                for idx in range(len(self.df))
                for col in offsets
                for row in offsets
            ]

    def __len__(self):
        """Number of images (training) or number of patches (evaluation)."""
        if self.is_train:
            return len(self.df)
        return len(self.patches)

    def get_transforms(self):
        """Build the albumentations pipeline for the current mode.

        Training: random crop, random square symmetry, one optional
        photometric transform, one optional noise transform, then ImageNet
        normalisation and tensor conversion. Evaluation: normalisation and
        tensor conversion only.
        """
        if self.is_train:
            return A.Compose([
                A.RandomCrop(height=self.patch_size, width=self.patch_size),
                A.SquareSymmetry(p=0.5),

                # At most one photometric perturbation per sample.
                A.OneOf([
                    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
                    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.6),
                    A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
                    A.RandomGamma(gamma_limit=(80, 120), p=0.6),
                    A.Sharpen(p=0.6)
                ], p=0.7),

                # At most one noise model per sample.
                A.OneOf([
                    A.GaussNoise(std_range=(0.1, 0.2), p=0.5),
                    A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.5),
                    A.MultiplicativeNoise(multiplier=(0.9, 1.1), per_channel=True, p=0.5),
                    A.SaltAndPepper(p=0.5)
                ], p=0.3),

                A.Normalize(mean=(0.485, 0.456, 0.406),
                            std=(0.229, 0.224, 0.225)),

                ToTensorV2()
            ])

        return A.Compose([
            A.Normalize(mean=(0.485, 0.456, 0.406),
                        std=(0.229, 0.224, 0.225)),

            ToTensorV2()
        ])

    def _load_pair(self, df_index):
        """Read, resize and binarise the image/mask pair at row ``df_index``.

        Returns:
            ``(image, mask)``: RGB image array and float32 mask of 0.0/1.0,
            both ``image_size x image_size``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or the mask.
        """
        record = self.df.iloc[df_index]
        img_path = os.path.join(self.data_dir, record['tiff_image_path'])
        mask_path = os.path.join(self.data_dir, record['tif_label_path'])

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        target_size = (self.image_size, self.image_size)
        image = cv2.resize(image, target_size, interpolation=cv2.INTER_CUBIC)
        # Nearest-neighbour keeps label values intact (no blended grey levels).
        mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
        mask = (mask == 255).astype('float32')
        return image, mask

    def __getitem__(self, index):
        """Return ``(image, mask)`` tensors; the mask has a leading channel axis."""
        if self.is_train:
            image, mask = self._load_pair(index)
        else:
            img_idx, row, col = self.patches[index]
            image, mask = self._load_pair(img_idx)
            size = self.patch_size
            image = image[row:row + size, col:col + size, :]
            mask = mask[row:row + size, col:col + size]

        augmented = self.get_transforms()(image=image, mask=mask)
        return augmented['image'], augmented['mask'].unsqueeze(0)


In [None]:
def train(model, train_ds, val_ds, device, num_epochs=200, batch_size=16, lr=1e-3, save_path='roads_seg.pth', num_workers=8):
    """Train a binary segmentation model and keep the best checkpoint.

    Uses Adam with per-stage learning rates (lower for early encoder stages)
    and ``BCEWithLogitsLoss``. After every epoch the model is evaluated on
    ``val_ds``; its ``state_dict`` is written to ``save_path`` whenever the
    mean validation loss improves.

    Args:
        model: Module exposing ``encoder.layer1..layer4`` and ``decoder``;
            ``segmentation_head`` is optimised too when present. The model is
            moved to ``device`` and trained in place.
        train_ds: Dataset yielding ``(image, mask)`` pairs for training.
        val_ds: Dataset yielding ``(image, mask)`` pairs for validation.
        device: Device (or device string) to train on.
        num_epochs: Number of passes over ``train_ds``.
        batch_size: Batch size for both loaders.
        lr: Base learning rate; encoder stages use a fraction of it.
        save_path: Where the best ``state_dict`` is saved.
        num_workers: Worker processes per DataLoader (0 loads in-process).

    Returns:
        ``(train_losses, val_losses, val_ious)``: one float per epoch each.
        The IoU is averaged over validation samples whose union of prediction
        and target is non-empty.
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=torch.device(device).type == "cuda",
        # DataLoader rejects persistent_workers=True when there are no workers.
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    model = model.to(device)

    # Discriminative learning rates: earlier encoder stages move more slowly.
    # NOTE(review): encoder parameters outside layer1-4 (e.g. an input stem)
    # are not passed to the optimizer and therefore keep their initial values
    # - confirm this is intended.
    param_groups = [
        {"params": model.encoder.layer1.parameters(), "lr": lr * 0.1},
        {"params": model.encoder.layer2.parameters(), "lr": lr * 0.1},
        {"params": model.encoder.layer3.parameters(), "lr": lr * 0.3},
        {"params": model.encoder.layer4.parameters(), "lr": lr * 0.5},
        {"params": model.decoder.parameters(), "lr": lr},
    ]
    # The output head is a separate sub-module from the decoder; without its
    # own group it would never be updated.
    head = getattr(model, "segmentation_head", None)
    if head is not None:
        head_params = list(head.parameters())
        if head_params:
            param_groups.append({"params": head_params, "lr": lr})
    optimizer = torch.optim.Adam(param_groups, lr=lr)

    criterion = torch.nn.BCEWithLogitsLoss()

    train_losses, val_losses, val_ious = [], [], []

    best_val_loss = float('inf')
    for epoch in range(1, num_epochs + 1):
        # ---- training pass ----
        model.train()
        total_loss = 0

        train_bar = tqdm(train_loader, desc=f"[Epoch {epoch}] Training")
        for imgs, labels in train_bar:
            imgs, labels = imgs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            train_bar.set_postfix(loss=batch_loss)

        avg_train_loss = total_loss / len(train_loader)

        # ---- validation pass ----
        model.eval()
        val_loss, val_iou, count = 0, 0, 0
        val_bar = tqdm(val_loader, desc=f"[Epoch {epoch}] Validation")
        with torch.no_grad():
            for imgs, labels in val_bar:
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs)
                batch_loss = criterion(outputs, labels).item()
                val_loss += batch_loss
                val_bar.set_postfix(loss=batch_loss)

                # Per-sample IoU computed for the whole batch at once, to
                # avoid two device synchronisations per sample.
                pred_masks = torch.sigmoid(outputs) > 0.5
                target_masks = labels.bool()
                intersection = (pred_masks & target_masks).flatten(1).sum(dim=1)
                union = (pred_masks | target_masks).flatten(1).sum(dim=1)
                # Samples with an empty union have no defined IoU; skip them.
                valid = union > 0
                val_iou += (intersection[valid].float() / union[valid].float()).sum().item()
                count += int(valid.sum().item())

        avg_val_loss = val_loss / len(val_loader)
        avg_val_iou = val_iou / count if count > 0 else 0

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        val_ious.append(avg_val_iou)

        print(f"Epoch {epoch}/{num_epochs} — Train Loss: {avg_train_loss:.4f} — Val Loss: {avg_val_loss:.4f} — Val IoU: {avg_val_iou:.2f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            print(f"Saved new best model at epoch {epoch}")

    return train_losses, val_losses, val_ious

In [None]:
# U-Net with an ImageNet-pretrained ResNet-34 encoder and one output channel.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=1,
)

# Training items are random augmented crops; validation items are fixed patches.
train_dataset = RoadsDataset(df=train_df, is_train=True)
val_dataset = RoadsDataset(df=val_df, is_train=False)

# Run the training loop with its default hyper-parameters and keep the
# per-epoch curves for later inspection.
train_losses, val_losses, val_ious = train(
    model=model,
    train_ds=train_dataset,
    val_ds=val_dataset,
    device=device,
)

In [None]:
def inference_one_image(model, dataset, img_idx):
    """Predict a full-size binary road mask for one image by stitching patches.

    Each patch of image ``img_idx`` is run through ``model``, thresholded at a
    probability of 0.5 and pasted into a full-size canvas at the patch offsets.

    Args:
        model: Segmentation model returning logits of shape ``(1, 1, H, W)``
            for a single-patch batch.
        dataset: Evaluation-mode dataset exposing ``patches``, a list of
            ``(image_index, row_offset, col_offset)`` tuples, and returning
            ``(image, mask)`` from ``dataset[i]``.
        img_idx: Index of the image whose patches are assembled.

    Returns:
        2-D float32 NumPy array. The side length is ``dataset.image_size``
        when the dataset exposes that attribute, otherwise 1536.
    """
    model.eval()

    # Work on the device the model already lives on instead of depending on a
    # notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    canvas_size = getattr(dataset, "image_size", 1536)
    full_pred = torch.zeros((canvas_size, canvas_size), dtype=torch.float32, device=run_device)

    with torch.no_grad():
        for idx, (patch_idx, row, col) in enumerate(dataset.patches):
            if patch_idx != img_idx:
                continue

            image, _ = dataset[idx]
            logits = model(image.unsqueeze(0).to(run_device))

            # Threshold probabilities and drop the batch and channel axes.
            patch_pred = (torch.sigmoid(logits) > 0.5).float().squeeze(0).squeeze(0)

            height, width = patch_pred.shape
            full_pred[row:row + height, col:col + width] += patch_pred

    return full_pred.cpu().numpy()

In [None]:
# --- Evaluate the best checkpoint on the test split --------------------------
test_dataset = RoadsDataset(df=test_df, is_train=False)

# encoder_weights=None: the checkpoint loaded below replaces every weight, so
# fetching the ImageNet weights first would be wasted work.
model = smp.Unet("resnet34", encoder_weights=None, classes=1)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load('roads_seg.pth', map_location=device))
model.to(device)

model.eval()

data_dir = './roads_dataset'
ious = []
for i in range(len(test_df)):
    full_pred = inference_one_image(model, test_dataset, i)

    mask_path = os.path.join(data_dir, test_df.iloc[i]['tif_label_path'])
    true_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None instead of raising when the file is unreadable.
    if true_mask is None:
        raise FileNotFoundError(f"Could not read mask file: {mask_path}")
    true_mask = (true_mask == 255).astype('float32')
    true_mask = cv2.resize(true_mask, (1536, 1536), interpolation=cv2.INTER_NEAREST)

    intersection = np.logical_and(full_pred, true_mask).sum()
    union = np.logical_or(full_pred, true_mask).sum()

    # IoU is undefined when neither prediction nor ground truth contains road
    # pixels; skip such images (as the validation metric does) so one of them
    # cannot turn the mean into NaN.
    if union == 0:
        continue

    ious.append(intersection / union)

if ious:
    print(f"Test IoU: {np.mean(ious)}")
else:
    print("Test IoU: undefined (no test image with a non-empty union)")