In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pathlib import Path
import numpy as np
import cv2
from PIL import Image
import yaml
from tqdm import tqdm
import wandb  # Para tracking de experimentos
import os
from typing import List, Tuple, Dict
import json

In [None]:
# Pick the compute device once for the whole notebook: CUDA when usable, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(
    f"Device: {device}",
    f"CUDA Available: {torch.cuda.is_available()}",
    sep="\n",
)
if device.type == 'cuda':
    # Only query the GPU name when a CUDA device was actually selected
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
class Vimeo90kDataset(Dataset):
    """
    Vimeo-90k triplet dataset for frame-interpolation training.

    Each item is read from ``<data_dir>/sequences/<triplet>/im{1,2,3}.png``.
    ``im2`` is the middle frame and serves as the ground truth.

    Args:
        data_dir: Dataset root (a ``str`` or ``Path``). Must contain
            ``tri_trainlist.txt`` / ``tri_testlist.txt`` and a ``sequences`` folder.
        is_train: Selects the train list (and random augmentation) when True,
            the test list (and a centre crop) when False.
        transform: Optional callable applied to the 9-channel stack of the three
            frames. When None a default crop/flip pipeline is built.
        crop_size: (height, width) used by the default transforms.
        cache: When True, decoded images are kept in an in-memory dict keyed by
            path. The dict is never evicted, so memory grows with every distinct
            image that is loaded.

    Raises:
        FileNotFoundError: if the list file does not exist.
    """
    def __init__(
        self,
        data_dir: str,
        is_train: bool = True,
        transform=None,
        crop_size: Tuple[int, int] = (256, 256),
        cache: bool = True
    ):
        print("Initializing Vimeo90kDataset...")
        self.data_dir = Path(data_dir)
        self.is_train = is_train
        self.crop_size = crop_size
        self.cache = cache
        self.image_cache = {}  # path string -> decoded image tensor

        print(f"Data directory: {self.data_dir}")
        # Pick the sequence list that matches the requested split
        list_file = 'tri_trainlist.txt' if is_train else 'tri_testlist.txt'
        list_path = self.data_dir / list_file

        print(f"List file path: {list_path}")

        with open(list_path, 'r') as f:
            # Skip blank lines (e.g. a trailing empty line): an empty entry would
            # otherwise resolve to "<data_dir>/sequences/im1.png" and fail at load time.
            self.triplets = [entry for entry in (line.strip() for line in f) if entry]
        print(f"First 5 triplets: {self.triplets[:5]}")

        print(f"Loaded {len(self.triplets)} triplets for {'training' if is_train else 'testing'}")

        # Default transforms operate on the stacked frames, so the same crop and
        # flip decision is applied to all three frames of a triplet.
        if transform is None:
            if is_train:
                self.transform = transforms.Compose([
                    transforms.RandomCrop(crop_size),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomVerticalFlip(p=0.5),
                ])
            else:
                self.transform = transforms.Compose([
                    transforms.CenterCrop(crop_size),
                ])
        else:
            self.transform = transform

    def __len__(self):
        """Number of triplets listed for the selected split."""
        return len(self.triplets)

    def __getitem__(self, idx):
        """
        Load the triplet at ``idx``.

        Returns:
            dict with ``'frame1'``, ``'frame2'`` (ground truth), ``'frame3'``
            tensors and the relative ``'triplet_path'`` string.
        """
        # Relative triplet path, e.g. 00001/0001
        triplet_path = self.triplets[idx]
        base_path = self.data_dir / 'sequences' / triplet_path

        # Load the three frames (served from the cache when enabled)
        frame1 = self._load_image(base_path / 'im1.png')
        frame2 = self._load_image(base_path / 'im2.png')  # Ground truth (middle)
        frame3 = self._load_image(base_path / 'im3.png')

        # Stack along the channel axis so one transform call covers all frames
        frames = torch.cat([frame1, frame2, frame3], dim=0)

        if self.transform is not None:
            frames = self.transform(frames)

        # Split the 9-channel stack back into three RGB frames
        frame1 = frames[:3, :, :]
        frame2 = frames[3:6, :, :]
        frame3 = frames[6:9, :, :]

        return {
            'frame1': frame1,
            'frame2': frame2,  # Ground truth
            'frame3': frame3,
            'triplet_path': triplet_path
        }

    def _load_image(self, path: Path) -> torch.Tensor:
        """Read an image as an RGB tensor via ``ToTensor``, using the cache if enabled."""
        path_str = str(path)

        # Cache hit: hand out a copy so callers can never modify the cached tensor
        if self.cache and path_str in self.image_cache:
            return self.image_cache[path_str].clone()

        # Cache miss: decode from disk. The context manager closes the file handle
        # right away instead of leaving it to the garbage collector.
        with Image.open(path) as img:
            img_tensor = transforms.ToTensor()(img.convert('RGB'))

        if self.cache:
            self.image_cache[path_str] = img_tensor

        return img_tensor

In [None]:
class ConvBlock(nn.Module):
    """Two consecutive 3x3 convolution -> BatchNorm -> ReLU stages (padding keeps H and W)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels
        for stage_inputs in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_inputs, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the input through both stages."""
        return self.conv(x)
    
class UNetInterpolator(nn.Module):
    """
    U-Net that predicts the middle frame of a triplet.

    Input:  frame1 and frame3, concatenated on the channel axis (6 channels by default).
    Output: the interpolated frame2 (3 channels by default) after a sigmoid.

    The input is max-pooled by 2 four times, so height and width should be
    multiples of 16; otherwise the upsampled maps do not match the skip tensors.
    """
    def __init__(self, in_channels=6, out_channels=3):
        super().__init__()

        # Contracting path: channel width doubles at each level
        self.enc1 = ConvBlock(in_channels, 64)
        self.enc2 = ConvBlock(64, 128)
        self.enc3 = ConvBlock(128, 256)
        self.enc4 = ConvBlock(256, 512)

        # Deepest level
        self.bottleneck = ConvBlock(512, 1024)

        # Expanding path: each transposed conv doubles H/W and halves the channels;
        # the following ConvBlock consumes the upsampled map plus its skip tensor.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = ConvBlock(1024, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(128, 64)

        # 1x1 projection to the output channels
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, frame1, frame3):
        """Return the predicted middle frame for the pair (frame1, frame3)."""
        features = torch.cat((frame1, frame3), dim=1)

        # Encoder: remember every pre-pooling activation for the skip connections
        skips = []
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            skip = encoder(features)
            skips.append(skip)
            features = self.pool(skip)

        features = self.bottleneck(features)

        # Decoder: upsample, join with the matching skip (deepest first), refine
        decoder_stages = (
            (self.upconv4, self.dec4),
            (self.upconv3, self.dec3),
            (self.upconv2, self.dec2),
            (self.upconv1, self.dec1),
        )
        for (upsample, refine), skip in zip(decoder_stages, reversed(skips)):
            features = upsample(features)
            features = refine(torch.cat((features, skip), dim=1))

        # Sigmoid keeps the output in the same 0..1 range as the input frames
        return torch.sigmoid(self.out(features))

In [None]:
"""
Combination of several losses for better output quality:
- L1 loss: pixel-wise reconstruction
- Perceptual loss (LPIPS): similarity of high-level features
- SSIM loss: structural similarity (not implemented yet; CombinedLoss below uses only L1 + LPIPS)
"""

# lpips is an optional dependency: record whether it could be imported so the
# loss can fall back to L1 only when it is missing.
try:
    import lpips
except ImportError:
    LPIPS_AVAILABLE = False
    print("Warning: lpips not available. Install with: pip install lpips")
else:
    LPIPS_AVAILABLE = True

class CombinedLoss(nn.Module):
    """
    Weighted sum of an L1 reconstruction loss and, when the ``lpips`` package
    is importable, an LPIPS (AlexNet) perceptual loss.

    Args:
        device: Device the LPIPS network is moved to (only used when lpips is available).
        w_l1: Weight of the L1 term.
        w_perceptual: Weight of the LPIPS term; forced to 0.0 when lpips is unavailable.
    """
    def __init__(self, device='cuda', w_l1=1.0, w_perceptual=0.1):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        # Not used by forward(); kept so existing code that accesses the attribute still works
        self.mse_loss = nn.MSELoss()

        # Perceptual loss backed by LPIPS
        if LPIPS_AVAILABLE:
            self.lpips_loss = lpips.LPIPS(net='alex').to(device)
            # LPIPS is a fixed metric here: no optimizer owns its weights, so freeze
            # them to avoid accumulating gradients that nobody clears. Gradients
            # still flow back to the prediction.
            for param in self.lpips_loss.parameters():
                param.requires_grad_(False)
        else:
            self.lpips_loss = None

        # Loss weights
        self.w_l1 = w_l1
        self.w_perceptual = w_perceptual if LPIPS_AVAILABLE else 0.0

    def forward(self, pred, target):
        """
        Compute the combined loss.

        Args:
            pred: Predicted frames; rescaled below on the assumption that values are in [0, 1].
            target: Ground-truth frames, same shape and range as ``pred``.

        Returns:
            (total_loss, parts): ``total_loss`` is the differentiable scalar,
            ``parts`` is ``{'l1': float, 'perceptual': float}`` for logging
            (``'perceptual'`` is 0.0 when LPIPS is unavailable).
        """
        loss_l1 = self.l1_loss(pred, target)
        total_loss = self.w_l1 * loss_l1

        perceptual_value = 0.0
        if self.lpips_loss is not None:
            # LPIPS expects inputs in [-1, 1]
            pred_norm = pred * 2 - 1
            target_norm = target * 2 - 1
            loss_perceptual = self.lpips_loss(pred_norm, target_norm).mean()
            total_loss = total_loss + self.w_perceptual * loss_perceptual
            perceptual_value = loss_perceptual.item()

        return total_loss, {
            'l1': loss_l1.item(),
            'perceptual': perceptual_value
        }

In [None]:
def calculate_psnr(pred, target):
    """
    Peak signal-to-noise ratio in dB between two tensors.

    The MSE is taken over every element of the inputs and the peak value is
    fixed at 1.0. Returns ``float('inf')`` when the tensors are identical.
    """
    max_pixel = 1.0
    mse = (pred - target).pow(2).mean()
    if mse == 0:
        # Identical inputs: the ratio is unbounded
        return float('inf')
    return (20 * torch.log10(max_pixel / mse.sqrt())).item()

def calculate_ssim(pred, target):
    """
    Simplified structural similarity index.

    This is a global SSIM: one mean, variance and covariance are computed over
    all elements of the inputs instead of over local windows. For the full
    windowed SSIM use a dedicated package such as pytorch-msssim.

    Variances and covariance use the same (population, divide-by-N)
    normalisation, so identical inputs score exactly 1.

    Returns:
        The SSIM value as a Python float.
    """
    # Stabilising constants; the 0.01 / 0.03 factors are applied to a dynamic range of 1.0
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    mu_pred = torch.mean(pred)
    mu_target = torch.mean(target)

    centered_pred = pred - mu_pred
    centered_target = target - mu_target

    # All three second moments divide by N. Mixing torch.var's default N-1 with a
    # plain mean for the covariance would bias the score below 1 for equal inputs.
    sigma_pred = torch.mean(centered_pred ** 2)
    sigma_target = torch.mean(centered_target ** 2)
    sigma_pred_target = torch.mean(centered_pred * centered_target)

    ssim = ((2 * mu_pred * mu_target + C1) * (2 * sigma_pred_target + C2)) / \
           ((mu_pred ** 2 + mu_target ** 2 + C1) * (sigma_pred + sigma_target + C2))

    return ssim.item()

In [None]:
class Trainer:
    """
    Train / validate / checkpoint loop for a frame-interpolation model.

    The model is called as ``model(frame1, frame3)``. The criterion is called as
    ``criterion(pred, target)`` and must return ``(loss_tensor, loss_dict)``
    where ``loss_dict`` has the keys ``'l1'`` and ``'perceptual'``.

    Args:
        model: Network to optimise.
        train_loader: Iterable of batches (dicts with 'frame1', 'frame2', 'frame3').
        val_loader: Same batch format, used for validation.
        optimizer: Optimizer that owns the model parameters.
        criterion: Loss callable (see above).
        device: Device each batch is moved to.
        config: Dict. Reads ``config.get('use_wandb')``,
            ``config['paths']['finetuned_dir']`` and
            ``config['training']['checkpoint_freq']``.
        scheduler: Optional learning-rate scheduler, stepped once per epoch after
            validation. A ``ReduceLROnPlateau`` receives the validation loss;
            any other scheduler is stepped without arguments.
    """
    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        optimizer,
        criterion,
        device,
        config,
        scheduler=None
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.config = config
        self.scheduler = scheduler

        self.best_val_loss = float('inf')
        self.start_epoch = 0

    def train_epoch(self, epoch):
        """
        Run one optimisation pass over ``train_loader``.

        Returns:
            Mean training loss over the batches of this epoch.

        Raises:
            ValueError: if the loader yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')

        for batch in progress_bar:
            frame1 = batch['frame1'].to(self.device)
            frame2 = batch['frame2'].to(self.device)  # Ground truth
            frame3 = batch['frame3'].to(self.device)

            # Forward
            pred_frame2 = self.model(frame1, frame3)

            # Loss
            loss, loss_dict = self.criterion(pred_frame2, frame2)

            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1

            # Update progress bar
            progress_bar.set_postfix({
                'loss': f'{loss_value:.4f}',
                'l1': f'{loss_dict["l1"]:.4f}'
            })

            # Log to wandb
            if self.config.get('use_wandb', False):
                wandb.log({
                    'train/loss': loss_value,
                    'train/l1_loss': loss_dict['l1'],
                    'train/perceptual_loss': loss_dict['perceptual'],
                    'epoch': epoch
                })

        if num_batches == 0:
            # An explicit error is clearer than a ZeroDivisionError
            raise ValueError('train_loader produced no batches')
        return total_loss / num_batches

    def validate(self, epoch):
        """
        Evaluate the model on ``val_loader`` without gradient tracking.

        PSNR and SSIM are computed once per batch and averaged over batches.

        Returns:
            (avg_loss, avg_psnr, avg_ssim)

        Raises:
            ValueError: if the loader yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        total_psnr = 0.0
        total_ssim = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc='Validation'):
                frame1 = batch['frame1'].to(self.device)
                frame2 = batch['frame2'].to(self.device)
                frame3 = batch['frame3'].to(self.device)

                pred_frame2 = self.model(frame1, frame3)

                loss, _ = self.criterion(pred_frame2, frame2)
                total_loss += loss.item()

                # Image-quality metrics for this batch
                total_psnr += calculate_psnr(pred_frame2, frame2)
                total_ssim += calculate_ssim(pred_frame2, frame2)
                num_batches += 1

        if num_batches == 0:
            raise ValueError('val_loader produced no batches')

        avg_loss = total_loss / num_batches
        avg_psnr = total_psnr / num_batches
        avg_ssim = total_ssim / num_batches

        print(f'\nValidation - Loss: {avg_loss:.4f}, PSNR: {avg_psnr:.2f}, SSIM: {avg_ssim:.4f}')

        if self.config.get('use_wandb', False):
            wandb.log({
                'val/loss': avg_loss,
                'val/psnr': avg_psnr,
                'val/ssim': avg_ssim,
                'epoch': epoch
            })

        return avg_loss, avg_psnr, avg_ssim

    def save_checkpoint(self, epoch, val_loss, is_best=False):
        """
        Write checkpoints to ``config['paths']['finetuned_dir']``.

        A numbered checkpoint is written when ``epoch`` is a multiple of
        ``config['training']['checkpoint_freq']``; ``best_model.pt`` is
        (over)written when ``is_best`` is True.
        """
        checkpoint_dir = Path(self.config['paths']['finetuned_dir'])
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
            'config': self.config,
            # Extra state so a resumed run can keep its best-model bookkeeping
            'best_val_loss': self.best_val_loss,
        }
        if self.scheduler is not None:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

        # Save regular checkpoint (train() counts epochs from start_epoch, 0 by
        # default, so epoch 0 always matches)
        if epoch % self.config['training']['checkpoint_freq'] == 0:
            path = checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
            torch.save(checkpoint, path)
            print(f'Checkpoint saved: {path}')

        # Save best model
        if is_best:
            path = checkpoint_dir / 'best_model.pt'
            torch.save(checkpoint, path)
            print(f'Best model saved: {path}')

    def _step_scheduler(self, val_loss):
        """Advance the optional LR scheduler once; plateau schedulers need the metric."""
        if self.scheduler is None:
            return
        if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(val_loss)
        else:
            self.scheduler.step()

    def train(self, num_epochs):
        """Train from ``self.start_epoch`` up to (not including) ``num_epochs``."""
        print(f'Starting training for {num_epochs} epochs...')

        for epoch in range(self.start_epoch, num_epochs):
            # Train
            train_loss = self.train_epoch(epoch)
            print(f'Epoch {epoch} - Train Loss: {train_loss:.4f}')

            # Validate
            val_loss, val_psnr, val_ssim = self.validate(epoch)

            # Track the best validation loss seen so far
            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss

            # Step before saving so the checkpoint holds the post-epoch scheduler state
            self._step_scheduler(val_loss)

            self.save_checkpoint(epoch, val_loss, is_best)

        print('Training completed!')

In [None]:
from pathlib import Path
from torch.utils.data import Subset
import random

# Tunable settings for this cell
SEED = 42
CROP_SIZE = (256, 256)
BATCH_SIZE = 16
SUBSET_RATIO = 0.15      # QUICK TEST: fraction of each split that is actually used
MAX_TIMED_BATCHES = 5    # batches fetched by the loading-speed check

# Seed the RNGs used below (subset sampling, weight init, shuffling and random
# crops/flips) so that re-running this cell gives the same run.
random.seed(SEED)
torch.manual_seed(SEED)

# Look for config.yaml in the current working directory first, then in its parent
candidates = [Path.cwd(), Path.cwd().parent]
PROJECT_ROOT = next((c for c in candidates if (c / 'config.yaml').exists()), None)
if PROJECT_ROOT is None:
    raise FileNotFoundError("config.yaml não encontrado; ajuste PROJECT_ROOT manualmente")

print("Project root:", PROJECT_ROOT)

with open(PROJECT_ROOT / 'config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Absolute path to the Vimeo-90k triplet dataset
VIMEO_DATA_DIR = PROJECT_ROOT / config['paths']['data_dir'] / 'vimeo_triplet'
print("Vimeo Dataset Directory:", VIMEO_DATA_DIR)

# Sanity checks: both split lists are opened below, so verify both up front
assert (VIMEO_DATA_DIR / 'tri_trainlist.txt').exists(), "tri_trainlist.txt não encontrado"
assert (VIMEO_DATA_DIR / 'tri_testlist.txt').exists(), "tri_testlist.txt não encontrado"

# Build the datasets
print("Loading datasets...")
train_dataset_full = Vimeo90kDataset(
    data_dir=VIMEO_DATA_DIR,
    is_train=True,
    crop_size=CROP_SIZE,
    cache=False
)
val_dataset_full = Vimeo90kDataset(
    data_dir=VIMEO_DATA_DIR,
    is_train=False,
    crop_size=CROP_SIZE,
    cache=False
)

# Use only a random subset of each split. Keep at least one sample (when the
# split is not empty) so that a tiny dataset does not produce an empty loader.
train_size = min(len(train_dataset_full), max(1, int(len(train_dataset_full) * SUBSET_RATIO)))
val_size = min(len(val_dataset_full), max(1, int(len(val_dataset_full) * SUBSET_RATIO)))

train_indices = random.sample(range(len(train_dataset_full)), train_size)
val_indices = random.sample(range(len(val_dataset_full)), val_size)
train_dataset = Subset(train_dataset_full, train_indices)
val_dataset = Subset(val_dataset_full, val_indices)

# ".0%" formatting avoids float noise such as 15.000000000000002
print(f"Using {SUBSET_RATIO:.0%} subset: {len(train_dataset)} train, {len(val_dataset)} val samples")

# DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,  # IMPORTANT: keep 0 for notebooks on Windows
    pin_memory=False
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=False
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

# Build the model
model = UNetInterpolator(in_channels=6, out_channels=3).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Optimizer and loss
optimizer = optim.Adam(
    model.parameters(),
    lr=config['training']['learning_rate'],
    betas=(0.9, 0.999)
)
criterion = CombinedLoss(device=device)

# Learning-rate scheduler.
# NOTE(review): nothing in this cell calls scheduler.step() and the scheduler is
# not handed to the Trainer, so as written the learning rate stays constant.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5
)

# Initialise Weights & Biases (optional)
USE_WANDB = False
if USE_WANDB:
    wandb.init(
        project='video-interpolation',
        config=config,
        name=f'unet-vimeo90k-{config["training"]["learning_rate"]}'
    )

config['use_wandb'] = USE_WANDB

# Build the trainer
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    config=config
)

# Quick check: how long does it take to fetch a few batches?
import time
print("Testing data loading speed...")
start = time.perf_counter()
timed_batches = 0
# zip() stops on the exhausted range before pulling another batch, so exactly
# MAX_TIMED_BATCHES loads are timed (an "if i >= 5: break" loop would load one more).
for i, batch in zip(range(MAX_TIMED_BATCHES), train_loader):
    print(f"Batch {i}: frames shape = {batch['frame1'].shape}")
    timed_batches = i + 1
elapsed = time.perf_counter() - start
if timed_batches:
    # Report the number of batches actually fetched; the loader may hold fewer
    print(f"{timed_batches} batches levaram {elapsed:.2f}s ({elapsed/timed_batches:.2f}s por batch)")

# Start training
trainer.train(num_epochs=config['training']['epochs'])

In [None]:
def visualize_results(model, dataset, num_samples=5,
                      save_path='../data/output/interpolation_results.png'):
    """
    Plot input frames, ground truth and prediction side by side, then save the figure.

    Each row shows, for one sample: frame 1, the ground-truth frame 2, the
    predicted frame 2 and frame 3.

    Args:
        model: Interpolation model called as ``model(frame1, frame3)``.
        dataset: Indexable dataset whose items are dicts with the keys
            ``'frame1'``, ``'frame2'`` and ``'frame3'``.
        num_samples: Number of leading samples to show; capped at ``len(dataset)``
            when the dataset defines a length.
        save_path: Destination of the PNG; missing parent folders are created.

    Raises:
        ValueError: if there is no sample to show.
    """
    import matplotlib.pyplot as plt

    if hasattr(dataset, '__len__'):
        num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        raise ValueError('visualize_results needs at least one sample')

    # Run on whatever device the model lives on instead of relying on a global
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()

    # squeeze=False keeps `axes` two-dimensional even when num_samples == 1
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4 * num_samples), squeeze=False)

    try:
        with torch.no_grad():
            for row in range(num_samples):
                sample = dataset[row]
                frame1 = sample['frame1']
                frame2_gt = sample['frame2']
                frame3 = sample['frame3']

                # Prediction for a batch of one
                frame2_pred = model(
                    frame1.unsqueeze(0).to(model_device),
                    frame3.unsqueeze(0).to(model_device),
                ).squeeze(0)

                panels = (
                    ('Frame 1', frame1),
                    ('Frame 2 (GT)', frame2_gt),
                    ('Frame 2 (Pred)', frame2_pred),
                    ('Frame 3', frame3),
                )
                for ax, (title, image) in zip(axes[row], panels):
                    # Channel-first tensor -> height x width x channel array for imshow
                    ax.imshow(image.detach().cpu().permute(1, 2, 0).numpy())
                    ax.set_title(title)
                    ax.axis('off')
    finally:
        # Put the model back into the mode it was in before this call
        model.train(was_training)

    plt.tight_layout()
    output_path = Path(save_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150)
    plt.show()

# Show and save the qualitative comparison for the first 5 items of val_dataset
visualize_results(model, val_dataset, num_samples=5)