In [1]:
import os
import json
import pandas as pd
from transformers import pipeline, BartForSequenceClassification, BartTokenizer
import re
import tqdm
import torch

Paths

In [2]:
data_dir = os.path.join(os.getenv("DATA_DIR"), "narrative_understanding/chatter")

Read file

In [3]:
character_desc_file = os.path.join(data_dir, "character_descriptions.csv")
character_desc_df = pd.read_csv(character_desc_file, index_col=None)
print(character_desc_df.shape)

(710294, 4)


Attribute types

In [4]:
attribute_types = ["physical appearance", "attire", "mental state", "demeanor", "profession", "age", "race",
                   "qualities", "healh status"]
attribute_types = sorted(attribute_types)
print(f"{len(attribute_types)} attribute types")
print("attribute types =>")
print(attribute_types)

9 attribute types
attribute types =>
['age', 'attire', 'demeanor', 'healh status', 'mental state', 'physical appearance', 'profession', 'qualities', 'race']


Model

In [5]:
bart_nli_model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli")
bart_nli_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-mnli")
bart_nli_model.to("cuda:0");

Prepare Input

In [6]:
segment_texts, attribute_labels = [], []
character_desc_sample_df = character_desc_df.sample(4)

for _, row in character_desc_sample_df.iterrows():
    character = row["character"]
    for attr in attribute_types:
        segment_texts.append(row["segment_text"])
        attribute_labels.append(f"This text describes the {attr} of {character}")

print(f"{len(segment_texts)} samples")

36 samples


In [7]:
result = bart_nli_tokenizer(segment_texts, attribute_labels, return_tensors="pt", padding="longest")
result = result.to("cuda:0")
bart_nli_model.eval()
with torch.no_grad():
    logits = bart_nli_model(**result)[0]

In [8]:
logits = logits[:,[0,2]]
prob = torch.softmax(logits, dim=1)
print(prob.shape)

torch.Size([36, 2])


In [11]:
prob = prob.cpu().numpy()
print(prob.shape, type(prob))

(36, 2) <class 'numpy.ndarray'>


In [13]:
for segment_text, attr, p in zip(segment_texts, attribute_labels, prob):
    if p[1] > 0.5:
        print("SEGMENT =>")
        print(segment_text)
        print(f"ATTRIBUTE TYPE = {attr}")
        print()

SEGMENT =>
Ken Taylor drives up to a nice, but not huge, home in the expat/diplomatic neighborhood of Tehran. Automatic gates open.
ATTRIBUTE TYPE = This text describes the profession of KEN TAYLOR

SEGMENT =>
Qjgku ughaa eSa fcYdqy Bhd gw!A FkSaD;wA vki yksx tk,] vki yksx tk,A eSa pyk tkAaxkA Farhan starts running towards the exit. The flummoxed Medical Staff call out after him.
ATTRIBUTE TYPE = This text describes the demeanor of FARHAN

SEGMENT =>
Qjgku ughaa eSa fcYdqy Bhd gw!A FkSaD;wA vki yksx tk,] vki yksx tk,A eSa pyk tkAaxkA Farhan starts running towards the exit. The flummoxed Medical Staff call out after him.
ATTRIBUTE TYPE = This text describes the mental state of FARHAN

