In [1]:
import json
from pathlib import Path

# Optional run parameters: read them from /content/params.json when that file
# exists, otherwise fall back to an empty dict so later `.get(...)` calls work.
params_path = Path("/content/params.json")
if not params_path.is_file():
    params = {}
else:
    with params_path.open("r", encoding="UTF-8") as params_file:
        params = json.load(params_file)

params

{'hf_dataset': 'weaviate/WithoutRetrieval-SchemaSplit-Test-80',
 'prompt_template': '## Instruction\nYour task is to write GraphQL for the Natural Language Query provided. Use the provided API reference and Schema to generate the GraphQL. The GraphQL should be valid for Weaviate.\n\nOnly use the API reference to understand the syntax of the request.\n\n## Natural Language Query\n{nlcommand}\n\n## Schema\n{schema}\n\n## API reference\n{apiRef}\n\n## Answer\n```graphql\n',
 'push_to_hub': 'substratusai/wgql-WithRetrieval-SchemaSplit-Train-80'}

In [2]:
from datasets import load_dataset

# Use the Hugging Face dataset named in params when one is given; otherwise
# load whatever JSON / JSON Lines files are present under /content/data/.
hf_dataset = params.get("hf_dataset")
if not hf_dataset:
    dataset = load_dataset("json", data_files="/content/data/*.json*")
else:
    dataset = load_dataset(hf_dataset)

dataset

Downloading readme:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/3.90M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input', 'output', 'nlcommand', 'apiRef', 'apiRefPath', 'schema', 'schemaPath'],
        num_rows: 825
    })
})

In [3]:
# Fallback prompt template, used only when params has no "prompt_template" key.
# The {nlcommand}, {schema} and {apiRef} placeholders are filled from a dataset
# row by format_map below.
# NOTE(review): this literal starts with a line break right after the opening
# quotes, so the fallback prompt begins with a newline.
default_prompt = """
## Instruction
Your task is to write GraphQL for the Natural Language Query provided. Use the provided API reference and Schema to generate the GraphQL. The GraphQL should be valid for Weaviate.

Only use the API reference to understand the syntax of the request.

## Natural Language Query
{nlcommand}

## Schema
{schema}

## API reference
{apiRef}

## Answer
```graphql
"""

# Prefer the template supplied in params; fall back to default_prompt.
prompt = params.get("prompt_template", default_prompt)
# Preview the rendered prompt for the first row of the train split.
print(prompt.format_map(dataset["train"][0]))

## Instruction
Your task is to write GraphQL for the Natural Language Query provided. Use the provided API reference and Schema to generate the GraphQL. The GraphQL should be valid for Weaviate.

Only use the API reference to understand the syntax of the request.

## Natural Language Query
```text
Show me the event name, description, year, significant impact, and the countries involved with their population for the top 10 historical events.
```

## Schema
{
"classes": [
{
"class": "HistoricalEvent",
"description": "Information about historical events",
"vectorIndexType": "hnsw",
"vectorizer": "text2vec-transformers",
"properties": [
{
"name": "eventName",
"dataType": ["text"],
"description": "Name of the historical event"
},
{
"name": "description",
"dataType": ["text"],
"description": "Detailed description of the event"
},
{
"name": "year",
"dataType": ["int"],
"description": "Year the event occurred"
},
{
"name": "hadSignificantImpact",
"dataType": ["boolean"],
"description": "Whethe

In [4]:
import transformers
import torch
import sys
from transformers import AutoTokenizer, AutoModelForCausalLM

model_path = "/content/model/"
# The model to evaluate is the repo named by the "push_to_hub" parameter;
# a missing key raises KeyError here.
model_id = params["push_to_hub"]
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load the weights in bfloat16 with automatic device placement.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    use_flash_attention_2=True,
)

Downloading (…)okenizer_config.json:   0%|          | 0.00/1.04k [00:00<?, ?B/s]

Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/21.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/648 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00002.bin:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00002.bin:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/183 [00:00<?, ?B/s]

In [5]:
# Shell escape: print GPU status (memory in use, utilisation) after loading the model.
! nvidia-smi

Fri Oct 20 16:31:36 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA L4           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   49C    P0    29W /  72W |  13610MiB / 23034MiB |      3%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [6]:
# Device that tokenized inputs are moved to in the generation cells below.
device = "cuda"
# Display the generation defaults that ship with the loaded model.
model.generation_config

GenerationConfig {
  "bos_token_id": 1,
  "do_sample": true,
  "eos_token_id": 2,
  "max_length": 4096,
  "pad_token_id": 0,
  "temperature": 0.6,
  "top_p": 0.9
}

In [7]:
# Token IDs the tokenizer produces when encoding a bare ``` (no special tokens added).
stop_ids = torch.LongTensor(tokenizer.encode("```", add_special_tokens=False))
# Author's note: these stop_ids are not reliable, because ``` can be produced
# by more than one token-ID sequence. The note names
# tensor([13940, 28832], device='cuda:0') as the sequence the model normally
# generates for it.
# NOTE(review): in the cells visible here stop_ids is only printed, never
# passed to generate(); confirm whether it is still needed.
print(stop_ids)
# Probe whether specific candidate IDs decode to exactly ```.
print(tokenizer.decode([8789]) == "```")
print(tokenizer.decode([13940, 28832]) == "```")
# Round-trip: encode ``` and decode it again.
print(tokenizer.decode(tokenizer.encode("```", add_special_tokens=False)))

tensor([7521])
False
False
```


In [8]:
from transformers import StoppingCriteria, StoppingCriteriaList

class BacktickStoppingCriteria(StoppingCriteria):
    """Stop generation once the generated text closes the code fence with ```.

    The check works on decoded text rather than token IDs, so it does not
    depend on which token-ID sequence the model used to emit the backticks.
    Surrounding whitespace is ignored and a suffix match is used, so a tail
    such as "\\n```" or "}```" also stops generation instead of running on
    until ``max_new_tokens``.

    Args:
        tokenizer: Tokenizer used to decode the most recent token IDs.
        lookback: Number of trailing tokens decoded together when looking for
            the fence. Defaults to 2.

    Raises:
        ValueError: If ``lookback`` is smaller than 1.
    """

    def __init__(self, tokenizer, lookback: int = 2):
        if lookback < 1:
            raise ValueError("lookback must be >= 1")
        self.tokenizer = tokenizer
        self.lookback = lookback

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the first sequence in the batch ends with ```.

        Only ``input_ids[0]`` is inspected, so this criterion is meant for
        batch size 1.
        """
        sequence = input_ids[0]
        # Decode the last few tokens together: the fence may be split across tokens.
        tail_text = self.tokenizer.decode(sequence[-self.lookback:])
        # Also check the last token alone, in case joint decoding inserts
        # characters between the pieces.
        last_text = self.tokenizer.decode(sequence[-1])
        return tail_text.strip().endswith("```") or last_text.strip() == "```"


stopping_criteria = StoppingCriteriaList([BacktickStoppingCriteria(tokenizer)])

In [9]:
# Pin the special-token IDs on both the model config and the tokenizer so the
# two agree: BOS = 1, EOS = 2, PAD = 0.
model.config.bos_token_id = 1
tokenizer.bos_token_id = 1
model.config.eos_token_id = 2
tokenizer.eos_token_id = 2
model.config.pad_token_id = 0
tokenizer.pad_token_id = 0

In [10]:
%%time
import torch


device = "cuda"
# Render the prompt for the first train row, tokenize it as a batch of one,
# and move the tensors to the GPU.
model_inputs = tokenizer([prompt.format_map(dataset["train"][0])],
                         return_tensors="pt").to(device)

# Generate at most 300 new tokens; stopping_criteria can end generation earlier.
generated_ids = model.generate(**model_inputs,
                               max_new_tokens=300,
                               stopping_criteria=stopping_criteria)

# Decode the whole returned sequence (prompt followed by the completion).
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))


## Instruction
Your task is to write GraphQL for the Natural Language Query provided. Use the provided API reference and Schema to generate the GraphQL. The GraphQL should be valid for Weaviate.

Only use the API reference to understand the syntax of the request.

## Natural Language Query
```text
Show me the event name, description, year, significant impact, and the countries involved with their population for the top 10 historical events.
```

## Schema
{
"classes": [
{
"class": "HistoricalEvent",
"description": "Information about historical events",
"vectorIndexType": "hnsw",
"vectorizer": "text2vec-transformers",
"properties": [
{
"name": "eventName",
"dataType": ["text"],
"description": "Name of the historical event"
},
{
"name": "description",
"dataType": ["text"],
"description": "Detailed description of the event"
},
{
"name": "year",
"dataType": ["int"],
"description": "Year the event occurred"
},
{
"name": "hadSignificantImpact",
"dataType": ["boolean"],
"description": "Whethe

In [11]:
print(model_inputs["input_ids"].shape)
# Number of prompt tokens; slicing from this index keeps only the newly generated tokens.
input_length = model_inputs["input_ids"].shape[1]
# NOTE: str.strip("```") treats its argument as a set of characters, so it
# removes any run of backticks from both ends of the decoded text.
print(tokenizer.decode(generated_ids[0][input_length:], skip_special_tokens=True).strip("```"))


torch.Size([1, 534])
{
  Get {
    HistoricalEvent (
      limit: 10
    ) {
      eventName
      description
      year
      hadSignificantImpact
      involvedCountries {
        ... on Country {
          countryName
          population
        }
      }
    }
  }
}



In [12]:
import json
# Run the model over every row of the train split and record its completion.
# Each row is kept in `entries` (used by the evaluation cells) and is also
# written to `output_path` as one JSON object per line.
dataset_size = len(dataset["train"])
output_path = "/content/artifacts/test-output.json"
entries = []

# Create the artifacts directory up front so opening the file cannot fail
# with FileNotFoundError.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)

print(f"Running inference for {dataset_size} entries in dataset")
# Open once in write mode: re-running this cell replaces the previous results
# instead of appending duplicate rows to the uploaded file.
with open(output_path, "w", encoding="utf-8") as file:
    for i in range(dataset_size):
        print(f"entry {i+1} of {dataset_size}")
        entry = dataset["train"][i]
        model_inputs = tokenizer([prompt.format_map(entry)],
                                 return_tensors="pt").to(device)

        generated_ids = model.generate(**model_inputs,
                                       max_new_tokens=300,
                                       stopping_criteria=stopping_criteria)
        # Skip the prompt tokens so only the completion is decoded.
        input_length = model_inputs["input_ids"].shape[1]
        output = tokenizer.decode(generated_ids[0][input_length:], skip_special_tokens=True)
        # strip("```") removes any leading/trailing backticks (character-set semantics).
        entry["modelOutput"] = output.strip("```")
        entries.append(entry)

        json.dump(entry, file)
        file.write("\n")
        # Flush per row so partial results survive an interrupted run.
        file.flush()

Running inference for 825 entries in dataset
entry 1 of 825
entry 2 of 825
entry 3 of 825
entry 4 of 825
entry 5 of 825
entry 6 of 825
entry 7 of 825
entry 8 of 825
entry 9 of 825
entry 10 of 825
entry 11 of 825
entry 12 of 825
entry 13 of 825
entry 14 of 825
entry 15 of 825
entry 16 of 825
entry 17 of 825
entry 18 of 825
entry 19 of 825
entry 20 of 825
entry 21 of 825
entry 22 of 825
entry 23 of 825
entry 24 of 825
entry 25 of 825
entry 26 of 825
entry 27 of 825
entry 28 of 825
entry 29 of 825
entry 30 of 825
entry 31 of 825
entry 32 of 825
entry 33 of 825
entry 34 of 825
entry 35 of 825
entry 36 of 825
entry 37 of 825
entry 38 of 825
entry 39 of 825
entry 40 of 825
entry 41 of 825
entry 42 of 825
entry 43 of 825
entry 44 of 825
entry 45 of 825
entry 46 of 825
entry 47 of 825
entry 48 of 825
entry 49 of 825
entry 50 of 825
entry 51 of 825
entry 52 of 825
entry 53 of 825
entry 54 of 825
entry 55 of 825
entry 56 of 825
entry 57 of 825
entry 58 of 825
entry 59 of 825
entry 60 of 825
entr

Upload the test dataset, together with the model output, to the original Hugging Face model repo.

In [13]:
from huggingface_hub import HfApi
# Upload only when a target repo is configured through the "push_to_hub" parameter.
repo_id = params.get("push_to_hub")
if repo_id:
    hf_api = HfApi()
    # Upload the inference results under their bare file name at the repo root.
    hf_api.upload_file(
            path_or_fileobj=Path(output_path),
            path_in_repo=Path(output_path).name,
            repo_id=repo_id,
    )
    # Upload the notebook at this path as well, but only if it exists.
    logs_path = Path("/content/artifacts/eval.ipynb")
    if logs_path.exists():
        hf_api.upload_file(
            path_or_fileobj=logs_path,
            path_in_repo=logs_path.name,
            repo_id=repo_id,
        )

## Execute the model output on a live Weaviate cluster

In [37]:
import weaviate
from weaviate.embedded import EmbeddedOptions

# Start an embedded Weaviate instance with default options; `client` is the
# handle the evaluation helpers below use for schema and query calls.
client = weaviate.Client(
    embedded_options=EmbeddedOptions()
)

Started /root/.cache/weaviate-embedded: process ID 7289


{"action":"startup","default_vectorizer_module":"none","level":"info","msg":"the default vectorizer modules is set to \"none\", as a result all new schema classes without an explicit vectorizer setting, will use this vectorizer","time":"2023-10-21T00:13:16Z"}
{"action":"startup","auto_schema_enabled":true,"level":"info","msg":"auto schema enabled setting is set to \"true\"","time":"2023-10-21T00:13:16Z"}
{"action":"grpc_startup","level":"info","msg":"grpc server listening at [::]:50051","time":"2023-10-21T00:13:16Z"}
{"action":"restapi_management","level":"info","msg":"Serving weaviate at http://127.0.0.1:6666","time":"2023-10-21T00:13:16Z"}


In [38]:
# Print the schema file referenced by the first train row; IPython substitutes
# the {...} expression into the shell command before running it.
! cat ToySchemas/{dataset["train"][0]["schemaPath"]}
dataset

{
    "classes": [
      {
        "class": "HistoricalEvent",
        "description": "Information about historical events",
        "vectorIndexType": "hnsw",
        "vectorizer": "text2vec-transformers",
        "properties": [
          {
            "name": "eventName",
            "dataType": ["text"],
            "description": "Name of the historical event"
          },
          {
            "name": "description",
            "dataType": ["text"],
            "description": "Detailed description of the event"
          },
          {
            "name": "year",
            "dataType": ["int"],
            "description": "Year the event occurred"
          },
          {
            "name": "hadSignificantImpact",
            "dataType": ["boolean"],
            "description": "Whether the event had a significant impact"
          },
          {
            "name": "involvedCountries",
            "dataType": ["Country"],
            "description": "Countries involved in the e

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


DatasetDict({
    train: Dataset({
        features: ['input', 'output', 'nlcommand', 'apiRef', 'apiRefPath', 'schema', 'schemaPath'],
        num_rows: 825
    })
})

In [42]:
from typing import Dict
import json
def json_reader(file_path):
    """Parse the JSON document stored at ``file_path`` and return it.

    The file is read as UTF-8 with ``errors='replace'``, so undecodable bytes
    are substituted instead of raising a decode error.
    """
    with open(file_path, mode='r', encoding='utf-8', errors='replace') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def remove_vectorizer(classes: list[Dict]) -> list[Dict]:
    """Return copies of the class definitions without their "vectorizer" key.

    Each returned dict is a new shallow copy, so the dicts passed in are left
    unmodified. Classes that have no "vectorizer" key are copied unchanged.

    Args:
        classes: Schema class definitions (dicts), each optionally carrying a
            top-level "vectorizer" entry.

    Returns:
        A new list, in the same order, of class dicts without "vectorizer".
    """
    return [
        {key: value for key, value in class_def.items() if key != "vectorizer"}
        for class_def in classes
    ]

def didItExecute(schemaPath, modelOutput):
    """Run ``modelOutput`` as a raw GraphQL query against a freshly reset schema.

    All existing schema classes are deleted first. The schema file
    ``ToySchemas/<schemaPath>`` is then loaded, passed through
    ``remove_vectorizer`` and created on the client.

    Returns:
        Whatever ``client.query.raw`` returns for ``modelOutput``.
    """
    client.schema.delete_all()
    toy_schema = json_reader(f'ToySchemas/{schemaPath}')
    toy_schema["classes"] = remove_vectorizer(toy_schema["classes"])
    client.schema.create(toy_schema)
    return client.query.raw(modelOutput)

# Smoke test: run the first recorded model output against its own schema and
# display the response.
sample = entries[0]
didItExecute(sample["schemaPath"], sample["modelOutput"])

{"action":"hnsw_vector_cache_prefill","count":1000,"index_id":"historicalevent_3GCGHMH4c4Em","level":"info","limit":1000000000000,"msg":"prefilled vector cache","time":"2023-10-21T00:18:03Z","took":90086}
{"action":"hnsw_vector_cache_prefill","count":1000,"index_id":"country_DHJOHEyNcXoH","level":"info","limit":1000000000000,"msg":"prefilled vector cache","time":"2023-10-21T00:18:03Z","took":86369}


{'data': {'Get': {'HistoricalEvent': []}}}

In [None]:
%%capture

# Execute every recorded model output and sort the examples into successes and
# failures. A response that contains an "errors" key counts as a failure, and
# failures are tallied per schema file and per API reference file.
counter = 1
successfulQueries = []
failedQueries = []
failedAPIsCount = {}
failedSchemasCount = {}
for idx, example in enumerate(entries):
    modelQuery = example["modelOutput"]
    weaviateResponse = didItExecute(example["schemaPath"], modelQuery)
    failed = "errors" in weaviateResponse

    if not failed:
        successfulQueries.append(example)
        continue

    print("FAILED! FAILED! FAILED! \n")
    print(idx)
    failedQueries.append(example)
    # Update failed Schema tracker
    failedSchemasCount[example["schemaPath"]] = failedSchemasCount.get(example["schemaPath"], 0) + 1
    # Update API tracker
    failedAPIsCount[example["apiRefPath"]] = failedAPIsCount.get(example["apiRefPath"], 0) + 1

In [45]:
# Summarise the evaluation: overall counts first, then the failure tallies per
# API reference file and per schema file.
print(f"{len(successfulQueries)} Queries successfully executed!")
print(f"{len(failedQueries)} Queries failed to execute!")
print("FAILED API Count \n")
print(failedAPIsCount)
# Header typo fixed: previously printed "FAILE SCHEMA COUNT".
print("FAILED SCHEMA COUNT \n")
print(failedSchemasCount)

493 Queries successfully executed!
332 Queries failed to execute!
FAILED API Count 

{'get-hybrid-explainScore.txt': 17, 'get-hybrid-with-autocut.txt': 20, 'get-hybrid-alpha.txt': 20, 'get-nearText.txt': 20, 'get-nearText-with-distance.txt': 20, 'aggregate-nearText-with-distance.txt': 8, 'get-hybrid-alpha-properties.txt': 20, 'aggregate-nearText-with-limit.txt': 14, 'get-hybrid-weight-properties.txt': 20, 'get-hybrid.txt': 20, 'get-where-with-search.txt': 20, 'get-hybrid-with-where.txt': 20, 'get-hybrid-with-limit.txt': 20, 'get-hybrid-fusionType.txt': 20, 'get-nearText-with-autocut.txt': 20, 'get-nearText-with-where.txt': 20, 'get-nearText-with-limit.txt': 20, 'get-reranking-vector-search.txt': 13}
FAILE SCHEMA COUNT 

{'historicalevent.json': 16, 'musicalinstrument.json': 17, 'weatherstation.json': 17, 'AIModels.json': 17, 'outdoorgear.json': 17, 'startups.json': 18, 'videogame.json': 18, 'books.json': 17, 'craftbeer.json': 17, 'chemicals.json': 18, 'pharmaceuticals.json': 15, 'filmf