In [1]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset
import json
from tqdm import tqdm
from transformers import DataCollatorForSeq2Seq
import torch
import evaluate
import numpy as np
import ast
import re

In [3]:
rouge = evaluate.load("rouge")

In [65]:
torch.cuda.empty_cache()

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large", 
                                              cache_dir="/scratch/wadhwa.s/cache", 
                                              device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large", 
                                          cache_dir="/scratch/wadhwa.s/cache")

In [66]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [67]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [68]:
valid = ["Work_For", "Live_In", "OrgBased_In", "Kill", "Located_In"]
df = pd.read_csv("conll_gpt3_explanations_responses.csv")
# for ix, row in df.iterrows():
#     row["explanations"] = row["explanations"] + "</s>"

# df.to_csv("conll_gpt3_explanations_responses.csv", index=False)

In [69]:

for ix, row in df.sample(100).iterrows():
    print(row["input"])
    print ("Relations: " + row["gold_relations"] + "\n" + row["explanations"])
    print ("\n----------------\n")

TEXT: Farley attorney Stephen Greiner contended during Wednesday 's 90-minute hearing that the pill ` ` is a device by which the board of directors of West Point says nobody can buy more than 10 percent. ' '
Relations: [['Stephen Greiner', 'Work_For', 'Farley']]
Explanation: Stephen Greiner is an attorney that works for Farley.</s>

----------------

TEXT: However , I do have feelings. ( Minoli ) Do you think that Berlusconi is the Freedom Alliance 's candidate for the post of prime ? ( Fini ) We must let the electors decide .
Relations: [['Berlusconi', 'Work_For', 'Freedom Alliance']]
Explanation: Berlusconi is the candidate for the post of prime for the Freedom Alliance.</s>

----------------

TEXT: Remodeling Planned To Improve H-2 Rocket Launch OW2202091394 Tokyo MAINICHI SHIMBUN in Japanese 21 Feb 94 Morning Edition p 11 -- FOR OFFICIAL USE ONLY
Relations: [['MAINICHI SHIMBUN', 'OrgBased_In', 'Tokyo']]
Explanation: MAINICHI SHIMBUN is an organization based in Tokyo.</s>

---------

In [76]:
conll_data = load_dataset("csv", data_files="conll_gpt3_explanations_responses.csv")
conll = conll_data["train"]

conll_reference = load_dataset("csv", data_files="conll_fewshot_reference.csv")
conll_eval = conll_reference["train"]

Using custom data configuration default-7c18d0d9ebe66610
Found cached dataset csv (/home/wadhwa.s/.cache/huggingface/datasets/csv/default-7c18d0d9ebe66610/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)
100%|██████████| 1/1 [00:00<00:00, 756.14it/s]
Using custom data configuration default-01766038312a4186
Found cached dataset csv (/home/wadhwa.s/.cache/huggingface/datasets/csv/default-01766038312a4186/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)
100%|██████████| 1/1 [00:00<00:00, 730.21it/s]


In [77]:
prefix = """List all relations of the following type in the given text and provide reasonable explanations for your answers - \n1. Kill: Entity A killed Entity B.\n2. Work_For: Entity A works for Entity B.\n3. Located_In: Entity A is located in Entity B.\n4. Live_In: Entity A lives in Entity B.\n5. OrgBased_In: Entity A is an organization based in Entity B.\n\n"""

print (prefix)

List all relations of the following type in the given text and provide reasonable explanations for your answers - 
1. Kill: Entity A killed Entity B.
2. Work_For: Entity A works for Entity B.
3. Located_In: Entity A is located in Entity B.
4. Live_In: Entity A lives in Entity B.
5. OrgBased_In: Entity A is an organization based in Entity B.




In [78]:
targets = [gold_relations + "\n" + explanation for gold_relations, explanation in zip(conll["gold_relations"], conll["explanations"])]
print (targets[10])

[['Betsy Ross', 'Live_In', 'Philadelphia']]
Explanation: Betsy Ross was born in Philadelphia and therefore lived in Philadelphia.</s>


In [79]:
def preprocess_function(examples):
    inputs = [prefix + example + "\nRelations: " for example in examples["input"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    targets = [gold_relations + "\n" + explanation for gold_relations, explanation in zip(conll["gold_relations"], conll["explanations"])]
    labels = tokenizer(text_target=targets, max_length=256, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [74]:
"""Split into train and validation sets"""
conll = conll.train_test_split(test_size=0.2)
conll

DatasetDict({
    train: Dataset({
        features: ['input', 'explanations', 'gold_relations'],
        num_rows: 724
    })
    test: Dataset({
        features: ['input', 'explanations', 'gold_relations'],
        num_rows: 182
    })
})

In [85]:
tokenized_conll = conll.map(preprocess_function, batched=True)

100%|██████████| 1/1 [00:00<00:00,  2.77ba/s]


In [86]:
tokenized_conll = tokenized_conll.train_test_split(test_size=0.2)
tokenized_conll["train"]

Dataset({
    features: ['input', 'explanations', 'gold_relations', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 724
})

In [88]:
torch.cuda.empty_cache()

In [89]:
torch.cuda.empty_cache()

training_args = Seq2SeqTrainingArguments(
    output_dir="/scratch/wadhwa.s/cache/conll_explanations",
    evaluation_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    eval_delay=3,
    per_device_eval_batch_size=4,
    logging_strategy="epoch",
#    logging_steps=30,
    weight_decay=0.01,
    save_total_limit=7,
    num_train_epochs=25,
    predict_with_generate=True,
    # gradient_accumulation_steps=4,
    # fp16=True,
    # push_to_hub=True,
)

In [90]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_conll["train"],
    eval_dataset=tokenized_conll["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [91]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: input, gold_relations, explanations. If input, gold_relations, explanations are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 724
  Num Epochs = 25
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 4525
  Number of trainable parameters = 783150080
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[34m[1mwandb[0m: Currently logged in as: [33msw7[0m. Use [1m`wandb login --relogin`[0m to force relogin


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
3,0.1576,0.170353,0.3235,0.2582,0.3229,0.323,19.0
4,0.1318,0.164951,0.3177,0.2531,0.3173,0.3176,19.0
5,0.1108,0.169526,0.3228,0.2576,0.3217,0.3218,19.0
6,0.097,0.163147,0.3178,0.2533,0.3177,0.3176,19.0
7,0.0835,0.170356,0.3231,0.2582,0.3226,0.3228,19.0
8,0.0746,0.174032,0.3209,0.2566,0.3205,0.3209,19.0
9,0.0647,0.193734,0.3198,0.2555,0.3196,0.3193,19.0
10,0.0586,0.193359,0.324,0.2596,0.3235,0.3232,19.0
11,0.0522,0.202077,0.3225,0.2588,0.3216,0.3219,19.0


Saving model checkpoint to /scratch/wadhwa.s/cache/conll_explanations/checkpoint-500
Configuration saved in /scratch/wadhwa.s/cache/conll_explanations/checkpoint-500/config.json
Model weights saved in /scratch/wadhwa.s/cache/conll_explanations/checkpoint-500/pytorch_model.bin
tokenizer config file saved in /scratch/wadhwa.s/cache/conll_explanations/checkpoint-500/tokenizer_config.json
Special tokens file saved in /scratch/wadhwa.s/cache/conll_explanations/checkpoint-500/special_tokens_map.json
Copy vocab file to /scratch/wadhwa.s/cache/conll_explanations/checkpoint-500/spiece.model
The following columns in the evaluation set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: input, gold_relations, explanations. If input, gold_relations, explanations are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 182
  Batch size = 4
The following columns in the

In [107]:
torch.cuda.empty_cache()

In [129]:
tuned_model = AutoModelForSeq2SeqLM.from_pretrained("/scratch/wadhwa.s/cache/conll_explanations/checkpoint-2000")
tokenizer = AutoTokenizer.from_pretrained("/scratch/wadhwa.s/cache/conll_explanations/checkpoint-2000")

loading configuration file /scratch/wadhwa.s/cache/conll_explanations/checkpoint-2000/config.json
Model config T5Config {
  "_name_or_path": "/scratch/wadhwa.s/cache/conll_explanations/checkpoint-2000",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 2816,
  "d_kv": 64,
  "d_model": 1024,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 24,
  "num_heads": 16,
  "num_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.24.0",
  "use_cache": true,
  "vocab_size": 32128
}

loading weights file /scratch/wadhwa.s/cache/conll_exp

In [130]:
df = pd.read_csv("conll_fewshot_reference.csv")
df.shape

(231, 3)

In [131]:
ip = []
gold = []
generated = []
nc_count = 0

for ix, row in tqdm(df.iterrows(), total=df.shape[0]):
    text = prefix + row["text"]
    inputs = tokenizer(text, return_tensors="pt").input_ids
    outputs = tuned_model.generate(inputs, max_new_tokens=256, do_sample=False)
    out = tokenizer.decode(outputs[0], skip_special_tokens=True)
    torch.cuda.empty_cache()
    try:
        generated.append(out.split(" Explanation: ")[0])
        ip.append(row["text"])
        gold.append(row["gold_labels"])
        # print ("PRED: " + out)
    except:
        nc_count += 1
        print (row["text"])
        print ("GOLD: " + row["gold_labels"])
        print ("PRED ---- NON CONFORMING OUTPUT: " + out)
        print ("\n----------------\n")


100%|██████████| 231/231 [35:52<00:00,  9.32s/it]


In [132]:
df = pd.DataFrame({"text": ip, "gold_labels": gold, "generated": generated})
# df.to_csv("conll_flan_explanations_generated_ckpt500.csv", index=False)
print (df.shape)

(231, 3)


In [133]:
invalid_count = 0
total_count = 0
valid_parsed = 0
for ix, row in df.iterrows():
    # print ("GOLD: " + row["gold_labels"])
    try:
        pred = ast.literal_eval(row["generated"])
        total_count += len(pred)
        valid_parsed += 1
        for relation in pred:
            if relation[1] not in valid:
                print ("TEXT: ", row["text"])
                print ("PRED: " + relation)
                print ("\n*********************\n")
    except:
        invalid_count += 1
        print ("TEXT: ", row["text"])
        print ("NON CONFORMING PRED: ", row["generated"])
        print ("\n----------------\n")

TEXT:  The opera company performed at the Palace of Fine Arts , in San Francisco , on June 30 and July 1-2 , said Kevin O 'Brien , a spokesman for the theater.
NON CONFORMING PRED:  [['Palace of Fine Arts', 'OrgBased_In', 'San Francisco'], ['Kevin O'Brien', 'Work_For', 'Palace of Fine Arts'], ['Kevin O'Brien', 'Live_In', 'San Francisco']]

----------------

TEXT:  It was filmed on location in Tokyo and Kyoto and is a co-production of Children 's Television Worship and NHK network in Japan .
NON CONFORMING PRED:  [['Children's Television Worship', 'OrgBased_In', 'Tokyo'], ['NHK network', 'OrgBased_In', 'Japan']]

----------------

TEXT:  Guy M. Struve , the New York attorney representing Church 's , said he was disappointed and did not know when the chain would decide whether to appeal Thursday 's ruling .
NON CONFORMING PRED:  [['Guy M. Struve', 'Work_For', 'Church's']]

----------------

TEXT:  During my recent tour , I met with Sultan Qabus Bin-Sa 'id of Oman .
NON CONFORMING PRED:  

In [134]:
print (invalid_count, total_count, valid_parsed)

4 299 227


In [32]:
text = [prefix + s1, prefix + s2]
print (text[1])

List all relations of the following type in the given text and provide reasonable explanations for your answers - 
1. Kill: Entity A killed Entity B.
2. Work_For: Entity A works for Entity B.
3. Located_In: Entity A is located in Entity B.
4. Live_In: Entity A lives in Entity B.
5. OrgBased_In: Entity A is an organization based in Entity B.

TEXT: Harry Truman then signed the final legislation to create the National Security Council at his residence in Washington .


In [42]:
torch.cuda.empty_cache()
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).input_ids

In [36]:
len(inputs)

2

In [37]:
outputs = model.generate(inputs, max_new_tokens=256, do_sample=False)

In [39]:
out = tokenizer.decode(outputs[1], skip_special_tokens=True)
print (out)

[['Harry Truman', 'Live_In', 'Washington']] Explanation: Harry Truman lives in Washington, so it can be concluded that Harry Truman Lives In Washington.


In [32]:
out = out.split(" Explanation")[0]

In [34]:
import ast

In [35]:
print (ast.literal_eval(out))

In [98]:
conll = json.load(open('conll04_dev.json'))

In [104]:
conll[4]

{'tokens': ['Marie',
  'Magdefrau',
  'Ferraro',
  ',',
  '50',
  ',',
  'of',
  'Bethany',
  ',',
  'Conn.',
  ',',
  'was',
  'shot',
  'to',
  'death',
  'Thursday',
  'when',
  'two',
  'bandits',
  'armed',
  'with',
  'assault',
  'rifles',
  'emerged',
  'from',
  'nearby',
  'bushes',
  'and',
  'began',
  'firing',
  'at',
  'a',
  'van',
  'carrying',
  'a',
  'Connecticut',
  'Audubon',
  'Society',
  'wildlife',
  'wild',
  'tour',
  'group',
  '.'],
 'entities': [{'type': 'Peop', 'start': 0, 'end': 3},
  {'type': 'Loc', 'start': 7, 'end': 8},
  {'type': 'Loc', 'start': 9, 'end': 10},
  {'type': 'Org', 'start': 35, 'end': 39}],
 'relations': [{'type': 'Live_In', 'head': 0, 'tail': 1},
  {'type': 'Live_In', 'head': 0, 'tail': 2},
  {'type': 'Located_In', 'head': 1, 'tail': 2}],
 'orig_id': 1771}