In [None]:
import os
# Normalise Windows-style separators so the path comparisons further down
# behave the same on every operating system.
cwd = "/".join(os.getcwd().split("\\"))
print(cwd)

In [None]:
import pickle
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import requests
import wandb
from sklearn.metrics import f1_score
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from glob import glob
from albumentations import (
    HorizontalFlip,
    VerticalFlip,
    RandomRotate90,
    ShiftScaleRotate,
    RandomBrightnessContrast,
    CLAHE,
    HueSaturationValue,
    GaussNoise,
    GridDistortion,
    Compose,
    RandomCrop,
    Resize
)
import cv2
from torch.utils.data import ConcatDataset, DataLoader

# Mini-batch size shared by every DataLoader created in this notebook.
BATCH_SIZE = 4

# The Kaggle environment is detected purely from the working directory.
kaggle = cwd == "/kaggle/working"
if kaggle:
    data_path = "/kaggle/input/"
else:
    data_path = cwd + "/../../data/"

# takes path of x and returns x and y as images
def get_label(x_path):
    """Open an input image together with its ground-truth mask.

    The mask path is derived from the image path by dataset-specific substring
    rewrites, selected by a marker found anywhere in the path
    ("massachusetts", "ethz", "googlemaps" or "deepglobe").

    Args:
        x_path: Path of the input image. Backslashes are normalised to "/".

    Returns:
        Tuple ``(x, y)`` with the results of ``Image.open`` for the input image
        and for the derived mask path.

    Raises:
        ValueError: If the path contains none of the known dataset markers.
    """
    x_path = x_path.replace("\\", "/")
    y_path = None

    # The checks are independent and evaluated in order, so if several markers
    # occur in one path the last matching rule wins.
    if "massachusetts" in x_path:
        y_path = x_path.replace("tiff/train/", "tiff/train_labels/").replace(".tiff", ".tif")

    if "ethz" in x_path or "googlemaps" in x_path:
        y_path = x_path.replace("images/", "groundtruth/")

    if "deepglobe" in x_path:
        y_path = x_path.replace("sat.jpg", "mask.png")

    if y_path is None:
        # Without this guard an unknown path surfaced as an UnboundLocalError
        # on y_path, which hides the actual problem.
        raise ValueError(f"Cannot derive a label path, unknown dataset for image: {x_path!r}")

    return Image.open(x_path), Image.open(y_path)

def save(model,optim,name):
    """Write the model and optimizer state dicts to ``<name>.pth``.

    On Kaggle the file is placed in ``/kaggle/working/``; otherwise ``name`` is
    used as given, with only the ``.pth`` suffix appended.

    Args:
        model: Object exposing ``state_dict()`` (the network).
        optim: Object exposing ``state_dict()`` (the optimizer).
        name: Checkpoint file name without extension.
    """
    prefix = "/kaggle/working/" if kaggle else ""
    checkpoint = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optim.state_dict(),
    )
    torch.save(checkpoint, prefix + name + ".pth")

def load(model,optim, name, map_location=None):
    """Restore model and optimizer state from ``<name>.pth`` (as written by ``save``).

    Args:
        model: Network whose ``load_state_dict`` receives ``model_state_dict``.
        optim: Optimizer whose ``load_state_dict`` receives ``optimizer_state_dict``.
        name: Checkpoint file name without extension.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps the
            previous behaviour; pass e.g. ``"cpu"`` to load a checkpoint that was
            saved on a GPU onto a machine without one.
    """
    path = ("/kaggle/working/" if kaggle else "") + name + ".pth"
    # NOTE: torch.load unpickles the file; only load checkpoints you created yourself.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optim.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Encoder backbone and pretrained weights handed to the U-Net constructor.
ENCODER = 'resnet50'
WEIGHTS = 'imagenet'

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Device: {DEVICE}")

In [None]:
def get_geometric_transforms_mass():
    """Build the geometric augmentation pipeline for the Massachusetts dataset.

    Random flips and 90-degree rotations, shift/scale/rotate with reflected
    borders and grid distortion, followed by an unconditional resize to
    416x416. The ``mask`` target is registered as an ``image`` target so it
    goes through the same pipeline as the input.

    Returns:
        The composed albumentations pipeline.
    """
    pipeline = Compose(
        [
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            RandomRotate90(p=0.5),
            ShiftScaleRotate(
                shift_limit=0.0625,
                scale_limit=0.2,
                rotate_limit=15,
                border_mode=cv2.BORDER_REFLECT,
                p=0.9,
            ),
            GridDistortion(p=0.5),
            Resize(height=416, width=416, p=1),
        ],
        additional_targets={'mask': 'image'},
    )
    return pipeline

def get_geometric_transforms_deepglobe():
    """Build the geometric augmentation pipeline for the DeepGlobe dataset.

    Only random flips and 90-degree rotations are applied (no shift/scale/rotate
    and no grid distortion), followed by an unconditional resize to 416x416.
    The ``mask`` target is registered as an ``image`` target.

    Returns:
        The composed albumentations pipeline.
    """
    pipeline = Compose(
        [
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            RandomRotate90(p=0.5),
            Resize(height=416, width=416, p=1),
        ],
        additional_targets={'mask': 'image'},
    )
    return pipeline
def get_geometric_transforms_official():
    """Build the geometric augmentation pipeline for the ETHZ and Google Maps data.

    Random flips and 90-degree rotations followed by an unconditional resize to
    416x416. The ``mask`` target is registered as an ``image`` target.

    Returns:
        The composed albumentations pipeline.
    """
    pipeline = Compose(
        [
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            RandomRotate90(p=0.5),
            Resize(height=416, width=416, p=1),
        ],
        additional_targets={'mask': 'image'},
    )
    return pipeline

# Do not use
def get_photometric_transforms():
    """Placeholder for photometric augmentation, which is deliberately disabled.

    The function is kept so existing call sites keep working. The previously
    unreachable pipeline definition (brightness/contrast, CLAHE,
    hue/saturation/value shift, Gaussian noise) after the early return was
    dead code and has been removed.

    Returns:
        Always ``None``.
    """
    return None

In [None]:
# Image processor taken from a pretrained SegFormer checkpoint and configured
# with size=416; in this notebook it is used to preprocess inputs only.
feature_extractor: SegformerImageProcessor = SegformerImageProcessor.from_pretrained(
    "nvidia/segformer-b5-finetuned-ade-640-640",
    size=416,
)

In [None]:
class CustomDataset(Dataset):
    """Dataset yielding preprocessed image/mask pairs plus copies for visualisation.

    Each item is loaded through ``get_label``, converted to RGB, brought to
    416x416, optionally augmented geometrically and then preprocessed with the
    module-level ``feature_extractor``.

    Args:
        image_files: Sequence of input image paths.
        geometric_transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` returning a mapping with the
            keys ``'image'`` and ``'mask'``.
        photometric_transform: Stored but currently not applied.
    """

    def __init__(self, image_files, geometric_transform=None, photometric_transform=None):
        self.image_files = image_files
        self.geometric_transform = geometric_transform
        # Kept for interface compatibility; photometric augmentation is disabled.
        self.photometric_transform = photometric_transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, augment and preprocess the sample at ``idx``.

        Returns:
            Tuple ``(x, y, path, x_orig, y_orig, x_augmented, y_augmented)``:
            ``x`` is the ``pixel_values`` tensor produced by ``feature_extractor``
            (moved to ``DEVICE``), ``y`` is a float32 tensor holding the first
            mask channel scaled by 1/255, ``path`` is the input path, and the
            remaining four entries are float32 arrays scaled by 1/255 holding
            the image/mask before and after geometric augmentation.
        """
        x_orig, y_orig = get_label(self.image_files[idx])

        x_orig = x_orig.convert("RGB")
        y_orig = y_orig.convert("RGB")

        # Compare the full (width, height): checking the width alone would let
        # a 416xN image with N != 416 slip through without being resized.
        target_size = (416, 416)
        if x_orig.size != target_size or y_orig.size != target_size:
            x_orig = x_orig.resize(target_size)
            y_orig = y_orig.resize(target_size)

        x_orig_np = np.array(x_orig, dtype=np.uint8)
        y_orig_np = np.array(y_orig, dtype=np.uint8)

        # Without a geometric transform the "augmented" views are the originals.
        x_augmented, y_augmented = x_orig_np, y_orig_np
        if self.geometric_transform:
            # Hand copies to the transform so the original arrays stay untouched.
            augmented = self.geometric_transform(image=x_orig_np.copy(), mask=y_orig_np.copy())
            x_augmented, y_augmented = augmented['image'], augmented['mask']

        pixel_values = feature_extractor(
            images=x_augmented.astype(np.float32), return_tensors="pt"
        ).pixel_values
        # Use the notebook-wide DEVICE instead of a hard-coded .cuda() so the
        # dataset also works on machines without a GPU.
        x = pixel_values.squeeze(0).to(DEVICE)
        y = torch.tensor((y_augmented.astype(np.float32)/255)[:, :, 0], dtype=torch.float32)

        # Float copies in [0, 1] returned alongside the tensors (astype allocates new arrays).
        x_orig_np = x_orig_np.astype(np.float32) / 255
        y_orig_np = y_orig_np.astype(np.float32) / 255
        x_augmented = x_augmented.astype(np.float32) / 255
        y_augmented = y_augmented.astype(np.float32) / 255

        return x, y, self.image_files[idx], x_orig_np, y_orig_np, x_augmented, y_augmented

In [None]:
mass_files_temp = glob(data_path + "massachusetts-roads-dataset/tiff/train/*.tiff")

# Keep only tiles in which fewer than 10% of all array values are exactly 255
# (this drops tiles that are mostly white).
mass_files = []
for file in mass_files_temp:
    # The context manager closes the file handle as soon as the pixels are
    # copied, instead of leaving one open handle per scanned file.
    with Image.open(file) as pil_img:
        img = np.array(pil_img)
    # img.size is the total number of array elements, so this also works for
    # images without a channel axis.
    frac = np.count_nonzero(img == 255) / img.size
    if frac < 0.1:
        mass_files.append(file)

print(len(mass_files))

In [None]:
# Massachusetts roads (pre-filtered file list).
massachusetts_dataset = CustomDataset(
    image_files=mass_files,
    geometric_transform=get_geometric_transforms_mass(),
    photometric_transform=get_photometric_transforms(),
)
massachusetts_loader = DataLoader(dataset=massachusetts_dataset, batch_size=BATCH_SIZE, shuffle=True)

# DeepGlobe road extraction.
deepglobe_dataset = CustomDataset(
    image_files=glob(data_path + "deepglobe-road-extraction-dataset/train/*.jpg"),
    geometric_transform=get_geometric_transforms_deepglobe(),
    photometric_transform=get_photometric_transforms(),
)
deepglobe_loader = DataLoader(dataset=deepglobe_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Massachusetts and DeepGlobe combined.
combined_dataset = ConcatDataset([massachusetts_dataset, deepglobe_dataset])
combined_loader = DataLoader(dataset=combined_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Google Maps (Boston / Los Angeles suburbs).
googlemaps_dataset = CustomDataset(
    image_files=glob(data_path + "googlemaps-boston-losangeles-suburbs/images/*.png"),
    geometric_transform=get_geometric_transforms_official(),
)
googlemaps_loader = DataLoader(dataset=googlemaps_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# List the official training images once. glob() returns files in arbitrary
# OS-dependent order, so the list is sorted to make the seeded split below
# select the same files on every machine and run.
official_image_files = np.array(
    sorted(glob(data_path + "ethz-cil-road-segmentation-2023/training/images/*.png"))
)
main_dataset_len = len(official_image_files)

# 80/20 train/validation split with a fixed seed.
val_size = int(main_dataset_len * 0.2)
train_size = main_dataset_len - val_size
torch.manual_seed(0)
indices = torch.randperm(main_dataset_len).tolist()
train_indices = indices[:train_size]
val_indices = indices[train_size:]

# Plain and geometrically augmented views of both splits.
train_dataset = CustomDataset(official_image_files[train_indices], None, None)
train_dataset_augmented = CustomDataset(official_image_files[train_indices], get_geometric_transforms_official(), None)
val_dataset = CustomDataset(official_image_files[val_indices], None, None)
val_dataset_augmented = CustomDataset(official_image_files[val_indices], get_geometric_transforms_official(), None)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
train_loader_augmented = DataLoader(train_dataset_augmented, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
val_loader_augmented = DataLoader(val_dataset_augmented, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)


In [None]:
!pip install segmentation_models_pytorch
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss

from torch import nn

class SegmentationModel(nn.Module):
    """Thin ``nn.Module`` wrapper around an ``smp.Unet`` with one output class.

    The U-Net is built with ``activation=None``, so ``forward`` returns the raw
    network output without a final activation.
    """

    def __init__(self):
        super().__init__()
        unet_config = dict(
            encoder_name=ENCODER,
            encoder_weights=WEIGHTS,
            in_channels=3,
            classes=1,
            activation=None,
        )
        self.backbone = smp.Unet(**unet_config)

    def forward(self, images, masks = None):
        """Run the U-Net on ``images``; ``masks`` is accepted but ignored."""
        output = self.backbone(images)
        return output


model = SegmentationModel()

In [None]:
# Move the network to the selected device and report how many parameters
# currently require gradients.
model.to(DEVICE)
print("model parameters:", sum(t.numel() for t in filter(lambda t: t.requires_grad, model.parameters())))

In [None]:
def visualize_sample(model, loader):
    """Plot the first sample of up to four batches next to the model prediction.

    Each row shows, left to right: original image, original mask, augmented
    image, augmented mask and the sigmoid of the model output (gray colormap).

    The forward passes run in eval mode, so normalisation and dropout layers do
    not use per-batch statistics on a single image. The model's previous
    train/eval mode is restored afterwards.

    Args:
        model: Network to visualise; called on a batch of one image.
        loader: Iterable yielding the 7-tuples produced by ``CustomDataset``.
    """
    rows = 4
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Run the input on whatever device the model lives on.
            device = next(model.parameters()).device
            fig, ax = plt.subplots(rows, 5, figsize=(40, 40))
            # zip with range(rows) stops after `rows` batches without
            # pulling an extra batch from the loader.
            for i, (x, _y, _name, x_orig, y_orig, x_augmented, y_augmented) in zip(range(rows), loader):
                pred = model(x[0].unsqueeze(0).to(device)).squeeze(0)
                pred = torch.sigmoid(pred).permute(1, 2, 0).cpu().numpy()

                ax[i][0].imshow(x_orig[0])
                ax[i][1].imshow(y_orig[0])
                ax[i][2].imshow(x_augmented[0])
                ax[i][3].imshow(y_augmented[0])
                ax[i][4].imshow(pred, cmap='gray')
    finally:
        model.train(was_training)

In [None]:
# Set up training
# Binary cross-entropy evaluated on the raw model output (the U-Net is built
# without a final activation), optimised with AdamW over all parameters.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4)

In [None]:
def train(model, dataset, optimizer):
    """Run one training epoch and return the mean batch loss.

    The loss is computed with the module-level ``criterion``. A running mean is
    printed every 100 steps and the epoch mean at the end.

    Args:
        model: Network to train; it is moved to ``DEVICE`` and put in train mode.
        dataset: Iterable of batches (a DataLoader) whose first two entries are
            the inputs and the masks; remaining entries are ignored.
        optimizer: Optimizer updating the model parameters.

    Returns:
        Mean loss over all processed batches, or ``nan`` if there were none.
    """
    model.to(DEVICE)
    model.train()
    total_loss = 0
    steps = 0
    for x, y, *_ in tqdm(dataset):
        # Same device as the model instead of a hard-coded .cuda().
        x = x.to(DEVICE)
        y = y.unsqueeze(1).to(DEVICE)  # add a channel axis to match the model output
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        steps += 1
        if steps % 100 == 0:
            print("Training Loss:", total_loss / steps)

    # Divide by the number of processed batches; an empty loader yields nan
    # rather than a ZeroDivisionError.
    mean_loss = total_loss / steps if steps else float("nan")
    print("Training Loss:", mean_loss)
    return mean_loss

In [None]:
def validate(model, dataset):
    """Evaluate patch-level F1 scores over ``dataset`` for several thresholds.

    Predictions (after sigmoid) and labels are average-pooled over
    non-overlapping 16x16 patches. A label patch counts as positive when its
    mean exceeds 0.25; the prediction threshold is swept over
    ``np.arange(0.15, 0.40, 0.05)``.

    Args:
        model: Network to evaluate; it is put in eval mode.
        dataset: Iterable of batches (a DataLoader) whose first two entries are
            the inputs and the masks; remaining entries are ignored.

    Returns:
        Dict mapping each prediction threshold to its F1 score.
    """
    model.eval()
    # Collect per-batch chunks and concatenate once at the end instead of
    # re-copying the growing arrays on every batch. The empty seed arrays keep
    # the result well-defined for an empty loader.
    pred_chunks = [np.array([], dtype=np.float32)]
    gt_chunks = [np.array([], dtype=np.float32)]
    with torch.no_grad():
        for x, y, *_ in dataset:
            x = x.to(DEVICE)
            y = y.unsqueeze(1).to(DEVICE)  # add extra dimension to match model's output
            y = F.interpolate(y, size=(416, 416), mode='bilinear', align_corners=False)
            y_pred = torch.sigmoid(model(x))

            # Average over 16x16 patches; a 416x416 map becomes 26x26.
            y_pred = F.avg_pool2d(y_pred, 16, stride=16)
            y = F.avg_pool2d(y, 16, stride=16)

            pred_chunks.append(y_pred.cpu().numpy().flatten())
            gt_chunks.append(y.cpu().numpy().flatten())

    y_preds = np.concatenate(pred_chunks)
    y_gt = np.concatenate(gt_chunks)
    scores = {}
    for tresh in np.arange(0.15,0.40,0.05):
        score = f1_score(y_gt>0.25, y_preds > tresh)
        print("Validation F1 Score for tresh",tresh,":", score)
        scores[float(tresh)] = score
    return scores


In [None]:
# Official train + validation splits (both without augmentation).
combined_official_dataset = ConcatDataset([train_dataset, val_dataset])
combined_official_loader = DataLoader(
    dataset=combined_official_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

# Augmented official training split together with the Google Maps data.
combined_official_google_dataset = ConcatDataset([train_dataset_augmented, googlemaps_dataset])
combined_official_google_loader = DataLoader(
    dataset=combined_official_google_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

# Loaders used by the training cells below.
actual_pretrain_loader = googlemaps_loader
actual_train_loader = combined_official_google_loader
actual_val_loader = val_loader

In [None]:
# See how the modules are structured
# for name, module in model.backbone.named_modules():
#    print(name, module)

    

In [None]:
# Stage 1: freeze the whole network, then re-enable gradients only for the
# last encoder layer, the decoder and the segmentation head.
model.requires_grad_(False)
for trainable_part in (
    model.backbone.encoder.layer4,
    model.backbone.decoder,
    model.backbone.segmentation_head,
):
    trainable_part.requires_grad_(True)

print("trainable parameters:", sum(t.numel() for t in model.parameters() if t.requires_grad))
# Fresh optimizer that only sees the parameters which still require gradients.
optimizer = optim.AdamW([t for t in model.parameters() if t.requires_grad], lr=1e-4)

for epoch in range(2):
    train(model=model, dataset=actual_pretrain_loader, optimizer=optimizer)
    validate(model=model, dataset=actual_val_loader)

In [None]:
# First layers of the encoder (conv1, bn1, relu, maxpool), grouped so their
# parameters can be addressed together.
module_encoder_first = nn.ModuleList(
    [
        model.backbone.encoder.conv1,
        model.backbone.encoder.bn1,
        model.backbone.encoder.relu,
        model.backbone.encoder.maxpool,
    ]
)

In [None]:
# Stage 2: train everything except the first layers of the encoder.
model.requires_grad_(True)
module_encoder_first.requires_grad_(False)

print("trainable parameters:", sum(t.numel() for t in model.parameters() if t.requires_grad))

# Fresh optimizer over the parameters that now require gradients.
optimizer = optim.AdamW([t for t in model.parameters() if t.requires_grad], lr=1e-4)

In [None]:
# Warm-up: one epoch with the new optimizer at a negligible learning rate.
optimizer.param_groups[0].update(lr=1e-9)
for epoch in range(1):
    train(model=model, dataset=actual_pretrain_loader, optimizer=optimizer)
    validate(model=model, dataset=actual_val_loader)

save(model=model, optim=optimizer, name="unet_post_warmup")

# Switch to the learning rate used by the following training cells.
optimizer.param_groups[0].update(lr=1e-4)

In [None]:
# Six epochs on the training loader, validating and overwriting the
# "unet_e-4" checkpoint after every epoch.
for epoch in range(6):
    train(model=model, dataset=actual_train_loader, optimizer=optimizer)
    validate(model=model, dataset=actual_val_loader)
    save(model=model, optim=optimizer, name="unet_e-4")

In [None]:
# Lower the learning rate by a factor of ten and train two more epochs.
optimizer.param_groups[0].update(lr=1e-5)

for epoch in range(2):
    train(model=model, dataset=actual_train_loader, optimizer=optimizer)
    validate(model=model, dataset=actual_val_loader)

save(model=model, optim=optimizer, name="unet_e-5")

In [None]:
# Final epochs on the official train + validation data together; there is no
# validation call here because the validation split is part of the training data.
for epoch in range(5):
    train(model=model, dataset=combined_official_loader, optimizer=optimizer)

In [None]:
# Show a few validation samples next to the model's predictions.
visualize_sample(model=model, loader=actual_val_loader)

In [None]:
# Predict a mask for every test image and store it as a 400x400 grayscale image.
# The preceding training cell leaves the model in train mode; switch to eval
# mode so normalisation layers use their running statistics and are not
# updated during inference.
model.eval()

test_path = data_path + "ethz-cil-road-segmentation-2023/" + "test/images/"

# Create the output folder once up front. This also defines output_dir (used
# by the archive cell) when the test folder turns out to be empty.
output_dir = "/kaggle/working/pred/"
os.makedirs(output_dir, exist_ok=True)

with torch.no_grad():
    files = os.listdir(test_path)
    for file in tqdm(files):
        # Close the image file as soon as its pixels have been copied.
        with Image.open(test_path + file) as test_img:
            x_orig = np.array(test_img.convert("RGB"), dtype=np.float32)
        # pixel_values already carries the batch dimension the model expects.
        x = feature_extractor(images=x_orig, return_tensors="pt").pixel_values.to(DEVICE)
        pred = torch.sigmoid(model(x))
        # Drop the batch and channel axes to get a 2-D probability map.
        pred = pred[0, 0].cpu().numpy()
        pred = Image.fromarray((pred*255).astype(np.uint8))
        pred = pred.resize((400, 400))

        pred.save(output_dir + file)

In [None]:
# zip the folder
import shutil

# Pack the contents of the prediction folder into /kaggle/working/pred.zip.
shutil.make_archive(base_name="/kaggle/working/pred", format='zip', root_dir=output_dir)