# In [None]:
# pip install python-levenshtein tqdm pyarrow wandb

# In [None]:
import os
import json
import random
import math
from typing import List, Optional, Tuple, Dict
from pathlib import Path
import numpy as np
import pandas as pd
import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import Levenshtein
from tqdm.auto import tqdm
import time
import pyarrow.parquet as pq
import wandb
import sys

# Authenticate with Weights & Biases and start the run.
# The API key must never be committed to source control: it is read from the
# WANDB_API_KEY environment variable instead. When the variable is unset,
# key=None lets wandb fall back to its own credential lookup.
# NOTE(review): the key that was previously hard-coded here should be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
wandb.init(project="asl-translation")

# Constants
# Pose landmark indices kept as features (LPOSE / RPOSE lists, concatenated).
LPOSE = [13, 15, 17, 19, 21]
RPOSE = [14, 16, 18, 20, 22]
POSE = LPOSE + RPOSE
# Default number of frames per sequence (used as ASLDataset's max_len).
FRAME_LEN = 128

def create_feature_columns():
    """Return the ordered landmark column names.

    Columns are grouped by axis (all x, then all y, then all z). Within each
    axis the order is right hand 0-20, left hand 0-20, then the POSE indices.
    """
    columns = []
    for axis in ('x', 'y', 'z'):
        columns.extend(f'{axis}_right_hand_{i}' for i in range(21))
        columns.extend(f'{axis}_left_hand_{i}' for i in range(21))
        columns.extend(f'{axis}_pose_{i}' for i in POSE)
    return columns

FEATURE_COLUMNS = create_feature_columns()

class ASLTokenizer:
    """Character-level tokenizer for ASL fingerspelling phrases.

    The character-to-index mapping is loaded from a JSON file and three
    special tokens are then assigned fixed indices:
    'P' (padding) = 59, '<' (start) = 60 and '>' (end) = 61.
    """
    def __init__(self, vocab_path: str):
        with open(vocab_path, 'r') as vocab_file:
            vocab = json.load(vocab_file)

        # Register the special tokens (padding, start, end).
        vocab.update({'P': 59, '<': 60, '>': 61})

        self.char_to_idx = vocab
        self.idx_to_char = {index: char for char, index in vocab.items()}
        self.vocab_size = len(vocab)

    def encode(self, text: str) -> torch.Tensor:
        """Return '<' + one index per character + '>' as a 1-D tensor.

        Characters missing from the vocabulary map to the padding index.
        """
        pad_idx = self.char_to_idx['P']
        body = [self.char_to_idx.get(char, pad_idx) for char in text]
        return torch.tensor([self.char_to_idx['<'], *body, self.char_to_idx['>']])

    def decode(self, tokens: torch.Tensor) -> str:
        """Turn token indices back into text.

        Decoding stops at the first end token; padding and start tokens are
        skipped.
        """
        end_idx = self.char_to_idx['>']
        skipped = [self.char_to_idx['P'], self.char_to_idx['<']]
        chars = []
        for token in tokens:
            index = token.item()
            if index == end_idx:
                break
            if index in skipped:
                continue
            chars.append(self.idx_to_char[index])
        return ''.join(chars)

def preprocess_data(
    data_dir: str,
    metadata_path: str,
    output_dir: str,
    chunk_size: int = 1000
):
    """Convert the parquet landmark files into chunked TFRecord files.

    Args:
        data_dir: Directory containing the ``*.parquet`` landmark files.
        metadata_path: CSV file with ``sequence_id`` and ``phrase`` columns.
        output_dir: Directory that receives ``chunk_XXXX.tfrecord`` files and
            a ``metadata.json`` summary (created if missing).
        chunk_size: Number of metadata rows handled per TFRecord file.

    A sequence is written only when the number of frames in which one hand
    has no NaN coordinates exceeds twice the phrase length.
    """
    print("Loading metadata...")
    df = pd.read_csv(metadata_path)
    os.makedirs(output_dir, exist_ok=True)

    # Map every sequence id to the parquet file that holds it.
    print("Creating sequence index...")
    sequence_map = {}
    for parquet_file in tqdm(list(Path(data_dir).glob('*.parquet'))):
        id_table = pq.read_table(parquet_file, columns=['sequence_id'])
        for seq_id in pd.unique(id_table['sequence_id'].to_numpy()):
            sequence_map[seq_id] = str(parquet_file)

    # Keep only metadata rows whose landmarks are actually available.
    df = df[df['sequence_id'].isin(sequence_map.keys())]

    num_chunks = (len(df) + chunk_size - 1) // chunk_size

    for chunk_idx in range(num_chunks):
        # iloc clips the upper bound, so the last chunk may be shorter.
        chunk_df = df.iloc[chunk_idx * chunk_size:(chunk_idx + 1) * chunk_size]
        tf_file = os.path.join(output_dir, f'chunk_{chunk_idx:04d}.tfrecord')

        with tf.io.TFRecordWriter(tf_file) as writer:
            for _, row in tqdm(chunk_df.iterrows(), total=len(chunk_df),
                             desc=f"Processing chunk {chunk_idx+1}/{num_chunks}"):
                seq_id = row['sequence_id']

                # Load only the rows that belong to this sequence.
                seq_df = pq.read_table(
                    sequence_map[seq_id],
                    filters=[('sequence_id', '=', seq_id)]
                ).to_pandas()

                landmark_cols = [col for col in seq_df.columns if col in FEATURE_COLUMNS]
                frames = seq_df[landmark_cols].values

                # Count, per hand, the frames without any NaN coordinate.
                right_idx = [i for i, col in enumerate(landmark_cols) if 'right_hand' in col]
                left_idx = [i for i, col in enumerate(landmark_cols) if 'left_hand' in col]
                r_nonan = np.count_nonzero(~np.isnan(frames[:, right_idx]).any(axis=1))
                l_nonan = np.count_nonzero(~np.isnan(frames[:, left_idx]).any(axis=1))
                no_nan = max(r_nonan, l_nonan)

                # Skip sequences with too few usable frames for their phrase.
                if no_nan <= 2 * len(row['phrase']):
                    continue

                # One float list per landmark column plus the phrase bytes.
                feature = {}
                for i, col in enumerate(landmark_cols):
                    feature[col] = tf.train.Feature(
                        float_list=tf.train.FloatList(value=frames[:, i]))
                feature['phrase'] = tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[row['phrase'].encode('utf-8')])
                )

                example = tf.train.Example(features=tf.train.Features(feature=feature))
                writer.write(example.SerializeToString())

    # Summary of the run; 'total_sequences' counts metadata rows after the
    # availability filter, not the number of examples written.
    metadata = {
        'num_chunks': num_chunks,
        'chunk_size': chunk_size,
        'total_sequences': len(df),
        'feature_columns': FEATURE_COLUMNS
    }

    with open(os.path.join(output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f)

    print(f"Preprocessing complete. {num_chunks} chunks saved to {output_dir}")

class ASLDataset(Dataset):
    """In-memory dataset of dominant-hand landmarks and tokenized phrases.

    TFRecord examples are parsed with tf.data; preprocessing and tokenization
    then run eagerly in Python and the results are kept as NumPy arrays.

    Args:
        tf_records: Paths of the TFRecord files to load.
        tokenizer: Object whose ``encode(str)`` returns a 1-D torch tensor of
            token indices.
        max_len: Number of frames every sequence is resized or padded to.
        augment: Stored on the instance; not used by this class.
        mode: Stored on the instance; not used by this class.
    """
    def __init__(
        self,
        tf_records: List[str],
        tokenizer,
        max_len: int = FRAME_LEN,
        augment: bool = True,
        mode: str = 'train'
    ):
        self.tf_records = tf_records
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.augment = augment
        self.mode = mode

        # Only record parsing runs inside the tf.data graph. convert_fn needs
        # Python strings and the torch-based tokenizer, which are unavailable
        # on symbolic tensors, so it is applied eagerly below.
        self.dataset = tf.data.TFRecordDataset([os.fspath(p) for p in self.tf_records])
        self.dataset = self.dataset.map(self.decode_fn)

        # self.data[i] is (landmarks, tokens); self.phrases[i] is the raw text.
        self.data = []
        self.phrases = []
        for landmarks, phrase in self.dataset.as_numpy_iterator():
            hand, tokens = self.convert_fn(landmarks, phrase)
            # np.array copies, giving torch.from_numpy a writable buffer.
            self.data.append((np.array(hand, dtype=np.float32), tokens.numpy()))
            self.phrases.append(self._to_text(phrase))

    @staticmethod
    def _to_text(phrase) -> str:
        """Return the phrase as ``str`` (accepts eager tensors, bytes or str)."""
        if hasattr(phrase, 'numpy'):
            phrase = phrase.numpy()
        if isinstance(phrase, bytes):
            return phrase.decode('utf-8')
        return str(phrase)

    def decode_fn(self, record_bytes):
        """Parse one serialized example into ``(landmarks, phrase)``.

        ``landmarks`` stacks the FEATURE_COLUMNS along axis 1, giving one row
        per frame; ``phrase`` is a scalar string tensor.
        """
        schema = {
            col: tf.io.VarLenFeature(dtype=tf.float32)
            for col in FEATURE_COLUMNS
        }
        schema["phrase"] = tf.io.FixedLenFeature([], dtype=tf.string)

        features = tf.io.parse_single_example(record_bytes, schema)
        phrase = features["phrase"]
        landmarks = tf.stack([
            tf.sparse.to_dense(features[col])
            for col in FEATURE_COLUMNS
        ], axis=1)

        return landmarks, phrase

    def convert_fn(self, landmarks, phrase):
        """Preprocess the landmarks and tokenize the phrase (eager only)."""
        landmarks = self.preprocess_landmarks(
            tf.convert_to_tensor(landmarks, dtype=tf.float32))
        phrase = self.tokenizer.encode(self._to_text(phrase))
        return landmarks, phrase

    def preprocess_landmarks(self, landmarks):
        """Select, normalize and length-fix the dominant hand.

        Args:
            landmarks: Float tensor with one row per frame and one column per
                entry of FEATURE_COLUMNS.

        Returns:
            Float tensor of shape ``(max_len, 63)``: the hand with fewer
            NaN-containing frames, standardized per column, with missing
            values set to 0.
        """
        # Detect dominant hand
        rhand = tf.gather(landmarks, [i for i, col in enumerate(FEATURE_COLUMNS) if 'right_hand' in col], axis=1)
        lhand = tf.gather(landmarks, [i for i, col in enumerate(FEATURE_COLUMNS) if 'left_hand' in col], axis=1)

        rnan_idx = tf.reduce_any(tf.math.is_nan(rhand), axis=1)
        lnan_idx = tf.reduce_any(tf.math.is_nan(lhand), axis=1)

        rnans = tf.math.count_nonzero(rnan_idx)
        lnans = tf.math.count_nonzero(lnan_idx)

        # Use dominant hand
        hand = lhand if rnans > lnans else rhand

        # Normalize with NaN-aware statistics. A plain reduce_mean/reduce_std
        # returns NaN for any column that has a single missing value, which
        # would turn the whole sample into NaN.
        valid = tf.math.logical_not(tf.math.is_nan(hand))
        zeros = tf.zeros_like(hand)
        count = tf.maximum(
            tf.reduce_sum(tf.cast(valid, hand.dtype), axis=0, keepdims=True), 1.0)
        mean = tf.reduce_sum(tf.where(valid, hand, zeros), axis=0, keepdims=True) / count
        var = tf.reduce_sum(
            tf.where(valid, tf.square(hand - mean), zeros), axis=0, keepdims=True) / count
        std = tf.sqrt(var)
        std = tf.where(tf.equal(std, 0), tf.ones_like(std), std)
        hand = tf.where(valid, (hand - mean) / std, zeros)

        # Handle sequence length
        num_frames = tf.shape(hand)[0]
        if num_frames > self.max_len:
            # tf.image.resize reads a 3-D input as (height, width, channels),
            # so add explicit batch and channel axes to resize along time.
            hand = tf.image.resize(
                hand[None, :, :, None],
                (self.max_len, tf.shape(hand)[1])
            )[0, :, :, 0]
        else:
            pad_len = self.max_len - num_frames
            hand = tf.pad(hand, [[0, pad_len], [0, 0]])

        return hand

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors plus the phrase text."""
        landmarks, tokens = self.data[idx]
        return {
            'landmarks': torch.from_numpy(landmarks).float(),
            'tokens': torch.from_numpy(tokens),
            # The original text, not the bytes of the token array.
            'phrase': self.phrases[idx],
            # Number of tokens, including the start and end tokens.
            'length': torch.tensor(len(tokens))
        }

def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Assemble dataset items into one batch.

    Token sequences are right-padded with the padding index (59) to the
    longest sequence in the batch; landmarks and lengths are stacked as-is and
    phrases are returned as a plain list.
    """
    longest = max(item['tokens'].size(0) for item in batch)

    stacked_landmarks = torch.stack([item['landmarks'] for item in batch])

    padded_tokens = []
    for item in batch:
        seq = item['tokens']
        # pad_token_idx = 59
        padded_tokens.append(F.pad(seq, (0, longest - seq.size(0)), value=59))

    return {
        'landmarks': stacked_landmarks,
        'tokens': torch.stack(padded_tokens),
        'phrase': [item['phrase'] for item in batch],
        'length': torch.stack([item['length'] for item in batch])
    }

class FeatureExtractor(nn.Module):
    """Per-frame landmark encoder: Conv1d over landmarks, mean-pool, Linear.

    Input is ``(B, T, L, C)``; output is ``(B, T, output_dim)``.
    """
    def __init__(self, input_channels: int = 3, output_dim: int = 52):
        super().__init__()
        hidden = 64
        self.conv = nn.Conv1d(input_channels, hidden, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm1d(hidden)
        self.linear = nn.Linear(hidden, output_dim)

    def forward(self, x):
        batch, frames, landmarks, channels = x.shape
        # Fold time into the batch axis and put channels first for Conv1d:
        # (B, T, L, C) -> (B*T, C, L).
        per_frame = x.permute(0, 1, 3, 2).reshape(batch * frames, channels, landmarks)
        activated = F.relu(self.bn(self.conv(per_frame)))
        # Average over the landmark axis, then project each frame.
        pooled = activated.mean(dim=2)
        return self.linear(pooled).reshape(batch, frames, -1)

class RotaryPositionalEmbedding(nn.Module):
    """Precomputed sin/cos tables for rotary position embeddings.

    Args:
        dim: Model width; the per-head width is ``dim // num_heads``.
        max_seq_len: Number of positions stored in the tables.
        num_heads: Number of attention heads (default 8).

    The ``sin`` and ``cos`` buffers have shape
    ``(max_seq_len, (dim // num_heads) // 2)``.
    """
    def __init__(self, dim: int, max_seq_len: int = 128, num_heads: int = 8):  # Changed from 384 to 128
        super().__init__()
        head_dim = dim // num_heads
        half_head_dim = head_dim // 2
        # Frequencies are spaced geometrically from 1 down to 1/10000. With a
        # single frequency the spacing is irrelevant, so clamp the divisor to
        # avoid dividing by zero.
        step = math.log(10000) / max(half_head_dim - 1, 1)
        inv_freq = torch.exp(torch.arange(half_head_dim) * -step)
        pos = torch.arange(max_seq_len)
        angles = pos[:, None] * inv_freq[None, :]
        self.register_buffer('sin', angles.sin())
        self.register_buffer('cos', angles.cos())

    def forward(self, x):
        """Return the ``(sin, cos)`` tables cut to ``x.shape[1]`` positions.

        Raises:
            ValueError: If the sequence is longer than the stored tables.
        """
        seq_len = x.shape[1]
        if seq_len > self.sin.shape[0]:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {self.sin.shape[0]}"
            )
        return self.sin[:seq_len], self.cos[:seq_len]

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings on q and k."""
    def __init__(self, dim: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        assert self.head_dim * num_heads == dim

        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def apply_rotary_pos_emb(self, q, k, sin, cos):
        """Rotate the two halves of each head's features by the given angles."""
        # Add batch and head axes so the tables broadcast over (B, L, H, D/2).
        sin = sin.unsqueeze(0).unsqueeze(2)
        cos = cos.unsqueeze(0).unsqueeze(2)

        def rotate(t):
            first, second = t.chunk(2, dim=-1)
            return torch.cat(
                [first * cos - second * sin, second * cos + first * sin], dim=-1
            )

        return rotate(q), rotate(k)

    def forward(self, x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over ``x`` of shape (B, L, D).

        ``mask`` is a boolean tensor; positions where it is False are excluded
        as attention keys.
        """
        batch, seq_len, width = x.shape
        per_head = (batch, seq_len, self.num_heads, self.head_dim)

        queries = self.q_proj(x).reshape(per_head)
        keys = self.k_proj(x).reshape(per_head)
        values = self.v_proj(x).reshape(per_head)

        queries, keys = self.apply_rotary_pos_emb(queries, keys, sin, cos)

        # (B, L, H, Dh) -> (B, H, L, Dh)
        queries, keys, values = (t.transpose(1, 2) for t in (queries, keys, values))

        scores = torch.matmul(queries, keys.transpose(-2, -1)) * (self.head_dim ** -0.5)

        if mask is not None:
            key_mask = mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(~key_mask, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))

        context = torch.matmul(weights, values).transpose(1, 2).contiguous()
        return self.out_proj(context.reshape(batch, seq_len, width))

class SqueezeformerBlock(nn.Module):
    """Encoder block: feed-forward, self-attention, convolution, feed-forward.

    Each sub-module is applied pre-norm and added back to its input after
    being multiplied by one learnable scalar shared by all four branches.
    """
    def __init__(self, dim: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.dim = dim
        self.norm1 = nn.LayerNorm(dim)
        self.mhsa = MultiHeadAttention(dim, num_heads, dropout)

        # Feed forward module 1
        self.ff1_norm = nn.LayerNorm(dim)
        self.ff1 = self._feed_forward(dim, dropout)

        # Convolution module
        self.conv_norm = nn.LayerNorm(dim)
        self.conv1 = nn.Conv1d(dim, dim*2, 1)
        self.glu = nn.GLU(dim=1)
        self.depthwise_conv = nn.Conv1d(dim, dim, 3, padding=1, groups=dim)
        self.batch_norm = nn.BatchNorm1d(dim)
        self.activation = nn.SiLU()
        self.pointwise_conv = nn.Conv1d(dim, dim, 1)
        self.conv_dropout = nn.Dropout(dropout)

        # Feed forward module 2
        self.ff2_norm = nn.LayerNorm(dim)
        self.ff2 = self._feed_forward(dim, dropout)

        self.dropout = nn.Dropout(dropout)
        self.scale = nn.Parameter(torch.ones(1))

    @staticmethod
    def _feed_forward(dim: int, dropout: float) -> nn.Sequential:
        """Build the Linear-SiLU-Dropout-Linear-Dropout stack (4x expansion)."""
        return nn.Sequential(
            nn.Linear(dim, dim*4),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(dim*4, dim),
            nn.Dropout(dropout)
        )

    def _conv_module(self, x: torch.Tensor) -> torch.Tensor:
        """Run the convolution branch on ``x`` of shape (B, L, D)."""
        # Conv1d expects channels first: (B, L, D) -> (B, D, L).
        y = self.conv_norm(x).transpose(1, 2)
        y = self.glu(self.conv1(y))
        y = self.activation(self.batch_norm(self.depthwise_conv(y)))
        y = self.conv_dropout(self.pointwise_conv(y))
        return y.transpose(1, 2)

    def forward(self, x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # First feed forward
        x = x + self.ff1(self.ff1_norm(x)) * self.scale

        # Self attention
        attended = self.dropout(self.mhsa(self.norm1(x), sin, cos, mask))
        x = x + attended * self.scale

        # Convolution module
        x = x + self._conv_module(x) * self.scale

        # Second feed forward
        x = x + self.ff2(self.ff2_norm(x)) * self.scale

        return x

class ASLTranslationModel(nn.Module):
    """Landmark-to-text model: part extractors, Squeezeformer encoder, Transformer decoder.

    ``forward`` returns ``(output, confidence)``. With ``tgt`` the output is
    the classifier applied to the decoder states; without it the classifier is
    applied to the encoder states. ``confidence`` comes from a linear head on
    the first encoder position.
    """
    def __init__(
        self,
        num_landmarks: int = 130,
        feature_dim: int = 208,
        num_classes: int = 62,  # Updated to match competition (59 + 3 special tokens)
        num_layers: int = 2,
        dropout: float = 0.1
    ):
        super().__init__()

        # One extractor per landmark group; their 52-dim outputs are
        # concatenated in forward().
        self.face_extractor = FeatureExtractor(3, 52)
        self.pose_extractor = FeatureExtractor(3, 52)
        self.left_hand_extractor = FeatureExtractor(3, 52)
        self.right_hand_extractor = FeatureExtractor(3, 52)

        # Embeddings
        self.target_embedding = nn.Embedding(num_classes, feature_dim)
        self.pos_embedding = RotaryPositionalEmbedding(feature_dim)

        # Squeezeformer encoder
        encoder_blocks = [
            SqueezeformerBlock(feature_dim, dropout=dropout)
            for _ in range(num_layers)
        ]
        self.squeezeformer_layers = nn.ModuleList(encoder_blocks)

        # Decoder (always two layers, independent of num_layers)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=feature_dim,
            nhead=8,
            dim_feedforward=feature_dim*4,
            dropout=dropout,
            batch_first=True,
            norm_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=2)

        # Output layers
        self.confidence_head = nn.Linear(feature_dim, 1)
        self.classifier = nn.Linear(feature_dim, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        tgt: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the model.

        Args:
            x: 4-D landmark tensor; the third axis is sliced into face (:76),
                pose (76:88), left hand (88:109) and right hand (109:).
            mask: Optional boolean mask over encoder positions, True = keep.
            tgt: Optional target token indices fed to the decoder.
        """
        # Unpacking enforces a 4-D input.
        _batch, _frames, _landmarks, _channels = x.shape

        # Extract and encode each landmark group, then concatenate.
        groups = (
            (self.face_extractor, x[:, :, :76]),
            (self.pose_extractor, x[:, :, 76:88]),
            (self.left_hand_extractor, x[:, :, 88:109]),
            (self.right_hand_extractor, x[:, :, 109:]),
        )
        features = torch.cat(
            [extractor(part) for extractor, part in groups], dim=-1
        )

        # Get positional embeddings
        sin, cos = self.pos_embedding(features)

        # Encoder
        encoder_out = features
        for block in self.squeezeformer_layers:
            encoder_out = block(encoder_out, sin, cos, mask)

        confidence = self.confidence_head(encoder_out[:, 0]).squeeze(-1)

        if tgt is None:
            return self.classifier(encoder_out), confidence

        tgt_embedded = self.dropout(self.target_embedding(tgt))
        causal_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(x.device)
        # The decoder expects True = ignore, the opposite of the encoder mask.
        memory_padding_mask = None if mask is None else ~mask

        decoder_out = self.decoder(
            tgt_embedded,
            encoder_out,
            tgt_mask=causal_mask,
            memory_key_padding_mask=memory_padding_mask
        )
        return self.classifier(decoder_out), confidence

    @staticmethod
    def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
        """Return an (sz, sz) causal mask: -inf above the diagonal, 0 elsewhere."""
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

class ASLTranslationLoss(nn.Module):
    """Token cross-entropy plus 0.1 times the MSE of the confidence estimate."""
    def __init__(self, pad_idx: int = 59):  # 59 is pad token index
        super().__init__()
        # Padding positions do not contribute to the sequence loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        confidence: torch.Tensor,
        confidence_target: torch.Tensor
    ) -> torch.Tensor:
        # Flatten all leading axes so every token is one classification row.
        num_classes = pred.size(-1)
        flat_logits = pred.reshape(-1, num_classes)
        flat_targets = target.reshape(-1)

        sequence_term = self.criterion(flat_logits, flat_targets)
        confidence_term = F.mse_loss(confidence, confidence_target)
        return sequence_term + 0.1 * confidence_term


class Trainer:
    """Training / validation loop with AdamW, OneCycleLR and optional W&B logging.

    Args:
        model: Model returning ``(logits, confidence)``.
        train_loader: Loader yielding dicts with 'landmarks', 'tokens',
            'phrase' and 'length'.
        val_loader: Loader with the same batch layout, used by ``validate``.
        tokenizer: Tokenizer used to decode predictions into text.
        learning_rate: Peak learning rate of the one-cycle schedule.
        weight_decay: AdamW weight decay.
        warmup_epochs: Epochs spent in the rising part of the schedule.
        max_epochs: Total number of training epochs.
        device: Device the model and batches are moved to.
        wandb_config: Optional dict with 'project' and 'run_name'; when given,
            a W&B run is started and metrics are logged.
    """
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        tokenizer: ASLTokenizer,
        learning_rate: float = 0.0001,  # Changed to match competition
        weight_decay: float = 0.08,
        warmup_epochs: int = 1,
        max_epochs: int = 13,  # Changed to match competition
        device: str = 'cuda',
        wandb_config: dict = None
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.tokenizer = tokenizer
        self.device = device
        self.max_epochs = max_epochs
        self.wandb_config = wandb_config
        # Mixed precision and gradient scaling only make sense on CUDA.
        self.use_amp = str(device).startswith('cuda')

        if wandb_config:
            # Callers pass the project under 'project'; the legacy
            # 'asl-translation' key is still honoured if present.
            project = wandb_config.get(
                'project', wandb_config.get('asl-translation', 'asl-translation')
            )
            wandb.init(
                project=project,
                name=wandb_config.get('run_name'),
                config={
                    'learning_rate': learning_rate,
                    'weight_decay': weight_decay,
                    'warmup_epochs': warmup_epochs,
                    'max_epochs': max_epochs,
                    'batch_size': train_loader.batch_size,
                    'architecture': 'Squeezeformer'
                }
            )
            wandb.watch(model, log_freq=100)

        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )

        # The scheduler is stepped once per batch.
        self.num_training_steps = len(train_loader) * max_epochs
        self.num_warmup_steps = len(train_loader) * warmup_epochs
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=learning_rate,
            total_steps=self.num_training_steps,
            pct_start=self.num_warmup_steps / self.num_training_steps,
            anneal_strategy='cos',
            cycle_momentum=False
        )

        self.criterion = ASLTranslationLoss()
        self.scaler = GradScaler(enabled=self.use_amp)

    def _frame_mask(self, landmarks: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Return a boolean (batch, frames) mask, True where frame index < length.

        NOTE(review): 'length' in the batch is the token count of the phrase,
        not the number of valid frames, so this hides every frame past the
        phrase length. Confirm whether a per-sample frame count should be
        used instead.
        """
        positions = torch.arange(landmarks.size(1), device=self.device)
        return positions[None, :] < lengths[:, None]

    def _confidence_target(self, pred_texts, true_texts) -> torch.Tensor:
        """Return 1 - edit_distance / len(reference) for every pair."""
        return torch.tensor([
            1 - Levenshtein.distance(pred_text, true_text) / max(len(true_text), 1)
            for pred_text, true_text in zip(pred_texts, true_texts)
        ], device=self.device)

    def train_epoch(self) -> float:
        """Run one epoch of teacher-forced training; return the mean batch loss."""
        self.model.train()
        total_loss = 0

        progress_bar = tqdm(
            self.train_loader,
            desc='Training',
            leave=True,
            dynamic_ncols=True
        )

        for batch_idx, batch in enumerate(progress_bar):
            self.optimizer.zero_grad()

            landmarks = batch['landmarks'].to(self.device)
            tokens = batch['tokens'].to(self.device)
            lengths = batch['length'].to(self.device)

            mask = self._frame_mask(landmarks, lengths)

            with autocast(enabled=self.use_amp):
                # Decoder input drops the last token; targets drop the first.
                pred, confidence = self.model(landmarks, mask, tokens[:, :-1])

                with torch.no_grad():
                    pred_texts = [
                        self.tokenizer.decode(p.argmax(-1).cpu()) for p in pred
                    ]
                    confidence_target = self._confidence_target(
                        pred_texts, batch['phrase'])

                loss = self.criterion(pred, tokens[:, 1:], confidence, confidence_target)

            self.scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()

            total_loss += loss.item()
            current_lr = self.optimizer.param_groups[0]['lr']

            if self.wandb_config and batch_idx % 10 == 0:
                wandb.log({
                    'batch_loss': loss.item(),
                    'learning_rate': current_lr,
                    'batch_confidence': confidence.mean().item(),
                    'batch_confidence_target': confidence_target.mean().item()
                })

            progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'lr': f'{current_lr:.2e}'
            })

        return total_loss / max(len(self.train_loader), 1)

    @torch.no_grad()
    def validate(self) -> Tuple[float, float]:
        """Evaluate on the validation loader.

        Returns:
            ``(mean loss, mean normalized edit-distance score)``.
        """
        self.model.eval()
        total_loss = 0
        predictions = []
        ground_truth = []
        confidence_scores = []

        for batch in tqdm(self.val_loader, desc='Validating'):
            landmarks = batch['landmarks'].to(self.device)
            tokens = batch['tokens'].to(self.device)
            lengths = batch['length'].to(self.device)

            mask = self._frame_mask(landmarks, lengths)

            # Teacher forcing, as in train_epoch. Without targets the model
            # returns one logit row per frame, which cannot be compared with
            # tokens[:, 1:] in the loss.
            pred, confidence = self.model(landmarks, mask, tokens[:, :-1])
            confidence_scores.extend(confidence.cpu().tolist())

            pred_texts = [self.tokenizer.decode(p.argmax(-1).cpu()) for p in pred]
            predictions.extend(pred_texts)
            ground_truth.extend(batch['phrase'])

            confidence_target = self._confidence_target(pred_texts, batch['phrase'])

            loss = self.criterion(pred, tokens[:, 1:], confidence, confidence_target)
            total_loss += loss.item()

        avg_loss = total_loss / max(len(self.val_loader), 1)
        # The trailing 1 avoids dividing by zero when both strings are empty.
        scores = [
            1 - Levenshtein.distance(pred, true) / max(len(pred), len(true), 1)
            for pred, true in zip(predictions, ground_truth)
        ]
        avg_score = sum(scores) / len(scores) if scores else 0.0

        if self.wandb_config:
            wandb.log({
                'val_loss': avg_loss,
                'val_score': avg_score,
                'val_confidence_mean': np.mean(confidence_scores),
                'val_confidence_std': np.std(confidence_scores)
            })

        return avg_loss, avg_score

    def train(self, save_dir: str):
        """Train for ``max_epochs`` and write checkpoints into ``save_dir``.

        Validation runs every 2 epochs and saves ``best_model.pt`` when the
        score improves; a checkpoint is also written every 5 epochs.
        """
        os.makedirs(save_dir, exist_ok=True)
        best_val_score = float('-inf')

        for epoch in range(self.max_epochs):
            print(f"\nEpoch {epoch + 1}/{self.max_epochs}")
            train_loss = self.train_epoch()

            if (epoch + 1) % 2 == 0:  # Validate every 2 epochs
                val_loss, val_score = self.validate()
                print(f"Validation Loss: {val_loss:.4f}")
                print(f"Validation Score: {val_score:.4f}")

                if val_score > best_val_score:
                    best_val_score = val_score
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'scheduler_state_dict': self.scheduler.state_dict(),
                        'val_score': val_score,
                    }, os.path.join(save_dir, 'best_model.pt'))

            if (epoch + 1) % 5 == 0:  # Save checkpoint every 5 epochs
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                }, os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pt'))

def create_dataloaders(
    processed_dir: str,
    tokenizer,
    batch_size: int = 64,
    num_workers: int = 4
):
    """Create train and validation dataloaders.

    The sorted ``chunk_*.tfrecord`` files in ``processed_dir`` are split 80/20
    by file: the first part feeds the training set, the rest the validation
    set.

    Args:
        processed_dir: Directory containing the TFRecord chunks.
        tokenizer: Tokenizer passed on to both datasets.
        batch_size: Batch size of both loaders.
        num_workers: Worker processes per loader (0 loads in the main process).

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: If fewer than two chunk files are found, since a file-level
            split then leaves one of the two sets empty.
    """
    # Get TFRecord files
    tf_records = sorted(Path(processed_dir).glob('chunk_*.tfrecord'))
    if len(tf_records) < 2:
        raise ValueError(
            f"Need at least 2 TFRecord chunks in {processed_dir} for a "
            f"train/validation split, found {len(tf_records)}"
        )
    train_size = int(len(tf_records) * 0.8)

    # Create datasets
    train_dataset = ASLDataset(
        tf_records[:train_size],
        tokenizer,
        augment=True,
        mode='train'
    )

    val_dataset = ASLDataset(
        tf_records[train_size:],
        tokenizer,
        augment=False,
        mode='val'
    )

    # Settings shared by both loaders. DataLoader rejects
    # persistent_workers=True when num_workers is 0.
    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': True,
        'persistent_workers': num_workers > 0,
        'collate_fn': collate_fn,
    }

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader

def main():
    """Preprocess the data if needed, build the model and run training."""
    config = {
        'data_dir': '/kaggle/input/asl-fingerspelling/train_landmarks',
        'metadata_path': '/kaggle/input/asl-fingerspelling/train.csv',
        'vocab_path': '/kaggle/input/asl-fingerspelling/character_to_prediction_index.json',
        'processed_dir': '/kaggle/working/processed_data',
        'save_dir': '/kaggle/working/models',
        'batch_size': 64,
        'num_workers': 4,
        'learning_rate': 0.0001,  # Competition learning rate
        'weight_decay': 0.08,
        'warmup_epochs': 1,
        'max_epochs': 13,  # Competition number of epochs
        'device': 'cuda' if torch.cuda.is_available() else 'cpu'
    }

    # Seed every random number generator in use.
    for seed_everything in (torch.manual_seed, random.seed, np.random.seed):
        seed_everything(42)

    # metadata.json is written at the end of preprocessing, so its presence
    # marks a finished preprocessing run.
    metadata_file = os.path.join(config['processed_dir'], 'metadata.json')
    if not os.path.exists(metadata_file):
        print("Preprocessing data...")
        preprocess_data(
            config['data_dir'],
            config['metadata_path'],
            config['processed_dir']
        )

    tokenizer = ASLTokenizer(config['vocab_path'])

    train_loader, val_loader = create_dataloaders(
        config['processed_dir'],
        tokenizer,
        batch_size=config['batch_size'],
        num_workers=config['num_workers']
    )

    model = ASLTranslationModel(
        num_landmarks=130,
        feature_dim=208,
        num_classes=62,  # 59 chars + 3 special tokens
        num_layers=2,
        dropout=0.1
    )

    # Spread batches over all GPUs when more than one is visible.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    run_name = f'asl-translation-{time.strftime("%Y%m%d-%H%M%S")}'
    optimiser_settings = {
        key: config[key]
        for key in ('learning_rate', 'weight_decay', 'warmup_epochs', 'max_epochs')
    }
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        tokenizer=tokenizer,
        device=config['device'],
        wandb_config={
            'project': 'asl-translation',
            'run_name': run_name
        },
        **optimiser_settings
    )

    print("\nStarting training...")
    trainer.train(config['save_dir'])

if __name__ == "__main__":
    main()