In [None]:
import os
import random
from PIL import Image
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

class TamperDataset(Dataset):
    """Dataset yielding (tampered image, binary tamper mask) tensor pairs.

    Args:
        pairs: sequence of ``(pristine_path, tampered_path, mask_path)`` tuples.
            Only the tampered image and the mask are loaded; the pristine path
            is ignored here.
        img_size: images and masks are both resized to ``img_size x img_size``.
        augment: when True, apply random horizontal/vertical flips identically
            to the image and its mask.
    """

    def __init__(self, pairs, img_size=512, augment=False):
        self.pairs = pairs
        self.augment = augment
        self.tf_img = T.Compose([T.Resize((img_size, img_size)), T.ToTensor()])
        # Nearest-neighbour resizing does not invent intermediate values along
        # mask boundaries; InterpolationMode is torchvision's documented type
        # for this argument (instead of a raw PIL integer constant).
        self.tf_mask = T.Compose([
            T.Resize((img_size, img_size), interpolation=T.InterpolationMode.NEAREST),
            T.ToTensor(),
        ])

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(img_t, mask_t)``.

        ``img_t`` is a 3 x img_size x img_size float tensor from ``ToTensor``;
        ``mask_t`` is a 1 x img_size x img_size tensor of 0.0/1.0 obtained by
        thresholding the greyscale mask at 0.5.
        """
        _, tampered_path, mask_path = self.pairs[idx]
        # Context managers close the underlying files promptly; PIL can keep a
        # file (e.g. TIFF) open after loading when it is not closed explicitly.
        with Image.open(tampered_path) as im:
            img = im.convert("RGB")
        with Image.open(mask_path) as im:
            mask = im.convert("L")
        # Flips are only considered for half the samples, and each flip then
        # fires with probability 0.5, so each flip occurs with overall p=0.25.
        if self.augment and random.random() > 0.5:
            if random.random() > 0.5:
                img = T.functional.hflip(img)
                mask = T.functional.hflip(mask)
            if random.random() > 0.5:
                img = T.functional.vflip(img)
                mask = T.functional.vflip(mask)
        img_t = self.tf_img(img)
        mask_t = (self.tf_mask(mask) > 0.5).float()
        return img_t, mask_t

class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm stages with an additive shortcut and a final ReLU.

    The first conv carries ``stride``. When ``stride != 1`` or the channel
    count changes, the shortcut is a strided 1x1 conv + BatchNorm projection
    so it matches the main path's shape; otherwise the input is added as-is.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # Attribute names and creation order are significant: they define the
        # state_dict keys and the order in which parameters are initialised.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        needs_projection = stride != 1 or in_ch != out_ch
        self.down = (
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        shortcut = self.down(x)
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return self.relu(h + shortcut)

class Down(nn.Module):
    """Encoder stage: a single stride-2 ResidualBlock mapping in_ch -> out_ch.

    The stride-2, padding-1 3x3 conv gives ceil(H/2) x ceil(W/2) outputs.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = ResidualBlock(in_ch=in_ch, out_ch=out_ch, stride=2)

    def forward(self, x):
        downsampled = self.block(x)
        return downsampled

class Up(nn.Module):
    """Decoder stage: 2x transposed-conv upsample, align to the skip, concat, ResidualBlock.

    Args:
        in_ch: channels of the incoming decoder feature map.
        skip_ch: channels of the encoder skip tensor passed to ``forward``.
        out_ch: channels produced by both the upsample and the fused block.
    """

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.block = ResidualBlock(out_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        upsampled = self.up(x)
        dh = skip.size(2) - upsampled.size(2)
        dw = skip.size(3) - upsampled.size(3)
        # pad() takes [left, right, top, bottom] for the last two dims; a
        # negative difference crops instead, so the result always matches
        # the skip's spatial size exactly.
        left = dw // 2
        top = dh // 2
        aligned = nn.functional.pad(upsampled, [left, dw - left, top, dh - top])
        fused = torch.cat((skip, aligned), dim=1)
        return self.block(fused)

class ResUNet(nn.Module):
    """Residual U-Net returning per-pixel logits at the input's spatial resolution.

    Four stride-2 encoder stages are followed by four 2x decoder stages. Each
    decoder stage is fused with the encoder feature map at the resolution it
    upsamples to, ending with the full-resolution map from ``inc``. Because
    ``Up`` pads/crops to its skip's size, an (N, in_ch, H, W) input yields an
    (N, out_ch, H, W) output, odd H/W included.

    Args:
        in_ch: input image channels.
        out_ch: output channels; raw logits, no activation is applied.
        base_filters: channel width ``C`` of the first stage; deeper stages
            use 2C, 4C and 8C.

    NOTE: the decoder skip widths are 8C/4C/2C/C. A state_dict saved from a
    model whose decoder blocks had different widths will not load here.
    """

    def __init__(self, in_ch=3, out_ch=1, base_filters=32):
        super().__init__()
        C = base_filters

        # Encoder (spatial sizes shown for a 512x512 input).
        self.inc = ResidualBlock(in_ch, C)   # e0: 512, C
        self.down1 = Down(C, C*2)            # e1: 256, 2C
        self.down2 = Down(C*2, C*4)          # e2: 128, 4C
        self.down3 = Down(C*4, C*8)          # e3: 64,  8C
        self.down4 = Down(C*8, C*8)          # e4: 32,  8C

        self.bridge = ResidualBlock(C*8, C*8)  # 32, 8C

        # Decoder: Up(in_ch, skip_ch, out_ch); skip_ch must equal the channel
        # count of the encoder map at the *upsampled* resolution.
        self.up4 = Up(C*8, C*8, C*4 * 2)  # 32 -> 64,   skip e3 (8C), out 8C
        self.up3 = Up(C*8, C*4, C*4)      # 64 -> 128,  skip e2 (4C)
        self.up2 = Up(C*4, C*2, C*2)      # 128 -> 256, skip e1 (2C)
        self.up1 = Up(C*2, C, C)          # 256 -> 512, skip e0 (C)

        self.outc = nn.Conv2d(C, out_ch, 1)

    def forward(self, x):
        e0 = self.inc(x)
        e1 = self.down1(e0)
        e2 = self.down2(e1)
        e3 = self.down3(e2)
        e4 = self.down4(e3)

        b = self.bridge(e4)

        # Each skip is one level shallower than the decoder input, so the 2x
        # upsample lands exactly on the skip's resolution.
        d4 = self.up4(b, e3)
        d3 = self.up3(d4, e2)
        d2 = self.up2(d3, e1)
        d1 = self.up1(d2, e0)

        return self.outc(d1)


def dice_loss(pred, target, eps=1e-7):
    """Soft Dice loss on logits, averaged over the batch.

    ``pred`` holds logits (a sigmoid is applied here); ``target`` holds 0/1
    values of the same shape. Overlap and totals are summed over dims 1-3
    per sample, and ``eps`` smooths both numerator and denominator.
    """
    probs = torch.sigmoid(pred)
    dims = (1, 2, 3)
    overlap = torch.sum(probs * target, dim=dims)
    total = torch.sum(probs, dim=dims) + torch.sum(target, dim=dims)
    dice = (overlap * 2 + eps) / (total + eps)
    return torch.mean(1 - dice)

def iou_metric(pred, target, thr=0.5, eps=1e-7):
    """Hard IoU between thresholded sigmoid(pred) and target, as a Python float.

    Probabilities strictly above ``thr`` count as positive. IoU is computed
    per sample over dims 1-3 (``eps``-smoothed, so empty-vs-empty scores 1.0)
    and then averaged over the batch.
    """
    dims = (1, 2, 3)
    hard = torch.sigmoid(pred).gt(thr).float()
    overlap = (hard * target).sum(dim=dims)
    combined = hard.sum(dim=dims) + target.sum(dim=dims) - overlap
    per_sample = (overlap + eps) / (combined + eps)
    return per_sample.mean().item()

def build_pairs(root, seed=None):
    """Collect (pristine, tampered, mask) path triples from the dataset layout.

    For each camera folder under *root*, every ``.tif``/``.tiff`` file in
    ``pristine/`` is paired with the same-named file in
    ``tampered-realistic/`` and a PNG mask with the same stem in
    ``ground-truth/``. Triples whose tampered image or mask is missing are
    skipped.

    Args:
        root: dataset root containing the camera-model sub-directories.
        seed: optional seed for the final shuffle. ``None`` (default) uses the
            module-level ``random`` state; an int gives a reproducible order
            independent of the global RNG.

    Returns:
        A shuffled list of ``(pristine_path, tampered_path, mask_path)`` tuples.

    Raises:
        FileNotFoundError: if a camera's ``pristine`` directory does not exist.
    """
    camera_models = ["Canon_60D", "Nikon_D7000", "Nikon_D90", "Sony_A57"]
    pairs = []
    for cam in camera_models:
        base = os.path.join(root, cam)
        p_dir = os.path.join(base, "pristine")
        t_dir = os.path.join(base, "tampered-realistic")
        m_dir = os.path.join(base, "ground-truth")
        # os.listdir order is filesystem-dependent; sorting makes the shuffled
        # result (and any downstream train/val split) depend only on the seed.
        for f in sorted(os.listdir(p_dir)):
            if not f.lower().endswith((".tif", ".tiff")):
                continue
            stem = f.rsplit(".", 1)[0]
            t = os.path.join(t_dir, f)
            if not os.path.exists(t):
                continue
            # Masks are expected as "<stem>.PNG"; the lower-case variant is
            # also accepted so lookups work on case-sensitive filesystems.
            candidates = (
                os.path.join(m_dir, stem + ".PNG"),
                os.path.join(m_dir, stem + ".png"),
            )
            m = next((c for c in candidates if os.path.exists(c)), None)
            if m is not None:
                pairs.append((os.path.join(p_dir, f), t, m))
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(pairs)
    return pairs

def train_loop(model, optimizer, criterion_bce, train_loader, device):
    """Train *model* for one pass over *train_loader*.

    The per-batch objective is ``criterion_bce(preds, masks) +
    dice_loss(preds, masks)`` on the raw model outputs. The progress bar
    shows the running mean loss.

    Args:
        model: network mapping an image batch to outputs shaped like the masks.
        optimizer: optimizer over the model's parameters; stepped once per batch.
        criterion_bce: loss called as ``criterion_bce(preds, masks)``
            (``nn.BCEWithLogitsLoss`` in this script's entry point).
        train_loader: iterable of ``(imgs, masks)`` tensor batches.
        device: device the batches are moved to.

    Returns:
        Mean loss over all samples seen (each batch weighted by its size), or
        ``nan`` if the loader produced no batches.
    """
    model.train()
    running_loss = 0.0
    seen = 0
    pbar = tqdm(train_loader, desc="Training")
    for imgs, masks in pbar:
        imgs = imgs.to(device)
        masks = masks.to(device)
        preds = model(imgs)
        loss = criterion_bce(preds, masks) + dice_loss(preds, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_n = imgs.size(0)
        running_loss += loss.item() * batch_n
        seen += batch_n
        # Track the count locally: tqdm only refreshes pbar.n when it redraws,
        # so it can lag behind the true number of processed batches.
        pbar.set_postfix({"loss": running_loss / seen})
    return running_loss / seen if seen else float("nan")

def valid_loop(model, criterion_bce, val_loader, device):
    """Evaluate *model* on *val_loader* without gradient tracking.

    Uses the same objective as training (``criterion_bce + dice_loss``) and
    the thresholded IoU from ``iou_metric``.

    Args:
        model: network mapping an image batch to outputs shaped like the masks.
        criterion_bce: loss called as ``criterion_bce(preds, masks)``.
        val_loader: iterable of ``(imgs, masks)`` tensor batches.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, mean_iou)`` averaged per sample: each batch's mean is
        weighted by its batch size, so a short final batch is not
        over-counted. Returns ``(nan, 0.0)`` if the loader yields nothing.
    """
    model.eval()
    loss_sum = 0.0
    iou_sum = 0.0
    n_samples = 0
    with torch.no_grad():
        for imgs, masks in tqdm(val_loader, desc="Validating"):
            imgs = imgs.to(device)
            masks = masks.to(device)
            preds = model(imgs)
            loss = criterion_bce(preds, masks) + dice_loss(preds, masks)
            # Both terms are batch means; multiplying by the batch size turns
            # them back into per-sample sums (assumes equal-sized samples, as
            # a fixed-size resize in the dataset would produce).
            batch_n = imgs.size(0)
            loss_sum += loss.item() * batch_n
            iou_sum += iou_metric(preds, masks) * batch_n
            n_samples += batch_n
    if n_samples == 0:
        return float("nan"), 0.0
    return loss_sum / n_samples, iou_sum / n_samples

if __name__ == "__main__":
    # Seed every RNG the script relies on: build_pairs shuffles with the
    # global `random` state (so this fixes the train/val split across runs,
    # keeping saved checkpoints evaluable on the same held-out images), the
    # dataset's flips use `random`, and weight init / DataLoader shuffling
    # use torch's default generator.
    seed = 42
    random.seed(seed)
    torch.manual_seed(seed)

    root = r"C:\Users\shrey\Desktop\Tamper Localization Project\realistic-tampering-dataset\data-images"
    pairs = build_pairs(root)
    # With fewer than 2 pairs the 85/15 split leaves the train set empty.
    if len(pairs) < 2:
        raise RuntimeError(
            f"found {len(pairs)} usable image/mask pairs under {root!r}; "
            "need at least 2 for a train/val split"
        )
    split = int(0.85 * len(pairs))
    train_pairs = pairs[:split]
    val_pairs = pairs[split:]
    train_ds = TamperDataset(train_pairs, img_size=512, augment=True)
    val_ds = TamperDataset(val_pairs, img_size=512, augment=False)
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=0, pin_memory=False)
    val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=0, pin_memory=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResUNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    bce = nn.BCEWithLogitsLoss()
    best_iou = 0.0
    epochs = 30
    for epoch in range(1, epochs+1):
        train_loss = train_loop(model, optimizer, bce, train_loader, device)
        val_loss, mean_iou = valid_loop(model, bce, val_loader, device)
        # Keep the checkpoint with the highest validation IoU seen so far.
        if mean_iou > best_iou:
            best_iou = mean_iou
            torch.save(model.state_dict(), "best_resunet.pth")
        torch.save(model.state_dict(), "last_resunet.pth")
        print(f"Epoch {epoch} finished. Train Loss: {train_loss:.4f} Val Loss: {val_loss:.4f} IOU: {mean_iou:.4f} Best IOU: {best_iou:.4f}")
    torch.save(model.state_dict(), "final_resunet.pth")


Training: 100%|█████████████████████████████████████████████████████████████| 47/47 [07:18<00:00,  9.32s/it, loss=1.38]
Validating: 100%|████████████████████████████████████████████████████████████████████████| 9/9 [00:25<00:00,  2.79s/it]


Epoch 1 finished. Train Loss: 1.3834 Val Loss: 1.2911 IOU: 0.0009 Best IOU: 0.0009


Training: 100%|█████████████████████████████████████████████████████████████| 47/47 [07:20<00:00,  9.37s/it, loss=1.27]
Validating: 100%|████████████████████████████████████████████████████████████████████████| 9/9 [00:24<00:00,  2.76s/it]


Epoch 2 finished. Train Loss: 1.2654 Val Loss: 1.2570 IOU: 0.0229 Best IOU: 0.0229


Training: 100%|█████████████████████████████████████████████████████████████| 47/47 [07:16<00:00,  9.29s/it, loss=1.23]
Validating: 100%|████████████████████████████████████████████████████████████████████████| 9/9 [00:25<00:00,  2.78s/it]


Epoch 3 finished. Train Loss: 1.2300 Val Loss: 1.2377 IOU: 0.0378 Best IOU: 0.0378


Training: 100%|█████████████████████████████████████████████████████████████| 47/47 [07:17<00:00,  9.30s/it, loss=1.22]
Validating: 100%|████████████████████████████████████████████████████████████████████████| 9/9 [00:24<00:00,  2.78s/it]


Epoch 4 finished. Train Loss: 1.2156 Val Loss: 1.2182 IOU: 0.0155 Best IOU: 0.0378


Training: 100%|██████████████████████████████████████████████████████████████| 47/47 [07:16<00:00,  9.28s/it, loss=1.2]
Validating: 100%|████████████████████████████████████████████████████████████████████████| 9/9 [00:25<00:00,  2.78s/it]


Epoch 5 finished. Train Loss: 1.1971 Val Loss: 1.2006 IOU: 0.0413 Best IOU: 0.0413


Training: 100%|█████████████████████████████████████████████████████████████| 47/47 [07:15<00:00,  9.27s/it, loss=1.19]
Validating: 100%|████████████████████████████████████████████████████████████████████████| 9/9 [00:25<00:00,  2.78s/it]


Epoch 6 finished. Train Loss: 1.1883 Val Loss: 1.1971 IOU: 0.0334 Best IOU: 0.0413


Training: 100%|█████████████████████████████████████████████████████████████| 47/47 [07:16<00:00,  9.29s/it, loss=1.18]
Validating: 100%|████████████████████████████████████████████████████████████████████████| 9/9 [00:25<00:00,  2.79s/it]


Epoch 7 finished. Train Loss: 1.1779 Val Loss: 1.1920 IOU: 0.0076 Best IOU: 0.0413


Training: 100%|█████████████████████████████████████████████████████████████| 47/47 [07:17<00:00,  9.31s/it, loss=1.17]
Validating: 100%|████████████████████████████████████████████████████████████████████████| 9/9 [00:25<00:00,  2.78s/it]


Epoch 8 finished. Train Loss: 1.1697 Val Loss: 1.1725 IOU: 0.0392 Best IOU: 0.0413


Training: 100%|█████████████████████████████████████████████████████████████| 47/47 [07:16<00:00,  9.29s/it, loss=1.17]
Validating: 100%|████████████████████████████████████████████████████████████████████████| 9/9 [00:24<00:00,  2.78s/it]


Epoch 9 finished. Train Loss: 1.1675 Val Loss: 1.1785 IOU: 0.0184 Best IOU: 0.0413


Training: 100%|█████████████████████████████████████████████████████████████| 47/47 [07:16<00:00,  9.28s/it, loss=1.16]
Validating: 100%|████████████████████████████████████████████████████████████████████████| 9/9 [00:25<00:00,  2.79s/it]


Epoch 10 finished. Train Loss: 1.1558 Val Loss: 1.1871 IOU: 0.0065 Best IOU: 0.0413


Training: 100%|█████████████████████████████████████████████████████████████| 47/47 [07:15<00:00,  9.27s/it, loss=1.15]
Validating: 100%|████████████████████████████████████████████████████████████████████████| 9/9 [00:25<00:00,  2.79s/it]


Epoch 11 finished. Train Loss: 1.1524 Val Loss: 1.1755 IOU: 0.0465 Best IOU: 0.0465


Training: 100%|█████████████████████████████████████████████████████████████| 47/47 [07:15<00:00,  9.26s/it, loss=1.15]
Validating: 100%|████████████████████████████████████████████████████████████████████████| 9/9 [00:24<00:00,  2.77s/it]


Epoch 12 finished. Train Loss: 1.1497 Val Loss: 1.1586 IOU: 0.0504 Best IOU: 0.0504


Training: 100%|█████████████████████████████████████████████████████████████| 47/47 [07:16<00:00,  9.28s/it, loss=1.14]
Validating: 100%|████████████████████████████████████████████████████████████████████████| 9/9 [00:24<00:00,  2.77s/it]


Epoch 13 finished. Train Loss: 1.1427 Val Loss: 1.1676 IOU: 0.0443 Best IOU: 0.0504


Training: 100%|█████████████████████████████████████████████████████████████| 47/47 [07:13<00:00,  9.22s/it, loss=1.14]
Validating: 100%|████████████████████████████████████████████████████████████████████████| 9/9 [00:24<00:00,  2.76s/it]


Epoch 14 finished. Train Loss: 1.1382 Val Loss: 1.1515 IOU: 0.0910 Best IOU: 0.0910


Training:   4%|██▋                                                           | 2/47 [00:32<12:32, 16.73s/it, loss=1.11]