In [1]:
import json
from pathlib import Path

# Optional run configuration: when the params file is present its JSON content
# is used, otherwise the notebook runs with an empty mapping.
params_path = Path("/content/params.json")
params = (
    json.loads(params_path.read_text(encoding="UTF-8"))
    if params_path.is_file()
    else {}
)

params

{'hf_dataset': 'weaviate/WithoutRetrieval-SchemaSplit-Test-80',
 'push_to_hub': 'substratusai/wgql-WithRetrieval-SchemaSplit-Train-80-3epochs'}

In [2]:
from datasets import load_dataset

# Use the dataset named in params when one is given; otherwise fall back to the
# JSON files under /content/data.
hf_dataset = params.get("hf_dataset")
dataset = (
    load_dataset(hf_dataset)
    if hf_dataset
    else load_dataset("json", data_files="/content/data/*.json*")
)

dataset

Downloading readme:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/3.90M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input', 'output', 'nlcommand', 'apiRef', 'apiRefPath', 'schema', 'schemaPath'],
        num_rows: 825
    })
})

In [3]:
# Built-in prompt template. The {nlcommand}, {schema} and {apiRef} placeholders
# are filled from a dataset row via str.format_map; the template ends on an
# opened ```graphql fence so the model continues with the query body.
default_prompt = """
## Instruction
Your task is to write GraphQL for the Natural Language Query provided. Use the provided API reference and Schema to generate the GraphQL. The GraphQL should be valid for Weaviate.

Only use the API reference to understand the syntax of the request.

## Natural Language Query
{nlcommand}

## Schema
{schema}

## API reference
{apiRef}

## Answer
```graphql
"""

# A "prompt_template" entry in params overrides the built-in template.
prompt = params.get("prompt_template", default_prompt)

# Preview the rendered prompt for the first row of the train split.
first_example = dataset["train"][0]
print(prompt.format_map(first_example))


## Instruction
Your task is to write GraphQL for the Natural Language Query provided. Use the provided API reference and Schema to generate the GraphQL. The GraphQL should be valid for Weaviate.

Only use the API reference to understand the syntax of the request.

## Natural Language Query
```text
Show me the event name, description, year, significant impact, and the countries involved with their population for the top 10 historical events.
```

## Schema
{
"classes": [
{
"class": "HistoricalEvent",
"description": "Information about historical events",
"vectorIndexType": "hnsw",
"vectorizer": "text2vec-transformers",
"properties": [
{
"name": "eventName",
"dataType": ["text"],
"description": "Name of the historical event"
},
{
"name": "description",
"dataType": ["text"],
"description": "Detailed description of the event"
},
{
"name": "year",
"dataType": ["int"],
"description": "Year the event occurred"
},
{
"name": "hadSignificantImpact",
"dataType": ["boolean"],
"description": "Wheth

In [4]:
import transformers
import torch
import sys
from transformers import AutoTokenizer, AutoModelForCausalLM

model_path = "/content/model/"

# The tokenizer is loaded strictly from the local directory (no hub lookup).
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)

# Model loading options, gathered in one place so they are easy to scan.
# NOTE(review): trust_remote_code=True lets code shipped with the checkpoint run
# at load time -- only keep it for checkpoints you trust.
model_load_options = {
    "device_map": "auto",
    "trust_remote_code": True,
    "torch_dtype": torch.bfloat16,
    "use_flash_attention_2": True,
}
model = AutoModelForCausalLM.from_pretrained(model_path, **model_load_options)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Shell escape: run the NVIDIA CLI to display the GPU, driver and memory usage.
! nvidia-smi

Sun Oct 15 07:03:54 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA L4           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P0    29W /  72W |  15452MiB / 23034MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [7]:
# Device string that the generation cells pass to `.to(...)` when moving the
# tokenized inputs.
device = "cuda"
# Display the generation defaults attached to the loaded model.
model.generation_config

GenerationConfig {
  "bos_token_id": 1,
  "eos_token_id": 2
}

In [53]:
# Token ids the tokenizer produces when "```" is encoded on its own.
backtick_token_ids = tokenizer.encode("```", add_special_tokens=False)
stop_ids = torch.LongTensor(backtick_token_ids)
# NOTE: stop_ids is not a reliable stop marker -- more than one token sequence
# decodes to "```". The checks below compare the single id 8789 with the pair
# [13940, 28832], which (per the original author's note) is what the model
# normally generates for a closing fence.
print(stop_ids)
print(tokenizer.decode([8789]) == "```")
print(tokenizer.decode([13940, 28832]) == "```")
print(tokenizer.decode(backtick_token_ids))

tensor([8789])
True
True
```


In [54]:
from transformers import StoppingCriteria, StoppingCriteriaList

class BacktickStoppingCriteria(StoppingCriteria):
    """Stop generation when the sequence tail decodes to a ``` fence.

    The check is done on decoded text rather than on fixed token ids, comparing
    the decoded last two tokens and the decoded last token against "```".
    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, tokenizer):
        """Keep a reference to the tokenizer used to decode the sequence tail."""
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True once the last two tokens or the last token decode to "```"."""
        sequence = input_ids[0]
        # Two-token tail first, then the single last token; any() stops at the
        # first match, so the second decode only runs when the first one fails.
        tails = (sequence[-2:], sequence[-1])
        return any(self.tokenizer.decode(tail) == "```" for tail in tails)


# Wrapped in a list because that is the form the generate() calls receive.
stopping_criteria = StoppingCriteriaList([BacktickStoppingCriteria(tokenizer)])

In [46]:
%%time
import torch


device = "cuda"

# Render the prompt for the first train row and tokenize it as a batch of one.
first_prompt = prompt.format_map(dataset["train"][0])
model_inputs = tokenizer([first_prompt], return_tensors="pt").to(device)

# Generate at most 300 new tokens, stopping early on the closing ``` fence.
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=300,
    stopping_criteria=stopping_criteria,
)

# The decoded text contains the prompt followed by the completion.
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



## Instruction
Your task is to write GraphQL for the Natural Language Query provided. Use the provided API reference and Schema to generate the GraphQL. The GraphQL should be valid for Weaviate.

Only use the API reference to understand the syntax of the request.

## Natural Language Query
```text
Show me the event name, description, year, significant impact, and the countries involved with their population for the top 10 historical events.
```

## Schema
{
"classes": [
{
"class": "HistoricalEvent",
"description": "Information about historical events",
"vectorIndexType": "hnsw",
"vectorizer": "text2vec-transformers",
"properties": [
{
"name": "eventName",
"dataType": ["text"],
"description": "Name of the historical event"
},
{
"name": "description",
"dataType": ["text"],
"description": "Detailed description of the event"
},
{
"name": "year",
"dataType": ["int"],
"description": "Year the event occurred"
},
{
"name": "hadSignificantImpact",
"dataType": ["boolean"],
"description": "Wheth

In [47]:
# Drop the prompt tokens so that only the newly generated completion is shown.
prompt_ids = model_inputs["input_ids"]
print(prompt_ids.shape)
input_length = prompt_ids.shape[1]
completion_ids = generated_ids[0][input_length:]
completion_text = tokenizer.decode(completion_ids, skip_special_tokens=True)
# str.strip treats its argument as a set of characters: backticks are removed
# from both ends of the text.
print(completion_text.strip("```"))


torch.Size([1, 537])
{
  Get {
    HistoricalEvent (
      limit: 10
    ) {
      eventName
      description
      year
      hadSignificantImpact
      involvedCountries {
        ... on Country {
          countryName
          population
        }
      }
    }
  }
}



In [None]:
import json

# Run the model over every row of the train split and write one JSON object per
# line (the row's fields plus "modelOutput") to output_path.
dataset_size = len(dataset["train"])
output_path = "/content/artifacts/test-output.json"
entries = []  # in-memory copy of every record written below

# The artifacts directory may not exist yet on a fresh run.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)

print(f"Running inference for {dataset_size} entries in dataset")
# Open the file once in write mode. The loop always starts from the first row,
# so appending (as before) duplicated every record when the cell was re-run.
with open(output_path, "w", encoding="utf-8") as file:
    for i, entry in enumerate(dataset["train"]):
        print(f"entry {i+1} of {dataset_size}")
        model_inputs = tokenizer([prompt.format_map(entry)],
                                 return_tensors="pt").to(device)

        generated_ids = model.generate(**model_inputs,
                                       max_new_tokens=300,
                                       stopping_criteria=stopping_criteria)

        # Keep only the tokens produced after the prompt.
        input_length = model_inputs["input_ids"].shape[1]
        output = tokenizer.decode(generated_ids[0][input_length:], skip_special_tokens=True)
        # str.strip treats its argument as a set of characters: backticks are
        # removed from both ends, which drops the closing ``` fence.
        entry["modelOutput"] = output.strip("```")
        entries.append(entry)

        json.dump(entry, file)
        file.write("\n")
        # Flush after each record so partial results survive an interrupted run.
        file.flush()

In [None]:
# Upload the inference results (and the executed notebook, when present) to the
# repository named by the "push_to_hub" param; nothing happens without it.
# NOTE(review): `hf_api` is not defined in any cell visible here -- confirm it
# is created (e.g. a huggingface_hub client) before this cell runs.
repo_id = params.get("push_to_hub")
if repo_id:
    results_file = Path(output_path)
    hf_api.upload_file(
        path_or_fileobj=results_file,
        path_in_repo=results_file.name,
        repo_id=repo_id,
    )

    # The notebook log is optional and only uploaded if it exists.
    logs_path = Path("/content/artifacts/eval.ipynb")
    if logs_path.exists():
        hf_api.upload_file(
            path_or_fileobj=logs_path,
            path_in_repo=logs_path.name,
            repo_id=repo_id,
        )