In [1]:
import numpy as np
import pandas as pd
import logging
import os
import glob
import regex as re
import torch
import argparse
import random
import itertools
from random import random
import ast
import sys
import ast

from sklearn.model_selection import train_test_split
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import simpletransformers

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from simpletransformers.t5 import T5Model, T5Args

In [3]:
# Cache locations for torch hub and Hugging Face downloads.
# NOTE: transformers / huggingface_hub read HF_HOME and TRANSFORMERS_CACHE once, when they
# are first imported. They were already imported in the first cell, so in this kernel the
# two assignments below do NOT redirect the Hugging Face cache. The `cache_dir=` arguments
# passed to from_pretrained further down are what actually select the download location.
# To make the variables effective, set them before `import transformers` (top of the first
# cell, or in the shell environment). TORCH_HOME is looked up lazily by torch.hub, so it
# still takes effect here.
CACHE_DIR = "/scratch/wadhwa.s/cache"
os.environ["TORCH_HOME"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR

if "transformers" in sys.modules:
    logging.warning(
        "transformers was imported before HF_HOME/TRANSFORMERS_CACHE were set; "
        "these variables will not change its cache location in this kernel."
    )

In [4]:
# Tokenizer for the FLAN-T5 XXL checkpoint, downloaded into the scratch cache directory.
tokenizer_cache_dir = "/scratch/wadhwa.s/cache"
tokenizer = AutoTokenizer.from_pretrained(
    "google/flan-t5-xxl",
    cache_dir=tokenizer_cache_dir,
)

Downloading: 100%|█████████████████████████| 2.42M/2.42M [00:00<00:00, 4.10MB/s]


In [5]:
# FLAN-T5 XXL as a seq2seq LM; device_map="auto" lets transformers/accelerate decide
# which device(s) each part of the weights is placed on.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-xxl",
    device_map="auto",
    cache_dir="/scratch/wadhwa.s/cache",
)

In [6]:
# Few-shot prompt: two (sentence, [drug, adverse event]) demonstrations, each terminated
# by "</s>" and a blank line, followed by the unlabeled sentence for the model to complete.
# NOTE(review): the T5 tokenizer maps the literal text "</s>" to its EOS token, so the
# encoder input contains EOS ids mid-sequence — confirm this is intended.
demonstrations = [
    ("Alprazolam caused nausea in several patients.", "[Alprazolam, nausea]"),
    ("Xanax was non-effective in treating headaches and made them worse.", "[Xanax, headaches]"),
]
query = "Taking aspirin made patients feel constipated."
input_text = "".join(
    f"{sentence}\n{relation}</s>\n\n" for sentence, relation in demonstrations
) + query

# Put the inputs on the device holding the model's (first) parameters rather than a
# hard-coded "cuda": this follows wherever device_map="auto" placed the model and also
# works on CPU-only machines.
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)

In [7]:
# Generate the completion for the few-shot prompt, capping the output sequence length.
max_output_length = 200
outputs = model.generate(input_ids, max_length=max_output_length)

In [8]:
# One decoded string per generated sequence, with special tokens (pad/EOS) removed.
decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(decoded_outputs)

['[Aspirin, constipation]']


In [83]:
# Load the few-shot ADE prompt template: demonstration sentences, each followed by a
# "Relation List:" line that ends in "</s>". Read explicitly as UTF-8 so the result does
# not depend on the machine's locale.
with open("ade_prompt.txt", "r", encoding="utf-8") as text_file:
    prompt = text_file.read()
print(prompt)

# Sanity-check every demonstration's "Relation List:" payload. It should be a literal list
# of [drug, adverse effect] pairs. A malformed demo (an unclosed bracket, or a pair written
# as one string such as ["drug,effect"]) gives the model malformed examples to imitate.
RELATION_PREFIX = "Relation List:"
for line_no, line in enumerate(prompt.splitlines(), start=1):
    if not line.startswith(RELATION_PREFIX):
        continue
    payload = line[len(RELATION_PREFIX):].strip()
    if payload.endswith("</s>"):
        payload = payload[:-len("</s>")].rstrip()
    try:
        pairs = ast.literal_eval(payload)
    except (ValueError, SyntaxError, TypeError):
        pairs = None
    well_formed = isinstance(pairs, list) and all(
        isinstance(pair, (list, tuple)) and len(pair) == 2 for pair in pairs
    )
    if not well_formed:
        logging.warning("ade_prompt.txt line %d: malformed relation list: %s", line_no, payload)

We report a rare case of colonic mucosal necrosis following Kalimate (calcium polystryrene sulfonate), an analogue of Kayexalate without sorbitol in a 34-yr-old man.
Relation List: [["calcium polystryrene sulfonate","colonic mucosal necrosis"],["Kalimate","colonic mucosal necrosis"],["Kayexalate","colonic mucosal necrosis"]]</s>

Moreover, these findings suggest that the incidence of BOOP following rituximab therapy may be higher than has been previously appreciated.
Relation List: [["rituximab","BOOP"]</s>

Malignant mixed mullerian tumor of the uterus in a patient taking raloxifene.
Relation List: [["raloxifene","Malignant mixed mullerian tumor of the uterus"]]</s>

We describe a case of clozapine-induced seizures in a patient with treatment-resistant schizophrenia.
Relation List: [["clozapine","seizures"]]</s>

Fever, pulmonary infiltrates, and pleural effusion following acyclovir therapy for herpes zoster ophthalmicus.
Relation List: [["acyclovir,Fever"],["acyclovir","pleural effus

In [84]:
# Unlabeled sentence that is appended to the few-shot prompt as the query.
test_str = "Dilated cardiomyopathy associated with chronic overuse of an adrenaline inhaler."

In [85]:
# Append the query sentence plus an open "Relation List: " slot for the model to fill in.
# This rebinds `prompt` in place, so running the cell a second time would append a second
# query. Refuse explicitly instead of silently building a corrupted prompt.
if prompt.endswith("\nRelation List: "):
    raise RuntimeError(
        "`prompt` already ends with an appended query; re-run the cell that reads "
        "ade_prompt.txt before appending a new one."
    )
prompt = prompt + test_str + "\nRelation List: "

In [86]:
# Echo the assembled prompt (template + appended query) before tokenizing it.
print(prompt)

We report a rare case of colonic mucosal necrosis following Kalimate (calcium polystryrene sulfonate), an analogue of Kayexalate without sorbitol in a 34-yr-old man.
Relation List: [["calcium polystryrene sulfonate","colonic mucosal necrosis"],["Kalimate","colonic mucosal necrosis"],["Kayexalate","colonic mucosal necrosis"]]</s>

Moreover, these findings suggest that the incidence of BOOP following rituximab therapy may be higher than has been previously appreciated.
Relation List: [["rituximab","BOOP"]</s>

Malignant mixed mullerian tumor of the uterus in a patient taking raloxifene.
Relation List: [["raloxifene","Malignant mixed mullerian tumor of the uterus"]]</s>

We describe a case of clozapine-induced seizures in a patient with treatment-resistant schizophrenia.
Relation List: [["clozapine","seizures"]]</s>

Fever, pulmonary infiltrates, and pleural effusion following acyclovir therapy for herpes zoster ophthalmicus.
Relation List: [["acyclovir,Fever"],["acyclovir","pleural effus

In [87]:
# Tokenize the template-plus-query prompt and move it to the device holding the model's
# (first) parameters, instead of a hard-coded "cuda" that fails on CPU-only machines.
# NOTE(review): the T5 tokenizer's model_max_length is 512; long few-shot prompts may exceed
# it. Do not enable truncation here — that would cut the query off the end of the prompt.
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

In [90]:
outputs = model.generate(
    input_ids,
    max_length=200,  # upper bound on the generated sequence length
)

In [91]:
generated = tokenizer.batch_decode(outputs, skip_special_tokens=True)
# A single prompt was encoded, so the only sequence of interest is the first one.
print(generated[0])

[["adrenaline","cardiomyopathy"]]


In [4]:
# Settings passed as `args=` to simpletransformers' T5Model below.
# NOTE(review): it is not evident from here that T5Model forwards "device_map" to
# from_pretrained, or that "cache_dir" selects the Hugging Face download cache rather
# than simpletransformers' own cache — verify against the installed version.
model_args = dict(
    cache_dir="/scratch/wadhwa.s/cache/",
    device_map="auto",
    use_multiprocessing=False,
    use_multiprocessed_decoding=False,
    no_save=True,
    overwrite_output_dir=True,
    # Options tried earlier, currently disabled:
    # num_train_epochs=4,
    # preprocess_inputs=False,
    # special_tokens_list=["<bos>", "<eos>", "<rel>", "<ent>"],
    # max_length=200,
    # num_beams=5,
    # learning_rate=lr,
)

In [5]:
# Release the FLAN-T5-XXL model loaded earlier through AutoModelForSeq2SeqLM before loading
# the same checkpoint again through simpletransformers. `model = T5Model(...)` builds the new
# model while the old one is still bound to `model`, so a top-to-bottom run would hold two
# copies of the XXL weights in memory at once.
import gc

if "model" in globals():
    del model
    gc.collect()  # the old model may sit in reference cycles; collect them now
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand cached GPU blocks back before the next load

model = T5Model("t5", "google/flan-t5-xxl", args=model_args)

In [16]:
# Task-prefixed inputs for T5Model.predict: an ADE relation-extraction query and an
# arithmetic sanity check.
to_predict = []
to_predict.append(
    "find [drug, adverse event] relation pairs: Alprazolam caused nausea in several patients."
)
to_predict.append("compute: 2 + 2")

In [17]:
# Run generation for every entry of `to_predict`; one output string per input.
preds = model.predict(to_predict=to_predict)

Generating outputs: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.11it/s]


In [18]:
# Display the generated strings (one per prompt in `to_predict`) as this cell's rich output.
preds

['alprazolam, adverse event nausea', '2']