### Create statistics.

In [None]:
# Master switch for the expensive path: the cells guarded by `if generate_data:`
# (building the word counts and uploading them) only run when this is True.
generate_data = False

In [None]:
import wandb
# Authenticate this session with Weights & Biases before any run is started.
wandb.login()

In [None]:
# W&B destination for the run started by this notebook.
wandb_entity = "nlp_and_interpretability"  # Change for your own wandb entity
wandb_project = "tinysql"
# Name under which the generated statistics files are logged as an artifact.
artifact_name = "TinyStoriesStatistics"
# Start the run in the configured project/entity.
wandb.init(project=wandb_project, entity=wandb_entity)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

def get_tokenizer_vocab(model_name):
    """
    Load the tokenizer for a model and expose its vocabulary.
    Args:
    - model_name (str): Name handed to AutoTokenizer.from_pretrained.
    Returns:
    - tuple: (vocab, tokenizer), where vocab is whatever tokenizer.get_vocab()
      returns and tokenizer is the loaded tokenizer object itself.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return loaded_tokenizer.get_vocab(), loaded_tokenizer

# Example usage
model_name = 'roneneldan/TinyStories-33M'
# NOTE(review): `model` is not referenced again in the visible part of this
# notebook; confirm it is needed before paying for the load.
model = AutoModelForCausalLM.from_pretrained(model_name)
# Keep both the vocabulary and the tokenizer object returned by the helper.
tokenizer_vocab, tokenizer = get_tokenizer_vocab(model_name)


from datasets import load_dataset

if generate_data:
    # These imports are only executed when the data is regenerated.
    from collections import Counter
    import re

    from tqdm import tqdm

    dataset = load_dataset('roneneldan/TinyStoriesInstruct')
    dtrain = dataset['train']

    # Pull the raw text of every training example into memory.
    texts = [item['text'] for item in tqdm(dtrain)]

    # Lowercase each text and collect its runs of word characters.
    words = []
    for sentence in tqdm(texts):
        words += re.findall(r'\b\w+\b', sentence.lower())

In [None]:
if generate_data:
    # Tally how often each word was seen.
    print('Running counter')
    word_counts = Counter(words)
    print('Finished counter')

    # Retain only the words seen at least 5 times.
    unigram_statistics = {}
    for word, count in word_counts.items():
        if count >= 5:
            unigram_statistics[word] = count

In [None]:
import json

if generate_data:
    # Each payload is keyed by the file name it is serialised to.
    outputs = {
        "tokenizer_vocab.json": tokenizer_vocab,
        "unigram_statistics.json": unigram_statistics,
    }
    for path, payload in outputs.items():
        with open(path, "w") as f:
            json.dump(payload, f)

    # Bundle the files that were just written into one W&B artifact.
    artifact = wandb.Artifact(artifact_name, type="dataset")
    for path in outputs:
        artifact.add_file(path)

    # Attach the artifact to the active W&B run.
    wandb.log_artifact(artifact)

### Load and analyze statistics

In [None]:
import json
import wandb

# Reuse the entity / project / artifact name configured earlier in this notebook
# (`wandb_entity`, `wandb_project`, `artifact_name`) instead of repeating them as
# literals, so changing the configuration in one place also changes where the
# statistics are fetched from. That configuration cell must have been run first.
run = wandb.init(project=wandb_project, entity=wandb_entity)
artifact = run.use_artifact(
    f'{wandb_entity}/{wandb_project}/{artifact_name}:v0', type='dataset'
)
# Download the artifact's files and remember the local directory.
artifact_dir = artifact.download()

# Vocabulary of the tokenizer, as stored in the artifact.
with open(f'{artifact_dir}/tokenizer_vocab.json', 'r') as f_in:
    tokenizer_vocab = json.load(f_in)

# Per-word counts, as stored in the artifact.
with open(f'{artifact_dir}/unigram_statistics.json', 'r') as f_in:
    unigram_statistics = json.load(f_in)

In [None]:
# Display the local directory the artifact was downloaded to.
artifact_dir

In [None]:
# Convert the raw counts into fractions of the total count over all kept words.
total_words = sum(unigram_statistics.values())

normalized_unigram_statistics = {}
for word, count in unigram_statistics.items():
    normalized_unigram_statistics[word] = count / total_words

In [None]:
import matplotlib.pyplot as plt
from collections import Counter

# Rank the words from highest to lowest count; the dict keeps that order.
ranked_items = sorted(unigram_statistics.items(), key=lambda pair: pair[1], reverse=True)
sorted_word_freq = dict(ranked_items)
# Number of top-ranked words to look at.
num_keys = 50


In [None]:
# Plot the `num_keys` most frequent words against their counts.
top_words = list(sorted_word_freq)[:num_keys]
top_counts = list(sorted_word_freq.values())[:num_keys]

plt.figure(figsize=(10, 5))
plt.plot(top_words, top_counts, marker='o')

# Axis labels and title
plt.xlabel('Words')
plt.ylabel('Frequency')
plt.title('Word Frequency Distribution')

# Rotate the word labels so they stay readable, fit the layout to them,
# and write the figure to disk.
plt.xticks(rotation=90)
plt.tight_layout()
plt.savefig("tiny_stories_distribution.png")

In [None]:
import huggingface_hub
# Authenticate with the Hugging Face Hub
# (presumably needed to load the datasets used below — confirm).
huggingface_hub.login()

In [None]:
# Repository name pattern on the Hugging Face Hub; `{}` is filled with a dataset key.
martian_template = "withmartian/{}_dataset"

# Short keys of the datasets to analyse.
keys = ["cs1", "cs2", "cs3"]

# Load every dataset up front, indexed by its short key.
datasets = {key: load_dataset(martian_template.format(key)) for key in keys}

In [None]:
# Names of the record fields whose text is analysed for each dataset.
relevant_fields = [
    "table_name", "english_prompt", "sql_statement", "table_fields"
]

In [None]:
import string
from functools import lru_cache

# Built once at definition time rather than on every call: maps each character
# in string.punctuation to a single space.
_PUNCTUATION_TO_SPACE = str.maketrans(string.punctuation, ' ' * len(string.punctuation))


def remove_punctuation_with_space(text):
    """
    Replace every ASCII punctuation character in `text` with a space.

    Args:
    - text (str): The text to clean.

    Returns:
    - str: A string of the same length as `text` in which each character from
      string.punctuation (this includes the underscore) has become ' '.
    """
    return text.translate(_PUNCTUATION_TO_SPACE)

@lru_cache(maxsize=128)
def get_all_tokens(dataset_key, field_name):
    """
    Collect the unique lowercase tokens found in one field of a dataset's "train" split.

    For every entry, punctuation in the field's text is replaced by spaces, the
    text is lowercased and then split on whitespace. Results are memoised per
    (dataset_key, field_name) pair via lru_cache.

    Args:
    - dataset_key (hashable): Key into the module-level `datasets` mapping.
    - field_name (str): Name of the field to read from each entry (e.g., "english_prompt").

    Returns:
    - all_tokens (set): All unique tokens found in that field across the split.
      Because of the cache, repeated calls return the same set object, so
      callers should treat it as read-only.
    """
    # Imported locally: the only other import of tqdm in this notebook sits inside
    # an `if generate_data:` branch and is skipped when that flag is False.
    from tqdm import tqdm

    all_tokens = set()

    dataset = datasets[dataset_key]["train"]

    for entry in tqdm(dataset):
        text = entry.get(field_name, "")

        # Skip missing, None or empty values before treating them as strings.
        if not text:
            continue

        # Punctuation -> spaces, then lowercase and split on whitespace.
        tokens = remove_punctuation_with_space(text).lower().split()
        all_tokens.update(tokens)

    return all_tokens

def calculate_token_occurrence_rate(dataset_key, field_names, unigram_statistics, *, null_threshold=5):
    """
    Look up, for each requested field of a dataset, how common its tokens are
    according to `unigram_statistics`.

    Args:
    - dataset_key: Key identifying the dataset; forwarded to get_all_tokens.
    - field_names (iterable of str): Fields to analyse (e.g., "english_prompt", "sql_statement").
    - unigram_statistics (dict): Maps a token to its numeric occurrence value.
      Tokens missing from the mapping are treated as 0.
    - null_threshold (number, keyword-only): Tokens whose value is strictly below
      this are reported as "null" tokens. Defaults to 5.

    Returns:
    - dict: One entry per field name, each a dict with
        "null_tokens": alphabetically sorted tokens below the threshold,
        "tokens_and_rates": (token, value) pairs sorted by descending value,
        "num_tokens": number of distinct tokens in the field,
        "num_null": number of null tokens.
    """

    # Report the dataset actually being processed (the parameter, not a global).
    print(f"Processing {dataset_key}")

    all_stats = {}
    for field in field_names:
        print(f'Processing {field} for {dataset_key}')
        all_tokens = get_all_tokens(dataset_key, field)

        # Pair every token with its value and order from most to least common.
        tokens_and_rates = sorted(
            ((token, unigram_statistics.get(token, 0)) for token in all_tokens),
            key=lambda pair: -pair[1],
        )

        null_tokens = sorted(token for token, rate in tokens_and_rates if rate < null_threshold)

        # Both lists are freshly built in this iteration, so no defensive copy is needed.
        all_stats[field] = {
            "null_tokens": null_tokens,
            "tokens_and_rates": tokens_and_rates,
            "num_tokens": len(tokens_and_rates),
            "num_null": len(null_tokens),
        }
    return all_stats

In [None]:
# Compute token statistics over `relevant_fields` for every loaded dataset,
# keyed by the dataset's short name.
dataset_stats = {}
for key in datasets:
    dataset_stats[key] = calculate_token_occurrence_rate(key, relevant_fields, unigram_statistics)

In [None]:
from copy import deepcopy
# Work on a deep copy so the full statistics in `dataset_stats` stay untouched.
trimmed_dataset_stats = deepcopy(dataset_stats)

# Drop the two list-valued entries from every field summary.
for per_field in trimmed_dataset_stats.values():
    for summary in per_field.values():
        for bulky_key in ('null_tokens', 'tokens_and_rates'):
            del summary[bulky_key]

trimmed_dataset_stats

In [None]:
# Inspect the stored "null_tokens" list for the `table_fields` field of cs3.
dataset_stats['cs3']['table_fields']['null_tokens']

In [None]:
# Name of the W&B artifact that will hold the per-dataset token statistics.
tiny_sql_artifact = "TinySQLStatistics"

In [None]:
# Create W&B artifact and add files
import json

artifact = wandb.Artifact(tiny_sql_artifact, type="dataset")

# Serialise the full per-dataset statistics to a local JSON file.
filename = "dataset_stats.json"
with open(filename, "w") as f_out:
    json.dump(dataset_stats, f_out)

# Attach the file only after the `with` block has closed it, so the JSON is
# completely flushed to disk before W&B reads it.
artifact.add_file(filename)

# Log the artifact to the W&B run
wandb.log_artifact(artifact)