In [1]:
from llm_bases.chatglm6b import ChatGML6B
glm = ChatGML6B()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:09<00:00,  1.25s/it]


In [2]:
import torch
import torch.nn.functional as F
import numpy as np
from typing import List

In [3]:
from split_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, FeedForward_GLM_Wrapped, copy_attention, copy_feedforward
from split_llm.common.torch_utils import relative_error


In [4]:
raw_glm_layers = glm.condgen.transformer.layers
attentions: List[Attention_GLM_Wrapped] = []
ffs: List[FeedForward_GLM_Wrapped] = []
for i in range(28):
    transformer_layer = raw_glm_layers[i].float()
    attn_wrapped = Attention_GLM_Wrapped(4096, 32, i)
    copy_attention(transformer_layer, attn_wrapped)
    attentions.append(attn_wrapped.cuda())

    ff_wrapped = FeedForward_GLM_Wrapped(4096, 32, i)
    if i == 27:
        copy_feedforward(transformer_layer, None, ff_wrapped)
        ff_wrapped.layernorm_out = glm.condgen.transformer.final_layernorm.float().cuda()
    else:
        copy_feedforward(transformer_layer, raw_glm_layers[i + 1].float(), ff_wrapped)
    ffs.append(ff_wrapped.cuda())


In [5]:
from split_llm.glm6b.utils import generate_attention_mask, generate_position_ids

In [6]:
# Convert the GLM model weights (embeddings at the beginning and the end) into CUDA-float

glm.condgen.transformer.word_embeddings.float().cuda()
glm.condgen.lm_head.float().cuda()

layernorm0 = glm.condgen.transformer.layers[0].input_layernorm.float().cuda()

In [7]:
def analyze_hidden_scales(query: str, generation_length: int):
    input_ids, position_ids, attention_masks = glm.get_tokenization(query)
    
    
    original_length = len(input_ids[0])
    input_ids = input_ids.cuda()
    start_id = input_ids[0, -1]
    input_ids = input_ids[:, :-1]
    position_ids = position_ids.cuda()[:, :, :-1]
    # The start_id has an attention mask

    
    all_scales = []
    previous_k = dict()
    previous_v = dict()

    predicted_ids = []

    for i in range(generation_length):
        # print(f"Generate {i}-th token")
        # Here the first layernorm is moved
        h = layernorm0(glm.get_initial_state(input_ids))
        layer_scales = []
        for j in range(28):
            # print("Layer", j)
            # Forward to the i-th attention layer    
            scale_h_in = h.abs().max().item()

            # Attention module
            attn_layer: Attention_GLM_Wrapped = attentions[j]
            q, k, v = attn_layer.generate_qkv(h, position_ids)
            if j not in previous_k:
                previous_k[j] = k
            else:
                previous_k[j] = torch.cat([previous_k[j], k], dim=0)
                k = previous_k[j]

            
            if j not in previous_v:
                previous_v[j] = v
            else:
                previous_v[j] = torch.cat([previous_v[j], v], dim=0)
                v = previous_v[j]


            scale_v = v.abs().max().item()  # Record the scale of V
            attn_scores = attn_layer.generate_logit_scores(q, k)
            scale_attn_score = attn_scores.abs().max().item()  # Record the scale since here the plaintext is used
            softmax_scores = attn_layer.generate_softmax_scores(attn_scores)
            weighted_v = attn_layer.generate_weighted_values(softmax_scores, v)
            attn_out = weighted_v @ attn_layer.attn_out_weight.T + attn_layer.attn_out_bias
            # Feedforward module
            ff_layer: FeedForward_GLM_Wrapped = ffs[j]
            h0 = ff_layer.layernorm_in(attn_out + h * (2 * 28) ** 0.5)
            h1 = ff_layer.mlp_dense_in(h0)
            
            scale_mlp_hidden = h1.abs().max().item()
            h2 = F.gelu(h1)

            #  h2 = gelu_openai(h1)
            #  Those two gelu implementations do not have significant difference
            h3 = ff_layer.mlp_dense_out(h2)
            h4 = h3 + ff_layer.residual_coef * h0
            h5 = ff_layer.layernorm_out(h4)

            h = h5

            layer_scales.append([scale_h_in, scale_v, scale_attn_score, scale_mlp_hidden])

        logits = glm.condgen.lm_head(h).permute(1, 0, 2).contiguous()[0, -1]
        # Get the logits on the next position
        if start_id is not None:
            next_id = start_id
            start_id = None
        else:
            next_id = torch.argmax(logits)
            predicted_ids.append(next_id.item())
        # print("Next ID", next_id)

        if next_id == glm.condgen.generation_config.eos_token_id:
            break
        input_ids = torch.tensor([[next_id]]).cuda()  # Append the last id
        position_ids = generate_position_ids(original_length, original_length + len(predicted_ids))[:, :, -1:].cuda()
        # print(position_ids)
        all_scales.append(layer_scales)
    
    print(predicted_ids)
    print(glm.decode(predicted_ids))
    return np.array(all_scales)

In [8]:
scales = analyze_hidden_scales("Tell me about Trump", 300)

  tensor = as_tensor(value)


[3347, 729, 107, 104, 896, 618, 6636, 172, 1814, 114, 100, 5, 16, 15, 257, 1053, 101, 100, 494, 608, 122, 1151, 5, 10, 8, 6, 5, 10, 8, 9, 25, 6, 103, 1151, 5, 10, 8, 6, 5, 10, 8, 10, 9, 7, 256, 116, 2258, 111, 1097, 5, 9, 16, 6, 5, 9, 18, 16, 21, 6, 105, 375, 875, 928, 6, 375, 875, 7, 729, 107, 104, 16492, 102, 644, 3890, 4770, 172, 132, 156, 958, 105, 687, 611, 37188, 102, 132, 156, 100, 1319, 101, 2971, 6603, 102, 1558, 3042, 7, 4, 4, 3663, 147, 15063, 6, 729, 116, 424, 108, 147, 6781, 3418, 6, 350, 147, 2256, 103, 975, 104, 1710, 111, 100, 289, 7, 159, 7, 11, 26291, 1683, 6, 147, 12544, 111, 5638, 6, 102, 147, 3980, 101, 100, 11983, 11, 9, 18, 15190, 7, 256, 154, 5549, 2971, 37391, 102, 5768, 6, 350, 2414, 12376, 102, 5273, 6, 611, 6252, 6, 102, 1251, 5663, 7, 4, 4, 1252, 2787, 1189, 6, 729, 2037, 103, 113, 104, 6781, 1854, 6, 4111, 2971, 1558, 3042, 102, 6603, 7, 256, 132, 156, 100, 1319, 101, 2971, 22016, 6, 350, 145, 822, 103, 100, 5, 10, 8, 10, 9, 1758, 7, 130005]
Donald Trump i

In [9]:
print(np.max(scales, axis=0))

[[41.12500763 20.23776627 81.44577026 32.46224976]
 [58.23212051 19.0498085  66.00358582 15.94203758]
 [59.32036209 16.80878639 23.91127205 14.46641922]
 [59.49517441 15.66162777 34.56282043 14.58930016]
 [60.65198517 13.19166183 23.30387306 11.98528004]
 [61.47579956 12.6825695  25.85115814 16.09440613]
 [58.38166428 13.28290176 23.06168938 15.86229134]
 [54.70422363 10.57917023 40.04846954 15.8881321 ]
 [54.81056595 12.15571213 43.60208893 15.75042915]
 [60.7454834  12.19743443 42.10726547 13.49996185]
 [63.45518112 10.58491707 31.90699005 14.45426464]
 [63.10396957 10.61173725 29.22563553 12.54403114]
 [68.5146637  11.26670647 35.01483154 13.34628487]
 [65.79148102  9.85550594 37.75686264 14.56965637]
 [67.33185577 12.64821339 25.17937469 15.72464943]
 [63.43676376 11.89875889 22.96869659 13.36803341]
 [63.79126358 12.93599415 43.95756149 16.05618668]
 [57.95978546 10.56634521 20.16744423 17.47854233]
 [55.22901535 10.91130257 19.68815041 18.04890823]
 [52.47626877 14.05230904 29.36

In [10]:
original_scale = 10000_4214

torch.tensor(original_scale + 34.52) - torch.tensor(original_scale)

tensor(32.)

In [11]:
# Check the scale of important weights
# This is helpful for deciding the mask_size


for i, (attn, ff) in enumerate(zip(attentions, ffs)):
    print(f"Layer {i:2d}, QKV max: {torch.max(attn.qkv_weight).item():4.4f}, AttnOut max: {torch.max(attn.attn_out_weight).item():4.4f}")

Layer  0, QKV max: 0.2466, AttnOut max: 0.6855
Layer  1, QKV max: 0.1552, AttnOut max: 0.5874
Layer  2, QKV max: 0.1327, AttnOut max: 0.3381
Layer  3, QKV max: 0.1476, AttnOut max: 0.4224
Layer  4, QKV max: 0.1367, AttnOut max: 0.2632
Layer  5, QKV max: 0.1354, AttnOut max: 0.3162
Layer  6, QKV max: 0.1494, AttnOut max: 0.6372
Layer  7, QKV max: 0.1163, AttnOut max: 0.4460
Layer  8, QKV max: 0.1791, AttnOut max: 0.6919
Layer  9, QKV max: 0.1908, AttnOut max: 0.8496
Layer 10, QKV max: 0.1379, AttnOut max: 0.3650
Layer 11, QKV max: 0.1356, AttnOut max: 0.5137
Layer 12, QKV max: 0.1398, AttnOut max: 0.3311
Layer 13, QKV max: 0.1420, AttnOut max: 0.8291
Layer 14, QKV max: 0.1332, AttnOut max: 0.6045
Layer 15, QKV max: 0.1268, AttnOut max: 0.8643
Layer 16, QKV max: 0.1398, AttnOut max: 0.3745
Layer 17, QKV max: 0.1396, AttnOut max: 0.5762
Layer 18, QKV max: 0.1377, AttnOut max: 0.3770
Layer 19, QKV max: 0.1519, AttnOut max: 0.6987
Layer 20, QKV max: 0.1560, AttnOut max: 0.6221
Layer 21, QKV

In [12]:
ffs

[FeedForward_GLM_Wrapped(
   (layernorm_in): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
   (mlp_dense_in): Linear(in_features=4096, out_features=16384, bias=True)
   (mlp_dense_out): Linear(in_features=16384, out_features=4096, bias=True)
   (layernorm_out): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
 ),
 FeedForward_GLM_Wrapped(
   (layernorm_in): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
   (mlp_dense_in): Linear(in_features=4096, out_features=16384, bias=True)
   (mlp_dense_out): Linear(in_features=16384, out_features=4096, bias=True)
   (layernorm_out): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
 ),
 FeedForward_GLM_Wrapped(
   (layernorm_in): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
   (mlp_dense_in): Linear(in_features=4096, out_features=16384, bias=True)
   (mlp_dense_out): Linear(in_features=16384, out_features=4096, bias=True)
   (layernorm_out): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
 ),
 FeedForwa