In [1]:
from llm_bases.chatglm6b import ChatGML6B
glm = ChatGML6B()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:08<00:00,  1.09s/it]


In [3]:
import torch
import torch.nn.functional as F
import numpy as np
from typing import List

In [4]:
from split_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, FeedForward_GLM_Wrapped, copy_attention, copy_feedforward
from split_llm.common.torch_utils import relative_error


In [5]:
raw_glm_layers = glm.condgen.transformer.layers
attentions: List[Attention_GLM_Wrapped] = []
ffs: List[FeedForward_GLM_Wrapped] = []
for i in range(28):
    transformer_layer = raw_glm_layers[i].float()
    attn_wrapped = Attention_GLM_Wrapped(4096, 32, i)
    copy_attention(transformer_layer, attn_wrapped)
    attentions.append(attn_wrapped.cuda())

    ff_wrapped = FeedForward_GLM_Wrapped(4096, 32, i)
    if i == 27:
        copy_feedforward(transformer_layer, None, ff_wrapped)
    else:
        copy_feedforward(transformer_layer, raw_glm_layers[i + 1].float(), ff_wrapped)
    ffs.append(ff_wrapped.cuda())

In [6]:
from split_llm.glm6b.utils import generate_attention_mask, generate_position_ids

In [7]:
# Convert the GLM model weights (embeddings at the beginning and the end) into CUDA-float

glm.condgen.transformer.word_embeddings.float().cuda()
glm.condgen.lm_head.float().cuda()

layernorm0 = glm.condgen.transformer.layers[0].input_layernorm.float().cuda()

In [8]:
def analyze_hidden_scales(query: str, generation_length: int):
    input_ids, position_ids, attention_masks = glm.get_tokenization(query)
    input_ids = input_ids.cuda()
    position_ids = position_ids.cuda()
    attention_masks = attention_masks.cuda()

    original_length = len(input_ids[0])
    
    all_scales = []
    for i in range(generation_length):
        # print(f"Generate {i}-th token")
        # Here the first layernorm is moved
        h = layernorm0(glm.get_initial_state(input_ids))
        layer_scales = []
        for j in range(28):
            # print("Layer", j)
            # Forward to the i-th attention layer    
            scale_h_in = h.abs().max().item()

            # Attention module
            attn_layer: Attention_GLM_Wrapped = attentions[j]
            q, k, v = attn_layer.generate_qkv(h, position_ids)
            scale_v = v.abs().max().item()  # Record the scale of V
            attn_scores = attn_layer.generate_logit_scores(q, k)
            scale_attn_score = attn_scores.abs().max().item()  # Record the scale since here the plaintext is used
            softmax_scores = attn_layer.generate_softmax_scores(attn_scores)
            weighted_v = attn_layer.generate_weighted_values(softmax_scores, v)
            attn_out = weighted_v @ attn_layer.attn_out_weight.T + attn_layer.attn_out_bias
            # Feedforward module
            ff_layer: FeedForward_GLM_Wrapped = ffs[j]
            h0 = ff_layer.layernorm_in(attn_out + h * (2 * 28) ** 0.5)
            h1 = ff_layer.mlp_dense_in(h0)
            
            scale_mlp_hidden = h1.abs().max().item()
            h2 = F.gelu(h1)

            #  h2 = gelu_openai(h1)
            #  Those two gelu implementations do not have significant difference
            h3 = ff_layer.mlp_dense_out(h2)
            h4 = h3 + ff_layer.residual_coef * h0
            h5 = ff_layer.layernorm_out(h4)

            h = h5

            layer_scales.append([scale_h_in, scale_v, scale_attn_score, scale_mlp_hidden])

        logits = glm.condgen.lm_head(h).permute(1, 0, 2).contiguous()[0, -1, :glm.n_tokens]
        # Get the logits on the next position
        next_id = torch.argmax(logits)
        # print("Next ID", next_id)

        if next_id == glm.condgen.generation_config.eos_token_id:
            break
        input_ids = torch.cat([input_ids, torch.tensor([[next_id]]).cuda()], dim=-1)  # Append the last id
        new_seq_len = len(input_ids[0])
        position_ids = generate_position_ids(original_length, new_seq_len).cuda()
        # print(position_ids)
        all_scales.append(layer_scales)
    
    token_ids = input_ids[0].tolist()

    print(glm.decode(token_ids))
    return np.array(all_scales)

In [10]:
scales = analyze_hidden_scales("Who are you?", 300)


Who are you? I am an AI assistant named ChatGLM-6B, which is developed based on the language model jointly trained by Tsinghua University KEG Lab and Zhipu AI Company in 2023. My job is to provide appropriate answers and support to users' questions and requests.


In [11]:
print(np.max(scales, axis=0))

[[ 41.12500763  19.47050858 537.60595703  29.79353333]
 [ 60.0743866   19.36989403 642.50958252  18.11071968]
 [ 60.00803757  14.69905281 214.58674622  13.34783936]
 [ 59.92230606  13.67966652 319.72744751  14.88466072]
 [ 61.09487915  12.17244339 292.53955078  12.44941902]
 [ 62.06698227  11.19873619 290.99072266  19.49498749]
 [ 63.56050873  14.38836288 272.722229    13.93253326]
 [ 55.89038849  13.50420952 393.37161255  14.81556225]
 [ 54.33137894  12.17324829 450.45004272  13.13796043]
 [ 60.39410019  12.61477852 404.82476807  18.34564018]
 [ 63.29088974  10.14312744 348.49676514  33.34929657]
 [ 63.05445099  11.6084137  412.60134888  16.34067917]
 [ 68.44353485  10.68967724 333.69073486  25.55764198]
 [ 65.80563354  12.22967529 402.78411865  15.27272034]
 [ 67.4360733   10.89254761 258.52587891  15.60247993]
 [ 62.42052078  10.16563988 295.77932739  13.91898155]
 [ 60.87455368  11.46492958 523.45654297  14.8211813 ]
 [ 54.4107666   10.77226257 172.12240601  13.12202263]
 [ 52.1065

In [None]:
original_scale = 10000_4214

torch.tensor(original_scale + 34.52) - torch.tensor(original_scale)

In [5]:
# Check the scale of important weights
# This is helpful for deciding the mask_size


for i, (attn, ff) in enumerate(zip(attentions, ffs)):
    print(f"Layer {i:2d}, QKV max: {torch.max(attn.qkv_weight).item():4.4f}, AttnOut max: {torch.max(attn.attn_out_weight).item():4.4f}")

Layer  0, QKV max: 0.2466, AttnOut max: 0.6855
Layer  1, QKV max: 0.1552, AttnOut max: 0.5874
Layer  2, QKV max: 0.1327, AttnOut max: 0.3381
Layer  3, QKV max: 0.1476, AttnOut max: 0.4224
Layer  4, QKV max: 0.1367, AttnOut max: 0.2632
Layer  5, QKV max: 0.1354, AttnOut max: 0.3162
Layer  6, QKV max: 0.1494, AttnOut max: 0.6372
Layer  7, QKV max: 0.1163, AttnOut max: 0.4460
Layer  8, QKV max: 0.1791, AttnOut max: 0.6919
Layer  9, QKV max: 0.1908, AttnOut max: 0.8496
Layer 10, QKV max: 0.1379, AttnOut max: 0.3650
Layer 11, QKV max: 0.1356, AttnOut max: 0.5137
Layer 12, QKV max: 0.1398, AttnOut max: 0.3311
Layer 13, QKV max: 0.1420, AttnOut max: 0.8291
Layer 14, QKV max: 0.1332, AttnOut max: 0.6045
Layer 15, QKV max: 0.1268, AttnOut max: 0.8643
Layer 16, QKV max: 0.1398, AttnOut max: 0.3745
Layer 17, QKV max: 0.1396, AttnOut max: 0.5762
Layer 18, QKV max: 0.1377, AttnOut max: 0.3770
Layer 19, QKV max: 0.1519, AttnOut max: 0.6987
Layer 20, QKV max: 0.1560, AttnOut max: 0.6221
Layer 21, QKV

In [10]:
ffs

FeedForward_GLM_Wrapped(
  (layernorm_in): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  (mlp_dense_in): Linear(in_features=4096, out_features=16384, bias=True)
  (mlp_dense_out): Linear(in_features=16384, out_features=4096, bias=True)
  (layernorm_out): Identity()
)