In [1]:
from llm_bases.chatglm6b import ChatGML6B
glm = ChatGML6B()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:12<00:00,  1.54s/it]


In [2]:
print(glm.condgen.lm_head.weight)

Parameter containing:
tensor([[-8.4076e-03, -9.3689e-03, -5.4436e-03,  ..., -6.4545e-03,
          1.6998e-02,  1.1108e-02],
        [-1.0071e-02,  5.8022e-03,  4.8018e-04,  ..., -4.2701e-04,
          1.0252e-03, -1.6556e-03],
        [ 1.9424e-02,  6.3477e-03,  2.4933e-02,  ...,  5.7297e-03,
          1.2512e-02,  9.4147e-03],
        ...,
        [-1.0078e-02,  3.0041e-03,  2.4376e-03,  ..., -4.7684e-06,
          1.5430e-03,  1.1053e-03],
        [-9.9945e-03,  4.4479e-03,  6.2141e-03,  ...,  1.7560e-04,
          1.4286e-03, -1.1883e-03],
        [-9.3536e-03,  1.7376e-03,  5.7373e-03,  ..., -1.0910e-03,
          4.3945e-03, -1.2541e-03]])


In [3]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from typing import List

In [4]:
from perm_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, FeedForward_GLM_Wrapped, copy_attention, copy_feedforward
from perm_llm.common.torch_utils import relative_error


In [5]:
raw_glm_layers = glm.condgen.transformer.layers
attentions: List[Attention_GLM_Wrapped] = []
ffs: List[FeedForward_GLM_Wrapped] = []
for i in range(28):
    transformer_layer = raw_glm_layers[i].float()
    attn_wrapped = Attention_GLM_Wrapped(4096, 32, i)
    copy_attention(transformer_layer, attn_wrapped)
    attentions.append(attn_wrapped.cuda())

    ff_wrapped = FeedForward_GLM_Wrapped(4096, 32, i)
    if i == 27:
        copy_feedforward(transformer_layer, None, ff_wrapped)
        ff_wrapped.layernorm_out = glm.condgen.transformer.final_layernorm.float().cuda()
    else:
        copy_feedforward(transformer_layer, raw_glm_layers[i + 1].float(), ff_wrapped)
    ffs.append(ff_wrapped.cuda())


In [6]:
from perm_llm.glm6b.utils import generate_attention_mask, generate_position_ids

In [7]:
# Convert the GLM model weights (embeddings at the beginning and the end) into CUDA-float

glm.condgen.transformer.word_embeddings.float().cuda()

print(f"Embedding Scale: {glm.condgen.transformer.word_embeddings.weight.std().item():8.4f}")

glm.condgen.lm_head.float().cuda()

layernorm0 = glm.condgen.transformer.layers[0].input_layernorm.float().cuda()

Embedding Scale:   0.0109


In [21]:
def analyze_hidden_scales(query: str, generation_length: int):
    input_ids, position_ids, attention_masks = glm.get_tokenization(query)
    
    
    original_length = len(input_ids[0])
    input_ids = input_ids.cuda()
    start_id = input_ids[0, -1]
    input_ids = input_ids[:, :-1]
    position_ids = position_ids.cuda()[:, :, :-1]
    # The start_id has an attention mask

    
    all_scales = []
    previous_k = dict()
    previous_v = dict()

    predicted_ids = []

    for i in range(generation_length):
        # print(f"Generate {i}-th token")
        # Here the first layernorm is moved
        current_scales = dict()

        initial_state = glm.get_initial_state(input_ids)
        current_scales["word_embedding"] = initial_state.std().item()

        h = layernorm0(initial_state)
        
        for j in range(28):
            scales = dict()
            # print("Layer", j)
            # Forward to the i-th attention layer    
            scales["h_in"] = h.std().item()

            # Attention module
            attn_layer: Attention_GLM_Wrapped = attentions[j]
            scales["qkv_linear"] = attn_layer.qkv_weight.std().item()

            
            q, k, v = attn_layer.generate_qkv(h, position_ids)

            
            if j not in previous_k:
                previous_k[j] = k
            else:
                previous_k[j] = torch.cat([previous_k[j], k], dim=0)
                k = previous_k[j]

            
            if j not in previous_v:
                previous_v[j] = v
            else:
                previous_v[j] = torch.cat([previous_v[j], v], dim=0)
                v = previous_v[j]


            scales["v"] = max(q.std().item(), k.std().item(), v.std().item())  # Record the scale of V

            attn_scores = attn_layer.generate_logit_scores(q, k)
            scales["attn_scores"] = attn_scores.std().item()  # Record the scale since here the plaintext is used

            softmax_scores = attn_layer.generate_softmax_scores(attn_scores)
            weighted_v = attn_layer.generate_weighted_values(softmax_scores, v)

            attn_out = weighted_v @ attn_layer.attn_out_weight.T + attn_layer.attn_out_bias
            scales["attn_out_linear"] = attn_layer.attn_out_weight.std().item()
            scales["attn_out"] = attn_out.std().item()

            # Feedforward module
            ff_layer: FeedForward_GLM_Wrapped = ffs[j]
            
            h_ff_in = attn_out + h * (2 * 28) ** 0.5
            scales["ff_h_in"] = h_ff_in.std().item()
            h0 = ff_layer.layernorm_in(h_ff_in)

            h1 = ff_layer.mlp_dense_in(h0)
            scales["mlp_dense_in"] = ff_layer.mlp_dense_in.weight.std().item()
            
            scales["mlp_hidden"] = h1.std().item()
            h2 = F.gelu(h1)

            #  h2 = gelu_openai(h1)
            #  Those two gelu implementations do not have significant difference
            h3 = ff_layer.mlp_dense_out(h2)
            scales["mlp_dense_out"] = ff_layer.mlp_dense_out.weight.std().item()
            scales["mlp_out"] = h3.std().item()

            h4 = h3 + ff_layer.residual_coef * h0

            scales["ff_out"] = h4.std().item()
            h5 = ff_layer.layernorm_out(h4)

            h = h5

            current_scales.update({f"layer_{j}-{k}": scales[k] for k in scales})

        logits = glm.condgen.lm_head(h).permute(1, 0, 2).contiguous()[0, -1]
        current_scales["logits"] = logits.std().item()
        # Get the logits on the next position
        if start_id is not None:
            next_id = start_id
            start_id = None
        else:
            next_id = torch.argmax(logits)
            predicted_ids.append(next_id.item())
        # print("Next ID", next_id)

        if next_id == glm.condgen.generation_config.eos_token_id:
            break
        input_ids = torch.tensor([[next_id]]).cuda()  # Append the last id
        position_ids = generate_position_ids(original_length, original_length + len(predicted_ids))[:, :, -1:].cuda()
        # print(position_ids)
        all_scales.append(current_scales)
    
    print(predicted_ids)
    print(glm.decode(predicted_ids))
    return pd.DataFrame(all_scales)

In [22]:
scales_df = analyze_hidden_scales("Tell me about Trump", 300)

[3347, 729, 107, 104, 896, 618, 6636, 172, 1814, 114, 100, 5, 16, 15, 257, 1053, 101, 100, 494, 608, 122, 1151, 5, 10, 8, 6, 5, 10, 8, 9, 25, 6, 103, 1151, 5, 10, 8, 6, 5, 10, 8, 10, 9, 7, 256, 116, 2258, 111, 1097, 5, 9, 16, 6, 5, 9, 18, 16, 21, 6, 105, 375, 875, 928, 6, 375, 875, 7, 729, 107, 104, 16492, 102, 644, 3890, 4770, 172, 132, 156, 958, 105, 687, 611, 37188, 102, 132, 156, 100, 1319, 101, 2971, 6603, 102, 1558, 3042, 7, 4, 4, 3663, 147, 15063, 6, 729, 116, 424, 108, 147, 6781, 3418, 6, 350, 147, 2256, 103, 975, 104, 1710, 111, 100, 289, 7, 159, 7, 11, 26291, 1683, 6, 147, 12544, 111, 5638, 6, 102, 147, 3980, 101, 100, 11983, 11, 9, 18, 15190, 7, 256, 154, 5549, 2971, 37391, 102, 5768, 6, 350, 2414, 12376, 102, 5273, 6, 611, 6252, 6, 102, 1251, 5663, 7, 4, 4, 1252, 2787, 1189, 6, 729, 2037, 103, 113, 104, 6781, 1854, 6, 4111, 2971, 1558, 3042, 102, 6603, 7, 256, 132, 156, 100, 1319, 101, 2971, 22016, 6, 350, 145, 822, 103, 100, 5, 10, 8, 10, 9, 1758, 7, 130005]
Donald Trump i

In [23]:
mean_scales = scales_df.mean(axis=0)
for k in mean_scales.index:
    print(f"{k:>26}", f"{mean_scales[k]:8.4f}")

            word_embedding   0.0110
              layer_0-h_in   0.9433
        layer_0-qkv_linear   0.0128
                 layer_0-v   2.0570
       layer_0-attn_scores   4.7131
   layer_0-attn_out_linear   0.0126
          layer_0-attn_out   4.6358
           layer_0-ff_h_in   9.1073
      layer_0-mlp_dense_in   0.0156
        layer_0-mlp_hidden   1.3827
     layer_0-mlp_dense_out   0.0155
           layer_0-mlp_out   3.1483
            layer_0-ff_out  10.2210
              layer_1-h_in   1.0298
        layer_1-qkv_linear   0.0138
                 layer_1-v   1.8481
       layer_1-attn_scores   4.4332
   layer_1-attn_out_linear   0.0142
          layer_1-attn_out   2.1595
           layer_1-ff_h_in   8.5945
      layer_1-mlp_dense_in   0.0166
        layer_1-mlp_hidden   1.5104
     layer_1-mlp_dense_out   0.0166
           layer_1-mlp_out   2.0980
            layer_1-ff_out   9.4040
              layer_2-h_in   1.0057
        layer_2-qkv_linear   0.0142
                 layer_2-v  