In [1]:
import time


from simple_socket.zf_socket import SocketServer
from split_llm.common.communication import Node
from split_llm.common.communication import Communication, Node, SimulatedCommunication
from split_llm.common.real_communication import RealCommunication

device = "cuda"


# Set up communication

address_dict = {
    "127.0.0.1:3000": "n0",
    "127.0.0.1:3001": "n1",
    "127.0.0.1:3002": "n2"
}
sock0 = SocketServer("127.0.0.1:3000", address_dict, 10)
sock1 = SocketServer("127.0.0.1:3001", address_dict, 10)
sock2 = SocketServer("127.0.0.1:3002", address_dict, 10)

time.sleep(1) # Wait the server to start listening

sock0.connect_all()
sock1.connect_all()
sock2.connect_all()

comm0 = RealCommunication(["n0", "n1", "n2"], {"n0": sock0}, tensor_device=device)
comm1 = RealCommunication(["n0", "n1", "n2"], {"n1": sock1}, tensor_device=device)
comm2 = RealCommunication(["n0", "n1", "n2"], {"n2": sock2}, tensor_device=device)
# comm0 = comm1 = comm2 = SimulatedCommunication(["n0", "n1", "n2"])
# comm0.new_stage("Test")
n0 = Node(comm0, "n0")
n1 = Node(comm1, "n1")
n2 = Node(comm2, "n2")

In [2]:

communication = SimulatedCommunication(["n0", "n1", "n2"])
communication.new_stage("Test")

n0_local = Node(communication, "n0")
n1_local = Node(communication, "n1")
n2_local = Node(communication, "n2")

In [3]:
print(time.time())

1714706601.8157182


In [4]:
import os
from split_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, copy_attention, FeedForward_GLM_Wrapped, copy_feedforward
from split_llm.glm6b.utils import generate_position_ids

from typing import List

from llm_bases.chatglm6b import ChatGML6B
glm = ChatGML6B()

from llm_bases.chatglm6b_official.modeling_chatglm import GLMBlock

raw_glm_layers: List[GLMBlock] = glm.condgen.transformer.layers
attentions: List[Attention_GLM_Wrapped] = []
attentions_public: List[Attention_GLM_Wrapped] = []
ffs: List[FeedForward_GLM_Wrapped] = []
for i in range(28):
    transformer_layer = raw_glm_layers[i].float()
    
    # The private attention layer
    attn_wrapped = Attention_GLM_Wrapped(4096, 32, i)
    copy_attention(transformer_layer, attn_wrapped)
    attn_wrapped.requires_grad_(False)
    attentions.append(attn_wrapped.cuda())

    # The public attention layer
    attn_wrapped_public = Attention_GLM_Wrapped(4096, 32, i)
    attn_wrapped_public.qkv_weight = None
    attn_wrapped_public.qkv_bias = None
    attn_wrapped_public.attn_out_weight = None
    attn_wrapped_public.attn_out_bias = None
    attn_wrapped_public.positional_embedding = attn_wrapped.positional_embedding
    attn_wrapped_public.requires_grad_(False)
    attentions_public.append(attn_wrapped_public.cuda())

    ff_wrapped = FeedForward_GLM_Wrapped(4096, 32, i)
    if i == 27:
        copy_feedforward(transformer_layer, None, ff_wrapped)
        ff_wrapped.layernorm_out = glm.condgen.transformer.final_layernorm.float()
    else:
        copy_feedforward(transformer_layer, raw_glm_layers[i + 1].float(), ff_wrapped)
    ff_wrapped.requires_grad_(False)
    ffs.append(ff_wrapped.cuda())

word_embedding = glm.condgen.transformer.word_embeddings.weight.float().cuda()
lm_head = glm.condgen.lm_head.float().cuda()
input_layernorm = raw_glm_layers[0].input_layernorm.float().cuda()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:08<00:00,  1.00s/it]


In [5]:
n0_local.space.attentions = attentions
n1_local.space.attentions = n2_local.space.attentions = attentions_public
n0_local.space.ffs = n1_local.space.ffs = ffs
n0_local.space.word_embedding = word_embedding
n0_local.space.input_layernorm = n1_local.space.input_layernorm = input_layernorm

In [6]:
from split_llm.glm6b.secure_inference import GLM_Protocol

In [7]:
protocol_local = GLM_Protocol(n0_local, n1_local, n2_local, 1, 500, device="cuda")
protocol_local.prepare()

In [8]:
import torch
def get_input_tensor(query: str):
    input_ids, _, _ = glm.get_tokenization(query)
    input_ids = input_ids[0]
    input_selector = torch.zeros(len(input_ids), glm.n_tokens)
    for i in range(len(input_ids)):
        input_selector[i, input_ids[i]] = 1
    return input_selector

In [9]:
# # Test another local by copy the storage dict of nodes
# n0_local1 = Node(communication, "n0")
# n1_local1 = Node(communication, "n1")
# n2_local1 = Node(communication, "n2")

# protocol_local1 = GLM_Protocol(n0_local1, n1_local1, n2_local1, 1, 100, device="cuda")

# n0_local1.space.attentions = attentions
# n1_local1.space.attentions = n2_local1.space.attentions = attentions_public
# n0_local1.space.ffs = n1_local1.space.ffs = ffs
# n0_local1.space.word_embedding = word_embedding
# n0_local1.space.input_layernorm = n1_local1.space.input_layernorm = input_layernorm

# n0_local1.storage = n0_local.storage.copy()
# n1_local1.storage = n1_local.storage.copy()
# n2_local1.storage = n2_local.storage.copy()

# def iteratively_generate_local1(query: str, length: int):
#     input_tensor = get_input_tensor(query).cuda()

#     generation_start_tensor = input_tensor[-1:]
#     input_tensor = input_tensor[:-1, :]
#     generated_ids = []
#     for i in range(length + 1):
#         protocol_local1.offline_execute(len(input_tensor))
#         n1_local1.storage[f"{protocol_local1.name}:x"] = input_tensor
#         protocol_local1.online_execute()

#         if generation_start_tensor is None:
#             next_id = n1_local1.storage[f"{protocol_local1.name}:z"][0]
#             print(next_id)
#             generated_ids.append(next_id)
#             print(glm.decode(generated_ids[-1]), end=' ')
#             if next_id == glm.condgen.config.eos_token_id:
#                 break
#             input_tensor = torch.zeros([1, glm.n_tokens]).cuda()
#             input_tensor[0, next_id] = 1
#         else:
#             next_id = n1_local1.storage[f"{protocol_local1.name}:z"][0]
#             print(next_id)
#             input_tensor = generation_start_tensor
#             generation_start_tensor = None
#     print()
#     print(glm.decode(generated_ids), end=' ')

# iteratively_generate_local1("Tell me about Trump", 300)


In [10]:

# n0_local1.storage["prediction/final_dense:z0"] + n1_local1.storage["prediction/final_dense:z1"]

In [11]:
n0.space.attentions = attentions
n1.space.attentions = n2.space.attentions = attentions_public
n0.space.ffs = n1.space.ffs = ffs
n0.space.word_embedding = word_embedding
n0.space.input_layernorm = n1.space.input_layernorm = input_layernorm

n0.storage = n0_local.storage.copy()
n1.storage = n1_local.storage.copy()
n2.storage = n2_local.storage.copy()


In [12]:
protocol0 = GLM_Protocol(n0, Node.from_remote_name("n1"), Node.from_remote_name("n2"), 1, 100, device="cuda")
protocol1 = GLM_Protocol(Node.from_remote_name("n0"), n1, Node.from_remote_name("n2"), 1, 100, device="cuda")
protocol2 = GLM_Protocol(Node.from_remote_name("n0"), Node.from_remote_name("n1"), n2, 1, 100, device="cuda")

In [13]:
import threading

In [14]:
# print("Start prepare...")
# start_time = time.time()
# prepare_th1 = threading.Thread(target=protocol1.prepare)
# prepare_th2 = threading.Thread(target=protocol2.prepare)
# prepare_th1.start()
# prepare_th2.start()
# protocol0.prepare()
# prepare_th1.join()
# prepare_th2.join()
# print(f"Prepare stopped in {time.time() - start_time:.3}s.")

In [15]:
def iteratively_generate(query: str, length: int):
    input_tensor = get_input_tensor(query).cuda()

    generation_start_tensor = input_tensor[-1:]
    input_tensor = input_tensor[:-1, :]
    generated_ids = []
    for i in range(length + 1):
        print("Start offline execute...")
        next_length = len(input_tensor)
        start_time = time.time()
        offline_th0 = threading.Thread(target=protocol0.offline_execute, args=(next_length,))
        offline_th2 = threading.Thread(target=protocol2.offline_execute, args=(next_length,))
        offline_th0.start()
        offline_th2.start()
        protocol1.offline_execute(next_length)
        offline_th0.join()
        offline_th2.join()

        n1.storage[f"{protocol1.name}:x"] = input_tensor
        print(f"Offline execution finished in {time.time() - start_time:.3}s.")

        print("Start online execute...")
        start_time = time.time()
        online_th0 = threading.Thread(target=protocol0.online_execute)
        online_th2 = threading.Thread(target=protocol2.online_execute)
        online_th0.start()
        online_th2.start()
        protocol1.online_execute()
        online_th0.join()
        online_th2.join()
        print(f"Online execution finished in {time.time() - start_time:.3}s.")

        if generation_start_tensor is None:
            next_id = n1.storage[f"{protocol1.name}:z"][0]
            # print(next_id)
            generated_ids.append(next_id)
            if next_id == glm.condgen.config.eos_token_id:
                break
            input_tensor = torch.zeros([1, glm.n_tokens]).cuda()
            input_tensor[0, next_id] = 1
        else:
            next_id = n1.storage[f"{protocol1.name}:z"][0]
            # print(next_id)
            input_tensor = generation_start_tensor
            generation_start_tensor = None

        print(glm.decode(generated_ids))

In [16]:
iteratively_generate("What is cat", 100)

  tensor = as_tensor(value)


Start offline execute...
Offline execution finished in 1.75s.
Start online execute...
Online execution finished in 2.83s.

Start offline execute...
Offline execution finished in 1.02s.
Start online execute...
Online execution finished in 2.1s.
Donald
Start offline execute...
Offline execution finished in 1.01s.
Start online execute...
Online execution finished in 2.07s.
Donald Trump
Start offline execute...
Offline execution finished in 1.02s.
Start online execute...
Online execution finished in 2.11s.
Donald Trump is
Start offline execute...
Offline execution finished in 1.02s.
Start online execute...
Online execution finished in 2.06s.
Donald Trump is a
Start offline execute...
Offline execution finished in 1.06s.
Start online execute...
Online execution finished in 2.07s.
Donald Trump is a former
Start offline execute...
Offline execution finished in 1.02s.
Start online execute...
Online execution finished in 2.07s.
Donald Trump is a former American
Start offline execute...
Offline 

In [17]:
n0.storage["prediction/final_dense:z0"] + n1.storage["prediction/final_dense:z1"]

KeyError: 'prediction/final_dense:z0'