In [None]:
import os

from transformers import AutoTokenizer, pipeline
from transformers.pipelines.pt_utils import KeyDataset
import torch, datasets, pandas as pd

# --- Run configuration -------------------------------------------------------
# Dataset to evaluate: "fb" or "semeval". Selects data/{task}_test.csv and the
# wording of the zero-shot instruction.
task = "fb"
# Passed to the generation pipeline as its device_map.
device = "cuda"

# Checkpoints listed for this notebook:
#   "google/flan-t5-xxl"
#   "mistralai/Mistral-7B-Instruct-v0.3"
#   "meta-llama/Meta-Llama-3-8B-Instruct"
#   "meta-llama/Meta-Llama-3-70B-Instruct"
target = "meta-llama/Meta-Llama-3-8B-Instruct"

# `truncate` is not a tokenizer option (truncation is requested per call via
# `truncation=True`), so the former `truncate=True` kwarg was a silent no-op
# and has been dropped.
tokenizer = AutoTokenizer.from_pretrained(target)

# Zero-shot

In [None]:
# Choose the zero-shot instruction matching the answer format of the dataset:
# "fb" asks for a stance towards both politicians, anything else asks for a
# single (target, stance) pair.
if task != "fb":
    prompt3 = "This statement contains a TARGET and a STANCE. The target is a politician and the stance represents the attitude expressed about them. The target options are Trump or Clinton and stance options are Favor, Against or None. Provide the answer in the following format: {TARGET, STANCE}\n\n"
else:
    prompt3 = "This statement may express a STANCE towards two politicians, Trump and Clinton. Stance represents the attitude expressed towards them. The stance options are Favor, Against or None. Provide the answer in the following format: {Trump: STANCE, Clinton: STANCE}\n\n"


def _as_chat_text(user_message):
    """Render a single user turn with the tokenizer's chat template.

    Returns the formatted string (not token ids), ending with the generation
    prompt so the model continues as the assistant.
    """
    # NOTE(review): this relies on the selected tokenizer providing a chat
    # template — confirm that holds for every checkpoint in the model list.
    conversation = [{"role": "user", "content": user_message}]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )


test = pd.read_csv(f"data/{task}_test.csv")
# The instruction is prepended in place, so the 'prompt' column (and the CSV
# written later) holds instruction + statement rather than the raw statement.
test["prompt"] = prompt3 + test["prompt"]
test["chat_text"] = test["prompt"].apply(_as_chat_text)

# Wrapped as a datasets.Dataset so the pipeline can iterate over one column.
test_ds = datasets.Dataset.from_pandas(test)

In [None]:
# The flan-t5 checkpoint is run through the "text2text-generation" task; every
# other target uses plain "text-generation".
pipeline_task = "text-generation"
if target == "google/flan-t5-xxl":
    pipeline_task = "text2text-generation"

pipe = pipeline(
    pipeline_task,
    model=target,
    tokenizer=tokenizer,
    device_map=device,
    torch_dtype=torch.bfloat16,
)

# Use the end-of-sequence token id as the padding token id.
pipe.tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
# Generate one short completion per test row and store it beside the inputs.
generation_kwargs = dict(
    # chat_text was rendered by apply_chat_template(tokenize=False); per the
    # Hugging Face chat-templating docs such text already carries the special
    # tokens, so the tokenizer must not add them a second time.
    add_special_tokens=False,
    # Return only the newly generated continuation, not the prompt.
    return_full_text=False,
    # NOTE(review): sampling is enabled and no seed is set in the visible
    # cells, so outputs may differ between runs.
    do_sample=True,
    temperature=0.1,
    max_new_tokens=20,
    batch_size=1,
)

# Each pipeline result is a list of candidates; keep the first (only) one.
output_labels = [
    out[0]
    for out in pipe(KeyDataset(test_ds, "chat_text"), **generation_kwargs)
]

# Column-wise concat: rows are matched on index between `test` and the outputs.
test_preds = pd.concat([test, pd.DataFrame(output_labels)], axis=1)

# Computed outside the f-string: reusing double quotes inside a double-quoted
# f-string is a SyntaxError before Python 3.12.
model_name = target.split("/")[-1]

# Make sure the output directory exists so a finished inference run is not
# lost at the final write.
os.makedirs("predicted_labels", exist_ok=True)
test_preds.to_csv(f"predicted_labels/{task}_{model_name}_zero.csv", index=False)