In [None]:
import os
from split_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, copy_attention, FeedForward_GLM_Wrapped, copy_feedforward
from split_llm.glm6b.utils import generate_position_ids

In [None]:
from typing import List
import sys
# Drop any cached copy of the protocol module so that the later
# `from split_llm.glm6b.secure_inference import GLM_Protocol` performs a fresh
# import. `pop` with a default is a no-op when the module was never imported,
# and, unlike a bare `except:`, it cannot hide unrelated errors or interrupts.
sys.modules.pop("split_llm.glm6b.secure_inference", None)

In [None]:
from llm_bases.chatglm6b import ChatGML6B
# Instantiate the project's ChatGLM-6B wrapper. Later cells read its
# `condgen` model, its `get_tokenization` / `decode` helpers and `n_tokens`.
# NOTE(review): presumably this loads the pretrained weights — confirm in ChatGML6B.
glm = ChatGML6B()

In [None]:
from llm_bases.chatglm6b_official.modeling_chatglm import GLMBlock


# Build, for every transformer block of the loaded model:
#   * a "private" attention wrapper holding the copied weights,
#   * a "public" attention wrapper with all weights removed (only the
#     positional embedding is shared with the private one),
#   * a feed-forward wrapper.
# Everything is cast to float32, frozen and moved to the GPU.
raw_glm_layers: List[GLMBlock] = glm.condgen.transformer.layers
# Derive the layer count from the model instead of hard-coding 28 / 27.
num_layers = len(raw_glm_layers)
attentions: List[Attention_GLM_Wrapped] = []
attentions_public: List[Attention_GLM_Wrapped] = []
ffs: List[FeedForward_GLM_Wrapped] = []
for i in range(num_layers):
    # `.float()` converts the block in place and returns it.
    transformer_layer = raw_glm_layers[i].float()

    # The private attention layer
    # NOTE(review): 4096 / 32 look like ChatGLM-6B's hidden size and head
    # count — confirm against the wrapper constructors.
    attn_wrapped = Attention_GLM_Wrapped(4096, 32, i)
    copy_attention(transformer_layer, attn_wrapped)
    attn_wrapped.requires_grad_(False)
    attentions.append(attn_wrapped.cuda())

    # The public attention layer: same structure, but no weights at all.
    attn_wrapped_public = Attention_GLM_Wrapped(4096, 32, i)
    attn_wrapped_public.qkv_weight = None
    attn_wrapped_public.qkv_bias = None
    attn_wrapped_public.attn_out_weight = None
    attn_wrapped_public.attn_out_bias = None
    attn_wrapped_public.positional_embedding = attn_wrapped.positional_embedding
    attn_wrapped_public.requires_grad_(False)
    attentions_public.append(attn_wrapped_public.cuda())

    ff_wrapped = FeedForward_GLM_Wrapped(4096, 32, i)
    if i == num_layers - 1:
        # The last block has no successor; the model's final layernorm is
        # used as this wrapper's output layernorm instead.
        copy_feedforward(transformer_layer, None, ff_wrapped)
        ff_wrapped.layernorm_out = glm.condgen.transformer.final_layernorm.float()
    else:
        copy_feedforward(transformer_layer, raw_glm_layers[i + 1].float(), ff_wrapped)
    ff_wrapped.requires_grad_(False)
    ffs.append(ff_wrapped.cuda())

word_embedding = glm.condgen.transformer.word_embeddings.float().cuda()
# Fold the first block's input layernorm into the embedding table. This
# overwrites the model's own weights, so re-running the cell on the same `glm`
# would normalise them a second time; the flag makes the cell idempotent.
if not getattr(word_embedding, "_input_layernorm_folded", False):
    word_embedding.weight.data = raw_glm_layers[0].input_layernorm.float().cuda()(word_embedding.weight.data)
    word_embedding._input_layernorm_folded = True
lm_head = glm.condgen.lm_head.float().cuda()


In [None]:
from split_llm.common.communication import Communication, Node, SimulatedCommunication
# Simulated three-party channel; all traffic is recorded under the "Test" stage.
communication = SimulatedCommunication(["n0", "n1", "n2"])
communication.new_stage("Test")

# One Node per party, created in the order n0, n1, n2.
n0, n1, n2 = (Node(communication, party) for party in ("n0", "n1", "n2"))

# n0 gets the weight-carrying attention wrappers, the embedding and the LM head.
n0.space.attentions = attentions
# n1 gets the weight-free attention wrappers.
n1.space.attentions = attentions_public
# Both parties reference the very same feed-forward wrapper list.
n0.space.ffs = ffs
n1.space.ffs = ffs
n0.space.word_embedding = word_embedding
n0.space.final_dense = lm_head

In [None]:
from split_llm.glm6b.secure_inference import GLM_Protocol

# Build the end-to-end protocol object over the three nodes, running on CUDA.
# NOTE(review): the meaning of the positional arguments 10 and 500 is not
# visible in this notebook — check GLM_Protocol's signature.
whole_protocol = GLM_Protocol(n0, n1, n2, 10, 500, device="cuda")

In [None]:
# Run the protocol's preparation step once, before any offline/online
# execution in the cells below (what it sets up is defined in GLM_Protocol).
whole_protocol.prepare()

In [None]:
import torch

In [None]:
def get_input_tensor(query: str):
    """One-hot encode the tokens of ``query``.

    The query is tokenized with ``glm.get_tokenization``; the first sequence
    of ids is kept. The result is a float tensor with one row per token and
    ``glm.n_tokens`` columns, where each row holds a single 1 at the column
    of that token's id.
    """
    ids_batch, _, _ = glm.get_tokenization(query)
    token_ids = ids_batch[0]
    one_hot = torch.zeros(len(token_ids), glm.n_tokens)
    # Explicit loop: each id is used as a plain scalar index.
    for row, token_id in enumerate(token_ids):
        one_hot[row, token_id] = 1
    return one_hot

# Smoke test: one-hot encode a short prompt and show the resulting shape.
input_tensor = get_input_tensor("Hello")
print(input_tensor.shape)

In [None]:
def iteratively_generate(query: str, length: int):
    """Generate up to ``length`` tokens for ``query`` via the secure protocol.

    Each generated token is decoded and printed as soon as it is available
    (space separated); generation stops early when the end-of-sequence id is
    produced. Nothing is returned.

    Args:
        query: Prompt text; it is one-hot encoded with ``get_input_tensor``.
        length: Maximum number of tokens to generate. ``0`` only runs the
            prompt prefix through the protocol and prints no token.
    """
    input_tensor = get_input_tensor(query).cuda()

    # The last prompt token is held back: the first protocol pass consumes the
    # prompt prefix only, and its output is ignored.
    # NOTE(review): this presumably relies on the protocol keeping state
    # between calls, since later passes feed a single token — confirm.
    generation_start_tensor = input_tensor[-1:]
    input_tensor = input_tensor[:-1, :]
    generated_ids = []
    for _ in range(length + 1):
        whole_protocol.offline_execute(len(input_tensor))
        n1.storage[f"{whole_protocol.name}:x"] = input_tensor
        whole_protocol.online_execute()

        if generation_start_tensor is not None:
            # First pass done: feed the held-back prompt token next.
            input_tensor = generation_start_tensor
            generation_start_tensor = None
            continue

        next_id = n1.storage[f"{whole_protocol.name}:z"][0]
        generated_ids.append(next_id)
        print(glm.decode(next_id), end=' ')
        if next_id == glm.condgen.config.eos_token_id:
            break
        # One-hot row for the token just produced; it is the next input.
        input_tensor = torch.zeros([1, glm.n_tokens]).cuda()
        input_tensor[0, next_id] = 1

    # Every token was already printed inside the loop. Only terminate the line
    # here: re-printing `generated_ids[-1]` duplicated the last token and
    # raised IndexError when nothing had been generated.
    print()

In [None]:
# Demo run: stream up to 300 generated tokens for a sample prompt.
iteratively_generate("Tell me about Trump", 300)

is a businessman and politician who was the President of the United States from January  2 0 ,  2 0 1 7 to January  2 0 ,  2 0 2 1 . He was born in Trump Tower in New York City and is the son of Donald Trump and Ivan a Trump . Trump began his 

In [None]:
# Debug inspection: display the attribute dict of the last layer protocol.
# NOTE(review): looks like leftover debugging — consider removing before sharing.
whole_protocol.layer_protocols[-1].__dict__