In [1]:
import os
from perm_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, copy_attention, FeedForward_GLM_Wrapped, copy_feedforward
from perm_llm.glm6b.utils import generate_position_ids

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import List

In [3]:
from llm_bases.chatglm6b import ChatGML6B
# Load the ChatGLM-6B wrapper. Later cells use its tokenizer helpers
# (`get_tokenization`, `decode`) and the underlying model (`condgen`).
glm = ChatGML6B()

Loading checkpoint shards: 100%|██████████| 8/8 [00:12<00:00,  1.62s/it]


In [4]:
from llm_bases.chatglm6b_official.modeling_chatglm import GLMBlock


# Wrap every ChatGLM-6B transformer layer for the secure-inference protocol.
# For each layer index i this cell builds:
#   * a private attention wrapper with the copied attention weights,
#   * a public attention wrapper whose qkv / attention-output weights and biases
#     are set to None (only the positional embedding is shared with the private one),
#   * a feed-forward wrapper (the last one also carries the final layernorm).
# All wrappers are frozen with requires_grad_(False) and moved to CUDA.
# NOTE(review): 28, 4096 and 32 are hard-coded, presumably the layer count, hidden
# size and head count of ChatGLM-6B — confirm; `len(raw_glm_layers)` would avoid 28.
# NOTE(review): the loop variable `i` remains defined (== 27) in the notebook
# namespace after this cell, so any later bare `i` silently resolves to it.
raw_glm_layers: List[GLMBlock] = glm.condgen.transformer.layers
attentions: List[Attention_GLM_Wrapped] = []
attentions_public: List[Attention_GLM_Wrapped] = []
ffs: List[FeedForward_GLM_Wrapped] = []
for i in range(28):
    # .float() casts the layer to float32 (torch modules convert in place and return self).
    transformer_layer = raw_glm_layers[i].float()
    
    # The private attention layer
    # presumably Attention_GLM_Wrapped(hidden_size, n_heads, layer_idx) — confirm.
    attn_wrapped = Attention_GLM_Wrapped(4096, 32, i)
    copy_attention(transformer_layer, attn_wrapped)
    attn_wrapped.requires_grad_(False)
    attentions.append(attn_wrapped.cuda())

    # The public attention layer
    # Same structure as the private wrapper, but with the projection weights and
    # biases removed; it reuses the private wrapper's positional embedding.
    attn_wrapped_public = Attention_GLM_Wrapped(4096, 32, i)
    attn_wrapped_public.qkv_weight = None
    attn_wrapped_public.qkv_bias = None
    attn_wrapped_public.attn_out_weight = None
    attn_wrapped_public.attn_out_bias = None
    attn_wrapped_public.positional_embedding = attn_wrapped.positional_embedding
    attn_wrapped_public.requires_grad_(False)
    attentions_public.append(attn_wrapped_public.cuda())

    ff_wrapped = FeedForward_GLM_Wrapped(4096, 32, i)
    if i == 27:
        # Last layer: there is no next layer, so the model's final layernorm is
        # installed as this block's output layernorm.
        copy_feedforward(transformer_layer, None, ff_wrapped)
        ff_wrapped.layernorm_out = glm.condgen.transformer.final_layernorm.float()
    else:
        # The next layer is passed in as well, presumably so its input layernorm
        # becomes this block's `layernorm_out` — confirm in copy_feedforward.
        copy_feedforward(transformer_layer, raw_glm_layers[i + 1].float(), ff_wrapped)
    ff_wrapped.requires_grad_(False)
    ffs.append(ff_wrapped.cuda())

# Model-level pieces outside the per-layer wrappers, cast to float32 and moved to CUDA.
word_embedding = glm.condgen.transformer.word_embeddings.weight.float().cuda()
lm_head = glm.condgen.lm_head.float().cuda()
input_layernorm = raw_glm_layers[0].input_layernorm.float().cuda()

In [5]:
from perm_llm.common.communication import Communication, Node, SimulatedCommunication
# Simulated three-party setup.
#   n0: private attention wrappers (with weights), feed-forwards, word embedding,
#       input layernorm.
#   n1: public (weight-stripped) attention wrappers, feed-forwards, input layernorm.
#   n2: public attention wrappers only.
communication = SimulatedCommunication(["n0", "n1", "n2"])
communication.new_stage("Test")

n0, n1, n2 = (Node(communication, party_name) for party_name in ("n0", "n1", "n2"))

n0.space.attentions = attentions
n1.space.attentions = attentions_public
n2.space.attentions = attentions_public

n0.space.ffs = ffs
n1.space.ffs = ffs

n0.space.word_embedding = word_embedding

n0.space.input_layernorm = input_layernorm
n1.space.input_layernorm = input_layernorm

In [6]:
from perm_llm.glm6b.secure_inference_utils import generate_scale_dict

# Mask scales handed to GLM_Protocol below.
# NOTE(review): the meaning of 100 is not visible here (presumably a scale
# parameter for the random masks) — confirm in generate_scale_dict.
mask_scale = generate_scale_dict(100)

In [7]:
import sys

# Drop the protocol modules from the import cache so the next cell re-imports them
# from disk (useful after editing them). Each module is removed independently, and
# pop() with a default is a no-op for a module that has not been imported yet.
# On a fresh kernel, for example, secure_inference is not imported at this point,
# but secure_inference_utils is. Real errors are not swallowed.
sys.modules.pop("perm_llm.glm6b.secure_inference", None)
sys.modules.pop("perm_llm.glm6b.secure_inference_utils", None)

In [8]:
from perm_llm.glm6b.secure_inference import GLM_Protocol

# Build the full secure-inference protocol over the three nodes with the mask
# scales above; device="cuda" selects the GPU.
whole_protocol = GLM_Protocol(n0, n1, n2, mask_scale, device="cuda")

In [9]:
# One-time protocol preparation, run in its own communication stage (presumably
# so its traffic is accounted separately from the offline/online stages).
communication.new_stage("prepare")
whole_protocol.prepare()

In [10]:
import torch

In [11]:
def get_input_tensor(query: str):
    """Tokenize ``query`` and return its one-hot encoding.

    Returns:
        A tensor of shape ``(num_tokens, glm.n_tokens)`` (torch default dtype and
        device). Row ``k`` is all zeros except for a 1 at the column of the
        ``k``-th token id.
    """
    tokenized, _, _ = glm.get_tokenization(query)
    token_ids = tokenized[0]
    onehot = torch.zeros(len(token_ids), glm.n_tokens)
    for row, token_id in enumerate(token_ids):
        onehot[row, token_id] = 1
    return onehot

In [12]:
def offline(prompt_len: int, generation_length: int):
    """Run the protocol's offline (preprocessing) phase ahead of generation.

    Prepares one offline execution for ``prompt_len`` tokens, followed by
    ``generation_length`` single-token executions. Each execution runs in its own
    communication stage, named ``offline_0``, ``offline_1``, ... in order.

    Args:
        prompt_len: number of tokens processed in the first (prompt) stage.
        generation_length: number of single-token stages to prepare.
    """
    stage_lengths = [prompt_len] + [1] * generation_length
    for stage_idx, next_length in enumerate(stage_lengths):
        # The index must be local. A bare `i` here would resolve to whatever
        # global `i` the notebook last left behind, which gives every stage the same name.
        communication.new_stage(f"offline_{stage_idx}")
        whole_protocol.offline_execute(next_length)


def iteratively_generate(query: str, length: int):
    """Generate text for ``query`` by running the protocol's online phase stage by stage.

    Stage 0 feeds every prompt token except the last. Stage 1 feeds the held-back
    last prompt token. Each later stage feeds the token produced by the previous
    stage. From stage 1 on, after the stage runs, the next token id is read from
    n1's ``<protocol name>:z`` entry, printed immediately and collected. The loop
    stops early when that id equals the EOS token id. At the end, the collected ids
    are decoded together and printed once more.

    Each stage runs in its own communication stage named ``online_<k>``.

    Args:
        query: prompt text.
        length: number of stages after the prompt-prefix stage, i.e. the maximum
            number of tokens generated.
    """
    prompt_onehot = get_input_tensor(query).cuda()
    # The final prompt token is held back and fed on its own. The output of that
    # stage is the first generated token.
    held_back_token = prompt_onehot[-1:]
    current_input = prompt_onehot[:-1, :]
    token_ids = []

    for stage in range(length + 1):
        communication.new_stage(f"online_{stage}")
        n1.storage[f"{whole_protocol.name}:x"] = current_input
        whole_protocol.online_execute()

        if held_back_token is not None:
            # Stage 0 only consumed the prompt prefix; its result is not read.
            current_input, held_back_token = held_back_token, None
            continue

        next_id = n1.storage[f"{whole_protocol.name}:z"][0]
        token_ids.append(next_id)
        print(glm.decode(next_id), end=' ')
        if next_id == glm.condgen.config.eos_token_id:
            break

        # The next stage consumes the freshly generated token as a one-hot row.
        current_input = torch.zeros([1, glm.n_tokens]).cuda()
        current_input[0, next_id] = 1

    print(glm.decode(token_ids), end=' ')

In [13]:
# First query: run the offline phase for the whole generation, then generate
# online token by token (output below). The prompt length passed to offline() is
# the token count minus one, because iteratively_generate feeds the last prompt
# token in its own online stage.
# NOTE(review): offline(..., generation_length + 1) prepares generation_length + 1
# single-token stages, but iteratively_generate(query, generation_length) runs only
# generation_length of them (fewer on early EOS). At least one preparation therefore
# goes unused; confirm the extra one is intentional headroom.
query = "How many stars are in the sky?"
generation_length = 30
offline(len(glm.get_tokenization(query)[0][0]) - 1, generation_length + 1)
iteratively_generate(query, generation_length)

  tensor = as_tensor(value)


It is difficult to give an exact number of stars in the sky , as the number of stars in the universe is constantly changing due to the expansion of It is difficult to give an exact number of stars in the sky, as the number of stars in the universe is constantly changing due to the expansion of 

In [14]:
# Debug check: shape of the first element of the last entry stored on n0 under
# this key. The key name suggests Beaver-triple material for layer 1's attention
# dot product. Output below: [1, 1, 32, 128].
n0.storage["transformer_layer_1/attn/dot_product:beaver_u0 appended, v0, w0"][-1][0].shape

torch.Size([1, 1, 32, 128])

In [15]:
import numpy as np
import json
class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy values to native Python types.

    The standard ``json`` module cannot serialize NumPy scalars or arrays. This
    encoder converts:

    * integer scalars to ``int``,
    * floating scalars to ``float``,
    * boolean scalars to ``bool``,
    * arrays to (nested) lists.

    Anything else falls through to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is neither np.integer nor a Python bool, so handle it explicitly.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
# Persist the recorded communication history as JSON. The output directory is
# created if it is missing, and the context manager flushes and closes the file
# deterministically instead of leaving that to garbage collection.
os.makedirs("temp", exist_ok=True)
with open("temp/comm_history.json", "w") as comm_history_file:
    json.dump(communication.comm_history, comm_history_file, indent=4, cls=NpEncoder)

In [16]:
# Reset the protocol before running a second query (presumably clears state left
# over from the previous run — confirm in GLM_Protocol.reset).
whole_protocol.reset()

In [17]:
# Second query, using the same offline-then-online procedure as the first
# (output below).
# NOTE(review): as with the first query, one more single-token offline stage is
# prepared than iteratively_generate consumes; confirm this is intentional.
query = "Tell me about Biden"
generation_length = 30
offline(len(glm.get_tokenization(query)[0][0]) - 1, generation_length + 1)
iteratively_generate(query, generation_length)

Joe Biden is the  4 6 th President of the United States , serving from January  2 0 ,  2 0 2 1 , until his resignation Joe Biden is the 46th President of the United States, serving from January 20, 2021, until his resignation 

In [18]:
# List every key currently held in n1's storage, one per line.
print(*n1.storage.keys(), sep="\n")

embedding_retrieval/onehot_matmul:x-u
transformer_layer_0/attn/qkv_matmul/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_0/attn/attn_out/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_1/attn/qkv_matmul/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_1/attn/attn_out/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_2/attn/qkv_matmul/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_2/attn/attn_out/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_3/attn/qkv_matmul/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_3/attn/attn_out/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_4/attn/qkv_matmul/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_4/attn/attn_out/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_5/attn/qkv_matmul/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_5/attn/attn_out/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_6/attn/qkv_matmul/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_6/attn/attn_out/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_7/attn/qkv_matmul/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_7/attn/attn_out/SS_Mul__CX_N0_Y_N1:x-u
transformer_layer_8/attn/qkv_matmu

In [19]:
# Shape of the last online input that iteratively_generate wrote to n1: a single
# one-hot row over the vocabulary (output below).
n1.storage[f"{whole_protocol.name}:x"].shape

torch.Size([1, 130528])

In [20]:
# Drop the notebook's references to the nodes and the protocol so they can be
# freed once nothing else references them.
del n0, n1, n2, whole_protocol

In [30]:
# Scratch check (run out of order) of float32 rounding near 1e4. The exact
# difference is 0.03, but float32 spacing at this magnitude is 2**-10 (~0.00098),
# so the result comes back as 0.030273438.
np.float32(10005.03)-np.float32(10005)

0.030273438