In [1]:
import os
from perm_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, copy_attention, FeedForward_GLM_Wrapped, copy_feedforward
from perm_llm.glm6b.utils import generate_position_ids

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import List

In [3]:
from llm_bases.chatglm6b import ChatGML6B
# Instantiate the project's ChatGLM-6B wrapper; every later cell reaches the
# underlying model and tokenizer helpers through this `glm` object.
glm = ChatGML6B()

Loading checkpoint shards: 100%|██████████| 8/8 [00:11<00:00,  1.45s/it]


In [4]:
from llm_bases.chatglm6b_official.modeling_chatglm import GLMBlock

device = "cpu"

# Constructor arguments shared by every wrapped layer.
# NOTE(review): presumably the hidden size and the number of attention heads
# of ChatGLM-6B — confirm against the Attention_GLM_Wrapped signature.
HIDDEN_SIZE = 4096
NUM_HEADS = 32

raw_glm_layers: List[GLMBlock] = glm.condgen.transformer.layers
# Take the depth from the loaded model instead of hard-coding 28 (and 27 for
# the last layer), so both places can never drift apart.
num_layers = len(raw_glm_layers)

attentions: List[Attention_GLM_Wrapped] = []
# NOTE(review): nothing in this notebook ever fills this list; it is kept only
# so that code outside this cell that may reference the name keeps working.
attentions_public: List[Attention_GLM_Wrapped] = []
ffs: List[FeedForward_GLM_Wrapped] = []
for i in range(num_layers):
    transformer_layer = raw_glm_layers[i].float()

    # The private attention layer: a frozen copy of this block's attention.
    attn_wrapped = Attention_GLM_Wrapped(HIDDEN_SIZE, NUM_HEADS, i)
    copy_attention(transformer_layer, attn_wrapped)
    attn_wrapped.requires_grad_(False)
    attentions.append(attn_wrapped.to(device))

    # The feed-forward wrapper is also handed the *next* block; the last block
    # has no successor, so it receives None and the model's final layernorm is
    # attached as its output layernorm instead.
    ff_wrapped = FeedForward_GLM_Wrapped(HIDDEN_SIZE, NUM_HEADS, i)
    is_last_layer = i == num_layers - 1
    if is_last_layer:
        copy_feedforward(transformer_layer, None, ff_wrapped)
        ff_wrapped.layernorm_out = glm.condgen.transformer.final_layernorm.float()
    else:
        copy_feedforward(transformer_layer, raw_glm_layers[i + 1].float(), ff_wrapped)
    ff_wrapped.requires_grad_(False)
    ffs.append(ff_wrapped.to(device))

# Float copies of the embedding matrix and of the first block's input layernorm.
word_embedding = glm.condgen.transformer.word_embeddings.weight.float().to(device)
input_layernorm = raw_glm_layers[0].input_layernorm.float().to(device)

In [5]:
from perm_llm.common.communication import Communication, Node, SimulatedCommunication
# Three parties talking over an in-process simulated channel.
communication = SimulatedCommunication(["n0", "n1", "n2"])
communication.new_stage("Test")

n0, n1, n2 = (Node(communication, party) for party in ("n0", "n1", "n2"))

# Every party receives the wrapped attention layers.
n1.space.attentions = attentions
n2.space.attentions = attentions
n0.space.attentions = attentions

# Feed-forward layers and the input layernorm go to n0 and n1 only.
n0.space.ffs = ffs
n1.space.ffs = ffs
n0.space.input_layernorm = input_layernorm
n1.space.input_layernorm = input_layernorm

# The word embedding is placed on n0 alone.
n0.space.word_embedding = word_embedding

In [6]:
from perm_llm.glm6b.secure_inference_utils import generate_scale_dict

# Scale table that is later handed to GLM_Protocol as its `mask_scale` argument.
# NOTE(review): the meaning of the argument 100 is not visible in this notebook —
# confirm against generate_scale_dict.
mask_scale = generate_scale_dict(100)

In [7]:
import sys
def purge_cached_modules(module_names):
    """Remove the given entries from ``sys.modules``.

    After this call, the next ``import`` of any listed module re-executes its
    source instead of reusing the cached module object.

    Each name is handled independently: a module that is not currently loaded
    is skipped without affecting the others, and no exception is swallowed.

    Args:
        module_names: iterable of fully qualified module names.
    """
    for module_name in module_names:
        # pop(..., None) is a no-op for names that are not loaded.
        sys.modules.pop(module_name, None)


# Force the protocol modules to be re-imported by the next cell.
purge_cached_modules([
    "perm_llm.glm6b.secure_inference",
    "perm_llm.glm6b.secure_inference_utils",
])

In [8]:
from perm_llm.glm6b.secure_inference import GLM_Protocol

# Build the three-party protocol object from the nodes and the mask scales;
# `device` is forwarded as a keyword argument.
whole_protocol = GLM_Protocol(n0, n1, n2, mask_scale, device=device)

In [9]:
# Open a dedicated "prepare" stage on the communication object, then run the
# protocol's preparation step.
# NOTE(review): what prepare() actually exchanges is defined in GLM_Protocol
# and is not visible in this notebook.
communication.new_stage("prepare")
whole_protocol.prepare()

In [10]:
import torch

In [11]:
def get_input_tensor(query: str):
    """Tokenize ``query`` and encode its tokens as one-hot rows.

    Args:
        query: text passed to ``glm.get_tokenization``.

    Returns:
        A ``torch.zeros``-initialised tensor with one row per token of the
        first tokenized sequence and ``glm.n_tokens`` columns; each row holds
        a single 1 at the column given by that token's id.
    """
    batch_ids, _, _ = glm.get_tokenization(query)
    token_ids = batch_ids[0]  # only the first sequence is used
    one_hot = torch.zeros(len(token_ids), glm.n_tokens)
    for row, token_id in enumerate(token_ids):
        one_hot[row, token_id] = 1
    return one_hot

In [12]:
def offline(prompt_len: int, generation_length: int):
    """Run the protocol's offline phase for a prompt plus generated tokens.

    Step 0 calls ``whole_protocol.offline_execute(prompt_len)``; it is followed
    by ``generation_length`` further steps that each call
    ``offline_execute(1)``. Before every call a communication stage named
    ``offline_<step>`` is opened.

    Args:
        prompt_len: length passed to the first offline step.
        generation_length: number of additional single-token offline steps.
    """
    step_lengths = [prompt_len]
    step_lengths += [1] * generation_length
    for step in range(len(step_lengths)):
        communication.new_stage(f"offline_{step}")
        whole_protocol.offline_execute(step_lengths[step])


def iteratively_generate(query: str, length: int):
    """Generate up to ``length`` tokens for ``query`` through the online protocol.

    The one-hot prompt is split in two: every row but the last is fed in
    step 0, and the last prompt row is fed alone in step 1. From step 1 on,
    the value the protocol leaves under ``"<protocol name>:z"`` in
    ``n1.storage`` is taken as the next token id, printed together with the
    step index, and fed back as a one-hot row. Generation stops early when the
    id equals ``glm.condgen.config.eos_token_id``. The decoded sequence is
    printed at the end (followed by a space instead of a newline).

    NOTE(review): presumably the matching offline steps (see ``offline``) must
    have been executed beforehand — confirm against GLM_Protocol.

    Args:
        query: prompt text.
        length: maximum number of tokens to generate.

    Returns:
        None; all results are printed.
    """
    prompt_one_hot = get_input_tensor(query).to(device)
    held_back_row = prompt_one_hot[-1:]      # last prompt token, fed in step 1
    input_tensor = prompt_one_hot[:-1, :]    # everything before it, fed in step 0
    generated_ids = []

    for step in range(length + 1):
        communication.new_stage(f"online_{step}")
        n1.storage[f"{whole_protocol.name}:x"] = input_tensor
        whole_protocol.online_execute()

        if held_back_row is not None:
            # Step 0 only consumed the prompt prefix: queue the held-back row
            # and do not read any output yet.
            input_tensor, held_back_row = held_back_row, None
            continue

        next_id = n1.storage[f"{whole_protocol.name}:z"]
        generated_ids.append(next_id)
        print(step, glm.decode(next_id))
        if next_id == glm.condgen.config.eos_token_id:
            break

        # Feed the freshly generated token back as a single one-hot row.
        input_tensor = torch.zeros([1, glm.n_tokens]).to(device)
        input_tensor[0, next_id] = 1

    print(glm.decode(generated_ids), end=' ')

In [13]:
query = "How many stars are in the sky?"
generation_length = 30

# iteratively_generate feeds the prompt without its last token first, hence
# the "- 1" on the tokenized prompt length.
prompt_token_ids = glm.get_tokenization(query)[0][0]
# NOTE(review): this schedules generation_length + 2 offline stages, while
# iteratively_generate runs at most generation_length + 1 online stages —
# confirm whether the spare offline stage is intended.
offline(len(prompt_token_ids) - 1, generation_length + 1)
iteratively_generate(query, generation_length)

  tensor = as_tensor(value)


1 It
2 is
3 difficult
4 to
5 give
6 an
7 exact
8 number
9 of
10 stars
11 in
12 the
13 sky
14 ,
15 as
16 the
17 number
18 of
19 stars
20 in
21 the
22 universe
23 is
24 constantly
25 changing
26 .
27 However
28 ,
29 the
30 total
It is difficult to give an exact number of stars in the sky, as the number of stars in the universe is constantly changing. However, the total 

In [14]:
# Debug peek into n0's storage: shape of element [0] of the last entry saved
# under this key. NOTE(review): the key name suggests Beaver-triple material for
# layer 1's attention dot product — not verified from this notebook.
n0.storage["transformer_layer_1/attn/dot_product:beaver_u0 appended, v0, w0"][-1][0].shape

torch.Size([1, 1, 32, 128])

In [15]:
import numpy as np
import json
class NpEncoder(json.JSONEncoder):
    """JSON encoder that turns NumPy scalars and arrays into built-in types."""

    # (NumPy type, converter) pairs tried in order by ``default``.
    _CONVERTERS = (
        (np.integer, int),
        (np.floating, float),
        (np.ndarray, lambda array: array.tolist()),
    )

    def default(self, obj):
        """Convert ``obj`` to something the stock encoder can serialise.

        NumPy integers become ``int``, NumPy floats become ``float`` and
        arrays become (nested) lists. Anything else is delegated to
        ``json.JSONEncoder.default``, which raises ``TypeError``.
        """
        for numpy_type, convert in self._CONVERTERS:
            if isinstance(obj, numpy_type):
                return convert(obj)
        return super().default(obj)
# Persist the communication log to disk. NpEncoder is passed in case the
# history holds NumPy values.
# Create the output directory if needed so a fresh checkout can run this cell,
# and use a context manager so the file is flushed and closed deterministically.
os.makedirs("temp", exist_ok=True)
with open("temp/comm_history.json", "w") as history_file:
    json.dump(communication.comm_history, history_file, indent=4, cls=NpEncoder)

In [16]:
# Reset the protocol before running the second query.
# NOTE(review): which state reset() clears is defined in GLM_Protocol and is
# not visible in this notebook.
whole_protocol.reset()

In [17]:
query = "Tell me about Biden"
generation_length = 30

# iteratively_generate feeds the prompt without its last token first, hence
# the "- 1" on the tokenized prompt length.
prompt_token_ids = glm.get_tokenization(query)[0][0]
# NOTE(review): this schedules generation_length + 2 offline stages, while
# iteratively_generate runs at most generation_length + 1 online stages —
# confirm whether the spare offline stage is intended.
offline(len(prompt_token_ids) - 1, generation_length + 1)
iteratively_generate(query, generation_length)

1 Joe
2 Biden
3 is
4 the
5 
6 4
7 6
8 th
9 President
10 of
11 the
12 United
13 States
14 ,
15 serving
16 from
17 January
18 
19 2
20 0
21 ,
22 
23 2
24 0
25 2
26 1
27 ,
28 to
29 January
30 
Joe Biden is the 46th President of the United States, serving from January 20, 2021, to January  

In [18]:
# Show the names of all stages recorded in the communication history so far.
list(communication.comm_history.keys())

['Default',
 'Test',
 'prepare',
 'offline_0',
 'offline_1',
 'offline_2',
 'offline_3',
 'offline_4',
 'offline_5',
 'offline_6',
 'offline_7',
 'offline_8',
 'offline_9',
 'offline_10',
 'offline_11',
 'offline_12',
 'offline_13',
 'offline_14',
 'offline_15',
 'offline_16',
 'offline_17',
 'offline_18',
 'offline_19',
 'offline_20',
 'offline_21',
 'offline_22',
 'offline_23',
 'offline_24',
 'offline_25',
 'offline_26',
 'offline_27',
 'offline_28',
 'offline_29',
 'offline_30',
 'offline_31',
 'online_0',
 'online_1',
 'online_2',
 'online_3',
 'online_4',
 'online_5',
 'online_6',
 'online_7',
 'online_8',
 'online_9',
 'online_10',
 'online_11',
 'online_12',
 'online_13',
 'online_14',
 'online_15',
 'online_16',
 'online_17',
 'online_18',
 'online_19',
 'online_20',
 'online_21',
 'online_22',
 'online_23',
 'online_24',
 'online_25',
 'online_26',
 'online_27',
 'online_28',
 'online_29',
 'online_30']

In [19]:
import json

# Persist the communication log again, now including the second query.
# This uses the same NpEncoder as the earlier dump of this file, so NumPy
# values in the history do not raise TypeError; output for plain Python
# values is unchanged.
# The context manager closes the file instead of leaking the handle, and the
# directory is created if needed so a fresh checkout can run this cell.
os.makedirs("temp", exist_ok=True)
with open("temp/comm_history.json", "w") as history_file:
    json.dump(communication.comm_history, history_file, indent=4, cls=NpEncoder)