# DRAEM Training on Kaggle - Stage 2: Defect Localization

This notebook trains the DRAEM model for bike defect localization.

**No annotation required!** Trains only on intact images.

---

## Setup

### 1. Upload Your Data

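Upload your data folder to Kaggle as a dataset with this structure:
```
/kaggle/input/bike-data/
└── intact/
    ├── bike001.jpg
    ├── bike002.jpg
    └── ...
```

Then point `CONFIG['intact_dir']` (Section 3) at the `intact/` folder. The exact path depends on your Kaggle dataset name. Images must sit directly in that folder, because subfolders are not scanned. They must also use a `.jpg`, `.jpeg`, `.png`, `.JPG` or `.PNG` extension.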

### 2. Run All Cells

Just run all cells in order!

---

## 1. Install Dependencies

In [None]:
# Cell 1: Installation (FIXED)
!pip uninstall -y numpy -q
!pip install -q numpy==1.24.4
!pip install -q opencv-python-headless
!pip install -q scikit-image

import warnings
warnings.filterwarnings('ignore')

print("‚úÖ Dependencies installed!")

---

## 2. Import Libraries

In [None]:
# Cell 2: Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import random
import os

# Seed every RNG the notebook draws from (torch, numpy, python `random`) so
# weight init, the train/val shuffle and augmentations are repeatable.
SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

# Report the runtime so logs show which PyTorch build / accelerator was used.
print(f"PyTorch: {torch.__version__}")
_cuda_ok = torch.cuda.is_available()
print(f"CUDA: {_cuda_ok}")
if _cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

---

## 3. Configuration

In [None]:
# Cell 3: Configuration
# Every tunable value for data loading, training and checkpointing lives here.
CONFIG = {
    # Folder that directly contains the defect-free ("intact") training images.
    'intact_dir': '/kaggle/input/datasetprivet/data/processed/intact',
    'epochs': 150,
    'batch_size': 8,
    'learning_rate': 0.0001,
    # Images are resized to image_size x image_size. The reconstructive
    # network uses four stride-2 convolutions (/16), so keep this a multiple of 16.
    'image_size': 256,
    # The shuffled image list is split by `train_split`; the remainder is
    # validation, so `val_split` must equal 1 - train_split (checked below).
    'train_split': 0.85,
    'val_split': 0.15,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_workers': 0,
    # Pinned host memory only helps GPU transfers; on CPU-only runs it just
    # triggers a DataLoader warning, so enable it only when CUDA is present.
    'pin_memory': torch.cuda.is_available(),
    # Perlin noise scale per axis is 2**k with k drawn from [min, max).
    'min_perlin_scale': 0,
    'max_perlin_scale': 6,
    # Probability that a sample receives a synthetic anomaly.
    'augmentation_prob': 0.8,
    'output_dir': '/kaggle/working',
    # Write a periodic checkpoint every `save_interval` epochs.
    'save_interval': 10,
    # Stop after this many consecutive epochs without a val-loss improvement.
    'early_stopping_patience': 20,
}

if abs(CONFIG['train_split'] + CONFIG['val_split'] - 1.0) > 1e-6:
    raise ValueError(
        f"train_split ({CONFIG['train_split']}) and val_split ({CONFIG['val_split']}) "
        "must add up to 1.0"
    )

os.makedirs(CONFIG['output_dir'], exist_ok=True)
print("‚úÖ Configuration set")

---

## 4. Verify Data

In [None]:
# Check if data exists and preview a few of the images training will use.
intact_dir = Path(CONFIG['intact_dir'])

if not intact_dir.exists():
    print(f"‚ùå ERROR: Data directory not found: {intact_dir}")
    print("\nPlease:")
    print("1. Upload your data to Kaggle")
    print("2. Update CONFIG['intact_dir'] in the cell above")
    print("3. Re-run this cell")
else:
    # Count images with exactly the same (non-recursive, case-sensitive)
    # patterns as the dataset cell, so this count matches what gets loaded.
    image_files = []
    for ext in ['*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG']:
        image_files.extend(intact_dir.glob(ext))
    image_files = sorted(image_files)  # deterministic preview order

    print(f"‚úÖ Data directory found: {intact_dir}")
    print(f"‚úÖ Found {len(image_files)} intact images")

    # Show sample images
    if image_files:
        n_show = min(5, len(image_files))
        # squeeze=False always yields a 2-D axes array, even for one image.
        fig, axes = plt.subplots(1, n_show, figsize=(15, 3), squeeze=False)
        for ax, img_path in zip(axes[0], image_files[:n_show]):
            # The context manager closes the file once the pixels are drawn.
            with Image.open(img_path) as img:
                ax.imshow(img)
                ax.set_title(f'{img_path.name}\n{img.size}')
            ax.axis('off')

        plt.tight_layout()
        plt.show()

        print(f"\n‚úÖ Ready to train!")
    else:
        print("\n‚ùå No images found! Please check your data directory.")

---

## 5. Check Data and Split Sizes

In [None]:
# Cell 4: Check Data
# Count the images the dataset cell will load and preview the split sizes.
intact_dir = Path(CONFIG['intact_dir'])

image_files = [
    path
    for pattern in ('*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG')
    for path in intact_dir.glob(pattern)
]

n_images = len(image_files)
n_train_preview = int(n_images * CONFIG['train_split'])
print(f"‚úÖ Found {n_images} intact images")
print(f"   Train: {n_train_preview}")
print(f"   Val: {n_images - n_train_preview}")

---

## 6. Perlin Noise Generator

In [None]:
# Cell 5: Perlin Noise (FIXED)
def lerp(a, b, t):
    """Linear interpolation: `a` at t=0, `b` at t=1; works element-wise on arrays."""
    delta = b - a
    return a + t * delta

def rand_perlin_2d(shape, res, fade=lambda t: 6*t**5 - 15*t**4 + 10*t**3):
    """Generate a 2-D gradient (Perlin) noise field with numpy.

    Args:
        shape: (height, width) of the output array.
        res: (rows, cols) number of lattice cells (noise periods) along each axis.
        fade: smoothing curve applied to the in-cell fractional coordinates;
            the default is Perlin's quintic 6t^5 - 15t^4 + 10t^3.

    Returns:
        float64 array of shape `shape`, roughly in [-1, 1], zero on lattice points.

    Gradients are looked up per pixel from the lattice cell it falls in, so
    `shape` does not have to be a multiple of `res` (a repeat-based tiling
    fails with a broadcasting error in that case). For multiples, the result
    equals the classic tiled implementation for the same random state.
    """
    height, width = shape
    # Continuous lattice coordinate of every row / column
    # (the same values np.mgrid[0:res:res/shape] produces).
    ys = np.arange(height) * (res[0] / height)
    xs = np.arange(width) * (res[1] / width)
    cell_y = np.floor(ys).astype(np.intp)
    cell_x = np.floor(xs).astype(np.intp)
    fy = (ys - cell_y)[:, None]  # (H, 1) fractional position inside the cell
    fx = (xs - cell_x)[None, :]  # (1, W)

    # One random unit gradient per lattice corner.
    angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1)
    gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1)

    rows = cell_y[:, None]
    cols = cell_x[None, :]

    def corner(dy, dx):
        # Dot product of the corner's gradient with the offset from that corner to each pixel.
        grad = gradients[rows + dy, cols + dx]  # (H, W, 2)
        return (fy - dy) * grad[..., 0] + (fx - dx) * grad[..., 1]

    def mix(a, b, t):
        # Linear interpolation, a at t=0 and b at t=1.
        return a + t * (b - a)

    n00 = corner(0, 0)
    n10 = corner(1, 0)
    n01 = corner(0, 1)
    n11 = corner(1, 1)
    ty = fade(fy)
    tx = fade(fx)
    return np.sqrt(2) * mix(mix(n00, n10, ty), mix(n01, n11, ty), tx)

def generate_perlin_noise_mask(shape, min_scale=0, max_scale=6):
    """Random binary anomaly mask thresholded from Perlin noise.

    Args:
        shape: (height, width); only the first two entries are used.
        min_scale, max_scale: the lattice resolution along each axis is
            2**k with k drawn uniformly from [min_scale, max_scale).

    Returns:
        (mask, perlin_noise): `mask` is float32 with 1.0 where the noise,
        rescaled to [0, 1], exceeds a threshold drawn from [0.3, 0.7);
        `perlin_noise` is that rescaled noise.
    """
    perlin_scalex = 2 ** np.random.randint(min_scale, max_scale)
    perlin_scaley = 2 ** np.random.randint(min_scale, max_scale)
    perlin_noise = rand_perlin_2d((shape[0], shape[1]), (perlin_scalex, perlin_scaley))
    noise_min = perlin_noise.min()
    noise_range = perlin_noise.max() - noise_min
    if noise_range > 0:
        perlin_noise = (perlin_noise - noise_min) / noise_range
    else:
        # A constant field would divide by zero (NaN); treat it as all-zero
        # noise, which yields an empty mask.
        perlin_noise = np.zeros_like(perlin_noise)
    threshold = np.random.uniform(0.3, 0.7)
    mask = (perlin_noise > threshold).astype(np.float32)
    return mask, perlin_noise

print("‚úÖ Perlin noise generator ready")

---

## 7. Reconstructive Network Blocks, Dataset and DataLoader

In [None]:
# ============================================================================
# DRAEM Model Architecture
# ============================================================================

class EncoderReconstructive(nn.Module):
    """Five-stage convolutional encoder.

    Each stage is two (Conv3x3 -> BatchNorm -> ReLU) layers; stages 1-4 are
    followed by 2x2 max pooling, so the output has 8 * base_width channels at
    1/16 of the input resolution.
    """

    @staticmethod
    def _double_conv(c_in, c_out):
        # Two 3x3 conv/BN/ReLU layers; padding=1 keeps the spatial size.
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
            nn.Conv2d(c_out, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
        )

    def __init__(self, in_channels, base_width):
        super().__init__()
        w = base_width
        # Attribute names define the state_dict keys; keep them stable so
        # saved checkpoints keep loading.
        self.block1 = self._double_conv(in_channels, w)
        self.mp1 = nn.MaxPool2d(2)
        self.block2 = self._double_conv(w, 2 * w)
        self.mp2 = nn.MaxPool2d(2)
        self.block3 = self._double_conv(2 * w, 4 * w)
        self.mp3 = nn.MaxPool2d(2)
        self.block4 = self._double_conv(4 * w, 8 * w)
        self.mp4 = nn.MaxPool2d(2)
        self.block5 = self._double_conv(8 * w, 8 * w)

    def forward(self, x):
        features = self.block1(x)
        stages = (
            (self.mp1, self.block2),
            (self.mp2, self.block3),
            (self.mp3, self.block4),
            (self.mp4, self.block5),
        )
        for pool, block in stages:
            features = block(pool(features))
        return features


class DecoderReconstructive(nn.Module):
    def __init__(self, base_width, out_channels):
        super().__init__()
        
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(base_width * 8, base_width * 8, 3, padding=1),
            nn.BatchNorm2d(base_width * 8),
            nn.ReLU(True)
        )
        
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(base_width * 8, base_width * 4, 3, padding=1),
            nn.BatchNorm2d(base_width * 4),
            nn.ReLU(True)
        )
        
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(base_width * 4, base_width * 2, 3, padding=1),
            nn.BatchNorm2d(base_width * 2),
            nn.ReLU(True)
        )
        
        self.up4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(base_width * 2, base_width, 3, padding=1),
            nn.BatchNorm2d(base_width),
            nn.ReLU(True)
        )
        
        self.final = nn.Conv2d(base_width, out_channels, 1)
    
    def forward(self, b5):
        up1 = self.up1(b5)
        up2 = self.up2(up1)
        up3 = self.up3(up2)
        up4 = self.up4(up3)
        output = self.final(up4)
        return output


class ReconstructiveSubNetwork(nn.Module):
    """Encoder/decoder reconstruction network built from EncoderReconstructive
    and DecoderReconstructive (no skip connections).

    NOTE: a later cell defines another class with this same name; when the
    notebook runs top to bottom, that later definition shadows this one and
    is the one `DRAEM` instantiates.
    """

    def __init__(self, in_channels=3, out_channels=3, base_width=128):
        super().__init__()
        self.encoder = EncoderReconstructive(in_channels, base_width)
        self.decoder = DecoderReconstructive(base_width, out_channels)

    def forward(self, x):
        return self.decoder(self.encoder(x))


# The discriminative (segmentation) network and the DRAEM wrapper are defined
# in the "Model Architecture" cell further down, after the dataset cell.

print("‚úÖ Reconstructive network defined")

In [None]:
# Cell 6: Dataset (FIXED)
class DRAEMDataset(Dataset):
    """Intact images paired with synthetically corrupted copies for DRAEM.

    Each item is ``(image_tensor, augmented_tensor, mask_tensor)``:
    the two image tensors are 3 x image_size x image_size floats normalised
    with (mean, std); mask_tensor is 1 x image_size x image_size float32 with
    1.0 where a Perlin anomaly was pasted.

    Args:
        image_paths: paths of image files readable by OpenCV.
        image_size: square side length images are resized to.
        augmentation_prob: probability a sample receives a synthetic anomaly.
        min_perlin_scale, max_perlin_scale: exponent range of the Perlin scale.
        is_train: enables flips/rotation/brightness jitter. When False, the
            synthetic anomaly of each index is fixed (seeded by val_seed + idx)
            so the validation loss is comparable from epoch to epoch.
        mean, std: per-channel Normalize parameters. The default 0.5/0.5 maps
            pixels to [-1, 1], the range of the reconstructive network's Tanh
            output and the range `ssim` assumes. Inference code must use the
            same values.
        val_seed: base seed for the deterministic validation anomalies.
    """

    def __init__(self, image_paths, image_size=256, augmentation_prob=0.8,
                 min_perlin_scale=0, max_perlin_scale=6, is_train=True,
                 mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), val_seed=0):
        self.image_paths = image_paths
        self.image_size = image_size
        self.augmentation_prob = augmentation_prob
        self.min_perlin_scale = min_perlin_scale
        self.max_perlin_scale = max_perlin_scale
        self.is_train = is_train
        self.val_seed = val_seed
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=list(mean), std=list(std))
        ])

    def __len__(self):
        return len(self.image_paths)

    def augment_with_transforms(self, image):
        """Random flips, rotation and brightness/contrast jitter (training only)."""
        if not self.is_train:
            return image
        image = np.array(image, dtype=np.uint8, copy=True)
        if np.random.rand() < 0.5:
            image = np.fliplr(image).copy()
        if np.random.rand() < 0.3:
            image = np.flipud(image).copy()
        if np.random.rand() < 0.5:
            angle = np.random.uniform(-20, 20)
            h, w = image.shape[:2]
            M = cv2.getRotationMatrix2D((w//2, h//2), angle, 1.0)
            image = cv2.warpAffine(np.ascontiguousarray(image, dtype=np.uint8), M, (w, h))
        if np.random.rand() < 0.5:
            alpha = np.random.uniform(0.7, 1.3)
            beta = np.random.uniform(-30, 30)
            # cv2.convertScaleAbs takes |alpha*x + beta|, so a negative beta
            # would brighten near-black pixels; saturate to [0, 255] instead.
            image = np.clip(np.rint(image.astype(np.float32) * alpha + beta), 0, 255).astype(np.uint8)
        return image

    def augment_image(self, image):
        """Paste a Perlin-shaped synthetic anomaly; returns (augmented uint8 image, float32 mask)."""
        image = np.array(image, dtype=np.uint8, copy=True)
        perlin_mask, _ = generate_perlin_noise_mask(image.shape[:2], self.min_perlin_scale, self.max_perlin_scale)
        mask_3ch = perlin_mask[:, :, np.newaxis]  # broadcasts over the colour channels
        if np.random.rand() < 0.5:
            # Uniform noise texture over the full 0..255 range (high is exclusive).
            anomaly_texture = np.random.randint(0, 256, image.shape, dtype=np.uint8)
        else:
            anomaly_texture = np.clip(image.astype(np.float32) * np.random.uniform(0.3, 1.5), 0, 255).astype(np.uint8)
        blended = image.astype(np.float32) * (1 - mask_3ch) + anomaly_texture.astype(np.float32) * mask_3ch
        return blended.astype(np.uint8), perlin_mask

    def __getitem__(self, idx):
        if self.is_train:
            return self._load_item(idx)
        # Validation: fix the synthetic anomaly per index, then restore the
        # global numpy RNG so the training augmentation stream is unaffected.
        saved_state = np.random.get_state()
        np.random.seed((self.val_seed + int(idx)) % (2 ** 32))
        try:
            return self._load_item(idx)
        finally:
            np.random.set_state(saved_state)

    def _load_item(self, idx):
        # Read, augment, corrupt and tensorise one sample.
        img_path = self.image_paths[idx]
        image = cv2.imread(str(img_path))
        if image is None:
            # cv2.imread reports missing/unreadable files by returning None.
            raise OSError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.image_size, self.image_size))
        image = self.augment_with_transforms(image)
        if np.random.rand() < self.augmentation_prob:
            augmented_image, mask = self.augment_image(image)
        else:
            augmented_image = image.copy()
            mask = np.zeros((self.image_size, self.image_size), dtype=np.float32)
        image_tensor = self.transform(Image.fromarray(np.ascontiguousarray(image, dtype=np.uint8)))
        augmented_tensor = self.transform(Image.fromarray(np.ascontiguousarray(augmented_image, dtype=np.uint8)))
        mask_tensor = torch.from_numpy(np.array(mask, dtype=np.float32, copy=True)).unsqueeze(0)
        return image_tensor, augmented_tensor, mask_tensor

# Create datasets (same file patterns as the data-check cells)
image_paths = []
for ext in ['*.jpg', '*.png', '*.jpeg', '*.JPG', '*.PNG']:
    image_paths.extend(Path(CONFIG['intact_dir']).glob(ext))

if not image_paths:
    raise FileNotFoundError(f"No images found in {CONFIG['intact_dir']}")

# glob() order depends on the filesystem; sort first so the seeded shuffle,
# and therefore the train/val split, is reproducible across machines.
image_paths.sort()
random.shuffle(image_paths)
n_train = int(len(image_paths) * CONFIG['train_split'])
train_paths = image_paths[:n_train]
val_paths = image_paths[n_train:]

# With drop_last=True, fewer training images than one batch would give zero
# training batches; an empty val split would give an empty val loader.
if len(train_paths) < CONFIG['batch_size']:
    raise ValueError(
        f"Only {len(train_paths)} training images but batch_size={CONFIG['batch_size']}; "
        "with drop_last=True no training batch would be produced."
    )
if not val_paths:
    raise ValueError("Validation split is empty; add images or lower CONFIG['train_split'].")

train_dataset = DRAEMDataset(train_paths, CONFIG['image_size'], CONFIG['augmentation_prob'],
                             CONFIG['min_perlin_scale'], CONFIG['max_perlin_scale'], True)
val_dataset = DRAEMDataset(val_paths, CONFIG['image_size'], CONFIG['augmentation_prob'],
                           CONFIG['min_perlin_scale'], CONFIG['max_perlin_scale'], False)

# Loader settings are taken from CONFIG so they can be tuned in one place.
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True,
                          num_workers=CONFIG['num_workers'], pin_memory=CONFIG['pin_memory'],
                          drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False,
                        num_workers=CONFIG['num_workers'], pin_memory=CONFIG['pin_memory'])

print(f"‚úÖ Datasets: Train={len(train_dataset)}, Val={len(val_dataset)}")

---

## 8. DRAEM Model Architecture

In [None]:
# Cell 7: Model Architecture
class ReconstructiveSubNetwork(nn.Module):
    """Lightweight convolutional autoencoder that reconstructs the input image.

    Four stride-2 4x4 convolutions (in -> 32 -> 64 -> 128 -> 256 channels,
    /16 spatially), a 3x3 bottleneck, then four stride-2 transposed
    convolutions back to the input size. The final Tanh keeps the output in
    [-1, 1]. Input height/width should be multiples of 16.

    This redefines (shadows) the ReconstructiveSubNetwork from the earlier
    architecture cell.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        down_widths = [in_channels, 32, 64, 128, 256]
        down_layers = []
        for c_in, c_out in zip(down_widths[:-1], down_widths[1:]):
            down_layers += [nn.Conv2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU(True)]
        self.encoder = nn.Sequential(*down_layers)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),
        )

        up_layers = []
        for c_in, c_out in ((256, 128), (128, 64), (64, 32)):
            up_layers += [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU(True)]
        up_layers += [nn.ConvTranspose2d(32, out_channels, 4, 2, 1), nn.Tanh()]
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        encoded = self.encoder(x)
        return self.decoder(self.bottleneck(encoded))

class DiscriminativeSubNetwork(nn.Module):
    """Encoder/decoder that predicts per-pixel class logits.

    Three stride-2 4x4 convolutions (in -> 64 -> 128 -> 256 channels) with
    BatchNorm + LeakyReLU(0.2), mirrored by three stride-2 transposed
    convolutions; the last one emits `out_channels` raw logits (no activation).
    Input height/width should be multiples of 8.
    """

    def __init__(self, in_channels=6, out_channels=2):
        super().__init__()
        enc = []
        for c_in, c_out in ((in_channels, 64), (64, 128), (128, 256)):
            enc.extend([nn.Conv2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.LeakyReLU(0.2, True)])
        self.encoder = nn.Sequential(*enc)

        dec = []
        for c_in, c_out in ((256, 128), (128, 64)):
            dec.extend([nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.LeakyReLU(0.2, True)])
        dec.append(nn.ConvTranspose2d(64, out_channels, 4, 2, 1))
        self.decoder = nn.Sequential(*dec)

    def forward(self, x):
        features = self.encoder(x)
        return self.decoder(features)

class DRAEM(nn.Module):
    """Reconstruct-then-segment anomaly model.

    forward(x) returns (reconstruction, segmentation): the reconstruction of
    `x` from ReconstructiveSubNetwork, and the DiscriminativeSubNetwork's
    2-channel per-pixel logits computed from the channel-wise concatenation
    of `x` and its reconstruction.
    """

    def __init__(self):
        super().__init__()
        self.reconstructive = ReconstructiveSubNetwork()
        self.discriminative = DiscriminativeSubNetwork()

    def forward(self, x):
        reconstruction = self.reconstructive(x)
        # 3 input + 3 reconstructed channels = the discriminator's 6 inputs.
        joined = torch.cat([x, reconstruction], dim=1)
        return reconstruction, self.discriminative(joined)

print("‚úÖ Model architecture defined")

---

## 9. SSIM Similarity Metric

In [None]:
# Cell 8: SSIM Loss
def ssim(img1, img2, window_size=11):
    """Mean structural similarity (SSIM) between two image batches.

    Args:
        img1, img2: 4-D (N, C, H, W) tensors (required by the grouped conv2d
            below), assumed to be in [-1, 1]; they are mapped to [0, 1] first,
            which is the dynamic range the C1/C2 constants assume.
        window_size: side of the Gaussian window (sigma = 1.5).

    Returns:
        Scalar tensor: SSIM averaged over batch, channels and pixels, in the
        inputs' dtype.
    """
    img1 = (img1 + 1) / 2
    img2 = (img2 + 1) / 2
    channel = img1.size(1)
    pad = window_size // 2
    sigma = 1.5

    # Normalised 1-D Gaussian (computed in float64, stored as float32), and
    # the separable 2-D window as its outer product.
    coords = torch.arange(window_size, dtype=torch.float64) - window_size // 2
    gauss = torch.exp(-coords ** 2 / (2 * sigma ** 2)).float()
    gauss = gauss / gauss.sum()
    window_2d = gauss[:, None] * gauss[None, :]
    # Match the input's device AND dtype so half/double inputs work with conv2d.
    window = (window_2d.expand(channel, 1, window_size, window_size)
              .contiguous()
              .to(device=img1.device, dtype=img1.dtype))

    mu1 = F.conv2d(img1, window, padding=pad, groups=channel)
    mu2 = F.conv2d(img2, window, padding=pad, groups=channel)
    mu1_sq, mu2_sq, mu1_mu2 = mu1.pow(2), mu2.pow(2), mu1 * mu2
    sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channel) - mu1_mu2
    C1, C2 = 0.01 ** 2, 0.03 ** 2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()

print("‚úÖ SSIM ready")

---

## 10. Loss Functions

In [None]:
# Cell 9: Loss Functions
class FocalLoss(nn.Module):
    """Focal loss on class logits: mean of alpha * (1 - p_t) ** gamma * CE.

    `inputs` are (N, C, ...) logits and `targets` integer class ids, as
    accepted by F.cross_entropy. `alpha` is a single scalar applied to every
    element (not a per-class weight), so it only rescales the loss.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce = F.cross_entropy(inputs, targets, reduction='none')
        p_t = ce.neg().exp()  # probability assigned to the true class
        modulating = (1 - p_t) ** self.gamma
        return (self.alpha * modulating * ce).mean()

class DRAEMLoss(nn.Module):
    """DRAEM objective: L2 reconstruction + (1 - SSIM) + focal segmentation.

    forward returns (total, recon_loss, ssim_loss, seg_loss) so callers can
    log each term separately.
    """

    def __init__(self):
        super().__init__()
        self.l2_loss = nn.MSELoss()
        self.focal_loss = FocalLoss()

    def forward(self, reconstruction, segmentation, target_img, target_mask):
        # The dataset yields masks as (N, 1, H, W) floats; cross-entropy wants
        # (N, H, W) integer class ids.
        class_map = target_mask.squeeze(1).long()
        recon_loss = self.l2_loss(reconstruction, target_img)
        ssim_loss = 1 - ssim(reconstruction, target_img)
        seg_loss = self.focal_loss(segmentation, class_map)
        total = recon_loss + ssim_loss + seg_loss
        return total, recon_loss, ssim_loss, seg_loss

print("‚úÖ Losses ready")

---

## 11. Training and Validation Functions

In [None]:
# Cell 10: Training Functions
def train_one_epoch(model, loader, criterion, optimizer, device, epoch, max_grad_norm=1.0):
    """Run one optimisation pass over `loader` and return batch-averaged losses.

    Args:
        model: network returning ``(reconstruction, segmentation)`` for a batch.
        loader: iterable of ``(intact, augmented, mask)`` tensor batches.
        criterion: callable returning ``(total, recon, ssim, seg)`` loss tensors.
        optimizer: optimizer over ``model.parameters()``.
        device: device the batches are moved to.
        epoch: zero-based epoch index, used only for the progress-bar label.
        max_grad_norm: gradient-norm clipping threshold (default 1.0).

    Returns:
        dict with keys 'total', 'recon', 'ssim', 'seg': mean per-batch values.

    Raises:
        ValueError: if `loader` yields no batches (instead of a bare
            ZeroDivisionError when averaging).
    """
    model.train()
    sums = {'total': 0.0, 'recon': 0.0, 'ssim': 0.0, 'seg': 0.0}
    n_batches = 0
    for intact, augmented, mask in tqdm(loader, desc=f"Epoch {epoch+1}"):
        intact, augmented, mask = intact.to(device), augmented.to(device), mask.to(device)
        # The model sees the corrupted image; targets are the intact image and the anomaly mask.
        reconstruction, segmentation = model(augmented)
        losses = criterion(reconstruction, segmentation, intact, mask)
        optimizer.zero_grad()
        losses[0].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        for key, value in zip(sums, losses):
            sums[key] += value.item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError("train_one_epoch(): loader yielded no batches")
    return {key: total / n_batches for key, total in sums.items()}

def validate(model, loader, criterion, device):
    """Evaluate `model` on `loader` without gradient tracking.

    Args:
        model: network returning ``(reconstruction, segmentation)`` for a batch.
        loader: iterable of ``(intact, augmented, mask)`` tensor batches.
        criterion: callable returning ``(total, recon, ssim, seg)`` loss tensors.
        device: device the batches are moved to.

    Returns:
        dict with keys 'total', 'recon', 'ssim', 'seg': mean per-batch values.

    Raises:
        ValueError: if `loader` yields no batches (e.g. an empty validation
            split), instead of a bare ZeroDivisionError when averaging.
    """
    model.eval()
    sums = {'total': 0.0, 'recon': 0.0, 'ssim': 0.0, 'seg': 0.0}
    n_batches = 0
    with torch.no_grad():
        for intact, augmented, mask in loader:
            intact, augmented, mask = intact.to(device), augmented.to(device), mask.to(device)
            reconstruction, segmentation = model(augmented)
            losses = criterion(reconstruction, segmentation, intact, mask)
            for key, value in zip(sums, losses):
                sums[key] += value.item()
            n_batches += 1
    if n_batches == 0:
        raise ValueError("validate(): loader yielded no batches")
    return {key: total / n_batches for key, total in sums.items()}

print("‚úÖ Training functions ready")

---

## 12. Training Loop

In [None]:
# Cell 11: Main Training Loop
device = torch.device(CONFIG['device'])
model = DRAEM().to(device)
criterion = DRAEMLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])

# Halve the LR after 10 epochs without val-loss improvement. `verbose=` is not
# passed (it is deprecated in recent PyTorch releases); LR drops are logged
# manually after each scheduler step below.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
    min_lr=1e-7
)

print(f"Model params: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
print(f"Starting training for {CONFIG['epochs']} epochs...\n")

history = {
    'train_loss': [], 'val_loss': [],
    'train_recon': [], 'val_recon': [],
    'train_ssim': [], 'val_ssim': [],
    'train_seg': [], 'val_seg': [],
    'learning_rates': []
}
# history key suffix -> key of the metric dicts returned by train_one_epoch / validate
HISTORY_METRICS = {'loss': 'total', 'recon': 'recon', 'ssim': 'ssim', 'seg': 'seg'}
best_val_loss = float('inf')
patience_counter = 0


def _checkpoint(epoch_idx, **extra):
    """Fields shared by every checkpoint; `extra` adds per-checkpoint entries.

    The scheduler state is included everywhere so training can be resumed
    from periodic checkpoints too, not only from draem_best.pth.
    """
    return {
        'epoch': epoch_idx + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'config': CONFIG,
        **extra,
    }


for epoch in range(CONFIG['epochs']):
    train_metrics = train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)
    val_metrics = validate(model, val_loader, criterion, device)
    current_lr = optimizer.param_groups[0]['lr']

    # Store metrics
    for suffix, metric_key in HISTORY_METRICS.items():
        history[f'train_{suffix}'].append(train_metrics[metric_key])
        history[f'val_{suffix}'].append(val_metrics[metric_key])
    history['learning_rates'].append(current_lr)

    print(f"\n{'='*70}")
    print(f"Epoch {epoch+1}/{CONFIG['epochs']}")
    print(f"{'='*70}")
    print(f"TRAIN - Total: {train_metrics['total']:.4f}, Recon: {train_metrics['recon']:.4f}, "
          f"SSIM: {train_metrics['ssim']:.4f}, Seg: {train_metrics['seg']:.4f}")
    print(f"VAL   - Total: {val_metrics['total']:.4f}, Recon: {val_metrics['recon']:.4f}, "
          f"SSIM: {val_metrics['ssim']:.4f}, Seg: {val_metrics['seg']:.4f}")
    # Scientific notation: with min_lr=1e-7, ':.6f' would print reduced LRs as 0.000000.
    print(f"LR: {current_lr:.2e}")

    # Step scheduler
    old_lr = current_lr
    scheduler.step(val_metrics['total'])
    new_lr = optimizer.param_groups[0]['lr']
    if new_lr < old_lr:
        print(f"üìâ Learning rate reduced: {old_lr:.2e} ‚Üí {new_lr:.2e}")

    if val_metrics['total'] < best_val_loss:
        best_val_loss = val_metrics['total']
        patience_counter = 0
        torch.save(_checkpoint(epoch, val_loss=best_val_loss, history=history),
                   f"{CONFIG['output_dir']}/draem_best.pth")
        print(f"‚≠ê NEW BEST MODEL! Val Loss: {best_val_loss:.4f}")
    else:
        patience_counter += 1
        print(f"No improvement for {patience_counter} epochs")

    if (epoch + 1) % CONFIG['save_interval'] == 0:
        torch.save(_checkpoint(epoch, loss=val_metrics['total']),
                   f"{CONFIG['output_dir']}/draem_epoch_{epoch+1}.pth")
        print(f"üíæ Checkpoint saved: epoch_{epoch+1}.pth")

    if patience_counter >= CONFIG['early_stopping_patience']:
        print(f"\n{'='*70}")
        print(f"‚ö†Ô∏è EARLY STOPPING at epoch {epoch+1}")
        print(f"No improvement for {CONFIG['early_stopping_patience']} epochs")
        print(f"{'='*70}")
        break

    print(f"{'='*70}\n")

print("\n" + "="*70)
print("TRAINING COMPLETE!")
print("="*70)
print(f"Best Val Loss: {best_val_loss:.4f}")
print(f"Total Epochs: {len(history['train_loss'])}")
print("="*70)

---

## 13. Save, Plot and Download

Save the final weights, plot the training curves, and then download the checkpoints written to `/kaggle/working` to use locally!

In [None]:
# Cell 12: Save Final Model
# NOTE: `model` holds the weights of the LAST epoch that ran (the loop never
# reloads the best weights); the lowest-val-loss weights are in draem_best.pth.
final_checkpoint = {
    'model_state_dict': model.state_dict(),
    'config': CONFIG,
    'history': history,
    'best_val_loss': best_val_loss,
}
final_path = f"{CONFIG['output_dir']}/draem_final.pth"
torch.save(final_checkpoint, final_path)
print("‚úÖ Final model saved")

In [None]:
# Cell 13: Plot Training History
# (axes position, train key, val key, panel title) for each loss curve.
panels = [
    ((0, 0), 'train_loss', 'val_loss', 'Total Loss'),
    ((0, 1), 'train_recon', 'val_recon', 'Reconstruction Loss'),
    ((1, 0), 'train_ssim', 'val_ssim', 'SSIM Loss'),
    ((1, 1), 'train_seg', 'val_seg', 'Segmentation Loss'),
]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
epochs = range(1, len(history['train_loss'])+1)

for (row, col), train_key, val_key, title in panels:
    ax = axes[row, col]
    ax.plot(epochs, history[train_key], 'b-', label='Train', lw=2)
    ax.plot(epochs, history[val_key], 'r-', label='Val', lw=2)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{CONFIG['output_dir']}/training_history.png", dpi=300)
plt.show()
print("‚úÖ Training history saved")

In [None]:
# Cell 14: Summary
rule = "="*70
print("\n" + rule)
print("TRAINING COMPLETE!")
print(rule)
print(f"Best Val Loss: {best_val_loss:.4f}")
print(f"Total Epochs: {len(history['train_loss'])}")
print("\nSaved files:")
# List every checkpoint written to the output directory with its size in MB.
for checkpoint_file in Path(CONFIG['output_dir']).glob('*.pth'):
    size_mb = checkpoint_file.stat().st_size / 1e6
    print(f"  ‚Ä¢ {checkpoint_file.name} ({size_mb:.2f} MB)")
print("\nüéâ SUCCESS! Download draem_best.pth for deployment!")
print(rule)

---

## Done! üéâ

Your DRAEM model is trained!

**Download `draem_final.pth` and use it for inference.**

**No annotation was needed!** üöÄ