In [34]:
import os

from datasets import load_dataset_builder, load_dataset, get_dataset_infos, get_dataset_config_names, list_datasets, get_dataset_split_names
import pandas as pd


### Utility functions

In [101]:
def add_class_labels(df, idx):
    """Attach a boolean ``'class'`` column to ``df``, modifying it in place.

    Rows whose index labels appear in ``idx`` are marked ``True``; every
    other row is marked ``False``. An existing ``'class'`` column is
    overwritten.

    Args:
        df: DataFrame to label. It is mutated, not copied.
        idx: Index labels (anything ``df.loc`` accepts as a row selector)
            of the rows that belong to the positive class.

    Returns:
        The same ``df`` object that was passed in.
    """
    label_col = 'class'
    df[label_col] = False  # every row starts out negative
    df.loc[idx, label_col] = True  # then only the selected rows are flipped
    return df

def rebalance(df, random_state=None):
    """Downsample the larger class so both classes have the same row count.

    Rows are split on the ``'class'`` column (``== True`` vs ``== False``;
    rows matching neither are dropped). The larger group is randomly
    reduced to the size of the smaller one, the two groups are
    concatenated, and the result is shuffled.

    Args:
        df: DataFrame containing a ``'class'`` column. It is not modified.
        random_state: Optional seed (or any value accepted by
            ``DataFrame.sample``) used for both the downsampling and the
            final shuffle. Pass a fixed integer to get the same output on
            every run; the default ``None`` keeps the unseeded behaviour.

    Returns:
        A new DataFrame with equally many positive and negative rows, in
        shuffled order, with a fresh ``0..n-1`` index.
    """
    # Separate the DataFrame into positive and negative examples.
    positive_df = df[df['class'] == True]
    negative_df = df[df['class'] == False]

    # Both groups end up with the size of the smaller one.
    min_count = min(len(positive_df), len(negative_df))

    # Only the larger group loses rows; on a tie the negative group is
    # merely reordered, which the final shuffle makes irrelevant.
    if len(positive_df) > len(negative_df):
        positive_df = positive_df.sample(n=min_count, random_state=random_state)
    else:
        negative_df = negative_df.sample(n=min_count, random_state=random_state)

    balanced_df = pd.concat([positive_df, negative_df])

    # Shuffle so the classes are interleaved, then discard the old index.
    return balanced_df.sample(frac=1, random_state=random_state).reset_index(drop=True)

## Datasets

#### IMDb digits: first sentences labelled by whether they contain a digit

In [97]:
# NOTE(review): the result of this call is neither assigned nor the last
# expression of the cell, so it is discarded — presumably left over from
# checking which splits exist; confirm whether it is still needed.
get_dataset_split_names("imdb")
# Load the "unsupervised" split of the "imdb" dataset.
dataset = load_dataset("imdb", split="unsupervised")

# Prefix of the CSV file name that the export cell writes under data/.
dataset_name = "imdb_digits"

In [98]:
# Characters that count as a digit when labelling sentences.
# Disabled alternative that would also include spelled-out numbers:
# numbers = set(['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']).union(set('0123456789'))
digits = set('0123456789')

In [99]:
# Work on the first n reviews only.
n = 300
# Slice the rows first so that only n texts are read, instead of reading the
# whole 'text' column and throwing most of it away.
texts = dataset[:n]['text']

# Keep each review up to and including its first period. str.partition
# leaves a review without any period unchanged, whereas slicing with
# x[:x.find('.') + 1] turned it into '' (find() returns -1 when not found).
first_sentences = [''.join(text.partition('.')[:2]) for text in texts]
df = pd.DataFrame(first_sentences, columns=['text'])


def _mentions_digit(sentence):
    """Return True if any character or whitespace-separated token of
    ``sentence`` is an entry of ``digits``."""
    return not digits.isdisjoint(sentence) or not digits.isdisjoint(sentence.split())


# Index labels of the sentences that contain a digit.
number_sentences_idx = [
    label for label, sentence in df['text'].items() if _mentions_digit(sentence)
]

In [102]:
# Mark the digit-bearing sentences as the positive class (df gains a 'class'
# column in place), then even out the two classes.
add_class_labels(df, number_sentences_idx)
balanced_df = rebalance(df)

output_path = f"data/{dataset_name}_data.csv"
# to_csv does not create missing parent directories, so make sure data/
# exists; otherwise the export fails on a fresh checkout.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
balanced_df.to_csv(output_path, index=False)