# Load the model

In [None]:
!pip3 install transformers huggingface_hub[hf_xet] -q

In [1]:
# Target-language code; it is embedded in the Hub repository name built near the end of the notebook.
# Keep exactly one of the assignments below active.
tgt_lang = "de"
# tgt_lang = "zh"

In [None]:
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
import torch

model_cache_dir = "/workspace/model/"

model_name = "Qwen/Qwen2-Audio-7B-Instruct"  # better use a fine-tuned model

# Prefer the GPU when one is available, but fall back to the CPU so the notebook does not
# crash on machines without CUDA (the cells visible here only inspect, prune and upload the model).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the weights in bfloat16, move them to the selected device and switch to eval mode.
model = Qwen2AudioForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    cache_dir=model_cache_dir,
).to(device).eval()

# Use the same cache directory as the model so all downloaded files end up in one place.
# `trust_remote_code` is intentionally not passed: the model above is loaded through a built-in
# class without it, so running repository-supplied code for the processor should not be needed.
processor = AutoProcessor.from_pretrained(model_name, cache_dir=model_cache_dir)

print("Model loaded:", model_name)

In [4]:
def count_parameters(model):
    """Return the total number of elements across all parameters of ``model``.

    Args:
        model: Any object whose ``parameters()`` method yields items that
            expose ``numel()``.

    Returns:
        The sum of ``numel()`` over every yielded parameter (0 when there are none).
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

# Parameter count of the model as loaded (this cell runs before any pruning),
# printed with thousands separators.
model_parameters = count_parameters(model)
print("Model parameters: {:,}".format(model_parameters))

Model parameters: 8,397,094,912


In [None]:
# Snapshot the state dict before pruning so its number of entries can be compared afterwards.
# NOTE(review): the values presumably reference the live weight tensors, which would keep the
# weights of later-removed layers alive; only the key count is used in the visible cells -- confirm
# and consider storing just the keys.
original_state_dict = model.state_dict()
print(f"Number of keys in the original state_dict: {len(original_state_dict)}")

In [None]:
# Display the configuration object attached to the loaded model (before pruning).
model.config

In [7]:
# Layer counts as a tuple: (audio tower / encoder, language model / decoder).
len(model.audio_tower.layers), len(model.language_model.model.layers)

(32, 32)

In [8]:
# Display the audio (encoder) sub-configuration.
model.config.audio_config

Qwen2AudioEncoderConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "attention_dropout": 0.0,
  "d_model": 1280,
  "dropout": 0.0,
  "encoder_attention_heads": 20,
  "encoder_ffn_dim": 5120,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 32,
  "init_std": 0.02,
  "max_source_positions": 1500,
  "model_type": "qwen2_audio_encoder",
  "num_hidden_layers": 32,
  "num_mel_bins": 128,
  "scale_embedding": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.51.3"
}

In [9]:
# Display the text (language model) sub-configuration.
model.config.text_config

Qwen2Config {
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 8192,
  "max_window_layers": 28,
  "model_type": "qwen2",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000,
  "sliding_window": 32768,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.51.3",
  "use_cache": true,
  "use_mrope": false,
  "use_sliding_window": false,
  "vocab_size": 156032
}

# Downscaling

In [None]:
layers_to_remove = sorted([13, 3, 20, 9, 29, 1, 19, 27])  # change based on layer importance evaluation

# Read the decoder depth from the loaded model instead of hard-coding 32, so the selection
# stays correct when a checkpoint with a different number of layers is loaded.
num_decoder_layers = len(model.language_model.model.layers)

# Reject duplicate or out-of-range indices up front. This can also catch an accidental
# re-run after pruning, when the model has fewer layers than the indices assume.
invalid_indices = [i for i in layers_to_remove if not 0 <= i < num_decoder_layers]
if invalid_indices or len(set(layers_to_remove)) != len(layers_to_remove):
    raise ValueError(
        f"layers_to_remove must contain unique indices in [0, {num_decoder_layers}); "
        f"got {layers_to_remove}"
    )

# Indices of the decoder layers that survive pruning, in their original order.
layers_to_keep = [i for i in range(num_decoder_layers) if i not in layers_to_remove]

# Report both counts explicitly (the kept count was previously printed under the label "Pruning").
print(f"Pruning {len(layers_to_remove)} of {num_decoder_layers} decoder layers; "
      f"keeping {len(layers_to_keep)}: {layers_to_keep}")

In [None]:
# Downscaling -- pruning decoder layers only (recommended; see the paper)
# For pruning the encoder too, uncomment the lines marked as "encoder"...
# ... in which case it needs to be part of the layer importance evaluation.

from torch import nn


# audio_layers = model.audio_tower.layers  # encoder
lm_layers = model.language_model.model.layers  # decoder

# Rebuild the layer list from the retained indices only; the new ModuleList is numbered 0..N-1.
# model.audio_tower.layers = nn.ModuleList([audio_layers[n] for n in layers_to_keep_enc])  # encoder
model.language_model.model.layers = nn.ModuleList([lm_layers[n] for n in layers_to_keep])  # decoder

# Renumber the attention modules of the retained decoder layers so that `layer_idx` is
# contiguous again. In transformers the attention module passes its `layer_idx` to the
# key/value cache, so stale pre-pruning indices would leave gaps in the cache (or point past
# its end) during in-session generation, and would differ from what a reloaded checkpoint gets.
for new_idx, layer in enumerate(model.language_model.model.layers):
    attention = getattr(layer, "self_attn", None)
    if hasattr(attention, "layer_idx"):
        attention.layer_idx = new_idx


# Ensure the config reflects the actual number of layers
print("Updating the config...")
print("Current values:",
      model.config.audio_config.encoder_layers,
      model.config.audio_config.num_hidden_layers,
      model.language_model.model.config.num_hidden_layers
      )
# model.config.audio_config.encoder_layers = len(model.audio_tower.layers)  # encoder
# model.config.audio_config.num_hidden_layers = len(model.audio_tower.layers)  # encoder

# Both the top-level text config and the config held by the decoder module are updated,
# so the two stay in agreement.
model.config.text_config.num_hidden_layers = len(model.language_model.model.layers)  # decoder
model.language_model.model.config.num_hidden_layers = len(model.language_model.model.layers)  # decoder


# Check the new number of parameters
model_parameters = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {model_parameters:,}")

# Check the new number of layers
print("New values:" ,
      model.config.audio_config.encoder_layers,
      model.config.audio_config.num_hidden_layers,
      model.language_model.model.config.num_hidden_layers
      )

Updating the config...
Current values: 32 32 32
Model parameters: 6,777,929,728
New values: 32 32 24


In [None]:
# Display the configuration again; the pruning cell above updated text_config.num_hidden_layers.
model.config

### Checks

In [16]:
# Check if both configs have the same values:
# first the top-level text config, then the config held by the decoder module.
for layer_config in (model.config.text_config, model.language_model.model.config):
    print(layer_config.num_hidden_layers)

24
24


In [None]:
# After pruning the model
# Compare the size of the state dictionary before and after
state_dict = model.state_dict()
print(f"Number of keys in the original state_dict: {len(original_state_dict)}")
print(f"Number of keys in the new state_dict: {len(state_dict)}")

# Collect the keys that belong to encoder (audio tower) and decoder (language model) layers
audio_layer_keys = [key for key in state_dict if 'audio_tower.layers' in key]
lm_layer_keys = [key for key in state_dict if 'language_model.model.layers' in key]

# Extract the unique layer indices: the index is the third dot-separated
# component of an audio key and the fourth component of a language-model key
audio_indices = {int(key.split('.')[2]) for key in audio_layer_keys}
lm_indices = {int(key.split('.')[3]) for key in lm_layer_keys}

print(f"Audio layer indices in state_dict: {sorted(audio_indices)}")
print(f"LM layer indices in state_dict: {sorted(lm_indices)}")

In [None]:
# # Optional: Save the pruned model locally (for verification purposes)
# pruned_model_path = f"pruned_qwen2_audio_ft_model_b16_{tgt_lang}"
# model.save_pretrained(pruned_model_path)
# processor.save_pretrained(pruned_model_path)

# !ls -lh {pruned_model_path}

# Upload to the Hub

In [None]:
# Layer counts of the (pruned) model: encoder = audio tower, decoder = language model.
num_layers_e = len(model.audio_tower.layers)
num_layers_d = len(model.language_model.model.layers)

user_id = "ymoslem"  # change to your user ID
# Build the repository ID from `user_id` (it was previously hard-coded, so changing
# `user_id` had no effect on where the model was uploaded).
output_model = f"{user_id}/qwen-audio-en-{tgt_lang}-{num_layers_e}-{num_layers_d}layers"
output_model

In [None]:
# Push the new model to Hugging Face

# Upload the model weights/config to the repository, creating it as private.
model.push_to_hub(output_model, private=True)

# Upload the processor to the same private repository.
processor.push_to_hub(output_model, private=True)