In [1]:
import time


from simple_socket.zf_socket import SocketServer
from split_llm.common.communication import Node
from split_llm.common.real_communication import RealCommunication

device = "cuda"


# Set up communication

address_dict = {
    "127.0.0.1:9000": "n0",
    "127.0.0.1:9001": "n1",
    "127.0.0.1:9002": "n2"
}
sock0 = SocketServer("127.0.0.1:9000", address_dict, 1000)
sock1 = SocketServer("127.0.0.1:9001", address_dict, 1000)
sock2 = SocketServer("127.0.0.1:9002", address_dict, 1000)

time.sleep(1) # Wait the server to start listening

sock0.connect_all()
sock1.connect_all()
sock2.connect_all()

comm0 = RealCommunication(["n0", "n1", "n2"], {"n0": sock0}, tensor_device=device)
comm1 = RealCommunication(["n0", "n1", "n2"], {"n1": sock1}, tensor_device=device)
comm2 = RealCommunication(["n0", "n1", "n2"], {"n2": sock2}, tensor_device=device)

n0 = Node(comm0, "n0")
n1 = Node(comm1, "n1")
n2 = Node(comm2, "n2")

In [2]:
import os
from split_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, copy_attention, FeedForward_GLM_Wrapped, copy_feedforward
from split_llm.glm6b.utils import generate_position_ids

from typing import List

from llm_bases.chatglm6b import ChatGML6B
glm = ChatGML6B()

from llm_bases.chatglm6b_official.modeling_chatglm import GLMBlock

raw_glm_layers: List[GLMBlock] = glm.condgen.transformer.layers
attentions: List[Attention_GLM_Wrapped] = []
attentions_public: List[Attention_GLM_Wrapped] = []
ffs: List[FeedForward_GLM_Wrapped] = []
for i in range(28):
    transformer_layer = raw_glm_layers[i].float()
    
    # The private attention layer
    attn_wrapped = Attention_GLM_Wrapped(4096, 32, i)
    copy_attention(transformer_layer, attn_wrapped)
    attn_wrapped.requires_grad_(False)
    attentions.append(attn_wrapped.cuda())

    # The public attention layer
    attn_wrapped_public = Attention_GLM_Wrapped(4096, 32, i)
    attn_wrapped_public.qkv_weight = None
    attn_wrapped_public.qkv_bias = None
    attn_wrapped_public.attn_out_weight = None
    attn_wrapped_public.attn_out_bias = None
    attn_wrapped_public.positional_embedding = attn_wrapped.positional_embedding
    attn_wrapped_public.requires_grad_(False)
    attentions_public.append(attn_wrapped_public.cuda())

    ff_wrapped = FeedForward_GLM_Wrapped(4096, 32, i)
    if i == 27:
        copy_feedforward(transformer_layer, None, ff_wrapped)
        ff_wrapped.layernorm_out = glm.condgen.transformer.final_layernorm.float()
    else:
        copy_feedforward(transformer_layer, raw_glm_layers[i + 1].float(), ff_wrapped)
    ff_wrapped.requires_grad_(False)
    ffs.append(ff_wrapped.cuda())

word_embedding = glm.condgen.transformer.word_embeddings.weight.float().cuda()
lm_head = glm.condgen.lm_head.float().cuda()
input_layernorm = raw_glm_layers[0].input_layernorm.float().cuda()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:09<00:00,  1.19s/it]


In [3]:
n0.space.attentions = attentions
n1.space.attentions = n2.space.attentions = attentions_public
n0.space.ffs = n1.space.ffs = ffs
n0.space.word_embedding = word_embedding
n0.space.input_layernorm = input_layernorm

In [4]:
from split_llm.glm6b.secure_inference import GLM_Protocol

protocol0 = GLM_Protocol(n0, Node.from_remote_name("n1"), Node.from_remote_name("n2"), 10, 500, device="cuda")
protocol1 = GLM_Protocol(Node.from_remote_name("n0"), n1, Node.from_remote_name("n2"), 10, 500, device="cuda")
protocol2 = GLM_Protocol(Node.from_remote_name("n0"), Node.from_remote_name("n1"), n2, 10, 500, device="cuda")

In [5]:
import threading

In [6]:
print("Start prepare...")
start_time = time.time()
prepare_th1 = threading.Thread(target=protocol1.prepare)
prepare_th2 = threading.Thread(target=protocol2.prepare)
prepare_th1.start()
prepare_th2.start()
protocol0.prepare()
prepare_th1.join()
prepare_th2.join()
print(f"Prepare stopped in {time.time() - start_time:.3}s.")

Start prepare...
Prepare stopped in 1.93e+03s.


In [7]:
import torch
def get_input_tensor(query: str):
    input_ids, _, _ = glm.get_tokenization(query)
    input_ids = input_ids[0]
    input_selector = torch.zeros(len(input_ids), glm.n_tokens)
    for i in range(len(input_ids)):
        input_selector[i, input_ids[i]] = 1
    return input_selector

input_tensor = get_input_tensor("Hello")
print(input_tensor.shape)

torch.Size([3, 130528])


  tensor = as_tensor(value)


In [8]:
def iteratively_generate(query: str, length: int):
    input_tensor = get_input_tensor(query).cuda()

    generation_start_tensor = input_tensor[-1:]
    input_tensor = input_tensor[:-1, :]
    generated_ids = []
    for i in range(length + 1):
        print("Start offline execute...")
        start_time = time.time()
        offline_th1 = threading.Thread(target=protocol1.offline_execute, args=(10,))
        offline_th2 = threading.Thread(target=protocol2.offline_execute, args=(10,))
        offline_th1.start()
        offline_th2.start()
        protocol0.offline_execute(10)
        offline_th1.join()
        offline_th2.join()
        n1.storage[f"{protocol1.name}:x"] = input_tensor

        print(f"Offline execution stopped in {time.time() - start_time:.3}s.")
        print("Start online execute...")
        start_time = time.time()
        online_th1 = threading.Thread(target=protocol1.online_execute)
        online_th2 = threading.Thread(target=protocol2.online_execute)
        online_th1.start()
        online_th2.start()
        protocol0.online_execute()
        online_th1.join()
        online_th2.join()
        print(f"Online execution stopped in {time.time() - start_time:.3}s.")

        if generation_start_tensor is None:
            next_id = n1.storage[f"{protocol1.name}:z"][0]
            generated_ids.append(next_id)
            print(glm.decode(generated_ids[-1]), end=' ')
            if next_id == glm.condgen.config.eos_token_id:
                break
            input_tensor = torch.zeros([1, glm.n_tokens]).cuda()
            input_tensor[0, next_id] = 1
        else:
            input_tensor = generation_start_tensor
            generation_start_tensor = None

        print(glm.decode(generated_ids))


In [10]:
iteratively_generate("Tell me about Trump", 100)

Exception in thread Thread-18 (offline_execute):
Traceback (most recent call last):
  File "/root/miniconda3/envs/llm/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/root/miniconda3/envs/llm/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 761, in run_closure
    _threading_Thread_run(self)
  File "/root/miniconda3/envs/llm/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/root/autodl-tmp/PermLLM/split_llm/glm6b/secure_inference.py", line 948, in offline_execute
    layer_protocol.offline_execute(next_length)
  File "/root/autodl-tmp/PermLLM/split_llm/glm6b/secure_inference.py", line 634, in offline_execute
    self.attn_protocol.offline_execute(next_length)
  File "/root/autodl-tmp/PermLLM/split_llm/glm6b/secure_inference.py", line 164, in offline_execute
    self.dot_product_protocol.offline_execute([next_length, 1, GLMConfig.n_heads, GLMConfig.head_dim], [next_length, self.total_length,

Start offline execute...


KeyboardInterrupt: 