In [1]:
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint location. NOTE(review): this is an absolute path on the author's
# machine; point it at your own copy of the checkpoint before running.
model_name = "/data/mfx/deepseek-ai/DeepSeek-V2-Lite"

# Tokenizer and model are both loaded from the same directory.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map='auto',
)

# Use the checkpoint's own generation settings, padding with the EOS token.
generation_config = GenerationConfig.from_pretrained(model_name)
generation_config.pad_token_id = generation_config.eos_token_id
model.generation_config = generation_config

# Baseline completion of a fixed prompt; `inputs` is reused by a later cell.
text = "An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is"
inputs = tokenizer(text, return_tensors="pt")
device_inputs = inputs.to(model.device)
outputs = model.generate(**device_inputs, max_new_tokens=100)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)

Loading checkpoint shards: 100%|██████████| 4/4 [00:20<00:00,  5.08s/it]
The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.


An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is a scalar.

In this tutorial, you will discover how to implement an attention function in Keras.

After completing this tutorial, you will know:

- How to define an attention function in Keras.
- How to use an attention function in a Keras model.
- How to use an attention function in a Keras model for sequence classification.

Let’s get started.

Tutorial Overview

This tutorial is divided into three parts;


In [4]:
# Attention dimensions, read once from the loaded model's config so the
# following cells can reshape the projection weights per head.
attn_config = model.config

hidden_size = attn_config.hidden_size
num_heads = attn_config.num_attention_heads
kv_lora_rank = attn_config.kv_lora_rank
v_head_dim = attn_config.v_head_dim
qk_rope_head_dim = attn_config.qk_rope_head_dim
qk_nope_head_dim = attn_config.qk_nope_head_dim
# Total per-head query width is the sum of the `nope` and `rope` parts.
q_head_dim = qk_nope_head_dim + qk_rope_head_dim

In [5]:
def _balanced_low_rank_factors(product, rank):
    """Split a batch of rank-limited matrices into two equally scaled factors.

    Args:
        product: float tensor of shape (num_heads, rows, cols).
        rank: number of singular triplets to keep. It is also used as the
            number of subspace iterations passed to ``torch.svd_lowrank``.

    Returns:
        ``(left, right)`` with shapes (num_heads, rows, rank) and
        (num_heads, rank, cols) such that ``left @ right`` approximates
        ``product``. Both factors are scaled by ``sqrt(S)``, so the singular
        values are shared evenly between the two sides.
    """
    # torch.svd_lowrank is a randomized method; seed torch beforehand if
    # bit-for-bit repeatable factors are needed.
    U, S, V = torch.svd_lowrank(product, rank, niter=rank)
    sqrt_s = torch.sqrt(S)
    left = torch.einsum('hMr,hr->hMr', U, sqrt_s)
    right = torch.einsum('hr,hNr->hrN', sqrt_s, V)
    return left, right


# Collect the attention modules first so the progress bar has a real total
# instead of counting every sub-module of the model.
attention_modules = [
    (name, module)
    for name, module in model.named_modules()
    if name.endswith("self_attn")
]

# NOTE(review): this accesses `module.q_proj` directly, so it presumably only
# applies to checkpoints whose attention uses a single dense query projection
# -- confirm before running on other variants.
for name, module in tqdm(attention_modules, desc="Rebalancing attention projections"):
    q_weight = module.q_proj.weight
    kv_b_weight = module.kv_b_proj.weight
    o_weight = module.o_proj.weight

    # ---- q_proj (nope slice) and the key rows of kv_b_proj ----
    # All factorization math runs in float32; results are cast back on write.
    q_proj = q_weight.data.to(torch.float32).view(num_heads, q_head_dim, hidden_size)
    # q_nope: (num_heads, qk_nope_head_dim, hidden_size); q_rope is left untouched.
    q_nope, q_rope = torch.split(q_proj, [qk_nope_head_dim, qk_rope_head_dim], dim=1)

    kv_b_proj = kv_b_weight.data.to(torch.float32).view(
        num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank
    )
    # k_nope: (num_heads, qk_nope_head_dim, kv_lora_rank)
    # value_states: (num_heads, v_head_dim, kv_lora_rank)
    k_nope, value_states = torch.split(kv_b_proj, [qk_nope_head_dim, v_head_dim], dim=1)

    # Per-head product q_nope^T @ k_nope: (num_heads, hidden_size, kv_lora_rank),
    # rank <= qk_nope_head_dim because it contracts over that dimension.
    q_t_k_up = torch.einsum("hdD,hdL->hDL", q_nope, k_nope)
    q_nope_t, k_nope = _balanced_low_rank_factors(q_t_k_up, qk_nope_head_dim)
    q_nope = q_nope_t.transpose(1, 2)  # back to (num_heads, qk_nope_head_dim, hidden_size)

    # Write back in the weight's own dtype rather than a hard-coded bfloat16,
    # so a model loaded in another precision is not silently recast.
    q_weight.data = (
        torch.cat([q_nope, q_rope], dim=1)
        .reshape(num_heads * q_head_dim, hidden_size)
        .contiguous()
        .to(q_weight.dtype)
    )

    # ---- o_proj and the value rows of kv_b_proj ----
    # o_proj viewed per head: (num_heads, hidden_size, v_head_dim)
    o_proj = o_weight.data.to(torch.float32).view(hidden_size, num_heads, v_head_dim).transpose(0, 1)
    # Per-head product o_proj @ value_states: (num_heads, hidden_size, kv_lora_rank),
    # rank <= v_head_dim.
    o_v_up = torch.einsum("hDd,hdL->hDL", o_proj, value_states)
    o_proj, value_states = _balanced_low_rank_factors(o_v_up, v_head_dim)

    kv_b_weight.data = (
        torch.cat([k_nope, value_states], dim=1)
        .reshape(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
        .contiguous()
        .to(kv_b_weight.dtype)
    )
    # (num_heads, hidden_size, v_head_dim) -> (hidden_size, num_heads * v_head_dim)
    o_weight.data = (
        o_proj.transpose(0, 1)
        .reshape(hidden_size, num_heads * v_head_dim)
        .contiguous()
        .to(o_weight.dtype)
    )

8809it [07:02, 20.87it/s]


In [6]:
# Run the same tokenized prompt through the model again so the completion can
# be compared with the earlier one.
device_inputs = inputs.to(model.device)
outputs = model.generate(**device_inputs, max_new_tokens=100)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is a scalar, which is the cosine similarity between the query and the key.

The attention function is used in the Transformer architecture, which is a neural network architecture that is used for natural language processing tasks such as machine translation and text classification. The attention function is used to compute the importance of each word in the input sequence, and to generate a representation of the input sequence that is more useful for the task at hand.

The attention function is defined as follows:

$$
attention


In [10]:
# Write the model weights and config to a local directory.
model.save_pretrained(save_directory="DeepSeek-V2-Lite_transMLA")
# Optional: uncomment to also upload the model to the Hugging Face Hub.
#model.push_to_hub("fxmeng/DeepSeek-V2-Lite_transMLA")

In [None]:
# Write the tokenizer files to a local directory.
tokenizer.save_pretrained(save_directory="DeepSeek-V2-Lite_transMLA")
# Optional: uncomment to also upload the tokenizer to the Hugging Face Hub.
#tokenizer.push_to_hub("fxmeng/DeepSeek-V2-Lite_transMLA")