# DRAEM Training on Kaggle - Stage 2: Defect Localization

This notebook trains the DRAEM model for bike defect localization.

**No annotation required!** The model trains only on intact (defect-free) images; synthetic defects are generated on the fly with Perlin-noise masks.

---

## Setup

### 1. Upload Your Data

Upload your data folder to Kaggle with this structure (files ending in `.jpg`, `.jpeg` or `.png` are picked up, including those in subfolders):
```
/kaggle/input/bike-data/
└── intact/
    ├── bike001.jpg
    ├── bike002.jpg
    └── ...
```

### 2. Run All Cells

Set `CONFIG['intact_dir']` in the Configuration section to your dataset path, then run all cells in order.

---

## 1. Install Dependencies

In [None]:
# Install required packages
!pip install -q opencv-python
!pip install -q scikit-image
!pip install -q albumentations

print("✅ Dependencies installed!")

---

## 2. Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import random
import os

# Set seeds for reproducibility
# PyTorch, NumPy and Python's `random` module each keep their own RNG state,
# so every one of them is seeded explicitly with the same value.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Report the runtime environment so it is recorded in the notebook output.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Device index 0 is queried; total memory is reported in GB (bytes / 1e9).
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

---

## 3. Configuration

In [None]:
# ============================================================================
# CONFIGURATION - Adjust these settings
# ============================================================================

# Single notebook-wide settings dictionary; later cells read their
# parameters from it by key.
CONFIG = {
    # Data paths
    'intact_dir': '/kaggle/input/bike-data/intact',  # ⚠️ Change this to your data path!
    'anomaly_source_path': None,  # Optional: path to texture images (DTD dataset)

    # Training parameters
    'epochs': 100,
    'batch_size': 8,  # Reduce to 4 if GPU memory issues
    'learning_rate': 0.0001,
    'image_size': 256,  # Images are resized to image_size x image_size

    # Device
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # Data loading
    'num_workers': 2,
    'pin_memory': True,

    # Perlin noise parameters
    # The anomaly generator passes these on as the bounds of the Perlin scale
    # exponent: the grid resolution is 2 ** randint(min, max), upper bound exclusive.
    'min_perlin_scale': 0,
    'max_perlin_scale': 6,

    # Output
    'output_dir': '/kaggle/working',
    'save_interval': 10,  # Save checkpoint every N epochs
}

# Echo the effective settings so they are recorded in the notebook output.
print("Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

# Create output directory
os.makedirs(CONFIG['output_dir'], exist_ok=True)
print(f"\n✅ Output directory: {CONFIG['output_dir']}")

---

## 4. Verify Data

In [None]:
# Check if data exists
# NOTE: a missing directory is only reported, not raised, so later cells will
# still execute (and fail) if the path is wrong.
intact_dir = Path(CONFIG['intact_dir'])

if not intact_dir.exists():
    print(f"❌ ERROR: Data directory not found: {intact_dir}")
    print("\nPlease:")
    print("1. Upload your data to Kaggle")
    print("2. Update CONFIG['intact_dir'] in the cell above")
    print("3. Re-run this cell")
else:
    # Count images
    # The search is recursive ('**/') and matches these lowercase extensions only.
    image_files = []
    for ext in ['*.jpg', '*.png', '*.jpeg']:
        image_files.extend(list(intact_dir.glob(f'**/{ext}')))

    print(f"✅ Data directory found: {intact_dir}")
    print(f"✅ Found {len(image_files)} intact images")

    # Show sample images
    if len(image_files) > 0:
        # One row with up to five previews.
        fig, axes = plt.subplots(1, min(5, len(image_files)), figsize=(15, 3))
        if len(image_files) == 1:
            # With a single subplot `plt.subplots` returns a bare Axes; wrap it
            # so the indexing in the loop below works.
            axes = [axes]

        for i, img_path in enumerate(image_files[:5]):
            img = Image.open(img_path)
            axes[i].imshow(img)
            # Title shows the file name and the PIL size tuple (width, height).
            axes[i].set_title(f'{img_path.name}\n{img.size}')
            axes[i].axis('off')

        plt.tight_layout()
        plt.show()

        print(f"\n✅ Ready to train!")
    else:
        print("\n❌ No images found! Please check your data directory.")

---

## 5. Perlin Noise Generator

In [None]:
# ============================================================================
# Perlin Noise for Synthetic Anomaly Generation
# ============================================================================

def rand_perlin_2d(shape, res, fade=lambda t: 6*t**5 - 15*t**4 + 10*t**3):
    """Generate one octave of 2D Perlin noise.

    Args:
        shape: (rows, cols) of the output array. Each entry is expected to be
            a multiple of the matching entry of ``res``; otherwise the repeated
            gradient arrays do not line up with the sample grid.
        res: (cells along axis 0, cells along axis 1) of the gradient lattice.
        fade: easing function applied to the in-cell offsets before
            interpolation (default: the quintic 6t^5 - 15t^4 + 10t^3).

    Returns:
        Float array of shape ``shape``.
    """
    rows, cols = shape[0], shape[1]
    cells0, cells1 = res[0], res[1]
    step0, step1 = cells0 / rows, cells1 / cols      # lattice units per pixel
    rep0, rep1 = rows // cells0, cols // cells1      # pixels per lattice cell

    # Fractional position of every pixel inside its lattice cell, shape (rows, cols, 2).
    frac = np.mgrid[0:cells0:step0, 0:cells1:step1].transpose(1, 2, 0) % 1
    off0, off1 = frac[:, :, 0], frac[:, :, 1]

    # One random unit gradient per lattice node.
    theta = 2 * np.pi * np.random.rand(cells0 + 1, cells1 + 1)
    lattice = np.dstack((np.cos(theta), np.sin(theta)))

    def corner_dot(rows_sel, cols_sel, d0, d1):
        # Expand the selected node gradients to pixel resolution and take the
        # dot product with the offset vector pointing from that corner.
        grad = lattice[rows_sel, cols_sel].repeat(rep0, 0).repeat(rep1, 1)
        return d0 * grad[:, :, 0] + d1 * grad[:, :, 1]

    near, far = slice(0, -1), slice(1, None)
    n00 = corner_dot(near, near, off0, off1)
    n10 = corner_dot(far, near, off0 - 1, off1)
    n01 = corner_dot(near, far, off0, off1 - 1)
    n11 = corner_dot(far, far, off0 - 1, off1 - 1)

    # Smooth interpolation: first along axis 0, then along axis 1.
    eased = fade(frac)
    t0, t1 = eased[:, :, 0], eased[:, :, 1]
    along0_near = n00 * (1 - t0) + t0 * n10
    along0_far = n01 * (1 - t0) + t0 * n11

    return np.sqrt(2) * ((1 - t1) * along0_near + t1 * along0_far)


def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5):
    """Sum several Perlin octaves into one noise field.

    Octave k uses a lattice resolution of ``2**k * res`` and is weighted by
    ``persistence**k`` (accumulated by repeated multiplication).

    Args:
        shape: (rows, cols) of the output array.
        res: base lattice resolution (axis 0, axis 1).
        octaves: number of octaves to accumulate.
        persistence: amplitude multiplier applied after each octave.

    Returns:
        Float array of shape ``shape``.
    """
    layered = np.zeros(shape)
    freq, amp = 1, 1

    for _octave in range(octaves):
        octave_res = (freq * res[0], freq * res[1])
        layered += amp * rand_perlin_2d(shape, octave_res)
        # Next octave: twice the detail, scaled-down contribution.
        freq, amp = freq * 2, amp * persistence

    return layered


def generate_perlin_noise_mask(img_size, min_perlin_scale=0, max_perlin_scale=6):
    """Generate a binary anomaly mask by thresholding Perlin noise.

    Args:
        img_size: (rows, cols) of the mask.
        min_perlin_scale: lower bound (inclusive) of the scale exponent.
        max_perlin_scale: upper bound (exclusive) of the scale exponent.

    Returns:
        Float array of shape ``img_size`` holding 1.0 where the noise exceeds
        a random threshold drawn from [0.0, 0.5) and 0.0 elsewhere.
    """
    # Lattice resolution is a random power of two.
    exponent = np.random.randint(min_perlin_scale, max_perlin_scale)
    cells = 2 ** exponent

    noise_field = rand_perlin_2d_octaves(img_size, (cells, cells), octaves=3)

    # The threshold is drawn after the noise so the RNG call order is fixed.
    cutoff = np.random.uniform(0.0, 0.5)
    return np.where(noise_field > cutoff, 1.0, 0.0)


def generate_smooth_anomaly(img_size, min_perlin_scale=0, max_perlin_scale=6):
    """Generate a smooth weight map from Perlin noise, rescaled to [0, 1].

    Args:
        img_size: (rows, cols) of the map.
        min_perlin_scale: lower bound (inclusive) of the scale exponent.
        max_perlin_scale: upper bound (exclusive) of the scale exponent.

    Returns:
        Float array of shape ``img_size`` after min-max normalization.
    """
    exponent = np.random.randint(min_perlin_scale, max_perlin_scale)
    cells = 2 ** exponent

    field = rand_perlin_2d_octaves(
        img_size, (cells, cells), octaves=4, persistence=0.6
    )

    # Min-max normalization; the epsilon guards against a constant field.
    lowest, highest = field.min(), field.max()
    return (field - lowest) / (highest - lowest + 1e-8)


print("✅ Perlin noise functions defined")

# Test Perlin noise
# Visual sanity check: a 3-octave noise field on a 4x4 base lattice, and a
# thresholded mask generated with the default scale range.
test_noise = rand_perlin_2d_octaves((128, 128), (4, 4), octaves=3)
test_mask = generate_perlin_noise_mask((128, 128))

# Left panel: raw noise. Right panel: binary mask.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(test_noise, cmap='gray')
axes[0].set_title('Perlin Noise')
axes[0].axis('off')

axes[1].imshow(test_mask, cmap='gray')
axes[1].set_title('Binary Mask')
axes[1].axis('off')

plt.tight_layout()
plt.show()

print("✅ Perlin noise working correctly!")

---

## 6. Anomaly Generator

In [None]:
# ============================================================================
# Synthetic Anomaly Generator
# ============================================================================

class AnomalyGenerator:
    """Generate synthetic anomalies for training.

    An anomaly is produced by drawing a Perlin-noise mask and altering the
    image inside that mask with one of three effects: blending in an external
    texture (only when texture images were found), blending in blurred random
    noise, or scaling the brightness.

    The Perlin scale bounds are read from the notebook-level ``CONFIG``
    dictionary when ``generate_anomaly`` is called.
    """

    def __init__(self, anomaly_source_path=None, resize_shape=(256, 256)):
        """
        Args:
            anomaly_source_path: optional directory of texture images; when
                missing or non-existent only noise/brightness anomalies are used.
            resize_shape: stored on the instance; not read by the methods of
                this class.
        """
        self.resize_shape = resize_shape
        self.anomaly_source_path = anomaly_source_path
        self.anomaly_source_images = []

        if anomaly_source_path and Path(anomaly_source_path).exists():
            self.anomaly_source_images = self._load_anomaly_sources()
            print(f"[OK] Loaded {len(self.anomaly_source_images)} anomaly source images")
        else:
            print("[INFO] No anomaly source images, using only Perlin noise")

    def _load_anomaly_sources(self):
        """Return the texture image paths found recursively under the source directory."""
        source_path = Path(self.anomaly_source_path)
        image_files = []
        for ext in ['*.jpg', '*.png', '*.jpeg']:
            image_files.extend(source_path.glob(f'**/{ext}'))
        # glob order depends on the filesystem; sorting keeps the seeded
        # `random.choice` over this list reproducible across machines.
        return sorted(image_files)

    def generate_anomaly(self, img):
        """Generate a synthetic anomaly on a normal image.

        Args:
            img: PIL image or H x W x 3 array with values in the 0-255 range.

        Returns:
            (augmented_img, anomaly_mask): a uint8 array shaped like the input
            and a float32 H x W mask that is 1.0 wherever the (augmented)
            Perlin mask is non-zero.
        """
        if isinstance(img, Image.Image):
            img = np.array(img)

        h, w = img.shape[:2]

        # Generate Perlin noise mask
        perlin_mask = generate_perlin_noise_mask(
            (h, w),
            min_perlin_scale=CONFIG['min_perlin_scale'],
            max_perlin_scale=CONFIG['max_perlin_scale']
        )

        # Augment mask
        perlin_mask = self._augment_mask(perlin_mask)

        # Choose anomaly type; 'texture' is only offered when sources exist,
        # so no second availability check is needed when dispatching.
        if len(self.anomaly_source_images) > 0:
            anomaly_type = random.choice(['texture', 'noise', 'brightness'])
        else:
            anomaly_type = random.choice(['noise', 'brightness'])

        if anomaly_type == 'texture':
            augmented_img = self._texture_anomaly(img, perlin_mask)
        elif anomaly_type == 'noise':
            augmented_img = self._noise_anomaly(img, perlin_mask)
        else:
            augmented_img = self._brightness_anomaly(img, perlin_mask)

        # The effects work in floating point; bring the result back to uint8.
        augmented_img = np.clip(augmented_img, 0, 255).astype(np.uint8)
        anomaly_mask = (perlin_mask > 0).astype(np.float32)

        return augmented_img, anomaly_mask

    def _texture_anomaly(self, img, mask):
        """Blend a randomly chosen texture image into ``img`` inside ``mask``.

        Inside the mask the result is ``weight * texture + (1 - weight) * img``
        with a smooth Perlin weight map; outside the mask the image is unchanged.
        """
        texture_path = random.choice(self.anomaly_source_images)
        texture = Image.open(texture_path).convert('RGB')
        # PIL's resize takes (width, height).
        texture = texture.resize((img.shape[1], img.shape[0]))
        texture = np.array(texture)

        smooth_weight = generate_smooth_anomaly(
            (img.shape[0], img.shape[1]),
            min_perlin_scale=0,
            max_perlin_scale=4
        )

        mask_3ch = np.stack([mask] * 3, axis=2)
        weight_3ch = np.stack([smooth_weight] * 3, axis=2)

        augmented = img * (1 - mask_3ch) + texture * mask_3ch * weight_3ch + img * mask_3ch * (1 - weight_3ch)
        return augmented

    def _noise_anomaly(self, img, mask):
        """Blend Gaussian-blurred uniform noise into ``img`` inside ``mask``."""
        # The upper bound of randint is exclusive, so 256 is required for the
        # noise to cover the full uint8 range 0..255.
        noise = np.random.randint(0, 256, img.shape, dtype=np.uint8)
        noise = cv2.GaussianBlur(noise, (5, 5), 0)

        smooth_weight = generate_smooth_anomaly(
            (img.shape[0], img.shape[1]),
            min_perlin_scale=1,
            max_perlin_scale=5
        )

        mask_3ch = np.stack([mask] * 3, axis=2)
        weight_3ch = np.stack([smooth_weight] * 3, axis=2)

        augmented = img * (1 - mask_3ch * weight_3ch) + noise * mask_3ch * weight_3ch
        return augmented

    def _brightness_anomaly(self, img, mask):
        """Scale pixel values inside ``mask`` by a random factor in [0.3, 2.0]."""
        brightness_factor = random.uniform(0.3, 2.0)
        mask_3ch = np.stack([mask] * 3, axis=2)

        # astype already returns a new array, so the input is never modified.
        augmented = img.astype(np.float32)
        augmented = augmented * (1 - mask_3ch) + augmented * brightness_factor * mask_3ch
        return augmented

    def _augment_mask(self, mask):
        """Randomly erode, dilate, open, close or keep the mask.

        The mask is converted to uint8 (0/255) for OpenCV and returned as a
        float array scaled back to the 0-1 range.
        """
        kernel_size = random.choice([3, 5, 7])
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        operation = random.choice(['erode', 'dilate', 'open', 'close', 'none'])

        # Ensure mask is numpy array and convert to uint8
        mask = np.asarray(mask, dtype=np.float32)
        mask_uint8 = (mask * 255).astype(np.uint8)

        if operation == 'erode':
            mask = cv2.erode(mask_uint8, kernel, iterations=1) / 255.0
        elif operation == 'dilate':
            mask = cv2.dilate(mask_uint8, kernel, iterations=1) / 255.0
        elif operation == 'open':
            mask = cv2.morphologyEx(mask_uint8, cv2.MORPH_OPEN, kernel) / 255.0
        elif operation == 'close':
            mask = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel) / 255.0
        else:
            mask = mask_uint8 / 255.0

        return mask


print("✅ Anomaly generator defined")

# Test anomaly generation
# A random mid-gray image (values 100-199) stands in for a real photo; with no
# texture path given, only noise and brightness anomalies can be produced.
test_img = np.random.randint(100, 200, (256, 256, 3), dtype=np.uint8)
generator = AnomalyGenerator()

# Top row: augmented images. Bottom row: the matching anomaly masks.
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
for i in range(3):
    aug_img, mask = generator.generate_anomaly(test_img)
    axes[0, i].imshow(aug_img)
    axes[0, i].set_title(f'Augmented {i+1}')
    axes[0, i].axis('off')
    axes[1, i].imshow(mask, cmap='gray')
    axes[1, i].set_title(f'Mask {i+1}')
    axes[1, i].axis('off')

plt.tight_layout()
plt.show()

print("✅ Anomaly generator working correctly!")

---

## 7. DRAEM Model Architecture

In [None]:
# ============================================================================
# DRAEM Model Architecture
# ============================================================================

class EncoderReconstructive(nn.Module):
    """Encoder of the reconstructive sub-network.

    Five double-convolution stages separated by four 2x max-pools, so the
    output has 1/16 of the input resolution and ``base_width * 8`` channels.
    """

    @staticmethod
    def _double_conv(c_in, c_out):
        """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding keeps the spatial size."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
            nn.Conv2d(c_out, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
        )

    def __init__(self, in_channels, base_width):
        super().__init__()
        w = base_width

        # Attribute names and creation order are kept stable so parameter
        # initialisation and state_dict keys stay the same.
        self.block1 = self._double_conv(in_channels, w)
        self.mp1 = nn.MaxPool2d(2)

        self.block2 = self._double_conv(w, w * 2)
        self.mp2 = nn.MaxPool2d(2)

        self.block3 = self._double_conv(w * 2, w * 4)
        self.mp3 = nn.MaxPool2d(2)

        self.block4 = self._double_conv(w * 4, w * 8)
        self.mp4 = nn.MaxPool2d(2)

        self.block5 = self._double_conv(w * 8, w * 8)

    def forward(self, x):
        """Return the bottleneck feature map for input batch ``x``."""
        feat = self.block1(x)
        feat = self.block2(self.mp1(feat))
        feat = self.block3(self.mp2(feat))
        feat = self.block4(self.mp3(feat))
        return self.block5(self.mp4(feat))


class DecoderReconstructive(nn.Module):
    """Decoder of the reconstructive sub-network.

    Four 2x bilinear upsampling stages bring the bottleneck back to 16 times
    its resolution; a final 1x1 convolution maps to ``out_channels``.
    """

    @staticmethod
    def _up_block(c_in, c_out):
        """2x bilinear upsample followed by 3x3 conv -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
        )

    def __init__(self, base_width, out_channels):
        super().__init__()
        w = base_width

        self.up1 = self._up_block(w * 8, w * 8)
        self.up2 = self._up_block(w * 8, w * 4)
        self.up3 = self._up_block(w * 4, w * 2)
        self.up4 = self._up_block(w * 2, w)

        self.final = nn.Conv2d(w, out_channels, 1)

    def forward(self, b5):
        """Decode bottleneck features ``b5`` into an image-shaped tensor."""
        feat = b5
        for stage in (self.up1, self.up2, self.up3, self.up4):
            feat = stage(feat)
        return self.final(feat)


class ReconstructiveSubNetwork(nn.Module):
    """Encoder-decoder that maps an input image to a reconstruction of the same size."""

    def __init__(self, in_channels=3, out_channels=3, base_width=128):
        super().__init__()
        # Encoder first, then decoder: creation order fixes parameter initialisation order.
        self.encoder = EncoderReconstructive(in_channels, base_width)
        self.decoder = DecoderReconstructive(base_width, out_channels)

    def forward(self, x):
        """Encode ``x`` to the bottleneck and decode it back to image space."""
        return self.decoder(self.encoder(x))


# The discriminative sub-network (a second encoder/decoder pair) is defined in
# the next cell; it is kept separate only to keep each cell at a readable length.

print("✅ Reconstructive network defined")

In [None]:
# ============================================================================
# DRAEM Model - Part 2: Discriminative Network
# ============================================================================

class EncoderDiscriminative(nn.Module):
    """Encoder of the discriminative (segmentation) sub-network.

    Six double-convolution stages separated by five 2x max-pools. All six
    pre-pooling feature maps are returned so the decoder can use them as
    skip connections.
    """

    @staticmethod
    def _double_conv(c_in, c_out):
        """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding keeps the spatial size."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
            nn.Conv2d(c_out, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
        )

    def __init__(self, in_channels, base_width):
        super().__init__()
        w = base_width

        # Attribute names and creation order are kept stable so parameter
        # initialisation and state_dict keys stay the same.
        self.block1 = self._double_conv(in_channels, w)
        self.mp1 = nn.MaxPool2d(2)

        self.block2 = self._double_conv(w, w * 2)
        self.mp2 = nn.MaxPool2d(2)

        self.block3 = self._double_conv(w * 2, w * 4)
        self.mp3 = nn.MaxPool2d(2)

        self.block4 = self._double_conv(w * 4, w * 8)
        self.mp4 = nn.MaxPool2d(2)

        self.block5 = self._double_conv(w * 8, w * 8)
        self.mp5 = nn.MaxPool2d(2)

        self.block6 = self._double_conv(w * 8, w * 8)

    def forward(self, x):
        """Return the feature maps ``(b1, ..., b6)``, from full resolution down to 1/32."""
        stages = (
            (None, self.block1),
            (self.mp1, self.block2),
            (self.mp2, self.block3),
            (self.mp3, self.block4),
            (self.mp4, self.block5),
            (self.mp5, self.block6),
        )
        features = []
        current = x
        for pool, block in stages:
            if pool is not None:
                current = pool(current)
            current = block(current)
            features.append(current)
        return tuple(features)


class DecoderDiscriminative(nn.Module):
    """U-Net style decoder of the discriminative sub-network.

    Each of the five stages upsamples by 2x, concatenates the matching
    encoder feature map along the channel axis and fuses the result with a
    double convolution. A final 1x1 convolution maps to ``out_channels``.
    """

    @staticmethod
    def _up_block(c_in, c_out):
        """2x bilinear upsample followed by 3x3 conv -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
        )

    @staticmethod
    def _fuse_block(c_in, c_out):
        """Two (3x3 conv -> BatchNorm -> ReLU) stages applied after a skip concatenation."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
            nn.Conv2d(c_out, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
        )

    def __init__(self, base_width, out_channels):
        super().__init__()
        w = base_width

        # up/db pairs are created alternately, in the same order as before, so
        # parameter initialisation and state_dict keys stay the same.
        self.up1 = self._up_block(w * 8, w * 8)
        self.db1 = self._fuse_block(w * 16, w * 8)

        self.up2 = self._up_block(w * 8, w * 8)
        self.db2 = self._fuse_block(w * 16, w * 8)

        self.up3 = self._up_block(w * 8, w * 4)
        self.db3 = self._fuse_block(w * 8, w * 4)

        self.up4 = self._up_block(w * 4, w * 2)
        self.db4 = self._fuse_block(w * 4, w * 2)

        self.up5 = self._up_block(w * 2, w)
        self.db5 = self._fuse_block(w * 2, w)

        self.final = nn.Conv2d(w, out_channels, 1)

    def forward(self, b1, b2, b3, b4, b5, b6):
        """Decode from the deepest map ``b6`` using ``b5`` ... ``b1`` as skip connections."""
        stages = (
            (self.up1, self.db1, b5),
            (self.up2, self.db2, b4),
            (self.up3, self.db3, b3),
            (self.up4, self.db4, b2),
            (self.up5, self.db5, b1),
        )
        feat = b6
        for upsample, fuse, skip in stages:
            # Upsampled features come first in the channel concatenation.
            feat = fuse(torch.cat((upsample(feat), skip), dim=1))
        return self.final(feat)


class DiscriminativeSubNetwork(nn.Module):
    """Encoder-decoder with skip connections that outputs per-pixel class scores."""

    def __init__(self, in_channels=6, out_channels=2, base_width=64):
        super().__init__()
        # Encoder first, then decoder: creation order fixes parameter initialisation order.
        self.encoder = EncoderDiscriminative(in_channels, base_width)
        self.decoder = DecoderDiscriminative(base_width, out_channels)

    def forward(self, x):
        """Run the encoder and feed all six of its feature maps to the decoder."""
        skip_features = self.encoder(x)
        return self.decoder(*skip_features)


class DRAEM(nn.Module):
    """Complete DRAEM model: a reconstructive network followed by a discriminative one.

    The discriminative network receives the input image concatenated with its
    reconstruction along the channel axis, hence ``in_channels * 2`` inputs.
    """

    def __init__(self, in_channels=3, out_channels_seg=2):
        super().__init__()
        # Reconstruction keeps the channel count of the input image.
        self.reconstructive = ReconstructiveSubNetwork(
            in_channels=in_channels, out_channels=in_channels, base_width=128
        )
        self.discriminative = DiscriminativeSubNetwork(
            in_channels=in_channels * 2, out_channels=out_channels_seg, base_width=64
        )

    def forward(self, x):
        """Return ``(reconstruction, segmentation)`` for input batch ``x``.

        ``segmentation`` is the raw output of the discriminative network; no
        softmax is applied here.
        """
        restored = self.reconstructive(x)
        # Channel order: original image first, reconstruction second.
        seg_scores = self.discriminative(torch.cat([x, restored], dim=1))
        return restored, seg_scores


print("✅ Complete DRAEM model defined")

# Test model
# Shape check only: run one forward pass on random input. The pass is wrapped
# in `torch.no_grad()` so no autograd graph (and none of the intermediate
# activations it would retain) is kept for this throwaway model.
test_model = DRAEM()
test_input = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    test_recon, test_seg = test_model(test_input)
print(f"  Input: {test_input.shape}")
print(f"  Reconstruction: {test_recon.shape}")
print(f"  Segmentation: {test_seg.shape}")
print("✅ Model architecture working correctly!")

---

## 8. Dataset and DataLoader

In [None]:
# ============================================================================
# Dataset
# ============================================================================

class DRAEMDataset(Dataset):
    """Dataset for DRAEM training with synthetic anomalies.

    Every item is derived from one intact image: the image itself, a copy
    with a freshly generated synthetic anomaly, and the mask of that anomaly.
    The anomaly is re-drawn on every access.
    """

    def __init__(self, intact_dir, anomaly_source_path=None, transform=None, image_size=256):
        """
        Args:
            intact_dir: directory searched recursively for .jpg/.png/.jpeg images.
            anomaly_source_path: optional texture directory forwarded to
                ``AnomalyGenerator``.
            transform: optional replacement for the default
                resize -> tensor -> ImageNet-normalize pipeline. It is applied
                to both images but not to the mask, so it should keep the
                ``image_size`` x ``image_size`` output resolution.
            image_size: side length images are resized to before augmentation.
        """
        self.intact_dir = Path(intact_dir)
        self.image_size = image_size

        # Get all intact images. glob yields paths in filesystem order, which
        # differs between machines; sorting makes the index -> image mapping
        # (and therefore seeded shuffling) reproducible.
        found = []
        for ext in ['*.jpg', '*.png', '*.jpeg']:
            found.extend(self.intact_dir.glob(f'**/{ext}'))
        self.image_paths = sorted(found)

        print(f"[OK] Found {len(self.image_paths)} intact images")

        # Anomaly generator
        self.anomaly_generator = AnomalyGenerator(
            anomaly_source_path=anomaly_source_path,
            resize_shape=(image_size, image_size)
        )

        # Transforms
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])
        else:
            self.transform = transform

    def __len__(self):
        """Number of intact images found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(intact_tensor, augmented_tensor, mask_tensor)`` for image ``idx``.

        ``mask_tensor`` is a float tensor of shape (1, image_size, image_size)
        containing 0/1 values.
        """
        # Load intact image
        img_path = self.image_paths[idx]
        img = Image.open(img_path).convert('RGB')
        # Resize before augmenting so the mask is generated at training resolution.
        img = img.resize((self.image_size, self.image_size))
        img_np = np.array(img)

        # Generate synthetic anomaly
        aug_img, anomaly_mask = self.anomaly_generator.generate_anomaly(img_np)

        # Convert to PIL for transforms
        aug_img_pil = Image.fromarray(aug_img)

        # Apply transforms
        intact_tensor = self.transform(img)
        augmented_tensor = self.transform(aug_img_pil)

        # Mask to tensor (add the channel dimension)
        mask_tensor = torch.from_numpy(anomaly_mask).unsqueeze(0).float()

        return intact_tensor, augmented_tensor, mask_tensor


# Create dataset
train_dataset = DRAEMDataset(
    intact_dir=CONFIG['intact_dir'],
    anomaly_source_path=CONFIG['anomaly_source_path'],
    image_size=CONFIG['image_size']
)

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
    pin_memory=CONFIG['pin_memory']
)

print(f"✅ Dataset created: {len(train_dataset)} samples")
print(f"✅ DataLoader created: {len(train_loader)} batches")

# Visualize sample batch
intact, augmented, mask = next(iter(train_loader))

# Denormalize for visualization
def denormalize(tensor):
    """Undo the ImageNet mean/std normalization of a (3, H, W) tensor."""
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return tensor * std + mean

# The batch can hold fewer samples than CONFIG['batch_size'] (small datasets),
# so the number of previews is taken from the batch that was actually loaded.
n_cols = 4
n_show = min(n_cols, intact.shape[0])

# Rows: intact image, augmented image, anomaly mask.
fig, axes = plt.subplots(3, n_cols, figsize=(12, 9))
for i in range(n_show):
    intact_img = denormalize(intact[i]).permute(1, 2, 0).numpy()
    augmented_img = denormalize(augmented[i]).permute(1, 2, 0).numpy()
    mask_img = mask[i].squeeze().numpy()

    axes[0, i].imshow(np.clip(intact_img, 0, 1))
    axes[0, i].set_title('Intact')
    axes[0, i].axis('off')

    axes[1, i].imshow(np.clip(augmented_img, 0, 1))
    axes[1, i].set_title('Augmented')
    axes[1, i].axis('off')

    axes[2, i].imshow(mask_img, cmap='gray')
    axes[2, i].set_title('Mask')
    axes[2, i].axis('off')

# Hide the panels that have no sample to show.
for unused_ax in axes[:, n_show:].ravel():
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

print("✅ Data pipeline working correctly!")

---

## 9. Loss Functions

In [None]:
# ============================================================================
# Loss Functions
# ============================================================================

def ssim(img1, img2, window_size=11):
    """Mean Structural Similarity Index between two image batches.

    Local statistics are computed with a uniform (average-pooling) window of
    size ``window_size``, stride 1 and ``window_size // 2`` padding. The
    stabilising constants are fixed at 0.01**2 and 0.03**2.

    Args:
        img1, img2: tensors of identical shape accepted by ``F.avg_pool2d``.
        window_size: side length of the averaging window.

    Returns:
        Scalar tensor: the mean of the per-position SSIM map.
    """
    C1, C2 = 0.01 ** 2, 0.03 ** 2

    def local_mean(t):
        return F.avg_pool2d(t, window_size, stride=1, padding=window_size // 2)

    mu1, mu2 = local_mean(img1), local_mean(img2)
    mu1_sq, mu2_sq, mu1_mu2 = mu1.pow(2), mu2.pow(2), mu1 * mu2

    # Variances and covariance via E[xy] - E[x]E[y].
    sigma1_sq = local_mean(img1 * img1) - mu1_sq
    sigma2_sq = local_mean(img2 * img2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)

    return (numerator / denominator).mean()


class FocalLoss(nn.Module):
    """Focal loss built on per-element cross-entropy.

    Computes ``alpha * (1 - p_t) ** gamma * CE`` for every element and returns
    the mean, where ``p_t = exp(-CE)`` is the probability assigned to the
    target class. ``alpha`` is a single scalar applied to all classes.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, inputs, targets):
        """
        Args:
            inputs: class scores as accepted by ``F.cross_entropy``.
            targets: integer class indices matching ``inputs``.

        Returns:
            Scalar tensor with the mean focal loss.
        """
        per_elem_ce = F.cross_entropy(inputs, targets, reduction='none')
        target_prob = torch.exp(-per_elem_ce)
        # Down-weight elements the model already classifies confidently.
        weighted = self.alpha * (1 - target_prob) ** self.gamma * per_elem_ce
        return weighted.mean()


class DRAEMLoss(nn.Module):
    """Training objective for DRAEM: MSE + (1 - SSIM) + focal segmentation loss.

    The three terms are summed with equal (unit) weights.
    """
    def __init__(self):
        super().__init__()
        self.l2_loss = nn.MSELoss()
        self.focal_loss = FocalLoss()

    def forward(self, reconstruction, segmentation, target_img, target_mask):
        """Compute the combined loss and its three components.

        Args:
            reconstruction: Image produced by the reconstructive sub-network.
            segmentation: Per-class logits from the discriminative sub-network.
            target_img: Image the reconstruction is compared against.
            target_mask: Ground-truth anomaly mask; assumed to carry a
                singleton channel at dim 1 with 0/1 values -- TODO confirm
                against the dataset.

        Returns:
            Tuple ``(total, mse_term, ssim_term, focal_term)`` of scalar tensors.
        """
        # Segmentation target: drop the channel dim and cast to class indices,
        # as required by cross-entropy inside the focal loss.
        class_index_map = target_mask.squeeze(1).long()

        # Pixel-wise and structural reconstruction terms
        mse_term = self.l2_loss(reconstruction, target_img)
        ssim_term = 1 - ssim(reconstruction, target_img)

        # Anomaly-mask prediction term
        focal_term = self.focal_loss(segmentation, class_index_map)

        combined = mse_term + ssim_term + focal_term
        return combined, mse_term, ssim_term, focal_term


print("✅ Loss functions defined")

---

## 10. Training Loop

In [None]:
# ============================================================================
# TRAINING
# ============================================================================

# Build the model and move it to the configured device
device = torch.device(CONFIG['device'])
model = DRAEM().to(device)

param_count_millions = sum(p.numel() for p in model.parameters()) / 1e6
print(f"Model on device: {device}")
print(f"Model parameters: {param_count_millions:.2f}M")

# Loss and optimizer: one Adam parameter group per sub-network, both starting
# from the same configured learning rate
criterion = DRAEMLoss()
optimizer = torch.optim.Adam([
    {'params': subnet.parameters(), 'lr': CONFIG['learning_rate']}
    for subnet in (model.reconstructive, model.discriminative)
])

# Learning rate scheduler: LR is multiplied by 0.2 at 80% and again at 90%
# of the configured number of epochs
lr_milestones = [int(CONFIG['epochs'] * fraction) for fraction in (0.8, 0.9)]
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=lr_milestones, gamma=0.2
)

# Run summary banner
banner_rule = "=" * 60
print("\n" + banner_rule)
print("STARTING TRAINING")
print(banner_rule)
run_summary = (
    ("Epochs", CONFIG['epochs']),
    ("Batch size", CONFIG['batch_size']),
    ("Learning rate", CONFIG['learning_rate']),
    ("Training samples", len(train_dataset)),
    ("Batches per epoch", len(train_loader)),
)
for summary_label, summary_value in run_summary:
    print(f"{summary_label}: {summary_value}")
print(banner_rule)

# Training history: one list of per-epoch averages for each loss component
history = {
    loss_name: []
    for loss_name in ('total_loss', 'recon_loss', 'ssim_loss', 'seg_loss')
}

# Training loop
for epoch in range(CONFIG['epochs']):
    epoch_number = epoch + 1  # 1-based epoch index used for display and file names
    model.train()

    # Running sums of the per-batch losses, keyed exactly like `history`
    running = {'total_loss': 0, 'recon_loss': 0, 'ssim_loss': 0, 'seg_loss': 0}

    pbar = tqdm(train_loader, desc=f"Epoch {epoch_number}/{CONFIG['epochs']}")

    for intact, augmented, mask in pbar:
        intact, augmented, mask = (
            tensor.to(device) for tensor in (intact, augmented, mask)
        )

        # Forward pass: the model is fed the augmented image
        reconstruction, segmentation = model(augmented)

        # Loss: reconstruction is scored against the intact image,
        # segmentation against the mask
        total_loss, recon_loss, ssim_loss, seg_loss = criterion(
            reconstruction, segmentation, intact, mask
        )

        # Backward pass
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Pull the scalar values once and reuse them for tracking and display
        batch_values = {
            'total_loss': total_loss.item(),
            'recon_loss': recon_loss.item(),
            'ssim_loss': ssim_loss.item(),
            'seg_loss': seg_loss.item(),
        }
        for loss_name, loss_value in batch_values.items():
            running[loss_name] += loss_value

        # Update progress bar
        postfix = {
            'loss': f"{batch_values['total_loss']:.4f}",
            'recon': f"{batch_values['recon_loss']:.4f}",
            'seg': f"{batch_values['seg_loss']:.4f}",
        }
        pbar.set_postfix(postfix)

    # Epoch statistics: average each running sum over the number of batches
    n_batches = len(train_loader)
    epoch_avg = {
        loss_name: loss_sum / n_batches for loss_name, loss_sum in running.items()
    }
    for loss_name, loss_avg in epoch_avg.items():
        history[loss_name].append(loss_avg)

    print(f"\nEpoch {epoch_number} Summary:")
    print(f"  Total Loss: {epoch_avg['total_loss']:.4f}")
    print(f"  Reconstruction Loss: {epoch_avg['recon_loss']:.4f}")
    print(f"  SSIM Loss: {epoch_avg['ssim_loss']:.4f}")
    print(f"  Segmentation Loss: {epoch_avg['seg_loss']:.4f}")
    # Printed before scheduler.step(), so this is the LR used during this epoch
    print(f"  LR: {optimizer.param_groups[0]['lr']:.6f}")

    # Step scheduler
    scheduler.step()

    # Only every `save_interval`-th epoch writes a checkpoint
    if epoch_number % CONFIG['save_interval'] != 0:
        continue

    checkpoint_path = f"{CONFIG['output_dir']}/draem_epoch_{epoch_number}.pth"
    checkpoint = {
        'epoch': epoch_number,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': epoch_avg['total_loss'],
        'config': CONFIG,
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"  [OK] Checkpoint saved: {checkpoint_path}")

print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60)

---

## 11. Save Final Model

In [None]:
# Save final model: weights plus the config and loss history of this run
final_model_path = f"{CONFIG['output_dir']}/draem_final.pth"

final_payload = dict(
    model_state_dict=model.state_dict(),
    config=CONFIG,
    history=history,
)
torch.save(final_payload, final_model_path)

print(f"✅ Final model saved: {final_model_path}")
final_size_mb = os.path.getsize(final_model_path) / 1e6  # bytes -> MB (decimal)
print(f"✅ Model size: {final_size_mb:.2f} MB")

---

## 12. Plot Training History

In [None]:
# Plot training curves: one panel per tracked loss, laid out row by row
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

curve_panels = (
    ('total_loss', 'Total Loss'),
    ('recon_loss', 'Reconstruction Loss'),
    ('ssim_loss', 'SSIM Loss'),
    ('seg_loss', 'Segmentation Loss'),
)

# axes.flat walks the 2x2 grid in row-major order: [0,0], [0,1], [1,0], [1,1]
for ax, (history_key, panel_title) in zip(axes.flat, curve_panels):
    ax.plot(history[history_key])
    ax.set_title(panel_title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(True)

plt.tight_layout()
history_plot_path = f"{CONFIG['output_dir']}/training_history.png"
plt.savefig(history_plot_path, dpi=150)
plt.show()

print("✅ Training history saved")

---

## 13. Download Model

The cell below lists the saved checkpoint files. Download `draem_final.pth` from the notebook's Kaggle output to use the trained model locally.

In [None]:
# List every .pth file in the output directory with its size
print("Saved files:")
for saved_path in Path(CONFIG['output_dir']).glob('*.pth'):
    saved_size_mb = saved_path.stat().st_size / 1e6  # bytes -> MB (decimal)
    print(f"  {saved_path.name} ({saved_size_mb:.2f} MB)")

print("\n✅ DRAEM training complete!")
print("\nNext steps:")
next_steps = (
    "1. Download draem_final.pth from Kaggle output",
    "2. Use for inference on your local machine",
    "3. Integrate with Stage 1 classifier for complete pipeline",
)
for step_line in next_steps:
    print(step_line)

---

## Done! 🎉

Your DRAEM model is trained!

**Download `draem_final.pth` and use it for inference.**

**No annotation was needed!** 🚀