In [111]:
from transformers import AutoConfig, AutoModelForCausalLM
import os
from accelerate import infer_auto_device_map, init_empty_weights

[2023-08-17 09:51:34,443] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [101]:
model_name_or_path = "/data/tongyx361/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/6fdf2e60f86ff2481f2241aaee459f85b5b0bbb9"
model_config_path = os.path.join(model_name_or_path, "config.json")

In [113]:
model_config = AutoConfig.from_pretrained(model_config_path)
print(model_config)

with init_empty_weights():
    model = AutoModelForCausalLM.from_config(model_config)

total_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
print(f"total_params = {total_params}")

for name, param in model.named_parameters():
    print(name, param.shape)


device_map = infer_auto_device_map(model)
print(device_map)

LlamaConfig {
  "_name_or_path": "/data/tongyx361/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/6fdf2e60f86ff2481f2241aaee459f85b5b0bbb9/config.json",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 4096,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.31.0",
  "use_cache": true,
  "vocab_size": 32000
}

total_params = 131072000
model.embed_tokens.weight torch.Size([32000, 4096])
model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096])
model.layers.0.self_attn.k_proj.weight torch.Size([4096, 4096])
model.layers.0.self_attn.v_proj.weight torch.S

In [109]:
h = hidden_size = model_config.hidden_size
# intermediate_size = model_config.intermediate_size
intermediate_size = 4 * h
V = vocab_size = model_config.vocab_size
l = num_hidden_layers = model_config.num_hidden_layers
assert model_config.num_attention_heads == model_config.num_key_value_heads
a = model_config.num_attention_heads

num_gpus = 4

b = batch_size_per_gpu = 2
# s = max_seq_len = 4096
s = max_seq_len = 1024

# b = batch_size_per_gpu = 20
# b = batch_size_per_gpu = 3
# s = max_seq_len = 2048

# num_gpus = 8

In [103]:
num_params_per_self_attention = 4 * (h**2 + h)
num_params_per_ffn = h * intermediate_size * 2 + intermediate_size + h
num_params_per_layer_norm = 2 * h
num_params_per_transformer_layer = (
    num_params_per_self_attention + num_params_per_ffn + 2 * num_params_per_layer_norm
)

num_params_token_embedding_matrix = vocab_size * hidden_size

num_params_transformer_model = (
    num_params_per_transformer_layer * num_hidden_layers
    + num_params_token_embedding_matrix
)


print("Number of parameters per self-attention layer:", num_params_per_self_attention)
print("Number of parameters per feed-forward network:", num_params_per_ffn)
print("Number of parameters per transformer layer:", num_params_per_transformer_layer)
print(
    "Number of parameters in the token embedding matrix:",
    num_params_token_embedding_matrix,
)
print("Number of parameters in the transformer model:", num_params_transformer_model)

Number of parameters per self-attention layer: 67125248
Number of parameters per feed-forward network: 134238208
Number of parameters per transformer layer: 201379840
Number of parameters in the token embedding matrix: 131072000
Number of parameters in the transformer model: 6575226880


In [104]:
print(4096**2)
print(12 * 4096**2)
print(12 * 32 * 4096**2)

16777216
201326592
6442450944


In [105]:
# mixed precision
mem_params_mixed_precision = num_params_transformer_model * (2 + 4)
mem_grads_mixed_precision = num_params_transformer_model * (2 + 4)
mem_adam_states = num_params_transformer_model * (4 + 4)
print(
    sum([mem_params_mixed_precision, mem_grads_mixed_precision, mem_adam_states])
    / (1024**3)
)

122.47314453125


In [106]:
mem_zero3_train_per_gpu = (
    sum([mem_params_mixed_precision, mem_grads_mixed_precision, mem_adam_states])
    / num_gpus
)

mem_zero3_train_per_gpu_gb = mem_zero3_train_per_gpu / (1024**3)
print(f"mem_zero3_train_per_gpu_gb = {mem_zero3_train_per_gpu_gb:.2f} GB")

mem_zero3_train_per_gpu_gb = 30.62 GB


In [107]:
mem_activation_per_transformer_layer_per_gpu = 34 * b * s * h + 5 * b * s**2 * a
mem_activation_transformer_model_per_gpu = (
    num_hidden_layers * mem_activation_per_transformer_layer_per_gpu
)
gb_mem_activation_transformer_model_per_gpu = (
    mem_activation_transformer_model_per_gpu / (1024**3)
)
print(
    f"gb_mem_activation_transformer_model_per_gpu = {gb_mem_activation_transformer_model_per_gpu:.2f} GB"
)

gb_mem_activation_transformer_model_per_gpu = 18.50 GB
