In [None]:
import csv
import gc
from collections import Counter

import numpy as np
import safetensors
import torch
from tqdm import tqdm

### Notes on Model Parameter Files

- Only safetensors files are accepted.
- Because the parameters are processed with NumPy, they are loaded into CPU RAM (not GPU memory).
- After processing each path, the memory will be cleared to free up RAM, but the contents of a single file cannot be read in chunks. Therefore, each file should fit into RAM.
    - As a rule of thumb, if a single safetensors file is smaller than 5GB, having around 25GB of RAM should prevent memory shortages.
    - When running on [Google Colaboratory](https://colab.google/), it is recommended to choose the High-RAM option in the Pro plan or higher. Note that High-RAM is not available in the Pay As You Go plan.


In [None]:
# Paths of the safetensors files to analyze; replace the placeholder with real path(s).
model_path = ["{your_model_path}"]

# Destination of the CSV file that receives the digit counts; replace the placeholder.
save_filepath = "{your_save_filepath}"

# Echo the configuration so the values in use are visible in the cell output.
print(f"{model_path=}")
print(f"{save_filepath=}")

In [None]:
def first_digit(x):
    """
    Extracts the first digit of a given number.

    The digit is taken from the shortest decimal representation that
    round-trips to the same float, so values just below a power of ten
    (e.g. 9.99999999996) are not rounded up to a leading 1.

    Parameters
    ----------
    x : float
        The number from which to extract the first digit.

    Returns
    -------
    int
        The first (most significant) digit of the absolute value of `x`,
        or 0 when `x` is zero.

    Raises
    ------
    ValueError
        If `x` is NaN or infinite, which have no leading digit.
    """
    if x == 0:
        return 0

    # Convert to a built-in float first: the repr of NumPy scalars may embed
    # the type name (e.g. "np.float64(0.5)"), which would pollute the digits.
    magnitude = float(abs(x))

    # repr() yields forms such as "0.00987", "123.45" or "1e-05"; dropping the
    # leading zeros and the decimal point leaves the most significant digit first.
    lead = repr(magnitude).lstrip("0.")[:1]
    if not lead.isdigit():
        raise ValueError(f"first_digit is undefined for non-finite value: {x!r}")
    return int(lead)


# Element-wise wrapper around `first_digit` (one input, one output) so it can be
# applied to a whole NumPy array in a single call.
vectorized_first_digit = np.frompyfunc(first_digit, 1, 1)


def process_models(path_list):
    """
    Processes the parameters of all models to analyze the distribution of their first digits.

    Tensors are read and counted one at a time, and only the running digit
    counts are kept, so peak memory is bounded by the largest single tensor
    rather than by the total number of parameters across all files.

    Zero-valued and non-finite (NaN / infinite) parameters are skipped because
    they have no leading digit.

    Parameters
    ----------
    path_list : list of str
        A list of file paths to the models in safetensors format.

    Returns
    -------
    collections.Counter
        A Counter object containing the frequency of each first digit.
    """
    digit_counts = Counter()

    # Outer bar: one step per model file
    with tqdm(total=len(path_list), desc="Processing models") as pbar:
        for path in path_list:
            # Load model parameters in safetensors format
            with safetensors.safe_open(path, framework="pt") as f:
                keys = list(f.keys())
                # Inner bar: one step per tensor; `with` closes it even if a read fails
                with tqdm(
                    total=len(keys), desc=f"Processing {path}", leave=False
                ) as file_pbar:
                    for k in keys:
                        tensor = f.get_tensor(k)
                        # Convert from BFloat16 to Float32 (NumPy has no bfloat16)
                        if tensor.dtype == torch.bfloat16:
                            tensor = tensor.to(torch.float32)
                        values = tensor.numpy().ravel()

                        # Exclude zero and non-finite parameters
                        values = values[(values != 0) & np.isfinite(values)]

                        # Fold this tensor's digits into the running counts
                        # instead of storing one list entry per parameter
                        digit_counts.update(vectorized_first_digit(values))

                        # Release memory before loading the next tensor
                        del tensor
                        del values
                        file_pbar.update(1)

            gc.collect()  # Perform garbage collection once per file
            pbar.update(1)

    return digit_counts


def save_counter_to_csv(counter, filename):
    """
    Writes the digit frequencies held in a Counter to a CSV file.

    The file starts with a ``Digit,Frequency`` header row, followed by one row
    per Counter entry in the Counter's own iteration order. A confirmation
    message is printed once the file has been written.

    Parameters
    ----------
    counter : collections.Counter
        A Counter object containing the frequency of each digit.
    filename : str
        The name of the file to save the CSV data to.

    Returns
    -------
    None
    """
    header = ["Digit", "Frequency"]
    rows = ([digit, frequency] for digit, frequency in counter.items())

    # newline="" lets the csv module control line endings itself
    with open(filename, mode="w", newline="") as csv_file:
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(header)
        csv_writer.writerows(rows)

    print(f"Saved to {filename}")

In [None]:
# Count the leading digits of the parameters in the configured model files,
# then write the resulting frequency table to the configured CSV path.
counter = process_models(model_path)
save_counter_to_csv(counter, save_filepath)