# 03. Training Data Preparation

## Setup

In [None]:
import json
from pprint import pprint
import os
import sys
import yaml

from langchain_sambanova import ChatSambaStudio

# The notebook runs from inside the kit; its parent is the kit directory and
# the kit's parent is the repository root. Putting the repo root on sys.path
# lets the shared `utils` package be imported below.
current_dir = os.getcwd()
kit_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
repo_dir = os.path.abspath(os.path.join(kit_dir, os.pardir))
sys.path.append(repo_dir)

from utils.fine_tuning.src.snsdk_wrapper import SnsdkWrapper

In [None]:
# Instantiate the SambaNova SDK SambaStudio client.
# It is used below to create the project and the endpoint and to fetch the
# endpoint details (URL / API key) for the LangChain chat model.
# NOTE(review): SnsdkWrapper() takes no arguments here, so it presumably reads
# its connection settings from its own config or environment — confirm.
sambastudio_client = SnsdkWrapper()

In [None]:
# Load the target model config. Its model name is used below to name and
# create the endpoint.
config_target_yaml = '01_config_target.yaml'

# Parse the YAML file into a dictionary. yaml.safe_load only builds standard
# YAML types (no arbitrary Python objects).
with open(config_target_yaml, 'r', encoding='utf-8') as file:
    config_target = yaml.safe_load(file)
# Plain print for the label: pprint would show the string's repr (with quotes).
print('Target model:')
pprint(config_target)

# Load the training data generation config (project, SambaStudio and file
# settings used by the cells below).
config_training_data_generation_yaml = '03_config_data_generation.yaml'

with open(config_training_data_generation_yaml, 'r', encoding='utf-8') as file:
    config_data_generation = yaml.safe_load(file)
print('Dataset creation:')
pprint(config_data_generation)

### Create Project

#### Set Project configs

In [None]:
# Project identity (name and description) copied from the `project` section
# of the data-generation config.
project = {
    key: config_data_generation['project'][key]
    for key in ('project_name', 'project_description')
}

In [None]:
# Create the SambaStudio project from the configured name and description.
# The call's return value is the cell output.
sambastudio_client.create_project(
    project_name=project['project_name'],
    project_description=project['project_description'],
)

### Create Endpoint

In [None]:
# Endpoint settings derived from the target model: the endpoint is named after
# the model (lower-cased) and uses one instance with no extra hyperparameters.
endpoint = {
    'endpoint_name': config_target['model']['model_name'].lower(),
    'endpoint_description': f"Endpoint for {config_target['model']['model_name']}",
    'endpoint_instances': 1,
    'hyperparams': {},
}

In [None]:
# Deploy the target model as an endpoint inside the project created above.
sambastudio_client.create_endpoint(
    # Where the endpoint lives and what it is called
    project_name=project['project_name'],
    endpoint_name=endpoint['endpoint_name'],
    endpoint_description=endpoint['endpoint_description'],
    # Which model (and version) is served
    model_name=config_target['model']['model_name'],
    model_version=1,
    # Deployment sizing / hardware
    instances=endpoint['endpoint_instances'],
    rdu_arch=config_data_generation['sambastudio']['rdu_arch'],
    hyperparams=endpoint['hyperparams'],
)

#### Get endpoint details

In [None]:
# Fetch the endpoint's details and keep the environment block for the
# LangChain wrapper (its SAMBASTUDIO_URL / SAMBASTUDIO_API_KEY entries are
# read when the chat model is built).
endpoint_env = sambastudio_client.get_endpoint_details(
    project_name=project['project_name'],
    endpoint_name=endpoint['endpoint_name'],
)['langchain_wrapper_env']

# Display the env without leaking credentials. Any entry whose name contains
# "API_KEY" is masked so the key never lands in saved notebook output.
# `endpoint_env` itself is left unchanged for the LangChain client.
pprint({
    name: ('********' if 'API_KEY' in str(name) and value else value)
    for name, value in endpoint_env.items()
})

### Inference

#### Test inference on a single question

In [None]:
# LangChain chat model bound to the deployed SambaStudio endpoint; used both
# for the smoke test and for generating the dataset completions.
# Decoding settings: sampling off with very low temperature / top_p.
# NOTE(review): confirm how the endpoint treats temperature/top_p when
# do_sample is False.
llm = ChatSambaStudio(
    sambastudio_url=endpoint_env.get("SAMBASTUDIO_URL"),
    sambastudio_api_key=endpoint_env.get("SAMBASTUDIO_API_KEY"),
    max_tokens=1024,
    do_sample=False,
    temperature=0.01,
    top_p=0.1,
)

In [None]:
# Sample conversation for a smoke test of the endpoint: a system prompt
# (OpenBioLLM persona) followed by one user question, as (role, content) tuples.
# NOTE(review): the system prompt has grammatical slips ("experienced from",
# "Saama AI Labs. who's"); it is presumably copied verbatim from the model's
# recommended prompt, so confirm before editing it.
messages = [
    ("system", "You are an expert and experienced from the healthcare and biomedical domain with extensive medical knowledge and practical experience. Your name is OpenBioLLM, and you were developed by Saama AI Labs. who's willing to help answer the user's query with explanation. In your explanation, leverage your deep medical expertise such as relevant anatomical structures, physiological processes, diagnostic criteria, treatment guidelines, or other pertinent medical concepts. Use precise medical terminology while still aiming to make the explanation clear and accessible to a general audience."),
    ("human", "What are the morphological characteristics of a particular organism that determine its correct genus classification in the Taxonomy system? Identify the key features that indicate the proper classification for a given species and explain how they are used in Taxonomy")
]

In [None]:
# Send the sample messages and display only the generated text (the last
# expression of a notebook cell is shown as its output).
llm.invoke(messages).content

### Create the dataset

In [None]:
# File locations from the `files` section of the data-generation config:
# the input records (each providing 'system_prompt' and 'instruction') and
# the output file that receives one generated entry per line.
input_filename, output_filename = (
    config_data_generation['files']['input_filename'],
    config_data_generation['files']['output_filename'],
)

In [None]:
from tqdm import tqdm

output_lines = []

# Read the input records. Accept either a single JSON document (an array of
# records) or JSON Lines (one JSON object per line). json.load alone would
# reject JSONL input with "Extra data".
with open(input_filename, 'r', encoding='utf-8') as f:
    raw_text = f.read()
try:
    data = json.loads(raw_text)
except json.JSONDecodeError:
    # Not one JSON document: parse as JSON Lines, skipping blank lines.
    data = [json.loads(line) for line in raw_text.splitlines() if line.strip()]
if isinstance(data, dict):
    # A one-line JSONL file (or a lone JSON object) parses as a dict; wrap it
    # so the loop below iterates records instead of the object's keys.
    data = [data]

for record in tqdm(data):
    # Extract the system prompt and the human instruction (empty if absent)
    system_msg: str = record.get('system_prompt', '')
    human_msg: str = record.get('instruction', '')

    # Build the (role, content) message list expected by the chat model
    messages = [
        ("system", system_msg),
        ("human", human_msg),
    ]

    # Generate the completion with the endpoint-backed chat model
    response = llm.invoke(messages)
    completion = response.content

    # Store a (system, prompt, completion) triple for the output file
    output_lines.append({
        "system": system_msg,
        "prompt": human_msg,
        "completion": completion,
    })

In [None]:
# Serialize each generated entry as one JSON object per line (JSON Lines)
# into the configured output file, then report how many were written.
with open(output_filename, "w", encoding="utf-8") as outfile:
    for entry in output_lines:
        outfile.write(f"{json.dumps(entry)}\n")

print(f"Successfully processed {len(output_lines)} entries and saved to {output_filename}.")