In [1]:
!git clone https://github.com/YunshiWen/VQShape.git
%cd VQShape


fatal: destination path 'VQShape' already exists and is not an empty directory.
/content/VQShape


In [None]:
!pip install lightning
!pip install sktime


In [2]:
from vqshape.pretrain import LitVQShape
import torch

checkpoint = "checkpoints/uea_dim256_codebook512/VQShape.ckpt"
lit = LitVQShape.load_from_checkpoint(checkpoint, map_location="cuda")
base_model = lit.model
base_model.eval()

for p in base_model.parameters():
    p.requires_grad = False


In [3]:
from sktime.datasets import load_from_tsfile
import numpy as np

X_train, y_train = load_from_tsfile("/content/FingerMovements_TRAIN.ts", return_data_type="numpy3D")
X_test, y_test   = load_from_tsfile("/content/FingerMovements_TEST.ts",  return_data_type="numpy3D")

label_map = {"left": 0, "right": 1}
y_train = np.array([label_map[y] for y in y_train])
y_test  = np.array([label_map[y] for y in y_test])


In [None]:
print(base_model)


In [79]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau # Importamos el scheduler
import numpy as np
import copy # Para guardar el mejor modelo

# --- (Asegúrate de tener tus funciones `preprocess_signal` y `get_histogram` definidas) ---
# ... (preprocess_signal y get_histogram van aquí) ...

# --- 1. Pre-calcula TODOS los histogramas (sin cambios) ---
print("--- Paso 1: Calculando Features ---")
H_train = get_features(X_train)
H_test  = get_features(X_test)
y_train_tensor = torch.tensor(y_train).long().cuda()
y_test_tensor  = torch.tensor(y_test).long().cuda()
print("Features calculados.")

# --- 2. Define el Clasificador (sin cambios) ---
class SimpleClassifier(nn.Module):
    def __init__(self, hist_dim=512):
        super().__init__()
        self.net = nn.Linear(hist_dim, 2)

    def forward(self, h):
        return self.net(h)

clf = SimpleClassifier(hist_dim=512).cuda()

# --- 3. Define Optimizador y Scheduler (¡AJUSTES CLAVE!) ---
# El repo usa weight_decay=1e-5 y lr=1e-3
optimizer = torch.optim.Adam(clf.parameters(), lr=1e-3, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

# Usamos un scheduler para reducir el LR si el test_acc deja de mejorar
# Esto ayuda a "afinar" el modelo y evitar sobreajuste
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10)

# --- 4. DataLoaders (sin cambios) ---
BATCH_SIZE = 32
train_dataset = TensorDataset(H_train, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = TensorDataset(H_test, y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# --- 5. Bucle de Entrenamiento con Early Stopping y Scheduler ---
print("\n--- Paso 2: Iniciando entrenamiento optimizado ---")

# Variables para guardar el mejor modelo
best_test_acc = 0.0
best_epoch = 0
best_model_state = None # Guardaremos los "pesos" del mejor modelo
MAX_EPOCHS = 150 # Más épocas

for epoch in range(MAX_EPOCHS):
    clf.train()
    train_correct = 0
    train_total = 0

    for h_batch, y_batch in train_loader:
        logits = clf(h_batch)
        loss = criterion(logits, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = logits.argmax(1)
        train_correct += (preds == y_batch).sum().item()
        train_total += y_batch.size(0)

    train_acc = train_correct / train_total

    # --- Bucle de Evaluación (Validación) ---
    clf.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for h_batch, y_batch in test_loader:
            logits = clf(h_batch)
            preds = logits.argmax(1)
            test_correct += (preds == y_batch).sum().item()
            test_total += y_batch.size(0)

    test_acc = test_correct / test_total

    # --- Lógica de Early Stopping y Scheduler ---

    # El scheduler se alimenta con el accuracy de test
    scheduler.step(test_acc)

    if test_acc > best_test_acc:
        best_test_acc = test_acc
        best_epoch = epoch
        best_model_state = copy.deepcopy(clf.state_dict()) # Guarda una copia de los pesos

        if (epoch + 1) % 5 == 0 or epoch == 0:
             print(f"Epoch {epoch+1:03d} | Train Acc: {train_acc:.3f} | Test Acc: {test_acc:.3f} (¡Nuevo Mejor!)")
    else:
        if (epoch + 1) % 10 == 0:
             print(f"Epoch {epoch+1:03d} | Train Acc: {train_acc:.3f} | Test Acc: {test_acc:.3f}")

    # Parada temprana si no mejora en 30 épocas
    if (epoch - best_epoch) > 30:
        print("Parando temprano por falta de mejora.")
        break

print(f"\nEntrenamiento finalizado.")
print(f"Mejor Test Accuracy: {best_test_acc:.3f} (en época {best_epoch+1})")

# Carga los mejores pesos en el modelo
if best_model_state:
    clf.load_state_dict(best_model_state)

--- Paso 1: Calculando Features ---
Calculando 316 histogramas...
Calculando 100 histogramas...
Features calculados.

--- Paso 2: Iniciando entrenamiento optimizado ---
Epoch 001 | Train Acc: 0.506 | Test Acc: 0.510 (¡Nuevo Mejor!)
Epoch 010 | Train Acc: 0.630 | Test Acc: 0.510
Epoch 020 | Train Acc: 0.684 | Test Acc: 0.540
Epoch 030 | Train Acc: 0.680 | Test Acc: 0.520
Epoch 040 | Train Acc: 0.703 | Test Acc: 0.490
Parando temprano por falta de mejora.

Entrenamiento finalizado.
Mejor Test Accuracy: 0.590 (en época 9)


In [80]:
# --- Guardar el mejor modelo ---
torch.save(best_model_state, "best_eeg_classifier.pt")
print("Modelo guardado como best_eeg_classifier.pt")


Modelo guardado como best_eeg_classifier.pt


In [81]:
clf = SimpleClassifier(hist_dim=512).cuda()
clf.load_state_dict(torch.load("best_eeg_classifier.pt"))
clf.eval()


SimpleClassifier(
  (net): Linear(in_features=512, out_features=2, bias=True)
)

In [85]:
i = 0   # índice

with torch.no_grad():
    x = preprocess_signal(X_test[i])
    hist = get_histogram(x)
    pred = clf(hist.unsqueeze(0)).argmax(1).item()

print("Predicción:", "right" if pred == 1 else "left")
print("Etiqueta real:", "right" if y_test[i] == 1 else "left")


Predicción: right
Etiqueta real: right
