## The HCT - main novel training

In [2]:
# Legacy imports for OSCD
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.padding import ReplicationPad2d
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torchvision.transforms as tr
from torchinfo import summary

# Other
import os
import numpy as np
import random
from skimage import io
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import tqdm as tqdm
from pandas import read_csv
from math import floor, ceil, sqrt, exp
from IPython import display
import time
from itertools import chain
import time
import warnings
from pprint import pprint

from numpy.random import RandomState

# Seed every random number generator used in this notebook (torch, numpy and
# the stdlib ``random`` module) so that runs are reproducible.
SEED = 2342

for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(SEED)
del _seed_rng

print('IMPORTS OK')

# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------

# Root folder of the OSCD dataset.
PATH_TO_DATASET = '../onera/OSCD/'
IS_PROTOTYPE = False

# Multiplier applied to the first entry of ChangeDetectionDataset.weights
# (tuning parameter; use 1 if unsure).
FP_MODIFIER = 10

BATCH_SIZE = 8          # other values tried: 32, 1
PATCH_SIDE = 128        # other value tried: 64

N_EPOCHS = 10
NORMALISE_IMGS = True
# Step between training-patch origins: just under half a patch (63 for 128).
TRAIN_STRIDE = PATCH_SIDE // 2 - 1
# Band selection: 0-RGB | 1-RGBIr | 2-All bands s.t. resolution <= 20m | 3-All bands
TYPE = 3
LOAD_TRAINED = False
DATA_AUG = True

# Computation is pinned to the CPU. To use a GPU when available, switch to:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

print('DEFINITIONS OK')

# Dataset Functions

def adjust_shape(I, s):
    """Return the 2-D array ``I`` cropped and/or edge-padded to shape ``s``.

    Rows/columns beyond ``s`` are discarded; missing rows/columns are filled
    by repeating the last row/column (``np.pad`` 'edge' mode).
    """
    rows, cols = s[0], s[1]
    cropped = I[:rows, :cols]

    missing_rows = max(0, rows - cropped.shape[0])
    missing_cols = max(0, cols - cropped.shape[1])
    pad_width = ((0, missing_rows), (0, missing_cols))

    return np.pad(cropped, pad_width, 'edge')
    

def read_sentinel_img(path):
    """Read a cropped Sentinel-2 image: RGB bands (B04, B03, B02).

    ``path`` must end with a path separator. The band-file prefix is the
    first directory entry with its 7-character ``"Bxx.tif"`` suffix removed.
    The bands are stacked along a new last axis as floats; when
    NORMALISE_IMGS is set, the whole stack is standardised with a single
    mean and standard deviation.
    """
    prefix = path + os.listdir(path)[0][:-7]
    bands = [io.imread(prefix + band + ".tif") for band in ("B04", "B03", "B02")]

    img = np.stack(bands, axis=2).astype('float')
    if NORMALISE_IMGS:
        img = (img - img.mean()) / img.std()
    return img

def read_sentinel_img_4(path):
    """Read a cropped Sentinel-2 image: RGB and NIR bands (B04, B03, B02, B08).

    ``path`` must end with a path separator. The band-file prefix is the
    first directory entry with its 7-character ``"Bxx.tif"`` suffix removed.
    The bands are stacked along a new last axis as floats; when
    NORMALISE_IMGS is set, the whole stack is standardised with a single
    mean and standard deviation.
    """
    prefix = path + os.listdir(path)[0][:-7]
    bands = [io.imread(prefix + band + ".tif") for band in ("B04", "B03", "B02", "B08")]

    img = np.stack(bands, axis=2).astype('float')
    if NORMALISE_IMGS:
        img = (img - img.mean()) / img.std()
    return img

def read_sentinel_img_leq20(path):
    """Read a cropped Sentinel-2 image: bands with resolution <= 20m.

    The 10 m bands (B04, B03, B02, B08) are read as-is. The 20 m bands
    (B05, B06, B07, B8A, B11, B12) are upsampled with ``zoom(..., 2)`` and then
    cropped/edge-padded to the shape of B04. The result is a float stack with
    bands on the last axis. When NORMALISE_IMGS is set, the whole stack is
    standardised with a single mean and standard deviation.
    """
    prefix = path + os.listdir(path)[0][:-7]

    red = io.imread(prefix + "B04.tif")
    target_shape = red.shape
    native = [red] + [io.imread(prefix + band + ".tif") for band in ("B03", "B02", "B08")]
    upsampled = [
        adjust_shape(zoom(io.imread(prefix + band + ".tif"), 2), target_shape)
        for band in ("B05", "B06", "B07", "B8A", "B11", "B12")
    ]

    img = np.stack(native + upsampled, axis=2).astype('float')
    if NORMALISE_IMGS:
        img = (img - img.mean()) / img.std()
    return img

def read_sentinel_img_leq60(path):
    """Read a cropped Sentinel-2 image: all 13 bands.

    The 10 m bands (B04, B03, B02, B08) are read as-is. The 20 m bands (B05,
    B06, B07, B8A, B11, B12) are upsampled by 2, and the 60 m bands (B01, B09,
    B10) are upsampled by 6, both with ``zoom``. Every upsampled band is
    cropped/edge-padded to the shape of B04. The result is a float stack with
    bands on the last axis. When NORMALISE_IMGS is set, the whole stack is
    standardised with a single mean and standard deviation.
    """
    prefix = path + os.listdir(path)[0][:-7]

    red = io.imread(prefix + "B04.tif")
    target_shape = red.shape
    native = [red] + [io.imread(prefix + band + ".tif") for band in ("B03", "B02", "B08")]

    def upsample(band, factor):
        # Resample one coarser band and force it onto the 10 m grid of B04.
        return adjust_shape(zoom(io.imread(prefix + band + ".tif"), factor), target_shape)

    from_20m = [upsample(band, 2) for band in ("B05", "B06", "B07", "B8A", "B11", "B12")]
    from_60m = [upsample(band, 6) for band in ("B01", "B09", "B10")]

    img = np.stack(native + from_20m + from_60m, axis=2).astype('float')
    if NORMALISE_IMGS:
        img = (img - img.mean()) / img.std()
    return img

def read_sentinel_img_trio(path):
    """Read a cropped Sentinel-2 image pair and its change map.

    Band selection follows the global TYPE (0: RGB, 1: RGB + NIR,
    2: bands <= 20 m, 3: all bands). The second image is cropped and/or
    edge-padded so that its spatial size matches the first.

    Args:
        path: folder containing ``imgs_1/``, ``imgs_2/`` and ``cm/cm.png``.

    Returns:
        (I1, I2, cm): the two band stacks (bands on the last axis) and a
        boolean change map (``True`` where ``cm.png`` is non-zero).

    Raises:
        ValueError: if TYPE is not 0, 1, 2 or 3.
    """
    readers = {
        0: read_sentinel_img,
        1: read_sentinel_img_4,
        2: read_sentinel_img_leq20,
        3: read_sentinel_img_leq60,
    }
    try:
        reader = readers[TYPE]
    except KeyError:
        raise ValueError(f'Unsupported band selection TYPE={TYPE!r}; expected 0, 1, 2 or 3.') from None

    I1 = reader(path + '/imgs_1/')
    I2 = reader(path + '/imgs_2/')

    cm = io.imread(path + '/cm/cm.png', as_gray=True) != 0

    # Match I2's spatial size to I1. First crop any excess, because np.pad
    # rejects negative widths. Then edge-pad any shortfall.
    h, w = I1.shape[0], I1.shape[1]
    I2 = I2[:h, :w]
    I2 = np.pad(I2, ((0, h - I2.shape[0]), (0, w - I2.shape[1]), (0, 0)), 'edge')

    return I1, I2, cm



def reshape_for_torch(I):
    """Convert an (H, W, C) numpy image into a (C, H, W) torch tensor.

    The transpose is a view and ``torch.from_numpy`` does not copy, so the
    returned tensor shares memory with ``I``.
    """
    channels_first = np.transpose(I, (2, 0, 1))
    return torch.from_numpy(channels_first)



class ChangeDetectionDataset(Dataset):
    """Change Detection dataset class, used for both training and test data.

    Every image pair listed in the split file is loaded into memory at
    construction time. Samples are square patches of side ``patch_side``
    whose top-left corners lie on a regular grid with step ``stride``.
    """

    def __init__(self, path, train = True, patch_side = 128, stride = None, use_all_bands = False, transform=None):
        """
        Args:
            path (string): Dataset root, ending with a path separator. It must
                contain ``all.txt`` / ``test.txt`` (one CSV row of image
                names) and one sub-folder per image name.
            train (bool): Read the image list from ``all.txt`` if True,
                otherwise from ``test.txt``.
            patch_side (int): Side length, in pixels, of the square patches.
            stride (int, optional): Step between patch origins; a falsy
                value means 1.
            use_all_bands (bool): Unused; band selection is controlled by the
                global TYPE.
            transform (callable, optional): Applied to each sample dict in
                ``__getitem__`` and to each image tensor in ``get_ood_img``.
        """
        self.transform = transform
        self.path = path
        self.patch_side = patch_side
        self.stride = stride if stride else 1

        fname = 'all.txt' if train else 'test.txt'
        # The list file is a single comma-separated row, so the image names
        # come back as the CSV header.
        self.names = read_csv(path + fname).columns
        self.n_imgs = self.names.shape[0]

        n_pix = 0
        true_pix = 0

        self.imgs_1 = {}
        self.imgs_2 = {}
        self.change_maps = {}
        self.n_patches_per_image = {}
        self.n_patches = 0
        self.patch_coords = []
        for im_name in tqdm(self.names):
            # Load and cache each image pair (channel-first) and its change map.
            I1, I2, cm = read_sentinel_img_trio(self.path + im_name)
            self.imgs_1[im_name] = reshape_for_torch(I1)
            self.imgs_2[im_name] = reshape_for_torch(I2)
            self.change_maps[im_name] = cm

            n_pix += np.prod(cm.shape)
            true_pix += cm.sum()

            # Number of patch origins along each spatial axis (never negative).
            s = self.imgs_1[im_name].shape
            n1 = self._n_positions(s[1], self.patch_side, self.stride)
            n2 = self._n_positions(s[2], self.patch_side, self.stride)
            n_patches_i = n1 * n2
            self.n_patches_per_image[im_name] = n_patches_i
            self.n_patches += n_patches_i

            for i in range(n1):
                for j in range(n2):
                    # (name, [row_start, row_end, col_start, col_end], [stride*(i+1), stride*(j+1)]).
                    # The third entry is kept for compatibility; __getitem__ does not use it.
                    current_patch_coords = (im_name,
                                            [self.stride*i, self.stride*i + self.patch_side, self.stride*j, self.stride*j + self.patch_side],
                                            [self.stride*(i + 1), self.stride*(j + 1)])
                    self.patch_coords.append(current_patch_coords)

        # Two weights derived from the change-pixel fraction. FP_MODIFIER scales
        # the first one. Presumably these are per-class loss weights for labels
        # [0 = no change, 1 = change]; confirm at the call site.
        self.weights = [ FP_MODIFIER * 2 * true_pix / n_pix, 2 * (n_pix - true_pix) / n_pix]

    @staticmethod
    def _n_positions(length, patch_side, stride):
        """Number of patch origins along an axis of ``length`` pixels (0 if shorter than a patch)."""
        return max(0, ceil((length - patch_side + 1) / stride))

    def get_img(self, im_name):
        """Return the cached (I1, I2, change_map) triple for ``im_name``."""
        return self.imgs_1[im_name], self.imgs_2[im_name], self.change_maps[im_name]

    def get_ood_img(self, im_name):
        """Return ``im_name``'s images passed through ``self.transform`` (dim 0 squeezed) plus its change map."""
        return self.transform(self.imgs_1[im_name]).squeeze(0), self.transform(self.imgs_2[im_name]).squeeze(0), self.change_maps[im_name]

    def __len__(self):
        # Derived from the coordinate list itself so it can never disagree with __getitem__.
        return len(self.patch_coords)

    def __getitem__(self, idx):
        """Return the patch sample ``{'I1', 'I2', 'label'}`` at index ``idx``, transformed if configured."""
        im_name, limits, _ = self.patch_coords[idx]

        I1 = self.imgs_1[im_name][:, limits[0]:limits[1], limits[2]:limits[3]]
        I2 = self.imgs_2[im_name][:, limits[0]:limits[1], limits[2]:limits[3]]

        label = self.change_maps[im_name][limits[0]:limits[1], limits[2]:limits[3]]
        label = torch.from_numpy(1*np.array(label)).float()

        sample = {'I1': I1, 'I2': I2, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample



class RandomFlip(object):
    """Mirror every tensor of a sample along its last axis when ``random.random() > 0.5``.

    The images are flipped on axis 2 and the label on axis 1. With the
    dataset's channel-first images, that is the width axis.
    """

    @staticmethod
    def _mirror(tensor, axis):
        """Return a new tensor holding ``tensor`` reversed along ``axis``."""
        return torch.from_numpy(np.flip(tensor.numpy(), axis=axis).copy())

    def __call__(self, sample):
        I1, I2, label = sample['I1'], sample['I2'], sample['label']

        if random.random() > 0.5:
            I1 = self._mirror(I1, 2)
            I2 = self._mirror(I2, 2)
            label = self._mirror(label, 1)

        return {'I1': I1, 'I2': I2, 'label': label}



class RandomRot(object):
    """Rotate every tensor of a sample by the same random number (0-3) of quarter turns.

    The images are rotated in their (1, 2) axes and the label in its (0, 1)
    axes. When zero turns are drawn, the original tensors are returned.
    """

    @staticmethod
    def _rotate(tensor, turns, axes):
        """Return a new tensor holding ``tensor`` rotated ``turns`` x 90 degrees in ``axes``."""
        return torch.from_numpy(np.rot90(tensor.numpy(), turns, axes=axes).copy())

    def __call__(self, sample):
        I1, I2, label = sample['I1'], sample['I2'], sample['label']

        turns = random.randint(0, 3)
        if not turns:
            return {'I1': I1, 'I2': I2, 'label': label}

        return {
            'I1': self._rotate(I1, turns, (1, 2)),
            'I2': self._rotate(I2, turns, (1, 2)),
            'label': self._rotate(label, turns, (0, 1)),
        }

# Confirm in the notebook output that the dataset utilities and augmentations are defined.
print("UTILS OK")

# Training functions
def _plot_training_curves(t, n_done, train_hist, test_hist):
    """Redraw live figures 1-4 (loss, accuracy, per-class accuracy, precision/recall/F1).

    Only the first ``n_done`` epochs of each history array are plotted. Each
    figure is shown in the notebook through IPython ``display``.
    """
    x = t[:n_done]

    def show(num, title, curves, ylim=None, xlim_left=None):
        # ``curves`` is a list of (history_array, legend_label, plot_kwargs).
        plt.figure(num=num)
        plt.clf()
        handles = [plt.plot(x, values[:n_done], label=label, **style)[0]
                   for values, label, style in curves]
        plt.legend(handles=handles)
        plt.grid()
        if ylim is not None:
            plt.gcf().gca().set_ylim(*ylim)
        if xlim_left is not None:
            plt.gcf().gca().set_xlim(left=xlim_left)
        plt.title(title)
        display.clear_output(wait=True)
        display.display(plt.gcf())

    show(1, 'Loss',
         [(train_hist['loss'], 'Train loss', {}),
          (test_hist['loss'], 'Test loss', {})],
         xlim_left=0)
    show(2, 'Accuracy',
         [(train_hist['accuracy'], 'Train accuracy', {}),
          (test_hist['accuracy'], 'Test accuracy', {})],
         ylim=(0, 100))
    show(3, 'Accuracy per class',
         [(train_hist['nochange_accuracy'], 'Train accuracy: no change', {}),
          (train_hist['change_accuracy'], 'Train accuracy: change', {}),
          (test_hist['nochange_accuracy'], 'Test accuracy: no change', {}),
          (test_hist['change_accuracy'], 'Test accuracy: change', {})],
         ylim=(0, 100))
    train_style = {'linewidth': 1}
    test_style = {'linestyle': 'dashed', 'linewidth': 2}
    show(4, 'Precision, Recall and F-measure',
         [(train_hist['precision'], 'Train precision', train_style),
          (train_hist['recall'], 'Train recall', train_style),
          (train_hist['Fmeasure'], 'Train Dice/F1', train_style),
          (test_hist['precision'], 'Test precision', test_style),
          (test_hist['recall'], 'Test recall', test_style),
          (test_hist['Fmeasure'], 'Test Dice/F1', test_style)],
         ylim=(0, 1))


def train(model, net_name, criterion, optimizer, scheduler, train_loader, train_dataset, test_dataset, feat=False, n_epochs = N_EPOCHS, save = True):
    """Train ``model`` for ``n_epochs`` epochs with live metric plots.

    Each epoch does the following:
    - runs one optimisation pass over ``train_loader`` and steps ``scheduler`` once;
    - evaluates full images of both datasets with ``test_patch_batch``;
    - redraws figures 1-4;
    - saves a ``state_dict`` checkpoint whenever the training F-measure or the
      training loss improves.
    When ``save`` is True, the figures are also written to
    ``<net_name>-0X-*.png``.

    Args:
        model: network called as ``model(I1, I2)``; returns ``(output, features)``
            when ``feat`` is True.
        net_name: filename prefix for the saved plots.
        criterion: loss called as ``criterion(logits, integer_labels)``.
        optimizer: optimiser stepped once per batch.
        scheduler: LR scheduler stepped once per epoch.
        train_loader: iterable of dicts with 'I1', 'I2' and 'label' tensors.
        train_dataset, test_dataset: passed to ``test_patch_batch`` each epoch.
        feat: see ``model``.
        n_epochs: number of epochs.
        save: whether to write the plot images.

    Returns:
        dict with the last epoch's train/test loss, accuracy and per-class accuracies.
    """
    t = np.linspace(1, n_epochs, n_epochs)
    metric_keys = ('loss', 'accuracy', 'nochange_accuracy', 'change_accuracy',
                   'precision', 'recall', 'Fmeasure')
    train_hist = {key: 0 * t for key in metric_keys}
    test_hist = {key: 0 * t for key in metric_keys}

    best_fm = 0
    best_lss = 1000
    pr_rec = None

    plt.figure(num=1)
    plt.figure(num=2)
    plt.figure(num=3)

    for epoch_index in tqdm(range(n_epochs)):
        model.train()
        print('Epoch: ' + str(epoch_index + 1) + ' of ' + str(n_epochs))

        for batch in train_loader:
            I1 = batch['I1'].float().to(device)
            I2 = batch['I2'].float().to(device)
            label = torch.squeeze(batch['label'].to(device))

            optimizer.zero_grad()
            if feat:
                output, _ = model(I1, I2)
            else:
                output = model(I1, I2)
            loss = criterion(output, label.long())
            loss.backward()
            optimizer.step()

        scheduler.step()

        # Full-image evaluation of both splits. The test split runs last, so
        # ``pr_rec`` holds the test-set precision/recall afterwards.
        for hist, dset in ((train_hist, train_dataset), (test_hist, test_dataset)):
            net_loss, net_acc, cl_acc, pr_rec = test_patch_batch(dset, criterion, model, feat=feat)
            values = (net_loss, net_acc, cl_acc[0], cl_acc[1], pr_rec[0], pr_rec[1], pr_rec[2])
            for key, value in zip(metric_keys, values):
                hist[key][epoch_index] = value

        _plot_training_curves(t, epoch_index + 1, train_hist, test_hist)

        # NOTE: checkpoint selection uses *training* metrics.
        fm = train_hist['Fmeasure'][epoch_index]
        if fm > best_fm:
            best_fm = fm
            save_str = 'net-best_epoch-' + str(epoch_index + 1) + '_fm-' + str(fm) + '.pth.tar'
            torch.save(model.state_dict(), save_str)

        lss = train_hist['loss'][epoch_index]
        if lss < best_lss:
            best_lss = lss
            save_str = 'net-best_epoch-' + str(epoch_index + 1) + '_loss-' + str(lss) + '.pth.tar'
            torch.save(model.state_dict(), save_str)

        if save:
            im_format = 'png'
            for num, suffix in ((1, '-01-loss.'), (2, '-02-accuracy.'),
                                (3, '-03-accuracy-per-class.'), (4, '-04-prec-rec-fmeas.')):
                plt.figure(num=num)
                plt.savefig(net_name + suffix + im_format)

    out = {'train_loss': train_hist['loss'][-1],
           'train_accuracy': train_hist['accuracy'][-1],
           'train_nochange_accuracy': train_hist['nochange_accuracy'][-1],
           'train_change_accuracy': train_hist['change_accuracy'][-1],
           'test_loss': test_hist['loss'][-1],
           'test_accuracy': test_hist['accuracy'][-1],
           'test_nochange_accuracy': test_hist['nochange_accuracy'][-1],
           'test_change_accuracy': test_hist['change_accuracy'][-1]}

    print('pr_c, rec_c, f_meas, pr_nc, rec_nc')
    print(pr_rec)

    return out


def kappa(tp, tn, fp, fn):
    """Cohen's kappa for a binary confusion matrix.

    Compares the observed agreement ``(tp + tn) / total`` with the agreement
    expected by chance from the predicted and true class totals. The division
    fails (Python ints) or yields inf/nan (numpy scalars) when the chance
    agreement equals 1.
    """
    total = tp + tn + fp + fn
    observed = (tp + tn) / total
    chance = ((tp + fp) * (tp + fn) + (tn + fp) * (tn + fn)) / (total * total)
    return (observed - chance) / (1 - chance)

def test_patch_batch(dset, criterion, model, feat=False, batch_size=BATCH_SIZE, patch_size=PATCH_SIDE):
    """Evaluate ``model`` on every full image of ``dset`` using batched patches.

    Each image pair is split into patches (``get_image_patch_batch``) and run
    through the model as one batch. The patch logits are stitched back
    together (``reconstruct_image_batch``) and cropped to the change map's
    size before scoring.

    Args:
        dset: object exposing ``names`` and ``get_img(name)`` returning
            ``(I1, I2, change_map)``.
        criterion: loss called as ``criterion(logits, integer_labels)``.
        model: called as ``model(I1, I2)``; returns ``(output, features)`` when ``feat``.
        feat: see ``model``.
        batch_size: unused; kept for interface compatibility.
        patch_size: patch side passed to ``reconstruct_image_batch``.

    Returns:
        ``(net_loss, net_accuracy, class_accuracy, pr_rec)``:
        - ``net_loss``: pixel-weighted mean loss (float);
        - ``net_accuracy``: overall accuracy in percent;
        - ``class_accuracy``: per-class accuracy in percent for
          [no change, change], as floats;
        - ``pr_rec``: [precision, recall, F-measure, no-change precision,
          no-change recall].
        Ratios with a zero denominator come out as nan (numpy division).
    """
    model.eval()
    model = model.to(device)

    n = 2
    class_correct = [0.] * n
    class_total = [0.] * n
    tot_loss = 0.
    tot_count = 0
    tp = tn = fp = fn = 0

    # Inference only: skipping autograd avoids building a graph over full images.
    with torch.no_grad():
        for img_index in dset.names:
            I1_full, I2_full, cm_full = dset.get_img(img_index)
            full_x, full_y = cm_full.shape
            cm = torch.unsqueeze(torch.from_numpy(1.0 * cm_full), 0).float().to(device)
            _, _, _, cm_px, cm_py = get_mask_patch_batch(cm_full)
            I1_patch, _, _, _, _ = get_image_patch_batch(I1_full)
            I2_patch, _, _, _, _ = get_image_patch_batch(I2_full)

            I1 = I1_patch.float().to(device)
            I2 = I2_patch.float().to(device)
            if feat:
                output, _ = model(I1, I2)
            else:
                output = model(I1, I2)

            # Stitch patch logits back into one image and drop the padding.
            cm_recon = reconstruct_image_batch(output, 1, (cm_px, cm_py), patch_size)
            output = cm_recon[:, :, :full_x, :full_y]

            loss = criterion(output, cm.long())
            n_pixels = cm.numel()
            tot_loss += loss.item() * n_pixels
            tot_count += n_pixels

            _, predicted = torch.max(output, 1)
            pred_int = predicted.int()
            gt_int = cm.int()

            # Vectorised per-class counts (replaces an O(H*W) Python loop).
            correct = pred_int == gt_int
            for label_value in range(n):
                in_class = gt_int == label_value
                class_correct[label_value] += (correct & in_class).sum().item()
                class_total[label_value] += in_class.sum().item()

            pr = (pred_int > 0).cpu().numpy()
            gt = (gt_int > 0).cpu().numpy()

            tp += np.logical_and(pr, gt).sum()
            tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
            fp += np.logical_and(pr, np.logical_not(gt)).sum()
            fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = float(tot_loss / tot_count)
    net_accuracy = 100 * (tp + tn) / tot_count

    # The 1e-5 floor keeps a class that never occurs from dividing by zero.
    class_accuracy = [float(100 * class_correct[i] / max(class_total[i], 0.00001)) for i in range(n)]

    prec = tp / (tp + fp)
    rec = tp / (tp + fn)
    f_meas = 2 * prec * rec / (prec + rec)
    prec_nc = tn / (tn + fn)
    rec_nc = tn / (tn + fp)

    pr_rec = [prec, rec, f_meas, prec_nc, rec_nc]

    return net_loss, net_accuracy, class_accuracy, pr_rec




# Training functions
def train_ood_batch(model, net_name, criterion, optimizer, scheduler, train_loader, train_dataset, test_dataset, feat=False, n_epochs = N_EPOCHS, save = True):
    """Train ``model`` for ``n_epochs`` epochs without plots, evaluating with ``test_patch_batch_ood``.

    Each epoch does the following:
    - runs one optimisation pass over ``train_loader`` and steps ``scheduler`` once;
    - evaluates both datasets;
    - saves a ``state_dict`` checkpoint whenever the training F-measure or the
      training loss improves.

    Args:
        model: network called as ``model(I1, I2)``; returns ``(output, features)``
            when ``feat`` is True.
        net_name: unused here; kept for signature compatibility with ``train``.
        criterion: loss called as ``criterion(logits, integer_labels)``.
        optimizer: optimiser stepped once per batch.
        scheduler: LR scheduler stepped once per epoch.
        train_loader: iterable of dicts with 'I1', 'I2' and 'label' tensors.
        train_dataset, test_dataset: passed to ``test_patch_batch_ood`` each epoch.
        feat: see ``model``.
        n_epochs: number of epochs.
        save: unused; kept for signature compatibility with ``train``.

    Returns:
        dict with the last epoch's train/test loss, accuracy and per-class accuracies.
    """
    t = np.linspace(1, n_epochs, n_epochs)
    metric_keys = ('loss', 'accuracy', 'nochange_accuracy', 'change_accuracy',
                   'precision', 'recall', 'Fmeasure')
    train_hist = {key: 0 * t for key in metric_keys}
    test_hist = {key: 0 * t for key in metric_keys}

    best_fm = 0
    best_lss = 1000
    pr_rec = None

    for epoch_index in tqdm(range(n_epochs)):
        model.train()
        print('Epoch: ' + str(epoch_index + 1) + ' of ' + str(n_epochs))

        for batch in train_loader:
            I1 = batch['I1'].float().to(device)
            I2 = batch['I2'].float().to(device)
            label = torch.squeeze(batch['label'].to(device))

            optimizer.zero_grad()
            if feat:
                output, _ = model(I1, I2)
            else:
                output = model(I1, I2)
            loss = criterion(output, label.long())
            loss.backward()
            optimizer.step()

        scheduler.step()

        # The test split is evaluated last, so ``pr_rec`` holds its precision/recall afterwards.
        for hist, dset in ((train_hist, train_dataset), (test_hist, test_dataset)):
            net_loss, net_acc, cl_acc, pr_rec = test_patch_batch_ood(dset, criterion, model, feat=feat)
            values = (net_loss, net_acc, cl_acc[0], cl_acc[1], pr_rec[0], pr_rec[1], pr_rec[2])
            for key, value in zip(metric_keys, values):
                hist[key][epoch_index] = value

        # NOTE: checkpoint selection uses *training* metrics.
        fm = train_hist['Fmeasure'][epoch_index]
        if fm > best_fm:
            best_fm = fm
            save_str = 'net-best_epoch-' + str(epoch_index + 1) + '_fm-' + str(fm) + '.pth.tar'
            torch.save(model.state_dict(), save_str)

        lss = train_hist['loss'][epoch_index]
        if lss < best_lss:
            best_lss = lss
            save_str = 'net-best_epoch-' + str(epoch_index + 1) + '_loss-' + str(lss) + '.pth.tar'
            torch.save(model.state_dict(), save_str)

    out = {'train_loss': train_hist['loss'][-1],
           'train_accuracy': train_hist['accuracy'][-1],
           'train_nochange_accuracy': train_hist['nochange_accuracy'][-1],
           'train_change_accuracy': train_hist['change_accuracy'][-1],
           'test_loss': test_hist['loss'][-1],
           'test_accuracy': test_hist['accuracy'][-1],
           'test_nochange_accuracy': test_hist['nochange_accuracy'][-1],
           'test_change_accuracy': test_hist['change_accuracy'][-1]}

    print('pr_c, rec_c, f_meas, pr_nc, rec_nc')
    print(pr_rec)

    return out


def kappa(tp, tn, fp, fn):
    """Cohen's kappa computed from binary confusion-matrix counts.

    NOTE: this module defines ``kappa`` twice with the same formula; this
    later definition is the one in effect after the module runs.
    """
    n = tp + tn + fp + fn
    predicted_pos, actual_pos = tp + fp, tp + fn
    actual_neg, predicted_neg = tn + fp, tn + fn

    p_observed = (tp + tn) / n
    # Agreement expected if predictions were independent of the truth.
    p_chance = (predicted_pos * actual_pos + actual_neg * predicted_neg) / (n * n)

    return (p_observed - p_chance) / (1 - p_chance)

def test_patch_batch_ood(dset, criterion, model, feat=False, batch_size=BATCH_SIZE, patch_size=PATCH_SIDE):
    """Like ``test_patch_batch``, but reads images through ``dset.dataset``.

    Each image pair is split into patches (``get_image_patch_batch``) and run
    through the model as one batch. The patch logits are stitched back
    together (``reconstruct_image_batch``) and cropped to the change map's
    size before scoring.

    NOTE(review): the loop covers every name of ``dset.dataset``. If ``dset``
    is a subset wrapper, its own index selection is ignored; confirm this is
    intended.

    Args:
        dset: wrapper whose ``.dataset`` exposes ``names`` and ``get_img(name)``.
        criterion: loss called as ``criterion(logits, integer_labels)``.
        model: called as ``model(I1, I2)``; returns ``(output, features)`` when ``feat``.
        feat: see ``model``.
        batch_size: unused; kept for interface compatibility.
        patch_size: patch side passed to ``reconstruct_image_batch``.

    Returns:
        ``(net_loss, net_accuracy, class_accuracy, pr_rec)``:
        - ``net_loss``: pixel-weighted mean loss (float);
        - ``net_accuracy``: overall accuracy in percent;
        - ``class_accuracy``: per-class accuracy in percent for
          [no change, change], as floats;
        - ``pr_rec``: [precision, recall, F-measure, no-change precision,
          no-change recall].
    """
    model.eval()
    model = model.to(device)
    source = dset.dataset

    n = 2
    class_correct = [0.] * n
    class_total = [0.] * n
    tot_loss = 0.
    tot_count = 0
    tp = tn = fp = fn = 0

    # Inference only: skipping autograd avoids building a graph over full images.
    with torch.no_grad():
        for img_index in source.names:
            I1_full, I2_full, cm_full = source.get_img(img_index)
            full_x, full_y = cm_full.shape
            cm = torch.unsqueeze(torch.from_numpy(1.0 * cm_full), 0).float().to(device)
            _, _, _, cm_px, cm_py = get_mask_patch_batch(cm_full)
            I1_patch, _, _, _, _ = get_image_patch_batch(I1_full)
            I2_patch, _, _, _, _ = get_image_patch_batch(I2_full)

            I1 = I1_patch.float().to(device)
            I2 = I2_patch.float().to(device)
            if feat:
                output, _ = model(I1, I2)
            else:
                output = model(I1, I2)

            # Stitch patch logits back into one image and drop the padding.
            cm_recon = reconstruct_image_batch(output, 1, (cm_px, cm_py), patch_size)
            output = cm_recon[:, :, :full_x, :full_y]

            loss = criterion(output, cm.long())
            n_pixels = cm.numel()
            tot_loss += loss.item() * n_pixels
            tot_count += n_pixels

            _, predicted = torch.max(output, 1)
            pred_int = predicted.int()
            gt_int = cm.int()

            # Vectorised per-class counts (replaces an O(H*W) Python loop).
            correct = pred_int == gt_int
            for label_value in range(n):
                in_class = gt_int == label_value
                class_correct[label_value] += (correct & in_class).sum().item()
                class_total[label_value] += in_class.sum().item()

            pr = (pred_int > 0).cpu().numpy()
            gt = (gt_int > 0).cpu().numpy()

            tp += np.logical_and(pr, gt).sum()
            tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
            fp += np.logical_and(pr, np.logical_not(gt)).sum()
            fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = float(tot_loss / tot_count)
    net_accuracy = 100 * (tp + tn) / tot_count

    # The 1e-5 floor keeps a class that never occurs from dividing by zero.
    class_accuracy = [float(100 * class_correct[i] / max(class_total[i], 0.00001)) for i in range(n)]

    prec = tp / (tp + fp)
    rec = tp / (tp + fn)
    f_meas = 2 * prec * rec / (prec + rec)
    prec_nc = tn / (tn + fn)
    rec_nc = tn / (tn + fp)

    pr_rec = [prec, rec, f_meas, prec_nc, rec_nc]

    return net_loss, net_accuracy, class_accuracy, pr_rec

def test(dset, criterion, model, feat=False):
    """Evaluate ``model`` on every full image of ``dset``, tiled into an N x N grid (N = 2).

    Every tile pair is passed through the model unbatched and scored against
    the matching crop of the change map. The tiles are non-overlapping, so
    each pixel is counted exactly once.

    Args:
        dset: object exposing ``names`` and ``get_img(name)`` returning
            ``(I1, I2, change_map)`` with channel-first image tensors.
        criterion: loss called as ``criterion(logits, integer_labels)``.
        model: called as ``model(I1, I2)``; returns ``(output, features)`` when ``feat``.
        feat: see ``model``.

    Returns:
        ``(net_loss, net_accuracy, class_accuracy, pr_rec)``:
        - ``net_loss``: pixel-weighted mean loss (tensor);
        - ``net_accuracy``: overall accuracy in percent;
        - ``class_accuracy``: per-class accuracy in percent for
          [no change, change];
        - ``pr_rec``: [precision, recall, F-measure, no-change precision,
          no-change recall].
    """
    N = 2
    model.eval()

    n = 2
    class_correct = [0.] * n
    class_total = [0.] * n
    tot_loss = 0
    tot_count = 0
    tp = tn = fp = fn = 0

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for img_index in dset.names:
            I1_full, I2_full, cm_full = dset.get_img(img_index)
            s = cm_full.shape

            # Tile origins along each axis; the last tile extends to the border.
            steps0 = np.arange(0, s[0], ceil(s[0] / N))
            steps1 = np.arange(0, s[1], ceil(s[1] / N))
            for ii in range(N):
                xmin = steps0[ii]
                xmax = s[0] if ii == N - 1 else steps0[ii + 1]
                for jj in range(N):
                    # Previously ``ymin = jj``, which made tiles overlap and double-counted pixels.
                    ymin = steps1[jj]
                    ymax = s[1] if jj == N - 1 else steps1[jj + 1]

                    I1 = torch.unsqueeze(I1_full[:, xmin:xmax, ymin:ymax], 0).float().to(device)
                    I2 = torch.unsqueeze(I2_full[:, xmin:xmax, ymin:ymax], 0).float().to(device)
                    cm = torch.unsqueeze(torch.from_numpy(1.0 * cm_full[xmin:xmax, ymin:ymax]), 0).float().to(device)

                    if feat:
                        output, _ = model(I1, I2)
                    else:
                        output = model(I1, I2)
                    loss = criterion(output, cm.long())
                    tot_loss += loss * np.prod(cm.size())
                    tot_count += np.prod(cm.size())

                    _, predicted = torch.max(output, 1)
                    pred_int = predicted.int()
                    gt_int = cm.int()

                    # Vectorised per-class counts (replaces an O(H*W) Python loop).
                    correct = pred_int == gt_int
                    for label_value in range(n):
                        in_class = gt_int == label_value
                        class_correct[label_value] += (correct & in_class).sum()
                        class_total[label_value] += in_class.sum().item()

                    pr = (pred_int > 0).cpu().numpy()
                    gt = (gt_int > 0).cpu().numpy()

                    tp += np.logical_and(pr, gt).sum()
                    tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
                    fp += np.logical_and(pr, np.logical_not(gt)).sum()
                    fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = tot_loss / tot_count
    net_accuracy = 100 * (tp + tn) / tot_count

    # The 1e-5 floor keeps a class that never occurs from dividing by zero.
    class_accuracy = [100 * class_correct[i] / max(class_total[i], 0.00001) for i in range(n)]

    prec = tp / (tp + fp)
    rec = tp / (tp + fn)
    f_meas = 2 * prec * rec / (prec + rec)
    prec_nc = tn / (tn + fn)
    rec_nc = tn / (tn + fp)

    pr_rec = [prec, rec, f_meas, prec_nc, rec_nc]

    return net_loss, net_accuracy, class_accuracy, pr_rec

def unseen_test(dset, criterion, model, feat=False):
    """Evaluate `model` on every image of `dset`, tiling each image into L x L crops.

    Tiles are cut directly from the full images (no padding), so tiles on the
    bottom/right border are simply smaller. Metrics are accumulated over all
    pixels of all images.

    Args:
        dset: object exposing ``names`` and ``get_img(name)`` returning
            ``(I1, I2, change_map)``; I1/I2 are sliced as ``[:, rows, cols]`` and
            the change map (numpy) as ``[rows, cols]``.
        criterion: loss called as ``criterion(logits, target)`` with ``target``
            a ``(1, H, W)`` long tensor.
        model: network called as ``model(I1, I2)``; when ``feat`` is True it
            must return ``(output, features)``.
        feat: whether ``model`` returns an extra features value.

    Returns:
        dict with 'net_loss', 'net_accuracy' (percent), 'class_accuracy'
        (percent per class), 'precision', 'recall', 'dice', 'kappa' and
        'f_meas'. Ratios whose denominator is zero are reported as 0.
    """
    L = 1024  # tile side in pixels
    n = 2     # number of classes

    model.eval()
    model = model.to(device)

    class_correct = [0] * n
    class_total = [0] * n
    tp = tn = fp = fn = 0
    tot_loss = 0.0
    tot_count = 0

    with torch.no_grad():  # evaluation only: no autograd graph needed
        for img_index in tqdm(dset.names):
            I1_full, I2_full, cm_full = dset.get_img(img_index)
            s = cm_full.shape

            for ii in range(ceil(s[0] / L)):
                for jj in range(ceil(s[1] / L)):
                    # Rows are bounded by the height s[0], columns by the width s[1].
                    xmin, xmax = L * ii, min(L * (ii + 1), s[0])
                    ymin, ymax = L * jj, min(L * (jj + 1), s[1])

                    I1 = I1_full[:, xmin:xmax, ymin:ymax].unsqueeze(0).float().to(device)
                    I2 = I2_full[:, xmin:xmax, ymin:ymax].unsqueeze(0).float().to(device)
                    target = (torch.from_numpy(1.0 * cm_full[xmin:xmax, ymin:ymax])
                              .unsqueeze(0).long().to(device))

                    if feat:
                        output, _ = model(I1, I2)
                    else:
                        output = model(I1, I2)

                    loss = criterion(output, target)
                    num_pix = target.numel()
                    tot_loss += float(loss) * num_pix
                    tot_count += num_pix

                    predicted = output.argmax(dim=1)
                    correct = predicted == target
                    # Exact integer counts (a float32 counter stops growing at 2**24).
                    for cls in range(n):
                        in_cls = target == cls
                        class_correct[cls] += int((correct & in_cls).sum())
                        class_total[cls] += int(in_cls.sum())

                    pr = (predicted > 0).cpu().numpy()
                    gt = (target > 0).cpu().numpy()
                    tp += np.logical_and(pr, gt).sum()
                    tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
                    fp += np.logical_and(pr, np.logical_not(gt)).sum()
                    fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = tot_loss / max(tot_count, 1)
    net_accuracy = 100.0 * (tp + tn) / max(tot_count, 1)
    class_accuracy = [100.0 * class_correct[c] / max(class_total[c], 1) for c in range(n)]

    prec = tp / max(tp + fp, 1e-9)
    rec = tp / max(tp + fn, 1e-9)
    f_meas = 2 * prec * rec / max(prec + rec, 1e-9)
    dice = f_meas  # same formula as the F-measure
    k = kappa(tp, tn, fp, fn)

    return {'net_loss': net_loss,
            'net_accuracy': net_accuracy,
            'class_accuracy': class_accuracy,
            'precision': prec,
            'recall': rec,
            'dice': dice,
            'kappa': k,
            'f_meas': f_meas}
# --- patch-based testing helpers ---
def extract_patches_batch(images, patch_size):
    """
    Cut every image of a batch into non-overlapping square patches.

    Args:
        images (torch.Tensor): batch of shape (B, C, H, W); H and W must be
            multiples of ``patch_size``.
        patch_size (int): side length of each patch.

    Returns:
        torch.Tensor: patches of shape (B * (H // patch_size) * (W // patch_size),
        C, patch_size, patch_size), ordered image by image and, inside an image,
        row by row.
    """
    batch, channels, height, width = images.shape
    assert height % patch_size == 0 and width % patch_size == 0, "Image size must be divisible by patch size"

    rows = height // patch_size
    cols = width // patch_size

    # (B, C, rows, cols, p, p) -> (B, rows, cols, C, p, p) so patches are row-major.
    tiles = images.unfold(2, patch_size, patch_size)
    tiles = tiles.unfold(3, patch_size, patch_size)
    tiles = tiles.permute(0, 2, 3, 1, 4, 5).contiguous()
    return tiles.view(batch * rows * cols, channels, patch_size, patch_size)


def reconstruct_image_batch(patches, batch_size, output_size=(32, 32), patch_size=2):
    """Stitch row-major square patches back into full images.

    Inverse of ``extract_patches_batch``.

    Args:
        patches (torch.Tensor): (N, C, patch_size, patch_size) patches, ordered
            image by image and row by row inside each image.
        batch_size (int or tuple): number of images; a tuple uses its first item.
            If ``patches`` only holds the patches of a single image, one image is
            reconstructed whatever ``batch_size`` says.
        output_size (int or tuple): (H, W) of each reconstructed image; both must
            be multiples of ``patch_size``.
        patch_size (int): side of the patches.

    Returns:
        torch.Tensor: images of shape (batch, C, H, W).

    Raises:
        ValueError: if ``output_size`` is smaller than or not divisible by
            ``patch_size``, or the number of patches does not match.
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size)  # Ensure it's a tuple

    if isinstance(batch_size, tuple):
        batch_size = batch_size[0]  # Extract integer batch size

    height, width = output_size
    if height < patch_size or width < patch_size:
        raise ValueError(f"Output size {output_size} must be larger than patch size {patch_size}")
    if height % patch_size or width % patch_size:
        raise ValueError(f"Output size {output_size} must be divisible by patch size {patch_size}")

    rows = height // patch_size
    cols = width // patch_size
    num_patches = rows * cols

    if patches.shape[0] == batch_size * num_patches:
        pass
    elif patches.shape[0] == num_patches:
        # Only one image's worth of patches: reconstruct exactly one image, so
        # the final view below uses the same batch size as this reshape.
        batch_size = 1
    else:
        raise ValueError(f"Unexpected number of patches: {patches.shape[0]}")

    grid = patches.view(batch_size, rows, cols, -1, patch_size, patch_size)
    # (B, rows, cols, C, p, p) -> (B, C, rows, p, cols, p) -> (B, C, H, W)
    grid = grid.permute(0, 3, 1, 4, 2, 5).contiguous()
    return grid.view(batch_size, -1, height, width)

def pad_trick(i, p):
    """
    Zero-pad the bottom and right of a (C, H, W) image so H and W become multiples of p.

    Args:
        i (torch.Tensor or np.ndarray): image of shape (C, H, W); numpy input is
            converted with ``torch.from_numpy``.
        p (int): patch size to pad to.

    Returns:
        torch.Tensor: padded image of shape (C, H + p_h, W + p_w), where
        p_h / p_w are the smallest non-negative pads making H / W divisible by p.
    """
    if isinstance(i, np.ndarray):
        i = torch.from_numpy(i)

    assert i.ndim == 3, f"Expected input shape (C, H, W), but got {i.shape}"

    height, width = i.shape[1], i.shape[2]
    extra_h = -height % p
    extra_w = -width % p

    # F.pad takes the last dimension first: (left, right, top, bottom).
    return F.pad(i, (0, extra_w, 0, extra_h), mode="constant", value=0)


def get_mask_patch_batch(cmt, patch_size=None):
    """Split a 2-D change map into zero-padded square patches.

    Args:
        cmt: change map of shape (H, W), numpy array or torch tensor.
        patch_size: patch side; defaults to the global PATCH_SIDE, read at call
            time so later reassignments of the global are honoured.

    Returns:
        (patches, orig_x, orig_y, padded_x, padded_y): patches of shape
        (num_patches, 1, patch_size, patch_size) in row-major order, the
        original height/width and the padded height/width.
    """
    if patch_size is None:
        patch_size = PATCH_SIDE
    mask = torch.as_tensor(cmt)
    orig_x, orig_y = mask.shape[0], mask.shape[1]
    # (H, W) -> (1, H, W) so pad_trick sees one channel, then add the batch axis.
    cm_batch = pad_trick(mask.unsqueeze(0), patch_size).unsqueeze(0)
    cm_patch = extract_patches_batch(cm_batch, patch_size)
    return cm_patch, orig_x, orig_y, cm_batch.shape[2], cm_batch.shape[3]

def get_image_patch_batch(cmt, patch_size=None):
    """Split a multi-band (C, H, W) image into zero-padded square patches.

    Args:
        cmt: image of shape (C, H, W), torch tensor or numpy array.
        patch_size: patch side; defaults to the global PATCH_SIDE, read at call
            time so later reassignments of the global are honoured.

    Returns:
        (patches, orig_x, orig_y, padded_x, padded_y): patches of shape
        (num_patches, C, patch_size, patch_size) in row-major order, the
        original height/width and the padded height/width.
    """
    if patch_size is None:
        patch_size = PATCH_SIDE
    orig_x, orig_y = cmt.shape[1], cmt.shape[2]
    cm_batch = pad_trick(cmt, patch_size).unsqueeze(0)
    cm_patch = extract_patches_batch(cm_batch, patch_size)
    return cm_patch, orig_x, orig_y, cm_batch.shape[2], cm_batch.shape[3]
    
def unseen_test_patch(dset, criterion, model, feat=False):
    """Evaluate `model` patch by patch on every image of `dset`.

    Each image pair and change map is zero-padded and cut into square patches
    by ``get_image_patch_batch`` / ``get_mask_patch_batch``. Patches are fed to
    the model one at a time. Before scoring, every prediction/label patch is
    cropped to the part lying inside the original image, so padding is never
    counted as (correctly predicted) "no change".

    Args:
        dset: object exposing ``names`` and ``get_img(name)`` returning
            ``(I1, I2, change_map)``.
        criterion: loss called as ``criterion(logits, target)`` with ``target``
            a ``(1, h, w)`` long tensor.
        model: network called as ``model(I1, I2)``; returns
            ``(output, features)`` when ``feat`` is True.
        feat: whether ``model`` returns an extra features value.

    Returns:
        dict with 'net_loss', 'net_accuracy', 'class_accuracy', 'precision',
        'recall', 'dice', 'kappa' and 'f_meas'. Ratios with a zero denominator
        are reported as 0.
    """
    n = 2  # number of classes
    model.eval()
    model = model.to(device)

    class_correct = [0] * n
    class_total = [0] * n
    tp = tn = fp = fn = 0
    tot_loss = 0.0
    tot_count = 0

    with torch.no_grad():  # evaluation only
        for img_index in tqdm(dset.names):
            I1_full, I2_full, cm_full = dset.get_img(img_index)
            cm_patch, cm_x, cm_y, _, cm_py = get_mask_patch_batch(cm_full)
            I1_patch = get_image_patch_batch(I1_full)[0]
            I2_patch = get_image_patch_batch(I2_full)[0]

            patch_size = cm_patch.shape[-1]
            patches_per_row = cm_py // patch_size

            for idx in range(cm_patch.shape[0]):
                # extract_patches_batch orders patches row-major, so patch idx
                # covers rows row*p.. and columns col*p.. of the padded image.
                row, col = divmod(idx, patches_per_row)
                h = min(patch_size, cm_x - row * patch_size)
                w = min(patch_size, cm_y - col * patch_size)

                I1 = I1_patch[idx].unsqueeze(0).float().to(device)
                I2 = I2_patch[idx].unsqueeze(0).float().to(device)
                if feat:
                    output, _ = model(I1, I2)
                else:
                    output = model(I1, I2)

                # Drop the zero padding added by pad_trick before scoring.
                output = output[:, :, :h, :w]
                target = cm_patch[idx][:, :h, :w].long().to(device)

                loss = criterion(output, target)
                num_pix = target.numel()
                tot_loss += float(loss) * num_pix
                tot_count += num_pix

                predicted = output.argmax(dim=1)
                correct = predicted == target
                for cls in range(n):
                    in_cls = target == cls
                    class_correct[cls] += int((correct & in_cls).sum())
                    class_total[cls] += int(in_cls.sum())

                pr = (predicted > 0).cpu().numpy()
                gt = (target > 0).cpu().numpy()
                tp += np.logical_and(pr, gt).sum()
                tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
                fp += np.logical_and(pr, np.logical_not(gt)).sum()
                fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = tot_loss / max(tot_count, 1)
    net_accuracy = 100.0 * (tp + tn) / max(tot_count, 1)
    class_accuracy = [100.0 * class_correct[c] / max(class_total[c], 1) for c in range(n)]

    prec = tp / max(tp + fp, 1e-9)
    rec = tp / max(tp + fn, 1e-9)
    f_meas = 2 * prec * rec / max(prec + rec, 1e-9)
    dice = f_meas  # same formula as the F-measure
    k = kappa(tp, tn, fp, fn)

    return {'net_loss': net_loss,
            'net_accuracy': net_accuracy,
            'class_accuracy': class_accuracy,
            'precision': prec,
            'recall': rec,
            'dice': dice,
            'kappa': k,
            'f_meas': f_meas}

def unseen_test_patch_batch(dset, criterion, model, feat=False, batch_size=BATCH_SIZE, patch_size=PATCH_SIDE, ood=False):
    """Evaluate `model` on whole images of `dset`, one image at a time.

    NOTE: this file defines ``unseen_test_patch_batch`` a second time further
    down (the OOD-transform "drop-in replacement"). When both run in the same
    namespace, the later definition replaces this one.

    Each image pair is zero-padded to a multiple of ``patch_size`` and cut into
    patches. The patches are run through the model ``batch_size`` at a time and
    stitched back together. The result is cropped to the original size before
    scoring, so padding never counts toward the metrics.

    Args:
        dset: object exposing ``names`` and ``get_img(name)`` (or
            ``get_ood_img(name)`` when ``ood`` is True) returning
            ``(I1, I2, change_map)``.
        criterion: loss called as ``criterion(logits, target)`` with ``target``
            a ``(1, H, W)`` long tensor.
        model: network called as ``model(I1, I2)``; returns
            ``(output, features)`` when ``feat`` is True.
        batch_size: number of patches per forward pass (non-positive/None means
            all patches of an image at once).
        patch_size: side of the square patches, used for tiling and stitching.
        ood: read images through ``dset.get_ood_img`` instead of ``dset.get_img``.

    Returns:
        dict with 'net_loss', 'net_accuracy', 'class_accuracy', 'precision',
        'recall', 'dice', 'kappa' and 'f_meas'. Ratios with a zero denominator
        are reported as 0.
    """
    n = 2  # number of classes
    model.eval()
    model = model.to(device)

    class_correct = [0] * n
    class_total = [0] * n
    tp = tn = fp = fn = 0
    tot_loss = 0.0
    tot_count = 0

    with torch.no_grad():  # evaluation only
        for img_index in tqdm(dset.names):
            if ood:
                I1_full, I2_full, cm_full = dset.get_ood_img(img_index)
            else:
                I1_full, I2_full, cm_full = dset.get_img(img_index)
            full_x, full_y = cm_full.shape
            target = torch.from_numpy(1.0 * cm_full).unsqueeze(0).long().to(device)

            # Tile with the same patch_size that is used to stitch the output back.
            I1 = extract_patches_batch(pad_trick(I1_full, patch_size).unsqueeze(0), patch_size).float().to(device)
            I2 = extract_patches_batch(pad_trick(I2_full, patch_size).unsqueeze(0), patch_size).float().to(device)

            n_patches = I1.shape[0]
            step = batch_size if batch_size and batch_size > 0 else max(n_patches, 1)
            pieces = []
            for start in range(0, n_patches, step):
                out = model(I1[start:start + step], I2[start:start + step])
                pieces.append(out[0] if feat else out)
            logits = torch.cat(pieces, dim=0)

            padded_size = (-(-full_x // patch_size) * patch_size,
                           -(-full_y // patch_size) * patch_size)
            output = reconstruct_image_batch(logits, 1, padded_size, patch_size)
            output = output[:, :, :full_x, :full_y]

            loss = criterion(output, target)
            num_pix = target.numel()
            tot_loss += float(loss) * num_pix
            tot_count += num_pix

            predicted = output.argmax(dim=1)
            correct = predicted == target
            for cls in range(n):
                in_cls = target == cls
                class_correct[cls] += int((correct & in_cls).sum())
                class_total[cls] += int(in_cls.sum())

            pr = (predicted > 0).cpu().numpy()
            gt = (target > 0).cpu().numpy()
            tp += np.logical_and(pr, gt).sum()
            tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
            fp += np.logical_and(pr, np.logical_not(gt)).sum()
            fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = tot_loss / max(tot_count, 1)
    net_accuracy = 100.0 * (tp + tn) / max(tot_count, 1)
    class_accuracy = [100.0 * class_correct[c] / max(class_total[c], 1) for c in range(n)]

    prec = tp / max(tp + fp, 1e-9)
    rec = tp / max(tp + fn, 1e-9)
    f_meas = 2 * prec * rec / max(prec + rec, 1e-9)
    dice = f_meas  # same formula as the F-measure
    k = kappa(tp, tn, fp, fn)

    return {'net_loss': net_loss,
            'net_accuracy': net_accuracy,
            'class_accuracy': class_accuracy,
            'precision': prec,
            'recall': rec,
            'dice': dice,
            'kappa': k,
            'f_meas': f_meas}



# --- helper: per-image, per-channel [0,1] scaling (optional) ---
class PerImageMinMax01(object):
    """Rescale sample['I1'] and sample['I2'] to [0, 1] independently per channel.

    The images are cast to float and reduced over dims (1, 2). The sample dict
    is updated in place and also returned; every other key is left untouched.
    """

    def __call__(self, sample):
        for key in ('I1', 'I2'):
            img = sample[key].float()
            lo = img.amin(dim=(1, 2), keepdim=True)
            span = img.amax(dim=(1, 2), keepdim=True) - lo
            sample[key] = (img - lo) / (span + 1e-6)  # epsilon guards flat channels
        return sample

# --- DROP-IN REPLACEMENT (minimal changes): this redefinition of unseen_test_patch_batch supersedes the one above ---
def unseen_test_patch_batch(
        dset, criterion, model, feat=False,
        batch_size=BATCH_SIZE, patch_size=PATCH_SIDE,
        ood=False,                      # keep your flag
        pair_transform=None,            # NEW: callable(sample)->sample for OOD
        apply_minmax=False              # NEW: normalize to [0,1] before OOD
    ):
    """Evaluate `model` on whole images of `dset`, optionally under an OOD transform.

    If ``ood=True`` and ``pair_transform`` is provided, the transform is applied
    to the full images (I1_full, I2_full) BEFORE patching. With
    ``apply_minmax=True`` the images are first rescaled per image and channel to
    [0, 1]. Labels are left unchanged. ``ood=True`` without a ``pair_transform``
    evaluates the unmodified images returned by ``dset.get_img``.

    Each image pair is zero-padded to a multiple of ``patch_size`` and cut into
    patches. The patches are run through the model ``batch_size`` at a time and
    stitched back together. The result is cropped to the original size before
    scoring, so padding never counts toward the metrics. Metrics are NaN-safe:
    ratios with a zero denominator are reported as 0, and an empty dataset
    yields zeros.

    Args:
        dset: object exposing ``names`` and ``get_img(name)`` returning
            ``(I1, I2, change_map)``.
        criterion: loss called as ``criterion(logits, target)`` with ``target``
            a ``(1, H, W)`` long tensor.
        model: network called as ``model(I1, I2)``; returns
            ``(output, features)`` when ``feat`` is True.
        batch_size: number of patches per forward pass (non-positive/None means
            all patches of an image at once).
        patch_size: side of the square patches, used for tiling and stitching.

    Returns:
        dict with 'net_loss', 'net_accuracy', 'class_accuracy', 'precision',
        'recall', 'dice', 'kappa' and 'f_meas'.
    """
    model.eval()
    model = model.to(device)

    n = 2
    class_correct = [0] * n
    class_total = [0] * n
    tp = tn = fp = fn = 0
    tot_loss = 0.0
    tot_count = 0

    mm = PerImageMinMax01() if apply_minmax else None

    with torch.no_grad():  # evaluation only
        for img_index in tqdm(dset.names):
            # --- load full images ---
            I1_full, I2_full, cm_full = dset.get_img(img_index)

            # --- optional minmax + pair-wise OOD transform (full-frame) ---
            if ood and pair_transform is not None:
                sample = {'I1': I1_full.clone(), 'I2': I2_full.clone(),
                          'label': torch.from_numpy(1.0 * cm_full).float()}
                if mm is not None:
                    sample = mm(sample)
                sample = pair_transform(sample)
                I1_full, I2_full = sample['I1'], sample['I2']  # keep cm_full unchanged

            full_x, full_y = cm_full.shape
            target = torch.from_numpy(1.0 * cm_full).unsqueeze(0).long().to(device)

            # --- patching with the same patch_size used for stitching ---
            I1 = extract_patches_batch(pad_trick(I1_full, patch_size).unsqueeze(0), patch_size).float().to(device)
            I2 = extract_patches_batch(pad_trick(I2_full, patch_size).unsqueeze(0), patch_size).float().to(device)

            # --- model forward, batch_size patches at a time ---
            n_patches = I1.shape[0]
            step = batch_size if batch_size and batch_size > 0 else max(n_patches, 1)
            pieces = []
            for start in range(0, n_patches, step):
                out = model(I1[start:start + step], I2[start:start + step])
                pieces.append(out[0] if feat else out)
            logits = torch.cat(pieces, dim=0)

            padded_size = (-(-full_x // patch_size) * patch_size,
                           -(-full_y // patch_size) * patch_size)
            output = reconstruct_image_batch(logits, 1, padded_size, patch_size)
            output = output[:, :, :full_x, :full_y]

            # --- loss accounting ---
            loss = criterion(output, target)
            num_pix = target.numel()
            tot_loss += float(loss) * num_pix
            tot_count += num_pix

            # --- predictions & per-class accuracy (vectorized, exact integer counts) ---
            predicted = output.argmax(dim=1)
            correct = predicted == target
            for cls in range(n):
                in_cls = target == cls
                class_correct[cls] += int((correct & in_cls).sum())
                class_total[cls] += int(in_cls.sum())

            # --- confusion counts ---
            pr = (predicted > 0).cpu().numpy()
            gt = (target > 0).cpu().numpy()
            tp += np.logical_and(pr, gt).sum()
            tn += np.logical_and(~pr, ~gt).sum()
            fp += np.logical_and(pr, ~gt).sum()
            fn += np.logical_and(~pr, gt).sum()

    # --- aggregate metrics (NaN-safe) ---
    net_loss = tot_loss / max(tot_count, 1)
    net_accuracy = 100.0 * (tp + tn) / max(tot_count, 1)
    class_accuracy = [100.0 * class_correct[c] / max(class_total[c], 1) for c in range(n)]

    prec   = tp / max(tp + fp, 1e-9)
    rec    = tp / max(tp + fn, 1e-9)
    f_meas = 2 * prec * rec / max(prec + rec, 1e-9)
    dice   = f_meas  # same formula here
    k      = kappa(tp, tn, fp, fn)

    return {
        'net_loss': net_loss,
        'net_accuracy': net_accuracy,
        'class_accuracy': class_accuracy,
        'precision': prec,
        'recall': rec,
        'dice': dice,
        'kappa': k,
        'f_meas': f_meas
    }





def unseen_test_patch_batch_ood(dset, criterion, model, feat=False, batch_size=BATCH_SIZE, patch_size=PATCH_SIDE, ood=False):
    """Evaluate `model` on the whole images of ``dset.dataset``.

    Like ``unseen_test_patch_batch``, but images are read from the wrapped
    ``dset.dataset`` (``get_ood_img`` when ``ood`` is True, else ``get_img``).
    NOTE(review): this iterates every name of ``dset.dataset``, not only the
    items selected by ``dset`` — confirm this is intended when ``dset`` is a subset.

    Each image pair is zero-padded to a multiple of ``patch_size`` and cut into
    patches. The patches are run through the model ``batch_size`` at a time and
    stitched back together. The result is cropped to the original size before
    scoring.

    Returns:
        dict with 'net_loss', 'net_accuracy', 'class_accuracy', 'precision',
        'recall', 'dice', 'kappa' and 'f_meas'. Ratios with a zero denominator
        are reported as 0.
    """
    n = 2  # number of classes
    model.eval()
    model = model.to(device)
    base = dset.dataset

    class_correct = [0] * n
    class_total = [0] * n
    tp = tn = fp = fn = 0
    tot_loss = 0.0
    tot_count = 0

    with torch.no_grad():  # evaluation only
        for img_index in tqdm(base.names):
            if ood:
                I1_full, I2_full, cm_full = base.get_ood_img(img_index)
            else:
                I1_full, I2_full, cm_full = base.get_img(img_index)
            full_x, full_y = cm_full.shape
            target = torch.from_numpy(1.0 * cm_full).unsqueeze(0).long().to(device)

            # Tile with the same patch_size that is used to stitch the output back.
            I1 = extract_patches_batch(pad_trick(I1_full, patch_size).unsqueeze(0), patch_size).float().to(device)
            I2 = extract_patches_batch(pad_trick(I2_full, patch_size).unsqueeze(0), patch_size).float().to(device)

            n_patches = I1.shape[0]
            step = batch_size if batch_size and batch_size > 0 else max(n_patches, 1)
            pieces = []
            for start in range(0, n_patches, step):
                out = model(I1[start:start + step], I2[start:start + step])
                pieces.append(out[0] if feat else out)
            logits = torch.cat(pieces, dim=0)

            padded_size = (-(-full_x // patch_size) * patch_size,
                           -(-full_y // patch_size) * patch_size)
            output = reconstruct_image_batch(logits, 1, padded_size, patch_size)
            output = output[:, :, :full_x, :full_y]

            loss = criterion(output, target)
            num_pix = target.numel()
            tot_loss += float(loss) * num_pix
            tot_count += num_pix

            predicted = output.argmax(dim=1)
            correct = predicted == target
            for cls in range(n):
                in_cls = target == cls
                class_correct[cls] += int((correct & in_cls).sum())
                class_total[cls] += int(in_cls.sum())

            pr = (predicted > 0).cpu().numpy()
            gt = (target > 0).cpu().numpy()
            tp += np.logical_and(pr, gt).sum()
            tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
            fp += np.logical_and(pr, np.logical_not(gt)).sum()
            fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = tot_loss / max(tot_count, 1)
    net_accuracy = 100.0 * (tp + tn) / max(tot_count, 1)
    class_accuracy = [100.0 * class_correct[c] / max(class_total[c], 1) for c in range(n)]

    prec = tp / max(tp + fp, 1e-9)
    rec = tp / max(tp + fn, 1e-9)
    f_meas = 2 * prec * rec / max(prec + rec, 1e-9)
    dice = f_meas  # same formula as the F-measure
    k = kappa(tp, tn, fp, fn)

    return {'net_loss': net_loss,
            'net_accuracy': net_accuracy,
            'class_accuracy': class_accuracy,
            'precision': prec,
            'recall': rec,
            'dice': dice,
            'kappa': k,
            'f_meas': f_meas}



    
def unseen_test_patch_batch_robust(dset, criterion, model, feat=False, batch_size=BATCH_SIZE, patch_size=PATCH_SIDE):
    """Evaluate `model` on whole images of ``dset`` via ``dset.get_img``.

    Each image pair is zero-padded to a multiple of ``patch_size`` and cut into
    patches. The patches are run through the model ``batch_size`` at a time and
    stitched back together. The result is cropped to the original size before
    scoring, so padding never counts toward the metrics.

    Args:
        dset: object exposing ``names`` and ``get_img(name)`` returning
            ``(I1, I2, change_map)``.
        criterion: loss called as ``criterion(logits, target)`` with ``target``
            a ``(1, H, W)`` long tensor.
        model: network called as ``model(I1, I2)``; returns
            ``(output, features)`` when ``feat`` is True.
        batch_size: number of patches per forward pass (non-positive/None means
            all patches of an image at once).
        patch_size: side of the square patches, used for tiling and stitching.

    Returns:
        dict with 'net_loss', 'net_accuracy', 'class_accuracy', 'precision',
        'recall', 'dice', 'kappa' and 'f_meas'. Ratios with a zero denominator
        are reported as 0.
    """
    n = 2  # number of classes
    model.eval()
    model = model.to(device)

    class_correct = [0] * n
    class_total = [0] * n
    tp = tn = fp = fn = 0
    tot_loss = 0.0
    tot_count = 0

    with torch.no_grad():  # evaluation only
        for img_index in tqdm(dset.names):
            I1_full, I2_full, cm_full = dset.get_img(img_index)
            full_x, full_y = cm_full.shape
            target = torch.from_numpy(1.0 * cm_full).unsqueeze(0).long().to(device)

            # Tile with the same patch_size that is used to stitch the output back.
            I1 = extract_patches_batch(pad_trick(I1_full, patch_size).unsqueeze(0), patch_size).float().to(device)
            I2 = extract_patches_batch(pad_trick(I2_full, patch_size).unsqueeze(0), patch_size).float().to(device)

            n_patches = I1.shape[0]
            step = batch_size if batch_size and batch_size > 0 else max(n_patches, 1)
            pieces = []
            for start in range(0, n_patches, step):
                out = model(I1[start:start + step], I2[start:start + step])
                pieces.append(out[0] if feat else out)
            logits = torch.cat(pieces, dim=0)

            padded_size = (-(-full_x // patch_size) * patch_size,
                           -(-full_y // patch_size) * patch_size)
            output = reconstruct_image_batch(logits, 1, padded_size, patch_size)
            output = output[:, :, :full_x, :full_y]

            loss = criterion(output, target)
            num_pix = target.numel()
            tot_loss += float(loss) * num_pix
            tot_count += num_pix

            predicted = output.argmax(dim=1)
            correct = predicted == target
            for cls in range(n):
                in_cls = target == cls
                class_correct[cls] += int((correct & in_cls).sum())
                class_total[cls] += int(in_cls.sum())

            pr = (predicted > 0).cpu().numpy()
            gt = (target > 0).cpu().numpy()
            tp += np.logical_and(pr, gt).sum()
            tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()
            fp += np.logical_and(pr, np.logical_not(gt)).sum()
            fn += np.logical_and(np.logical_not(pr), gt).sum()

    net_loss = tot_loss / max(tot_count, 1)
    net_accuracy = 100.0 * (tp + tn) / max(tot_count, 1)
    class_accuracy = [100.0 * class_correct[c] / max(class_total[c], 1) for c in range(n)]

    prec = tp / max(tp + fp, 1e-9)
    rec = tp / max(tp + fn, 1e-9)
    f_meas = 2 * prec * rec / max(prec + rec, 1e-9)
    dice = f_meas  # same formula as the F-measure
    k = kappa(tp, tn, fp, fn)

    return {'net_loss': net_loss,
            'net_accuracy': net_accuracy,
            'class_accuracy': class_accuracy,
            'precision': prec,
            'recall': rec,
            'dice': dice,
            'kappa': k,
            'f_meas': f_meas}

    
def save_test_results(model, dset, net_name, feat=False, batch_size=BATCH_SIZE, patch_size=PATCH_SIDE):
    """Predict every image of `dset` and save a colour-coded comparison PNG.

    For each image name the file ``f'{net_name}-{name}.png'`` is written. The
    ground-truth change map goes in the red and blue channels and the
    prediction in the green channel. For 0/1 maps this gives: white = both
    change, magenta = missed change, green = false alarm, black = both
    no-change.

    Args:
        model: network called as ``model(I1, I2)``; returns
            ``(output, features)`` when ``feat`` is True. It is switched to eval mode.
        dset: object exposing ``names`` and ``get_img(name)`` returning
            ``(I1, I2, change_map)``.
        net_name: prefix of the output file names.
        batch_size: number of patches per forward pass (non-positive/None means
            all patches of an image at once).
        patch_size: side of the square patches, used for tiling and stitching.
    """
    model.eval()
    model = model.to(device)

    with torch.no_grad():  # inference only
        for name in tqdm(dset.names):
            I1_full, I2_full, cm_full = dset.get_img(name)
            full_x, full_y = cm_full.shape

            # Tile with the same patch_size that is used to stitch the output back.
            I1 = extract_patches_batch(pad_trick(I1_full, patch_size).unsqueeze(0), patch_size).float().to(device)
            I2 = extract_patches_batch(pad_trick(I2_full, patch_size).unsqueeze(0), patch_size).float().to(device)

            n_patches = I1.shape[0]
            step = batch_size if batch_size and batch_size > 0 else max(n_patches, 1)
            pieces = []
            for start in range(0, n_patches, step):
                out = model(I1[start:start + step], I2[start:start + step])
                pieces.append(out[0] if feat else out)
            logits = torch.cat(pieces, dim=0)

            padded_size = (-(-full_x // patch_size) * patch_size,
                           -(-full_y // patch_size) * patch_size)
            output = reconstruct_image_batch(logits, 1, padded_size, patch_size)
            output = output[:, :, :full_x, :full_y]
            predicted = output.argmax(dim=1).squeeze(0).cpu().numpy()

            gt = 1.0 * np.asarray(cm_full)
            I = np.stack((gt, predicted, gt), axis=2)
            im = (I * 255).astype(np.uint8)
            with warnings.catch_warnings():
                # catch_warnings alone changes nothing; the filter is what silences
                # scikit-image's "low contrast image" warning (e.g. empty change maps).
                warnings.filterwarnings("ignore", message=".*low contrast image.*")
                io.imsave(f'{net_name}-{name}.png', im)
# Marker output so the notebook shows that this cell's helper definitions ran.
print("Training & testing Functions Defined")


# ================== bounded_ood_transforms.py ==================
import math, random
import numpy as np
import torch
import torch.nn.functional as F
from typing import Optional

# -------- basic helpers --------
def _clamp01(x: torch.Tensor) -> torch.Tensor:
    return x.clamp_(0.0, 1.0)

def _depthwise_box_blur(x: torch.Tensor, k: int):
    """x: (C,H,W). Depthwise box blur with odd kernel size."""
    C, H, W = x.shape
    k = int(k) | 1
    pad = k // 2
    filt = torch.ones(C, 1, k, k, device=x.device, dtype=x.dtype) / (k * k)
    y = F.pad(x.unsqueeze(0), (pad, pad, pad, pad), mode='reflect')
    y = F.conv2d(y, filt, groups=C).squeeze(0)
    return y

def _project_linf(delta: torch.Tensor, eps: float) -> torch.Tensor:
    """Scale delta so that ‖delta‖∞ ≤ eps (per-sample, per-channel)."""
    if delta.dim() == 3:  # CHW
        scale = eps / (delta.abs().amax(dim=(1,2), keepdim=True) + 1e-8)
    else:                 # NCHW
        scale = eps / (delta.abs().amax(dim=(2,3), keepdim=True) + 1e-8)
    scale = torch.clamp(scale, max=1.0)
    return delta * scale

class PerImageMinMax01:
    """Map I1/I2 of a sample dict to [0, 1], per image and per channel.

    Both images are cast to float and rescaled using their min/max over dims
    (1, 2). The dict is modified in place and returned; 'label' and any other
    keys are not touched.
    """
    def __call__(self, sample):
        for name in ('I1', 'I2'):
            arr = sample[name].float()
            floor_ = arr.amin(dim=(1, 2), keepdim=True)
            ceil_ = arr.amax(dim=(1, 2), keepdim=True)
            sample[name] = (arr - floor_) / (ceil_ - floor_ + 1e-6)
        return sample

# -------- single-image proposals (CHW, in [0,1]) --------
def _prop_lowfreq(x: torch.Tensor, eps: float, seed: int, kernel: int) -> torch.Tensor:
    """Add seeded, box-blurred Gaussian noise bounded by `eps`, then clamp to [0, 1].

    The noise is drawn inside ``torch.random.fork_rng`` so the global RNG state
    is left untouched; the same `seed` yields the same noise for the same shape.
    """
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        noise = torch.randn_like(x)
    smooth = _depthwise_box_blur(noise, kernel)
    return _clamp01(x + _project_linf(smooth, eps))

def _prop_per_band(x: torch.Tensor, eps: float, seed: int) -> torch.Tensor:
    # per-channel DC bias in [-eps, eps]
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        C = x.shape[0]
        bias = (2*torch.rand(C,1,1, device=x.device, dtype=x.dtype) - 1.0) * eps
    y = x + bias
    return _clamp01(y)

def _prop_blur(x: torch.Tensor, eps: float, kernel: int) -> torch.Tensor:
    """Move `x` towards its box-blurred version by at most `eps` per channel, then clamp to [0, 1]."""
    residual = _depthwise_box_blur(x, kernel) - x
    bounded = _project_linf(residual, eps)
    return _clamp01(x + bounded)

def _prop_shadow(x: torch.Tensor, eps: float, seed: int, kernel: int, alpha: float=0.2) -> torch.Tensor:
    """Darken `x` with a seeded, blurred random mask, bounded by `eps`, then clamp to [0, 1].

    The darkening ``x * (1 - alpha*mask)`` is converted to an additive residual
    and projected so its per-channel max magnitude is at most `eps`.
    """
    height, width = x.shape[1], x.shape[2]
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        shade = torch.rand(1, height, width, device=x.device, dtype=x.dtype)
    shade = _depthwise_box_blur(shade, kernel)
    darkened = x * (1 - alpha*shade)
    bounded = _project_linf(darkened - x, eps)
    return _clamp01(x + bounded)

# -------- pair-aware ε-bounded OOD transform (operates on sample dict) --------
class BoundedPairOOD:
    """
    Apply a structured, ε-bounded OOD perturbation to (I1, I2) with CHW layout.

    kind: 'lowfreq' | 'per_band' | 'blur' | 'shadow'
      - 'lowfreq' : box-blurred Gaussian noise added to the image
      - 'per_band': one random constant offset per channel, drawn in [-eps, eps]
      - 'blur'    : residual towards a box-blurred copy of the image
      - 'shadow'  : multiplicative darkening by a blurred random mask
    All residuals are bounded by `eps` per channel in L∞ norm.

    Pairing is chosen with either:
      - mode in {'shared','indep','anti'}, or
      - correlated=True/False (alias: True→'shared', False→'indep').
    'shared' uses the same seed for both dates, 'indep' a different seed for
    date 2, and 'anti' subtracts date 1's residual from date 2.

    Every proposal already clamps its result to [0, 1]; `assume01=True` only
    adds a final clamp on top. Inputs are presumably expected in [0, 1] (e.g.
    after per-image min-max scaling).

    One seed is drawn per call from a RandomState seeded with `seed`. A given
    instance therefore produces a reproducible sequence of perturbations, but
    two calls on the same sample get different perturbations.
    """
    def __init__(self, *,
                 kind: str="lowfreq",
                 eps: float=1/255,
                 correlated: Optional[bool]=None,
                 mode: Optional[str]=None,
                 lowfreq_kernel: int=9,
                 blur_kernel: int=5,
                 seed: int=123,
                 assume01: bool=True):
        self.kind = str(kind)
        if self.kind not in ("lowfreq", "per_band", "blur", "shadow"):
            # Fail at construction time instead of on the first sample.
            raise ValueError(f"Unknown kind '{self.kind}'")
        self.eps  = float(eps)
        self.lowfreq_kernel = int(lowfreq_kernel)
        self.blur_kernel    = int(blur_kernel)
        self.assume01       = bool(assume01)
        self._np = np.random.RandomState(int(seed))

        # alias handling
        if mode is not None and correlated is not None:
            raise ValueError("Specify either 'mode' or 'correlated', not both.")
        if mode is None and correlated is None:
            mode = "indep"
        if correlated is not None:
            mode = "shared" if correlated else "indep"
        if mode not in ("shared","indep","anti"):
            raise ValueError("mode must be one of {'shared','indep','anti'}")
        self.mode = mode

    def _prop(self, x: torch.Tensor, eps: float, seed: int):
        """Dispatch to the single-image proposal selected by ``self.kind``."""
        if self.kind == "lowfreq":
            return _prop_lowfreq(x, eps, seed, self.lowfreq_kernel)
        if self.kind == "per_band":
            return _prop_per_band(x, eps, seed)
        if self.kind == "blur":
            return _prop_blur(x, eps, self.blur_kernel)
        if self.kind == "shadow":
            return _prop_shadow(x, eps, seed, self.lowfreq_kernel)
        raise ValueError(f"Unknown kind '{self.kind}'")

    def __call__(self, sample):
        """Return a new sample dict with perturbed 'I1'/'I2' and the original 'label'."""
        I1, I2, y = sample['I1'].float(), sample['I2'].float(), sample['label']
        s = int(self._np.randint(0, 2**31-1))

        Y1 = self._prop(I1, self.eps, s)
        if self.mode == "anti":
            # Date 2 receives date 1's residual with the opposite sign.
            Y2 = _clamp01(I2 - _project_linf(Y1 - I1, self.eps))
        else:
            # 'shared' reuses date 1's seed; 'indep' offsets it.
            s2 = s if self.mode == "shared" else (s + 1337)
            Y2 = self._prop(I2, self.eps, s2)

        if self.assume01:
            Y1 = _clamp01(Y1)
            Y2 = _clamp01(Y2)

        return {'I1': Y1, 'I2': Y2, 'label': y}

# -------- wrapper that makes get_img() path see OOD too --------
class OODDatasetView(torch.utils.data.Dataset):
    """
    Dataset view that applies a pair-aware transform on both access paths.

    Wraps a base ChangeDetectionDataset and applies ``pair_transform`` in
    ``__getitem__`` AND ``get_img``, so code that reads either path sees OOD.

    Args:
        base_ds: wrapped dataset. Must expose ``names``, ``n_patches_per_image``,
            ``n_patches``, ``patch_coords`` (mirrored here), ``imgs_1``,
            ``imgs_2``, ``change_maps`` (read by ``get_img``), ``len()`` and
            indexing (used by ``__len__``/``__getitem__``).
        pair_transform: callable mapping a dict with keys 'I1', 'I2', 'label'
            to a dict with the same keys, or ``None`` for no transform.
        apply_minmax: when True, images are min-max rescaled to [0, 1) per
            channel BEFORE the transform, so ε is measured on that scale.
    """
    def __init__(self, base_ds, pair_transform, apply_minmax=True):
        self.base = base_ds
        self.transform = pair_transform
        self.apply_minmax = apply_minmax
        # mirror the public fields the test code expects
        self.names = list(base_ds.names)
        self.n_patches_per_image = dict(base_ds.n_patches_per_image)
        self.n_patches = base_ds.n_patches
        self.patch_coords = list(base_ds.patch_coords)

    def __len__(self):
        return len(self.base)

    def _minmax01(self, x: torch.Tensor) -> torch.Tensor:
        """Rescale each slice along dim 0 of ``x`` to [0, 1) (min over dims 1 and 2)."""
        xmin = x.amin(dim=(1, 2), keepdim=True)
        xmax = x.amax(dim=(1, 2), keepdim=True)
        # +1e-6 avoids division by zero on constant channels
        return (x - xmin) / (xmax - xmin + 1e-6)

    def _apply_pair(self, I1, I2, cm):
        """Normalise (optionally) and transform one full image pair.

        Returns ``(I1, I2, cm)`` where ``cm`` is the change map exactly as
        stored in the base dataset (the transform's label output is ignored).
        """
        if self.apply_minmax:
            I1 = self._minmax01(I1)
            I2 = self._minmax01(I2)
        # Mirror __getitem__: a None transform means "no transform", not a crash.
        if self.transform is None:
            return I1, I2, cm
        label = cm if torch.is_tensor(cm) else torch.from_numpy(cm.astype('float32'))
        sample = self.transform({'I1': I1, 'I2': I2, 'label': label})
        return sample['I1'], sample['I2'], cm

    # used by your DataLoader path (kept for completeness)
    def __getitem__(self, idx):
        s = self.base.__getitem__(idx)
        sample = {'I1': s['I1'].float(), 'I2': s['I2'].float(), 'label': s['label']}
        # IMPORTANT: apply minmax BEFORE the pair transform if you want ε in [0,1]
        # NOTE(review): this path uses PerImageMinMax01 while get_img uses
        # _minmax01; confirm the two normalise identically.
        if self.apply_minmax:
            sample = PerImageMinMax01()(sample)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    # CRITICAL: your unseen_test_patch_batch uses this path
    def get_img(self, im_name):
        """Return the (optionally normalised and transformed) full pair for ``im_name``."""
        I1 = self.base.imgs_1[im_name]
        I2 = self.base.imgs_2[im_name]
        cm = self.base.change_maps[im_name]
        return self._apply_pair(I1, I2, cm)
# ================== end of file ==================


# Train dataset
# ==== TRAIN DATASET + LOADER (clean) ======================================
# Assumes these are already defined earlier in the notebook:
# - ChangeDetectionDataset
# - RandomFlip, RandomRot
# - device, PATH_TO_DATASET, TRAIN_STRIDE, BATCH_SIZE, SEED
# - DATA_AUG (bool)
# (seed_worker and the generator g are defined right below.)
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy and ``random`` from torch's per-worker seed.

    Inside a DataLoader worker, ``torch.initial_seed()`` is that worker's seed,
    so deriving the other RNGs from it makes worker-side augmentation
    reproducible. ``worker_id`` is unused.
    """
    worker_seed = torch.initial_seed() % 2**32
    # The notebook imports NumPy as ``np``; the bare name ``numpy`` is undefined.
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Seeded generator that fixes the DataLoader shuffle order.
# NOTE(review): the same `g` is passed to both train_loader and
# train_loader_ood below, so they draw from (and advance) one shared RNG
# stream; iterating one changes the order later produced by the other.
g = torch.Generator()
g.manual_seed(SEED)
# Clean training transform: random flip + rotation when DATA_AUG, else none.
# (PerImageMinMax01 could be appended here if [0,1] inputs are wanted.)
if DATA_AUG:
    train_tf = tr.Compose([RandomFlip(), RandomRot()])
else:
    train_tf = None

train_dataset = ChangeDetectionDataset(
    PATH_TO_DATASET, train=True, stride=TRAIN_STRIDE, transform=train_tf
)

# Class-imbalance weights exposed by the dataset (`.weights`), moved to `device`;
# consumed by the weighted CrossEntropyLoss at the bottom of this cell.
dataset_weights = torch.FloatTensor(train_dataset.weights).to(device)
print("Class weights:", dataset_weights)

# No num_workers is passed, so loading runs in the main process and
# seed_worker (a per-worker init hook) is not invoked.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=g
)

print("Train dataloader ready")

# ==== OPTIONAL: OOD-robust training loader (ε-bounded perturbations) ======
# Builds a SECOND ChangeDetectionDataset (the data is read again) whose
# transform adds a low-frequency ε = 1/255 perturbation after flip/rotate.
USE_OOD_TRAIN = True  # set False to skip building the OOD-robust loader
if USE_OOD_TRAIN:
    robust_train_tf = tr.Compose([
        RandomFlip(), RandomRot(),
        PerImageMinMax01(),  # presumably per-image rescale to [0,1] — ε is defined on that scale
        # correlated=True selects mode "shared": both dates use the same seed.
        # assume01=True clamps the perturbed images to [0,1].
        BoundedPairOOD(kind="lowfreq", eps=1/255,
                       correlated=True, lowfreq_kernel=9,
                       seed=2025, assume01=True)
    ])
    train_dataset_ood = ChangeDetectionDataset(
        PATH_TO_DATASET, train=True, stride=TRAIN_STRIDE, transform=robust_train_tf
    )
    train_loader_ood = DataLoader(
        train_dataset_ood,
        batch_size=BATCH_SIZE,
        shuffle=True,
        worker_init_fn=seed_worker,
        generator=g
    )
    print("OOD-robust train dataloader ready")

# ==== LOSS (example) ======================================================
# Class-weighted cross-entropy using the training-set weights computed above.
loss_fn = nn.CrossEntropyLoss(weight=dataset_weights)


# 1) Identity transform that just passes (I1,I2,label) through.
class IdentityPair:
    """No-op pair transform: hands back the very same sample object it receives.

    Used for the clean OODDatasetView so clean and OOD views share one code path.
    """

    def __call__(self, sample):
        unchanged = sample
        return unchanged

# 2) Build the base clean dataset exactly as your test API expects (no transform here)
# Clean test set, no transform; every view below wraps this same object.
test_dataset = ChangeDetectionDataset(PATH_TO_DATASET, train=False, stride=TRAIN_STRIDE, transform=None)

# 3) Wrap it in views so BOTH clean and OOD go through the same normalization path
#    (set apply_minmax=True to consistently bring inputs to [0,1] per image).
clean_view = OODDatasetView(
    base_ds=test_dataset,
    pair_transform=IdentityPair(),
    apply_minmax=True        # <— SAME normalization as OOD views
)

# One OOD view per perturbation family; the third positional arg is apply_minmax.
# correlated=True -> BoundedPairOOD mode "shared" (same seed for both dates);
# correlated=False -> mode "indep" (date-2 seed offset by 1337).
# The "blur" kind does not receive a seed (see BoundedPairOOD._prop), so its
# seed only affects the RandomState draws, not the blur itself.
ood_views = {
    "lf_1": OODDatasetView(test_dataset, BoundedPairOOD(kind="lowfreq", eps=1/255, correlated=True,  lowfreq_kernel=9, seed=123), True),
    "lf_2": OODDatasetView(test_dataset, BoundedPairOOD(kind="lowfreq", eps=2/255, correlated=True,  lowfreq_kernel=9, seed=123), True),
    "shadow": OODDatasetView(test_dataset, BoundedPairOOD(kind="shadow",  eps=1/255, correlated=False, lowfreq_kernel=9, seed=7),  True),
    "pband":  OODDatasetView(test_dataset, BoundedPairOOD(kind="per_band", eps=1/255, correlated=True,                 seed=9),  True),
    "blur_1": OODDatasetView(test_dataset, BoundedPairOOD(kind="blur",    eps=1/255, correlated=False, blur_kernel=5, seed=11), True),
}

# LiRPA-friendly encoder-decoder
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages; H and W are preserved."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in->out channels, second stage out->out.
        for c_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(c_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order kept so state_dict keys are unchanged.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)
 
class Down(nn.Module):
    """Encoder stage: strided 3x3 conv (halves H, W) then a 3x3 conv, each with BN + ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for c_in, stride in ((in_channels, 2), (out_channels, 1)):
            stages.extend((
                nn.Conv2d(c_in, out_channels, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ))
        # Same attribute name / layer order as before, so checkpoints still load.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
 
class Up(nn.Module):
    """Decoder stage: 2x nearest upsample + 3x3 conv, concat with skip, DoubleConv.

    The concatenated tensor has (skip channels + out_channels) channels and is
    fed to DoubleConv(in_channels, ...), so this assumes the skip tensor carries
    in_channels - out_channels channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        if upsampled.shape[-2:] != x2.shape[-2:]:
            dy = x2.size(2) - upsampled.size(2)
            dx = x2.size(3) - upsampled.size(3)
            # F.pad order is (left, right, top, bottom); negative amounts crop.
            upsampled = F.pad(upsampled, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        return self.conv(torch.cat([x2, upsampled], dim=1))
 
class OutConv(nn.Module):
    """Pointwise (1x1) convolution projecting features to output logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))

    def forward(self, x):
        projected = self.conv(x)
        return projected
 
class EncDec_LiRPA(nn.Module):
    """Plain U-Net style encoder/decoder over the channel-concatenated image pair.

    Widths 8 -> 16 -> 32 -> 64 -> 128 down, mirrored on the way up, with skip
    connections from each encoder stage.
    """

    def __init__(self, in_channels=26, out_channels=2):
        super().__init__()
        self.inc = DoubleConv(in_channels, 8)
        # Creation (and registration) order matches the layer order below.
        self.down1, self.down2, self.down3, self.down4 = (
            Down(8, 16), Down(16, 32), Down(32, 64), Down(64, 128),
        )
        self.up1, self.up2, self.up3, self.up4 = (
            Up(128, 64), Up(64, 32), Up(32, 16), Up(16, 8),
        )
        self.outc = OutConv(8, out_channels)

    def forward(self, x1, x2):
        e1 = self.inc(torch.cat((x1, x2), 1))
        e2 = self.down1(e1)
        e3 = self.down2(e2)
        e4 = self.down3(e3)
        decoded = self.down4(e4)
        for up, skip in ((self.up1, e4), (self.up2, e3), (self.up3, e2), (self.up4, e1)):
            decoded = up(decoded, skip)
        return self.outc(decoded)

# LiRPA-friendly local-attention encoder-decoder (FALCONet architecture)
class AttentionCondenser(nn.Module):
    """Lightweight attention using convolutions instead of full self-attention.

    Computes a per-element gate sigmoid(dropout(conv2(conv1(x)))) with two 1x1
    convs (d -> d // reduction -> d) and returns ``x * gate``. There is no
    nonlinearity between the two convs.
    """

    def __init__(self, d, reduction=8, dropout=0.1):
        super().__init__()
        squeezed = d // reduction
        self.conv1 = nn.Conv2d(d, squeezed, kernel_size=1)
        self.conv2 = nn.Conv2d(squeezed, d, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout2d(p=dropout)  # applied to the gate logits

    def forward(self, x):
        gate_logits = self.dropout(self.conv2(self.conv1(x)))
        return x * self.sigmoid(gate_logits)


class ConvAttention(nn.Module):
    """Sigmoid-gated convolutional token mixer for (batch, seq_len, hidden_dim) input.

    Channels-first internally: depthwise 1-D conv along the sequence, then
    q/k projections (hidden_dim // reduction_ratio wide) and a v projection,
    gate = sigmoid(score_proj(q * k)), output = out_proj(gate * v) + input.
    No seq_len x seq_len attention matrix is formed.
    """

    def __init__(self, hidden_dim, kernel_size=3, reduction_ratio=8):
        super().__init__()
        squeezed = hidden_dim // reduction_ratio

        def pointwise(c_in, c_out):
            return nn.Conv1d(c_in, c_out, 1, bias=False)

        # One filter per channel (groups=hidden_dim): local context along the sequence.
        self.depthwise_conv = nn.Conv1d(
            hidden_dim, hidden_dim, kernel_size,
            padding=kernel_size // 2, groups=hidden_dim, bias=False,
        )
        self.q_proj = pointwise(hidden_dim, squeezed)
        self.k_proj = pointwise(hidden_dim, squeezed)
        self.v_proj = pointwise(hidden_dim, hidden_dim)
        self.score_proj = pointwise(squeezed, hidden_dim)
        self.out_proj = pointwise(hidden_dim, hidden_dim)

    def forward(self, x):
        seq = x.transpose(1, 2)  # -> (batch, hidden_dim, seq_len)
        local = self.depthwise_conv(seq)
        gate = torch.sigmoid(self.score_proj(self.q_proj(local) * self.k_proj(local)))
        mixed = self.out_proj(gate * self.v_proj(local))
        # Residual, then back to (batch, seq_len, hidden_dim).
        return (mixed + seq).transpose(1, 2)


class MultiHeadConvAttention(nn.Module):
    """Split the channel axis into ``num_heads`` groups, mix each with its own
    ConvAttention, concatenate the heads and apply a bias-free linear projection.

    Input/output: (batch, seq_len, hidden_dim).
    """

    def __init__(self, hidden_dim, num_heads=2, kernel_size=3):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        assert hidden_dim % num_heads == 0, \
            f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"

        self.heads = nn.ModuleList(
            [ConvAttention(self.head_dim, kernel_size) for _ in range(num_heads)]
        )
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x):
        _, _, embed_dim = x.shape

        assert embed_dim == self.num_heads * self.head_dim, \
            f"Expected embedding dim {embed_dim} to match {self.num_heads * self.head_dim} (={self.num_heads * self.head_dim})"

        # Contiguous channel chunks, head i gets channels [i*head_dim, (i+1)*head_dim).
        chunks = x.split(self.head_dim, dim=-1)
        merged = torch.cat([head(chunk) for head, chunk in zip(self.heads, chunks)], dim=-1)
        return self.out_proj(merged)

class FALCONetMHA_LiRPA(nn.Module):
    """FALCONetMHA_LiRPA segmentation network.

    U-Net-style encoder/decoder applied to ``torch.cat((x1, x2), 1)``. Channel
    widths go 8 -> 16 -> 32 -> 64 -> 128 down and are mirrored on the way up.
    When ``attention`` is True, a MultiHeadConvAttention token mixer (over the
    H*W positions) is applied to the outputs of down2, down3 and down4; the
    mixed maps are also what the decoder receives as skips / bottleneck.

    Args:
        in_channels: channels of the concatenated input pair.
        out_channels: number of output logit channels.
        dropout: stored as ``self.dropout``; not used by the layers built here.
        reduction: stored as ``self.reduction``; not used by the layers built here.
        attention: build and apply the three ``token_mixer_*`` modules.
        num_heads: heads per token mixer; must divide 32, 64 and 128.
    """

    def __init__(self, in_channels=26, out_channels=2, dropout=0.1, reduction=8, attention=True, num_heads=4):
        """Init FALCONetMHA_LiRPA fields."""
        super(FALCONetMHA_LiRPA, self).__init__()
        self.dropout = dropout
        self.reduction = reduction
        self.attention = attention
        # Creation order matters: with a fixed torch seed it determines the
        # initial weights, so submodules are built strictly in layer order.
        cur_depth = 8
        self.inc = DoubleConv(in_channels, cur_depth)
        self.down1 = Down(cur_depth, cur_depth*2)
        cur_depth *= 2

        self.down2 = Down(cur_depth, cur_depth*2)
        cur_depth *= 2
        if self.attention:
            self.token_mixer_2 = MultiHeadConvAttention(cur_depth, num_heads=num_heads)
        self.down3 = Down(cur_depth, cur_depth*2)
        cur_depth *= 2
        if self.attention:
            self.token_mixer_3 = MultiHeadConvAttention(cur_depth, num_heads=num_heads)
        self.down4 = Down(cur_depth, cur_depth*2)
        cur_depth *= 2
        if self.attention:
            self.token_mixer_4 = MultiHeadConvAttention(cur_depth, num_heads=num_heads)
        self.up1 = Up(cur_depth, cur_depth // 2)
        cur_depth //= 2
        self.up2 = Up(cur_depth, cur_depth // 2)
        cur_depth //= 2
        self.up3 = Up(cur_depth, cur_depth // 2)
        cur_depth //= 2
        self.up4 = Up(cur_depth, cur_depth // 2)
        cur_depth //= 2
        self.outc = OutConv(cur_depth, out_channels)

    @staticmethod
    def _mix_tokens(x, mixer):
        """Apply a (B, N, C) token mixer to a (B, C, H, W) map; returns (B, C, H, W)."""
        B, C, H, W = x.shape
        tokens = x.reshape(B, C, H * W).transpose(1, 2)  # (B, N, C), N = H*W
        tokens = mixer(tokens)
        return tokens.transpose(1, 2).reshape(B, C, H, W)

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), 1)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        if self.attention:
            x3 = self._mix_tokens(x3, self.token_mixer_2)
        x4 = self.down3(x3)
        if self.attention:
            x4 = self._mix_tokens(x4, self.token_mixer_3)
        x5 = self.down4(x4)
        if self.attention:
            x5 = self._mix_tokens(x5, self.token_mixer_4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits



IMPORTS OK
DEFINITIONS OK
UTILS OK
Training & testing Functions Defined


100%|███████████████████████████████████████████████████████████████████████████████████| 24/24 [00:22<00:00,  1.08it/s]


Class weights: tensor([0.6437, 1.9356])
Train dataloader ready


100%|███████████████████████████████████████████████████████████████████████████████████| 24/24 [00:20<00:00,  1.18it/s]


OOD-robust train dataloader ready


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.91it/s]


In [2]:
# ## Imports for training

# %%
import os, math, time, random, json
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

# For tables/CSV
import pandas as pd

# Try optional auto_LiRPA
try:
    from auto_LiRPA import BoundedModule, BoundedTensor
    from auto_LiRPA.perturbations import PerturbationLpNorm
    _HAS_LIRPA = True
except Exception:
    _HAS_LIRPA = False


# %% [markdown]
# ## Utilities: device, clipping, simple Gaussian blur

# %%
def get_device():
    """Return ``cuda`` when torch reports an available GPU, otherwise ``cpu``."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

# Rebinds the notebook-wide `device` global for the cells that follow.
device = get_device()

def clip_01(x):
    """Return ``x`` clamped to [0, 1] as a new tensor.

    Deliberately out-of-place: an in-place ``clamp_`` would mutate the
    caller's tensor and raises on leaf tensors that require grad.
    """
    return x.clamp(0.0, 1.0)

def gaussian_blur2d(x, sigma=1.2, kernel_size=5):
    """Depthwise Gaussian blur on BCHW; output keeps the input's H and W.

    Args:
        x: tensor indexed (B, C, H, W); every channel is blurred independently
            with the same kernel.
        sigma: Gaussian std in pixels. ``sigma <= 0`` returns ``x`` unchanged.
        kernel_size: requested kernel width; values < 3 return ``x`` unchanged.
            Even values are rounded up to the odd width 2 * (k // 2) + 1 so the
            kernel is centred and the spatial size is preserved.

    Returns:
        The blurred tensor, or ``x`` itself when the blur is disabled.

    Borders use zero padding, so values near the image edge are pulled towards 0.
    """
    if sigma <= 0 or kernel_size < 3:
        return x
    half = kernel_size // 2
    size = 2 * half + 1  # actual number of taps in the sample grid below
    t = torch.arange(-half, half + 1, device=x.device, dtype=x.dtype)
    g1 = torch.exp(-(t ** 2) / (2 * sigma ** 2))
    g1 = g1 / g1.sum()
    g2 = (g1[:, None] * g1[None, :])[None, None]  # 1x1xSxS separable outer product
    c = x.shape[1]
    w = g2.expand(c, 1, size, size).contiguous()
    return F.conv2d(x, w, padding=half, groups=c)

# %% [markdown]
# ## Physically grounded OOD transforms (fast, differentiable)


import math
import torch
import torch.nn.functional as F

def shadow(img, eps):
    """Multiplicative soft shadow with ≤ eps darkness (B,C,H,W)->(B,C,H,W).

    A linear ramp at a random angle (Python ``random``) is normalised to
    [0, 1], blurred (sigma 3, 21-tap kernel), scaled by eps * U(0.5, 1) and
    used as ``img * (1 - strength * ramp)``; the result is clipped to [0, 1].
    The same ramp is shared by every image and channel in the batch.
    """
    n, _, h, w = img.shape
    grid_opts = dict(device=img.device, dtype=img.dtype)
    ys = torch.linspace(-1, 1, h, **grid_opts).view(h, 1).expand(h, w)
    xs = torch.linspace(-1, 1, w, **grid_opts).view(1, w).expand(h, w)
    theta = random.uniform(0.0, math.pi)
    # math.cos/sin give Python floats, which broadcast against the tensors
    ramp = math.cos(theta) * xs + math.sin(theta) * ys
    ramp = (ramp - ramp.min()) / (ramp.max() - ramp.min() + 1e-8)  # [0,1]
    ramp = ramp[None, None].repeat(n, 1, 1, 1)                      # Bx1xHxW
    ramp = gaussian_blur2d(ramp, sigma=3.0, kernel_size=21)

    strength = float(eps) * random.uniform(0.5, 1.0)                # ≤ eps
    return clip_01(img * (1.0 - strength * ramp))

def passband_shift(img, eps):
    """Per-band affine jitter: x' = x*(1+α) + β, with |α|,|β| ≤ eps/2 (B,C,H,W).

    α and β are drawn independently per (batch item, channel) with
    ``torch.rand`` (α first, then β); the result is clipped to [0, 1].
    """
    n, c, _, _ = img.shape
    half_range = float(eps) * 0.5

    def symmetric_uniform():
        # U(-half_range, half_range), one value per (n, c), broadcast over H, W
        return (torch.rand(n, c, 1, 1, device=img.device, dtype=img.dtype) - 0.5) * (2 * half_range)

    gain = symmetric_uniform()
    offset = symmetric_uniform()
    return clip_01(img * (1 + gain) + offset)

def lf_drift(img, eps):
    """Low-frequency additive drift, per-band (bounded by eps in L∞).

    A coarse Gaussian-noise grid (H/scale x W/scale, scale drawn from {8, 16}
    with Python ``random``) is bilinearly upsampled to (H, W), blurred, then
    rescaled so that its max |value| per (batch item, channel) equals eps.
    The drifted image is clipped to [0, 1].

    The coarse grid is at least 1x1, so patches smaller than ``scale`` get a
    (nearly) constant offset instead of crashing on an empty tensor.
    """
    B, C, H, W = img.shape
    scale = random.choice([8, 16])
    gh, gw = max(1, H // scale), max(1, W // scale)
    # the 0.2 factor is irrelevant after the max-normalisation below
    small = torch.randn(B, C, gh, gw, device=img.device, dtype=img.dtype) * 0.2
    field = F.interpolate(small, size=(H, W), mode='bilinear', align_corners=False)
    field = gaussian_blur2d(field, sigma=1.0, kernel_size=5)
    # hard bound amplitude to eps
    max_abs = field.abs().amax(dim=(2, 3), keepdim=True).clamp_min(1e-8)
    field = field / max_abs * float(eps)
    return clip_01(img + field)


def blur(img, eps):
    """Mild Gaussian blur whose sigma is max(0.5, 64 * eps).

    For eps in 1/255..2/255 this gives sigma ≈ 0.5 (the floor) with a 5-tap
    kernel; a 7-tap kernel is used once sigma ≥ 1.5. Unlike the other kinds,
    the pixel change is not bounded by eps and no [0, 1] clipping is applied.
    """
    sigma = max(0.5, float(eps) * 64.0)
    if sigma < 1.5:
        taps = 5
    else:
        taps = 7
    return gaussian_blur2d(img, sigma=sigma, kernel_size=taps)

def apply_phys_noise_pair(I1, I2, kind, eps):
    """Perturb both dates independently with the same family and eps.

    ``kind`` is one of 'lf', 'shadow', 'pband', 'blur'. I1 is perturbed first,
    then I2, each with its own random draws.

    Raises:
        ValueError: for an unknown ``kind``.
    """
    family = {
        'lf': lf_drift,
        'shadow': shadow,
        'pband': passband_shift,
        'blur': blur,
    }.get(kind)
    if family is None:
        raise ValueError(f"Unknown kind: {kind}")
    first = family(I1, eps)
    second = family(I2, eps)
    return first, second

def sample_mixture_kind():
    """Draw one OOD perturbation family uniformly at random (Python ``random``)."""
    families = ('lf', 'shadow', 'pband', 'blur')
    return random.choice(families)

# %% [markdown]
# ## HCT v2 (no LiRPA): drop-in training loop
#
# * Same signature flavor as your `train(...)`, so loaders + `test_patch_batch` work unchanged.
# * Adds: ε-ramp, head-margin hinge on OOD views, EMA, grad-clip, optional Dice on clean.

# %%
class DiceLoss(nn.Module):
    """Soft binary Dice loss on the 'change' channel (index 1).

    The Dice score is computed over the whole batch at once (not per sample):
    1 - (2 * sum(p * t) + eps) / (sum(p) + sum(t) + eps), with p the softmax
    probability of class 1 and t = (target == 1). Returns 0 when the logits
    have fewer than two channels.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, target):
        if logits.shape[1] < 2:  # safety
            return torch.tensor(0.0, device=logits.device)
        p_change = logits.softmax(dim=1)[:, 1]
        is_change = target.eq(1).float()
        overlap = (p_change * is_change).sum()
        total = p_change.sum() + is_change.sum()
        score = (2 * overlap + self.eps) / (total + self.eps)
        return 1.0 - score

def clean_pred_argmax(logits):
    """Hard per-pixel class index: argmax over dim 1 of (B, C, H, W) logits -> (B, H, W)."""
    return logits.argmax(dim=1)

def margin_for_clean_class(perturbed_logits, clean_pred):
    """Per-pixel logit margin of the clean-predicted class under perturbation.

    m = z[c*] - max_{j != c*} z[j]

    For two classes this is exactly z[c*] - z[1 - c*], the original definition;
    for C > 2 it compares against the strongest competing class instead of
    indexing ``1 - c*`` (which is not a valid class there).

    Args:
        perturbed_logits: logits indexed (B, C, H, W), C >= 2.
        clean_pred: (B, H, W) integer class indices (e.g. from clean_pred_argmax).

    Returns:
        (B, H, W) margins; positive where the clean class still wins.
        With a single class there is no competitor and the margin is +inf.
    """
    z = perturbed_logits
    idx = clean_pred.long().unsqueeze(1)  # (B, 1, H, W)
    z_c = torch.gather(z, 1, idx)
    # Mask out the clean class, then take the strongest remaining logit.
    z_rivals = z.scatter(1, idx, float('-inf'))
    z_o = z_rivals.amax(dim=1, keepdim=True)
    return (z_c - z_o)[:, 0]

@torch.no_grad()
def _evaluate_epoch(train_dataset, test_dataset, criterion_ce, model, feat=False, test_patch_batch=None):
    """Run ``test_patch_batch`` on the train and test sets and flatten its outputs.

    ``test_patch_batch(ds, criterion_ce, model, feat=feat)`` is expected to
    return ``(loss, accuracy, class_accuracies, (precision, recall, f1))``.
    Returns ``{}`` when no evaluation function is supplied; otherwise keys
    ``{train,test}_{loss,accuracy,nochange_accuracy,change_accuracy,
    precision,recall,f1}`` as floats. Runs under ``torch.no_grad()``; it does
    not itself switch the model to eval mode.
    """
    if test_patch_batch is None:
        return {}
    metrics = {}
    for split, ds in (('train', train_dataset), ('test', test_dataset)):
        loss, acc, per_class, prf = test_patch_batch(ds, criterion_ce, model, feat=feat)
        metrics[f'{split}_loss'] = float(loss)
        metrics[f'{split}_accuracy'] = float(acc)
        metrics[f'{split}_nochange_accuracy'] = float(per_class[0])
        metrics[f'{split}_change_accuracy'] = float(per_class[1])
        metrics[f'{split}_precision'] = float(prf[0])
        metrics[f'{split}_recall'] = float(prf[1])
        metrics[f'{split}_f1'] = float(prf[2])
    return metrics

def _ema_update(ema_state, model, decay):
    """Blend ``model.state_dict()`` into ``ema_state`` in place.

    Floating-point entries: ema = decay * ema + (1 - decay) * current.
    Non-floating entries (e.g. BatchNorm's integer ``num_batches_tracked``)
    cannot be scaled by a float in place, so they are copied verbatim.
    """
    with torch.no_grad():
        for k, v in model.state_dict().items():
            avg = ema_state[k]
            if avg.is_floating_point():
                avg.mul_(decay).add_(v, alpha=1.0 - decay)
            else:
                avg.copy_(v)


def train_HCT_noLiRPA_v2(
    model, net_name, criterion_ce, optimizer, scheduler,
    train_loader, train_dataset, test_dataset,
    feat=False, n_epochs=50, save=True,
    # HCT knobs:
    eps_max=1.0/255.0, ramp_frac=0.4, K=1, lambda_hct=0.35, tau=0.1,
    use_dice=False, lambda_dice=0.1,
    grad_clip=1.0, ema_decay=0.0,
    test_patch_batch=None,
):
    """
    Drop-in trainer that preserves your data pipeline & metrics.

    Per batch:
      - ``criterion_ce`` (plus ``lambda_dice`` * Dice when ``use_dice``) on the
        clean inputs;
      - HCT hinge: for each of ``K`` draws, perturb (I1, I2) with a randomly
        chosen physically grounded OOD family at the current ε and add
        ``lambda_hct * mean_k(mean(clamp(tau - margin, min=0)))``, where margin
        is the perturbed logit of the clean argmax class minus its strongest
        competitor.
    ε at epoch e (1-based) is ``eps_max * min(1, (e / n_epochs) / ramp_frac)``.

    Args:
        model: called as ``model(I1, I2)``; with ``feat=True`` its output is
            indexed ``[0]`` to obtain the logits.
        net_name: filename prefix for checkpoints.
        criterion_ce, optimizer: loss and optimiser.
        scheduler: stepped once per epoch, or ``None``.
        train_loader: yields dicts with 'I1', 'I2', 'label'.
        train_dataset, test_dataset: forwarded to ``test_patch_batch``.
        save: when False, no checkpoints are written.
        grad_clip: max global grad norm; ``None`` or <= 0 disables clipping.
        ema_decay: > 0 keeps an EMA of the state_dict, loaded into ``model``
            after the last epoch (per-epoch evaluation uses the raw weights).
        test_patch_batch: evaluation function; ``None`` skips metrics.

    Returns:
        dict of final metrics (evaluated after the optional EMA load) plus
        ``'history'``, a list of per-epoch dicts.

    Checkpoints (when ``save``): ``{net_name}-best_f1-epoch{epoch}.pth.tar``
    whenever the TRAIN F1 improves and ``{net_name}-best_loss-epoch{epoch}.pth.tar``
    whenever the TRAIN loss improves (selection uses training-set metrics).
    """
    model = model.to(device)
    ema_shadow = None
    if ema_decay > 0.0:
        ema_shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    dice_loss = DiceLoss() if use_dice else None

    def _logits(a, b):
        # With feat=True the model returns a tuple whose first item is the logits.
        out = model(a, b)
        return out[0] if feat else out

    best_f1 = -1e9
    best_loss = 1e9
    history = []

    for epoch in range(1, n_epochs+1):
        model.train()
        # ε schedule: linear ramp, saturating at eps_max once progress >= ramp_frac
        progress = epoch / max(1, n_epochs)
        eps = eps_max * min(1.0, progress / max(1e-8, ramp_frac))

        for batch in train_loader:
            I1 = batch['I1'].float().to(device)
            I2 = batch['I2'].float().to(device)
            label = torch.squeeze(batch['label'].to(device)).long()
            if label.dim() == I1.dim() - 2:
                # squeeze() also removed the batch axis (batch of size 1);
                # restore it so the target matches the (1, C, H, W) logits.
                label = label.unsqueeze(0)

            optimizer.zero_grad()

            # Clean forward
            logits_clean = _logits(I1, I2)
            loss = criterion_ce(logits_clean, label)
            if use_dice:
                loss = loss + lambda_dice * dice_loss(logits_clean, label)

            # Head-consistency loss
            with torch.no_grad():
                clean_pred = clean_pred_argmax(logits_clean)  # BxHxW

            hct_losses = []
            for _ in range(int(K)):
                kind = sample_mixture_kind()
                I1_pert, I2_pert = apply_phys_noise_pair(I1, I2, kind=kind, eps=eps)
                margin = margin_for_clean_class(_logits(I1_pert, I2_pert), clean_pred)  # BxHxW
                # hinge on (tau - margin), i.e., penalize when margin < tau
                hct_losses.append(torch.clamp(tau - margin, min=0.0).mean())
            if hct_losses:
                loss = loss + lambda_hct * torch.stack(hct_losses).mean()

            loss.backward()
            if grad_clip is not None and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

            if ema_shadow is not None:
                _ema_update(ema_shadow, model, ema_decay)

        if scheduler is not None:
            scheduler.step()

        # Eval
        eval_out = _evaluate_epoch(train_dataset, test_dataset, criterion_ce, model, feat=feat, test_patch_batch=test_patch_batch)
        history.append(dict(epoch=epoch, eps=eps, **eval_out))
        fm = eval_out.get('train_f1', 0.0)
        if fm > best_f1:
            best_f1 = fm
            if save:
                torch.save(model.state_dict(), f'{net_name}-best_f1-epoch{epoch}.pth.tar')
        lss = eval_out.get('train_loss', 1e9)
        if lss < best_loss:
            best_loss = lss
            if save:
                torch.save(model.state_dict(), f'{net_name}-best_loss-epoch{epoch}.pth.tar')

        print(f"[Epoch {epoch}/{n_epochs}] eps={eps:.5f} | "
              f"train F1={eval_out.get('train_f1',0):.3f} test F1={eval_out.get('test_f1',0):.3f}")

    # Optionally restore EMA weights for export
    if ema_shadow is not None:
        model.load_state_dict(ema_shadow, strict=False)

    # Final eval
    final = _evaluate_epoch(train_dataset, test_dataset, criterion_ce, model, feat=feat, test_patch_batch=test_patch_batch)
    final['history'] = history
    return final


# %% [markdown]
# ## Ablations from logs (no re-verification needed)
# Expect CSV with columns like:
#  model, tap, eps, pp_pass_pct, pp_med_coverage, pp_med_spill, pred_strict_pass_pct, ...
# (Use whatever your logger emits; adapt column names below.)

# %%
def ablate_tap_from_logs(csv_path, model_filter=None):
    """Average PP pass% and median coverage per (model, tap, eps) from a log CSV.

    Args:
        csv_path: anything ``pd.read_csv`` accepts; needs columns model, tap,
            eps, pp_pass_pct, pp_med_coverage.
        model_filter: optional iterable of model names to keep (falsy = keep all).

    Returns:
        DataFrame with columns model, tap, eps, 'pp_pass_%', 'pp_cov'.
    """
    logs = pd.read_csv(csv_path)
    if model_filter:
        logs = logs[logs['model'].isin(model_filter)]
    summary = logs.groupby(['model', 'tap', 'eps'], as_index=False).agg(
        **{
            'pp_pass_%': ('pp_pass_pct', 'mean'),
            'pp_cov': ('pp_med_coverage', 'mean'),
        }
    )
    return summary

def plot_tap_summary(g, title="Tap placement vs PP pass/coverage"):
    """One subplot per model: pass% (solid) and coverage*100 (dashed) vs ε for each tap.

    ``g`` is the frame returned by ``ablate_tap_from_logs``. Returns the Figure.
    """
    model_names = g['model'].unique()
    n_models = len(model_names)
    fig, axes = plt.subplots(n_models, 1, figsize=(6, 3*n_models), sharex=True)
    # plt.subplots returns a bare Axes (not an array) for a single row
    axes = [axes] if n_models == 1 else axes
    for ax, name in zip(axes, model_names):
        rows = g[g['model'] == name]
        for tap in sorted(rows['tap'].unique()):
            tap_rows = rows[rows['tap'] == tap]
            ax.plot(tap_rows['eps'], tap_rows['pp_pass_%'], marker='o', label=f"{tap} (pass%)")
            ax.plot(tap_rows['eps'], tap_rows['pp_cov']*100.0, marker='x', linestyle='--', label=f"{tap} (cov%×100)")
        ax.set_title(name)
        ax.set_ylabel("Pass% / Cov%")
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=8, ncol=2)
    axes[-1].set_xlabel("ε")
    fig.suptitle(title)
    fig.tight_layout()
    return fig

def predicate_sweep_from_logs(csv_path, rho_list, gamma_list, smin_list):
    """
    Re-compute strict pass% per (model, eps) over a grid of predicate thresholds.

    Expected per-row (per-scene) columns: 'model', 'eps', 'overlap',
    'fp_ratio', and either 'min_cc_size' (thresholded at smin) or
    'min_cc_ok' (0/1; then smin has no effect). A row passes when
    overlap >= rho AND fp_ratio <= gamma AND the pattern check holds.
    No 'scene_id' column is needed.

    Returns:
        DataFrame with columns model, eps, pass_pct, rho, gamma, smin (one row
        per (model, eps) group per threshold triple); empty with those columns
        when any threshold list is empty.
    """
    columns = ['model', 'eps', 'pass_pct', 'rho', 'gamma', 'smin']
    df = pd.read_csv(csv_path)
    out = []
    for rho in rho_list:
        for gamma in gamma_list:
            for smin in smin_list:
                # if 'min_cc_size' exists, threshold it; else expect 'min_cc_ok'
                if 'min_cc_size' in df.columns:
                    ok_pattern = (df['min_cc_size'] >= smin)
                else:
                    ok_pattern = (df['min_cc_ok'] >= 1)
                ok_overlap = (df['overlap'] >= rho)
                ok_fp = (df['fp_ratio'] <= gamma)
                strict_pass = (ok_overlap & ok_fp & ok_pattern).astype(np.float64)

                # Group the pass indicator directly; no extra id column required.
                agg = (
                    df[['model', 'eps']]
                    .assign(pass_pct=strict_pass * 100.0)
                    .groupby(['model', 'eps'], as_index=False)['pass_pct']
                    .mean()
                )
                agg['rho'] = rho
                agg['gamma'] = gamma
                agg['smin'] = smin
                out.append(agg)
    if not out:
        return pd.DataFrame(columns=columns)
    return pd.concat(out, ignore_index=True)

def plot_predicate_sweep(sweep_df, model, eps, title="Predicate sweep (pass %)"):
    """Heatmap of pass% over (rho, gamma) for one model and ε.

    When several smin values are present, only the smallest is shown.
    Returns the Figure.
    """
    rows = sweep_df[(sweep_df['model'] == model) & (sweep_df['eps'] == eps)]
    smins = sorted(rows['smin'].unique())
    if len(smins) > 1:
        rows = rows[rows['smin'] == smins[0]]
    grid = rows.pivot(index='rho', columns='gamma', values='pass_pct').sort_index()

    fig, ax = plt.subplots(figsize=(5, 4))
    im = ax.imshow(grid.values, origin='lower', aspect='auto')
    ax.set_xticks(range(grid.shape[1]))
    ax.set_xticklabels([f"{c:.2f}" for c in grid.columns], rotation=45)
    ax.set_yticks(range(grid.shape[0]))
    ax.set_yticklabels([f"{r:.2f}" for r in grid.index])
    ax.set_xlabel("γ (FP cap)")
    ax.set_ylabel("ρ (overlap)")
    ax.set_title(f"{title}: {model}, ε={eps}")
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="Pass %")
    fig.tight_layout()
    return fig

# %% [markdown]
# ## PP calibration vs GT (per-scene, image-level)
# Provide clean pred (H×W), certified mask (H×W, bool), and GT mask (H×W).
# Returns coverage, PP-precision wrt GT, PP-recall wrt GT, PP-F1 wrt GT.

# %%
def pp_calibration_scene(clean_pred_hw, cert_mask_hw, gt_hw):
    """
    Coverage and GT precision/recall/F1 of the clean prediction, restricted to
    PP-certified pixels.

    clean_pred_hw: array, 1=change, 0=no-change (cast to uint8)
    cert_mask_hw:  array, truthy where the pixel is PP-certified
    gt_hw:         array, 1=change, 0=no-change (cast to uint8)

    Returns dict(coverage, pp_precision, pp_recall, pp_f1). Coverage is the
    certified fraction of pixels; when nothing is certified, all three scores
    are 0.0. Scores use a 1e-8 denominator guard.
    """
    pred = clean_pred_hw.astype(np.uint8)
    certified = cert_mask_hw.astype(bool)
    truth = gt_hw.astype(np.uint8)

    coverage = 0.0 if certified.size <= 0 else certified.mean().item()
    if certified.sum() == 0:
        return dict(coverage=coverage, pp_precision=0.0, pp_recall=0.0, pp_f1=0.0)

    says_change = certified & (pred == 1)
    says_static = certified & (pred == 0)
    tp = (says_change & (truth == 1)).sum()
    fp = (says_change & (truth == 0)).sum()
    fn = (says_static & (truth == 1)).sum()

    precision = tp / (tp+fp+1e-8)
    recall = tp / (tp+fn+1e-8)
    f1 = 2*precision*recall/(precision+recall+1e-8)
    return dict(coverage=float(coverage), pp_precision=float(precision),
                pp_recall=float(recall), pp_f1=float(f1))

def pp_calibration_batch(clean_pred_list, cert_mask_list, gt_list):
    """
    Aggregate per-scene PP calibration metrics over a collection of scenes.

    Each (clean_pred, cert_mask, gt) triple is scored with pp_calibration_scene,
    and the per-scene results are summarised by their mean and median.

    Args:
        clean_pred_list: iterable of per-scene clean predictions (H×W).
        cert_mask_list:  iterable of per-scene PP-certified masks (H×W, bool).
        gt_list:         iterable of per-scene ground-truth masks (H×W).

    Returns:
        pandas DataFrame with index ['mean', 'median'] and one column per metric
        returned by pp_calibration_scene.

    Raises:
        ValueError: if the three inputs differ in length (zip would otherwise
            silently drop the surplus scenes) or if no scene is given.
    """
    preds = list(clean_pred_list)
    certs = list(cert_mask_list)
    gts = list(gt_list)
    if not (len(preds) == len(certs) == len(gts)):
        raise ValueError(
            f"length mismatch: {len(preds)} predictions, {len(certs)} cert masks, "
            f"{len(gts)} GT masks"
        )
    if not preds:
        raise ValueError("pp_calibration_batch needs at least one scene")

    rows = [pp_calibration_scene(cp, cm, gt) for cp, cm, gt in zip(preds, certs, gts)]
    df = pd.DataFrame(rows)
    # describe()'s '50%' row is the median; rename it for readability.
    return df.describe().loc[['mean', '50%']].rename(index={'50%': 'median'})

# %% [markdown]
# ## Optional: β-CROWN (or cut-plane) on the head only
# If auto_LiRPA is installed, compute α-CROWN vs β-CROWN lower bounds (LB) on the change margin of the tail subgraph and summarize the per-pixel gain.

# %%
def beta_crown_head_lbs(model_tail, z_center, z_radius, margin_matrix_fn, method='alpha-crown'):
    """
    Per-pixel lower bounds on the change margin of a network tail over an L∞ box.

    The tail is wrapped so that its bounded output *is* the per-pixel margin,
    flattened to (B, H*W). auto_LiRPA then bounds each margin entry directly, so
    no specification matrix is needed. (A spec matrix C must have one column per
    flattened output element, i.e. 2*H*W for (B, 2, H, W) logits, not 2.)

    Args:
        model_tail: nn.Module mapping tail input z -> logits (B, 2, H, W).
            It is switched to eval mode as a side effect.
        z_center, z_radius: define the L∞ box z in [z_center - z_radius, z_center + z_radius].
        margin_matrix_fn: callable mapping logits to the margin z_change - z_nochange
            with B*H*W entries (e.g. shape (B, H, W)). It is traced by auto_LiRPA, so it
            must use supported ops. If None, logits[:, 0] - logits[:, 1] is used
            (the sign convention of the original +1/-1 spec).
            NOTE(review): if channel 1 is the 'change' class (labels 1=change with
            CrossEntropyLoss), the default sign is flipped — confirm.
        method: bounding method forwarded to BoundedModule.compute_bounds
            (e.g. 'CROWN', 'alpha-crown').
            NOTE(review): plain compute_bounds may not accept 'beta-crown'; β-CROWN
            is usually enabled through bound options / the α,β-CROWN verifier — confirm.

    Returns:
        Tensor (B, H, W) of lower bounds on the margin, where H×W is the spatial
        size of the logits (which need not match the tail input's).

    Raises:
        RuntimeError: if auto_LiRPA is not available.
    """
    if not _HAS_LIRPA:
        raise RuntimeError("auto_LiRPA not available; install to run β-CROWN on the head.")

    model_tail.eval()
    z0 = z_center.clone().detach()
    batch = z0.shape[0]

    # Probe once for the logits' spatial size instead of assuming it equals z's.
    with torch.no_grad():
        _, _, H, W = model_tail(z0).shape

    def _default_margin(logits):
        # Slices (not integer indexing) keep the traced graph simple for auto_LiRPA.
        return logits[:, 0:1] - logits[:, 1:2]

    margin_fn = margin_matrix_fn if margin_matrix_fn is not None else _default_margin

    class _MarginHead(nn.Module):
        """Tail followed by the margin, flattened so each pixel is one bounded output."""

        def __init__(self, tail):
            super().__init__()
            self.tail = tail

        def forward(self, z):
            return torch.flatten(margin_fn(self.tail(z)), start_dim=1)

    bounded_net = BoundedModule(_MarginHead(model_tail).eval(), z0, device=z0.device)
    ptb = PerturbationLpNorm(norm=np.inf, x_L=(z0 - z_radius), x_U=(z0 + z_radius))
    x = BoundedTensor(z0, ptb)

    # Only the lower bound is used, so skip computing/optimising the upper bound.
    lb, _ = bounded_net.compute_bounds(x=(x,), method=method, bound_upper=False)
    return lb.reshape(batch, H, W)

def summarize_lb_improvement(lb_alpha, lb_beta):
    """
    Summarise how much β-CROWN tightens α-CROWN lower bounds.

    Args:
        lb_alpha, lb_beta: torch tensors (broadcast-compatible) holding the two
            methods' lower bounds.

    Returns:
        dict of floats over the element-wise gain lb_beta - lb_alpha:
        mean, median, p95 (95th percentile) and frac_improved (share of
        entries whose gain exceeds 1e-8).
    """
    gain = lb_beta - lb_alpha
    g = gain.detach().cpu().numpy()
    improved = g > 1e-8
    return {
        "mean": float(np.mean(g)),
        "median": float(np.median(g)),
        "p95": float(np.percentile(g, 95)),
        "frac_improved": float(np.mean(improved)),
    }

def _latex_escape(text):
    """Backslash-escape the LaTeX specials _ & % # $ in *text* unless already escaped."""
    out = []
    prev = ""
    for ch in str(text):
        # e.g. the '_' in "FALCONet_HCTv3" is a compile error in LaTeX text mode.
        if ch in "_&%#$" and prev != "\\":
            out.append("\\")
        out.append(ch)
        prev = ch
    return "".join(out)


def latex_table_alpha_vs_beta(stats_list, caption="Head LB gains: β-CROWN vs α-CROWN", label="tab:beta_vs_alpha"):
    r"""
    Render per-model lower-bound gains (β-CROWN minus α-CROWN) as a LaTeX table.

    Args:
        stats_list: iterable of dicts with keys model, eps, mean, median, p95,
            frac_improved (frac_improved is a fraction; it is printed as a percent).
        caption: caption text, inserted verbatim inside \textbf{...}.
            NOTE(review): the default contains Unicode β/α, which pdfLaTeX may
            reject without a Greek-capable font setup — confirm for your build.
        label: LaTeX label for the table.

    Returns:
        The table source as one newline-joined string. The `model` field is
        escaped for LaTeX specials; `eps`, `caption` and `label` are inserted
        verbatim so they may carry LaTeX markup. The \toprule/\midrule/\bottomrule
        rules require the booktabs package.
    """
    lines = [
        "\\begin{table}[t]",
        "\\centering",
        f"\\caption{{\\textbf{{{caption}}}}}",
        f"\\label{{{label}}}",
        "\\setlength{\\tabcolsep}{6pt}",
        "\\scriptsize",
        "\\begin{tabular}{l c c c c c}",
        "\\toprule",
        "Model & $\\varepsilon$ & Mean $\\Delta$LB & Median $\\Delta$LB & 95p $\\Delta$LB & Frac($\\Delta>0$) \\\\",
        "\\midrule"
    ]
    for s in stats_list:
        lines.append(
            f"{_latex_escape(s['model'])} & {s['eps']} & {s['mean']:.4f} & {s['median']:.4f} & "
            f"{s['p95']:.4f} & {s['frac_improved']*100:.1f}\\% \\\\"
        )
    lines += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]
    return "\n".join(lines)



In [3]:
# %% [markdown]
# ## Usage: train FALCONetMHA with HCT, save the weights, then (optionally) run the analyses
#
# The training/saving code below is live, not a sketch. It relies on names defined
# elsewhere in the notebook (FALCONetMHA_LiRPA, dataset_weights, train_HCT_noLiRPA_v2,
# train_loader, train_dataset, test_dataset, test_patch_batch). The analysis calls
# further down are commented out; adapt names to your codebase before enabling them.
#
# Presumably 2*13 input channels = two Sentinel-2 images with all 13 bands stacked
# (cf. the TYPE = 3 "All bands" setting at the top of the notebook) — confirm.
model = FALCONetMHA_LiRPA(2*13, 2 , dropout=0.1, reduction=8, attention=True, num_heads=4)
# Class-weighted cross-entropy; dataset_weights presumably counters class imbalance — confirm.
loss_fn = nn.CrossEntropyLoss(weight=dataset_weights)
opt = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
# Multiplies the learning rate by 0.95 on every sch.step() call.
sch = torch.optim.lr_scheduler.ExponentialLR(opt, 0.95)

# HCT training run; the keyword arguments are interpreted by train_HCT_noLiRPA_v2
# (defined elsewhere in the notebook) — see that function for their meaning.
out = train_HCT_noLiRPA_v2(
    model, "FALCONet_HCTv3", loss_fn, opt, sch,
    train_loader, train_dataset, test_dataset,
    feat=False, n_epochs=60, eps_max=8/255, K=1,
    lambda_hct=0.45, tau=0.1, use_dice=True, lambda_dice=0.1,
    grad_clip=1.0, ema_decay=0.0,
    test_patch_batch=test_patch_batch
)
print(out)

# Saves only the parameter state_dict via torch.save, despite the .pth.tar suffix.
torch.save(model.state_dict(), 'FALCONetMHA_HCT_vfree_v3.pth.tar')
print('SAVE OK')
# Optional analyses (disabled); uncomment and adapt as needed.
# # Tap ablation from logs
# g = ablate_tap_from_logs("predicate_pass_logf8.csv", model_filter=["FALCONet_HCTv3","AttUNet","EncDec"])
# fig = plot_tap_summary(g); fig.savefig("tap_vs_ppf8.png", dpi=200)

# # Predicate sweep
# sw = predicate_sweep_from_logs("per_scene_statsf8.csv",
#                                rho_list=[0.10,0.20,0.30], gamma_list=[0.30,0.50,0.70], smin_list=[8,16])
# fig = plot_predicate_sweep(sw, model="FALCONet_HCTv3", eps="1/255"); fig.savefig("sweep_f8_eps1.png", dpi=200)

# # PP calibration vs GT (per scene):
# #   collect per-scene arrays: clean_pred (H,W), cert_mask (H,W), gt (H,W)
# df_pp = pp_calibration_batch(clean_preds, cert_masks, gts)
# print(df_pp)

# # β-CROWN on head (optional; requires auto_LiRPA):
# if _HAS_LIRPA:
#     lb_a = beta_crown_head_lbs(model_tail, z_center, z_rad, margin_matrix_fn, method='alpha-crown')
#     lb_b = beta_crown_head_lbs(model_tail, z_center, z_rad, margin_matrix_fn, method='beta-crown')
#     stats = summarize_lb_improvement(lb_a, lb_b)
#     print(stats)


[Epoch 1/60] eps=0.00131 | train F1=0.305 test F1=0.362
[Epoch 2/60] eps=0.00261 | train F1=0.392 test F1=0.454
[Epoch 3/60] eps=0.00392 | train F1=0.488 test F1=0.527
[Epoch 4/60] eps=0.00523 | train F1=0.497 test F1=0.564
[Epoch 5/60] eps=0.00654 | train F1=0.403 test F1=0.464
[Epoch 6/60] eps=0.00784 | train F1=0.484 test F1=0.548
[Epoch 7/60] eps=0.00915 | train F1=0.493 test F1=0.544
[Epoch 8/60] eps=0.01046 | train F1=0.570 test F1=0.619
[Epoch 9/60] eps=0.01176 | train F1=0.562 test F1=0.609
[Epoch 10/60] eps=0.01307 | train F1=0.581 test F1=0.619
[Epoch 11/60] eps=0.01438 | train F1=0.536 test F1=0.588
[Epoch 12/60] eps=0.01569 | train F1=0.547 test F1=0.605
[Epoch 13/60] eps=0.01699 | train F1=0.552 test F1=0.615
[Epoch 14/60] eps=0.01830 | train F1=0.510 test F1=0.552
[Epoch 15/60] eps=0.01961 | train F1=0.605 test F1=0.647
[Epoch 16/60] eps=0.02092 | train F1=0.546 test F1=0.592
[Epoch 17/60] eps=0.02222 | train F1=0.582 test F1=0.633
[Epoch 18/60] eps=0.02353 | train F1=0.5