# 04. Data preparation

## Setup

In [None]:
from pprint import pprint
import os
import sys
from datasets import load_dataset
import yaml

# Resolve the kit directory (parent of the notebook's working directory) and
# the repository root (parent of the kit directory).
current_dir = os.getcwd()
kit_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
repo_dir = os.path.abspath(os.path.join(kit_dir, os.pardir))

# Make the repository root importable so the `utils.fine_tuning` imports resolve.
# Guard the append so re-running this cell does not stack duplicate entries.
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

from utils.fine_tuning.src import sambastudio_utils
from utils.fine_tuning.src.snsdk_wrapper import SnsdkWrapper

In [None]:
# Instantiate the SambaNova SDK SambaStudio client (SnsdkWrapper).
# The same client is reused later to list apps, upload the dataset and list datasets.
sambastudio_client = SnsdkWrapper()

In [None]:
# Load the target model config
config_target_yaml = '../01_config_target.yaml'

# Parse the YAML file into a dictionary
with open(config_target_yaml, 'r') as file:
    config_target = yaml.safe_load(file)
# Plain print for the label: pprint would show the string's repr, quotes included
print('Target model:')
pprint(config_target)

# Load the data preparation config
# The path gets its own name so `config_data_preparation` only ever holds the parsed dictionary
config_data_preparation_yaml = '../04_config_data_preparation.yaml'

# Parse the YAML file into a dictionary
with open(config_data_preparation_yaml, 'r') as file:
    config_data_preparation = yaml.safe_load(file)
print('Dataset preparation:')
pprint(config_data_preparation)

### Select your training dataset
You can use your own dataset (see [synthetic data generation util](../synthetic_data_gen/notebooks/quickstart_synthetic_data_gen.ipynb)).

### Prepare Dataset

To upload a dataset to SambaStudio, we first need to convert it to a suitable format (hdf5). For this we will use the generative data prep utility.

In [None]:
# Show the configured input files entry; the data prep cell appends it to kit_dir
config_data_preparation['files']['input_files']

In [None]:
# Convert the raw input files with the generative data prep pipeline.
# Input and output locations from the config are resolved against the kit directory.
files_cfg = config_data_preparation['files']
target_cfg = config_data_preparation['target_model']

hdf5_dataset_path = sambastudio_utils.gen_data_prep_pipeline(
    input_files=f"{kit_dir}/{files_cfg['input_files']}",
    output_path=f"{kit_dir}/{files_cfg['output_path']}",
    # tokenize with the tokenizer of the model that will be trained
    tokenizer=target_cfg['hf_name'],
    max_seq_length=target_cfg['max_seq_length'],
    shuffle="on_RAM",
    input_packing_config="single::truncate_right",
    prompt_keyword="prompt",
    completion_keyword="completion",
    num_training_splits=8,
    apply_chat_template=False,
)

Find more details about the gen data prep parameters [here](https://github.com/sambanova/generative_data_prep?tab=readme-ov-file#flags)

### Set dataset configs

Some parameters must be provided to upload the previously prepared dataset; we will keep these parameters in a `dataset` dict.

In [None]:
# Parameters describing the dataset to upload, gathered in a single dict.
# `dataset_apps_availability` is added to this dict in a later cell.
dataset = dict(
    dataset_path=hdf5_dataset_path,
    dataset_name=config_data_preparation['dataset']['dataset_name'],
    dataset_description=(
        'This dataset contains question and answer pairs sourced from Q&A pages '
        'and FAQs from CDC and WHO pertaining to COVID-19'
    ),
    dataset_job_types=['evaluation', 'train'],
    dataset_source_type='localMachine',
    dataset_language='english',
    dataset_filetype='hdf5',
    dataset_url='https://dummy_url',
    dataset_metadata={},
)

You should indicate which apps the uploaded dataset will be available for. If you are not sure, you can list all the apps in SambaStudio and select those you want.

In [None]:
# List the apps known to the SambaStudio client and display them.
# NOTE: the name is misspelled ("avaliable") but is kept because the app filter cell reads it.
avaliable_apps = sambastudio_client.list_apps()
avaliable_apps

In [None]:
# In this case we will train a llama3 model, so we include every app whose name
# contains "llama3" (compared in lower case with spaces removed)
llama3_apps = []
for app in avaliable_apps:
    normalized_name = app['name'].replace(' ', '').lower()
    if 'llama3' in normalized_name:
        llama3_apps.append(app['name'])

dataset['dataset_apps_availability'] = llama3_apps

In [None]:
# Display the full set of parameters that will be passed when uploading the dataset
dataset

### Upload Dataset to SambaStudio

In [None]:
# Execute the create dataset method from client with dataset parameters (this can take a while)
sambastudio_client.create_dataset(
    dataset_path = dataset['dataset_path'],
    dataset_name = dataset['dataset_name'],
    dataset_description = dataset['dataset_description'],
    dataset_job_types = dataset['dataset_job_types'],
    dataset_source_type = dataset['dataset_source_type'],
    dataset_language = dataset['dataset_language'],
    dataset_url = dataset['dataset_url'],
    dataset_apps_availability = dataset['dataset_apps_availability'],
    dataset_filetype = dataset['dataset_filetype'],
    dataset_metadata = dataset['dataset_metadata']
)

In [None]:
# Check that the dataset is now in the SambaStudio environment by showing the last listed entry.
# NOTE(review): assumes the newly uploaded dataset is returned last and the list is non-empty — confirm.
sambastudio_client.list_datasets()[-1]