[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tyronen/recurrent-rebels/blob/hensel/hensel/hensel_colab.ipynb)

# Hensel Model Training with W&B Sweeps

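This notebook sets up and runs hyperparameter tuning for the Hensel model using Weights & Biases (wandb). The notebook will:
1. Set up the environment (install packages, clone the repository, log in to W&B)
2. Provide the required data files (`train.npz`, `val.npz`, `train_vocab.json`), either by uploading them from your computer or by copying them from Google Drive
3. Run hyperparameter optimization using a wandb sweep
4. Retrieve the best run from the sweep (all runs are logged to W&B) and plot its learning curves and final metrics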

**Important:** Make sure you're using a GPU runtime:
- Go to Runtime → Change runtime type
- Under **Hardware accelerator**, select a GPU
- Click Save


In [None]:
# Install required packages
%pip install torch torchinfo wandb numpy pandas tqdm tensorboard matplotlib


In [None]:
# Clone the repository (once), switch to the `hensel` branch and make it importable.
import os
import subprocess
import sys

REPO_URL = 'https://github.com/tyronen/recurrent-rebels.git'
# Absolute path: with the previous relative clone, re-running this cell (after the
# working directory had already changed) cloned a second, nested copy and cd'ed into it.
# NOTE(review): assumes the Colab working directory /content — adjust when running elsewhere.
REPO_DIR = '/content/recurrent-rebels'
BRANCH = 'hensel'


def _git(*args, cwd=None):
    """Run `git <args>`, echo its output, and raise CalledProcessError if it fails.

    Unlike `!git ...`, a failed clone or checkout stops the notebook here instead of
    surfacing later (e.g. as an import error for the `hensel` package).
    """
    result = subprocess.run(['git', *args], cwd=cwd, capture_output=True, text=True)
    print(result.stdout + result.stderr, end='')
    result.check_returncode()


if not os.path.isdir(REPO_DIR):
    _git('clone', REPO_URL, REPO_DIR)
_git('checkout', BRANCH, cwd=REPO_DIR)
os.chdir(REPO_DIR)
print(f"Working directory: {os.getcwd()}")

# Add the project root to the Python path (absolute, and only once across re-runs)
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)


In [None]:
import wandb

# Authenticate this kernel with Weights & Biases; prompts for your API key
wandb.login()


In [None]:
# Method 1: Upload the data files from your local machine
import os
import shutil

from google.colab import files

# Set to True to open a file picker and upload train.npz, val.npz and train_vocab.json.
UPLOAD_FROM_LOCAL = False

# Create data directory
os.makedirs('data', exist_ok=True)

if UPLOAD_FROM_LOCAL:
    uploaded = files.upload()
    for filename in uploaded:
        # shutil.move instead of `!mv "{filename}" ...`: the file name is chosen by the
        # user and must not be interpolated into a shell command (quotes would break it).
        shutil.move(filename, os.path.join('data', filename))
        print(f"Moved {filename} -> data/{filename}")


In [None]:
# Method 2: Copy the data files from Google Drive
import os
import shutil

from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Set COPY_FROM_DRIVE = True and point DRIVE_DATA_DIR at the Drive folder that holds the files.
COPY_FROM_DRIVE = False
DRIVE_DATA_DIR = '/content/drive/MyDrive/path_to_data'
DATA_FILES = ('train.npz', 'val.npz', 'train_vocab.json')

if COPY_FROM_DRIVE:
    os.makedirs('data', exist_ok=True)
    for name in DATA_FILES:
        src = os.path.join(DRIVE_DATA_DIR, name)
        # Fail loudly on a wrong path instead of a silently failing `!cp`.
        if not os.path.isfile(src):
            raise FileNotFoundError(f"{src} not found; check DRIVE_DATA_DIR")
        shutil.copy(src, os.path.join('data', name))
        print(f"Copied {src} -> data/{name}")


In [None]:
# Verify that every data file needed for training is present in data/
import os

required_files = ['train.npz', 'val.npz', 'train_vocab.json']
# Pad names to the longest one so the size column lines up
# (a fixed width of 15 was shorter than 'train_vocab.json').
name_width = max(len(name) for name in required_files)
missing_files = []

for file in required_files:
    path = os.path.join('data', file)
    if os.path.exists(path):
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f"✓ {file:<{name_width}} found ({size_mb:.1f} MB)")
    else:
        print(f"✗ {file:<{name_width}} missing")
        missing_files.append(file)

all_files_present = not missing_files

# Stop here rather than only printing a warning: otherwise "Run All" goes on to start
# the sweep, whose trials would presumably fail on the missing data.
if not all_files_present:
    raise FileNotFoundError(
        "Please ensure all required files are present before proceeding. "
        f"Missing from data/: {', '.join(missing_files)}"
    )


In [None]:
# Import the sweep configuration and training function, then create the sweep
import wandb

from hensel.sweep_config import init_sweep
from hensel.train import train

# NOTE(review): assumes init_sweep() creates the sweep in the 'hensel-model' project
# (the same project the results cell queries) — confirm in hensel/sweep_config.py.
WANDB_PROJECT = 'hensel-model'

# Initialize the sweep
sweep_id = init_sweep()
print(f"Sweep initialized with ID: {sweep_id}")

# W&B sweep pages live at https://wandb.ai/<entity>/<project>/sweeps/<id>; the
# previous link left out the entity.
entity = wandb.Api().default_entity
print(f"\nView sweep at: https://wandb.ai/{entity}/{WANDB_PROJECT}/sweeps/{sweep_id}")


In [None]:
# Launch a W&B agent that calls `train` with configurations proposed by the sweep.
# Change N_TRIALS to run more or fewer trials.
N_TRIALS = 20
wandb.agent(sweep_id, function=train, count=N_TRIALS)


In [None]:
# Get the best run from the sweep and show its hyperparameters
import wandb

api = wandb.Api()
sweep = api.sweep(f"hensel-model/{sweep_id}")
best_run = sweep.best_run()
# best_run() can come back empty (None), e.g. when no trial has run yet; fail with a
# clear message instead of an AttributeError on `best_run.name`.
if best_run is None:
    raise RuntimeError(f"Sweep {sweep_id} has no runs yet; run the sweep cell first.")

print(f"Best run: {best_run.name}")
print(f"Best validation loss: {best_run.summary.get('val_loss', 'N/A')}")
print("\nBest hyperparameters:")
for key, value in best_run.config.items():
    print(f"{key}: {value}")


In [None]:
# Plot learning curves and report the final metrics of the best run
import matplotlib.pyplot as plt
import pandas as pd

# Get the metrics from the best run: one history row per logged step
history = pd.DataFrame(best_run.scan_history())
if history.empty:
    raise RuntimeError(f"Run {best_run.name} has no logged history to plot.")


def _metric_points(frame, column):
    """Return the (epoch, column) rows where `column` was actually logged.

    If metrics are logged in separate wandb.log() calls, each history row holds
    only some of them and the rest are NaN; plotting those NaNs directly breaks
    the line into isolated points that are not drawn without markers.
    """
    return frame[['epoch', column]].dropna()


def _last_value(frame, column):
    """Return the most recent non-NaN value logged for `column`."""
    return frame[column].dropna().iloc[-1]


def _plot_panel(ax, frame, curves, title, ylabel):
    """Draw each (column, label, extra plot kwargs) curve on `ax` and style the panel."""
    for column, label, extra in curves:
        points = _metric_points(frame, column)
        ax.plot(points['epoch'], points[column], label=label, alpha=0.8, **extra)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


fig, (ax_loss, ax_r2, ax_mae) = plt.subplots(1, 3, figsize=(15, 5))

# Training and validation loss
_plot_panel(ax_loss, history,
            [('train_loss', 'Train', {}), ('val_loss', 'Validation', {})],
            'Learning Curves', 'Loss')

# R² scores
_plot_panel(ax_r2, history,
            [('r2_log', 'R² (log scale)', {}), ('r2_real', 'R² (real scale)', {})],
            'R² Scores', 'R²')

# Mean absolute error
_plot_panel(ax_mae, history,
            [('mae', 'MAE', {'color': 'orange'})],
            'Mean Absolute Error', 'MAE')

fig.tight_layout()
plt.show()

# Final metrics: last logged value per metric (the last history row may not contain all of them)
final_metrics = pd.Series({
    column: _last_value(history, column)
    for column in ('train_loss', 'val_loss', 'mae', 'r2_log', 'r2_real')
})
print("\nFinal Metrics:")
print(f"Train Loss: {final_metrics['train_loss']:.4f}")
print(f"Val Loss: {final_metrics['val_loss']:.4f}")
print(f"MAE: {final_metrics['mae']:.4f}")
print(f"R² (log): {final_metrics['r2_log']:.4f}")
print(f"R² (real): {final_metrics['r2_real']:.4f}")
