In [19]:
from typing import List
from dataclasses import dataclass, field
import math

from sympy import isprime
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from tokenizers import normalizers, Regex 

In [20]:
@dataclass
class EngramConfig:
    tokenizer_name_or_path: str = "/home/user/Downloads/Qwen2.5-0.5B-Instruct"
    engram_vocab_size: List[int] = field(default_factory=lambda: [151665*5, 151665*5]) # [2-gram vocab size, 3-gram vocab size]
    max_ngram_size: int = 3 # 最多连续3个词作为一个词组
    n_embed_per_ngram: int = 512
    n_head_per_ngram: int = 8 # 每个ngram有几个hash头，减缓冲突
    layer_ids: List[int] = field(default_factory=lambda: [1, 15]) # 在第几层添加engram模块
    pad_id: int = 2
    seed: int = 0
    kernel_size: int = 4
    
@dataclass
class BackBoneConfig:
    hidden_size: int = 1024
    hc_mult: int = 4
    vocab_size: int = 151665
    num_layers: int = 30

In [21]:
engram_cfg = EngramConfig()
backbone_config = BackBoneConfig()

In [22]:
class CompressedTokenizer:
    def __init__(
        self,
        tokenizer_name_or_path,
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True)
        
        SENTINEL = "\uE000"
        # 对原始词表进行标准化，转换小写，去掉重音符号，去掉空格等，用于词表压缩
        self.normalizer = normalizers.Sequence([
            normalizers.NFKC(),
            normalizers.NFD(),
            normalizers.StripAccents(),
            normalizers.Lowercase(),
            normalizers.Replace(Regex(r"[ \t\r\n]+"), " "),
            normalizers.Replace(Regex(r"^ $"), SENTINEL),
            normalizers.Strip(),
            normalizers.Replace(SENTINEL, " "),
        ])
        
        self.lookup_table, self.num_new_token = self._build_lookup_table()
    
    def __len__(self):
        return self.num_new_token
    
    # 将原始词表中的token_id映射到压缩后的token id
    def _build_lookup_table(self):
        old2new = {}
        key2new = {}          
        new_tokens = []

        vocab_size = len(self.tokenizer)
        for tid in range(vocab_size):
            text = self.tokenizer.decode([tid], skip_special_tokens=False)
            
            if "�" in text:
                key = self.tokenizer.convert_ids_to_tokens(tid)
            else:
                norm = self.normalizer.normalize_str(text)
                key = norm if norm else text

            nid = key2new.get(key)
            if nid is None:
                nid = len(new_tokens)
                key2new[key] = nid
                new_tokens.append(key)
            old2new[tid] = nid
        
        lookup = np.empty(vocab_size, dtype=np.int64)
        for tid in range(vocab_size):
            lookup[tid] = old2new[tid]

        return lookup, len(new_tokens)
    
    def _compress(self, input_ids):
        arr = np.asarray(input_ids, dtype=np.int64)
        pos_mask = arr >= 0
        out = arr.copy()
        valid_ids = arr[pos_mask]
        out[pos_mask] = self.lookup_table[valid_ids]
        return out   
    
    def __call__(self, input_ids):
        return self._compress(input_ids)

In [23]:
compressed_tokenizer = CompressedTokenizer(
            tokenizer_name_or_path=engram_cfg.tokenizer_name_or_path
        )

In [24]:
print('压缩后的词表大小:', len(compressed_tokenizer))
print('原始词表大小:', compressed_tokenizer.tokenizer.vocab_size)
print('压缩率:', 1 - len(compressed_tokenizer) / compressed_tokenizer.tokenizer.vocab_size)

压缩后的词表大小: 107453
原始词表大小: 151643
压缩率: 0.29140810983691956


In [25]:

input_ids = compressed_tokenizer.tokenizer.encode('hello world, Hello world')
print('原始input_ids:', input_ids)
compressedinput_ids = compressed_tokenizer(input_ids)
print('压缩后的input_ids:', compressedinput_ids)


原始input_ids: [14990, 1879, 11, 21927, 1879]
压缩后的input_ids: [6378 1346   11 6378 1346]


In [29]:
class ShortConv(nn.Module):
    def __init__(
        self, 
        hidden_size: int, 
        kernel_size: int = 4, 
        dilation: int = 1, 
        norm_eps: float = 1e-5,
        hc_mult: int = 4,
        activation: bool = True,
    ):
        super().__init__()
        self.hc_mult = hc_mult
        self.activation = activation
        
        total_channels = hidden_size * hc_mult
        self.conv = nn.Conv1d(
            in_channels=total_channels,
            out_channels=total_channels,
            kernel_size=kernel_size,
            groups=total_channels,
            bias=False,
            padding=(kernel_size - 1) * dilation,
            dilation=dilation,
        )

        self.norms = nn.ModuleList([
            nn.RMSNorm(hidden_size, eps=norm_eps) 
            for _ in range(hc_mult)
        ])
        
        if self.activation:
            self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input:  (B,L,HC_MULT,D)
        Output: (B,L,HC_MULT,D)
        """
        B, T, G, C = x.shape
        
        assert G == self.hc_mult, f"Input groups {G} != hc_mult {self.hc_mult}"

        normed_chunks = []
        for i in range(G):
            chunk = x[:, :, i, :]
            normed_chunks.append(self.norms[i](chunk))
        
        x_norm = torch.cat(normed_chunks, dim=-1)
        x_bct = x_norm.transpose(1, 2)
        y_bct = self.conv(x_bct)
        y_bct = y_bct[..., :T]

        if self.activation:
            y_bct = self.act_fn(y_bct)
        y = y_bct.transpose(1, 2).view(B, T, G, C).contiguous()
        
        return y

In [9]:
# 发现下一个质数
def find_next_prime(start, seen_primes):
    candidate = start + 1
    while True:
        if isprime(candidate) and candidate not in seen_primes:
            return candidate
        candidate += 1

class NgramHashMapping:
    def __init__(
        self, 
        engram_vocab_size,
        max_ngram_size,
        n_embed_per_ngram,
        n_head_per_ngram,
        layer_ids,
        tokenizer_name_or_path,
        pad_id,
        seed,  
    ):
        self.vocab_size_per_ngram = engram_vocab_size
        self.max_ngram_size = max_ngram_size
        self.n_embed_per_ngram = n_embed_per_ngram
        self.n_head_per_ngram = n_head_per_ngram
        self.pad_id = pad_id
        self.layer_ids = layer_ids

        self.compressed_tokenizer = CompressedTokenizer(
            tokenizer_name_or_path=tokenizer_name_or_path
        )            
        self.tokenizer_vocab_size = len(self.compressed_tokenizer)
        if self.pad_id is not None:
            self.pad_id = int(self.compressed_tokenizer.lookup_table[self.pad_id])

        
        # int64能表示的最大值
        max_long = np.iinfo(np.int64).max
        # 除以压缩后词表大小（确保乘积不会溢出）
        M_max = int(max_long // self.tokenizer_vocab_size)
        # half_bound是后续生成随机数的上界，为了确保乘子是奇数（后续会2r+1），提前除以2，防止溢出
        half_bound = max(1, M_max // 2)
        PRIME_1 = 10007
        
        # 记录每一层不同ngram的乘子
        self.layer_multipliers = {}

        for layer_id in self.layer_ids:
            # 为每一层生成一个随机数种子
            base_seed = int(seed + PRIME_1 * int(layer_id))
            g = np.random.default_rng(base_seed)
            # 生成max_ngram_size个随机数，范围是[0, half_bound)
            r = g.integers(
                low=0,
                high=half_bound,
                size=(self.max_ngram_size,),
                dtype=np.int64
            )
            # 确保乘子是奇数
            multipliers = r * 2 + 1
            self.layer_multipliers[layer_id] = multipliers

        self.vocab_size_across_layers = self.calculate_vocab_size_across_layers()

    def calculate_vocab_size_across_layers(self):
        # 记录已生成的质数，防止重复
        seen_primes = set()
        
        # 记录所有层所有ngram的不同头的词表大小
        vocab_size_across_layers = {}
        
        for layer_id in self.layer_ids:
            # 记录当前层所有ngram（2-gram、3-gram、...）的不同头的词表大小
            all_ngram_vocab_sizes = []
            for ngram in range(2, self.max_ngram_size + 1):
                current_ngram_heads_sizes = []
                # 获取ngram的基础词表大小
                vocab_size = self.vocab_size_per_ngram[ngram - 2]
                # ngram头数
                num_head = self.n_head_per_ngram
                # 从基础词表大小开始寻找下一个质数
                current_prime_search_start = vocab_size - 1
                
                # 为每个头生成一个质数（这个质数比基础词表略大，可以理解成当前头对应的虚拟词表）
                for _ in range(num_head):
                    found_prime = find_next_prime(
                        current_prime_search_start, 
                        seen_primes
                    )
                    seen_primes.add(found_prime)
                    current_ngram_heads_sizes.append(found_prime)
                    current_prime_search_start = found_prime
                
                all_ngram_vocab_sizes.append(current_ngram_heads_sizes)
            vocab_size_across_layers[layer_id] = all_ngram_vocab_sizes
            
        return vocab_size_across_layers

    def _get_ngram_hashes(
        self,
        input_ids: np.ndarray,
        layer_id: int,
    ) -> np.ndarray:
        x = np.asarray(input_ids, dtype=np.int64)
        B, T = x.shape

        # 获取当前层所有ngram的乘子
        multipliers = self.layer_multipliers[layer_id]

        # 用于处理ngram的偏移
        def shift_k(k: int) -> np.ndarray:
            if k == 0: return x
            shifted = np.pad(x, ((0, 0), (k, 0)),
                                mode='constant', constant_values=self.pad_id)[:, :T]
            return shifted

        # 计算当前层所有ngram的偏移结果
        base_shifts = [shift_k(k) for k in range(self.max_ngram_size)]

        # all_hashes存储的是某一层所有ngram所有头的hash值
        all_hashes = []
        
        for n in range(2, self.max_ngram_size + 1):
            n_gram_index = n - 2
            # 获取ngram对应的tokens
            tokens = base_shifts[:n]
            # 将ngram内的token_id乘以对应的乘子并异或，得到一个值mix（这个值可能会很大，如果对每个mix值生成一个embedding，范围将会非常大，几十亿甚至几百亿维的 embedding 表，根本存不下，所以需要对其取模压缩到固定范围内）
            mix = (tokens[0] * multipliers[0])
            for k in range(1, n):
                mix = np.bitwise_xor(mix, tokens[k] * multipliers[k])
            num_heads_for_this_ngram = self.n_head_per_ngram
            head_vocab_sizes = self.vocab_size_across_layers[layer_id][n_gram_index]
            
            # 对mix值进行取模，模即是之前每个token对应的质数，即每个头的词表大小，这样就可以将mix值压缩到虚拟词表的范围内
            # 为了减缓冲突，每个ngram会有多个头，每个头对应一个不同的质数
            # 两个ngram只有当所有头对应的hash值都相等才会冲突，这个概率非常低
            for j in range(num_heads_for_this_ngram):
                mod = int(head_vocab_sizes[j])
                head_hash = mix % mod
                all_hashes.append(head_hash.astype(np.int64, copy=False))
        
        return np.stack(all_hashes, axis=2)

    def hash(self, input_ids):
        # 先压缩（原词表映射到新词表）
        input_ids = self.compressed_tokenizer(input_ids)
        # 存储所有层的hash值
        hash_ids_for_all_layers = {}
        for layer_id in self.layer_ids:
            hash_ids_for_all_layers[layer_id] = self._get_ngram_hashes(input_ids, layer_id=layer_id)
        return hash_ids_for_all_layers

In [30]:
hash_mapping = NgramHashMapping(
            engram_vocab_size=engram_cfg.engram_vocab_size,
            max_ngram_size = engram_cfg.max_ngram_size,
            n_embed_per_ngram = engram_cfg.n_embed_per_ngram,
            n_head_per_ngram = engram_cfg.n_head_per_ngram,
            layer_ids = engram_cfg.layer_ids,
            tokenizer_name_or_path=engram_cfg.tokenizer_name_or_path,
            pad_id = engram_cfg.pad_id,
            seed = engram_cfg.seed)

hash_mapping.vocab_size_across_layers

{1: [[758339, 758341, 758357, 758363, 758383, 758393, 758411, 758431],
  [758441, 758449, 758453, 758491, 758501, 758503, 758519, 758521]],
 15: [[758551, 758561, 758573, 758579, 758599, 758617, 758629, 758633],
  [758671, 758687, 758699, 758707, 758711, 758713, 758729, 758731]]}

In [32]:
input_ids = np.array([[101, 2000, 2022, 1037, 2204, 2154, 102]])
hash_input_ids = hash_mapping.hash(input_ids)
print('hash_input_ids', hash_input_ids)


hash_input_ids {1: array([[[755599,  31554, 689651, 733106, 369655, 302899, 739642, 655782,
         494078, 275297, 230074, 511023, 747045, 712517,  62185, 692641],
        [400072, 438427, 147140, 394998, 365056, 320358, 730452, 183936,
          85194,  24976,  14144, 344290,  93743, 735087, 660096, 189452],
        [598914, 732143, 362602, 197311, 690520, 603491, 305037,   3385,
         116786, 360635, 553188,  65963, 432618, 350759, 107148, 317562],
        [141012,  29587, 270597, 121304,  21465, 199878, 750317, 584124,
          75068, 554830, 386574, 575367, 654789, 542405, 227831, 640827],
        [219573, 634069, 308857, 286706, 118050, 151360, 645674, 143477,
         548738, 752325, 183427,  42745, 149888, 707116, 366234, 102811],
        [648509,  56564, 191684, 199396, 674053, 294194,  87774, 393810,
         428917, 574008,  63545, 181506,   3075,  16603, 410713, 116337],
        [155668, 690406,  79215,  10545, 706724, 479947, 197641, 398283,
         512050, 476935, 5

In [33]:
class MultiHeadEmbedding(nn.Module):
    def __init__(self, list_of_N: List[int], D: int):
        super().__init__()
        # list_of_N为每个头的词表列表
        self.num_heads = len(list_of_N)
        self.embedding_dim = D
        
        # ngram在每个头内的索引是[0,vocal_size]之间,
        # 由于要把所有头的词表放在同一个大Embedding里，所以需要对每个头的索引进行偏移
        offsets = [0]
        for n in list_of_N[:-1]:
            offsets.append(offsets[-1] + n)
        
        self.register_buffer("offsets", torch.tensor(offsets, dtype=torch.long))
        print('offsets:', offsets)
        # 所有头的词表大小总和总词表大小
        total_N = sum(list_of_N)
        print('总词表大小:', total_N)
        self.embedding = nn.Embedding(num_embeddings=total_N, embedding_dim=D)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        shifted_input_ids = input_ids + self.offsets
        output = self.embedding(shifted_input_ids)
        
        return output

In [34]:
list_of_N = [x for y in hash_mapping.vocab_size_across_layers[1] for x in y]
print('list_of_N:', list_of_N)

list_of_N: [758339, 758341, 758357, 758363, 758383, 758393, 758411, 758431, 758441, 758449, 758453, 758491, 758501, 758503, 758519, 758521]


In [35]:
multi_head_embedding = MultiHeadEmbedding(
        list_of_N = [x for y in hash_mapping.vocab_size_across_layers[1] for x in y],
        D = engram_cfg.n_embed_per_ngram // engram_cfg.n_head_per_ngram,
    )
print('multi_head_embedding:', multi_head_embedding(torch.from_numpy(hash_input_ids[1])).shape)


offsets: [0, 758339, 1516680, 2275037, 3033400, 3791783, 4550176, 5308587, 6067018, 6825459, 7583908, 8342361, 9100852, 9859353, 10617856, 11376375]
总词表大小: 12134896
multi_head_embedding: torch.Size([1, 7, 16, 64])


In [36]:
class Engram(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.hash_mapping = NgramHashMapping(
            engram_vocab_size=engram_cfg.engram_vocab_size,
            max_ngram_size = engram_cfg.max_ngram_size,
            n_embed_per_ngram = engram_cfg.n_embed_per_ngram,
            n_head_per_ngram = engram_cfg.n_head_per_ngram,
            layer_ids = engram_cfg.layer_ids,
            tokenizer_name_or_path=engram_cfg.tokenizer_name_or_path,
            pad_id = engram_cfg.pad_id,
            seed = engram_cfg.seed,
        )
        # 每个层对应一个multi_head_embedding
        self.multi_head_embedding = MultiHeadEmbedding(
            list_of_N = [x for y in self.hash_mapping.vocab_size_across_layers[self.layer_id] for x in y],
            D = engram_cfg.n_embed_per_ngram // engram_cfg.n_head_per_ngram,
        )
        self.short_conv = ShortConv(
            hidden_size = backbone_config.hidden_size,
            kernel_size = engram_cfg.kernel_size,
            dilation    = engram_cfg.max_ngram_size,
            hc_mult     = backbone_config.hc_mult,
        )
        # ngram的数量（2-gram、3-gram等）乘以每个ngram的embedding维度（2*n_embed_per_ngram，n_embed_per_ngram=n_head_per_ngram*D）
        engram_hidden_size = (engram_cfg.max_ngram_size-1) * engram_cfg.n_embed_per_ngram
        print('engram_hidden_size:', engram_hidden_size)
        
        # V的投影矩阵
        self.value_proj = nn.Linear(engram_hidden_size,backbone_config.hidden_size)
        # K的投影矩阵
        self.key_projs = nn.ModuleList(
            [nn.Linear(engram_hidden_size,backbone_config.hidden_size) for _ in range(backbone_config.hc_mult)]
        )
        self.norm1 = nn.ModuleList([nn.RMSNorm(backbone_config.hidden_size) for _ in range(backbone_config.hc_mult)])
        self.norm2 = nn.ModuleList([nn.RMSNorm(backbone_config.hidden_size) for _ in range(backbone_config.hc_mult)])
    
    def forward(self,hidden_states,input_ids):
        """
        hidden_states: [B, L, HC_MULT, D]
        input_ids: [B, L]
        """
        print('input_ids shape:', input_ids.shape)
        # 原始input_ids经过压缩，和多头hash，每个原始token得到多个hash值（该token的2-gram、3gram在不同头对应的索引）
        hash_input_ids = torch.from_numpy(self.hash_mapping.hash(input_ids)[self.layer_id])
        print('hash_input_ids shape:', hash_input_ids.shape)
        # 根据得到的hash索引查找对应索引的embedding
        embeddings = self.multi_head_embedding(hash_input_ids).flatten(start_dim=-2)
        print('embeddings shape:', embeddings.shape)
        gates = []
        for hc_idx in range(backbone_config.hc_mult):
            # embeddings经过K的投影矩阵得到K
            key = self.key_projs[hc_idx](embeddings)
            normed_key = self.norm1[hc_idx](key)
            query = hidden_states[:,:,hc_idx,:]
            # Q由隐藏层得到（包含上下文信息）
            normed_query = self.norm2[hc_idx](query)
            gate = (normed_key * normed_query).sum(dim=-1) / math.sqrt(backbone_config.hidden_size)
            gate = gate.abs().clamp_min(1e-6).sqrt() * gate.sign()
            gate = gate.sigmoid().unsqueeze(-1)
            print('gate shape:', gate.shape)
            gates.append(gate)
         
        gates = torch.stack(gates,dim=2)
        print('gates shape:', gates.shape)
        # embeddings经过V的投影矩阵得到V
        value = gates * self.value_proj(embeddings).unsqueeze(2)
        output = value + self.short_conv(value)
        return output

In [37]:
engram = Engram(layer_id=1)
hidden_states = torch.randn(1, 6, 4, 1024)
input_ids = torch.randint(0, 10000, (1, 6))
output = engram(hidden_states, input_ids)
print(output.shape)

offsets: [0, 758339, 1516680, 2275037, 3033400, 3791783, 4550176, 5308587, 6067018, 6825459, 7583908, 8342361, 9100852, 9859353, 10617856, 11376375]
总词表大小: 12134896
engram_hidden_size: 1024
input_ids shape: torch.Size([1, 6])
hash_input_ids shape: torch.Size([1, 6, 16])
embeddings shape: torch.Size([1, 6, 1024])
gate shape: torch.Size([1, 6, 1])
gate shape: torch.Size([1, 6, 1])
gate shape: torch.Size([1, 6, 1])
gate shape: torch.Size([1, 6, 1])
gates shape: torch.Size([1, 6, 4, 1])
torch.Size([1, 6, 4, 1024])


In [17]:
class TransformerBlock(nn.Module):
    def __init__(self,layer_id):
        super().__init__()
        self.attn = lambda x:x
        self.moe  = lambda x:x
        self.engram = None
        if layer_id in engram_cfg.layer_ids:
            self.engram = Engram(layer_id=layer_id)
    
    def forward(self,input_ids,hidden_states):
        if self.engram is not None:
            hidden_states = self.engram(hidden_states=hidden_states,input_ids=input_ids) + hidden_states
        hidden_states = self.attn(hidden_states) + hidden_states
        hidden_states = self.moe(hidden_states) + hidden_states
        return hidden_states

In [18]:
LLM = [
        nn.Embedding(backbone_config.vocab_size,backbone_config.hidden_size),
        *[TransformerBlock(layer_id=layer_id) for layer_id in range(backbone_config.num_layers)],
        nn.Linear(backbone_config.hidden_size, backbone_config.vocab_size)
    ]

text = "Only Alexander the Great could tame the horse Bucephalus."
tokenizer = AutoTokenizer.from_pretrained(engram_cfg.tokenizer_name_or_path,trust_remote_code=True)
input_ids = tokenizer(text,return_tensors='pt').input_ids

B,L = input_ids.shape

for idx, layer in enumerate(LLM):
    if idx == 0:
        hidden_states = LLM[0](input_ids)
        ## mock hyper-connection
        hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, backbone_config.hc_mult, -1)      
    elif idx == len(LLM)-1:
        ## mock hyper-connection
        hidden_states = hidden_states[:,:,0,:] 
        output = layer(hidden_states)
    else:
        hidden_states = layer(input_ids=input_ids,hidden_states=hidden_states)

print("✅ Forward Complete!")
print(f"{input_ids.shape=}\n{output.shape=}")

offsets: [0, 758339, 1516680, 2275037, 3033400, 3791783, 4550176, 5308587, 6067018, 6825459, 7583908, 8342361, 9100852, 9859353, 10617856, 11376375]
总词表大小: 12134896
engram_hidden_size: 1024
offsets: [0, 758551, 1517112, 2275685, 3034264, 3792863, 4551480, 5310109, 6068742, 6827413, 7586100, 8344799, 9103506, 9862217, 10620930, 11379659]
总词表大小: 12138390
engram_hidden_size: 1024
input_ids shape: torch.Size([1, 13])
hash_input_ids shape: torch.Size([1, 13, 16])
embeddings shape: torch.Size([1, 13, 1024])
gate shape: torch.Size([1, 13, 1])
gate shape: torch.Size([1, 13, 1])
gate shape: torch.Size([1, 13, 1])
gate shape: torch.Size([1, 13, 1])
gates shape: torch.Size([1, 13, 4, 1])
input_ids shape: torch.Size([1, 13])
hash_input_ids shape: torch.Size([1, 13, 16])
embeddings shape: torch.Size([1, 13, 1024])
gate shape: torch.Size([1, 13, 1])
gate shape: torch.Size([1, 13, 1])
gate shape: torch.Size([1, 13, 1])
gate shape: torch.Size([1, 13, 1])
gates shape: torch.Size([1, 13, 4, 1])
✅ Forwar