In [1]:
from llm_bases.chatglm6b import ChatGML6B
# Instantiate the ChatGLM-6B wrapper; later cells reach the underlying model through `glm.condgen`.
glm = ChatGML6B()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:16<00:00,  2.11s/it]


In [2]:
import torch
import torch.nn.functional as F
import numpy as np

In [3]:
from split_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, FeedForward_GLM_Wrapped, copy_attention, copy_feedforward
from split_llm.common.torch_utils import relative_error


In [4]:
# Wrap every transformer layer into the split attention / feed-forward modules.
# Wrapper i is copied from raw layer i. The feed-forward wrapper of layer i also receives
# raw layer i + 1 (None for the last layer), presumably so that the next layer's input
# layernorm is folded into it — verify in copy_feedforward.
# NOTE(review): 4096 / 32 presumably are ChatGLM-6B's hidden size / number of heads — confirm
# against the wrapper constructors.
raw_glm_layers = glm.condgen.transformer.layers
n_glm_layers = len(raw_glm_layers)
attentions = []
ffs = []
for i in range(n_glm_layers):
    # .float() converts the module in place; each wrapper must use its own layer's weights.
    transformer_layer = raw_glm_layers[i].float()

    attn_wrapped = Attention_GLM_Wrapped(4096, 32, i)
    copy_attention(transformer_layer, attn_wrapped)
    attentions.append(attn_wrapped)

    ff_wrapped = FeedForward_GLM_Wrapped(4096, 32, i)
    next_layer = raw_glm_layers[i + 1].float() if i + 1 < n_glm_layers else None
    copy_feedforward(transformer_layer, next_layer, ff_wrapped)
    ffs.append(ff_wrapped)

assert len(attentions) == len(ffs) == n_glm_layers

: 

In [None]:
from split_llm.glm6b.utils import generate_attention_mask, generate_position_ids

In [None]:
# Cast the two embedding-side modules of GLM to float32 on the GPU: the word embedding
# (used at the start of the forward pass) and the LM head (used at the end).
# nn.Module.float() / .cuda() convert the module in place and return it, so the calls can be
# issued as separate statements; the last one's return value is the cell output.
glm.condgen.transformer.word_embeddings.float()
glm.condgen.transformer.word_embeddings.cuda()
glm.condgen.lm_head.float()
glm.condgen.lm_head.cuda()

In [None]:
def _forward_layer_with_scales(h, position_ids, attn_layer, ff_layer, residual_alpha):
    """Run one wrapped GLM layer (attention + feed-forward) and record activation scales.

    Args:
        h: Hidden state entering the layer. Assumed to be already layer-normalized: the
            caller applies the first layer's input layernorm, and later ones are presumably
            folded into the previous ``ff_layer.layernorm_out`` (verify in copy_feedforward).
        position_ids: Position ids forwarded to ``attn_layer.generate_qkv``.
        attn_layer: Wrapped attention module for this layer.
        ff_layer: Wrapped feed-forward module for this layer.
        residual_alpha: Scale applied to ``h`` on the attention residual path.

    Returns:
        ``(h_next, [scale_h_in, scale_attn_score, scale_mlp_hidden])``. Each scale is the
        maximum absolute value of, respectively, the layer input, the pre-softmax attention
        logits, and the MLP hidden activations before GeLU.
    """
    scale_h_in = h.abs().max().item()

    # Attention module
    q, k, v = attn_layer.generate_qkv(h, position_ids)
    attn_scores = attn_layer.generate_logit_scores(q, k)
    # Record the logit scale here, where the plaintext scores are available.
    scale_attn_score = attn_scores.abs().max().item()
    # The softmax is taken over the score tensor itself, not over its recorded scale.
    softmax_scores = attn_layer.generate_softmax_scores(attn_scores)
    weighted_v = attn_layer.generate_weighted_values(softmax_scores, v)
    attn_out = weighted_v @ attn_layer.attn_out_weight.T + attn_layer.attn_out_bias

    # Feed-forward module
    h0 = ff_layer.layernorm_in(attn_out + h * residual_alpha)
    h1 = ff_layer.mlp_dense_in(h0)
    scale_mlp_hidden = h1.abs().max().item()
    # F.gelu and gelu_openai were compared; they show no significant difference here.
    h2 = F.gelu(h1)
    h3 = ff_layer.mlp_dense_out(h2)
    h4 = h3 + ff_layer.residual_coef * h0
    h_next = ff_layer.layernorm_out(h4)

    return h_next, [scale_h_in, scale_attn_score, scale_mlp_hidden]


def analyze_hidden_scales(query: str, generation_length: int):
    """Greedily generate tokens for ``query`` through the wrapped layers and record, at every
    step, the activation scales of every layer.

    The whole sequence is re-run at each step (no KV cache). The decoded text is printed.

    Args:
        query: Prompt passed to ``glm.get_tokenization``.
        generation_length: Number of greedy decoding steps.

    Returns:
        ``np.ndarray`` of shape ``(generation_length, n_layers, 3)``. The last axis holds
        ``[scale_h_in, scale_attn_score, scale_mlp_hidden]`` (max absolute values).

    Raises:
        ValueError: if ``ffs`` does not hold one wrapper per entry of ``attentions``.
    """
    n_layers = len(attentions)
    if len(ffs) != n_layers:
        raise ValueError(
            f"expected one feed-forward wrapper per attention wrapper, "
            f"got {len(ffs)} vs {n_layers}"
        )
    # Attention-residual scale sqrt(2 * n_layers) (= sqrt(56) for the 28-layer model).
    residual_alpha = (2 * n_layers) ** 0.5

    # The first layer's input layernorm is applied here, outside the wrapped layers.
    # .float()/.cuda() convert the module in place, so do it once rather than every step.
    first_layernorm = glm.condgen.transformer.layers[0].input_layernorm.float().cuda()

    # NOTE(review): the attention mask returned by get_tokenization is not used by the wrapped
    # forward below (nor extended as tokens are appended) — confirm that
    # generate_logit_scores handles masking internally.
    input_ids, position_ids, _attention_mask = glm.get_tokenization(query)
    input_ids = input_ids.cuda()
    position_ids = position_ids.cuda()
    # input_ids is batch-first (1, seq_len): tokens are appended on the last dim and decoded
    # from row 0, so the prompt length is the size of the last dimension.
    original_length = input_ids.shape[-1]

    all_scales = []
    with torch.no_grad():  # pure analysis: no autograd graph is needed
        for _ in range(generation_length):
            h = first_layernorm(glm.get_initial_state(input_ids))

            layer_scales = []
            for j in range(n_layers):
                h, scales = _forward_layer_with_scales(
                    h, position_ids, attentions[j], ffs[j], residual_alpha
                )
                layer_scales.append(scales)
            all_scales.append(layer_scales)

            # Logits at the last position; the permute presumably turns (seq, batch, vocab)
            # into batch-first — assumes GLM layers keep hidden states as (seq, batch, hidden).
            logits = glm.condgen.lm_head(h).permute(1, 0, 2).contiguous()[0, -1, :glm.n_tokens]
            next_id = torch.argmax(logits)

            # Append the greedy token and recompute position ids for the longer sequence.
            input_ids = torch.cat([input_ids, next_id.view(1, 1).to(input_ids.dtype)], dim=-1)
            new_seq_len = input_ids.shape[-1]
            position_ids = generate_position_ids(original_length, new_seq_len).cuda()

    print(glm.decode(input_ids[0].tolist()))
    return np.array(all_scales)

In [None]:
# Record the scales for a sample prompt. Keep the array for later inspection and display
# only the maximum of each recorded scale over all generation steps, not every step.
hidden_scales = analyze_hidden_scales("What's the difference between a dog and a cat?", 100)
hidden_scales.max(axis=0)