In [1]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset
from transformers import DataCollatorForSeq2Seq
import torch
import evaluate
import re
import numpy as np
from tqdm import tqdm

In [2]:
torch.cuda.empty_cache()

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large", 
                                              cache_dir="/scratch/wadhwa.s/cache", 
                                              device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large", 
                                          cache_dir="/scratch/wadhwa.s/cache")

In [3]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [4]:
relation_map = {
    "/location/location/contains" : "LOC_CONTAINS",
    "/location/administrative_division/country" : "LOC_COUNTRY",
    "/location/country/capital" : "LOC_CAPITAL",
    "/location/neighborhood/neighborhood_of" : "LOC_NEIGHBORHOOD",
    "/location/country/administrative_divisions" : "LOC_ADMIN_DIVISIONS",

    "/people/person/nationality" : "PEO_NATIONALITY",
    "/people/person/place_lived" : "PEO_PLACE_LIVED",
    "/people/person/place_of_birth" : "PEO_PLACE_OF_BIRTH",
    "/people/deceased_person/place_of_death" : "PEO_PLACE_OF_DEATH",
    "/people/person/children" : "PEO_CHILDREN",
    "/people/person/religion" : "PEO_RELIGION",
    "/people/ethnicity/geographic_distribution" : "PEO_GEORGRAPHIC_DISTRIBUTION",
    "/people/person/ethnicity" : "PEO_ETHNICITY",
    "/people/ethnicity/people" : "PEO_PEOPLE",
    "/people/person/profession" : "PEO_PROFESSION",

    "/business/person/company" : "BUS_COMPANY",
    "/business/company/founders" : "BUS_FOUNDERS",
    "/business/company/place_founded" : "BUS_PLACE_FOUNDED",
    "/business/company_shareholder/major_shareholder_of" : "BUS_MAJOR_SHAREHOLDER",
    "/business/company/major_shareholders" : "BUS_MAJOR_SHAREHOLDERS",
    "/business/company/advisors" : "BUS_ADVISORS",
    "/business/company/industry" : "BUS_INDUSTRY",

    "/sports/sports_team_location/teams" : "SPO_TEAMS",
    "/sports/sports_team/location" : "SPO_LOCATION",
}

In [11]:
for key, value in relation_map.items():
    print ("<li><strong>",key,"</strong>: ", rel_map_inv[rel_map[value]],"</li>")

<li><strong> /location/location/contains </strong>:  [A, B] Entity A is a location that contains Entity B </li>
<li><strong> /location/administrative_division/country </strong>:  [A, B] Entity B is a country that contains Entity A </li>
<li><strong> /location/country/capital </strong>:  [A, B] Entity B is the capital of Entity A </li>
<li><strong> /location/neighborhood/neighborhood_of </strong>:  [A, B] Entity B is a neighborhood that contains Entity A </li>
<li><strong> /location/country/administrative_divisions </strong>:  [A, B] Entity B is located in the administrative divisions of Entity A </li>
<li><strong> /people/person/nationality </strong>:  [A, B] Entity A is a national of Entity B </li>
<li><strong> /people/person/place_lived </strong>:  [A, B] Entity A has lived in Entity B </li>
<li><strong> /people/person/place_of_birth </strong>:  [A, B] Entity A was born in Entity B </li>
<li><strong> /people/deceased_person/place_of_death </strong>:  [A, B] Entity A died in Entity B 

In [8]:
assign = 0
rel_map = {}
for r in list(relation_map.values()):
    rel_map[r] = assign
    assign += 1

print (rel_map)

{'LOC_CONTAINS': 0, 'LOC_COUNTRY': 1, 'LOC_CAPITAL': 2, 'LOC_NEIGHBORHOOD': 3, 'LOC_ADMIN_DIVISIONS': 4, 'PEO_NATIONALITY': 5, 'PEO_PLACE_LIVED': 6, 'PEO_PLACE_OF_BIRTH': 7, 'PEO_PLACE_OF_DEATH': 8, 'PEO_CHILDREN': 9, 'PEO_RELIGION': 10, 'PEO_GEORGRAPHIC_DISTRIBUTION': 11, 'PEO_ETHNICITY': 12, 'PEO_PEOPLE': 13, 'PEO_PROFESSION': 14, 'BUS_COMPANY': 15, 'BUS_FOUNDERS': 16, 'BUS_PLACE_FOUNDED': 17, 'BUS_MAJOR_SHAREHOLDER': 18, 'BUS_MAJOR_SHAREHOLDERS': 19, 'BUS_ADVISORS': 20, 'BUS_INDUSTRY': 21, 'SPO_TEAMS': 22, 'SPO_LOCATION': 23}


In [9]:
rel_map_inv = {
    0 : "[A, B] Entity A is a location that contains Entity B",
    1 : "[A, B] Entity B is a country that contains Entity A",
    2 : "[A, B] Entity B is the capital of Entity A",
    3 : "[A, B] Entity B is a neighborhood that contains Entity A",
    4 : "[A, B] Entity B is located in the administrative divisions of Entity A",
    
    5 : "[A, B] Entity A is a national of Entity B",
    6 : "[A, B] Entity A has lived in Entity B",
    7 : "[A, B] Entity A was born in Entity B",
    8 : "[A, B] Entity A died in Entity B",
    9 : "[A, B] Entity B is the child of Entity A",
    10 : "[A, B] Entity A is a member Entity B",
    11 : "[A, B] Entity A is a member of Entity B's geographic distribution",
    12 : "[A, B] Entity A is of ethnicity Entity B",
    13 : "[A, B] Entity B is the ethnicity of Entity A",
    14 : "[A, B] Entity B is the profession of Entity A",
    
    15 : "[A, B] Entity A is a member of company Entity B",
    16 : "[A, B] Entity A was founded by Entity B",
    17 : "[A, B] Entity A was founded in Entity B",
    18 : "[A, B] Entity B is a major shareholder of Entity A",
    19 : "[A, B] Entity B is a major shareholder of Entity A",
    20 : "[A, B] Entity B advises to Entity A",
    21 : "[A, B] Entity A is in the industry Entity B",
    
    22 : "[A, B] Entity B is a team located in Entity A",
    23 : "[A, B] Entity B is a location that contains sports team Entity A",

}

In [7]:
df = pd.read_csv("nyt_fewshot_explanations.csv")
for ix, row in df.iterrows():
    print(row["text"])
    print (row["gold_relations"] + "\n" + row["explanations"])
    print ("\n----------------\n")

Massachusetts ASTON MAGNA Great Barrington ; also at Bard College , Annandale-on-Hudson , N.Y. , July 1-Aug .
[['Annandale-on-Hudson', 'LOC_CONTAINS', 'Bard College']]
Explanations: Bard College is located in Annandale-on-Hudson, NY; 

----------------

It will be the final movie credited to Debra Hill , a film producer and native of Haddonfield , who produced '' Halloween '' and was considered a pioneering woman in film .
[['Debra Hill', 'PEO_PLACE_OF_BIRTH', 'Haddonfield']]
Explanations: Film producer Debra Hill was a native of (born-in) Haddonfield; 

----------------

Under pressure from Mr. Kerkorian and other disgruntled shareholders , Mr. Wagoner started talks on Friday in Detroit with Carlos Ghosn , the chief executive of Renault and Nissan .
[['Carlos Ghosn', 'BUS_COMPANY', 'Renault']]
Explanations: Carlos Ghosn is the chief executive officer of (member-of) Renault; 

----------------

Mr. Ferrer still holds commanding leads over the other two Democrats in the race -- United S

In [8]:
nyt_data = load_dataset("csv", data_files="nyt_fewshot_explanations.csv")
nyt = nyt_data["train"]

# nyt_reference = load_dataset("csv", data_files="nyt_fewshot_reference.csv")
# nyt_eval = nyt_reference["train"]

Using custom data configuration default-2c630774495b6253
Found cached dataset csv (/home/wadhwa.s/.cache/huggingface/datasets/csv/default-2c630774495b6253/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)
100%|██████████| 1/1 [00:00<00:00, 463.56it/s]


In [9]:
nyt[10]

{'text': "Somewhat chastened by his retreat in the polls , Mr. Blair acknowledged that Britons had turned against him in part over accusations that he led them into a war in Iraq on dubious legal grounds and on the false premise that Saddam Hussein presented a direct threat because of a supposed arsenal of unconventional weapons that was never found . ''",
 'true_relations': "[['Saddam Hussein', 8, 'Iraq'], ['Saddam Hussein', 7, 'Iraq'], ['Saddam Hussein', 5, 'Iraq']]",
 'gold_relations': "[['Saddam Hussein', 'PEO_PLACE_OF_DEATH', 'Iraq'], ['Saddam Hussein', 'PEO_PLACE_OF_BIRTH', 'Iraq'], ['Saddam Hussein', 'PEO_NATIONALITY', 'Iraq']]",
 'explanations': 'Explanations: Saddam Hussein died in Iraq; Saddam Hussein was from Iraq; Saddam Hussein was the president of Iraq; '}

In [10]:
prefix = """ Find the following relations in the TEXT below:
    0 : "[A, B] Entity A is a location that contains Entity B",
    1 : "[A, B] Entity B is a country that contains Entity A",
    2 : "[A, B] Entity B is the capital of Entity A",
    3 : "[A, B] Entity B is a neighborhood that contains Entity A",
    4 : "[A, B] Entity B is located in the administrative divisions of Entity A",
    5 : "[A, B] Entity A is a national of Entity B",
    6 : "[A, B] Entity A has lived in Entity B",
    7 : "[A, B] Entity A was born in Entity B",
    8 : "[A, B] Entity A died in Entity B",
    9 : "[A, B] Entity B is the child of Entity A",
    10 : "[A, B] Entity A is a member Entity B",
    11 : "[A, B] Entity A is a member of Entity B's geographic distribution",
    12 : "[A, B] Entity A is of ethnicity Entity B",
    13 : "[A, B] Entity B is the ethnicity of Entity A",
    14 : "[A, B] Entity B is the profession of Entity A",
    15 : "[A, B] Entity A is a member of company Entity B",
    16 : "[A, B] Entity A was founded by Entity B",
    17 : "[A, B] Entity A was founded in Entity B",
    18 : "[A, B] Entity B is a major shareholder of Entity A",
    19 : "[A, B] Entity B is a major shareholder of Entity A",
    20 : "[A, B] Entity B advises to Entity A",
    21 : "[A, B] Entity A is in the industry Entity B",
    22 : "[A, B] Entity B is a team located in Entity A",
    23 : "[A, B] Entity B is a location that contains sports team Entity A",
    
TEXT: """

print(prefix)

 Find the following relations in the TEXT below:
    0 : "[A, B] Entity A is a location that contains Entity B",
    1 : "[A, B] Entity B is a country that contains Entity A",
    2 : "[A, B] Entity B is the capital of Entity A",
    3 : "[A, B] Entity B is a neighborhood that contains Entity A",
    4 : "[A, B] Entity B is located in the administrative divisions of Entity A",
    5 : "[A, B] Entity A is a national of Entity B",
    6 : "[A, B] Entity A has lived in Entity B",
    7 : "[A, B] Entity A was born in Entity B",
    8 : "[A, B] Entity A died in Entity B",
    9 : "[A, B] Entity B is the child of Entity A",
    10 : "[A, B] Entity A is a member Entity B",
    11 : "[A, B] Entity A is a member of Entity B's geographic distribution",
    12 : "[A, B] Entity A is of ethnicity Entity B",
    13 : "[A, B] Entity B is the ethnicity of Entity A",
    14 : "[A, B] Entity B is the profession of Entity A",
    15 : "[A, B] Entity A is a member of company Entity B",
    16 : "[A, B] En

In [11]:
prefix = """ Find the following relations in the TEXT below and provide brief explanation:

TEXT: """

print(prefix)

 Find the following relations in the TEXT below and provide brief explanation:

TEXT: 


In [12]:
targets = [gold_label + "\n" + explanation for gold_label, explanation in zip(nyt["gold_relations"], nyt["explanations"])]
print (targets[10])

[['Saddam Hussein', 'PEO_PLACE_OF_DEATH', 'Iraq'], ['Saddam Hussein', 'PEO_PLACE_OF_BIRTH', 'Iraq'], ['Saddam Hussein', 'PEO_NATIONALITY', 'Iraq']]
Explanations: Saddam Hussein died in Iraq; Saddam Hussein was from Iraq; Saddam Hussein was the president of Iraq; 


In [13]:
def preprocess_function(examples):
    inputs = [prefix + example for example in examples["text"]]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)
    targets = [gold_label + "\n" + explanation for gold_label, explanation in zip(nyt["gold_relations"], nyt["explanations"])]
    labels = tokenizer(text_target=targets, max_length=512, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [14]:
tokenized_nyt = nyt.map(preprocess_function, batched=True)

Loading cached processed dataset at /home/wadhwa.s/.cache/huggingface/datasets/csv/default-2c630774495b6253/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-1bd9d11c5f7a2275.arrow


In [15]:
tokenized_nyt

Dataset({
    features: ['text', 'true_relations', 'gold_relations', 'explanations', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 20
})

In [16]:
torch.cuda.empty_cache()

In [17]:
torch.cuda.empty_cache()

training_args = Seq2SeqTrainingArguments(
    output_dir="/scratch/wadhwa.s/cache/myawesomemodel_nyt",
    evaluation_strategy="no",
    learning_rate=3e-5,
    per_device_train_batch_size=2,
    # per_device_eval_batch_size=2,
    logging_steps=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=12,
    predict_with_generate=False,
    # gradient_accumulation_steps=4,
    # fp16=True,
    # push_to_hub=True,
)

In [18]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_nyt,
    # eval_dataset=tokenized_ade["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

In [19]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: text, explanations, true_relations, gold_relations. If text, explanations, true_relations, gold_relations are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 20
  Num Epochs = 12
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 2
  Gradient Accumulation steps = 1
  Total optimization steps = 120
  Number of trainable parameters = 783150080
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msw7[0m. Use [1m`wandb login --relogin`[0m to force relogin


You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
8,2.0524
16,1.5583
24,1.2708
32,1.0004
40,0.9054
48,0.8193
56,0.7768
64,0.6658
72,0.6567
80,0.5378




Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=120, training_loss=0.8543308993180593, metrics={'train_runtime': 54.2754, 'train_samples_per_second': 4.422, 'train_steps_per_second': 2.211, 'total_flos': 95665875959808.0, 'train_loss': 0.8543308993180593, 'epoch': 12.0})

In [20]:
torch.cuda.empty_cache()

In [26]:
tuned_model = AutoModelForSeq2SeqLM.from_pretrained("/scratch/wadhwa.s/cache/myawesomemodel/checkpoint-1000")
tokenizer = AutoTokenizer.from_pretrained("/scratch/wadhwa.s/cache/myawesomemodel/checkpoint-1000")

loading configuration file /scratch/wadhwa.s/cache/myawesomemodel/checkpoint-1000/config.json
Model config T5Config {
  "_name_or_path": "/scratch/wadhwa.s/cache/myawesomemodel/checkpoint-1000",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 2816,
  "d_kv": 64,
  "d_model": 1024,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 24,
  "num_heads": 16,
  "num_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.24.0",
  "use_cache": true,
  "vocab_size": 32128
}

loading weights file /scratch/wadhwa.s/cache/myawesomemodel/ch

In [17]:
df = pd.read_csv("conll_fewshot_reference.csv")
df.head()

Unnamed: 0,text,gold_labels,explanations
0,"John Wilkes Booth , who assassinated President...","[['John Wilkes Booth', 'Kill', 'President Linc...",********
1,The opera company performed at the Palace of F...,"[['Palace of Fine Arts', 'Located_In', 'San Fr...",********
2,"In the field of mechanics , Wang Ziqiang at th...","[['Wang Ziqiang', 'Work_For', 'Institute of Me...",********
3,"Sun Hung Kai Properties , a Hong Kong construc...","[['Sun Hung Kai Properties', 'OrgBased_In', 'H...",********
4,"Marie Magdefrau Ferraro , 50 , of Bethany , Co...","[['Marie Magdefrau Ferraro', 'Live_In', 'Betha...",********


In [19]:
ip = []
gold = []
generated = []

for ix, row in df.iterrows():
    text = prefix + row["text"]
    inputs = tokenizer(text, return_tensors="pt").input_ids
    outputs = model.generate(inputs, max_new_tokens=256, do_sample=False)
    out = tokenizer.decode(outputs[0], skip_special_tokens=True)
    torch.cuda.empty_cache()
    print (row["text"])
    print ("GOLD: " + row["gold_labels"])
    try:
        generated.append(out.split(" Explanation: ")[0])
        ip.append(row["text"])
        gold.append(row["gold_labels"])
        print ("PRED: " + out)
    except:
        print ("PRED ---- NON CONFORMING OUTPUT: " + out)

    print ("\n----------------\n")


John Wilkes Booth , who assassinated President Lincoln , was an actor .
GOLD: [['John Wilkes Booth', 'Kill', 'President Lincoln']]
PRED: [['John Wilkes Booth', 'Work_For', 'John Wilkes Booth']] Explanation: John Wilkes Booth was an actor who assassinated President Lincoln.

----------------

The opera company performed at the Palace of Fine Arts , in San Francisco , on June 30 and July 1-2 , said Kevin O 'Brien , a spokesman for the theater.
GOLD: [['Palace of Fine Arts', 'Located_In', 'San Francisco']]
PRED: [['Kevin O'Brien', 'Work_For', 'Palace of Fine Arts'], ['Kevin O'Brien', 'Work_For', 'San Francisco']] Explanation: The opera company performed at the Palace of Fine Arts, which is located in San Francisco.

----------------

In the field of mechanics , Wang Ziqiang at the Institute of Mechanics has made considerable headway in the area of elastoplastic crack mechanics .
GOLD: [['Wang Ziqiang', 'Work_For', 'Institute of Mechanics']]
PRED: [['Wang Ziqiang', 'Work_For', 'Institute o

In [27]:
df = pd.DataFrame({"text": ip, "gold_labels": gold, "generated": generated})
#df.to_csv("conll_fewshot_flan_generated.csv", index=False)
print (df.shape)

(231, 3)


In [38]:
invalid_count = 0
total_count = 0
valid_parsed = 0
for ix, row in df.iterrows():
    # print ("GOLD: " + row["gold_labels"])
    try:
        pred = ast.literal_eval(row["generated"])
        total_count += len(pred)
        valid_parsed += 1
        for relation in pred:
            if relation[1] not in valid:
                print (row["text"])
                print ("PRED: " + relation)
                print ("\n----------------\n")
    except:
        print ("NON CONFORMING PRED: ", row["generated"])
        print ("\n----------------\n")

NON CONFORMING PRED:  [['Kevin O'Brien', 'Work_For', 'Palace of Fine Arts'], ['Kevin O'Brien', 'Work_For', 'San Francisco']]

----------------

The Warren Commission found Oswald acted alone in killing Kennedy .
NON CONFORMING PRED:  [['Oswald', 'Acted_Alone', 'Kill'], ['Oswald', 'Live_In', 'Williams County'], ['Oswald', 'Live_In', 'Williams County']]

----------------

Martin Luther King III , a son of the slain civil rights leader , said the execution of King assassin James Earl Ray ` ` would not bring my father back. ' '
NON CONFORMING PRED:  [['Martin Luther King III', 'Son', 'Martin Luther King III']]

----------------

` ` With South Carolina being Jesse Jackson 's home state , there was a very strong incentive in the black community. ' '
NON CONFORMING PRED:  [['Jesse Jackson', 'Home_State', 'South Carolina']]

----------------

You 'll pass through Rolling Fork , where Muddy Waters was born , close to Greenville , home of Nelson Street 's funky strip of blues clubs , and into C

In [40]:
print (invalid_count, total_count, valid_parsed)

0 244 217


In [33]:
s1 = "It is my sincere hope that we will reignite our united purpose , '' Senator Mitch McConnell of Kentucky , the Republican whip , said on the Senate floor before the ceremony ."

In [34]:
text = [prefix + s1]
print (text[0])

 Find the following relations in the TEXT below and provide brief explanation:

TEXT: It is my sincere hope that we will reignite our united purpose , '' Senator Mitch McConnell of Kentucky , the Republican whip , said on the Senate floor before the ceremony .


In [35]:
torch.cuda.empty_cache()
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).input_ids

In [36]:
len(inputs)

1

In [37]:
outputs = model.generate(inputs, max_new_tokens=256, do_sample=False)

In [38]:
out = tokenizer.decode(outputs[0], skip_special_tokens=True)
print (out)

[['Mitch McConnell', 'RELATIVE', 'Republican']] Explanations: Senator Mitch McConnell is the Republican whip of the Senate; he is the speaker of the Senate;


In [32]:
out = out.split(" Explanation")[0]

In [34]:
import ast

In [35]:
print (ast.literal_eval(out))