# Colorização Automática de Imagens em Tons de Cinza via Classificação Multimodal

#### O objetivo deste notebook é implementar um algoritmo de Deep Learning para a disciplina de Processamento de Imagens, seguindo a referência e embasamento teórico encontrado em http://richzhang.github.io/colorization/

### Dependências

In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from skimage import color
import matplotlib.pyplot as plt

### Configurações de execução
#### 'Q' representa o número de 'intervalos discretos' de categorias (cores) disponíveis para a classificação

In [4]:
DEVICE = 'cuda' if torch.cuda.is_available() else None
BATCH_SIZE = 16 
IMG_SIZE = 256
Q = 313  

### Pré-processamento do Dataset

In [5]:
class DogColorizationDataset(Dataset):
    def __init__(self, root_dir):
        self.root = root_dir
        try:
            self.image_paths = [os.path.join(root_dir, f) for f in os.listdir(root_dir) 
                           if f.endswith(('.JPEG'))]
        
        self.transform = transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.RandomHorizontalFlip(),
        ])

        # Carregar bins AB pré-processados (em execuções anteriores)
        self.ab_bins = torch.from_numpy(np.load('data/ab_bins.npy')).float()
        self.weights = torch.load('data/color_weights.pt')

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        rgb_img = Image.open(self.image_paths[idx]).convert('RGB')
        rgb_img = self.transform(rgb_img)
        
        # Converter para Lab e normalizar
        lab_img = color.rgb2lab(np.array(rgb_img)).astype(np.float32)
        L = lab_img[:,:,0] / 50.0 - 1.0  # [-1, 1]
        ab = lab_img[:,:,1:] / 128.0      # [-1, 1]
        
        return (
            torch.FloatTensor(L).unsqueeze(0).to(DEVICE),  # L channel
            torch.FloatTensor(ab).permute(2,0,1).to(DEVICE) # AB channels
        )

### Arquitetura do Modelo

In [6]:
class ColorNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Encoder (Downsampling)
        self.enc1 = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.ReLU(), nn.BatchNorm2d(64)
        )
        
        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.ReLU(), nn.BatchNorm2d(128)
        )
        
        self.enc3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1), nn.ReLU(),
            nn.Conv2d(256, 256, 3, stride=2, padding=1), nn.ReLU(), nn.BatchNorm2d(256)
        )
        
        # Middle (Dilated Convs)
        self.mid = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=2, dilation=2), nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=2, dilation=2), nn.ReLU(),
            nn.BatchNorm2d(512)
        )
        
        # Decoder (Upsampling)
        self.dec1 = nn.Sequential(
            nn.Conv2d(512, 256, 3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        
        self.dec2 = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        
        self.dec3 = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        
        self.final = nn.Conv2d(64, Q, 3, padding=1)

    def forward(self, x):
        x = self.enc1(x)    # 128x128
        x = self.enc2(x)    # 64x64
        x = self.enc3(x)    # 32x32
        x = self.mid(x)     # 32x32
        x = self.dec1(x)    # 64x64
        x = self.dec2(x)    # 128x128
        x = self.dec3(x)    # 256x256
        return self.final(x)

### Função de perda aplicada:

In [7]:
class ColorLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.ab_bins = torch.from_numpy(np.load('data/ab_bins.npy')).float().to(DEVICE)
        self.weights = torch.load('data/color_weights.pt').to(DEVICE)
    
    def soft_encode(self, ab):
        # Encontrar 5 vizinhos mais próximos
        ab_flat = ab.permute(0,2,3,1).reshape(-1, 2)
        dists = torch.cdist(ab_flat, self.ab_bins)
        _, top5 = torch.topk(dists, 5, largest=False, dim=1)
        
        # Suavização Gaussiana
        sigma = 5.0
        weights = torch.exp(-dists.gather(1, top5)**2/(2*sigma**2))
        weights = weights / weights.sum(dim=1, keepdim=True)
        
        # Codificação suave
        soft_labels = torch.zeros(ab_flat.size(0), Q, device=DEVICE)
        soft_labels.scatter_(1, top5, weights)
        return soft_labels.view(ab.size(0), IMG_SIZE, IMG_SIZE, Q).permute(0,3,1,2)

    def forward(self, pred, ab_true):
        soft_targets = self.soft_encode(ab_true*128)  # Desnormalizar
        loss = F.cross_entropy(pred, soft_targets.argmax(dim=1), 
                              weight=self.weights, reduction='mean')
        return loss

### Pré-processar dados

In [12]:
import warnings
warnings.filterwarnings('ignore')

def preprocess_colors():
    os.makedirs('data', exist_ok=True)
    
    # 1. Definir bins AB (grid size 10 como no artigo)
    q = 10
    ab_bins = np.linspace(-128, 128, q*2+1)
    ab_grid = np.meshgrid(ab_bins, ab_bins)
    ab_bins = np.stack(ab_grid, axis=-1).reshape(-1, 2)
    
    # 2. Filtrar bins inválidos (fora da gama RGB)
    valid_bins = []
    for a, b in ab_bins:
        # Converter Lab para RGB
        lab = np.array([[50, a, b]], dtype=np.float32)  # L=50 (valor neutro)
        rgb = color.lab2rgb(lab)[0,0]
        if (rgb >= 0).all() and (rgb <= 1).all():
            valid_bins.append([a, b])
    
    ab_bins_valid = np.array(valid_bins)
    np.save('data/ab_bins.npy', ab_bins_valid)
    
    # 3. Calcular distribuição de pesos (usando Imagewoof)
    train_dataset = DogColorizationDataset('data/train')  # Usar dataset real
    hist = np.zeros(len(ab_bins_valid))
    
    for L, ab in train_dataset:
        ab_flat = ab.squeeze().permute(1,2,0).cpu().numpy() * 128  # Desnormalizar
        dists = np.linalg.norm(ab_flat[:, :, None] - ab_bins_valid, axis=-1)
        nearest = np.argmin(dists, axis=-1)
        hist += np.bincount(nearest.flatten(), minlength=len(ab_bins_valid))
    
    # Suavizar e calcular pesos
    sigma = 5
    p = hist / hist.sum()
    p_smooth = np.zeros_like(p)
    for i in range(len(ab_bins_valid)):
        dists = np.linalg.norm(ab_bins_valid - ab_bins_valid[i], axis=-1)
        p_smooth[i] = np.sum(p * np.exp(-dists**2/(2*sigma**2)))
    
    p_smooth = p_smooth / p_smooth.sum()
    weights = 1 / ((1 - 0.5) * p_smooth + 0.5 / len(ab_bins_valid))
    weights = weights / weights.mean()  # Normalizar
    torch.save(torch.FloatTensor(weights), 'data/color_weights.pt')

# Executar pré-processamento
preprocess_colors()

FileNotFoundError: [Errno 2] No such file or directory: 'data/color_weights.pt'

### Treinamento

In [8]:
CHECKPOINT_DIR = 'checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def train(resume_checkpoint=None):
    train_set = DogColorizationDataset('data/train')
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                             num_workers=4, pin_memory=True)

    model = ColorNet().to(DEVICE)
    criterion = ColorLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
    scaler = torch.cuda.amp.GradScaler()
    
    # Variáveis de estado do treinamento
    start_epoch = 0
    best_loss = float('inf')
    loss_history = []

    # Carregar checkpoint se especificado
    if resume_checkpoint:
        checkpoint = torch.load(resume_checkpoint)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scaler.load_state_dict(checkpoint['scaler_state'])
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint['best_loss']
        loss_history = checkpoint['loss_history']
        print(f"Resuming training from epoch {start_epoch}")

    for epoch in range(start_epoch, 100):
        epoch_loss = 0.0
        
        for i, (L, ab) in enumerate(train_loader):
            with torch.cuda.amp.autocast():
                pred = model(L)
                loss = criterion(pred, ab)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            
            epoch_loss += loss.item()
            
            # Salvamento intermediário a cada 25% de uma época
            if i % (len(train_loader)//4) == 0:
                checkpoint_path = os.path.join(CHECKPOINT_DIR, 
                    f'interim_epoch_{epoch}_batch_{i}.pt')
                torch.save({
                    'epoch': epoch,
                    'batch': i,
                    'model_state': model.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'scaler_state': scaler.state_dict(),
                    'loss': loss.item(),
                    'best_loss': best_loss,
                    'loss_history': loss_history
                }, checkpoint_path)
        
        # Salvamento completo no final de cada época
        avg_loss = epoch_loss / len(train_loader)
        loss_history.append(avg_loss)
        
        # Salvar checkpoint regular
        regular_checkpoint = os.path.join(CHECKPOINT_DIR, f'epoch_{epoch}.pt')
        torch.save({
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scaler_state': scaler.state_dict(),
            'loss': avg_loss,
            'best_loss': best_loss,
            'loss_history': loss_history
        }, regular_checkpoint)
        
        # Salvar melhor modelo
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'best_model.pt'))
        
        print(f'Epoch: {epoch+1}, Loss: {avg_loss:.4f}, Best Loss: {best_loss:.4f}')

In [9]:
def find_latest_checkpoint():
    checkpoints = [f for f in os.listdir(CHECKPOINT_DIR) if f.startswith('epoch_')]
    if not checkpoints:
        return None
    latest = sorted(checkpoints, key=lambda x: int(x.split('_')[1].split('.')[0]))[-1]
    return os.path.join(CHECKPOINT_DIR, latest)

# Modifique a execução principal para resumir automaticamente
if __name__ == '__main__':
    latest_checkpoint = find_latest_checkpoint()
    if latest_checkpoint:
        print(f"Found existing checkpoint: {latest_checkpoint}")
        train(resume_checkpoint=latest_checkpoint)
    else:
        train()

FileNotFoundError: [Errno 2] No such file or directory: 'data/ab_bins.npy'

### Inferir (1 imagem)

In [None]:
def colorize(image_path, model_path='color_net.pth'):
    model = ColorNet().to(DEVICE)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    
    # Pré-processamento
    img = Image.open(image_path).convert('RGB')
    img = transforms.Resize((IMG_SIZE, IMG_SIZE))(img)
    L = color.rgb2lab(np.array(img))[:,:,0]
    L_tensor = torch.FloatTensor(L/50.0 - 1.0).unsqueeze(0).unsqueeze(0).to(DEVICE)
    
    # Predição
    with torch.no_grad():
        pred = model(L_tensor)
        probs = F.softmax(pred/0.38, dim=1)  # Annealed mean
        ab = torch.einsum('bqhw,qc->bchw', probs, self.ab_bins)
    
    # Pós-processamento
    lab = torch.cat([(L_tensor.squeeze()*50 + 50).unsqueeze(0), ab.squeeze()*128], dim=0)
    rgb = color.lab2rgb(lab.permute(1,2,0).cpu().numpy())
    
    plt.imshow(rgb)
    plt.axis('off')
    plt.show()