# In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import pandas as pd
from tqdm import tqdm

# Device selection: keep GPU index 4 as the preferred device, but only when
# that index actually exists. `torch.cuda.is_available()` alone says nothing
# about how many GPUs are present, so on a smaller machine fall back to GPU 0
# instead of failing while loading the model. Use the CPU only without CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda:4" if torch.cuda.device_count() > 4 else "cuda:0")
else:
    device = torch.device("cpu")

model_id = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# `device_map={"": device}` already places the whole model on `device` during
# loading, so no additional `.to(device)` call is made afterwards.
# float16 is used on GPU to halve memory; on CPU the weights stay in float32
# because half-precision CPU kernels are incomplete or very slow in many
# PyTorch builds.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
    device_map={"": device},
)

def generate_narrative_from_tweets(df_subset, max_tokens=512, max_prompt_tokens=2048):
    """Summarise a group of tweets into a short narrative with the loaded LLM.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        df_subset: DataFrame with a ``tweet_text`` column. Missing values are
            dropped and the remaining values are converted to ``str``.
        max_tokens: Maximum number of new tokens to generate.
        max_prompt_tokens: Upper bound on the prompt length, in tokens. Tweets
            that do not fit are left out (whole tweets, in the given order) so
            the instructions and the trailing ``Narrative:`` cue always
            survive.

    Returns:
        The generated narrative with surrounding whitespace stripped, or an
        empty string when ``df_subset`` contains no usable tweets.
    """
    tweets = df_subset["tweet_text"].dropna().astype(str).tolist()
    if not tweets:
        # Nothing to summarise: do not ask the model to invent a narrative.
        return ""

    # Prompt text is kept exactly as before, split around the tweet list.
    head = (
        "You are a narrative extractor. The following tweets come from a "
        "coordinated group around the same topic. \n"
        "    Write a concise narrative summarizing their message, framing, and tone.\n"
        "\n"
        "Tweets:\n"
    )
    tail = "\n\nNarrative:"

    # Truncating the finished prompt cuts from the right by default, which
    # would remove the "Narrative:" cue whenever there are many tweets.
    # Instead, budget the tweets so the fixed parts of the prompt always fit.
    budget = max_prompt_tokens - len(tokenizer(head + tail)["input_ids"])
    per_tweet_ids = tokenizer(tweets, add_special_tokens=False)["input_ids"]

    kept = []
    used = 0
    for tweet, ids in zip(tweets, per_tweet_ids):
        cost = len(ids) + 1  # +1 for the newline that joins tweets
        # Always keep the first tweet so the prompt is never empty.
        if kept and used + cost > budget:
            break
        kept.append(tweet)
        used += cost

    prompt = head + "\n".join(kept) + tail

    # `truncation` stays on purely as a hard cap, e.g. if a single tweet is
    # longer than the whole budget.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_prompt_tokens,
    ).to(device)
    prompt_length = inputs["input_ids"].shape[-1]

    # Pass an explicit pad token; fall back to EOS when the tokenizer has none.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
        )

    # A causal LM returns prompt + continuation. Decode only the continuation
    # rather than splitting the text on "Narrative:", which breaks when the
    # marker is missing from the prompt or repeated in the output.
    generated_ids = outputs[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

# Collects one record per (topic, orientation) pair.
narratives = []

for topic_id, df_topic in tqdm(df_input.groupby("topic_id")):
    # Distinct non-missing orientations within this topic.
    orientations = df_topic["user_orientation"].dropna().unique()

    for orientation in orientations:
        df_oriented = df_topic.loc[df_topic["user_orientation"] == orientation]
        df_sorted = df_oriented.sort_values(by="tweet_id")

        # Best effort: a failed generation is recorded as an "Error: ..."
        # narrative so the remaining groups are still processed.
        try:
            narrative = generate_narrative_from_tweets(df_sorted)
        except Exception as e:
            narrative = f"Error: {str(e)}"

        narratives.append(
            dict(
                topic_id=topic_id,
                user_orientation=orientation,
                narrative=narrative,
                num_tweets=len(df_sorted),
            )
        )

df_narratives = pd.DataFrame(narratives)
