In [22]:
import os
import re
from typing import Dict, Union

import torch

from balm.config import BalmConfig
from balm.models import BalmForMaskedLM, BalmForSequenceClassification
from balm.models.base import BalmBase



In [2]:
# Model configuration built with BalmConfig's defaults (no overrides are passed);
# shared by both model instantiations below.
config = BalmConfig()


In [3]:
# Masked-LM model from the shared config. Note: it also has a child module named
# `balm` (see the state-dict keys below), so `balm.balm` refers to that submodule.
balm = BalmForMaskedLM(config=config)

In [4]:
# Names of the child modules registered directly on the masked-LM model.
# NOTE(review): `_modules` is private nn.Module state; `balm.named_children()` is the public API.
balm._modules.keys()

odict_keys(['balm', 'lm_head', 'criterion'])

In [5]:
# Sequence-classification variant built from the same config, for comparing module layouts.
balm_classifier = BalmForSequenceClassification(config=config)

In [6]:
# Child module names of the classifier, to compare with the masked-LM model's
# ('classifier' vs 'lm_head'). NOTE(review): `_modules` is private; prefer `named_children()`.
balm_classifier._modules.keys()

odict_keys(['balm', 'classifier', 'criterion'])

In [7]:
# Does the model expose a `.module` attribute? DataParallel / DistributedDataParallel
# wrappers expose the wrapped model that way, so this is presumably a "is it wrapped?" check.
hasattr(balm, "module")

False

In [15]:
# List every state-dict entry name, one per line (iterating a dict yields its keys).
print(*balm.state_dict(), sep="\n")



balm.embed_tokens.weight
balm.layers.0.norm1.weight
balm.layers.0.norm1.bias
balm.layers.0.norm2.weight
balm.layers.0.norm2.bias
balm.layers.0.attention.in_proj_weight
balm.layers.0.attention.in_proj_bias
balm.layers.0.attention.out_proj.weight
balm.layers.0.attention.out_proj.bias
balm.layers.0.feed_forward.0.weight
balm.layers.0.feed_forward.0.bias
balm.layers.0.feed_forward.2.weight
balm.layers.0.feed_forward.2.bias
balm.layers.1.norm1.weight
balm.layers.1.norm1.bias
balm.layers.1.norm2.weight
balm.layers.1.norm2.bias
balm.layers.1.attention.in_proj_weight
balm.layers.1.attention.in_proj_bias
balm.layers.1.attention.out_proj.weight
balm.layers.1.attention.out_proj.bias
balm.layers.1.feed_forward.0.weight
balm.layers.1.feed_forward.0.bias
balm.layers.1.feed_forward.2.weight
balm.layers.1.feed_forward.2.bias
balm.layers.2.norm1.weight
balm.layers.2.norm1.bias
balm.layers.2.norm2.weight
balm.layers.2.norm2.bias
balm.layers.2.attention.in_proj_weight
balm.layers.2.attention.in_proj_bias

In [23]:
def dtype_byte_size(dtype: torch.dtype) -> Union[int, float]:
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`.

    The bit width is parsed from the first run of digits in `str(dtype)`. This means
    composite names such as `torch.float8_e4m3fn` resolve to 8 bits rather than being
    rejected.

    Whole-byte widths return an `int`. Sub-byte widths (e.g. 4-bit types) return a
    fractional `float` instead of being truncated to 0. `torch.bool` is counted as 1/8 byte.

    Example:

    ```py
    >>> dtype_byte_size(torch.float32)
    4
    ```

    Raises:
        ValueError: if no bit width can be parsed from `str(dtype)`.
    """
    if dtype == torch.bool:
        # Counted as a single bit.
        # NOTE(review): torch stores bool tensors with one byte per element
        # (`Tensor.element_size() == 1`), so this underestimates serialized size. Confirm intent.
        return 1 / 8
    # Match the first digit run that follows a non-digit: "32" in "torch.float32",
    # "8" in "torch.float8_e4m3fn". The previous end-anchored pattern (`$`) rejected
    # the float8 names because they end in letters.
    bit_search = re.search(r"[^\d](\d+)", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.group(1))
    if bit_size % 8 == 0:
        return bit_size // 8
    # Sub-byte dtype: keep the fraction instead of floor-dividing to 0.
    return bit_size / 8


def convert_file_size_to_int(size: Union[int, str]):
    """
    Convert a human-readable size such as `"5MB"` into a number of bytes.

    Supported forms:
        * binary (IEC) units `KiB`, `MiB`, `GiB` (case-insensitive), i.e. powers of 1024;
        * decimal (SI) units `KB`, `MB`, `GB` (case-insensitive), i.e. powers of 1000.
          When the string ends in a lowercase `b` (e.g. `"5Gb"`, `"8kb"`), the value is read
          as bits and divided by 8.

    Args:
        size (`int` or `str`): The size to convert. An `int` is returned unchanged.

    Raises:
        ValueError: if the unit is not recognized, or the numeric part is not an integer.

    Example:
    ```py
    >>> convert_file_size_to_int("1MiB")
    1048576
    ```
    """
    if isinstance(size, int):
        return size

    upper_size = size.upper()

    # Binary units are tested first. None of them can be expressed in bits.
    for suffix, factor in (("GIB", 2**30), ("MIB", 2**20), ("KIB", 2**10)):
        if upper_size.endswith(suffix):
            return int(size[: -len(suffix)]) * factor

    for suffix, factor in (("GB", 10**9), ("MB", 10**6), ("KB", 10**3)):
        if not upper_size.endswith(suffix):
            continue
        n_bytes = int(size[: -len(suffix)]) * factor
        # A lowercase trailing "b" (checked on the raw string) means bits.
        if size.endswith("b"):
            return n_bytes // 8
        return n_bytes

    raise ValueError(
        "`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'."
    )

In [24]:
def _storage_key(tensor: torch.Tensor):
    """Return a hashable identifier of the storage backing `tensor`, or None if unavailable.

    Tensors with equal keys share memory (e.g. tied weights, or views of one parameter).
    None is returned when the storage cannot be inspected (older torch without
    `untyped_storage`, or tensors whose data pointer cannot be read). It is also returned
    when the data pointer is 0, i.e. there is no real allocation. Such tensors are
    therefore never grouped together by accident.
    """
    try:
        storage = tensor.untyped_storage()
        data_ptr = storage.data_ptr()
    except (AttributeError, NotImplementedError, RuntimeError):
        return None
    if data_ptr == 0:
        return None
    return (tensor.device, data_ptr, storage.nbytes())


def _shard_file_name(weights_name: str, idx: int, n_shards: int) -> str:
    """Insert a `-XXXXX-of-YYYYY` shard suffix (1-based `idx`) before the extension of `weights_name`."""
    root, ext = os.path.splitext(weights_name)
    return f"{root}-{idx + 1:05d}-of-{n_shards:05d}{ext}"


def shard_checkpoint(
    state_dict: Dict[str, torch.Tensor],
    max_shard_size: Union[int, str] = "10GB",
    weights_name: str = "pytorch_model.bin",
):
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each
    sub-checkpoint does not exceed a given size.

    The sub-checkpoints are determined by iterating through the `state_dict` in the order of
    its keys. No optimization is made to bring each sub-checkpoint as close as possible to
    the maximum size. For example, with a 10GB limit and weights of sizes
    [6GB, 6GB, 2GB, 6GB, 2GB, 2GB], the shards are [6GB], [6+2GB], [6+2+2GB], not
    [6+2+2GB], [6+2GB], [6GB].

    Tensors that share storage with a tensor already placed (e.g. tied embedding/output
    weights) are put in that tensor's shard. This way the storage is serialized once, and
    its size is counted only once in `total_size`.

    <Tip warning={true}>

    If one of the model's weights is bigger than `max_shard_size`, it ends up in its own
    sub-checkpoint, which will be larger than `max_shard_size`. Shared-storage tensors
    added to an earlier shard can also push that shard past the limit.

    </Tip>

    Args:
        state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be
            digits followed by a unit (like `"5MB"`).
        weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
            The name of the model save file. When sharding, a `-XXXXX-of-YYYYY` suffix is
            inserted before its extension (e.g. `pytorch_model-00001-of-00002.bin`).

    Returns:
        `(shards, index)`:
        - `shards` maps file name to sub-state-dict.
        - With a single shard, `shards` is `{weights_name: state_dict_subset}` and `index`
          is None.
        - Otherwise `index` is `{"metadata": {"total_size": ...}, "weight_map": {key: file}}`.
    """
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = [{}]
    last_block_size = 0
    total_size = 0
    # Storage key -> index of the shard that already holds that storage.
    storage_to_shard = {}

    for key, weight in state_dict.items():
        storage_key = _storage_key(weight)
        if storage_key is not None and storage_key in storage_to_shard:
            # Shares memory with an already-placed tensor: keep them together and
            # do not count the bytes a second time.
            sharded_state_dicts[storage_to_shard[storage_key]][key] = weight
            continue

        weight_size = weight.numel() * dtype_byte_size(weight.dtype)

        # Start a new shard only if this weight would overflow the current one and the
        # current one is non-empty (an oversized weight still gets a shard of its own).
        if (
            last_block_size + weight_size > max_shard_size
            and len(sharded_state_dicts[-1]) > 0
        ):
            sharded_state_dicts.append({})
            last_block_size = 0

        sharded_state_dicts[-1][key] = weight
        last_block_size += weight_size
        total_size += weight_size
        if storage_key is not None:
            storage_to_shard[storage_key] = len(sharded_state_dicts) - 1

    # A single shard keeps the plain file name and needs no index.
    if len(sharded_state_dicts) == 1:
        return {weights_name: sharded_state_dicts[0]}, None

    # Otherwise, give every shard a distinct file name and build the key -> file index.
    n_shards = len(sharded_state_dicts)
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = _shard_file_name(weights_name, idx, n_shards)
        shards[shard_file] = shard
        for key in shard:
            weight_map[key] = shard_file

    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    return shards, index

NameError: name 'WEIGHTS_NAME' is not defined

In [21]:
# Byte footprint of each masked-LM state-dict entry, followed by the grand total.
total_size = 0
for param_name, tensor in balm.state_dict().items():
    n_bytes = tensor.numel() * dtype_byte_size(tensor.dtype)
    total_size += n_bytes
    print(f"{param_name}: {n_bytes} bytes")
print(f"Total size: {total_size} bytes")



balm.embed_tokens.weight: 42240 bytes
balm.layers.0.norm1.weight: 1280 bytes
balm.layers.0.norm1.bias: 1280 bytes
balm.layers.0.norm2.weight: 1280 bytes
balm.layers.0.norm2.bias: 1280 bytes
balm.layers.0.attention.in_proj_weight: 1228800 bytes
balm.layers.0.attention.in_proj_bias: 3840 bytes
balm.layers.0.attention.out_proj.weight: 409600 bytes
balm.layers.0.attention.out_proj.bias: 1280 bytes
balm.layers.0.feed_forward.0.weight: 1638400 bytes
balm.layers.0.feed_forward.0.bias: 5120 bytes
balm.layers.0.feed_forward.2.weight: 819200 bytes
balm.layers.0.feed_forward.2.bias: 1280 bytes
balm.layers.1.norm1.weight: 1280 bytes
balm.layers.1.norm1.bias: 1280 bytes
balm.layers.1.norm2.weight: 1280 bytes
balm.layers.1.norm2.bias: 1280 bytes
balm.layers.1.attention.in_proj_weight: 1228800 bytes
balm.layers.1.attention.in_proj_bias: 3840 bytes
balm.layers.1.attention.out_proj.weight: 409600 bytes
balm.layers.1.attention.out_proj.bias: 1280 bytes
balm.layers.1.feed_forward.0.weight: 1638400 bytes
