# Modern Model Study

In [2]:
from transformers import AutoModel

# Checkpoint identifier handed to `from_pretrained`; kept in one place so it is easy to swap.
CHECKPOINT_ID = "google/gemma-3-4b-pt"
gemma4b = AutoModel.from_pretrained(CHECKPOINT_ID)

def count_model_parameters(model, is_human: bool = False, trainable_only: bool = True):
    """Count the parameters of ``model``.

    Args:
        model: Object exposing ``parameters()`` whose items provide ``numel()``
            and ``requires_grad`` (e.g. a ``torch.nn.Module``).
        is_human: If true, return the count as a string expressed in millions
            with two decimals (e.g. ``"1.50M"``); otherwise return the raw int.
        trainable_only: If true (default, the historical behavior), parameters
            whose ``requires_grad`` is falsy are excluded. Pass ``False`` to
            get the true total including frozen parameters.

    Returns:
        ``str`` when ``is_human`` is true, else ``int``.
    """
    params = sum(
        p.numel()
        for p in model.parameters()
        # Short-circuit so `requires_grad` is only consulted when filtering.
        if not trainable_only or p.requires_grad
    )
    return f"{params / 1e6:.2f}M" if is_human else params

def show_param_count(layer_name, num_params):
    """Build the report line for one layer.

    Args:
        layer_name: Human-readable label inserted into the message.
        num_params: Count (or pre-formatted string) to display.

    Returns:
        The string ``"Number of <layer_name> Parameters: <num_params>"``.
    """
    return f"Number of {layer_name} Parameters: {num_params}"

Loading checkpoint shards: 100%|██████████| 2/2 [00:37<00:00, 18.84s/it]


In [3]:
# Dump the module tree so the submodule names used in the cells below can be looked up.
print(gemma4b)

Gemma3Model(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(4096, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (self_attn): SiglipAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_fn): 

## Vision Encoder Parameter Count

In [4]:
# Shorthand for the vision transformer nested inside the vision tower;
# the following cells read its `embeddings` and `encoder.layers`.
vision_encoder = gemma4b.vision_tower.vision_model

### Embedding Layer Count

In [8]:
# Report each embedding sub-layer of the vision encoder on its own line.
vision_embeddings = vision_encoder.embeddings
for label, submodule in (
    ("Patch Embedding", vision_embeddings.patch_embedding),
    ("Positional Embedding", vision_embeddings.position_embedding),
):
    print(show_param_count(label, count_model_parameters(submodule, is_human=False)))

Number of Patch Embedding Parameters: 678528
Number of Positional Embedding Parameters: 4718592


### Transformer Layer Count

In [12]:
# Break down the first encoder layer of the vision encoder, then report the layer as a whole.
first_vision_block = vision_encoder.encoder.layers[0]
vision_block_parts = [
    ("Pre-Attention LayerNorm", first_vision_block.layer_norm1),
    ("Self-Multi-head Attention", first_vision_block.self_attn),
    ("Pre-MLP LayerNorm", first_vision_block.layer_norm2),
    ("MLP", first_vision_block.mlp),
    ("Transformer Block", first_vision_block),
]
for label, submodule in vision_block_parts:
    print(show_param_count(label, count_model_parameters(submodule, is_human=False)))

Number of Pre-Attention LayerNorm Parameters: 2304
Number of Self-Multi-head Attention Parameters: 5313024
Number of Pre-MLP LayerNorm Parameters: 2304
Number of MLP Parameters: 9921872
Number of Transformer Block Parameters: 15239504


### Total Vision Encoder Count

In [13]:
# The count covers the entire vision encoder module, so the label must say so
# (it previously read "Positional Embedding", copied from an earlier cell).
print(show_param_count("Vision Encoder", count_model_parameters(vision_encoder, is_human=False)))

Number of Positional Embedding Parameters: 416866032


## Language Model Count

In [14]:
# Shorthand for the language-model submodule; the following cells read its
# `embed_tokens` and `layers`.
lang_model = gemma4b.language_model

### Word Embedding Count

In [None]:
# Parameter count of the token embedding table of the language model.
word_embedding_params = count_model_parameters(lang_model.embed_tokens, is_human=False)
print(show_param_count("Word Embedding", word_embedding_params))

Number of Word Embedding Parameters: 671252480


### Transformer Layer Count

In [16]:
# Break down the first layer of the language model, then report the layer as a whole.
first_lang_block = lang_model.layers[0]
lang_block_parts = [
    ("Pre-Attention RMSNorm", first_lang_block.input_layernorm),
    ("Self-Grouped-Query Attention", first_lang_block.self_attn),
    ("Post-Attention RMSNorm", first_lang_block.post_attention_layernorm),
    ("Pre-MLP RMSNorm", first_lang_block.pre_feedforward_layernorm),
    ("MLP", first_lang_block.mlp),
    ("Post-MLP RMSNorm", first_lang_block.post_feedforward_layernorm),
    ("Transformer Block", first_lang_block),
]
for label, submodule in lang_block_parts:
    print(show_param_count(label, count_model_parameters(submodule, is_human=False)))

Number of Pre-Attention RMSNorm Parameters: 2560
Number of Self-Grouped-Query Attention Parameters: 15729152
Number of Post-Attention RMSNorm Parameters: 2560
Number of Pre-MLP RMSNorm Parameters: 2560
Number of MLP Parameters: 78643200
Number of Post-MLP RMSNorm Parameters: 2560
Number of Transformer Block Parameters: 94382592


### Total Language Model Count

In [17]:
# Count over the whole language-model submodule.
lang_model_params = count_model_parameters(lang_model, is_human=False)
print(show_param_count("Language Model", lang_model_params))

Number of Language Model Parameters: 3880263168


## Multimodal Layer Count

In [21]:
# List every named parameter of the multimodal projector with its shape and element count.
projector = gemma4b.multi_modal_projector
for param_name, tensor in projector.named_parameters():
    print(f"Shape of {param_name} Layer: ", tensor.shape)
    print(f"Number of {param_name} Parameter: ", tensor.numel())

Shape of mm_input_projection_weight Layer:  torch.Size([1152, 2560])
Number of mm_input_projection_weight Parameter:  2949120
Shape of mm_soft_emb_norm.weight Layer:  torch.Size([1152])
Number of mm_soft_emb_norm.weight Parameter:  1152


## Total Count

In [22]:
# Count over the full model object loaded at the top of the notebook.
gemma4b_params = count_model_parameters(gemma4b, is_human=False)
print(show_param_count("Gemma3 4B", gemma4b_params))

Number of Gemma3 4B Parameters: 4300079472
