In [None]:
import copy
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.decomposition import PCA

# Make the in-repo `phantomx` package (../python) importable without installing it.
sys.path.insert(0, '../python')

from phantomx.tokenizer import SpikeTokenizer
from phantomx.vqvae import VQVAE, VQVAETrainer
from phantomx.inference import LabramDecoder
from phantomx.data import MCMazeDataLoader

# Seed the global RNGs once, up front, so that the stochastic steps in this
# notebook (synthetic Poisson spikes drawn via the legacy `np.random` API,
# torch weight initialisation, and — if the data loaders shuffle with torch's
# default generator — batch order) are reproducible under Restart -> Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

print("PhantomX loaded successfully!")

## 1. POYO Tokenization

In [None]:
# Build a POYO-style spike tokenizer for a 142-channel array, fit it on a
# synthetic training matrix, then tokenize one synthetic bin of spike counts.
tokenizer_kwargs = dict(
    n_channels=142,
    quantization_levels=16,
    use_population_norm=True,
    dropout_invariant=True,
)
tokenizer = SpikeTokenizer(**tokenizer_kwargs)

# One synthetic bin of Poisson spike counts (mean 2.0 per channel).
# Drawn before the training matrix so the RNG stream order is unchanged.
spike_counts = np.random.poisson(lam=2.0, size=142)
print(f"Input: {spike_counts.shape} spike counts")

# Synthetic stand-in for training data; in practice, fit on real training data.
train_spikes = np.random.poisson(lam=2.0, size=(1000, 142))
tokenizer.fit(train_spikes)

tokens = tokenizer.tokenize(spike_counts)
print(f"Output: {tokens.shape} discrete tokens")
print(f"Token values: {tokens}")

## 2. VQ-VAE Training

In [None]:
# Instantiate the VQ-VAE and report its parameter count and codebook settings.
model = VQVAE(
    n_tokens=16,
    token_dim=256,
    embedding_dim=64,
    num_codes=256,
    commitment_cost=0.25,
    output_dim=2,  # presumably (vx, vy) velocity — the decode cell prints two components
)

n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")
print(f"Codebook size: {model.num_codes}")
print(f"Embedding dim: {model.embedding_dim}")

In [None]:
# Build train/val/test loaders over the MC_Maze NWB recording. The fitted
# tokenizer from §1 is handed to the loader.
# Path is relative to this notebook — adjust for your checkout.
mc_maze_path = '../../PhantomLink/data/mc_maze.nwb'

data_loader = MCMazeDataLoader(
    data_path=mc_maze_path,
    tokenizer=tokenizer,
    batch_size=32,
)
train_loader, val_loader, test_loader = data_loader.get_loaders()

for split_name, split_loader in (("Train", train_loader), ("Val", val_loader)):
    print(f"{split_name} batches: {len(split_loader)}")

In [None]:
# Short demo run only — use train_labram.py for full training.
# `history` holds the training curves plotted in the next cell.
demo_epochs = 5
demo_save_dir = '../models/demo'

trainer = VQVAETrainer(model, learning_rate=3e-4)
history = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=demo_epochs,
    save_dir=demo_save_dir,
)

In [None]:
# Three side-by-side panels against epoch: total loss (train vs. val),
# reconstruction MSE, and codebook perplexity.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Per panel: [(history key, legend label or None)], y-axis label, title.
# A legend label of None draws the series without a legend entry.
panel_specs = [
    ([('train_loss', 'Train'), ('val_loss', 'Val')], 'Loss', 'Total Loss'),
    ([('reconstruction_loss', None)], 'MSE', 'Reconstruction Loss'),
    ([('perplexity', None)], 'Perplexity', 'Codebook Utilization'),
]

for ax, (series, y_label, title) in zip(axes, panel_specs):
    for history_key, legend_label in series:
        line_kwargs = {} if legend_label is None else {'label': legend_label}
        ax.plot(history[history_key], **line_kwargs)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    if any(label is not None for _, label in series):
        ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 3. Zero-Shot Decoding

In [None]:
# Wrap the trained model and fitted tokenizer in a decoder with test-time
# adaptation disabled, then decode one synthetic bin of spike counts.
decoder = LabramDecoder(model=model, tokenizer=tokenizer, use_tta=False)

spike_counts = np.random.poisson(2.0, size=142)
velocity = decoder.decode(spike_counts)
vx, vy = velocity[0], velocity[1]
print(f"Predicted velocity: vx={vx:.3f}, vy={vy:.3f}")

In [None]:
# Decode every test batch and score the predictions against recorded kinematics.
pred_chunks, target_chunks = [], []
for batch in test_loader:
    batch_spikes = batch['spike_counts'].numpy()
    batch_targets = batch['kinematics'].numpy()
    pred_chunks.append(decoder.decode(batch_spikes))
    target_chunks.append(batch_targets)

all_preds = np.concatenate(pred_chunks)
all_targets = np.concatenate(target_chunks)

# Pooled R² = 1 - SS_res / SS_tot, with both sums taken over every sample and
# every output dimension (each dimension centred on its own mean).
residual_ss = ((all_targets - all_preds) ** 2).sum()
total_ss = ((all_targets - all_targets.mean(axis=0)) ** 2).sum()
r2 = 1 - residual_ss / total_ss

print(f"Zero-shot R² score: {r2:.4f}")

## 4. Test-Time Adaptation

In [None]:
# Decoder with entropy-based test-time adaptation (TTA), run on a synthetic
# drifting spike stream.
#
# NOTE(review): `tta_lr` suggests TTA takes gradient steps on the model it is
# given. Without a copy, `decoder_tta` would share `model` with §5 (which reads
# the codebook from `model` directly) and §6 (which decodes via `decoder`), so
# their results would depend on whether this cell ran. A deep copy isolates the
# adaptation; if LabramDecoder already copies internally, this is just redundant.
decoder_tta = LabramDecoder(
    model=copy.deepcopy(model),
    tokenizer=tokenizer,
    use_tta=True,
    tta_strategy='entropy',
    tta_lr=1e-4,
)

# Simulate slow drift: the Poisson rate ramps linearly from 2.0 to 3.0 across
# the stream. (In practice, use PhantomLink's NoiseInjectionMiddleware.)
n_drift_steps = 100
drift = np.linspace(0, 1, n_drift_steps)

tta_preds = []
for drift_offset in drift:
    spikes = np.random.poisson(2.0 + drift_offset, size=142)
    tta_preds.append(decoder_tta.decode(spikes))
tta_preds = np.array(tta_preds)

# Adaptation summary reported by the decoder.
stats = decoder_tta.get_statistics()
print(f"Samples adapted: {stats['n_samples_adapted']}")
print(f"Mean entropy: {stats['mean_entropy']:.4f}")

## 5. Codebook Visualization

In [None]:
# Project the learned codebook vectors to 2-D with PCA and plot them,
# coloured by code index.
#
# `.detach()` is needed if get_codebook_embeddings() returns a tensor that
# requires grad (e.g. the codebook parameter itself): Tensor.numpy() raises a
# RuntimeError on such tensors. It is harmless when the tensor is already detached.
codebook = model.get_codebook_embeddings().detach().cpu().numpy()
print(f"Codebook shape: {codebook.shape}")

pca = PCA(n_components=2)
codebook_2d = pca.fit_transform(codebook)
pc1_var, pc2_var = pca.explained_variance_ratio_

fig, ax = plt.subplots(figsize=(10, 8))
points = ax.scatter(
    codebook_2d[:, 0],
    codebook_2d[:, 1],
    c=np.arange(len(codebook)),
    cmap='viridis',
    s=50,
)
fig.colorbar(points, ax=ax, label='Code Index')
# Show how much codebook variance each projected axis captures.
ax.set_xlabel(f'PC1 ({pc1_var:.1%} of variance)')
ax.set_ylabel(f'PC2 ({pc2_var:.1%} of variance)')
ax.set_title('Codebook Embedding Space (PCA)')
ax.grid(True, alpha=0.3)
plt.show()

## 6. Electrode Dropout Test

In [None]:
# Electrode-dropout robustness. For each dropout rate, silence a fixed random
# subset of channels for the whole test set, decode, and record the pooled R².
dropout_rates = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
target_rate = 0.5   # dropout level the target R² refers to
target_r2 = 0.6     # target pooled R² at `target_rate`


def _keep_mask(n_channels, dropout_rate):
    """Return a boolean (n_channels,) mask with exactly
    round(dropout_rate * n_channels) randomly chosen channels set to False."""
    keep = np.ones(n_channels, dtype=bool)
    n_dead = int(round(dropout_rate * n_channels))
    keep[np.random.choice(n_channels, size=n_dead, replace=False)] = False
    return keep


def _r2_under_dropout(dropout_rate):
    """Decode the full test set with a single channel-dropout mask and return
    pooled R² (sums of squares taken over all samples and output dims)."""
    keep = None
    preds, targets = [], []
    for batch in test_loader:
        spikes = batch['spike_counts'].numpy()
        if keep is None:
            # A dead electrode stays dead, so draw one mask per rate and reuse
            # it for every batch. Channels are taken from the LAST axis
            # because that is the axis a (n_channels,) mask broadcasts against.
            keep = _keep_mask(spikes.shape[-1], dropout_rate)
        preds.append(decoder.decode(spikes * keep))
        targets.append(batch['kinematics'].numpy())

    preds = np.concatenate(preds)
    targets = np.concatenate(targets)
    ss_res = np.sum((targets - preds) ** 2)
    ss_tot = np.sum((targets - np.mean(targets, axis=0)) ** 2)
    return 1 - ss_res / ss_tot


r2_scores = [_r2_under_dropout(rate) for rate in dropout_rates]
# Look scores up by rate rather than by list position, so editing
# `dropout_rates` cannot silently mislabel the reported number.
r2_by_rate = dict(zip(dropout_rates, r2_scores))

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(dropout_rates, r2_scores, 'o-', linewidth=2, markersize=8)
ax.axhline(y=target_r2, color='r', linestyle='--',
           label=f'Target ({target_rate:.0%} dropout)')
ax.set_xlabel('Electrode Dropout Rate')
ax.set_ylabel('R² Score')
ax.set_title('Decoder Robustness to Electrode Dropout')
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()

print(f"R² at {target_rate:.0%} dropout: {r2_by_rate[target_rate]:.4f}")