In [1]:
from llm_bases.chatglm6b import ChatGML6B
glm = ChatGML6B()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:08<00:00,  1.10s/it]


In [2]:
import torch
import torch.nn.functional as F
import numpy as np
from typing import List

In [3]:
from split_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, FeedForward_GLM_Wrapped, copy_attention, copy_feedforward
from split_llm.common.torch_utils import relative_error


In [4]:
raw_glm_layers = glm.condgen.transformer.layers
attentions: List[Attention_GLM_Wrapped] = []
ffs: List[FeedForward_GLM_Wrapped] = []
for i in range(28):
    transformer_layer = raw_glm_layers[i].float()
    attn_wrapped = Attention_GLM_Wrapped(4096, 32, i)
    copy_attention(transformer_layer, attn_wrapped)
    attentions.append(attn_wrapped.cuda())

    ff_wrapped = FeedForward_GLM_Wrapped(4096, 32, i)
    if i == 27:
        copy_feedforward(transformer_layer, None, ff_wrapped)
    else:
        copy_feedforward(transformer_layer, raw_glm_layers[i + 1].float(), ff_wrapped)
    ffs.append(ff_wrapped.cuda())

In [5]:
from split_llm.glm6b.utils import generate_attention_mask, generate_position_ids

In [6]:
# Convert the GLM model weights (embeddings at the beginning and the end) into CUDA-float

glm.condgen.transformer.word_embeddings.float().cuda()
glm.condgen.lm_head.float().cuda()

layernorm0 = glm.condgen.transformer.layers[0].input_layernorm.float().cuda()

In [7]:
def analyze_hidden_scales(query: str, generation_length: int):
    input_ids, position_ids, attention_masks = glm.get_tokenization(query)
    input_ids = input_ids.cuda()
    position_ids = position_ids.cuda()
    attention_masks = attention_masks.cuda()

    original_length = len(input_ids[0])
    
    all_scales = []
    for i in range(generation_length):
        # print(f"Generate {i}-th token")
        # Here the first layernorm is moved
        h = layernorm0(glm.get_initial_state(input_ids))
        layer_scales = []
        for j in range(28):
            # print("Layer", j)
            # Forward to the i-th attention layer    
            scale_h_in = h.abs().max().item()

            # Attention module
            attn_layer: Attention_GLM_Wrapped = attentions[j]
            q, k, v = attn_layer.generate_qkv(h, position_ids)
            scale_v = v.abs().max().item()  # Record the scale of V
            attn_scores = attn_layer.generate_logit_scores(q, k)
            scale_attn_score = attn_scores.abs().max().item()  # Record the scale since here the plaintext is used
            softmax_scores = attn_layer.generate_softmax_scores(attn_scores)
            weighted_v = attn_layer.generate_weighted_values(softmax_scores, v)
            attn_out = weighted_v @ attn_layer.attn_out_weight.T + attn_layer.attn_out_bias
            # Feedforward module
            ff_layer: FeedForward_GLM_Wrapped = ffs[j]
            h0 = ff_layer.layernorm_in(attn_out + h * (2 * 28) ** 0.5)
            h1 = ff_layer.mlp_dense_in(h0)
            
            scale_mlp_hidden = h1.abs().max().item()
            h2 = F.gelu(h1)

            #  h2 = gelu_openai(h1)
            #  Those two gelu implementations do not have significant difference
            h3 = ff_layer.mlp_dense_out(h2)
            h4 = h3 + ff_layer.residual_coef * h0
            h5 = ff_layer.layernorm_out(h4)

            h = h5

            layer_scales.append([scale_h_in, scale_v, scale_attn_score, scale_mlp_hidden])

        logits = glm.condgen.lm_head(h).permute(1, 0, 2).contiguous()[0, -1, :glm.n_tokens]
        # Get the logits on the next position
        next_id = torch.argmax(logits)
        # print("Next ID", next_id)

        if next_id == glm.condgen.generation_config.eos_token_id:
            break
        input_ids = torch.cat([input_ids, torch.tensor([[next_id]]).cuda()], dim=-1)  # Append the last id
        new_seq_len = len(input_ids[0])
        position_ids = generate_position_ids(original_length, new_seq_len).cuda()
        # print(position_ids)
        all_scales.append(layer_scales)
    
    token_ids = input_ids[0][original_length:].tolist()
    print(token_ids)
    print(glm.decode(token_ids))
    return np.array(all_scales)

In [8]:
scales = analyze_hidden_scales("Hello", 300)

  tensor = as_tensor(value)


[19316, 5, 128788, 35, 1372, 129, 115, 4418, 120, 788, 31]
Hello 👍! How can I assist you today?


In [9]:
print(np.max(scales, axis=0))

[[41.12500763 19.31932831 25.2648735  32.29057312]
 [58.48934937 16.72546196 18.4454422  18.62054825]
 [59.89986801  8.33433437 19.08116913 13.16208076]
 [59.77453995 11.97554684 20.70600319 10.75597191]
 [60.81726837 10.28095436 19.48379517 13.32979965]
 [61.65838242 10.48189735 24.07427406 11.46224022]
 [59.78767014 13.86101818 15.82366657  9.56827068]
 [54.31472397 12.87061596 22.50842094 10.48479462]
 [54.08573151 12.18045807 26.48773003 13.21073627]
 [60.27223969 12.03366661 21.73172188 10.63670063]
 [63.42442322 11.79520321 21.94503975 13.76262283]
 [63.14886475 10.18587685 21.80389595 10.65892792]
 [68.64946747 10.64644051 23.12402916 11.56313801]
 [65.92312622 10.57049942 26.07779884 10.45457745]
 [67.54730225 10.21110439 14.49261379 12.66041183]
 [61.51429749  8.75083542 14.03304482 11.71735096]
 [58.00505829  9.42823505 22.04002571 12.25470924]
 [52.15565491  9.21329308 16.26906204 11.69371414]
 [49.33687973 11.47721958 14.51432228 13.43934155]
 [46.26225281 11.33680058 16.42

In [10]:
original_scale = 10000_4214

torch.tensor(original_scale + 34.52) - torch.tensor(original_scale)

tensor(32.)

In [11]:
# Check the scale of important weights
# This is helpful for deciding the mask_size


for i, (attn, ff) in enumerate(zip(attentions, ffs)):
    print(f"Layer {i:2d}, QKV max: {torch.max(attn.qkv_weight).item():4.4f}, AttnOut max: {torch.max(attn.attn_out_weight).item():4.4f}")

Layer  0, QKV max: 0.2466, AttnOut max: 0.6855
Layer  1, QKV max: 0.1552, AttnOut max: 0.5874
Layer  2, QKV max: 0.1327, AttnOut max: 0.3381
Layer  3, QKV max: 0.1476, AttnOut max: 0.4224
Layer  4, QKV max: 0.1367, AttnOut max: 0.2632
Layer  5, QKV max: 0.1354, AttnOut max: 0.3162
Layer  6, QKV max: 0.1494, AttnOut max: 0.6372
Layer  7, QKV max: 0.1163, AttnOut max: 0.4460
Layer  8, QKV max: 0.1791, AttnOut max: 0.6919
Layer  9, QKV max: 0.1908, AttnOut max: 0.8496
Layer 10, QKV max: 0.1379, AttnOut max: 0.3650
Layer 11, QKV max: 0.1356, AttnOut max: 0.5137
Layer 12, QKV max: 0.1398, AttnOut max: 0.3311
Layer 13, QKV max: 0.1420, AttnOut max: 0.8291
Layer 14, QKV max: 0.1332, AttnOut max: 0.6045
Layer 15, QKV max: 0.1268, AttnOut max: 0.8643
Layer 16, QKV max: 0.1398, AttnOut max: 0.3745
Layer 17, QKV max: 0.1396, AttnOut max: 0.5762
Layer 18, QKV max: 0.1377, AttnOut max: 0.3770
Layer 19, QKV max: 0.1519, AttnOut max: 0.6987
Layer 20, QKV max: 0.1560, AttnOut max: 0.6221
Layer 21, QKV

In [12]:
ffs

[FeedForward_GLM_Wrapped(
   (layernorm_in): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
   (mlp_dense_in): Linear(in_features=4096, out_features=16384, bias=True)
   (mlp_dense_out): Linear(in_features=16384, out_features=4096, bias=True)
   (layernorm_out): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
 ),
 FeedForward_GLM_Wrapped(
   (layernorm_in): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
   (mlp_dense_in): Linear(in_features=4096, out_features=16384, bias=True)
   (mlp_dense_out): Linear(in_features=16384, out_features=4096, bias=True)
   (layernorm_out): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
 ),
 FeedForward_GLM_Wrapped(
   (layernorm_in): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
   (mlp_dense_in): Linear(in_features=4096, out_features=16384, bias=True)
   (mlp_dense_out): Linear(in_features=16384, out_features=4096, bias=True)
   (layernorm_out): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
 ),
 FeedForwa