In [39]:
import torch
from transformers import AutoConfig
from accelerate.commands.estimate import estimate_command_parser, gather_data

# Assumed fixed memory overhead per GPU, in bytes (presumably CUDA context /
# allocator overhead -- TODO confirm). NOTE: 10**9 is a decimal GB, while the
# functions below report sizes in 1024-based units.
DEFAULT_PYTORCH_USAGE_PER_GPU = 10**9

# dtype names passed to accelerate's estimator; not referenced by the
# functions in this cell.
dtypes = ["float32", "float16", "int8", "int4"]

# Reusable accelerate CLI parser; get_model_size feeds it argv-style lists.
parser = estimate_command_parser()

def get_model_size(model_id, dtype):
  """Return the weight size of `model_id` at `dtype` using accelerate's estimator.

  Builds argv-style arguments `[model_id, "--dtypes", dtype]` for the
  module-level `parser`, runs `gather_data`, and reads index 2 (the size
  column) of the first returned row.

  Returns:
    dict with the dtype and the size in bytes, MiB (1024**2) and GiB (1024**3).
  """
  estimate_rows = gather_data(parser.parse_args([model_id, "--dtypes", dtype]))
  size_bytes = estimate_rows[0][2]
  bytes_per_mib = 1024 * 1024
  bytes_per_gib = bytes_per_mib * 1024
  return {
    "dtype": dtype,
    "model_size_in_bytes": size_bytes,
    "model_size_in_megabytes": size_bytes / bytes_per_mib,
    "model_size_in_gigabytes": size_bytes / bytes_per_gib,
  }

def get_tgi_memory(model_id, max_prefill_tokens, dtype, config=None):
  """Estimate memory for the attention-score tensor of one prefill pass.

  Counts max_prefill_tokens**2 * num_attention_heads elements (one
  tokens x tokens score matrix per head) at the element width of `dtype`.
  Nothing else (weights, KV cache, other activations) is included.

  Args:
    model_id: Hugging Face Hub model id; only used to load the config when
      `config` is not supplied.
    max_prefill_tokens: number of tokens in the prefill sequence.
    dtype: "float32", "float16", "int8", "int4", or any other torch dtype
      name (resolved through torch).
    config: optional already-loaded model config exposing
      `num_attention_heads`; pass it to skip `AutoConfig.from_pretrained`.

  Returns:
    dict with the dtype and the size in bytes, MiB (1024**2) and GiB (1024**3).

  Raises:
    AttributeError: if `dtype` is neither in the table below nor a torch
      dtype name.
  """
  # Explicit widths for the dtypes this notebook uses. `torch.int4` does not
  # exist in many torch releases, so `getattr(torch, "int4")` would raise;
  # int4 is sized as half a byte, consistent with accelerate's estimator.
  bytes_per_element = {"float32": 4, "float16": 2, "int8": 1, "int4": 0.5}
  if dtype in bytes_per_element:
    element_size = bytes_per_element[dtype]
  else:
    element_size = getattr(torch, dtype).itemsize
  # NOTE(review): with int8/int4 quantized weights, attention activations are
  # often still 16-bit; confirm `dtype` is the right width for this tensor.
  if config is None:
    config = AutoConfig.from_pretrained(model_id)
  num_elements = (max_prefill_tokens ** 2) * config.num_attention_heads
  memory = num_elements * element_size
  return {
    "dtype": dtype,
    "memory_in_bytes": memory,
    "memory_in_megabytes": memory/(1024*1024),
    "memory_in_gigabytes": memory/(1024*1024*1024),
  }
  
def get_real_memory(model_id, max_prefill_tokens, dtype, num_gpus=1, buffer_factor=1.2):
  """Estimate total and per-GPU memory to serve `model_id`.

  total = (weights + prefill attention scores
           + num_gpus * DEFAULT_PYTORCH_USAGE_PER_GPU) * buffer_factor

  Weights come from `get_model_size`, the prefill term from `get_tgi_memory`.
  The per-GPU figures assume the total is split evenly across `num_gpus`.

  Args:
    model_id: Hugging Face Hub model id.
    max_prefill_tokens: number of tokens in the prefill sequence.
    dtype: dtype name forwarded to both estimators.
    num_gpus: number of GPUs the model is spread across; must be >= 1.
    buffer_factor: multiplicative safety margin on the sum (default 1.2,
      i.e. 20% headroom).

  Returns:
    dict with the dtype and the total and per-GPU sizes in bytes,
    MiB (1024**2) and GiB (1024**3).

  Raises:
    ValueError: if `num_gpus` < 1 (checked before any model lookup).
  """
  if num_gpus < 1:
    raise ValueError(f"num_gpus must be >= 1, got {num_gpus}")
  model_size = get_model_size(model_id, dtype)
  tgi_memory = get_tgi_memory(model_id, max_prefill_tokens, dtype)
  per_gpu_overhead = num_gpus * DEFAULT_PYTORCH_USAGE_PER_GPU
  real_memory_with_buffer = (
    model_size["model_size_in_bytes"] + tgi_memory["memory_in_bytes"] + per_gpu_overhead
  ) * buffer_factor
  return {
    "dtype": dtype,
    "real_memory_in_bytes": real_memory_with_buffer,
    "real_memory_in_megabytes": real_memory_with_buffer/(1024*1024),
    "real_memory_in_gigabytes": real_memory_with_buffer/(1024*1024*1024),
    "real_memory_per_gpu_in_bytes": real_memory_with_buffer/num_gpus,
    "real_memory_per_gpu_in_megabytes": real_memory_with_buffer/(1024*1024*num_gpus),
    "real_memory_per_gpu_in_gigabytes": real_memory_with_buffer/(1024*1024*1024*num_gpus),
  }

In [40]:
get_tgi_memory("meta-llama/Llama-2-7b-chat-hf", max_prefill_tokens=8192, dtype="float16")

4096


{'dtype': 'float16',
 'memory_in_bytes': 4294967296,
 'memory_in_megabytes': 4096.0,
 'memory_in_gigabytes': 4.0}

In [41]:
# NOTE(review): identical to the In [40] call; this cell is a duplicate and could be removed.
get_tgi_memory("meta-llama/Llama-2-7b-chat-hf", max_prefill_tokens=8192, dtype="float16")

4096


{'dtype': 'float16',
 'memory_in_bytes': 4294967296,
 'memory_in_megabytes': 4096.0,
 'memory_in_gigabytes': 4.0}

In [42]:
get_model_size("meta-llama/Llama-2-7b-chat-hf", dtype="float16")

Loading pretrained config for `meta-llama/Llama-2-7b-chat-hf` from `transformers`...


{'dtype': 'float16',
 'model_size_in_bytes': 13281800192.0,
 'model_size_in_megabytes': 12666.51171875,
 'model_size_in_gigabytes': 12.369640350341797}

In [43]:
get_real_memory("meta-llama/Llama-2-7b-chat-hf", max_prefill_tokens=8192, dtype="float16")

Loading pretrained config for `meta-llama/Llama-2-7b-chat-hf` from `transformers`...
4096


{'dtype': 'float16',
 'real_memory_in_bytes': 22292120985.6,
 'real_memory_in_megabytes': 21259.4232421875,
 'real_memory_in_gigabytes': 20.76115550994873,
 'real_memory_per_gpu_in_bytes': 22292120985.6,
 'real_memory_per_gpu_in_megabytes': 21259.4232421875,
 'real_memory_per_gpu_in_gigabytes': 20.76115550994873}

In [44]:
get_real_memory("meta-llama/Llama-2-70b-chat-hf", max_prefill_tokens=8192, dtype="float16")

Loading pretrained config for `meta-llama/Llama-2-70b-chat-hf` from `transformers`...
8192


{'dtype': 'float16',
 'real_memory_in_bytes': 176503274496.0,
 'real_memory_in_megabytes': 168326.6396484375,
 'real_memory_in_gigabytes': 164.38148403167725,
 'real_memory_per_gpu_in_bytes': 176503274496.0,
 'real_memory_per_gpu_in_megabytes': 168326.6396484375,
 'real_memory_per_gpu_in_gigabytes': 164.38148403167725}

In [45]:
get_real_memory("meta-llama/Llama-2-70b-chat-hf", max_prefill_tokens=8192, dtype="float16", num_gpus=4)

Loading pretrained config for `meta-llama/Llama-2-70b-chat-hf` from `transformers`...
8192


{'dtype': 'float16',
 'real_memory_in_bytes': 180103274496.0,
 'real_memory_in_megabytes': 171759.8671875,
 'real_memory_in_gigabytes': 167.73424530029297,
 'real_memory_per_gpu_in_bytes': 45025818624.0,
 'real_memory_per_gpu_in_megabytes': 42939.966796875,
 'real_memory_per_gpu_in_gigabytes': 41.93356132507324}