In [None]:
import time


from simple_socket.zf_socket import SocketServer
from perm_llm.common.communication import Node
from perm_llm.common.communication import Communication, Node, SimulatedCommunication
from perm_llm.common.real_communication import RealCommunication

# Device string passed as `tensor_device` when the RealCommunication objects are built.
device = "cuda"


# Set up communication

# One loopback endpoint per party. Every server receives the same
# address -> node-name mapping.
addr0, addr1, addr2 = "127.0.0.1:3000", "127.0.0.1:3001", "127.0.0.1:3002"
address_dict = {addr0: "n0", addr1: "n1", addr2: "n2"}

sock0 = SocketServer(addr0, address_dict, 20)
sock1 = SocketServer(addr1, address_dict, 20)
sock2 = SocketServer(addr2, address_dict, 20)

time.sleep(1)  # Give the servers a moment to start listening


import threading

# sock0 and sock1 run connect_all() on background threads while sock2 runs it
# on the calling thread, so the three calls overlap instead of running one
# after another.
connect_0, connect_1 = (
    threading.Thread(target=server.connect_all) for server in (sock0, sock1)
)
for connector in (connect_0, connect_1):
    connector.start()
sock2.connect_all()
for connector in (connect_0, connect_1):
    connector.join()

# Each party gets its own communication object wrapping only its own socket,
# keyed by the party's node name.
comm0, comm1, comm2 = (
    RealCommunication({name: sock}, tensor_device=device)
    for name, sock in (("n0", sock0), ("n1", sock1), ("n2", sock2))
)
# The nodes are created after all three communication objects exist.
n0, n1, n2 = (
    Node(comm, name)
    for comm, name in ((comm0, "n0"), (comm1, "n1"), (comm2, "n2"))
)

In [None]:
import os
from perm_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, copy_attention, FeedForward_GLM_Wrapped, copy_feedforward
from perm_llm.glm6b.utils import generate_position_ids

from typing import List

from llm_bases.chatglm6b import ChatGML6B
glm = ChatGML6B()

from llm_bases.chatglm6b_official.modeling_chatglm import GLMBlock

# Copy each of the 28 transformer blocks into a wrapped attention module and a
# wrapped feed-forward module, with gradients disabled and moved to the GPU.
raw_glm_layers: List[GLMBlock] = glm.condgen.transformer.layers
attentions: List[Attention_GLM_Wrapped] = []
attentions_public: List[Attention_GLM_Wrapped] = []
ffs: List[FeedForward_GLM_Wrapped] = []
for layer_idx in range(28):
    current_block = raw_glm_layers[layer_idx].float()

    # Attention part of this block
    wrapped_attn = Attention_GLM_Wrapped(4096, 32, layer_idx)
    copy_attention(current_block, wrapped_attn)
    wrapped_attn.requires_grad_(False)
    attentions.append(wrapped_attn.cuda())

    # Feed-forward part of this block. copy_feedforward also receives the
    # following block, or None for the last one, which instead takes the
    # model's final layernorm as its output layernorm.
    wrapped_ff = FeedForward_GLM_Wrapped(4096, 32, layer_idx)
    is_last_block = layer_idx == 27
    following_block = None if is_last_block else raw_glm_layers[layer_idx + 1].float()
    copy_feedforward(current_block, following_block, wrapped_ff)
    if is_last_block:
        wrapped_ff.layernorm_out = glm.condgen.transformer.final_layernorm.float()
    wrapped_ff.requires_grad_(False)
    ffs.append(wrapped_ff.cuda())

# Pieces that sit in front of the first block.
word_embedding = glm.condgen.transformer.word_embeddings.weight.float().cuda()
input_layernorm = raw_glm_layers[0].input_layernorm.float().cuda()

In [None]:
from perm_llm.glm6b.secure_inference import GLM_Protocol
from perm_llm.glm6b.secure_inference_utils import generate_scale_dict

In [None]:
import torch
def get_input_tensor(query: str):
    """Build a one-hot selector matrix for the tokens of `query`.

    The query is tokenized with `glm.get_tokenization`; the first element of
    the first returned value is taken as the token id sequence.

    Returns a tensor with one row per token and `glm.n_tokens` columns. Each
    row is all zeros except for a 1 in the column of that token's id.
    """
    token_batch, _, _ = glm.get_tokenization(query)
    token_ids = token_batch[0]
    one_hot_rows = torch.zeros(len(token_ids), glm.n_tokens)
    for row, token_id in enumerate(token_ids):
        one_hot_rows[row, token_id] = 1
    return one_hot_rows

In [None]:

# n0_local1.storage["prediction/final_dense:z0"] + n1_local1.storage["prediction/final_dense:z1"]

In [None]:
# Every party holds the wrapped attention layers.
for party in (n2, n1, n0):
    party.space.attentions = attentions

# Only n0 and n1 hold the feed-forward layers.
for party in (n0, n1):
    party.space.ffs = ffs

# The word embedding is given to n0 alone.
n0.space.word_embedding = word_embedding

# The first block's input layernorm goes to n0 and n1.
for party in (n0, n1):
    party.space.input_layernorm = input_layernorm


In [None]:
scale_dict = generate_scale_dict(100)

# One protocol instance per party. In each, the party's own slot holds the
# local Node and the other two slots hold remote-name placeholders.
remote = Node.from_remote_name
protocol0 = GLM_Protocol(n0, remote("n1"), remote("n2"), scale_dict, device="cuda")
protocol1 = GLM_Protocol(remote("n0"), n1, remote("n2"), scale_dict, device="cuda")
protocol2 = GLM_Protocol(remote("n0"), remote("n1"), n2, scale_dict, device="cuda")

In [None]:
import threading

In [None]:
# Run prepare() for all three parties at the same time: protocol1 and
# protocol2 on background threads, protocol0 on the calling thread.
print("Start prepare...")
start_time = time.time()
prepare_th1 = threading.Thread(target=protocol1.prepare)
prepare_th2 = threading.Thread(target=protocol2.prepare)
prepare_th1.start()
prepare_th2.start()
protocol0.prepare()
prepare_th1.join()
prepare_th2.join()
# `.3f` forces fixed-point seconds. A bare `.3` means three significant digits
# and switches to scientific notation once the duration reaches 100 s.
print(f"Prepare stopped in {time.time() - start_time:.3f}s.")

In [None]:

def offline(prompt_length: int, generation_length: int) -> None:
    """Run the offline phase once for every upcoming online step.

    The first round is executed with `prompt_length`; it is followed by
    `generation_length` rounds of length 1. In each round `offline_execute`
    is called on all three protocols concurrently (protocol0 and protocol2 on
    background threads, protocol1 on the calling thread) and the round's wall
    time is printed.

    Args:
        prompt_length: Length passed to the first offline round.
        generation_length: Number of additional length-1 rounds.
    """
    lengths = [prompt_length] + [1] * generation_length
    for next_length in lengths:
        print("Start offline execute...")
        start_time = time.time()
        offline_th0 = threading.Thread(target=protocol0.offline_execute, args=(next_length,))
        offline_th2 = threading.Thread(target=protocol2.offline_execute, args=(next_length,))
        offline_th0.start()
        offline_th2.start()
        protocol1.offline_execute(next_length)
        # Wait for both background parties before starting the next round.
        offline_th0.join()
        offline_th2.join()
        # `.3f` forces fixed-point seconds. A bare `.3` means three significant
        # digits and prints e.g. `1.23e+02` once a round takes 100 s or more.
        print(f"Offline execution finished in {time.time() - start_time:.3f}s.")

def iteratively_generate(query: str, length: int) -> None:
    """Run the online phase step by step and print the text decoded so far.

    At most `length + 1` online rounds are executed:

    1. the one-hot prompt rows except the last one,
    2. the held-back last prompt row,
    3. then one round per previously generated token.

    The output of round 1 is discarded. From round 2 on, the first element of
    the protocol output is appended to the generated ids. Generation stops
    early when that id equals `glm.condgen.config.eos_token_id`. After every
    round that does not stop, the ids collected so far are decoded and printed.

    Args:
        query: Prompt text, converted with `get_input_tensor`.
        length: Maximum number of tokens to generate.
    """
    input_tensor = get_input_tensor(query).cuda()

    # Hold back the last prompt row; it is fed as its own round right after
    # the rest of the prompt.
    generation_start_tensor = input_tensor[-1:]
    input_tensor = input_tensor[:-1, :]
    generated_ids = []
    for _ in range(length + 1):
        # Party n1 supplies the input under the protocol's ":x" storage key.
        n1.storage[f"{protocol1.name}:x"] = input_tensor
        print("Start online execute...")
        start_time = time.time()
        online_th0 = threading.Thread(target=protocol0.online_execute)
        online_th2 = threading.Thread(target=protocol2.online_execute)
        online_th0.start()
        online_th2.start()
        protocol1.online_execute()
        online_th0.join()
        online_th2.join()
        # `.3f` forces fixed-point seconds. A bare `.3` means three significant
        # digits and prints e.g. `1.23e+02` once a round takes 100 s or more.
        print(f"Online execution finished in {time.time() - start_time:.3f}s.")

        # The result is read back from n1's storage under the ":z" key.
        next_id = n1.storage[f"{protocol1.name}:z"][0]
        if generation_start_tensor is not None:
            # Prompt round: the prediction is not kept; feed the held-back
            # last prompt row next.
            input_tensor = generation_start_tensor
            generation_start_tensor = None
        else:
            generated_ids.append(next_id)
            if next_id == glm.condgen.config.eos_token_id:
                break
            # The next input is the one-hot row of the token just produced.
            input_tensor = torch.zeros([1, glm.n_tokens]).cuda()
            input_tensor[0, next_id] = 1

        print(glm.decode(generated_ids))

In [None]:
# NOTE(review): passing (None, None) presumably turns network simulation off;
# confirm against RealCommunication.simulate_network.
comm0.simulate_network(None, None)
comm1.simulate_network(None, None)
comm2.simulate_network(None, None)

# Best-effort reset: a failure must not stop the notebook, but it is reported
# rather than silently dropped. Catching `Exception` (not a bare `except`)
# lets KeyboardInterrupt and SystemExit propagate.
try:
    protocol0.reset()
    protocol1.reset()
    protocol2.reset()
except Exception as reset_error:
    print(f"Protocol reset skipped: {reset_error!r}")

In [None]:
query = "How many stars are in the sky?"
generation_length = 30

# The first offline round is sized for the prompt minus its last token, which
# matches how iteratively_generate holds back the last prompt row.
prompt_token_ids = glm.get_tokenization(query)[0][0]
offline(len(prompt_token_ids) - 1, generation_length)

In [None]:
# Optional: uncomment the three lines below to call simulate_network(10, 1000)
# on every party before generating.
# NOTE(review): the meaning and units of the two arguments are not visible
# here; confirm before relying on them.
# comm0.simulate_network(10, 1000)
# comm1.simulate_network(10, 1000)
# comm2.simulate_network(10, 1000)

# Run the online phase for the query; progress and decoded text are printed.
iteratively_generate(query, generation_length)