In [None]:
# Install required packages (if not already available)
# Suppress dependency warnings for pre-installed Kaggle packages
!pip install torch torchvision tqdm Pillow numpy --quiet --no-warn-conflicts 2>/dev/null || \
 pip install torch torchvision tqdm Pillow numpy -q

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import os

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## Utility Functions

In [None]:
# Per-channel ImageNet mean / standard deviation (RGB order)
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize(tensor):
    """Standardise a [0, 1] image batch with the ImageNet channel statistics.

    The statistics are reshaped to (1, 3, 1, 1) so they broadcast over a
    (N, 3, H, W) input, and are created on the same device as ``tensor``.
    """
    stats_shape = (1, 3, 1, 1)
    channel_mean = torch.tensor(MEAN, device=tensor.device).reshape(stats_shape)
    channel_std = torch.tensor(STD, device=tensor.device).reshape(stats_shape)
    centred = tensor - channel_mean
    return centred / channel_std

def create_border_mask(h, w, border_width, device):
    """Build a (1, 3, h, w) mask that is 1 on the border frame and 0 inside it.

    The frame is ``border_width`` pixels thick on all four sides.
    """
    # Start from an all-ones canvas and carve out the interior rectangle;
    # what remains set is exactly the top/bottom/left/right strips.
    mask = torch.ones((1, 3, h, w), device=device)
    mask[:, :, border_width:h - border_width, border_width:w - border_width] = 0
    return mask

def stack_borders(image_tensor, border_width):
    """Cut the four border strips out of an image batch and stack them vertically.

    The top and bottom strips are used as-is. The left and right strips
    (corners excluded) are transposed into horizontal strips and bilinearly
    resized to ``(border_width, W)`` so all four share one width. The stack is
    then doubled along the height axis until it is at least 64 rows tall,
    giving the fidelity network enough spatial extent.
    """
    height, width = image_tensor.shape[2], image_tensor.shape[3]
    inner_rows = slice(border_width, height - border_width)

    strips = [
        image_tensor[:, :, :border_width, :],            # top
        image_tensor[:, :, height - border_width:, :],   # bottom
    ]

    vertical_sides = (
        image_tensor[:, :, inner_rows, :border_width],           # left
        image_tensor[:, :, inner_rows, width - border_width:],   # right
    )
    for side in vertical_sides:
        # Swap H and W so the vertical strip lies flat, then match the image width
        flat_side = side.permute(0, 1, 3, 2)
        strips.append(
            torch.nn.functional.interpolate(
                flat_side,
                size=(border_width, width),
                mode='bilinear',
                align_corners=False,
            )
        )

    stacked = torch.cat(strips, dim=2)

    # Keep doubling until the stack is tall enough for VGG
    while stacked.size(2) < 64:
        stacked = torch.cat([stacked, stacked], dim=2)

    return stacked

def truncation_loss(adv_image):
    """Mean absolute distance between each pixel and its nearest 8-bit level.

    The rounded target is detached, so gradients flow only through
    ``adv_image`` and pull every value towards a multiple of 1/255.
    """
    nearest_level = (adv_image * 255.0).round().detach() / 255.0
    return (nearest_level - adv_image).abs().mean()

def load_image(path, size=224):
    """Load an image file as a (1, 3, size, size) float tensor in [0, 1].

    Args:
        path: Path of the image file to read.
        size: Side length the image is resized to (aspect ratio is not kept).

    Returns:
        Tensor of shape (1, 3, size, size) produced by ``ToTensor``.
    """
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])
    # Open inside a context manager so the file handle is released
    # deterministically; this function is called once per image in long loops.
    with Image.open(path) as source:
        img = source.convert('RGB')
    return transform(img).unsqueeze(0)

def save_image(tensor, path):
    """Write a [0, 1] image tensor to ``path`` and return the PIL image.

    The tensor is detached, moved to the CPU, stripped of a leading batch
    dimension of size 1 and clamped to [0, 1] before conversion.
    """
    pixels = torch.clamp(tensor.detach().cpu().squeeze(0), 0, 1)
    picture = transforms.ToPILImage()(pixels)
    picture.save(path)
    return picture

## Model Definitions

In [None]:
class TargetModel(nn.Module):
    """Frozen, pretrained ResNet50 classifier that accepts [0, 1] images."""

    def __init__(self, device):
        super().__init__()
        self.device = device
        self.model = models.resnet50(pretrained=True).to(device)
        # Inference only: evaluation mode and no parameter gradients
        self.model.eval()
        self.model.requires_grad_(False)

    def forward(self, x):
        """Apply ImageNet normalization to ``x`` and return the class logits."""
        return self.model(normalize(x))

class Vgg19Fidelity(nn.Module):
    """Frozen VGG19 feature extractor used for content and style losses."""

    def __init__(self, device):
        super().__init__()
        backbone = models.vgg19(pretrained=True).features.to(device)
        self.model = backbone.eval()
        self.model.requires_grad_(False)

        # Index of each tapped layer inside ``vgg19.features`` -> feature name
        self.layer_map = {
            '0': 'conv_1',
            '5': 'conv_2',
            '10': 'conv_3',
            '19': 'conv_4',
            '28': 'conv_5',
        }
        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

    @staticmethod
    def _gram(feature_map):
        """Gram matrix of a (B, C, H, W) feature map, scaled by 1 / (C * H * W)."""
        batch, channels, rows, cols = feature_map.shape
        flat = feature_map.view(batch, channels, rows * cols)
        return torch.bmm(flat, flat.transpose(1, 2)) / (channels * rows * cols)

    def forward(self, x):
        """Return a dict mapping feature name -> activation for the tapped layers."""
        activation = normalize(x)
        captured = {}
        # Every layer is run, including those after the last tap.
        # NOTE(review): torchvision's VGG ReLU layers are presumably built with
        # inplace=True, in which case each stored tensor is overwritten by the
        # ReLU that follows it - confirm whether post-ReLU features are intended.
        for index, module in self.model.named_children():
            activation = module(activation)
            label = self.layer_map.get(index)
            if label is not None:
                captured[label] = activation
        return captured

    def compute_loss(self, adv_stacked, clean_stacked):
        """Return ``(content_loss, style_loss)`` between two stacked-border batches.

        Content loss is the mean squared difference of the content-layer
        activations; style loss is the mean squared difference of the Gram
        matrices, summed over the style layers.
        """
        adv_feats = self(adv_stacked)
        clean_feats = self(clean_stacked)

        # Content loss
        loss_c = 0.0
        for name in self.content_layers:
            delta = adv_feats[name] - clean_feats[name]
            loss_c = loss_c + torch.mean(delta ** 2)

        # Style loss (Gram matrices)
        loss_s = 0.0
        for name in self.style_layers:
            gram_delta = self._gram(adv_feats[name]) - self._gram(clean_feats[name])
            loss_s = loss_s + torch.mean(gram_delta ** 2)

        return loss_c, loss_s

## Untargeted Attack Implementation

In [None]:
class UntargetedBorderAttack:
    """Untargeted adversarial attack that modifies only image borders.

    An L-BFGS optimizer perturbs the pixels selected by ``create_border_mask``
    to minimise ``lambda_a * attack + lambda_f * (content + style) + truncation``
    until the target model no longer predicts the original class.
    """

    def __init__(self, target_model, fidelity_model, device,
                 border_width=4, lambda_a=1.5, lambda_f=1000.0, max_iter=50):
        """
        Args:
            target_model: Classifier under attack; called on [0, 1] images.
            fidelity_model: Object exposing ``compute_loss(adv, clean)`` that
                returns a (content, style) loss pair.
            device: Device on which the border mask is created.
            border_width: Thickness in pixels of the editable frame.
            lambda_a: Weight of the attack (margin) loss.
            lambda_f: Weight of the fidelity (content + style) loss.
            max_iter: Maximum number of outer optimizer steps.
        """
        self.target_model = target_model
        self.fidelity_model = fidelity_model
        self.device = device
        self.border_width = border_width
        self.lambda_a = lambda_a
        self.lambda_f = lambda_f
        self.max_iter = max_iter

    def _untargeted_loss(self, logits, original_class):
        """
        Margin loss for untargeted attack.
        L_margin = Z(x')_y - max(Z(x')_i) for i != y
        We minimize this to reduce the score of true class and increase others.

        Only row 0 of ``logits`` is used. Returns a tensor of shape (1,).
        """
        z_y = logits[0, original_class]

        # Get max of other classes (work on a copy so the caller's logits stay intact)
        logits_others = logits.clone()
        logits_others[0, original_class] = -float('inf')
        z_max_other, _ = torch.max(logits_others, dim=1)

        l_margin = z_y - z_max_other

        # Leaky ReLU: max(L_margin, 0.1 * L_margin) keeps a small gradient
        # once the margin has become negative.
        l_a = torch.max(l_margin, 0.1 * l_margin)
        return l_a

    def run(self, image, original_class):
        """
        Performs untargeted attack.

        Args:
            image: Input image tensor (1, 3, H, W) in range [0, 1]
            original_class: Ground truth class to avoid

        Returns:
            adv_image: Adversarial image (clamped to [0, 1], identical to
                ``image`` outside the border)
            success: Whether attack succeeded (misclassified). The flag is
                evaluated on the clamped, border-only image that is returned.
        """
        _, _, h, w = image.shape
        mask = create_border_mask(h, w, self.border_width, self.device)
        inverse_mask = 1 - mask

        # Initialize adversarial image
        adv_image = image.clone().detach().requires_grad_(True)
        optimizer = optim.LBFGS([adv_image], max_iter=20, history_size=10)

        clean_stacked_borders = stack_borders(image.detach(), self.border_width)

        def project():
            """Clamp to [0, 1] and restore every non-border pixel from the clean image."""
            with torch.no_grad():
                adv_image.clamp_(0, 1)
                adv_image.data = adv_image.data * mask + image.data * inverse_mask

        def closure():
            optimizer.zero_grad()

            # Enforce constraints
            project()

            logits = self.target_model(adv_image)

            # Attack Loss (untargeted)
            l_attack = self._untargeted_loss(logits, original_class)

            # Fidelity Loss
            adv_stacked = stack_borders(adv_image, self.border_width)
            l_c, l_s = self.fidelity_model.compute_loss(adv_stacked, clean_stacked_borders)
            l_fidelity = l_c + l_s

            # Truncation Loss
            l_trunc = truncation_loss(adv_image * mask)

            # Total Loss
            total_loss = (self.lambda_a * l_attack) + (self.lambda_f * l_fidelity) + l_trunc

            total_loss.backward()
            # Only border pixels may receive an update
            adv_image.grad.data.mul_(mask)

            return total_loss

        success = False

        print(f"Running untargeted attack (original class: {original_class})...")

        for iteration in range(self.max_iter):
            optimizer.step(closure)

            # Project BEFORE checking success so the prediction is made on the
            # same image that will be returned. The next closure call would
            # apply this projection anyway, so the optimization is unchanged.
            project()

            # Check success
            with torch.no_grad():
                logits = self.target_model(adv_image)
                pred = torch.argmax(logits, dim=1).item()

            if pred != original_class:  # Misclassified = success
                success = True
                print(f"✓ Attack succeeded at iteration {iteration}! Misclassified to class {pred}")
                break

            if (iteration + 1) % 10 == 0:
                print(f"Iter {iteration + 1}: Prediction {pred}")

        # Final cleanup (also covers max_iter == 0, where the loop never runs)
        project()

        if not success:
            print(f"✗ Attack failed after {self.max_iter} iterations")

        return adv_image, success

## Load Models

In [None]:
# Build the classifier under attack first, then the VGG19 fidelity network
print("Loading models...")
target_model, fidelity_model = TargetModel(device), Vgg19Fidelity(device)
print("Models loaded successfully!")

## Configuration & Setup

Upload your test images to Kaggle and update the path below.

In [None]:
# Run configuration
IMAGE_DIR = '/kaggle/input/imagenet-validation-set/processed/processed'  # UPDATE THIS PATH
OUTPUT_DIR = '/kaggle/working/results'
BORDER_WIDTH = 4
MAX_ITERATIONS = 50

# Make sure the results folder exists before anything is written to it
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(
    "Configuration loaded.",
    f"Image directory: {IMAGE_DIR}",
    f"Output directory: {OUTPUT_DIR}",
    sep="\n",
)

## Run Attack

In [None]:
# Batch Attack on All Images
import glob

print(f"Scanning directory: {IMAGE_DIR}")
# Support multiple image formats including .pt tensor files.
# The set drops duplicates (on case-insensitive filesystems '*.jpg' and '*.JPG'
# match the same files) and sorting makes the processing order reproducible,
# since glob returns paths in arbitrary order.
image_paths = sorted({
    path
    for ext in ['*.pt', '*.jpg', '*.jpeg', '*.png', '*.JPEG', '*.JPG', '*.PNG']
    for path in glob.glob(os.path.join(IMAGE_DIR, ext))
})

print(f"Found {len(image_paths)} images to attack")

# Initialize attacker
attacker = UntargetedBorderAttack(
    target_model,
    fidelity_model,
    device,
    border_width=BORDER_WIDTH,
    max_iter=MAX_ITERATIONS,
    lambda_a=1.5,  # Higher for untargeted attacks
    lambda_f=1000.0
)

results = []
success_count = 0
attacked_count = 0  # images processed without error (denominator of the ASR)

for i, img_path in enumerate(tqdm(image_paths, desc="Attacking images")):
    try:
        # Load image (handle both .pt files and regular images)
        if img_path.endswith('.pt'):
            # NOTE(security): torch.load unpickles the file; only point
            # IMAGE_DIR at tensor files from a trusted source.
            image = torch.load(img_path, map_location=device)
            # Ensure correct shape (1, 3, H, W)
            if image.dim() == 3:
                image = image.unsqueeze(0)
        else:
            image = load_image(img_path).to(device)

        # Get original prediction
        with torch.no_grad():
            orig_logits = target_model(image)
            orig_pred = torch.argmax(orig_logits, dim=1).item()
            orig_conf = torch.softmax(orig_logits, dim=1)[0, orig_pred].item()

        # Run attack
        adv_image, _ = attacker.run(image, orig_pred)

        # Get adversarial prediction
        with torch.no_grad():
            adv_logits = target_model(adv_image)
            adv_pred = torch.argmax(adv_logits, dim=1).item()
            adv_conf = torch.softmax(adv_logits, dim=1)[0, adv_pred].item()

        # Judge success on the image that is actually returned and saved, so
        # the flag can never contradict the adversarial_pred stored in the row.
        success = adv_pred != orig_pred

        # Save adversarial image BEFORE recording the result: if saving fails
        # the image gets exactly one (error) row and is not counted as success.
        base_name = os.path.splitext(os.path.basename(img_path))[0]
        save_path = os.path.join(OUTPUT_DIR, f"adv_{base_name}.png")
        save_image(adv_image, save_path)

    except Exception as e:
        # Best effort: log the failure, record it, and continue with the next image
        print(f"\nError processing {img_path}: {e}")
        results.append({
            'filename': os.path.basename(img_path),
            'original_pred': -1,
            'original_conf': 0.0,
            'adversarial_pred': -1,
            'adversarial_conf': 0.0,
            'success': False,
            'error': str(e)
        })
        continue

    attacked_count += 1
    if success:
        success_count += 1

    results.append({
        'filename': os.path.basename(img_path),
        'original_pred': orig_pred,
        'original_conf': orig_conf,
        'adversarial_pred': adv_pred,
        'adversarial_conf': adv_conf,
        'success': success
    })

    # Print progress every 100 images (errored images are excluded from the rate)
    if (i + 1) % 100 == 0:
        current_asr = (success_count / attacked_count) * 100
        print(f"\nProgress: {i+1}/{len(image_paths)} | Current ASR: {current_asr:.2f}%")

# Calculate final statistics
total_attacked = attacked_count
final_asr = (success_count / total_attacked * 100) if total_attacked > 0 else 0

print(f"\n{'='*60}")
print("ATTACK SUMMARY")
print(f"{'='*60}")
print(f"Total images found: {len(image_paths)}")
print(f"Successfully attacked: {total_attacked}")
print(f"Successful attacks: {success_count}")
print(f"Attack Success Rate (ASR): {final_asr:.2f}%")
print(f"Results saved to: {OUTPUT_DIR}")
print(f"{'='*60}")

## Visualize Results

In [None]:
# Visualize Sample Results (first successful attack)
successful_results = [r for r in results if r.get('success', False)]

if successful_results:
    # Get first successful attack
    sample = successful_results[0]
    sample_path = os.path.join(IMAGE_DIR, sample['filename'])

    # Load original image (handle .pt files)
    if sample_path.endswith('.pt'):
        orig_image = torch.load(sample_path, map_location='cpu')
        if orig_image.dim() == 3:
            orig_image = orig_image.unsqueeze(0)
    else:
        orig_image = load_image(sample_path)

    adv_path = os.path.join(OUTPUT_DIR, f"adv_{os.path.splitext(sample['filename'])[0]}.png")

    orig_img = save_image(orig_image, os.path.join(OUTPUT_DIR, 'sample_original.png'))

    # Read the adversarial PNG at its stored resolution. Going through
    # load_image would force a 224x224 resize, which breaks the subtraction
    # below whenever the original (e.g. a .pt tensor) has another size.
    with Image.open(adv_path) as adv_file:
        adv_img = adv_file.convert('RGB')

    # Compute difference
    adv_tensor = transforms.ToTensor()(adv_img).unsqueeze(0)
    diff = torch.abs(adv_tensor - orig_image)
    diff_amplified = diff * 10
    diff_img = save_image(diff_amplified, os.path.join(OUTPUT_DIR, 'sample_difference.png'))

    # Display
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(orig_img)
    axes[0].set_title(f'Original\nPred: {sample["original_pred"]} ({sample["original_conf"]:.2f})')
    axes[0].axis('off')

    axes[1].imshow(adv_img)
    axes[1].set_title(f'Adversarial\nPred: {sample["adversarial_pred"]} ({sample["adversarial_conf"]:.2f})')
    axes[1].axis('off')

    axes[2].imshow(diff_img)
    axes[2].set_title('Difference (10x amplified)')
    axes[2].axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'sample_comparison.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"Sample visualization from: {sample['filename']}")
else:
    print("No successful attacks to visualize.")

## Save Results & Summary Statistics

Write the per-image attack results to a CSV file and print summary statistics.

In [None]:
# Save Detailed Results to CSV
import csv

csv_path = os.path.join(OUTPUT_DIR, 'attack_results.csv')

with open(csv_path, 'w', newline='') as f:
    if results:
        # Rows for failed images carry an extra 'error' key. Build the header
        # from the union of all keys (first-seen order); taking it from
        # results[0] alone makes DictWriter raise ValueError on the first row
        # that has a key missing from the header.
        fieldnames = list(dict.fromkeys(key for row in results for key in row))
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)

print(f"Detailed results saved to: {csv_path}")

# Show summary statistics
print("\n" + "="*60)
print("DETAILED STATISTICS")
print("="*60)

successful_results = [r for r in results if r.get('success', False)]
failed_results = [r for r in results if not r.get('success', False) and 'error' not in r]

if successful_results:
    avg_orig_conf = sum(r['original_conf'] for r in successful_results) / len(successful_results)
    avg_adv_conf = sum(r['adversarial_conf'] for r in successful_results) / len(successful_results)
    print(f"Successful Attacks: {len(successful_results)}")
    print(f"  - Avg Original Confidence: {avg_orig_conf:.4f}")
    print(f"  - Avg Adversarial Confidence: {avg_adv_conf:.4f}")

if failed_results:
    print(f"\nFailed Attacks: {len(failed_results)}")

# Images that raised during loading, attacking or saving
errors = [r for r in results if 'error' in r]
if errors:
    print(f"Errors: {len(errors)}")
    print("\nFirst 5 errors:")
    for r in errors[:5]:
        print(f"  - {r['filename']}: {r.get('error', 'Unknown')}")

print("="*60)