<a href="https://colab.research.google.com/github/tijanicica/ai-speak/blob/main/AI_speak.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# AI Speak Takmičenje

In [None]:
# Instalacija (samo jednom)
!pip install -q torchcodec transformers torchaudio librosa wandb accelerate

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR

import numpy as np
import pandas as pd
import os
import math
import time
import gc
import random
import warnings
from tqdm import tqdm
from typing import Optional, Dict, List, Tuple

# Audio processing
import torchaudio
import torchaudio.transforms as T
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
import librosa

# Signal processing
from scipy import signal
from scipy.interpolate import interp1d
from scipy.stats import pearsonr

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

warnings.filterwarnings('ignore')

# Set seeds for reproducibility
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) with ``seed``.

    Also switches cuDNN to deterministic kernels and turns off its
    auto-tuner so repeated runs pick the same algorithms.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs once, as soon as this cell runs, so later random draws start
# from a fixed state.
set_seed(42)

# Environment report: PyTorch version and, when a CUDA device is visible,
# the name and memory size of device 0.
print("All packages imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is in bytes; dividing by 1024**3 converts it to GiB.
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")





[?25l   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/2.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m2.1/2.1 MB[0m [31m77.4 MB/s[0m eta [36m0:00:00[0m
[?25hAll packages imported successfully!
PyTorch version: 2.9.0+cu126
CUDA available: True
GPU: Tesla T4
VRAM: 14.74 GB


In [None]:
# Phoneme vocabulary: the Serbian phoneme symbols followed by the two
# non-speech labels 'sil' and 'sp'. A symbol's position in this list is the
# integer id used for it everywhere else, so the order must not change.
SERBIAN_PHONEMES = [
    'a', 'b', 'c', 'ƒç', 'ƒá', 'd', 'd≈æ', 'ƒë', 'e', 'f', 'g', 'h', 'i', 'j',
    'k', 'l', '«â', 'm', 'n', '«å', 'o', 'p', 'r', 's', '≈°', 't', 'u', 'v',
    'z', '≈æ', 'sil', 'sp'
]

# Vocabulary size and the symbol -> id lookup table.
NUM_PHONEMES = len(SERBIAN_PHONEMES)
PHONEME_TO_IDX = dict(zip(SERBIAN_PHONEMES, range(NUM_PHONEMES)))


class AdvancedAudioProcessor:
    """Frozen Wav2Vec2 (XLS-R) encoder used purely as a feature extractor.

    The pretrained model is put in eval mode and all of its parameters have
    ``requires_grad`` disabled, so it is never trained.
    """

    def __init__(self, model_name="facebook/wav2vec2-xls-r-300m", device='cuda'):
        """Load the pretrained feature extractor and encoder onto ``device``."""
        self.device = device
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
        self.model = Wav2Vec2Model.from_pretrained(model_name).to(device)
        self.model.eval()

        # The encoder is a fixed front-end: no gradients are ever needed.
        for param in self.model.parameters():
            param.requires_grad = False

    def extract_features(self, waveform, sr=16000):
        """Return frame-level speech representations for one utterance.

        Args:
            waveform: audio tensor, either 1-D ``(samples,)`` or with leading
                channel dimension(s) ``(channels, samples)``. Multi-channel
                audio is averaged down to mono.
            sr: sampling rate of ``waveform``; anything other than 16000 is
                resampled to 16 kHz first.

        Returns:
            CPU tensor of shape ``(frames, hidden)``: the mean of the last
            four hidden layers of the encoder.
        """
        # Collapse every leading (channel) dimension. Previously a stereo
        # tensor survived `squeeze()` as (2, N) and was treated as a batch
        # of two utterances by the feature extractor. For mono input this is
        # identical to the old `squeeze()`.
        if waveform.dim() > 1:
            waveform = waveform.reshape(-1, waveform.shape[-1]).mean(dim=0)

        if sr != 16000:
            waveform = T.Resample(sr, 16000)(waveform)

        # Peak-normalise; the epsilon guards against an all-zero signal.
        waveform = waveform / (waveform.abs().max() + 1e-8)

        with torch.no_grad():
            inputs = self.feature_extractor(
                waveform.cpu().numpy(),
                sampling_rate=16000,
                return_tensors="pt"
            )

            outputs = self.model(
                inputs.input_values.to(self.device),
                output_hidden_states=True
            )

            # Unweighted mean of the last 4 hidden layers, then drop the
            # batch dimension of size 1.
            hidden_states = outputs.hidden_states[-4:]
            features = torch.stack(hidden_states).mean(0).squeeze(0)

        return features.cpu()


class PhonemeProcessor:
    """Turns time-stamped phoneme alignments into one phoneme id per frame."""

    # Added before truncating ``time * fps`` so that values such as
    # 0.29 * 100 == 28.999999999999996 land on frame 29 instead of 28.
    _FRAME_EPS = 1e-6

    def __init__(self, fps=100):
        """Store the frame rate and the module-level phoneme vocabulary."""
        self.fps = fps
        self.phoneme_to_idx = PHONEME_TO_IDX
        self.num_phonemes = NUM_PHONEMES

    def parse_alignment_file(self, filepath):
        """Parse a UTF-8 text file with lines of ``start_time end_time phoneme``.

        Lines with fewer than three whitespace-separated fields are skipped;
        extra fields are ignored.

        Returns:
            List of ``(start, end, phoneme)`` tuples with float times.

        Raises:
            ValueError: if a time field cannot be converted to float.
        """
        alignments = []
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.split()
                if len(parts) < 3:
                    continue
                alignments.append((float(parts[0]), float(parts[1]), parts[2]))
        return alignments

    def _time_to_frame(self, seconds):
        """Convert a timestamp to a frame index (truncating, float-error safe)."""
        return int(seconds * self.fps + self._FRAME_EPS)

    def create_phoneme_indices(self, alignments, num_frames):
        """Create per-frame phoneme indices (not one-hot) for an embedding layer.

        Args:
            alignments: iterable of ``(start_time, end_time, phoneme)``.
            num_frames: length of the output sequence.

        Returns:
            ``np.int64`` array of length ``num_frames``. Frames not covered
            by any known phoneme hold the 'sil' id when the vocabulary has
            one (falling back to 0), rather than silently becoming id 0,
            which is a real phoneme.
        """
        fill_idx = self.phoneme_to_idx.get('sil', 0)
        phoneme_indices = np.full(num_frames, fill_idx, dtype=np.int64)

        for start_time, end_time, phoneme in alignments:
            phoneme_idx = self.phoneme_to_idx.get(phoneme)
            if phoneme_idx is None:
                # Symbol outside the vocabulary: leave the frames at the fill value.
                continue

            # Clamp to the valid frame range; the end index is exclusive.
            start_frame = max(0, min(self._time_to_frame(start_time), num_frames - 1))
            end_frame = max(0, min(self._time_to_frame(end_time), num_frames))

            phoneme_indices[start_frame:end_frame] = phoneme_idx

        return phoneme_indices


class AggressiveDenoiser:
    """
    Region-specific Savitzky-Golay smoothing of blendshape label tracks.

    The labels are treated as noisy and are filtered before training. The
    window lengths below are the reduced ones (previously 15/13/11/9/9),
    chosen to preserve speech dynamics: the jaw and mouth regions get the
    shortest windows.
    """

    # Savitzky-Golay window length (in frames) per region.
    WINDOW_LENGTHS = {
        'brows': 11,
        'eyes': 9,
        'other': 7,
        'jaw': 5,
        'mouth': 5,
    }
    # Polynomial order used for every region.
    POLY_ORDER = 3

    def __init__(self):
        # Column ranges of each facial region in the 52-wide blendshape vector.
        self.regions = {
            'jaw': slice(24, 28),
            'mouth': slice(28, 52),
            'eyes': slice(5, 19),
            'brows': slice(0, 5),
            'other': slice(19, 24)
        }

    def savgol_filter(self, data, window, poly):
        """Savitzky-Golay filter along axis 0 with handling for short input.

        The window is shrunk to the largest odd value not exceeding
        ``len(data)``, but never below ``poly + 2`` (rounded up to odd).
        When even that minimum window does not fit, ``data`` is returned
        unfiltered.

        Args:
            data: array whose first axis is time.
            window: requested window length in frames.
            poly: polynomial order of the filter.

        Returns:
            The filtered array, or ``data`` itself when it is too short.
        """
        num_frames = len(data)
        if num_frames < window:
            window = num_frames if num_frames % 2 == 1 else num_frames - 1
        window = max(poly + 2, window)
        if window % 2 == 0:
            window += 1

        # Explicit "too short" case. This used to be reached only through a
        # bare `except:` that also hid every unrelated error.
        if window > num_frames:
            return data

        try:
            return signal.savgol_filter(data, window, poly, axis=0)
        except ValueError:
            # Best effort: keep the raw track if SciPy rejects the parameters.
            return data

    def denoise(self, blendshapes):
        """Smooth every region with its own window and clip to [0, 1].

        Args:
            blendshapes: array of shape ``(frames, 52)``; it is not modified.

        Returns:
            A new array of the same shape with values clipped to [0, 1].
        """
        denoised = blendshapes.copy()

        # Regions do not overlap, so the processing order is irrelevant.
        for name, window in self.WINDOW_LENGTHS.items():
            columns = self.regions[name]
            denoised[:, columns] = self.savgol_filter(
                blendshapes[:, columns], window, self.POLY_ORDER
            )

        return np.clip(denoised, 0, 1)

class ChampionshipDataset(Dataset):
    """
    Per-utterance dataset pairing audio with blendshape targets.

    Each item bundles, all aligned to the blendshape frame count:
    1. Wav2Vec2 audio features
    2. Phoneme indices (for an embedding layer)
    3. Denoised blendshape targets
    4. A normalised RMS energy envelope of the audio
    Audio augmentation is applied only for the training split.
    """

    # One frozen feature extractor per device, shared by all instances, so
    # the train and validation datasets do not each load their own copy of
    # the pretrained encoder.
    _processor_cache = {}

    def __init__(
        self,
        base_path,
        split='train',
        fps=100,
        augment=True,
        device='cuda'
    ):
        """
        Args:
            base_path: directory containing the blendshape, audio and label
                sub-folders.
            split: 'train' selects the first 80% of the sorted files; any
                other value selects the remaining 20%.
            fps: blendshape frame rate used to place phoneme alignments.
            augment: enable audio augmentation (training split only).
            device: device on which the audio feature extractor runs.
        """
        self.base_path = base_path
        self.fps = fps
        self.augment = augment and split == 'train'
        self.device = device

        # Paths
        self.blendshape_path = os.path.join(base_path, 'spk08_ser')
        self.audio_path = os.path.join(base_path, 'ser/audio')
        self.label_path = os.path.join(base_path, 'labels 08 srp')

        # Get file list
        self.file_ids = self._get_file_list()

        # Train/val split (80/20) over the sorted file list
        split_idx = int(len(self.file_ids) * 0.8)
        if split == 'train':
            self.file_ids = self.file_ids[:split_idx]
        else:
            self.file_ids = self.file_ids[split_idx:]

        # Initialize processors
        self.audio_processor = self._get_audio_processor(device)
        self.phoneme_processor = PhonemeProcessor(fps=fps)
        self.denoiser = AggressiveDenoiser()

        print(f"‚úÖ {split.upper()}: {len(self.file_ids)} samples")

    @classmethod
    def _get_audio_processor(cls, device):
        """Return the shared feature extractor for ``device``, creating it once."""
        key = str(device)
        if key not in cls._processor_cache:
            cls._processor_cache[key] = AdvancedAudioProcessor(device=device)
        return cls._processor_cache[key]

    @staticmethod
    def _resample_frames(values, target_len):
        """Linearly resample ``values`` along axis 0 to ``target_len`` frames.

        The first and last source frames map to the first and last output
        frames. Input that already has ``target_len`` frames is returned
        unchanged; a single frame is repeated; empty input yields zeros.
        """
        values = np.asarray(values)
        source_len = values.shape[0]
        if source_len == target_len:
            return values
        if source_len == 0:
            return np.zeros((target_len,) + values.shape[1:], dtype=values.dtype)
        if source_len == 1:
            # interp1d needs at least two points.
            return np.repeat(values, target_len, axis=0)

        old_indices = np.linspace(0, source_len - 1, source_len)
        new_indices = np.linspace(0, source_len - 1, target_len)
        interpolator = interp1d(old_indices, values, axis=0, kind='linear')
        return interpolator(new_indices)

    def _get_file_list(self):
        """Return the sorted ids of all ``.npy`` blendshape files."""
        npy_files = sorted([f for f in os.listdir(self.blendshape_path) if f.endswith('.npy')])
        file_ids = [f.replace('.npy', '').replace('08_ser_spk08_', '') for f in npy_files]
        return file_ids

    def _load_audio(self, file_id):
        """Load the utterance's wav file as a mono ``(1, samples)`` tensor."""
        audio_file = os.path.join(self.audio_path, f"spk08_{file_id}.wav")
        waveform, sr = torchaudio.load(audio_file)
        # Everything downstream squeezes the channel axis and assumes mono,
        # so multi-channel files are averaged down here.
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        return waveform, sr

    def _augment_audio(self, waveform, sr):
        """Randomly time-stretch, pitch-shift and add noise to the waveform.

        Returns the input unchanged when augmentation is disabled, and in
        half of the calls when it is enabled.
        """
        if not self.augment or np.random.rand() > 0.5:
            return waveform

        # Time stretch (+/-10%)
        if np.random.rand() > 0.5:
            rate = np.random.uniform(0.9, 1.1)
            waveform_np = waveform.squeeze().numpy()
            waveform_np = librosa.effects.time_stretch(waveform_np, rate=rate)
            waveform = torch.from_numpy(waveform_np).unsqueeze(0)

        # Pitch shift (+/-2 semitones)
        if np.random.rand() > 0.5:
            n_steps = np.random.randint(-2, 3)
            waveform_np = waveform.squeeze().numpy()
            waveform_np = librosa.effects.pitch_shift(waveform_np, sr=sr, n_steps=n_steps)
            waveform = torch.from_numpy(waveform_np).unsqueeze(0)

        # Background noise (subtle)
        if np.random.rand() > 0.7:
            noise = torch.randn_like(waveform) * 0.005
            waveform = waveform + noise

        return waveform

    def _compute_energy(self, waveform, num_frames):
        """
        Compute the RMS energy envelope of the audio, aligned to
        ``num_frames`` blendshape frames and min-max normalised to [0, 1].

        The envelope is linearly resampled to the target length, the same
        mapping ``__getitem__`` applies to the audio features. Cropping or
        padding instead would let the two drift apart whenever the audio
        duration differs from the blendshape duration, e.g. after
        time-stretch augmentation.
        """
        y = waveform.squeeze().numpy()

        # At 16 kHz a hop of 160 samples is one value every 10 ms (100 FPS).
        hop_length = 160
        frame_length = 320

        rms = librosa.feature.rms(y=y, hop_length=hop_length, frame_length=frame_length)[0]
        rms = self._resample_frames(rms, num_frames)

        # Normalise to [0, 1] so the loss weights stay in a stable range; a
        # constant (or empty) envelope is left as is.
        if rms.size:
            rms_min = rms.min()
            rms_max = rms.max()
            if rms_max > rms_min:
                rms = (rms - rms_min) / (rms_max - rms_min + 1e-8)

        return rms.astype(np.float32)

    def __len__(self):
        return len(self.file_ids)

    def __getitem__(self, idx):
        """Build one training example.

        Returns:
            dict with ``audio_features`` (frames, feature_dim),
            ``phoneme_indices`` (frames,), ``blendshapes`` (frames, 52),
            ``energy`` (frames,) and ``file_id``; ``frames`` is the number
            of blendshape frames in the target file.
        """
        file_id = self.file_ids[idx]

        # 1. Load and augment the audio
        waveform, sr = self._load_audio(file_id)
        waveform = self._augment_audio(waveform, sr)

        # Make a 16 kHz version for feature extraction and energy
        if sr != 16000:
            resampler = T.Resample(sr, 16000)
            waveform_16k = resampler(waveform)
        else:
            waveform_16k = waveform

        # 2. Extract Wav2Vec2 features from the 16 kHz waveform
        audio_features = self.audio_processor.extract_features(waveform_16k, 16000)

        # 3. Load blendshapes (target)
        bs_file = os.path.join(self.blendshape_path, f"08_ser_spk08_{file_id}.npy")
        blendshapes = np.load(bs_file).astype(np.float32)
        num_frames = blendshapes.shape[0]

        # 4. Denoise the blendshape labels
        blendshapes = self.denoiser.denoise(blendshapes)

        # 5. RMS energy of the audio (used for the energy-weighted loss)
        energy = self._compute_energy(waveform_16k, num_frames)

        # 6. Load phoneme alignments
        label_file = os.path.join(self.label_path, f"spk08_{file_id}.txt")
        alignments = self.phoneme_processor.parse_alignment_file(label_file)

        # 7. Per-frame phoneme indices (for the embedding layer)
        phoneme_indices = self.phoneme_processor.create_phoneme_indices(alignments, num_frames)

        # 8. Stretch the audio features onto the blendshape frame grid
        if audio_features.shape[0] != num_frames:
            audio_features = torch.from_numpy(
                self._resample_frames(audio_features.numpy(), num_frames)
            ).float()

        return {
            'audio_features': audio_features,  # (T, feature_dim)
            'phoneme_indices': torch.from_numpy(phoneme_indices).long(),  # (T,)
            'blendshapes': torch.from_numpy(blendshapes),  # (T, 52)
            'energy': torch.from_numpy(energy),  # (T,)
            'file_id': file_id
        }



def create_dataloaders(base_path, batch_size=8, num_workers=0, device='cuda'):
    """Create the train and validation dataloaders.

    Variable-length utterances are zero-padded to the longest sequence in
    each batch, and a boolean ``mask`` marks the real (non-padded) frames.

    Args:
        base_path: dataset root passed to ``ChampionshipDataset``.
        batch_size: number of utterances per batch.
        num_workers: DataLoader worker processes.
        device: device used by the datasets' audio feature extractor.

    Returns:
        ``(train_loader, val_loader)``. Each batch is a dict with
        ``audio_features`` (B, T, feature_dim), ``phoneme_indices`` (B, T),
        ``blendshapes`` (B, T, num_blendshapes), ``energy`` (B, T) and
        ``mask`` (B, T).
    """
    from torch.nn.utils.rnn import pad_sequence

    train_dataset = ChampionshipDataset(
        base_path, split='train', augment=True, device=device
    )
    val_dataset = ChampionshipDataset(
        base_path, split='val', augment=False, device=device
    )

    # Per-frame tensors that need padding. Within one item they all share
    # the same length, so each key pads to the same batch-wide maximum.
    padded_keys = ('audio_features', 'phoneme_indices', 'blendshapes', 'energy')

    def collate_fn(batch):
        """Zero-pad variable-length sequences and build the validity mask."""
        lengths = torch.tensor([item['audio_features'].shape[0] for item in batch])

        # pad_sequence takes feature width and dtype from the tensors
        # themselves, so nothing is tied to a 1024-wide audio feature or a
        # 52-wide blendshape vector.
        collated = {
            key: pad_sequence([item[key] for item in batch], batch_first=True)
            for key in padded_keys
        }

        # True for real frames, False for padding.
        max_len = int(lengths.max())
        collated['mask'] = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)
        return collated

    # Pinned host memory only helps (and only avoids a warning) with CUDA.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory
    )

    return train_loader, val_loader


# Cell-completion banner: confirms the dataset definitions above were
# executed and lists the features they implement.
print("Championship Dataset loaded!")
print("Key features:")
print("   - Aggressive MediaPipe denoising")
print("   - Phoneme embedding (learned space)")
print("   - Energy-aware weighting")
print("   - Smart augmentation")


In [None]:

"""
🏆 CHAMPIONSHIP LIP-SYNC SYSTEM - ĆELIJA 3
SOTA Model Architecture combining:
- VQ-VAE (CodeTalker) - Discrete motion codes
- Autoregressive GRU (FaceFormer) - Temporal consistency
- Phoneme Embedding - Learned phoneme space
- Multi-scale Discriminator - Fine + coarse motion
"""

# ==============================================================================
# 1. VECTOR QUANTIZATION (CodeTalker)
# ==============================================================================

class VectorQuantizer(nn.Module):
    """
    Learnable codebook that snaps each latent vector to its nearest entry.

    Returns the quantized latents (with straight-through gradients), the
    combined codebook/commitment loss and the chosen code indices.
    Reference: CodeTalker (CVPR 2023).
    """

    def __init__(self, num_embeddings=512, embedding_dim=256, commitment_cost=0.25):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost

        # Codebook entries start uniformly in [-1/K, 1/K], K = num_embeddings.
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        bound = 1 / num_embeddings
        self.embeddings.weight.data.uniform_(-bound, bound)

    def forward(self, z):
        """Quantize ``z`` of shape (B, T, D) against the codebook.

        Returns:
            ``(quantized, codebook_loss, indices)`` where ``quantized`` has
            the shape of ``z`` and ``indices`` has shape (B, T).
        """
        batch, steps, dim = z.shape
        codebook = self.embeddings.weight
        flat = z.reshape(-1, dim)

        # Squared Euclidean distance via ||z||^2 + ||e||^2 - 2 z.e
        z_norm = torch.sum(flat**2, dim=1, keepdim=True)
        code_norm = torch.sum(codebook**2, dim=1)
        cross = torch.matmul(flat, codebook.t())
        distances = z_norm + code_norm - 2 * cross

        # Nearest code per latent vector, looked up through a one-hot product
        encoding_indices = torch.argmin(distances, dim=1)
        one_hot = F.one_hot(encoding_indices, self.num_embeddings).float()
        quantized = torch.matmul(one_hot, codebook).view(batch, steps, dim)

        # Codebook term pulls the codes toward the encoder output; the
        # commitment term pulls the encoder output toward the codes.
        commitment = F.mse_loss(quantized.detach(), z)
        codebook_term = F.mse_loss(quantized, z.detach())
        codebook_loss = codebook_term + self.commitment_cost * commitment

        # Straight-through estimator: forward uses the code, backward sees identity.
        quantized = z + (quantized - z).detach()

        return quantized, codebook_loss, encoding_indices.view(batch, steps)


# ==============================================================================
# 2. POSITIONAL ENCODING
# ==============================================================================

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to a (B, T, d_model) input.

    A table for ``max_len`` positions is precomputed and stored as the
    ``pe`` buffer. Longer inputs are handled by computing a larger table on
    the fly, which leaves the stored buffer (and the state dict) unchanged.
    Both even and odd ``d_model`` are supported.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.register_buffer('pe', self._build_table(max_len, d_model))

    @staticmethod
    def _build_table(length, d_model):
        """Return the (1, length, d_model) sin/cos table."""
        pe = torch.zeros(length, d_model)
        position = torch.arange(0, length).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pe.unsqueeze(0)

    def forward(self, x):
        """Add the positional table to ``x`` of shape (B, T, d_model)."""
        seq_len = x.size(1)
        if seq_len <= self.pe.size(1):
            return x + self.pe[:, :seq_len]

        # Sequence longer than the cached table: build a matching one with
        # the buffer's dtype and device.
        table = self._build_table(seq_len, self.pe.size(2)).to(self.pe)
        return x + table


# ==============================================================================
# 3. CHAMPIONSHIP MODEL (VQ-VAE + Autoregressive GRU)
# ==============================================================================

class ChampionshipLipSyncModel(nn.Module):
    """
    Audio + phoneme to blendshape model.

    Pipeline:
    1. Phoneme embedding (learned, not one-hot)
    2. Transformer encoder over the fused audio + phoneme features
    3. Optional vector quantization of the encoded sequence (off by default)
    4. Unidirectional GRU for temporal consistency
    5. Region-specific sigmoid heads (brows, eyes, other, jaw, mouth)
    """

    def __init__(
        self,
        audio_dim=1024,
        phoneme_embedding_dim=64,
        d_model=512,
        num_encoder_layers=4,
        num_heads=8,
        dim_feedforward=2048,
        num_blendshapes=52,
        codebook_size=512,
        gru_hidden=128,
        dropout=0.1,
        use_vq=False
    ):
        """
        Args:
            audio_dim: width of the per-frame audio features.
            phoneme_embedding_dim: width of the learned phoneme embedding.
            d_model: Transformer model width.
            num_encoder_layers: number of Transformer encoder layers.
            num_heads: attention heads per encoder layer.
            dim_feedforward: hidden width of the encoder feed-forward blocks.
            num_blendshapes: kept for interface compatibility; the output
                width is fixed by the region heads (5 + 14 + 5 + 4 + 24 = 52).
            codebook_size: number of VQ codebook entries.
            gru_hidden: GRU hidden size, which is also the head input width.
            dropout: dropout rate inside the Transformer encoder.
            use_vq: when True the encoder output passes through the VQ
                bottleneck; when False (default) the VQ layers are bypassed
                and the returned VQ loss is zero.
        """
        super().__init__()

        self.d_model = d_model
        self.gru_hidden = gru_hidden
        self.use_vq = use_vq

        # ============ PHONEME EMBEDDING (learned space) ============
        self.phoneme_embedding = nn.Embedding(NUM_PHONEMES, phoneme_embedding_dim)

        # ============ INPUT PROJECTION ============
        self.input_projection = nn.Linear(audio_dim + phoneme_embedding_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model)

        # ============ TRANSFORMER ENCODER ============
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)

        # ============ VECTOR QUANTIZATION ============
        # Always constructed so state-dict keys do not depend on `use_vq`.
        self.pre_quant = nn.Linear(d_model, 256)
        self.vq = VectorQuantizer(num_embeddings=codebook_size, embedding_dim=256)
        self.post_quant = nn.Linear(256, d_model)

        # ============ GRU (temporal memory) ============
        # Carries state from frame to frame. nn.GRU only applies dropout
        # between stacked layers, so with a single layer it must be 0. The
        # previous non-zero value had no effect and triggered a warning.
        self.gru = nn.GRU(
            input_size=d_model,
            hidden_size=gru_hidden,
            num_layers=1,
            batch_first=True,
            dropout=0.0
        )

        # ============ OUTPUT HEADS (region-specific) ============
        self.jaw_head = self._make_head(gru_hidden, 4)
        self.mouth_head = self._make_head(gru_hidden, 24)
        self.eye_head = self._make_head(gru_hidden, 14)
        self.brow_head = self._make_head(gru_hidden, 5)
        self.other_head = self._make_head(gru_hidden, 5)

    def _make_head(self, input_dim, num_outputs):
        """Build one region head: Linear -> LayerNorm -> GELU -> Dropout -> Linear -> Sigmoid."""
        head = nn.Sequential(
            nn.Linear(input_dim, input_dim // 2),
            nn.LayerNorm(input_dim // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(input_dim // 2, num_outputs),
            nn.Sigmoid()
        )
        # Bias the last linear layer to -3.0 so the sigmoid output starts
        # near 0.05, i.e. close to a neutral face.
        nn.init.constant_(head[-2].bias, -3.0)
        return head

    def forward(self, audio_features, phoneme_indices, mask=None, hidden_state=None):
        """
        Args:
            audio_features: (B, T, audio_dim) float tensor.
            phoneme_indices: (B, T) integer phoneme ids.
            mask: optional (B, T) boolean tensor, True for real frames.
            hidden_state: optional GRU hidden state to continue from.

        Returns:
            ``(blendshapes, vq_loss, hidden_state)``: blendshapes is
            (B, T, 52) with columns ordered brows, eyes, other, jaw, mouth;
            vq_loss is a scalar (zero when VQ is bypassed); hidden_state is
            the final GRU state.
        """
        # ============ PHONEME EMBEDDING ============
        phoneme_emb = self.phoneme_embedding(phoneme_indices)

        # ============ COMBINE AUDIO + PHONEME ============
        x = torch.cat([audio_features, phoneme_emb], dim=-1)
        x = self.input_projection(x)
        x = self.pos_encoder(x)

        # ============ ENCODE (Transformer) ============
        # The encoder expects True at padded positions, the inverse of `mask`.
        src_key_padding_mask = None if mask is None else ~mask
        encoded = self.encoder(x, src_key_padding_mask=src_key_padding_mask)

        # ============ OPTIONAL VECTOR QUANTIZATION ============
        if self.use_vq:
            quantized, vq_loss, _ = self.vq(self.pre_quant(encoded))
            quantized = self.post_quant(quantized)
        else:
            # VQ bypass: feed the encoder output straight to the GRU, which
            # was found more stable on a small dataset (about 80 sentences).
            # The loss is zero so it does not affect training.
            quantized = encoded
            vq_loss = torch.zeros((), device=x.device)

        # ============ GRU ============
        # nn.GRU treats hx=None as an all-zero initial state.
        gru_out, hidden_state = self.gru(quantized, hidden_state)

        # ============ BLENDSHAPE PREDICTION ============
        jaw = self.jaw_head(gru_out)
        mouth = self.mouth_head(gru_out)
        eye = self.eye_head(gru_out)
        brow = self.brow_head(gru_out)
        other = self.other_head(gru_out)

        blendshapes = torch.cat([brow, eye, other, jaw, mouth], dim=-1)

        return blendshapes, vq_loss, hidden_state

# ==============================================================================
# 4. MULTI-SCALE DISCRIMINATOR (MelGAN/HiFiGAN style)
# ==============================================================================

class MultiScaleDiscriminator(nn.Module):
    """
    Three convolutional discriminators applied to the blendshape sequence at
    its original rate and at 2x and 4x temporal average-pooling, so both
    fine and coarse motion are scored.
    """

    def __init__(self, num_blendshapes=52):
        super().__init__()

        self.disc_1x = self._make_discriminator(num_blendshapes, scale=1)
        self.disc_2x = self._make_discriminator(num_blendshapes, scale=2)
        self.disc_4x = self._make_discriminator(num_blendshapes, scale=4)

    def _make_discriminator(self, in_channels, scale):
        """Build the pooling, conv stack and classifier for one scale."""
        pool = nn.Identity() if scale <= 1 else nn.AvgPool1d(scale, scale)

        # (in_channels, out_channels, kernel_size, stride); padding is
        # kernel_size // 2 for every layer.
        conv_specs = [
            (in_channels, 128, 15, 1),
            (128, 256, 11, 2),
            (256, 512, 7, 2),
            (512, 512, 5, 2),
        ]
        layers = []
        for c_in, c_out, kernel, stride in conv_specs:
            layers.append(
                nn.Conv1d(c_in, c_out, kernel_size=kernel, stride=stride, padding=kernel // 2)
            )
            layers.append(nn.LeakyReLU(0.2))

        head = nn.Conv1d(512, 1, kernel_size=3, padding=1)
        return nn.ModuleDict({
            'downsample': pool,
            'conv_blocks': nn.Sequential(*layers),
            'classifier': head,
        })

    def _forward_single_scale(self, x, disc):
        """Run one discriminator on a (B, T, C) sequence; return (score, features)."""
        channels_first = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        pooled = disc['downsample'](channels_first)
        features = disc['conv_blocks'](pooled)
        return disc['classifier'](features), features

    def forward(self, blendshapes):
        """Return ``(scores, features)``: two lists with one entry per scale (1x, 2x, 4x)."""
        per_scale = [
            self._forward_single_scale(blendshapes, disc)
            for disc in (self.disc_1x, self.disc_2x, self.disc_4x)
        ]
        scores = [score for score, _ in per_scale]
        features = [feat for _, feat in per_scale]
        return scores, features


# Cell-completion banner listing the architecture components defined above.
# NOTE(review): the model's forward pass currently routes the encoder output
# straight to the GRU, so the VQ-VAE line names a component that is
# constructed but not on the active path.
print("Championship Model Architecture loaded!")
print("Key innovations:")
print("   - VQ-VAE: Discrete motion codebook")
print("   - Autoregressive GRU: Temporal memory")
print("   - Phoneme Embedding: Learned phoneme space")
print("   - Multi-scale Discriminator: Fine + coarse motion")





In [None]:


"""
🏆 CHAMPIONSHIP LIP-SYNC SYSTEM - ĆELIJA 4
Advanced Loss Functions:
- Correlation Loss (direktna optimizacija sinhronizacije)
- L1 Loss (bolje od MSE za blendshapes)
- Energy-aware weighting (fokus na aktivne frejmove)
- Perceptual Loss (feature matching)
- Velocity Loss (dinamika pokreta)
"""

# ==============================================================================
# 1. CORRELATION LOSS (Research preporuka - kljuƒçno!)
# ==============================================================================

def correlation_loss(pred, target, mask=None):
    """
    Pearson correlation loss along the time axis (dim 1): ``1 - mean(corr)``.

    Rewards predictions that follow the *shape* of the target motion over
    time rather than only matching its magnitude.

    Args:
        pred, target: tensors of shape (B, T) or (B, T, C).
        mask: optional.
            * A frame mask whose first two dimensions are (B, T): only
              frames where it is non-zero enter the means, variances and
              covariance, so padded frames do not affect the correlation.
              Sequences with no valid frame are left out of the average.
            * Any other shape is treated, as before, as a weight on the
              per-sequence correlation values.

    Returns:
        Scalar tensor; 0 for perfectly correlated signals, 2 for perfectly
        anti-correlated ones.
    """
    frame_mask = (
        mask is not None
        and mask.dim() >= 2
        and pred.dim() >= 2
        and mask.shape[:2] == pred.shape[:2]
    )

    if frame_mask:
        # Broadcast the (B, T) mask over any trailing channel dimension.
        weights = mask.to(pred.dtype)
        while weights.dim() < pred.dim():
            weights = weights.unsqueeze(-1)

        # Statistics over valid frames only.
        counts = weights.sum(dim=1, keepdim=True).clamp_min(1.0)
        pred_mean = (pred * weights).sum(dim=1, keepdim=True) / counts
        target_mean = (target * weights).sum(dim=1, keepdim=True) / counts

        pred_centered = (pred - pred_mean) * weights
        target_centered = (target - target_mean) * weights
    else:
        pred_centered = pred - pred.mean(dim=1, keepdim=True)
        target_centered = target - target.mean(dim=1, keepdim=True)

    # Correlation; the epsilon keeps constant signals from dividing by zero.
    numerator = (pred_centered * target_centered).sum(dim=1)
    denominator = torch.sqrt(
        (pred_centered ** 2).sum(dim=1) * (target_centered ** 2).sum(dim=1) + 1e-8
    )
    corr = numerator / denominator

    if frame_mask:
        # Average only over sequences that contain at least one valid frame.
        has_frames = (weights.sum(dim=1) > 0).to(corr.dtype).expand_as(corr)
        return 1 - (corr * has_frames).sum() / has_frames.sum().clamp_min(1.0)

    if mask is not None:
        # Legacy path: mask weights the correlation values themselves.
        corr = corr * mask.float()
        return 1 - corr.sum() / (mask.sum() + 1e-8)

    return 1 - corr.mean()


# ==============================================================================
# 2. CHAMPIONSHIP LOSS FUNCTION
# ==============================================================================

class ChampionshipLoss(nn.Module):
    """
    Composite generator loss for blendshape prediction.

    Terms (each multiplied by its weight and summed in ``forward``):

    1. Reconstruction: region- and energy-weighted L1
    2. Correlation: per-blendshape ``1 - Pearson r`` over time
    3. Velocity: L1 on first temporal differences
    4. Perceptual: L1 feature matching on discriminator features
    5. Smoothness: squared second temporal difference (anti-jitter)
    """

    def __init__(
        self,
        reconstruction_weight=15.0,
        correlation_weight=5.0,
        velocity_weight=5.0,
        perceptual_weight=0.1,
        smoothness_weight=0.0001
    ):
        super().__init__()
        self.reconstruction_weight = reconstruction_weight
        self.correlation_weight = correlation_weight
        self.velocity_weight = velocity_weight
        self.perceptual_weight = perceptual_weight
        self.smoothness_weight = smoothness_weight

        # Channel groups that receive extra weight in the losses below.
        self.mouth_indices = list(range(28, 52))
        self.jaw_indices = [24, 25, 26, 27]

    def reconstruction_loss(self, pred, target, energy=None, mask=None):
        """
        Weighted L1 reconstruction loss.

        Args:
            pred: Predicted blendshapes (B, T, C).
            target: Ground-truth blendshapes (B, T, C).
            energy: Optional per-frame energy (B, T); each frame is scaled by
                ``0.5 + energy``.
            mask: Optional (B, T) validity mask.

        Returns:
            Scalar tensor. With a mask the sum is divided by
            ``mask.sum() * C``; without one it is the plain mean.
        """
        l1 = torch.abs(pred - target)

        # Mouth channels count 15x, jaw channels 5x, everything else 1x.
        weights = torch.ones_like(l1)
        weights[:, :, self.mouth_indices] *= 15.0
        weights[:, :, self.jaw_indices] *= 5.0
        l1 = l1 * weights

        if energy is not None:
            # Maps to [0.5, 1.5] only if energy lies in [0, 1] - TODO confirm
            # against the dataset's energy normalisation.
            energy_weight = 0.5 + energy.unsqueeze(-1)
            l1 = l1 * energy_weight

        if mask is not None:
            l1 = l1 * mask.unsqueeze(-1)
            return l1.sum() / (mask.sum() * pred.shape[-1] + 1e-8)

        return l1.mean()

    def correlation_loss_wrapper(self, pred, target, mask=None):
        """
        Average ``correlation_loss`` over blendshape channels.

        Channels whose target has (near) zero standard deviation are skipped
        because their correlation is undefined.

        Returns:
            Scalar tensor. If every channel is skipped a zero tensor is
            returned (never a Python float) so callers can always use
            ``.item()``.
        """
        total_corr_loss = 0
        num_valid = 0

        for i in range(pred.shape[-1]):
            target_i = target[:, :, i]  # (B, T)

            if target_i.std() < 1e-6:
                continue

            total_corr_loss = total_corr_loss + correlation_loss(pred[:, :, i], target_i, mask)
            num_valid += 1

        if num_valid == 0:
            return pred.new_zeros(())

        return total_corr_loss / (num_valid + 1e-8)

    def velocity_loss(self, pred, target, mask=None):
        """
        L1 loss on frame-to-frame differences, mouth channels weighted 3x.

        NOTE(review): the masked branch divides by the number of valid frames
        only, while the unmasked branch averages over frames *and* channels,
        so the two differ by a factor of C. Left unchanged because the loss
        weights may have been tuned against the masked scale.
        """
        pred_vel = pred[:, 1:] - pred[:, :-1]
        target_vel = target[:, 1:] - target[:, :-1]

        vel_loss = torch.abs(pred_vel - target_vel)

        weights = torch.ones_like(vel_loss)
        weights[:, :, self.mouth_indices] *= 3.0
        vel_loss = vel_loss * weights

        if mask is not None:
            # Difference t is attributed to frame t+1.
            valid_mask = mask[:, 1:]
            return (vel_loss * valid_mask.unsqueeze(-1)).sum() / (valid_mask.sum() + 1e-8)

        return vel_loss.mean()

    def perceptual_loss(self, pred_features, target_features):
        """Mean L1 distance between paired discriminator feature tensors."""
        total_loss = 0
        for pred_f, target_f in zip(pred_features, target_features):
            total_loss += F.l1_loss(pred_f, target_f)

        return total_loss / len(pred_features)

    def smoothness_loss(self, pred, mask=None):
        """
        Squared second temporal difference of the prediction (anti-jitter).

        NOTE(review): same masked/unmasked normalisation difference as in
        ``velocity_loss``.
        """
        acceleration = pred[:, 2:] - 2*pred[:, 1:-1] + pred[:, :-2]

        if mask is not None:
            valid_mask = mask[:, 2:]
            return (acceleration.pow(2) * valid_mask.unsqueeze(-1)).sum() / (valid_mask.sum() + 1e-8)

        return acceleration.pow(2).mean()

    def forward(self, pred, target, pred_features, target_features, energy=None, mask=None):
        """
        Compute every loss term and their weighted sum.

        Args:
            pred: Predicted blendshapes (B, T, 52)
            target: Ground truth blendshapes (B, T, 52)
            pred_features: List of discriminator features for pred
            target_features: List of discriminator features for target
            energy: Frame-wise energy (B, T)
            mask: Valid frame mask (B, T)

        Returns:
            Dictionary with keys 'total', 'reconstruction', 'correlation',
            'velocity', 'perceptual' and 'smoothness'. Individual terms are
            unweighted; 'total' is the weighted sum.
        """
        recon = self.reconstruction_loss(pred, target, energy, mask)
        corr = self.correlation_loss_wrapper(pred, target, mask)
        vel = self.velocity_loss(pred, target, mask)
        perc = self.perceptual_loss(pred_features, target_features)
        smooth = self.smoothness_loss(pred, mask)

        total = (
            self.reconstruction_weight * recon +
            self.correlation_weight * corr +
            self.velocity_weight * vel +
            self.perceptual_weight * perc +
            self.smoothness_weight * smooth
        )

        return {
            'total': total,
            'reconstruction': recon,
            'correlation': corr,
            'velocity': vel,
            'perceptual': perc,
            'smoothness': smooth
        }


# Cell banner. The weights are read from ChampionshipLoss's defaults so the
# printed values cannot drift from what the loss actually uses.
_loss_defaults = ChampionshipLoss()
print("Championship Loss Functions loaded!")
print("Key innovations:")
print("   - Correlation Loss: Direct sync optimization")
print("   - L1 Loss: Better than MSE for blendshapes")
print("   - Energy-aware: Focus on active frames")
print("   - Perceptual Loss: Feature matching")
print("\n Loss Weights:")
print(f"   Reconstruction: {_loss_defaults.reconstruction_weight}")
print(f"   Correlation: {_loss_defaults.correlation_weight} (NOVO - kljuƒçno!)")
print(f"   Velocity: {_loss_defaults.velocity_weight}")
print(f"   Perceptual: {_loss_defaults.perceptual_weight}")
print(f"   Smoothness: {_loss_defaults.smoothness_weight}")
del _loss_defaults





In [None]:


"""
🏆 CHAMPIONSHIP LIP-SYNC SYSTEM - ĆELIJA 5
Championship Trainer sa:
- Gradient Accumulation (batch_size=8 efektivno)
- Mixed Precision Training
- Cosine Annealing sa Warmup
- GRU Hidden State Reset
"""

class ChampionshipTrainer:
    """
    Adversarial trainer for the lip-sync generator and its discriminator.

    Provides gradient accumulation, mixed precision (enabled only when CUDA is
    in use), OneCycleLR schedules with 10% warm-up and cosine annealing,
    correlation-based checkpointing and early stopping.

    Args:
        model: Generator; called as ``model(audio, phoneme, mask, hidden)`` and
            expected to return ``(pred, vq_loss, _)``.
        discriminator: Called as ``discriminator(x)`` and expected to return
            ``(scores, features)``, both lists.
        train_loader: Iterable of training batches (dicts of tensors).
        val_loader: Iterable of validation batches.
        device: Device string used for the modules and batches.
        learning_rate: Peak learning rate for both optimizers.
        gradient_accumulation_steps: Micro-batches per optimizer step.
        use_wandb: Log metrics to Weights & Biases when True.
        num_epochs: Number of epochs the LR schedules are sized for.
    """

    def __init__(
        self,
        model,
        discriminator,
        train_loader,
        val_loader,
        device='cuda',
        learning_rate=1e-4,
        gradient_accumulation_steps=8,  # batch_size=1 * 8 = effective batch of 8
        use_wandb=False,
        num_epochs=200
    ):
        self.model = model.to(device)
        self.discriminator = discriminator.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.gradient_accumulation_steps = max(1, int(gradient_accumulation_steps))
        self.use_wandb = use_wandb

        # CUDA autocast / GradScaler are pointless (and noisy) on CPU.
        self.use_amp = torch.cuda.is_available() and str(device).startswith('cuda')

        # Loss functions
        self.criterion = ChampionshipLoss()
        self.adversarial_loss = nn.MSELoss()  # Least-squares GAN

        # Optimizers
        self.optimizer_g = optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            betas=(0.5, 0.999),
            weight_decay=1e-4
        )
        self.optimizer_d = optim.AdamW(
            discriminator.parameters(),
            lr=learning_rate,
            betas=(0.5, 0.999),
            weight_decay=1e-4
        )

        # Schedule lengths. The generator sees every batch, the discriminator
        # every second one; each steps once per accumulation window.
        accum = self.gradient_accumulation_steps
        batches_per_epoch = len(train_loader)
        d_calls_per_epoch = (batches_per_epoch + 1) // 2
        self.total_steps_g = max(1, batches_per_epoch * num_epochs // accum)
        self.total_steps_d = max(1, d_calls_per_epoch * num_epochs // accum)

        self.scheduler_g = OneCycleLR(
            self.optimizer_g,
            max_lr=learning_rate,
            total_steps=self.total_steps_g,
            pct_start=0.1,
            anneal_strategy='cos'
        )
        self.scheduler_d = OneCycleLR(
            self.optimizer_d,
            max_lr=learning_rate,
            total_steps=self.total_steps_d,
            pct_start=0.1,
            anneal_strategy='cos'
        )

        # One scaler per optimizer: the two step on different cadences, and a
        # shared scaler would change its scale in the middle of the other
        # optimizer's accumulation window.
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)
        self.scaler_d = torch.cuda.amp.GradScaler(enabled=self.use_amp)

        # Micro-batch counters that drive the accumulation windows.
        self._g_micro_steps = 0
        self._d_micro_steps = 0

        self.best_val_loss = float('inf')
        self.best_val_corr = 0.0
        self.patience = 40
        self.patience_counter = 0

    @staticmethod
    def _advance_scheduler(scheduler, total_steps):
        """Step ``scheduler`` unless its schedule is already exhausted.

        OneCycleLR raises once stepped past ``total_steps``; holding the final
        learning rate lets training run longer than the schedule was sized for.
        """
        if scheduler.last_epoch < total_steps - 1:
            scheduler.step()

    def reset_gru_hidden(self, batch_size):
        """Return the initial GRU hidden state for a batch.

        ``None`` is passed on to the model; presumably the model then creates
        a fresh hidden state itself - confirm against the model definition.
        """
        return None

    def train_discriminator(self, batch, accumulation_step):
        """
        Run one discriminator micro-step (least-squares GAN objective).

        Args:
            batch: Dict with 'audio_features', 'phoneme_indices',
                'blendshapes' and 'mask' tensors.
            accumulation_step: Ignored; kept for backward compatibility. The
                trainer counts discriminator calls itself, because this method
                is invoked only on every second batch and the caller's batch
                index therefore never completes an accumulation window when
                the window length is even.

        Returns:
            The un-normalised discriminator loss as a float.
        """
        audio = batch['audio_features'].to(self.device)
        phoneme = batch['phoneme_indices'].to(self.device)
        target = batch['blendshapes'].to(self.device)
        mask = batch['mask'].to(self.device)
        accum = self.gradient_accumulation_steps

        with torch.cuda.amp.autocast(enabled=self.use_amp):
            # Fake samples come from the generator without building its graph.
            with torch.no_grad():
                hidden = self.reset_gru_hidden(audio.shape[0])
                fake, _, _ = self.model(audio, phoneme, mask, hidden)

            real_scores, _ = self.discriminator(target)
            fake_scores, _ = self.discriminator(fake.detach())

            # Least-squares GAN: real -> 1, fake -> 0, averaged over scales.
            d_loss_real = sum([F.mse_loss(s, torch.ones_like(s)) for s in real_scores]) / len(real_scores)
            d_loss_fake = sum([F.mse_loss(s, torch.zeros_like(s)) for s in fake_scores]) / len(fake_scores)

            d_loss = (d_loss_real + d_loss_fake) / (2 * accum)

        self.scaler_d.scale(d_loss).backward()

        self._d_micro_steps += 1
        if self._d_micro_steps % accum == 0:
            self.scaler_d.step(self.optimizer_d)
            self.scaler_d.update()
            self.optimizer_d.zero_grad()
            self._advance_scheduler(self.scheduler_d, self.total_steps_d)

        return d_loss.item() * accum

    def train_generator(self, batch, accumulation_step, epoch):
        """
        Run one generator micro-step with all loss terms.

        Args:
            batch: Dict with 'audio_features', 'phoneme_indices',
                'blendshapes', 'energy' and 'mask' tensors.
            accumulation_step: Ignored; kept for backward compatibility (the
                trainer counts generator micro-steps itself).
            epoch: Zero-based epoch index; selects the adversarial weight.

        Returns:
            Dict of float metrics: 'g_loss', 'reconstruction', 'correlation',
            'velocity', 'perceptual', 'vq_loss'.
        """
        audio = batch['audio_features'].to(self.device)
        phoneme = batch['phoneme_indices'].to(self.device)
        target = batch['blendshapes'].to(self.device)
        energy = batch['energy'].to(self.device)
        mask = batch['mask'].to(self.device)
        accum = self.gradient_accumulation_steps

        # Freeze the discriminator while the generator loss is backpropagated
        # through it; otherwise that backward pass would add gradients to the
        # discriminator's own accumulation buffers.
        frozen = [p for p in self.discriminator.parameters() if p.requires_grad]
        for p in frozen:
            p.requires_grad_(False)

        try:
            with torch.cuda.amp.autocast(enabled=self.use_amp):
                hidden = self.reset_gru_hidden(audio.shape[0])
                pred, vq_loss, _ = self.model(audio, phoneme, mask, hidden)

                # Discriminator features for the perceptual loss.
                with torch.no_grad():
                    _, target_features = self.discriminator(target)
                pred_scores, pred_features = self.discriminator(pred)

                losses = self.criterion(pred, target, pred_features, target_features, energy, mask)

                # Generator wants the discriminator to score its output as real.
                adv_loss = sum([F.mse_loss(s, torch.ones_like(s)) for s in pred_scores]) / len(pred_scores)

                # Adversarial weight: 0.1 for the first 30 epochs, 0.2 afterwards.
                adv_weight = 0.1 if epoch < 30 else 0.2

                g_loss = (
                    losses['total'] +
                    adv_weight * adv_loss +
                    1.0 * vq_loss
                ) / accum

            self.scaler.scale(g_loss).backward()
        finally:
            for p in frozen:
                p.requires_grad_(True)

        self._g_micro_steps += 1
        if self._g_micro_steps % accum == 0:
            # Unscale before clipping so the threshold applies to true gradients.
            self.scaler.unscale_(self.optimizer_g)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer_g)
            self.scaler.update()
            self.optimizer_g.zero_grad()
            self._advance_scheduler(self.scheduler_g, self.total_steps_g)

        return {
            'g_loss': g_loss.item() * accum,
            'reconstruction': losses['reconstruction'].item(),
            'correlation': losses['correlation'].item(),
            'velocity': losses['velocity'].item(),
            'perceptual': losses['perceptual'].item(),
            'vq_loss': vq_loss.item()
        }

    @torch.no_grad()
    def validate(self):
        """
        Evaluate on the validation loader.

        Returns:
            ``(mean_total_loss, mean_pearson_correlation)``. The correlation
            is computed per batch over all valid frames and channels; batches
            with 10 or fewer valid frames or a NaN correlation are skipped,
            and 0.0 is returned if none qualify.
        """
        self.model.eval()
        val_losses = []
        val_corrs = []

        for batch in self.val_loader:
            audio = batch['audio_features'].to(self.device)
            phoneme = batch['phoneme_indices'].to(self.device)
            target = batch['blendshapes'].to(self.device)
            energy = batch['energy'].to(self.device)
            mask = batch['mask'].to(self.device)

            with torch.cuda.amp.autocast(enabled=self.use_amp):
                hidden = self.reset_gru_hidden(audio.shape[0])
                pred, vq_loss, _ = self.model(audio, phoneme, mask, hidden)

                _, target_features = self.discriminator(target)
                _, pred_features = self.discriminator(pred)

                losses = self.criterion(pred, target, pred_features, target_features, energy, mask)

            val_losses.append(losses['total'].item())

            # Boolean indexing keeps valid frames only; float32 avoids handing
            # half-precision arrays to SciPy.
            frame_mask = mask.bool()
            pred_np = pred[frame_mask].float().cpu().numpy()
            target_np = target[frame_mask].float().cpu().numpy()

            if len(pred_np) > 10:
                corr, _ = pearsonr(pred_np.flatten(), target_np.flatten())
                if not np.isnan(corr):
                    val_corrs.append(corr)

        self.model.train()
        mean_loss = np.mean(val_losses) if val_losses else float('nan')
        return mean_loss, np.mean(val_corrs) if val_corrs else 0.0

    def train(self, num_epochs, save_dir='checkpoints'):
        """
        Main training loop.

        Trains the discriminator on every second batch and the generator on
        every batch, validates after each epoch, saves ``best_model.pt``
        whenever validation correlation improves, writes a periodic checkpoint
        every 10 epochs, and stops early after ``self.patience`` epochs
        without improvement.

        Args:
            num_epochs: Maximum number of epochs to run.
            save_dir: Directory for checkpoints (created if missing).
        """
        os.makedirs(save_dir, exist_ok=True)

        if self.use_wandb:
            import wandb
            wandb.init(project='championship-lipsync')

        # Start with empty gradient buffers and fresh accumulation windows.
        self.optimizer_g.zero_grad()
        self.optimizer_d.zero_grad()
        self._g_micro_steps = 0
        self._d_micro_steps = 0

        for epoch in range(num_epochs):
            self.model.train()
            epoch_metrics = {
                'g_loss': [], 'reconstruction': [], 'correlation': [],
                'velocity': [], 'perceptual': [], 'vq_loss': [], 'd_loss': []
            }

            pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

            for batch_idx, batch in enumerate(pbar):
                # Discriminator is trained on every second batch.
                if batch_idx % 2 == 0:
                    d_loss = self.train_discriminator(batch, batch_idx)
                    epoch_metrics['d_loss'].append(d_loss)

                g_metrics = self.train_generator(batch, batch_idx, epoch)
                for k, v in g_metrics.items():
                    epoch_metrics[k].append(v)

                pbar.set_postfix({
                    'G': f"{np.mean(epoch_metrics['g_loss']):.3f}",
                    'D': f"{np.mean(epoch_metrics['d_loss']) if epoch_metrics['d_loss'] else 0:.3f}",
                    'VQ': f"{np.mean(epoch_metrics['vq_loss']):.4f}",
                    'Corr': f"{np.mean(epoch_metrics['correlation']):.3f}",
                    'Rec': f"{np.mean(epoch_metrics['reconstruction']):.3f}",
                    'Best': f"{self.best_val_corr:.3f}"
                })

                # Clear cache periodically
                if batch_idx % 10 == 0:
                    torch.cuda.empty_cache()

            # Validation
            val_loss, val_corr = self.validate()
            print(f"\nEpoch {epoch+1} - Val Loss: {val_loss:.4f}, Val Correlation: {val_corr:.4f}")

            if self.use_wandb:
                wandb.log({
                    'epoch': epoch,
                    'train/g_loss': np.mean(epoch_metrics['g_loss']),
                    'train/correlation': np.mean(epoch_metrics['correlation']),
                    'train/reconstruction': np.mean(epoch_metrics['reconstruction']),
                    'val/loss': val_loss,
                    'val/correlation': val_corr
                })

            # The best model is selected by validation correlation, not loss.
            if val_corr > self.best_val_corr:
                self.best_val_corr = val_corr
                self.best_val_loss = val_loss
                self.patience_counter = 0

                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'discriminator_state_dict': self.discriminator.state_dict(),
                    'optimizer_g_state_dict': self.optimizer_g.state_dict(),
                    'optimizer_d_state_dict': self.optimizer_d.state_dict(),
                    'val_loss': val_loss,
                    'val_correlation': val_corr
                }
                torch.save(checkpoint, os.path.join(save_dir, 'best_model.pt'))
                print(f"‚úÖ Saved best model (corr: {val_corr:.4f})")
            else:
                self.patience_counter += 1

            # Early stopping
            if self.patience_counter >= self.patience:
                print(f"‚ö†Ô∏è Early stopping at epoch {epoch+1}")
                break

            # Periodic checkpoint
            if (epoch + 1) % 10 == 0:
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'val_correlation': val_corr
                }
                torch.save(checkpoint, os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pt'))

            torch.cuda.empty_cache()

        if self.use_wandb:
            wandb.finish()

        print(f"\nüèÜ Training complete!")
        print(f"Best validation correlation: {self.best_val_corr:.4f}")
        print(f"Best validation loss: {self.best_val_loss:.4f}")


# Cell banner: confirms the trainer cell executed and lists its features.
print("\n".join([
    "Championship Trainer loaded!",
    "Key features:",
    "   - Gradient Accumulation: Effective batch_size=8",
    "   - Mixed Precision: Faster training",
    "   - Correlation-based checkpointing",
    "   - GRU hidden state management",
]))






In [None]:


"""
🏆 CHAMPIONSHIP LIP-SYNC SYSTEM - ĆELIJA 6
Main Training Script
"""

def train_championship_model(
    base_path='/content/drive/MyDrive/AI-speak',
    save_dir='checkpoints_championship',
    batch_size=1,  # must be 1 because sequences have different lengths
    gradient_accumulation_steps=8,  # effective batch = batch_size * this
    learning_rate=1e-4,
    num_epochs=200,
    d_model=512,  # lower to 384 if VRAM is limited
    num_encoder_layers=4,  # lower to 3 to save VRAM
    codebook_size=512,
    use_wandb=False,
    device='cuda'
):
    """
    Build the data loaders, model, discriminator and trainer, then train.

    Args:
        base_path: Dataset root passed to ``create_dataloaders``.
        save_dir: Directory for checkpoints.
        batch_size: Loader batch size.
        gradient_accumulation_steps: Micro-batches per optimizer step.
        learning_rate: Peak learning rate for both optimizers.
        num_epochs: Maximum number of training epochs.
        d_model: Generator model width.
        num_encoder_layers: Number of generator encoder layers.
        codebook_size: VQ codebook size.
        use_wandb: Log to Weights & Biases when True.
        device: Device string used for training.

    Returns:
        The ``ChampionshipTrainer`` instance, giving access to the trained
        model and the best validation metrics.
    """

    print("="*70)
    print("üèÜ CHAMPIONSHIP LIP-SYNC TRAINING")
    print("="*70)
    print(f"Device: {device}")
    print(f"Effective batch size: {batch_size * gradient_accumulation_steps}")
    print(f"Model: d_model={d_model}, layers={num_encoder_layers}")
    print(f"Codebook size: {codebook_size}")
    print(f"Learning rate: {learning_rate}")
    print("="*70 + "\n")

    # Clear memory
    torch.cuda.empty_cache()
    gc.collect()

    # ========== LOAD DATA ==========
    print("üìÇ Loading datasets...")
    train_loader, val_loader = create_dataloaders(
        base_path,
        batch_size=batch_size,
        num_workers=0,
        device=device
    )
    print(f"‚úÖ Train: {len(train_loader)} batches")
    print(f"‚úÖ Val: {len(val_loader)} batches\n")

    # ========== BUILD MODEL ==========
    print("üèóÔ∏è Building Championship model...")
    model = ChampionshipLipSyncModel(
        audio_dim=1024,
        phoneme_embedding_dim=64,
        d_model=d_model,
        num_encoder_layers=num_encoder_layers,
        num_heads=8,
        dim_feedforward=2048,
        num_blendshapes=52,
        codebook_size=codebook_size,
        gru_hidden=256,
        dropout=0.1
    )

    discriminator = MultiScaleDiscriminator(num_blendshapes=52)

    # Model summary
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"‚úÖ Generator params: {total_params:,} (trainable: {trainable_params:,})")

    disc_params = sum(p.numel() for p in discriminator.parameters())
    print(f"‚úÖ Discriminator params: {disc_params:,}\n")

    # ========== KEY IMPROVEMENTS ==========
    print("="*70)
    print("üîß KEY IMPROVEMENTS (Research-based):")
    print("="*70)
    print("‚úÖ VQ-VAE: Discrete motion codebook (CodeTalker)")
    print("‚úÖ Autoregressive GRU: Temporal memory (FaceFormer)")
    print("‚úÖ Phoneme Embedding: Learned phoneme space")
    print("‚úÖ Correlation Loss: Direct sync optimization")
    print("‚úÖ L1 Loss: Better than MSE for blendshapes")
    print("‚úÖ Energy-aware: Focus on active frames")
    print("‚úÖ Multi-scale Discriminator: Fine + coarse motion")
    print(f"‚úÖ Gradient Accumulation: Effective batch_size={batch_size * gradient_accumulation_steps}")
    print("="*70 + "\n")

    # ========== INITIALIZE TRAINER ==========
    print("‚öôÔ∏è Initializing trainer...")
    trainer = ChampionshipTrainer(
        model=model,
        discriminator=discriminator,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        learning_rate=learning_rate,
        gradient_accumulation_steps=gradient_accumulation_steps,
        use_wandb=use_wandb
    )
    print("‚úÖ Trainer ready!\n")

    # ========== LOSS WEIGHTS ==========
    # Read from the trainer's criterion so the report matches what is used.
    # The VQ and adversarial weights are constants in
    # ChampionshipTrainer.train_generator.
    criterion = trainer.criterion
    print("üìä Loss Weights:")
    print(f"   Reconstruction (L1): {criterion.reconstruction_weight}")
    print(f"   Correlation: {criterion.correlation_weight} ‚≠ê (NOVO - kljuƒçno!)")
    print(f"   Velocity: {criterion.velocity_weight}")
    print(f"   Perceptual: {criterion.perceptual_weight}")
    print(f"   Smoothness: {criterion.smoothness_weight}")
    print("   VQ-VAE: 1.0")
    print("   Adversarial: 0.1 (0.2 from epoch 30)\n")

    # ========== START TRAINING ==========
    print("="*70)
    print("üöÄ STARTING TRAINING")
    print("="*70 + "\n")

    try:
        trainer.train(num_epochs=num_epochs, save_dir=save_dir)
    except KeyboardInterrupt:
        print("\n‚ö†Ô∏è Training interrupted by user")
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'note': 'interrupted'
        }
        # The interrupt may arrive before the trainer created the directory.
        os.makedirs(save_dir, exist_ok=True)
        torch.save(checkpoint, os.path.join(save_dir, 'interrupted.pt'))
        print("üíæ Saved interrupted checkpoint")

    print("\n" + "="*70)
    print("üéâ TRAINING COMPLETE!")
    print("="*70)
    print(f"Best correlation: {trainer.best_val_corr:.4f}")
    print(f"Best loss: {trainer.best_val_loss:.4f}")

    return trainer


# ==============================================================================
# POKRETANJE
# ==============================================================================

if __name__ == "__main__":
    # Mount Google Drive when running on Colab. Only a missing google.colab
    # package means "not on Colab"; a failing mount must surface as an error
    # instead of being reported as "Not on Colab".
    try:
        from google.colab import drive
    except ImportError:
        print("Not on Colab, skipping drive mount")
    else:
        drive.mount('/content/drive')

    # Run training
    train_championship_model(
        base_path='/content/drive/MyDrive/AI-speak',  # CHANGE THIS to your dataset root
        save_dir='checkpoints_championship',
        batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=1e-4,
        num_epochs=200,
        d_model=512,  # lower to 384 to save VRAM
        num_encoder_layers=4,  # lower to 3 to save VRAM
        codebook_size=512,
        use_wandb=False,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    )


print("‚úÖ Training script ready!")
print("üìå Pokrenite: train_championship_model()")
print("üìå Ili direktno pokrenite ovu ƒáeliju")






In [None]:

"""
🏆 CHAMPIONSHIP LIP-SYNC SYSTEM - ĆELIJA 7
Production-Ready Inference System
"""

class MinimalPostProcessor:
    """
    Light post-processing that only enforces hard constraints on predicted
    blendshapes: eye-blink symmetry, a per-frame velocity limit for jaw,
    mouth and eye channels, and clipping to [0, 1].

    NOTE(review): ``constraints['jaw_mutual_exclusive']`` and
    ``constraints['max_velocity']['default']`` are configured but not applied
    by any method of this class.
    """

    def __init__(self, fps=100):
        self.fps = fps

        self.constraints = {
            'eye_symmetry': {
                'pairs': [(13, 14)],
                'max_diff': 0.3
            },
            'jaw_mutual_exclusive': {
                'indices': [24, 36],
                'max_sum': 1.2
            },
            'max_velocity': {
                'jaw': 0.25,
                'mouth': 0.20,
                'eyes': 0.30,
                'default': 0.18
            }
        }

    def enforce_symmetry(self, blendshapes):
        """
        Average paired channels on frames where they differ too much.

        Args:
            blendshapes: Array of shape (frames, channels).

        Returns:
            A copy in which, for each configured pair, frames whose absolute
            difference exceeds ``max_diff`` hold the pair's mean in both
            channels.
        """
        processed = blendshapes.copy()
        max_diff = self.constraints['eye_symmetry']['max_diff']

        for left_idx, right_idx in self.constraints['eye_symmetry']['pairs']:
            diff = np.abs(processed[:, left_idx] - processed[:, right_idx])
            excessive = diff > max_diff

            if excessive.any():
                avg = (processed[excessive, left_idx] + processed[excessive, right_idx]) / 2
                processed[excessive, left_idx] = avg
                processed[excessive, right_idx] = avg

        return processed

    def clamp_extreme_velocity(self, blendshapes):
        """
        Limit per-frame change of jaw, mouth and eye channels.

        Each frame is compared with the already-limited previous frame, so
        the limit holds for every consecutive pair in the output. (Using
        velocities computed before clamping would merely push an oversized
        jump to the following frame.) Frames within the limit keep their
        exact values; an empty input is returned unchanged.

        Args:
            blendshapes: Array of shape (frames, channels).

        Returns:
            A rate-limited copy of ``blendshapes``.
        """
        processed = blendshapes.copy()
        limits = self.constraints['max_velocity']

        groups = (
            ([24, 25, 26, 27], limits['jaw']),
            (list(range(28, 52)), limits['mouth']),
            (list(range(5, 19)), limits['eyes']),
        )

        for indices, max_vel in groups:
            # Fancy indexing yields a copy, written back after limiting.
            cols = processed[:, indices]
            for t in range(1, len(cols)):
                delta = cols[t] - cols[t - 1]
                over = np.abs(delta) > max_vel
                if over.any():
                    cols[t, over] = cols[t - 1, over] + np.sign(delta[over]) * max_vel
            processed[:, indices] = cols

        return processed

    def process(self, blendshapes):
        """Apply symmetry, then the velocity limit, then clip to [0, 1]."""
        processed = blendshapes.copy()
        processed = self.enforce_symmetry(processed)
        processed = self.clamp_extreme_velocity(processed)
        return np.clip(processed, 0, 1)


class ChampionshipInference:
    """
    Inference wrapper: audio file -> Wav2Vec2 features -> blendshape frames.

    Args:
        model_path: Path to a checkpoint containing 'model_state_dict'.
        device: Device string for the model and feature extractor.
        d_model: Generator width; must match the trained checkpoint.
        num_encoder_layers: Encoder depth; must match the checkpoint.
        codebook_size: VQ codebook size; must match the checkpoint.
    """

    def __init__(self, model_path, device='cuda', d_model=512,
                 num_encoder_layers=4, codebook_size=512):
        self.device = device
        self.fps = 100
        # CUDA autocast is only meaningful when running on a CUDA device.
        self.use_amp = torch.cuda.is_available() and str(device).startswith('cuda')

        print("üîÑ Loading Championship model...")
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)

        # Load model
        self.model = ChampionshipLipSyncModel(
            audio_dim=1024,
            phoneme_embedding_dim=64,
            d_model=d_model,
            num_encoder_layers=num_encoder_layers,
            num_heads=8,
            dim_feedforward=2048,
            num_blendshapes=52,
            codebook_size=codebook_size,
            gru_hidden=256,
            dropout=0.1
        ).to(device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Audio processor
        print("üîÑ Loading Wav2Vec2...")
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xls-r-300m")
        self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-300m").to(device)
        self.wav2vec.eval()

        for param in self.wav2vec.parameters():
            param.requires_grad = False

        # Processors
        self.post_processor = MinimalPostProcessor(fps=self.fps)
        self.phoneme_processor = PhonemeProcessor(fps=self.fps)

        # Performance tracking
        self.latency_history = []

        print("‚úÖ Inference system ready!")
        print(f"üìä Model epoch: {checkpoint.get('epoch', 'unknown')}")
        # Some checkpoints (e.g. the interrupted one) carry no validation
        # metric, so the value is formatted defensively.
        print(f"üìä Val correlation: {self._format_metric(checkpoint.get('val_correlation', 'unknown'))}")

    @staticmethod
    def _format_metric(value):
        """Format a numeric metric with 4 decimals; fall back to ``str``."""
        try:
            return f"{float(value):.4f}"
        except (TypeError, ValueError):
            return str(value)

    @staticmethod
    def _to_mono(waveform):
        """Collapse a (channels, samples) waveform to 1-D by channel mean."""
        if waveform.dim() == 2:
            return waveform.mean(dim=0)
        return waveform.reshape(-1)

    def extract_audio_features(self, audio_path, sr=16000):
        """
        Load an audio file and return Wav2Vec2 features.

        The waveform is resampled to ``sr`` if needed, down-mixed to mono,
        peak-normalised, and encoded; the mean of the last four hidden states
        is returned.

        Args:
            audio_path: Path to the audio file.
            sr: Sampling rate fed to the feature extractor.

        Returns:
            NumPy array of per-frame features.
        """
        waveform, original_sr = torchaudio.load(audio_path)

        if original_sr != sr:
            resampler = T.Resample(original_sr, sr)
            waveform = resampler(waveform)

        # Multi-channel input would otherwise be treated as a batch.
        waveform = self._to_mono(waveform)
        waveform = waveform / (waveform.abs().max() + 1e-8)

        with torch.no_grad():
            inputs = self.feature_extractor(
                waveform.cpu().numpy(),
                sampling_rate=sr,
                return_tensors="pt"
            )

            outputs = self.wav2vec(
                inputs.input_values.to(self.device),
                output_hidden_states=True
            )

            hidden_states = outputs.hidden_states[-4:]
            features = torch.stack(hidden_states).mean(0).squeeze(0)

        return features.cpu().numpy()

    @torch.no_grad()
    def predict(
        self,
        audio_path: str,
        phoneme_alignment_path: Optional[str] = None,
        target_fps: int = 100,
        apply_post_processing: bool = True
    ) -> Dict[str, np.ndarray]:
        """
        Generate blendshape coefficients from audio.

        Args:
            audio_path: Path to audio file
            phoneme_alignment_path: Optional phoneme alignment file; when it
                is missing, an all-zero phoneme sequence is used
            target_fps: Target frame rate
            apply_post_processing: Apply minimal post-processing

        Returns:
            Dictionary with 'blendshapes', 'timestamps', 'latency', 'fps'
            and 'vq_loss'.
        """
        start_time = time.time()

        # Extract audio features
        audio_features = self.extract_audio_features(audio_path)

        # Duration is derived assuming 50 feature frames per second.
        audio_duration = audio_features.shape[0] / 50
        num_frames = int(audio_duration * target_fps)

        # Linear interpolation of the features onto the target frame grid.
        old_indices = np.linspace(0, audio_features.shape[0] - 1, audio_features.shape[0])
        new_indices = np.linspace(0, audio_features.shape[0] - 1, num_frames)
        interpolator = interp1d(old_indices, audio_features, axis=0, kind='linear')
        audio_features_aligned = interpolator(new_indices)

        # Phoneme sequence
        if phoneme_alignment_path and os.path.exists(phoneme_alignment_path):
            alignments = self.phoneme_processor.parse_alignment_file(phoneme_alignment_path)
            phoneme_indices = self.phoneme_processor.create_phoneme_indices(alignments, num_frames)
        else:
            phoneme_indices = np.zeros(num_frames, dtype=np.int64)

        # Prepare inputs
        audio_tensor = torch.from_numpy(audio_features_aligned).float().unsqueeze(0).to(self.device)
        phoneme_tensor = torch.from_numpy(phoneme_indices).long().unsqueeze(0).to(self.device)
        mask = torch.ones(1, num_frames, dtype=torch.bool).to(self.device)

        # Generate blendshapes
        with torch.cuda.amp.autocast(enabled=self.use_amp):
            blendshapes, vq_loss, _ = self.model(audio_tensor, phoneme_tensor, mask)

        # Cast to float32 so post-processing never runs on half precision.
        blendshapes_np = blendshapes.squeeze(0).float().cpu().numpy()

        # Post-processing
        if apply_post_processing:
            blendshapes_np = self.post_processor.process(blendshapes_np)

        # Timestamps
        timestamps = np.arange(num_frames) / target_fps

        # Latency
        latency = time.time() - start_time
        self.latency_history.append(latency)

        return {
            'blendshapes': blendshapes_np,
            'timestamps': timestamps,
            'latency': latency,
            'fps': target_fps,
            'vq_loss': vq_loss.item()
        }

    def export_to_csv(self, blendshapes, output_path):
        """Write blendshape frames to CSV, one named column per channel."""
        blendshape_names = [
            'browInnerUp', 'browDownLeft', 'browDownRight', 'browOuterUpLeft', 'browOuterUpRight',
            'eyeLookUpLeft', 'eyeLookUpRight', 'eyeLookDownLeft', 'eyeLookDownRight',
            'eyeLookInLeft', 'eyeLookInRight', 'eyeLookOutLeft', 'eyeLookOutRight',
            'eyeBlinkLeft', 'eyeBlinkRight', 'eyeSquintLeft', 'eyeSquintRight',
            'eyeWideLeft', 'eyeWideRight', 'cheekPuff', 'cheekSquintLeft', 'cheekSquintRight',
            'noseSneerLeft', 'noseSneerRight', 'jawOpen', 'jawForward', 'jawLeft', 'jawRight',
            'mouthFunnel', 'mouthPucker', 'mouthLeft', 'mouthRight', 'mouthRollUpper',
            'mouthRollLower', 'mouthShrugUpper', 'mouthShrugLower', 'mouthClose',
            'mouthSmileLeft', 'mouthSmileRight', 'mouthFrownLeft', 'mouthFrownRight',
            'mouthDimpleLeft', 'mouthDimpleRight', 'mouthUpperUpLeft', 'mouthUpperUpRight',
            'mouthLowerDownLeft', 'mouthLowerDownRight', 'mouthPressLeft', 'mouthPressRight',
            'mouthStretchLeft', 'mouthStretchRight', 'tongueOut'
        ]

        df = pd.DataFrame(blendshapes, columns=blendshape_names)
        df.to_csv(output_path, index=False)
        print(f"‚úÖ Exported {len(df)} frames to {output_path}")

    def get_performance_stats(self):
        """
        Summarise the latencies recorded by ``predict``.

        Returns:
            ``None`` if no prediction has run; otherwise a dict with mean,
            median and 95th-percentile latency, the number of inferences and
            'fps_realtime' (``1 / mean_latency``, i.e. predict calls per
            second).
        """
        if not self.latency_history:
            return None

        return {
            'mean_latency': np.mean(self.latency_history),
            'median_latency': np.median(self.latency_history),
            'p95_latency': np.percentile(self.latency_history, 95),
            'total_inferences': len(self.latency_history),
            'fps_realtime': 1.0 / np.mean(self.latency_history)
        }


# Cell banner: confirms the inference cell executed and shows a usage example.
print("\n".join([
    "‚úÖ Championship Inference system loaded!",
    "üìå Usage:",
    "   inference = ChampionshipInference('checkpoints_championship/best_model.pt')",
    "   result = inference.predict('audio.wav', 'phonemes.txt')",
]))





In [None]:


"""
🏆 CHAMPIONSHIP LIP-SYNC SYSTEM - ĆELIJA 8
Comprehensive Validation & Visualization
"""

class ChampionshipValidator:
    """
    Comprehensive validation system
    """

    def __init__(self, model_path, base_path, device='cuda'):
        """Load the inference system and collect the validation file ids.

        Args:
            model_path: Checkpoint path handed to ``ChampionshipInference``.
            base_path: Dataset root; ``_get_validation_files`` lists the
                ``spk08_ser`` folder underneath it.
            device: Device identifier handed to ``ChampionshipInference``.
        """
        self.model_path, self.base_path, self.device = model_path, base_path, device

        # Inference wrapper used for every prediction made by this validator.
        self.inference = ChampionshipInference(model_path, device)

        # File ids that the report will sample from.
        held_out = self._get_validation_files()
        self.val_files = held_out
        print(f"‚úÖ Found {len(held_out)} validation files")

    def _get_validation_files(self):
        """Get validation file IDs"""
        path = os.path.join(self.base_path, 'spk08_ser')
        all_files = sorted([f for f in os.listdir(path) if f.endswith('.npy')])
        file_ids = [f.replace('.npy', '').replace('08_ser_spk08_', '') for f in all_files]
        return file_ids[int(len(file_ids) * 0.8):]

    def load_ground_truth(self, file_id):
        """Load ground truth blendshapes"""
        path = os.path.join(self.base_path, 'spk08_ser', f'08_ser_spk08_{file_id}.npy')
        return np.load(path).astype(np.float32)

    def predict_single_file(self, file_id):
        """Predict blendshapes for single file"""
        audio = os.path.join(self.base_path, 'ser/audio', f'spk08_{file_id}.wav')
        phoneme = os.path.join(self.base_path, 'labels 08 srp', f'spk08_{file_id}.txt')

        result = self.inference.predict(
            audio,
            phoneme_alignment_path=phoneme if os.path.exists(phoneme) else None
        )
        return result['blendshapes']

    def compute_metrics(self, pred, gt):
        """Compute comprehensive metrics - UPDATED with Jaw and Velocity metrics"""
        min_len = min(len(pred), len(gt))
        pred, gt = pred[:min_len], gt[:min_len]

        # 1. Basic metrics (Overall)
        mse = np.mean((pred - gt) ** 2)
        mae = np.mean(np.abs(pred - gt))
        overall_corr, _ = pearsonr(pred.flatten(), gt.flatten())

        # 2. Per-blendshape correlation (Mean of all 52)
        per_bs_corr = []
        for i in range(52):
            if np.std(gt[:, i]) > 1e-6 and np.std(pred[:, i]) > 1e-6:
                try:
                    c, _ = pearsonr(pred[:, i], gt[:, i])
                    if not np.isnan(c): per_bs_corr.append(c)
                except: pass

        # 3. Mouth-specific (28-51)
        mouth_indices = list(range(28, 52))
        mouth_mse = np.mean((pred[:, mouth_indices] - gt[:, mouth_indices]) ** 2)
        mouth_mae = np.mean(np.abs(pred[:, mouth_indices] - gt[:, mouth_indices]))
        mouth_corr, _ = pearsonr(pred[:, mouth_indices].flatten(), gt[:, mouth_indices].flatten())

        # 4. Jaw-specific (24 - jawOpen) - KRITIƒåNO ZA LIP-SYNC
        jaw_idx = 24
        jaw_mae = np.mean(np.abs(pred[:, jaw_idx] - gt[:, jaw_idx]))
        jaw_corr = 0
        if np.std(gt[:, jaw_idx]) > 1e-6 and np.std(pred[:, jaw_idx]) > 1e-6:
            jaw_corr, _ = pearsonr(pred[:, jaw_idx], gt[:, jaw_idx])

        # 5. Dynamics & Velocity
        pred_vel = np.diff(pred, axis=0)
        gt_vel = np.diff(gt, axis=0)
        velocity_mse = np.mean((pred_vel - gt_vel) ** 2)

        # Velocity Correlation (da li se usta otvaraju istom brzinom)
        vel_corr, _ = pearsonr(pred_vel.flatten(), gt_vel.flatten())

        # Jitter / Smoothness
        pred_jitter = np.mean(np.abs(np.diff(pred_vel, axis=0)))
        gt_jitter = np.mean(np.abs(np.diff(gt_vel, axis=0)))

        return {
            'mse': mse,
            'mae': mae,
            'correlation': overall_corr,
            'per_bs_corr_mean': np.mean(per_bs_corr) if per_bs_corr else 0,
            'mouth_mse': mouth_mse,
            'mouth_mae': mouth_mae,      # DODATO
            'mouth_correlation': mouth_corr,
            'jaw_mae': jaw_mae,          # DODATO
            'jaw_correlation': jaw_corr, # DODATO
            'velocity_mse': velocity_mse,
            'velocity_correlation': vel_corr, # DODATO
            'smoothness_ratio': pred_jitter / (gt_jitter + 1e-8),
            'pred_jitter': pred_jitter,
            'gt_jitter': gt_jitter
        }

    '''
    def compute_metrics(self, pred, gt):
        """Compute comprehensive metrics"""
        min_len = min(len(pred), len(gt))
        pred, gt = pred[:min_len], gt[:min_len]

        # Basic metrics
        mse = np.mean((pred - gt) ** 2)
        mae = np.mean(np.abs(pred - gt))

        # Correlation (overall)
        overall_corr, _ = pearsonr(pred.flatten(), gt.flatten())

        # Per-blendshape correlation
        per_bs_corr = []
        for i in range(52):
            if np.std(gt[:, i]) > 1e-6 and np.std(pred[:, i]) > 1e-6:
                try:
                    c, _ = pearsonr(pred[:, i], gt[:, i])
                    if not np.isnan(c):
                        per_bs_corr.append(c)
                except:
                    pass

        # Velocity metrics
        pred_vel = np.diff(pred, axis=0)
        gt_vel = np.diff(gt, axis=0)
        velocity_mse = np.mean((pred_vel - gt_vel) ** 2)

        # Jitter
        pred_jitter = np.mean(np.abs(np.diff(pred_vel, axis=0)))
        gt_jitter = np.mean(np.abs(np.diff(gt_vel, axis=0)))

        # Mouth-specific metrics
        mouth_indices = list(range(28, 52))
        mouth_mse = np.mean((pred[:, mouth_indices] - gt[:, mouth_indices]) ** 2)
        mouth_corr, _ = pearsonr(pred[:, mouth_indices].flatten(), gt[:, mouth_indices].flatten())

        return {
            'mse': mse,
            'mae': mae,
            'correlation': overall_corr,
            'per_bs_corr_mean': np.mean(per_bs_corr) if per_bs_corr else 0,
            'velocity_mse': velocity_mse,
            'smoothness_ratio': pred_jitter / (gt_jitter + 1e-8),
            'pred_jitter': pred_jitter,
            'gt_jitter': gt_jitter,
            'mouth_mse': mouth_mse,
            'mouth_correlation': mouth_corr
        }
    '''

    def visualize_prediction(self, file_id, save_path='validation_plots'):
        """Create detailed visualization - UPDATED with specific Jaw and Mouth metrics"""
        os.makedirs(save_path, exist_ok=True)

        print(f"üìä Visualizing: {file_id}")

        # Predict
        pred = self.predict_single_file(file_id)
        gt = self.load_ground_truth(file_id)

        # Align
        min_len = min(len(pred), len(gt))
        pred = pred[:min_len]
        gt = gt[:min_len]

        # Compute metrics
        metrics = self.compute_metrics(pred, gt)

        # Create figure
        fig = plt.figure(figsize=(22, 15))

        # 1. Key mouth movements (jawOpen je ovde najbitniji)
        ax1 = plt.subplot(3, 3, 1)
        mouth_indices = [24, 28, 36, 37, 38]
        mouth_names = ['jawOpen', 'mouthFunnel', 'mouthClose', 'mouthSmileL', 'mouthSmileR']

        for idx, name in zip(mouth_indices, mouth_names):
            ax1.plot(gt[:, idx], label=f'{name} (GT)', linestyle='--', alpha=0.7)
            ax1.plot(pred[:, idx], label=f'{name} (Pred)', linewidth=2)

        ax1.set_xlabel('Frame')
        ax1.set_ylabel('Value')
        ax1.set_title(f'Key Movements (Jaw Corr: {metrics["jaw_correlation"]:.3f})')
        ax1.legend(fontsize=7, ncol=2)
        ax1.grid(alpha=0.3)

        # 2. Error heatmap
        ax2 = plt.subplot(3, 3, 2)
        error = np.abs(pred - gt)
        im = ax2.imshow(error.T, aspect='auto', cmap='hot', interpolation='nearest')
        ax2.set_xlabel('Frame')
        ax2.set_ylabel('Blendshape Index')
        ax2.set_title('Absolute Error Heatmap')
        plt.colorbar(im, ax=ax2, label='Error')

        # 3. Per-blendshape correlation
        ax3 = plt.subplot(3, 3, 3)
        correlations = []
        for i in range(52):
            if np.std(gt[:, i]) > 1e-6 and np.std(pred[:, i]) > 1e-6:
                try:
                    c, _ = pearsonr(pred[:, i], gt[:, i])
                    correlations.append(c if not np.isnan(c) else 0)
                except: correlations.append(0)
            else: correlations.append(0)

        colors = ['green' if c > 0.8 else 'orange' if c > 0.6 else 'red' for c in correlations]
        # MARKER ZA VILICU (index 24) - ƒçinimo ga prepoznatljivim
        colors[24] = 'blue'
        ax3.bar(range(52), correlations, color=colors, alpha=0.7)
        ax3.axhline(y=0.7, color='blue', linestyle=':', label='Target Threshold')
        ax3.set_xlabel('Index (Blue bar is Jaw)')
        ax3.set_ylabel('Correlation')
        ax3.set_title('Per-BS Correlation')
        ax3.grid(alpha=0.3)

        # 4. Velocity comparison
        ax4 = plt.subplot(3, 3, 4)
        pred_vel = np.linalg.norm(np.diff(pred, axis=0), axis=1)
        gt_vel = np.linalg.norm(np.diff(gt, axis=0), axis=1)
        ax4.plot(gt_vel, label='GT Vel', color='blue', alpha=0.5)
        ax4.plot(pred_vel, label='Pred Vel', color='orange', alpha=0.8)
        ax4.set_title(f'Motion Dynamics (Vel Corr: {metrics["velocity_correlation"]:.3f})')
        ax4.legend()
        ax4.grid(alpha=0.3)

        # 5. Distribution comparison
        ax5 = plt.subplot(3, 3, 5)
        ax5.hist(gt.flatten(), bins=50, alpha=0.4, label='GT', density=True, color='blue')
        ax5.hist(pred.flatten(), bins=50, alpha=0.4, label='Pred', density=True, color='orange')
        ax5.set_title('Value Distribution (Target: Overlap at 0)')
        ax5.legend()

        # 6. Mouth-only comparison
        ax6 = plt.subplot(3, 3, 6)
        mouth_indices = list(range(28, 52))
        for i in range(0, 24, 6): # Prikazujemo svaku 6. rastegnutu usnu radi preglednosti
            ax6.plot(gt[:, mouth_indices[i]], alpha=0.3, color='blue', linestyle='--')
            ax6.plot(pred[:, mouth_indices[i]], alpha=0.8, color='orange')
        ax6.set_title(f'Mouth Detail (MAE: {metrics["mouth_mae"]:.4f})')
        ax6.grid(alpha=0.3)

        # 7-9. Metrics summary - NOVE METRIKE UKLJUƒåENE
        ax7 = plt.subplot(3, 3, 7)
        ax7.axis('off')

        # Assessment logic based on LIPS (Mouth & Jaw)
        lip_score = (metrics['mouth_correlation'] + metrics['jaw_correlation']) / 2
        if lip_score > 0.80 and metrics['mouth_mae'] < 0.04:
            assessment = 'üèÜ CHAMPIONSHIP LEVEL'
        elif lip_score > 0.65 and metrics['mouth_mae'] < 0.06:
            assessment = '‚úÖ EXCELLENT'
        elif lip_score > 0.45:
            assessment = 'üëç GOOD'
        else:
            assessment = '‚ö†Ô∏è NEEDS IMPROVEMENT'

        metrics_text = f"""
üìä OVERALL QUALITY
{'='*35}
MSE:              {metrics['mse']:.6f}
MAE:              {metrics['mae']:.6f}
Correlation:      {metrics['correlation']:.4f}

üìà MOUTH & JAW (LIP-SYNC)
{'='*35}
Mouth Corr:       {metrics['mouth_correlation']:.4f}
Mouth MAE:        {metrics['mouth_mae']:.4f}
Jaw Corr:         {metrics['jaw_correlation']:.4f}
Jaw MAE:          {metrics['jaw_mae']:.4f}

üéØ DYNAMICS (ENERGY)
{'='*35}
Vel. Corr:        {metrics['velocity_correlation']:.4f}
Smoothness Ratio: {metrics['smoothness_ratio']:.4f}

‚≠ê STATUS: {assessment}
        """

        ax7.text(0.05, 0.5, metrics_text, fontsize=11, family='monospace',
                verticalalignment='center',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.2))

        plt.suptitle(f'Championship Validation: {file_id}', fontsize=16, fontweight='bold')
        plt.tight_layout()

        output_path = os.path.join(save_path, f'{file_id}_analysis.png')
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.close()

        print(f"‚úÖ Saved: {output_path}")
        return metrics

    '''
    def visualize_prediction(self, file_id, save_path='validation_plots'):
        """Create detailed visualization"""
        os.makedirs(save_path, exist_ok=True)

        print(f"üìä Visualizing: {file_id}")

        # Predict
        pred = self.predict_single_file(file_id)
        gt = self.load_ground_truth(file_id)

        # Align
        min_len = min(len(pred), len(gt))
        pred = pred[:min_len]
        gt = gt[:min_len]

        # Compute metrics
        metrics = self.compute_metrics(pred, gt)

        # Create figure
        fig = plt.figure(figsize=(20, 14))

        # 1. Key mouth movements
        ax1 = plt.subplot(3, 3, 1)
        mouth_indices = [24, 28, 36, 37, 38]
        mouth_names = ['jawOpen', 'mouthFunnel', 'mouthClose', 'mouthSmileL', 'mouthSmileR']

        for idx, name in zip(mouth_indices, mouth_names):
            ax1.plot(gt[:, idx], label=f'{name} (GT)', linestyle='--', alpha=0.7)
            ax1.plot(pred[:, idx], label=f'{name} (Pred)', linewidth=2)

        ax1.set_xlabel('Frame')
        ax1.set_ylabel('Blendshape Value')
        ax1.set_title('Key Mouth Movements')
        ax1.legend(fontsize=7, ncol=2)
        ax1.grid(alpha=0.3)

        # 2. Error heatmap
        ax2 = plt.subplot(3, 3, 2)
        error = np.abs(pred - gt)
        movement_mask = np.max(gt, axis=1) > 0.1

        if movement_mask.sum() > 0:
            error_subset = error[movement_mask][:200]
            im = ax2.imshow(error_subset.T, aspect='auto', cmap='hot', interpolation='nearest')
            ax2.set_xlabel('Frame')
            ax2.set_ylabel('Blendshape Index')
            ax2.set_title('Absolute Error Heatmap')
            plt.colorbar(im, ax=ax2, label='Error')

        # 3. Per-blendshape correlation
        ax3 = plt.subplot(3, 3, 3)
        correlations = []
        for i in range(52):
            if np.std(gt[:, i]) > 1e-6 and np.std(pred[:, i]) > 1e-6:
                try:
                    corr, _ = pearsonr(pred[:, i], gt[:, i])
                    correlations.append(corr if not np.isnan(corr) else 0)
                except:
                    correlations.append(0)
            else:
                correlations.append(0)

        colors = ['green' if c > 0.8 else 'orange' if c > 0.6 else 'red' for c in correlations]
        ax3.bar(range(52), correlations, color=colors, alpha=0.7)
        ax3.axhline(y=0.8, color='green', linestyle='--', linewidth=1.5, label='Excellent')
        ax3.axhline(y=0.6, color='orange', linestyle='--', linewidth=1.5, label='Good')
        ax3.set_xlabel('Blendshape Index')
        ax3.set_ylabel('Correlation')
        ax3.set_title('Per-Blendshape Correlation')
        ax3.legend()
        ax3.grid(alpha=0.3)

        # 4. Velocity comparison
        ax4 = plt.subplot(3, 3, 4)
        pred_vel = np.linalg.norm(np.diff(pred, axis=0), axis=1)
        gt_vel = np.linalg.norm(np.diff(gt, axis=0), axis=1)

        ax4.plot(gt_vel, label='GT Velocity', color='blue', alpha=0.7, linewidth=1.5)
        ax4.plot(pred_vel, label='Pred Velocity', color='orange', linewidth=2)
        ax4.set_xlabel('Frame')
        ax4.set_ylabel('Velocity Magnitude')
        ax4.set_title('Motion Velocity')
        ax4.legend()
        ax4.grid(alpha=0.3)

        # 5. Distribution comparison
        ax5 = plt.subplot(3, 3, 5)
        ax5.hist(gt.flatten(), bins=50, alpha=0.5, label='GT', density=True, color='blue')
        ax5.hist(pred.flatten(), bins=50, alpha=0.5, label='Pred', density=True, color='orange')
        ax5.set_xlabel('Blendshape Value')
        ax5.set_ylabel('Density')
        ax5.set_title('Value Distribution')
        ax5.legend()
        ax5.grid(alpha=0.3)

        # 6. Mouth-only comparison
        ax6 = plt.subplot(3, 3, 6)
        mouth_indices = list(range(28, 52))
        mouth_gt = gt[:, mouth_indices]
        mouth_pred = pred[:, mouth_indices]

        for i in range(0, 24, 4):
            ax6.plot(mouth_gt[:, i], alpha=0.5, color='blue')
            ax6.plot(mouth_pred[:, i], alpha=0.7, color='orange')

        ax6.set_xlabel('Frame')
        ax6.set_ylabel('Mouth Blendshapes')
        ax6.set_title(f'Mouth Region (Corr: {metrics["mouth_correlation"]:.3f})')
        ax6.grid(alpha=0.3)

        # 7-9. Metrics summary
        ax7 = plt.subplot(3, 3, 7)
        ax7.axis('off')

        # Assessment
        if metrics['correlation'] > 0.85 and metrics['mae'] < 0.05:
            assessment = 'üèÜ CHAMPIONSHIP LEVEL'
            color = 'green'
        elif metrics['correlation'] > 0.75 and metrics['mae'] < 0.08:
            assessment = '‚úÖ EXCELLENT'
            color = 'darkgreen'
        elif metrics['correlation'] > 0.60:
            assessment = 'üëç GOOD'
            color = 'orange'
        else:
            assessment = '‚ö†Ô∏è NEEDS IMPROVEMENT'
            color = 'red'

        metrics_text = f"""
üìä QUALITY METRICS
{'='*40}
MSE:              {metrics['mse']:.6f}
MAE:              {metrics['mae']:.6f}
Correlation:      {metrics['correlation']:.6f}
Per-BS Corr Avg:  {metrics['per_bs_corr_mean']:.6f}

üìà MOUTH SPECIFIC
{'='*40}
Mouth MSE:        {metrics['mouth_mse']:.6f}
Mouth Corr:       {metrics['mouth_correlation']:.6f}

üéØ DYNAMICS
{'='*40}
Velocity MSE:     {metrics['velocity_mse']:.6f}
Smoothness Ratio: {metrics['smoothness_ratio']:.6f}

‚≠ê OVERALL
{'='*40}
{assessment}
        """

        ax7.text(0.1, 0.5, metrics_text, fontsize=10, family='monospace',
                verticalalignment='center',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

        plt.suptitle(f'Championship Validation: {file_id}', fontsize=16, fontweight='bold')
        plt.tight_layout()

        output_path = os.path.join(save_path, f'{file_id}_analysis.png')
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.close()

        print(f"‚úÖ Saved: {output_path}")
        return metrics
    '''
    def generate_championship_report(self, num_samples=10):
        """Generate comprehensive validation report"""
        print("\n" + "üèÜ"*35)
        print("CHAMPIONSHIP VALIDATION REPORT")
        print("üèÜ"*35 + "\n")

        # Sample files
        sample_files = random.sample(self.val_files, min(num_samples, len(self.val_files)))

        all_metrics = []
        for file_id in tqdm(sample_files, desc="Validating"):
            metrics = self.visualize_prediction(file_id)
            metrics['file_id'] = file_id
            all_metrics.append(metrics)

        df = pd.DataFrame(all_metrics)

        # Print results
        print("\n" + "="*70)
        print("üìä VALIDATION RESULTS")
        print("="*70)
        print(df[['file_id', 'correlation', 'mae', 'mouth_correlation']].to_string(index=False))

        # Summary statistics
        print("\n" + "="*70)
        print("üìà SUMMARY STATISTICS")
        print("="*70)
        print(f"Average Correlation: {df['correlation'].mean():.4f} (¬±{df['correlation'].std():.4f})")
        print(f"Average MAE:         {df['mae'].mean():.6f} (¬±{df['mae'].std():.6f})")
        print(f"Average Mouth Corr:  {df['mouth_correlation'].mean():.4f} (¬±{df['mouth_correlation'].std():.4f})")

        # Quality distribution
        championship = (df['correlation'] > 0.85).sum()
        excellent = ((df['correlation'] > 0.75) & (df['correlation'] <= 0.85)).sum()
        good = ((df['correlation'] > 0.60) & (df['correlation'] <= 0.75)).sum()
        needs_work = (df['correlation'] <= 0.60).sum()

        print("\n" + "="*70)
        print("üéØ QUALITY DISTRIBUTION")
        print("="*70)
        print(f"üèÜ Championship (>0.85):  {championship}/{len(df)} ({100*championship/len(df):.1f}%)")
        print(f"‚úÖ Excellent (0.75-0.85): {excellent}/{len(df)} ({100*excellent/len(df):.1f}%)")
        print(f"üëç Good (0.60-0.75):      {good}/{len(df)} ({100*good/len(df):.1f}%)")
        print(f"‚ö†Ô∏è  Needs Work (<0.60):    {needs_work}/{len(df)} ({100*needs_work/len(df):.1f}%)")

        # Final assessment
        avg_corr = df['correlation'].mean()
        print("\n" + "="*70)
        print("üéØ FINAL ASSESSMENT")
        print("="*70)

        if avg_corr > 0.85:
            print("üèÜ MODEL STATUS: CHAMPIONSHIP LEVEL")
            print("   Outstanding performance! Ready for production!")
        elif avg_corr > 0.75:
            print("‚úÖ MODEL STATUS: EXCELLENT")
            print("   Great performance! Suitable for most applications!")
        elif avg_corr > 0.60:
            print("üëç MODEL STATUS: GOOD")
            print("   Solid performance, some room for improvement")
        else:
            print("‚ö†Ô∏è MODEL STATUS: NEEDS IMPROVEMENT")
            print("   Consider additional training or tuning")

        # Save report
        df.to_csv('championship_validation_metrics.csv', index=False)
        print(f"\nüíæ Saved: championship_validation_metrics.csv")
        print(f"üìÇ Visualizations in: validation_plots/")

        return df


# Confirm the cell was executed and show how to run a validation report.
print(
    "‚úÖ Championship Validator loaded!",
    "üìå Usage:",
    "   validator = ChampionshipValidator(",
    "       'checkpoints_championship/best_model.pt',",
    "       '/content/drive/MyDrive/AI-speak'",
    "   )",
    "   validator.generate_championship_report()",
    sep="\n",
)





In [None]:


"""
🏆 POKRENI VALIDACIJU ODMAH
Samo kopiraj i pokreni ovu ćeliju!
"""

# ============================================================================
# CONFIGURE PATHS (change if needed)
# ============================================================================
import os

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import torch

MODEL_PATH = 'checkpoints_championship/best_model.pt'  # Path to your model
BASE_PATH = '/content/drive/MyDrive/AI-speak'      # Path to the data

# Use the GPU when one is visible, otherwise fall back to the CPU instead of
# requesting 'cuda' unconditionally.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# ============================================================================
# RUN THE VALIDATION
# ============================================================================

print("üîç Pokreƒáem validaciju...\n")

# Create the validator
validator = ChampionshipValidator(
    model_path=MODEL_PATH,
    base_path=BASE_PATH,
    device=DEVICE
)

# Generate the report (10 samples)
validation_df = validator.generate_championship_report(num_samples=10)

# ============================================================================
# SHOW PLOTS (when running in Jupyter/Colab)
# ============================================================================

print("\n" + "="*70)
print("üìä PRIKAZ GRAFIKA")
print("="*70 + "\n")

# Show the first 3 plots found in the output folder
plot_dir = 'validation_plots/'
if os.path.exists(plot_dir):
    plots = sorted([f for f in os.listdir(plot_dir) if f.endswith('_analysis.png')])

    print(f"Pronaƒëeno {len(plots)} grafika. Prikazujem prva 3...\n")

    for plot_file in plots[:3]:
        img = mpimg.imread(os.path.join(plot_dir, plot_file))
        plt.figure(figsize=(20, 15))
        plt.imshow(img)
        plt.axis('off')
        plt.title(f'Validation: {plot_file}', fontsize=16)
        plt.show()
        print(f"‚úÖ Prikazano: {plot_file}\n")
else:
    print("‚ö†Ô∏è Folder 'validation_plots/' ne postoji")

# ============================================================================
# ADDITIONAL STATISTICS
# ============================================================================

print("\n" + "="*70)
print("üìà DETALJNE STATISTIKE")
print("="*70 + "\n")

if validation_df.empty:
    # Nothing was evaluated, so there are no metric columns to index.
    print("Nema rezultata validacije za prikaz.")
else:
    print("Per-file breakdown:")
    print(validation_df[['file_id', 'correlation', 'mae', 'mouth_correlation', 'smoothness_ratio']].to_string(index=False))

    # Best and worst file by overall correlation share the same layout.
    extremes = (
        ("\nüìä Najbolji fajl:", validation_df['correlation'].idxmax),
        ("\nüìä Najgori fajl:", validation_df['correlation'].idxmin),
    )
    for heading, pick_row in extremes:
        print(heading)
        row_idx = pick_row()
        print(f"   File: {validation_df.loc[row_idx, 'file_id']}")
        print(f"   Correlation: {validation_df.loc[row_idx, 'correlation']:.4f}")
        print(f"   MAE: {validation_df.loc[row_idx, 'mae']:.6f}")

print("\n" + "="*70)
print("‚úÖ VALIDACIJA ZAVR≈†ENA!")
print("="*70)
print("üìÇ CSV report: championship_validation_metrics.csv")
print("üìÇ Grafici: validation_plots/")

In [None]:



import shutil
import os
from datetime import datetime

# ==========================================
# SAVE CONFIGURATION
# ==========================================
# Where the trainer stores checkpoints (check that it matches your 'save_dir')
LOCAL_CHECKPOINT_PATH = 'checkpoints_championship/best_model.pt'

# Destination folder on Drive
DRIVE_FOLDER = '/content/drive/MyDrive/AI-speak/Final_Models'


def _copy_to_drive(src, dst):
    """Copy ``src`` to ``dst``; print the error and return False on failure."""
    try:
        shutil.copy(src, dst)
    except OSError as e:
        print(f"‚ùå Gre≈°ka pri kopiranju: {e}")
        return False
    return True


# Create the Drive folder if it does not exist. exist_ok keeps the call from
# failing if the folder appears between the check and the creation; the
# message is still printed only when the folder was missing.
_folder_was_missing = not os.path.exists(DRIVE_FOLDER)
os.makedirs(DRIVE_FOLDER, exist_ok=True)
if _folder_was_missing:
    print(f"‚úÖ Kreiran folder na Drive-u: {DRIVE_FOLDER}")

# Build a file name that carries the date and time
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
destination_path = os.path.join(DRIVE_FOLDER, f'championship_model_best_{timestamp}.pt')

# ==========================================
# COPY
# ==========================================
if os.path.exists(LOCAL_CHECKPOINT_PATH):
    if _copy_to_drive(LOCAL_CHECKPOINT_PATH, destination_path):
        print("="*50)
        print("üöÄ USPE≈†NO SAƒåUVANO NA DRIVE!")
        print(f"üìç Lokacija: {destination_path}")
        print("="*50)
else:
    print(f"‚ö†Ô∏è Fajl {LOCAL_CHECKPOINT_PATH} nije pronaƒëen. Proveri putanju treninga.")

# Optional: also copy the latest metrics CSV report if it exists. The copy
# goes through the same error handling as the model copy so a failure here
# is reported instead of aborting the cell.
if os.path.exists('championship_validation_metrics.csv'):
    csv_destination = os.path.join(DRIVE_FOLDER, f'metrics_{timestamp}.csv')
    if _copy_to_drive('championship_validation_metrics.csv', csv_destination):
        print("üìä Saƒçuvan i CSV izve≈°taj.")

