# 🧠 EMG Spectrogram CNN Training

## 1. Import Libraries

In [None]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Import the object-oriented model and preprocessing utility
from cnn_model import CNNmodel
from emg_preprocessing import preprocess, EMGDataset, FS, WINDOW_SIZE, STRIDE, NPERSEG, NOVERLAP, N_CHANNELS

# Only reached if every import above resolved, so this line doubles as a quick
# sanity check that cnn_model.py and emg_preprocessing.py are importable.
print("Libraries and modules imported successfully.")

## 2. Configuration and Device Setup

In [None]:
# --- Training hyperparameters ---
EPOCHS = 25        # number of train/test passes run by the training loop
BATCH_SIZE = 32    # batch size given to both DataLoaders
LR = 0.001         # learning rate passed to CNNmodel (1e-3)

In [None]:
# --- Data Labels and Class Names ---
# Maps each gesture name to the integer class index used as the training target.
LABELS = {"rest": 0, "pinch": 1}
# Display names ordered by class index. Derived from LABELS so the two can
# never drift apart when a gesture is added or renumbered.
CLASS_NAMES = sorted(LABELS, key=LABELS.get)

# Placeholder for actual data paths
# TODO: centralize data location, e.g. ../data/
# Each entry is (CSV path, integer class label).
DATA_FILES = [
    ("../myo/samples/raymond_arm_90_deg_200hz.csv", LABELS["rest"]),
    ("../myo/samples/raymond_arm_90_deg_pinch_200hz.csv", LABELS["pinch"]),
    ("../myo/samples/raymond_arm_down_200hz.csv", LABELS["rest"]),
    ("../myo/samples/raymond_arm_down_pinch_200hz.csv", LABELS["pinch"]),
]

# Device setup: prefer CUDA, then Apple-Silicon MPS, then fall back to CPU.
# The MPS backend is looked up defensively because torch builds that predate
# it do not expose `torch.backends.mps` at all.
if torch.cuda.is_available():
    _device_name = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
DEVICE = torch.device(_device_name)
print(f"Using device: {DEVICE}")

## 3. Data Preprocessing Functions and Dataset Class
*(This logic is now defined in `emg_preprocessing.py`)*

In [None]:
# Nothing is defined in this cell any more: `preprocess` and `EMGDataset` are
# imported from emg_preprocessing at the top of the notebook.
print("Preprocessing functions imported from emg_preprocessing.py.")

## 4. Load, Normalize, and Split Data

In [None]:
# Collect the spectrogram windows and matching labels from every recording.
X_all = []
Y_all = []

print("Preprocessing all data files...")
for path, label in DATA_FILES:
    print(f"Processing file: {path} with label: {label}")
    X_specs = preprocess(path)
    # Recordings that yield no windows are skipped.
    if X_specs.size > 0:
        X_all.append(X_specs)
        # One label per window (first axis of X_specs).
        Y_all.append(np.full(X_specs.shape[0], label, dtype=np.int64))

# Fail with an actionable message instead of np.concatenate's generic
# "need at least one array to concatenate".
if not X_all:
    raise ValueError(
        "No spectrogram windows were produced from DATA_FILES; "
        "check the file paths and that each recording is long enough."
    )

X_all = np.concatenate(X_all)
Y_all = np.concatenate(Y_all)

print(f"Total windows collected: {len(X_all)}")

# 1. Choose the train/test partition first.
# Splitting an index array with the same arguments selects the same samples as
# splitting X_all/Y_all directly, and lets the statistics below be computed
# from the training windows only.
train_idx, test_idx = train_test_split(
    np.arange(len(X_all)), test_size=0.2, random_state=42, stratify=Y_all
)

# 2. Calculate Normalization Parameters (training windows only)
# Reducing over axes (0, 2, 3) keeps one value per entry of axis 1
# (presumably the EMG channel axis - confirm against emg_preprocessing).
# Using only the training split keeps test-set statistics out of preprocessing.
mean = X_all[train_idx].mean(axis=(0, 2, 3), keepdims=True)
std = X_all[train_idx].std(axis=(0, 2, 3), keepdims=True)
# A constant channel has zero spread; dividing by it would produce inf/NaN.
std[std == 0] = 1.0

print(f"Global Mean shape: {mean.shape}")
print(f"Global Std shape: {std.shape}")

# Save normalization params for inference
# NOTE: saving a dict stores a pickled 0-d object array; read it back with
# np.load(..., allow_pickle=True).item().
np.save("normalization_params.npy", {'mean': mean, 'std': std})
print("Normalization parameters saved to normalization_params.npy")

# 3. Apply Normalization
X_all = (X_all - mean) / std

# 4. Materialize the split
X_train, X_test = X_all[train_idx], X_all[test_idx]
Y_train, Y_test = Y_all[train_idx], Y_all[test_idx]

print(f"Training windows: {len(X_train)}")
print(f"Testing windows: {len(X_test)}")

## 5. Initialize DataLoaders and Model

In [8]:
# Build the classifier. Per the original note, CNNmodel sets up its own loss
# criterion and optimizer, so only the learning rate and device are passed in.
model = CNNmodel(
    device=DEVICE,
    learning_rate=LR,
    num_classes=len(LABELS),
    in_channels=N_CHANNELS,
)
# Announce the setup, then dump the layer summary on the following line.
print("Model initialized with built-in optimizer and loss function.", model, sep="\n")

Model initialized with built-in optimizer and loss function.
CNNmodel(
  (conv1): Conv2d(8, 32, kernel_size=(5, 3), stride=(1, 1), padding=(2, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (global_pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc1): Linear(in_features=128, out_features=64, bias=True)
  (drop): Dropout(p=0.3, inplace=False)
  (fc2): Linear(in_features=64, out_features=2, bias=True)
  (relu): ReLU()
  (cri

In [9]:
# Training pipeline: wrap the arrays in a dataset and reshuffle every epoch.
train_dataset = EMGDataset(X_train, Y_train)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)

# Evaluation pipeline: same batch size, fixed order.
test_dataset = EMGDataset(X_test, Y_test)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

# Report how many batches one pass over each split takes.
n_train_batches = len(train_loader)
n_test_batches = len(test_loader)
print(f"Number of batches in training set: {n_train_batches}")
print(f"Number of batches in test set: {n_test_batches}")

Number of batches in training set: 123
Number of batches in test set: 31


## 6. Train Model

In [None]:
# Where the best-performing weights are written during training.
weights_path = "train_single_subject_myo_model.pth"

# Running record of the best epoch: its test accuracy plus the predictions
# and labels from that epoch (kept for the confusion matrix later on).
best_acc = 0.0
best_preds = best_labels = None

In [None]:
# Run EPOCHS rounds of training followed by evaluation, checkpointing the
# weights whenever the test accuracy strictly improves on the best seen so far.
print(f"Starting training for {EPOCHS} epochs on {DEVICE}...")
for epoch in range(1, EPOCHS + 1):
    # Train Epoch: returns the epoch's training loss and accuracy.
    train_loss, train_acc = model.train_epoch(train_loader)

    # Test/Validation Epoch: returns accuracy plus the predictions and labels
    # gathered over test_loader.
    test_acc, preds_all, labels_all = model.test_epoch(test_loader)
    print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")

    # Save best model checkpoint.
    # Strict ">" means a tie keeps the earlier epoch's weights and predictions.
    # NOTE(review): the same test split is used to pick the checkpoint, so
    # best_acc is an optimistic estimate of held-out performance.
    if test_acc > best_acc:
        best_acc = test_acc
        best_preds = preds_all
        best_labels = labels_all
        torch.save(model.state_dict(), weights_path)
        print(f"Model saved to {weights_path} (Best Acc: {best_acc:.4f})")
# The emoji was previously stored mis-encoded (UTF-8 bytes read as cp1252).
print("\n\n🎉 Training Complete")

## 7. Visualize Results

In [None]:
# Plot the confusion matrix for the best epoch recorded by the training loop.
# best_preds / best_labels stay None if no epoch ever beat the initial best_acc.
if best_preds is not None and best_labels is not None:
    print("\nBest Test Accuracy:", best_acc)

    # Confusion Matrix
    # Class indices are derived from CLASS_NAMES so the matrix rows/columns and
    # the tick labels stay aligned if more classes are added.
    class_indices = list(range(len(CLASS_NAMES)))
    cm = confusion_matrix(best_labels, best_preds, labels=class_indices)
    disp = ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES)

    fig, ax = plt.subplots(figsize=(8, 8))
    disp.plot(cmap=plt.cm.Blues, ax=ax)
    # The dash was previously stored mis-encoded (UTF-8 bytes read as cp1252).
    plt.title("EMG CNN Classification — Confusion Matrix")
    plt.show()
else:
    print("Could not generate confusion matrix. Check if training ran successfully.")