# Logo Detection Training Notebook

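This self-contained notebook includes all the code needed to train a logo detection model: dependency installation, data-loading and image utilities, the training script, and an optional script for compositing transparent logos onto backgrounds. No other files from the repository are required.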

## 1. Install Dependencies

In [None]:
!pip install Pillow torch torchvision rembg tqdm numpy scikit-learn onnxruntime

## 2. Mount Google Drive

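Mount your Google Drive so the notebook can read the dataset archive and save the trained model. The training cell expects the archive at `My Drive/logo-trainer/data.zip`.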

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the training cell reads the dataset
# archive from, and writes the trained model to, paths under this mount point.
drive.mount('/content/drive')

## 3. Utility Functions

These are the utility functions for unzipping data, image manipulation, and loading data.

In [None]:
import zipfile
import os

def unzip_data(zip_file_path, destination_dir):
    """
    Extracts every member of a zip archive into a destination directory.

    Args:
        zip_file_path (str): The path to the zip file (path-like objects
            are accepted as well).
        destination_dir (str): The directory to extract the archive into.

    Raises:
        FileNotFoundError: If nothing exists at ``zip_file_path``.
        ValueError: If the file is not a valid zip archive. The underlying
            ``zipfile.BadZipFile`` is attached as ``__cause__``.
    """
    if not os.path.exists(zip_file_path):
        raise FileNotFoundError(f"Zip file not found at: {zip_file_path}")

    try:
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(destination_dir)
    except zipfile.BadZipFile as err:
        # Chain explicitly so the traceback shows the real parsing failure.
        raise ValueError(f"Invalid zip file: {zip_file_path}") from err

    # Only reached when extraction finished without raising.
    print(f"Successfully extracted {zip_file_path} to {destination_dir}")

In [None]:
from PIL import Image
import random
from typing import Tuple

def load_image(image_path: str) -> Image.Image:
    """
    Opens the image stored at the given location with Pillow.

    Args:
        image_path: Location of the image file on disk.

    Returns:
        The Pillow Image object produced by ``Image.open``.
    """
    opened = Image.open(image_path)
    return opened

def save_image(image: Image.Image, save_path: str) -> None:
    """
    Writes a Pillow image to disk.

    Args:
        image: The Pillow Image object to write.
        save_path: Destination path, handed unchanged to the image's
            ``save`` method.
    """
    write_to = image.save
    write_to(save_path)
    return None

def crop_background(background: Image.Image, width: int, height: int) -> Image.Image:
    """
    Cuts a randomly positioned ``width`` x ``height`` window out of a background.

    Args:
        background: The image to take the window from.
        width: Horizontal size of the window.
        height: Vertical size of the window.

    Returns:
        The result of ``background.crop`` for the chosen window.

    Raises:
        ValueError: If the background is narrower or shorter than the window.
    """
    # Slack left over on each axis; a negative value means the window cannot fit.
    max_left = background.width - width
    max_top = background.height - height
    if max_left < 0 or max_top < 0:
        raise ValueError("Background image is smaller than the foreground image.")

    # Horizontal offset is drawn before the vertical one.
    x0 = random.randint(0, max_left)
    y0 = random.randint(0, max_top)

    box = (x0, y0, x0 + width, y0 + height)
    return background.crop(box)

def composite_images(background: Image.Image, foreground: Image.Image) -> Image.Image:
    """
    Composites a foreground image onto the top-left corner of a background.

    The foreground's alpha channel is used as the paste mask. A foreground
    that is not already in "RGBA" mode is converted first, so palette or
    fully opaque images can be composited as well.

    Args:
        background: The background image. It is not modified.
        foreground: The foreground image, ideally with an alpha channel.

    Returns:
        A new image holding the background with the foreground pasted at (0, 0).
    """
    # The foreground doubles as the paste mask; Pillow rejects masks in modes
    # without a usable alpha/greyscale band (e.g. "P" or "RGB").
    if foreground.mode != "RGBA":
        foreground = foreground.convert("RGBA")

    # Paste into a copy so the caller's background object stays untouched.
    composite = background.copy()
    composite.paste(foreground, (0, 0), foreground)
    return composite

In [None]:
from PIL import Image
from pathlib import Path
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

class LogoDataset(Dataset):
    """
    Dataset of (image, mask) pairs held in two parallel sequences.

    Args:
        images: Sequence of input images.
        masks: Sequence of masks; ``masks[i]`` belongs to ``images[i]``.
        transform: Optional callable applied to each image.
        mask_transform: Optional callable applied to each mask. When omitted,
            ``transform`` is applied to the mask too, which is the original
            behaviour. Supply it when masks need different preprocessing
            from images.
    """

    def __init__(self, images, masks, transform=None, mask_transform=None):
        self.images = images
        self.masks = masks
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        # The number of samples is defined by the image sequence.
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]

        # Compare against None instead of relying on truthiness, so a callable
        # that happens to be falsy (e.g. an empty container module) still runs.
        if self.transform is not None:
            image = self.transform(image)

        # Resolve the fallback at access time so reassigning ``self.transform``
        # after construction keeps affecting masks, as it did before.
        mask_transform = self.mask_transform if self.mask_transform is not None else self.transform
        if mask_transform is not None:
            mask = mask_transform(mask)

        return image, mask

def _load_detached_image(path):
    """Decodes an image fully into memory and releases its file handle."""
    with Image.open(path) as img:
        # Image.open is lazy; load() forces decoding before the file is closed.
        img.load()
    return img


def load_data(input_dir, mask_dir):
    """
    Loads image/mask pairs from two directories and splits them 80/20.

    Files are paired by filename: only ``*.png`` names present in both
    directories are used, so a mask is never matched with the wrong image
    even when both directories happen to contain the same number of files.
    All images are decoded into memory up front so that no file handles
    stay open afterwards.

    Args:
        input_dir (str): Path to the directory of training images.
        mask_dir (str): Path to the directory of mask images.

    Returns:
        tuple: ``(X_train, X_val, y_train, y_val)`` lists of Pillow images,
        split with ``test_size=0.2`` and ``random_state=42``.

    Raises:
        ValueError: If fewer than two matching image/mask pairs are found,
            since a train/validation split needs at least two samples.
    """
    input_path = Path(input_dir)
    mask_path = Path(mask_dir)

    image_names = {p.name for p in input_path.glob("*.png")}
    mask_names = {p.name for p in mask_path.glob("*.png")}
    # Sorted so the split stays reproducible for a given random_state.
    common_names = sorted(image_names & mask_names)

    if not (len(common_names) == len(image_names) == len(mask_names)):
        print(
            f"Warning: Images and masks do not match one-to-one. "
            f"Found {len(image_names)} images and {len(mask_names)} masks; "
            f"using the {len(common_names)} pairs with matching filenames."
        )

    if len(common_names) < 2:
        raise ValueError(
            f"Need at least 2 matching image/mask pairs to create a train/validation "
            f"split; found {len(common_names)} in {input_path} and {mask_path}."
        )

    images = [_load_detached_image(input_path / name) for name in common_names]
    masks = [_load_detached_image(mask_path / name) for name in common_names]

    X_train, X_val, y_train, y_val = train_test_split(images, masks, test_size=0.2, random_state=42)

    return X_train, X_val, y_train, y_val

## 4. Training Script

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from rembg import new_session
from tqdm import tqdm
import time
from pathlib import Path
from PIL import Image
from sklearn.model_selection import train_test_split

# --- Run configuration -------------------------------------------------------
# These values replace the former command line arguments; edit them as needed.
zip_file_path_str = "/content/drive/My Drive/logo-trainer/data.zip"
output_path_str = "/content/drive/My Drive/logo-trainer/models/u2net_logo.pth"
epochs = 50
batch_size = 16
learning_rate = 0.001

# --- Dataset extraction ------------------------------------------------------
zip_file_path = Path(zip_file_path_str)
extracted_data_path = Path("data/extracted")
unzip_data(zip_file_path, extracted_data_path)

# Inputs and masks are read from these two folders inside the extracted archive.
input_dir = extracted_data_path.joinpath("data/transparent")
mask_dir = extracted_data_path.joinpath("data/output")

# --- Preprocessing and loaders -----------------------------------------------
# Every sample is resized to 320x320 and converted to a tensor.
transform = transforms.Compose([transforms.Resize((320, 320)), transforms.ToTensor()])

X_train, X_val, y_train, y_val = load_data(input_dir, mask_dir)

train_dataset, val_dataset = (
    LogoDataset(X_train, y_train, transform=transform),
    LogoDataset(X_val, y_val, transform=transform),
)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

# Obtain the U2-Net model through rembg.
session = new_session("u2net")

# NOTE(review): rembg sessions are believed to wrap an ONNX Runtime inference
# session rather than a PyTorch module - confirm that `session.model` exists
# for the installed rembg version. Fail early with an actionable message
# instead of a bare AttributeError further down.
model = getattr(session, "model", None)
if not isinstance(model, torch.nn.Module):
    raise TypeError(
        "new_session('u2net') did not provide a trainable torch.nn.Module as "
        "`session.model`. Training requires a PyTorch implementation of the "
        "network; an ONNX inference session cannot be optimised with torch."
    )

# Check for GPU availability (queried once, then reused for the messages).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("GPU is available. Training will be on GPU.")
else:
    print("GPU is not available. Training will be on CPU.")
    print("To use GPU in Google Colab, go to Runtime -> Change runtime type and select GPU as the hardware accelerator.")

print(f"Using device: {device}")
model.to(device)

# Adam over all model parameters; the loss expects raw logits from the model.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.BCEWithLogitsLoss()

for epoch in range(epochs):
    # ---- Training pass ----
    model.train()
    for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        # Only the first model output is trained against the mask
        # (presumably the fused prediction - TODO confirm for this model).
        loss = criterion(outputs[0], masks)
        loss.backward()
        optimizer.step()

    # ---- Validation pass: mean of per-batch IoU ----
    model.eval()
    total_iou = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            preds = torch.sigmoid(outputs[0]) > 0.5
            # Binarise targets with the same threshold as the predictions;
            # otherwise every non-zero mask value (masks may hold fractional
            # values after resizing) would count as foreground.
            targets = masks > 0.5

            intersection = torch.logical_and(preds, targets).sum().item()
            union = torch.logical_or(preds, targets).sum().item()
            # An empty union means prediction and target are both empty:
            # score it as perfect agreement instead of producing 0/0 = NaN.
            total_iou += intersection / union if union else 1.0
            num_batches += 1

    # Guard against an empty validation loader (would divide by zero).
    avg_iou = total_iou / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1}/{epochs}, Validation IoU: {avg_iou:.4f}")

output_path = Path(output_path_str)
# exist_ok=True replaces the separate existence check, so there is no window
# between checking for the directory and creating it.
output_path.parent.mkdir(parents=True, exist_ok=True)

# Never overwrite an earlier checkpoint: if the target file already exists,
# append a timestamp to the file name instead.
if output_path.exists():
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    output_path = output_path.with_name(f"{output_path.stem}_{timestamp}{output_path.suffix}")

# Only the weights (state dict) are stored, not the full model object.
torch.save(model.state_dict(), output_path)
print(f"Model saved to {output_path}")

## 5. Image Processing Script

This script can be used to add backgrounds to transparent images.

In [None]:
import os
import random
import logging

# Set up root logging: INFO and above, each record prefixed with time and level
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

def process_images(
    transparent_dir: str = 'data/transparent',
    bg_dir: str = 'data/bg-sample',
    output_dir: str = 'data/output',
) -> None:
    """
    Composites every transparent PNG onto a randomly chosen background.

    For each ``.png`` file in ``transparent_dir`` a background is picked at
    random from ``bg_dir``, a random region of the foreground's size is
    cropped from it, the foreground is pasted on top, and the result is
    saved under the same file name in ``output_dir``. A failure on one image
    is logged and does not stop the rest of the batch.

    Args:
        transparent_dir: Directory containing the transparent ``.png`` images.
        bg_dir: Directory containing background images
            (``.png``, ``.jpg`` or ``.jpeg``).
        output_dir: Directory the composited images are written to; created
            if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Extension checks are case-insensitive for both folders, so "LOGO.PNG"
    # is picked up just like "logo.png".
    transparent_images = [f for f in os.listdir(transparent_dir) if f.lower().endswith('.png')]
    background_images = [f for f in os.listdir(bg_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    if not background_images:
        logging.warning("No background images found.")
        return

    for transparent_image_name in transparent_images:
        try:
            transparent_image_path = os.path.join(transparent_dir, transparent_image_name)
            # Context managers close the underlying files once the composite
            # has been written, instead of leaking one handle per image.
            with load_image(transparent_image_path) as transparent_image:
                bg_image_name = random.choice(background_images)
                bg_image_path = os.path.join(bg_dir, bg_image_name)
                with load_image(bg_image_path) as background_image:
                    # Check the size explicitly so that only a too-small
                    # background is reported as such; other ValueErrors (for
                    # example from pasting) are reported as real errors below.
                    if (background_image.width < transparent_image.width
                            or background_image.height < transparent_image.height):
                        logging.warning(
                            "Skipping %s due to small background: %s is %dx%d, needs at least %dx%d",
                            transparent_image_name,
                            bg_image_name,
                            background_image.width,
                            background_image.height,
                            transparent_image.width,
                            transparent_image.height,
                        )
                        continue

                    cropped_bg = crop_background(
                        background_image, transparent_image.width, transparent_image.height
                    )
                    final_image = composite_images(cropped_bg, transparent_image)

                    output_path = os.path.join(output_dir, transparent_image_name)
                    save_image(final_image, output_path)

            logging.info("Processed %s", transparent_image_name)

        except Exception:
            # Best-effort batch job: record the traceback and move on.
            logging.exception("Error processing %s", transparent_image_name)

# You can call this function to process the images
# process_images()