# Routing Weight Optimization: Data Processing and Training Workflow

This notebook runs the main workflow for the offline RL agent:
1. **Setup:** Clones the repo (or pulls updates), installs dependencies, and sets up paths.
2. **Data Preparation:** Runs the `create_dataset.py` script to generate the `MDPDataset`.
3. **Model Training:** Runs the `train_cql.py` script to train the CQL agent.
4. **Basic Results:** Loads and plots some training metrics.

## 1. Setup

In [None]:
# === Google Colab Setup ===
# Clone the repository (if not already done) and pull latest changes

import os

REPO_NAME = "openRoad-dr-training" # Your repository name
# Use HTTPS URL for easier public/token access in Colab by default
# You might need to generate a Personal Access Token (PAT) on GitHub 
# and use it in the URL like: https://<YOUR_PAT>@github.com/saikanam/openRoad-dr-training.git
# Or configure SSH keys if you prefer.
REPO_URL_HTTPS = "https://github.com/saikanam/openRoad-dr-training.git"

# Check if running in Colab
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print("Running in Google Colab")
    # Check if repo already exists in Colab's /content directory
    colab_repo_path = f"/content/{REPO_NAME}"
    if not os.path.exists(colab_repo_path):
        print(f"Cloning repository: {REPO_URL_HTTPS} into {colab_repo_path}")
        # Clone using HTTPS
        !git clone {REPO_URL_HTTPS} {colab_repo_path}
        %cd {colab_repo_path}
    else:
        print(f"Repository {REPO_NAME} already exists in /content. Pulling latest changes...")
        %cd {colab_repo_path}
        !git pull origin main
else:
    # Assume running locally, repository is the current directory
    print("Running locally, repository assumed to be current directory.")
    # Optionally, you could still run git pull here if desired
    # !git pull origin main 
    pass 

print(f"\nCurrent directory: {os.getcwd()}")

In [None]:
# === Install Dependencies ===
print("Installing dependencies from requirements.txt...")
!pip install -r requirements.txt

In [None]:
# === Imports and Path Definitions ===
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import d3rlpy # Verify import after install

# All paths are relative, so they resolve against the current working
# directory (expected to be the repository root).
DATA_INPUT_DIR = "training_data" # Relative path from create_dataset.py default
DATA_OUTPUT_DIR = "data"
LOG_DIR = "d3rlpy_logs"
DATASET_FILENAME = "routing_dataset.h5"

# Files produced inside DATA_OUTPUT_DIR
DATASET_PATH = os.path.join(DATA_OUTPUT_DIR, DATASET_FILENAME)
SCALER_PATH = os.path.join(DATA_OUTPUT_DIR, "state_scaler.pkl")

# Create the output locations up front; exist_ok makes re-running this cell safe.
for _output_dir in (DATA_OUTPUT_DIR, LOG_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# One line per location, in the order: dataset, scaler, logs.
print(
    f"Dataset will be saved to: {DATASET_PATH}",
    f"Scaler will be saved to: {SCALER_PATH}",
    f"Training logs will be saved in: {LOG_DIR}",
    sep="\n",
)

## 2. Data Preparation

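Run the script to process the raw CSV data into the `MDPDataset` format required by d3rlpy. This assumes the `training_data` directory exists at the root of the repository; it is not currently tracked by Git, so it must be added manually (e.g. uploaded or copied) after cloning.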

In [None]:
# === Run Data Creation Script ===
# Note: This assumes the 'training_data' folder is present in the root 
# of the repository when cloned. It's not tracked by Git currently.

if not os.path.exists(DATA_INPUT_DIR):
    print(f"Warning: Input data directory '{DATA_INPUT_DIR}' not found. \n",
          f"Please ensure it exists in the repository root ('{os.getcwd()}') before running this step.")
else:
    print(f"Running data creation script... Input: {DATA_INPUT_DIR}, Output Dir: {DATA_OUTPUT_DIR}")
    !python src/data_processing/create_dataset.py --input_dir {DATA_INPUT_DIR} --output_dir {DATA_OUTPUT_DIR} --output_filename {DATASET_FILENAME}

    # === Verify Output ===
    if os.path.exists(DATASET_PATH):
        print(f"\nDataset file created successfully at: {DATASET_PATH}")
    else:
        print(f"\nError: Dataset file not found at {DATASET_PATH}. Check script output above.")

    if os.path.exists(SCALER_PATH):
        print(f"State scaler file created successfully at: {SCALER_PATH}")
    else:
        print(f"Error: State scaler file not found at {SCALER_PATH}. Check script output above.")

## 3. Model Training

Run the training script using the generated dataset. Adjust hyperparameters as needed.

In [None]:
# === Training Configuration ===
# Using parameters that showed some promise previously (low LR, low alpha, no reward scaling)
EXPERIMENT_NAME = "CQL_Colab_Run_v3" # Increment experiment name
CONSERVATIVE_WEIGHT = 1.0
ACTOR_LR = 1e-6
CRITIC_LR = 1e-6
EPOCHS = 50 # Adjust as needed
USE_REWARD_SCALER_FLAG = "--no-use_reward_scaler" # Use '--use_reward_scaler' to enable
DEVICE_FLAG = "--use_gpu" # Use '--use_cpu' if no GPU available in Colab
SEED = 42 # Set a random seed for reproducibility

# === Run Training Script ===
print(f"Starting training run: {EXPERIMENT_NAME}")
if not os.path.exists(DATASET_PATH):
    print(f"Error: Dataset file not found at {DATASET_PATH}. Cannot start training.")
else:
    !python src/training/train_cql.py \
        --dataset {DATASET_PATH} \
        --experiment_name {EXPERIMENT_NAME} \
        --conservative_weight {CONSERVATIVE_WEIGHT} \
        --actor_lr {ACTOR_LR} \
        --critic_lr {CRITIC_LR} \
        --epochs {EPOCHS} \
        --seed {SEED} \
        {USE_REWARD_SCALER_FLAG} \
        {DEVICE_FLAG}
        # Add other arguments as needed

    print(f"\nTraining finished. Logs should be available in: {os.path.join(LOG_DIR, EXPERIMENT_NAME)}")

## 4. Basic Results

Load and plot some basic training metrics from the logs.

In [None]:
# === Load Logs ===
# Reads one CSV per metric from the experiment's log directory into
# `metrics_data` (metric name -> DataFrame). Missing or unreadable files are
# reported and skipped rather than aborting the cell.

def _is_numeric_label(label):
    """Return True if a column label can be parsed as a number."""
    try:
        float(label)
    except (TypeError, ValueError):
        return False
    return True


def _read_metric_csv(csv_path):
    """Read one metric CSV into a DataFrame.

    A file with 'step' and 'value' header columns is returned exactly as
    pandas parses it. A three-column file whose first row consists only of
    numbers has no real header (pandas would otherwise consume the first data
    row as column names), so it is re-read as headerless rows named
    epoch, step, value. Any other layout is returned as parsed.

    NOTE(review): the epoch/step/value column order follows what d3rlpy's
    file logger is understood to write -- confirm against train_cql.py.
    """
    df = pd.read_csv(csv_path)
    if {"step", "value"}.issubset(df.columns):
        return df
    if len(df.columns) == 3 and all(_is_numeric_label(col) for col in df.columns):
        return pd.read_csv(csv_path, header=None, names=["epoch", "step", "value"])
    return df


# Use the experiment name defined in the training cell
# NOTE(review): d3rlpy may append a timestamp suffix to the experiment
# directory unless the training script disables it -- confirm the directory
# name actually created by train_cql.py matches EXPERIMENT_NAME.
log_path = os.path.join(LOG_DIR, EXPERIMENT_NAME)
metrics_to_plot = [
    "critic_loss", "actor_loss", "temp_loss", "alpha_loss",
    "conservative_loss", "td_error", "initial_state_value",
    "temperature", "alpha" # Include temp and alpha if available
]

metrics_data = {}
print(f"Attempting to load logs from: {log_path}")

if os.path.exists(log_path):
    for metric in metrics_to_plot:
        csv_path = os.path.join(log_path, f"{metric}.csv")
        if os.path.exists(csv_path):
            try:
                metrics_data[metric] = _read_metric_csv(csv_path)
                print(f"  Loaded {metric}.csv")
            except Exception as e:
                # Best effort: one unreadable file (e.g. empty) should not
                # prevent the remaining metrics from loading.
                print(f"  Error loading {metric}.csv: {e}")
        else:
            # Don't print missing for optional ones like temp/alpha
            if metric not in ['temperature', 'alpha']:
                 print(f"  {metric}.csv not found.")
else:
    print(f"Log directory not found: {log_path}")

In [None]:
# === Plot Metrics ===
# Draws one line plot per loaded, non-empty metric on a 3-column grid.
sns.set_theme(style="darkgrid")
num_metrics = len(metrics_data)

if num_metrics == 0:
    print("No metric data loaded to plot.")
else:
    # Grid is sized for every loaded metric; rows = ceil(num_metrics / ncols).
    ncols = 3
    nrows = -(-num_metrics // ncols)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(18, 5 * nrows), squeeze=False)
    axes = axes.flatten() # 1-D view of the grid so it can be paired with metrics

    # Empty frames get no subplot of their own; their slots are hidden below.
    non_empty_metrics = [
        (name, frame) for name, frame in metrics_data.items() if not frame.empty
    ]

    for ax, (metric, df) in zip(axes, non_empty_metrics):
        if not {'step', 'value'}.issubset(df.columns):
            # Keep the panel, but flag that the expected columns are absent.
            ax.set_title(f"{metric} - Data Missing Columns")
            continue
        sns.lineplot(data=df, x='step', y='value', ax=ax)
        ax.set_title(f"{metric} vs. Training Step")
        ax.set_xlabel("Training Step")
        ax.set_ylabel(metric.replace('_', ' ').title())

    # Hide every grid slot that did not receive a metric.
    for spare_ax in axes[len(non_empty_metrics):]:
        spare_ax.set_visible(False)

    plt.tight_layout(pad=3.0)
    plt.show()