In [1]:
from llm_bases.chatglm6b import ChatGML6B
# Instantiate the project's ChatGLM-6B wrapper; the recorded cell output shows
# the checkpoint shards being loaded during this call.
glm = ChatGML6B()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:08<00:00,  1.04s/it]


In [3]:
# Print the lm_head weight matrix as a quick sanity check of the loaded parameters.
print(glm.condgen.lm_head.weight)

Parameter containing:
tensor([[-8.4076e-03, -9.3689e-03, -5.4436e-03,  ..., -6.4545e-03,
          1.6998e-02,  1.1108e-02],
        [-1.0071e-02,  5.8022e-03,  4.8018e-04,  ..., -4.2701e-04,
          1.0252e-03, -1.6556e-03],
        [ 1.9424e-02,  6.3477e-03,  2.4933e-02,  ...,  5.7297e-03,
          1.2512e-02,  9.4147e-03],
        ...,
        [-1.0078e-02,  3.0041e-03,  2.4376e-03,  ..., -4.7684e-06,
          1.5430e-03,  1.1053e-03],
        [-9.9945e-03,  4.4479e-03,  6.2141e-03,  ...,  1.7560e-04,
          1.4286e-03, -1.1883e-03],
        [-9.3536e-03,  1.7376e-03,  5.7373e-03,  ..., -1.0910e-03,
          4.3945e-03, -1.2541e-03]])


In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from typing import List

In [None]:
from split_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, FeedForward_GLM_Wrapped, copy_attention, copy_feedforward
from split_llm.common.torch_utils import relative_error


In [None]:
# Re-package every transformer layer of the loaded model into a pair of wrapped
# modules (attention + feed-forward), cast to float and placed on the GPU.
# NOTE(review): 4096 and 32 are presumably the hidden size and the number of
# attention heads, and 28 the number of layers -- confirm against the wrapper
# constructors and the model config.
raw_glm_layers = glm.condgen.transformer.layers
attentions: List[Attention_GLM_Wrapped] = []
ffs: List[FeedForward_GLM_Wrapped] = []
for layer_idx in range(28):
    current_layer = raw_glm_layers[layer_idx].float()

    wrapped_attention = Attention_GLM_Wrapped(4096, 32, layer_idx)
    copy_attention(current_layer, wrapped_attention)
    attentions.append(wrapped_attention.cuda())

    wrapped_ff = FeedForward_GLM_Wrapped(4096, 32, layer_idx)
    is_last_layer = layer_idx == 27
    if not is_last_layer:
        # Every layer except the last is given its successor layer as well.
        copy_feedforward(current_layer, raw_glm_layers[layer_idx + 1].float(), wrapped_ff)
    else:
        # The last layer has no successor; its output layernorm is replaced by
        # the model's final layernorm instead.
        copy_feedforward(current_layer, None, wrapped_ff)
        wrapped_ff.layernorm_out = glm.condgen.transformer.final_layernorm.float().cuda()
    ffs.append(wrapped_ff.cuda())


In [None]:
from split_llm.glm6b.utils import generate_attention_mask, generate_position_ids

In [None]:
# Cast the modules at both ends of the network (the input word embeddings and
# the output lm_head) to float and move them to the GPU. The return values are
# discarded, so this relies on .float()/.cuda() converting the modules in place.
glm_transformer = glm.condgen.transformer
for end_module in (glm_transformer.word_embeddings, glm.condgen.lm_head):
    end_module.float().cuda()

# The input layernorm of layer 0 is kept separately; analyze_hidden_scales
# applies it to the embeddings before entering the wrapped layers.
layernorm0 = glm_transformer.layers[0].input_layernorm.float().cuda()

In [None]:
@torch.no_grad()  # Pure inference: never build autograd history for the growing K/V caches.
def analyze_hidden_scales(query: str, generation_length: int):
    """Greedily generate from `query` and record activation magnitudes per layer.

    The prompt is tokenized and its last token is split off as `start_id`.
    The first forward pass processes the remaining prompt tokens and its
    logits are ignored (`start_id` is fed as the next input instead). Every
    later pass processes exactly one token and picks the next one with argmax.
    Keys and values of each layer are cached and concatenated across passes.

    Args:
        query: Prompt text passed to `glm.get_tokenization`.
        generation_length: Maximum number of forward passes to run.

    Returns:
        A numpy array built from one entry per executed forward pass, each
        entry holding one row per layer with the four values
        `[scale_h_in, scale_v, scale_attn_score, scale_mlp_hidden]`
        (all are `abs().max()` of the respective tensor).

    Side effects:
        Prints the generated token ids and their decoded text. If the EOS
        token is produced, it is included in the printed ids.
    """
    num_layers = 28  # Number of wrapped layers; also enters the residual scaling below.

    # The attention mask returned by the tokenizer is not used by this function.
    input_ids, position_ids, _ = glm.get_tokenization(query)

    original_length = len(input_ids[0])
    input_ids = input_ids.cuda()
    # The last prompt token is held back and fed as the first single-token input.
    start_id = input_ids[0, -1]
    input_ids = input_ids[:, :-1]
    position_ids = position_ids.cuda()[:, :, :-1]

    all_scales = []
    previous_k = dict()
    previous_v = dict()

    predicted_ids = []

    for _step in range(generation_length):
        # Here the first layernorm is moved out of layer 0 and applied up front.
        h = layernorm0(glm.get_initial_state(input_ids))
        layer_scales = []
        for j in range(num_layers):
            # Forward through the j-th attention layer
            scale_h_in = h.abs().max().item()

            # Attention module
            attn_layer: Attention_GLM_Wrapped = attentions[j]
            q, k, v = attn_layer.generate_qkv(h, position_ids)

            # Extend the per-layer K/V caches with the new entries along dim 0
            # (presumably the sequence axis -- TODO confirm against generate_qkv).
            if j not in previous_k:
                previous_k[j] = k
            else:
                previous_k[j] = torch.cat([previous_k[j], k], dim=0)
                k = previous_k[j]

            if j not in previous_v:
                previous_v[j] = v
            else:
                previous_v[j] = torch.cat([previous_v[j], v], dim=0)
                v = previous_v[j]

            scale_v = v.abs().max().item()  # Record the scale of V
            attn_scores = attn_layer.generate_logit_scores(q, k)
            scale_attn_score = attn_scores.abs().max().item()  # Record the scale since here the plaintext is used
            softmax_scores = attn_layer.generate_softmax_scores(attn_scores)
            weighted_v = attn_layer.generate_weighted_values(softmax_scores, v)
            attn_out = weighted_v @ attn_layer.attn_out_weight.T + attn_layer.attn_out_bias

            # Feedforward module
            ff_layer: FeedForward_GLM_Wrapped = ffs[j]
            h0 = ff_layer.layernorm_in(attn_out + h * (2 * num_layers) ** 0.5)
            h1 = ff_layer.mlp_dense_in(h0)

            scale_mlp_hidden = h1.abs().max().item()
            h2 = F.gelu(h1)

            #  h2 = gelu_openai(h1)
            #  Those two gelu implementations do not have significant difference
            h3 = ff_layer.mlp_dense_out(h2)
            h4 = h3 + ff_layer.residual_coef * h0
            h5 = ff_layer.layernorm_out(h4)

            h = h5

            layer_scales.append([scale_h_in, scale_v, scale_attn_score, scale_mlp_hidden])

        # Record this pass before any early exit, so the scales of the pass that
        # produces EOS are not lost.
        all_scales.append(layer_scales)

        # Logits at the last position after swapping the first two axes
        # (presumably h is sequence-first -- TODO confirm).
        logits = glm.condgen.lm_head(h).permute(1, 0, 2).contiguous()[0, -1]
        if start_id is not None:
            # First pass: feed the held-back prompt token instead of a prediction.
            next_id = start_id
            start_id = None
        else:
            next_id = torch.argmax(logits)
            predicted_ids.append(next_id.item())

        if next_id == glm.condgen.generation_config.eos_token_id:
            break
        input_ids = torch.tensor([[next_id]]).cuda()  # Next pass sees only the newest token
        position_ids = generate_position_ids(original_length, original_length + len(predicted_ids))[:, :, -1:].cuda()

    print(predicted_ids)
    print(glm.decode(predicted_ids))
    return np.array(all_scales)

In [None]:
# Run the scale analysis on a sample prompt, allowing at most 300 iterations of the generation loop.
scales = analyze_hidden_scales("Tell me about Trump", 300)

In [None]:
# Maximum over axis 0 (the generation steps): one row per layer, one column per
# recorded quantity [hidden input, V, attention score, MLP hidden].
print(np.max(scales, axis=0))

In [None]:
# Precision check: add a small offset to a value around 1e8 and subtract the
# value again. With PyTorch's default float32 dtype, representable numbers of
# this magnitude are 8 apart, so the offset cannot be recovered exactly.
original_scale = 100_004_214  # same integer as the literal 10000_4214

torch.sub(torch.tensor(original_scale + 34.52), torch.tensor(original_scale))

In [None]:
# Check the scale of important weights.
# This is helpful for deciding the mask_size.
#
# The absolute maximum is reported, matching how activation scales are measured
# in analyze_hidden_scales (.abs().max()). A signed max would under-report the
# magnitude whenever the most extreme weight is negative.
for i, (attn, ff) in enumerate(zip(attentions, ffs)):
    qkv_scale = attn.qkv_weight.abs().max().item()
    attn_out_scale = attn.attn_out_weight.abs().max().item()
    print(f"Layer {i:2d}, QKV abs max: {qkv_scale:4.4f}, AttnOut abs max: {attn_out_scale:4.4f}")

In [None]:
# Display the list of wrapped feed-forward modules for inspection.
ffs