# Imports

In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import sys
import pickle

# Call models
from src.call_models import google_connect, call_gemini, all_string_gemini_config, all_int_gemini_config
from src.translate_func import gemini_translation, gemini_classification

# Datasets

from src.benchmarks_code import hellaswag
from prompts import hellaswag_prompts

# Access keys
from my_access_keys import google_access_key

# .csv utils
from src.save_utils import add_dataset_to_csv

# Remove annoying warning
from IPython.core.display_functions import display

# Get dataset

In [2]:
# Build the Google client from the locally stored access key; this client is
# passed to the classification and translation calls further down.
google_client = google_connect(google_access_key)

In [5]:
# Load the HellaSwag data via the project helper; the result is indexed by
# split name (the 'hellaswag_val' entry is used below).
hellaswag_dataset = hellaswag.get_hellaswag_datasets()

# Show what was loaded.
hellaswag_dataset

{'hellaswag_val': Dataset({
     features: ['ind', 'activity_label', 'ctx_a', 'ctx_b', 'ctx', 'endings', 'source_id', 'split', 'split_type', 'label'],
     num_rows: 10042
 })}

In [8]:
# Work on a 500-sample window of the validation split, starting right after
# its first 300 rows.
val_split = hellaswag_dataset['hellaswag_val']
hesw_text = val_split.skip(300).take(500)
hesw_text

Dataset({
    features: ['ind', 'activity_label', 'ctx_a', 'ctx_b', 'ctx', 'endings', 'source_id', 'split', 'split_type', 'label'],
    num_rows: 500
})

# Classify whether to translate or not

In [10]:
%%time
# Experiment tag for the classification pass.
exp_name = 'gemini_think_classify'

# The helper receives the data wrapped under a key and hands back two results
# indexed by that same key, so the key is named once and reused for unwrapping.
split_key = 'run_on'

labels_by_split, text_by_split = gemini_classification(
    google_client,
    {split_key: hesw_text},
    hellaswag_prompts.HELLASWAG_CLASSIFICATION_GEMINI,
    hellaswag_prompts.HELLASWAG_CLASSIFICATION_FEW_SHOTS,
    hellaswag.hellaswag_sample_to_dict,
    hellaswag.hellaswag_dict_to_sample,
    think_bud=256,  # presumably the model's thinking-token budget -- TODO confirm
)

# Keep only the entries for the single slice that was submitted.
labels = labels_by_split[split_key]
text_output = text_by_split[split_key]

Classifying run_on...


  0%|          | 0/500 [00:00<?, ?it/s]

sdk_http_response=HttpResponse(
  headers=<dict len=11>
) candidates=None create_time=None model_version='gemini-2.5-flash' prompt_feedback=GenerateContentResponsePromptFeedback(
  block_reason=<BlockedReason.PROHIBITED_CONTENT: 'PROHIBITED_CONTENT'>
) response_id='NIrKaKXbFvHP_uMP2svFmAs' usage_metadata=GenerateContentResponseUsageMetadata(
  prompt_token_count=1597,
  prompt_tokens_details=[
    ModalityTokenCount(
      modality=<MediaModality.TEXT: 'TEXT'>,
      token_count=1597
    ),
  ],
  total_token_count=1597
) automatic_function_calling_history=[] parsed=None
sdk_http_response=HttpResponse(
  headers=<dict len=11>
) candidates=None create_time=None model_version='gemini-2.5-flash' prompt_feedback=GenerateContentResponsePromptFeedback(
  block_reason=<BlockedReason.PROHIBITED_CONTENT: 'PROHIBITED_CONTENT'>
) response_id='Y4rKaJHaKMXP_uMP2-ThoA8' usage_metadata=GenerateContentResponseUsageMetadata(
  prompt_token_count=1577,
  prompt_tokens_details=[
    ModalityTokenCount(
 

In [11]:
# Peek at the first few labels returned by the classification run.
labels[:3]

['Universal', 'Universal', 'failed']

In [12]:
# Destination for the pickled classification labels
filename = 'compare_csv/hellaswag/hella_sawg_301-800_labels.pkl'

# Serialise the `labels` list to disk in binary mode
with open(filename, mode='wb') as labels_file:
    pickle.dump(labels, labels_file)

In [13]:
# Tally how many samples received each label.
label_series = pd.Series(labels)
display(label_series.value_counts())

# Only these two labels go on to translation; every other label is dropped.
# The surviving index values are positions into `labels`.
use_this = label_series[label_series.isin(['Universal', 'Can be localized'])]

print(f'\nWill translate {len(use_this)} samples.')

Universal           324
Can be localized     99
Foreign              60
failed               17
Name: count, dtype: int64


Will translate 423 samples.


# Translate

In [14]:
# Inspect the dataset before it is filtered by classification label.
hesw_text

Dataset({
    features: ['ind', 'activity_label', 'ctx_a', 'ctx_b', 'ctx', 'endings', 'source_id', 'split', 'split_type', 'label'],
    num_rows: 500
})

In [15]:
# `use_this.index` holds positions into `labels`, which was produced from the
# unfiltered `hesw_text`. This cell overwrites `hesw_text`, so running it a
# second time (or out of order) would apply those positions to an
# already-filtered dataset. Fail early with a clear message instead of
# selecting the wrong rows or hitting an opaque out-of-range error.
assert len(hesw_text) == len(labels), (
    f'hesw_text has {len(hesw_text)} rows but there are {len(labels)} labels; '
    're-create hesw_text from the source dataset before filtering.'
)

# Keep only the samples whose label was selected for translation.
hesw_text = hesw_text.select(use_this.index)
hesw_text

Dataset({
    features: ['ind', 'activity_label', 'ctx_a', 'ctx_b', 'ctx', 'endings', 'source_id', 'split', 'split_type', 'label'],
    num_rows: 423
})

In [16]:
# Main results CSV; the companion "-text" CSV path is derived from this name
# by replacing the trailing ".csv".
# NOTE(review): the name says "test_301-500", but the samples come from
# 'hellaswag_val' via skip(300).take(500) and the labels pickle is named
# "301-800" -- confirm whether this file name is intentional.
hellaswag_file_name = 'compare_csv/hellaswag/hellaswag_test_301-500.csv'

In [17]:
# Register the untranslated samples under an 'original' column in both the
# main CSV and its "-text" companion (presumably the helper also writes the
# file at the given path -- TODO confirm), then preview each frame.
text_file_name = hellaswag_file_name[:-4] + '-text.csv'

df = add_dataset_to_csv(hellaswag_file_name, 'original', hesw_text, hellaswag.hellaswag_sample_to_dict)
text_df = add_dataset_to_csv(text_file_name, 'original', hesw_text, hellaswag.hellaswag_sample_to_dict)

for frame in (df, text_df):
    display(frame.head(2))

Unnamed: 0,original
0,<activity_label>Rope skipping</activity_label>...
1,<activity_label>Rope skipping</activity_label>...


Unnamed: 0,original
0,<activity_label>Rope skipping</activity_label>...
1,<activity_label>Rope skipping</activity_label>...


In [18]:
# Re-read both CSVs from disk and preview them, so the frames used from here
# on reflect what is stored in the files.
text_file_name = hellaswag_file_name[:-4] + '-text.csv'

df = pd.read_csv(hellaswag_file_name)
text_df = pd.read_csv(text_file_name)

print(df.shape, text_df.shape)
for frame in (df, text_df):
    display(frame.head(2))

(423, 1) (423, 1)


Unnamed: 0,original
0,<activity_label>Rope skipping</activity_label>...
1,<activity_label>Rope skipping</activity_label>...


Unnamed: 0,original
0,<activity_label>Rope skipping</activity_label>...
1,<activity_label>Rope skipping</activity_label>...


In [19]:
# Sanity check: both frames should have the same number of rows before translating.
print(df.shape, text_df.shape)

(423, 1) (423, 1)


In [None]:
%%time

# Experiment tag; later cells use it to name the CSV columns for this run.
exp_name = 'gemini'

# As with classification, the data goes in wrapped under a key and both
# results come back indexed by that key.
split_key = 'run'

translated_by_split, text_by_split = gemini_translation(
    google_client,
    {split_key: hesw_text},
    hellaswag_prompts.HELLASWAG_INSTRUCT_V1_GEMINI,
    hellaswag_prompts.HELLASWAG_FEW_SHOTS,
    hellaswag.hellaswag_sample_to_dict,
    hellaswag.hellaswag_dict_to_sample,
    if_pro=True,  # presumably selects the Pro model variant -- TODO confirm
    think_bud=4096,  # presumably the model's thinking-token budget -- TODO confirm
)

# Unwrap the results for the single slice that was submitted.
hebrew_datasets = translated_by_split[split_key]
text_output = text_by_split[split_key]
hebrew_datasets

Translating run...


  0%|          | 0/423 [00:00<?, ?it/s]

-|-

In [None]:
# Register the translated samples in the main CSV under a column named after
# the experiment.
df = add_dataset_to_csv(hellaswag_file_name, exp_name, hebrew_datasets, hellaswag.hellaswag_sample_to_dict)

# Attach the `text_output` returned by the translation call as an extra
# column, then rewrite the companion "-text" CSV without the index.
text_column = exp_name + ' text'
text_file_name = hellaswag_file_name[:-4] + '-text.csv'
text_df[text_column] = text_output
text_df.to_csv(text_file_name, index=False)

print(df.shape, text_df.shape)
for frame in (df, text_df):
    display(frame.head(2))

In [None]:
# Copy the answer label and the question index from the translated dataset
# into the results frame. Wrapping the values in a Series means they are
# aligned to `df` by index label.
for column_name, feature_name in (('answer_label', 'label'), ('question ind', 'ind')):
    df[column_name] = pd.Series(hebrew_datasets[feature_name])

In [None]:
# Preview the results frame with the newly added columns.
df.head()

In [None]:
# Write the results frame back to the main CSV, overwriting it; the index is not saved.
df.to_csv(hellaswag_file_name, index=False)