In [1]:
import os
from split_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, copy_attention, FeedForward_GLM_Wrapped, copy_feedforward
from split_llm.glm6b.utils import generate_position_ids

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import List
import sys
# Evict the secure-inference module from the import cache so the later
# `from split_llm.glm6b.secure_inference import ...` re-executes it and picks up
# source edits without restarting the kernel.
# `pop` with a default is a no-op when the module was never imported; unlike the
# previous bare `except: pass`, it cannot swallow unrelated errors (e.g. KeyboardInterrupt).
sys.modules.pop("split_llm.glm6b.secure_inference", None)

In [3]:
from llm_bases.chatglm6b import ChatGML6B
# Instantiate the ChatGLM-6B wrapper; constructing it loads the model checkpoint
# shards (see the "Loading checkpoint shards" progress output).
glm = ChatGML6B()

Loading checkpoint shards: 100%|██████████| 8/8 [00:08<00:00,  1.09s/it]


In [4]:
from llm_bases.chatglm6b_official.modeling_chatglm import GLMBlock


# Wrap every GLM transformer block into the layer types used by the split protocol.
# NOTE(review): the wrapper constructor arguments look like (hidden_size, n_heads, layer_id);
# confirm against Attention_GLM_Wrapped / FeedForward_GLM_Wrapped.
HIDDEN_SIZE = 4096
N_HEADS = 32

raw_glm_layers: List[GLMBlock] = glm.condgen.transformer.layers
# Derive the depth from the model instead of hard-coding 28 (and 27 for the last block).
n_layers = len(raw_glm_layers)
attentions: List[Attention_GLM_Wrapped] = []
attentions_public: List[Attention_GLM_Wrapped] = []
ffs: List[FeedForward_GLM_Wrapped] = []
for i in range(n_layers):
    # Cast the original block to float32 before copying its weights.
    transformer_layer = raw_glm_layers[i].float()

    # Private attention layer: weights copied from the original block.
    attn_wrapped = Attention_GLM_Wrapped(HIDDEN_SIZE, N_HEADS, i)
    copy_attention(transformer_layer, attn_wrapped)
    attn_wrapped.requires_grad_(False)
    attentions.append(attn_wrapped.cuda())

    # Public attention layer: all projection weights/biases are dropped; only the
    # positional embedding is shared with the private layer.
    # NOTE(review): built with layer id 0 rather than i; confirm this is intended.
    attn_wrapped_public = Attention_GLM_Wrapped(HIDDEN_SIZE, N_HEADS, 0)
    attn_wrapped_public.qkv_weight = None
    attn_wrapped_public.qkv_bias = None
    attn_wrapped_public.attn_out_weight = None
    attn_wrapped_public.attn_out_bias = None
    attn_wrapped_public.positional_embedding = attn_wrapped.positional_embedding
    attn_wrapped_public.requires_grad_(False)
    attentions_public.append(attn_wrapped_public.cuda())

    # Feed-forward layer: copy_feedforward also receives the following block
    # (cast to float32 here), or None after the last block.
    ff_wrapped = FeedForward_GLM_Wrapped(HIDDEN_SIZE, N_HEADS, i)
    next_layer = raw_glm_layers[i + 1].float() if i + 1 < n_layers else None
    copy_feedforward(transformer_layer, next_layer, ff_wrapped)
    ff_wrapped.requires_grad_(False)
    ffs.append(ff_wrapped.cuda())

# Word embedding table, truncated to the first glm.n_tokens rows, with layer 0's
# input layernorm applied to every row. Assuming input_layernorm normalizes over the
# last (hidden) dimension, each looked-up embedding is therefore already normalized.
# NOTE(review): the embedding module's recorded vocabulary size is not updated by the
# truncation; confirm nothing downstream relies on it.
word_embedding = glm.condgen.transformer.word_embeddings.float().cuda()
word_embedding.weight.data = raw_glm_layers[0].input_layernorm.float().cuda()(word_embedding.weight.data[:glm.n_tokens])
lm_head = glm.condgen.lm_head.float().cuda()


In [5]:
from split_llm.common.communication import Communication, Node, SimulatedCommunication
# Simulated three-party setup. n0 receives the private (weight-carrying) layers,
# the word embedding and the output head; n1 receives the weight-less public
# attention layers; n2 is given no model state in this cell.
party_names = ("n0", "n1", "n2")
communication = SimulatedCommunication(list(party_names))
communication.new_stage("Test")

n0, n1, n2 = (Node(communication, party) for party in party_names)

n0.space.attentions = attentions
n1.space.attentions = attentions_public
# n0 and n1 share the very same feed-forward list object.
n0.space.ffs = ffs
n1.space.ffs = ffs
n0.space.word_embedding = word_embedding
n0.space.final_dense = lm_head

In [6]:
from split_llm.glm6b.secure_inference import GLM_Protocol

# Build the full secure-inference protocol over the three nodes, running on CUDA.
# NOTE(review): the meaning of the positional arguments 10 and 100 is not visible here;
# confirm against GLM_Protocol's signature and consider passing them by keyword.
whole_protocol = GLM_Protocol(n0, n1, n2, 10, 100, device="cuda")

In [7]:
# Run the protocol's preparation step; this notebook always calls it before any
# offline/online execution.
whole_protocol.prepare()

In [8]:
# Offline phase of the protocol.
# NOTE(review): the argument appears to be the number of input tokens. 3 matches the
# 3-token "Hello" prompt built further down, and the later single-token cell pairs a
# 1-row input with an offline count of 1. The value is fixed here before the prompt
# exists; TODO derive it from the prompt.
whole_protocol.offline_execute(3)

In [9]:
import torch

In [10]:
def get_input_tensor(query: str) -> torch.Tensor:
    """One-hot encode the tokenization of ``query``.

    Args:
        query: Text to tokenize with the global ``glm`` model wrapper.

    Returns:
        A CPU tensor of the default float dtype with shape
        ``(number_of_tokens, glm.n_tokens)``. Row ``k`` contains a single 1 at the
        column of the k-th token id.
    """
    input_ids, _, _ = glm.get_tokenization(query)
    # Only the first sequence of the returned batch is encoded. Normalize ids to a CPU
    # long tensor so the indexing below works for lists and (possibly GPU) tensors alike.
    token_ids = torch.as_tensor(input_ids[0], dtype=torch.long).cpu()
    n_input = len(token_ids)
    input_selector = torch.zeros(n_input, glm.n_tokens)
    # Vectorized scatter of the ones instead of a per-token Python loop.
    input_selector[torch.arange(n_input), token_ids] = 1
    return input_selector

# One-hot encoded prompt: one row per prompt token, one column per vocabulary entry.
input_tensor = get_input_tensor(query="Hello")
print(input_tensor.size())

torch.Size([3, 130006])


  tensor = as_tensor(value)


In [11]:
# Hand the one-hot prompt (moved to the GPU) to n1 as the protocol input,
# stored under "<protocol name>:x".
n1.storage[f"{whole_protocol.name}:x"] = input_tensor.cuda()

In [12]:
# Online phase of the protocol; its result is read back from n1's storage under
# "<protocol name>:z" in the next cell.
whole_protocol.online_execute()

In [13]:
# Show the protocol output held by n1; in the saved run this was a one-element list
# of token ids ([194]).
print(n1.storage[f"{whole_protocol.name}:z"])

[194]


In [27]:
# One manual decoding step: feed a single token id back through the protocol.
# TODO: 1110 was typed in by hand (it matches this cell's own saved output);
# derive it from the previous step instead of hard-coding it.
token_id = 1110
# Allocate directly on the GPU instead of building on the CPU and copying.
index_selector = torch.zeros(1, glm.n_tokens, device="cuda")
index_selector[0, token_id] = 1
n1.storage[f"{whole_protocol.name}:x"] = index_selector
# Offline count tied to the number of input rows (1 here) rather than a literal.
whole_protocol.offline_execute(index_selector.shape[0])
whole_protocol.online_execute()
print(n1.storage[f"{whole_protocol.name}:z"])

[1110]


In [28]:
# Token ids gathered by hand from successive decoding steps (194 and 1110 match the
# outputs printed above); the trailing id repeats five times.
# TODO: collect these ids programmatically.
generated_token_ids = [194, 107, 100, 254, 114, 104, 437, 589] + [1110] * 5
glm.decode(generated_token_ids)

'It is the same as a few months ago ago ago ago ago'

In [None]:
# Debug (never executed in the saved run): inspect the shape of one intermediate value
# n0 keeps in storage. NOTE(review): the key naming is defined inside the protocol;
# confirm the key still exists before relying on this.
n0.storage["transformer_layer_0/attn/dot_product:x-u"].shape

In [None]:
# Debug (never executed in the saved run): display everything n0 has stored.
# NOTE(review): this may render many large tensors; consider showing only the keys
# or removing the cell before sharing.
n0.storage

In [None]:
# Sanity check: recover the prompt's token ids from its one-hot encoding.
input_tensor.argmax(dim=1)