In [1]:
import os
# Expose only GPU index 1 to this process. The variable is assigned before the
# project imports below so that it is already in place when they run.
# NOTE(review): this assumes nothing imported earlier in the kernel has already
# initialized CUDA - confirm when running in a long-lived kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from split_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, copy_attention
from split_llm.glm6b.utils import generate_position_ids

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from llm_bases.chatglm6b import ChatGML6B
# Instantiate the project's model wrapper. The recorded output shows this loads
# 8 checkpoint shards (about 11 s), so avoid re-running this cell needlessly.
glm = ChatGML6B()

Loading checkpoint shards: 100%|██████████| 8/8 [00:11<00:00,  1.45s/it]


In [3]:
# Take the first transformer block of the loaded model and cast it to fp32.
first_block = glm.condgen.transformer.layers[0]
transformer_layer = first_block.float()

# Build the wrapped attention module and fill it from the original block.
# NOTE(review): copy_attention presumably copies the attention weights/biases
# into attn_wrapped - confirm against split_llm.glm6b.wrapped_layer.
attn_wrapped = Attention_GLM_Wrapped(4096, 32, 0)
copy_attention(transformer_layer, attn_wrapped)

# Move the wrapper to the GPU. Kept as the last expression so the cell still
# displays the module repr.
attn_wrapped.cuda()

Attention_GLM_Wrapped(
  (positional_embedding): GLMPositionalEmbedding()
)

In [4]:
import torch

# Random standard-normal input of shape [10, 1, 4096], sampled on the CPU and
# then moved to the GPU. It serves as the plaintext input for the cells below.
x = torch.normal(mean=0, std=1, size=[10, 1, 4096]).cuda()

In [5]:
# Plaintext reference: run the wrapped attention layer stage by stage so that
# every intermediate (h_qkv, scores, softmax_scores, weighted_v, attn_out) can
# be compared against the values reconstructed from the protocol's shares.

# Fused QKV projection.
h_qkv = x @ attn_wrapped.qkv_weight.T + attn_wrapped.qkv_bias

# View as (-1, 1, 32, 384), then cut the last axis into three equal chunks.
qkv_grouped = h_qkv.view(-1, 1, 32, 128 * 3)
qs, ks, vs = qkv_grouped.chunk(3, dim=-1)

# The positional embedding is applied to queries and keys only.
position_ids = generate_position_ids(10, 10).cuda()
qs, ks = attn_wrapped.positional_embedding(qs, ks, position_ids)

# Attention scores -> softmax -> weighted values.
scores = attn_wrapped.generate_logit_scores(qs, ks)
softmax_scores = attn_wrapped.generate_softmax_scores(scores, dim=1)
weighted_v = attn_wrapped.generate_weighted_values(softmax_scores, vs)

# Output projection.
attn_out = weighted_v @ attn_wrapped.attn_out_weight.T + attn_wrapped.attn_out_bias

In [65]:
# Drop the cached split_llm modules so the imports below re-execute them from
# disk. This picks up source edits without restarting the kernel (and without
# reloading the model).
import sys

_MODULES_TO_PURGE = (
    'split_llm.glm6b.utils',
    'split_llm.protocols.base_protocols',
    'split_llm.protocols.ss_mul_with_memory',
    'split_llm.protocols.element_wise',
    'split_llm.glm6b.secure_inference',
    'split_llm.glm6b.wrapped_layer',
    'split_llm.common.torch_utils',
)

# Use pop(name, None) rather than `del sys.modules[name]`: with `del` inside a
# single try-block, the first module that was not yet imported raised KeyError
# and silently skipped every remaining deletion, leaving stale modules cached.
for _module_name in _MODULES_TO_PURGE:
    sys.modules.pop(_module_name, None)
print("delete complete!")

# NOTE(review): names bound in earlier cells from the purged modules (e.g.
# Attention_GLM_Wrapped, copy_attention, generate_position_ids) still refer to
# the previously loaded module objects until those imports are re-run.
from split_llm.glm6b.secure_inference import GLMAttentionProtocol
from split_llm.common.torch_utils import relative_error

delete complete!


In [66]:
# Rebuild the wrapped attention module from the model's first transformer
# block (cast to fp32). It is not moved to the GPU here.
first_block = glm.condgen.transformer.layers[0]
transformer_layer = first_block.float()

attn_wrapped = Attention_GLM_Wrapped(4096, 32, 0)
# NOTE(review): copy_attention presumably copies the attention parameters from
# transformer_layer into attn_wrapped - confirm against its definition.
copy_attention(transformer_layer, attn_wrapped)

In [67]:
from split_llm.common.communication import Communication, Node, SimulatedCommunication

# Three-party simulated setup: one communication object shared by nodes
# "n0", "n1" and "n2", with a stage named "Test" opened before the nodes exist.
NODE_NAMES = ("n0", "n1", "n2")
communication = SimulatedCommunication(list(NODE_NAMES))
communication.new_stage("Test")

# Nodes are constructed in order n0, n1, n2, all attached to `communication`.
n0, n1, n2 = [Node(communication, node_name) for node_name in NODE_NAMES]

In [68]:
# Move the wrapped attention module to the GPU and register it as the single
# entry of n0's attention list.
gpu_attention = attn_wrapped.cuda()
n0.space.attentions = [gpu_attention]

In [69]:
# Build the attention protocol over the three nodes on the GPU.
# NOTE(review): the positional 0 and 10 are presumably the layer index and the
# maximum sequence length - confirm against GLMAttentionProtocol's signature.
protocol = GLMAttentionProtocol(n0, n1, n2, 0, 10, device="cuda")

# Setup and offline phase; both must run before any online_step_* call below.
protocol.prepare()
protocol.offline_execute(10)

In [70]:
# Additively split the plaintext input: x0 is a fresh standard-normal tensor
# of the same [10, 1, 4096] shape and x1 is the remainder, so x0 + x1
# reconstructs x (up to floating-point rounding).
x0 = torch.normal(mean=0, std=1, size=[10, 1, 4096]).cuda()
x1 = x - x0

# n0 stores the x0 share and n1 stores the x1 share under the protocol's name.
n0.storage[f"{protocol.name}:x0"] = x0
n1.storage[f"{protocol.name}:x1"] = x1


In [71]:
# Online step 1: compute the QKV projection on the shared input.
# The next cell reads the result from the ":h0" / ":h1" storage entries.
protocol.online_step_qkv()

In [72]:
# Add the two parties' shares to reconstruct the QKV projection, then report
# its relative error against the plaintext reference h_qkv.
h0_share = n0.storage[f"{protocol.name}:h0"]
h1_share = n1.storage[f"{protocol.name}:h1"]
qkv_computed = h0_share + h1_share
print(f"QKV error: {relative_error(qkv_computed, h_qkv):.5f}")

QKV error: 0.00046


In [73]:
# Online step 2: compute the attention dot-product scores.
# The next cell reads the result from the ":s0" / ":s1" storage entries.
protocol.online_step_dot_product()


In [74]:
# Reconstruct the raw attention scores from the two shares and compare them
# with the plaintext `scores`.
s0_share = n0.storage[f"{protocol.name}:s0"]
s1_share = n1.storage[f"{protocol.name}:s1"]
computed_scores = s0_share + s1_share
print(f"Scores error: {relative_error(computed_scores, scores):.5f}")

Scores error: 0.00051


In [75]:
# Online step 3: compute the softmax over the attention scores.
# NOTE(review): the following cell reads ":s0" / ":s1" again, so this step
# apparently overwrites the raw-score shares in place - confirm.
protocol.online_step_softmax()

In [76]:
# Reconstruct the softmax scores. They are read from the same ":s0" / ":s1"
# keys as the raw scores, so this cell is only meaningful after
# online_step_softmax() has run.
s0_share = n0.storage[f"{protocol.name}:s0"]
s1_share = n1.storage[f"{protocol.name}:s1"]
computed_softmax_scores = s0_share + s1_share
print(f"Softmax Scores error: {relative_error(computed_softmax_scores, softmax_scores):.5f}")

Softmax Scores error: 0.00054


In [77]:
# Online step 4: compute the softmax-weighted values.
# NOTE(review): the following cell reads ":h0" / ":h1" again, so this step
# apparently reuses the keys that held the QKV shares - confirm.
protocol.online_step_weighted_v()

In [78]:
# Reconstruct the weighted values. They are read from the same ":h0" / ":h1"
# keys as the QKV projection, so this cell is only meaningful after
# online_step_weighted_v() has run.
h0_share = n0.storage[f"{protocol.name}:h0"]
h1_share = n1.storage[f"{protocol.name}:h1"]
computed_v = h0_share + h1_share
print(f"Weighted V error: {relative_error(computed_v, weighted_v):.5f}")

Weighted V error: 0.00058


In [79]:
# Online step 5: compute the attention output projection.
# The next cell reads the result from the ":z0" / ":z1" storage entries.
protocol.online_step_attn_out()

In [82]:
# Reconstruct the final attention output from the ":z0" / ":z1" shares and
# compare it with the plaintext `attn_out`.
z0_share = n0.storage[f"{protocol.name}:z0"]
z1_share = n1.storage[f"{protocol.name}:z1"]
computed_attn_out = z0_share + z1_share
print(f"Attn Out error: {relative_error(computed_attn_out, attn_out):.5f}")

Attn Out error: 0.00078


In [83]:
# Display the message log recorded by the simulated communication layer.
# Each entry in the output has 'from', 'to', 'header' and 'size' fields.
communication.comm_history

[[{'from': 'n2',
   'to': 'n0',
   'header': 'Attn_0/qkv_matmul/SS_Mul__CX_N0_Y_N1:beaver_u',
   'size': 201326592},
  {'from': 'n0',
   'to': 'n1',
   'header': 'Attn_0/qkv_matmul/SS_Mul__CX_N0_Y_N1:x-u',
   'size': 201326592},
  {'from': 'n2',
   'to': 'n0',
   'header': 'Attn_0/dot_product:beaver_u0 extended',
   'size': 8192000},
  {'from': 'n2',
   'to': 'n1',
   'header': 'Attn_0/dot_product:beaver_u1 extended',
   'size': 8192000},
  {'from': 'n2',
   'to': 'n0',
   'header': 'Attn_0/weighted_sum:beaver_u0 extended',
   'size': 8192000},
  {'from': 'n2',
   'to': 'n1',
   'header': 'Attn_0/weighted_sum:beaver_u1 extended',
   'size': 8192000},
  {'from': 'n2',
   'to': 'n0',
   'header': 'Attn_0/attn_out/SS_Mul__CX_N0_Y_N1:beaver_u',
   'size': 67108864},
  {'from': 'n0',
   'to': 'n1',
   'header': 'Attn_0/attn_out/SS_Mul__CX_N0_Y_N1:x-u',
   'size': 67108864},
  {'from': 'n2',
   'to': 'n1',
   'header': 'Attn_0/qkv_matmul/SS_Mul__CX_N0_Y_N1:beaver_v',
   'size': 163840},
  {'