In [3]:
import os
from dotenv import load_dotenv
import pandas as pd
import json

from predibase import PredibaseClient
from neo4j import GraphDatabase, RoutingControl

In [4]:
# Populate os.environ from a local .env file, if one exists. By default
# load_dotenv does not override variables that are already set in the
# environment, so shell-provided settings still take precedence.
load_dotenv()

neo4j_uri = os.environ.get('NEO4J_URI')
neo4j_username = os.environ.get('NEO4J_USERNAME')
neo4j_password = os.environ.get('NEO4J_PASSWORD')
# (user, password) tuple passed as `auth` to GraphDatabase.driver.
neo4j_auth = (neo4j_username, neo4j_password)

In [5]:
# OpenAI API key from the environment; None when the variable is unset.
# NOTE(review): not referenced by any of the cells shown here.
open_ai_api_key = os.getenv('OPENAI_API_KEY')

In [6]:
# Clients for the two services used in this notebook:
#   pc     - Predibase client (fine-tuning and serving). It is constructed with no
#            arguments, so presumably it reads its credentials from the
#            environment -- confirm.
#   driver - Neo4j driver, authenticated with the (user, password) tuple above.
pc = PredibaseClient()
driver = GraphDatabase.driver(neo4j_uri, auth=neo4j_auth)

In [7]:
# Helper to run a Cypher query through the Neo4j driver and return plain dicts.

def execute_query(driver, query, params=None):
    """Run a Cypher query in a managed read transaction and return its rows.

    Args:
        driver: Neo4j driver, as created by ``GraphDatabase.driver``.
        query: Cypher query string.
        params: Optional mapping of query parameters. It is passed to
            ``tx.run`` so values are bound by the server rather than
            string-formatted into the query. Defaults to no parameters.

    Returns:
        A list with one dict per result record (``record.data()``).
    """
    with driver.session() as session:
        def _execute(tx):
            result = tx.run(query, params)
            # Materialize the records inside the transaction function, while
            # the transaction is still open.
            return [record.data() for record in result]
        return session.execute_read(_execute)

#### Load and preview training data

This uses the training data generated synthetically by the code in the `generate_synthetic_data` directory of this repository.

In [8]:
# Synthetic training set. Its columns are instruction, input and output, plus an
# unnamed leading index column (shown as `Unnamed: 0` in the preview below).
training_data = 'training-data-twitter.csv'
df = pd.read_csv(filepath_or_buffer=training_data)

In [9]:
# Size of the training set.
print("Number of rows: {}".format(df.shape[0]))

Number of rows: 1838


In [10]:
# Preview the first five training examples.
display(df.head(n=5))

Unnamed: 0,instruction,input,output
0,"Given this schema, write a Cypher query that r...",I am neo4j. Excluding common stop words in Eng...,MATCH (n:User {screen_name: 'neo4j'})-[:POSTS]...
1,"Given this schema, write a Cypher query that r...",I am neo4j. What are the most common words in ...,MATCH (n:User {screen_name: 'neo4j'})-[:POSTS]...
2,"Given this schema, write a Cypher query that r...",I am neo4j. Identify frequently used words in ...,MATCH (n:User {screen_name: 'neo4j'})-[:POSTS]...
3,"Given this schema, write a Cypher query that r...",I am neo4j. Show the top words from my favorit...,MATCH (n:User {screen_name: 'neo4j'})-[:POSTS]...
4,"Given this schema, write a Cypher query that r...",I am neo4j. List the significant words in twee...,MATCH (n:User {screen_name: 'neo4j'})-[:POSTS]...


#### Upload training data to Predibase

In [11]:
# Upload the training CSV to Predibase. No dataset name is passed, so Predibase
# derives one from the file name (see the output below). The `_1_2_3_4_5_6`
# suffix suggests each re-run of this cell uploads another copy.
# NOTE(review): consider passing an explicit name if the SDK version supports it.
dataset = pc.upload_dataset(training_data)

Dataset name was not explicitly provided. Defaulting to: training-data-twitter_1_2_3_4_5_6
Uploading dataset...
Dataset uploaded.


#### Configure and start the training job

In [12]:
# Template used to build the prompt for each training example. The
# {instruction} and {input} placeholders correspond to the dataset columns of
# the same name; the `output` column is the fine-tuning target.
# Keep the lines flush-left: any indentation inside the triple-quoted string
# would become literal leading whitespace in every prompt the model is trained on.

prompt_template = """Below is an instruction that describes a task, paired with an input
that may provide further context. Write a response that appropriately
completes the request.

### Instruction: {instruction}

### Input: {input}

### Response:
"""

In [13]:
# Base model to fine-tune, pulled from Hugging Face.
llm = pc.LLM("hf://meta-llama/Llama-2-7b-hf")

# Fine-tuning settings: prompts are built from `prompt_template`, and the model
# learns to produce the `output` column of the uploaded dataset.
finetune_options = dict(
    prompt_template=prompt_template,
    target="output",
    dataset=dataset,
)
job = llm.finetune(**finetune_options)

# Block on the job; its progress and metrics are reported in the cell output.
model = job.get()

✓ Queued 0:00:31   
✓ Preprocessing 0:00:23   
                                    

┌──────────┬──────────┬──────────────────┬──────────────────────────┬──────────┬──────────┬──────────┐
│  epochs  [0m│   time   [0m│     feature      [0m│          metric          [0m│  train   [0m│   val    [0m│   test   [0m│
├──────────┼──────────┼──────────────────┼──────────────────────────┼──────────┼──────────┼──────────┤
│ 744/5241 steps ■■■■■■■■■■■■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□ │

#### Load new model with adapter

Wait for training in the previous step to finish before setting up the deployment.

In [None]:
# Llama-2-7b base deployment on Predibase that will serve the adapter.
BASE_DEPLOYMENT_URI = "pb://deployments/llama-2-7b"
base_deployment = pc.LLM(BASE_DEPLOYMENT_URI)

# Attach the newly fine-tuned `model` as an adapter on top of the base deployment.
adapter_deployment = base_deployment.with_adapter(model)

#### Test the new model 

In [None]:
# Introspect the graph schema from Neo4j and serialize it to JSON so it can be
# embedded in the prompt.
query = "CALL apoc.meta.graph()"
meta_graph = execute_query(driver, query)
meta_graph_str = json.dumps(meta_graph)

# NOTE(review): this instruction has one space before "Schema:", while the later
# test cell uses two ("for.  Schema:"). Confirm which one matches the training
# data exactly.
result = adapter_deployment.prompt(
    {
        "instruction": f"Given this schema, write a Cypher query that returns the data I am looking for. Schema:  {meta_graph_str}",
        "input": "I am neo4j. Find the hashtags used in my tweets that have the most favourites.",
    },
    max_new_tokens=256,
)

# Show the generated Cypher; without this the test cell produces no visible output.
print(result.response)

#### Deploy new model 

In [None]:
# Deploy the fine-tuned model and keep a handle to the deployment for prompting.
# NOTE(review): "llama-2-7b" is also the name of the base deployment used above
# (pb://deployments/llama-2-7b). Confirm this does not clash; a distinct name
# (e.g. one specific to this Cypher model) is probably intended.
finetuned_llm = model.deploy("llama-2-7b").get()

If you have already fine-tuned the model and only want to prompt it:
- skip (or comment out) the deploy cell above;
- enable the cell below, setting the model name to your own fine-tuned model.

In [None]:
# Alternative to the training/deploy cells above: reuse an adapter that has
# already been fine-tuned. It is disabled by default so that "Run All" does not
# replace the `model` just trained above with a different one.
USE_EXISTING_MODEL = False
# NOTE(review): this name looks like the Predibase Code Alpaca tutorial model,
# not the Cypher model trained in this notebook. Set it to your own model's name.
EXISTING_MODEL_NAME = "Llama-2-7b-hf-code_alpaca_800"

if USE_EXISTING_MODEL:
    # Specify the adapter to use, which is the model you have already fine-tuned.
    model = pc.get_model(EXISTING_MODEL_NAME)
    adapter_deployment = base_deployment.with_adapter(model)
    # The next test cell prompts `finetuned_llm`. Point it at this adapter so
    # that cell also works when the deploy cell above is skipped.
    finetuned_llm = adapter_deployment

#### Test newly deployed model

In [None]:
# Ask the deployed fine-tuned model for a Cypher query, embedding the graph
# schema introspected earlier, then show what it generated.
question = {
    "instruction": f"Given this schema, write a Cypher query that returns the data I am looking for.  Schema:  {meta_graph_str}",
    "input": "I am neo4j. How many of my tweets did 'nsmith_piano' reply to?",
}
result = finetuned_llm.prompt(question, max_new_tokens=256)

print(result.response)

#### Call Neo4j with the query

In [None]:
# Run the Cypher generated by the fine-tuned model against Neo4j. `query` still
# holds the schema-introspection call ("CALL apoc.meta.graph()"), so the model's
# response must be used explicitly. execute_query runs inside a read
# transaction, so a generated write query should be rejected rather than
# modifying the graph.
generated_query = result.response.strip()
neo4j_result = execute_query(driver, generated_query)

print(neo4j_result)