In [1]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Instruction-tuned T5 checkpoint used for every summarization experiment below.
model_name = "google/flan-t5-base"
model = AutoModelForSeq2SeqLM.from_pretrained(
    pretrained_model_name_or_path=model_name,
)

In [3]:
# Fast tokenizer for the same checkpoint as the model above.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True,
)

In [4]:
# Sanity check: encode one sentence and inspect the token ids and attention mask.
sentence = "I would like to order a pizza."
sentence_encoded = tokenizer(text=sentence, return_tensors="pt")
print(sentence_encoded)

{'input_ids': tensor([[  27,  133,  114,   12,  455,    3,    9, 6871,    5,    1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [5]:
# DialogSum: dialogues paired with human-written summaries (train/validation/test splits).
dialog_dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(path=dialog_dataset_name)

Downloading readme: 100%|██████████| 4.65k/4.65k [00:00<00:00, 8.97MB/s]
Downloading data: 100%|██████████| 11.3M/11.3M [00:02<00:00, 4.29MB/s]
Downloading data: 100%|██████████| 442k/442k [00:01<00:00, 434kB/s]
Downloading data: 100%|██████████| 1.35M/1.35M [00:01<00:00, 837kB/s]
  return pd.read_csv(xopen(filepath_or_buffer, "rb", download_config=download_config), **kwargs)
Generating train split: 12460 examples [00:00, 116040.72 examples/s]
  return pd.read_csv(xopen(filepath_or_buffer, "rb", download_config=download_config), **kwargs)
Generating validation split: 500 examples [00:00, 73814.79 examples/s]
  return pd.read_csv(xopen(filepath_or_buffer, "rb", download_config=download_config), **kwargs)
Generating test split: 1500 examples [00:00, 120224.26 examples/s]


In [9]:
# Peek at one test row: the raw dialogue followed by its reference summary.
first_test_row = dataset['test'][0]
print(first_test_row['dialogue'])
print(first_test_row['summary'])

#Person1#: Ms. Dawson, I need you to take a dictation for me.
#Person2#: Yes, sir...
#Person1#: This should go out as an intra-office memorandum to all employees by this afternoon. Are you ready?
#Person2#: Yes, sir. Go ahead.
#Person1#: Attention all staff... Effective immediately, all office communications are restricted to email correspondence and official memos. The use of Instant Message programs by employees during working hours is strictly prohibited.
#Person2#: Sir, does this apply to intra-office communications only? Or will it also restrict external communications?
#Person1#: It should apply to all communications, not only in this office between employees, but also any outside communications.
#Person2#: But sir, many employees use Instant Messaging to communicate with their clients.
#Person1#: They will just have to change their communication methods. I don't want any - one using Instant Messaging in this office. It wastes too much time! Now, please continue with the memo. Wh

In [30]:
# Test-split rows used as running examples by the summarization cells below.
example_indices = [40, 200]

#### Summary without prompt engineering

In [20]:
# 50-character separator between printed sections.
dash_line = '-' * 50

# Baseline: feed each raw dialogue to the model with no instruction at all.
for i, index in enumerate(example_indices):
    example = dataset['test'][index]
    dialogue = example['dialogue']
    summary = example['summary']

    inputs = tokenizer(dialogue, return_tensors='pt')
    # Unpack the full encoding so generate() also receives the attention mask.
    output_ids = model.generate(**inputs, max_new_tokens=50)
    generated_summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    print(dash_line)
    print('Example ', i + 1)
    print(dash_line)
    print(f'BASELINE HUMAN SUMMARY\n{summary}')
    print(dash_line)
    print(f'MODEL GENERATION\n{generated_summary}')

-------------------------------------------------
Example  1
-------------------------------------------------
BASELINE HUMAN SUMMARY
#Person1# attends Brian's birthday party. Brian thinks #Person1# looks great and charming.
-------------------------------------------------
MODEL GENERATION
Brian, thank you for coming to our party.
-------------------------------------------------
Example  2
-------------------------------------------------
BASELINE HUMAN SUMMARY
#Person1# is about to make a prank. #Person2# thinks it's cruel at first but then joins.
-------------------------------------------------
MODEL GENERATION
#Person1#: Yeah.


In [31]:
# Zero-shot: prefix each dialogue with an explicit summarization instruction.
for i, index in enumerate(example_indices):
    example = dataset['test'][index]
    dialogue = example['dialogue']
    summary = example['summary']

    # Built from explicit pieces rather than an indented triple-quoted f-string,
    # so no source-code indentation or trailing whitespace leaks into the text
    # the model sees.
    prompt = (
        "Summarize the following conversation.\n"
        "\n"
        f"{dialogue}\n"
        "\n"
        "Summary:"
    )

    inputs = tokenizer(prompt, return_tensors='pt')
    # Unpack the full encoding so generate() also receives the attention mask.
    output_ids = model.generate(**inputs, max_new_tokens=50)
    generated_summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    print(dash_line)
    print('Example ', i + 1)
    print(dash_line)
    print(f'BASELINE HUMAN SUMMARY\n{summary}')
    print(dash_line)
    print(f'MODEL GENERATION\n{generated_summary}')

-------------------------------------------------
Example  1
-------------------------------------------------
BASELINE HUMAN SUMMARY
#Person1# is in a hurry to catch a train. Tom tells #Person1# there is plenty of time.
-------------------------------------------------
MODEL GENERATION
The train is about to leave.
-------------------------------------------------
Example  2
-------------------------------------------------
BASELINE HUMAN SUMMARY
#Person1# teaches #Person2# how to upgrade software and hardware in #Person2#'s system.
-------------------------------------------------
MODEL GENERATION
#Person1#: I'm thinking of upgrading my computer.


#### One-shot inference: building the prompt

In [32]:
def make_prompt(example_indices_full, example_index_to_summary, data=None):
    """Build an in-context (zero/one/few-shot) summarization prompt.

    Each index in ``example_indices_full`` contributes a solved example: its
    dialogue followed by the reference summary. The dialogue at
    ``example_index_to_summary`` is appended last, with the question left
    open for the model to answer.

    Args:
        example_indices_full: Indices of rows to include together with their
            human summaries. An empty sequence yields a zero-shot prompt.
        example_index_to_summary: Index of the row whose dialogue the model
            should summarize.
        data: Indexable collection of rows, each exposing ``'dialogue'`` and
            ``'summary'`` keys. Defaults to ``dataset['test']``.

    Returns:
        The prompt string. Examples are separated by a blank line.
    """
    if data is None:
        data = dataset['test']

    # Assemble without source-code indentation. Indented triple-quoted
    # f-strings would inject runs of spaces into the model input and would
    # indent only the first line of each multi-line dialogue.
    parts = []
    for index in example_indices_full:
        example = data[index]
        parts.append(
            f"Dialogue:\n{example['dialogue']}\n"
            f"What was going on?\n{example['summary']}\n"
        )
    parts.append(
        f"Dialogue:\n{data[example_index_to_summary]['dialogue']}\n"
        "What was going on?\n"
    )
    return "\n".join(parts)

#### One-shot inference: generating a summary

In [37]:
# One solved example (row 40) followed by the dialogue to summarize (row 200).
example_indices_full = [40]
example_index_to_summarize = 200
one_shot_prompt = make_prompt(example_indices_full, example_index_to_summarize)

# Reference summary for the dialogue the model is asked to summarize.
summary = dataset['test'][example_index_to_summarize]['summary']

# NOTE(review): the prompt contains two full dialogues and may exceed the
# tokenizer's model_max_length. Truncation is deliberately left off, because
# it would cut the dialogue to summarize, which comes last in the prompt.
inputs = tokenizer(one_shot_prompt, return_tensors='pt')
# Unpack the full encoding so generate() also receives the attention mask.
output = tokenizer.decode(
    model.generate(**inputs, max_new_tokens=50)[0],
    skip_special_tokens=True,
)
print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n{summary}\n')
print(dash_line)
print(f'MODEL GENERATION - ONE SHOT:\n{output}')

-------------------------------------------------
BASELINE HUMAN SUMMARY:
#Person1# teaches #Person2# how to upgrade software and hardware in #Person2#'s system.

-------------------------------------------------
MODEL GENERATION - ONE SHOT:
#Person1 wants to upgrade his system. #Person2 wants to add a painting program to his software. #Person1 wants to add a CD-ROM drive.
