In [8]:
import os
import pickle
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# ===========================
#  Define same model as training
# ===========================
class BiLSTMRegressor(nn.Module):
    """LSTM-based regressor conditioned on a learned company embedding.

    The company embedding is repeated along the sequence axis and concatenated
    to the per-step features before they enter the (optionally bidirectional)
    LSTM. The LSTM output at the final sequence position is passed through a
    two-layer MLP that produces one scalar per sample.

    Args:
        feature_dim: Number of input features per time step.
        hidden_size: Hidden size of each LSTM direction.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
        company_count: Number of rows in the company embedding table.
        company_emb_dim: Width of each company embedding vector.
        dropout: Dropout probability used in the MLP head and, when
            ``num_layers > 1``, passed to the LSTM.
    """

    def __init__(self, feature_dim, hidden_size, num_layers, bidirectional, company_count, company_emb_dim, dropout=0.0):
        super().__init__()
        # Attribute names (and the layer order inside ``head``) determine the
        # state_dict keys, so they must stay as they are for checkpoints to load.
        self.company_emb = nn.Embedding(company_count, company_emb_dim)

        # Dropout is only handed to the LSTM when there is more than one layer.
        lstm_dropout = 0.0 if num_layers <= 1 else dropout
        self.lstm = nn.LSTM(
            input_size=feature_dim + company_emb_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=lstm_dropout,
            bidirectional=bidirectional,
        )

        # A bidirectional LSTM emits both directions side by side.
        direction_count = 2 if bidirectional else 1
        lstm_out_dim = hidden_size * direction_count
        bottleneck_dim = lstm_out_dim // 2
        self.head = nn.Sequential(
            nn.Linear(lstm_out_dim, bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(bottleneck_dim, 1),
        )

    def forward(self, x, c):
        """Predict one scalar per sample.

        Args:
            x: 3-D tensor indexed as (batch, sequence, feature).
            c: Integer tensor of company indices, one per batch element.

        Returns:
            Tensor of predictions with the trailing singleton axis removed.
        """
        _, seq_len, _ = x.shape
        # Broadcast each sample's company vector over every time step.
        company_vec = self.company_emb(c)
        company_steps = company_vec.unsqueeze(1).expand(-1, seq_len, -1)
        lstm_out, _ = self.lstm(torch.cat([x, company_steps], dim=-1))
        # Only the output at the final sequence position feeds the head.
        final_step = lstm_out[:, -1, :]
        prediction = self.head(final_step)
        return prediction.squeeze(-1)


# ===========================
#  Load checkpoint and setup
# ===========================
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint carries non-tensor entries (mean, std, enc_classes) next to
# the weights, so a full load is requested explicitly rather than relying on
# torch.load's default for ``weights_only``.
# NOTE(review): presumably those entries are NumPy objects, which a
# weights-only load may reject - confirm against the training script.
# SECURITY: a full load unpickles arbitrary objects; only load checkpoints
# that come from a trusted source.
checkpoint = torch.load(
    "./bilstm_model.pt",
    map_location=DEVICE,
    weights_only=False,
)

# Normalisation statistics and label-encoder classes stored with the model.
mean = checkpoint["mean"]
std = checkpoint["std"]
enc_classes = checkpoint["enc_classes"]
feature_dim = checkpoint["feature_dim"]

# These hyperparameters are not read from the checkpoint; they have to match
# the ones used when the saved weights were produced, or load_state_dict fails.
model = BiLSTMRegressor(
    feature_dim=feature_dim,
    hidden_size=128,
    num_layers=2,
    bidirectional=True,
    company_count=len(enc_classes),
    company_emb_dim=32,
    dropout=0.2
).to(DEVICE)
model.load_state_dict(checkpoint["model_state"])
# Switch to evaluation mode so the dropout layers are inactive.
model.eval()

# ===========================
#  Load test data
# ===========================
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only point these paths at files from a trusted source.
with open("./windows/test_windows.pkl", "rb") as windows_file:
    test_windows = pickle.load(windows_file)
with open("./windows/company_list.pkl", "rb") as companies_file:
    companies = pickle.load(companies_file)

# Rebuild a label encoder from the classes stored in the checkpoint so that
# ticker strings can be mapped to indices and back.
enc = LabelEncoder()
enc.fit(enc_classes)

def ensure_encoded(windows):
    """Return the windows with float targets and integer ticker indices.

    Args:
        windows: Iterable of ``(X, y, t)`` triples, where ``t`` is either a
            ticker string or a value convertible with ``int``.

    Returns:
        A new list of ``(X, float(y), ticker_index)`` tuples. String tickers
        are mapped to indices with the module-level label encoder ``enc``;
        any other ticker value is converted with ``int``.
    """
    def _encode(window):
        features, target, ticker = window
        if not isinstance(ticker, str):
            ticker_idx = int(ticker)
        else:
            ticker_idx = int(enc.transform([ticker])[0])
        return (features, float(target), ticker_idx)

    return [_encode(window) for window in windows]

# Replace the loaded windows with copies whose target is a float and whose
# ticker is an integer index.
test_windows = ensure_encoded(test_windows)

def normalize_window(X):
    """Standardise a window with the checkpoint's ``mean`` and ``std``.

    Args:
        X: Array-like window of raw features.
            NOTE(review): assumed to broadcast against the module-level
            ``mean`` and ``std`` - confirm against the training script.

    Returns:
        The standardised window as ``float32``, with NaN and infinite values
        replaced by 0.0.
    """
    scaled = (X - mean) / std
    # NaN and +/-inf values are all replaced by zero.
    cleaned = np.nan_to_num(scaled, nan=0.0, posinf=0.0, neginf=0.0)
    return cleaned.astype(np.float32)

# ===========================
#  Run evaluation
# ===========================
true_vals, preds, ticker_ids = [], [], []

# Gradients are never needed during evaluation, so autograd is disabled once
# around the whole loop instead of re-entering the context for every window.
with torch.no_grad():
    for X, y, t in test_windows:
        X = normalize_window(X)
        # unsqueeze(0) adds a leading batch axis of size 1: one window per pass.
        X = torch.tensor(X, dtype=torch.float32).unsqueeze(0).to(DEVICE)
        t_tensor = torch.tensor([t], dtype=torch.long).to(DEVICE)
        pred = model(X, t_tensor).item()
        true_vals.append(y)
        preds.append(pred)
        ticker_ids.append(t)

# Decode all ticker indices back to their labels in a single encoder call
# rather than one call per window.
tickers = list(enc.inverse_transform(ticker_ids))

# ===========================
#  Compute metrics
# ===========================
# Regression metrics over all test windows; RMSE is derived from MSE.
r2 = r2_score(true_vals, preds)
mae = mean_absolute_error(true_vals, preds)
mse = mean_squared_error(true_vals, preds)
rmse = np.sqrt(mse)

# Each metric is reported with six decimal places.
report_lines = (
    "\n📊 Evaluation Metrics:",
    f"  MSE  = {mse:.6f}",
    f"  RMSE = {rmse:.6f}",
    f"  MAE  = {mae:.6f}",
    f"  R²   = {r2:.6f}",
)
for report_line in report_lines:
    print(report_line)

# ===========================
#  Save CSV
# ===========================
# One row per test window: decoded company label, true target, model prediction.
prediction_columns = {"Company": tickers, "True": true_vals, "Predicted": preds}
df = pd.DataFrame(prediction_columns)
# index=False keeps the DataFrame's row index out of the file.
df.to_csv("bilstm_predictions.csv", index=False)
print("\n✅ Saved predictions to bilstm_predictions.csv")



  checkpoint = torch.load("./bilstm_model.pt", map_location=DEVICE)



📊 Evaluation Metrics:
  MSE  = 32.060603
  RMSE = 5.662208
  MAE  = 3.762915
  R²   = 0.996520

✅ Saved predictions to bilstm_predictions.csv
