## Benchmarking Llama 2 JSON inference with key-wise batches

1.   List item
2.   List item



In [11]:
!pip install prettytable transformers accelerate

Collecting accelerate
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: accelerate
Successfully installed accelerate-0.24.1


In [2]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [9]:
%cd gdrive/MyDrive/madlibs\ \(1\)

/content/gdrive/MyDrive/madlibs (1)


In [10]:
import json
from collections import defaultdict
import time
from tqdm import tqdm
from torch.utils.data import Dataset
from prettytable import PrettyTable
from itertools import islice
import numpy as np
from utils import JSONBatcher, JSONDataset

In [22]:
import json
import time
from tqdm import tqdm
from prettytable import PrettyTable
import numpy as np


class JSONBatchedBenchmark:
    def __init__(self, dataset_file):
        self.dataset_file = dataset_file
        self.dataset = self.load_dataset(dataset_file)

    def load_dataset(self, dataset_file):
        with open(dataset_file, "r") as f:
            dataset = [json.loads(line) for line in f.readlines()]
        return dataset

    def has_matching_schema(self, output, target):

        if type(output) is not type(target):
          return False

        output_keys = output.keys()
        target_keys = target.keys()

        if output_keys != target_keys:
          return False

        else:
          for key in output_keys:
            if type(output[key]) is dict:
              if not self.has_matching_schema(output[key], target[key]):
                return False

        return True

    def generate_prompt(self, passage, schema):
        user_message = f"""{passage}
        From the above passage, extract the following schema: {schema}

        Only output JSON with the allowed types."""

        prompt = f"""<s><<SYS>>You only respond in JSON. You do not add text before. You do not add text after. Only JSON. <</SYS>>[INST] {user_message} [/INST]"""
        return prompt

    def run(self, generate, batch_sizes, **kwargs):
        evals = []

        #dataset generator object from raw JSON file
        batcher = JSONBatcher(self.dataset_file)
        data, schemas = batcher.get_dataset(self.generate_prompt)

        #Initialize Hugging Face Dataset object
        dataset = JSONDataset(data)

        for batch_size in batch_sizes:
          outputs = []
          run_times = []

          start_time = time.time()

          for out in tqdm(generate(dataset, batch_size = batch_size, **sampling_params)):
              time_taken = round(time.time() - start_time, 3)
              run_times.append(time_taken)
              outputs.append(out)
              start_time = time.time()

          for output, run_time, schema in zip(outputs, run_times, schemas):
              evaluation = {}

              result = output[0]["generated_text"].strip()
              result = result.replace("\'", "\"")

              evaluation["generation"] = result
              evaluation["time_taken"] = time_taken

              # check if result is valid JSON
              try:
                  json_result = json.loads(result)
                  evaluation["is_valid"] = True

                  # check if result matches schema
                  # JSON might have erroneous keys
                  evaluation["matches_schema"] = self.has_matching_schema(json_result, schema)
                  evaluation["error_type"] = None
              except ValueError:
                  evaluation["is_valid"] = False
                  evaluation["matches_schema"] = False

                  if result[0] != "{":
                      evaluation["error_type"] = "prefix"
                  elif result[-1] != "}":
                      evaluation["error_type"] = "suffix"
                  else:
                      evaluation["error_type"] = "invalid"

              evaluation["batch_size"] = batch_size
              evals.append(evaluation)

        return evals

    def print(self, results, show_generation=False):
        table = PrettyTable()

        # Define the table columns
        table.field_names = [
            "Valid (✅/❌)",
            "Matches Schema (✅/❌)",
            "Batch Size",
            "Time (s)",
            "Error",
        ]
        if show_generation:
            table.add_column("Generation")

        valid_counter, schema_counter, total_time = 0, 0, 0

        for result in results:
            is_valid = "✅" if result["is_valid"] else "❌"
            matches_schema = "✅" if result["matches_schema"] else "❌"
            error_type = result["error_type"]
            batch_size = result["batch_size"]

            valid_counter += result["is_valid"]
            schema_counter += result["matches_schema"]
            total_time += result["time_taken"]

            row = [is_valid, matches_schema, batch_size, result["time_taken"], error_type]
            if show_generation:
                row.append(result["generation"])

            table.add_row(row)

        valid_accuracy = valid_counter / len(results)
        schema_accuracy = schema_counter / len(results)
        average_time = round(total_time / len(results), 3)

        table.add_row(["-", "-", "-", "-", "-"])
        table.add_row(
            [
                f"Accuracy: {valid_accuracy}",
                f"Accuracy: {schema_accuracy}",
                "-",
                f"Average: {average_time}",
                "-",
            ]
        )

        print(table)


In [13]:
# load model
from transformers import AutoTokenizer
import transformers
import torch

model = "NousResearch/Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)

tokenizer_config.json:   0%|          | 0.00/746 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/21.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/435 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/583 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/179 [00:00<?, ?B/s]



In [14]:
sampling_params = {
    "num_return_sequences": 1,
    "eos_token_id": tokenizer.eos_token_id,
    "max_length": 512,
    "return_full_text": False
}

In [23]:
eval_harness = JSONBatchedBenchmark("jsonbench (1).jsonl")
outputs = eval_harness.run(pipeline, batch_sizes = [2,4], **sampling_params)

31it [02:03,  3.98s/it]
31it [01:16,  2.45s/it]


In [24]:
eval_harness.print(outputs)

+------------------------------+------------------------------+------------+----------------+--------+
|        Valid (✅/❌)         |    Matches Schema (✅/❌)    | Batch Size |    Time (s)    | Error  |
+------------------------------+------------------------------+------------+----------------+--------+
|              ❌              |              ❌              |     2      |     18.242     | prefix |
|              ❌              |              ❌              |     2      |     18.242     | prefix |
|              ❌              |              ❌              |     2      |     18.242     | prefix |
|              ❌              |              ❌              |     2      |     18.242     | prefix |
|              ❌              |              ❌              |     2      |     18.242     | prefix |
|              ❌              |              ❌              |     2      |     18.242     | prefix |
|              ❌              |              ❌              |     2      |     18.242    

## Types of issues
1. Incorrect output (we can't solve this)
2. Undesired prefix
3. Undesired suffix
4. Invalid JSON

2-4 are solved by constrained sampling.