In [None]:
# %% [markdown]
# ## Train RNN Model
# Train the standalone RNN model with CTC loss.

# %% [code]
import torch
import torch.nn.functional as F

def _ctc_batch_loss(model, images, labels, ctc_loss, num_classes, error_prefix, verbose=False):
    """Run ``model`` on one batch and return the CTC loss, assuming one target label per image.

    Args:
        model: Module whose forward returns raw scores; expected shape
            (batch, seq_len, num_classes).
        images: Input batch, already on the target device.
        labels: Target labels, already on the target device.
        ctc_loss: A ``torch.nn.CTCLoss`` instance.
        num_classes: Expected size of the class dimension (labels + blank).
        error_prefix: Prefix for the diagnostic message printed if the loss call fails.
        verbose: If True, print the model output shape.

    Returns:
        The scalar loss tensor produced by ``ctc_loss``.

    Raises:
        ValueError: If the model's class dimension differs from ``num_classes``.
    """
    outputs = model(images)
    if verbose:
        print(f"Model output shape: {outputs.shape}")  # Debug
    if outputs.size(2) != num_classes:
        raise ValueError(f"Model output has {outputs.size(2)} classes, expected {num_classes}")
    log_probs = F.log_softmax(outputs, dim=2)
    batch_size = images.size(0)
    # Every sequence uses the full model output length.
    input_lengths = torch.full((batch_size,), log_probs.size(1), dtype=torch.long)
    # Single character per image, so every target sequence has length 1.
    target_lengths = torch.ones(batch_size, dtype=torch.long)
    try:
        # CTCLoss takes log-probs laid out as (seq_len, batch, num_classes).
        return ctc_loss(log_probs.transpose(0, 1), labels, input_lengths, target_lengths)
    except Exception as e:
        # Dump shapes/lengths to help diagnose the failure, then re-raise unchanged.
        print(f"{error_prefix}: {e}")
        print(f"Output shape: {log_probs.shape}, Labels shape: {labels.shape}")
        print(f"Input lengths: {input_lengths}, Target lengths: {target_lengths}")
        raise


def _mean_loss(total, count):
    """Return ``total / count``, or NaN when no batches were processed."""
    return total / count if count else float("nan")


def train_rnn(model, train_loader, val_loader, device, optimizer, epochs=50, *, verbose=False):
    """Train the RNN model with CTC loss, reporting train and validation loss per epoch.

    Index 0 is reserved for the CTC blank token, so the model must emit
    ``len(train_loader.dataset.label_map) + 1`` classes per time step. Each
    image is treated as having exactly one target label.

    Args:
        model: RNNModel instance; ``model(images)`` is expected to return raw
            scores of shape (batch_size, seq_len, num_classes).
        train_loader: DataLoader for training data; its ``dataset`` must expose
            ``label_map``.
        val_loader: DataLoader for validation data.
        device: Device to run training on (cuda or cpu).
        optimizer: Optimizer for training.
        epochs: Number of training epochs.
        verbose: If True, print per-batch shape diagnostics.

    Returns:
        model: Trained model (the same object, left in training mode).

    Raises:
        ValueError: If the model's output class dimension does not match
            the expected number of classes.
    """
    num_classes = len(train_loader.dataset.label_map) + 1  # +1 for blank token
    print(f"Number of classes: {num_classes}")
    ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=True)
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        train_batches = 0
        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            if verbose:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Image shape: {images.shape}, Label shape: {labels.shape}")  # Debug
            optimizer.zero_grad()
            loss = _ctc_batch_loss(
                model, images, labels, ctc_loss, num_classes, "CTC Loss Error", verbose=verbose
            )
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            train_batches += 1
        # Average over batches actually seen; NaN (not ZeroDivisionError) for an empty loader.
        print(f"Epoch {epoch+1}, Loss: {_mean_loss(total_loss, train_batches)}")

        model.eval()
        val_loss = 0.0
        val_batches = 0
        try:
            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(device), labels.to(device)
                    loss = _ctc_batch_loss(
                        model, images, labels, ctc_loss, num_classes, "Validation CTC Loss Error"
                    )
                    val_loss += loss.item()
                    val_batches += 1
        finally:
            # Restore training mode even if validation raises.
            model.train()
        print(f"Validation Loss: {_mean_loss(val_loss, val_batches)}")

    return model