In [1]:
import os
import random
import re
import time
import warnings
from math import floor, ceil, sqrt, exp

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.functional as TF
from pandas import read_csv
from PIL import Image
from scipy.ndimage import zoom
from skimage import io
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from tqdm import tqdm as tqdm

import ssl
# SECURITY NOTE(review): this replaces the default HTTPS context factory with the
# unverified one, i.e. TLS certificate verification is switched off for code in
# this process that relies on the default context. Presumably a workaround for
# downloading pretrained weights behind a broken certificate chain -- confirm
# it is still required and remove it otherwise.
ssl._create_default_https_context = ssl._create_unverified_context

import sys
# '__file__' is a string literal here (not the module attribute), so abspath()
# resolves it against the current working directory and CURRENT_DIR ends up being
# the working directory the kernel was started in.
CURRENT_DIR = os.path.dirname(os.path.abspath('__file__'))
# Make the parent of the working directory importable so the `utils` and `models`
# packages imported below can be found.
parent_dir = os.path.abspath(os.path.join(CURRENT_DIR, os.pardir))
sys.path.append(parent_dir)

from utils.helpers import crop_image
from models.change_vit import Trainer, Encoder, Decoder, DinoVisionTransformer, PatchEmbed, Block, MemEffAttention, Mlp, BasicBlock, FeatureInjector, BlockInjector, CrossAttention, MlpDecoder, ResNet
from models.efficientunet import CDUnet, UpSamplingBlock, ConvBlock
from models.siamconc import SiamUnet_conc
from models.siamdiff import SiamUnet_diff
from models.stackunet import Unet
from models.efficientunet_respath_attn import CDUnetResPath, Respath, BasicConv, GridAttentionBlock2D
from models.model import SupervisedModel, SemiSupervisedModel

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Root directory of the labeled dataset, expressed relative to CURRENT_DIR.
DATA_PATH = f"{CURRENT_DIR}/../datasets/concrete_cd_labeled-v2-2/concrete_cd_labeled-v2-2"
# Fraction of the dataset that is held out for validation when it is split.
VALIDATION_PART = 0.125

In [10]:
class ChangeDetectionTestDataset(Dataset):
    """In-memory dataset of before/after image patches with their change masks.

    Every ``*.PNG`` file found in ``<data_path>/masks`` must be named
    ``<video>_<snapshot>.PNG``. It is paired with the images
    ``<data_path>/data/<video>/before_<snapshot>.png`` and
    ``<data_path>/data/<video>/after_<snapshot>.png``. The three images are cut
    into patches with ``crop_image`` and every patch triple becomes one item.

    Args:
        data_path: Dataset root containing the ``masks`` and ``data`` folders.
        transforms: Stored on the instance only; no method of this class applies it.
        mask_pattern: Optional regular expression; only mask file names matching
            it (``re.match``, i.e. anchored at the start) are loaded.

    Attributes:
        fetched_data: List of ``{"before", "after", "mask"}`` patch dictionaries.
        image_num: Number of mask images that were actually loaded.
        weights: Placeholder ``[1.0, 1.0]``; never updated by this class.
    """

    def __init__(self, data_path, transforms=None, mask_pattern=None):
        self.data_path = data_path
        self.transforms = transforms
        self.test_regex = None if mask_pattern is None else re.compile(mask_pattern)
        self.weights = [1.0, 1.0]
        self.image_num = 0

        self.fetched_data = []
        self._fetch_paths(data_path)

    @staticmethod
    def _load_array(image_path):
        """Read an image file into a NumPy array and close the file handle right away."""
        with Image.open(image_path) as image:
            return np.array(image)

    def _fetch_paths(self, data_path):
        """Load every matching mask with its image pair and store the cropped patches.

        Args:
            data_path: Dataset root; both the masks and the before/after images
                are resolved relative to it.
        """
        masks_path = os.path.join(data_path, "masks")
        # Sorted so that item indices do not depend on the file-system listing order.
        mask_names = sorted(os.listdir(masks_path))
        if self.test_regex is not None:
            mask_names = [name for name in mask_names if self.test_regex.match(name)]

        loaded_images = 0
        for mask_name in mask_names:
            # Only upper-case ".PNG" files are treated as masks; anything else is skipped.
            if not mask_name.endswith(".PNG"):
                continue

            video_name, snapshot_name = mask_name.split("_")
            snapshot_name = snapshot_name.split('.')[0]
            # Resolve the snapshots below the data_path argument, not a module-level global.
            snapshots_dir_path = os.path.join(data_path, "data", video_name)
            before_patches = crop_image(self._load_array(f"{snapshots_dir_path}/before_{snapshot_name}.png"))
            after_patches = crop_image(self._load_array(f"{snapshots_dir_path}/after_{snapshot_name}.png"))

            uncropped_mask = self._load_array(f"{masks_path}/{mask_name}")
            if uncropped_mask.ndim == 2:
                # Single-channel image: add the channel axis instead of slicing the width.
                uncropped_mask = uncropped_mask[..., np.newaxis]
            else:
                # Multi-channel image: keep the first channel only.
                uncropped_mask = uncropped_mask[..., :1]
            mask_patches = crop_image(uncropped_mask)

            for before_sample, after_sample, mask_sample in zip(before_patches, after_patches, mask_patches):
                self.fetched_data.append({"before": before_sample, "after": after_sample, "mask": mask_sample})
            loaded_images += 1

        # Count only the masks that were really loaded, not every directory entry.
        self.image_num = loaded_images

    def augment_image_mask(
        self,
        patch_1: torch.Tensor,
        patch_2: torch.Tensor,
        mask: torch.Tensor,
        probability: float = 0.8,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply the same random rotation / flips to both patches and the mask.

        Each of the three augmentations (rotation by an integer angle in
        [-50, 50], vertical flip, horizontal flip) is drawn independently and
        applied when ``random.random() <= probability``.

        Args:
            patch_1: First image patch.
            patch_2: Second image patch.
            mask: Mask belonging to the two patches.
            probability: Threshold used for each of the three random draws.

        Returns:
            The tuple ``(patch_1, patch_2, mask)`` after augmentation.
        """
        if random.random() <= probability:
            angle = random.randint(-50, 50)
            patch_1 = TF.rotate(patch_1, angle)
            patch_2 = TF.rotate(patch_2, angle)
            mask = TF.rotate(mask, angle)
        if random.random() <= probability:
            patch_1 = TF.vflip(patch_1)
            patch_2 = TF.vflip(patch_2)
            mask = TF.vflip(mask)
        if random.random() <= probability:
            patch_1 = TF.hflip(patch_1)
            patch_2 = TF.hflip(patch_2)
            mask = TF.hflip(mask)

        return patch_1, patch_2, mask

    def __len__(self):
        """Return the number of stored patch triples."""
        return len(self.fetched_data)

    def __getitem__(self, idx):
        """Return the augmented, normalized patch triple stored at ``idx``.

        The random augmentation runs on every access, so the same index can
        yield different tensors. NOTE(review): this also applies to any
        validation subset derived from this dataset -- confirm that is intended.
        """
        data_sample = self.fetched_data[idx]
        before_patch = data_sample["before"]
        after_patch = data_sample["after"]
        mask = data_sample["mask"]

        before_patch, after_patch, mask = TF.to_tensor(before_patch), TF.to_tensor(after_patch), TF.to_tensor(mask)
        before_patch, after_patch, mask = self.augment_image_mask(before_patch, after_patch, mask)
        # NOTE(review): `(0.485)` / `(0.229)` are plain floats, not 1-tuples; presumably the
        # same statistics are meant for every channel -- confirm. The mask is not normalized.
        before_patch = TF.normalize(before_patch, mean=(0.485), std=(0.229))
        after_patch = TF.normalize(after_patch, mean=(0.485), std=(0.229))

        return {"before": before_patch, "after": after_patch, "mask": mask}

In [11]:
import re

# Fixed seed so that the train/validation membership is the same on every run
# (given the same dataset ordering).
SPLIT_SEED = 42

# The negative lookahead skips every mask whose file name starts with "4A_" or "5A_".
# NOTE(review): presumably those videos are reserved for testing -- confirm.
dataset = ChangeDetectionTestDataset(DATA_PATH, mask_pattern=r"^(?!4A_|5A_).*")

# Fractional lengths: VALIDATION_PART of the patches go to validation, the rest to
# training. A dedicated generator keeps the split independent of the global RNG state.
train_dataset, val_dataset = random_split(
    dataset,
    [1 - VALIDATION_PART, VALIDATION_PART],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    The loss is ``1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``
    where ``p`` and ``t`` are the flattened predictions and targets.

    Args:
        smooth: Constant added to numerator and denominator so the ratio stays
            defined when both inputs sum to zero.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, probs, targets):
        """Return the scalar Dice loss for ``probs`` against ``targets``.

        Args:
            probs: Predicted values; any shape.
            targets: Target values with the same number of elements as ``probs``.

        Returns:
            Zero-dimensional tensor holding ``1 - dice_coefficient``.
        """
        # reshape() instead of view(): view() raises on non-contiguous tensors
        # (e.g. after a transpose or permute), reshape() copies when it has to.
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice_coeff = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)

        return 1 - dice_coeff

# Module-level loss instances shared by the training and evaluation functions.
# NOTE(review): nn.BCELoss works on probabilities, so the models are presumably
# expected to return sigmoid-activated outputs -- confirm for every model.
loss_bce = nn.BCELoss()
loss_dice = DiceLoss()

In [None]:
def create_model_object(model_name):
    """Instantiate the network registered under ``model_name``.

    Supported names: ``"change_vit"``, ``"siam_diff"``, ``"siam_conc"``,
    ``"stackunet"``, ``"efficientunet"``, ``"efficientunet_respath_attn"``
    and ``"efficientunet_respath"``.

    Args:
        model_name: Key identifying the architecture to build.

    Returns:
        A freshly constructed model, or ``None`` when ``model_name`` is not one
        of the supported names. Callers must check for ``None`` before use.
    """

    def _build_respath(with_attention_gates):
        # Both ResPath variants share one class and differ only in the gate state.
        output_model = CDUnetResPath(out_channels=1, pretrained=True)
        if with_attention_gates:
            output_model.activate_attention_gates()
        else:
            output_model.deactivate_attention_gates()
        return output_model

    # Builders are wrapped in lambdas so that only the requested model is constructed.
    builders = {
        "change_vit": lambda: Trainer("tiny").float(),
        "siam_diff": lambda: SiamUnet_diff(3*1, 1),
        "siam_conc": lambda: SiamUnet_conc(3*1, 1),
        "stackunet": lambda: Unet(3*2, 1),
        "efficientunet": lambda: CDUnet(out_channels=1, pretrained=True),
        "efficientunet_respath_attn": lambda: _build_respath(True),
        "efficientunet_respath": lambda: _build_respath(False),
    }

    builder = builders.get(model_name)
    if builder is None:
        return None
    return builder()

def _build_training_config(model_name, epochs, input_size, lr, batch_size, has_backbone_stages):
    """Assemble one training configuration dictionary.

    Args:
        model_name: Name understood by ``create_model_object``.
        epochs: ``[epochs_freeze, epochs_unfreeze]``.
        input_size: ``[height, width]`` the inputs are resized to.
        lr: ``[lr_freeze, lr_unfreeze]``.
        batch_size: ``[batch_size_freeze, batch_size_unfreeze]``.
        has_backbone_stages: When true, the freeze/unfreeze hooks call the
            model's ``freeze_backbone`` / ``unfreeze_backbone``; otherwise both
            hooks return the model untouched.

    Returns:
        Dictionary holding the settings above plus the shared save directory,
        the evaluation interval and the target device.
    """
    if has_backbone_stages:
        freeze_hook = lambda x: x.freeze_backbone()
        unfreeze_hook = lambda x: x.unfreeze_backbone()
    else:
        freeze_hook = lambda x: x
        unfreeze_hook = lambda x: x

    return {
        "model_name": model_name,
        "save_dir": f"{CURRENT_DIR}/../weights/change_detection",
        "epochs": list(epochs),
        "input_size": list(input_size),
        "lr": list(lr),
        "batch_size": list(batch_size),
        # Evaluate the model every "eval_epoch" epochs.
        "eval_epoch": 5,
        "device": device,
        "freeze_function": freeze_hook,
        "unfreeze_function": unfreeze_hook,
    }


# One entry per model to train, in training order. Models with a freezable backbone
# get a frozen stage followed by an unfrozen one; the others run the second stage only.
models_training_configs = [
    _build_training_config("efficientunet_respath", (100, 100), (512, 512), (1e-2, 1e-3), (16, 4), True),
    _build_training_config("efficientunet_respath_attn", (100, 100), (512, 512), (1e-2, 1e-3), (16, 4), True),
    _build_training_config("efficientunet", (100, 100), (512, 512), (1e-2, 1e-3), (16, 4), True),
    _build_training_config("change_vit", (0, 200), (256, 256), (0, 1e-3), (0, 8), False),
    _build_training_config("siam_diff", (0, 200), (500, 500), (0, 1e-2), (0, 16), False),
    _build_training_config("siam_conc", (0, 200), (500, 500), (0, 1e-2), (0, 16), False),
    _build_training_config("stackunet", (0, 200), (500, 500), (0, 1e-2), (0, 16), False),
]

In [None]:
def compute_eval_loss(model, val_loader, resizer, device):
    """Sum the BCE + Dice loss of ``model`` over every batch of ``val_loader``.

    The model is switched to evaluation mode for the duration of the call and
    gradients are disabled. Afterwards the mode the model had on entry is
    restored, even if evaluation raises.

    Args:
        model: Network called as ``model(befores, afters)``.
        val_loader: Iterable of batches with ``'before'``, ``'after'`` and ``'mask'`` entries.
        resizer: Callable applied to each of the three batch tensors.
        device: Device the resized tensors are moved to.

    Returns:
        Float sum (not mean) of the per-batch losses, so the value scales with
        the number of batches.
    """
    was_training = model.training
    model.eval()
    eval_loss = 0.0
    try:
        with torch.no_grad():
            for batch in val_loader:
                befores = resizer(batch['before']).float().to(device)
                afters = resizer(batch['after']).float().to(device)
                # Resizing may push mask values slightly outside [0, 1]; clamp them back.
                masks = resizer(batch['mask']).float().to(device)
                masks = torch.clamp(masks, 0.0, 1.0)

                output = model(befores, afters)
                loss = loss_bce(output, masks) + loss_dice(output, masks)
                eval_loss += loss.item()
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        model.train(was_training)
    return eval_loss

In [None]:
def train(model, train_loader, val_loader, epochs, optimizer, model_training_configs, stage_name=""):
    """Run one training stage and write checkpoints into the configured directory.

    Every ``eval_epoch`` epochs the validation loss is computed; when it improves
    on the best value seen in this call, the whole model object is saved as
    ``<model_name>_<stage_name>_loss_<int(loss)>.pth``. After the last epoch the
    model is saved again as ``<model_name>_<stage_name>.pth``.

    Args:
        model: Network called as ``model(befores, afters)``.
        train_loader: Iterable of training batches (``'before'``, ``'after'``, ``'mask'``).
        val_loader: Iterable of validation batches passed to ``compute_eval_loss``.
        epochs: Number of epochs to run.
        optimizer: Optimizer that updates the model parameters.
        model_training_configs: Dictionary providing ``eval_epoch``, ``input_size``,
            ``device``, ``save_dir`` and ``model_name``.
        stage_name: Label used in the log output and in checkpoint file names.
    """
    eval_epoch = model_training_configs["eval_epoch"]
    target_device = model_training_configs["device"]
    save_dir = model_training_configs["save_dir"]
    model_name = model_training_configs["model_name"]
    best_eval_loss = None

    # Create the checkpoint directory up front so the first torch.save cannot fail
    # on a missing folder after epochs of training.
    os.makedirs(save_dir, exist_ok=True)

    # Learning rate is multiplied by 0.95 after every epoch.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.95)
    resizer = torchvision.transforms.Resize(model_training_configs["input_size"])
    for epoch_index in range(epochs):
        print(f'[{stage_name}] Epoch: {epoch_index + 1} of {epochs}')
        epoch_loss = 0.0
        for batch in train_loader:
            befores = resizer(batch['before']).float().to(target_device)
            afters = resizer(batch['after']).float().to(target_device)
            masks = resizer(batch['mask']).float().to(target_device)
            masks = torch.clamp(masks, 0.0, 1.0)

            optimizer.zero_grad()
            output = model(befores, afters)
            # NOTE(review): training optimizes BCE only, while compute_eval_loss reports
            # BCE + Dice, so the two printed losses are not comparable -- confirm intent.
            loss = loss_bce(output, masks)
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()

        if (epoch_index + 1) % eval_epoch == 0:
            current_eval_loss = compute_eval_loss(model, val_loader, resizer, target_device)
            if best_eval_loss is None or best_eval_loss > current_eval_loss:
                best_eval_loss = current_eval_loss
                # The loss is truncated to an integer in the file name, so checkpoints whose
                # losses share the same integer part overwrite each other.
                torch.save(model, f'{save_dir}/{model_name}_{stage_name}_loss_{int(current_eval_loss)}.pth')
            print(f"Eval Loss: {current_eval_loss}")

        # Summed (not averaged) training loss of this epoch.
        print(epoch_loss)
        scheduler.step()

    torch.save(model, f'{save_dir}/{model_name}_{stage_name}.pth')

In [None]:
# Validation batches are only summed into one loss value, so shuffling them is pointless.
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Train every configured model: an optional stage with the backbone frozen,
# followed by an optional stage with everything trainable.
for model_training_configs in models_training_configs:
    model = create_model_object(model_training_configs["model_name"])
    if model is None:
        # create_model_object returns None for names it does not know; fail with a
        # clear message instead of an AttributeError on `None.to(...)`.
        raise ValueError(f'Unknown model name: {model_training_configs["model_name"]!r}')
    model = model.to(model_training_configs["device"])
    model.train()

    freeze_epochs = model_training_configs["epochs"][0]
    if freeze_epochs != 0:
        model_training_configs["freeze_function"](model)
        train_loader = DataLoader(train_dataset, batch_size=model_training_configs["batch_size"][0], shuffle=True)
        optimizer = torch.optim.NAdam(model.parameters(), lr=model_training_configs["lr"][0], weight_decay=1e-4)
        train(model, train_loader, val_loader, freeze_epochs, optimizer, model_training_configs, "freeze")
        # Re-enable the backbone before the second stage builds its own optimizer.
        model_training_configs["unfreeze_function"](model)

    unfreeze_epochs = model_training_configs["epochs"][1]
    if unfreeze_epochs != 0:
        train_loader = DataLoader(train_dataset, batch_size=model_training_configs["batch_size"][1], shuffle=True)
        optimizer = torch.optim.NAdam(model.parameters(), lr=model_training_configs["lr"][1], weight_decay=1e-4)
        train(model, train_loader, val_loader, unfreeze_epochs, optimizer, model_training_configs, "unfreeze")


[freeze] Epoch: 1 of 100
