In [None]:
import json
import numpy as np
import torch
from torch import cuda
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from langchain import PromptTemplate, FewShotPromptTemplate

from models.worker import Worker

def stream_jsonl(file_path, shuffle=True, rng=None):
    """Yield the JSON records of a JSON-Lines file, optionally in random order.

    Despite the name, the whole file is read into memory first, because
    shuffling needs every line up front.

    Args:
        file_path: Path to a ``.jsonl`` file (one JSON document per line).
        shuffle: If True (the default, matching previous behaviour), yield the
            records in a random order.
        rng: Optional ``numpy.random.Generator`` used for shuffling. When None,
            the legacy global ``np.random`` state is used, so
            ``np.random.seed`` controls the order.

    Yields:
        The decoded object for each non-blank line.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Explicit UTF-8: the platform default encoding is not guaranteed to be UTF-8.
    with open(file_path, 'r', encoding='utf-8') as file:
        # Drop blank lines (e.g. a stray empty line) instead of failing in json.loads.
        lines = [line for line in file if line.strip()]

    if shuffle:
        (np.random if rng is None else rng).shuffle(lines)

    for line in lines:
        yield json.loads(line)
            
# Test splits that queries are sampled from
# (presumably the Anthropic HH-RLHF helpful/harmless subsets — confirm).
file_paths = [
    './data/helpful-base/test.jsonl',
    './data/harmless-base/test.jsonl',
]

# Model identifier handed to Worker when it is constructed.
model_id = 'princeton-nlp/Sheared-LLaMA-2.7B'

In [None]:
# Build the project's Worker wrapper for the configured model id.
# NOTE(review): Worker is defined in models/worker.py (not shown). Because
# load_model() is called in a later cell, this constructor presumably does not
# load full weights yet — confirm.
worker = Worker(model_id)
# Bare last expression so the notebook renders the `model` attribute's repr.
worker.model

In [None]:
# Inspect the planned device placement; this runs before worker.load_model().
# NOTE(review): presumably forwards to accelerate.infer_auto_device_map (imported
# at the top of the notebook). If so, listing "LlamaDecoderLayer" keeps each
# decoder layer whole on a single device — confirm against models/worker.py.
worker.check_device_map(no_split_module_classes=["LlamaDecoderLayer"])

In [None]:
# Load the model through the Worker (implementation in models/worker.py, not shown);
# presumably this is where the weights are actually materialized.
worker.load_model()

In [None]:
# Few-shot prompt: an instruction prefix, a couple of hand-written Human/AI
# exchanges, then the sampled query.
# TODO: build the examples from the training split's preferred responses, and
# hold those records out of any fine-tuning data so the model isn't tuned on them.
examples = [
    {
        "query": "What is the capital of Japan?",
        "answer": "Tokyo is the capital city of Japan.",
    },
    {
        "query": "What is the main ingredient in an omelette?",
        "answer": "The main ingredient in an omelette is egg.",
    },
]

# Layout of one rendered example; the surrounding newlines space examples apart.
example_template = """
Human: {query}
AI: {answer}
"""

example_prompt = PromptTemplate(
    input_variables=["query", "answer"],
    template=example_template,
)

prefix = """
You are an AI responding to questions from a human. Try and be helpful but not harmful.
Some examples of good behaviour are:
"""

# NOTE(review): the suffix adds no "Human:" tag, so the query is expected to
# carry its own speaker tag (e.g. "Human: ..."). The trailing newline after "AI:"
# also differs from the examples' "AI: {answer}" layout — consider dropping it
# once Worker.generate_text's output handling is confirmed.
suffix = """
{query}
AI:
"""

prompt_template = FewShotPromptTemplate(
    examples=examples,
    example_prompt=example_prompt,
    prefix=prefix,
    suffix=suffix,
    input_variables=["query"],
    example_separator="\n",
)

In [None]:
# Sample one conversation from a randomly chosen test split and ask the model to
# answer its opening human turn.
data_path = file_paths[np.random.randint(len(file_paths))]
record = next(stream_jsonl(data_path))  # stream_jsonl is already a generator

# NOTE(review): assumes 'chosen' is a transcript shaped like
# "\n\nHuman: ...\n\nAssistant: ...". Cutting at the first "\n\nAssistant:"
# keeps the entire first Human turn, including multi-line questions and the
# "Human:" tag that the prompt suffix does not add itself.
query = record['chosen'].split('\n\nAssistant:', 1)[0].strip()

prompt = prompt_template.format(query=query)

response = worker.generate_text(prompt)
print(response)  # print renders newlines; a bare `response` would show the escaped repr