In [1]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset
from transformers import DataCollatorForSeq2Seq
import torch
import evaluate
import numpy as np
import ast
import jsonlines
import json
import re
import glob
import numpy as np
from tqdm import tqdm

In [2]:
torch.cuda.empty_cache()

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large", 
                                              cache_dir="/scratch/wadhwa.s/cache", 
                                              device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large", 
                                          cache_dir="/scratch/wadhwa.s/cache")

In [3]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [4]:
rouge = evaluate.load("rouge")

In [5]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [6]:
path = "/home/wadhwa.s/gpt3/nyt_explanations/"
df = pd.read_csv(path + "nyt_explanations_training.csv")
df.head()

Unnamed: 0,input,gold_relations,explanations
0,TEXT: When Weah returned to Liberia this sprin...,"[['Florida', '/location/location/contains', 'F...",Explanation: Fort Lauderdale is a location in ...
1,TEXT: Senate Republicans intend to be as coope...,"[['Mitch McConnell', '/people/person/place_liv...",Explanation: Mitch McConnell is a person who l...
2,"TEXT: Libya , where young men in Darfur used t...","[['Darfur', '/location/administrative_division...",Explanation: Darfur is a place located in the ...
3,TEXT: They crossed five international boundari...,"[['Africa', '/location/location/contains', 'Ni...","Explanation: Nigeria, Chad, Accra, Benin and T..."
4,TEXT: Ending with a rock 'n' rock fiesta in En...,"[['Mexico', '/location/location/contains', 'En...",Explanation: Ensenada is a location located in...


In [7]:
nyt_data = load_dataset("csv", data_files=path+"nyt_explanations_training.csv")
nyt_data = nyt_data["train"]

Using custom data configuration default-915147562b964706
Found cached dataset csv (/home/wadhwa.s/.cache/huggingface/datasets/csv/default-915147562b964706/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)
100%|██████████| 1/1 [00:00<00:00, 538.84it/s]


In [8]:
prefix = "List all relations between entities of the types [LOCATION, PERSON, ORGANIZATION] in the following TEXT and provide reasonable explanation:\n\n"
print (prefix + nyt_data[0]["input"] + "\nRelations: ")

List all relations between entities of the types [LOCATION, PERSON, ORGANIZATION] in the following TEXT and provide reasonable explanation:

TEXT: When Weah returned to Liberia this spring after a long sojourn in Florida -- his American wife and children live in Fort Lauderdale -- his arrival shut down the capital for the day .
Relations: 


In [9]:
nyt = nyt_data.train_test_split(test_size=0.1)
nyt

DatasetDict({
    train: Dataset({
        features: ['input', 'gold_relations', 'explanations'],
        num_rows: 22438
    })
    test: Dataset({
        features: ['input', 'gold_relations', 'explanations'],
        num_rows: 2494
    })
})

In [10]:
target = [prefix + example + "\nRelations: " for example in nyt["train"]["input"]]
print (target[0])

List all relations between entities of the types [LOCATION, PERSON, ORGANIZATION] in the following TEXT and provide reasonable explanation:

TEXT: The lack of security that led to some of the election problems continued to be an issue in Iraq , as one American soldier was killed and another was wounded just after noon on Wednesday in an insurgent attack on a convoy near Baghdad , the United States military said .
Relations: 


In [11]:
def preprocess_function(examples):
    inputs = [prefix + example + "\nRelations: " for example in examples["input"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    targets = [gold_relations + "\n" + explanation for gold_relations, explanation in zip(examples["gold_relations"], examples["explanations"])]
    labels = tokenizer(text_target=targets, max_length=256, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [12]:
tokenized_nyt = nyt.map(preprocess_function, batched=True)

100%|██████████| 23/23 [00:09<00:00,  2.55ba/s]
100%|██████████| 3/3 [00:00<00:00,  3.11ba/s]


In [13]:
tokenized_nyt

DatasetDict({
    train: Dataset({
        features: ['input', 'gold_relations', 'explanations', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 22438
    })
    test: Dataset({
        features: ['input', 'gold_relations', 'explanations', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 2494
    })
})

In [16]:
torch.cuda.empty_cache()

training_args = Seq2SeqTrainingArguments(
    output_dir="/scratch/wadhwa.s/cache/nyt_explanations",
    evaluation_strategy="steps",
    logging_strategy="steps",
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    logging_steps=1000,
    eval_delay=1000,
    save_total_limit=7,
    weight_decay=0.01,
    num_train_epochs=10,
    predict_with_generate=True,
    # gradient_accumulation_steps=4,
    # fp16=True,
    # push_to_hub=True,
)

using `logging_steps` to initialize `eval_steps` to 1000
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [17]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_nyt["train"],
    eval_dataset=tokenized_nyt["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [18]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: gold_relations, explanations, input. If gold_relations, explanations, input are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 22438
  Num Epochs = 10
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 56100
  Number of trainable parameters = 783150080
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[34m[1mwandb[0m: Currently logged in as: [33msw7[0m. Use [1m`wandb login --relogin`[0m to force relogin


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1000,0.249,0.140235,0.3062,0.2486,0.3052,0.3051,19.0
2000,0.1555,0.121165,0.3159,0.2601,0.3151,0.3151,19.0
3000,0.1393,0.112603,0.3198,0.2646,0.3195,0.3193,19.0
4000,0.1344,0.109109,0.3158,0.2609,0.3153,0.3152,19.0
5000,0.1238,0.10585,0.3219,0.2669,0.3216,0.3217,19.0
6000,0.1161,0.103039,0.323,0.2677,0.3225,0.3224,19.0
7000,0.1068,0.100134,0.3216,0.2665,0.3212,0.3213,19.0
8000,0.1084,0.098545,0.3235,0.2687,0.323,0.3229,19.0
9000,0.1015,0.097185,0.3239,0.2692,0.3235,0.3236,19.0
10000,0.1026,0.09423,0.3253,0.2701,0.3249,0.3248,19.0


Saving model checkpoint to /scratch/wadhwa.s/cache/nyt_explanations/checkpoint-500
Configuration saved in /scratch/wadhwa.s/cache/nyt_explanations/checkpoint-500/config.json
Model weights saved in /scratch/wadhwa.s/cache/nyt_explanations/checkpoint-500/pytorch_model.bin
tokenizer config file saved in /scratch/wadhwa.s/cache/nyt_explanations/checkpoint-500/tokenizer_config.json
Special tokens file saved in /scratch/wadhwa.s/cache/nyt_explanations/checkpoint-500/special_tokens_map.json
Copy vocab file to /scratch/wadhwa.s/cache/nyt_explanations/checkpoint-500/spiece.model
The following columns in the evaluation set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: gold_relations, explanations, input. If gold_relations, explanations, input are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 2494
  Batch size = 4
Saving model checkpoint to /scratch/wad

## Validation

In [9]:
prefix = "List all relations between entities of the types [LOCATION, PERSON, ORGANIZATION] in the following TEXT and provide reasonable explanation:\n\n"

In [2]:
text = []
true_relations = []
num_relations = []

with open("nyt/raw_valid.json") as f:
    for line in f:
        data = ast.literal_eval(line)
        # print (data["sentText"])
        # print ("RELATIONS: ")
        relations = []
        for relation in data["relationMentions"]:
            relations.append([relation["em1Text"], relation["label"], relation["em2Text"]])
        text.append(data["sentText"])
        true_relations.append(relations)
        num_relations.append(len(relations))

In [6]:
text[0]

'In Queens , North Shore Towers , near the Nassau border , supplanted a golf course , and housing replaced a gravel quarry in Douglaston .'

In [8]:
true_relations[0]

[['Douglaston', '/location/neighborhood/neighborhood_of', 'Queens'],
 ['Queens', '/location/location/contains', 'Douglaston']]

In [11]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [10]:
tuned_model = AutoModelForSeq2SeqLM.from_pretrained("/scratch/wadhwa.s/cache/nyt_explanations/checkpoint-17500")
tokenizer = AutoTokenizer.from_pretrained("/scratch/wadhwa.s/cache/nyt_explanations/checkpoint-17500")

In [12]:
tuned_model = tuned_model.to(device)

In [75]:
inputs_to_model = []
gold_relations = []
generated_relations = []
full_output = []
nc_flag = []

nc_count = 0
for ix, input in enumerate(tqdm(text)):
    torch.cuda.empty_cache()
    inputs_to_model.append(input)
    gold_relations.append(true_relations[ix])
    input = prefix + input + "\nRelations: "
    inputs = tokenizer(input, return_tensors="pt").input_ids.to(device)
    outputs = tuned_model.generate(inputs, max_new_tokens=256, do_sample=False)
    out = tokenizer.decode(outputs[0], skip_special_tokens=True)
    full_output.append(out)
    try:
        generated = ast.literal_eval(out.split("Explanation:")[0])
        generated_relations.append(generated)
        nc_flag.append(0)
    except:
        nc_count += 1
        nc_flag.append(1)
        generated_relations.append([])

100%|██████████| 5000/5000 [2:17:21<00:00,  1.65s/it]  


In [76]:
pd.DataFrame({"input": inputs_to_model, "gold_relations": gold_relations, "generated_relations": generated_relations, "full_output": full_output, "nc_flag": nc_flag}).to_csv("nyt/nyt_explanations_output.csv", index=False)

In [77]:
for i, j, k in zip(inputs_to_model, gold_relations, generated_relations):
    print (i)
    print ("True Relations: ", j)
    print ("Generated Relations: ", k)
    print ("\n----------------------\n")

In Queens , North Shore Towers , near the Nassau border , supplanted a golf course , and housing replaced a gravel quarry in Douglaston .
True Relations:  [['Douglaston', '/location/neighborhood/neighborhood_of', 'Queens'], ['Queens', '/location/location/contains', 'Douglaston']]
Generated Relations:  [['Queens', '/location/location/contains', 'Douglaston'], ['Douglaston', '/location/neighborhood/neighborhood_of', 'Queens']]

----------------------

In his authoritative and tough-minded new book , '' The Assassins ' Gate : America in Iraq , '' the New Yorker writer George Packer reminds us that the decision of the Bush administration to go to war against Iraq and its increasingly embattled handling of the occupation were both predicated upon large , abstract ideas about the role of America in the post-cold war world -- most notably , a belief in pre-emptive and unilateral action , the viability of exporting democracy abroad , the urge to streamline the military and the dream of remakin

In [54]:
ix = 22

In [55]:
sample = prefix + text[ix] + "\nRelations: "
print (sample)

List all relations between entities of the types [LOCATION, PERSON, ORGANIZATION] in the following TEXT and provide reasonable explanation:

Attempting to draw a distinction based on the medium used by the blogger or reporter is misguided , said Jack Balkin , a professor at Yale Law School -LRB- also a blogger -RRB- . ''
Relations: 


In [56]:
inputs = tokenizer(sample, return_tensors="pt").input_ids.to(device)

In [57]:
outputs = tuned_model.generate(inputs, max_new_tokens=256, do_sample=False)

In [58]:
out = tokenizer.decode(outputs[0], skip_special_tokens=True)

In [64]:
r = ast.literal_eval(out.split("Explanation:")[0])
print (type(r[0]))

<class 'list'>


In [60]:
print (text[ix])
print ("True relations: ", true_relations[ix])

Attempting to draw a distinction based on the medium used by the blogger or reporter is misguided , said Jack Balkin , a professor at Yale Law School -LRB- also a blogger -RRB- . ''
True relations:  [['Jack Balkin', '/business/person/company', 'Yale Law School']]


In [63]:
torch.cuda.empty_cache()

In [66]:
r[0] in true_relations[ix]

True