In [2]:
import numpy as np
import matplotlib.pyplot as plt

from meta_training import (
    EvaluationConfig,
    MetaTrainingConfig,
    evaluate_memory_module,
    resolve_device,
    run_meta_training,
)


In [None]:
# Build the meta-training configuration. Keyword arguments are grouped by
# concern; the exact meaning of each field is defined by MetaTrainingConfig
# in meta_training (not visible from this notebook).
training_config = MetaTrainingConfig(
    # Requested device; resolve_device() below reports the one actually used.
    device_preference="cuda",
    # Key / value / context dimensionalities.
    key_dim=16,
    val_dim=16,
    context_dim=5,
    # Sequence generation and batching.
    seq_len=50,
    num_sequences=500,
    batch_size=10,
    recall_window=1,
    output_corr=0.5,
    # Outer-loop optimiser settings (presumably Adam-style betas -- confirm
    # against meta_training).
    outer_lr=0.01,
    beta1=0.95,
    beta2=0.99,
    # Logging cadence.
    log_every_sequences=50,
)

# Resolve the preference up front so the banner names the concrete device.
actual_device = resolve_device(training_config.device_preference)
print(f"--- Starting Training on {actual_device} ---")

# `artifacts` is consumed by the evaluation and plotting cells further down
# (they read `artifacts.memory_module` and `artifacts.outer_losses`).
artifacts = run_meta_training(training_config)
print("--- Training complete ---")


--- Starting Training on cuda ---
Epoch 10 | Avg Outer Loss: 3.8744
  Sample Hyperparams -> LR: 0.1215
Epoch 50 | Avg Outer Loss: 3.9578
  Sample Hyperparams -> LR: 0.1111
Epoch 100 | Avg Outer Loss: 3.8525
  Sample Hyperparams -> LR: 0.1002
Epoch 150 | Avg Outer Loss: 3.8996
  Sample Hyperparams -> LR: 0.1216
Epoch 200 | Avg Outer Loss: 3.9090
  Sample Hyperparams -> LR: 0.0950
Epoch 250 | Avg Outer Loss: 3.8454
  Sample Hyperparams -> LR: 0.0983
Epoch 300 | Avg Outer Loss: 3.8128
  Sample Hyperparams -> LR: 0.1048
Epoch 350 | Avg Outer Loss: 3.8212
  Sample Hyperparams -> LR: 0.1322
Epoch 400 | Avg Outer Loss: 3.7702
  Sample Hyperparams -> LR: 0.1051


In [None]:
# Evaluate the trained memory module. The evaluation config copies the
# task-shape fields from the training config; only `num_sequences` is set
# independently here.
evaluation_config = EvaluationConfig(
    seq_len=training_config.seq_len,
    num_sequences=20,
    key_dim=training_config.key_dim,
    val_dim=training_config.val_dim,
    context_dim=training_config.context_dim,
    output_corr=training_config.output_corr,
)

results = evaluate_memory_module(artifacts.memory_module, evaluation_config)

# Move results to host memory before converting to NumPy for printing/plotting.
results_cpu = results.cpu()
offsets = results_cpu.offsets.numpy()
accuracies = results_cpu.mean_accuracy.numpy()
counts = results_cpu.counts.numpy()

# An offset with zero observations has no meaningful mean accuracy, so apply
# the same mask to BOTH the printed table and the plot. Previously the plot
# used the unfiltered arrays and could show points for never-observed offsets.
valid_mask = counts > 0
valid_offsets = offsets[valid_mask]
valid_accuracies = accuracies[valid_mask]
valid_counts = counts[valid_mask]

print("--- Recall Accuracy by Offset ---")
for offset, accuracy, count in zip(valid_offsets, valid_accuracies, valid_counts):
    print(f"Offset {int(offset)} | Accuracy: {accuracy:.3f} | Observations: {int(count)}")

# Explicit Figure/Axes interface keeps this cell independent of pyplot's
# implicit "current figure" state.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(valid_offsets, valid_accuracies, marker="o")
ax.set_xlabel("Offset (0 = current timestep)")
ax.set_ylabel("Recall accuracy")
ax.set_title("Recall accuracy by offset")
ax.set_ylim(0.0, 1.05)
ax.grid(True)
plt.show()


In [None]:
# Plot the recorded outer-loop losses against their position in the list
# (one point per entry of `artifacts.outer_losses`).
outer_losses = artifacts.outer_losses
update_index = np.arange(len(outer_losses))

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(update_index, outer_losses)
ax.set_xlabel("Meta update")
ax.set_ylabel("Average outer loss")
ax.set_title("Training trajectory")
ax.grid(True)
plt.show()
