In [12]:
# Install the notebook's third-party dependencies into the runtime.
# Versions are not pinned; the captured log of this run shows transformers 4.35.2,
# accelerate 0.24.1 and jsonformer 0.12.0 being downloaded.
!pip install transformers accelerate jsonformer prettytable

Collecting transformers
  Downloading transformers-4.35.2-py3-none-any.whl (7.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m30.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m32.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jsonformer
  Downloading jsonformer-0.12.0-py3-none-any.whl (6.6 kB)
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloading huggingface_hub-0.19.3-py3-none-any.whl (311 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m311.2/311.2 kB[0m [31m34.2 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.19,>=0.14 (from transformers)
  Downloading tokenizers-0.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m76.7 MB/s[0m eta 

In [13]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from jsonformer import Jsonformer

# Load the causal LM checkpoint and its tokenizer from the Hugging Face hub id below.
# `model` and `tokenizer` are module-level globals: JSONFormerBenchmark.run reads them
# directly instead of receiving them as arguments, so this cell must run first.
print("Loading model and tokenizer...")
model_name = "databricks/dolly-v2-3b"
model = AutoModelForCausalLM.from_pretrained(model_name, use_cache=True, device_map="auto")
# NOTE(review): `use_cache` is passed to the tokenizer as well — confirm the tokenizer
# actually uses it; it looks like a model-side option.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, use_cache=True)
print("Loaded model and tokenizer")

Loading model and tokenizer...


(…)cks/dolly-v2-3b/resolve/main/config.json:   0%|          | 0.00/819 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/5.68G [00:00<?, ?B/s]

(…)v2-3b/resolve/main/tokenizer_config.json:   0%|          | 0.00/450 [00:00<?, ?B/s]

(…)/dolly-v2-3b/resolve/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

(…)-3b/resolve/main/special_tokens_map.json:   0%|          | 0.00/228 [00:00<?, ?B/s]

Loaded model and tokenizer


In [16]:
import json
import time
from tqdm import tqdm
from prettytable import PrettyTable
import numpy as np


class JSONBenchmark:
    """Benchmark harness that checks whether a text generator emits valid, correct JSON.

    The dataset is a JSONL file with one JSON object per line. From every
    record the harness reads the keys ``"passage"`` and ``"schema"`` (to build
    the prompt) and ``"extracted_data"`` (the expected answer).
    """

    def __init__(self, dataset_file):
        """Load the benchmark records from ``dataset_file`` (path to a JSONL file)."""
        self.dataset = self.load_dataset(dataset_file)

    def load_dataset(self, dataset_file):
        """Parse a JSONL file into a list holding one decoded object per line.

        Blank (whitespace-only) lines are skipped so a trailing newline or an
        empty separator line does not abort loading.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if a non-blank line is not valid JSON.
        """
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(dataset_file, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    def generate_prompt(self, passage, schema):
        """Return the full prompt asking the model to extract ``schema`` from ``passage``."""
        user_message = f"""{passage}

From the above passage, extract the following schema:
{schema}

Only output JSON with the allowed types."""
        prompt = f"""<s><<SYS>>You only respond in JSON. You do not add text before. You do not add text after. Only JSON.<</SYS>>[INST] {user_message} [/INST]"""
        return prompt

    @staticmethod
    def _classify_error(result):
        """Label why ``result`` failed to parse: "prefix", "suffix" or "invalid"."""
        # startswith/endswith are safe on an empty generation, unlike result[0].
        if not result.startswith("{"):
            return "prefix"
        if not result.endswith("}"):
            return "suffix"
        return "invalid"

    def run(self, generate, **kwargs):
        """Generate one completion per dataset record and score it.

        Args:
            generate: callable invoked as ``generate(prompt, **kwargs)``; its
                return value is indexed as ``[0]["generated_text"]``.
            **kwargs: forwarded unchanged to ``generate``.

        Returns:
            A list with one dict per record, containing the keys
            ``"generation"``, ``"time_taken"`` (seconds, rounded to 3 places),
            ``"is_valid"``, ``"matches_schema"`` and ``"error_type"``
            (``None`` when the generation parsed as JSON).
        """
        evals = []
        for data in tqdm(self.dataset):
            prompt = self.generate_prompt(data["passage"], data["schema"])

            # perf_counter is monotonic, so the duration cannot be skewed by
            # system clock adjustments.
            start_time = time.perf_counter()
            result = generate(prompt, **kwargs)[0]["generated_text"].strip()
            time_taken = round(time.perf_counter() - start_time, 3)

            evaluation = {"generation": result, "time_taken": time_taken}

            try:
                json_result = json.loads(result)
            except ValueError:
                # json.JSONDecodeError is a ValueError subclass.
                evaluation["is_valid"] = False
                evaluation["matches_schema"] = False
                evaluation["error_type"] = self._classify_error(result)
            else:
                evaluation["is_valid"] = True
                # Exact equality: extra or missing keys count as a mismatch.
                evaluation["matches_schema"] = json_result == data["extracted_data"]
                evaluation["error_type"] = None

            evals.append(evaluation)

        return evals

    def print(self, results, show_generation=False):
        """Print a table with one row per evaluation plus a summary row.

        Args:
            results: list of evaluation dicts as returned by ``run``.
            show_generation: when True, add a "Generation" column holding the
                raw generated text.

        An empty ``results`` list prints a table whose summary reads "n/a"
        instead of raising ``ZeroDivisionError``.
        """
        table = PrettyTable()

        field_names = [
            "Valid (✅/❌)",
            "Matches Schema (✅/❌)",
            "Time (s)",
            "Error",
        ]
        # The header must be complete before rows are added; add_column needs
        # the column data up front, so extend the header list instead.
        if show_generation:
            field_names.append("Generation")
        table.field_names = field_names

        valid_counter, schema_counter, total_time = 0, 0, 0

        for result in results:
            is_valid = "✅" if result["is_valid"] else "❌"
            matches_schema = "✅" if result["matches_schema"] else "❌"

            valid_counter += result["is_valid"]
            schema_counter += result["matches_schema"]
            total_time += result["time_taken"]

            row = [is_valid, matches_schema, result["time_taken"], result["error_type"]]
            if show_generation:
                row.append(result["generation"])

            table.add_row(row)

        total = len(results)
        if total:
            valid_accuracy = valid_counter / total
            schema_accuracy = schema_counter / total
            average_time = round(total_time / total, 3)
        else:
            valid_accuracy = schema_accuracy = average_time = "n/a"

        # Summary rows need as many cells as the header has columns.
        padding = ["-"] * (len(field_names) - 4)
        table.add_row(["-", "-", "-", "-"] + padding)
        table.add_row(
            [
                f"Accuracy: {valid_accuracy}",
                f"Accuracy: {schema_accuracy}",
                f"Average: {average_time}",
                "-",
            ]
            + padding
        )

        print(table)

In [54]:
class JSONFormerBenchmark(JSONBenchmark):
    """JSONBenchmark variant that produces each answer with ``Jsonformer``.

    The dataset's per-record ``"schema"`` mapping is converted into a
    ``{"type": "object", "properties": ...}`` structure and handed to
    ``Jsonformer`` together with the module-level ``model`` and ``tokenizer``.
    Scoring compares key structure only (see ``has_matching_schema``), not
    the generated values.
    """

    # Type names passed through unchanged by convert_schema_to_jsonformer_format.
    SUPPORTED_FIELD_TYPES = ("string", "number", "boolean", "array", "object")

    def convert_schema_to_jsonformer_format(self, schema):
        """Convert a ``{field: type-name | nested mapping}`` schema for Jsonformer.

        Nested mappings are converted recursively. A type name that is not in
        ``SUPPORTED_FIELD_TYPES`` is replaced by ``"number"``.

        Returns:
            ``{"type": "object", "properties": {field: {...}, ...}}``.
        """
        properties = {}
        for key, field in schema.items():
            if isinstance(field, dict):
                properties[key] = self.convert_schema_to_jsonformer_format(field)
            elif field in self.SUPPORTED_FIELD_TYPES:
                properties[key] = {"type": field}
            else:
                # Fallback kept from the original behaviour.
                # NOTE(review): confirm "number" is the right default for every
                # unrecognised type name in the dataset.
                properties[key] = {"type": "number"}

        return {"type": "object", "properties": properties}

    def has_matching_schema(self, output, target):
        """Return True when ``output`` and ``target`` have the same key structure.

        Both must be mappings with identical key sets at every nesting level.
        Leaf values are not compared. A nested mapping on one side paired
        with a non-mapping on the other is a mismatch.
        """
        if not isinstance(output, dict) or not isinstance(target, dict):
            return False

        if output.keys() != target.keys():
            return False

        for key, value in output.items():
            expected = target[key]
            # Recurse when either side is nested so a dict-vs-scalar pair is
            # reported as a mismatch instead of raising or passing silently.
            if isinstance(value, dict) or isinstance(expected, dict):
                if not self.has_matching_schema(value, expected):
                    return False

        return True

    def generate_prompt(self, passage):
        """Return ``passage`` followed by the instruction line that introduces the schema.

        The schema text itself is not part of the returned prompt.
        """
        prompt = f"""{passage}

From the above passage, generate information based on the following schema:"""
        return prompt

    @staticmethod
    def _classify_error(result):
        """Label why ``result`` failed to parse: "prefix", "suffix" or "invalid"."""
        if not result.startswith("{"):
            return "prefix"
        if not result.endswith("}"):
            return "suffix"
        return "invalid"

    def run(self, **kwargs):
        """Generate one Jsonformer result per dataset record and score it.

        Args:
            **kwargs: accepted so the method can be called like the base
                class's ``run``; the values are not read here.

        Returns:
            A list with one dict per processed record, containing the keys
            ``"generation"`` (the result serialised as JSON text),
            ``"time_taken"`` (seconds, rounded to 3 places), ``"is_valid"``,
            ``"matches_schema"`` and ``"error_type"``.

        Reads the module-level globals ``model`` and ``tokenizer``.
        """
        evals = []
        # NOTE(review): the slice skips the final dataset record; the reason is
        # not visible in this file — confirm it is intentional.
        for data in tqdm(self.dataset[:-1]):
            prompt = self.generate_prompt(data["passage"])
            json_schema = self.convert_schema_to_jsonformer_format(data["schema"])

            # Time construction plus generation with a monotonic clock.
            start_time = time.perf_counter()
            builder = Jsonformer(
                model=model,
                tokenizer=tokenizer,
                json_schema=json_schema,
                prompt=prompt,
            )
            result_dict = builder()
            time_taken = round(time.perf_counter() - start_time, 3)

            # Serialise with json.dumps: rewriting quotes in the Python repr
            # breaks on True/False/None and on strings that contain quotes.
            try:
                result = json.dumps(result_dict, ensure_ascii=False)
            except (TypeError, ValueError):
                # Not JSON-serialisable: keep the repr so the validity check
                # below reports the failure.
                result = str(result_dict)

            evaluation = {"generation": result, "time_taken": time_taken}

            try:
                json.loads(result)
            except ValueError:
                evaluation["is_valid"] = False
                evaluation["matches_schema"] = False
                evaluation["error_type"] = self._classify_error(result)
            else:
                evaluation["is_valid"] = True
                # Structural comparison only; generated values are not checked.
                evaluation["matches_schema"] = self.has_matching_schema(
                    result_dict, data["extracted_data"]
                )
                evaluation["error_type"] = None

            evals.append(evaluation)

        return evals


In [56]:
# Keyword arguments passed to the harness's run() call in a later cell.
# NOTE(review): JSONFormerBenchmark.run accepts **kwargs but never reads them, so
# these values currently have no effect on the Jsonformer generation — confirm intent.
sampling_params = {
    "do_sample": True,
    "top_k": 10,
    "num_return_sequences": 1,
    "eos_token_id": tokenizer.eos_token_id,
    "max_length": 512,
    "return_full_text": False
}

In [57]:
# Build the Jsonformer harness; the constructor loads the JSONL records from this
# relative path, so the file must be in the working directory.
eval_harness = JSONFormerBenchmark("jsonbench.jsonl")

In [58]:
# Run the benchmark; `outputs` holds one evaluation dict per processed record.
outputs = eval_harness.run(**sampling_params)

100%|██████████| 2/2 [00:07<00:00,  3.92s/it]


In [59]:
# Print the per-record results table followed by the accuracy / average-time summary row.
eval_harness.print(outputs)

+---------------+------------------------+----------------+-------+
| Valid (✅/❌) | Matches Schema (✅/❌) |    Time (s)    | Error |
+---------------+------------------------+----------------+-------+
|       ✅      |           ✅           |     3.123      |  None |
|       ✅      |           ✅           |     4.708      |  None |
|       -       |           -            |       -        |   -   |
| Accuracy: 1.0 |     Accuracy: 1.0      | Average: 3.916 |   -   |
+---------------+------------------------+----------------+-------+
