# NEST v2: Neural EEG Sequence Transducer — Cloud Training

Train NEST v2 on Google Colab using a GPU.

**Requirements:**
- Runtime type: GPU (T4 or A100 recommended)
- ZuCo dataset uploaded to Google Drive under `MyDrive/ZuCo_Dataset/ZuCo`

In [None]:
# Install dependencies
# (-q asks pip for quiet output.)
!pip install torch torchaudio transformers jiwer scipy h5py -q
# Clone the NEST repository into /content/NEST. stderr from `git clone` is
# discarded; if the clone fails (e.g. the directory already exists from an
# earlier run of this cell), fall back to `git pull` inside the existing checkout.
!git clone https://github.com/wazder/NEST /content/NEST 2>/dev/null || (cd /content/NEST && git pull)

import os
# Work from the checkout so that later cells can use repo-relative paths
# such as scripts/train_nest_v2.py.
os.chdir('/content/NEST')
print("Working directory:", os.getcwd())

# NEST v2 Architecture:
# - Input: Word-level EEG frequency features (840-dim: 105 channels x 8 freq bands)
# - EEG Encoder: Transformer (6 layers, d_model=768, 8 heads)
# - Text Decoder: BART (facebook/bart-base) with cross-attention to EEG encoder
# - Dataset: ZuCo (3 tasks, 11 subjects, ~12K sentence-subject pairs)
# - Evaluation: Subject-independent (train on 8, test on 2 held-out subjects)

In [None]:
import torch

# Report which accelerator (if any) this runtime exposes, then the PyTorch
# version. Only device 0 is inspected.
cuda_ready = torch.cuda.is_available()

if not cuda_ready:
    # No CUDA device: tell the user how to switch the Colab runtime.
    for message in (
        "GPU: Not available",
        "WARNING: Training without GPU will be very slow.",
        "Go to Runtime > Change runtime type > GPU",
    ):
        print(message)
else:
    props = torch.cuda.get_device_properties(0)
    gpu_name = torch.cuda.get_device_name(0)
    total_memory_gb = props.total_memory / 1e9  # bytes -> decimal gigabytes
    print(f"GPU: {gpu_name}")
    print(f"CUDA memory: {total_memory_gb:.1f} GB")
    print(f"CUDA version: {torch.version.cuda}")

print(f"PyTorch version: {torch.__version__}")

In [None]:
# Mount Google Drive for ZuCo data access
from google.colab import drive
drive.mount('/content/drive')

import os

ZUCO_PATH = '/content/drive/MyDrive/ZuCo_Dataset/ZuCo'


def list_mat_files(data_dir):
    """Return the sorted names of the ``.mat`` files directly inside *data_dir*.

    Only the top level of *data_dir* is inspected; sub-directories are not
    searched.

    Args:
        data_dir: Path of the directory expected to hold the ZuCo ``.mat`` files.

    Returns:
        A sorted list of file names ending in ``.mat`` (possibly empty), or
        ``None`` when *data_dir* does not exist or is not a directory.
    """
    # isdir (rather than exists) so a stray file at this path is reported as
    # "not found" instead of making os.listdir raise NotADirectoryError.
    if not os.path.isdir(data_dir):
        return None
    return sorted(name for name in os.listdir(data_dir) if name.endswith('.mat'))


mat_files = list_mat_files(ZUCO_PATH)

if mat_files:
    print(f"ZuCo data found: {len(mat_files)} .mat files")
    # Show at most five entries to keep the cell output short.
    for f in mat_files[:5]:
        size_mb = os.path.getsize(os.path.join(ZUCO_PATH, f)) / 1e6
        print(f"  {f}: {size_mb:.1f} MB")
    if len(mat_files) > 5:
        print(f"  ... and {len(mat_files) - 5} more")
else:
    if mat_files is None:
        print("ZuCo data NOT found at:", ZUCO_PATH)
    else:
        # The folder exists but holds no .mat files at its top level; do not
        # report this as "data found".
        print("ZuCo folder exists but contains no .mat files:", ZUCO_PATH)
    print()
    print("Setup instructions:")
    print("1. Request ZuCo dataset: https://osf.io/q3zws/")
    print("2. Upload to Google Drive under: MyDrive/ZuCo_Dataset/ZuCo/")
    print("3. Expected files: resultsZAB.mat, resultsZDM.mat, etc.")

In [None]:
# Run NEST v2 training
# Adjust --data-dir and --output-dir as needed
#
# The script path is relative, so this cell relies on the working directory
# being the NEST checkout (/content/NEST, set by the install cell).
# --data-dir is the same Drive folder that the ZUCO_PATH cell checks.
# --output-dir is under /content rather than under the mounted Drive; the last
# cell of this notebook copies that folder to Drive after training.
# NOTE(review): presumably --fp16 enables mixed-precision training and
# --batch-size 16 with --grad-accum 4 gives an effective batch of 64 —
# confirm against scripts/train_nest_v2.py.

!python scripts/train_nest_v2.py \
    --model bart \
    --epochs 200 \
    --batch-size 16 \
    --fp16 \
    --grad-accum 4 \
    --d-model 768 \
    --num-layers 6 \
    --data-dir /content/drive/MyDrive/ZuCo_Dataset/ZuCo \
    --output-dir /content/results/nest_v2_bart

In [None]:
import json
import matplotlib.pyplot as plt
import os

results_path = '/content/results/nest_v2_bart/training_log.json'


def split_training_log(log):
    """Split a training log into the series needed for plotting.

    Args:
        log: Sequence of per-epoch dicts. Each entry must have ``'epoch'`` and
            ``'train_loss'`` keys; ``'val_wer'`` is optional and may be ``None``.

    Returns:
        A tuple ``(epochs, train_loss, valid_epochs, valid_wer)`` where the
        last two lists cover only the entries that carry a non-``None``
        ``'val_wer'``.

    Raises:
        KeyError: If an entry lacks ``'epoch'`` or ``'train_loss'``.
    """
    epochs = [entry['epoch'] for entry in log]
    train_loss = [entry['train_loss'] for entry in log]
    wer_points = [
        (entry['epoch'], entry['val_wer'])
        for entry in log
        if entry.get('val_wer') is not None
    ]
    valid_epochs = [epoch for epoch, _ in wer_points]
    valid_wer = [wer for _, wer in wer_points]
    return epochs, train_loss, valid_epochs, valid_wer


if not os.path.exists(results_path):
    print(f"Results not found at {results_path}")
    print("Check /content/results/nest_v2_bart/ for available files:")
    if os.path.exists('/content/results/nest_v2_bart'):
        for f in os.listdir('/content/results/nest_v2_bart'):
            print(f"  {f}")
else:
    with open(results_path, encoding='utf-8') as f:
        log = json.load(f)

    epochs, train_loss, valid_epochs, valid_wer = split_training_log(log)

    if not epochs:
        # An empty log would otherwise crash on train_loss[-1] below.
        print(f"Training log at {results_path} has no epochs; nothing to plot.")
    else:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        ax1.plot(epochs, train_loss, 'b-', linewidth=1.5)
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training Loss')
        ax1.grid(True, alpha=0.3)

        # Label the right-hand panel even when no WER was logged, so the saved
        # figure never contains an unexplained blank axes.
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('WER')
        ax2.grid(True, alpha=0.3)
        if valid_wer:
            ax2.plot(valid_epochs, valid_wer, 'r-', linewidth=1.5)
            ax2.set_title('Validation WER')
            best_wer = min(valid_wer)
            print(f"Best WER: {best_wer:.4f} ({best_wer*100:.2f}%)")
        else:
            ax2.set_title('Validation WER (none logged)')

        plt.tight_layout()
        plt.savefig('/content/results/nest_v2_bart/training_curves.png', dpi=150, bbox_inches='tight')
        plt.show()
        print(f"Final train loss: {train_loss[-1]:.4f}")

In [None]:
import shutil
import os
from datetime import datetime

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
dest = f'/content/drive/MyDrive/NEST_results_{timestamp}'


def path_size_bytes(path):
    """Return the size in bytes of *path*.

    For a regular file this is ``os.path.getsize(path)``. For a directory it
    is the sum of the sizes of all files found beneath it by ``os.walk``, so a
    sub-directory of results is reported by its contents rather than by the
    size of the directory entry itself.
    """
    if not os.path.isdir(path):
        return os.path.getsize(path)
    return sum(
        os.path.getsize(os.path.join(root, name))
        for root, _dirs, names in os.walk(path)
        for name in names
    )


drive_root = os.path.dirname(dest)

if not os.path.exists('/content/results/nest_v2_bart'):
    print("No results found at /content/results/nest_v2_bart")
elif not os.path.isdir(drive_root):
    # shutil.copytree creates missing parent directories, so without this
    # guard an unmounted Drive would yield a local-only copy that is still
    # reported as saved to Google Drive.
    print(f"Google Drive is not mounted at {drive_root}; results were NOT copied.")
    print("Run the drive.mount('/content/drive') cell, then re-run this cell.")
else:
    shutil.copytree('/content/results/nest_v2_bart', dest)
    print(f"Results saved to Google Drive: {dest}")
    saved_files = os.listdir(dest)
    print(f"Saved {len(saved_files)} files:")
    for f in sorted(saved_files):
        size_mb = path_size_bytes(os.path.join(dest, f)) / 1e6
        print(f"  {f}: {size_mb:.1f} MB")