# Token analysis with and without context

This notebook analyses the token count distribution obtained with the tokenizers of various model families (DistilBERT, BERT, Llama), both with and without surrounding context. It also computes how many sentences would be cut for a given `max_token_count` value.

In [None]:
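# Need to change cwd to the parent directory in order to import qut01
# (note: re-running this cell moves up one more directory each time)
# Q: Is there an easy way to set the working directory of the notebook at launch?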
%cd ..

In [None]:
import itertools
from collections import defaultdict

import hydra
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import matplotlib.ticker
import numpy as np
import tqdm

import qut01

In [None]:
def get_dataloaders(config_overrides):
    """Instantiate the datamodule described by the hydra overrides and return its dataloaders.

    Args:
        config_overrides: list of hydra override strings, forwarded as-is to
            `qut01.utils.config.init_hydra_and_compose_config`.

    Returns:
        A 4-tuple `(datamodule, train_dataloader, val_dataloader, test_dataloader)`.
    """
    logger = qut01.utils.logging.setup_logging_for_analysis_script()

    # Log the data config that the overrides actually select (e.g. the llama variant) instead of a
    # fixed name; keep the previously hard-coded name when no `data=...` override is present.
    data_config_name = "sentence_sampler.yaml"
    for override in config_overrides or []:
        key, separator, value = str(override).partition("=")
        if separator and key.strip().lstrip("+") == "data":
            data_config_name = value.strip().strip("\"'")
    logger.info(f"initializing hydra and fetching data config for '{data_config_name}'...")

    config = qut01.utils.config.init_hydra_and_compose_config(overrides=config_overrides)
    logger.info("initialization complete!")

    logger.info(f"Instantiating datamodule: {config.data.datamodule._target_}")  # noqa
    datamodule: pl.LightningDataModule = hydra.utils.instantiate(config.data.datamodule)
    assert isinstance(datamodule, pl.LightningDataModule), f"unexpected type: {type(datamodule)}"
    logger.info("running 'datamodule.prepare_data()'...")
    datamodule.prepare_data()
    logger.info("running 'datamodule.setup()'...")
    # NOTE(review): only the "fit" stage is set up before `test_dataloader()` is requested below;
    # confirm that the datamodule also prepares its test split in that stage.
    datamodule.setup(stage="fit")
    return datamodule, datamodule.train_dataloader(), datamodule.val_dataloader(), datamodule.test_dataloader()

In [None]:
def get_token_count(dataloader):
    """Collect the token count of every non-None batch yielded by a dataloader.

    Each batch is expected to be a mapping whose "sentence_token_ids" entry exposes a `.shape`
    with the batch size first and the token count second. Batches equal to `None` are skipped
    and their number is printed once the loop is done.

    Args:
        dataloader: iterable of batches (or `None` placeholders).

    Returns:
        A list with one token count (`shape[1]`) per non-None batch, in iteration order.

    Raises:
        AssertionError: if a batch holds more than one sentence.
    """
    none_batch_count = 0
    token_id_counts = []

    for batch in tqdm.tqdm(dataloader):
        if batch is None:
            none_batch_count += 1
            continue
        batch_tokens = batch["sentence_token_ids"]
        # With more than one sentence per batch, shape[1] would no longer be the length of a
        # single sentence, so fail loudly and say what went wrong.
        assert batch_tokens.shape[0] == 1, (
            "expected exactly one sentence per batch (use 'data.sentence_batch_size=1'), "
            f"got token ids with shape {tuple(batch_tokens.shape)}"
        )
        token_id_counts.append(batch_tokens.shape[1])
    print(f"found {none_batch_count} batches equal to None")
    return token_id_counts


def plot_distributions(
    tokens,
    names,
    bins=100,
):
    """Plot a histogram and a reversed cumulative histogram of token counts for each dataset.

    Two vertically stacked axes are created per entry of `tokens`: the raw histogram, then a
    reversed cumulative, normalized histogram on a logarithmic y axis.

    Args:
        tokens: sequence of token-count collections, one per dataset.
        names: sequence of dataset names; `names[i]` is used to title the plots of `tokens[i]`.
        bins: `bins` argument forwarded to `Axes.hist` for both plots.
    """
    fig, axs = plt.subplots(nrows=len(tokens) * 2, figsize=[10, 16])

    for i, name in enumerate(names):
        # Histogram
        ax = axs[2 * i]
        ax.hist(tokens[i], bins=bins)
        ax.set_title(f"Histogram of token counts for {name}")
        ax.set_xlabel("Token count")
        ax.set_ylabel("Frequency")
        ax.grid(axis="y", alpha=0.75)
        ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))

        # Reversed cumulative histogram (`cumulative=-1`): each bar accumulates the bins at or
        # above it, and `density=True` normalizes it so that the first bar equals 1.
        # The histogram is labelled so that `ax.legend()` below has an entry to display.
        ax = axs[2 * i + 1]
        ax.hist(tokens[i], bins=bins, cumulative=-1, density=True, label=name)
        ax.set_yscale("log")
        ax.set_ylim(1e-4, 1)
        # The plotted values are fractions in [0, 1]; show the ticks as percentages so that they
        # agree with the "%" announced by the y label (set after `set_yscale`, which resets it).
        ax.yaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(lambda y, _pos: f"{100 * y:g}%"))
        ax.set_title(f"CDF of token counts for {name}")
        ax.set_xlabel("Token count")
        ax.set_ylabel("% samples with greater than x tokens")
        ax.grid(True, which="major", alpha=0.75)
        ax.legend()

    plt.tight_layout()
    plt.show()


def get_percentage_of_samples_above_length(tokens, length_threshold) -> float:
    """Return the percentage of token counts that reach or exceed a threshold.

    NOTE(review): the comparison is inclusive (`>=`) although the name says "above"; it is kept
    as-is so that reported numbers do not change. Confirm whether a strict `>` was intended.

    Args:
        tokens: sized collection of token counts.
        length_threshold: threshold each token count is compared against.

    Returns:
        A percentage in [0, 100]; 0.0 when `tokens` is empty (instead of dividing by zero).
    """
    sample_count = len(tokens)
    if sample_count == 0:
        return 0.0
    reached = sum(1 for count in tokens if count >= length_threshold)
    return 100 * reached / sample_count

In [None]:
def analyze_token_counts(config_overrides, cutoffs=None):
    """Run the full token-count analysis for one configuration.

    Builds the dataloaders for the given overrides, prints one tokenized example, gathers the
    token counts of the train/valid/test dataloaders, prints the percentage of samples reaching
    each cutoff, and finally plots the token-count distributions.

    Args:
        config_overrides: list of hydra override strings passed to `get_dataloaders`.
        cutoffs: iterable of token-count cutoffs to report on; defaults to `range(0, 2000, 50)`.
    """
    if cutoffs is None:
        cutoffs = range(0, 2000, 50)

    # Grab the dataloaders for the specified config
    _, train_dl, valid_dl, test_dl = get_dataloaders(config_overrides)

    # Print an example, to make sure we are dealing with the expected tokenizer/vocabulary.
    # Dataloaders may yield `None` batches (see `get_token_count`), so look for the first real one.
    example = next((batch for batch in train_dl if batch is not None), None)
    print()
    print("*** TESTING TOKENIZER WITH EXAMPLE ***")
    if example is None:
        print("no non-None batch found in the train dataloader")
    else:
        print(f"sentence is:\n{example['sentence_orig_text']}")
        print(f"token ids are:\n{example['sentence_token_ids']}")

    split_names = ["train", "valid", "test"]
    split_tokens = [get_token_count(dataloader) for dataloader in (train_dl, valid_dl, test_dl)]

    # Perform cutoff analysis
    print()
    print("*** CUTOFF ANALYSIS ***")
    for cutoff in cutoffs:
        print()
        print(f"Cutoff = {cutoff} tokens")
        for tokens, name in zip(split_tokens, split_names):
            percentage = get_percentage_of_samples_above_length(tokens, cutoff)
            print(f"{name} has {percentage:.2f} % of samples over")

    # Generate distribution plots
    plot_distributions(split_tokens, split_names)

# DistilBERT

## No Context

In [None]:
# DistilBERT tokenizer, target sentence only.
config_overrides = (
    # data config and classification setup
    [
        'data="sentence_sampler.yaml"',
        "data.classif_setup=any",
        "data.num_criteria=11",
        "data.sentence_batch_size=1",  # one sentence per batch, so no padding is added
    ]
    # tokenizer
    + [
        "++data.tokenizer._target_=transformers.AutoTokenizer.from_pretrained",
        "++data.tokenizer.pretrained_model_name_or_path=distilbert-base-uncased",
    ]
    # context: none
    + ["data.context_word_count=0"]
)
analyze_token_counts(config_overrides=config_overrides)

## 300 Words Context

In [None]:
# DistilBERT tokenizer, 300 words of context delimited by [SEP] tokens.
config_overrides = (
    # data config and classification setup
    [
        'data="sentence_sampler.yaml"',
        "data.classif_setup=any",
        "data.num_criteria=11",
        "data.sentence_batch_size=1",  # one sentence per batch, so no padding is added
    ]
    # tokenizer
    + [
        "++data.tokenizer._target_=transformers.AutoTokenizer.from_pretrained",
        "++data.tokenizer.pretrained_model_name_or_path=distilbert-base-uncased",
    ]
    # context size and boundary tokens
    + [
        "data.context_word_count=300",
        "data.left_context_boundary_token='[SEP]'",
        "data.right_context_boundary_token='[SEP]'",
    ]
)
analyze_token_counts(config_overrides=config_overrides)

# BERT

## No Context

In [None]:
# BERT tokenizer, target sentence only.
config_overrides = (
    # data config and classification setup
    [
        'data="sentence_sampler.yaml"',
        "data.classif_setup=any",
        "data.num_criteria=11",
        "data.sentence_batch_size=1",  # one sentence per batch, so no padding is added
    ]
    # tokenizer
    + [
        "++data.tokenizer._target_=transformers.AutoTokenizer.from_pretrained",
        "++data.tokenizer.pretrained_model_name_or_path=bert-base-uncased",
    ]
    # context: none
    + ["data.context_word_count=0"]
)
analyze_token_counts(config_overrides=config_overrides)

## 100 Words Context

In [None]:
# BERT tokenizer, 100 words of context delimited by [SEP] tokens.
config_overrides = (
    # data config and classification setup
    [
        'data="sentence_sampler.yaml"',
        "data.classif_setup=any",
        "data.num_criteria=11",
        "data.sentence_batch_size=1",  # one sentence per batch, so no padding is added
    ]
    # tokenizer
    + [
        "++data.tokenizer._target_=transformers.AutoTokenizer.from_pretrained",
        "++data.tokenizer.pretrained_model_name_or_path=bert-base-uncased",
    ]
    # context size and boundary tokens
    + [
        "data.context_word_count=100",
        "data.left_context_boundary_token='[SEP]'",
        "data.right_context_boundary_token='[SEP]'",
    ]
)
analyze_token_counts(config_overrides=config_overrides)

## 200 Words Context

In [None]:
# BERT tokenizer, 200 words of context delimited by [SEP] tokens.
config_overrides = (
    # data config and classification setup
    [
        'data="sentence_sampler.yaml"',
        "data.classif_setup=any",
        "data.num_criteria=11",
        "data.sentence_batch_size=1",  # one sentence per batch, so no padding is added
    ]
    # tokenizer
    + [
        "++data.tokenizer._target_=transformers.AutoTokenizer.from_pretrained",
        "++data.tokenizer.pretrained_model_name_or_path=bert-base-uncased",
    ]
    # context size and boundary tokens
    + [
        "data.context_word_count=200",
        "data.left_context_boundary_token='[SEP]'",
        "data.right_context_boundary_token='[SEP]'",
    ]
)
analyze_token_counts(config_overrides=config_overrides)

## 300 Words Context

In [None]:
# BERT tokenizer, 300 words of context delimited by [SEP] tokens.
config_overrides = (
    # data config and classification setup
    [
        'data="sentence_sampler.yaml"',
        "data.classif_setup=any",
        "data.num_criteria=11",
        "data.sentence_batch_size=1",  # one sentence per batch, so no padding is added
    ]
    # tokenizer
    + [
        "++data.tokenizer._target_=transformers.AutoTokenizer.from_pretrained",
        "++data.tokenizer.pretrained_model_name_or_path=bert-base-uncased",
    ]
    # context size and boundary tokens
    + [
        "data.context_word_count=300",
        "data.left_context_boundary_token='[SEP]'",
        "data.right_context_boundary_token='[SEP]'",
    ]
)
analyze_token_counts(config_overrides=config_overrides)

# LLAMA 

## No Context

In [None]:
# Llama data config, no context override.
# NOTE(review): unlike the BERT "No Context" cells, `data.context_word_count` is not set here, so
# the value comes from the YAML config — presumably 0; confirm.
config_overrides = [
    'data="sentence_sampler_with_llama_tokenizer.yaml"',  # data config used for this run
    "data.classif_setup=any",
    "data.num_criteria=11",
    "data.sentence_batch_size=1",  # one sentence per batch, so no padding is added
]
analyze_token_counts(config_overrides=config_overrides)

## 100 Words Context

In [None]:
# Llama data config, 100 words of context.
# NOTE(review): no context boundary tokens are overridden here (the BERT cells set '[SEP]');
# whatever the YAML config defines is used — confirm that this is intended.
config_overrides = (
    # data config and classification setup
    [
        'data="sentence_sampler_with_llama_tokenizer.yaml"',
        "data.classif_setup=any",
        "data.num_criteria=11",
        "data.sentence_batch_size=1",  # one sentence per batch, so no padding is added
    ]
    # context size
    + ["data.context_word_count=100"]
)
analyze_token_counts(config_overrides=config_overrides)

## 200 Words Context

In [None]:
# Llama data config, 200 words of context.
# NOTE(review): no context boundary tokens are overridden here (the BERT cells set '[SEP]');
# whatever the YAML config defines is used — confirm that this is intended.
config_overrides = (
    # data config and classification setup
    [
        'data="sentence_sampler_with_llama_tokenizer.yaml"',
        "data.classif_setup=any",
        "data.num_criteria=11",
        "data.sentence_batch_size=1",  # one sentence per batch, so no padding is added
    ]
    # context size
    + ["data.context_word_count=200"]
)
analyze_token_counts(config_overrides=config_overrides)

## 300 Words Context

In [None]:
# Llama data config, 300 words of context.
# NOTE(review): no context boundary tokens are overridden here (the BERT cells set '[SEP]');
# whatever the YAML config defines is used — confirm that this is intended.
config_overrides = (
    # data config and classification setup
    [
        'data="sentence_sampler_with_llama_tokenizer.yaml"',
        "data.classif_setup=any",
        "data.num_criteria=11",
        "data.sentence_batch_size=1",  # one sentence per batch, so no padding is added
    ]
    # context size
    + ["data.context_word_count=300"]
)
analyze_token_counts(config_overrides=config_overrides)