In [2]:
from datasets import load_dataset_builder, load_dataset, get_dataset_infos, get_dataset_config_names, list_datasets, get_dataset_split_names
import pandas as pd
from typing import Callable, Dict, List, Optional, Tuple, Union


  from .autonotebook import tqdm as notebook_tqdm


### Utility functions

In [7]:
def add_true_false_labels(df, idx):
    df['class'] = False
    df.loc[idx, 'class'] = True
    return df

def add_type_labels(df, idx):
    df['class'] = "Type 1"
    df.loc[idx, 'class'] = "Type 2"
    return df

def rebalance(df):
    # Separate the DataFrame into positive and negative examples
    positive_df = df[df['class'] == True]
    negative_df = df[df['class'] == False]

    # Determine the smaller count
    min_count = min(len(positive_df), len(negative_df))

    # Sample from the larger group
    if len(positive_df) > len(negative_df):
        positive_df = positive_df.sample(n=min_count)
    else:
        negative_df = negative_df.sample(n=min_count)

    # Concatenate both DataFrames
    balanced_df = pd.concat([positive_df, negative_df])

    # Shuffle the DataFrame if needed
    balanced_df = balanced_df.sample(frac=1).reset_index(drop=True)

    return balanced_df

def save_dataset(df: pd.DataFrame, name: str, label_map: Callable = None):
    if label_map:
        df['class'] = df['class'].apply(label_map)
    df.to_csv(f"data/{name}_data.csv", index=False)

## Datasets

#### imdb Digits

In [8]:
get_dataset_split_names("imdb")
dataset = load_dataset("imdb", split="unsupervised")

dataset_name = "imdb_digits"

In [9]:
# numbers = set(['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']).union(set('0123456789'))
digits = set('0123456789')

In [10]:
# put the first n items of dataset into a pandas dataframe
n = 300
texts = dataset['text'][:n]
df = pd.DataFrame(texts, columns=['text'])

# remove everything after the first period
df['text'] = df['text'].map(lambda x: x[:x.find('.')+1])

#iterate over the dataframe and find the sentences that contain digits
number_sentences_idx = []
for index, row in df.iterrows():
    sentence_set = set(row['text'].split()).union(set(row['text']))
    if sentence_set.intersection(digits):
        number_sentences_idx.append(index)

In [11]:
add_type_labels(df, number_sentences_idx)
balanced_df = rebalance(df)
balanced_df.to_csv(f"data/{dataset_name}_data.csv", index=False)

#### Single digit arithmetic

In [12]:
import random

def generate_addition_strings(num_strings):
    strings = []
    for _ in range(num_strings):
        a = random.randint(0, 9)
        b = random.randint(0, 9)
        correct_c = a + b
        incorrect_c = random.randint(0, 18)  # Generate a random incorrect value for c
        while incorrect_c == correct_c:
            incorrect_c = random.randint(0, 18)

        if random.random() < 0.5:
            c = correct_c
            is_correct = True
        else:
            c = incorrect_c
            is_correct = False

        string = f"{a} + {b} = {c}"
        strings.append((string, is_correct))

    return strings

In [13]:
addition_df = pd.DataFrame(generate_addition_strings(100), columns=['text', 'class'])

# convert class column to boolean
addition_df['class'] = addition_df['class'].astype(bool)

In [14]:
save_dataset(addition_df, name="sd_addition", label_map=None)

In [15]:
type(addition_df['class'][0])

numpy.bool_

#### GPT Digits

Validate GPT-4 generated data

In [35]:
df_digits = pd.read_csv("data/unprocessed/gpt_digits_data.csv")
df_digits['Contains Digit'].astype(bool)
digits = set('0123456789')
for idx, row in df_digits.copy().iterrows():
    if row['Contains Digit'] and not set(row['Sentence']).intersection(digits):
        df_digits.at[idx, 'Contains Digit'] = False
    elif not row['Contains Digit'] and set(row['Sentence']).intersection(digits):
       df_digits.at[idx, 'Contains Digit'] = True

df_digits.to_csv("data/gpt_digits_data.csv", index=False)