In [None]:
import numpy as np
import pickle
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, Subset
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils import clip_grad_norm_
from sklearn.preprocessing import StandardScaler
from matplotlib import pyplot as plt
import sys # Import sys for sys.exit()
import time # For timing epochs

# --- Configuration ---
note = 'data'
agent_name = 'ppo_baseline_0331_5cost'
data_path = f'data/{agent_name}_{note}'
model_save_path = f'data/{agent_name}_{note}_mps.pt'  # suffix added for clarity
batch_size = 64  # <<< Consider increasing batch size for GPU utilization >>>
num_epochs = 30
learning_rate = 0.001
hidden_size = 64  # <<< Consider increasing model size for GPU >>>
num_layers = 2
gradient_clip_value = 5.0  # max gradient norm passed to clip_grad_norm_
train_split = 0.8
val_split = 0.1  # whatever remains after train + val becomes the test split
seed = 42  # seeds the split permutation, DataLoader shuffling and weight init

# Seed the global RNGs so a Restart & Run All reproduces the same split and results.
torch.manual_seed(seed)
np.random.seed(seed)

# --- Device Setup (MPS Acceleration) ---
# Pick the first available accelerator in priority order:
# Apple-silicon GPU (MPS) -> NVIDIA GPU (CUDA) -> CPU.
print("--- Setting up device ---")
if torch.backends.mps.is_available():
    device, device_msg = torch.device("mps"), "Using MPS device (Apple Silicon GPU)"
elif torch.cuda.is_available():
    device, device_msg = torch.device("cuda"), "Using CUDA device (NVIDIA GPU)"
else:
    device, device_msg = torch.device("cpu"), "Using CPU device"
print(device_msg)


# --- Load Data ---
# The pickle must hold a 2-item list/tuple (x_data, ys) whose items are both lists.
# Any failure prints a message and exits with status 1.
# SECURITY NOTE: pickle.load can execute arbitrary code -- only load trusted files.
print(f"Loading data from: {data_path}")
try:
    with open(data_path, 'rb') as fh:
        loaded_data = pickle.load(fh)

    is_pair = isinstance(loaded_data, (list, tuple)) and len(loaded_data) == 2
    if not is_pair:
        raise TypeError("Pickle file should contain a list or tuple of two elements: x_data and ys")

    x_data_raw, ys_raw = loaded_data
    if not (isinstance(x_data_raw, list) and isinstance(ys_raw, list)):
        raise TypeError("Both elements in the loaded pickle file should be lists.")

    print(f"Loaded {len(x_data_raw)} samples.")
except FileNotFoundError:
    print(f"Error: Data file not found at {data_path}")
    sys.exit(1)
except Exception as e:
    print(f"Error loading or parsing pickle file: {e}")
    sys.exit(1)

# --- Preprocessing ---

# 1. Target Variable Processing
print("Processing target variables (y_data)...")
# Positions in each flattened y vector that are used as regression targets.
# Defined outside the try so the IndexError handler can always report them.
target_indices = [3, 9, 11]
try:
    # Flatten every target to 1-D and stack into a (num_samples, y_dim) matrix;
    # torch.stack requires all samples to have the same number of elements.
    y_data = torch.stack([torch.tensor(y, dtype=torch.float32).view(-1) for y in ys_raw])
    y_data = y_data[:, target_indices]
    output_size = y_data.shape[1]
    print(f"  Target shape: {y_data.shape}")

    # Divide each column by its largest absolute value so targets fall in [-1, 1].
    # Dividing by the signed max would flip the sign of an all-negative column and
    # would not bound large negative values. All-zero columns get a divisor of 1.
    # y_max_vals keeps shape (1, output_size) so predictions can be rescaled later.
    y_max_vals = y_data.abs().amax(dim=0, keepdim=True)
    y_max_vals[y_max_vals == 0] = 1.0
    y_data = y_data / y_max_vals
    print("  Targets normalized by max absolute value per column.")

except IndexError:
    print(f"Error: Target indices {target_indices} out of bounds for loaded y data.")
    sys.exit(1)
except Exception as e:
    print(f"Error processing y_data: {e}")
    sys.exit(1)

# 2. Input Variable Processing (Normalization & Padding)
print("Processing input variables (x_data)...")
x_data_tensors = [torch.tensor(x, dtype=torch.float32) for x in x_data_raw]

if not x_data_tensors:
    print("Error: No input data loaded.")
    sys.exit(1)

# Every sample must be a 2-D (time_steps, features) array with one common feature count.
# The feature count is taken from the first 2-D sample. Rank is checked before shape[1]
# is read, so a 1-D (or 3-D) sample is reported below instead of raising IndexError here.
input_size = next((t.shape[1] for t in x_data_tensors if t.ndim == 2), None)
bad_samples = [(i, t.shape) for i, t in enumerate(x_data_tensors)
               if t.ndim != 2 or t.shape[1] != input_size]
if bad_samples:
    print("Error: Input features have inconsistent dimensions or are not 2D across samples.")
    for i, shape in bad_samples:
        print(f"  Problematic sample index: {i}, shape: {shape}")
    sys.exit(1)
print(f"  Input feature size: {input_size}")

try:
    # Fit the scaler on every time step of every sample. Then transform the
    # concatenated array in one call and split it back into per-sample tensors.
    # This matches transforming each sample separately, with one call instead of one per sample.
    # NOTE(review): the scaler is fit on all samples, including those later held out
    # for validation/testing, so their statistics leak into preprocessing -- consider
    # fitting on the training split only.
    seq_lengths = [t.shape[0] for t in x_data_tensors]
    all_x_concatenated = torch.cat(x_data_tensors, dim=0).numpy()
    scaler = StandardScaler()
    scaler.fit(all_x_concatenated)
    all_x_scaled = torch.as_tensor(scaler.transform(all_x_concatenated), dtype=torch.float32)
    x_data_normalized_tensors = list(torch.split(all_x_scaled, seq_lengths, dim=0))
    print("  Input features normalized (StandardScaler).")
except Exception as e:
    print(f"Error during input normalization: {e}")
    sys.exit(1)

# Pad at the START of each sequence: values are right-aligned with zeros in front.
# This makes the last time step of every row real data. GRUNet predicts from
# out[:, -1, :], so end-padding would read shorter sequences after a run of padding steps.
# After standardization, the padding value 0 equals the per-feature mean.
# NOTE(review): any code that batches variable-length sequences for this model should
# pad the same way.
max_len = max(seq_lengths)
num_features = all_x_scaled.shape[1]
padded_data = torch.zeros(len(x_data_normalized_tensors), max_len, num_features, dtype=torch.float32)
for row, seq in enumerate(x_data_normalized_tensors):
    padded_data[row, max_len - seq.shape[0]:] = seq
print(f"  Padded data shape: {padded_data.shape}")

# --- Create Datasets and DataLoaders ---
print("Creating datasets and dataloaders...")
num_samples = len(padded_data)
# Random permutation of sample indices; reproducible only if torch was seeded beforehand.
indices = torch.randperm(num_samples)

train_end_idx = int(num_samples * train_split)
val_end_idx = train_end_idx + int(num_samples * val_split)
# If a validation split was requested but rounds down to zero samples, give it one sample.
# If train + val would consume every sample, drop validation so the rest becomes the test set.
if val_split > 0 and val_end_idx == train_end_idx and val_end_idx < num_samples:
    val_end_idx += 1
elif val_end_idx >= num_samples:
    val_end_idx = train_end_idx

train_indices = indices[:train_end_idx]
val_indices = indices[train_end_idx:val_end_idx]
test_indices = indices[val_end_idx:]

print(f"  Train samples: {len(train_indices)}")
print(f"  Validation samples: {len(val_indices)}")
print(f"  Test samples: {len(test_indices)}")

if len(train_indices) == 0 or len(test_indices) == 0: print("Warning: Training or test set is empty.")
if len(val_indices) == 0: print("Warning: Validation set is empty.")

# Subset views index into the shared padded tensor instead of copying every split.
# padded_data is num_samples x max_len x input_size floats, so copies are costly.
full_dataset = TensorDataset(padded_data, y_data)
train_dataset = Subset(full_dataset, train_indices.tolist())
val_dataset = Subset(full_dataset, val_indices.tolist()) if len(val_indices) > 0 else None
test_dataset = Subset(full_dataset, test_indices.tolist())

# Pinned host memory only speeds up host->CUDA copies. The DataLoader reports it as
# unsupported on MPS, so enable it for CUDA devices only.
pin_memory_flag = device.type == "cuda"
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory_flag)
val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=pin_memory_flag) if val_dataset else None
test_loader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=pin_memory_flag)

# --- Define Model ---
class GRUNet(nn.Module):
    """Stacked unidirectional GRU with a linear head applied to the final time step.

    Args:
        input_size: number of features per time step.
        hidden_size: GRU hidden units per layer.
        num_layers: number of stacked GRU layers.
        output_size: number of regression outputs produced by the linear head.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # batch_first=True: inputs are laid out as (batch, seq_len, input_size).
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x, h=None):
        """Run the GRU over `x` and predict from its output at the last time step.

        Args:
            x: (batch, seq_len, input_size) tensor.
            h: optional initial hidden state in nn.GRU's (num_layers, batch, hidden_size)
               layout; None starts from zeros.

        Returns:
            (predictions of shape (batch, output_size), final GRU hidden state).

        Note: only out[:, -1, :] reaches the head, so the last time step of each row
        is assumed to be real data rather than padding.
        """
        seq_out, h_final = self.gru(x, h)
        return self.fc(seq_out[:, -1, :]), h_final

# --- Initialize Model, Loss, Optimizer ---
# Build the network, move its parameters to the selected device, then create the
# MSE criterion and an Adam optimizer over the (already moved) parameters.
print("Initializing model...")
model = GRUNet(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    output_size=output_size,
)
model = model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

print(model)
trainable_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {trainable_param_count}")

# --- Training Loop ---
# Each epoch runs one optimisation pass over train_loader, then an evaluation pass
# when a validation loader exists. Per-sample-averaged losses are appended to
# train_losses and val_losses; val_losses gets None when there is no validation set.
print("\n--- Starting Training ---")
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    epoch_start_time = time.time()

    # --- Training pass ---
    model.train()
    # Running sum of per-sample losses, kept on the device. Calling loss.item() every
    # batch would force a device->host synchronisation per step.
    train_loss_sum = torch.zeros((), device=device)
    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        pred, _ = model(x_batch, None)
        loss = criterion(pred, y_batch)

        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), gradient_clip_value)
        optimizer.step()

        # criterion returns the batch mean. Weight it by the batch size so the
        # shorter final batch does not skew the epoch average.
        train_loss_sum += loss.detach() * x_batch.size(0)

    n_train = len(train_loader.dataset)
    avg_epoch_train_loss = train_loss_sum.item() / n_train if n_train else float("nan")
    train_losses.append(avg_epoch_train_loss)

    # --- Validation pass ---
    model.eval()
    avg_epoch_val_loss = None
    if val_loader is not None:
        val_loss_sum = torch.zeros((), device=device)
        with torch.no_grad():
            for x_batch, y_batch in val_loader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                outputs, _ = model(x_batch, None)
                val_loss_sum += criterion(outputs, y_batch) * x_batch.size(0)
        avg_epoch_val_loss = val_loss_sum.item() / len(val_loader.dataset)
    val_losses.append(avg_epoch_val_loss)

    epoch_time = time.time() - epoch_start_time
    log_line = f"Epoch {epoch+1}/{num_epochs}: Train Loss = {avg_epoch_train_loss:.5f}"
    if avg_epoch_val_loss is not None:
        log_line += f", Val Loss = {avg_epoch_val_loss:.5f}"
    print(f"{log_line}, Time = {epoch_time:.2f}s")

print("--- Training Complete ---")

# --- Testing ---
# Evaluate on the held-out test split, report the per-sample mean MSE, and plot
# predicted vs. actual values for each target column.
print("\n--- Starting Testing ---")
model.eval()
avg_test_loss = None  # stays None when there is no test data

n_test = len(test_loader.dataset)
if n_test == 0:
    print("Warning: No test data loaded. Skipping testing.")
else:
    test_start_time = time.time()
    test_loss_sum = 0.0
    output_chunks, target_chunks = [], []
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            outputs, _ = model(x_batch, None)
            # Weight each batch-mean loss by the batch size so the shorter final
            # batch does not bias the average.
            test_loss_sum += criterion(outputs, y_batch).item() * x_batch.size(0)
            output_chunks.append(outputs.cpu())
            target_chunks.append(y_batch.cpu())
    test_time = time.time() - test_start_time

    avg_test_loss = test_loss_sum / n_test
    print(f"Final Test Loss (MSE): {avg_test_loss:.5f}, Test Time = {test_time:.2f}s")

    all_test_outputs = torch.cat(output_chunks, dim=0).numpy()
    all_test_targets = torch.cat(target_chunks, dim=0).numpy()

    # --- Plot Test Results: one predicted-vs-actual panel per target column ---
    print("Plotting test results...")
    num_outputs_to_plot = all_test_outputs.shape[1]
    fig_test, axes_test = plt.subplots(1, num_outputs_to_plot, figsize=(6 * num_outputs_to_plot, 5.5), squeeze=False)

    for i, ax in enumerate(axes_test[0]):
        outputs_i = all_test_outputs[:, i]
        targets_i = all_test_targets[:, i]
        ax.scatter(outputs_i, targets_i, alpha=0.2)
        # Identity line spanning the combined data range plus a 5% margin.
        min_val = min(outputs_i.min(), targets_i.min())
        max_val = max(outputs_i.max(), targets_i.max())
        padding = (max_val - min_val) * 0.05
        ax.plot([min_val - padding, max_val + padding], [min_val - padding, max_val + padding], 'r--', label='y=x')
        ax.set_title(f'Test Set - Output {i+1} (Index {target_indices[i]})')
        ax.set_xlabel("Predicted Value")
        ax.set_ylabel("Actual Value")
        ax.axis('equal')
        ax.grid(True)
        ax.legend()

    plt.tight_layout()
    plt.show()


# --- Plot Training/Validation Loss Curve ---
# The x-axis is derived from the number of recorded epochs rather than num_epochs,
# so the plot still renders if training was stopped early. Epochs whose validation
# loss is None (no validation set) are skipped.
epochs_run = range(1, len(train_losses) + 1)
valid_epochs = [i + 1 for i, l in enumerate(val_losses) if l is not None]
valid_val_losses = [l for l in val_losses if l is not None]

fig_loss, ax_loss = plt.subplots(figsize=(10, 5))
ax_loss.plot(epochs_run, train_losses, label='Training Loss')
if valid_val_losses:
    ax_loss.plot(valid_epochs, valid_val_losses, label='Validation Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('MSE Loss')
ax_loss.set_title('Training and Validation Loss')
ax_loss.legend()
ax_loss.grid(True)
plt.show()


# --- Save Model ---
# Model saving is handled in the following notebook cell.

print("\nScript finished.")


--- Setting up device ---
Using MPS device (Apple Silicon GPU)
Loading data from: data/ppo_baseline_0331_5cost_data
Loaded 125000 samples.
Processing target variables (y_data)...
  Target shape: torch.Size([125000, 3])
  Targets normalized by max value per column.
Processing input variables (x_data)...
  Input feature size: 5
  Input features normalized (StandardScaler).
  Padded data shape: torch.Size([125000, 500, 5])
Creating datasets and dataloaders...
  Train samples: 100000
  Validation samples: 12500
  Test samples: 12500
Initializing model...
GRUNet(
  (gru): GRU(5, 64, num_layers=2, batch_first=True)
  (fc): Linear(in_features=64, out_features=3, bias=True)
)
Total parameters: 38787

--- Starting Training ---


In [None]:

# --- Save Model (Optional) ---
# The checkpoint bundles:
#   - trained weights and optimizer state;
#   - the architecture sizes needed to rebuild GRUNet;
#   - the target column indices;
#   - the input scaler statistics and target scaling factors, to reproduce preprocessing.
print(f"\nSaving model to: {model_save_path}")
# avg_test_loss only exists once the test section has run. Look it up defensively so
# this cell also works after an interrupted training run.
final_test_loss = globals().get("avg_test_loss") if test_loader and len(test_loader) > 0 else None
torch.save({
    # Epochs actually completed; fewer than num_epochs if training was interrupted.
    'epoch': len(train_losses),
    # Weights are copied to CPU so the checkpoint loads without the training accelerator.
    'model_state_dict': {name: tensor.cpu() for name, tensor in model.state_dict().items()},
    # NOTE(review): optimizer state tensors stay on the training device; pass
    # map_location to torch.load when resuming on a different device.
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': final_test_loss,  # final test MSE, or None if testing did not run
    'input_size': input_size,
    'hidden_size': hidden_size,
    'num_layers': num_layers,
    'output_size': output_size,
    'target_indices': target_indices,
    # NOTE(review): the next three entries are numpy arrays. torch.load with
    # weights_only=True (the default in recent PyTorch releases) may reject them --
    # confirm how the loading side reads this checkpoint.
    'scaler_mean': scaler.mean_,
    'scaler_scale': scaler.scale_,
    'y_max_vals': y_max_vals.cpu().numpy(),  # per-column target divisors, shape (1, output_size)
}, model_save_path)
print("Model saved.")

# --- Notify (Optional) ---
# Call a user-provided `notify()` hook if one is defined in the notebook namespace.
# Looking the name up explicitly, instead of catching NameError, avoids silently
# swallowing a NameError raised inside notify() itself.
notify_hook = globals().get("notify")
if callable(notify_hook):
    notify_hook()

print("\nScript finished.")
