In [1]:
import torch
import transformers

import logging
import numpy as np

import argparse

from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    AutoTokenizer, 
    AutoModelForQuestionAnswering,
)

In [2]:
def ask(tokenizer, model, question, context):
    """Answer ``question`` by extracting a span of ``context`` with an extractive QA model.

    Args:
        tokenizer: tokenizer paired with ``model``; called as
            ``tokenizer(question, context, return_tensors="pt")``.
        model: question-answering model whose output exposes ``start_logits`` and
            ``end_logits`` (a batch of one; row 0 is used).
        question: question text.
        context: passage the answer is extracted from.

    Returns:
        str: the decoded tokens of the highest-scoring (start, end) span with end >= start.
    """
    inputs = tokenizer(question, context, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]
    num_tokens = start_logits.size(0)

    # Score every (start, end) pair jointly and forbid end < start. Taking the two
    # argmaxes independently can produce end < start, which yields an empty answer.
    span_scores = start_logits[:, None] + end_logits[None, :]
    positions = torch.arange(num_tokens, device=span_scores.device)
    invalid = positions[None, :] < positions[:, None]  # column (end) before row (start)
    span_scores = span_scores.masked_fill(invalid, float("-inf"))
    answer_start, answer_end = divmod(int(torch.argmax(span_scores)), num_tokens)

    answer_ids = inputs["input_ids"][0][answer_start:answer_end + 1]
    return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(answer_ids))

In [3]:
# Timestamped INFO-level logging so main() can report the arguments it was called with.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Fallback generation length used by adjust_length_to_model when the requested length is
# negative and the model reports no positive max_position_embeddings.
MAX_LENGTH = 10_000

# model_type -> (LM-head model class, tokenizer class) accepted by main().
MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
}

In [4]:
def adjust_length_to_model(length, max_sequence_length):
    """Clamp a requested generation length to what the model supports.

    Args:
        length: requested number of tokens; a negative value means "as long as allowed".
        max_sequence_length: the model's positional limit; values <= 0 mean "unknown".

    Returns:
        ``max_sequence_length`` when the request is negative or exceeds a known limit,
        ``MAX_LENGTH`` when the request is negative and no limit is known,
        otherwise ``length`` unchanged.
    """
    model_has_limit = max_sequence_length > 0

    if length < 0:
        # Unbounded request: use the model's limit, or the global safety cap.
        return max_sequence_length if model_has_limit else MAX_LENGTH

    if model_has_limit and length > max_sequence_length:
        # No generation bigger than model size.
        return max_sequence_length

    return length

In [5]:
def main(**kwarg):
    """Sample text continuations of a prompt with a causal language model.

    Keyword Args:
        model_type: key into ``MODEL_CLASSES`` (e.g. ``"gpt2"``).
        model_name_or_path: checkpoint passed to ``from_pretrained``.
        length: number of new tokens to request; clamped by ``adjust_length_to_model``.
        k: top-k sampling cutoff.
        prompt: prompt text; if falsy or missing, the user is asked via ``input()``.
        stop_token: generated text is cut at the first occurrence of this string
            (falsy or missing disables the cut).
        num_return_sequences: number of samples to draw (default 2).
        temperature, top_p, repetition_penalty: optional sampling overrides
            (defaults 1.0, 0.9 and 1.0).
        no_cuda: when True (default) always run on CPU.
        seed: optional int; when given, ``torch.manual_seed`` is called for reproducibility.

    Returns:
        list[str]: one string per generated sequence, each beginning with the prompt text.

    Raises:
        KeyError: if ``model_type`` is missing or not in ``MODEL_CLASSES``.
    """
    no_cuda = kwarg.get('no_cuda', True)
    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")

    seed = kwarg.get('seed')
    if seed is not None:
        torch.manual_seed(seed)

    # Initialize the model and tokenizer
    model_type = kwarg['model_type']
    try:
        model_class, tokenizer_class = MODEL_CLASSES[model_type]
    except KeyError:
        raise KeyError(
            "the model {} you specified is not supported. You are welcome to add it and open a PR :)".format(model_type)
        ) from None

    tokenizer = tokenizer_class.from_pretrained(kwarg['model_name_or_path'])
    model = model_class.from_pretrained(kwarg['model_name_or_path'])
    # The prompt tensor is moved to `device`; the model must live there too.
    model.to(device)

    max_positions = model.config.max_position_embeddings
    length = adjust_length_to_model(kwarg['length'], max_sequence_length=max_positions)
    logger.info(kwarg)

    prompt_text = kwarg.get('prompt') or input("Model prompt >>> ")

    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=True, return_tensors="pt").to(device)
    prompt_length = encoded_prompt.size(-1)
    input_ids = encoded_prompt if prompt_length > 0 else None

    # Prompt plus continuation must fit in the positional embeddings; asking for more
    # (e.g. length=1000 on a 1024-position model) would index past them.
    max_length = length + prompt_length
    if max_positions > 0:
        max_length = min(max_length, max_positions)

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=kwarg.get('temperature', 1.0),
        top_k=kwarg['k'],
        top_p=kwarg.get('top_p', 0.9),
        repetition_penalty=kwarg.get('repetition_penalty', 1.0),
        do_sample=True,
        num_return_sequences=kwarg.get('num_return_sequences', 2),
    )

    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    # Loop-invariant: length of the prompt as it appears in decoded text.
    decoded_prompt_length = len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True))
    stop_token = kwarg.get('stop_token')

    generated_sequences = []
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print("=== GENERATED SEQUENCE {} ===".format(generated_sequence_idx + 1))

        text = tokenizer.decode(generated_sequence.tolist(), clean_up_tokenization_spaces=True)

        # Cut at the stop token only when it is present: str.find returns -1 otherwise,
        # and slicing with -1 would silently drop the last character.
        if stop_token:
            stop_index = text.find(stop_token)
            if stop_index != -1:
                text = text[:stop_index]

        # Re-attach the original prompt text, replacing its decoded (pre-processed) form.
        total_sequence = prompt_text + text[decoded_prompt_length:]

        generated_sequences.append(total_sequence)
        print(total_sequence)

    return generated_sequences

In [6]:
# Local extractive-QA checkpoint directory shared by the tokenizer and the model
# (relative to the notebook's working directory).
QA_MODEL_DIR = "../qa_script/BioBERT_cancer"

tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_DIR)
model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_DIR)

In [7]:
# Example question/context pair; ask()'s answer for this pair is reused below as PROMPT.
# The question previously read "of…rian" (a truncated paste); the context makes clear
# it refers to ovarian germ cell tumors.
question = "What are the symptoms of Ovarian Germ Cell Tumors ?"
context = """Signs of ovarian germ cell tumor are swelling of the abdomen or vaginal bleeding after menopause. Ovarian germ cell tumors can be hard to diagnose (find) early. Often there are no symptoms in the early stages, but tumors may be found during regular gynecologic exams (checkups). Check with your doctor if you have either of the following:          - Swollen abdomen without weight gain in other parts of the body.     - Bleeding from the vagina after menopause (when you are no longer having menstrual periods)."""

print(question)
print("Cancer_bioasq: ", ask(tokenizer, model, question, context))

What are the symptoms of…rian Germ Cell Tumors ?
Cancer_bioasq:  signs of ovarian germ cell tumor are swelling of the abdomen or vaginal bleeding after menopause.


In [8]:
# Use the QA answer as the prompt for the causal LM checkpoint in OUTPUT_DIR.
OUTPUT_DIR = '../text_generation_script/GPT2_text_generator'
PROMPT = ask(tokenizer, model, question, context)

output = main(
    model_type='gpt2',
    model_name_or_path=OUTPUT_DIR,
    length=300,            # new tokens requested (clamped by adjust_length_to_model)
    prompt=PROMPT,
    stop_token="<EOS>",    # generated text after this marker is discarded
    k=30,                  # top-k sampling cutoff
    num_return_sequences=2,
)

11/28/2020 14:57:11 - INFO - __main__ -   {'model_type': 'gpt2', 'model_name_or_path': '../text_generation_script/GPT2_text_generator', 'length': 300, 'prompt': 'signs of ovarian germ cell tumor are swelling of the abdomen or vaginal bleeding after menopause.', 'stop_token': '<EOS>', 'k': 30, 'num_return_sequences': 2}
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


=== GENERATED SEQUENCE 1 ===
signs of ovarian germ cell tumor are swelling of the abdomen or vaginal bleeding after menopause. These and other signs and symptoms may be caused by ovarian germ cell tumor or by other conditions. Check with your doctor if you have any of the following: Weakness or feeling tired. Weight loss with little or no effect on menstrual periods. Vaginal bleeding after menopause. Sometimes pain or swelling in a women's vagina. Fever or night sweats. 
=== GENERATED SEQUENCE 2 ===
signs of ovarian germ cell tumor are swelling of the abdomen or vaginal bleeding after menopause. These and other signs and symptoms may be caused by ovarian germ cell tumor or by other conditions. Check with your doctor if you have any of the following: Swelling of the abdomen or vaginal bleeding after menopause. Trouble starting the flow of urine. Weight loss for no known reason. Trouble emptying the bladder completely after menopause. Pain or feeling of fullness below the ribs on the lef

In [9]:
# Same prompt, longer request: adjust_length_to_model caps 1000 at the model's
# max_position_embeddings.
output = main(
    model_type='gpt2',
    model_name_or_path=OUTPUT_DIR,
    length=1000,
    prompt=PROMPT,
    stop_token="<EOS>",
    k=30,
    num_return_sequences=2,
)

11/28/2020 14:59:56 - INFO - __main__ -   {'model_type': 'gpt2', 'model_name_or_path': '../text_generation_script/GPT2_text_generator', 'length': 1000, 'prompt': 'signs of ovarian germ cell tumor are swelling of the abdomen or vaginal bleeding after menopause.', 'stop_token': '<EOS>', 'k': 30, 'num_return_sequences': 2}
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


=== GENERATED SEQUENCE 1 ===
signs of ovarian germ cell tumor are swelling of the abdomen or vaginal bleeding after menopause. Other conditions can increase the risk of hairy cell leukemia. Check with your doctor if you have any of the following: Pain or swelling in the abdomen. Weakness or feeling tired. Weight loss for no known reason. A menstrual period that does not go away. 
=== GENERATED SEQUENCE 2 ===
signs of ovarian germ cell tumor are swelling of the abdomen or vaginal bleeding after menopause. These and other signs may be caused by ovarian germ cell tumor or by other conditions. Check with your doctor if you have any of the following: Pain or swelling in the abdomen. A lump in the abdomen, vagina, or rectum. Weight loss for no known reason. Pain or a feeling of fullness below the ribs on the left side. Weight loss with no known reason. A dark urinelike color that does not go away. 
