In [5]:
import os

# Set Hugging Face caching paths
os.environ["HF_HOME"] = "/data/models"
os.environ["TRANSFORMERS_CACHE"] = "/data/models"
os.environ["HF_DATASETS_CACHE"] = "/data/models/datasets"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "false"
os.environ["HF_HUB_DISABLE_XET"] = "True"

# Auto-create directories if they don't exist
os.makedirs("/data/models", exist_ok=True)
os.makedirs("/data/models/datasets", exist_ok=True)

In [6]:
from transformers import AutoConfig

model_name = "Qwen/Qwen3-VL-235B-A22B-Thinking" # @param {type:"string"}
model_config = AutoConfig.from_pretrained(model_name)

# hidden_layers = model_config.num_hidden_layers
# hidden_size =  model_config.hidden_size
# attention_heads = model_config.num_attention_heads
# kv_heads = 0
# if hasattr(model_config, "num_key_value_heads"):
#   kv_heads = model_config.num_key_value_heads

# print("Model: "+str(model_name))
# print("Hidden layers (L): "+str(hidden_layers))
# print("Hidden size (h): "+str(hidden_size))
# print("Attention heads (a): "+str(attention_heads))
# if kv_heads > 0:
#   print("Key-value heads (g): "+str(kv_heads))

MoE MODEL  

In [7]:
hidden_layers = model_config.text_config.num_hidden_layers
hidden_size =  model_config.text_config.hidden_size
attention_heads = model_config.text_config.num_attention_heads
kv_heads = 0
if hasattr(model_config.text_config, "num_key_value_heads"):
  kv_heads = model_config.text_config.num_key_value_heads
print("Model: "+str(model_name))
print("Hidden layers (L): "+str(hidden_layers))
print("Hidden size (h): "+str(hidden_size))
print("Attention heads (a): "+str(attention_heads))
if kv_heads > 0:
  print("Key-value heads (g): "+str(kv_heads))

Model: Qwen/Qwen3-VL-235B-A22B-Thinking
Hidden layers (L): 94
Hidden size (h): 4096
Attention heads (a): 64
Key-value heads (g): 4


In [8]:
#Number of parameters in the model (in billions)
nb_billion_parameters = 235 # @param {type:"number"}
print("Number of parameters in the model (n): "+str(nb_billion_parameters)+"B")

#Precision of the parameters in the model
bitwidth_model = 16 # @param {type:"integer"}
print("Bitwidth of the model's parameters (p): "+str(bitwidth_model)+"-bit")

#The maximum number of tokens in a sequence
seqlen = 256000 # @param {type:"integer"}
print("Sequence length (s): "+str(seqlen))

#The batch size
batch_size = 16 # @param {type:"integer"}
print("Batch size (b): "+str(batch_size))



#Use FlashAttention
Flash_Attention = True # @param {type:"boolean"}
tile_size = 128
if Flash_Attention:
  print("Use FlashAttention: Yes")
else:
  print("Use FlashAttention: No")

#Use a KV cache (if yes, should be equal to the seqlen)
Use_KV_Cache = True # @param {type:"boolean"}
kv_cache = 0
if Use_KV_Cache:
  print("Use a KV cache: Yes")
  kv_cache = seqlen
else:
  print("Use a KV cache: No")

Number of parameters in the model (n): 235B
Bitwidth of the model's parameters (p): 16-bit
Sequence length (s): 256000
Batch size (b): 16
Use FlashAttention: Yes
Use a KV cache: Yes


In [9]:
def estimate_consumption_inference():
  return round((32*seqlen*batch_size*hidden_size + 4*attention_heads*seqlen*seqlen*batch_size)*2/(1000**3),2)
def estimate_consumption_inference_gqa():
  return round((28*seqlen*batch_size*hidden_size + ((2*kv_heads)/attention_heads)*seqlen*batch_size*hidden_size + 4*kv_heads*seqlen*seqlen*batch_size)*2/(1000**3),2)

def estimate_consumption_inference_FA(): #Ignoring GQA for simplicity; will be slightly lower with GQA
  return round((32*seqlen*batch_size*hidden_size + 4*tile_size*seqlen*batch_size)*2/(1000**3),2)


def kv_cache():
  return round(2*hidden_layers*seqlen*batch_size*hidden_size*2/(1000**3),2)

def kv_cache_gqa():
  return round(2*hidden_layers*seqlen*batch_size*(hidden_size/kv_heads)*2/(1000**3),2)

def estimate_model_size():
  return round(nb_billion_parameters*bitwidth_model/8*(1000**3)/(1000**3),2)


activation_consumption_inference = estimate_consumption_inference()
activation_consumption_inference_gqa = estimate_consumption_inference_gqa()
activation_consumption_inference_FA = estimate_consumption_inference_FA()
model_consumption = estimate_model_size()

print("Memory consumption of the model: "+str(model_consumption)+" GB\n")

print("Memory consumption of vanilla inference: "+str(activation_consumption_inference)+" GB \n")
if kv_heads > 0:
  print("Memory consumption of inference with GQA: "+str(activation_consumption_inference_gqa)+" GB \n")

print("Memory consumption of inference with FlashAttention: "+str(activation_consumption_inference_FA)+" GB \n")

if Use_KV_Cache:
  if kv_heads > 0:
    kv_cache_cost = kv_cache_gqa()
    print("Memory consumption of the KV cache (with GQA): "+str(kv_cache_cost)+" GB \n")
  else:
    kv_cache_cost = kv_cache()
    print("Memory consumption of the KV cache: "+str(kv_cache_cost)+" GB \n")
else:
  kv_cache_cost = 0

if Flash_Attention:
  print("Total Memory consumption (given the selected configuration): "+str(round(model_consumption+kv_cache_cost+activation_consumption_inference_FA,2))+" GB\n")
elif kv_heads > 0:
  print("Total Memory consumption (given the selected configuration): "+str(round(model_consumption+kv_cache_cost+activation_consumption_inference_gqa,2))+" GB\n")
else:
  print("Total Memory consumption (given the selected configuration): "+str(round(model_consumption+kv_cache_cost+activation_consumption_inference,2))+" GB\n")

Memory consumption of the model: 470.0 GB

Memory consumption of vanilla inference: 537944.65 GB 

Memory consumption of inference with GQA: 34498.15 GB 

Memory consumption of inference with FlashAttention: 1077.94 GB 

Memory consumption of the KV cache (with GQA): 1577.06 GB 

Total Memory consumption (given the selected configuration): 3125.0 GB

