## Benchmarking Llama 2 JSON inference with no guardrails

In [96]:
!pip install prettytable

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting prettytable
  Downloading prettytable-3.9.0-py3-none-any.whl.metadata (26 kB)
Downloading prettytable-3.9.0-py3-none-any.whl (27 kB)
Installing collected packages: prettytable
Successfully installed prettytable-3.9.0
[0m

In [144]:
import json
import time
from tqdm import tqdm
from prettytable import PrettyTable
import numpy as np


class JSONBenchmark:
    def __init__(self, dataset_file):
        self.dataset = self.load_dataset(dataset_file)

    def load_dataset(self, dataset_file):
        with open(dataset_file, "r") as f:
            dataset = [json.loads(line) for line in f.readlines()]
        return dataset

    def generate_prompt(self, passage, schema):
        user_message = f"""{passage}
    
From the above passage, extract the following schema:
{schema}

Only output JSON with the allowed types."""
        prompt = f"""<s><<SYS>>You only respond in JSON. You do not add text before. You do not add text after. Only JSON.<</SYS>>[INST] {user_message} [/INST]"""
        return prompt

    def run(self, generate, **kwargs):
        evals = []
        for data in tqdm(self.dataset):
            evaluation = {}

            prompt = self.generate_prompt(data["passage"], data["schema"])
            start_time = time.time()
            result = generate(prompt, **kwargs)[0]["generated_text"].strip()
            time_taken = round(time.time() - start_time, 3)

            evaluation["generation"] = result
            evaluation["time_taken"] = time_taken

            # check if result is valid JSON
            try:
                json_result = json.loads(result)
                evaluation["is_valid"] = True

                # check if result matches schema
                # JSON might have erroneous keys
                schema = data["extracted_data"]
                evaluation["matches_schema"] = json_result == schema
                evaluation["error_type"] = None
            except ValueError:
                evaluation["is_valid"] = False
                evaluation["matches_schema"] = False

                if result[0] != "{":
                    evaluation["error_type"] = "prefix"
                elif result[-1] != "}":
                    evaluation["error_type"] = "suffix"
                else:
                    evaluation["error_type"] = "invalid"

            evals.append(evaluation)

        return evals

    def print(self, results, show_generation=False):
        table = PrettyTable()

        # Define the table columns
        table.field_names = [
            "Valid (✅/❌)",
            "Matches Schema (✅/❌)",
            "Time (s)",
            "Error",
        ]
        if show_generation:
            table.add_column("Generation")

        valid_counter, schema_counter, total_time = 0, 0, 0

        for result in results:
            is_valid = "✅" if result["is_valid"] else "❌"
            matches_schema = "✅" if result["matches_schema"] else "❌"
            error_type = result["error_type"]

            valid_counter += result["is_valid"]
            schema_counter += result["matches_schema"]
            total_time += result["time_taken"]

            row = [is_valid, matches_schema, result["time_taken"], error_type]
            if show_generation:
                row.append(result["generation"])

            table.add_row(row)

        valid_accuracy = valid_counter / len(results)
        schema_accuracy = schema_counter / len(results)
        average_time = round(total_time / len(results), 3)

        table.add_row(["-", "-", "-", "-"])
        table.add_row(
            [
                f"Accuracy: {valid_accuracy}",
                f"Accuracy: {schema_accuracy}",
                f"Average: {average_time}",
                "-",
            ]
        )

        print(table)


In [1]:
!pip install transformers accelerate

Collecting transformers
  Downloading transformers-4.35.1-py3-none-any.whl.metadata (123 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m123.1/123.1 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate
  Downloading accelerate-0.24.1-py3-none-any.whl.metadata (18 kB)
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloading huggingface_hub-0.19.1-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from transformers)
  Downloading regex-2023.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.9/40.9 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.15,>=0.14 (from transformers)
  Downloading tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylin

In [165]:
eval_harness = JSONBenchmark("jsonbench.jsonl")

In [6]:
# load model
from transformers import AutoTokenizer
import transformers
import torch

model = "NousResearch/Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)

Downloading (…)okenizer_config.json:   0%|          | 0.00/746 [00:00<?, ?B/s]

Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/21.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/435 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/583 [00:00<?, ?B/s]

Downloading (…)fetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Downloading (…)of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/179 [00:00<?, ?B/s]



In [136]:
sampling_params = {
    "do_sample": True,
    "top_k": 10,
    "num_return_sequences": 1,
    "eos_token_id": tokenizer.eos_token_id,
    "max_length": 512,
    "return_full_text": False
}

In [166]:
outputs = eval_harness.run(pipeline, **sampling_params)

100%|██████████| 3/3 [00:12<00:00,  4.27s/it]


In [167]:
eval_harness.print(outputs)

+---------------+------------------------+----------------+---------+
| Valid (✅/❌) | Matches Schema (✅/❌) |    Time (s)    |  Error  |
+---------------+------------------------+----------------+---------+
|       ❌      |           ❌           |     3.664      |  prefix |
|       ❌      |           ❌           |     5.871      |  prefix |
|       ❌      |           ❌           |     3.257      | invalid |
|       -       |           -            |       -        |    -    |
| Accuracy: 0.0 |     Accuracy: 0.0      | Average: 4.264 |    -    |
+---------------+------------------------+----------------+---------+


In [168]:
print(outputs)

[{'generation': 'Here is the extracted schema in JSON format:\n\n{\n"company": {"name": "Apple Inc.", "location": "California"},\n"product": {"name": "iPhone 13", "release_date": "2021-09-15"},\n"event": {"type": "online", "attendance": 1000000},\n"CEO": {"name": "Tim Cook"},\n"price": {"base_model": 699}\n}\n}', 'time_taken': 3.664, 'is_valid': False, 'matches_schema': False, 'error_type': 'prefix'}, {'generation': 'Here is the extracted schema in JSON format:\n\n{\n"acquirer": {\n"name": "Zoom Video Communications",\n"location": "San Jose",\n"CEO": "Eric Yuan"\n},\n"target": {\n"name": "Five9",\n"CEO": "Rowan Trollope",\n"business_domain": "cloud contact center software"\n},\n"transaction": {\n"type": "acquisition",\n"method": "shares",\n"announcement_date": "July 18, 2021",\n"expected_close_date": "first half of 2022"\n,"potential_impact": "speed up Zoom\'s entrance into the contact center as a service (CCaaS) market"\n}\n}', 'time_taken': 5.871, 'is_valid': False, 'matches_schema':

In [171]:
print(eval_harness.dataset[2]["extracted_data"])

{'company': {'name': 'Microsoft Corporation', 'product': {'name': 'Windows 11', 'launch_date': '2021-10-05', 'features': ['redesigned task bar', 'improved voice typing', 'simplified window management', 'access to Android apps'], 'editions': [{'name': 'business edition', 'price': 199}, {'name': 'home edition', 'price': 139}]}}}


In [170]:
print(outputs[2]["generation"].strip())

{'company': {'name': 'Microsoft Corporation', 'product': {'name': 'Windows 11', 'launch_date': '2021-10-05', 'features': ['redesigned task bar', 'improved voice typing','simplified window management', 'access to Android apps'], 'editions': [{'name': 'Business', 'price': '199'}, {'name': 'Home', 'price': '139'}]}}}


## Types of issues
1. Incorrect output (we can't solve this)
2. Undesired prefix
3. Undesired suffix
4. Invalid JSON

2-4 are solved by constrained sampling.