In [1]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import os
import pandas as pd
import numpy as np
import tqdm

In [2]:
# Load FLAN-T5 small and its tokenizer. Use the first GPU when one is
# available, otherwise fall back to CPU so the notebook still runs end to end.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
model.to(device);  # trailing ';' suppresses the long module repr

In [13]:
text = "DETECTIVE SERGEANT ALONZO HARRIS, in black shirt, black leather jacket. And just enough platinum and diamonds to look like somebody. He reads the paper in a booth. The gun leather-tough LAPD vet is a hands-on, blue-collar cop who can kick your ass with a look."
character = "ALONZO HARRIS"
result = tokenizer(f"Answer yes or no if the following text describes or mentions {character}'s health status."
                   f"\n\n{text}\n\nAnswer:", return_tensors="pt").to("cuda:0")

In [14]:
# Greedy decoding of a short answer. `temperature` is dropped because it has
# no effect unless do_sample=True. Unpacking `result` forwards every tensor
# the tokenizer returned (input_ids and, by default, attention_mask).
output = model.generate(**result, max_new_tokens=5)

In [15]:
# Decode the first generated sequence to text, dropping special tokens.
first_sequence = output[0]
tokenizer.decode(first_sequence, skip_special_tokens=True)

'yes'

In [None]:
# read character descriptions
# Fail with a clear message when DATA_DIR is unset, instead of the opaque
# TypeError os.path.join raises when handed None.
data_root = os.getenv("DATA_DIR")
if not data_root:
    raise EnvironmentError("DATA_DIR environment variable is not set")
data_dir = os.path.join(data_root, "narrative_understanding/chatter")
character_desc_file = os.path.join(data_dir, "character_descriptions.csv")
character_desc_df = pd.read_csv(character_desc_file, index_col=None)
print(f"{len(character_desc_df)} character descriptions")

# read attribute types: one per line. splitlines() also copes with CRLF files,
# and blank lines are skipped so they don't become empty attribute names.
attr_type_file = os.path.join(data_dir, "attributes.txt")
with open(attr_type_file, "r", encoding="utf-8") as fr:
    attributes = [line.strip() for line in fr.read().splitlines() if line.strip()]
attributes = sorted(attributes)
print(f"{len(attributes)} attribute types")
print("attribute types =>")
print(attributes)
print()

# prepare dataset
ids, segment_texts, labels = [], [], []
# Per-sample hypothesis template; CHARACTER and ATTRIBUTE are substituted below.
# NOTE(review): wording mirrors the exploratory yes/no prompt — confirm it is
# the intended hypothesis phrasing.
hypothesis = "Answer yes or no if the following text describes or mentions CHARACTER's ATTRIBUTE."
# Without both placeholders every label would be identical (e.g. all "").
assert "CHARACTER" in hypothesis and "ATTRIBUTE" in hypothesis, "hypothesis template is missing placeholders"
n = 1000  # number of character descriptions to sample
seed = 0  # fixed so the sample is reproducible across runs
character_desc_df = (
    character_desc_df
    # cap at the available rows: sample(n) raises if n exceeds len(df)
    .sample(min(n, len(character_desc_df)), random_state=seed)
    # .str.len() gives NaN for missing text instead of raising like len(NaN)
    .assign(segment_text_size=lambda df: df["segment_text"].str.split().str.len())
    .sort_values(by="segment_text_size", ascending=False)
)
for ind, row in tqdm.tqdm(character_desc_df.iterrows(), total=len(character_desc_df), desc="creating nli data"):
    character = str(row["character"])
    # one sample per (description, attribute) pair
    for attr in attributes:
        ids.append(ind)
        segment_texts.append(row["segment_text"])
        label = hypothesis.replace("CHARACTER", character).replace("ATTRIBUTE", attr)
        labels.append(label)
assert len(ids) == len(segment_texts) == len(labels) == len(character_desc_df) * len(attributes)
print()
print(f"{len(segment_texts)} samples\n")