In [33]:
import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset
from transformers import DataCollatorForSeq2Seq
import torch
import evaluate
import numpy as np
from tqdm import tqdm

In [1]:
!pwd

/home/wadhwa.s/gpt3


In [2]:
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large", 
                                              cache_dir="/scratch/wadhwa.s/cache", 
                                              device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large", 
                                          cache_dir="/scratch/wadhwa.s/cache")

In [3]:
rouge = evaluate.load("rouge")

In [4]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [5]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [6]:
df = pd.read_csv("ade_gpt3_explanations_responses.csv")
for ix, row in df.iterrows():
    # print("TEXT: ", row["input"])
    row["explanations"] = "Explanation: " + row["explanations"] + "</s>"
    # print ("Relations: ", row["gold_relations"])
    # print (row["explanations"])
    # print ("\n----------------\n")
df.to_csv("ade_gpt3_explanations_responses.csv", index=False)

In [7]:
ade_data = load_dataset("csv", data_files="ade_gpt3_explanations_responses.csv")
ade_data = ade_data["train"]

Using custom data configuration default-6f29c012792c1789


Downloading and preparing dataset csv/default to /home/wadhwa.s/.cache/huggingface/datasets/csv/default-6f29c012792c1789/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 6364.65it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 516.54it/s]
  return pd.read_csv(xopen(filepath_or_buffer, "rb", use_auth_token=use_auth_token), **kwargs)
                                                        

Dataset csv downloaded and prepared to /home/wadhwa.s/.cache/huggingface/datasets/csv/default-6f29c012792c1789/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 137.51it/s]


In [8]:
ade = ade_data.train_test_split(test_size=0.2)

prefix = """List the drugs and their corresponding adverse-effects in the following TEXT using [drug, effect] format:\nTEXT: """

print (prefix)

List the drugs and their corresponding adverse-effects in the following TEXT using [drug, effect] format:
TEXT: 


In [9]:
target = [prefix + example + "\nRelations: " for example in ade["train"]["input"]]
print (target[0])

List the drugs and their corresponding adverse-effects in the following TEXT using [drug, effect] format:
TEXT: A case of allopurinol hypersensitivity, possibly the first in a black African, is reported.
Relations: 


In [10]:
def preprocess_function(examples):
    inputs = [prefix + example + "\nRelations: " for example in examples["input"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    targets = [gold_relations + "\n" + explanation for gold_relations, explanation in zip(examples["gold_relations"], examples["explanations"])]
    labels = tokenizer(text_target=targets, max_length=256, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [11]:
tokenized_ade = ade.map(preprocess_function, batched=True)

100%|██████████| 4/4 [00:01<00:00,  3.39ba/s]
100%|██████████| 1/1 [00:00<00:00,  3.97ba/s]


In [12]:
tokenized_ade

DatasetDict({
    train: Dataset({
        features: ['input', 'explanations', 'gold_relations', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 3403
    })
    test: Dataset({
        features: ['input', 'explanations', 'gold_relations', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 851
    })
})

In [13]:
torch.cuda.empty_cache()

In [16]:
torch.cuda.empty_cache()

training_args = Seq2SeqTrainingArguments(
    output_dir="/scratch/wadhwa.s/cache/ade_explanations",
    evaluation_strategy="steps",
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    logging_steps=1000,
    eval_delay=600,
    logging_strategy="epoch",
    save_total_limit=7,
    weight_decay=0.01,
    num_train_epochs=10,
    predict_with_generate=True,
    # gradient_accumulation_steps=4,
    # fp16=True,
    # push_to_hub=True,
)

In [17]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ade["train"],
    eval_dataset=tokenized_ade["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [18]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: explanations, gold_relations, input. If explanations, gold_relations, input are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 3403
  Num Epochs = 10
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 8510
  Number of trainable parameters = 783150080
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[34m[1mwandb[0m: Currently logged in as: [33msw7[0m. Use [1m`wandb login --relogin`[0m to force relogin


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1000,0.2451,0.16363,0.2704,0.1824,0.2697,0.2698,19.0
2000,0.1674,0.158184,0.2698,0.1823,0.269,0.2691,19.0
3000,0.1372,0.158266,0.2695,0.1834,0.2687,0.2687,19.0
4000,0.1151,0.154412,0.2699,0.1839,0.269,0.2692,19.0
5000,0.1006,0.160567,0.2722,0.1873,0.2716,0.2716,19.0


Saving model checkpoint to /scratch/wadhwa.s/cache/ade_explanations/checkpoint-500
Configuration saved in /scratch/wadhwa.s/cache/ade_explanations/checkpoint-500/config.json
Model weights saved in /scratch/wadhwa.s/cache/ade_explanations/checkpoint-500/pytorch_model.bin
tokenizer config file saved in /scratch/wadhwa.s/cache/ade_explanations/checkpoint-500/tokenizer_config.json
Special tokens file saved in /scratch/wadhwa.s/cache/ade_explanations/checkpoint-500/special_tokens_map.json
Copy vocab file to /scratch/wadhwa.s/cache/ade_explanations/checkpoint-500/spiece.model
Deleting older checkpoint [/scratch/wadhwa.s/cache/ade_explanations/checkpoint-1000] due to args.save_total_limit
The following columns in the evaluation set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: explanations, gold_relations, input. If explanations, gold_relations, input are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this mess

In [20]:
torch.cuda.empty_cache()

In [21]:
tuned_model = AutoModelForSeq2SeqLM.from_pretrained("/scratch/wadhwa.s/cache/ade_explanations/checkpoint-3000")
tokenizer = AutoTokenizer.from_pretrained("/scratch/wadhwa.s/cache/ade_explanations/checkpoint-3000")

loading configuration file /scratch/wadhwa.s/cache/ade_explanations/checkpoint-3000/config.json
Model config T5Config {
  "_name_or_path": "/scratch/wadhwa.s/cache/ade_explanations/checkpoint-3000",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 2816,
  "d_kv": 64,
  "d_model": 1024,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 24,
  "num_heads": 16,
  "num_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.24.0",
  "use_cache": true,
  "vocab_size": 32128
}

loading weights file /scratch/wadhwa.s/cache/ade_explanati

In [34]:
ip = []
gold = []
generated = []
nc_count = 0

for _, i in enumerate(tqdm(ade["test"])):
    text = prefix + i["input"] + "\nRelations: "
    inputs = tokenizer(text, return_tensors="pt").input_ids
    outputs = tuned_model.generate(inputs, max_new_tokens=256, do_sample=False)
    out = tokenizer.decode(outputs[0], skip_special_tokens=True)
    torch.cuda.empty_cache()
    try:
        generated.append(out.split(" Explanation: ")[0])
        ip.append(i["input"])
        gold.append(i["gold_relations"])
    except:
        nc_count += 1
        print ("********* Not able to Decode *********")
        print (i["input"])
        print ("Relations: ", i["gold_relations"])
        print ("Generated: ", out)
        print ("\n----------------\n")
    # break

100%|██████████| 851/851 [2:26:24<00:00, 10.32s/it]  


In [35]:
df = pd.DataFrame({"text": ip, "gold_labels": gold, "generated": generated})
df.to_csv("ade_flan_explanations_generated.csv", index=False)
print (df.shape)

(851, 3)


In [1]:
!pwd

/home/wadhwa.s/gpt3


In [36]:
for text, gold_relations, generated in zip(ip, gold, generated):
    print ("TEXT: ", text)
    print ("Relations: ", gold_relations)
    print ("Generated: ", generated)
    print ("\n----------------\n")

TEXT:  Severe hepatitis caused by cyproterone acetate.
Relations:  [['cyproterone acetate', 'Severe hepatitis']]
Generated:  [['cyproterone acetate', 'Severe hepatitis']]

----------------

TEXT:  Sustained-release procainamide-induced reversible granulocytopenia after myocardial infarction.
Relations:  [['procainamide', 'reversible granulocytopenia']]
Generated:  [['procainamide','reversible granulocytopenia']]

----------------

TEXT:  CASE REPORT: We hereby report a case of radiation recall dermatitis and myositis occurring on gemcitabine monotherapy, five months after completing chemoradiation for locally advanced pancreatic cancer.
Relations:  [['gemcitabine', 'myositis'], ['gemcitabine', 'radiation recall dermatitis']]
Generated:  [['gemcitabine','myositis'], ['gemcitabine', 'radiation recall dermatitis']]

----------------

TEXT:  Hepatic angiosarcoma occurring after cyclophosphamide therapy: case report and review of the literature.
Relations:  [['cyclophosphamide', 'Hepatic an

In [22]:
text = prefix + "spindle coma in benzodiazepine toxicity: case report." + "\nRelations: "
print (text)

List the drugs and their corresponding adverse-effects in the following TEXT using [drug, effect] format:
TEXT: spindle coma in benzodiazepine toxicity: case report.
Relations: 


In [23]:
inputs = tokenizer(text, return_tensors="pt").input_ids

In [24]:
outputs = tuned_model.generate(inputs, max_new_tokens=100, do_sample=False)

In [25]:
out = tokenizer.decode(outputs[0], skip_special_tokens=True)
print (out)

[['benzodiazepine','spindle coma']] Explanation: Explanation: Spindle coma was reported in benzodiazepine toxicity.


In [35]:
import ast

In [36]:
print (ast.literal_eval(out))

[['Alprazolam', 'headache'], ['Xanax', 'headache']]
