In [3]:
import torch
import math

x = torch.randn(4,32,125,32)
print(x[:, :16].shape)  # 第一个维度不变，第二维度取前 16 个，后面维度不免
print(x[..., :16].shape)  # 表示对所有前面的维度保持不变，只对最后一个维度进行切片取前 16 个

  from .autonotebook import tqdm as notebook_tqdm


torch.Size([4, 16, 125, 32])
torch.Size([4, 32, 125, 16])


In [11]:
# k和v无法用 repeat(1,4,1,1) ，原因如下
k = torch.randn(1,2)
shape = list(k.shape)
shape[1] *= 4
print(k)
a = k.unsqueeze(2)
print(a)
b = a.repeat(1,1,4)
print(b)
k1 = b.reshape(shape)
print(k1)  
k2 = k.repeat(1, 4)
print(k2) 
print((k1 == k2).all())

tensor([[-0.6282, -0.2202]])
tensor([[[-0.6282],
         [-0.2202]]])
tensor([[[-0.6282, -0.6282, -0.6282, -0.6282],
         [-0.2202, -0.2202, -0.2202, -0.2202]]])
tensor([[-0.6282, -0.6282, -0.6282, -0.6282, -0.2202, -0.2202, -0.2202, -0.2202]])
tensor([[-0.6282, -0.2202, -0.6282, -0.2202, -0.6282, -0.2202, -0.6282, -0.2202]])
tensor(False)


In [13]:
import torch
import math

@torch.no_grad()
def llama_rotary_embedding(length):
    inv_freq = torch.arange(0, 32, 2) / 32
    inv_freq = 1 / (500000 ** inv_freq)
    inv_freq = inv_freq.reshape(16, 1)

    position_ids = torch.arange(length).reshape(1, length).float()
    freq = inv_freq.matmul(position_ids).transpose(0,1)
    emb = torch.cat((freq, freq), -1)
    return emb.cos(), emb.sin()

def apply_rotary_pos_emb(x, cos, sin):
    def rotate_half(x):
        left = x[..., :16]
        right = -x[..., 16:]
        return torch.cat((right, left), -1)
    return x * cos + rotate_half(x) * sin

def get_causal_mask(attention_mask):
    B, L = attention_mask.shape
    min_value = -1e15
    causal_mask = torch.full((L, L), min_value).triu(diagonal=1)
    causal_mask = causal_mask.reshape(1,1,L,L).repeat(B, 1, 1, 1)
    causal_mask = causal_mask.to(attention_mask.device)

    mask = attention_mask.reshape(B, 1, 1, L) == 0
    causal_mask = causal_mask.masked_fill(mask, min_value)
    return causal_mask

In [15]:
class LlamaRMSNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(1024))
    def forward(self, x):
        var = x.pow(2).mean(-1, keepdim=True)
        x = x * (var + 1e-5).rsqrt()
        return self.weight * x
        
class LlamaMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gate_proj = torch.nn.Linear(1024, 14336, bias=False)
        self.up_proj = torch.nn.Linear(1024, 14336, bias=False)
        self.down_proj = torch.nn.Linear(14336, 1024, bias=False)
        self.act_fn = torch.nn.SiLU()
    def forward(self, x):
        left = self.act_fn(self.gate_proj(x))
        right = self.up_proj(x)
        return self.down_proj(left * right)

class LlamaAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(1024, 1024, bias=False)
        self.k_proj = torch.nn.Linear(1024, 256, bias=False)
        self.v_proj = torch.nn.Linear(1024, 256, bias=False)
        self.o_proj = torch.nn.Linear(1024, 1024, bias=False)

    def forward(self, hidden_state, attention_mask):
        B, L, _ = hidden_state.shape
        q = self.q_proj(hidden_state).reshape(B, L, 32, 32).transpose(1,2)
        k = self.k_proj(hidden_state).reshape(B, L, 8, 32).transpose(1,2)
        v = self.v_proj(hidden_state).reshape(B, L, 8, 32).transpose(1,2)

        cos, sin = llama_rotary_embedding(L)
        cos, sin = cos.to(hidden_state.device), sin.to(hidden_state.device)
        q = apply_rotary_pos_emb(q, cos, sin)
        k = apply_rotary_pos_emb(k, cos, sin)
        k = k.unsqueeze(2).repeat(1, 1, 4, 1, 1).reshape(B, -1, L, 32)
        v = v.unsqueeze(2).repeat(1, 1, 4, 1, 1).reshape(B, -1, L, 32)

        attn = q.matmul(k.transpose(2,3)) / math.sqrt(32)
        attention_mask = get_causal_mask(attention_mask)
        attn = (attn + attention_mask).softmax(-1)
        attn = attn.matmul(v)
        
        attn = attn.transpose(1,2).reshape(B, L, -1)
        attn = self.o_proj(attn)
        return attn

In [17]:
class LlamaDecoderLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = LlamaAttention()
        self.mlp = LlamaMLP()
        self.input_layernorm = LlamaRMSNorm()
        self.post_attention_layernorm = LlamaRMSNorm()
        

    def forward(self, hidden_state, attention_mask):
        res = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state, attention_mask) + res
        res = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state) + res
        return  hidden_state  
        
class LlamaModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(128256, 1024, None)
        self.layers = torch.nn.ModuleList([LlamaDecoderLayer() for _ in range(4)])
        self.norm = LlamaRMSNorm()

    def forward(self, input_ids, attention_mask):
        hidden_state = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_state = layer(hidden_state, attention_mask)
        hidden_state = self.norm(hidden_state)
        return hidden_state

class LlamaForCausalLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = LlamaModel()
        self.lm_head = torch.nn.Linear(1024, 128256, bias=False)
        
    def forward(self, input_ids, attention_mask, labels=None):
        logits = self.model(input_ids, attention_mask)
        logits = self.lm_head(logits)
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1].reshape(-1, 128256)
            shift_labels = labels[:, 1:].reshape(-1)
            loss = torch.nn.functional.cross_entropy(shift_logits, shift_labels)
        return loss, logits

In [19]:
from transformers import LlamaConfig, LlamaForCausalLM as LlamaForCausalLM_Original

#测试是否和官方模型的计算输出一样
config = "{'vocab_size': 128256, 'max_position_embeddings': 8192, 'hidden_size': 4096, 'intermediate_size': 14336, 'num_hidden_layers': 32, 'num_attention_heads': 32, 'num_key_value_heads': 8, 'hidden_act': 'silu', 'initializer_range': 0.02, 'rms_norm_eps': 1e-05, 'pretraining_tp': 1, 'use_cache': True, 'rope_theta': 500000.0, 'rope_scaling': None, 'attention_bias': False, 'attention_dropout': 0.0, 'mlp_bias': False, 'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': 'bfloat16', 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': False, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': ['LlamaForCausalLM'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': 128000, 'pad_token_id': None, 'eos_token_id': 128001, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': '', 'transformers_version': '4.38.2', 'model_type': 'llama'}"
config = LlamaConfig.from_dict(eval(config))
config.hidden_size = 1024
config.num_hidden_layers = 4

model_actor1 = LlamaForCausalLM_Original(config)
model_actor2 = LlamaForCausalLM()

model_actor2.load_state_dict(model_actor1.state_dict())

input = {
    'input_ids': torch.randint(100, 50000, [4, 125]),
    'attention_mask': torch.ones(4, 125).long(),
    'labels': torch.randint(100, 50000, [4, 125])
}
input['attention_mask'][:, 120:] = 0

out = model_actor1(**input)
loss, logits = model_actor2(**input)

print(out.loss, out.logits.shape)
print(loss, logits.shape)

out.loss == loss, (out.logits == logits).all()

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


tensor(11.9729, grad_fn=<NllLossBackward0>) torch.Size([4, 125, 128256])
tensor(11.9729, grad_fn=<NllLossBackward0>) torch.Size([4, 125, 128256])


(tensor(True), tensor(True))