In [None]:
# Benchmark prompts and their reasoning categories.
# `prompts` and `types` are parallel lists: types[i] is the category of prompts[i].
# Several entries deliberately pack two questions into a single prompt string.
# NOTE: the name `types` shadows the standard-library `types` module if that is
# ever imported in this notebook; the name is kept for compatibility.
prompts = [
    'What is the capital of the United States of America?',
    'What is the capital of India?',
    'What is the capital of China?',
    'What is the next number in the sequence: 1, 1, 2, 3, 5, 8, ...? If all cats have tails, and Fluffy is a cat, does Fluffy have a tail?',
    'If you eat too much junk food, what will happen to your health? How does smoking affect the risk of lung cancer?',
    'In the same way that pen is related to paper, what is fork related to? If tree is related to forest, what is brick related to?',
    'Every time John eats peanuts, he gets a rash. Does John have a peanut allergy? Every time Sarah studies for a test, she gets an A. Will Sarah get an A on the next test if she studies?',
    'All dogs have fur. Max is a dog. Does Max have fur? If it is raining outside, and Mary does not like to get wet, will Mary take an umbrella?',
    'If I had studied harder, would I have passed the exam? What would have happened if Thomas Edison had not invented the light bulb?',
    'The center of Tropical Storm Arlene, at 02/1800 UTC, is near 26.7N 86.2W. This position is about 425 km/230 nm to the west of Fort Myers in Florida, and it is about 550 km/297 nm to the NNW of the western tip of Cuba. The tropical storm is moving southward, or 175 degrees, 4 knots. The estimated minimum central pressure is 1002 mb. The maximum sustained wind speeds are 35 knots with gusts to 45 knots. The sea heights that are close to the tropical storm are ranging from 6 feet to a maximum of 10 feet.  Precipitation: scattered to numerous moderate is within 180 nm of the center in the NE quadrant. Isolated moderate is from 25N to 27N between 80W and 84W, including parts of south Florida.  Broad surface low pressure extends from the area of the tropical storm, through the Yucatan Channel, into the NW part of the Caribbean Sea.   Where and when will the storm make landfall?'
]

types = [
    'Knowledge Retrieval',
    'Knowledge Retrieval',
    'Knowledge Retrieval',
    'Logical Reasoning',
    'Cause and Effect',
    'Analogical Reasoning',
    'Inductive Reasoning',
    'Deductive Reasoning',
    'Counterfactual Reasoning',
    'In Context'
]

# Guard the pairing: adding a prompt without a matching type (or vice versa)
# would silently misalign prompt/category pairs.
assert len(prompts) == len(types), (
    f"prompts ({len(prompts)}) and types ({len(types)}) must be the same length"
)

In [None]:
import gc
import os
import time

# import matplotlib.pyplot as plt
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Turn tokenizer parallelism off via TOKENIZERS_PARALLELISM so the
# associated warning message is not printed.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})


def _timed(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and return ``(result, elapsed_seconds)``.

    Uses ``time.perf_counter`` (monotonic, high resolution) rather than
    ``time.time`` so interval measurements are not affected by wall-clock changes.
    """
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def benchmarking(model_ids):
    """Benchmark load time and per-prompt generation latency of seq2seq models.

    For each model id, the tokenizer, the model and a LangChain
    ``HuggingFacePipeline`` wrapper (around a ``text2text-generation`` pipeline
    with ``max_length=512``) are loaded, each step timed separately. Then every
    entry of the module-level ``prompts`` list is answered and timed.

    Args:
        model_ids: iterable of model ids loadable with ``AutoTokenizer`` and
            ``AutoModelForSeq2SeqLM``, e.g. ``'google/flan-t5-small'``.

    Returns:
        dict with keys:
            ``'tokenizer_load_times'``, ``'model_load_times'``,
            ``'pipeline_load_times'``: ``{model_id: seconds}``;
            ``'generation_times'``: ``{model_id: [seconds per prompt]}``;
            ``'results'``: ``{model_id: [{'prompt', 'type', 'answer', 'time'}, ...]}``,
            one record per prompt, in ``prompts`` order.

    Raises:
        ValueError: if the module-level ``prompts`` and ``types`` lists differ
            in length (they are parallel lists).
    """
    if len(prompts) != len(types):
        raise ValueError(
            f"prompts ({len(prompts)}) and types ({len(types)}) must be the same length"
        )

    tokenizer_load_times = {}
    model_load_times = {}
    pipeline_load_times = {}
    generation_times = {}
    results = {}

    for model_id in model_ids:
        tokenizer, tokenizer_load_times[model_id] = _timed(
            AutoTokenizer.from_pretrained, model_id
        )
        model, model_load_times[model_id] = _timed(
            AutoModelForSeq2SeqLM.from_pretrained, model_id
        )

        # Pipeline construction and the LangChain wrapper are timed together:
        # both are paid before the first generation can run.
        pipeline_start = time.perf_counter()
        pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
        local_llm = HuggingFacePipeline(pipeline=pipe)
        pipeline_load_times[model_id] = time.perf_counter() - pipeline_start

        generation_times[model_id] = []
        results[model_id] = []
        for prompt, prompt_type in zip(prompts, types):
            answer, elapsed = _timed(local_llm, prompt)
            generation_times[model_id].append(elapsed)
            # String keys keep each record self-describing and tabulatable.
            results[model_id].append(
                {"prompt": prompt, "type": prompt_type, "answer": answer, "time": elapsed}
            )

        # Release this model before the next one loads, so two models' weights
        # are not held in memory at the same time.
        del local_llm, pipe, model, tokenizer
        gc.collect()

    return {
        "tokenizer_load_times": tokenizer_load_times,
        "model_load_times": model_load_times,
        "pipeline_load_times": pipeline_load_times,
        "generation_times": generation_times,
        "results": results,
    }
            


In [None]:
# FLAN-T5 checkpoints to benchmark, listed in the order of their size names
# (small -> xxl). NOTE(review): the larger checkpoints may need substantial
# RAM and download time — confirm the target machine can hold them.
model_ids = [
    'google/flan-t5-small',
    'google/flan-t5-base',
    'google/flan-t5-large',
    'google/flan-t5-xl',
    'google/flan-t5-xxl',
]

benchmarking(model_ids)