In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
!pip install unidecode
# !pip install -r requirements.txt




[notice] A new release of pip is available: 24.0 -> 25.0.1
[notice] To update, run: C:\Users\aisha\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip


In [3]:
import yaml
from pathlib import Path
import torch
from torch.utils.data import DataLoader, ConcatDataset, Subset
from emg2qwerty.data import WindowedEMGDataset
from emg2qwerty.transforms import Compose, ToTensor, LogSpectrogram, SpecAugment, RandomBandRotation, TemporalAlignmentJitter


batch_size = 32
num_workers = 2

# 1. Load the configuration file
config_path = Path("config/user/single_user.yaml")
with open(config_path, 'r') as f:
    first_line = f.readline()
    if first_line.startswith('# @package'):
        cfg = yaml.safe_load(f)
    else:
        f.seek(0)
        cfg = yaml.safe_load(f)

# 2. Set the data root path
data_root = Path("data")

# 3. Create transforms
# Training transforms with augmentation
train_transform = Compose([
    ToTensor(),
    LogSpectrogram(),
    SpecAugment(),
    RandomBandRotation()
])

# Validation/test transforms without augmentation
eval_transform = Compose([
    ToTensor(),
    LogSpectrogram()
])

# 4. Create datasets
train_datasets = []
for session_info in cfg['dataset']['train']:
    session_id = session_info['session']
    file_path = data_root / f"{session_id}.hdf5"

    dataset = WindowedEMGDataset(
        hdf5_path=file_path,
        window_length=2000,  # 1 second at 2kHz
        stride=1000,         # 50% overlap
        padding=(200, 200),  # 100ms context on each side
        jitter=True,         # Apply jitter for training
        transform=train_transform
    )
    train_datasets.append(dataset)

val_datasets = []
for session_info in cfg['dataset']['val']:
    session_id = session_info['session']
    file_path = data_root / f"{session_id}.hdf5"

    dataset = WindowedEMGDataset(
        hdf5_path=file_path,
        window_length=2000,
        stride=1000,
        padding=(200, 200),
        jitter=False,  # No jitter for validation
        transform=eval_transform
    )
    val_datasets.append(dataset)

test_datasets = []
for session_info in cfg['dataset']['test']:
    session_id = session_info['session']
    file_path = data_root / f"{session_id}.hdf5"

    dataset = WindowedEMGDataset(
        hdf5_path=file_path,
        window_length=2000,
        stride=1000,
        padding=(200, 200),
        jitter=False,  # No jitter for testing
        transform=eval_transform
    )
    test_datasets.append(dataset)

# 5. Combine datasets and create DataLoaders
train_dataset = ConcatDataset(train_datasets)
val_dataset = ConcatDataset(val_datasets)
test_dataset = ConcatDataset(test_datasets)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=WindowedEMGDataset.collate,
    num_workers=num_workers
)


val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,  # No shuffle for validation
    collate_fn=WindowedEMGDataset.collate,
    num_workers=num_workers
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,  # No shuffle for testing
    collate_fn=WindowedEMGDataset.collate,
    num_workers=num_workers
)

# for debugging
small_train_loader = DataLoader(
    Subset(train_dataset, list(range(100))),
    batch_size=batch_size,
    shuffle=True,
    collate_fn=WindowedEMGDataset.collate,
    num_workers=num_workers
)

small_val_loader = DataLoader(
    Subset(val_dataset, list(range(100))),
    batch_size=batch_size,
    shuffle=False,  # No shuffle for validation
    collate_fn=WindowedEMGDataset.collate,
    num_workers=num_workers
)

print(f"Created train_loader with {len(train_dataset)} windows")
print(f"Created val_loader with {len(val_dataset)} windows")
print(f"Created test_loader with {len(test_dataset)} windows")


Created train_loader with 30713 windows
Created val_loader with 1698 windows
Created test_loader with 2251 windows


In [4]:
import pandas as pd
import h5py
from pathlib import Path
import yaml
import json

# Load the configuration file
config_path = Path("C:/Users/aisha/OneDrive/Documents/cs247_project/config/user/single_user.yaml")
with open(config_path, 'r') as f:
    # Skip the first line which contains "# @package _global_"
    first_line = f.readline()
    if first_line.startswith('# @package'):
        cfg = yaml.safe_load(f)
    else:
        # If the first line doesn't have the package declaration, reset to beginning of file
        f.seek(0)
        cfg = yaml.safe_load(f)

# Set the data root path (adjust this to your actual data location)
data_root = Path("C:/Users/aisha/OneDrive/Documents/cs247_project/data")  # Update this to your actual data path

# Function to extract metadata from a session file
def extract_session_metadata(file_path):
    try:
        with h5py.File(file_path, 'r') as f:
            emg2qwerty_group = f['emg2qwerty']

            # Get metadata attributes
            metadata = {}
            for key, val in emg2qwerty_group.attrs.items():
                if key in ['keystrokes', 'prompts']:
                    try:
                        metadata[key] = json.loads(val)
                    except json.JSONDecodeError:
                        # Fall back to yaml if json fails
                        metadata[key] = yaml.safe_load(val)
                else:
                    metadata[key] = val

            # Calculate basic statistics
            session_stats = {
                'user': metadata.get('user', 'unknown'),
                'session': metadata.get('session_name', 'unknown'),
                'duration_mins': metadata.get('duration_mins', 0),
                'duration_hours': metadata.get('duration_mins', 0) / 60.0,
                'num_keystrokes': len(metadata.get('keystrokes', [])),
                'num_prompts': len(metadata.get('prompts', [])),
                'split': 'unknown'  # Will be set later
            }

            return session_stats
    except Exception as e:
        print(f"Error processing {file_path}: {str(e)}")
        return None

# Collect all sessions from config
all_sessions = []
splits = ['train', 'val', 'test']

for split in splits:
    if split in cfg['dataset']:
        for session_info in cfg['dataset'][split]:
            session_info_dict = {
                'user': session_info['user'],
                'session': session_info['session'],
                'split': split
            }
            all_sessions.append(session_info_dict)

# Process each session to get metadata
session_stats = []

for session_info in all_sessions:
    user_id = session_info['user']
    session_id = session_info['session']
    split = session_info['split']
    file_path = data_root / f"{session_id}.hdf5"

    if file_path.exists():
        stats = extract_session_metadata(file_path)
        if stats:
            stats['split'] = split  # Add which split this session belongs to
            session_stats.append(stats)
    else:
        print(f"Warning: File not found: {file_path}")

# Create a DataFrame from the collected statistics
df = pd.DataFrame(session_stats)
print(df)

# Calculate aggregated statistics per split
split_stats = df.groupby('split').agg({
    'session': 'count',
    'duration_hours': 'sum',
    'num_keystrokes': 'sum',
    'num_prompts': 'sum'
}).rename(columns={'session': 'count'})

# Calculate overall statistics
num_sessions = len(df)
total_duration_hours = df["duration_hours"].sum()
total_keystrokes = df["num_keystrokes"].sum()
total_prompts = df["num_prompts"].sum()

# Print the results
print("Overall statistics:")
print(f"Number of sessions: {num_sessions}")
print(f"Total duration (hours): {total_duration_hours:.2f}")
print(f"Total keystrokes: {total_keystrokes}")
print(f"Total prompts: {total_prompts}")

print("\nStatistics by split:")
print(split_stats)

# Also print individual statistics for train, val, and test sets
for split in splits:
    if split in df['split'].values:
        split_df = df[df['split'] == split]
        print(f"\n{split.capitalize()} set statistics:")
        print(f"Number of sessions: {len(split_df)}")
        print(f"Total duration (hours): {split_df['duration_hours'].sum():.2f}")
        print(f"Total keystrokes: {split_df['num_keystrokes'].sum()}")
        print(f"Total prompts: {split_df['num_prompts'].sum()}")

# Save to CSV files
df.to_csv("all_session_metadata.csv", index=False)
split_stats.to_csv("split_statistics.csv")


        user                                            session  \
0   89335547  2021-06-03-1622765527-keystrokes-dca-study@1-0...   
1   89335547  2021-06-02-1622681518-keystrokes-dca-study@1-0...   
2   89335547  2021-06-04-1622863166-keystrokes-dca-study@1-0...   
3   89335547  2021-07-22-1627003020-keystrokes-dca-study@1-0...   
4   89335547  2021-07-21-1626916256-keystrokes-dca-study@1-0...   
5   89335547  2021-07-22-1627004019-keystrokes-dca-study@1-0...   
6   89335547  2021-06-05-1622885888-keystrokes-dca-study@1-0...   
7   89335547  2021-06-02-1622679967-keystrokes-dca-study@1-0...   
8   89335547  2021-06-03-1622764398-keystrokes-dca-study@1-0...   
9   89335547  2021-07-21-1626917264-keystrokes-dca-study@1-0...   
10  89335547  2021-06-05-1622889105-keystrokes-dca-study@1-0...   
11  89335547  2021-06-03-1622766673-keystrokes-dca-study@1-0...   
12  89335547  2021-06-04-1622861066-keystrokes-dca-study@1-0...   
13  89335547  2021-07-22-1627001995-keystrokes-dca-study@1-0..

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
from emg2qwerty.charset import charset

In [7]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Create position encodings
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x is [seq_len, batch, features]
        return x + self.pe[:x.size(0), :]

In [8]:
class ConvModule(nn.Module):
    def __init__(self, d_model, kernel_size=15, expansion_factor=2, dropout=0.1):
        super().__init__()

        # Pointwise Conv -> Depthwise Conv -> Pointwise Conv
        self.layer_norm = nn.LayerNorm(d_model)
        self.pointwise_conv1 = nn.Conv1d(d_model, d_model * expansion_factor, kernel_size=1)
        self.glu = nn.GLU(dim=1)
        self.depthwise_conv = nn.Conv1d(
            d_model, d_model,
            kernel_size=kernel_size,
            padding=(kernel_size-1)//2,
            groups=d_model
        )
        self.batch_norm = nn.BatchNorm1d(d_model)
        self.swish = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x is [batch, seq_len, d_model]
        residual = x

        # Layer norm
        x = self.layer_norm(x)

        # Transpose to [batch, d_model, seq_len] for Conv1d
        x = x.transpose(1, 2)

        # Pointwise conv and GLU activation
        x = self.pointwise_conv1(x)
        x = self.glu(x)

        # Depthwise conv
        x = self.depthwise_conv(x)
        x = self.batch_norm(x)
        x = self.swish(x)

        # Pointwise conv
        x = self.pointwise_conv2(x)
        x = self.dropout(x)

        # Transpose back to [batch, seq_len, d_model]
        x = x.transpose(1, 2)

        # Residual connection
        return x + residual

In [9]:
class FeedForwardModule(nn.Module):
    def __init__(self, d_model, expansion_factor=4, dropout=0.1):
        super().__init__()

        self.layer_norm = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, d_model * expansion_factor)
        self.swish = nn.SiLU()
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_model * expansion_factor, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        residual = x

        x = self.layer_norm(x)
        x = self.linear1(x)
        x = self.swish(x)
        x = self.dropout1(x)
        x = self.linear2(x)
        x = self.dropout2(x)

        return x + residual

In [10]:
class ConformerBlock(nn.Module):
    def __init__(self, d_model, nhead=8, dropout=0.1, kernel_size=15):
        super().__init__()

        # First feed-forward module (half-step)
        self.feed_forward1 = FeedForwardModule(d_model, dropout=dropout)

        # Multi-headed self-attention
        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=False)
        self.self_attn_dropout = nn.Dropout(dropout)

        # Convolution module
        self.conv_module = ConvModule(d_model, kernel_size=kernel_size, dropout=dropout)

        # Second feed-forward module (half-step)
        self.feed_forward2 = FeedForwardModule(d_model, dropout=dropout)

        # Final layer norm
        self.final_layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, attention_mask=None):
        """
        Args:
            x: input tensor of shape [seq_len, batch, d_model]
            attention_mask: mask for self-attention of shape [batch, seq_len]
        """
        # Apply first FFN (half step)
        x = x + 0.5 * self.feed_forward1(x)

        # Apply self-attention
        residual = x
        x = self.self_attn_layer_norm(x)

        # For PyTorch's MultiheadAttention when batch_first=False:
        # x shape should be [seq_len, batch, d_model]
        # key_padding_mask should be [batch, seq_len]
        attn_output, _ = self.self_attn(
            x, x, x,
            key_padding_mask=attention_mask  # Already in correct shape [batch, seq_len]
        )

        x = residual + self.self_attn_dropout(attn_output)

        # Apply convolution module
        residual = x
        x_t = x.transpose(0, 1)  # [batch, seq_len, d_model]
        x_t = self.conv_module(x_t)
        x = residual + x_t.transpose(0, 1)  # [seq_len, batch, d_model]

        # Apply second FFN (half step)
        x = x + 0.5 * self.feed_forward2(x)

        # Final layer norm
        x = self.final_layer_norm(x)

        return x

In [11]:
class EMGConformer(nn.Module):
    def __init__(self, num_classes=99, d_model=256, nhead=8, num_layers=4, dropout=0.3):
        super().__init__()
        
        # Input EMG signal dimensions
        self.features = 1056  # 2 bands * 16 channels * 33 frequencies
        
        # CNN feature extraction with output channels = cnn_output_dim
        cnn_output_dim = 256  # This should match your last CNN layer's output channels
        
        self.conv_layers = nn.Sequential(
            nn.Conv1d(self.features, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout),
            
            nn.Conv1d(512, cnn_output_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(cnn_output_dim),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(dropout)
        )
        
        # Add projection layer to ensure dimensions match
        self.projection = nn.Linear(cnn_output_dim, d_model) if cnn_output_dim != d_model else nn.Identity()
        
        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model)
        
        # Conformer blocks
        self.conformer_blocks = nn.ModuleList([
            ConformerBlock(d_model, nhead, dropout, kernel_size=15)
            for _ in range(num_layers)
        ])
        
        # Classification head
        self.classifier = nn.Linear(d_model, num_classes)
    
    def forward(self, x, input_lengths=None):
        # x shape: [T, N, B, C, F] - [time, batch, bands, channels, freq]
        T, N, B, C, F = x.shape
        device = x.device
        
        # Flatten features and prepare for Conv1d
        x = x.reshape(T, N, B*C*F).permute(1, 2, 0)  # [N, B*C*F, T]
        
        # Apply CNN layers
        x = self.conv_layers(x)  # [N, cnn_output_dim, T//2]
        
        # Calculate new sequence lengths after CNN pooling 
        if input_lengths is not None:
            new_lengths = torch.div(input_lengths, 2, rounding_mode='floor')
            new_lengths = torch.clamp(new_lengths, min=1)
            # Create padding mask for transformer
            max_len = x.size(2)
            padding_mask = (torch.arange(max_len, device=device).expand(N, max_len) 
                           >= new_lengths.unsqueeze(1))
        else:
            padding_mask = None
        
        # Prepare for conformer: [T//2, N, cnn_output_dim]
        x = x.permute(2, 0, 1)
        
        # Apply projection if needed
        x = self.projection(x)  # [T//2, N, d_model]
        
        # Apply positional encoding
        x = self.pos_encoder(x)
        
        # Apply conformer blocks
        for block in self.conformer_blocks:
            x = block(x, padding_mask)
        
        # Apply classifier
        x = self.classifier(x)
        
        return x

In [12]:
def calculate_cer(predictions, targets):
    """
    Calculate Character Error Rate
    """
    total_edits = 0
    total_length = 0

    for pred, target in zip(predictions, targets):
        # Remove blank tokens and duplicates
        filtered_pred = []
        prev = None
        for p in pred:
            if p != charset().null_class and p != prev:
                filtered_pred.append(p)
            prev = p

        # Calculate edit distance
        edit_distance = levenshtein_distance(filtered_pred, target)

        total_edits += edit_distance
        total_length += len(target)

    # Return CER
    return total_edits / max(1, total_length)

def levenshtein_distance(s1, s2):
    """
    Calculate Levenshtein distance between two sequences
    """
    if len(s1) < len(s2):
        return levenshtein_distance(s2, s1)

    if len(s2) == 0:
        return len(s1)

    previous_row = range(len(s2) + 1)
    for i, c1 in enumerate(s1):
        current_row = [i + 1]
        for j, c2 in enumerate(s2):
            insertions = previous_row[j + 1] + 1
            deletions = current_row[j] + 1
            substitutions = previous_row[j] + (c1 != c2)
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row

    return previous_row[-1]

In [13]:
def decode_predictions(log_probs, input_lengths):
    """Perform CTC greedy decoding on log probabilities"""
    predictions = log_probs.argmax(dim=2)
    decoded = []
    
    for i in range(predictions.size(1)):  # Loop through batch
        pred_len = input_lengths[i].item()
        pred = predictions[:pred_len, i].tolist()
        
        # Remove consecutive duplicates and blanks (CTC decoding)
        result = []
        prev = None
        for p in pred:
            if p != 0:  # Not blank
                if p != prev:  # Not a duplicate
                    result.append(p)
            prev = p
        
        decoded.append(result)
    
    return decoded

In [14]:
def train_model(model, train_loader, val_loader, device, epochs=100, lr=0.001, weight_decay=1e-4):
    """
    Train the Conformer model on EMG data
    """
    model = model.to(device)

    # CTC loss for sequence prediction
    criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

    # Adam optimizer with improved parameters
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999), eps=1e-8)

    # Learning rate scheduler with warmup
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=0.002,  # Lower max learning rate
        steps_per_epoch=len(train_loader),
        epochs=epochs,
        pct_start=0.1,  # 10% warmup
        div_factor=10,  # Initial lr = max_lr/div_factor
        final_div_factor=100  # Final lr = initial_lr/final_div_factor
    )

    # Add these variables at the beginning of the function:
    best_val_cer = float('inf')
    best_epoch = 0
    
    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_cer': [],
        'all_predictions': [],
        'all_targets': []
    }

    # Training loop
    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_batches = 0

        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]')
        for batch_idx, batch in enumerate(progress_bar):
            inputs = batch['inputs'].to(device)
            targets = batch['targets'].to(device)
            input_lengths = batch['input_lengths'].to(device)
            target_lengths = batch['target_lengths'].to(device)

            # Forward pass
            outputs = model(inputs, input_lengths)

            # Compute log probabilities
            log_probs = outputs.log_softmax(2)

            # Adjust input lengths for downsampling in the model
            # We only have one pooling layer with stride 2
            input_lengths = torch.div(input_lengths, 2, rounding_mode='floor')
            input_lengths = torch.clamp(input_lengths, min=1)  # Ensure no zero lengths

            # Compute loss
            loss = criterion(log_probs, targets.T, input_lengths, target_lengths)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

            optimizer.step()
            scheduler.step()

            # Update statistics
            train_loss += loss.item()
            train_batches += 1

            # Update progress bar
            progress_bar.set_postfix({'loss': loss.item()})

        avg_train_loss = train_loss / train_batches
        history['train_loss'].append(avg_train_loss)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_batches = 0
        all_predictions = []
        all_targets = []

        with torch.no_grad():
            progress_bar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} [Val]')
            for batch_idx, batch in enumerate(progress_bar):
                inputs = batch['inputs'].to(device)
                targets = batch['targets'].to(device)
                input_lengths = batch['input_lengths'].to(device)
                target_lengths = batch['target_lengths'].to(device)

                # Forward pass
                outputs = model(inputs, input_lengths)

                # Compute log probabilities
                log_probs = outputs.log_softmax(2)

                # Adjust input lengths for downsampling in the model
                input_lengths = torch.div(input_lengths, 2, rounding_mode='floor')
                input_lengths = torch.clamp(input_lengths, min=1)  # Ensure no zero lengths

                # Compute loss
                loss = criterion(log_probs, targets.T, input_lengths, target_lengths)

                # Update statistics
                val_loss += loss.item()
                val_batches += 1

                # Decode predictions using proper CTC decoding
                predictions = decode_predictions(log_probs.cpu(), input_lengths.cpu())
                targets_np = targets.cpu().numpy()

                # Store predictions and targets for CER calculation
                for i in range(inputs.size(1)):  # Loop through batch
                    target_seq = targets_np[:target_lengths[i].item(), i]
                    all_predictions.append(predictions[i])  # Already processed by decode_predictions
                    all_targets.append(target_seq)

                # Update progress bar
                progress_bar.set_postfix({'loss': loss.item()})

        avg_val_loss = val_loss / val_batches
        history['val_loss'].append(avg_val_loss)

        # Calculate Character Error Rate
        cer = calculate_cer(all_predictions, all_targets)
        history['val_cer'].append(cer)
        history['all_predictions'].append(all_predictions)
        history['all_targets'].append(all_targets)

        # Add checkpoint saving here:
        if cer < best_val_cer:
            best_val_cer = cer
            best_epoch = epoch + 1
            torch.save(model.state_dict(), 'conformer_best.pth')
            print(f"Saved checkpoint at epoch {epoch+1} with CER: {best_val_cer:.4f}")

        # Print epoch summary
        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'  Train Loss: {avg_train_loss:.4f}')
        print(f'  Val Loss: {avg_val_loss:.4f}, Val CER: {cer:.4f}')

    return history, best_epoch, best_val_cer

In [15]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize the EMGConformer model with better parameters
model = EMGConformer(
    num_classes=99,
    d_model=256,        # Increased from 128 for better representation
    nhead=8,            # Increased from 4 for more attention heads
    num_layers=4,       # Increased from 2 for deeper model
    dropout=0.3         # Reduced from 0.5 to prevent overfitting
)

# Print model architecture and parameter count
print(model)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

# Move the model to the device
model = model.to(device)

# Test with a small batch
try:
    print("Testing model with a small batch...")
    # Get a small batch
    test_batch = next(iter(train_loader))
    test_inputs = test_batch['inputs'].to(device)
    test_lengths = test_batch['input_lengths'].to(device)
    
    # Test forward pass
    with torch.no_grad():
        outputs = model(test_inputs, test_lengths)
    
    print(f"Forward pass successful! Output shape: {outputs.shape}")
    
    # Now try the training
    print("Starting training...")
    history, best_epoch, best_val_cer = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        epochs=80,      # Train for more epochs
        lr=0.001,        # Better learning rate
        weight_decay=1e-4  # Slightly reduced weight decay
    )
    print(f"Best validation CER: {best_val_cer:.4f} at epoch {best_epoch}")
except Exception as e:
    print(f"Error encountered: {type(e).__name__}: {e}")
    
    # Try with small batch only for debugging
    if 'small_train_loader' in globals() and 'small_val_loader' in globals():
        print("Trying with small data loaders for debugging...")
        try:
            history = train_model(
                model=model,
                train_loader=small_train_loader,
                val_loader=small_val_loader,
                device=device,
                epochs=2,  # Just a few epochs
                lr=0.001,
                weight_decay=1e-4
            )
        except Exception as e2:
            print(f"Error with small data loaders: {type(e2).__name__}: {e2}")

Using device: cuda
EMGConformer(
  (conv_layers): Sequential(
    (0): Conv1d(1056, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Conv1d(512, 256, kernel_size=(3,), stride=(1,), padding=(1,))
    (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Dropout(p=0.3, inplace=False)
  )
  (projection): Identity()
  (pos_encoder): PositionalEncoding()
  (conformer_blocks): ModuleList(
    (0-3): 4 x ConformerBlock(
      (feed_forward1): FeedForwardModule(
        (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (swish): SiLU()
        (dropout1): Dropout(p=0.3, inplace=False)
        (linear2): Lin

Epoch 1/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 28.46it/s, loss=1.14] 
Epoch 1/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.98it/s, loss=1.4]  


Saved checkpoint at epoch 1 with CER: 0.5106
Epoch 1/80:
  Train Loss: 2.2054
  Val Loss: 0.9225, Val CER: 0.5106


Epoch 2/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 29.09it/s, loss=0.624]
Epoch 2/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.69it/s, loss=0.972] 


Saved checkpoint at epoch 2 with CER: 0.3709
Epoch 2/80:
  Train Loss: 0.9118
  Val Loss: 0.6061, Val CER: 0.3709


Epoch 3/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 28.71it/s, loss=1.35]   
Epoch 3/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.76it/s, loss=1.16]  


Epoch 3/80:
  Train Loss: 0.7622
  Val Loss: 0.6117, Val CER: 0.3822


Epoch 4/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.11it/s, loss=0.423]  
Epoch 4/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.80it/s, loss=1.21]  


Saved checkpoint at epoch 4 with CER: 0.3577
Epoch 4/80:
  Train Loss: 0.7273
  Val Loss: 0.5580, Val CER: 0.3577


Epoch 5/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 28.92it/s, loss=0.692]  
Epoch 5/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.74it/s, loss=0.579] 


Epoch 5/80:
  Train Loss: 0.7142
  Val Loss: 0.5184, Val CER: 0.3591


Epoch 6/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.42it/s, loss=0.822] 
Epoch 6/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.42it/s, loss=1.13]   


Epoch 6/80:
  Train Loss: 0.6616
  Val Loss: 0.6099, Val CER: 0.3610


Epoch 7/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.40it/s, loss=0.39]   
Epoch 7/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.04it/s, loss=0.525] 


Epoch 7/80:
  Train Loss: 0.6237
  Val Loss: 0.5552, Val CER: 0.3578


Epoch 8/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.33it/s, loss=0.593]  
Epoch 8/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.17it/s, loss=0.555] 


Saved checkpoint at epoch 8 with CER: 0.3275
Epoch 8/80:
  Train Loss: 0.5768
  Val Loss: 0.4302, Val CER: 0.3275


Epoch 9/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 28.99it/s, loss=0.362]   
Epoch 9/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.85it/s, loss=0.889] 


Epoch 9/80:
  Train Loss: 0.5261
  Val Loss: 0.4848, Val CER: 0.3328


Epoch 10/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.54it/s, loss=0.886]  
Epoch 10/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.22it/s, loss=0.577]  


Saved checkpoint at epoch 10 with CER: 0.3096
Epoch 10/80:
  Train Loss: 0.5010
  Val Loss: 0.4183, Val CER: 0.3096


Epoch 11/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.29it/s, loss=0.324]   
Epoch 11/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.17it/s, loss=0.922]  


Saved checkpoint at epoch 11 with CER: 0.3040
Epoch 11/80:
  Train Loss: 0.4649
  Val Loss: 0.3967, Val CER: 0.3040


Epoch 12/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 28.34it/s, loss=0.369]   
Epoch 12/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.98it/s, loss=0.387] 


Saved checkpoint at epoch 12 with CER: 0.2864
Epoch 12/80:
  Train Loss: 0.4270
  Val Loss: 0.3429, Val CER: 0.2864


Epoch 13/80 [Train]: 100%|██████████| 960/960 [00:39<00:00, 24.17it/s, loss=0.46]     
Epoch 13/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.52it/s, loss=0.494]  


Epoch 13/80:
  Train Loss: 0.4006
  Val Loss: 0.4266, Val CER: 0.3299


Epoch 14/80 [Train]: 100%|██████████| 960/960 [00:38<00:00, 24.85it/s, loss=0.411]   
Epoch 14/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.36it/s, loss=0.392] 


Epoch 14/80:
  Train Loss: 0.4020
  Val Loss: 0.4204, Val CER: 0.3160


Epoch 15/80 [Train]: 100%|██████████| 960/960 [00:38<00:00, 24.86it/s, loss=0.433]   
Epoch 15/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.38it/s, loss=0.667]   


Epoch 15/80:
  Train Loss: 0.3708
  Val Loss: 0.3177, Val CER: 0.2879


Epoch 16/80 [Train]: 100%|██████████| 960/960 [00:39<00:00, 24.29it/s, loss=0.393]   
Epoch 16/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.16it/s, loss=0.573]  


Epoch 16/80:
  Train Loss: 0.3397
  Val Loss: 0.3063, Val CER: 0.2905


Epoch 17/80 [Train]: 100%|██████████| 960/960 [00:39<00:00, 24.29it/s, loss=0.852]   
Epoch 17/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.30it/s, loss=0.732]  


Saved checkpoint at epoch 17 with CER: 0.2856
Epoch 17/80:
  Train Loss: 0.3168
  Val Loss: 0.2780, Val CER: 0.2856


Epoch 18/80 [Train]: 100%|██████████| 960/960 [00:39<00:00, 24.58it/s, loss=0.281]   
Epoch 18/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.85it/s, loss=0.387]  


Saved checkpoint at epoch 18 with CER: 0.2730
Epoch 18/80:
  Train Loss: 0.3066
  Val Loss: 0.2858, Val CER: 0.2730


Epoch 19/80 [Train]: 100%|██████████| 960/960 [00:38<00:00, 24.69it/s, loss=0.0928]   
Epoch 19/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.44it/s, loss=0.0928] 


Saved checkpoint at epoch 19 with CER: 0.2651
Epoch 19/80:
  Train Loss: 0.2905
  Val Loss: 0.2598, Val CER: 0.2651


Epoch 20/80 [Train]: 100%|██████████| 960/960 [00:40<00:00, 23.89it/s, loss=0.195]    
Epoch 20/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.83it/s, loss=0.724]  


Epoch 20/80:
  Train Loss: 0.2932
  Val Loss: 0.2891, Val CER: 0.2768


Epoch 21/80 [Train]: 100%|██████████| 960/960 [00:40<00:00, 23.75it/s, loss=0.523]    
Epoch 21/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.72it/s, loss=0.568]  


Saved checkpoint at epoch 21 with CER: 0.2626
Epoch 21/80:
  Train Loss: 0.2658
  Val Loss: 0.2769, Val CER: 0.2626


Epoch 22/80 [Train]: 100%|██████████| 960/960 [00:40<00:00, 23.69it/s, loss=0.547]   
Epoch 22/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.31it/s, loss=0.983]  


Epoch 22/80:
  Train Loss: 0.2509
  Val Loss: 0.3676, Val CER: 0.3025


Epoch 23/80 [Train]: 100%|██████████| 960/960 [00:39<00:00, 24.09it/s, loss=0.548]    
Epoch 23/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.67it/s, loss=0.567]  


Epoch 23/80:
  Train Loss: 0.2599
  Val Loss: 0.2906, Val CER: 0.2711


Epoch 24/80 [Train]: 100%|██████████| 960/960 [00:40<00:00, 23.83it/s, loss=0.421]   
Epoch 24/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.50it/s, loss=0.0397] 


Epoch 24/80:
  Train Loss: 0.2240
  Val Loss: 0.2479, Val CER: 0.2629


Epoch 25/80 [Train]: 100%|██████████| 960/960 [00:40<00:00, 23.50it/s, loss=0.305]   
Epoch 25/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.47it/s, loss=0.672]   


Saved checkpoint at epoch 25 with CER: 0.2550
Epoch 25/80:
  Train Loss: 0.2234
  Val Loss: 0.2140, Val CER: 0.2550


Epoch 26/80 [Train]: 100%|██████████| 960/960 [00:40<00:00, 23.61it/s, loss=0.207]   
Epoch 26/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.55it/s, loss=0.356] 


Epoch 26/80:
  Train Loss: 0.2012
  Val Loss: 0.2771, Val CER: 0.2608


Epoch 27/80 [Train]: 100%|██████████| 960/960 [00:39<00:00, 24.01it/s, loss=0.212]   
Epoch 27/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.72it/s, loss=0.66]    


Epoch 27/80:
  Train Loss: 0.1867
  Val Loss: 0.2468, Val CER: 0.2576


Epoch 28/80 [Train]: 100%|██████████| 960/960 [00:39<00:00, 24.37it/s, loss=0.61]     
Epoch 28/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.17it/s, loss=0.484]  


Epoch 28/80:
  Train Loss: 0.1725
  Val Loss: 0.2572, Val CER: 0.2609


Epoch 29/80 [Train]: 100%|██████████| 960/960 [00:40<00:00, 23.81it/s, loss=0.0386]  
Epoch 29/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.97it/s, loss=0.249]  


Saved checkpoint at epoch 29 with CER: 0.2493
Epoch 29/80:
  Train Loss: 0.1660
  Val Loss: 0.2479, Val CER: 0.2493


Epoch 30/80 [Train]: 100%|██████████| 960/960 [00:39<00:00, 24.41it/s, loss=0.457]    
Epoch 30/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.15it/s, loss=0.508]  


Epoch 30/80:
  Train Loss: 0.1616
  Val Loss: 0.3018, Val CER: 0.2802


Epoch 31/80 [Train]: 100%|██████████| 960/960 [00:39<00:00, 24.49it/s, loss=0.15]    
Epoch 31/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.39it/s, loss=0.0587] 


Epoch 31/80:
  Train Loss: 0.1661
  Val Loss: 0.2962, Val CER: 0.2721


Epoch 32/80 [Train]: 100%|██████████| 960/960 [00:39<00:00, 24.36it/s, loss=0.247]    
Epoch 32/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.21it/s, loss=0.222]  


Epoch 32/80:
  Train Loss: 0.1487
  Val Loss: 0.2747, Val CER: 0.2649


Epoch 33/80 [Train]: 100%|██████████| 960/960 [00:41<00:00, 23.23it/s, loss=0.197]   
Epoch 33/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.45it/s, loss=0.495] 


Epoch 33/80:
  Train Loss: 0.1169
  Val Loss: 0.3345, Val CER: 0.2828


Epoch 34/80 [Train]: 100%|██████████| 960/960 [00:40<00:00, 23.67it/s, loss=0.393]    
Epoch 34/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.13it/s, loss=0.509]  


Epoch 34/80:
  Train Loss: 0.1127
  Val Loss: 0.2756, Val CER: 0.2596


Epoch 35/80 [Train]: 100%|██████████| 960/960 [00:40<00:00, 23.63it/s, loss=0.137]    
Epoch 35/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.31it/s, loss=0.603] 


Epoch 35/80:
  Train Loss: 0.0957
  Val Loss: 0.3416, Val CER: 0.2723


Epoch 36/80 [Train]: 100%|██████████| 960/960 [00:40<00:00, 23.75it/s, loss=0.0285]   
Epoch 36/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.48it/s, loss=0.0522]  


Epoch 36/80:
  Train Loss: 0.1057
  Val Loss: 0.2494, Val CER: 0.2545


Epoch 37/80 [Train]: 100%|██████████| 960/960 [00:40<00:00, 23.96it/s, loss=0.203]    
Epoch 37/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.21it/s, loss=0.327]   


Epoch 37/80:
  Train Loss: 0.0795
  Val Loss: 0.2790, Val CER: 0.2650


Epoch 38/80 [Train]: 100%|██████████| 960/960 [00:40<00:00, 23.42it/s, loss=0.0552]   
Epoch 38/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.36it/s, loss=0.109]   


Epoch 38/80:
  Train Loss: 0.0631
  Val Loss: 0.2727, Val CER: 0.2565


Epoch 39/80 [Train]: 100%|██████████| 960/960 [00:40<00:00, 23.45it/s, loss=0.0177]   
Epoch 39/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.24it/s, loss=0.332]  


Epoch 39/80:
  Train Loss: 0.0611
  Val Loss: 0.3360, Val CER: 0.2815


Epoch 40/80 [Train]: 100%|██████████| 960/960 [00:41<00:00, 23.38it/s, loss=0.196]    
Epoch 40/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.37it/s, loss=0.495]  


Epoch 40/80:
  Train Loss: 0.0526
  Val Loss: 0.3599, Val CER: 0.2786


Epoch 41/80 [Train]: 100%|██████████| 960/960 [00:40<00:00, 23.58it/s, loss=-0.0945]  
Epoch 41/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.22it/s, loss=0.126]  


Epoch 41/80:
  Train Loss: 0.0367
  Val Loss: 0.3366, Val CER: 0.2733


Epoch 42/80 [Train]: 100%|██████████| 960/960 [00:41<00:00, 23.41it/s, loss=-0.145]   
Epoch 42/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.34it/s, loss=0.464] 


Saved checkpoint at epoch 42 with CER: 0.2486
Epoch 42/80:
  Train Loss: 0.0181
  Val Loss: 0.2574, Val CER: 0.2486


Epoch 43/80 [Train]: 100%|██████████| 960/960 [00:39<00:00, 24.09it/s, loss=-0.281]   
Epoch 43/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.31it/s, loss=0.162]  


Epoch 43/80:
  Train Loss: 0.0257
  Val Loss: 0.2492, Val CER: 0.2569


Epoch 44/80 [Train]: 100%|██████████| 960/960 [00:35<00:00, 27.18it/s, loss=-0.0184] 
Epoch 44/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.74it/s, loss=0.342]   


Epoch 44/80:
  Train Loss: 0.0196
  Val Loss: 0.2868, Val CER: 0.2637


Epoch 45/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 28.85it/s, loss=-0.227]  
Epoch 45/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.43it/s, loss=0.415]  


Epoch 45/80:
  Train Loss: -0.0002
  Val Loss: 0.3639, Val CER: 0.2869


Epoch 46/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.12it/s, loss=0.0876]   
Epoch 46/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.69it/s, loss=0.0137]


Epoch 46/80:
  Train Loss: -0.0062
  Val Loss: 0.3565, Val CER: 0.2790


Epoch 47/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 28.91it/s, loss=0.194]    
Epoch 47/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.68it/s, loss=0.113]  


Epoch 47/80:
  Train Loss: -0.0199
  Val Loss: 0.3026, Val CER: 0.2555


Epoch 48/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 29.07it/s, loss=-0.137]  
Epoch 48/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.65it/s, loss=0.526]  


Saved checkpoint at epoch 48 with CER: 0.2472
Epoch 48/80:
  Train Loss: -0.0281
  Val Loss: 0.2541, Val CER: 0.2472


Epoch 49/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.15it/s, loss=0.19]     
Epoch 49/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.59it/s, loss=0.188]  


Saved checkpoint at epoch 49 with CER: 0.2399
Epoch 49/80:
  Train Loss: -0.0287
  Val Loss: 0.2417, Val CER: 0.2399


Epoch 50/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.11it/s, loss=-0.126]   
Epoch 50/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.52it/s, loss=0.724]  


Epoch 50/80:
  Train Loss: -0.0421
  Val Loss: 0.3180, Val CER: 0.2514


Epoch 51/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 29.07it/s, loss=0.172]    
Epoch 51/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.57it/s, loss=0.6]   


Epoch 51/80:
  Train Loss: -0.0489
  Val Loss: 0.3614, Val CER: 0.2591


Epoch 52/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.20it/s, loss=0.205]    
Epoch 52/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.79it/s, loss=0.222]  


Epoch 52/80:
  Train Loss: -0.0536
  Val Loss: 0.2685, Val CER: 0.2470


Epoch 53/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.13it/s, loss=0.144]    
Epoch 53/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.70it/s, loss=0.473]  


Saved checkpoint at epoch 53 with CER: 0.2330
Epoch 53/80:
  Train Loss: -0.0559
  Val Loss: 0.2345, Val CER: 0.2330


Epoch 54/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.24it/s, loss=-0.124]   
Epoch 54/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.17it/s, loss=0.461]  


Epoch 54/80:
  Train Loss: -0.0684
  Val Loss: 0.3040, Val CER: 0.2478


Epoch 55/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.45it/s, loss=-0.123]   
Epoch 55/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.22it/s, loss=0.615] 


Epoch 55/80:
  Train Loss: -0.0782
  Val Loss: 0.3588, Val CER: 0.2705


Epoch 56/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.61it/s, loss=-0.136]   
Epoch 56/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 15.99it/s, loss=0.289]  


Epoch 56/80:
  Train Loss: -0.0819
  Val Loss: 0.3089, Val CER: 0.2468


Epoch 57/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.55it/s, loss=-0.28]    
Epoch 57/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.13it/s, loss=0.0276] 


Epoch 57/80:
  Train Loss: -0.0881
  Val Loss: 0.2875, Val CER: 0.2443


Epoch 58/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 29.07it/s, loss=0.27]     
Epoch 58/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.27it/s, loss=0.468]  


Epoch 58/80:
  Train Loss: -0.0963
  Val Loss: 0.2774, Val CER: 0.2360


Epoch 59/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.37it/s, loss=-0.124]  
Epoch 59/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.07it/s, loss=0.334]  


Epoch 59/80:
  Train Loss: -0.0984
  Val Loss: 0.2804, Val CER: 0.2433


Epoch 60/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.38it/s, loss=0.074]    
Epoch 60/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.23it/s, loss=0.139] 


Epoch 60/80:
  Train Loss: -0.1052
  Val Loss: 0.3254, Val CER: 0.2514


Epoch 61/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.32it/s, loss=-0.257]   
Epoch 61/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.46it/s, loss=0.544]  


Epoch 61/80:
  Train Loss: -0.1118
  Val Loss: 0.2960, Val CER: 0.2394


Epoch 62/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 28.89it/s, loss=0.0751]   
Epoch 62/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.62it/s, loss=0.538]   


Epoch 62/80:
  Train Loss: -0.1141
  Val Loss: 0.2867, Val CER: 0.2346


Epoch 63/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 29.09it/s, loss=-0.112]   
Epoch 63/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.54it/s, loss=0.285] 


Epoch 63/80:
  Train Loss: -0.1141
  Val Loss: 0.3237, Val CER: 0.2460


Epoch 64/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.34it/s, loss=-0.16]    
Epoch 64/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.31it/s, loss=0.401] 


Epoch 64/80:
  Train Loss: -0.1224
  Val Loss: 0.3426, Val CER: 0.2490


Epoch 65/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.34it/s, loss=-0.413]   
Epoch 65/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.87it/s, loss=0.525] 


Epoch 65/80:
  Train Loss: -0.1209
  Val Loss: 0.3377, Val CER: 0.2493


Epoch 66/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.10it/s, loss=-0.0381]  
Epoch 66/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.52it/s, loss=0.605]  


Epoch 66/80:
  Train Loss: -0.1267
  Val Loss: 0.3335, Val CER: 0.2436


Epoch 67/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.34it/s, loss=0.0604]   
Epoch 67/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.36it/s, loss=0.314]  


Epoch 67/80:
  Train Loss: -0.1325
  Val Loss: 0.3191, Val CER: 0.2410


Epoch 68/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.24it/s, loss=-0.0355]  
Epoch 68/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.52it/s, loss=0.394] 


Epoch 68/80:
  Train Loss: -0.1338
  Val Loss: 0.3422, Val CER: 0.2480


Epoch 69/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 29.06it/s, loss=-0.133]  
Epoch 69/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.67it/s, loss=0.656]  


Epoch 69/80:
  Train Loss: -0.1340
  Val Loss: 0.3087, Val CER: 0.2382


Epoch 70/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 29.08it/s, loss=-0.103]   
Epoch 70/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.67it/s, loss=0.661] 


Epoch 70/80:
  Train Loss: -0.1396
  Val Loss: 0.3529, Val CER: 0.2479


Epoch 71/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 28.88it/s, loss=-0.197]   
Epoch 71/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.12it/s, loss=0.392]  


Epoch 71/80:
  Train Loss: -0.1450
  Val Loss: 0.3062, Val CER: 0.2405


Epoch 72/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.19it/s, loss=-0.0383]  
Epoch 72/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.40it/s, loss=0.469]  


Epoch 72/80:
  Train Loss: -0.1408
  Val Loss: 0.3015, Val CER: 0.2362


Epoch 73/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.32it/s, loss=0.228]    
Epoch 73/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.14it/s, loss=0.35]   


Epoch 73/80:
  Train Loss: -0.1430
  Val Loss: 0.3227, Val CER: 0.2416


Epoch 74/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.48it/s, loss=-0.0179]  
Epoch 74/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.52it/s, loss=0.335] 


Epoch 74/80:
  Train Loss: -0.1386
  Val Loss: 0.3196, Val CER: 0.2419


Epoch 75/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 28.70it/s, loss=-0.034]   
Epoch 75/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.47it/s, loss=0.385]  


Epoch 75/80:
  Train Loss: -0.1491
  Val Loss: 0.3093, Val CER: 0.2383


Epoch 76/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.34it/s, loss=-0.232]   
Epoch 76/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.27it/s, loss=0.445]  


Epoch 76/80:
  Train Loss: -0.1487
  Val Loss: 0.3119, Val CER: 0.2364


Epoch 77/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.13it/s, loss=-0.146]  
Epoch 77/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.05it/s, loss=0.491]  


Epoch 77/80:
  Train Loss: -0.1469
  Val Loss: 0.3319, Val CER: 0.2444


Epoch 78/80 [Train]: 100%|██████████| 960/960 [00:33<00:00, 29.05it/s, loss=-0.057]  
Epoch 78/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.12it/s, loss=0.359]   


Epoch 78/80:
  Train Loss: -0.1488
  Val Loss: 0.3195, Val CER: 0.2391


Epoch 79/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.43it/s, loss=0.0622]   
Epoch 79/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.18it/s, loss=0.468]  


Epoch 79/80:
  Train Loss: -0.1471
  Val Loss: 0.3304, Val CER: 0.2420


Epoch 80/80 [Train]: 100%|██████████| 960/960 [00:32<00:00, 29.38it/s, loss=0.322]   
Epoch 80/80 [Val]: 100%|██████████| 54/54 [00:03<00:00, 16.27it/s, loss=0.493] 

Epoch 80/80:
  Train Loss: -0.1494
  Val Loss: 0.3460, Val CER: 0.2472
Best validation CER: 0.2330 at epoch 53





In [20]:
all_predictions_lst = history['all_predictions']
all_targets_lst = history['all_targets']

# Make sure this epoch exists in your history
epoch_id = min(60-1, len(all_predictions_lst)-1)  # Get the last epoch by default
print(f"Using results from epoch {epoch_id}")

# Output how many epochs of predictions we have
print(f"Number of epochs with predictions: {len(all_predictions_lst)}")

# Make sure we have predictions for this epoch
if epoch_id < len(all_predictions_lst) and all_predictions_lst[epoch_id]:
    # Get number of samples and choose one that exists
    num_samples = len(all_predictions_lst[epoch_id])
    it = min(12, num_samples-1)  # Default to sample 12 or the last one if fewer
    
    print(f"Number of samples in epoch {epoch_id}: {num_samples}")
    print(f"Sample {it} prediction: {all_predictions_lst[epoch_id][it]}")
    pred = all_predictions_lst[epoch_id][it]
    
    print(f"Number of targets in epoch {epoch_id}: {len(all_targets_lst[epoch_id])}")
    print(f"Sample {it} target: {all_targets_lst[epoch_id][it]}")
    target = all_targets_lst[epoch_id][it]
    
    # Convert indices to characters
    char_set = charset()
    
    # Filter out null class tokens and duplicates from predictions
    filtered_pred = []
    prev = None
    for p in pred:
        if p != char_set.null_class and p != prev:
            filtered_pred.append(p)
        prev = p
    
    pred_chars = char_set.labels_to_str(filtered_pred)
    target_chars = char_set.labels_to_str(target)
    print(f'prediction: {pred_chars}')
    print(f'target    : {target_chars}')
    
    # Show additional examples
    print("\nAdditional examples:")
    for i in range(5):
        sample_idx = min(i, num_samples-1)
        pred = all_predictions_lst[epoch_id][sample_idx]
        target = all_targets_lst[epoch_id][sample_idx]
        
        # Filter out null class tokens and duplicates
        filtered_pred = []
        prev = None
        for p in pred:
            if p != char_set.null_class and p != prev:
                filtered_pred.append(p)
            prev = p
            
        pred_chars = char_set.labels_to_str(filtered_pred)
        target_chars = char_set.labels_to_str(target)
        print(f'Example {i+1}:')
        print(f'  prediction: {pred_chars}')
        print(f'  target    : {target_chars}')
else:
    print(f"No predictions available for epoch {epoch_id}")

Using results from epoch 59
Number of epochs with predictions: 60
Number of samples in epoch 59: 1698
Sample 12 prediction: [1, 17, 14, 22, 7, 96, 5, 14, 96]
Number of targets in epoch 59: 1698
Sample 12 target: [ 1 17 14 22 13 96  5 14 23 96]
prediction: browh fo 
target    : brown fox 

Additional examples:
Example 1:
  prediction: 
  target    : 
Example 2:
  prediction: 
  target    : 
Example 3:
  prediction: 
  target    : 
Example 4:
  prediction: 
  target    : 
Example 5:
  prediction: o
  target    : 


In [19]:
def evaluate_on_test(model, test_loader, device):
    """Evaluate the model on test data and compute CER"""
    model.eval()
    all_predictions = []
    all_targets = []
    
    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc="Test Evaluation")
        for batch in progress_bar:
            inputs = batch['inputs'].to(device)
            targets = batch['targets'].to(device)
            input_lengths = batch['input_lengths'].to(device)
            target_lengths = batch['target_lengths'].to(device)
            
            # Forward pass
            outputs = model(inputs, input_lengths)
            
            # Compute log probabilities
            log_probs = outputs.log_softmax(2)
            
            # Adjust input lengths for downsampling in the model
            input_lengths = torch.div(input_lengths, 2, rounding_mode='floor')
            input_lengths = torch.clamp(input_lengths, min=1)
            
            # Decode predictions
            predictions = decode_predictions(log_probs.cpu(), input_lengths.cpu())
            targets_np = targets.cpu().numpy()
            
            # Store predictions and targets for CER calculation
            for i in range(inputs.size(1)):
                target_seq = targets_np[:target_lengths[i].item(), i]
                all_predictions.append(predictions[i])
                all_targets.append(target_seq)
    
    # Calculate Character Error Rate
    cer = calculate_cer(all_predictions, all_targets)
    
    return cer, all_predictions, all_targets

# Load the best model
print(f"Loading best model from epoch {best_epoch}...")
best_model = EMGConformer(
    num_classes=99,
    d_model=256,
    nhead=8,
    num_layers=4,
    dropout=0.3
).to(device)
best_model.load_state_dict(torch.load('conformer_best.pth'))

# Evaluate on test data
test_cer, test_predictions, test_targets = evaluate_on_test(best_model, test_loader, device)
print(f"Test CER: {test_cer:.4f}")

# Optional: Display a few examples
char_set = charset()
print("\nTest Examples:")
for i in range(min(5, len(test_predictions))):
    pred = test_predictions[i]
    target = test_targets[i]
    
    # Filter out null class tokens and duplicates
    filtered_pred = []
    prev = None
    for p in pred:
        if p != char_set.null_class and p != prev:
            filtered_pred.append(p)
        prev = p
    
    pred_chars = char_set.labels_to_str(filtered_pred)
    target_chars = char_set.labels_to_str(target)
    print(f'Example {i+1}:')
    print(f'  prediction: {pred_chars}')
    print(f'  target    : {target_chars}')

Loading best model from epoch 43...


Test Evaluation: 100%|██████████| 71/71 [00:04<00:00, 17.69it/s]

Test CER: 0.3058

Test Examples:
Example 1:
  prediction: 
  target    : 
Example 2:
  prediction: 
  target    : 
Example 3:
  prediction: '
  target    : 
Example 4:
  prediction: 
  target    : 
Example 5:
  prediction: 
  target    : 





In [None]:
'''class SimpleEMGConformer(nn.Module):
    def __init__(self, num_classes=99, d_model=128, nhead=4, num_layers=2, dropout=0.5):
        super().__init__()
        
        # Input EMG signal dimensions
        self.features = 1056  # 2 bands * 16 channels * 33 frequencies
        
        # Enhanced CNN feature extraction
        self.conv_layers = nn.Sequential(
            nn.Conv1d(self.features, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            
            nn.Conv1d(256, d_model, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(d_model),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),  # Only one pooling layer
            nn.Dropout(dropout)
        )
        
        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model)
        
        # Simple Transformer encoder layers instead of full Conformer blocks
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model*2,
            dropout=dropout,
            batch_first=False
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(d_model, num_classes)
        )
    
    def forward(self, x, input_lengths=None):
        # x shape: [T, N, B, C, F] - [time, batch, bands, channels, freq]
        T, N, B, C, F = x.shape
        device = x.device
        
        # Flatten features and prepare for Conv1d
        x = x.reshape(T, N, B*C*F).permute(1, 2, 0)  # [N, B*C*F, T]
        
        # Apply CNN layers
        x = self.conv_layers(x)  # [N, d_model, T//2]
        
        # Calculate new sequence lengths after CNN pooling (only one pooling layer)
        if input_lengths is not None:
            new_lengths = torch.div(input_lengths, 2, rounding_mode='floor')
            new_lengths = torch.clamp(new_lengths, min=1)
            # Create padding mask for transformer
            max_len = x.size(2)
            padding_mask = (torch.arange(max_len, device=device).expand(N, max_len) 
                           >= new_lengths.unsqueeze(1))
        else:
            padding_mask = None
        
        # Prepare for transformer: [T//2, N, d_model]
        x = x.permute(2, 0, 1)
        
        # Apply positional encoding
        x = self.pos_encoder(x)
        
        # Apply transformer encoder
        x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
        
        # Apply classifier
        time_steps, batch_size, hidden_dim = x.size()
        x = x.reshape(-1, hidden_dim)
        x = self.classifier(x)
        x = x.view(time_steps, batch_size, -1)
        
        return x'''