In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from torch.utils.data import DataLoader
from datasets import Dataset
import pandas as pd
from tqdm import tqdm
from transformers import DataCollatorWithPadding

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


# Tokenizer is taken from the base FLAN-T5 checkpoint, weights from the
# fine-tuned checkpoint.
# NOTE(review): this assumes the fine-tuned model kept the base vocabulary.
# Confirm, or load the tokenizer from `model_path` instead.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")

model_path = "Anshul99/finetuned_model_final"
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
print(device)


# NOTE(review): `df` is not created in this cell; presumably an earlier
# notebook cell builds it.
text_col = "masked_text"

# Fail fast: the prompt builder below reads only this column.
if text_col not in df.columns:
    raise ValueError("The DataFrame must contain a 'masked_text' column.")

# Force every entry to a Python str so the tokenizer always receives text.
# Missing values are stringified too (e.g. NaN becomes the string "nan").
df[text_col] = df[text_col].astype(str)

def preprocess_function(examples, *, max_length=512):
    """Prepend the stance-classification instruction to each text and tokenize.

    Intended for ``Dataset.map(..., batched=True)``, so ``examples["masked_text"]``
    is a list of strings for the current map chunk.

    Args:
        examples: Batch mapping that contains a ``"masked_text"`` list.
        max_length: Maximum number of tokens per encoded input; longer inputs
            are truncated.

    Returns:
        The tokenizer's batch encoding with unpadded, variable-length
        sequences. The inference loop expects it to contain ``input_ids``
        and ``attention_mask``.
    """
    prefix = (
        "Classify stance towards '[MASK]' as 'positive' (supporting/defending '[MASK]') or 'negative' (opposing/criticizing '[MASK]'). Sentence: "
    )
    inputs = [prefix + text for text in examples["masked_text"]]
    # Deliberately no padding here. padding="longest" inside a batched map()
    # pads every row to the longest sequence of the whole map chunk (1000 rows
    # by default) and stores that padding in the dataset, inflating every
    # inference batch. DataCollatorWithPadding pads each DataLoader batch
    # to its own longest sequence instead.
    return tokenizer(inputs, truncation=True, max_length=max_length)

dataset = Dataset.from_pandas(df)
# Drop every original column (including any pandas index column) so only the
# tokenizer outputs reach the collator.
tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=dataset.column_names)
# Pads each DataLoader batch dynamically to that batch's longest sequence.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")
batch_size = 64
# shuffle defaults to False, so predictions come back in df's row order;
# the final column assignment relies on that.
dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, collate_fn=data_collator)

# Generation budget for the label. For encoder-decoder models, generate()'s
# `max_length` counts the decoder start token, so `max_length=2` permits only
# ONE generated token and silently truncates any multi-token label.
# `max_new_tokens` states the budget directly. Greedy generation still stops
# at EOS, and skip_special_tokens strips EOS/padding from the decoded text.
max_new_tokens = 4

stance_predictions = []
model.eval()
with torch.no_grad():
    for batch in tqdm(dataloader, desc="Generating Predictions"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
        )
        decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        stance_predictions.extend(decoded_preds)

# One decoded prediction per row, aligned with df by the unshuffled loader.
df["stance"] = stance_predictions