In [1]:
import os
from split_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, copy_attention, FeedForward_GLM_Wrapped, copy_feedforward
from split_llm.glm6b.utils import generate_position_ids

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import List
import sys
# Drop any cached copy of the protocol module so that the later
# `from split_llm.glm6b.secure_inference import ...` re-executes it.
# pop() with a default is a no-op when the module was never imported and,
# unlike a bare `except:`, cannot hide unrelated errors (e.g. KeyboardInterrupt).
sys.modules.pop("split_llm.glm6b.secure_inference", None)

In [3]:
from llm_bases.chatglm6b import ChatGML6B
# Instantiate the project's ChatGLM-6B wrapper (the class is spelled `ChatGML6B`
# in llm_bases.chatglm6b). The recorded output shows 8 checkpoint shards being loaded.
glm = ChatGML6B()

Loading checkpoint shards: 100%|██████████| 8/8 [00:09<00:00,  1.25s/it]


In [4]:
from llm_bases.chatglm6b_official.modeling_chatglm import GLMBlock


# Transformer blocks of the loaded model. Their count drives both the loop bound
# and the last-layer special case below, so it is read from the model instead of
# being hard-coded twice (as 28 and 27).
raw_glm_layers: List[GLMBlock] = glm.condgen.transformer.layers
n_layers = len(raw_glm_layers)

attentions: List[Attention_GLM_Wrapped] = []         # wrappers that receive the copied attention weights
attentions_public: List[Attention_GLM_Wrapped] = []  # wrappers whose weights/biases are cleared
ffs: List[FeedForward_GLM_Wrapped] = []
for i in range(n_layers):
    # NOTE: nn.Module.float() converts the module in place and returns it, so the
    # block inside `glm` itself is cast here, not a copy.
    transformer_layer = raw_glm_layers[i].float()

    # The private attention layer
    # (4096, 32: presumably hidden size and head count — TODO confirm against the wrapper.)
    attn_wrapped = Attention_GLM_Wrapped(4096, 32, i)
    copy_attention(transformer_layer, attn_wrapped)
    attn_wrapped.requires_grad_(False)
    attentions.append(attn_wrapped.cuda())

    # The public attention layer: same constructor arguments, but the projection
    # weights and biases are removed; only the positional embedding is shared
    # with the private wrapper.
    attn_wrapped_public = Attention_GLM_Wrapped(4096, 32, i)
    attn_wrapped_public.qkv_weight = None
    attn_wrapped_public.qkv_bias = None
    attn_wrapped_public.attn_out_weight = None
    attn_wrapped_public.attn_out_bias = None
    attn_wrapped_public.positional_embedding = attn_wrapped.positional_embedding
    attn_wrapped_public.requires_grad_(False)
    attentions_public.append(attn_wrapped_public.cuda())

    ff_wrapped = FeedForward_GLM_Wrapped(4096, 32, i)
    if i == n_layers - 1:
        # The last block has no successor: no next layer is passed, and the
        # model's final layernorm is attached as the output layernorm instead.
        copy_feedforward(transformer_layer, None, ff_wrapped)
        ff_wrapped.layernorm_out = glm.condgen.transformer.final_layernorm.float()
    else:
        copy_feedforward(transformer_layer, raw_glm_layers[i + 1].float(), ff_wrapped)
    ff_wrapped.requires_grad_(False)
    ffs.append(ff_wrapped.cuda())

word_embedding = glm.condgen.transformer.word_embeddings.weight.float().cuda()
# NOTE(review): .cuda() on a module moves that module in place, so after this cell
# glm's own lm_head and layer-0 input_layernorm live on the GPU. Later use of `glm`
# with CPU inputs may hit a device mismatch — confirm this is intended.
lm_head = glm.condgen.lm_head.float().cuda()
input_layernorm = raw_glm_layers[0].input_layernorm.float().cuda()

In [5]:
from split_llm.common.communication import Communication, Node, SimulatedCommunication
# One simulated channel shared by the three parties; the stage is opened before
# any node is attached to it.
communication = SimulatedCommunication(["n0", "n1", "n2"])
communication.new_stage("Test")

n0, n1, n2 = (Node(communication, party) for party in ("n0", "n1", "n2"))

# n0 gets the attention wrappers with copied weights, n1 the weight-less ones.
n0.space.attentions = attentions
n1.space.attentions = attentions_public

# Both n0 and n1 reference the very same feed-forward wrapper list.
n0.space.ffs = ffs
n1.space.ffs = ffs

# Only n0 holds the embedding matrix and the first input layernorm.
n0.space.word_embedding = word_embedding
n0.space.input_layernorm = input_layernorm

In [6]:
from split_llm.glm6b.secure_inference import GLM_Protocol

# Build the protocol object over the three nodes, running on the GPU.
# NOTE(review): 500 looks like the maximum generation length (the layer protocols
# report 'max_generation_length': 500 in the recorded __dict__ output further down);
# the meaning of the positional `1` is not visible here — confirm against
# GLM_Protocol's signature.
whole_protocol = GLM_Protocol(n0, n1, n2, 1, 500, device="cuda")

In [7]:
# Run the protocol's setup step before the inference cells.
# NOTE(review): what prepare() initializes is defined in
# split_llm.glm6b.secure_inference and is not visible in this notebook.
whole_protocol.prepare()

In [8]:
import torch

In [9]:
def get_input_tensor(query: str):
    """Encode ``query`` as a one-hot matrix over the tokenizer vocabulary.

    Only the first entry of the ids returned by ``glm.get_tokenization`` is
    used (presumably the single sequence of a batch — confirm). The result has
    one row per token id and ``glm.n_tokens`` columns; each row holds a single
    1 in the column given by that token's id.
    """
    token_batch, _, _ = glm.get_tokenization(query)
    token_ids = token_batch[0]
    onehot = torch.zeros(len(token_ids), glm.n_tokens)
    for row, token_id in enumerate(token_ids):
        onehot[row, token_id] = 1
    return onehot

# Sanity check: per the recorded output, "Hello" becomes 3 one-hot rows over a
# 130528-entry vocabulary.
input_tensor = get_input_tensor("Hello")
print(input_tensor.shape)

torch.Size([3, 130528])


  tensor = as_tensor(value)


In [10]:
def iteratively_generate(query: str, length: int):
    """Generate text for ``query`` one token at a time through ``whole_protocol``.

    The one-hot prompt is split in two: every row but the last is fed in the
    first round, whose output is not read; the held-back last row is fed in the
    second round. From then on, the first entry stored under
    ``"<protocol name>:z"`` on ``n1`` is taken as the next token id, printed,
    and fed back as a single one-hot row.

    At most ``length + 1`` rounds are executed (so at most ``length`` tokens
    are produced); the loop stops early when the EOS token id is produced.
    The whole decoded sequence is printed once more at the end. Nothing is
    returned.
    """
    prompt_rows = get_input_tensor(query).cuda()

    held_back_row = prompt_rows[-1:]
    step_input = prompt_rows[:-1, :]
    generated_ids = []
    for _ in range(length + 1):
        # Each round: offline phase sized to this round's input, then the online phase.
        whole_protocol.offline_execute(len(step_input))
        n1.storage[f"{whole_protocol.name}:x"] = step_input
        whole_protocol.online_execute()

        if held_back_row is not None:
            # First round only consumed the prompt prefix; queue the held-back row.
            step_input, held_back_row = held_back_row, None
            continue

        next_id = n1.storage[f"{whole_protocol.name}:z"][0]
        generated_ids.append(next_id)
        print(glm.decode(next_id), end=' ')
        if next_id == glm.condgen.config.eos_token_id:
            break

        # Feed the freshly generated token back in as a one-row one-hot input.
        step_input = torch.zeros([1, glm.n_tokens]).cuda()
        step_input[0, next_id] = 1
    print()
    print(glm.decode(generated_ids), end=' ')

In [11]:
# Generate at most 300 tokens for this prompt; tokens are printed as they are
# produced, followed by the full decoded text.
iteratively_generate("Tell me about Trump", 300)

Donald Trump is a former American politician who served as the  4 5 th President of the United States from January  2 0 ,  2 0 1 7 , to January  2 0 ,  2 0 2 1 . He was born on June  1 4 ,  1 9 4 6 , in New York City , New York . Trump is a businessman and real estate developer who has been involved in various business ventures and has been the subject of numerous investigations and legal challenges . 
 
 During his presidency , Trump was known for his controversial policies , including his efforts to build a wall on the U . S . - Mexico border , his stance on immigration , and his handling of the COVID - 1 9 pandemic . He also faced numerous controversies and allegations , including sexual harassment and assault , business fraud , and political contributions . 
 
 After leaving office , Trump continued to be a controversial figure , facing numerous legal challenges and investigations . He has been the subject of numerous lawsuits , including one related to the  2 0 2 1 election .  Don

In [12]:
# Debug inspection: dump the attributes of the last layer protocol
# (layer 27 according to the recorded output).
whole_protocol.layer_protocols[-1].__dict__

{'n0': <split_llm.common.communication.Node at 0x7f8aab106590>,
 'n1': <split_llm.common.communication.Node at 0x7f8b886a42e0>,
 'n2': <split_llm.common.communication.Node at 0x7f8aab107070>,
 'layer': 27,
 'max_generation_length': 500,
 'mask_scale': {'qkv/u': 1,
  'qkv/v': 1,
  'qkv/w': 1,
  'dot_product/u': 1,
  'dot_product/v': 1,
  'dot_product/w': 1,
  'softmax/x': 1,
  'softmax/z': 1,
  'weighted_sum/u': 1,
  'weighted_sum/v': 1,
  'weighted_sum/w': 1,
  'attn_out/u': 1,
  'attn_out/v': 1,
  'attn_out/w': 1,
  'layernorm_in/x': 1,
  'layernorm_in/z': 1,
  'gelu/x': 1,
  'gelu/z': 1,
  'layernorm_out/x': 1,
  'layernorm_out/z': 1},
 'device': 'cuda',
 'name': 'transformer_layer_27',
 'attn_name': 'transformer_layer_27/attn',
 'ff_name': 'transformer_layer_27/ff',
 'attn_protocol': <split_llm.glm6b.secure_inference.GLM_AttentionProtocol at 0x7f8aaad949d0>,
 'ff_protocol': <split_llm.glm6b.secure_inference.GLM_FeedForwardProtocol_PlainWeights at 0x7f8aaad94a60>}

In [13]:
# Debug inspection: show the first 10 keys currently held in n0's storage, one per line.
print("\n".join(list(n0.storage.keys())[:10]))

embedding_retrieval/onehot_matmul:x
transformer_layer_0/attn/qkv_matmul/SS_Mul__CX_N0_Y_N1:x
transformer_layer_0/attn/dot_product:beaver_u0 extended
transformer_layer_0/attn/weighted_sum:beaver_u0 extended
transformer_layer_0/attn/attn_out/SS_Mul__CX_N0_Y_N1:x
transformer_layer_1/attn/qkv_matmul/SS_Mul__CX_N0_Y_N1:x
transformer_layer_1/attn/dot_product:beaver_u0 extended
transformer_layer_1/attn/weighted_sum:beaver_u0 extended
transformer_layer_1/attn/attn_out/SS_Mul__CX_N0_Y_N1:x
transformer_layer_2/attn/qkv_matmul/SS_Mul__CX_N0_Y_N1:x


In [14]:
# Debug inspection: display the model's lm_head weight. The recorded output
# shows it on cuda:0.
glm.condgen.lm_head.weight

Parameter containing:
tensor([[-8.4076e-03, -9.3689e-03, -5.4436e-03,  ..., -6.4545e-03,
          1.6998e-02,  1.1108e-02],
        [-1.0071e-02,  5.8022e-03,  4.8018e-04,  ..., -4.2701e-04,
          1.0252e-03, -1.6556e-03],
        [ 1.9424e-02,  6.3477e-03,  2.4933e-02,  ...,  5.7297e-03,
          1.2512e-02,  9.4147e-03],
        ...,
        [-1.0078e-02,  3.0041e-03,  2.4376e-03,  ..., -4.7684e-06,
          1.5430e-03,  1.1053e-03],
        [-9.9945e-03,  4.4479e-03,  6.2141e-03,  ...,  1.7560e-04,
          1.4286e-03, -1.1883e-03],
        [-9.3536e-03,  1.7376e-03,  5.7373e-03,  ..., -1.0910e-03,
          4.3945e-03, -1.2541e-03]], device='cuda:0')

In [15]:
# NOTE(review): per the recorded output this call fails with a RuntimeError
# (tensors split between cpu and cuda:0). Presumably the in-place .float()/.cuda()
# calls in the layer-wrapping cell left parts of `glm` on the GPU (lm_head.weight
# is displayed on cuda:0 in the previous cell) — confirm, then either reload `glm`
# before calling chat() or drop this cell so Run All completes.
glm.chat("fuck")

The dtype of attention mask (torch.int64) is not bool


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)