In [None]:
import pandas as pd
import numpy as np
import torch

from sklearn.metrics import mean_squared_error, mean_absolute_error

# Add the parent directory to sys.path so twin_model can be imported
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from twin_model.model import BiGRUWithAttention
from twin_model.utils import load_model
import matplotlib.pyplot as plt

# Load processed data.
# NOTE(review): raw CSV values are fed to the model below. If the model was
# trained on normalized/scaled features, the same scaler must be applied here
# (and inverted on `pred`) — confirm against the training pipeline.
df = pd.read_csv("../data/processed/ai4i_cleaned.csv")
sensor_cols = ['Air temperature [K]', 'Process temperature [K]',
               'Rotational speed [rpm]', 'Torque [Nm]', 'Tool wear [min]']

# Evaluate on a fixed window taken from healthy (non-failure) rows only.
df_healthy = df[df['Machine failure'] == 0].reset_index(drop=True)
window_start, window_len = 300, 20
sequence = df_healthy[sensor_cols].iloc[window_start:window_start + window_len].values
if len(sequence) != window_len:
    # iloc silently truncates past the end; fail loudly instead of scoring a short window.
    raise ValueError(
        f"Expected {window_len} healthy rows starting at {window_start}, got {len(sequence)}"
    )


def _checkpoint_hidden_size(state_dict, default, path):
    """Infer the GRU hidden size from a BiGRUWithAttention checkpoint.

    Parameters
    ----------
    state_dict : object returned by ``torch.load`` for the checkpoint file;
        expected to be a mapping of parameter name -> tensor.
    default : hidden size to use when the layout is not recognized.
    path : checkpoint path, used only in the error message.

    Returns
    -------
    int
        The ``hidden_size`` stored in the checkpoint, or ``default``.

    Raises
    ------
    RuntimeError
        If the checkpoint holds weights of a unidirectional ``gru`` layer
        instead of the ``bigru`` layer BiGRUWithAttention expects. Otherwise
        ``load_state_dict`` reports this as a list of missing/unexpected keys.
    """
    if not isinstance(state_dict, dict):
        return default
    if "bigru.weight_hh_l0" in state_dict:
        # nn.GRU stores weight_hh_l0 with shape (3 * hidden_size, hidden_size).
        return int(state_dict["bigru.weight_hh_l0"].shape[1])
    if "gru.weight_hh_l0" in state_dict:
        raise RuntimeError(
            f"Checkpoint {path!r} contains unidirectional 'gru.*' weights and no "
            "'bigru.*'/'attn.*' weights, so it was saved from a different model "
            "class than BiGRUWithAttention. Re-train/re-save BiGRUWithAttention "
            "to this path, or load the checkpoint with the class it was trained with."
        )
    return default


# Load model. The hidden size is read from the checkpoint itself so it cannot
# drift from the training configuration; the default is only a fallback.
checkpoint_path = "../models/bigru_attention_twin.pth"
default_hidden_size = 64
checkpoint = torch.load(checkpoint_path, map_location="cpu")
hidden_size = _checkpoint_hidden_size(checkpoint, default_hidden_size, checkpoint_path)
del checkpoint  # load_model re-reads the file; drop this inspection copy
model = load_model(BiGRUWithAttention, checkpoint_path, len(sensor_cols), hidden_size)
model.eval()

# Predict (reconstruct) the window; the model expects a leading batch dimension.
with torch.no_grad():
    input_tensor = torch.tensor(sequence).unsqueeze(0).float()  # shape: [1, 20, 5]
    pred = model(input_tensor)[0].numpy()  # expected shape: [20, 5]
if pred.shape != sequence.shape:
    # Per-column metrics below index pred[:, i]; a mismatched shape would give
    # meaningless numbers rather than an error.
    raise ValueError(f"Model output shape {pred.shape} != input window shape {sequence.shape}")

# Compute per-sensor metrics.
for i, col in enumerate(sensor_cols):
    # `squared=False` was deprecated in scikit-learn 1.4 and removed in 1.6;
    # take the square root explicitly so this works on every version.
    rmse = np.sqrt(mean_squared_error(sequence[:, i], pred[:, i]))
    mae = mean_absolute_error(sequence[:, i], pred[:, i])
    print(f"{col}: RMSE = {rmse:.4f}, MAE = {mae:.4f}")

# Optional: plot one sensor (looked up by name rather than a magic index).
torque_idx = sensor_cols.index('Torque [Nm]')
plt.figure(figsize=(8, 4))
plt.plot(sequence[:, torque_idx], label="True Torque")
plt.plot(pred[:, torque_idx], label="Predicted Torque", linestyle='--')
plt.title("Torque Prediction vs Ground Truth")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


RuntimeError: Error(s) in loading state_dict for BiGRUWithAttention:
	Missing key(s) in state_dict: "bigru.weight_ih_l0", "bigru.weight_hh_l0", "bigru.bias_ih_l0", "bigru.bias_hh_l0", "bigru.weight_ih_l0_reverse", "bigru.weight_hh_l0_reverse", "bigru.bias_ih_l0_reverse", "bigru.bias_hh_l0_reverse", "attn.weight", "attn.bias". 
	Unexpected key(s) in state_dict: "gru.weight_ih_l0", "gru.weight_hh_l0", "gru.bias_ih_l0", "gru.bias_hh_l0". 
	size mismatch for fc.weight: copying a param with shape torch.Size([5, 64]) from checkpoint, the shape in current model is torch.Size([5, 128]).