In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import timm
import segmentation_models_pytorch as smp
from pathlib import Path
from tqdm import tqdm

# Configuration
# Stage 1 gate: images whose classifier probability is below this are reported as 'authentic'.
CLASSIFIER_THRESHOLD = 0.25
# Stage 2: per-pixel probability cut-off for the segmentation ensemble
# (rescaled by image brightness when USE_ADAPTIVE is enabled).
SEG_THRESHOLD = 0.35
# Connected components smaller than this many pixels are discarded; the filter runs on the
# segmentation-resolution mask, before it is resized back to the original image size.
MIN_AREA = 300
# Square side lengths the images are resized to before each model.
CLASSIFIER_SIZE = 384
SEG_SIZE = 512
# Feature toggles: 4x flip test-time augmentation and brightness-adaptive thresholding.
USE_TTA = True
USE_ADAPTIVE = True

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Model Definitions

class ForgeryClassifier(nn.Module):
    """Image-level forgery classifier that emits one logit per image.

    A timm backbone, created headless via ``num_classes=0``, feeds a small MLP head.
    The target convention is forged = 1, authentic = 0; apply ``torch.sigmoid`` to the
    output to get the probability of "forged".
    """
    def __init__(self, backbone='efficientnet_b2'):
        super().__init__()
        self.backbone = timm.create_model(backbone, pretrained=False, num_classes=0)
        feat_dim = self.backbone.num_features
        # This layer order fixes the state-dict keys (classifier.0 / classifier.3)
        # that load_state_dict must match, so keep it as is.
        head_layers = [
            nn.Linear(feat_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        logits = self.classifier(self.backbone(x))
        # Drop the trailing size-1 logit dim; assumes the headless backbone returns
        # pooled (B, F) features, giving (B, 1) -> (B,). TODO confirm for other backbones.
        return logits.squeeze(-1)


class AttentionGate(nn.Module):
    """Spatial attention that rescales its input by a learned per-pixel gate in [0, 1].

    A 1x1 -> 3x3 -> 1x1 conv stack (BatchNorm + ReLU between) squeezes the channels
    down to a single map. A sigmoid turns it into the gate, which is multiplied back
    onto the input (broadcast over channels).
    """
    def __init__(self, in_channels):
        super().__init__()
        mid = max(8, in_channels // 2)
        narrow = max(4, in_channels // 4)
        # The layer order defines the state-dict keys conv.0 ... conv.6 that the
        # checkpoints are loaded against, so it must not change.
        layers = [
            nn.Conv2d(in_channels, mid, kernel_size=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid, narrow, kernel_size=3, padding=1),
            nn.BatchNorm2d(narrow),
            nn.ReLU(inplace=True),
            nn.Conv2d(narrow, 1, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        gate = self.conv(x)
        return x * gate


class AttentionFPN(nn.Module):
    """Wrap a segmentation network and gate its output with an ``AttentionGate``.

    The attribute names ``base`` and ``attention`` determine the state-dict prefixes
    (``base.*``, ``attention.*``) used when loading checkpoints.
    """
    def __init__(self, base_model):
        super().__init__()
        self.base = base_model
        # Single-channel gate: in this notebook the wrapped model is built with classes=1.
        self.attention = AttentionGate(1)

    def forward(self, x):
        raw_logits = self.base(x)
        return self.attention(raw_logits)

In [None]:
# Load Models
# NOTE: Update these paths to your Kaggle dataset paths

MODEL_DIR = Path('/kaggle/input/your-models-dataset')  # Update this!
if not MODEL_DIR.is_dir():
    # Fail early with an actionable message instead of a bare error from torch.load.
    raise FileNotFoundError(
        f"MODEL_DIR does not exist: {MODEL_DIR}. "
        "Point it at the Kaggle dataset that holds the .pth checkpoints."
    )


def _load_checkpoint_into(model, ckpt_path):
    """Load weights from ``ckpt_path`` into ``model``, switch it to eval mode, and return it.

    Accepts either a bare state dict or a training checkpoint that wraps the state dict
    under the ``'model_state_dict'`` key. The classifier and every segmentation model
    are loaded the same way.

    NOTE(review): ``torch.load`` unpickles the file, so only load trusted checkpoints.
    On PyTorch >= 2.6 ``weights_only`` defaults to True, which rejects checkpoints that
    pickle arbitrary Python objects (e.g. extra training metadata).
    """
    ckpt = torch.load(ckpt_path, map_location=device)
    if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
        ckpt = ckpt['model_state_dict']
    model.load_state_dict(ckpt)
    model.eval()
    return model


def _build_seg_model():
    """Build one untrained attention-gated FPN on ``device`` in the checkpoint's layout."""
    base = smp.FPN(
        encoder_name="timm-efficientnet-b2",
        encoder_weights=None,  # weights come from our checkpoint, not a pretrained download
        in_channels=3,
        classes=1,
    )
    return AttentionFPN(base).to(device)


# Load binary classifier
classifier = _load_checkpoint_into(
    ForgeryClassifier('efficientnet_b2').to(device),
    MODEL_DIR / 'binary_classifier_best.pth',
)
print("✓ Binary classifier loaded")

# Load 4-model ensemble
model_names = [
    'highres_no_ela_v4_best.pth',
    'hard_negative_v4_best.pth',
    'high_recall_v4_best.pth',
    'enhanced_aug_v4_best.pth'
]

seg_models = []
for name in model_names:
    seg_models.append(_load_checkpoint_into(_build_seg_model(), MODEL_DIR / name))
    print(f"  ✓ {name}")

print(f"✓ Loaded {len(seg_models)} ensemble models")

In [None]:
# Preprocessing functions

# ImageNet channel statistics (RGB order) used to normalise inputs for both models.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def _to_normalized_tensor(img, size):
    """Resize a BGR image to ``size`` x ``size`` and return a normalised model input.

    Args:
        img: 3-channel BGR image as returned by ``cv2.imread`` (presumably uint8,
            given the /255 scaling).
        size: target square side length in pixels.

    Returns:
        float32 tensor of shape (1, 3, size, size) on ``device``, RGB channel order,
        ImageNet-normalised.
    """
    img_rgb = cv2.cvtColor(cv2.resize(img, (size, size)), cv2.COLOR_BGR2RGB)
    img_norm = (img_rgb.astype(np.float32) / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
    # HWC -> CHW, then add the batch dimension.
    return torch.from_numpy(img_norm.transpose(2, 0, 1)).unsqueeze(0).float().to(device)


def preprocess_for_classifier(img):
    """Prepare a BGR image for the binary classifier (CLASSIFIER_SIZE input)."""
    return _to_normalized_tensor(img, CLASSIFIER_SIZE)


def preprocess_for_segmentation(img):
    """Prepare a BGR image for the segmentation ensemble (SEG_SIZE input)."""
    return _to_normalized_tensor(img, SEG_SIZE)


def apply_tta_ensemble(models, img_tensor):
    """Average sigmoid predictions over an ensemble with 4x flip test-time augmentation.

    Each model sees four views of the input:
    - the original;
    - flipped along dim 3 (horizontal);
    - flipped along dim 2 (vertical);
    - flipped along both.

    Each prediction is flipped back before use. The four views are averaged per model,
    and those per-model means are then averaged across the ensemble.
    """
    flip_variants = ([], [3], [2], [2, 3])

    def _predict_view(net, dims):
        if not dims:
            return torch.sigmoid(net(img_tensor))
        flipped_pred = torch.sigmoid(net(torch.flip(img_tensor, dims=dims)))
        return torch.flip(flipped_pred, dims=dims)

    with torch.no_grad():
        per_model_means = [
            torch.stack([_predict_view(net, dims) for dims in flip_variants]).mean(dim=0)
            for net in models
        ]
    return torch.stack(per_model_means).mean(dim=0)


def get_adaptive_threshold(img, base_threshold=0.35):
    """Scale the segmentation threshold by the image's mean grayscale brightness.

    Brightness is the mean of the grayscale image divided by 255. Dark images
    (< 0.3) get 0.85x the base threshold and bright images (> 0.7) get 1.15x.
    Anything in between returns ``base_threshold`` unchanged.
    """
    mean_luma = np.mean(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)) / 255.0
    if mean_luma < 0.3:
        scale = 0.85
    elif mean_luma > 0.7:
        scale = 1.15
    else:
        return base_threshold
    return base_threshold * scale

In [None]:
def mask_to_rle(mask):
    """Convert a binary mask to a run-length string.

    The mask is flattened in row-major (C) order and encoded as space-separated
    ``<gap> <length>`` pairs:
    - ``gap`` is the number of background pixels since the end of the previous
      foreground run (or since pixel 0 for the first run);
    - ``length`` is the size of the foreground run.

    Any non-zero pixel counts as foreground. An all-background or empty mask yields ``''``.

    Examples: ``[0, 1, 1, 0, 1]`` -> ``"1 2 1 1"``; ``[1, 1, 0]`` -> ``"0 2"``.

    NOTE(review): many Kaggle segmentation metrics expect absolute ``<start> <length>``
    pairs (often 1-indexed and column-major) rather than relative gaps. Confirm this
    encoding against the competition's submission format.
    """
    fg = np.asarray(mask).ravel() != 0
    # Pad with background on both sides so every foreground run has a rising and a falling edge.
    padded = np.concatenate(([False], fg, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    starts, ends = edges[0::2], edges[1::2]
    lengths = ends - starts
    # Background gap before each run: measured from pixel 0 for the first run,
    # otherwise from the end of the previous run.
    gaps = starts - np.concatenate(([0], ends[:-1]))
    return ' '.join(f"{g} {n}" for g, n in zip(gaps, lengths))


def _segmentation_probs(seg_tensor, seg_models):
    """Return the ensemble's mean foreground-probability map as a numpy array.

    Uses flip TTA when USE_TTA is set, otherwise a plain mean of per-model sigmoids.
    """
    if USE_TTA:
        pred = apply_tta_ensemble(seg_models, seg_tensor)
    else:
        with torch.no_grad():
            pred = torch.stack([torch.sigmoid(m(seg_tensor)) for m in seg_models]).mean(dim=0)
    # Drop the singleton batch/channel dims -> 2-D map at segmentation resolution.
    return pred.squeeze().cpu().numpy()


def _drop_small_components(mask, min_area):
    """Zero out 8-connected foreground components with fewer than ``min_area`` pixels."""
    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False  # label 0 is the background
    # One lookup over the label image instead of a full-image pass per component.
    return keep[labels].astype(np.uint8)


def process_image(img_path, classifier, seg_models):
    """Two-stage processing of a single image.

    Stage 1: the binary classifier returns 'authentic' when its probability is
    below CLASSIFIER_THRESHOLD.
    Stage 2: the segmentation ensemble produces a probability map. The map is
    thresholded (optionally brightness-adaptive), stripped of components smaller
    than MIN_AREA, and resized back to the original image size.

    Returns:
        'authentic' when the image is unreadable, rejected by the classifier, or ends
        up with an empty mask; otherwise the ``mask_to_rle`` string of the mask.
    """
    import warnings

    img = cv2.imread(str(img_path))
    if img is None:
        # Keep the best-effort fallback, but make the failure visible.
        warnings.warn(f"Could not read image {img_path}; predicting 'authentic'.")
        return 'authentic'

    orig_h, orig_w = img.shape[:2]

    # Stage 1: image-level gate.
    with torch.no_grad():
        cls_prob = torch.sigmoid(classifier(preprocess_for_classifier(img))).item()
    if cls_prob < CLASSIFIER_THRESHOLD:
        return 'authentic'

    # Stage 2: segmentation with the ensemble.
    probs = _segmentation_probs(preprocess_for_segmentation(img), seg_models)
    threshold = get_adaptive_threshold(img, SEG_THRESHOLD) if USE_ADAPTIVE else SEG_THRESHOLD
    mask = (probs > threshold).astype(np.uint8)

    if MIN_AREA > 0:
        mask = _drop_small_components(mask, MIN_AREA)

    # Nearest-neighbour keeps the resized mask strictly binary.
    mask_orig = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
    if not mask_orig.any():
        return 'authentic'
    return mask_to_rle(mask_orig)

In [None]:
# Process test images and generate submission

# NOTE(review): this points at the `validation_images` split although the messages say
# "test images" - confirm this is the folder you intend to submit predictions for.
TEST_DIR = Path('/kaggle/input/recodai-luc-scientific-image-forgery-detection/validation_images')

# Only regular files are candidate images; skip any sub-directories.
image_files = sorted(p for p in TEST_DIR.glob('*') if p.is_file())
print(f"Found {len(image_files)} test images")

results = [
    {'case_id': img_path.stem, 'annotation': process_image(img_path, classifier, seg_models)}
    for img_path in tqdm(image_files, desc="Processing")
]

# Explicit columns keep the CSV header (and the summary below) valid even with zero images.
submission = pd.DataFrame(results, columns=['case_id', 'annotation'])
# Two files sharing a stem (e.g. x.png and x.jpg) would yield duplicate case_ids.
assert submission['case_id'].is_unique, "duplicate case_id values in submission"
submission.to_csv('submission.csv', index=False)

is_forged = submission['annotation'] != 'authentic'
print("\nSubmission saved!")
print(f"Total: {len(submission)}")
print(f"Forged: {is_forged.sum()}")
print(f"Authentic: {(~is_forged).sum()}")
submission.head(10)