## Import and Setup


In [None]:
# !pip install segmentation-models-pytorch albumentations opencv-python pyyaml tqdm matplotlib


In [None]:
import os
import cv2
import torch
import numpy as np
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from matplotlib import pyplot as plt
from tqdm import tqdm
from PIL import Image

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Side length in pixels of the square that images and masks are resized to.
IMAGE_SIZE = 1024

print(f"Using device: {DEVICE}")


## Define the Dataset

In [None]:
class CensorshipDataset(Dataset):
    """Paired image / binary-mask dataset for one censorship type.

    Images and masks are matched by position in the sorted directory
    listings, so both directories must hold the same number of files and
    their names must sort in the same order.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks (any non-zero pixel is
            treated as "censored").
        transform: When truthy, random flips / 90-degree rotations are applied
            before conversion to tensors. When falsy, samples are only resized
            and converted to tensors.

    Raises:
        ValueError: If the two directories hold a different number of files,
            or (on access) if an image or mask cannot be read.
    """

    def __init__(self, image_dir, mask_dir, transform=True):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))

        # Pairing is purely positional, so a count mismatch would silently
        # attach masks to the wrong images (or fail late with IndexError).
        if len(self.image_filenames) != len(self.mask_filenames):
            raise ValueError(
                f"Image/mask count mismatch: {len(self.image_filenames)} files in "
                f"{image_dir} vs {len(self.mask_filenames)} files in {mask_dir}"
            )

        self.transform = transform
        self.aug = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            ToTensorV2(),
        ])
        # Used when augmentation is disabled so both modes return tensors.
        self.to_tensor = A.Compose([ToTensorV2()])
        # NOTE(review): the nearest-neighbour setting presumably also applies to
        # the image, not only the mask - kept unchanged; confirm it is intended.
        self.resize = A.Resize(IMAGE_SIZE, IMAGE_SIZE, interpolation=cv2.INTER_NEAREST)

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        ``image`` is the tensor produced by ``ToTensorV2`` from the resized RGB
        image (no normalisation is applied here); ``mask`` is a ``long`` tensor
        holding 0 for background and 1 for censored pixels.
        """
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise ValueError(f"Failed to load mask: {mask_path}")
        # Convert to binary mask (0 for background, 1 for censorship)
        mask = (mask > 0).astype(np.uint8)

        # Resize both before any augmentation so every sample has the same size.
        resized = self.resize(image=image, mask=mask)
        image, mask = resized["image"], resized["mask"]

        # Always convert to tensors; augment only when requested. Previously
        # transform=False left NumPy arrays behind and crashed on mask.long().
        pipeline = self.aug if self.transform else self.to_tensor
        converted = pipeline(image=image, mask=mask)
        image, mask = converted["image"], converted["mask"]

        return image, mask.long()
    
# Define paths for different censorship types
dataset_base = "dataset"
censorship_types = ["black_bars", "white_bars", "transparent_black"]

# Build one dataset per censorship type whose image and mask folders both
# exist, and preview the first few samples of each.
datasets = {}
for ctype in censorship_types:
    image_dir = os.path.join(dataset_base, "images", ctype)
    mask_dir = os.path.join(dataset_base, "masks", ctype)

    if not (os.path.exists(image_dir) and os.path.exists(mask_dir)):
        print(f"Warning: {ctype} directories not found, skipping.")
        continue

    datasets[ctype] = CensorshipDataset(image_dir, mask_dir)
    print(f"Loaded {ctype} dataset with {len(datasets[ctype])} images")

    print(f"Displaying 3 samples from {ctype} dataset:")
    for i in range(min(3, len(datasets[ctype]))):
        image, mask = datasets[ctype][i]
        # Channels-first tensor -> channels-last array for matplotlib.
        image = image.permute(1, 2, 0).cpu().numpy()
        mask = mask.cpu().numpy()

        fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(10, 5))

        image_ax.imshow(image)
        image_ax.set_title(f"{ctype} - Image {i+1}")

        mask_ax.imshow(mask, cmap="gray")
        mask_ax.set_title(f"{ctype} - Mask {i+1}")

        plt.tight_layout()
        plt.show()


## Create and Train Models


In [None]:
def create_model():
    """Build the single-class (binary) UNet++ segmentation network on DEVICE.

    The network uses an ImageNet-initialised ``efficientnet-b6`` encoder,
    takes 3-channel input and emits one output channel.
    """
    network = smp.UnetPlusPlus(
        encoder_name="efficientnet-b6",
        encoder_weights="imagenet",
        encoder_depth=5,
        in_channels=3,
        classes=1,
    )
    return network.to(DEVICE)

# Models produced by the training cells below, keyed by censorship type.
trained_models = {}

# ImageNet channel statistics. These are the values A.Normalize() uses by
# default, which is what preprocess_image applies at inference time.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _prepare_images(images):
    """Return a float image batch the network can consume.

    Integer (e.g. uint8) batches are scaled to [0, 1] and normalised with the
    ImageNet statistics so training sees the same value range as inference.
    Batches that are already floating point are assumed to be normalised by
    the dataset and are returned untouched.
    """
    if images.is_floating_point():
        return images
    scaled = images.float() / 255.0
    mean = scaled.new_tensor(_IMAGENET_MEAN).view(1, -1, 1, 1)
    std = scaled.new_tensor(_IMAGENET_STD).view(1, -1, 1, 1)
    return (scaled - mean) / std


def _batch_iou(pred, target):
    """Return the mean per-sample IoU of thresholded logits against binary masks.

    A sample whose prediction and target are both empty counts as IoU 1.
    """
    batch_size = pred.size(0)
    pred_mask = (pred.sigmoid() > 0.5).reshape(batch_size, -1).float()
    target_mask = target.reshape(batch_size, -1).float()

    intersection = (pred_mask * target_mask).sum(dim=1)
    union = (pred_mask + target_mask).gt(0).sum(dim=1).float()

    # clamp() only protects the division; empty unions are replaced by 1.0.
    per_sample = torch.where(
        union > 0,
        intersection / union.clamp(min=1.0),
        torch.ones_like(union),
    )
    return per_sample.mean().item()


def train_model(model, dataset, model_name, epochs=30, val_split=0.2, patience=10):
    """
    Train a model on the given dataset with validation, early stopping, and learning rate scheduling

    The best checkpoint (lowest validation loss) is written to
    ``best_{model_name}_model.pth`` and the last epoch to
    ``final_{model_name}_model.pth``; loss / IoU curves are saved to
    ``{model_name}_training_curves.png``.

    Args:
        model: The segmentation model to train
        dataset: Dataset containing training images and masks
        model_name: Name of the model (used for saving)
        epochs: Maximum number of training epochs
        val_split: Proportion of data to use for validation
        patience: Number of epochs to wait before early stopping

    Returns:
        A freshly created model loaded with the best checkpoint's weights,
        in evaluation mode.

    Raises:
        ValueError: If ``epochs`` is below 1 or the dataset has fewer than two
            samples (one for training, one for validation).
    """
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    # Create validation split
    dataset_size = len(dataset)
    if dataset_size < 2:
        raise ValueError(
            f"Need at least 2 samples to split into train/validation, got {dataset_size}"
        )
    # Keep at least one sample on each side so the per-epoch averages below
    # never divide by zero (tiny datasets used to yield an empty split).
    val_size = min(max(1, int(dataset_size * val_split)), dataset_size - 1)
    train_size = dataset_size - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

    # Define loss function, optimizer and scheduler
    loss_fn = smp.losses.FocalLoss(mode='binary')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)

    # Initialize tracking variables
    best_val_loss = float('inf')
    patience_counter = 0
    train_losses = []
    val_losses = []
    val_ious = []
    best_epoch = 0

    print(f"Starting training for {model_name}:")
    print(f"- Training samples: {train_size}")
    print(f"- Validation samples: {val_size}")

    for epoch in range(epochs):
        # Training phase
        model.train()
        running_train_loss = 0.0

        for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            # The dataset yields integer image tensors; the network needs floats.
            images = _prepare_images(images.to(DEVICE))
            masks = masks.to(DEVICE)

            # Forward pass
            preds = model(images)
            loss = loss_fn(preds, masks)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_train_loss += loss.item()

        avg_train_loss = running_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation phase
        model.eval()
        running_val_loss = 0.0
        running_iou = 0.0

        with torch.no_grad():
            for batch_idx, (images, masks) in enumerate(val_loader):
                images = _prepare_images(images.to(DEVICE))
                masks = masks.to(DEVICE)

                # Forward pass
                preds = model(images)
                val_loss = loss_fn(preds, masks)

                # One-off sanity print: first validation batch of the first epoch.
                if epoch == 0 and batch_idx == 0:
                    print(f"\nDebug Information:")
                    print(f"Predictions range: ({preds.min():.4f}, {preds.max():.4f})")
                    print(f"Predictions shape: {preds.shape}")
                    print(f"Masks range: ({masks.min():.4f}, {masks.max():.4f})")
                    print(f"Masks shape: {masks.shape}")
                    print(f"Unique mask values: {torch.unique(masks).tolist()}")

                # Calculate metrics
                running_val_loss += val_loss.item()
                running_iou += _batch_iou(preds, masks)

        avg_val_loss = running_val_loss / len(val_loader)
        avg_iou = running_iou / len(val_loader)
        val_losses.append(avg_val_loss)
        val_ious.append(avg_iou)

        # Print epoch results
        print(f"\n{model_name} - Epoch {epoch+1}/{epochs}")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss: {avg_val_loss:.4f}, IoU: {avg_iou:.4f}")

        # Update learning rate based on validation loss
        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"  Current LR: {current_lr:.6f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_epoch = epoch + 1
            patience_counter = 0

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': best_val_loss,
                'val_iou': avg_iou
            }, f'best_{model_name}_model.pth')
            print(f"  Saved new best model with Val Loss: {best_val_loss:.4f}")
        else:
            patience_counter += 1
            print(f"  Patience: {patience_counter}/{patience}")

        # Early stopping
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Save final model
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': avg_val_loss,
        'val_iou': avg_iou
    }, f'final_{model_name}_model.pth')

    print(f"Training completed. Best model saved at epoch {best_epoch} with Val Loss: {best_val_loss:.4f}")

    # Plot training curves
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title(f'{model_name} - Loss Curves')

    plt.subplot(1, 2, 2)
    plt.plot(val_ious, label='Validation IoU')
    plt.xlabel('Epoch')
    plt.ylabel('IoU Score')
    plt.legend()
    plt.title(f'{model_name} - IoU Curve')

    plt.tight_layout()
    plt.savefig(f'{model_name}_training_curves.png')
    plt.show()

    # Load and return best model. map_location keeps this working when the
    # checkpoint was written on a different device than the current one.
    best_model = create_model()
    checkpoint = torch.load(f'best_{model_name}_model.pth', map_location=DEVICE)
    best_model.load_state_dict(checkpoint['model_state_dict'])
    best_model.eval()
    return best_model

def load_existing_model(model_name):
    """
    Try to load an existing model with the given name

    The candidate files are checked in order and the first one that exists is
    used. Both checkpoint layouts are accepted: the dictionary written by
    ``train_model`` (weights stored under ``'model_state_dict'``) and a bare
    state dict.

    Args:
        model_name: Name of the model to load (e.g., 'black_bars')

    Returns:
        model: Loaded model in evaluation mode, or None if no model file exists
    """
    # Check for different possible model file paths
    model_paths = [
        f'best_{model_name}_model.pth',
        f'final_{model_name}_model.pth',
        f'segmentation_model_{model_name}.pth'
    ]

    for model_path in model_paths:
        if not os.path.exists(model_path):
            continue

        print(f"Found existing model at {model_path}")
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        checkpoint = torch.load(model_path, map_location=DEVICE)

        # train_model saves a wrapper dict; passing it straight to
        # load_state_dict fails, so unwrap the weights first.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        model = create_model()
        model.load_state_dict(state_dict)
        model.eval()  # Set to evaluation mode
        return model

    print(f"No existing model found for {model_name}")
    return None

## Train Model for Black Bars Detection

This section focuses on training a model specifically for detecting black bars in images.

In [None]:
# Train model for black bars detection
black_bars_dataset = datasets.get('black_bars')
if black_bars_dataset:
    print('Processing black bars model...')

    # Reuse a saved checkpoint when one exists; train from scratch otherwise.
    black_bars_model = load_existing_model('black_bars')

    if black_bars_model is None:
        print('Training new model for black bars...')
        black_bars_model = create_model()
        black_bars_model = train_model(black_bars_model, black_bars_dataset, 'black_bars')
    else:
        # The loaded model is used as-is. Retraining here would overwrite
        # best_black_bars_model.pth on the first epoch, even with worse weights.
        # Call train_model(...) explicitly if fine-tuning is wanted.
        print('Using existing black bars model')
        black_bars_model.eval()

    # Store in trained_models dictionary (created here if the cell that
    # defines it has not been run).
    trained_models = {} if 'trained_models' not in globals() else trained_models
    trained_models['black_bars'] = black_bars_model
else:
    print('Black bars dataset not found. Skipping.')

## Train Model for White Bars Detection

This section trains a model specifically for detecting white bars in images.

In [None]:
# Train model for white bars detection
white_bars_dataset = datasets.get('white_bars')
if white_bars_dataset:
    print('Processing white bars model...')

    # Reuse a saved checkpoint when one exists; train from scratch otherwise.
    white_bars_model = load_existing_model('white_bars')

    if white_bars_model is None:
        print('Training new model for white bars...')
        white_bars_model = create_model()
        # Only the white bars model is trained here; a stray black-bars
        # retraining call that was pasted into this branch has been removed.
        white_bars_model = train_model(white_bars_model, white_bars_dataset, 'white_bars')
    else:
        print('Using existing white bars model')

    # Store in trained_models dictionary (created here if the cell that
    # defines it has not been run).
    trained_models = {} if 'trained_models' not in globals() else trained_models
    trained_models['white_bars'] = white_bars_model
else:
    print('White bars dataset not found. Skipping.')

## Train Model for Transparent Black Detection

In [None]:
# Train model for transparent black detection
transparent_black_dataset = datasets.get('transparent_black')
if transparent_black_dataset:
    print('Processing transparent black model...')

    # Reuse a saved checkpoint when one exists; train from scratch otherwise.
    transparent_black_model = load_existing_model('transparent_black')

    if transparent_black_model is None:
        print('Training new model for transparent black...')
        transparent_black_model = create_model()
        transparent_black_model = train_model(transparent_black_model, transparent_black_dataset, 'transparent_black')
    else:
        # The loaded model is used as-is; a stray black-bars retraining call
        # that was pasted into this branch has been removed.
        print('Using existing transparent black model')

    # Store in trained_models dictionary (created here if the cell that
    # defines it has not been run).
    trained_models = {} if 'trained_models' not in globals() else trained_models
    trained_models['transparent_black'] = transparent_black_model
else:
    print('transparent_black dataset not found. Skipping.')

## Load the Saved Models

In [None]:
# Load all saved models from the pretrained/ folder, keyed by censorship type.
loaded_models = {}

for ctype in censorship_types:
    model_path = f'pretrained/best_{ctype}_model.pth'
    if not os.path.exists(model_path):
        # Report the miss so an empty loaded_models is not a silent surprise.
        print(f"No pretrained model found for {ctype} at {model_path}")
        continue

    model = create_model()
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    loaded_models[ctype] = model
    print(f"Loaded model for {ctype}")

def preprocess_image(img_path):
    """Read an image from disk and build the tensor fed to the model.

    Args:
        img_path: Path of the image file to read.

    Returns:
        A tuple ``(original_color_image, tensor_image)``: the untouched RGB
        image at its original size, and a batched tensor on DEVICE built from
        a grayscale copy that was resized and normalised.

    Raises:
        ValueError: If the file cannot be read as an image.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        raise ValueError(f"Failed to load image: {img_path}")

    # Colour copy kept for display / export.
    original_color_image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # The model input is a grayscale version replicated to three channels.
    # NOTE(review): CensorshipDataset in this file feeds RGB images during
    # training - confirm the grayscale conversion here is intended.
    model_input = bgr
    if len(bgr.shape) == 3:
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        model_input = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

    steps = (
        A.Resize(IMAGE_SIZE, IMAGE_SIZE, interpolation=cv2.INTER_NEAREST),
        A.Normalize(),
        ToTensorV2(),
    )
    for step in steps:
        model_input = step(image=model_input)["image"]

    # Add the batch dimension and move to the compute device.
    tensor_image = model_input.unsqueeze(0).to(DEVICE)
    return original_color_image, tensor_image

def predict_mask(tensor_image, model):
    """Run ``model`` on ``tensor_image`` and return sigmoid probabilities.

    Size-1 dimensions of the model output are squeezed away and the result is
    returned as a NumPy array on the CPU. Gradients are not tracked.
    """
    with torch.no_grad():
        logits = model(tensor_image)
        probabilities = torch.sigmoid(logits.squeeze())
        return probabilities.cpu().numpy()

def create_opacity_mask(predicted_mask):
    """
    Quantise a probability mask into discrete opacity levels.

    Args:
        predicted_mask: The predicted mask from the model.

    Returns:
        A ``uint8`` array of the same shape holding 100 where the prediction
        exceeds 0.5, 75 where it exceeds 0.35, 50 where it exceeds 0.2, and 0
        elsewhere (all comparisons are strict).
    """
    # np.select takes the first matching condition, so test the highest
    # threshold first.
    conditions = [
        predicted_mask > 0.5,
        predicted_mask > 0.35,
        predicted_mask > 0.2,
    ]
    levels = [100, 75, 50]
    return np.select(conditions, levels, default=0).astype(np.uint8)

def test_specific_model(test_dir, model_type=None):
    """
    Test specific models based on subdirectories in the test directory.

    For every selected model in ``loaded_models``, each image found in
    ``test_dir/<censorship type>`` is run through the model and shown next to
    its predicted opacity mask. A failure on one image is reported and does
    not stop the remaining images.

    Args:
        test_dir: Path to the test directory containing subdirectories for each censorship type.
        model_type: Specific model type to test (e.g., 'white_bars'). If None, test all available models.
    """
    # Without this check, a model type that was never loaded did nothing at all.
    if model_type and model_type not in loaded_models:
        print(f"No loaded model for {model_type}; nothing to test.")
        return

    for ctype, model in loaded_models.items():
        if model_type and ctype != model_type:
            continue  # Skip other models if a specific model_type is provided

        specific_test_dir = os.path.join(test_dir, ctype)
        if not os.path.exists(specific_test_dir):
            print(f"Skipping {ctype}: No directory found at {specific_test_dir}")
            continue

        # Sorted so images are shown in a stable order across runs.
        test_images = sorted(
            os.path.join(specific_test_dir, filename)
            for filename in os.listdir(specific_test_dir)
            if filename.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

        if not test_images:
            print(f"No test images found in {specific_test_dir} for {ctype}")
            continue

        print(f"\nTesting {ctype} model on {len(test_images)} images in {specific_test_dir}")

        for img_path in test_images:
            # Best-effort per image: report the failure and continue.
            try:
                print(f"Testing on: {img_path}")
                image, tensor_image = preprocess_image(img_path)
                predicted_mask = predict_mask(tensor_image, model)
                opacity_mask = create_opacity_mask(predicted_mask)

                # Display results
                plt.figure(figsize=(10, 5))
                plt.subplot(1, 2, 1)
                plt.imshow(image)
                plt.title("Input Image")

                plt.subplot(1, 2, 2)
                plt.imshow(opacity_mask, cmap="gray")
                plt.title(f"{ctype} Detection")

                plt.tight_layout()
                plt.show()

            except Exception as e:
                print(f"Error processing {img_path}: {e}")

def export_image_and_mask(image, mask, output_dir, name):
    """Write one PNG named ``{name}.png`` into ``output_dir``.

    The directory is created if needed. When ``mask`` is two-dimensional it is
    the array that gets written; otherwise ``image`` is written instead.
    """
    os.makedirs(output_dir, exist_ok=True)
    destination = os.path.join(output_dir, f"{name}.png")

    payload = mask if len(mask.shape) == 2 else image
    Image.fromarray(payload).save(destination, format="PNG")

def export_for_lama_inpainting(model_types=None):
    """
    Process test images with selected models and save results for LAMA inpainting.

    For each selected censorship type, every image in ``input/<type>`` is run
    through the matching model. The original colour image and a mask resized
    to the image's size are written as ``<stem>.png`` into the LAMA input
    folders. Files sharing a stem overwrite each other. A failure on one image
    is reported and does not stop the remaining images.

    Args:
        model_types: List of censorship types to use (e.g., ['black_bars', 'white_bars']).
                     If None, all available models will be used.
    """
    base_test_dir = 'input'
    demo_input_dir = '../lama-inpainting/input/images'
    demo_mask_dir = '../lama-inpainting/input/masks'

    if model_types is None:
        model_types = list(loaded_models.keys())
    # Drop requested types for which no model was loaded.
    model_types = [t for t in model_types if t in loaded_models]

    if not model_types:
        print("No valid model types specified!")
        return

    print(f"Using models: {', '.join(model_types)}")

    for ctype in model_types:
        specific_test_dir = os.path.join(base_test_dir, ctype)
        if not os.path.exists(specific_test_dir):
            print(f"Skipping {ctype}: No directory found at {specific_test_dir}")
            continue

        # Sorted so exports (and any same-stem overwrites) are deterministic.
        test_images = sorted(
            os.path.join(specific_test_dir, filename)
            for filename in os.listdir(specific_test_dir)
            if filename.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        print(f"Found {len(test_images)} test images for {ctype}")

        # The model is the same for every image of this type.
        model = loaded_models[ctype]

        for img_path in test_images:
            # Best-effort per image: report the failure and continue.
            try:
                filename = os.path.basename(img_path)
                name, _ = os.path.splitext(filename)

                # Preprocess image
                original_color_image, tensor_image = preprocess_image(img_path)

                # Predict mask using the model
                predicted_mask = predict_mask(tensor_image, model)
                opacity_mask = create_opacity_mask(predicted_mask)

                # Stretch so the strongest level present becomes 255; an
                # all-zero mask is left untouched to avoid dividing by zero.
                if opacity_mask.max() > 0:
                    opacity_mask = (opacity_mask / opacity_mask.max() * 255).astype(np.uint8)

                # Resize mask back to original image size
                h, w = original_color_image.shape[:2]
                resized_mask = cv2.resize(opacity_mask, (w, h), interpolation=cv2.INTER_NEAREST)

                # Save original color image and mask
                export_image_and_mask(original_color_image, original_color_image, demo_input_dir, name)
                export_image_and_mask(resized_mask, resized_mask, demo_mask_dir, name)

                print(f"Processed {filename} -> {name}.png")

            except Exception as e:
                print(f"Error processing {img_path}: {e}")

    print(f"Exported images and masks to:")
    print(f"  Images: {demo_input_dir}")
    print(f"  Masks: {demo_mask_dir}")

## Test the Models

In [None]:
# Show predictions for one censorship type on the images under input/<type>.
# Swap which line is active to test a different model.
test_specific_model(test_dir="input", model_type="black_bars")
# test_specific_model(test_dir="input", model_type="white_bars")
# test_specific_model(test_dir="input", model_type="transparent_black")

## Export Images and Masks for LAMA Inpainting

In [None]:
# Export images and predicted masks for the chosen censorship type(s).
# Swap which line is active to export with different models.
# export_for_lama_inpainting(model_types=["black_bars", "white_bars"])
export_for_lama_inpainting(model_types=["transparent_black"])