In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig, set_seed

In [2]:
# Identifiers handed to `load_dataset` and `from_pretrained` in the next cell.
# Change them here to try a different dataset or seq2seq checkpoint.
DATASET_NAME = "knkarthick/dialogsum"
MODEL_NAME = "google/flan-t5-base"

In [3]:
# Load the dataset plus the pretrained seq2seq model and its tokenizer.
# The helper functions defined further down rely on these notebook-level
# names, so this cell has to run before them.
dataset = load_dataset(DATASET_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

In [None]:
# Show the training split's feature names and row count.
dataset['train']

Dataset({
    features: ['id', 'dialogue', 'summary', 'topic'],
    num_rows: 12460
})

In [81]:
# Print the raw text of test dialogue 200; `print` keeps the line breaks readable.
print(dataset['test']['dialogue'][200])

#Person1#: Have you considered upgrading your system?
#Person2#: Yes, but I'm not sure what exactly I would need.
#Person1#: You could consider adding a painting program to your software. It would allow you to make up your own flyers and banners for advertising.
#Person2#: That would be a definite bonus.
#Person1#: You might also want to upgrade your hardware because it is pretty outdated now.
#Person2#: How can we do that?
#Person1#: You'd probably need a faster processor, to begin with. And you also need a more powerful hard disc, more memory and a faster modem. Do you have a CD-ROM drive?
#Person2#: No.
#Person1#: Then you might want to add a CD-ROM drive too, because most new software programs are coming out on Cds.
#Person2#: That sounds great. Thanks.


In [67]:
def make_n_shot_summary_prompt(example_ids=None, summarize_id=0, data=None, my_set='test'):
    """Build a zero-/few-shot summarization prompt from dialogue records.

    Each index in ``example_ids`` contributes a worked example (dialogue
    followed by its human-written summary). The prompt always ends with the
    dialogue at ``summarize_id`` and an open ``SUMMARY:`` header for the model
    to complete.

    Args:
        example_ids: Iterable of row indices used as worked examples, in the
            order given. ``None`` or an empty sequence yields a zero-shot prompt.
        summarize_id: Row index of the dialogue to be summarized.
        data: Mapping from split name to a table exposing ``'dialogue'`` and
            ``'summary'`` columns. ``None`` (the default) means the
            notebook-level ``dataset``, looked up when the function is called
            so that a reloaded dataset is picked up without re-running this
            cell.
        my_set: Name of the split that both the examples and the target
            dialogue are read from.

    Returns:
        The assembled prompt string.
    """
    if data is None:
        # Resolved at call time rather than frozen into the signature.
        data = dataset

    # Fetch each column once instead of once per example: column access on a
    # dataset split can be expensive, and it does not change inside the loop.
    split = data[my_set]
    dialogues = split['dialogue']

    prompt = ''
    if example_ids:
        # Summaries are only needed for worked examples, so a zero-shot
        # prompt never touches the 'summary' column.
        summaries = split['summary']
        for i in example_ids:
            prompt += f"""
DIALOGUE:

{dialogues[i]}

SUMMARY:

{summaries[i]}
"""

    prompt += f"""
DIALOGUE:

{dialogues[summarize_id]}

SUMMARY:
"""
    return prompt
    

def get_model_completion(prompt, tokenizer=tokenizer, model=model, gen_config=None):
    """Generate one sampled completion for ``prompt`` and return it as text.

    Args:
        prompt: Text that is tokenized and fed to the model.
        tokenizer: Tokenizer used to encode the prompt and decode the result.
            The default is the notebook-level ``tokenizer`` as it existed when
            this cell was executed; re-run the cell after reloading it.
        model: Model whose ``generate`` method produces the completion (same
            definition-time default caveat as ``tokenizer``).
        gen_config: Optional generation config forwarded to ``model.generate``
            as ``generation_config``. If it sets ``max_new_tokens`` that limit
            is honoured; otherwise a 1000-token budget is applied.

    Returns:
        The first generated sequence, decoded with special tokens removed.

    Notes:
        ``num_beams=1`` and ``do_sample=True`` are always passed as explicit
        keyword arguments. NOTE(review): per the transformers documentation,
        keyword arguments given to ``generate`` take precedence over the
        matching fields of ``generation_config`` -- confirm for the installed
        version.
    """
    # Keep the full encoding (ids and attention mask) and place it on the
    # model's device so the call also works when the model lives on a GPU.
    encoded = tokenizer(prompt, return_tensors='pt').to(model.device)

    sampling_kwargs = {'num_beams': 1, 'do_sample': True}
    # Only fall back to the generous default budget when the caller's config
    # does not carry its own limit (also covers gen_config=None).
    if getattr(gen_config, 'max_new_tokens', None) is None:
        sampling_kwargs['max_new_tokens'] = 1000

    completion = model.generate(**encoded,
                                generation_config=gen_config,
                                **sampling_kwargs)[0]
    return tokenizer.decode(completion, skip_special_tokens=True)
    

In [68]:
# Three-shot prompt: test rows 40, 80 and 20 act as worked examples, and test
# row 200 is the dialogue left for the model to summarize.
prompt = make_n_shot_summary_prompt(example_ids=[40, 80, 20], summarize_id=200)
print(prompt)


DIALOGUE:

#Person1#: What time is it, Tom?
#Person2#: Just a minute. It's ten to nine by my watch.
#Person1#: Is it? I had no idea it was so late. I must be off now.
#Person2#: What's the hurry?
#Person1#: I must catch the nine-thirty train.
#Person2#: You've plenty of time yet. The railway station is very close. It won't take more than twenty minutes to get there.

SUMMARY:

#Person1# is in a hurry to catch a train. Tom tells #Person1# there is plenty of time.

DIALOGUE:

#Person1#: May, do you mind helping me prepare for the picnic?
#Person2#: Sure. Have you checked the weather report?
#Person1#: Yes. It says it will be sunny all day. No sign of rain at all. This is your father's favorite sausage. Sandwiches for you and Daniel.
#Person2#: No, thanks Mom. I'd like some toast and chicken wings.
#Person1#: Okay. Please take some fruit salad and crackers for me.
#Person2#: Done. Oh, don't forget to take napkins disposable plates, cups and picnic blanket.
#Person1#: All set. May, can yo

In [75]:
# Sampling is stochastic: seed the random number generators first so that a
# fresh Restart & Run All reproduces the same completion.
set_seed(42)
gen_config = GenerationConfig(temperature=1.1, do_sample=True)
get_model_completion(prompt, gen_config=gen_config)

"And do you know what kind of things you'd like to get out of a package?"