In [None]:
# U-Net based Image Segmentation Generator used for mask prediction

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import cv2
from PIL import Image
import os
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import random

# =============================================================================
# 1. LIGHTWEIGHT U-NET ARCHITECTURE (SAME AS BEFORE)
# =============================================================================

class DoubleConv(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by BatchNorm and ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the second stage.
        mid_channels: Channel count between the two stages. Any falsy value
            (``None`` or ``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels

        # kernel 3 with padding 1 keeps the spatial size; the conv bias is
        # switched off because each conv is directly followed by BatchNorm.
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(inplace=True))

        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both Conv-BN-ReLU stages to ``x``."""
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` by a factor of two, then convolve it."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder stage: upsample, merge with the skip connection, then ``DoubleConv``.

    Args:
        in_channels: Channel count after concatenating the skip tensor with
            the upsampled tensor.
        out_channels: Channel count produced by the stage.
        bilinear: ``True`` uses parameter-free bilinear upsampling; ``False``
            uses a transposed convolution that halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2`` and fuse the two.

        Args:
            x1: Feature map coming from the deeper layer.
            x2: Skip-connection feature map from the encoder.
        """
        upsampled = self.up(x1)

        # The skip tensor can be larger when the input size is odd; spread
        # the missing rows/columns over both sides (extra one goes last).
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class LightweightUNet(nn.Module):
    """Small U-Net (32 to 512 channels) that outputs per-pixel probabilities.

    Args:
        n_channels: Number of channels of the input image.
        n_classes: Number of output channels.
        bilinear: Passed to every ``Up`` stage; when ``True`` the channel
            counts of the deeper layers are halved.

    ``forward`` returns ``sigmoid`` activations, i.e. values in ``[0, 1]``,
    not raw logits.
    """

    def __init__(self, n_channels=3, n_classes=1, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        factor = 2 if bilinear else 1

        # Encoder
        self.inc = DoubleConv(n_channels, 32)
        self.down1 = Down(32, 64)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512 // factor)

        # Decoder (each stage consumes the matching encoder output)
        self.up1 = Up(512, 256 // factor, bilinear)
        self.up2 = Up(256, 128 // factor, bilinear)
        self.up3 = Up(128, 64 // factor, bilinear)
        self.up4 = Up(64, 32, bilinear)

        # 1x1 convolution projecting to the requested number of outputs
        self.outc = nn.Conv2d(32, n_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder/decoder and return sigmoid probabilities."""
        # Collect encoder outputs; the last one is the bottleneck.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        # Walk back up, pairing each decoder stage with the deepest
        # remaining skip tensor.
        features = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            features = up(features, skips.pop())

        return torch.sigmoid(self.outc(features))


# =============================================================================
# 2. SYNTHETIC DISTORTION GENERATOR (SAME AS BEFORE)
# =============================================================================

class SyntheticDistortionGenerator:
    """Create synthetically damaged copies of an image plus a damage mask.

    Every ``add_*`` method takes an image array and returns
    ``(distorted, mask)``. ``mask`` is a ``uint8`` array of shape
    ``image.shape[:2]`` holding ``1`` on pixels labelled as damaged and ``0``
    elsewhere. All randomness comes from the global ``np.random`` state, so
    seed NumPy to get reproducible output.
    """

    @staticmethod
    def _random_rect(h, w, margin, min_size, max_size):
        """Pick a random rectangle that lies inside an ``h`` x ``w`` image.

        Args:
            h, w: Image height and width.
            margin: The top-left corner is drawn from ``[0, size - margin)``.
            min_size: Smallest side length (inclusive) when the image allows it.
            max_size: Upper bound (exclusive) of the side length.

        Returns:
            ``(x, y, rect_w, rect_h)`` as plain Python ints.
        """
        def _extent(available):
            upper = min(max_size, available)
            # np.random.randint requires low < high. On images too small for
            # the requested minimum, use whatever room is left instead of
            # raising ValueError.
            if upper <= min_size:
                return max(1, upper)
            return int(np.random.randint(min_size, upper))

        x = int(np.random.randint(0, max(1, w - margin)))
        y = int(np.random.randint(0, max(1, h - margin)))
        rect_w = _extent(w - x)
        rect_h = _extent(h - y)
        return x, y, rect_w, rect_h

    @staticmethod
    def add_gaussian_noise(image, severity=0.1):
        """Add Gaussian noise inside 3-9 random rectangles.

        The noise is applied only where the mask is ``1`` so that the mask is
        an accurate label of which pixels were altered.

        Args:
            image: Image array with values in ``[0, 255]``.
            severity: Noise standard deviation as a fraction of 255.

        Returns:
            ``(noisy, mask)``.
        """
        noise = np.random.normal(0, severity * 255, image.shape).astype(np.float32)
        fully_noisy = np.clip(image.astype(np.float32) + noise, 0, 255).astype(np.uint8)

        h, w = image.shape[:2]
        mask = np.zeros((h, w), dtype=np.uint8)
        num_patches = np.random.randint(3, 10)

        for _ in range(num_patches):
            x, y, patch_w, patch_h = SyntheticDistortionGenerator._random_rect(
                h, w, margin=50, min_size=20, max_size=100)
            mask[y:y+patch_h, x:x+patch_w] = 1

        # Keep the clean pixels outside the labelled rectangles.
        noisy = image.copy()
        damaged = mask.astype(bool)
        noisy[damaged] = fully_noisy[damaged]

        return noisy, mask

    @staticmethod
    def add_scratches(image, num_scratches=5):
        """Draw white straight lines ("scratches") between random points.

        Args:
            image: Image array accepted by ``cv2.line``.
            num_scratches: Number of lines to draw.

        Returns:
            ``(scratched, mask)``; the mask contains the same lines drawn
            with value ``1``.
        """
        scratched = image.copy()
        mask = np.zeros(image.shape[:2], dtype=np.uint8)

        h, w = image.shape[:2]

        for _ in range(num_scratches):
            # Cast to plain ints: OpenCV argument parsing is stricter about
            # NumPy scalar types in some releases.
            x1, y1 = int(np.random.randint(0, w)), int(np.random.randint(0, h))
            x2, y2 = int(np.random.randint(0, w)), int(np.random.randint(0, h))
            thickness = int(np.random.randint(1, 5))

            cv2.line(scratched, (x1, y1), (x2, y2), (255, 255, 255), thickness)
            cv2.line(mask, (x1, y1), (x2, y2), 1, thickness)

        return scratched, mask

    @staticmethod
    def add_blur_patches(image, num_patches=3):
        """Gaussian-blur random rectangular regions of the image.

        Args:
            image: Image array accepted by ``cv2.GaussianBlur``.
            num_patches: Number of rectangles to blur.

        Returns:
            ``(blurred, mask)``.
        """
        blurred = image.copy()
        mask = np.zeros(image.shape[:2], dtype=np.uint8)

        h, w = image.shape[:2]

        for _ in range(num_patches):
            x, y, patch_w, patch_h = SyntheticDistortionGenerator._random_rect(
                h, w, margin=50, min_size=40, max_size=150)

            patch = blurred[y:y+patch_h, x:x+patch_w]
            kernel_size = int(np.random.choice([15, 21, 31]))
            blurred_patch = cv2.GaussianBlur(patch, (kernel_size, kernel_size), 0)
            blurred[y:y+patch_h, x:x+patch_w] = blurred_patch

            mask[y:y+patch_h, x:x+patch_w] = 1

        return blurred, mask

    @staticmethod
    def add_jpeg_artifacts(image, quality=20):
        """Round-trip the image through JPEG compression.

        The mask marks pixels whose grayscale absolute difference between the
        input and the compressed image exceeds 10.

        Args:
            image: 3-channel ``uint8`` image (the difference image is
                converted with ``cv2.COLOR_BGR2GRAY``).
            quality: JPEG quality passed to ``cv2.imencode``.

        Returns:
            ``(compressed, mask)``.
        """
        # int() guards against NumPy integer scalars (e.g. from
        # np.random.randint) in the OpenCV parameter list.
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
        _, encimg = cv2.imencode('.jpg', image, encode_param)
        compressed = cv2.imdecode(encimg, 1)

        diff = np.abs(image.astype(np.float32) - compressed.astype(np.float32))
        diff_gray = cv2.cvtColor(diff.astype(np.uint8), cv2.COLOR_BGR2GRAY)
        _, mask = cv2.threshold(diff_gray, 10, 1, cv2.THRESH_BINARY)

        return compressed, mask

    @staticmethod
    def add_random_patches(image, num_patches=5):
        """Overwrite random rectangles with a random solid colour.

        Args:
            image: 3-channel image array.
            num_patches: Number of rectangles to overwrite.

        Returns:
            ``(damaged, mask)``.
        """
        damaged = image.copy()
        mask = np.zeros(image.shape[:2], dtype=np.uint8)

        h, w = image.shape[:2]

        for _ in range(num_patches):
            x, y, patch_w, patch_h = SyntheticDistortionGenerator._random_rect(
                h, w, margin=30, min_size=20, max_size=80)

            color = np.random.randint(0, 256, 3).tolist()
            damaged[y:y+patch_h, x:x+patch_w] = color
            mask[y:y+patch_h, x:x+patch_w] = 1

        return damaged, mask

    @classmethod
    def generate_distortion(cls, image):
        """Apply one randomly chosen distortion and return ``(distorted, mask)``.

        The choice is uniform over noise, scratches, blur, JPEG artifacts,
        solid patches, and a combination of scratches plus one blur patch.
        """
        distortion_type = np.random.choice([
            'noise', 'scratches', 'blur', 'jpeg', 'patches', 'combined'
        ])

        if distortion_type == 'noise':
            return cls.add_gaussian_noise(image, severity=float(np.random.uniform(0.05, 0.15)))
        elif distortion_type == 'scratches':
            return cls.add_scratches(image, num_scratches=int(np.random.randint(3, 8)))
        elif distortion_type == 'blur':
            return cls.add_blur_patches(image, num_patches=int(np.random.randint(2, 5)))
        elif distortion_type == 'jpeg':
            return cls.add_jpeg_artifacts(image, quality=int(np.random.randint(10, 30)))
        elif distortion_type == 'patches':
            return cls.add_random_patches(image, num_patches=int(np.random.randint(3, 7)))
        else:  # combined
            img, mask = cls.add_scratches(image, num_scratches=2)
            img, mask2 = cls.add_blur_patches(img, num_patches=1)
            # A pixel is damaged if either distortion touched it.
            mask = np.logical_or(mask, mask2).astype(np.uint8)
            return img, mask


# =============================================================================
# 3. MODIFIED DATASET CLASS FOR KAGGLE (ACCEPTS FILE LIST)
# =============================================================================

class DistortionDataset(Dataset):
    """Dataset yielding ``(distorted_image, mask)`` tensor pairs.

    Images are read from a single directory, restricted to an explicit list
    of filenames so the same directory can back several splits.
    """

    def __init__(self, image_dir, image_files, transform=None, use_synthetic=True, img_size=256):
        """
        Args:
            image_dir: Directory that contains every image.
            image_files: Filenames (relative to ``image_dir``) in this split.
            transform: Optional callable applied to the image tensor and to
                the mask tensor.
            use_synthetic: When ``True`` a random synthetic distortion is
                generated per item; otherwise the clean image and an
                all-zero mask are returned.
            img_size: Side length images are resized to (square).
        """
        self.image_dir = image_dir
        self.image_files = image_files
        self.transform = transform
        self.use_synthetic = use_synthetic
        self.img_size = img_size

        self.distortion_gen = SyntheticDistortionGenerator()

    def __len__(self):
        return len(self.image_files)

    def _load_rgb(self, img_path):
        """Read ``img_path`` as RGB; on any failure log it and return a black image."""
        try:
            image = cv2.imread(img_path)
            if image is None:
                raise ValueError(f"Failed to load image: {img_path}")
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            return np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)

    def __getitem__(self, idx):
        """Return the ``idx``-th sample as ``(image, mask)`` float tensors.

        The image tensor is channel-first and scaled to ``[0, 1]``; the mask
        tensor has a leading singleton channel dimension.
        """
        side = self.img_size
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = cv2.resize(self._load_rgb(img_path), (side, side))

        if not self.use_synthetic:
            distorted, mask = image, np.zeros(image.shape[:2], dtype=np.uint8)
        else:
            distorted, mask = self.distortion_gen.generate_distortion(image)

        image_tensor = torch.from_numpy(distorted).permute(2, 0, 1).float() / 255.0
        mask_tensor = torch.from_numpy(mask).unsqueeze(0).float()

        if self.transform:
            # Rewind torch's RNG between the two calls so a random transform
            # draws the same numbers for the image and for its mask.
            rng_snapshot = torch.get_rng_state()
            image_tensor = self.transform(image_tensor)
            torch.set_rng_state(rng_snapshot)
            mask_tensor = self.transform(mask_tensor)

        return image_tensor, mask_tensor


# =============================================================================
# 4. DATA SPLITTING UTILITY FOR SINGLE DIRECTORY
# =============================================================================

def split_dataset_from_directory(image_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=42):
    """
    Split images from a single directory into train/val/test sets

    The directory listing is sorted before shuffling, so a given ``seed``
    yields the same split regardless of the order in which the filesystem
    returns entries.

    Args:
        image_dir: Directory containing all images
        train_ratio: Proportion for training (default: 0.7)
        val_ratio: Proportion for validation (default: 0.15)
        test_ratio: Proportion for testing (default: 0.15). When the three
            ratios sum to 1 (or more) the test set is simply every file left
            after the train and validation sets; when they sum to less than
            1, only ``int(total * test_ratio)`` files are used and the rest
            are left out.
        seed: Random seed for reproducibility

    Returns:
        train_files, val_files, test_files: Lists of filenames for each split

    Side effects:
        Seeds the global ``random`` and ``np.random`` generators with ``seed``.
    """
    # Set seed for reproducibility
    random.seed(seed)
    np.random.seed(seed)

    # os.listdir returns entries in arbitrary order; sort so that the seeded
    # shuffle below is deterministic across machines and runs.
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
    all_files = sorted(
        f for f in os.listdir(image_dir)
        if f.lower().endswith(image_extensions)
    )

    print(f"Total images found: {len(all_files)}")

    # Shuffle
    random.shuffle(all_files)

    # Calculate split boundaries
    total = len(all_files)
    train_end = int(total * train_ratio)
    val_end = train_end + int(total * val_ratio)
    if train_ratio + val_ratio + test_ratio < 1.0 - 1e-9:
        test_end = val_end + int(total * test_ratio)
    else:
        # Ratios cover the whole dataset: give rounding leftovers to test.
        test_end = total

    # Split
    train_files = all_files[:train_end]
    val_files = all_files[train_end:val_end]
    test_files = all_files[val_end:test_end]

    print(f"Train: {len(train_files)} images")
    print(f"Val: {len(val_files)} images")
    print(f"Test: {len(test_files)} images")

    return train_files, val_files, test_files


# =============================================================================
# 5. KAGGLE-SPECIFIC PATH FINDER
# =============================================================================

def find_flickr8k_path():
    """
    Locate the Flickr8k image directory among known Kaggle input paths.

    Returns:
        The first candidate path that exists.

    Raises:
        FileNotFoundError: If none of the candidates exists. Before raising,
            the contents of ``/kaggle/input`` (when present) are printed to
            help pick the right path manually.
    """
    candidates = (
        '/kaggle/input/flickr8k/Images',
        '/kaggle/input/flickr-8k/Images',
        '/kaggle/input/flickr8k-dataset/Images',
        '/kaggle/input/flickr-image-dataset/Images',
        '/kaggle/input/flickr8k',
        '/kaggle/input/flickr-8k',
    )

    # Candidates are probed in order; more specific paths come first.
    found = next((path for path in candidates if os.path.exists(path)), None)
    if found is not None:
        print(f"✓ Found Flickr8k at: {found}")
        return found

    print("Could not find Flickr8k automatically. Available input directories:")
    if os.path.exists('/kaggle/input'):
        for entry in os.listdir('/kaggle/input'):
            entry_path = os.path.join('/kaggle/input', entry)
            print(f"  - {entry_path}")
            if not os.path.isdir(entry_path):
                continue
            # Only a short preview of each dataset folder
            for child in os.listdir(entry_path)[:5]:
                print(f"    - {child}")

    raise FileNotFoundError("Please manually specify the correct path to Flickr8k images")


# =============================================================================
# 6. LOSS FUNCTIONS (SAME AS BEFORE)
# =============================================================================

class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    Args:
        smooth: Constant added to numerator and denominator; keeps the ratio
            defined when both prediction and target are empty.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - dice`` for same-sized ``pred`` and ``target`` tensors."""
        flat_pred = pred.reshape(-1)
        flat_target = target.reshape(-1)

        overlap = torch.sum(flat_pred * flat_target)
        numerator = 2. * overlap + self.smooth
        denominator = flat_pred.sum() + flat_target.sum() + self.smooth

        return 1 - numerator / denominator


class CombinedLoss(nn.Module):
    """Weighted sum of binary cross-entropy and Dice loss.

    ``pred`` must already be probabilities in ``[0, 1]`` because the BCE term
    uses ``nn.BCELoss`` (not the logits variant).

    Args:
        bce_weight: Multiplier of the BCE term.
        dice_weight: Multiplier of the Dice term.
    """

    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce_weight, self.dice_weight = bce_weight, dice_weight
        self.bce = nn.BCELoss()
        self.dice = DiceLoss()

    def forward(self, pred, target):
        """Return ``bce_weight * BCE + dice_weight * Dice``."""
        weighted_bce = self.bce_weight * self.bce(pred, target)
        weighted_dice = self.dice_weight * self.dice(pred, target)
        return weighted_bce + weighted_dice


# =============================================================================
# 7. TRAINING FUNCTIONS (SAME AS BEFORE)
# =============================================================================

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train; switched to training mode.
        dataloader: Iterable of ``(images, masks)`` batches supporting ``len``.
        optimizer: Optimizer stepping the model parameters.
        criterion: Callable ``criterion(outputs, masks)`` returning the loss.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch losses (sum divided by the number of batches).
    """
    model.train()
    total_loss = 0.0

    progress = tqdm(dataloader, desc='Training')
    for batch_images, batch_masks in progress:
        batch_images, batch_masks = batch_images.to(device), batch_masks.to(device)

        optimizer.zero_grad()
        loss = criterion(model(batch_images), batch_masks)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        # Show the current batch loss next to the progress bar
        progress.set_postfix({'loss': batch_loss})

    return total_loss / len(dataloader)


def validate(model, dataloader, criterion, device):
    """Compute the mean per-batch loss on ``dataloader`` without gradients.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        dataloader: Iterable of ``(images, masks)`` batches supporting ``len``.
        criterion: Callable ``criterion(outputs, masks)`` returning the loss.
        device: Device the batches are moved to.

    Returns:
        Sum of the batch losses divided by the number of batches.
    """
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for batch_images, batch_masks in tqdm(dataloader, desc='Validation'):
            batch_images, batch_masks = batch_images.to(device), batch_masks.to(device)
            predictions = model(batch_images)
            total_loss += criterion(predictions, batch_masks).item()

    return total_loss / len(dataloader)


# =============================================================================
# 8. MODIFIED TRAINING FUNCTION FOR KAGGLE (SINGLE DIRECTORY)
# =============================================================================

def _seed_dataloader_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    The synthetic distortions draw from ``np.random``. Following the PyTorch
    reproducibility notes, derive the seed of the other libraries from the
    worker's torch seed so that workers do not replay identical distortion
    sequences (older PyTorch releases do not re-seed NumPy per worker).
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def train_mask_predictor_kaggle(
    image_dir=None,
    train_ratio=0.7,
    val_ratio=0.15,
    test_ratio=0.15,
    epochs=50,
    batch_size=8,
    learning_rate=0.001,
    img_size=256,
    save_dir='/kaggle/working/checkpoints',
    device=None,
    num_workers=2,
    seed=42
):
    """
    Train the mask-prediction U-Net on images from a single directory.

    The directory is split into train/val/test file lists, the split is
    pickled to ``save_dir``, and the model is trained with a BCE + Dice loss.
    The checkpoint with the lowest validation loss is written to
    ``best_model.pth``; an additional checkpoint is written every 10 epochs,
    and the loss history is pickled at the end.

    Args:
        image_dir: Directory containing all images (if None, auto-detect Flickr8k)
        train_ratio: Proportion for training
        val_ratio: Proportion for validation
        test_ratio: Proportion for testing
        epochs: Number of training epochs
        batch_size: Batch size
        learning_rate: Learning rate
        img_size: Image size (square)
        save_dir: Directory to save checkpoints (default: /kaggle/working/checkpoints)
        device: Device to train on (auto-detect if None); a string such as
            ``'cuda'`` / ``'cuda:0'`` or a ``torch.device``
        num_workers: Number of data loading workers
        seed: Random seed; used for the data split and, when not None, for
            torch's global generator

    Returns:
        ``(model, history)`` where ``history`` is a dict with the keys
        ``'train_losses'``, ``'val_losses'`` and ``'best_val_loss'``.

    Raises:
        ValueError: If the training or validation split ends up empty.
    """
    # pickle is not imported at the top of this cell, so import it once here.
    import pickle

    # Auto-detect device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    # Handles 'cuda', 'cuda:0' and torch.device objects alike.
    use_cuda = torch.device(device).type == 'cuda'

    # Make weight initialisation and batch shuffling follow the seed as well.
    if seed is not None:
        torch.manual_seed(seed)

    # Auto-detect Flickr8k path if not provided
    if image_dir is None:
        image_dir = find_flickr8k_path()

    print(f"Loading images from: {image_dir}")

    # Create save directory
    os.makedirs(save_dir, exist_ok=True)
    print(f"Checkpoints will be saved to: {save_dir}")

    # Split dataset
    print("\nSplitting dataset...")
    train_files, val_files, test_files = split_dataset_from_directory(
        image_dir, train_ratio, val_ratio, test_ratio, seed
    )
    if not train_files or not val_files:
        # Fail early with a clear message instead of a later division by
        # zero when averaging over an empty loader.
        raise ValueError(
            f"Not enough images in {image_dir}: got {len(train_files)} training "
            f"and {len(val_files)} validation images"
        )

    # Save split information
    split_info = {
        'train': train_files,
        'val': val_files,
        'test': test_files
    }
    split_path = os.path.join(save_dir, 'data_split.pkl')
    with open(split_path, 'wb') as f:
        pickle.dump(split_info, f)
    print(f"Data split saved to: {split_path}")

    # Create datasets
    train_dataset = DistortionDataset(image_dir, train_files, img_size=img_size)
    val_dataset = DistortionDataset(image_dir, val_files, img_size=img_size)

    # Create dataloaders
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=use_cuda,
        worker_init_fn=_seed_dataloader_worker,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Create model
    print("\nInitializing model...")
    model = LightweightUNet(n_channels=3, n_classes=1, bilinear=True)
    model = model.to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    # Loss and optimizer
    criterion = CombinedLoss(bce_weight=0.5, dice_weight=0.5)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # `verbose` is deprecated in recent PyTorch releases, so learning-rate
    # changes are reported manually in the loop below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    # Training loop
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    print("\nStarting training...")
    print("=" * 70)

    for epoch in range(epochs):
        print(f'\nEpoch {epoch+1}/{epochs}')
        print("-" * 70)

        # Train
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
        train_losses.append(train_loss)
        print(f'Training Loss: {train_loss:.4f}')

        # Validate
        val_loss = validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        print(f'Validation Loss: {val_loss:.4f}')

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f'Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}')

        # 'epoch' is stored 0-based (index of the epoch just finished).
        checkpoint_state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
        }

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint_path = os.path.join(save_dir, 'best_model.pth')
            torch.save({**checkpoint_state, 'best_val_loss': best_val_loss}, checkpoint_path)
            print(f'✓ Saved best model (val_loss: {val_loss:.4f}) to {checkpoint_path}')

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth')
            torch.save(checkpoint_state, checkpoint_path)
            print(f'✓ Saved checkpoint to {checkpoint_path}')

    print("\n" + "=" * 70)
    print("Training completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Models saved in: {save_dir}")

    # Save training history
    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'best_val_loss': best_val_loss
    }
    with open(os.path.join(save_dir, 'training_history.pkl'), 'wb') as f:
        pickle.dump(history, f)

    return model, history


# =============================================================================
# 9. INFERENCE CLASS (SAME AS BEFORE)
# =============================================================================

class MaskPredictor:
    """Inference wrapper that loads a trained ``LightweightUNet`` checkpoint."""

    def __init__(self, checkpoint_path, device=None):
        """
        Args:
            checkpoint_path: Path to a checkpoint dict containing at least
                ``'model_state_dict'``.
            device: Device to run on; auto-detected (CUDA if available) when None.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.device = device
        self.model = LightweightUNet(n_channels=3, n_classes=1, bilinear=True)

        # Load checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(device)
        self.model.eval()

        # The training loop stores the 0-based index of the last finished
        # epoch, so the number of completed epochs is one more than that.
        saved_epoch = checkpoint.get('epoch')
        epochs_trained = saved_epoch + 1 if isinstance(saved_epoch, int) else 'unknown'

        print(f"Model loaded from {checkpoint_path}")
        print(f"Trained for {epochs_trained} epochs")
        print(f"Best val loss: {checkpoint.get('best_val_loss', checkpoint.get('val_loss', 'unknown'))}")

    def predict(self, image, threshold=0.5):
        """
        Predict mask for a single image

        Args:
            image: numpy array (H, W, 3) in RGB format, values [0, 255]
            threshold: Threshold applied to the model output

        Returns:
            mask: numpy uint8 array (H, W) with values in {0, 1}. Note the
            polarity: the thresholded model output is inverted before it is
            returned, so ``0`` marks pixels whose output exceeded
            ``threshold`` and ``1`` marks all other pixels.
        """
        original_size = image.shape[:2]

        # Preprocess. ascontiguousarray makes a copy only when needed and
        # lets torch.from_numpy accept negative-stride views such as
        # ``bgr[:, :, ::-1]``.
        pixels = np.ascontiguousarray(image)
        image_tensor = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).float() / 255.0
        image_tensor = image_tensor.to(self.device)

        # Predict
        with torch.no_grad():
            output = self.model(image_tensor)

        # Postprocess: first (only) batch item, first output channel.
        # Explicit indexing keeps the result 2-D even when H or W is 1.
        mask = output[0, 0].cpu().numpy()
        mask = (mask > threshold).astype(np.uint8)

        # Resize to original size if needed
        if mask.shape != original_size:
            mask = cv2.resize(mask, (original_size[1], original_size[0]),
                              interpolation=cv2.INTER_NEAREST)

        return 1 - mask

    def predict_batch(self, images, threshold=0.5):
        """Predict masks for an iterable of images, one ``predict`` call each."""
        return [self.predict(image, threshold) for image in images]


# =============================================================================
# 10. VISUALIZATION UTILITIES
# =============================================================================

def visualize_predictions(images_list, save_path='/kaggle/working/predictions.png'):
    """
    Plot image / ground-truth mask / predicted mask triplets, one row each.

    The figure is written to ``save_path`` and then shown.

    Args:
        images_list: List of tuples (image, ground_truth_mask, predicted_mask)
        save_path: Path to save visualization
    """
    import matplotlib.pyplot as plt

    n_samples = len(images_list)
    # squeeze=False always yields a 2-D grid of axes, also for a single row
    fig, axes = plt.subplots(n_samples, 3, figsize=(15, 5 * n_samples), squeeze=False)

    for row_axes, (image, gt_mask, pred_mask) in zip(axes, images_list):
        panels = (
            (image, 'Original Image', {}),
            (gt_mask, 'Ground Truth Mask', {'cmap': 'gray'}),
            (pred_mask, 'Predicted Mask', {'cmap': 'gray'}),
        )
        for ax, (picture, title, imshow_kwargs) in zip(row_axes, panels):
            ax.imshow(picture, **imshow_kwargs)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"Visualization saved to {save_path}")
    plt.show()


def plot_training_history(history, save_path='/kaggle/working/training_history.png'):
    """Plot the training and validation loss curves, save the figure and show it.

    Args:
        history: Dict with the keys ``'train_losses'`` and ``'val_losses'``
            (one value per epoch).
        save_path: Path the figure is written to.
    """
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 6))

    curves = (('train_losses', 'Training Loss'), ('val_losses', 'Validation Loss'))
    for key, label in curves:
        ax.plot(history[key], label=label, linewidth=2)

    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title('Training History', fontsize=14)
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"Training history plot saved to {save_path}")
    plt.show()


# =============================================================================
# 11. COMPLETE KAGGLE USAGE EXAMPLE
# =============================================================================

def kaggle_complete_pipeline():
    """
    Complete pipeline for Kaggle notebook
    Run this in a Kaggle notebook cell

    Steps: train the model on the auto-detected Flickr8k directory, plot the
    loss curves, then distort up to five held-out test images, predict their
    masks with the best checkpoint and save a comparison figure.
    """
    # pickle is not imported at the top of this cell.
    import pickle

    save_dir = '/kaggle/working/checkpoints'
    best_model_path = os.path.join(save_dir, 'best_model.pth')
    split_path = os.path.join(save_dir, 'data_split.pkl')
    history_path = os.path.join(save_dir, 'training_history.pkl')

    print("=" * 70)
    print("LIGHTWEIGHT U-NET MASK PREDICTOR - KAGGLE PIPELINE")
    print("=" * 70)

    # Step 1: Train the model
    print("\n[STEP 1] Training model...")
    # Resolve the dataset once and reuse it for training and testing.
    image_dir = find_flickr8k_path()
    model, history = train_mask_predictor_kaggle(
        image_dir=image_dir,
        train_ratio=0.7,
        val_ratio=0.15,
        test_ratio=0.15,
        epochs=30,  # Start with 30 epochs for faster testing
        batch_size=16,  # Increase if you have enough GPU memory
        learning_rate=0.001,
        img_size=256,
        save_dir=save_dir,
        num_workers=2,
        seed=42
    )

    # Step 2: Plot training history
    print("\n[STEP 2] Plotting training history...")
    plot_training_history(history)

    # Step 3: Test on some images
    print("\n[STEP 3] Testing on sample images...")
    predictor = MaskPredictor(best_model_path)

    # Load test data split (file written by the training step above)
    with open(split_path, 'rb') as f:
        split_info = pickle.load(f)

    # Test on a few images
    test_files = split_info['test'][:5]  # First 5 test images
    distortion_gen = SyntheticDistortionGenerator()

    visualizations = []
    for img_file in test_files:
        img_path = os.path.join(image_dir, img_file)
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None; skip the file
            # rather than crashing in cvtColor.
            print(f"Skipping unreadable image: {img_path}")
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Generate synthetic distortion for testing
        distorted, gt_mask = distortion_gen.generate_distortion(image)

        # MaskPredictor.predict returns the inverted mask (0 = flagged
        # pixel) while the ground truth uses 1 = damaged; flip the
        # prediction so both panels use the same polarity.
        pred_mask = 1 - predictor.predict(distorted)

        visualizations.append((distorted, gt_mask, pred_mask))

    # Visualize results
    if visualizations:
        visualize_predictions(visualizations)
    else:
        print("No test images could be loaded; skipping visualization.")

    print("\n" + "=" * 70)
    print("Pipeline completed successfully!")
    print("=" * 70)
    print("\nSaved files:")
    print(f"  - Best model: {best_model_path}")
    print(f"  - Training history: {history_path}")
    print(f"  - Data split: {split_path}")
    print("  - Visualizations: /kaggle/working/predictions.png")
    print("  - Training plot: /kaggle/working/training_history.png")


# =============================================================================
# 12. QUICK TEST FUNCTION
# =============================================================================

def quick_test():
    """Smoke test: report the device, locate the dataset and run one forward pass.

    Returns ``None``. If the dataset cannot be located or listed, the error
    is printed and the model check is skipped.
    """
    print("Running quick test...")

    # Check GPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")

    # Check path
    try:
        dataset_dir = find_flickr8k_path()
        sample_files = os.listdir(dataset_dir)[:5]
        print(f"Found {len(sample_files)} sample files: {sample_files}")
    except Exception as e:
        print(f"Error: {e}")
    else:
        # Push one random 256x256 RGB batch through a freshly built model
        net = LightweightUNet()
        prediction = net(torch.randn(1, 3, 256, 256))
        print(f"Model output shape: {prediction.shape}")

        print("✓ Quick test passed!")


# =============================================================================
# KAGGLE NOTEBOOK EXECUTION
# =============================================================================

if __name__ == '__main__':
    # Entry-point selector. Nothing runs by default: every option below is
    # commented out and the trailing `pass` only keeps the block valid.
    # Uncomment exactly one option to run it.

    # Option 1: Quick test (recommended first) - checks the device, the
    # dataset path and a single forward pass without training.
    # quick_test()

    # Option 2: Run complete pipeline - training, loss plot and a
    # visual check on a few test images.
    # kaggle_complete_pipeline()

    # Option 3: Train only with custom parameters
    # model, history = train_mask_predictor_kaggle(
    #     epochs=50,
    #     batch_size=16,
    #     img_size=256
    # )
    pass

In [None]:
# Partial Convolution-based Inpainting Generator.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
import cv2
import numpy as np
import os
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

# =============================================================================
# 1. PARTIAL CONVOLUTION LAYER (Core Innovation)
# =============================================================================

class PartialConv2d(nn.Module):
    """
    Partial Convolution Layer

    Key innovation: Convolution that handles irregular masks
    - Only convolves over valid (unmasked) pixels
    - Re-normalizes each output by how many valid inputs its window contained
    - Automatically updates mask as it goes deeper

    Paper: "Image Inpainting for Irregular Holes Using Partial Convolutions"
           Liu et al., ECCV 2018

    Two convolutions with identical geometry are held:
    - ``input_conv``: the learnable feature convolution
    - ``mask_conv``: a frozen, bias-free convolution with all-ones weights that
      sums the mask values seen by every output location
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__()

        # Learnable feature convolution
        self.input_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride, padding, dilation, groups, bias
        )

        # Same geometry, no bias: used only to count valid inputs per window
        self.mask_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride, padding, dilation, groups, False
        )

        self.input_conv.apply(self.weights_init('kaiming'))

        # Mask convolution: fixed weights (all ones)
        torch.nn.init.constant_(self.mask_conv.weight, 1.0)

        # Freeze mask convolution weights
        for param in self.mask_conv.parameters():
            param.requires_grad = False

    def weights_init(self, init_type='kaiming'):
        """
        Build an initializer suitable for ``nn.Module.apply``.

        Args:
            init_type: 'kaiming' or 'xavier'

        Returns:
            A function that initializes the weight of any module whose class
            name starts with 'Conv' or 'Linear' and zeroes its bias if present.

        Raises:
            NotImplementedError: (from the returned function) for any other
                ``init_type``.
        """
        def init_fun(m):
            classname = m.__class__.__name__
            if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
                if init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=1.0)
                else:
                    raise NotImplementedError(f'Initialization {init_type} not supported')
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
        return init_fun

    def forward(self, input_x, mask):
        """
        Args:
            input_x: [B, C, H, W] - Image with holes
            mask: [B, C, H, W] - Binary mask (1=valid, 0=hole)

        Returns:
            output: [B, C', H', W'] - Convolved output; exactly 0 wherever the
                receptive field contained no valid pixel
            updated_mask: [B, C', H', W'] - 1 where the receptive field
                contained at least one valid pixel, else 0
        """
        with torch.no_grad():
            # Sum of mask values in each kernel window (over all input channels)
            mask_sum = self.mask_conv(mask)

            # Decide which windows saw valid data BEFORE any clamping. Clamping
            # first makes every window look non-empty, so the mask would never
            # propagate.
            updated_mask = (mask_sum > 0).to(mask_sum.dtype)

            # Kernel size for normalization.
            # NOTE(review): mask_sum also sums over input channels, so a fully
            # valid window yields 1 / (in_channels / groups) rather than 1.
            # Left unchanged to keep the scale of valid-region activations as
            # before - confirm whether this constant factor is intended.
            kernel_size = self.input_conv.kernel_size[0] * self.input_conv.kernel_size[1]

            # The clamp only guards the division; multiplying by updated_mask
            # zeroes the ratio in empty windows so a bias cannot be blown up
            # by the 1e-8 denominator.
            ratio = (kernel_size / mask_sum.clamp(min=1e-8)) * updated_mask

        # Convolve only the valid pixels, then re-normalize by valid count
        output = self.input_conv(input_x * mask) * ratio

        return output, updated_mask


# =============================================================================
# 2. PARTIAL CONVOLUTION U-NET ARCHITECTURE
# =============================================================================

class PartialConvUNet(nn.Module):
    """
    Encoder-decoder (U-Net style) inpainting network built from PartialConv2d layers.

    Layout:
    - enc1..enc8: stride-2 partial convolutions, each followed by BatchNorm + ReLU
    - dec8..dec2: stride-1 partial convolutions, each followed by BatchNorm,
      LeakyReLU(0.2) and a x2 nearest-neighbour upsampling
    - dec1: final partial convolution (with bias), upsampled x2 and passed through tanh
    - every decoder stage below dec8 consumes the matching encoder features and
      mask through a channel-wise concatenation (skip connection)
    """
    def __init__(self, input_channels=3, output_channels=3):
        super().__init__()

        # (in_channels, out_channels, kernel_size) for enc1..enc8.
        # All encoder stages use stride 2 and "same"-style padding (kernel // 2).
        encoder_specs = [
            (input_channels, 64, 7),
            (64, 128, 5),
            (128, 256, 5),
            (256, 512, 3),
            (512, 512, 3),
            (512, 512, 3),
            (512, 512, 3),
            (512, 512, 3),
        ]
        for level, (c_in, c_out, k) in enumerate(encoder_specs, start=1):
            setattr(self, f'enc{level}',
                    PartialConv2d(c_in, c_out, k, stride=2, padding=k // 2, bias=False))
            setattr(self, f'bn{level}', nn.BatchNorm2d(c_out))

        # (level, in_channels, out_channels) for dec8..dec2.
        # Below dec8 the input is the upsampled decoder output concatenated
        # with the encoder output of the level underneath.
        decoder_specs = [
            (8, 512, 512),
            (7, 512 + 512, 512),
            (6, 512 + 512, 512),
            (5, 512 + 512, 512),
            (4, 512 + 512, 256),
            (3, 256 + 256, 128),
            (2, 128 + 128, 64),
        ]
        for level, c_in, c_out in decoder_specs:
            setattr(self, f'dec{level}',
                    PartialConv2d(c_in, c_out, 3, stride=1, padding=1, bias=False))
            setattr(self, f'dbn{level}', nn.BatchNorm2d(c_out))

        # Output stage: no batch norm, bias enabled
        self.dec1 = PartialConv2d(64 + 64, output_channels, 3, stride=1, padding=1, bias=True)

        self.relu = nn.ReLU(inplace=True)
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, input_img, mask):
        """
        Args:
            input_img: [B, 3, H, W] - Image with holes
            mask: [B, 1, H, W] - Binary mask (1=valid, 0=hole)

        Returns:
            output: [B, 3, H, W] - Inpainted image, squashed to [-1, 1] by tanh

        Note:
            The eight halvings followed by eight doublings only line up with
            the skip tensors when H and W are multiples of 256 (such as the
            256x256 inputs used elsewhere in this file).
        """
        # One mask channel per image channel
        cur_mask = mask.repeat(1, input_img.shape[1], 1, 1)
        cur = input_img

        # Encoder: keep every stage's activations and mask for the skips
        skip_feats, skip_masks = [], []
        for level in range(1, 9):
            cur, cur_mask = getattr(self, f'enc{level}')(cur, cur_mask)
            cur = self.relu(getattr(self, f'bn{level}')(cur))
            skip_feats.append(cur)
            skip_masks.append(cur_mask)

        # Bottleneck decoder stage (no skip input)
        cur, cur_mask = self.dec8(skip_feats[7], skip_masks[7])
        cur = self.leaky_relu(self.dbn8(cur))
        cur = F.interpolate(cur, scale_factor=2, mode='nearest')
        cur_mask = F.interpolate(cur_mask, scale_factor=2, mode='nearest')

        # dec7..dec2: concatenate with the encoder output one level below
        for level in range(7, 1, -1):
            cur = torch.cat([cur, skip_feats[level - 1]], dim=1)
            cur_mask = torch.cat([cur_mask, skip_masks[level - 1]], dim=1)
            cur, cur_mask = getattr(self, f'dec{level}')(cur, cur_mask)
            cur = self.leaky_relu(getattr(self, f'dbn{level}')(cur))
            cur = F.interpolate(cur, scale_factor=2, mode='nearest')
            cur_mask = F.interpolate(cur_mask, scale_factor=2, mode='nearest')

        # Output stage: skip from enc1, then back to the input resolution
        cur = torch.cat([cur, skip_feats[0]], dim=1)
        cur_mask = torch.cat([cur_mask, skip_masks[0]], dim=1)
        cur, _ = self.dec1(cur, cur_mask)
        cur = F.interpolate(cur, scale_factor=2, mode='nearest')

        # Tanh to get output in [-1, 1] range
        return torch.tanh(cur)


# =============================================================================
# 3. LOSS FUNCTIONS
# =============================================================================

class VGGPerceptualLoss(nn.Module):
    """
    Feature-space L1 loss computed on three stages of a frozen, pretrained VGG16.

    NOTE(review): the mean/std constants used in ``forward`` are the usual
    ImageNet statistics, which presumably assume inputs scaled to [0, 1] -
    confirm that callers pass images in that range.
    """
    def __init__(self, device='cuda'):
        super().__init__()

        backbone = models.vgg16(pretrained=True).features.to(device).eval()
        layers = list(backbone.children())

        # Consecutive chunks of the VGG feature extractor
        self.slice1 = nn.Sequential(*layers[:4])    # relu1_2
        self.slice2 = nn.Sequential(*layers[4:9])   # relu2_2
        self.slice3 = nn.Sequential(*layers[9:16])  # relu3_3

        # The backbone is a fixed feature extractor, never trained
        for weight in self.parameters():
            weight.requires_grad = False

    def _feature_pyramid(self, x):
        """Run x through the three slices in sequence and return all three outputs."""
        f1 = self.slice1(x)
        f2 = self.slice2(f1)
        f3 = self.slice3(f2)
        return f1, f2, f3

    def forward(self, pred, target):
        """
        Args:
            pred: [B, 3, H, W] - Predicted image
            target: [B, 3, H, W] - Ground truth image

        Returns:
            Scalar tensor: sum of the mean absolute feature differences at the
            three stages.
        """
        # Per-channel normalization constants, broadcastable over [B, 3, H, W]
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(pred.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(pred.device)

        pred_feats = self._feature_pyramid((pred - mean) / std)
        target_feats = self._feature_pyramid((target - mean) / std)

        # Accumulate the L1 distance stage by stage (shallow to deep)
        loss = F.l1_loss(pred_feats[0], target_feats[0])
        loss = loss + F.l1_loss(pred_feats[1], target_feats[1])
        loss = loss + F.l1_loss(pred_feats[2], target_feats[2])

        return loss


class StyleLoss(nn.Module):
    """
    Texture loss: L1 distance between Gram matrices of frozen VGG16 features.

    NOTE(review): the mean/std constants used in ``forward`` are the usual
    ImageNet statistics, which presumably assume inputs scaled to [0, 1] -
    confirm that callers pass images in that range.
    """
    def __init__(self, device='cuda'):
        super().__init__()

        backbone = models.vgg16(pretrained=True).features.to(device).eval()
        layers = list(backbone.children())
        self.slice1 = nn.Sequential(*layers[:4])
        self.slice2 = nn.Sequential(*layers[4:9])
        self.slice3 = nn.Sequential(*layers[9:16])

        # Fixed feature extractor: no gradients for the VGG weights
        for weight in self.parameters():
            weight.requires_grad = False

    def gram_matrix(self, x):
        """
        Channel-by-channel correlation of a feature map.

        Args:
            x: [B, C, H, W] feature tensor

        Returns:
            [B, C, C] Gram matrix divided by C * H * W
        """
        B, C, H, W = x.shape
        flat = x.view(B, C, H * W)
        correlations = torch.bmm(flat, flat.transpose(1, 2))
        return correlations / (C * H * W)

    def forward(self, pred, target):
        """
        Args:
            pred: [B, 3, H, W] - Predicted image
            target: [B, 3, H, W] - Ground truth image

        Returns:
            Scalar tensor: sum over three VGG stages of the mean absolute
            difference between the Gram matrices of pred and target.
        """
        # Per-channel normalization constants, broadcastable over [B, 3, H, W]
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(pred.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(pred.device)

        pred_norm = (pred - mean) / std
        target_norm = (target - mean) / std

        # Features at three depths; each slice continues from the previous one
        p1 = self.slice1(pred_norm)
        p2 = self.slice2(p1)
        p3 = self.slice3(p2)

        t1 = self.slice1(target_norm)
        t2 = self.slice2(t1)
        t3 = self.slice3(t2)

        # Compare Gram matrices stage by stage (shallow to deep)
        loss = F.l1_loss(self.gram_matrix(p1), self.gram_matrix(t1))
        loss = loss + F.l1_loss(self.gram_matrix(p2), self.gram_matrix(t2))
        loss = loss + F.l1_loss(self.gram_matrix(p3), self.gram_matrix(t3))

        return loss


class TotalVariationLoss(nn.Module):
    """
    Total variation penalty: summed absolute difference between neighbouring
    pixels, divided by the number of elements in the input.
    Encourages spatial smoothness.
    """
    def forward(self, x):
        """
        Args:
            x: [B, C, H, W] tensor

        Returns:
            Scalar tensor with the normalized total variation of x.
        """
        n, channels, height, width = x.shape

        # Neighbours along the height axis (each row vs. the row above it)
        row_diff = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().sum()

        # Neighbours along the width axis (each column vs. the column to its left)
        col_diff = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().sum()

        total = row_diff + col_diff
        return total / (n * channels * height * width)


class InpaintingLoss(nn.Module):
    """
    Combined loss for inpainting

    Weighted sum of:
    - valid-region L1 (weight 1.0)
    - hole-region L1 (weight 6.0)
    - VGG perceptual loss (weight 0.05)
    - VGG style loss (weight 120.0)
    - total variation of the predicted hole region (weight 0.1)
    """
    def __init__(self, device='cuda'):
        super().__init__()

        self.l1_loss = nn.L1Loss()
        self.perceptual_loss = VGGPerceptualLoss(device)
        self.style_loss = StyleLoss(device)
        self.tv_loss = TotalVariationLoss()

    def forward(self, pred, target, mask):
        """
        Args:
            pred: [B, 3, H, W] - Predicted (inpainted) image in [-1, 1]
                (the network in this file ends in tanh)
            target: [B, 3, H, W] - Ground truth clean image in [-1, 1]
                (the dataset in this file scales images with / 127.5 - 1)
            mask: [B, 1, H, W] - Binary mask (1=valid, 0=hole)

        Returns:
            total_loss: scalar tensor, the weighted sum described on the class
            parts: dict of Python floats with keys 'valid', 'hole',
                'perceptual', 'style' and 'tv' (unweighted values)
        """
        # Expand mask
        mask_3ch = mask.repeat(1, 3, 1, 1)
        hole_3ch = 1 - mask_3ch

        # Valid loss (on originally valid pixels)
        valid_loss = self.l1_loss(pred * mask_3ch, target * mask_3ch)

        # Hole loss (on filled regions)
        hole_loss = self.l1_loss(pred * hole_3ch, target * hole_3ch)

        # The VGG-based losses normalize with ImageNet mean/std, which are
        # defined for images in [0, 1]. pred/target live in [-1, 1], so map
        # them to [0, 1] first; feeding them unchanged shifts the VGG input
        # far away from the distribution the network was trained on.
        pred_01 = (pred + 1.0) / 2.0
        target_01 = (target + 1.0) / 2.0

        # Perceptual loss
        perceptual_loss = self.perceptual_loss(pred_01, target_01)

        # Style loss
        style_loss = self.style_loss(pred_01, target_01)

        # Total variation loss (smoothness), restricted to the hole region
        tv_loss = self.tv_loss(pred * hole_3ch)

        # Weighted combination
        total_loss = (
            1.0 * valid_loss +
            6.0 * hole_loss +
            0.05 * perceptual_loss +
            120.0 * style_loss +
            0.1 * tv_loss
        )

        return total_loss, {
            'valid': valid_loss.item(),
            'hole': hole_loss.item(),
            'perceptual': perceptual_loss.item(),
            'style': style_loss.item(),
            'tv': tv_loss.item()
        }


# =============================================================================
# 4. DATASET FOR TRAINING
# =============================================================================

class InpaintingDataset(Dataset):
    """
    Dataset for training inpainting model
    Uses your existing U-Net to generate masks OR synthetic masks

    Each item is a ``(masked_image, mask, image)`` triple of tensors where
    ``image`` is the RGB image scaled to [-1, 1], ``mask`` is 1 for valid
    pixels and 0 for holes, and ``masked_image = image * mask``.
    """
    def __init__(self, image_dir, image_files, mask_predictor=None, 
                 img_size=256, use_synthetic_masks=True):
        """
        Args:
            image_dir: Directory with images
            image_files: List of image filenames
            mask_predictor: Your trained U-Net (optional)
            img_size: Image size (images are resized to img_size x img_size)
            use_synthetic_masks: If True, generate random masks
        """
        self.image_dir = image_dir
        self.image_files = image_files
        self.mask_predictor = mask_predictor
        self.img_size = img_size
        self.use_synthetic_masks = use_synthetic_masks

    def generate_random_mask(self, img_size):
        """
        Generate random irregular mask

        Args:
            img_size: side length of the square mask. Rectangle sides are drawn
                from [20, img_size // 3), so img_size must be at least 63.

        Returns:
            float32 array [img_size, img_size] with 1=valid and 0=hole
        """
        mask = np.ones((img_size, img_size), dtype=np.float32)

        # Random rectangles (slicing clips them at the image border)
        num_rects = np.random.randint(1, 5)
        for _ in range(num_rects):
            x1 = np.random.randint(0, img_size - 20)
            y1 = np.random.randint(0, img_size - 20)
            w = np.random.randint(20, img_size // 3)
            h = np.random.randint(20, img_size // 3)
            mask[y1:y1+h, x1:x1+w] = 0

        # Random lines (scratches)
        num_lines = np.random.randint(0, 10)
        for _ in range(num_lines):
            x1, y1 = np.random.randint(0, img_size, 2)
            x2, y2 = np.random.randint(0, img_size, 2)
            thickness = np.random.randint(1, 5)
            cv2.line(mask, (x1, y1), (x2, y2), 0, thickness)

        # Eroding the valid (=1) region grows every hole by about one pixel
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.erode(mask, kernel, iterations=1)

        return mask

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """
        Load one image and build its training triple.

        Returns:
            masked_image: [3, S, S] float tensor, image with holes set to 0
            mask: [1, S, S] float tensor (1=valid, 0=hole)
            image: [3, S, S] float tensor in [-1, 1] (ground truth)

        Raises:
            OSError: if the image file cannot be read or decoded.
        """
        # Load image
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising
            raise OSError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.img_size, self.img_size))

        # Generate or predict mask
        if self.use_synthetic_masks or self.mask_predictor is None:
            mask = self.generate_random_mask(self.img_size)
        else:
            # Use your trained U-Net to predict mask
            # NOTE(review): this assumes predict() returns an [S, S] array with
            # 1=valid / 0=hole at the resized resolution - confirm against the
            # mask predictor's convention.
            with torch.no_grad():
                mask = self.mask_predictor.predict(image, threshold=0.5)

        # Convert to tensors
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 127.5 - 1.0  # [-1, 1]
        mask = torch.from_numpy(mask).unsqueeze(0).float()

        # Masked image: holes become 0, which is mid-gray in the [-1, 1] range
        masked_image = image * mask

        return masked_image, mask, image  # masked, mask, ground_truth


# =============================================================================
# 5. TRAINING FUNCTION
# =============================================================================

def _seed_numpy_worker(worker_id):
    """Give each DataLoader worker its own NumPy RNG state, derived from its PyTorch seed."""
    # Inside a worker torch.initial_seed() already differs per worker and per
    # epoch; NumPy's global RNG does not, so the random masks drawn with
    # np.random would otherwise be duplicated across workers and epochs.
    np.random.seed(torch.initial_seed() % 2**32)


def train_inpainting_model(
    image_dir,
    train_files,
    val_files,
    epochs=100,
    batch_size=8,
    learning_rate=0.0002,
    img_size=256,
    save_dir='/kaggle/working/inpainting_checkpoints',
    device='cuda',
    mask_predictor=None,
    use_synthetic_masks=True
):
    """
    Train partial convolution inpainting model

    Args:
        image_dir: Directory with images
        train_files: Filenames used for training
        val_files: Filenames used for validation
        epochs: Number of epochs
        batch_size: Batch size for both loaders
        learning_rate: Adam learning rate (halved every 30 epochs by StepLR)
        img_size: Side length images are resized to
        save_dir: Directory for checkpoints, sample figures and the history pickle
        device: Device string the model, loss and batches are moved to
        mask_predictor: Optional mask predictor forwarded to the datasets
        use_synthetic_masks: If True the datasets draw random masks

    Returns:
        (model, history) where history is
        {'train_loss': [...], 'val_loss': [...]} with one entry per epoch.

    Raises:
        ValueError: if the training or validation loader would yield no batch.

    Side effects:
        Writes 'best_inpainting_model.pth', 'checkpoint_epoch_N.pth' and
        'samples_epoch_N.png' (every 10 epochs) and 'training_history.pkl'
        into save_dir.
    """
    os.makedirs(save_dir, exist_ok=True)
    
    print("="*70)
    print("PARTIAL CONVOLUTION INPAINTING TRAINING")
    print("="*70)
    print(f"Device: {device}")
    print(f"Training images: {len(train_files)}")
    print(f"Validation images: {len(val_files)}")
    print(f"Epochs: {epochs}")
    print(f"Batch size: {batch_size}")
    print("="*70)
    
    # Create datasets
    train_dataset = InpaintingDataset(
        image_dir, train_files, mask_predictor, 
        img_size, use_synthetic_masks
    )
    val_dataset = InpaintingDataset(
        image_dir, val_files, mask_predictor,
        img_size, use_synthetic_masks
    )
    
    # A trailing batch with a single sample makes BatchNorm fail in training
    # mode once the bottleneck is 1x1 ("expected more than 1 value per
    # channel"), so drop the last batch only in exactly that situation.
    drop_singleton_batch = len(train_dataset) % batch_size == 1
    
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size,
        shuffle=True, num_workers=2, pin_memory=True,
        drop_last=drop_singleton_batch,
        worker_init_fn=_seed_numpy_worker
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size,
        shuffle=False, num_workers=2, pin_memory=True,
        worker_init_fn=_seed_numpy_worker
    )
    
    # Fail early with a clear message instead of a ZeroDivisionError when the
    # per-epoch averages are computed
    if len(train_loader) == 0:
        raise ValueError("Training set yields no batches; provide more training images")
    if len(val_loader) == 0:
        raise ValueError("Validation set yields no batches; provide validation images")
    
    # Create model
    model = PartialConvUNet().to(device)
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {total_params:,}")
    
    # Loss and optimizer
    criterion = InpaintingLoss(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
    
    # Training loop
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': []}
    
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        print("-"*70)
        
        # Training (also switches back from the eval mode left by validation
        # and visualization in the previous epoch)
        model.train()
        train_loss = 0.0
        train_losses_detail = {'valid': 0, 'hole': 0, 'perceptual': 0, 'style': 0, 'tv': 0}
        
        pbar = tqdm(train_loader, desc='Training')
        for masked_img, mask, gt_img in pbar:
            masked_img = masked_img.to(device)
            mask = mask.to(device)
            gt_img = gt_img.to(device)
            
            # Forward
            optimizer.zero_grad()
            pred_img = model(masked_img, mask)
            
            # Loss
            loss, loss_dict = criterion(pred_img, gt_img, mask)
            
            # Backward
            loss.backward()
            optimizer.step()
            
            # Accumulate losses
            train_loss += loss.item()
            for key in loss_dict:
                train_losses_detail[key] += loss_dict[key]
            
            pbar.set_postfix({'loss': loss.item()})
        
        # Per-batch averages for the epoch
        train_loss /= len(train_loader)
        for key in train_losses_detail:
            train_losses_detail[key] /= len(train_loader)
        
        history['train_loss'].append(train_loss)
        
        print(f"Train Loss: {train_loss:.4f}")
        print(f"  Valid: {train_losses_detail['valid']:.4f}, "
              f"Hole: {train_losses_detail['hole']:.4f}, "
              f"Perceptual: {train_losses_detail['perceptual']:.4f}")
        
        # Validation
        model.eval()
        val_loss = 0.0
        
        with torch.no_grad():
            for masked_img, mask, gt_img in tqdm(val_loader, desc='Validation'):
                masked_img = masked_img.to(device)
                mask = mask.to(device)
                gt_img = gt_img.to(device)
                
                pred_img = model(masked_img, mask)
                loss, _ = criterion(pred_img, gt_img, mask)
                val_loss += loss.item()
        
        val_loss /= len(val_loader)
        history['val_loss'].append(val_loss)
        
        print(f"Val Loss: {val_loss:.4f}")
        
        # Learning rate scheduling
        scheduler.step()
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'train_loss': train_loss
            }, os.path.join(save_dir, 'best_inpainting_model.pth'))
            print(f"✓ Saved best model (val_loss: {val_loss:.4f})")
        
        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss
            }, os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth'))
        
        # Visualize samples every 10 epochs
        if (epoch + 1) % 10 == 0:
            visualize_results(model, val_loader, device, 
                            save_path=os.path.join(save_dir, f'samples_epoch_{epoch+1}.png'))
    
    print("\n" + "="*70)
    print("Training completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    
    # Save history
    import pickle
    with open(os.path.join(save_dir, 'training_history.pkl'), 'wb') as f:
        pickle.dump(history, f)
    
    return model, history


# =============================================================================
# 6. VISUALIZATION
# =============================================================================

def visualize_results(model, dataloader, device, save_path=None, num_samples=4):
    """
    Visualize inpainting results

    Takes the first batch of ``dataloader``, runs the model on up to
    ``num_samples`` of its items and plots one row per item with four
    columns: masked input, mask, prediction, ground truth.

    Args:
        model: network called as ``model(masked_imgs, masks)``
        dataloader: yields ``(masked_imgs, masks, gt_imgs)`` batches
        device: device the batch is moved to
        save_path: if given, the figure is also written to this path
        num_samples: maximum number of rows; clipped to the batch size

    Note:
        The model is switched to eval mode and is left in eval mode.
    """
    model.eval()
    
    # Get samples
    masked_imgs, masks, gt_imgs = next(iter(dataloader))
    
    # The first batch may hold fewer items than requested
    num_samples = min(num_samples, masked_imgs.shape[0])
    
    masked_imgs = masked_imgs[:num_samples].to(device)
    masks = masks[:num_samples].to(device)
    gt_imgs = gt_imgs[:num_samples].to(device)
    
    with torch.no_grad():
        pred_imgs = model(masked_imgs, masks)
    
    # Convert to numpy
    def tensor_to_img(tensor):
        img = tensor.cpu().numpy()
        img = (img + 1) / 2.0  # [-1, 1] -> [0, 1]
        img = np.transpose(img, (0, 2, 3, 1))
        return np.clip(img, 0, 1)
    
    masked_np = tensor_to_img(masked_imgs)
    pred_np = tensor_to_img(pred_imgs)
    gt_np = tensor_to_img(gt_imgs)
    masks_np = masks.cpu().numpy().transpose(0, 2, 3, 1)
    
    # Plot. squeeze=False keeps `axes` two-dimensional even for a single row,
    # so axes[i, j] indexing below always works.
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples), squeeze=False)
    
    for i in range(num_samples):
        axes[i, 0].imshow(masked_np[i])
        axes[i, 0].set_title('Masked Input')
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(masks_np[i], cmap='gray')
        axes[i, 1].set_title('Mask')
        axes[i, 1].axis('off')
        
        axes[i, 2].imshow(pred_np[i])
        axes[i, 2].set_title('Predicted')
        axes[i, 2].axis('off')
        
        axes[i, 3].imshow(gt_np[i])
        axes[i, 3].set_title('Ground Truth')
        axes[i, 3].axis('off')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    
    # Release the figure; this function is called repeatedly during training
    plt.close(fig)


# =============================================================================
# 7. INFERENCE CLASS
# =============================================================================

class InpaintingPredictor:
    """Easy-to-use wrapper for inference"""
    
    def __init__(self, checkpoint_path, device='cuda'):
        """
        Build a PartialConvUNet and load its weights.

        Args:
            checkpoint_path: file containing a dict with the keys
                'model_state_dict', 'epoch' and 'val_loss'
            device: device the model is placed on
        """
        self.device = device
        self.model = PartialConvUNet().to(device)
        
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()
        
        print(f"✓ Inpainting model loaded from {checkpoint_path}")
        print(f"  Epoch: {checkpoint['epoch']}")
        print(f"  Val Loss: {checkpoint['val_loss']:.4f}")
    
    def inpaint(self, image, mask):
        """
        Inpaint image
        
        Args:
            image: numpy array [H, W, 3], RGB, [0-255]
            mask: numpy array [H, W], binary [0-1]. It is multiplied into the
                image and passed to the model as its mask, whose documented
                convention is 1=valid, 0=hole.
        
        Returns:
            inpainted: numpy array [H, W, 3], RGB, [0-255], uint8. This is the
            raw network output resized to the input size; it is not blended
            with the original pixels.
        """
        original_size = image.shape[:2]
        
        # Preprocess: the network runs at a fixed 256x256 resolution
        image = cv2.resize(image, (256, 256))
        # Nearest-neighbour keeps the mask binary (linear interpolation would
        # create fractional values along hole borders); the float32 cast lets
        # bool / integer masks through as well.
        mask = cv2.resize(
            np.asarray(mask, dtype=np.float32), (256, 256),
            interpolation=cv2.INTER_NEAREST
        )
        
        # To tensor; image goes from [0, 255] to [-1, 1]
        image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 127.5 - 1.0
        mask_tensor = torch.from_numpy(mask).unsqueeze(0).float()
        
        # Masked input (holes set to 0)
        masked_input = image_tensor * mask_tensor
        
        # Add batch dimension
        masked_input = masked_input.unsqueeze(0).to(self.device)
        mask_tensor = mask_tensor.unsqueeze(0).to(self.device)
        
        # Inpaint
        with torch.no_grad():
            pred = self.model(masked_input, mask_tensor)
        
        # Postprocess: [-1, 1] -> [0, 255], channels last
        pred = pred.squeeze(0).cpu().numpy()
        pred = (pred + 1) / 2.0 * 255.0
        pred = np.transpose(pred, (1, 2, 0))
        # Round before the cast: astype(uint8) alone truncates, which darkens
        # values such as 99.99997 to 99
        pred = np.clip(np.rint(pred), 0, 255).astype(np.uint8)
        
        # Resize back
        if pred.shape[:2] != original_size:
            pred = cv2.resize(pred, (original_size[1], original_size[0]))
        
        return pred


# =============================================================================
# 8. KAGGLE USAGE EXAMPLE
# =============================================================================

def kaggle_train_inpainting():
    """
    End-to-end Kaggle training run: load the saved train/val split, then train
    the inpainting model on synthetic masks for 50 epochs.
    """
    import pickle

    # Locations on Kaggle
    image_dir = '/kaggle/input/flickr8k/Images'  # Adjust path
    save_dir = '/kaggle/working/inpainting_checkpoints'
    split_path = '/kaggle/input/mask-predictor-model-u-net-style/checkpoints/data_split.pkl'

    # Reuse the data split written during U-Net training.
    # NOTE: unpickling can execute arbitrary code - only load split files you trust.
    with open(split_path, 'rb') as handle:
        split_info = pickle.load(handle)

    # Optional: plug in the trained U-Net instead of synthetic masks
    # from mask_predictor import MaskPredictor
    # mask_predictor = MaskPredictor('/kaggle/input/your-unet-dataset/best_model.pth')
    mask_predictor = None  # Use synthetic masks for simplicity

    training_options = {
        'image_dir': image_dir,
        'train_files': split_info['train'],
        'val_files': split_info['val'],
        'epochs': 50,  # Start with 50, increase if needed
        'batch_size': 8,
        'learning_rate': 0.0002,
        'img_size': 256,
        'save_dir': save_dir,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'mask_predictor': mask_predictor,
        'use_synthetic_masks': True,
    }
    model, history = train_inpainting_model(**training_options)

    print("\n✓ Training complete!")
    print(f"Best model saved at: {save_dir}/best_inpainting_model.pth")


# =============================================================================
# 9. COMPLETE PIPELINE: U-NET + INPAINTING
# =============================================================================

def complete_pipeline_demo():
    """
    Two-stage demo on one Flickr8k image: predict a mask with the U-Net, feed
    image and mask to the inpainting model, then plot and save a 4-panel figure.
    """
    # Both stages are restored from their Kaggle checkpoints
    # from mask_predictor import MaskPredictor
    mask_predictor = MaskPredictor('/kaggle/input/u-net-model-trained/best_model.pth')
    inpainting_model = InpaintingPredictor('/kaggle/input/partial-convolution-model-trained/best_inpainting_model.pth')

    # OpenCV loads BGR; the rest of the pipeline works in RGB
    bgr = cv2.imread('/kaggle/input/flickr8k/Images/3226254560_2f8ac147ea.jpg')
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Stage 1: mask prediction, stage 2: inpainting
    mask = mask_predictor.predict(image, threshold=0.5)
    inpainted = inpainting_model.inpaint(image, mask)

    # Per-pixel mean absolute change across the colour channels
    diff = np.abs(image.astype(float) - inpainted.astype(float)).mean(axis=2)

    # (data, title, colormap) for each panel, left to right
    panels = [
        (image, 'Original (Damaged)', None),
        (mask, 'Predicted Mask', 'gray'),
        (inpainted, 'Inpainted', None),
        (diff, 'Difference', 'hot'),
    ]

    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    for ax, (data, title, cmap) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)

    for ax in axes:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('/kaggle/working/complete_pipeline_result.png', dpi=150)
    plt.show()


if __name__ == '__main__':
    # Training is disabled by default; uncomment to retrain the inpainting model.
    # kaggle_train_inpainting()

    # Run the two-stage demo (needs both trained checkpoints to be available)
    complete_pipeline_demo()

In [None]:
# Inference Pipeline: run the mask predictor and the inpainting model on a batch of images read from an input directory, and write the results to an output directory.
# Input: input_directory_path
# Output: output_directory_path

In [None]:
import os
import cv2
import random
import numpy as np

# Batch inference: distort up to 20 random images, predict the damage mask,
# inpaint, write all intermediate images to output_dir and zip the directory.

# Find images (adjust path if needed)
image_dir = '/kaggle/input/flickr8k/Images'  # or your image path
output_dir = '/kaggle/working/outputs'
os.makedirs(output_dir, exist_ok=True)

# Get up to 20 random images (random.sample raises if asked for more items
# than the directory contains)
files = [f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png'))]
selected = random.sample(files, min(20, len(files)))

# Load models
# from mask_predictor import MaskPredictor, SyntheticDistortionGenerator
# from inpainting import InpaintingPredictor

mask_model = MaskPredictor('/kaggle/input/u-net-model-trained/best_model.pth')
inpaint_model = InpaintingPredictor('/kaggle/input/partial-convolution-model-trained/best_inpainting_model.pth')
distortion_gen = SyntheticDistortionGenerator()

# Process
for i, fname in enumerate(selected, 1):
    bgr = cv2.imread(f'{image_dir}/{fname}')
    if bgr is None:
        # cv2.imread returns None for unreadable files; skip instead of crashing
        print(f"Skipping unreadable image: {fname}")
        continue
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    
    # distort -> predict mask -> inpaint
    distorted, mask_gt = distortion_gen.generate_distortion(img)
    mask = mask_model.predict(distorted)
    result = inpaint_model.inpaint(distorted, mask)
    
    # OpenCV writes BGR, so convert every RGB image back before saving
    cv2.imwrite(f'{output_dir}/test_image_{i}.png', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    cv2.imwrite(f'{output_dir}/test_image_{i}_distorted.png', cv2.cvtColor(distorted, cv2.COLOR_RGB2BGR))
    mask_vis = cv2.cvtColor((mask * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR)
    cv2.imwrite(f'{output_dir}/test_image_{i}_pred_mask.png', mask_vis)
    cv2.imwrite(f'{output_dir}/test_image_{i}_restored.png', cv2.cvtColor(result, cv2.COLOR_RGB2BGR))

print(f"✓ Saved to {output_dir}")

# Bundle everything into /kaggle/working/inpainting_outputs.zip
import shutil
zip_path = '/kaggle/working/inpainting_outputs'
shutil.make_archive(zip_path, 'zip', output_dir)
print(f"✓ Zipped successfully: {zip_path}.zip")