In [None]:
import sys
import os
sys.path.append(os.path.abspath('..'))


from src.utils import *
from src.llm_gen import *
from src.dataset_llm_templates import *
from src.data_loader import *

import pandas as pd
from copy import deepcopy
import pickle
import time
from sklearn.model_selection import train_test_split



#############################################################

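# API KEY SETUP INSTRUCTIONS
# Pick the serving backend you use below and copy its api_key / api_base
# values into the `api_details` dict that follows this comment block.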


#############################################################

# for vllm
# api_key = "EMPTY"
# api_base = "http://localhost:8000/v1"

# for together
# api_key = "add together api key"
# api_base = "https://api.together.xyz/v1"


# for azure openai
# api_key = "EMPTY"
# api_base = "add azure deployment link"

# for openai
# api_key = "EMPTY"
# api_base: DO NOT INCLUDE (leave the "api_base" entry out for openai)
#############################################################

# Connection settings forwarded unchanged to llm_gen(api_details=...).
# Replace the placeholder strings with the values for your serving backend.
api_details = dict(
    api_base="add api base",
    api_version="2023-07-01-preview",
    api_key="add api key",
)


# Short tag embedded in the output pickle filename.
# Alternative: 'gpt-4' (do not use other short names).
model_short_name = 'mixtral'

# Model identifier passed to llm_gen(model=...).
# Alternative: "gpt4_20230815" (use name of your model deployment).
model = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Serving backend; supported: 'azure_openai', 'together', 'vllm'.
llm_serving = 'together'


# Factors to evaluate: one generation run per (seed, ns, dataset) combination.
seeds = list(range(10))

# Training-sample sizes, e.g. 10*2 = 20, 20*2 = 40, 50*2 = 100, 100*2 = 200
# (presumably a per-class count that is doubled -- confirm in sample_and_split).
n_samples = [10, 20, 50, 100]

datasets = ['compas']

# Name of the label column in the LLM-generated frame, keyed by dataset.
# Built once up front because it does not depend on any loop variable.
ylabel_map = {'covid': "is_dead",
              "adult": "salary",
              "compas": "y",
              "drug": "y",
              "bio": "y",
              "higgs": "y",
              "seer": "mortCancer",
              "cutract": "mortCancer",
              "maggic": "death_all",
              "support": "death",
              }

# Directory the per-run pickles are written to. Create it before any LLM call
# so a missing folder cannot throw away a finished generation run at save time.
save_dir = '../save_dfs'
os.makedirs(save_dir, exist_ok=True)

# Run one generation job per (seed, ns, dataset) combination and pickle the
# resulting splits plus the LLM-generated samples. A failure in one
# combination is reported and the sweep moves on to the next one.
for seed in seeds:
    for ns in n_samples:
        for dataset in datasets:
            try:
                # sleep between runs from a rate limit perspective
                # time.sleep(120)
                n_synthetic = 20
                # Reset for every run; halved after each failed llm_gen attempt.
                n_processes = 5

                # Resolve the label column first so an unsupported dataset
                # fails before any LLM call is made.
                ylabel = ylabel_map[dataset]

                df_feat, df_label, df = get_data(dataset=dataset, seed=seed)

                X_train, X_remain, y_train, y_remain = sample_and_split(
                    df_feat, df_label, ns=ns, seed=seed
                )

                # Split whatever is left after sampling 50/50; the first half
                # is stored as 'Oracle', the second as 'Test'.
                X_val, X_test, y_val, y_test = train_test_split(
                    X_remain, y_remain, test_size=0.5, random_state=seed
                )

                # Keep independent copies of the training split.
                X_train_orig = deepcopy(X_train)
                y_train_orig = deepcopy(y_train)

                results = {
                    'Original': {"X": X_train_orig, 'y': y_train_orig},
                    'Oracle': {"X": X_val, 'y': y_val},
                    'Test': {"X": X_test, 'y': y_test},
                }

                prompt, generator_template, format_instructions, example_df = langchain_templates(
                    X_train_orig, y_train_orig, dataset=dataset
                )

                # Use at most 20 in-context examples; constant across retries.
                ic_samples = min(len(example_df), 20)

                # Start from "no result" on every run. Without this reset a
                # run whose attempts all fail would silently reuse the frame
                # generated by a previous (seed, ns, dataset) combination and
                # save it under this run's filename.
                df_llm = None

                retries = 4  # Max retries you want to attempt

                while retries > 0:
                    try:
                        print(f'Running {dataset}, {seed}, {model} --- {n_processes}')
                        generated = llm_gen(prompt, generator_template, format_instructions, example_df,
                                            n_samples=n_synthetic,
                                            temperature=0.9,
                                            max_tokens=1000, model=model,
                                            n_processes=n_processes,
                                            ic_samples=ic_samples,
                                            llm_serving=llm_serving,
                                            api_details=api_details)
                        print(generated.shape)
                        # Accept the result only after the whole attempt succeeded.
                        df_llm = generated
                        break  # if successful, break out of the loop
                    except Exception as e:
                        print(f"Error: {e}. Retrying with reduced n_processes...")
                        n_processes = n_processes // 2
                        retries -= 1
                        if n_processes < 1:
                            print("Error: Minimum n_processes reached. Exiting...")
                            break

                if df_llm is None:
                    # Caught by the outer handler: the run is reported and
                    # skipped instead of saving stale or missing data.
                    raise RuntimeError(
                        f"LLM generation failed for dataset={dataset}, ns={ns}, seed={seed}; "
                        "nothing was saved for this run."
                    )

                # Best effort: align the generated columns with the dtypes of
                # the example frame. On failure keep the frame as generated,
                # but say so instead of swallowing the error silently.
                try:
                    df_llm = df_llm.astype(example_df.dtypes)
                except Exception as e:
                    print(f"Warning: could not cast generated data to example dtypes ({e}); "
                          "keeping original dtypes.")

                X_train_llm = df_llm.drop(columns=[ylabel])
                y_train_llm = df_llm[ylabel]

                # NOTE(review): 'df' holds the full frame returned by
                # get_data, not the generated df_llm -- confirm this is intended.
                results['llm'] = {"X": X_train_llm, 'y': y_train_llm, 'df': df}

                with open(f'{save_dir}/pipeline_llm_{dataset}_{n_synthetic}_{model_short_name}_{ns}_{seed}.pickle', 'wb') as f:
                    pickle.dump(results, f)

            except Exception as e:
                # Report the failure and carry on with the next combination.
                import traceback
                print(traceback.format_exc())
                print(e)
                continue