# Imports

In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import sys
import pickle

# Call models
from src.call_models import google_connect, call_gemini, all_string_gemini_config, all_int_gemini_config
from src.translate_func import gemini_translation, gemini_classification

# Datasets

from src.benchmarks_code import hellaswag
from prompts import hellaswag_prompts

# Access keys
from my_access_keys import google_access_key

# .csv utils
from src.save_utils import add_dataset_to_csv

# Import display explicitly to remove an annoying warning
from IPython.core.display_functions import display

# Get dataset

In [None]:
# Build the Google client that is handed to gemini_classification() and
# gemini_translation() in the cells below.
google_client = google_connect(google_access_key)

In [None]:
# Load the HellaSwag datasets through the project helper; the result is
# indexed by split name further down (e.g. 'hellaswag_val').
hellaswag_dataset = hellaswag.get_hellaswag_datasets()

# Bare expression so the notebook renders the loaded object.
hellaswag_dataset

In [None]:
# Work on a 500-sample window of the 'hellaswag_val' split that starts after
# the first 800 samples (the output file names refer to it as 801-1300).
val_split = hellaswag_dataset['hellaswag_val']
hesw_text = val_split.skip(800).take(500)
hesw_text

# Classify whether to translate or not

In [None]:
%%time
exp_name = 'gemini_think_classify'

labels, text_output = gemini_classification(
    google_client,
    {'run_on': hesw_text},
    hellaswag_prompts.HELLASWAG_CLASSIFICATION_GEMINI,
    hellaswag_prompts.HELLASWAG_CLASSIFICATION_FEW_SHOTS,
    hellaswag.hellaswag_sample_to_dict,
    hellaswag.hellaswag_dict_to_sample,
    think_bud=256,
)

labels, text_output = labels['run_on'], text_output['run_on']

In [None]:
# Spot-check the first three predicted labels.
labels[:3]

In [None]:
import os

# Destination for the pickled classification labels.
# NOTE(review): 'hella_sawg' looks like a typo of 'hellaswag'; the name is
# left unchanged so anything that already reads this path keeps working.
filename = 'compare_csv/hellaswag/hella_sawg_801-1300_labels.pkl'

# Create the target directory first; otherwise open() raises
# FileNotFoundError on a checkout where the folder does not exist yet.
os.makedirs(os.path.dirname(filename), exist_ok=True)

# pickle writes bytes, so the file must be opened in binary mode.
with open(filename, 'wb') as file:
    pickle.dump(labels, file)

In [None]:
# Tally how many samples received each label.
label_series = pd.Series(labels)
display(label_series.value_counts())

# Keep only the labels that mark a sample for translation. The surviving
# index values are the positions of those samples inside `labels`.
is_universal = label_series.eq('Universal')
is_localizable = label_series.eq('Can be localized')
use_this = label_series[is_universal | is_localizable]

print(f'\nWill translate {use_this.shape[0]} samples.')

In [None]:
# Scratch arithmetic (145 + 38 = 183).
# NOTE(review): presumably label counts copied by hand from a value_counts()
# output — confirm, or replace with a computed expression.
145 + 38

In [None]:
# Scratch arithmetic (213 + 56 = 269).
# NOTE(review): presumably label counts copied by hand from a value_counts()
# output — confirm, or replace with a computed expression.
213 + 56

In [None]:
# Adds the results of the two preceding scratch cells
# (145 + 38 = 183 and 213 + 56 = 269), giving 452.
183 + 269

In [None]:
# Label distribution for the samples from position 200 onward.
pd.Series(labels[200:]).value_counts()

# Translate

In [None]:
# Show the dataset slice as it is before the label-based selection below.
hesw_text

In [None]:
# Keep only the samples chosen for translation. `use_this.index` holds
# positions into the slice that was classified, so the selection is only
# valid while `hesw_text` still has exactly one row per entry in `labels`.
if len(hesw_text) != len(labels):
    # This cell overwrites hesw_text with its own result. Running it a second
    # time would apply the positions to the already filtered dataset and pick
    # the wrong rows (or fail with an opaque index error).
    raise ValueError(
        f'hesw_text has {len(hesw_text)} rows but there are {len(labels)} labels; '
        're-create the unfiltered slice before selecting.'
    )
hesw_text = hesw_text.select(use_this.index)
hesw_text

In [None]:
# Main CSV for this run; a companion '<name>-text.csv' path is derived from it
# in the cells below. '801-1300' matches the skip(800).take(500) window.
# NOTE(review): the name says 'test' but the samples come from the
# 'hellaswag_val' split — confirm the naming is intentional.
hellaswag_file_name = 'compare_csv/hellaswag/hellaswag_test_801-1300.csv'

In [None]:
# Pass the untranslated samples to add_dataset_to_csv under the name
# 'original', once for the main CSV and once for its '-text' companion, then
# preview the first two rows of each returned frame.
text_file_name = f'{hellaswag_file_name[:-4]}-text.csv'
csv_args = ('original', hesw_text, hellaswag.hellaswag_sample_to_dict)
df = add_dataset_to_csv(hellaswag_file_name, *csv_args)
text_df = add_dataset_to_csv(text_file_name, *csv_args)
for frame in (df, text_df):
    display(frame.head(2))

In [None]:
# Reload both CSVs from disk, then show their shapes and first two rows.
text_file_name = f'{hellaswag_file_name[:-4]}-text.csv'
df, text_df = pd.read_csv(hellaswag_file_name), pd.read_csv(text_file_name)
print(df.shape, text_df.shape)
for frame in (df, text_df):
    display(frame.head(2))

In [None]:
# Quick check of the (rows, columns) shape of both frames.
print(df.shape, text_df.shape)

In [None]:
%%time

exp_name = 'gemini'

hebrew_datasets, text_output = gemini_translation(
    google_client,
    {'run': hesw_text},
    hellaswag_prompts.HELLASWAG_INSTRUCT_V1_GEMINI,
    hellaswag_prompts.HELLASWAG_FEW_SHOTS,
    hellaswag.hellaswag_sample_to_dict,
    hellaswag.hellaswag_dict_to_sample,
    if_pro=True,
    think_bud=4096,
)

hebrew_datasets = hebrew_datasets['run']
text_output = text_output['run']
hebrew_datasets

In [None]:
# Store the translated dataset in the main CSV under exp_name, and keep the
# accompanying text output as an extra column of the '-text' companion CSV.
text_file_name = f'{hellaswag_file_name[:-4]}-text.csv'
df = add_dataset_to_csv(
    hellaswag_file_name,
    exp_name,
    hebrew_datasets,
    hellaswag.hellaswag_sample_to_dict,
)
text_df[f'{exp_name} text'] = text_output
text_df.to_csv(text_file_name, index=False)
print(df.shape, text_df.shape)
for frame in (df, text_df):
    display(frame.head(2))

In [None]:
# Copy two fields of the translated dataset into df as extra columns.
# Each entry is (df column name, dataset field name).
for column, field in (('answer_label', 'label'), ('question ind', 'ind')):
    df[column] = pd.Series(hebrew_datasets[field])

In [None]:
# Preview the frame with the newly added columns.
df.head()

In [None]:
# Write the frame to the main CSV path, without the index column.
df.to_csv(hellaswag_file_name, index=False)