<a href="https://colab.research.google.com/github/petervinhchau/public/blob/main/Copy_of_Untitled13.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required dependencies before importing torch: these packages can
# pull in a different torch build, and the report below should describe the
# version this kernel will actually load. %pip installs into the running
# kernel's environment (unlike !pip, which uses whatever `pip` is on PATH).
%pip install -q pytorch-lightning torchaudio transformers tensorboard

# Check PyTorch and CUDA environment
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


In [None]:
from google.colab import drive

# Paths used by the setup commands below.
DRIVE_ROOT = "/content/drive"
PROJ_TMP = f"{DRIVE_ROOT}/MyDrive/colab/ece247/proj/tmp"
PROJECT_ZIP = f"{PROJ_TMP}/emg2qwerty-main.zip"
PROJECT_DIR = "/content/emg2qwerty-main"
DRIVE_DATA_DIR = f"{PROJ_TMP}/data"

drive.mount(DRIVE_ROOT, force_remount=True)

# Unpack a fresh copy of the project code (-o: overwrite without prompting, -q: quiet).
!unzip -oq {PROJECT_ZIP} -d /content/

# Work from the project root so relative paths in later cells resolve there.
%cd {PROJECT_DIR}
!ls -l {PROJECT_DIR}/
!ls -l {DRIVE_DATA_DIR}/

In [None]:
%%writefile /content/emg2qwerty-main/models/lstm_encoder.py
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

class LSTMEncoder(nn.Module):
    """
    Bidirectional LSTM encoder for EMG signals.

    Processes input sequences of shape (batch, time, features) and returns
    output sequences of shape (batch, time, output_dim), where output_dim is
    hidden_size * 2 (if bidirectional) to match the expected classifier input.

    Args:
        input_dim: Number of input features per time step.
        hidden_size: LSTM hidden units per direction.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Also run an LSTM over the reversed sequence.
        dropout: Dropout between stacked layers; forced to 0 for a single
            layer, where nn.LSTM would otherwise warn.
    """
    def __init__(self, input_dim, hidden_size=256, num_layers=2, bidirectional=True, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.input_dim = input_dim

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )
        # Output dimension: hidden_size*2 if bidirectional; else hidden_size.
        self.output_dim = hidden_size * 2 if bidirectional else hidden_size

    def forward(self, x, lengths=None):
        """
        Forward pass through the LSTM encoder.

        Args:
            x: Tensor of shape (batch_size, time_steps, input_dim)
            lengths: Optional valid length of each sequence, as a tensor on
                any device or a sequence of ints.

        Returns:
            lstm_out: Tensor of shape (batch_size, time_steps, output_dim).
                When lengths are given, positions past a sequence's length
                are zero.
        """
        if lengths is None:
            lstm_out, _ = self.lstm(x)
            return lstm_out

        # pack_padded_sequence requires lengths as a 1-D CPU int64 tensor.
        lengths = torch.as_tensor(lengths).to(device="cpu", dtype=torch.int64)
        x_packed = rnn_utils.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(x_packed)
        # total_length keeps the time axis equal to x's even when every
        # sequence in the batch is shorter than the padded length.
        lstm_out, _ = rnn_utils.pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))
        return lstm_out


In [None]:
%%bash
# Ensure the models package exists and exports LSTMEncoder.
# All paths are absolute: %%bash runs in the kernel's working directory
# (/content/emg2qwerty-main after the setup cell), so relative paths such as
# emg2qwerty-main/models would point at a nested, wrong directory.
MODELS_DIR="/content/emg2qwerty-main/models"
INIT_FILE="$MODELS_DIR/__init__.py"

if [ ! -d "$MODELS_DIR" ]; then
    mkdir -p "$MODELS_DIR"
    echo "Created directory: $MODELS_DIR"
fi

# Ensure __init__.py exists in the models directory
if [ ! -f "$INIT_FILE" ]; then
    touch "$INIT_FILE"
    echo "Created file: $INIT_FILE"
fi

# Append import for LSTMEncoder if not already present
grep -q "LSTMEncoder" "$INIT_FILE" || echo "from .lstm_encoder import LSTMEncoder" >> "$INIT_FILE"


In [None]:
%%writefile /content/emg2qwerty-main/test_encoders.py
import torch
import argparse
from models.lstm_encoder import LSTMEncoder

# TDSEncoder is optional: it stays None (LSTM-only testing) unless the module imports.
TDSEncoder = None
try:
    from models.tds_encoder import TDSEncoder
except ImportError:
    print("WARNING: Could not import TDSEncoder. Only LSTM encoder will be tested.")

def test_encoders(feature_dim=768, tds_params=None, batch_size=8, time_steps=100):
    """
    Smoke-test encoder output shapes on random input.

    Runs LSTMEncoder on (batch, time, feature) input. When TDSEncoder is
    importable and ``tds_params`` is not None, it also runs TDSEncoder on the
    same data in (batch, feature, time) layout. It then checks that both
    outputs agree on the time axis (dim 1) and the feature axis (last dim).
    This assumes both encoders emit (batch, time, dim).

    Args:
        feature_dim: Input features per time step.
        tds_params: Keyword arguments for TDSEncoder; None skips the TDS check.
        batch_size: Batch size of the random input.
        time_steps: Sequence length of the random input.

    Returns:
        True if the encoders are compatible or TDS was not tested, else False.
    """
    # Dummy input for LSTM encoder: (B, T, F)
    x_lstm = torch.randn(batch_size, time_steps, feature_dim)
    # Dummy input for TDS encoder: (B, F, T)
    x_tds = x_lstm.transpose(1, 2)

    lstm = LSTMEncoder(input_dim=feature_dim, hidden_size=256, num_layers=2, bidirectional=True)
    with torch.no_grad():  # shape check only; no autograd graph needed
        out_lstm = lstm(x_lstm)
    print(f"LSTM encoder output shape: {out_lstm.shape}")

    if TDSEncoder is None or tds_params is None:
        return True

    tds = TDSEncoder(**tds_params)
    with torch.no_grad():
        out_tds = tds(x_tds)
    print(f"TDS encoder output shape: {out_tds.shape}")
    output_dims_match = out_tds.shape[-1] == out_lstm.shape[-1]
    # Time is axis 1 of a (B, T, D) output; axis 2 is the feature axis already
    # compared above.
    time_dims_match = out_tds.shape[1] == out_lstm.shape[1]
    compatible = output_dims_match and time_dims_match
    print(f"Output dims match: {output_dims_match}")
    print(f"Time dims match: {time_dims_match}")
    print(f"Encoders are compatible: {compatible}")
    return compatible

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test encoder compatibility")
    parser.add_argument("--feature_dim", type=int, default=768, help="Input feature dimension")
    args = parser.parse_args()
    # Update tds_params if your TDS encoder parameters are available
    tds_params = {}
    compatible = test_encoders(feature_dim=args.feature_dim, tds_params=tds_params)
    # A non-zero exit status lets `!python test_encoders.py` (or CI) detect a mismatch.
    raise SystemExit(0 if compatible else 1)


In [None]:
%%writefile /content/emg2qwerty-main/models/emg_model.py
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F

# Import encoders
from .tds_encoder import TDSEncoder
from .lstm_encoder import LSTMEncoder

class EMGLightningModel(pl.LightningModule):
    """
    CTC sequence model: an encoder (TDS or bidirectional LSTM) followed by a
    per-time-step linear classifier.

    Batches are ``(x, lengths, targets, target_lengths)``; class index 0 is the
    CTC blank. ``x`` may be (B, T, F) or (B, F, T). ``forward`` uses
    ``feature_dim`` to bring it into the layout the chosen encoder expects:
    (B, T, F) for LSTM, (B, F, T) for TDS.
    """

    def __init__(
        self,
        encoder_type="TDS",          # "TDS" or "LSTM"
        feature_dim=768,             # Input feature dimension
        num_classes=30,              # Number of output classes (including blank)
        learning_rate=1e-3,
        tds_params=None,             # TDS encoder parameters (if using TDS)
        lstm_hidden_size=256,
        lstm_num_layers=2,
        lstm_dropout=0.2,
        **kwargs
    ):
        super().__init__()
        # Store constructor args so load_from_checkpoint can rebuild the model.
        self.save_hyperparameters()
        self.encoder_type = encoder_type.upper()
        self.feature_dim = feature_dim

        if self.encoder_type == "TDS":
            self.encoder = TDSEncoder(**(tds_params if tds_params is not None else {}))
        elif self.encoder_type == "LSTM":
            self.encoder = LSTMEncoder(
                input_dim=feature_dim,
                hidden_size=lstm_hidden_size,
                num_layers=lstm_num_layers,
                bidirectional=True,
                dropout=lstm_dropout
            )
        else:
            raise ValueError(f"Unknown encoder type: {encoder_type}")

        self.classifier = nn.Linear(self.encoder.output_dim, num_classes)
        self.ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
        self.learning_rate = learning_rate

    def _to_encoder_layout(self, x):
        """
        Return ``x`` in the layout the encoder expects.

        The feature axis is identified by matching ``feature_dim``. This avoids
        comparing T against F, which misfires whenever T < F (e.g. 100 steps
        of 768 features).
        """
        if x.dim() != 3:
            return x
        feature_last = x.shape[2] == self.feature_dim
        feature_first = x.shape[1] == self.feature_dim
        if self.encoder_type == "LSTM":
            # LSTMEncoder wants (B, T, F); only a clearly (B, F, T) input is flipped.
            if feature_first and not feature_last:
                x = x.transpose(1, 2)
        elif self.encoder_type == "TDS":
            # TDS wants (B, F, T).
            if feature_last and not feature_first:
                x = x.transpose(1, 2)
            elif not (feature_first or feature_last) and x.shape[1] > x.shape[2]:
                # feature_dim does not identify the axis (e.g. tds_params use a
                # different channel count): assume the longer axis is time.
                x = x.transpose(1, 2)
        return x

    def forward(self, x, lengths=None):
        """
        Compute per-time-step class logits.

        Args:
            x: Input features, (B, T, F) or (B, F, T).
            lengths: Optional valid length per sequence, forwarded to the encoder.

        Returns:
            Logits of shape (B, T', num_classes), T' being the encoder's output length.
        """
        features = self.encoder(self._to_encoder_layout(x), lengths)
        return self.classifier(features)

    def _ctc_step(self, batch):
        """Run one ``(x, lengths, targets, target_lengths)`` batch; return (log_probs, loss)."""
        x, lengths, targets, target_lengths = batch
        logits = self(x, lengths)
        # nn.CTCLoss expects (T, B, C) log-probabilities.
        log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)
        # NOTE(review): input lengths double as CTC output lengths, which
        # assumes the encoder does not downsample time — confirm for TDS.
        loss = self.ctc_loss(log_probs, targets, lengths, target_lengths)
        return log_probs, loss

    def training_step(self, batch, batch_idx):
        _, loss = self._ctc_step(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        _, loss = self._ctc_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        log_probs, loss = self._ctc_step(batch)
        # Logged so trainer.test() reports at least one metric; decoding and
        # CER/WER are not implemented here.
        self.log("test_loss", loss)
        return {"test_log_probs": log_probs}

    def configure_optimizers(self):
        """Adam (weight decay 1e-5) with per-epoch cosine annealing over the run."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=1e-5)
        # Lightning allows max_epochs=-1 (unbounded); keep T_max positive.
        t_max = max(1, self.trainer.max_epochs or 1)
        scheduler = {
            "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max),
            "interval": "epoch",
            "frequency": 1,
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler}


In [None]:
%%writefile /content/emg2qwerty-main/train.py
import os
import argparse
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from models.emg_model import EMGLightningModel
from data.datamodule import EMGDataModule

def main():
    """Parse CLI arguments, train an EMGLightningModel, then test its best checkpoint."""
    parser = argparse.ArgumentParser(description="Train EMG2QWERTY model")
    parser.add_argument('--encoder', type=str, default='TDS', choices=['TDS', 'LSTM'],
                        help='Encoder type: TDS (original) or LSTM (new)')
    parser.add_argument('--feature_dim', type=int, default=768, help='Input feature dimension')
    parser.add_argument('--lstm_hidden_size', type=int, default=256, help='LSTM hidden size')
    parser.add_argument('--lstm_num_layers', type=int, default=2, help='Number of LSTM layers')
    parser.add_argument('--lstm_dropout', type=float, default=0.2, help='Dropout rate for LSTM')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-3, help='Initial learning rate')
    parser.add_argument('--gradient_clip_val', type=float, default=5.0, help='Gradient clipping value')
    parser.add_argument('--data_dir', type=str, default='data', help='Dataset directory')
    parser.add_argument('--name', type=str, default='model', help='Name for saving model and logs')

    args = parser.parse_args()

    data_module = EMGDataModule(data_dir=args.data_dir, batch_size=args.batch_size)
    model = EMGLightningModel(
        encoder_type=args.encoder,
        feature_dim=args.feature_dim,
        lstm_hidden_size=args.lstm_hidden_size,
        lstm_num_layers=args.lstm_num_layers,
        lstm_dropout=args.lstm_dropout,
        learning_rate=args.learning_rate
    )

    logger = TensorBoardLogger(save_dir="lightning_logs", name=f"{args.encoder.lower()}_{args.name}")
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        filename="best_model-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min"
    )
    early_stop_callback = EarlyStopping(monitor="val_loss", patience=10, mode="min")

    trainer = pl.Trainer(
        max_epochs=args.epochs,
        gradient_clip_val=args.gradient_clip_val,
        logger=logger,
        callbacks=[checkpoint_callback, early_stop_callback],
        accelerator="auto",
        devices="auto"
    )

    # Pass the datamodule by keyword rather than in the dataloader slot.
    trainer.fit(model, datamodule=data_module)
    # Evaluate the best checkpoint (lowest val_loss), not the last-epoch
    # weights; fall back to the in-memory model if no checkpoint was saved.
    best_ckpt = "best" if checkpoint_callback.best_model_path else None
    trainer.test(model, datamodule=data_module, ckpt_path=best_ckpt)

    print(f"Best model checkpoint: {checkpoint_callback.best_model_path}")

if __name__ == "__main__":
    main()


In [None]:
%%writefile /content/emg2qwerty-main/compare_models.py
import os
import torch
import pytorch_lightning as pl
import argparse
import time
import json
from models.emg_model import EMGLightningModel
from data.datamodule import EMGDataModule

def format_metric(value, spec=".2f"):
    """
    Format a metric for display.

    Args:
        value: Number to format, or None when the metric was not reported.
        spec: ``format`` spec, e.g. ".2f" (default), ".1f", or "," for counts.

    Returns:
        The formatted string, or "N/A" when ``value`` is None.
    """
    return "N/A" if value is None else format(value, spec)


def _relative_percent(new, base):
    """Percent change of ``new`` relative to ``base``; None if missing or ``base`` is 0."""
    if new is None or not base:
        return None
    return (new / base - 1) * 100


def _evaluate(label, checkpoint, test_dataloader):
    """
    Load ``checkpoint``, count its parameters, and time one test pass.

    Returns:
        Dict with keys "parameters", "CER", "WER", "inference_time". CER/WER
        are None when the model does not log "test_cer"/"test_wer".
    """
    print(f"\nEvaluating {label} model...")
    model = EMGLightningModel.load_from_checkpoint(checkpoint)
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{label} model parameters: {n_params:,}")

    start_time = time.perf_counter()
    trainer = pl.Trainer(accelerator="auto", devices=1)
    outputs = trainer.test(model, test_dataloader)
    elapsed = time.perf_counter() - start_time
    # trainer.test returns one metrics dict per dataloader; tolerate an empty list.
    metrics = outputs[0] if outputs else {}

    result = {
        "parameters": n_params,
        "CER": metrics.get("test_cer"),
        "WER": metrics.get("test_wer"),
        "inference_time": elapsed,
    }
    print(f"{label} - CER: {format_metric(result['CER'])}%, "
          f"WER: {format_metric(result['WER'])}%, Time: {elapsed:.2f}s")
    return result


def main():
    """Evaluate TDS and LSTM checkpoints on the test split and save a JSON comparison."""
    parser = argparse.ArgumentParser(description="Compare TDS and LSTM models")
    parser.add_argument('--tds_checkpoint', type=str, required=True, help='Path to TDS model checkpoint')
    parser.add_argument('--lstm_checkpoint', type=str, required=True, help='Path to LSTM model checkpoint')
    parser.add_argument('--data_dir', type=str, default='data', help='Test dataset directory')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for evaluation')
    args = parser.parse_args()

    for ckpt in [args.tds_checkpoint, args.lstm_checkpoint]:
        if not os.path.exists(ckpt):
            raise FileNotFoundError(f"Checkpoint not found: {ckpt}")

    data_module = EMGDataModule(data_dir=args.data_dir, batch_size=args.batch_size)
    data_module.setup(stage='test')
    test_dataloader = data_module.test_dataloader()

    tds = _evaluate("TDS", args.tds_checkpoint, test_dataloader)
    lstm = _evaluate("LSTM", args.lstm_checkpoint, test_dataloader)
    comparison = {}
    results = {"TDS": tds, "LSTM": lstm, "comparison": comparison}

    for metric in ("CER", "WER"):
        if tds[metric] is not None and lstm[metric] is not None:
            comparison[f"{metric}_diff"] = lstm[metric] - tds[metric]
            # Skipped when the TDS value is 0 (relative change undefined).
            relative = _relative_percent(lstm[metric], tds[metric])
            if relative is not None:
                comparison[f"{metric}_relative"] = relative
    if tds["inference_time"] and lstm["inference_time"]:
        comparison["speed_ratio"] = tds["inference_time"] / lstm["inference_time"]
    if tds["parameters"] and lstm["parameters"]:
        comparison["param_diff"] = lstm["parameters"] - tds["parameters"]
        comparison["param_ratio"] = lstm["parameters"] / tds["parameters"]

    # format_metric prints "N/A" for missing values instead of raising on None.
    print("\n=== Model Comparison Results ===")
    print(f"Character Error Rate: TDS = {format_metric(tds['CER'])}%, LSTM = {format_metric(lstm['CER'])}%")
    print(f"CER Difference: {format_metric(comparison.get('CER_diff'))}% "
          f"({format_metric(comparison.get('CER_relative'), '.1f')}% relative)")
    print(f"Word Error Rate: TDS = {format_metric(tds['WER'])}%, LSTM = {format_metric(lstm['WER'])}%")
    print(f"WER Difference: {format_metric(comparison.get('WER_diff'))}% "
          f"({format_metric(comparison.get('WER_relative'), '.1f')}% relative)")
    print(f"Model Size: TDS = {format_metric(tds['parameters'], ',')}, "
          f"LSTM = {format_metric(lstm['parameters'], ',')}")
    print(f"Parameter Ratio: LSTM is {format_metric(comparison.get('param_ratio'))}x the size of TDS")
    print(f"Speed Ratio (TDS inference / LSTM inference): {format_metric(comparison.get('speed_ratio'))}")

    with open("model_comparison_results.json", "w") as f:
        json.dump(results, f, indent=4)
    print("\nComparison results saved to model_comparison_results.json")

if __name__ == "__main__":
    main()


In [None]:
# Smoke-test encoder output shapes (768 is also the script's --feature_dim default).
!python test_encoders.py --feature_dim 768


In [None]:
%%writefile /content/drive/MyDrive/colab/ece247/proj/tmp/data/datamodule.py
from torch.utils.data import DataLoader, Dataset
import torch

class DummyEMGDataset(Dataset):
    """
    Synthetic EMG dataset: random input features paired with random CTC targets.

    Each item is ``(x, input_length, y, target_length)``:
      * ``x`` -- float tensor of shape (time_steps, feature_dim)
      * ``input_length`` -- int equal to ``time_steps`` (full-length sequence)
      * ``y`` -- int64 tensor of ``target_length`` labels in [1, num_classes);
        0 is excluded because it is the CTC blank
      * ``target_length`` -- int

    Args:
        mode: Split name ("train", "val", "test"); stored but not used yet.
        num_samples: Number of items in the dataset.
        time_steps: Sequence length of each input.
        feature_dim: Features per time step.
        num_classes: Label vocabulary size, including the blank at index 0.
        target_length: Length of each target sequence.
        seed: If not None, item ``idx`` is generated from ``seed + idx`` so it
            is identical on every access; if None, fresh random data is drawn
            each time.
    """
    def __init__(self, mode="train", num_samples=100, time_steps=100, feature_dim=768,
                 num_classes=30, target_length=10, seed=None):
        if num_classes < 2:
            raise ValueError("num_classes must be >= 2 (index 0 is the CTC blank)")
        self.mode = mode
        self.samples = list(range(num_samples))
        self.time_steps = time_steps
        self.feature_dim = feature_dim
        self.num_classes = num_classes
        self.target_length = target_length
        self.seed = seed

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        idx = int(idx)
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            # Also lets plain iteration (`for item in dataset`) terminate.
            raise IndexError(f"index {idx} out of range for dataset of size {size}")
        generator = None
        if self.seed is not None:
            generator = torch.Generator().manual_seed(self.seed + idx)
        x = torch.randn(self.time_steps, self.feature_dim, generator=generator)
        y = torch.randint(1, self.num_classes, (self.target_length,), generator=generator)
        return x, self.time_steps, y, self.target_length

class EMGDataModule:
    """
    Minimal data module exposing train/val/test DataLoaders over DummyEMGDataset.

    This is a plain class, so nothing calls ``setup`` automatically; each
    ``*_dataloader`` method therefore builds the datasets on first use.

    Args:
        data_dir: Location of the real dataset (reserved; unused by the dummy data).
        batch_size: Batch size for every split's DataLoader.
    """
    def __init__(self, data_dir, batch_size=32):
        self.data_dir = data_dir  # This can be used to load real data later.
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Here you can implement different setups for train, val, and test.
        # For now, we use the same dummy dataset.
        self.train_dataset = DummyEMGDataset(mode="train")
        self.val_dataset = DummyEMGDataset(mode="val")
        self.test_dataset = DummyEMGDataset(mode="test")

    def _ensure_setup(self):
        """Create the datasets if ``setup`` has not been called yet."""
        if not hasattr(self, "test_dataset"):
            self.setup()

    def train_dataloader(self):
        self._ensure_setup()
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        self._ensure_setup()
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        self._ensure_setup()
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)


In [None]:
%%writefile /content/drive/MyDrive/colab/ece247/proj/tmp/data/datamodule.py
import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl

class DummyEMGDataset(Dataset):
    """
    Synthetic EMG dataset: random input features paired with random CTC targets.

    Each item is ``(x, input_length, y, target_length)``:
      * ``x`` -- float tensor of shape (time_steps, feature_dim)
      * ``input_length`` -- int equal to ``time_steps`` (full-length sequence)
      * ``y`` -- int64 tensor of ``target_length`` labels in [1, num_classes);
        0 is excluded because it is the CTC blank
      * ``target_length`` -- int

    Args:
        mode: Split name ("train", "val", "test"); stored but not used yet.
        num_samples: Number of items in the dataset.
        time_steps: Sequence length of each input.
        feature_dim: Features per time step.
        num_classes: Label vocabulary size, including the blank at index 0.
        target_length: Length of each target sequence.
        seed: If not None, item ``idx`` is generated from ``seed + idx`` so it
            is identical on every access; if None, fresh random data is drawn
            each time.
    """
    def __init__(self, mode="train", num_samples=100, time_steps=100, feature_dim=768,
                 num_classes=30, target_length=10, seed=None):
        if num_classes < 2:
            raise ValueError("num_classes must be >= 2 (index 0 is the CTC blank)")
        self.mode = mode
        self.samples = list(range(num_samples))
        self.time_steps = time_steps
        self.feature_dim = feature_dim
        self.num_classes = num_classes
        self.target_length = target_length
        self.seed = seed

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        idx = int(idx)
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            # Also lets plain iteration (`for item in dataset`) terminate.
            raise IndexError(f"index {idx} out of range for dataset of size {size}")
        generator = None
        if self.seed is not None:
            generator = torch.Generator().manual_seed(self.seed + idx)
        x = torch.randn(self.time_steps, self.feature_dim, generator=generator)
        y = torch.randint(1, self.num_classes, (self.target_length,), generator=generator)
        return x, self.time_steps, y, self.target_length

class EMGDataModule(pl.LightningDataModule):
    """
    LightningDataModule serving synthetic EMG batches from DummyEMGDataset.

    Args:
        data_dir: Location of the real dataset (kept for later; the dummy data ignores it).
        batch_size: Batch size used by every split's DataLoader.
    """
    def __init__(self, data_dir, batch_size=32):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir

    def prepare_data(self):
        """Nothing to download or preprocess for synthetic data."""
        return None

    def setup(self, stage=None):
        """Build one dummy dataset per split; ``stage`` is ignored."""
        self.train_dataset, self.val_dataset, self.test_dataset = (
            DummyEMGDataset(mode=split) for split in ("train", "val", "test")
        )

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the module's batch size."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

    def on_exception(self, exception):
        """Deliberate no-op hook.

        NOTE(review): presumably present because some Lightning versions look
        up ``on_exception`` on the datamodule when a run fails — confirm
        before removing.
        """
        return None


In [None]:
%%writefile /content/emg2qwerty-main/models/lstm_encoder.py
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

class LSTMEncoder(nn.Module):
    """
    Bidirectional LSTM encoder for EMG signals.

    Processes input sequences of shape (batch, time, features) and returns
    output sequences of shape (batch, time, output_dim), where output_dim is
    hidden_size * 2 (if bidirectional) to match the expected classifier input.

    Args:
        input_dim: Number of input features per time step.
        hidden_size: LSTM hidden units per direction.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Also run an LSTM over the reversed sequence.
        dropout: Dropout between stacked layers; forced to 0 for a single
            layer, where nn.LSTM would otherwise warn.
    """
    def __init__(self, input_dim, hidden_size=256, num_layers=2, bidirectional=True, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.input_dim = input_dim

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )
        # Calculate output dimension: hidden_size * 2 if bidirectional; else hidden_size.
        self.output_dim = hidden_size * 2 if bidirectional else hidden_size

    def forward(self, x, lengths=None):
        """
        Forward pass through the LSTM encoder.

        Args:
            x: Tensor of shape (batch_size, time_steps, input_dim)
            lengths: Optional valid length of each sequence, as a tensor on
                any device or a sequence of ints.

        Returns:
            lstm_out: Tensor of shape (batch_size, time_steps, output_dim).
                When lengths are given, positions past a sequence's length
                are zero.
        """
        if lengths is None:
            lstm_out, _ = self.lstm(x)
            return lstm_out

        # pack_padded_sequence requires lengths as a 1-D CPU int64 tensor.
        lengths = torch.as_tensor(lengths).to(device="cpu", dtype=torch.int64)
        x_packed = rnn_utils.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(x_packed)
        # total_length keeps the time axis equal to x's even when every
        # sequence in the batch is shorter than the padded length.
        lstm_out, _ = rnn_utils.pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))
        return lstm_out


In [None]:
%%writefile /content/emg2qwerty-main/models/emg_model.py
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F

# Import encoders (if TDSEncoder is not used, you can leave it or create a dummy file)
from .tds_encoder import TDSEncoder
from .lstm_encoder import LSTMEncoder

class EMGLightningModel(pl.LightningModule):
    """
    CTC sequence model: an encoder (TDS or bidirectional LSTM) followed by a
    per-time-step linear classifier.

    Batches are ``(x, lengths, targets, target_lengths)``; class index 0 is the
    CTC blank. ``x`` may be (B, T, F) or (B, F, T). ``forward`` uses
    ``feature_dim`` to bring it into the layout the chosen encoder expects:
    (B, T, F) for LSTM, (B, F, T) for TDS.
    """

    def __init__(
        self,
        encoder_type="TDS",          # "TDS" or "LSTM"
        feature_dim=768,             # Input feature dimension
        num_classes=30,              # Number of output classes (including blank)
        learning_rate=1e-3,
        tds_params=None,             # TDS encoder parameters (if using TDS)
        lstm_hidden_size=256,
        lstm_num_layers=2,
        lstm_dropout=0.2,
        **kwargs
    ):
        super().__init__()
        # Store constructor args so load_from_checkpoint can rebuild the model.
        self.save_hyperparameters()
        self.encoder_type = encoder_type.upper()
        self.feature_dim = feature_dim

        if self.encoder_type == "TDS":
            self.encoder = TDSEncoder(**(tds_params if tds_params is not None else {}))
        elif self.encoder_type == "LSTM":
            self.encoder = LSTMEncoder(
                input_dim=feature_dim,
                hidden_size=lstm_hidden_size,
                num_layers=lstm_num_layers,
                bidirectional=True,
                dropout=lstm_dropout
            )
        else:
            raise ValueError(f"Unknown encoder type: {encoder_type}")

        self.classifier = nn.Linear(self.encoder.output_dim, num_classes)
        self.ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
        self.learning_rate = learning_rate

    def _to_encoder_layout(self, x):
        """
        Return ``x`` in the layout the encoder expects.

        The feature axis is identified by matching ``feature_dim``. This avoids
        comparing T against F, which misfires whenever T < F (e.g. 100 steps
        of 768 features).
        """
        if x.dim() != 3:
            return x
        feature_last = x.shape[2] == self.feature_dim
        feature_first = x.shape[1] == self.feature_dim
        if self.encoder_type == "LSTM":
            # LSTMEncoder wants (B, T, F); only a clearly (B, F, T) input is flipped.
            if feature_first and not feature_last:
                x = x.transpose(1, 2)
        elif self.encoder_type == "TDS":
            # TDS wants (B, F, T).
            if feature_last and not feature_first:
                x = x.transpose(1, 2)
            elif not (feature_first or feature_last) and x.shape[1] > x.shape[2]:
                # feature_dim does not identify the axis (e.g. tds_params use a
                # different channel count): assume the longer axis is time.
                x = x.transpose(1, 2)
        return x

    def forward(self, x, lengths=None):
        """
        Compute per-time-step class logits.

        Args:
            x: Input features, (B, T, F) or (B, F, T).
            lengths: Optional valid length per sequence, forwarded to the encoder.

        Returns:
            Logits of shape (B, T', num_classes), T' being the encoder's output length.
        """
        features = self.encoder(self._to_encoder_layout(x), lengths)
        return self.classifier(features)

    def _ctc_step(self, batch):
        """Run one ``(x, lengths, targets, target_lengths)`` batch; return (log_probs, loss)."""
        x, lengths, targets, target_lengths = batch
        logits = self(x, lengths)
        # nn.CTCLoss expects (T, B, C) log-probabilities.
        log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)
        # NOTE(review): input lengths double as CTC output lengths, which
        # assumes the encoder does not downsample time — confirm for TDS.
        loss = self.ctc_loss(log_probs, targets, lengths, target_lengths)
        return log_probs, loss

    def training_step(self, batch, batch_idx):
        _, loss = self._ctc_step(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        _, loss = self._ctc_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        log_probs, loss = self._ctc_step(batch)
        # Logged so trainer.test() reports at least one metric; decoding and
        # CER/WER are not implemented here.
        self.log("test_loss", loss)
        return {"test_log_probs": log_probs}

    def configure_optimizers(self):
        """Adam (weight decay 1e-5) with per-epoch cosine annealing over the run."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=1e-5)
        # Lightning allows max_epochs=-1 (unbounded); keep T_max positive.
        t_max = max(1, self.trainer.max_epochs or 1)
        scheduler = {
            "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max),
            "interval": "epoch",
            "frequency": 1,
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler}


In [None]:
import os
import sys

# Parent directory of the Drive-hosted `data` package (data/datamodule.py).
DATA_PARENT = "/content/drive/MyDrive/colab/ece247/proj/tmp/"

# sys.path only affects this kernel. Scripts launched with `!python ...` are
# separate processes that read PYTHONPATH instead, so export it there too.
# Both updates are skipped if already present, so re-running is harmless.
if DATA_PARENT not in sys.path:
    sys.path.insert(0, DATA_PARENT)
_existing_pythonpath = os.environ.get("PYTHONPATH", "")
if DATA_PARENT not in _existing_pythonpath.split(os.pathsep):
    os.environ["PYTHONPATH"] = os.pathsep.join(
        p for p in (DATA_PARENT, _existing_pythonpath) if p
    )


In [None]:
# Train the LSTM variant on the Drive-hosted dummy datamodule.
LSTM_TRAIN_DATA_DIR = "/content/drive/MyDrive/colab/ece247/proj/tmp/data/"
!python train.py --encoder LSTM --batch_size 32 --epochs 50 --name lstm --learning_rate 1e-3 --gradient_clip_val 5.0 --data_dir {LSTM_TRAIN_DATA_DIR}


In [None]:
%%writefile /content/emg2qwerty-main/train.py
import sys
# Prepend the directory that holds the `data` package so `data.datamodule` resolves.
_DATA_PACKAGE_PARENT = "/content/drive/MyDrive/colab/ece247/proj/tmp/"
sys.path.insert(0, _DATA_PACKAGE_PARENT)

import os
import argparse
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from models.emg_model import EMGLightningModel
from data.datamodule import EMGDataModule

def _build_parser():
    """Build the command-line parser for training and test-only runs."""
    parser = argparse.ArgumentParser(description="Train or Test EMG2QWERTY model")
    # Model parameters
    parser.add_argument('--encoder', type=str, default='TDS', choices=['TDS', 'LSTM'],
                        help='Encoder type: TDS (original) or LSTM (new)')
    parser.add_argument('--feature_dim', type=int, default=768, help='Input feature dimension')
    parser.add_argument('--lstm_hidden_size', type=int, default=256, help='LSTM hidden size')
    parser.add_argument('--lstm_num_layers', type=int, default=2, help='Number of LSTM layers')
    parser.add_argument('--lstm_dropout', type=float, default=0.2, help='Dropout rate for LSTM')
    # Training parameters
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-3, help='Initial learning rate')
    parser.add_argument('--gradient_clip_val', type=float, default=5.0, help='Gradient clipping value')
    # Data and output paths
    parser.add_argument('--data_dir', type=str, default='data', help='Dataset directory')
    parser.add_argument('--name', type=str, default='model', help='Name for saving model and logs')
    # Test-only mode
    parser.add_argument('--test_only', action='store_true', help='Run in test-only mode (requires checkpoint)')
    parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path for test-only mode')
    return parser


def _run_test_only(args, data_module):
    """Load the checkpoint at ``args.ckpt_path`` and evaluate it on the test split."""
    if args.ckpt_path is None:
        raise ValueError("For test-only mode, please specify a checkpoint path using --ckpt_path")
    model = EMGLightningModel.load_from_checkpoint(args.ckpt_path)
    trainer = pl.Trainer(accelerator="auto", devices="auto")
    trainer.test(model, datamodule=data_module)


def _run_training(args, data_module):
    """Train a fresh model, then evaluate the best checkpoint by val_loss."""
    model = EMGLightningModel(
        encoder_type=args.encoder,
        feature_dim=args.feature_dim,
        lstm_hidden_size=args.lstm_hidden_size,
        lstm_num_layers=args.lstm_num_layers,
        lstm_dropout=args.lstm_dropout,
        learning_rate=args.learning_rate
    )

    logger = TensorBoardLogger(save_dir="lightning_logs", name=f"{args.encoder.lower()}_{args.name}")
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        filename="best_model-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min"
    )
    early_stop_callback = EarlyStopping(monitor="val_loss", patience=10, mode="min")

    trainer = pl.Trainer(
        max_epochs=args.epochs,
        gradient_clip_val=args.gradient_clip_val,
        logger=logger,
        callbacks=[checkpoint_callback, early_stop_callback],
        accelerator="auto",
        devices="auto"
    )

    trainer.fit(model, datamodule=data_module)
    # Test the best checkpoint rather than the final-epoch weights; fall back
    # to the in-memory model when no checkpoint was saved.
    best_ckpt = "best" if checkpoint_callback.best_model_path else None
    trainer.test(model, datamodule=data_module, ckpt_path=best_ckpt)
    print(f"Best model checkpoint: {checkpoint_callback.best_model_path}")


def main():
    """Parse arguments, prepare data, and dispatch to training or test-only mode."""
    args = _build_parser().parse_args()

    data_module = EMGDataModule(data_dir=args.data_dir, batch_size=args.batch_size)
    data_module.setup(stage="fit")

    if args.test_only:
        _run_test_only(args, data_module)
    else:
        _run_training(args, data_module)

if __name__ == "__main__":
    main()


In [None]:
# Re-run LSTM training with the updated train.py (which adds the data path itself).
lstm_train_flags = "--encoder LSTM --batch_size 32 --epochs 50 --name lstm --learning_rate 1e-3 --gradient_clip_val 5.0"
!python train.py {lstm_train_flags} --data_dir /content/drive/MyDrive/colab/ece247/proj/tmp/data/


In [None]:
import re
from pathlib import Path

# Select the lowest-val_loss checkpoint from the newest LSTM run instead of a
# hard-coded version/epoch, so this cell still works after re-training.
_LSTM_LOG_DIR = Path("lightning_logs/lstm_lstm")
_VAL_LOSS_RE = re.compile(r"val_loss=(\d+(?:\.\d+)?)")


def _val_loss_of(ckpt):
    """val_loss parsed from a ModelCheckpoint filename; +inf if absent."""
    match = _VAL_LOSS_RE.search(ckpt.name)
    return float(match.group(1)) if match else float("inf")


_run_dirs = sorted(_LSTM_LOG_DIR.glob("version_*"), key=lambda p: int(p.name.split("_")[-1]))
if not _run_dirs:
    raise FileNotFoundError(f"No training runs found under {_LSTM_LOG_DIR}")
_ckpts = list((_run_dirs[-1] / "checkpoints").glob("*.ckpt"))
if not _ckpts:
    raise FileNotFoundError(f"No checkpoints found in {_run_dirs[-1] / 'checkpoints'}")
lstm_best_ckpt = str(min(_ckpts, key=_val_loss_of))
print(f"Testing checkpoint: {lstm_best_ckpt}")

!python train.py --test_only --ckpt_path "{lstm_best_ckpt}" --data_dir /content/drive/MyDrive/colab/ece247/proj/tmp/data/


In [None]:
%%writefile /content/emg2qwerty-main/compare_models.py
import sys
# Prepend the directory containing the `data` folder so `data.datamodule` imports.
_DATA_FOLDER_PARENT = "/content/drive/MyDrive/colab/ece247/proj/tmp/"
sys.path.insert(0, _DATA_FOLDER_PARENT)

import os
import torch
import pytorch_lightning as pl
import argparse
import time
import json
from models.emg_model import EMGLightningModel
from data.datamodule import EMGDataModule

def format_metric(value, spec=".2f"):
    """
    Format a metric for display.

    Args:
        value: Number to format, or None when the metric was not reported.
        spec: ``format`` spec; ".2f" (default) keeps the previous behaviour,
            ".1f" suits relative changes, "," suits parameter counts.

    Returns:
        The formatted string, or "N/A" when ``value`` is None.
    """
    if value is None:
        return "N/A"
    return format(value, spec)

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or None when either side is missing or the denominator is 0."""
    if numerator is None or denominator is None or denominator == 0:
        return None
    return numerator / denominator


def _format_percent(fraction):
    """Format a fractional rate (e.g. 0.25) as a percentage string with two decimals ("25.00")."""
    return format_metric(None if fraction is None else fraction * 100)


def _format_count(value):
    """Format an integer count with thousands separators, or "N/A" when it is missing."""
    return "N/A" if value is None else f"{value:,}"


def _evaluate_checkpoint(label, checkpoint_path, test_dataloader):
    """Load a checkpoint, run Lightning's test loop on it and collect its metrics.

    Args:
        label: Display name used in log lines (e.g. "TDS", "LSTM").
        checkpoint_path: Path to a checkpoint loadable by EMGLightningModel.
        test_dataloader: Dataloader passed to ``Trainer.test``.

    Returns:
        Dict with "parameters" (total parameter count), "CER" and "WER" (as logged
        under ``test_cer``/``test_wer``, or None if not logged) and "inference_time"
        (wall-clock seconds covering Trainer construction plus the test loop).
    """
    print(f"\nEvaluating {label} model...")
    model = EMGLightningModel.load_from_checkpoint(checkpoint_path)
    model.eval()
    num_params = sum(p.numel() for p in model.parameters())
    print(f"{label} model parameters: {num_params:,}")

    start_time = time.perf_counter()
    trainer = pl.Trainer(accelerator="auto", devices=1)
    metrics = trainer.test(model, test_dataloader)[0]
    elapsed = time.perf_counter() - start_time

    result = {
        "parameters": num_params,
        "CER": metrics.get("test_cer", None),
        "WER": metrics.get("test_wer", None),
        "inference_time": elapsed,
    }
    # test_cer / test_wer are fractions, so scale them to percent for display.
    print(f"{label} - CER: {_format_percent(result['CER'])}%, "
          f"WER: {_format_percent(result['WER'])}%, Time: {format_metric(elapsed)}s")
    return result


def _compare_results(tds, lstm):
    """Derive LSTM-vs-TDS deltas and ratios from two per-model result dicts.

    Entries whose inputs are missing are omitted. Relative error changes are also
    omitted when the TDS value is 0, instead of raising ZeroDivisionError.
    """
    comparison = {}
    for metric in ("CER", "WER"):
        tds_value, lstm_value = tds.get(metric), lstm.get(metric)
        if tds_value is None or lstm_value is None:
            continue
        comparison[f"{metric}_diff"] = lstm_value - tds_value
        ratio = _safe_ratio(lstm_value, tds_value)
        if ratio is not None:
            comparison[f"{metric}_relative"] = (ratio - 1) * 100
    if tds.get("inference_time") and lstm.get("inference_time"):
        comparison["speed_ratio"] = tds["inference_time"] / lstm["inference_time"]
    if tds.get("parameters") and lstm.get("parameters"):
        comparison["param_diff"] = lstm["parameters"] - tds["parameters"]
        comparison["param_ratio"] = lstm["parameters"] / tds["parameters"]
    return comparison


def _print_comparison(results):
    """Print a human-readable side-by-side summary of the TDS and LSTM results."""
    tds, lstm, comparison = results["TDS"], results["LSTM"], results["comparison"]
    print(f"Character Error Rate: TDS = {_format_percent(tds.get('CER'))}%, "
          f"LSTM = {_format_percent(lstm.get('CER'))}%")
    # *_diff values are differences of fractions, shown here in percentage points.
    print(f"CER Difference: {_format_percent(comparison.get('CER_diff'))}% "
          f"({format_metric(comparison.get('CER_relative'))}% relative)")
    print(f"Word Error Rate: TDS = {_format_percent(tds.get('WER'))}%, "
          f"LSTM = {_format_percent(lstm.get('WER'))}%")
    print(f"WER Difference: {_format_percent(comparison.get('WER_diff'))}% "
          f"({format_metric(comparison.get('WER_relative'))}% relative)")
    print(f"Model Size: TDS = {_format_count(tds.get('parameters'))}, "
          f"LSTM = {_format_count(lstm.get('parameters'))}")
    print(f"Parameter Ratio: LSTM is {format_metric(comparison.get('param_ratio'))}x the size of TDS")
    print(f"Speed Ratio (TDS inference / LSTM inference): {format_metric(comparison.get('speed_ratio'))}")


def main():
    """Evaluate an LSTM checkpoint (and optionally a TDS checkpoint) on the test set.

    Results are printed and written to ``model_comparison_results.json`` in the
    current working directory. CER/WER values are stored in the JSON exactly as the
    models log them (fractions); they are scaled to percent only when printed.

    Raises:
        FileNotFoundError: If a given checkpoint path does not exist.
    """
    parser = argparse.ArgumentParser(description="Compare TDS and LSTM models")
    parser.add_argument('--tds_checkpoint', type=str, default=None,
                        help='Path to TDS model checkpoint (optional)')
    parser.add_argument('--lstm_checkpoint', type=str, required=True,
                        help='Path to LSTM model checkpoint')
    parser.add_argument('--data_dir', type=str, default='data', help='Test dataset directory')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for evaluation')
    args = parser.parse_args()

    # Fail fast before loading data: LSTM is required, TDS is optional.
    for checkpoint in (args.lstm_checkpoint, args.tds_checkpoint):
        if checkpoint is not None and not os.path.exists(checkpoint):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")

    data_module = EMGDataModule(data_dir=args.data_dir, batch_size=args.batch_size)
    data_module.setup(stage='test')
    test_dataloader = data_module.test_dataloader()

    results = {"TDS": {}, "LSTM": {}, "comparison": {}}
    has_tds = args.tds_checkpoint is not None

    if has_tds:
        results["TDS"] = _evaluate_checkpoint("TDS", args.tds_checkpoint, test_dataloader)
    else:
        print("No TDS checkpoint provided. Skipping TDS evaluation.")

    results["LSTM"] = _evaluate_checkpoint("LSTM", args.lstm_checkpoint, test_dataloader)

    # Comparisons only make sense when both models were evaluated.
    if has_tds:
        results["comparison"] = _compare_results(results["TDS"], results["LSTM"])

    print("\n=== Model Comparison Results ===")
    if has_tds:
        _print_comparison(results)
    else:
        print("Only LSTM checkpoint evaluated; no TDS comparison available.")

    with open("model_comparison_results.json", "w") as f:
        json.dump(results, f, indent=4)
    print("\nComparison results saved to model_comparison_results.json")

if __name__ == "__main__":
    main()


In [None]:
!python compare_models.py --lstm_checkpoint /content/emg2qwerty-main/lightning_logs/lstm_lstm/version_4/checkpoints/best_model-epoch=21-val_loss=3.57.ckpt --batch_size 32


In [None]:
%%writefile /content/emg2qwerty-main/models/emg_model.py
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F

# Since we're only using LSTM, we only import the LSTM encoder.
from .lstm_encoder import LSTMEncoder

class EMGLightningModel(pl.LightningModule):
    """CTC-trained EMG-to-character model: a bidirectional LSTM encoder plus a linear classifier.

    Batches are unpacked as ``(x, lengths, targets, target_lengths)``. ``targets`` is
    indexed per sample as ``target[:tgt_len]``, so it is presumably a padded
    (batch, max_target_len) tensor of class indices. TODO confirm against EMGDataModule.
    Class index 0 is the CTC blank.
    """
    def __init__(
        self,
        encoder_type="LSTM",          # Only LSTM is supported in this update
        feature_dim=768,             # Input feature dimension
        num_classes=30,              # Number of output classes (including blank index 0)
        learning_rate=1e-3,
        lstm_hidden_size=256,
        lstm_num_layers=2,
        lstm_dropout=0.2,
        **kwargs                     # Accepted (and ignored) so older checkpoints' extra hparams still load
    ):
        super().__init__()
        self.save_hyperparameters()
        self.encoder_type = encoder_type.upper()

        if self.encoder_type == "LSTM":
            self.encoder = LSTMEncoder(
                input_dim=feature_dim,
                hidden_size=lstm_hidden_size,
                num_layers=lstm_num_layers,
                bidirectional=True,
                dropout=lstm_dropout
            )
            encoder_out_dim = self.encoder.output_dim
        else:
            raise ValueError(f"Unknown encoder type: {encoder_type}")

        self.classifier = nn.Linear(encoder_out_dim, num_classes)
        # zero_infinity avoids inf losses when a target is too long for its input.
        self.ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
        self.learning_rate = learning_rate

    def forward(self, x, lengths=None):
        """Return per-frame class logits of shape (B, T, num_classes).

        Assumes x is already batch-first (B, T, feature_dim). TODO confirm with the datamodule.
        """
        features = self.encoder(x, lengths)
        logits = self.classifier(features)
        return logits

    def _ctc_log_probs(self, x, lengths):
        """Run the model and return time-major log-probabilities (T, B, C), as nn.CTCLoss expects."""
        logits = self(x, lengths)
        return F.log_softmax(logits, dim=-1).transpose(0, 1)

    def training_step(self, batch, batch_idx):
        """Compute and log the CTC loss for one training batch."""
        x, lengths, targets, target_lengths = batch
        log_probs = self._ctc_log_probs(x, lengths)
        # The encoder keeps the time dimension, so the input lengths are also the output lengths.
        loss = self.ctc_loss(log_probs, targets, lengths, target_lengths)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the CTC loss for one validation batch."""
        x, lengths, targets, target_lengths = batch
        log_probs = self._ctc_log_probs(x, lengths)
        loss = self.ctc_loss(log_probs, targets, lengths, target_lengths)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Greedy-decode one test batch and log its character error rate (as a fraction).

        Each batch's CER is logged with ``batch_size`` set to the number of reference
        characters in the batch. Lightning's epoch-level weighted mean of ``test_cer``
        is then sum(edits) / sum(reference chars), i.e. the corpus CER, rather than
        an average of per-batch CERs.
        """
        x, lengths, targets, target_lengths = batch
        log_probs = self._ctc_log_probs(x, lengths)

        # Only decode valid frames: padded frames would otherwise add spurious tokens.
        predictions = self._greedy_decode(log_probs, lengths)

        total_edits = 0
        total_length = 0
        for pred, target, tgt_len in zip(predictions, targets, target_lengths):
            target_list = target[: int(tgt_len)].tolist()
            total_edits += self._levenshtein_distance(pred, target_list)
            total_length += len(target_list)
        cer = total_edits / total_length if total_length > 0 else 0.0
        # NOTE: placeholder, NOT a real word error rate. It is simply 1.2 * CER until
        # a class-index -> character mapping is available to split words.
        wer = cer * 1.2

        weight = max(total_length, 1)
        self.log("test_cer", cer, prog_bar=True, batch_size=weight)
        self.log("test_wer", wer, prog_bar=True, batch_size=weight)
        return {"test_cer": cer, "test_wer": wer}

    def _greedy_decode(self, log_probs, lengths=None):
        """
        Performs greedy (best-path) CTC decoding on log probabilities.

        Args:
            log_probs: Tensor of shape (T, B, C).
            lengths: Optional per-sample number of valid frames (length B). Frames at
                or beyond a sample's length are ignored. If None, all T frames are decoded.

        Returns:
            A list of lists of token indices for each sample in the batch, with
            repeats collapsed and blanks (index 0) removed.
        """
        pred_indices = torch.argmax(log_probs, dim=-1).transpose(0, 1).cpu()  # (B, T)
        batch_size, max_time = pred_indices.shape
        if lengths is None:
            seq_lengths = [max_time] * batch_size
        else:
            seq_lengths = [int(n) for n in lengths]

        decoded = []
        for seq, seq_len in zip(pred_indices.tolist(), seq_lengths):
            prev = None
            decoded_seq = []
            for token in seq[:seq_len]:
                if token == 0:  # blank separates tokens, so "a _ a" decodes to "aa"
                    prev = None
                    continue
                if token != prev:
                    decoded_seq.append(token)
                    prev = token
            decoded.append(decoded_seq)
        return decoded

    def _levenshtein_distance(self, s1, s2):
        """
        Compute the Levenshtein (edit) distance between sequences s1 and s2.

        Uses a two-row dynamic program, iterating over the longer sequence.
        """
        if len(s1) < len(s2):
            return self._levenshtein_distance(s2, s1)
        if len(s2) == 0:
            return len(s1)
        previous_row = list(range(len(s2) + 1))
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row
        return previous_row[-1]

    def configure_optimizers(self):
        """Adam with small weight decay, cosine-annealed once per epoch over the training run."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=1e-5)
        max_epochs = self.trainer.max_epochs
        # max_epochs can be -1 (unlimited) or None; CosineAnnealingLR needs a positive T_max.
        t_max = max_epochs if max_epochs is not None and max_epochs > 0 else 1
        scheduler = {
            "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max),
            "interval": "epoch",
            "frequency": 1,
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler}


In [None]:
!python train.py --encoder LSTM --batch_size 32 --epochs 50 --name lstm --learning_rate 1e-3 --gradient_clip_val 5.0 --data_dir /content/drive/MyDrive/colab/ece247/proj/tmp/data/


In [None]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/


In [None]:
!cp -dpRvf /content/emg2qwerty-main/   /content/drive/MyDrive/colab/ece247/proj/tmp/bak2/

In [None]:
!sync