In [1]:
!pip install -U fvcore

[notice] A new release of pip is available: 24.1.1 -> 24.1.2
[notice] To update, run: pip install --upgrade pip


In [2]:
# Standard library
import os
import random
import re
import time
import time
import warnings
from itertools import chain
from math import floor, ceil, sqrt, exp
from pprint import pprint
from typing import Optional, Callable

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
from fvcore.nn import FlopCountAnalysis
from IPython import display
from pandas import read_csv
from PIL import Image
from scipy.ndimage import zoom
from skimage import io
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.feature_extraction import create_feature_extractor
from tqdm import tqdm as tqdm

import sys
# `__file__` is not defined inside a notebook. abspath('__file__') resolves the
# *literal string* '__file__' against the current working directory, so
# CURRENT_DIR ends up being the kernel's working directory (typically the
# notebook's folder) — equivalent to os.getcwd().
CURRENT_DIR = os.path.dirname(os.path.abspath('__file__'))
# Add the parent directory to the import path; presumably this is the repo
# root that contains the `utils/` and `models/` packages imported below.
parent_dir = os.path.abspath(os.path.join(CURRENT_DIR, os.pardir))
sys.path.append(parent_dir)

# Project-local imports: these only resolve after the sys.path change above,
# so they must stay below it.
from utils.helpers import crop_image
from models.change_vit import Trainer, Encoder, Decoder, DinoVisionTransformer, PatchEmbed, Block, MemEffAttention, Mlp, BasicBlock, FeatureInjector, BlockInjector, CrossAttention, MlpDecoder, ResNet
from models.efficientunet import CDUnet, UpSamplingBlock, ConvBlock
from models.siamconc import SiamUnet_conc
from models.siamdiff import SiamUnet_diff
from models.stackunet import Unet
from models.efficientunet_respath_attn import CDUnetResPath, Respath, BasicConv, GridAttentionBlock2D

from models.model import SupervisedModel, SemiSupervisedModel

# Run on the first CUDA device when available, otherwise on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Root of the labeled concrete change-detection dataset, relative to the notebook
# directory. It is expected to contain `masks/` and `data/<video>/` sub-folders.
DATA_PATH = (
    f"{CURRENT_DIR}/../datasets/"
    "concrete_cd_labeled-v2-2/concrete_cd_labeled-v2-2"
)

In [4]:
class ChangeDetectionTestDataset(Dataset):
    """In-memory patch dataset for evaluating change-detection models.

    For every mask file ``<video>_<snapshot>.PNG`` in ``<data_path>/masks``
    (optionally filtered by ``mask_pattern``), the matching
    ``<data_path>/data/<video>/before_<snapshot>.png`` and
    ``after_<snapshot>.png`` images are loaded. All three images are split with
    ``crop_image``, and each aligned (before, after, mask) patch triple becomes
    one sample.

    Args:
        data_path: Dataset root containing ``masks/`` and ``data/``.
        transforms: Stored as ``self.transforms``; not applied by this class.
        mask_pattern: Optional regular expression. Only mask file names it
            matches with ``re.match`` (anchored at the start) are considered.

    Attributes:
        image_num: Number of mask images actually loaded.
        fetched_data: List of ``{"before", "after", "mask"}`` patch dicts.
    """

    def __init__(self, data_path, transforms=None, mask_pattern=None):
        self.data_path = data_path
        self.transforms = transforms
        self.test_regex = None if mask_pattern is None else re.compile(mask_pattern)
        self.weights = [1.0, 1.0]  # not updated anywhere in this class
        self.image_num = 0

        self.fetched_data = []
        self._fetch_paths(data_path)

    @staticmethod
    def _load_image(path):
        """Read an image file into a NumPy array, closing the file handle afterwards."""
        with Image.open(path) as img:
            return np.array(img)

    def _fetch_paths(self, data_path):
        """Load, crop and store every matching (before, after, mask) triple under ``data_path``."""
        masks_path = os.path.join(data_path, "masks")
        mask_names = os.listdir(masks_path)
        if self.test_regex is not None:
            mask_names = [name for name in mask_names if self.test_regex.match(name)]

        loaded_masks = 0
        for mask_name in mask_names:
            # Only files with an upper-case ".PNG" extension are treated as masks.
            if not mask_name.endswith(".PNG"):
                continue
            # Expects exactly one "_" separating the video and snapshot names.
            video_name, snapshot_name = mask_name.split("_")
            snapshot_name = snapshot_name.split('.')[0]
            # Resolve snapshots relative to the `data_path` argument, not a global,
            # so the dataset can point at any root.
            snapshots_dir_path = os.path.join(data_path, "data", video_name)
            before_patches = crop_image(self._load_image(f"{snapshots_dir_path}/before_{snapshot_name}.png"))
            after_patches = crop_image(self._load_image(f"{snapshots_dir_path}/after_{snapshot_name}.png"))

            raw_mask = self._load_image(f"{masks_path}/{mask_name}")
            # Keep a single channel with a trailing channel axis. For a 2-D
            # (grayscale) array, `[..., :1]` would instead keep only the first
            # column, so add the axis explicitly in that case.
            mask = raw_mask[..., None] if raw_mask.ndim == 2 else raw_mask[..., :1]
            mask_patches = crop_image(mask)

            for before_sample, after_sample, mask_sample in zip(before_patches, after_patches, mask_patches):
                self.fetched_data.append({"before": before_sample, "after": after_sample, "mask": mask_sample})
            loaded_masks += 1

        # Count only masks that were actually loaded (non-PNG matches are skipped above).
        self.image_num = loaded_masks

    def __len__(self):
        return len(self.fetched_data)

    def __getitem__(self, idx):
        """Return the idx-th patch triple as tensors.

        ``before``/``after`` are converted with ``TF.to_tensor`` and normalised;
        ``mask`` is only converted with ``TF.to_tensor``.
        """
        data_sample = self.fetched_data[idx]
        before_patch = data_sample["before"]
        after_patch = data_sample["after"]
        mask = data_sample["mask"]

        before_patch, after_patch, mask = TF.to_tensor(before_patch), TF.to_tensor(after_patch), TF.to_tensor(mask)
        # NOTE(review): a single scalar mean/std is broadcast over every channel —
        # confirm this matches the normalisation used at training time.
        before_patch = TF.normalize(before_patch, mean=(0.485), std=(0.229))
        after_patch = TF.normalize(after_patch, mean=(0.485), std=(0.229))

        return {"before": before_patch, "after": after_patch, "mask": mask}

In [5]:
import re
from torch.utils.data import DataLoader, random_split

# Evaluate on videos 4A and 5A only. The pattern is a regex, not a glob:
# `re.match` anchors it at the start of the file name, so it keeps masks named
# "4A_<snapshot>.PNG" / "5A_<snapshot>.PNG". The glob-style "5A_*" would also
# accept any other name that merely starts with "5A".
dataset = ChangeDetectionTestDataset(DATA_PATH, mask_pattern=r"(?:4A|5A)_")
test_loader = DataLoader(dataset, batch_size=1)

In [6]:
def get_model_preparation_function(model_name):
    """Return a callable that configures a model instance before evaluation.

    Matching is case-insensitive on substrings of ``model_name``:
    names containing "respathattn" get attention gates switched on, other
    names containing "respath" get them switched off, and every other model
    gets a no-op. The returned callable returns whatever the invoked method
    returns (``None`` for the no-op).
    """
    lowered = model_name.lower()
    if "respath" not in lowered:
        return lambda x: None
    if "respathattn" in lowered:
        return lambda x: x.activate_attention_gates()
    return lambda x: x.deactivate_attention_gates()

def get_model_input_size(model_name):
    """Return the ``[height, width]`` input size expected by a model.

    Rules are checked in order against the lower-cased name, and the first
    substring hit wins. Unknown models default to ``[500, 500]``. A new list is
    returned on every call, so callers may mutate it safely.
    """
    lowered = model_name.lower()
    size_rules = (
        ("changevit", (256, 256)),
        ("eff", (512, 512)),
        ("pred_diff", (512, 512)),
        # Also matches "siamthreshevolv"; a separate rule for it was unreachable.
        ("siamthresh", (512, 512)),
    )
    for fragment, (height, width) in size_rules:
        if fragment in lowered:
            return [height, width]
    return [500, 500]

In [7]:
# Models to evaluate. Each entry bundles a display name, a prediction wrapper
# and the device it should run on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_root_path = f"{CURRENT_DIR}/../weights/change_detection"

# Supervised checkpoints under `model_root_path` are currently disabled.
# To include them, start from this list instead of the empty one:
#     models = [
#         {
#             "model_name": model_name.split('.')[0],
#             "model_prediction_wrapper": SupervisedModel(
#                 os.path.join(model_root_path, model_name),
#                 input_size=get_model_input_size(model_name.split('.')[0]),
#             ),
#             "device": device,
#         }
#         for model_name in os.listdir(model_root_path)
#     ]
models = []

models.extend(
    {
        "model_name": semi_name,
        "model_prediction_wrapper": SemiSupervisedModel(semi_name),
        "device": device,
    }
    for semi_name in ("pred_diff", "siamthresh", "siamthreshevolv")
)

In [8]:
class DiceLoss(nn.Module):
    """Soft Dice loss over all elements of the batch at once.

    Computes ``1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``
    on the flattened tensors, so the whole batch is scored as one region.

    Args:
        smooth: Constant added to numerator and denominator; keeps the ratio
            defined when both inputs are all zeros.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, probs, targets):
        """Return the scalar Dice loss between ``probs`` and ``targets``.

        Both tensors are expected to hold the same number of elements.
        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (e.g. transposed or sliced tensors) work instead of raising.
        """
        flat_probs = probs.reshape(-1)
        flat_targets = targets.reshape(-1)

        intersection = (flat_probs * flat_targets).sum()
        dice_coeff = (2. * intersection + self.smooth) / (flat_probs.sum() + flat_targets.sum() + self.smooth)
        return 1 - dice_coeff

In [9]:
def calculate_metrics(pred_prob_map, ground_truth_mask, threshold=0.5):
    """Count confusion-matrix entries for a thresholded prediction.

    A pixel is predicted positive when ``pred_prob_map > threshold``. It is a
    true positive in the ground truth when ``ground_truth_mask > 0``. Both
    tensors are flattened, so any shapes with the same element count work,
    including non-contiguous views.

    Args:
        pred_prob_map: Tensor of predicted probabilities.
        ground_truth_mask: Tensor of ground-truth values (non-zero = positive).
        threshold: Decision threshold applied to ``pred_prob_map``.

    Returns:
        Tuple ``(TP, FP, TN, FN)`` of Python floats.

    Raises:
        ValueError: If the two tensors contain a different number of elements.
            Without this check, broadcasting could silently produce wrong counts.
    """
    predicted = (pred_prob_map > threshold).reshape(-1)
    actual = (ground_truth_mask > 0).reshape(-1)
    if predicted.numel() != actual.numel():
        raise ValueError(
            f"prediction has {predicted.numel()} elements but mask has {actual.numel()}"
        )

    TP = float((predicted & actual).sum().item())
    FP = float((predicted & ~actual).sum().item())
    TN = float((~predicted & ~actual).sum().item())
    FN = float((~predicted & actual).sum().item())
    return TP, FP, TN, FN

In [10]:
import torch.nn as nn
class WrapperModel(nn.Module):
    """Adapt a two-input model to a single-argument ``forward``.

    ``forward`` receives one object holding exactly two inputs (e.g. a list
    ``[before, after]``), unpacks it, and calls the wrapped model as
    ``original_model(first, second)``. The wrapped model is registered as the
    submodule ``original_model``.
    """

    def __init__(self, original_model):
        super().__init__()
        self.original_model = original_model

    def forward(self, combined_input):
        first, second = combined_input
        return self.original_model(first, second)

def get_flops(model, input_size, device, in_channels=3):
    """Count operations of a two-input model with fvcore on random inputs.

    Args:
        model: Module called as ``model(input1, input2)``.
        input_size: Either a single value (square ``input_size x input_size``
            inputs, as before) or a ``(height, width)`` list/tuple.
        device: Device on which the random inputs are created; it should
            match the model's device.
        in_channels: Channel count of each synthetic input (default 3).

    Returns:
        The value of ``FlopCountAnalysis(...).total()``.
    """
    if isinstance(input_size, (list, tuple)):
        height, width = input_size
    else:
        height = width = input_size
    shape = (1, in_channels, height, width)
    input1 = torch.rand(shape, device=device)
    input2 = torch.rand(shape, device=device)
    # The two tensors are passed as one list, which WrapperModel unpacks.
    flops = FlopCountAnalysis(WrapperModel(model), [input1, input2])
    return flops.total()

In [11]:
# Loss functions used to score each model's clamped predictions against the masks.
loss_dice = DiceLoss()
loss_bce = torch.nn.BCELoss()

In [13]:
result_dict = {}
for model_params in models:
    model = model_params["model_prediction_wrapper"].to(model_params["device"])
    model_name = model_params["model_name"]
    model.eval()

    print(f"Current Model: {model_name}")

    bce_loss_total = 0.0
    dice_loss_total = 0.0
    TP_overall, FP_overall, TN_overall, FN_overall = 0.0, 0.0, 0.0, 0.0

    with torch.no_grad():
        for batch in tqdm(test_loader):
            befores = batch['before'].float().to(model_params["device"])
            afters = batch['after'].float().to(model_params["device"])
            masks = batch['mask'].float().to(model_params["device"])
            masks = torch.clamp(masks, 0.0, 1.0)

            # Predictions are clamped into [0, 1] before losses and thresholding.
            output = torch.clamp(model(befores, afters), 0.0, 1.0)
            bce_loss_total += loss_bce(output, masks).item()
            dice_loss_total += loss_dice(output, masks).item()

            # Score channel 0 of *every* sample in the batch. Indexing with
            # `output[0][0]` would score only the first sample and silently
            # drop the rest whenever batch_size > 1.
            TP, FP, TN, FN = calculate_metrics(output[:, 0], masks[:, 0], 0.5)
            TP_overall += TP
            FP_overall += FP
            TN_overall += TN
            FN_overall += FN

    # Ratios with an empty denominator are reported as 0 (same convention as
    # precision/recall). Loss averages over an empty loader are NaN, not a crash.
    num_batches = len(test_loader)
    union = TP_overall + FP_overall + FN_overall
    total_pixels = union + TN_overall
    iou_total = TP_overall / union if union > 0 else 0
    precision = TP_overall / (TP_overall + FP_overall) if (TP_overall + FP_overall) > 0 else 0
    recall = TP_overall / (TP_overall + FN_overall) if (TP_overall + FN_overall) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    accuracy = (TP_overall + TN_overall) / total_pixels if total_pixels > 0 else 0
    bce_mean = bce_loss_total / num_batches if num_batches > 0 else float("nan")
    dice_mean = dice_loss_total / num_batches if num_batches > 0 else float("nan")

    print(f"[{model_name}] bce: {bce_mean}, dice: {dice_mean}, iou: {iou_total}, precision: {precision}, recall: {recall}, f1_score: {f1_score}, accuracy: {accuracy}")
    current_result = {"BCE": bce_mean, "Dice": dice_mean, "IoU": iou_total, "Precision": precision, "Recall": recall, "F1": f1_score, "Accuracy": accuracy}

    # Cost measured on random square inputs of the model's configured size.
    num_flops = get_flops(model, get_model_input_size(model_name)[0], model_params["device"])
    current_result["GFlops"] = num_flops / 1_000_000_000
    current_result["Params"] = sum(p.numel() for p in model.parameters())

    result_dict[model_name] = current_result

results_df = pd.DataFrame(result_dict).T

Current Model: pred_diff


  5%|▌         | 23/448 [00:26<08:14,  1.16s/it]

In [None]:
# Persist the per-model metrics table; writing .xlsx needs an Excel engine
# such as openpyxl to be installed.
results_df.to_excel(excel_writer="test_results.xlsx")