# In [None]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModel

# Load the headline-level table and collapse it to one document per trading day
# (so only one document per day has to be embedded).
df = pd.read_parquet("../data/processed/model_table_clean.parquet").copy()
df["trading_date"] = pd.to_datetime(df["trading_date"])
# Use a stable sort: the default sort_values algorithm (quicksort) does not
# guarantee that rows sharing a trading_date keep their original relative order.
# That order decides how headlines are concatenated below and which row supplies
# the "first" return, so it has to be reproducible.
df = df.sort_values("trading_date", kind="mergesort")

day_df = (
    df.groupby("trading_date")
      .agg(
          # Skip missing headlines rather than letting str.join raise TypeError
          # on NaN/None; astype(str) is a no-op for values that are already str.
          doc=("clean_headline", lambda x: " ".join(x.dropna().astype(str))),
          return_t_plus_1=("return_t_plus_1", "first"),
      )
      .reset_index()
)

# Drop days whose next-day return is missing *before* labelling, so a missing
# return is never silently turned into a 0 ("not up") label.
day_df = day_df.dropna(subset=["return_t_plus_1"]).copy()

# Binary target: 1 when the next-day return is strictly positive, else 0.
day_df["label"] = (day_df["return_t_plus_1"] > 0).astype(int)

# Pretrained checkpoint used purely as a frozen feature extractor.
model_name = "yiyanghkust/finbert-tone"

# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModel.from_pretrained(model_name)

# Inference only: turn off gradient tracking on every parameter and put the
# module in eval mode before moving it to the selected device.
base_model.requires_grad_(False)
base_model.eval()
base_model.to(device)

print("Device:", device)
print("Days:", len(day_df))

def embed_texts(texts, max_length=128, batch_size=32):
    """
    Compute FinBERT embeddings for a list of texts using mean pooling of last hidden state.

    We do mean pooling because it is simple, stable, and works well as a generic embedding.

    Relies on the module-level ``tokenizer``, ``base_model`` and ``device``.

    Args:
        texts: Sequence of strings to embed (a list, or anything ``list()`` can
            materialise, e.g. a pandas Series of strings).
        max_length: Maximum number of tokens kept per text; longer texts are
            truncated by the tokenizer.
        batch_size: Number of texts encoded per forward pass. Must be >= 1.

    Returns:
        A 2-D NumPy array with one row per input text, in input order. For empty
        input an array with zero rows is returned.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Materialise once so slicing below always yields plain lists for the tokenizer.
    texts = list(texts)
    if not texts:
        # np.vstack([]) would raise; return a well-formed empty matrix instead.
        # NOTE(review): assumes the model config exposes `hidden_size` (true for
        # BERT-style models) -- confirm if the checkpoint is swapped.
        return np.empty((0, base_model.config.hidden_size), dtype=np.float32)

    all_vecs = []

    with torch.no_grad():
        for i in tqdm(range(0, len(texts), batch_size)):
            batch = texts[i:i+batch_size]

            enc = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            ).to(device)

            out = base_model(**enc)
            # out.last_hidden_state: (batch, seq, hidden)
            hidden = out.last_hidden_state

            # Attention mask to avoid pooling over padding tokens. Cast it to the
            # hidden-state dtype so the multiply/divide below do not depend on
            # implicit integer -> float type promotion.
            mask = enc["attention_mask"].unsqueeze(-1).to(hidden.dtype)  # (batch, seq, 1)
            masked_hidden = hidden * mask

            # Mean pooling: sum over real tokens divided by the number of real tokens.
            sum_hidden = masked_hidden.sum(dim=1)
            # Token counts are whole numbers, so clamping at 1 only affects rows
            # with zero real tokens (whose sum is 0 anyway) and avoids a 0/0.
            denom = mask.sum(dim=1).clamp(min=1.0)
            mean_pooled = sum_hidden / denom

            all_vecs.append(mean_pooled.cpu().numpy())

    return np.vstack(all_vecs)

# Embed every day-level document; row i of `embeddings` corresponds to row i of day_df.
embeddings = embed_texts(day_df["doc"].tolist(), max_length=128, batch_size=32)
print("Embeddings shape:", embeddings.shape)

out_path = "../data/features/finbert_day_embeddings.parquet"
os.makedirs(os.path.dirname(out_path), exist_ok=True)

# Give the embedding dimensions string labels ("0", "1", ...). A bare
# pd.DataFrame(embeddings) has integer column labels, and mixing those with the
# string metadata columns makes to_parquet fail on pandas versions/engines that
# require string column names.
emb_df = pd.DataFrame(
    embeddings,
    columns=[str(i) for i in range(embeddings.shape[1])],
)
# Metadata columns go first; .values is used so assignment is positional and
# does not try to align on day_df's (non-contiguous after dropna) index.
emb_df.insert(0, "trading_date", day_df["trading_date"].values)
emb_df.insert(1, "label", day_df["label"].values)
emb_df.insert(2, "return_t_plus_1", day_df["return_t_plus_1"].values)

emb_df.to_parquet(out_path, index=False)

print("Saved:", out_path)
