In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.padding import ReplicationPad2d
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torchvision.transforms as tr
from torchinfo import summary

# Other
import os
import numpy as np
import random
from skimage import io
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import tqdm as tqdm
from pandas import read_csv
from math import floor, ceil, sqrt, exp
from IPython import display
import time
from itertools import chain
import time
import warnings
from pprint import pprint

from numpy.random import RandomState

# Seed every random number generator used in this file (PyTorch, NumPy and the
# standard-library `random` module) with the same value so runs are repeatable.
SEED=2342

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
print('IMPORTS OK')

# Global Variables' Definitions

# Dataset location, given as a string prefix (note the trailing slash).
PATH_TO_DATASET = '../onera/OSCD/'
IS_PROTOTYPE = False

# Multiplies the first entry of ChangeDetectionDataset.weights.
FP_MODIFIER = 10 # Tuning parameter, use 1 if unsure
# Default `batch_size` argument of the patch-based evaluation functions.
#BATCH_SIZE = 32
BATCH_SIZE = 8
#BATCH_SIZE = 1

# Default `patch_size` argument of the patch-based evaluation functions.
PATCH_SIDE = 128
#PATCH_SIDE = 64

# Default `n_epochs` argument of the training functions.
N_EPOCHS = 10
# When true, every image reader standardises its output with the global mean/std.
NORMALISE_IMGS = True
# Half a patch minus one pixel (63 for PATCH_SIDE = 128).
TRAIN_STRIDE = int(PATCH_SIDE/2) - 1
# Selects which band reader read_sentinel_img_trio uses.
TYPE = 3 # 0-RGB | 1-RGBIr | 2-All bands s.t. resolution <= 20m | 3-All bands
LOAD_TRAINED = False
DATA_AUG = True
# Computation is pinned to the CPU; the CUDA auto-selection is kept commented out.
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
device = torch.device("cpu") 

print('DEFINITIONS OK')

# Dataset Functions

def adjust_shape(I, s):
    """Force the 2-D array I to shape (s[0], s[1]).

    Rows/columns beyond the target are cut off; missing rows/columns are
    appended at the bottom/right by repeating the edge values.
    """
    target_rows, target_cols = s[0], s[1]

    # Cut away anything that overshoots the requested extent.
    trimmed = I[:target_rows, :target_cols]

    # Whatever is still missing has to be appended (never a negative amount).
    missing_rows = target_rows - trimmed.shape[0]
    if missing_rows < 0:
        missing_rows = 0
    missing_cols = target_cols - trimmed.shape[1]
    if missing_cols < 0:
        missing_cols = 0

    pad_widths = ((0, missing_rows), (0, missing_cols))
    return np.pad(trimmed, pad_widths, 'edge')
    

def read_sentinel_img(path):
    """Load the red, green and blue bands (B04, B03, B02) found under `path`.

    The common file-name prefix is taken from the first directory entry with
    its last 7 characters removed. The bands are stacked along axis 2 as
    floats and, when NORMALISE_IMGS is set, standardised with the mean and
    standard deviation of the whole stack.
    """
    prefix = os.listdir(path)[0][:-7]

    layers = [io.imread(path + prefix + suffix)
              for suffix in ("B04.tif", "B03.tif", "B02.tif")]
    image = np.stack(layers, axis=2).astype('float')

    if not NORMALISE_IMGS:
        return image
    return (image - image.mean()) / image.std()

def read_sentinel_img_4(path):
    """Load the red, green, blue and near-infrared bands (B04, B03, B02, B08).

    The common file-name prefix is taken from the first directory entry with
    its last 7 characters removed. The bands are stacked along axis 2 as
    floats and, when NORMALISE_IMGS is set, standardised with the mean and
    standard deviation of the whole stack.
    """
    prefix = os.listdir(path)[0][:-7]

    layers = [io.imread(path + prefix + suffix)
              for suffix in ("B04.tif", "B03.tif", "B02.tif", "B08.tif")]
    image = np.stack(layers, axis=2).astype('float')

    if not NORMALISE_IMGS:
        return image
    return (image - image.mean()) / image.std()

def read_sentinel_img_leq20(path):
    """Load the ten bands whose resolution is 20m or finer.

    B04, B03, B02 and B08 are used as read; B05, B06, B07, B8A, B11 and B12
    are zoomed by a factor of 2 and then cropped/padded to the shape of B04.
    All bands are stacked along axis 2 as floats and, when NORMALISE_IMGS is
    set, standardised with the mean and standard deviation of the whole stack.
    """
    prefix = os.listdir(path)[0][:-7]

    def load(suffix):
        return io.imread(path + prefix + suffix)

    red = load("B04.tif")
    reference_shape = red.shape  # every resampled band is forced to this shape

    layers = [red]
    for suffix in ("B03.tif", "B02.tif", "B08.tif"):
        layers.append(load(suffix))
    for suffix in ("B05.tif", "B06.tif", "B07.tif", "B8A.tif", "B11.tif", "B12.tif"):
        layers.append(adjust_shape(zoom(load(suffix), 2), reference_shape))

    image = np.stack(layers, axis=2).astype('float')

    if not NORMALISE_IMGS:
        return image
    return (image - image.mean()) / image.std()

def read_sentinel_img_leq60(path):
    """Load all thirteen bands.

    B04, B03, B02 and B08 are used as read; B05, B06, B07, B8A, B11 and B12
    are zoomed by a factor of 2; B01, B09 and B10 are zoomed by a factor of 6.
    Every zoomed band is cropped/padded to the shape of B04. All bands are
    stacked along axis 2 as floats and, when NORMALISE_IMGS is set,
    standardised with the mean and standard deviation of the whole stack.
    """
    prefix = os.listdir(path)[0][:-7]

    def load(suffix):
        return io.imread(path + prefix + suffix)

    red = load("B04.tif")
    reference_shape = red.shape  # every resampled band is forced to this shape

    def load_zoomed(suffix, factor):
        return adjust_shape(zoom(load(suffix), factor), reference_shape)

    layers = [red]
    for suffix in ("B03.tif", "B02.tif", "B08.tif"):
        layers.append(load(suffix))
    for suffix in ("B05.tif", "B06.tif", "B07.tif", "B8A.tif", "B11.tif", "B12.tif"):
        layers.append(load_zoomed(suffix, 2))
    for suffix in ("B01.tif", "B09.tif", "B10.tif"):
        layers.append(load_zoomed(suffix, 6))

    image = np.stack(layers, axis=2).astype('float')

    if not NORMALISE_IMGS:
        return image
    return (image - image.mean()) / image.std()

def read_sentinel_img_trio(path):
    """Read a Sentinel-2 image pair and its change map.

    Args:
        path: Folder prefix of one location; the images are read from
            path + '/imgs_1/' and path + '/imgs_2/' and the change map from
            path + '/cm/cm.png'.

    Returns:
        (I1, I2, cm): the two images as produced by the reader selected by
        the module-level TYPE, with I2 cropped and/or edge-padded to I1's
        first two dimensions, and cm as a boolean array that is True where
        the change-map image is non-zero.

    Raises:
        ValueError: if TYPE is not one of 0, 1, 2, 3.
    """
    readers = {
        0: read_sentinel_img,
        1: read_sentinel_img_4,
        2: read_sentinel_img_leq20,
        3: read_sentinel_img_leq60,
    }
    if TYPE not in readers:
        # Without this check I1/I2 would simply be unbound further down.
        raise ValueError('Unsupported TYPE: ' + repr(TYPE) + ' (expected 0, 1, 2 or 3)')
    reader = readers[TYPE]

    # read images
    I1 = reader(path + '/imgs_1/')
    I2 = reader(path + '/imgs_2/')

    cm = io.imread(path + '/cm/cm.png', as_gray=True) != 0

    # Bring I2 onto I1's pixel grid: crop where it is larger, then repeat the
    # edge where it is smaller. Cropping first keeps the pad widths >= 0
    # (np.pad rejects negative widths).
    rows, cols = I1.shape[0], I1.shape[1]
    I2 = I2[:rows, :cols]
    I2 = np.pad(I2, ((0, rows - I2.shape[0]), (0, cols - I2.shape[1]), (0, 0)), 'edge')

    return I1, I2, cm



def reshape_for_torch(I):
    """Move the last axis of a 3-D array to the front and wrap it as a tensor.

    An array indexed (a, b, c) comes back as a tensor indexed (c, a, b).
    """
    channels_first = np.transpose(I, axes=(2, 0, 1))
    tensor = torch.from_numpy(channels_first)
    return tensor



class ChangeDetectionDataset(Dataset):
    """Change Detection dataset class, used for both training and test data.

    All image pairs and change maps are loaded into memory at construction
    time. One dataset item is a square patch cut at the same position from
    both images and from the change map.
    """

    def __init__(self, path, train = True, patch_side = 128, stride = None, use_all_bands = False, transform=None):
        """
        Args:
            path (string): Prefix of the dataset folder. The name list and the
                per-image folders are located by plain string concatenation,
                so it should end with a path separator.
            train (bool): Read the image names from 'all.txt' when true,
                from 'test.txt' otherwise. The names are the column labels
                pandas finds in the first line of that file.
            patch_side (int): Side length of the square patches.
            stride (int, optional): Step between consecutive patches; a falsy
                value means 1.
            use_all_bands: Accepted for backward compatibility; not used
                (the band selection is done by the module-level TYPE).
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """

        # basics
        self.transform = transform
        self.path = path
        self.patch_side = patch_side
        self.stride = stride if stride else 1

        fname = 'all.txt' if train else 'test.txt'

        self.names = read_csv(path + fname).columns
        self.n_imgs = self.names.shape[0]

        # pixel statistics over all change maps, used for the class weights
        n_pix = 0
        true_pix = 0

        # load images
        self.imgs_1 = {}
        self.imgs_2 = {}
        self.change_maps = {}
        self.n_patches_per_image = {}
        self.n_patches = 0
        self.patch_coords = []
        for im_name in tqdm(self.names):
            # load and store each image
            I1, I2, cm = read_sentinel_img_trio(self.path + im_name)
            self.imgs_1[im_name] = reshape_for_torch(I1)
            self.imgs_2[im_name] = reshape_for_torch(I2)
            self.change_maps[im_name] = cm

            n_pix += np.prod(cm.shape)
            true_pix += cm.sum()

            # calculate the number of patches (never negative, so the total
            # always equals the number of stored patch coordinates)
            s = self.imgs_1[im_name].shape
            n1 = self._n_positions(s[1], self.patch_side, self.stride)
            n2 = self._n_positions(s[2], self.patch_side, self.stride)
            n_patches_i = n1 * n2
            self.n_patches_per_image[im_name] = n_patches_i
            self.n_patches += n_patches_i

            # generate patch coordinates
            for i in range(n1):
                for j in range(n2):
                    # limits are [start, stop) along tensor axis 1, then axis 2
                    current_patch_coords = (im_name,
                                    [self.stride*i, self.stride*i + self.patch_side, self.stride*j, self.stride*j + self.patch_side],
                                    [self.stride*(i + 1), self.stride*(j + 1)])
                    self.patch_coords.append(current_patch_coords)

        # [FP_MODIFIER * 2 * changed fraction, 2 * unchanged fraction]
        self.weights = [ FP_MODIFIER * 2 * true_pix / n_pix, 2 * (n_pix - true_pix) / n_pix]

    @staticmethod
    def _n_positions(length, patch_side, stride):
        """Number of patch start positions that fit in `length` pixels (0 if the patch is too big)."""
        return max(0, ceil((length - patch_side + 1) / stride))

    def get_img(self, im_name):
        """Return the full (image 1, image 2, change map) stored for `im_name`."""
        return self.imgs_1[im_name], self.imgs_2[im_name], self.change_maps[im_name]

    def get_ood_img(self, im_name):
        """Return both full images passed through `self.transform` (dim 0 squeezed) and the change map."""
        return self.transform(self.imgs_1[im_name]).squeeze(0), self.transform(self.imgs_2[im_name]).squeeze(0), self.change_maps[im_name]

    def __len__(self):
        return self.n_patches

    def __getitem__(self, idx):
        """Return the idx-th patch as {'I1', 'I2', 'label'}, transformed if a transform is set."""
        im_name, limits, _ = self.patch_coords[idx]

        I1 = self.imgs_1[im_name][:, limits[0]:limits[1], limits[2]:limits[3]]
        I2 = self.imgs_2[im_name][:, limits[0]:limits[1], limits[2]:limits[3]]

        label = self.change_maps[im_name][limits[0]:limits[1], limits[2]:limits[3]]
        # 1 * bool array -> integer array, then to a float tensor
        label = torch.from_numpy(1*np.array(label)).float()

        sample = {'I1': I1, 'I2': I2, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample



class RandomFlip(object):
    """With probability one half, mirror both images and the label along their last axis."""

    def __call__(self, sample):
        first, second, target = sample['I1'], sample['I2'], sample['label']

        # A single draw decides for the whole sample, so all three stay aligned.
        if not random.random() > 0.5:
            return {'I1': first, 'I2': second, 'label': target}

        def mirrored(tensor, axis):
            # Reverse one axis in NumPy; the copy gives from_numpy a fresh,
            # positively-strided buffer.
            return torch.from_numpy(np.flip(tensor.numpy(), axis=axis).copy())

        return {'I1': mirrored(first, 2),
                'I2': mirrored(second, 2),
                'label': mirrored(target, 1)}



class RandomRot(object):
    """Rotate both images and the label by the same random multiple of 90 degrees."""

    def __call__(self, sample):
        first, second, target = sample['I1'], sample['I2'], sample['label']

        # 0..3 quarter turns; zero leaves the sample untouched.
        quarter_turns = random.randint(0, 3)
        if quarter_turns == 0:
            return {'I1': first, 'I2': second, 'label': target}

        def turned(tensor, plane):
            # Rotate in NumPy within the given pair of axes; the copy gives
            # from_numpy a fresh, positively-strided buffer.
            rotated = np.rot90(tensor.numpy(), quarter_turns, axes=plane)
            return torch.from_numpy(rotated.copy())

        return {'I1': turned(first, (1, 2)),
                'I2': turned(second, (1, 2)),
                'label': turned(target, (0, 1))}

# Progress marker: printed once the dataset helpers and augmentation classes above are defined.
print('UTILS OK')

# Training functions
def _plot_epoch_curves(fig_num, x, curves, title, ylim=None, xlim_left=None):
    """Redraw figure `fig_num` with the given (y, label, line kwargs) curves and show it in the notebook."""
    plt.figure(num=fig_num)
    plt.clf()
    handles = []
    for y, label, line_kwargs in curves:
        line, = plt.plot(x, y, label=label, **line_kwargs)
        handles.append(line)
    plt.legend(handles=handles)
    plt.grid()
    axes = plt.gcf().gca()
    if ylim is not None:
        axes.set_ylim(*ylim)
    if xlim_left is not None:
        axes.set_xlim(left = xlim_left)
    plt.title(title)
    display.clear_output(wait=True)
    display.display(plt.gcf())


def train(model, net_name, criterion, optimizer, scheduler, train_loader, train_dataset, test_dataset, feat=False, n_epochs = N_EPOCHS, save = True):
    """Train `model` and track train/test metrics after every epoch.

    Each epoch runs one pass over `train_loader`, steps the scheduler, then
    evaluates on both datasets with test_patch_batch and redraws four
    figures (loss, accuracy, per-class accuracy, precision/recall/F-measure).
    The model's state dict is written to the working directory whenever the
    training F-measure or the training loss improves; this happens
    regardless of `save`.

    Args:
        model: Network called as model(I1, I2).
        net_name: Prefix of the figure file names.
        criterion: Loss called as criterion(output, integer labels).
        optimizer, scheduler: Stepped once per batch / once per epoch.
        train_loader: Iterable of batches with keys 'I1', 'I2', 'label'.
        train_dataset, test_dataset: Datasets handed to test_patch_batch.
        feat: When true the model returns a pair and only its first element
            is used as the prediction.
        n_epochs: Number of epochs.
        save: When true the four figures are saved as PNG after every epoch.

    Returns:
        dict with the last epoch's loss, accuracy and per-class accuracies
        for the train and test sets.
    """
    t = np.linspace(1, n_epochs, n_epochs)

    metric_names = ('loss', 'accuracy', 'nochange_accuracy', 'change_accuracy',
                    'precision', 'recall', 'Fmeasure')
    history = {(split, metric): 0 * t
               for split in ('train', 'test') for metric in metric_names}

    best_fm = 0
    best_lss = 1000

    plt.figure(num=1)
    plt.figure(num=2)
    plt.figure(num=3)

    for epoch_index in tqdm(range(n_epochs)):
        model.train()
        # Report against the requested number of epochs, not the global default.
        print('Epoch: ' + str(epoch_index + 1) + ' of ' + str(n_epochs))

        for batch in train_loader:
            I1 = batch['I1'].float().to(device)
            I2 = batch['I2'].float().to(device)
            label = batch['label'].to(device)
            # Only drop a singleton channel axis. An unrestricted squeeze
            # would also remove the batch axis of a one-sample batch.
            if label.dim() == 4:
                label = label.squeeze(1)

            optimizer.zero_grad()
            if feat:
                output, _ = model(I1, I2)
            else:
                output = model(I1, I2)
            loss = criterion(output, label.long())
            loss.backward()
            optimizer.step()

        scheduler.step()

        # Evaluate on the training set first and the test set second, so that
        # `pr_rec` holds the test metrics afterwards (printed at the end).
        for split, dataset in (('train', train_dataset), ('test', test_dataset)):
            split_loss, split_accuracy, cl_acc, pr_rec = test_patch_batch(dataset, criterion, model, feat=feat)
            history[split, 'loss'][epoch_index] = split_loss
            history[split, 'accuracy'][epoch_index] = split_accuracy
            history[split, 'nochange_accuracy'][epoch_index] = cl_acc[0]
            history[split, 'change_accuracy'][epoch_index] = cl_acc[1]
            history[split, 'precision'][epoch_index] = pr_rec[0]
            history[split, 'recall'][epoch_index] = pr_rec[1]
            history[split, 'Fmeasure'][epoch_index] = pr_rec[2]

        done = slice(0, epoch_index + 1)
        x = t[done]

        _plot_epoch_curves(1, x,
                           [(history['train', 'loss'][done], 'Train loss', {}),
                            (history['test', 'loss'][done], 'Test loss', {})],
                           'Loss', xlim_left=0)

        _plot_epoch_curves(2, x,
                           [(history['train', 'accuracy'][done], 'Train accuracy', {}),
                            (history['test', 'accuracy'][done], 'Test accuracy', {})],
                           'Accuracy', ylim=(0, 100))

        _plot_epoch_curves(3, x,
                           [(history['train', 'nochange_accuracy'][done], 'Train accuracy: no change', {}),
                            (history['train', 'change_accuracy'][done], 'Train accuracy: change', {}),
                            (history['test', 'nochange_accuracy'][done], 'Test accuracy: no change', {}),
                            (history['test', 'change_accuracy'][done], 'Test accuracy: change', {})],
                           'Accuracy per class', ylim=(0, 100))

        thin = {'linewidth': '1'}
        dashed = {'linestyle': 'dashed', 'linewidth': '2'}
        _plot_epoch_curves(4, x,
                           [(history['train', 'precision'][done], 'Train precision', thin),
                            (history['train', 'recall'][done], 'Train recall', thin),
                            (history['train', 'Fmeasure'][done], 'Train Dice/F1', thin),
                            (history['test', 'precision'][done], 'Test precision', dashed),
                            (history['test', 'recall'][done], 'Test recall', dashed),
                            (history['test', 'Fmeasure'][done], 'Test Dice/F1', dashed)],
                           'Precision, Recall and F-measure', ylim=(0, 1))

        # Checkpoint on the best training F-measure ...
        fm = history['train', 'Fmeasure'][epoch_index]
        if fm > best_fm:
            best_fm = fm
            save_str = 'net-best_epoch-' + str(epoch_index + 1) + '_fm-' + str(fm) + '.pth.tar'
            torch.save(model.state_dict(), save_str)

        # ... and on the best training loss.
        lss = history['train', 'loss'][epoch_index]
        if lss < best_lss:
            best_lss = lss
            save_str = 'net-best_epoch-' + str(epoch_index + 1) + '_loss-' + str(lss) + '.pth.tar'
            torch.save(model.state_dict(), save_str)

        if save:
            im_format = 'png'

            plt.figure(num=1)
            plt.savefig(net_name + '-01-loss.' + im_format)

            plt.figure(num=2)
            plt.savefig(net_name + '-02-accuracy.' + im_format)

            plt.figure(num=3)
            plt.savefig(net_name + '-03-accuracy-per-class.' + im_format)

            plt.figure(num=4)
            plt.savefig(net_name + '-04-prec-rec-fmeas.' + im_format)

    out = {'train_loss': history['train', 'loss'][-1],
           'train_accuracy': history['train', 'accuracy'][-1],
           'train_nochange_accuracy': history['train', 'nochange_accuracy'][-1],
           'train_change_accuracy': history['train', 'change_accuracy'][-1],
           'test_loss': history['test', 'loss'][-1],
           'test_accuracy': history['test', 'accuracy'][-1],
           'test_nochange_accuracy': history['test', 'nochange_accuracy'][-1],
           'test_change_accuracy': history['test', 'change_accuracy'][-1]}

    print('pr_c, rec_c, f_meas, pr_nc, rec_nc')
    print(pr_rec)

    return out


def kappa(tp, tn, fp, fn):
    """Cohen's kappa computed from the four cells of a binary confusion matrix."""
    total = tp + tn + fp + fn

    # Fraction of samples on which prediction and reference agree.
    observed = (tp + tn) / total

    # Agreement expected by chance from the row/column totals.
    positive_term = (tp + fp) * (tp + fn)
    negative_term = (tn + fp) * (tn + fn)
    expected = (positive_term + negative_term) / (total * total)

    return (observed - expected) / (1 - expected)

def test_patch_batch(dset, criterion, model, feat=False, batch_size=BATCH_SIZE, patch_size=PATCH_SIDE):
    """Evaluate `model` on every full image of `dset`, processed as a batch of patches.

    Each image pair is split with get_image_patch_batch, run through the
    model in one forward pass, reassembled with reconstruct_image_batch and
    cropped back to the change map's size before loss and metrics are
    computed. Class 0 is "no change", class 1 is "change".

    Args:
        dset: Object exposing `names` and `get_img(name)`.
        criterion: Loss called as criterion(output, integer labels); must
            return a scalar.
        model: Network called as model(I1, I2); moved to `device` and put in
            eval mode.
        feat: When true the model returns a pair and only its first element
            is used as the prediction.
        batch_size: Unused; kept for interface compatibility.
        patch_size: Patch side handed to reconstruct_image_batch.

    Returns:
        (net_loss, net_accuracy, class_accuracy, pr_rec): pixel-weighted mean
        loss as a float, overall accuracy in percent, per-class accuracy in
        percent as a list of floats, and
        [precision, recall, F-measure, no-change precision, no-change recall].
        The precision/recall ratios are NaN when their denominator is zero.
    """
    model.eval()
    model = model.to(device)

    n = 2
    class_correct = [0] * n
    class_total = [0] * n
    class_accuracy = [0.] * n

    tot_loss = 0.0
    tot_count = 0

    tp = 0
    tn = 0
    fp = 0
    fn = 0

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for img_index in dset.names:
            I1_full, I2_full, cm_full = dset.get_img(img_index)
            full_x, full_y = cm_full.shape
            cm = torch.unsqueeze(torch.from_numpy(1.0*cm_full),0).float().to(device)
            # Only the last two values returned for the mask are needed here.
            _, _, _, cm_px, cm_py = get_mask_patch_batch(cm_full)
            I1_patch, _, _, _, _ = get_image_patch_batch(I1_full)
            I2_patch, _, _, _, _ = get_image_patch_batch(I2_full)

            # first perform inference
            I1 = I1_patch.float().to(device)
            I2 = I2_patch.float().to(device)
            if feat:
                output, _ = model(I1, I2)
            else:
                output = model(I1, I2)

            # reassemble the patches and crop to the change map's size
            cm_recon = reconstruct_image_batch(output, 1, (cm_px, cm_py), patch_size)
            output = cm_recon[:, :, :full_x, :full_y]

            loss = criterion(output, cm.long())
            n_pixels = cm.numel()
            tot_loss += loss.item() * n_pixels
            tot_count += n_pixels

            _, predicted = torch.max(output, 1)

            # per-class hit counts, vectorised over all pixels
            labels = cm.int()
            hits = predicted.int() == labels
            for cls in range(n):
                of_class = labels == cls
                class_correct[cls] += int((hits & of_class).sum())
                class_total[cls] += int(of_class.sum())

            pr = (predicted.int() > 0).cpu().numpy()
            gt = (labels > 0).cpu().numpy()

            tp += np.logical_and(pr, gt).sum()
            tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
            fp += np.logical_and(pr, np.logical_not(gt)).sum()
            fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = tot_loss / tot_count
    net_accuracy = 100 * (tp + tn)/tot_count

    for i in range(n):
        # Plain Python numbers throughout, so a class that never occurs
        # yields 0.0 instead of failing on a tensor-only method.
        class_accuracy[i] = float(100 * class_correct[i] / max(class_total[i],0.00001))

    prec = tp / (tp + fp)
    rec = tp / (tp + fn)
    f_meas = 2 * prec * rec / (prec + rec)
    prec_nc = tn / (tn + fn)
    rec_nc = tn / (tn + fp)

    pr_rec = [prec, rec, f_meas, prec_nc, rec_nc]

    return net_loss, net_accuracy, class_accuracy, pr_rec




# Training functions
def train_ood_batch(model, net_name, criterion, optimizer, scheduler, train_loader, train_dataset, test_dataset, feat=False, n_epochs = N_EPOCHS, save = True):
    """Train `model`, evaluating after every epoch with test_patch_batch_ood.

    Each epoch runs one pass over `train_loader`, steps the scheduler and
    then evaluates on both datasets. The model's state dict is written to
    the working directory whenever the training F-measure or the training
    loss improves. No figures are produced.

    Args:
        model: Network called as model(I1, I2).
        net_name: Unused; kept for interface compatibility.
        criterion: Loss called as criterion(output, integer labels).
        optimizer, scheduler: Stepped once per batch / once per epoch.
        train_loader: Iterable of batches with keys 'I1', 'I2', 'label'.
        train_dataset, test_dataset: Objects handed to test_patch_batch_ood
            (which reads their `.dataset` attribute).
        feat: When true the model returns a pair and only its first element
            is used as the prediction.
        n_epochs: Number of epochs.
        save: Unused; kept for interface compatibility.

    Returns:
        dict with the last epoch's loss, accuracy and per-class accuracies
        for the train and test sets.
    """
    t = np.linspace(1, n_epochs, n_epochs)

    metric_names = ('loss', 'accuracy', 'nochange_accuracy', 'change_accuracy',
                    'precision', 'recall', 'Fmeasure')
    history = {(split, metric): 0 * t
               for split in ('train', 'test') for metric in metric_names}

    best_fm = 0
    best_lss = 1000

    plt.figure(num=1)
    plt.figure(num=2)
    plt.figure(num=3)

    for epoch_index in tqdm(range(n_epochs)):
        model.train()
        # Report against the requested number of epochs, not the global default.
        print('Epoch: ' + str(epoch_index + 1) + ' of ' + str(n_epochs))

        for batch in train_loader:
            I1 = batch['I1'].float().to(device)
            I2 = batch['I2'].float().to(device)
            label = batch['label'].to(device)
            # Only drop a singleton channel axis. An unrestricted squeeze
            # would also remove the batch axis of a one-sample batch.
            if label.dim() == 4:
                label = label.squeeze(1)

            optimizer.zero_grad()
            if feat:
                output, _ = model(I1, I2)
            else:
                output = model(I1, I2)
            loss = criterion(output, label.long())
            loss.backward()
            optimizer.step()

        scheduler.step()

        # Evaluate on the training set first and the test set second, so that
        # `pr_rec` holds the test metrics afterwards (printed at the end).
        for split, dataset in (('train', train_dataset), ('test', test_dataset)):
            split_loss, split_accuracy, cl_acc, pr_rec = test_patch_batch_ood(dataset, criterion, model, feat=feat)
            history[split, 'loss'][epoch_index] = split_loss
            history[split, 'accuracy'][epoch_index] = split_accuracy
            history[split, 'nochange_accuracy'][epoch_index] = cl_acc[0]
            history[split, 'change_accuracy'][epoch_index] = cl_acc[1]
            history[split, 'precision'][epoch_index] = pr_rec[0]
            history[split, 'recall'][epoch_index] = pr_rec[1]
            history[split, 'Fmeasure'][epoch_index] = pr_rec[2]

        # Checkpoint on the best training F-measure ...
        fm = history['train', 'Fmeasure'][epoch_index]
        if fm > best_fm:
            best_fm = fm
            save_str = 'net-best_epoch-' + str(epoch_index + 1) + '_fm-' + str(fm) + '.pth.tar'
            torch.save(model.state_dict(), save_str)

        # ... and on the best training loss.
        lss = history['train', 'loss'][epoch_index]
        if lss < best_lss:
            best_lss = lss
            save_str = 'net-best_epoch-' + str(epoch_index + 1) + '_loss-' + str(lss) + '.pth.tar'
            torch.save(model.state_dict(), save_str)

    out = {'train_loss': history['train', 'loss'][-1],
           'train_accuracy': history['train', 'accuracy'][-1],
           'train_nochange_accuracy': history['train', 'nochange_accuracy'][-1],
           'train_change_accuracy': history['train', 'change_accuracy'][-1],
           'test_loss': history['test', 'loss'][-1],
           'test_accuracy': history['test', 'accuracy'][-1],
           'test_nochange_accuracy': history['test', 'nochange_accuracy'][-1],
           'test_change_accuracy': history['test', 'change_accuracy'][-1]}

    print('pr_c, rec_c, f_meas, pr_nc, rec_nc')
    print(pr_rec)

    return out


def kappa(tp, tn, fp, fn):
    """Cohen's kappa from true/false positive and negative counts.

    NOTE: this re-defines a function of the same name and body that already
    appears earlier in this file.
    """
    n_samples = tp + tn + fp + fn
    agreement = (tp + tn) / n_samples

    # Chance agreement: product of the marginals for each class, normalised.
    chance_change = (tp + fp) * (tp + fn)
    chance_nochange = (tn + fp) * (tn + fn)
    chance = (chance_change + chance_nochange) / (n_samples * n_samples)

    numerator = agreement - chance
    denominator = 1 - chance
    return numerator / denominator

def test_patch_batch_ood(dset, criterion, model, feat=False, batch_size=BATCH_SIZE, patch_size=PATCH_SIDE):
    """Evaluate `model` on every full image of `dset.dataset`, processed as a batch of patches.

    Same procedure as test_patch_batch, except that the images are read from
    the object stored in `dset.dataset` rather than from `dset` itself.
    Class 0 is "no change", class 1 is "change".

    Args:
        dset: Wrapper whose `.dataset` exposes `names` and `get_img(name)`.
        criterion: Loss called as criterion(output, integer labels); must
            return a scalar.
        model: Network called as model(I1, I2); moved to `device` and put in
            eval mode.
        feat: When true the model returns a pair and only its first element
            is used as the prediction.
        batch_size: Unused; kept for interface compatibility.
        patch_size: Patch side handed to reconstruct_image_batch.

    Returns:
        (net_loss, net_accuracy, class_accuracy, pr_rec): pixel-weighted mean
        loss as a float, overall accuracy in percent, per-class accuracy in
        percent as a list of floats, and
        [precision, recall, F-measure, no-change precision, no-change recall].
        The precision/recall ratios are NaN when their denominator is zero.
    """
    model.eval()
    model = model.to(device)
    source = dset.dataset

    n = 2
    class_correct = [0] * n
    class_total = [0] * n
    class_accuracy = [0.] * n

    tot_loss = 0.0
    tot_count = 0

    tp = 0
    tn = 0
    fp = 0
    fn = 0

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for img_index in source.names:
            I1_full, I2_full, cm_full = source.get_img(img_index)
            full_x, full_y = cm_full.shape
            cm = torch.unsqueeze(torch.from_numpy(1.0*cm_full),0).float().to(device)
            # Only the last two values returned for the mask are needed here.
            _, _, _, cm_px, cm_py = get_mask_patch_batch(cm_full)
            I1_patch, _, _, _, _ = get_image_patch_batch(I1_full)
            I2_patch, _, _, _, _ = get_image_patch_batch(I2_full)

            # first perform inference
            I1 = I1_patch.float().to(device)
            I2 = I2_patch.float().to(device)
            if feat:
                output, _ = model(I1, I2)
            else:
                output = model(I1, I2)

            # reassemble the patches and crop to the change map's size
            cm_recon = reconstruct_image_batch(output, 1, (cm_px, cm_py), patch_size)
            output = cm_recon[:, :, :full_x, :full_y]

            loss = criterion(output, cm.long())
            n_pixels = cm.numel()
            tot_loss += loss.item() * n_pixels
            tot_count += n_pixels

            _, predicted = torch.max(output, 1)

            # per-class hit counts, vectorised over all pixels
            labels = cm.int()
            hits = predicted.int() == labels
            for cls in range(n):
                of_class = labels == cls
                class_correct[cls] += int((hits & of_class).sum())
                class_total[cls] += int(of_class.sum())

            pr = (predicted.int() > 0).cpu().numpy()
            gt = (labels > 0).cpu().numpy()

            tp += np.logical_and(pr, gt).sum()
            tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
            fp += np.logical_and(pr, np.logical_not(gt)).sum()
            fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = tot_loss / tot_count
    net_accuracy = 100 * (tp + tn)/tot_count

    for i in range(n):
        # Plain Python numbers throughout, so a class that never occurs
        # yields 0.0 instead of failing on a tensor-only method.
        class_accuracy[i] = float(100 * class_correct[i] / max(class_total[i],0.00001))

    prec = tp / (tp + fp)
    rec = tp / (tp + fn)
    f_meas = 2 * prec * rec / (prec + rec)
    prec_nc = tn / (tn + fn)
    rec_nc = tn / (tn + fp)

    pr_rec = [prec, rec, f_meas, prec_nc, rec_nc]

    return net_loss, net_accuracy, class_accuracy, pr_rec

def test(dset, criterion, model, feat=False):
    """Evaluate `model` on every image of `dset`, processing each image as an N x N grid of tiles.

    Each image is cut into N = 2 tiles per axis; the last tile along an axis
    extends to the image border. Loss and confusion counts are accumulated
    over all tiles of all images. Class 0 is "no change", class 1 is
    "change".

    Args:
        dset: Object exposing `names` and `get_img(name)`.
        criterion: Loss called as criterion(output, integer labels).
        model: Network called as model(I1, I2); put in eval mode.
        feat: When true the model returns a pair and only its first element
            is used as the prediction.

    Returns:
        (net_loss, net_accuracy, class_accuracy, pr_rec): pixel-weighted mean
        loss (a tensor), overall accuracy in percent, per-class accuracy in
        percent (tensors), and
        [precision, recall, F-measure, no-change precision, no-change recall].
    """
    N = 2
    model.eval()
    tot_loss = 0
    tot_count = 0

    n = 2
    class_correct = [0.] * n
    class_total = [0.] * n
    class_accuracy = [0.] * n

    tp = 0
    tn = 0
    fp = 0
    fn = 0

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for img_index in dset.names:
            I1_full, I2_full, cm_full = dset.get_img(img_index)

            s = cm_full.shape

            # start offsets of the tiles along each axis
            steps0 = np.arange(0,s[0],ceil(s[0]/N))
            steps1 = np.arange(0,s[1],ceil(s[1]/N))
            for ii in range(N):
                xmin = steps0[ii]
                xmax = s[0] if ii == N-1 else steps0[ii+1]
                for jj in range(N):
                    # Use the tile offset, not the tile index: with the index
                    # the second column would start at pixel 1 and overlap
                    # the first one.
                    ymin = steps1[jj]
                    ymax = s[1] if jj == N-1 else steps1[jj+1]

                    I1 = I1_full[:, xmin:xmax, ymin:ymax]
                    I2 = I2_full[:, xmin:xmax, ymin:ymax]
                    cm = cm_full[xmin:xmax, ymin:ymax]

                    I1 = torch.unsqueeze(I1, 0).float().to(device)
                    I2 = torch.unsqueeze(I2, 0).float().to(device)
                    cm = torch.unsqueeze(torch.from_numpy(1.0*cm),0).float().to(device)

                    if feat:
                        output, _ = model(I1, I2)
                    else:
                        output = model(I1, I2)
                    loss = criterion(output, cm.long())
                    n_pixels = cm.numel()
                    tot_loss += loss.data * n_pixels
                    tot_count += n_pixels

                    _, predicted = torch.max(output, 1)

                    # per-class hit counts, vectorised over all pixels
                    labels = cm.int()
                    hits = predicted.int() == labels
                    for cls in range(n):
                        of_class = labels == cls
                        class_correct[cls] = class_correct[cls] + (hits & of_class).sum()
                        class_total[cls] += int(of_class.sum())

                    pr = (predicted.int() > 0).cpu().numpy()
                    gt = (labels > 0).cpu().numpy()

                    tp += np.logical_and(pr, gt).sum()
                    tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
                    fp += np.logical_and(pr, np.logical_not(gt)).sum()
                    fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = tot_loss/tot_count
    net_accuracy = 100 * (tp + tn)/tot_count

    for i in range(n):
        class_accuracy[i] = 100 * class_correct[i] / max(class_total[i],0.00001)

    prec = tp / (tp + fp)
    rec = tp / (tp + fn)
    f_meas = 2 * prec * rec / (prec + rec)
    prec_nc = tn / (tn + fn)
    rec_nc = tn / (tn + fp)

    pr_rec = [prec, rec, f_meas, prec_nc, rec_nc]

    return net_loss, net_accuracy, class_accuracy, pr_rec

def unseen_test(dset, criterion, model, feat=False):
    """Evaluate ``model`` on every full image of ``dset``, tile by tile.

    Each image pair returned by ``dset.get_img`` is split into tiles of at
    most 1024 x 1024 pixels; the model is run on each tile and the loss,
    per-class accuracy and confusion counts are accumulated over all tiles.
    Inputs are passed to the model without being moved to another device.

    Args:
        dset: Object exposing ``names`` and ``get_img(name)``, the latter
            returning ``(I1, I2, change_map)``.
        criterion: Loss called as ``criterion(output, target.long())``; it is
            expected to return a scalar.
        model: Network called as ``model(I1, I2)``.
        feat: When True the model returns ``(output, features)`` and only the
            first element is used.

    Returns:
        dict with keys ``net_loss``, ``net_accuracy``, ``class_accuracy``,
        ``precision``, ``recall``, ``dice``, ``kappa`` and ``f_meas``.
        Ratios whose denominator is zero are reported as 0.
    """
    L = 1024  # maximum tile side fed to the model at once
    n = 2     # number of classes tracked in the per-class accuracy
    model.eval()

    tot_loss = 0.0
    tot_count = 0
    class_correct = [0.0] * n
    class_total = [0.0] * n
    tp = tn = fp = fn = 0

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for img_index in tqdm(dset.names):
            I1_full, I2_full, cm_full = dset.get_img(img_index)
            s = cm_full.shape

            for ii in range(ceil(s[0] / L)):
                for jj in range(ceil(s[1] / L)):
                    xmin = L * ii
                    # Rows are bounded by the image height s[0]; bounding them
                    # by s[1] dropped the bottom rows of tall images.
                    xmax = min(L * (ii + 1), s[0])
                    ymin = L * jj
                    ymax = min(L * (jj + 1), s[1])

                    I1 = torch.unsqueeze(I1_full[:, xmin:xmax, ymin:ymax], 0).float()
                    I2 = torch.unsqueeze(I2_full[:, xmin:xmax, ymin:ymax], 0).float()
                    cm_tile = cm_full[xmin:xmax, ymin:ymax]
                    cm = torch.unsqueeze(torch.from_numpy(1.0 * cm_tile), 0).float()

                    if feat:
                        output, _ = model(I1, I2)
                    else:
                        output = model(I1, I2)

                    loss = criterion(output, cm.long())
                    num_pix = cm.numel()
                    tot_loss += loss.item() * num_pix
                    tot_count += num_pix

                    _, predicted = torch.max(output, 1)

                    # Vectorised per-class bookkeeping (replaces a per-pixel loop).
                    labels = cm.int()
                    correct = predicted.int() == labels
                    for cls in range(n):
                        in_class = labels == cls
                        class_total[cls] += float(in_class.sum().item())
                        class_correct[cls] += float((correct & in_class).sum().item())

                    pr = (predicted > 0).cpu().numpy()
                    gt = (labels > 0).cpu().numpy()

                    tp += np.logical_and(pr, gt).sum()
                    tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
                    fp += np.logical_and(pr, np.logical_not(gt)).sum()
                    fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = tot_loss / max(tot_count, 1)
    net_accuracy = 100 * (tp + tn) / max(tot_count, 1)

    # Plain floats: a class that never occurs simply scores 0.
    class_accuracy = [100 * class_correct[i] / max(class_total[i], 0.00001)
                      for i in range(n)]

    prec = tp / max(tp + fp, 1e-9)
    rec = tp / max(tp + fn, 1e-9)
    f_meas = 2 * prec * rec / max(prec + rec, 1e-9)
    dice = f_meas  # identical formula for binary masks

    k = kappa(tp, tn, fp, fn)

    return {'net_loss': net_loss,
            'net_accuracy': net_accuracy,
            'class_accuracy': class_accuracy,
            'precision': prec,
            'recall': rec,
            'dice': dice,
            'kappa': k,
            'f_meas': f_meas}
# test patch
def extract_patches_batch(images, patch_size):
    """
    Cut a batch of images into non-overlapping square patches.

    Patches are emitted image by image, and within an image row by row of the
    patch grid (left to right, then top to bottom).

    Args:
        images (torch.Tensor): Input images of shape (B, C, H, W).
        patch_size (int): Side length of each square patch; must divide H and W.

    Returns:
        torch.Tensor: Patches of shape (B * num_patches, C, patch_size, patch_size).
    """
    n_img, n_chan, height, width = images.shape

    # A partial patch at the border cannot be represented, so refuse early.
    assert not (height % patch_size != 0 or width % patch_size != 0), "Image size must be divisible by patch size"

    grid_rows = height // patch_size
    grid_cols = width // patch_size

    # Slide a window of step == size along H, then along W:
    # (B, C, grid_rows, grid_cols, patch_size, patch_size)
    tiles = images.unfold(2, patch_size, patch_size)
    tiles = tiles.unfold(3, patch_size, patch_size)

    # Bring the grid axes next to the batch axis so they can be merged into it.
    tiles = tiles.permute(0, 2, 3, 1, 4, 5).contiguous()
    return tiles.view(n_img * grid_rows * grid_cols, n_chan, patch_size, patch_size)


def reconstruct_image_batch(patches, batch_size, output_size=(32, 32), patch_size=2):
    """
    Stitch non-overlapping patches back into full images.

    The patches are expected in the order image by image, then row by row of
    the patch grid (left to right, top to bottom).

    Args:
        patches (torch.Tensor): Tensor of shape (N, C, patch_size, patch_size).
        batch_size (int or tuple): Number of images; for a tuple the first
            element is used.
        output_size (int or tuple): (H, W) of each reconstructed image; an int
            means a square image. Both sides must be multiples of patch_size.
        patch_size (int): Side length of each patch.

    Returns:
        torch.Tensor: Images of shape (B, C, H, W). B equals ``batch_size``
        when N == batch_size * patches_per_image, and 1 when N equals the
        patch count of a single image.

    Raises:
        ValueError: If the output size is smaller than, or not a multiple of,
            the patch size, or if N matches neither expected patch count.
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size)  # square image shorthand

    if isinstance(batch_size, tuple):
        batch_size = batch_size[0]

    out_h, out_w = output_size[0], output_size[1]
    if out_h < patch_size or out_w < patch_size:
        raise ValueError(f"Output size {output_size} must be larger than patch size {patch_size}")
    if out_h % patch_size != 0 or out_w % patch_size != 0:
        # Otherwise the final reshape fails with an opaque shape error.
        raise ValueError(f"Output size {output_size} must be divisible by patch size {patch_size}")

    grid_rows = out_h // patch_size
    grid_cols = out_w // patch_size
    patches_per_image = grid_rows * grid_cols

    if patches.shape[0] == batch_size * patches_per_image:
        n_images = batch_size
    elif patches.shape[0] == patches_per_image:
        # Only one image's worth of patches: the result has batch dimension 1.
        # Reusing the caller's batch_size here would fold channels into a
        # bogus batch axis.
        n_images = 1
    else:
        raise ValueError(f"Unexpected number of patches: {patches.shape[0]}")

    grid = patches.reshape(n_images, grid_rows, grid_cols, -1, patch_size, patch_size)
    # (B, rows, cols, C, ph, pw) -> (B, C, rows, ph, cols, pw), so that merging
    # (rows, ph) and (cols, pw) yields the image rows and columns.
    grid = grid.permute(0, 3, 1, 4, 2, 5).contiguous()
    return grid.view(n_images, -1, out_h, out_w)

def pad_trick(i, p):
    """
    Zero-pad a (C, H, W) image on the bottom and right so H and W become multiples of p.

    Args:
        i (torch.Tensor or np.ndarray): Image of shape (C, H, W); arrays are
            converted with ``torch.from_numpy``.
        p (int): Patch size the spatial dimensions must be divisible by.

    Returns:
        torch.Tensor: Image of shape (C, H + p_h, W + p_w), where p_h and p_w
        are the smallest non-negative paddings that reach a multiple of p.
    """
    image = torch.from_numpy(i) if isinstance(i, np.ndarray) else i

    assert image.ndim == 3, f"Expected input shape (C, H, W), but got {image.shape}"

    _, height, width = image.shape
    # -x % p is the distance from x up to the next multiple of p (0 if aligned).
    extra_rows = -height % p
    extra_cols = -width % p

    # F.pad takes the last dimension first: (left, right, top, bottom).
    return F.pad(image, (0, extra_cols, 0, extra_rows), mode="constant", value=0)


def get_mask_patch_batch(cmt):
    """
    Pad a 2-D change mask to a multiple of PATCH_SIDE and cut it into patches.

    Args:
        cmt (np.ndarray): Change mask of shape (H, W).

    Returns:
        tuple: (patches, orig_x, orig_y, padded_x, padded_y) where ``patches``
        has shape (num_patches, 1, PATCH_SIDE, PATCH_SIDE), ``orig_*`` are the
        mask's original height/width and ``padded_*`` its size after padding.
    """
    side = PATCH_SIDE
    mask = torch.from_numpy(cmt)
    orig_x, orig_y = mask.shape[0], mask.shape[1]

    # Add a channel axis for padding, then a batch axis for patch extraction.
    padded = pad_trick(torch.unsqueeze(mask, 0), side)
    batch = torch.unsqueeze(padded, 0)
    padded_x, padded_y = batch.shape[2], batch.shape[3]

    patches = extract_patches_batch(batch, side)
    return patches, orig_x, orig_y, padded_x, padded_y

def get_image_patch_batch(cmt):
    """
    Pad a (C, H, W) image to a multiple of PATCH_SIDE and cut it into patches.

    Args:
        cmt (torch.Tensor or np.ndarray): Image of shape (C, H, W).

    Returns:
        tuple: (patches, orig_x, orig_y, padded_x, padded_y) where ``patches``
        has shape (num_patches, C, PATCH_SIDE, PATCH_SIDE), ``orig_*`` are the
        image's original height/width and ``padded_*`` its size after padding.
    """
    side = PATCH_SIDE
    orig_x, orig_y = cmt.shape[1], cmt.shape[2]

    # Pad, then add the batch axis expected by the patch extractor.
    batch = torch.unsqueeze(pad_trick(cmt, side), 0)
    padded_x, padded_y = batch.shape[2], batch.shape[3]

    patches = extract_patches_batch(batch, side)
    return patches, orig_x, orig_y, padded_x, padded_y
    
def unseen_test_patch(dset, criterion, model, feat=False):
    """Evaluate ``model`` on ``dset`` one PATCH_SIDE x PATCH_SIDE patch at a time.

    Every image pair and its change mask are zero-padded and cut into patches
    by ``get_image_patch_batch`` / ``get_mask_patch_batch``; the model is run
    on each patch separately and loss, per-class accuracy and confusion counts
    are accumulated over all patches.

    NOTE: the patches are scored as produced, so the zero-padded border pixels
    of the mask take part in the loss and in every metric.

    Args:
        dset: Object exposing ``names`` and ``get_img(name)``, the latter
            returning ``(I1, I2, change_map)``.
        criterion: Loss called as ``criterion(output, target.long())``; it is
            expected to return a scalar.
        model: Network called as ``model(I1, I2)``; it is moved to ``device``.
        feat: When True the model returns ``(output, features)`` and only the
            first element is used.

    Returns:
        dict with keys ``net_loss``, ``net_accuracy``, ``class_accuracy``,
        ``precision``, ``recall``, ``dice``, ``kappa`` and ``f_meas``.
        Ratios whose denominator is zero are reported as 0.
    """
    n = 2  # number of classes tracked in the per-class accuracy
    model.eval()
    model = model.to(device)

    tot_loss = 0.0
    tot_count = 0
    class_correct = [0.0] * n
    class_total = [0.0] * n
    tp = tn = fp = fn = 0

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for img_index in tqdm(dset.names):
            I1_full, I2_full, cm_full = dset.get_img(img_index)
            cm_patch = get_mask_patch_batch(cm_full)[0]
            I1_patch = get_image_patch_batch(I1_full)[0]
            I2_patch = get_image_patch_batch(I2_full)[0]

            for patch_idx in range(cm_patch.shape[0]):
                I1 = torch.unsqueeze(I1_patch[patch_idx], 0).float().to(device)
                I2 = torch.unsqueeze(I2_patch[patch_idx], 0).float().to(device)
                # A mask patch is (1, P, P); its leading axis acts as the batch axis.
                cm = (1.0 * cm_patch[patch_idx]).float().to(device)

                if feat:
                    output, _ = model(I1, I2)
                else:
                    output = model(I1, I2)

                loss = criterion(output, cm.long())
                num_pix = cm.numel()
                tot_loss += loss.item() * num_pix
                tot_count += num_pix

                _, predicted = torch.max(output, 1)

                # Vectorised per-class bookkeeping (replaces a per-pixel loop
                # that also reused the patch index name).
                labels = cm.int()
                correct = predicted.int() == labels
                for cls in range(n):
                    in_class = labels == cls
                    class_total[cls] += float(in_class.sum().item())
                    class_correct[cls] += float((correct & in_class).sum().item())

                pr = (predicted > 0).cpu().numpy()
                gt = (labels > 0).cpu().numpy()

                tp += np.logical_and(pr, gt).sum()
                tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
                fp += np.logical_and(pr, np.logical_not(gt)).sum()
                fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = tot_loss / max(tot_count, 1)
    net_accuracy = 100 * (tp + tn) / max(tot_count, 1)

    # Plain floats: a class that never occurs simply scores 0.
    class_accuracy = [100 * class_correct[i] / max(class_total[i], 0.00001)
                      for i in range(n)]

    prec = tp / max(tp + fp, 1e-9)
    rec = tp / max(tp + fn, 1e-9)
    f_meas = 2 * prec * rec / max(prec + rec, 1e-9)
    dice = f_meas  # identical formula for binary masks

    k = kappa(tp, tn, fp, fn)

    return {'net_loss': net_loss,
            'net_accuracy': net_accuracy,
            'class_accuracy': class_accuracy,
            'precision': prec,
            'recall': rec,
            'dice': dice,
            'kappa': k,
            'f_meas': f_meas}

def unseen_test_patch_batch(dset, criterion, model, feat=False, batch_size=BATCH_SIZE, patch_size=PATCH_SIDE, ood=False):
    """Evaluate ``model`` on ``dset`` with one batched forward pass per image.

    NOTE: a later definition with the same name further down this file
    replaces this one once the whole cell has run.

    Each image pair is zero-padded and cut into patches, all patches of an
    image go through the model as a single batch, and the outputs are stitched
    back together and cropped to the original size before loss and metrics are
    computed, so padded pixels do not contribute to the scores.

    Args:
        dset: Object exposing ``names`` and ``get_img(name)`` (and
            ``get_ood_img(name)`` when ``ood`` is True), returning
            ``(I1, I2, change_map)``.
        criterion: Loss called as ``criterion(output, target.long())``; it is
            expected to return a scalar.
        model: Network called as ``model(I1, I2)``; it is moved to ``device``.
        feat: When True the model returns ``(output, features)`` and only the
            first element is used.
        batch_size: Unused; kept for signature compatibility.
        patch_size: Patch side used when stitching. The extraction helpers
            always cut with ``PATCH_SIDE``, so the two must agree.
        ood: When True images are read with ``dset.get_ood_img``.

    Returns:
        dict with keys ``net_loss``, ``net_accuracy``, ``class_accuracy``,
        ``precision``, ``recall``, ``dice``, ``kappa`` and ``f_meas``.
        Ratios whose denominator is zero are reported as 0.
    """
    n = 2  # number of classes tracked in the per-class accuracy
    model.eval()
    model = model.to(device)

    tot_loss = 0.0
    tot_count = 0
    class_correct = [0.0] * n
    class_total = [0.0] * n
    tp = tn = fp = fn = 0

    # Evaluation only: no autograd graph is needed for whole-image batches.
    with torch.no_grad():
        for img_index in tqdm(dset.names):
            if ood:
                I1_full, I2_full, cm_full = dset.get_ood_img(img_index)
            else:
                I1_full, I2_full, cm_full = dset.get_img(img_index)
            full_x, full_y = cm_full.shape
            cm = torch.unsqueeze(torch.from_numpy(1.0 * cm_full), 0).float().to(device)

            # Only the padded extent of the mask is needed to stitch the output.
            _, _, _, cm_px, cm_py = get_mask_patch_batch(cm_full)
            I1 = get_image_patch_batch(I1_full)[0].float().to(device)
            I2 = get_image_patch_batch(I2_full)[0].float().to(device)

            if feat:
                output, _ = model(I1, I2)
            else:
                output = model(I1, I2)

            stitched = reconstruct_image_batch(output, 1, (cm_px, cm_py), patch_size)
            output = stitched[:, :, :full_x, :full_y]  # drop the padded border

            loss = criterion(output, cm.long())
            num_pix = cm.numel()
            tot_loss += loss.item() * num_pix
            tot_count += num_pix

            _, predicted = torch.max(output, 1)

            # Vectorised per-class bookkeeping (replaces a per-pixel loop).
            labels = cm.int()
            correct = predicted.int() == labels
            for cls in range(n):
                in_class = labels == cls
                class_total[cls] += float(in_class.sum().item())
                class_correct[cls] += float((correct & in_class).sum().item())

            pr = (predicted > 0).cpu().numpy()
            gt = (labels > 0).cpu().numpy()

            tp += np.logical_and(pr, gt).sum()
            tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
            fp += np.logical_and(pr, np.logical_not(gt)).sum()
            fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = tot_loss / max(tot_count, 1)
    net_accuracy = 100 * (tp + tn) / max(tot_count, 1)

    # Plain floats: a class that never occurs simply scores 0.
    class_accuracy = [100 * class_correct[i] / max(class_total[i], 0.00001)
                      for i in range(n)]

    prec = tp / max(tp + fp, 1e-9)
    rec = tp / max(tp + fn, 1e-9)
    f_meas = 2 * prec * rec / max(prec + rec, 1e-9)
    dice = f_meas  # identical formula for binary masks

    k = kappa(tp, tn, fp, fn)

    return {'net_loss': net_loss,
            'net_accuracy': net_accuracy,
            'class_accuracy': class_accuracy,
            'precision': prec,
            'recall': rec,
            'dice': dice,
            'kappa': k,
            'f_meas': f_meas}



# --- helper: per-image, per-channel [0,1] scaling (optional) ---
class PerImageMinMax01(object):
    """Rescale ``sample['I1']`` and ``sample['I2']`` to [0, 1] per channel.

    Each image is cast to float and, channel by channel (reduction over
    dims 1 and 2), mapped to ``(x - min) / (max - min + 1e-6)``. The sample
    dict is modified in place and returned; other keys are left alone.
    """

    def __call__(self, sample):
        for key in ('I1', 'I2'):
            img = sample[key].float()
            lo = img.amin(dim=(1, 2), keepdim=True)
            # The small constant keeps constant channels from dividing by zero.
            span = img.amax(dim=(1, 2), keepdim=True) - lo + 1e-6
            sample[key] = (img - lo) / span
        return sample

# --- DROP-IN REPLACEMENT (minimal changes) ---
def unseen_test_patch_batch(
        dset, criterion, model, feat=False,
        batch_size=BATCH_SIZE, patch_size=PATCH_SIDE,
        ood=False,                      # enables pair_transform (see below)
        pair_transform=None,            # callable(sample) -> sample for OOD
        apply_minmax=False              # normalize to [0,1] before OOD
    ):
    """Evaluate ``model`` on ``dset`` with one batched forward pass per image.

    Each image pair is zero-padded and cut into patches, all patches of an
    image go through the model as a single batch, and the outputs are stitched
    back together and cropped to the original size before loss and metrics are
    computed, so padded pixels do not contribute to the scores.

    If ``ood`` is True and ``pair_transform`` is provided, the transform is
    applied to the full images (I1, I2) BEFORE patching; labels are left
    unchanged. ``apply_minmax`` rescales the images to [0, 1] just before that
    transform. With ``ood=True`` but no ``pair_transform`` the images are used
    exactly as returned by ``dset.get_img`` and ``apply_minmax`` has no effect.

    Args:
        dset: Object exposing ``names`` and ``get_img(name)``, the latter
            returning ``(I1, I2, change_map)``.
        criterion: Loss called as ``criterion(output, target.long())``; it is
            expected to return a scalar.
        model: Network called as ``model(I1, I2)``; it is moved to ``device``.
        feat: When True the model returns ``(output, features)`` and only the
            first element is used.
        batch_size: Unused; kept for signature compatibility.
        patch_size: Patch side used when stitching. The extraction helpers
            always cut with ``PATCH_SIDE``, so the two must agree.
        ood: See above.
        pair_transform: Callable taking and returning a dict with keys
            ``'I1'``, ``'I2'`` and ``'label'``.
        apply_minmax: See above.

    Returns:
        dict with keys ``net_loss``, ``net_accuracy``, ``class_accuracy``,
        ``precision``, ``recall``, ``dice``, ``kappa`` and ``f_meas``.
        Ratios whose denominator is zero are reported as 0.
    """
    model.eval()
    model = model.to(device)

    n = 2  # number of classes tracked in the per-class accuracy
    class_correct = [0.0] * n
    class_total = [0.0] * n

    tp = tn = fp = fn = 0
    tot_loss = 0.0
    tot_count = 0.0

    mm = PerImageMinMax01() if apply_minmax else None

    # Evaluation only: no autograd graph is needed for whole-image batches.
    with torch.no_grad():
        for img_index in tqdm(dset.names):
            I1_full, I2_full, cm_full = dset.get_img(img_index)

            # Optional min-max + pair-wise OOD transform on the full frame.
            if ood and pair_transform is not None:
                sample = {'I1': I1_full.clone(), 'I2': I2_full.clone(),
                          'label': torch.from_numpy(1.0*cm_full).float()}
                if mm is not None:
                    sample = mm(sample)
                sample = pair_transform(sample)
                I1_full, I2_full = sample['I1'], sample['I2']  # cm_full stays as loaded

            full_x, full_y = cm_full.shape
            cm = torch.unsqueeze(torch.from_numpy(1.0*cm_full), 0).float().to(device)

            # Only the padded extent of the mask is needed to stitch the output.
            _, _, _, cm_px, cm_py = get_mask_patch_batch(cm_full)
            I1 = get_image_patch_batch(I1_full)[0].float().to(device)
            I2 = get_image_patch_batch(I2_full)[0].float().to(device)

            if feat:
                output, _ = model(I1, I2)
            else:
                output = model(I1, I2)

            stitched = reconstruct_image_batch(output, 1, (cm_px, cm_py), patch_size)
            output = stitched[:, :, :full_x, :full_y]  # drop the padded border

            loss = criterion(output, cm.long())
            num_pix = float(cm.numel())
            # .item() keeps the running loss a plain float, so the final
            # division also works when no image was processed.
            tot_loss += loss.item() * num_pix
            tot_count += num_pix

            _, predicted = torch.max(output, 1)

            # Vectorised per-class bookkeeping (replaces a per-pixel loop).
            labels = cm.int()
            correct = predicted.int() == labels
            for cls in range(n):
                in_class = labels == cls
                class_total[cls] += float(in_class.sum().item())
                class_correct[cls] += float((correct & in_class).sum().item())

            pr = (predicted > 0).cpu().numpy()
            gt = (labels > 0).cpu().numpy()
            tp += np.logical_and(pr, gt).sum()
            tn += np.logical_and(~pr, ~gt).sum()
            fp += np.logical_and(pr, ~gt).sum()
            fn += np.logical_and(~pr, gt).sum()

    net_loss = tot_loss / max(tot_count, 1.0)
    net_accuracy = 100.0 * (tp + tn) / max(tot_count, 1.0)

    class_accuracy = [100.0 * class_correct[i] / max(class_total[i], 1e-9)
                      for i in range(n)]

    prec   = tp / max(tp + fp, 1e-9)
    rec    = tp / max(tp + fn, 1e-9)
    f_meas = 2 * prec * rec / max(prec + rec, 1e-9)
    dice   = f_meas  # identical formula for binary masks
    k      = kappa(tp, tn, fp, fn)

    return {
        'net_loss': net_loss,
        'net_accuracy': net_accuracy,
        'class_accuracy': class_accuracy,
        'precision': prec,
        'recall': rec,
        'dice': dice,
        'kappa': k,
        'f_meas': f_meas
    }





def unseen_test_patch_batch_ood(dset, criterion, model, feat=False, batch_size=BATCH_SIZE, patch_size=PATCH_SIDE, ood=False):
    """Evaluate ``model`` on the dataset wrapped by ``dset``, one batched pass per image.

    Images are read from ``dset.dataset``, so ``dset`` is a wrapper exposing
    the underlying dataset through a ``dataset`` attribute. Each image pair is
    zero-padded and cut into patches, all patches of an image go through the
    model as a single batch, and the outputs are stitched back together and
    cropped to the original size before loss and metrics are computed, so
    padded pixels do not contribute to the scores.

    Args:
        dset: Wrapper whose ``dataset`` attribute exposes ``names`` and
            ``get_img(name)`` (and ``get_ood_img(name)`` when ``ood`` is
            True), returning ``(I1, I2, change_map)``.
        criterion: Loss called as ``criterion(output, target.long())``; it is
            expected to return a scalar.
        model: Network called as ``model(I1, I2)``; it is moved to ``device``.
        feat: When True the model returns ``(output, features)`` and only the
            first element is used.
        batch_size: Unused; kept for signature compatibility.
        patch_size: Patch side used when stitching. The extraction helpers
            always cut with ``PATCH_SIDE``, so the two must agree.
        ood: When True images are read with ``get_ood_img``.

    Returns:
        dict with keys ``net_loss``, ``net_accuracy``, ``class_accuracy``,
        ``precision``, ``recall``, ``dice``, ``kappa`` and ``f_meas``.
        Ratios whose denominator is zero are reported as 0.
    """
    n = 2  # number of classes tracked in the per-class accuracy
    model.eval()
    model = model.to(device)

    tot_loss = 0.0
    tot_count = 0
    class_correct = [0.0] * n
    class_total = [0.0] * n
    tp = tn = fp = fn = 0

    base = dset.dataset

    # Evaluation only: no autograd graph is needed for whole-image batches.
    with torch.no_grad():
        for img_index in tqdm(base.names):
            if ood:
                I1_full, I2_full, cm_full = base.get_ood_img(img_index)
            else:
                I1_full, I2_full, cm_full = base.get_img(img_index)
            full_x, full_y = cm_full.shape
            cm = torch.unsqueeze(torch.from_numpy(1.0 * cm_full), 0).float().to(device)

            # Only the padded extent of the mask is needed to stitch the output.
            _, _, _, cm_px, cm_py = get_mask_patch_batch(cm_full)
            I1 = get_image_patch_batch(I1_full)[0].float().to(device)
            I2 = get_image_patch_batch(I2_full)[0].float().to(device)

            if feat:
                output, _ = model(I1, I2)
            else:
                output = model(I1, I2)

            stitched = reconstruct_image_batch(output, 1, (cm_px, cm_py), patch_size)
            output = stitched[:, :, :full_x, :full_y]  # drop the padded border

            loss = criterion(output, cm.long())
            num_pix = cm.numel()
            tot_loss += loss.item() * num_pix
            tot_count += num_pix

            _, predicted = torch.max(output, 1)

            # Vectorised per-class bookkeeping (replaces a per-pixel loop).
            labels = cm.int()
            correct = predicted.int() == labels
            for cls in range(n):
                in_class = labels == cls
                class_total[cls] += float(in_class.sum().item())
                class_correct[cls] += float((correct & in_class).sum().item())

            pr = (predicted > 0).cpu().numpy()
            gt = (labels > 0).cpu().numpy()

            tp += np.logical_and(pr, gt).sum()
            tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
            fp += np.logical_and(pr, np.logical_not(gt)).sum()
            fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = tot_loss / max(tot_count, 1)
    net_accuracy = 100 * (tp + tn) / max(tot_count, 1)

    # Plain floats: a class that never occurs simply scores 0.
    class_accuracy = [100 * class_correct[i] / max(class_total[i], 0.00001)
                      for i in range(n)]

    prec = tp / max(tp + fp, 1e-9)
    rec = tp / max(tp + fn, 1e-9)
    f_meas = 2 * prec * rec / max(prec + rec, 1e-9)
    dice = f_meas  # identical formula for binary masks

    k = kappa(tp, tn, fp, fn)

    return {'net_loss': net_loss,
            'net_accuracy': net_accuracy,
            'class_accuracy': class_accuracy,
            'precision': prec,
            'recall': rec,
            'dice': dice,
            'kappa': k,
            'f_meas': f_meas}



    
def unseen_test_patch_batch_robust(dset, criterion, model, feat=False, batch_size=BATCH_SIZE, patch_size=PATCH_SIDE):
    """Evaluate ``model`` on ``dset`` with one batched forward pass per image.

    Each image pair is zero-padded and cut into patches, all patches of an
    image go through the model as a single batch, and the outputs are stitched
    back together and cropped to the original size before loss and metrics are
    computed, so padded pixels do not contribute to the scores.

    Args:
        dset: Object exposing ``names`` and ``get_img(name)``, the latter
            returning ``(I1, I2, change_map)``.
        criterion: Loss called as ``criterion(output, target.long())``; it is
            expected to return a scalar.
        model: Network called as ``model(I1, I2)``; it is moved to ``device``.
        feat: When True the model returns ``(output, features)`` and only the
            first element is used.
        batch_size: Unused; kept for signature compatibility.
        patch_size: Patch side used when stitching. The extraction helpers
            always cut with ``PATCH_SIDE``, so the two must agree.

    Returns:
        dict with keys ``net_loss``, ``net_accuracy``, ``class_accuracy``,
        ``precision``, ``recall``, ``dice``, ``kappa`` and ``f_meas``.
        Ratios whose denominator is zero are reported as 0.
    """
    n = 2  # number of classes tracked in the per-class accuracy
    model.eval()
    model = model.to(device)

    tot_loss = 0.0
    tot_count = 0
    class_correct = [0.0] * n
    class_total = [0.0] * n
    tp = tn = fp = fn = 0

    # Evaluation only: no autograd graph is needed for whole-image batches.
    with torch.no_grad():
        for img_index in tqdm(dset.names):
            I1_full, I2_full, cm_full = dset.get_img(img_index)
            full_x, full_y = cm_full.shape
            cm = torch.unsqueeze(torch.from_numpy(1.0 * cm_full), 0).float().to(device)

            # Only the padded extent of the mask is needed to stitch the output.
            _, _, _, cm_px, cm_py = get_mask_patch_batch(cm_full)
            I1 = get_image_patch_batch(I1_full)[0].float().to(device)
            I2 = get_image_patch_batch(I2_full)[0].float().to(device)

            if feat:
                output, _ = model(I1, I2)
            else:
                output = model(I1, I2)

            stitched = reconstruct_image_batch(output, 1, (cm_px, cm_py), patch_size)
            output = stitched[:, :, :full_x, :full_y]  # drop the padded border

            loss = criterion(output, cm.long())
            num_pix = cm.numel()
            tot_loss += loss.item() * num_pix
            tot_count += num_pix

            _, predicted = torch.max(output, 1)

            # Vectorised per-class bookkeeping (replaces a per-pixel loop).
            labels = cm.int()
            correct = predicted.int() == labels
            for cls in range(n):
                in_class = labels == cls
                class_total[cls] += float(in_class.sum().item())
                class_correct[cls] += float((correct & in_class).sum().item())

            pr = (predicted > 0).cpu().numpy()
            gt = (labels > 0).cpu().numpy()

            tp += np.logical_and(pr, gt).sum()
            tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
            fp += np.logical_and(pr, np.logical_not(gt)).sum()
            fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = tot_loss / max(tot_count, 1)
    net_accuracy = 100 * (tp + tn) / max(tot_count, 1)

    # Plain floats: a class that never occurs simply scores 0.
    class_accuracy = [100 * class_correct[i] / max(class_total[i], 0.00001)
                      for i in range(n)]

    prec = tp / max(tp + fp, 1e-9)
    rec = tp / max(tp + fn, 1e-9)
    f_meas = 2 * prec * rec / max(prec + rec, 1e-9)
    dice = f_meas  # identical formula for binary masks

    k = kappa(tp, tn, fp, fn)

    return {'net_loss': net_loss,
            'net_accuracy': net_accuracy,
            'class_accuracy': class_accuracy,
            'precision': prec,
            'recall': rec,
            'dice': dice,
            'kappa': k,
            'f_meas': f_meas}

    
def save_test_results(model, dset, net_name, feat=False, batch_size=BATCH_SIZE, patch_size=PATCH_SIDE):
    """Run ``model`` on every image of ``dset`` and save a comparison PNG per image.

    Each image pair is padded, cut into patches, passed through the model as
    one batch, stitched back and cropped to the original size. The saved
    3-channel image holds the ground-truth change map in channels 0 and 2 and
    the predicted class in channel 1, scaled by 255, and is written to
    ``f'{net_name}-{name}.png'``.

    Args:
        model: Network called as ``model(I1, I2)``.
        dset: Object exposing ``names`` and ``get_img(name)``, the latter
            returning ``(I1, I2, change_map)``.
        net_name: Prefix of the output file names.
        feat: When True the model returns ``(output, features)`` and only the
            first element is used.
        batch_size: Unused; kept for signature compatibility.
        patch_size: Patch side used when stitching. The extraction helpers
            always cut with ``PATCH_SIDE``, so the two must agree.
    """
    for name in tqdm(dset.names):
        I1_full, I2_full, cm_full = dset.get_img(name)
        full_x, full_y = cm_full.shape

        # Only the padded extent of the mask is needed to stitch the output.
        _, _, _, cm_px, cm_py = get_mask_patch_batch(cm_full)
        I1 = get_image_patch_batch(I1_full)[0].float().to(device)
        I2 = get_image_patch_batch(I2_full)[0].float().to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            if feat:
                output, _ = model(I1, I2)
            else:
                output = model(I1, I2)

        stitched = reconstruct_image_batch(output, 1, (cm_px, cm_py), patch_size)
        _, predicted = torch.max(stitched[:, :, :full_x, :full_y], 1)

        truth = np.squeeze(np.asarray(1.0 * cm_full, dtype=np.float32))
        pred = np.squeeze(predicted.cpu().numpy())
        I = np.stack((truth, pred, truth), 2)
        im = (I * 255).astype(np.uint8)

        # catch_warnings() alone filters nothing; install an explicit filter so
        # warnings raised while writing the file are actually silenced, and
        # keep it scoped to the save call.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            io.imsave(f'{net_name}-{name}.png', im)
# Notebook progress marker: printed once all definitions above in this cell have run.
print("Training & testing Functions Defined")

IMPORTS OK
DEFINITIONS OK
UTILS OK
Training & testing Functions Defined


In [4]:
# ================== bounded_ood_transforms.py ==================
import math, random
import numpy as np
import torch
import torch.nn.functional as F
from typing import Optional

# -------- basic helpers --------
def _clamp01(x: torch.Tensor) -> torch.Tensor:
    """Clamp ``x`` to [0, 1] IN PLACE and return that same tensor object."""
    x.clamp_(min=0.0, max=1.0)
    return x

def _depthwise_box_blur(x: torch.Tensor, k: int):
    """Box-blur each channel of a (C, H, W) tensor independently.

    ``k`` is made odd by setting its lowest bit (e.g. 4 -> 5), the image is
    reflect-padded by ``k // 2`` on every side, and a uniform k x k kernel is
    applied per channel, so the output has the same shape as ``x``.
    """
    channels, _, _ = x.shape
    size = int(k) | 1          # force an odd kernel so it has a centre pixel
    half = size // 2
    kernel = torch.ones(channels, 1, size, size, device=x.device, dtype=x.dtype) / (size * size)
    padded = F.pad(x.unsqueeze(0), (half, half, half, half), mode='reflect')
    # groups == channels: one kernel per channel, no mixing across channels.
    return F.conv2d(padded, kernel, groups=channels).squeeze(0)

def _project_linf(delta: torch.Tensor, eps: float) -> torch.Tensor:
    """Shrink ``delta`` so its largest magnitude per channel is at most ``eps``.

    The peak is taken over the two spatial axes: dims (1, 2) for a 3-D (CHW)
    tensor, dims (2, 3) otherwise (NCHW). Channels already within the bound
    are returned unscaled; larger ones are scaled down uniformly.
    """
    spatial_dims = (1, 2) if delta.dim() == 3 else (2, 3)
    peak = delta.abs().amax(dim=spatial_dims, keepdim=True)
    # Never scale up: the factor is capped at 1.
    factor = torch.clamp(eps / (peak + 1e-8), max=1.0)
    return delta * factor

class PerImageMinMax01:
    """Scale ``sample['I1']`` and ``sample['I2']`` to [0, 1] per channel.

    The minimum and maximum are taken over dims 1 and 2 of each image and the
    image becomes ``(x - min) / (max - min + 1e-6)``. The dict is updated in
    place and returned; every other entry (e.g. the label) is untouched.
    """

    def __call__(self, sample):
        for name in ('I1', 'I2'):
            image = sample[name].float()
            low = image.amin(dim=(1, 2), keepdim=True)
            high = image.amax(dim=(1, 2), keepdim=True)
            # 1e-6 guards against a zero range on constant channels.
            sample[name] = (image - low) / (high - low + 1e-6)
        return sample

# -------- single-image proposals (CHW, in [0,1]) --------
def _prop_lowfreq(x: torch.Tensor, eps: float, seed: int, kernel: int) -> torch.Tensor:
    """Add seeded, box-blurred Gaussian noise bounded by ``eps`` and clamp to [0, 1].

    Pipeline: noise -> blur with ``kernel`` -> L-inf projection to ``eps`` ->
    add to ``x`` -> clamp. A new tensor is returned; ``x`` is not modified.
    """
    # Seed inside a forked RNG context so seeding here is confined to the fork.
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        noise = torch.randn_like(x)
    delta = _project_linf(_depthwise_box_blur(noise, kernel), eps)
    return _clamp01(x + delta)

def _prop_per_band(x: torch.Tensor, eps: float, seed: int) -> torch.Tensor:
    """Shift every channel of ``x`` by its own seeded constant in [-eps, eps], then clamp to [0, 1].

    A new tensor is returned; ``x`` is not modified.
    """
    # Seed inside a forked RNG context so seeding here is confined to the fork.
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        uniform = torch.rand(x.shape[0], 1, 1, device=x.device, dtype=x.dtype)
    # Map U[0, 1) to [-eps, eps); the (C, 1, 1) shape broadcasts over H and W.
    offset = (2*uniform - 1.0) * eps
    return _clamp01(x + offset)

def _prop_blur(x: torch.Tensor, eps: float, kernel: int) -> torch.Tensor:
    """Move ``x`` towards its box-blurred version by at most ``eps`` per channel peak, then clamp to [0, 1].

    A new tensor is returned; ``x`` is not modified.
    """
    # The change a full blur would make, scaled down to respect the eps budget.
    residual = _depthwise_box_blur(x, kernel) - x
    return _clamp01(x + _project_linf(residual, eps))

def _prop_shadow(x: torch.Tensor, eps: float, seed: int, kernel: int, alpha: float=0.2) -> torch.Tensor:
    """Darken ``x`` with a seeded, blurred mask, bounded by ``eps``, then clamp to [0, 1].

    A single-channel uniform random mask is box-blurred and used to attenuate
    ``x`` by up to ``alpha``; the resulting change is projected to the L-inf
    budget ``eps`` and applied additively. A new tensor is returned.
    """
    # Seed inside a forked RNG context so seeding here is confined to the fork.
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        shade = torch.rand(1, x.shape[1], x.shape[2], device=x.device, dtype=x.dtype)
    shade = _depthwise_box_blur(shade, kernel)
    darkened = x * (1 - alpha*shade)
    delta = _project_linf(darkened - x, eps)
    return _clamp01(x + delta)

# -------- pair-aware ε-bounded OOD transform (operates on sample dict) --------
class BoundedPairOOD:
    """
    Apply a structured, eps-bounded OOD perturbation to (I1, I2) with CHW layout.

    kind: 'lowfreq' | 'per_band' | 'blur' | 'shadow'
    Use either:
      - mode in {'shared','indep','anti'}, or
      - correlated=True/False (alias: True -> 'shared', False -> 'indep').
    With neither given the mode defaults to 'indep'.

    Modes:
      - 'shared': both images are perturbed with the same seed.
      - 'indep':  the second image uses a different seed (seed + 1337).
      - 'anti':   the second image receives the negated residual of the first.

    Outputs are always clamped to [0, 1]: every proposal helper clamps its
    result, so ``assume01=False`` does not disable clamping; ``assume01=True``
    merely clamps once more.

    A fresh seed is drawn from the instance's RandomState on every call, so
    the perturbation applied to a given sample depends on how many calls were
    made before it, not on the sample itself.
    """
    def __init__(self, *,
                 kind: str="lowfreq",
                 eps: float=1/255,
                 correlated: Optional[bool]=None,
                 mode: Optional[str]=None,
                 lowfreq_kernel: int=9,
                 blur_kernel: int=5,
                 seed: int=123,
                 assume01: bool=True):
        self.kind = str(kind)
        self.eps  = float(eps)
        self.lowfreq_kernel = int(lowfreq_kernel)
        self.blur_kernel    = int(blur_kernel)
        self.assume01       = bool(assume01)
        self._np = np.random.RandomState(int(seed))

        # alias handling
        if mode is not None and correlated is not None:
            raise ValueError("Specify either 'mode' or 'correlated', not both.")
        if mode is None and correlated is None:
            mode = "indep"
        if correlated is not None:
            mode = "shared" if correlated else "indep"
        if mode not in ("shared","indep","anti"):
            raise ValueError("mode must be one of {'shared','indep','anti'}")
        self.mode = mode

    def _prop(self, x: torch.Tensor, eps: float, seed: int):
        """Return the perturbed version of ``x`` for the configured ``kind``."""
        if self.kind == "lowfreq":
            return _prop_lowfreq(x, eps, seed, self.lowfreq_kernel)
        if self.kind == "per_band":
            return _prop_per_band(x, eps, seed)
        if self.kind == "blur":
            return _prop_blur(x, eps, self.blur_kernel)
        if self.kind == "shadow":
            return _prop_shadow(x, eps, seed, self.lowfreq_kernel)
        raise ValueError(f"Unknown kind '{self.kind}'")

    def __call__(self, sample):
        """Perturb ``sample['I1']`` / ``sample['I2']``; return a new dict with keys I1, I2, label."""
        I1, I2, y = sample['I1'].float(), sample['I2'].float(), sample['label']
        # One seed per call, drawn from the instance's own random stream.
        s = int(self._np.randint(0, 2**31-1))

        Y1 = self._prop(I1, self.eps, s)

        if self.mode == "anti":
            # Date 2 gets the opposite-signed residual of date 1. No separate
            # proposal is generated for I2: it would be discarded anyway.
            Y2 = _clamp01(I2 - _project_linf(Y1 - I1, self.eps))
        else:
            s2 = s if self.mode == "shared" else (s + 1337)
            Y2 = self._prop(I2, self.eps, s2)

        if self.assume01:
            Y1 = _clamp01(Y1)
            Y2 = _clamp01(Y2)

        return {'I1': Y1, 'I2': Y2, 'label': y}

# -------- wrapper that makes get_img() path see OOD too --------
class OODDatasetView(torch.utils.data.Dataset):
    """View over a base change-detection dataset that applies a pair transform.

    The same ``pair_transform`` is applied on both access paths,
    ``__getitem__`` (patch samples) and ``get_img`` (whole images), so code
    using either path sees the transformed data.

    Args:
        base_ds: dataset exposing ``names``, ``n_patches_per_image``,
            ``n_patches``, ``patch_coords``, ``imgs_1``, ``imgs_2`` and
            ``change_maps``, plus ``__len__``/``__getitem__``.
        pair_transform: callable taking and returning a dict with keys
            ``'I1'``, ``'I2'``, ``'label'``; ``None`` means "no transform" on
            both access paths.
        apply_minmax: when true, images are rescaled before the transform.
    """

    def __init__(self, base_ds, pair_transform, apply_minmax=True):
        self.base = base_ds
        self.transform = pair_transform
        self.apply_minmax = apply_minmax
        # Shallow copies of the public fields the test code reads from a dataset.
        self.names = list(base_ds.names)
        self.n_patches_per_image = dict(base_ds.n_patches_per_image)
        self.n_patches = base_ds.n_patches
        self.patch_coords = list(base_ds.patch_coords)

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.base)

    def _minmax01(self, x: torch.Tensor) -> torch.Tensor:
        """Rescale ``x`` using the min/max taken over dims 1 and 2.

        The ``1e-6`` in the denominator avoids division by zero on constant
        slices, so the maximum maps to slightly below 1.
        """
        xmin = x.amin(dim=(1, 2), keepdim=True)
        xmax = x.amax(dim=(1, 2), keepdim=True)
        return (x - xmin) / (xmax - xmin + 1e-6)

    def _apply_pair(self, I1, I2, cm):
        """Normalise (optionally) and transform one whole-image pair.

        Returns ``(I1, I2, cm)`` where ``cm`` is the object that was passed in;
        the tensor version of ``cm`` is only handed to the transform as
        ``'label'``.
        """
        label = cm if torch.is_tensor(cm) else torch.from_numpy(cm.astype('float32'))
        if self.apply_minmax:
            I1, I2 = self._minmax01(I1), self._minmax01(I2)
        sample = {'I1': I1, 'I2': I2, 'label': label}
        # Same None guard as __getitem__, so a view without a transform works
        # on the get_img path too instead of failing with "not callable".
        if self.transform is not None:
            sample = self.transform(sample)
        return sample['I1'], sample['I2'], cm

    def __getitem__(self, idx):
        """Return the base sample ``idx`` after normalisation and transform."""
        s = self.base.__getitem__(idx)
        sample = {'I1': s['I1'].float(), 'I2': s['I2'].float(), 'label': s['label']}
        # Normalise BEFORE the pair transform so that eps is expressed in [0, 1].
        # NOTE(review): this path uses PerImageMinMax01 while get_img uses
        # _minmax01 -- confirm both normalise identically.
        if self.apply_minmax:
            sample = PerImageMinMax01()(sample)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def get_img(self, im_name):
        """Return the transformed ``(I1, I2, cm)`` triple for image ``im_name``."""
        I1 = self.base.imgs_1[im_name]
        I2 = self.base.imgs_2[im_name]
        cm = self.base.change_maps[im_name]
        return self._apply_pair(I1, I2, cm)
# ================== end of file ==================


In [5]:
#Train dataset
# ==== TRAIN DATASET + LOADER (clean) ======================================
# Assumes the following are already defined in earlier cells:
# - ChangeDetectionDataset
# - RandomFlip, RandomRot
# - device, SEED, PATH_TO_DATASET, TRAIN_STRIDE, BATCH_SIZE
# - DATA_AUG (bool)
# seed_worker and the generator g are defined just below in this cell.
def seed_worker(worker_id):
    """Seed NumPy and ``random`` from torch's seed; meant as ``worker_init_fn``.

    Args:
        worker_id: worker index supplied by the DataLoader (unused).

    The seed is ``torch.initial_seed()`` reduced modulo 2**32 so it fits the
    range NumPy accepts.
    """
    worker_seed = torch.initial_seed() % 2**32
    # The file imports numpy as `np`; the bare name `numpy` is not defined.
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Dedicated generator so DataLoader shuffling is driven by SEED.
g = torch.Generator()
g.manual_seed(SEED)
# Training transform: random flip + rotation when DATA_AUG is set, otherwise none.
# (PerImageMinMax01 can be appended here as well if normalised inputs are wanted.)
if DATA_AUG:
    train_tf = tr.Compose([RandomFlip(), RandomRot()])
else:
    train_tf = None

train_dataset = ChangeDetectionDataset(
    PATH_TO_DATASET, train=True, stride=TRAIN_STRIDE, transform=train_tf
)

# Class-imbalance weights taken from the training set, moved to the target device.
dataset_weights = torch.FloatTensor(train_dataset.weights).to(device)
print("Class weights:", dataset_weights)

# num_workers is left at its default, so loading happens in the main process;
# worker_init_fn only takes effect once num_workers > 0.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=g
)

print("Train dataloader ready")

# ==== OPTIONAL: OOD-robust training loader (ε-bounded perturbations) ======
# Builds a second dataset/loader whose transform adds a bounded perturbation.
USE_OOD_TRAIN = True  # currently enabled; set False to skip building the OOD loader
if USE_OOD_TRAIN:
    robust_train_tf = tr.Compose([
        RandomFlip(), RandomRot(),
        PerImageMinMax01(),  # rescale first so eps=1/255 is expressed in [0,1] units
        BoundedPairOOD(kind="lowfreq", eps=1/255,
                       correlated=True, lowfreq_kernel=9,
                       seed=2025, assume01=True)
    ])
    # NOTE: this constructs the training dataset a second time (second progress bar).
    train_dataset_ood = ChangeDetectionDataset(
        PATH_TO_DATASET, train=True, stride=TRAIN_STRIDE, transform=robust_train_tf
    )
    # The generator g is shared with train_loader, so the shuffle order of each
    # loader depends on how much the other one has already consumed from g.
    train_loader_ood = DataLoader(
        train_dataset_ood,
        batch_size=BATCH_SIZE,
        shuffle=True,
        worker_init_fn=seed_worker,
        generator=g
    )
    print("OOD-robust train dataloader ready")

# ==== LOSS (example) ======================================================
# Weighted cross-entropy using the class weights computed above.
loss_fn = nn.CrossEntropyLoss(weight=dataset_weights)

100%|███████████████████████████████████████████████████████████████████████████████████| 24/24 [00:22<00:00,  1.05it/s]


Class weights: tensor([0.6437, 1.9356])
Train dataloader ready


100%|███████████████████████████████████████████████████████████████████████████████████| 24/24 [00:22<00:00,  1.07it/s]

OOD-robust train dataloader ready





In [6]:
# 1) Identity transform that just passes (I1,I2,label) through.
class IdentityPair:
    """No-op pair transform: hands the sample back exactly as received."""

    def __call__(self, sample):
        """Return ``sample`` itself (same object, nothing copied or modified)."""
        untouched = sample
        return untouched

# 2) Build the base (untransformed) test dataset; the views below wrap this object.
test_dataset = ChangeDetectionDataset(PATH_TO_DATASET, train=False, stride=TRAIN_STRIDE, transform=None)

# 3) Wrap it in views so BOTH clean and OOD go through the same normalization path
#    (apply_minmax=True rescales every image before the pair transform runs).
clean_view = OODDatasetView(
    base_ds=test_dataset,
    pair_transform=IdentityPair(),
    apply_minmax=True        # <— SAME normalization as OOD views
)

# One view per perturbation setting. All share test_dataset as their base; the
# trailing positional True is apply_minmax.
ood_views = {
    "lf_1": OODDatasetView(test_dataset, BoundedPairOOD(kind="lowfreq", eps=1/255, correlated=True,  lowfreq_kernel=9, seed=123), True),
    "lf_2": OODDatasetView(test_dataset, BoundedPairOOD(kind="lowfreq", eps=2/255, correlated=True,  lowfreq_kernel=9, seed=123), True),
    "shadow": OODDatasetView(test_dataset, BoundedPairOOD(kind="shadow",  eps=1/255, correlated=False, lowfreq_kernel=9, seed=7),  True),
    "pband":  OODDatasetView(test_dataset, BoundedPairOOD(kind="per_band", eps=1/255, correlated=True,                 seed=9),  True),
    "blur_1": OODDatasetView(test_dataset, BoundedPairOOD(kind="blur",    eps=1/255, correlated=False, blur_kernel=5, seed=11), True),
}

100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.44it/s]


In [15]:
# # 4) Evaluate using your existing API (IMPORTANT: pass the *dataset views*, not DataLoaders)
# results = []

# specs = [
#     ("clean",   None),
#     ("lf_1",    BoundedPairOOD(kind="lowfreq", eps=1/255, correlated=True,  seed=42)),
#     ("lf_2",    BoundedPairOOD(kind="lowfreq", eps=2/255, correlated=True,  seed=42)),
#     ("shadow",  BoundedPairOOD(kind="shadow",  eps=1/255, correlated=False, seed=7)),
#     ("pband",   BoundedPairOOD(kind="per_band",eps=1/255, correlated=True,  seed=9)),
#     ("blur_1",  BoundedPairOOD(kind="blur",   eps=1/255, correlated=False,  seed=11)),
# ]
# for name, tf in specs:
#     m = unseen_test_patch_batch(test_dataset, loss_fn, netsota, feat=False, 
#                                 ood=(tf is not None),
#                                 pair_transform=tf,
#                                 apply_minmax=True)
#     m['setting'] = name
#     results.append(m)




# # 5) Turn into a DataFrame + save
# import pandas as pd
# df = pd.DataFrame(results)
# print(df)
# df.to_csv("./ood_eval_summary.csv", index=False)

# # Optional: LaTeX table
# def df_to_latex_table(df, out_path="table_ood_eval_summary.tex"):
#     cols = ["setting","net_loss","net_accuracy","acc_bg","acc_fg","precision","recall","dice","kappa","f_meas"]
#     use = [c for c in cols if c in df.columns]
#     with open(out_path, "w") as f:
#         f.write("\\begin{table}[t]\n\\centering\n")
#         f.write("\\caption{\\textbf{Empirical OOD evaluation on OSCD test set.}}\n")
#         f.write("\\label{tab:ood_eval_summary}\n")
#         f.write("\\begin{tabular}{%s}\n\\toprule\n" % ("l" + "r"*(len(use)-1)))
#         f.write(" & ".join([("\\textbf{"+c.replace('_','\\_')+"}") for c in use]) + " \\\\\n\\midrule\n")
#         for _, r in df[use].iterrows():
#             f.write(" & ".join([str(r[c]) if not isinstance(r[c], float) else f"{r[c]:.4f}" for c in use]) + " \\\\\n")
#         f.write("\\bottomrule\n\\end{tabular}\n\\end{table}\n")
#     print(f"[saved] {out_path}")

# df_to_latex_table(df)

100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:39<00:00,  3.94s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:42<00:00,  4.29s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:43<00:00,  4.38s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:39<00:00,  3.93s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:41<00:00,  4.15s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:42<00:00,  4.21s/it]

   net_loss  net_accuracy                          class_accuracy  precision  \
0  0.218059     94.437734  [96.43497681799634, 57.79088114560873]   0.469065   
1  0.599110     94.831699                            [100.0, 0.0]   0.000000   
2  0.599109     94.831699                            [100.0, 0.0]   0.000000   
3  0.599081     94.831699                            [100.0, 0.0]   0.000000   
4  0.599216     94.831699                            [100.0, 0.0]   0.000000   
5  0.599111     94.831699                            [100.0, 0.0]   0.000000   

     recall      dice     kappa    f_meas setting  
0  0.577909  0.517829  0.488654  0.517829   clean  
1  0.000000  0.000000  0.000000  0.000000    lf_1  
2  0.000000  0.000000  0.000000  0.000000    lf_2  
3  0.000000  0.000000  0.000000  0.000000  shadow  
4  0.000000  0.000000  0.000000  0.000000   pband  
5  0.000000  0.000000  0.000000  0.000000  blur_1  
[saved] table_ood_eval_summary.tex





In [5]:
# LiRPA-friendly encoder-decoder
class DoubleConv(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by BatchNorm and ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First pass maps in_channels -> out_channels, second keeps out_channels.
        for width_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(width_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order determine the state_dict keys.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv stack; spatial size is preserved."""
        features = self.double_conv(x)
        return features
 
class Down(nn.Module):
    """Downsampling stage: stride-2 3x3 conv, then a stride-1 3x3 conv.

    Each convolution is followed by BatchNorm and ReLU. The stride-2
    convolution does the downsampling; no pooling layer is used.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        strided = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        strided_norm = nn.BatchNorm2d(out_channels)
        refine = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        refine_norm = nn.BatchNorm2d(out_channels)
        # Layer order fixes the state_dict indices (conv.0 ... conv.5).
        self.conv = nn.Sequential(
            strided,
            strided_norm,
            nn.ReLU(inplace=True),
            refine,
            refine_norm,
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the downsampled feature map."""
        reduced = self.conv(x)
        return reduced
 
class Up(nn.Module):
    """Decoder stage: upsample, merge with a skip connection, then DoubleConv.

    ``self.up`` doubles the resolution (nearest neighbour) and maps
    ``in_channels`` to ``out_channels``. The result is concatenated after the
    skip tensor and fed to ``DoubleConv(in_channels, out_channels)``, so the
    skip tensor must carry ``in_channels - out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2, mode='nearest')
        reduce = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.up = nn.Sequential(upsample, reduce)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to the skip tensor ``x2`` and fuse both."""
        x1 = self.up(x1)
        if x1.shape[-2:] != x2.shape[-2:]:
            # Pad x1 to x2's height/width, splitting the gap between both
            # sides (the extra pixel of an odd gap goes right/bottom).
            gap_h = x2.size(2) - x1.size(2)
            gap_w = x2.size(3) - x1.size(3)
            left, top = gap_w // 2, gap_h // 2
            x1 = F.pad(x1, [left, gap_w - left, top, gap_h - top])
        return self.conv(torch.cat([x2, x1], dim=1))
 
class OutConv(nn.Module):
    """Final 1x1 convolution mapping features to ``out_channels`` logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))

    def forward(self, x):
        """Return per-pixel logits with the same spatial size as ``x``."""
        logits = self.conv(x)
        return logits
 
class EncDec_LiRPA(nn.Module):
    """Small U-Net-style encoder-decoder over a channel-concatenated image pair.

    Args:
        in_channels: channels after concatenating both inputs (default 26).
        out_channels: number of output logit channels (default 2).

    Channel widths are 8-16-32-64-128 on the way down and mirrored on the way
    up. Sub-module names (``inc``, ``down1..4``, ``up1..4``, ``outc``) are the
    state_dict prefixes.
    """

    def __init__(self, in_channels=26, out_channels=2):
        super().__init__()
        widths = (8, 16, 32, 64, 128)
        self.inc = DoubleConv(in_channels, widths[0])
        # Encoder: down{k} maps widths[k-1] -> widths[k].
        for level in range(1, 5):
            setattr(self, f"down{level}", Down(widths[level - 1], widths[level]))
        # Decoder: up{k} maps widths[5-k] -> widths[4-k] (up1 is 128 -> 64).
        for level in range(1, 5):
            setattr(self, f"up{level}", Up(widths[5 - level], widths[4 - level]))
        self.outc = OutConv(widths[0], out_channels)

    def forward(self, x1, x2):
        """Concatenate ``x1`` and ``x2`` on dim 1 and return the logits."""
        pair = torch.cat((x1, x2), 1)
        # skips[k] holds the encoder output at depth k (0 = full resolution).
        skips = [self.inc(pair)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            skips.append(stage(skips[-1]))
        decoded = skips.pop()
        # Each decoder stage consumes the next-shallower encoder output.
        for stage in (self.up1, self.up2, self.up3, self.up4):
            decoded = stage(decoded, skips.pop())
        return self.outc(decoded)

# Example usage: build the network, run one forward pass and print a summary.
# NOTE(review): the name string is spelled 'EcnDec_LiRPA' while the class is
# EncDec_LiRPA -- confirm whether anything keys on this string before renaming.
model, model_name = EncDec_LiRPA(), 'EcnDec_LiRPA'
# Alternative smaller input: x = torch.randn(1, 13, 128, 128)
# Each input has 13 channels; the model concatenates both to its 26 input channels.
x = torch.randn(1, 13, 256, 256)
out = model(x, x)
print(summary(model, input_size=(x.shape, x.shape)))

Layer (type:depth-idx)                   Output Shape              Param #
EncDec_LiRPA                             [1, 2, 256, 256]          --
├─DoubleConv: 1-1                        [1, 8, 256, 256]          --
│    └─Sequential: 2-1                   [1, 8, 256, 256]          --
│    │    └─Conv2d: 3-1                  [1, 8, 256, 256]          1,880
│    │    └─BatchNorm2d: 3-2             [1, 8, 256, 256]          16
│    │    └─ReLU: 3-3                    [1, 8, 256, 256]          --
│    │    └─Conv2d: 3-4                  [1, 8, 256, 256]          584
│    │    └─BatchNorm2d: 3-5             [1, 8, 256, 256]          16
│    │    └─ReLU: 3-6                    [1, 8, 256, 256]          --
├─Down: 1-2                              [1, 16, 128, 128]         --
│    └─Sequential: 2-2                   [1, 16, 128, 128]         --
│    │    └─Conv2d: 3-7                  [1, 16, 128, 128]         1,168
│    │    └─BatchNorm2d: 3-8             [1, 16, 128, 128]         32
│    │  

In [None]:
# LiRPA-friendly encoder-decoder: build the network and load its checkpoint.
netsota, netsota_name = EncDec_LiRPA(), 'EcnDec_LiRPA'
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from another device still loads on this machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust,
# or pass weights_only=True as the other checkpoint loads in this notebook do.
netsota.load_state_dict(
    torch.load('../oscd_ood/verificationfriendly_EncDec_LiRPA.pth.tar', map_location=device)
)
netsota = netsota.to(device)
# Weighted cross-entropy with the class weights computed from the training set.
loss_fn = nn.CrossEntropyLoss(weight=dataset_weights)
print('LOAD OK')

In [6]:
# LiRPA-friendly local-attention encoder-decoder (FALCONet architecture)
class AttentionCondenser(nn.Module):
    """Channel gate built from two 1x1 convolutions instead of self-attention.

    Args:
        d: number of input (and output) channels.
        reduction: the bottleneck has ``d // reduction`` channels.
        dropout: drop probability of the ``Dropout2d`` applied to the gate
            logits before the sigmoid.
    """

    def __init__(self, d, reduction=8, dropout=0.1):
        super().__init__()
        squeezed = d // reduction
        self.conv1 = nn.Conv2d(d, squeezed, kernel_size=1)
        self.conv2 = nn.Conv2d(squeezed, d, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout2d(p=dropout)

    def forward(self, x):
        """Return ``x`` scaled element-wise by the sigmoid gate."""
        # Squeeze then expand; there is no non-linearity between the two convs.
        gate_logits = self.dropout(self.conv2(self.conv1(x)))
        return x * self.sigmoid(gate_logits)


class ConvAttention(nn.Module):
    """Convolutional token mixer with sigmoid gating and a residual connection.

    Args:
        hidden_dim: feature size of each token.
        kernel_size: width of the depthwise convolution along the sequence.
        reduction_ratio: query/key projections use
            ``hidden_dim // reduction_ratio`` channels.

    All convolutions are bias-free ``Conv1d`` layers.
    """

    def __init__(self, hidden_dim, kernel_size=3, reduction_ratio=8):
        super().__init__()
        reduced_dim = hidden_dim // reduction_ratio

        # Per-channel (grouped) convolution over neighbouring tokens.
        self.depthwise_conv = nn.Conv1d(
            hidden_dim,
            hidden_dim,
            kernel_size,
            padding=kernel_size // 2,
            groups=hidden_dim,
            bias=False,
        )

        # Pointwise projections for query, key and value.
        self.q_proj = nn.Conv1d(hidden_dim, reduced_dim, 1, bias=False)
        self.k_proj = nn.Conv1d(hidden_dim, reduced_dim, 1, bias=False)
        self.v_proj = nn.Conv1d(hidden_dim, hidden_dim, 1, bias=False)

        # Lifts the element-wise q*k product back to hidden_dim gate logits.
        self.score_proj = nn.Conv1d(reduced_dim, hidden_dim, 1, bias=False)

        # Output projection applied to the gated values.
        self.out_proj = nn.Conv1d(hidden_dim, hidden_dim, 1, bias=False)

    def forward(self, x):
        """
        x: (batch, seq_len, hidden_dim)
        Output: (batch, seq_len, hidden_dim)
        """
        channels_first = x.transpose(1, 2)  # Conv1d expects (batch, hidden_dim, seq_len)
        local = self.depthwise_conv(channels_first)

        # Gate = sigmoid(score_proj(q * k)); no softmax over tokens is involved.
        gate = torch.sigmoid(self.score_proj(self.q_proj(local) * self.k_proj(local)))
        mixed = self.out_proj(gate * self.v_proj(local))

        # The residual uses the input from before the depthwise convolution.
        return (mixed + channels_first).transpose(1, 2)


class MultiHeadConvAttention(nn.Module):
    """Runs one ``ConvAttention`` per feature slice and mixes the result.

    The last dimension is split into ``num_heads`` contiguous slices of
    ``hidden_dim // num_heads`` features. Each slice goes through its own
    ``ConvAttention``; the outputs are concatenated back in head order and
    passed through a bias-free linear layer.
    """

    def __init__(self, hidden_dim, num_heads=2, kernel_size=3):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        assert hidden_dim % num_heads == 0, \
            f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"

        per_head = [ConvAttention(self.head_dim, kernel_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(per_head)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x):
        """Map ``(batch, seq_len, hidden_dim)`` to a tensor of the same shape."""
        batch_size, seq_len, embed_dim = x.shape

        assert embed_dim == self.num_heads * self.head_dim, \
            f"Expected embedding dim {embed_dim} to match {self.num_heads * self.head_dim} (={self.num_heads * self.head_dim})"

        # One (batch, seq_len, head_dim) slice per head.
        slices = x.view(batch_size, seq_len, self.num_heads, self.head_dim).unbind(dim=2)
        mixed = [head(chunk) for head, chunk in zip(self.heads, slices)]

        # Stack the heads next to the feature axis and merge them back.
        merged = torch.stack(mixed, dim=2).reshape(batch_size, seq_len, embed_dim)
        return self.out_proj(merged)

class FALCONetMHA_LiRPA(nn.Module):
    """FALCONetMHA_LiRPA segmentation network.

    Encoder-decoder over a channel-concatenated image pair. When ``attention``
    is set, a ``MultiHeadConvAttention`` token mixer follows the outputs of
    ``down2``, ``down3`` and ``down4``.

    Args:
        in_channels: channels after concatenating both inputs (default 26).
        out_channels: number of output logit channels.
        dropout: stored on the instance; not used by the layers built here.
        reduction: stored on the instance; not used by the layers built here.
        attention: enable the three token mixers.
        num_heads: number of heads of each token mixer.
    """

    def __init__(self, in_channels=26, out_channels=2, dropout=0.1, reduction=8, attention=True, num_heads=4):
        """Init FALCONetMHA_LiRPA fields."""
        super().__init__()
        self.dropout = dropout
        self.reduction = reduction
        self.attention = attention

        # Encoder: the width doubles at every Down stage (8 -> 128).
        cur_depth = 8
        self.inc = DoubleConv(in_channels, cur_depth)
        self.down1 = Down(cur_depth, cur_depth * 2)
        cur_depth *= 2

        self.down2 = Down(cur_depth, cur_depth * 2)
        cur_depth *= 2
        if self.attention:
            self.token_mixer_2 = MultiHeadConvAttention(cur_depth, num_heads=num_heads)
        self.down3 = Down(cur_depth, cur_depth * 2)
        cur_depth *= 2
        if self.attention:
            self.token_mixer_3 = MultiHeadConvAttention(cur_depth, num_heads=num_heads)
        self.down4 = Down(cur_depth, cur_depth * 2)
        cur_depth *= 2
        if self.attention:
            self.token_mixer_4 = MultiHeadConvAttention(cur_depth, num_heads=num_heads)

        # Decoder: the width halves at every Up stage (128 -> 8).
        self.up1 = Up(cur_depth, cur_depth // 2)
        cur_depth //= 2
        self.up2 = Up(cur_depth, cur_depth // 2)
        cur_depth //= 2
        self.up3 = Up(cur_depth, cur_depth // 2)
        cur_depth //= 2
        self.up4 = Up(cur_depth, cur_depth // 2)
        cur_depth //= 2
        self.outc = OutConv(cur_depth, out_channels)

    @staticmethod
    def _mix_tokens(mixer, feat):
        """Apply ``mixer`` to ``feat`` viewed as a token sequence.

        ``feat`` of shape (B, C, H, W) is flattened to (B, H*W, C), passed
        through ``mixer`` and reshaped back to (B, C, H, W).
        """
        B, C, H, W = feat.shape
        tokens = feat.reshape(B, C, H * W).transpose(1, 2)  # (B, N, C)
        tokens = mixer(tokens)
        return tokens.transpose(1, 2).reshape(B, C, H, W)

    def forward(self, x1, x2):
        """Concatenate ``x1`` and ``x2`` on dim 1 and return the logits."""
        x = torch.cat((x1, x2), 1)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        if self.attention:
            x3 = self._mix_tokens(self.token_mixer_2, x3)
        x4 = self.down3(x3)
        if self.attention:
            x4 = self._mix_tokens(self.token_mixer_3, x4)
        x5 = self.down4(x4)
        if self.attention:
            x5 = self._mix_tokens(self.token_mixer_4, x5)
        # The skip connections carry the attention-mixed features.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

# Example usage: build the network, run one forward pass and print a summary.
model, model_name = FALCONetMHA_LiRPA(2*13, 2 , dropout=0.1, reduction=8, attention=True, num_heads=4), 'FALCONetMHA_LiRPA'
# Alternative smaller input: x = torch.randn(1, 13, 128, 128)
# Each input has 13 channels; the model concatenates both to its 26 input channels.
x = torch.randn(1, 13, 256, 256)
out = model(x, x)
print(summary(model, input_size=(x.shape, x.shape)))

Layer (type:depth-idx)                   Output Shape              Param #
FALCONetMHA_LiRPA                        [1, 2, 256, 256]          --
├─DoubleConv: 1-1                        [1, 8, 256, 256]          --
│    └─Sequential: 2-1                   [1, 8, 256, 256]          --
│    │    └─Conv2d: 3-1                  [1, 8, 256, 256]          1,880
│    │    └─BatchNorm2d: 3-2             [1, 8, 256, 256]          16
│    │    └─ReLU: 3-3                    [1, 8, 256, 256]          --
│    │    └─Conv2d: 3-4                  [1, 8, 256, 256]          584
│    │    └─BatchNorm2d: 3-5             [1, 8, 256, 256]          16
│    │    └─ReLU: 3-6                    [1, 8, 256, 256]          --
├─Down: 1-2                              [1, 16, 128, 128]         --
│    └─Sequential: 2-2                   [1, 16, 128, 128]         --
│    │    └─Conv2d: 3-7                  [1, 16, 128, 128]         1,168
│    │    └─BatchNorm2d: 3-8             [1, 16, 128, 128]         32
│    │  

In [None]:
# FALCONetMHA fine-tuning: build the model, load the checkpoint, set up the optimiser.
falconet_retrained, falconet_retrained_name = FALCONetMHA_LiRPA(2*13, 2 , dropout=0.1, reduction=8, attention=True, num_heads=4), 'FALCONetMHA_LiRPA'

# map_location remaps the stored tensors onto `device`.
_falconet_state = torch.load('../oscd_ood/FALCONetMHA_LiRPA.pth.tar', weights_only=True, map_location=device)
# strict=False skips parameters that are missing from, or unknown to, the model
# and leaves them at their random initialisation, so report any mismatch.
_falconet_load = falconet_retrained.load_state_dict(_falconet_state, strict=False)
if _falconet_load.missing_keys or _falconet_load.unexpected_keys:
    print('WARNING: partial checkpoint load;',
          'missing keys:', list(_falconet_load.missing_keys),
          'unexpected keys:', list(_falconet_load.unexpected_keys))
falconet_retrained = falconet_retrained.to(device)
# Weighted cross-entropy with the class weights computed from the training set.
loss_fn = nn.CrossEntropyLoss(weight=dataset_weights)
optimizer_falconet_retrained = torch.optim.Adam(falconet_retrained.parameters(), lr=5e-4, weight_decay=1e-4)
# Multiplies the learning rate by 0.95 on every scheduler step.
scheduler_falconet_retrained = torch.optim.lr_scheduler.ExponentialLR(optimizer_falconet_retrained, 0.95)
print('LOAD OK')

In [25]:
# 4) Evaluate every setting with unseen_test_patch_batch. The base test_dataset
#    is passed each time; the perturbation is supplied through pair_transform.
results = []

# (setting name, pair transform); None marks the unperturbed baseline.
specs = [
    ("clean",   None),
    ("lf_1",    BoundedPairOOD(kind="lowfreq", eps=1/255, correlated=True,  seed=42)),
    ("lf_2",    BoundedPairOOD(kind="lowfreq", eps=2/255, correlated=True,  seed=42)),
    ("shadow",  BoundedPairOOD(kind="shadow",  eps=1/255, correlated=False, seed=7)),
    ("pband",   BoundedPairOOD(kind="per_band",eps=1/255, correlated=True,  seed=9)),
    ("blur_1",  BoundedPairOOD(kind="blur",   eps=1/255, correlated=False,  seed=11)),
]
# NOTE(review): the clean row runs with ood=False and pair_transform=None, so it
# may bypass the apply_minmax normalisation that the OOD rows go through --
# confirm inside unseen_test_patch_batch. If so, the clean-vs-OOD gap mixes a
# normalisation shift with the eps-bounded perturbation; an identity transform
# with ood=True (cf. clean_view / IdentityPair) would give a like-for-like row.
for name, tf in specs:
    m = unseen_test_patch_batch(test_dataset, loss_fn, falconet_retrained, feat=False, 
                                ood=(tf is not None),
                                pair_transform=tf,
                                apply_minmax=True)
    m['setting'] = name
    results.append(m)




# 5) Collect the per-setting metric dicts into a DataFrame and save it as CSV.
import pandas as pd
df = pd.DataFrame(results)
print(df)
df.to_csv("./ood_eval_summary_falco.csv", index=False)

# Optional: LaTeX table
def df_to_latex_table(df, out_path="table_ood_eval_summary_falco.tex"):
    """Write selected metric columns of ``df`` as a LaTeX ``table`` to ``out_path``.

    Only the known metric columns that are present in ``df`` are emitted, in a
    fixed order. Floats are printed with four decimals; every other value is
    converted with ``str`` and has its underscores escaped, so that setting
    names such as ``lf_1`` compile in LaTeX text mode.

    Args:
        df: DataFrame with one row per evaluation setting.
        out_path: destination ``.tex`` file (overwritten if it exists).
    """
    cols = ["setting","net_loss","net_accuracy","acc_bg","acc_fg","precision","recall","dice","kappa","f_meas"]
    use = [c for c in cols if c in df.columns]

    def _cell(value):
        # Numbers need no escaping; strings may contain '_' (e.g. "blur_1").
        if isinstance(value, float):
            return f"{value:.4f}"
        return str(value).replace('_', '\\_')

    header = " & ".join("\\textbf{" + c.replace('_', '\\_') + "}" for c in use)
    lines = [
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{\\textbf{Empirical OOD evaluation on OSCD test set.}}",
        "\\label{tab:ood_eval_summary}",
        # First column left-aligned, all remaining columns right-aligned.
        "\\begin{tabular}{%s}" % ("l" + "r"*(len(use)-1)),
        "\\toprule",
        header + " \\\\",
        "\\midrule",
    ]
    for _, r in df[use].iterrows():
        lines.append(" & ".join(_cell(r[c]) for c in use) + " \\\\")
    lines += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    print(f"[saved] {out_path}")

# Write the LaTeX version of `df` to the function's default out_path.
df_to_latex_table(df)

100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:42<00:00,  4.29s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:43<00:00,  4.31s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:43<00:00,  4.34s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:40<00:00,  4.06s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:40<00:00,  4.02s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:41<00:00,  4.18s/it]

   net_loss  net_accuracy                          class_accuracy  precision  \
0  0.194746     95.454259  [97.22432635492156, 62.97579159777969]   0.552876   
1  0.726732     94.831699                            [100.0, 0.0]   0.000000   
2  0.726746     94.831699                            [100.0, 0.0]   0.000000   
3  0.726882     94.831699                            [100.0, 0.0]   0.000000   
4  0.727310     94.831699                            [100.0, 0.0]   0.000000   
5  0.726733     94.831699                            [100.0, 0.0]   0.000000   

     recall      dice     kappa    f_meas setting  
0  0.629758  0.588818  0.564867  0.588818   clean  
1  0.000000  0.000000  0.000000  0.000000    lf_1  
2  0.000000  0.000000  0.000000  0.000000    lf_2  
3  0.000000  0.000000  0.000000  0.000000  shadow  
4  0.000000  0.000000  0.000000  0.000000   pband  
5  0.000000  0.000000  0.000000  0.000000  blur_1  
[saved] table_ood_eval_summary_falco.tex





In [7]:
# FALCONetMHA (HCT checkpoint): build the model and load the weights for evaluation.
falconet_retrained, falconet_retrained_name = FALCONetMHA_LiRPA(2*13, 2 , dropout=0.1, reduction=8, attention=True, num_heads=4), 'FALCONetMHA_LiRPA'

# map_location remaps the stored tensors onto `device`.
_falconet_hct_state = torch.load('./FALCONetMHA_HCT_vfree_v3.pth.tar', weights_only=True, map_location=device)
# strict=False skips parameters that are missing from, or unknown to, the model
# and leaves them at their random initialisation, so report any mismatch.
_falconet_hct_load = falconet_retrained.load_state_dict(_falconet_hct_state, strict=False)
if _falconet_hct_load.missing_keys or _falconet_hct_load.unexpected_keys:
    print('WARNING: partial checkpoint load;',
          'missing keys:', list(_falconet_hct_load.missing_keys),
          'unexpected keys:', list(_falconet_hct_load.unexpected_keys))
falconet_retrained = falconet_retrained.to(device)
# Weighted cross-entropy with the class weights computed from the training set.
loss_fn = nn.CrossEntropyLoss(weight=dataset_weights)
# Evaluate every setting with unseen_test_patch_batch. The base test_dataset is
# passed each time; the perturbation is supplied through pair_transform.
results = []

# (setting name, pair transform); None marks the unperturbed baseline.
specs = [
    ("clean",   None),
    ("lf_1",    BoundedPairOOD(kind="lowfreq", eps=1/255, correlated=True,  seed=42)),
    ("lf_2",    BoundedPairOOD(kind="lowfreq", eps=2/255, correlated=True,  seed=42)),
    ("shadow",  BoundedPairOOD(kind="shadow",  eps=1/255, correlated=False, seed=7)),
    ("pband",   BoundedPairOOD(kind="per_band",eps=1/255, correlated=True,  seed=9)),
    ("blur_1",  BoundedPairOOD(kind="blur",   eps=1/255, correlated=False,  seed=11)),
]
# NOTE(review): the clean row runs with ood=False and pair_transform=None, so it
# may bypass the apply_minmax normalisation that the OOD rows go through --
# confirm inside unseen_test_patch_batch. If so, the clean-vs-OOD gap mixes a
# normalisation shift with the eps-bounded perturbation; an identity transform
# with ood=True (cf. clean_view / IdentityPair) would give a like-for-like row.
for name, tf in specs:
    m = unseen_test_patch_batch(test_dataset, loss_fn, falconet_retrained, feat=False, 
                                ood=(tf is not None),
                                pair_transform=tf,
                                apply_minmax=True)
    m['setting'] = name
    results.append(m)




# 5) Collect the per-setting metric dicts into a DataFrame and save it as CSV.
import pandas as pd
df = pd.DataFrame(results)
print(df)
df.to_csv("./ood_eval_summary_falco_hct.csv", index=False)

# Optional: LaTeX table
def df_to_latex_table(df, out_path="table_ood_eval_summary_falco_hct.tex"):
    """Write selected metric columns of ``df`` as a LaTeX ``table`` to ``out_path``.

    Only the known metric columns that are present in ``df`` are emitted, in a
    fixed order. Floats are printed with four decimals; every other value is
    converted with ``str`` and has its underscores escaped, so that setting
    names such as ``lf_1`` compile in LaTeX text mode.

    Args:
        df: DataFrame with one row per evaluation setting.
        out_path: destination ``.tex`` file (overwritten if it exists).
    """
    cols = ["setting","net_loss","net_accuracy","acc_bg","acc_fg","precision","recall","dice","kappa","f_meas"]
    use = [c for c in cols if c in df.columns]

    def _cell(value):
        # Numbers need no escaping; strings may contain '_' (e.g. "blur_1").
        if isinstance(value, float):
            return f"{value:.4f}"
        return str(value).replace('_', '\\_')

    header = " & ".join("\\textbf{" + c.replace('_', '\\_') + "}" for c in use)
    lines = [
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{\\textbf{Empirical OOD evaluation on OSCD test set.}}",
        "\\label{tab:ood_eval_summary}",
        # First column left-aligned, all remaining columns right-aligned.
        "\\begin{tabular}{%s}" % ("l" + "r"*(len(use)-1)),
        "\\toprule",
        header + " \\\\",
        "\\midrule",
    ]
    for _, r in df[use].iterrows():
        lines.append(" & ".join(_cell(r[c]) for c in use) + " \\\\")
    lines += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    print(f"[saved] {out_path}")

# Write the LaTeX version of `df` to the function's default out_path.
df_to_latex_table(df)

100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:44<00:00,  4.40s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:51<00:00,  5.14s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:47<00:00,  4.78s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:40<00:00,  4.05s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:42<00:00,  4.21s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:48<00:00,  4.88s/it]

   net_loss  net_accuracy                           class_accuracy  precision  \
0  0.147368     96.080913   [97.38767785631303, 72.10344675848803]   0.600681   
1  0.801486     94.858665   [99.99804718213521, 0.557591606580461]   0.939619   
2  0.801540     94.858308  [99.99801292217268, 0.5513053426956757]   0.937968   
3  0.802180     94.858113  [99.99821848194792, 0.5437618260339333]   0.943293   
4  0.801611     94.858145  [99.99808144209776, 0.5469049579763259]   0.939525   
5  0.801707     94.858698  [99.99804718213521, 0.5582202329689395]   0.939683   

     recall      dice     kappa    f_meas setting  
0  0.721034  0.655378  0.634784  0.655378   clean  
1  0.005576  0.011086  0.010483  0.011086    lf_1  
2  0.005513  0.010962  0.010364  0.010962    lf_2  
3  0.005438  0.010813  0.010227  0.010813  shadow  
4  0.005469  0.010875  0.010283  0.010875   pband  
5  0.005582  0.011098  0.010494  0.011098  blur_1  
[saved] table_ood_eval_summary_falco_hct.tex





In [22]:
# Attunet definition
class conv_block(nn.Module):
    """Two 3x3 stride-1 convolutions, each followed by BatchNorm and ReLU.

    Args:
        ch_in: channels of the input feature map.
        ch_out: channels produced by both convolutions.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = []
        # First pass maps ch_in -> ch_out, the second keeps ch_out.
        for width_in in (ch_in, ch_out):
            layers.extend([
                nn.Conv2d(width_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the conv stack; spatial size is preserved."""
        return self.conv(x)

class up_conv(nn.Module):
    """2x upsampling followed by a 3x3 convolution, BatchNorm and ReLU.

    ``nn.Upsample`` is used with its default interpolation mode.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2)
        project = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        self.up = nn.Sequential(
            enlarge,
            project,
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the upsampled feature map with ``ch_out`` channels."""
        return self.up(x)

class Recurrent_block(nn.Module):
    """Recurrent 3x3 conv unit: one shared conv/BN/ReLU applied repeatedly.

    The same ``self.conv`` is first applied to the input and then ``t`` more
    times to ``x + previous_output``, so the block performs ``t + 1``
    convolutions in total. Channel count (``ch_out``) and spatial size are
    unchanged.

    Args:
        ch_out: number of input and output channels.
        t: number of recurrent refinement steps after the initial conv.
    """

    def __init__(self,ch_out,t=2):
        super(Recurrent_block,self).__init__()
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        """Return the recurrently refined features for ``x``."""
        # The initial conv is hoisted out of the loop so that x1 is always
        # bound; previously t <= 0 left it undefined and raised
        # UnboundLocalError. For t >= 1 the sequence of conv calls is
        # unchanged.
        x1 = self.conv(x)
        for _ in range(self.t):
            x1 = self.conv(x + x1)
        return x1
        
class RRCNN_block(nn.Module):
    """Residual block built from two chained ``Recurrent_block`` units.

    A 1x1 convolution first maps ``ch_in`` to ``ch_out`` channels; the result
    is passed through the two recurrent units and added back to itself.
    """

    def __init__(self,ch_in,ch_out,t=2):
        super().__init__()
        recurrent_pair = [Recurrent_block(ch_out, t=t) for _ in range(2)]
        self.RCNN = nn.Sequential(*recurrent_pair)
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self,x):
        """Return ``proj + RCNN(proj)`` where ``proj`` is the 1x1 projection of ``x``."""
        projected = self.Conv_1x1(x)
        refined = self.RCNN(projected)
        return projected + refined


class single_conv(nn.Module):
    """A single 3x3 conv + batch norm + ReLU stage (``ch_in`` -> ``ch_out``)."""

    def __init__(self,ch_in,ch_out):
        super().__init__()
        layers = (
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )
        self.conv = nn.Sequential(*layers)

    def forward(self,x):
        """Apply the conv/BN/ReLU stage to ``x``."""
        return self.conv(x)

class Attention_block(nn.Module):
    """Additive attention gate that rescales skip features ``x`` using ``g``.

    Both inputs are projected to ``F_int`` channels by 1x1 conv + batch norm,
    summed, passed through ReLU, and reduced to a single-channel sigmoid map
    that multiplies ``x``.

    Args:
        F_g: channels of the gating signal ``g``.
        F_l: channels of the skip features ``x``.
        F_int: channels of the intermediate projection.
    """

    def __init__(self,F_g,F_l,F_int):
        super().__init__()

        def pointwise_bn(c_in, c_out):
            # 1x1 convolution followed by batch norm.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(c_out),
            ]

        self.W_g = nn.Sequential(*pointwise_bn(F_g, F_int))
        self.W_x = nn.Sequential(*pointwise_bn(F_l, F_int))
        self.psi = nn.Sequential(*pointwise_bn(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self,g,x):
        """Return ``x`` multiplied by the attention map computed from ``g`` and ``x``."""
        gate_proj = self.W_g(g)
        skip_proj = self.W_x(x)
        attention = self.psi(self.relu(gate_proj + skip_proj))
        return x * attention



class AttU_Net(nn.Module):
    """Attention U-Net over the channel-wise concatenation of two images.

    The encoder has five ``conv_block`` levels (64 to 1024 channels) separated
    by 2x2 max pooling. Each decoder level upsamples, gates the matching
    encoder features with an ``Attention_block``, concatenates both and fuses
    them with a ``conv_block``. A final 1x1 convolution produces ``output_ch``
    channels; ``self.sm`` is an identity, so the 1x1 conv output is returned
    as is.

    Args:
        img_ch: channel count of the concatenated input pair.
        output_ch: number of output channels.
    """

    def __init__(self,img_ch=26,output_ch=1):
        super().__init__()

        widths = (64, 128, 256, 512, 1024)

        self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)

        # Encoder: Conv1 .. Conv5, registered in that order.
        previous = img_ch
        for level, width in enumerate(widths, start=1):
            setattr(self, f"Conv{level}", conv_block(ch_in=previous, ch_out=width))
            previous = width

        # Decoder: (Up, Att, Up_conv) triples for levels 5 down to 2.
        for level in range(5, 1, -1):
            deep = widths[level - 1]
            shallow = widths[level - 2]
            setattr(self, f"Up{level}", up_conv(ch_in=deep, ch_out=shallow))
            setattr(self, f"Att{level}",
                    Attention_block(F_g=shallow, F_l=shallow, F_int=shallow // 2))
            setattr(self, f"Up_conv{level}", conv_block(ch_in=deep, ch_out=shallow))

        self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
        self.sm = nn.Identity()

    def forward(self, x1, x2):
        """Run the network on the pair ``(x1, x2)`` and return the 1x1 conv output."""
        stacked = torch.cat((x1, x2), 1)

        # encoding path: keep every level's output for the skip connections
        encoded = [self.Conv1(stacked)]
        for level in range(2, 6):
            pooled = self.Maxpool(encoded[-1])
            encoded.append(getattr(self, f"Conv{level}")(pooled))

        # decoding + concat path
        decoded = encoded[-1]
        for level in range(5, 1, -1):
            upsampled = getattr(self, f"Up{level}")(decoded)
            gated_skip = getattr(self, f"Att{level}")(g=upsampled, x=encoded[level - 2])
            merged = torch.cat((gated_skip, upsampled), dim=1)
            decoded = getattr(self, f"Up_conv{level}")(merged)

        return self.sm(self.Conv_1x1(decoded))

In [24]:
# Build the Attention U-Net (2*13 input channels, 2 output channels) directly on
# `device`, matching how the FALCONet evaluation cell places its model.
saved_aunet, saved_aunet_name = AttU_Net(2*13, 2).to(device), 'AttU_Net'
lossatt_fn = nn.CrossEntropyLoss(weight=dataset_weights)
# map_location makes loading independent of the device the checkpoint was saved
# from. strict=False tolerates key mismatches, so keep this call as the last
# expression of the cell: its return value (the key-matching report) is then
# displayed and mismatches do not go unnoticed.
saved_aunet.load_state_dict(torch.load('../oscd_ood/verificationfriendly_attunet_finetune.pth.tar', map_location=device, weights_only=True), strict=False)

<All keys matched successfully>

In [26]:
# 4) Evaluate using your existing API (IMPORTANT: pass the *dataset views*, not DataLoaders)
# One metrics record per evaluation setting, collected for the summary table.
results = []

# (setting name, pair transform). None is the unperturbed baseline; every other
# entry is a BoundedPairOOD configured with a perturbation kind, an eps budget
# expressed in multiples of 1/255, a `correlated` flag and a fixed seed.
specs = [
    ("clean",   None),
    ("lf_1",    BoundedPairOOD(kind="lowfreq", eps=1/255, correlated=True,  seed=42)),
    ("lf_2",    BoundedPairOOD(kind="lowfreq", eps=2/255, correlated=True,  seed=42)),
    ("shadow",  BoundedPairOOD(kind="shadow",  eps=1/255, correlated=False, seed=7)),
    ("pband",   BoundedPairOOD(kind="per_band",eps=1/255, correlated=True,  seed=9)),
    ("blur_1",  BoundedPairOOD(kind="blur",   eps=1/255, correlated=False,  seed=11)),
]
# NOTE: `name`, `tf` and `m` are notebook-level variables; this loop overwrites
# any earlier value bound to those names.
for name, tf in specs:
    # The baseline is the only run with ood=False.
    # NOTE(review): in the recorded output of this cell all five perturbed
    # settings give almost identical metrics that differ sharply from "clean",
    # regardless of kind or eps. Presumably something gated on `ood` (e.g. the
    # apply_minmax rescaling) dominates the perturbation itself - confirm
    # inside unseen_test_patch_batch before interpreting the numbers.
    m = unseen_test_patch_batch(test_dataset, lossatt_fn, saved_aunet, feat=False, 
                                ood=(tf is not None),
                                pair_transform=tf,
                                apply_minmax=True)
    # Tag the returned metrics with the setting so rows stay identifiable.
    m['setting'] = name
    results.append(m)




# 5) Turn into a DataFrame (one row per setting), print it and save it as CSV
import pandas as pd
df = pd.DataFrame(results)
print(df)
df.to_csv("./ood_eval_summary_att.csv", index=False)

# Optional: LaTeX table
def df_to_latex_table(df, out_path="table_ood_eval_summary_att.tex"):
    """Write the evaluation summary in ``df`` as a LaTeX table to ``out_path``.

    Only the known metric columns that are present are emitted, in a fixed
    order. Float cells are formatted with four decimals; every other cell is
    converted with ``str`` and has its underscores escaped so that values
    such as ``lf_1`` are valid LaTeX. The table uses ``\\toprule`` /
    ``\\midrule`` / ``\\bottomrule`` rules.

    Args:
        df: DataFrame with one row per evaluation setting. It is not modified.
        out_path: destination ``.tex`` file (overwritten if it exists).
    """
    cols = ["setting","net_loss","net_accuracy","acc_bg","acc_fg","precision","recall","dice","kappa","f_meas"]

    table = df.copy()
    # The metrics rows carry per-class accuracy as a 2-element 'class_accuracy'
    # entry rather than as acc_bg / acc_fg, so those columns used to be dropped
    # silently. Index 1 is the change (foreground) class: in the printed
    # summaries it equals 100 * recall.
    if "class_accuracy" in table.columns:
        pairs = list(table["class_accuracy"])
        if all(hasattr(p, "__len__") and len(p) == 2 for p in pairs):
            for idx, col in enumerate(("acc_bg", "acc_fg")):
                if col not in table.columns:
                    table[col] = [p[idx] for p in pairs]

    use = [c for c in cols if c in table.columns]

    def _cell(value):
        # Floats get fixed precision; text is escaped like the header cells.
        if isinstance(value, float):
            return f"{value:.4f}"
        return str(value).replace("_", "\\_")

    with open(out_path, "w") as f:
        f.write("\\begin{table}[t]\n\\centering\n")
        f.write("\\caption{\\textbf{Empirical OOD evaluation on OSCD test set.}}\n")
        f.write("\\label{tab:ood_eval_summary}\n")
        f.write("\\begin{tabular}{%s}\n\\toprule\n" % ("l" + "r"*(len(use)-1)))
        f.write(" & ".join([("\\textbf{"+c.replace('_','\\_')+"}") for c in use]) + " \\\\\n\\midrule\n")
        for _, r in table[use].iterrows():
            f.write(" & ".join(_cell(r[c]) for c in use) + " \\\\\n")
        f.write("\\bottomrule\n\\end{tabular}\n\\end{table}\n")
    print(f"[saved] {out_path}")

# Write the LaTeX version of the summary, using the function's default output path.
df_to_latex_table(df)

100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [03:21<00:00, 20.12s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [02:35<00:00, 15.55s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [02:48<00:00, 16.86s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [02:32<00:00, 15.25s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [02:28<00:00, 14.86s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [02:28<00:00, 14.88s/it]


   net_loss  net_accuracy                           class_accuracy  precision  \
0  0.216362     95.729411  [97.70030001449196, 59.566122066672115]   0.585344   
1  1.402886     27.777803   [26.40004193419415, 53.05795306675384]   0.037803   
2  1.402800     27.782969  [26.405900387788517, 53.05040955009209]   0.037801   
3  1.403117     27.736412   [26.35468174379098, 53.08938438617776]   0.037803   
4  1.407761     27.144002  [25.702063717363533, 53.60171489278777]   0.037831   
5  1.403131     27.750512  [26.36999594704643, 53.081212243127545]   0.037805   

     recall      dice     kappa    f_meas setting  
0  0.595661  0.590457  0.567932  0.590457   clean  
1  0.530580  0.070578 -0.028680  0.070578    lf_1  
2  0.530504  0.070574 -0.028685  0.070574    lf_2  
3  0.530894  0.070579 -0.028683  0.070579  shadow  
4  0.536017  0.070674 -0.028643  0.070674   pband  
5  0.530812  0.070582 -0.028679  0.070582  blur_1  
[saved] table_ood_eval_summary_att.tex


In [13]:
# ================================
# Head-Consistency Training (no LiRPA)
# Drop-in replacement with same signature & plotting
# ================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import random

# ---- Physically-plausible OOD perturbations (Sentinel-2-ish) ----
def _gaussian_kernel(k=3, sigma=1.0, device="cpu", dtype=torch.float32):
    ax = torch.arange(k, device=device, dtype=dtype) - (k - 1) / 2.0
    w = torch.exp(-(ax**2) / (2 * sigma**2))
    w = w / w.sum()
    k2d = (w[:, None] * w[None, :]).unsqueeze(0).unsqueeze(0)  # 1x1xk xk
    return k2d

def gaussian_blur_depthwise(x, k=3, sigma=1.0):
    """Blur every channel of a 4-D tensor ``x`` with the same Gaussian kernel.

    The kernel is replicated once per channel and applied as a grouped
    convolution (one group per channel) with ``k // 2`` padding.
    """
    _, channels, _, _ = x.shape
    kernel = _gaussian_kernel(k, sigma, device=x.device, dtype=x.dtype)
    per_channel = kernel.repeat(channels, 1, 1, 1)  # [C,1,k,k]
    return F.conv2d(x, per_channel, padding=k//2, groups=channels)

def lowfreq_drift(x, amp=0.02, grid_shrink=16):
    """Add a smooth random offset field to ``x`` and clamp to ``[0, 1]``.

    Uniform noise in ``[-1, 1]`` is drawn on a grid ``grid_shrink`` times
    coarser than the input (at least 1x1), bilinearly upsampled to the input
    size, scaled by ``amp`` and added to ``x``.
    """
    B, C, H, W = x.shape
    coarse_h = max(1, H // grid_shrink)
    coarse_w = max(1, W // grid_shrink)
    noise = torch.rand(B, C, coarse_h, coarse_w, device=x.device, dtype=x.dtype) * 2 - 1  # [-1,1]
    field = F.interpolate(noise, size=(H, W), mode='bilinear', align_corners=False)
    shifted = x + amp * field
    return torch.clamp(shifted, 0.0, 1.0)

def shadow_patch(x, alpha=(0.5, 0.85)):
    """Darken one random rectangle per sample; the input is left untouched.

    For each sample a box with height in ``[H//6, H//2]`` and width in
    ``[W//6, W//2]`` (lower bounds at least 1, upper bounds at least 2) is
    placed at a random position and all channels inside it are multiplied by
    a factor drawn uniformly from ``alpha``. Box geometry and factor come
    from Python's ``random`` module. The result is clamped to ``[0, 1]``.
    """
    B, C, H, W = x.shape
    shaded = x.clone()
    min_h, max_h = max(1, H//6), max(2, H//2)
    min_w, max_w = max(1, W//6), max(2, W//2)
    for sample in range(B):
        box_h = random.randint(min_h, max_h)
        box_w = random.randint(min_w, max_w)
        top = random.randint(0, max(0, H - box_h))
        left = random.randint(0, max(0, W - box_w))
        factor = random.uniform(alpha[0], alpha[1])
        rows = slice(top, top + box_h)
        cols = slice(left, left + box_w)
        shaded[sample, :, rows, cols] *= factor
    return torch.clamp(shaded, 0.0, 1.0)

def passband_jitter(x, scale=0.03):
    """Multiply each (sample, channel) plane by a random gain near 1.

    The gain is ``1 + u * scale`` with ``u`` uniform in ``[-1, 1]``, shared by
    all pixels of a plane. The result is clamped to ``[0, 1]``.
    """
    B, C, H, W = x.shape
    unit = torch.rand(B, C, 1, 1, device=x.device, dtype=x.dtype) * 2 - 1
    gain = 1.0 + unit * scale
    return torch.clamp(x * gain, 0.0, 1.0)

def uniform_linf(x, eps=1/255):
    """Add independent uniform noise in ``[-eps, eps]`` to ``x`` and clamp to ``[0, 1]``."""
    noise = (torch.rand_like(x) * 2 - 1) * eps
    perturbed = x + noise
    return torch.clamp(perturbed, 0.0, 1.0)

def phys_ood_perturb(x1, x2, eps_linf=1/255,
                     p_lf=0.7, p_shadow=0.5, p_blur=0.7, p_pband=0.7):
    """Return randomly perturbed copies of the image pair ``(x1, x2)``.

    Both images always receive uniform noise bounded by ``eps_linf``. Four
    further perturbations are then applied in a fixed order (low-frequency
    drift, per-band gain jitter, shadow patch, Gaussian blur), each gated by
    one coin flip with the corresponding probability. A stage that fires is
    applied to both images, but each image gets its own random draw.

    NOTE(review): most stages clamp to [0, 1], so inputs are presumably
    expected to be scaled to that range already - confirm against the data
    pipeline.
    """
    first = uniform_linf(x1, eps_linf)
    second = uniform_linf(x2, eps_linf)

    # (probability, perturbation) pairs in application order.
    stages = (
        (p_lf, lambda img: lowfreq_drift(img, amp=2.5*eps_linf, grid_shrink=16)),
        (p_pband, lambda img: passband_jitter(img, scale=3*eps_linf)),
        (p_shadow, lambda img: shadow_patch(img, alpha=(0.5, 0.85))),
        (p_blur, lambda img: gaussian_blur_depthwise(img, k=3, sigma=0.8)),
    )
    for probability, perturb in stages:
        if random.random() < probability:
            first = perturb(first)
            second = perturb(second)
    return first, second

# ---- Helper losses for masked pixel-wise CE and margin hinge ----
def _pixelwise_ce_masked(logits, target, weight=None, mask=None):
    # logits [B,2,H,W]; target [B,H,W] long
    loss = F.cross_entropy(logits, target.long(), weight=weight, reduction='none')
    if mask is None:
        return loss.mean()
    loss = loss * mask
    denom = mask.sum().clamp_min(1.0)
    return loss.sum() / denom

def _pseudo_labels_and_conf(logits):
    probs = F.softmax(logits, dim=1)
    conf, pseudo = probs.max(dim=1)  # [B,H,W], [B,H,W]
    return pseudo, conf

# ---------------------------------------------------------------
# Same signature & outputs as your 'train' function
# Adds HCT losses internally; keeps plotting & eval intact
# ---------------------------------------------------------------
def train_HCT_noLiRPA(model, net_name, criterion, optimizer, scheduler,
                      train_loader, train_dataset, test_dataset,
                      feat=False, n_epochs=N_EPOCHS, save=True,
                      # HCT hyperparams (tweak as needed)
                      eps_linf=1.0/255.0, K_noisy=2, tau_conf=0.7,
                      lambda_ce=1.0, lambda_cons=1.0,
                      lambda_margin=0.5, margin_target=1.0):
    """Train ``model`` with supervised CE plus consistency and margin terms.

    For every batch the loss is
    ``lambda_ce * CE(clean) + lambda_cons * L_cons + lambda_margin * L_margin``:

    * ``CE(clean)`` is ``criterion`` on the clean prediction and the label.
    * ``L_cons`` is the mean, over ``K_noisy`` perturbed copies of the input
      pair (``phys_ood_perturb``), of the cross-entropy between the perturbed
      prediction and the clean arg-max pseudo label, restricted to pixels
      whose clean confidence is at least ``tau_conf``.
    * ``L_margin`` is a hinge pushing the clean logit gap between the
      predicted and the other class up to ``margin_target`` on the same
      pixels (two output classes are assumed by ``alt = 1 - pred``).

    After each epoch the scheduler is stepped, ``test_patch_batch`` is run on
    ``train_dataset`` and ``test_dataset``, and four figures (loss, accuracy,
    per-class accuracy, precision/recall/F-measure) are redrawn.

    Side effects: a state dict is written to the working directory whenever
    the train F-measure or the train loss improves, independently of
    ``save``; ``save`` only controls whether the figures are written as
    ``<net_name>-0N-*.png``.

    Args:
        model: network called as ``model(I1, I2)``; returns a
            ``(logits, extra)`` pair when ``feat`` is true, otherwise logits.
        net_name: prefix for the saved figure files.
        criterion: supervised loss; its ``weight`` attribute, if present, is
            reused as class weights for the consistency term.
        optimizer, scheduler: stepped per batch and per epoch respectively.
        train_loader: iterable of dict batches with keys 'I1', 'I2', 'label'.
        train_dataset, test_dataset: passed to ``test_patch_batch``.
        feat: see ``model``.
        n_epochs: number of epochs (must be at least 1).
        save: whether to save the figures.
        eps_linf, K_noisy, tau_conf, lambda_ce, lambda_cons, lambda_margin,
            margin_target: hyper-parameters of the terms described above.

    Returns:
        dict with the last epoch's train/test loss, accuracy and per-class
        (no-change / change) accuracy.
    """
    t = np.linspace(1, n_epochs, n_epochs)

    epoch_train_loss = 0 * t
    epoch_train_accuracy = 0 * t
    epoch_train_change_accuracy = 0 * t
    epoch_train_nochange_accuracy = 0 * t
    epoch_train_precision = 0 * t
    epoch_train_recall = 0 * t
    epoch_train_Fmeasure = 0 * t
    epoch_test_loss = 0 * t
    epoch_test_accuracy = 0 * t
    epoch_test_change_accuracy = 0 * t
    epoch_test_nochange_accuracy = 0 * t
    epoch_test_precision = 0 * t
    epoch_test_recall = 0 * t
    epoch_test_Fmeasure = 0 * t

    best_fm = 0
    best_lss = 1000

    plt.figure(num=1); plt.figure(num=2); plt.figure(num=3); plt.figure(num=4)

    # try to reuse class weights if provided in your CE
    class_weights = getattr(criterion, "weight", None)

    for epoch_index in tqdm(range(n_epochs)):
        model.train()
        print('Epoch: ' + str(epoch_index + 1) + ' of ' + str(n_epochs))

        for batch in train_loader:
            I1 = batch['I1'].float().to(device)
            I2 = batch['I2'].float().to(device)
            label = batch['label'].to(device)
            # Drop singleton axes but always keep the batch axis. A bare
            # torch.squeeze() also removes the batch dimension when a batch
            # holds a single sample, and the loss call then fails on the
            # mismatching target shape.
            label = label.reshape(
                [label.shape[0]] + [s for s in label.shape[1:] if s != 1])

            optimizer.zero_grad()

            # --- Clean forward ---
            if feat:
                output_clean, _ = model(I1, I2)
            else:
                output_clean = model(I1, I2)

            # Supervised CE on clean
            ce_clean = criterion(output_clean, label.long())

            # Pseudo labels + confidence gating for PP-consistency
            # (computed without gradient: they act as fixed targets)
            with torch.no_grad():
                pseudo, conf = _pseudo_labels_and_conf(output_clean)
            mask = (conf >= tau_conf).float()  # [B,H,W]

            # --- Margin hinge on clean logits (pred vs alt) under mask ---
            pred = pseudo.long()          # [B,H,W]
            alt = (1 - pred).long()
            pred_logits = output_clean.gather(1, pred.unsqueeze(1)).squeeze(1)  # [B,H,W]
            alt_logits  = output_clean.gather(1, alt.unsqueeze(1)).squeeze(1)   # [B,H,W]
            m_clean = pred_logits - alt_logits
            margin_penalty = F.relu(margin_target - m_clean) * mask
            L_margin = margin_penalty.sum() / mask.sum().clamp_min(1.0)

            # --- Consistency to pseudo labels for K noisy draws ---
            cons_terms = []
            for _ in range(K_noisy):
                I1p, I2p = phys_ood_perturb(I1, I2, eps_linf=eps_linf)
                if feat:
                    out_p, _ = model(I1p, I2p)
                else:
                    out_p = model(I1p, I2p)
                # CE to pseudo under the confidence mask; reuse class weights if any
                cons_terms.append(_pixelwise_ce_masked(out_p, pseudo, weight=class_weights, mask=mask))

            # With K_noisy == 0 fall back to a zero that still lives on the
            # same device/dtype as the other terms.
            L_cons = torch.stack(cons_terms).mean() if cons_terms else 0.0 * ce_clean

            # --- Total loss ---
            loss = lambda_ce * ce_clean + lambda_cons * L_cons + lambda_margin * L_margin
            loss.backward()
            optimizer.step()

        scheduler.step()

        # ---- Epoch-end: reuse your evaluation exactly as-is ----
        epoch_train_loss[epoch_index], epoch_train_accuracy[epoch_index], cl_acc, pr_rec = \
            test_patch_batch(train_dataset, criterion, model, feat=feat)
        epoch_train_nochange_accuracy[epoch_index] = cl_acc[0]
        epoch_train_change_accuracy[epoch_index] = cl_acc[1]
        epoch_train_precision[epoch_index] = pr_rec[0]
        epoch_train_recall[epoch_index] = pr_rec[1]
        epoch_train_Fmeasure[epoch_index] = pr_rec[2]

        epoch_test_loss[epoch_index], epoch_test_accuracy[epoch_index], cl_acc, pr_rec = \
            test_patch_batch(test_dataset, criterion, model, feat=feat)
        epoch_test_nochange_accuracy[epoch_index] = cl_acc[0]
        epoch_test_change_accuracy[epoch_index] = cl_acc[1]
        epoch_test_precision[epoch_index] = pr_rec[0]
        epoch_test_recall[epoch_index] = pr_rec[1]
        epoch_test_Fmeasure[epoch_index] = pr_rec[2]

        # ---- Same plots as your trainer ----
        plt.figure(num=1); plt.clf()
        l1_1, = plt.plot(t[:epoch_index + 1], epoch_train_loss[:epoch_index + 1], label='Train loss')
        l1_2, = plt.plot(t[:epoch_index + 1], epoch_test_loss[:epoch_index + 1], label='Test loss')
        plt.legend(handles=[l1_1, l1_2]); plt.grid(); plt.gcf().gca().set_xlim(left=0); plt.title('Loss')
        display.clear_output(wait=True); display.display(plt.gcf())

        plt.figure(num=2); plt.clf()
        l2_1, = plt.plot(t[:epoch_index + 1], epoch_train_accuracy[:epoch_index + 1], label='Train accuracy')
        l2_2, = plt.plot(t[:epoch_index + 1], epoch_test_accuracy[:epoch_index + 1], label='Test accuracy')
        plt.legend(handles=[l2_1, l2_2]); plt.grid(); plt.gcf().gca().set_ylim(0, 100); plt.title('Accuracy')
        display.clear_output(wait=True); display.display(plt.gcf())

        plt.figure(num=3); plt.clf()
        l3_1, = plt.plot(t[:epoch_index + 1], epoch_train_nochange_accuracy[:epoch_index + 1], label='Train accuracy: no change')
        l3_2, = plt.plot(t[:epoch_index + 1], epoch_train_change_accuracy[:epoch_index + 1], label='Train accuracy: change')
        l3_3, = plt.plot(t[:epoch_index + 1], epoch_test_nochange_accuracy[:epoch_index + 1], label='Test accuracy: no change')
        l3_4, = plt.plot(t[:epoch_index + 1], epoch_test_change_accuracy[:epoch_index + 1], label='Test accuracy: change')
        plt.legend(handles=[l3_1, l3_2, l3_3, l3_4]); plt.grid(); plt.gcf().gca().set_ylim(0, 100); plt.title('Accuracy per class')
        display.clear_output(wait=True); display.display(plt.gcf())

        plt.figure(num=4); plt.clf()
        l4_1, = plt.plot(t[:epoch_index + 1], epoch_train_precision[:epoch_index + 1], linewidth=1, label='Train precision')
        l4_2, = plt.plot(t[:epoch_index + 1], epoch_train_recall[:epoch_index + 1], linewidth=1, label='Train recall')
        l4_3, = plt.plot(t[:epoch_index + 1], epoch_train_Fmeasure[:epoch_index + 1], linewidth=1, label='Train Dice/F1')
        l4_4, = plt.plot(t[:epoch_index + 1], epoch_test_precision[:epoch_index + 1], linestyle='dashed', linewidth=2, label='Test precision')
        l4_5, = plt.plot(t[:epoch_index + 1], epoch_test_recall[:epoch_index + 1], linestyle='dashed', linewidth=2, label='Test recall')
        l4_6, = plt.plot(t[:epoch_index + 1], epoch_test_Fmeasure[:epoch_index + 1], linestyle='dashed', linewidth=2, label='Test Dice/F1')
        plt.legend(handles=[l4_1, l4_2, l4_3, l4_4, l4_5, l4_6]); plt.grid(); plt.gcf().gca().set_ylim(0, 1); plt.title('Precision, Recall and F-measure')
        display.clear_output(wait=True); display.display(plt.gcf())

        # ---- Save best by train F1 and by train loss (same as your code) ----
        # These checkpoints are written regardless of `save`.
        fm = epoch_train_Fmeasure[epoch_index]
        if fm > best_fm:
            best_fm = fm
            save_str = f'net-best_epoch-{epoch_index + 1}_fm-{fm}.pth.tar'
            torch.save(model.state_dict(), save_str)

        lss = epoch_train_loss[epoch_index]
        if lss < best_lss:
            best_lss = lss
            save_str = f'net-best_epoch-{epoch_index + 1}_loss-{lss}.pth.tar'
            torch.save(model.state_dict(), save_str)

        if save:
            im_format = 'png'
            plt.figure(num=1); plt.savefig(net_name + '-01-loss.' + im_format)
            plt.figure(num=2); plt.savefig(net_name + '-02-accuracy.' + im_format)
            plt.figure(num=3); plt.savefig(net_name + '-03-accuracy-per-class.' + im_format)
            plt.figure(num=4); plt.savefig(net_name + '-04-prec-rec-fmeas.' + im_format)

    out = {'train_loss': epoch_train_loss[-1],
           'train_accuracy': epoch_train_accuracy[-1],
           'train_nochange_accuracy': epoch_train_nochange_accuracy[-1],
           'train_change_accuracy': epoch_train_change_accuracy[-1],
           'test_loss': epoch_test_loss[-1],
           'test_accuracy': epoch_test_accuracy[-1],
           'test_nochange_accuracy': epoch_test_nochange_accuracy[-1],
           'test_change_accuracy': epoch_test_change_accuracy[-1]}

    # pr_rec here is the tuple from the last test-set evaluation.
    print('pr_c, rec_c, f_meas, pr_nc, rec_nc')
    print(pr_rec)
    return out


In [16]:
# Model (no LiRPA wrapper)
# Single source of truth for the run name (figure prefix and checkpoint file).
name = 'FALCONetMHA_HCT_vfree'
# Put the model on `device` before the optimizer is built: the training loop
# moves every batch to `device`, so the weights have to live there as well.
falconet_hct = FALCONetMHA_LiRPA(2*13, 2 , dropout=0.1, reduction=8, attention=True, num_heads=4).to(device)

# Loss/opt/sched as before
loss_fn = nn.CrossEntropyLoss(weight=dataset_weights)
optimizer = torch.optim.Adam(falconet_hct.parameters(), lr=5e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.95)

print('LOAD OK')
t0 = time.time()
out_dic = train_HCT_noLiRPA(
    falconet_hct, name,
    loss_fn, optimizer, scheduler,
    train_loader, train_dataset, test_dataset,
    feat=False, n_epochs=N_EPOCHS, save=True,
    # Optional HCT knobs:
    eps_linf=4/255, K_noisy=4, tau_conf=0.8,
    lambda_ce=1.0, lambda_cons=1.0, lambda_margin=0.5, margin_target=1.0
)

t1 = time.time()
print("Elapsed time for training:", t1 - t0)
# Final weights; this is the file the evaluation cell loads afterwards.
torch.save(falconet_hct.state_dict(), name + '.pth.tar')
print('SAVE OK')


LOAD OK


  0%|                                                                                            | 0/10 [00:00<?, ?it/s]

Epoch: 1 of 10


  0%|                                                                                            | 0/10 [03:02<?, ?it/s]


KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

In [None]:
# FALCONetMHA Finetuning
falconet_retrained, falconet_retrained_name = FALCONetMHA_LiRPA(2*13, 2 , dropout=0.1, reduction=8, attention=True, num_heads=4), 'FALCONetMHA_LiRPA'

# map_location makes loading independent of the device the checkpoint was saved from.
_hct_state = torch.load('./FALCONetMHA_HCT_vfree.pth.tar', map_location=device, weights_only=True)
# strict=False tolerates key mismatches, which would otherwise leave layers at
# their random initialisation without any sign; report them explicitly.
_hct_load_report = falconet_retrained.load_state_dict(_hct_state, strict=False)
if _hct_load_report.missing_keys or _hct_load_report.unexpected_keys:
    warnings.warn(
        "Checkpoint does not fully match FALCONetMHA_LiRPA: "
        f"missing={list(_hct_load_report.missing_keys)}, "
        f"unexpected={list(_hct_load_report.unexpected_keys)}"
    )
falconet_retrained = falconet_retrained.to(device)
loss_fn = nn.CrossEntropyLoss(weight=dataset_weights)
# One metrics record per evaluation setting, collected for the summary table.
results = []

# (setting name, pair transform). None is the unperturbed baseline; every other
# entry is a BoundedPairOOD configured with a perturbation kind, an eps budget
# expressed in multiples of 1/255, a `correlated` flag and a fixed seed.
specs = [
    ("clean",   None),
    ("lf_1",    BoundedPairOOD(kind="lowfreq", eps=1/255, correlated=True,  seed=42)),
    ("lf_2",    BoundedPairOOD(kind="lowfreq", eps=2/255, correlated=True,  seed=42)),
    ("shadow",  BoundedPairOOD(kind="shadow",  eps=1/255, correlated=False, seed=7)),
    ("pband",   BoundedPairOOD(kind="per_band",eps=1/255, correlated=True,  seed=9)),
    ("blur_1",  BoundedPairOOD(kind="blur",   eps=1/255, correlated=False,  seed=11)),
]
# NOTE: `name`, `tf` and `m` are notebook-level variables; this loop overwrites
# any earlier value bound to those names (e.g. the training run name).
for name, tf in specs:
    # The baseline is the only run with ood=False.
    # NOTE(review): presumably anything gated on `ood` (e.g. the apply_minmax
    # rescaling) is therefore absent from the baseline - confirm inside
    # unseen_test_patch_batch that "clean" and perturbed rows are comparable.
    m = unseen_test_patch_batch(test_dataset, loss_fn, falconet_retrained, feat=False, 
                                ood=(tf is not None),
                                pair_transform=tf,
                                apply_minmax=True)
    # Tag the returned metrics with the setting so rows stay identifiable.
    m['setting'] = name
    results.append(m)




# 5) Turn into a DataFrame (one row per setting), print it and save it as CSV
import pandas as pd
df = pd.DataFrame(results)
print(df)
df.to_csv("./ood_eval_summary_falco_hct.csv", index=False)

# Optional: LaTeX table
def df_to_latex_table(df, out_path="table_ood_eval_summary_falco_hct.tex"):
    """Write the evaluation summary in ``df`` as a LaTeX table to ``out_path``.

    Only the known metric columns that are present are emitted, in a fixed
    order. Float cells are formatted with four decimals; every other cell is
    converted with ``str`` and has its underscores escaped so that values
    such as ``lf_1`` are valid LaTeX. The table uses ``\\toprule`` /
    ``\\midrule`` / ``\\bottomrule`` rules.

    Args:
        df: DataFrame with one row per evaluation setting. It is not modified.
        out_path: destination ``.tex`` file (overwritten if it exists).
    """
    cols = ["setting","net_loss","net_accuracy","acc_bg","acc_fg","precision","recall","dice","kappa","f_meas"]

    table = df.copy()
    # The metrics rows carry per-class accuracy as a 2-element 'class_accuracy'
    # entry rather than as acc_bg / acc_fg, so those columns used to be dropped
    # silently. Index 1 is the change (foreground) class: in the printed
    # summaries it equals 100 * recall.
    if "class_accuracy" in table.columns:
        pairs = list(table["class_accuracy"])
        if all(hasattr(p, "__len__") and len(p) == 2 for p in pairs):
            for idx, col in enumerate(("acc_bg", "acc_fg")):
                if col not in table.columns:
                    table[col] = [p[idx] for p in pairs]

    use = [c for c in cols if c in table.columns]

    def _cell(value):
        # Floats get fixed precision; text is escaped like the header cells.
        if isinstance(value, float):
            return f"{value:.4f}"
        return str(value).replace("_", "\\_")

    with open(out_path, "w") as f:
        f.write("\\begin{table}[t]\n\\centering\n")
        f.write("\\caption{\\textbf{Empirical OOD evaluation on OSCD test set.}}\n")
        f.write("\\label{tab:ood_eval_summary}\n")
        f.write("\\begin{tabular}{%s}\n\\toprule\n" % ("l" + "r"*(len(use)-1)))
        f.write(" & ".join([("\\textbf{"+c.replace('_','\\_')+"}") for c in use]) + " \\\\\n\\midrule\n")
        for _, r in table[use].iterrows():
            f.write(" & ".join(_cell(r[c]) for c in use) + " \\\\\n")
        f.write("\\bottomrule\n\\end{tabular}\n\\end{table}\n")
    print(f"[saved] {out_path}")

# Write the LaTeX version of the summary, using the function's default output path.
df_to_latex_table(df)