In [1]:
import os
import sys

# Make the project packages (desi_llm, llm_bases, simple_pir) importable.
# The location can be overridden with the DESI_LLM_ROOT environment variable;
# the original hard-coded path remains the default.
PROJECT_ROOT = os.environ.get("DESI_LLM_ROOT", "/home/zf/pycharm/DecentralizedLLM")
# Only append once, so re-running this cell does not keep growing sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)


import tqdm

from simple_pir.pir import PIRClient

from desi_llm.glm6b.obfuscated_layer import WrappedGLMBlock

from desi_llm.nodes.computation_node import ComputationNode
from desi_llm.nodes.model_provider import ModelProvider
from desi_llm.nodes.obfuscator import ObfuscatorNode
from llm_bases.chatglm6b import ChatGML6B

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ============= Setup stage ===========================================
# Load the ChatGLM-6B model into CPU memory.
glm6b = ChatGML6B()

# Model provider work:
model_provider = ModelProvider(glm6b)

# One obfuscator node and one computation node per transformer layer; both are
# created in the same iteration so the construction order is unchanged.
obfuscators, computation_nodes = [], []
for glm_layer in tqdm.tqdm(glm6b.condgen.transformer.layers):
    obfuscators.append(ObfuscatorNode())
    wrapped_block = WrappedGLMBlock(glm_layer.layer_id)
    computation_nodes.append(ComputationNode(wrapped_block))


Loading checkpoint shards: 100%|██████████| 8/8 [00:19<00:00,  2.38s/it]
100%|██████████| 28/28 [01:00<00:00,  2.16s/it]


In [3]:
# Model provider: generate the shared embedding and the PIR hints (this also
# sets up the PIR server).
print("Model provider setup PIR server...")
obfuscator_share = model_provider.generate_shared_embedding()
pir_lwe_mat, pir_hint = model_provider.setup_pir_server()

# User: build a PIR client from the LWE matrix, the hint, and the server's
# scale factor and plaintext modulus.
print("User creating PIR client...")
pir_client = PIRClient(
    pir_lwe_mat,
    pir_hint,
    model_provider.pir_server.get_scale_factor(),
    model_provider.pir_server.plain_modulus,
)

Model provider setup PIR sever...
User creating PIR client...


In [4]:
import torch
import numpy as np

In [5]:
from desi_llm.glm6b.obfuscated_layer import keys_to_tensor

In [6]:
import os

# Load the pre-computed obfuscated layers and both sets of obfuscation keys.
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints
# from a trusted location.
model_save_dir = "./saved_model/"

for i, computation_node in tqdm.tqdm(enumerate(computation_nodes), total=len(computation_nodes)):
    layer_path = os.path.join(model_save_dir, f"wrappedGLM_{i}.pth")
    computation_node.layer.load_state_dict(torch.load(layer_path, map_location="cpu"))

for i, obfuscator in tqdm.tqdm(enumerate(obfuscators), total=len(obfuscators)):
    obfuscator.key = torch.load(os.path.join(model_save_dir, f"obfuscatorKey_{i}.pth"), map_location="cpu")
    model_provider.obfuscation_nodes[i].key = torch.load(
        os.path.join(model_save_dir, f"providerKey_{i}.pth"), map_location="cpu"
    )

28it [00:27,  1.00it/s]
28it [00:07,  3.97it/s]


In [7]:
import os

# Set to True to re-generate every obfuscation from scratch and overwrite the
# saved checkpoints (!!! Taking a long time, can be hours !!!).
# NOTE: the checkpoint-loading cell runs before this one, so on a machine without
# saved checkpoints that cell fails first.
REGENERATE_OBFUSCATIONS = False

if REGENERATE_OBFUSCATIONS:
    print("Model provider generating obfuscated layers...")
    obfuscated_layers = model_provider.generate_obfuscations("cuda")

    # Set up obfuscator nodes and computation nodes. Each provider-obfuscated layer
    # is obfuscated a second time by a fresh obfuscator node (twice-obfuscation).
    obfuscators = []
    computation_nodes = []
    for obfuscated_layer in tqdm.tqdm(obfuscated_layers):
        obfuscator = ObfuscatorNode()
        obfuscators.append(obfuscator)
        computation_nodes.append(ComputationNode(obfuscator.obfuscate(obfuscated_layer)))

    # Save all the obfuscated models; torch.save does not create missing directories.
    model_save_dir = "./saved_model/"
    os.makedirs(model_save_dir, exist_ok=True)

    for i, computation_node in tqdm.tqdm(enumerate(computation_nodes), total=len(computation_nodes)):
        torch.save(computation_node.layer.state_dict(), os.path.join(model_save_dir, f"wrappedGLM_{i}.pth"))

    for i, obfuscator in tqdm.tqdm(enumerate(obfuscators), total=len(obfuscators)):
        torch.save(obfuscator.key, os.path.join(model_save_dir, f"obfuscatorKey_{i}.pth"))
        torch.save(model_provider.obfuscation_nodes[i].key, os.path.join(model_save_dir, f"providerKey_{i}.pth"))

In [8]:
# Progress marker: prints a fixed timestamp string.
print("Executed here, 2024.1.17 17:24", end="\n")

Executed here, 2024.1.17 17:24


In [9]:
# Move every runtime component to a single GPU in half precision.
device = "cuda:1"

for idx, transformation in enumerate(model_provider.input_transformations):
    model_provider.input_transformations[idx] = transformation.half().to(device)

for comp_node in computation_nodes:
    comp_node.layer = comp_node.layer.half().to(device)

# Convert the keys of both the model provider's and the obfuscators' nodes to
# tensors (float16 floats, int32 ints) and place every key component on the GPU.
for obf_node in model_provider.obfuscation_nodes + obfuscators:
    obf_node.key = keys_to_tensor(obf_node.key, float_type=torch.half, int_type=torch.int)
    node_key = obf_node.key
    node_key.qkv = [[first.to(device), second.to(device)] for first, second in node_key.qkv]
    node_key.mlp_output = node_key.mlp_output.to(device)
    node_key.attn_out = node_key.attn_out.to(device)

In [10]:
# Move the provider's word-embedding key to the GPU in half precision.
# torch.as_tensor accepts array-likes as well as an existing tensor, so re-running
# this cell does not trigger torch.tensor's copy-construct warning.
model_provider.word_embedding_key = torch.as_tensor(model_provider.word_embedding_key).half().to(device)

In [11]:
from desi_llm.secure_inference import NetworkSimulator
# NOTE(review): the arguments look like a latency (0.01) and a bandwidth of
# 100 MiB (100 << 20 bytes) — confirm against NetworkSimulator's signature.
ns = NetworkSimulator(0.01, 100 << 20)

In [12]:
# Tokenize the prompt: token ids become a plain Python list (used for the PIR
# queries); position ids and attention masks move to the GPU.
tokenized_prompt = glm6b.get_tokenization("Hello, who are you?")
token_ids = tokenized_prompt[0][0].tolist()
print("Token ids", token_ids)
position_ids, attention_masks = (part.to(device) for part in tokenized_prompt[1:])

Token ids [19316, 6, 172, 118, 120, 31, 130001, 130004]


  tensor = as_tensor(value)


In [13]:
# The user privately queries the permuted word-embedding ids via PIR.
# Each permuted id is split over two database records (index token_id and
# token_id + max_token_id), so two queries are issued per token.
# A separate list is built instead of extending `token_ids` in place, so re-running
# this cell does not keep doubling the token list.
pir_indices = token_ids + [token_id + glm6b.max_token_id for token_id in token_ids]
pir_queries = [pir_client.query(index) for index in pir_indices]

ns.transfer(np.array(pir_queries), "pir: client queries")

pir_answers = [model_provider.pir_server.answer(query) for query in pir_queries]

ns.transfer(np.array(pir_answers), "pir: server answers")

recovered_perm_ids = [pir_client.recover(index, answer) for index, answer in zip(pir_indices, pir_answers)]

In [14]:
import numpy as np
# Recombine the two PIR records per token: the first half of the answers holds the
# high part and the second half the low part (value = high * 1000 + low).
# NOTE: not idempotent — re-run the PIR query cell before re-running this one.
recovered_perm_ids = np.asarray(recovered_perm_ids)
n_prompt_tokens = len(recovered_perm_ids) // 2
recovered_perm_ids = recovered_perm_ids[:n_prompt_tokens] * 1000 + recovered_perm_ids[n_prompt_tokens:]
expected_perm_ids = model_provider.permutation[token_ids[:n_prompt_tokens]]
print(recovered_perm_ids)
print(expected_perm_ids)
# Sanity check: the PIR round trip must reproduce the provider's permutation.
assert np.array_equal(recovered_perm_ids, expected_perm_ids), "PIR recovery does not match model_provider.permutation"

[ 64656  50031  11507  15176  51657  70418 126738 107543]
[ 64656  50031  11507  15176  51657  70418 126738 107543]


In [15]:
# Communication volume and simulated time accumulated so far (the PIR phase).
print("Current communication and time: {:.2f} MB, {:.2f} s".format(ns.total_comm / 1024 ** 2, ns.total_time))

Current communication and time: 0.12 MB, 0.02 s


In [16]:
from desi_llm.common.utils import generate_random_linear_combination, generate_random_transformations
# Sequence length = number of recovered (permuted) token ids; n_random_vectors is
# the `k` passed to generate_random_linear_combination.
seq_len = len(recovered_perm_ids)
n_random_vectors = 1

In [17]:
# The user first reconstructs its share of the word embeddings.
from desi_llm.glm6b.configs import GLM6BConfig
from desi_llm.common.utils import random_vec_with_seed, generate_random_linear_combination

import torch

# Seed shared by the user and the obfuscator for re-masking the embedding shares:
# the mask is added on the user side and subtracted on the obfuscator side.
# NOTE(review): a fixed, hard-coded seed makes this mask predictable.
REMASK_SEED = 19260817

# Each permuted id seeds the user's pseudo-random share of that token's embedding.
embedding_share_0 = np.array(
    [random_vec_with_seed(perm_id, GLM6BConfig.model_dim, [-1, 1]) for perm_id in recovered_perm_ids]
)
# The user re-masks the word embedding.
embedding_share_0 += random_vec_with_seed(REMASK_SEED, embedding_share_0.shape, [-1, 1])

# The user generates the linear combinations.
# NOTE(review): the random transformation is replaced by all-ones tensors (the
# generator call is commented out), so this step uses a fixed, non-random transformation.
# random_transformation = generate_random_transformations(seq_len, n_random_vectors)
random_transformation = (torch.ones(seq_len, 1, 1), torch.ones(seq_len, 1, 1))

rlcs_0 = generate_random_linear_combination(torch.tensor(embedding_share_0)[:, None, :].half().to(device),
                                            n_random_vectors, random_transformation[0].half().to(device))

# The user sends the recovered permutation to the obfuscator, along with the
# transformation matrices, so that both maintain a share of the random vector.
ns.transfer(np.array(recovered_perm_ids), "word embedding: user recovered permutation")
print(f"Current communication and time: {ns.total_comm / (1024 ** 2):.2f} MB, {ns.total_time:.2f} s")

# Obfuscator side: take its share of the same rows and remove the re-mask.
# (The transformation can also be synced via a random seed.)
embedding_share_1 = obfuscator_share[recovered_perm_ids]
embedding_share_1 -= random_vec_with_seed(REMASK_SEED, embedding_share_1.shape, [-1, 1])

# Both shares are multiplied by the provider's word-embedding key.
rlcs_1 = torch.tensor(embedding_share_1)[:, None, :].half().to(device)
rlcs_0 = rlcs_0 @ model_provider.word_embedding_key
rlcs_1 = rlcs_1 @ model_provider.word_embedding_key

random_transformation = (random_transformation[0].half().to(device), random_transformation[1].half().to(device))

Current communication and time: 0.12 MB, 0.03 s


In [None]:
# Secure forward pass through every transformer layer.
# On entry to layer i the hidden state is held as two additive shares:
#   rlcs_0 - share 0, expressed as linear combinations mixed with random_transformation[0]
#   rlcs_1 - share 1, as-is
# Per layer: the model provider applies its input transformation and obfuscation key,
# the obfuscator applies its own key, the computation node recombines the shares and
# runs its obfuscated layer, and the output is re-split into two new shares whose
# mlp_output obfuscation is undone by the obfuscator and then by the model provider.
# NOTE(review): only the input transformations and the layer forward run under
# torch.no_grad(); the obfuscation forward passes below do not.
# rlcs_1 += rlcs_0
# rlcs_0 -= rlcs_0
for i in tqdm.tqdm(range(len(obfuscators))):
    with torch.no_grad():
        # Share 0 only gets the projection part (only_projection=True); share 1 gets
        # the transformation without that restriction (presumably projection + translation).
        qkv_rlcs_proj = model_provider.input_transformations[i](rlcs_0, position_ids, only_projection=True)
        qkv_rlcs_tran = model_provider.input_transformations[i](rlcs_1, position_ids)
        # 3 * [seq_len, k, n_heads, head_dim]

        # Same split for the residual branch, with only_affine=True.
        residual_proj = model_provider.input_transformations[i](rlcs_0, position_ids, only_affine=True, only_projection=True)
        residual_tran = model_provider.input_transformations[i](rlcs_1, position_ids, only_affine=True)


    # Model provider: apply its "qkv" obfuscation to each share's (q, k, v) triple.
    qkv_rlcs = list(qkv_rlcs_proj) + list(qkv_rlcs_tran)
    # print(qkv_rlcs[0] + qkv_rlcs[3])
    qkv_rlcs = list(model_provider.obfuscation_nodes[i].forward_pass(qkv_rlcs[:3], "qkv")) + \
               list(model_provider.obfuscation_nodes[i].forward_pass(qkv_rlcs[3:], "qkv"))
    # (2 * 3(q, k, v)) * [seq_len, k, n_heads, head_dim]


    # Model provider: apply its "attn_out" obfuscation to both residual shares.
    residual_rlcs = [residual_proj, residual_tran]
    residual_rlcs = [model_provider.obfuscation_nodes[i].forward_pass(rlc, "attn_out") for rlc in residual_rlcs]
    # (2 * seq_len) * [k, model_dim]


    # send rlcs to the obfuscator (simulated transfer; only accounts for size/time)
    ns.transfer(qkv_rlcs + residual_rlcs, f"forward embedding: send to obfusctor {i}")

    # Obfuscator: apply its own key on top of the provider's.
    qkv_rlcs = list(obfuscators[i].forward_pass(qkv_rlcs[:3], "qkv")) + list(obfuscators[i].forward_pass(qkv_rlcs[3:], "qkv"))
    # (2 * 3(q, k, v)) * [seq_len, k, n_heads, head_dim]
    residual_rlcs = [obfuscators[i].forward_pass(rlc, "attn_out") for rlc in residual_rlcs]
    # 2 * [seq_len, k, model_dim]

    # send rlcs to the computation node
    ns.transfer(qkv_rlcs + residual_rlcs, f"forward embedding: send to computation node {i}")


    # Computation node: collapse share 0's k combinations with random_transformation[1]
    # (presumably undoing the mixing by random_transformation[0]), then add share 1.
    forward_embedding_share_0 = [torch.sum(rlc * random_transformation[1][:, 0, :, None, None], dim=1, keepdim=True) for rlc in qkv_rlcs[:3]]
    forward_embedding_share_1 = qkv_rlcs[3:]
    # (2 * 3(qkv)) * [seq_len, 1, n_heads, head_dim]
    forward_embedding = [a + b for a, b in zip(forward_embedding_share_0, forward_embedding_share_1)]

    # Same recombination for the residual branch.
    residual_embedding_share_0 = torch.sum(residual_rlcs[0] * random_transformation[1][:, 0, :, None], dim=1, keepdim=True)
    residual_embedding_share_1 = residual_rlcs[1]
    # 2 * [seq_len, 1, model_dim]

    residual_embedding = residual_embedding_share_0 + residual_embedding_share_1
    # [seq_len, 1, model_dim]
#     print("Residual=============\n", residual_embedding)

    # Run the (twice-)obfuscated layer i on the recombined inputs.
    with torch.no_grad():
        embedding = computation_nodes[i].forward_pass(forward_embedding, attention_masks, residual_embedding)

    # Check the 'raw' embedding
#     print("Output==============\n", embedding)
#     input()

    # computation node generate linear combinations:
    # re-split the output into share 0 = uniform noise in [-std, std) scaled by the
    # output's std, and share 1 = output - share 0.
    embedding_share_0 = 2 * torch.std(embedding) * (torch.rand_like(embedding) - 0.5)
    embedding_share_1 = embedding - embedding_share_0

    # Fresh random transformations for the next layer; share 0 is mixed into k combinations.
    random_transformation = generate_random_transformations(seq_len, n_random_vectors)
    random_transformation = random_transformation[0].half().to(device), random_transformation[1].half().to(device)
    rlcs_0 = generate_random_linear_combination(embedding_share_0, n_random_vectors, random_transformation[0])
    rlcs_1 = embedding_share_1


    # Undo the "mlp_output" obfuscation: first the obfuscator's key, then the provider's.
    rlcs_0 = obfuscators[i].forward_pass(rlcs_0, "mlp_output", reverse=True)
    rlcs_1 = obfuscators[i].forward_pass(rlcs_1, "mlp_output", reverse=True)

    rlcs_0 = model_provider.obfuscation_nodes[i].forward_pass(rlcs_0, "mlp_output", reverse=True)
    rlcs_1 = model_provider.obfuscation_nodes[i].forward_pass(rlcs_1, "mlp_output", reverse=True)

  0%|          | 0/28 [00:00<?, ?it/s]

 tensor([[[-1.5605e+00, -3.8086e-02, -1.5391e+00,  ..., -1.7393e+00,
           6.7529e-01,  1.2617e+00]],

        [[ 1.8203e+00, -6.6309e-01,  3.1104e-01,  ...,  1.7871e-01,
           2.0820e+00, -1.5225e+00]],

        [[ 6.5527e-01,  6.3086e-01, -1.5645e+00,  ..., -1.2881e+00,
          -1.5271e-01,  1.0439e+00]],

        ...,

        [[ 6.5869e-01,  2.7930e-01,  7.4512e-01,  ..., -3.2837e-01,
           9.6094e-01,  5.5762e-01]],

        [[ 1.8359e-01, -3.9014e-01,  9.5625e+00,  ..., -7.4414e-01,
          -2.2754e-01,  1.5552e-01]],

        [[ 2.7441e-01, -5.3467e-02, -1.6772e-01,  ...,  1.7529e-01,
          -8.9111e-03, -1.7676e-01]]], device='cuda:1', dtype=torch.float16)
 tensor([[[-1.2256,  0.3003, -1.9492,  ..., -2.1230,  0.7051,  1.2422]],

        [[ 0.7681, -0.2625, -0.0475,  ...,  0.2107,  0.7202, -0.9814]],

        [[ 0.5845,  0.4648, -0.7349,  ..., -0.5518, -0.9180,  0.9951]],

        ...,

        [[ 0.0180,  0.0241,  0.1892,  ...,  0.2820,  0.1628,  0.1159]],

In [None]:
# Undo the mlp_output obfuscation of the last layer's raw output (obfuscator key
# first, then the model provider's), using the final `i` left by the layer loop.
embedding = obfuscators[i].forward_pass(embedding, "mlp_output", reverse=True)
embedding = model_provider.obfuscation_nodes[i].forward_pass(embedding, "mlp_output", reverse=True)


def recover_final_embedding(embedding: torch.Tensor, position: int = -1) -> int:
    """Greedy-decode one token from a de-obfuscated hidden state.

    Args:
        embedding: hidden states, indexed as ``embedding[position, 0, :]``.
        position: sequence position to decode; defaults to the last token.

    Returns:
        The arg-max token id. The id and its decoded text are also printed.
    """
    final_embedding = embedding[position, 0, :].cpu().float()
    # Inference only: do not build an autograd graph through lm_head.
    # NOTE: Module.float() converts glm6b's lm_head to float32 in place.
    with torch.no_grad():
        logits = glm6b.condgen.lm_head.float()(final_embedding)
    token_id = torch.argmax(logits).item()
    print(token_id)
    print(glm6b.tokenizer.decode(token_id))
    return token_id


recover_final_embedding(embedding);

In [None]:
# Show which layer index the loop stopped at.
print(f"{i}")

In [None]:
# Shape of share 0's q after collapsing the k combinations (without keepdim).
qkv_rlcs[0].mul(random_transformation[1][:, 0, :, None, None]).sum(dim=1).shape

In [None]:
import sys

# Drop the cached module so the next import re-executes it (picks up source edits).
# pop() with a default keeps the cell re-runnable; `del` raised KeyError once the
# entry was already gone.
# sys.modules.pop('desi_llm.glm6b.obfuscated_layer', None)
sys.modules.pop('desi_llm.nodes.model_provider', None)

In [None]:
import sys
from functools import partial

# Hot-reload ObfuscatorNode and rebind forward_pass on the existing node objects so
# they run the freshly loaded code without being re-created. pop() keeps the cell
# re-runnable (`del` raises KeyError when the module is not cached).
sys.modules.pop('desi_llm.nodes.obfuscator', None)
from desi_llm.nodes.obfuscator import ObfuscatorNode
for o in obfuscators:
    o.forward_pass = partial(ObfuscatorNode.forward_pass, o)
for o in model_provider.obfuscation_nodes:
    o.forward_pass = partial(ObfuscatorNode.forward_pass, o)

In [None]:
import sys
from functools import partial  # previously only imported by another cell

# Hot-reload ComputationNode and rebind forward_pass on the existing computation nodes.
# pop() keeps the cell re-runnable (`del` raises KeyError when the module is not cached).
sys.modules.pop('desi_llm.nodes.computation_node', None)
from desi_llm.nodes.computation_node import ComputationNode
for c in computation_nodes:
    c.forward_pass = partial(ComputationNode.forward_pass, c)

In [None]:
# Scratch: inspect how the tokenizer encodes a single word.
probe_word = "Thousand"
glm6b.tokenizer(probe_word)

In [None]:
# Shape of the current hidden state (torch.Size).
embedding.size()

In [None]:
# Free a leftover `random_transformations` from an earlier session, if present.
# The notebook only defines `random_transformation` (singular), so a plain
# `del random_transformations` raises NameError on a fresh kernel.
globals().pop("random_transformations", None)

In [None]:
import matplotlib.pyplot as plt

# Beat timestamps (presumably seconds) for the sliding-window rate plot below.
seqs = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 15, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30])

def compute_avg_beats(end_time: float, window: float = 5, scale: float = 20, beat_times=None):
    """Count beats in the closed window [end_time - window, end_time] and scale the count.

    Args:
        end_time: right edge of the window.
        window: window length; defaults to the originally hard-coded 5.
        scale: multiplier applied to the beat count; defaults to the originally
            hard-coded 20. NOTE(review): 20 != 60 / 5, so this is not beats per
            minute for a 5-second window — confirm the intended unit.
        beat_times: array-like of beat timestamps; defaults to the global ``seqs``.

    Returns:
        ``scale`` times the number of beats in the window. Both window edges are
        inclusive, so a beat exactly on a boundary is counted in two adjacent windows.
    """
    if beat_times is None:
        beat_times = seqs
    beat_times = np.asarray(beat_times)
    num_beats = np.sum((beat_times >= end_time - window) & (beat_times <= end_time))
    return scale * num_beats

# Sliding-window beat rate for every window end time from 5 to 29.
# Plot against the actual end times (the bare plt.plot(rates) used indices 0..24).
end_times = list(range(5, 30))
rates = [compute_avg_beats(t) for t in end_times]
fig, ax = plt.subplots()
ax.plot(end_times, rates)
ax.set(xlabel="window end time", ylabel="scaled beat count", title="Sliding-window beat rate")
plt.show()

In [None]:
# Plaintext reference tokenization of the same prompt, for the comparison cells below.
reference_prompt = "Hello, who are you?"
tokenization = glm6b.get_tokenization(reference_prompt)

In [None]:
# Plaintext reference: run the original (unobfuscated) GLM layer 0 on the prompt.
# Inference only, so no autograd graph is built.
# NOTE: Module.float() converts glm6b's layer 0 to float32 in place.
with torch.no_grad():
    initial_state = glm6b.get_initial_state(tokenization[0])
    state_0 = glm6b.condgen.transformer.layers[0].float()(initial_state, tokenization[1], tokenization[2], torch.tensor(0))

In [None]:
import torch.nn.functional as F

# Compare the plaintext reference (layer-normed input, layer-0 output) with the
# embedding reconstructed from the two shares.
print(F.layer_norm(initial_state, initial_state.shape[-1:]))
print(state_0)
# The shares are NumPy arrays right after the embedding-share cell but CUDA tensors
# once the layer loop has run; as_tensor(...).cpu() handles both and matches the CPU key.
reconstructed_embedding = torch.as_tensor(embedding_share_0 + embedding_share_1).cpu().float()
print(reconstructed_embedding @ model_provider.word_embedding_key.cpu().float())

In [None]:
# Plaintext word embeddings of the first 8 prompt tokens as a float32 NumPy array.
# Index before converting (the original converted the whole table first), and
# detach() so .numpy() also works when the weight requires grad.
embedding_table = glm6b.condgen.transformer.word_embeddings.weight.detach()[:ChatGML6B.max_token_id]
raw_embedding = embedding_table[token_ids[:8]].float().numpy()

In [None]:
# Display the plaintext reference embeddings.
print(f"{raw_embedding}")

In [None]:
# Check: provider's rotated embedding share + obfuscator's share (at the permuted
# ids), multiplied by the word-embedding key, for the first 8 prompt tokens.
first_token_ids = token_ids[:8]
summed_shares = (model_provider.shared_rotated_word_embedding[first_token_ids]
                 + obfuscator_share[model_provider.permutation[first_token_ids]])
# as_tensor: word_embedding_key is already a tensor once the device-setup cell has run.
torch.as_tensor(summed_shares).half().cuda() @ torch.as_tensor(model_provider.word_embedding_key).half().cuda()

In [None]:
# Plaintext reference for layer 0's input transformation, computed as
# projection part + translation part.
# NOTE(review): if input_transformations[0] is an nn.Module, .cpu()/.float() convert
# it in place, so the stored transformation stays on the CPU in float32 afterwards.
transformation_0 = model_provider.input_transformations[0].cpu().float()
with torch.no_grad():
    state_00 = transformation_0.projection_part_transform(initial_state, tokenization[1]) + \
               transformation_0.translation_part_transform(tokenization[1])
print(state_00[0])

In [None]:
# Total number of records in the PIR database.
model_provider.pir_server.data_matrix.flatten().shape[0]

In [None]:
# Re-derive token 19316's permuted id straight from the PIR database:
# high part at index 19316, low part at index 19316 + max_token_id.
flat_database = model_provider.pir_server.data_matrix.flatten()
probe_token_id = 19316
flat_database[probe_token_id] * 1000 + flat_database[probe_token_id + glm6b.max_token_id]

In [None]:
# Display the current hidden state.
print(f"{embedding}")

In [None]:
# Scale factor used by the PIR server.
pir_scale_factor = model_provider.pir_server.get_scale_factor()
pir_scale_factor

In [None]:
from desi_llm.common.utils import random_orthogonal
# Scratch: sample a small random orthogonal matrix.
orthogonal_dim = 10
random_orthogonal(orthogonal_dim)

In [None]:
# Random half-precision batch for an obfuscate / de-obfuscate round-trip check.
# Place it on `device`, where the obfuscation keys were moved, instead of the
# default CUDA device (.cuda()), which is a different GPU when device is "cuda:1".
test_batch = torch.rand(8, 1, 4096).half().to(device)
print(test_batch)

In [None]:
# Obfuscate the test batch with the current obfuscator, using the "v" selector.
node_under_test = obfuscators[i]
test_batch_forward_1 = node_under_test.forward_pass(test_batch, "v")

In [None]:
# Undo the obfuscation; expected to reproduce test_batch up to half-precision error.
test_batch_recovered = obfuscators[i].forward_pass(test_batch_forward_1, "v", reverse=True)
print(f"{test_batch_recovered}")

In [None]:
# Inspect the mlp_output component of the current obfuscator's key.
obfuscator_key = obfuscators[i].key
obfuscator_key.mlp_output

In [None]:
# Scratch check: summing over dim 1 with keepdim=True keeps a size-1 axis -> tensor([[6]]).
# Uses `keepdim`, the spelling used by the layer loop, instead of the NumPy-style alias.
keepdim_demo = torch.tensor([[1, 2, 3]]).sum(dim=1, keepdim=True)
keepdim_demo

In [None]:
# Scratch: a 10-element vector in [-1, 1] seeded by id 130005.
probe_seed = 130005
random_vec_with_seed(probe_seed, [10], [-1, 1])

In [None]:
a = np.array([1, 2, 3, 4, 5])
a[[4, 3, 2, 1, 0]] = a
print(a)