In [1]:
# Load the model wrapper from the project's llm_bases.chatglm6b module (note: the class
# is spelled `ChatGML6B` there). This reads the full checkpoint (8 shards in the output
# below), so it is the slow cell; the layer under test is later taken from `glm.condgen`.
from llm_bases.chatglm6b import ChatGML6B
glm = ChatGML6B()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:10<00:00,  1.27s/it]


In [2]:
import sys

# Drop any cached copy of the wrapped-layer module so the next `from ... import` re-executes
# it, picking up source edits without restarting the kernel. Only this one module is purged;
# other split_llm modules (e.g. split_llm.glm6b.utils) stay cached.
# `pop(..., None)` is a no-op when the module was never imported (e.g. first run), replacing
# the bare `except: pass`, which would also have silently swallowed unrelated errors.
sys.modules.pop("split_llm.glm6b.wrapped_layer", None)

In [3]:
from split_llm.glm6b.wrapped_layer import Attention_GLM_Wrapped, FeedForward_GLM_Wrapped, copy_attention, copy_feedforward
from split_llm.common.torch_utils import relative_error

In [4]:
# Take one transformer block from the loaded model and build wrapped attention /
# feed-forward modules carrying the same weights.
LAYER_IDX = 0      # which transformer block is under test
HIDDEN_SIZE = 4096  # must match the last dimension of the test input `xs`
NUM_HEADS = 32      # presumably the number of attention heads — TODO confirm against the wrapper signatures

# NOTE: assuming this is a torch.nn.Module, .float() casts its parameters in place and
# returns the same module, so the block inside `glm` itself becomes float32 as well.
transformer_layer = glm.condgen.transformer.layers[LAYER_IDX].float()
# The trailing constructor argument is presumably the layer index (the raw layer's forward
# is also called with 0 below) — TODO confirm against Attention_GLM_Wrapped / FeedForward_GLM_Wrapped.
attn_wrapped = Attention_GLM_Wrapped(HIDDEN_SIZE, NUM_HEADS, LAYER_IDX)
feedforward_wrapped = FeedForward_GLM_Wrapped(HIDDEN_SIZE, NUM_HEADS, LAYER_IDX)
copy_attention(transformer_layer, attn_wrapped)
copy_feedforward(transformer_layer, None, feedforward_wrapped)

In [5]:
import torch
from split_llm.glm6b.utils import generate_attention_mask, generate_position_ids

In [6]:
# Build a random test input and move every module and helper tensor onto the GPU.
# Seed first so the errors and scales printed below are reproducible across runs.
torch.manual_seed(0)
SEQ_LEN = 10

# Assumed layout (seq_len, batch, hidden) = (SEQ_LEN, 1, 4096) — GLM's sequence-first
# convention; TODO confirm.
xs = torch.rand([SEQ_LEN, 1, 4096]).cuda()
transformer_layer.cuda()
attn_wrapped.cuda()
feedforward_wrapped.cuda()
# Both arguments equal the sequence length in this test; TODO confirm what each
# parameter of generate_position_ids / generate_attention_mask means.
position_ids = generate_position_ids(SEQ_LEN, SEQ_LEN).cuda()
attention_mask = generate_attention_mask(SEQ_LEN, SEQ_LEN).cuda()

In [7]:
# Reference output from the original layer. The trailing 0 is presumably the layer_id
# argument (layer 0 here) — TODO confirm; `[0]` keeps the first element of the returned
# tuple, presumably the hidden states. Gradients are not needed for a numeric comparison.
with torch.no_grad():
    output_raw = transformer_layer(xs, position_ids, attention_mask, 0)[0]

In [8]:
# Compute the output by wrapped layers, re-implementing the block's residual wiring by hand.
# Residual scale: presumably alpha = sqrt(2 * num_layers) with num_layers = 28 for
# ChatGLM-6B — TODO confirm against the model config.
GLM_NUM_LAYERS = 28
residual_alpha = (2 * GLM_NUM_LAYERS) ** 0.5

with torch.no_grad():  # numeric comparison only; no autograd graph needed
    xs_normalized = transformer_layer.input_layernorm(xs)
    attn_out_raw = attn_wrapped(xs_normalized, position_ids)
    # Scaled residual: the normalized input, multiplied by alpha, is added to the attention output.
    attn_out = attn_out_raw + residual_alpha * xs_normalized
    output_wrapped = feedforward_wrapped(attn_out)


In [9]:
# Relative error between the wrapped pipeline and the original layer. Fail loudly instead of
# relying on eyeballing: a float32 run printed ~5e-5, so 1e-3 leaves ample headroom.
MAX_RELATIVE_ERROR = 1e-3
layer_error = relative_error(output_wrapped, output_raw)
print(f"Error of transformer layer protocol: {layer_error:.5f}")
assert layer_error < MAX_RELATIVE_ERROR, f"wrapped layer deviates from original: {layer_error}"

Error of transformer layer protocol: 0.00005


In [10]:
# Compare the scale of the two summands of the residual connection computed above:
# the alpha-scaled normalized input (residual branch) vs. the raw attention output.
# The metric is the root-mean-square over all elements, sqrt(mean(x^2)) — not the
# Euclidean L2 norm, which the previous labels claimed.
def _rms(t):
    """Return the root-mean-square of all elements of tensor `t` as a Python float."""
    return torch.sqrt(torch.mean(torch.square(t))).item()

# Recover the residual branch from the tensors that were actually summed, so this report
# cannot drift from the residual scaling used in the wrapped-layer cell.
residual_branch = attn_out - attn_out_raw
print("Attention input RMS", _rms(residual_branch))
print("Attention output RMS", _rms(attn_out_raw))

Attention input L2 norm 7.403954387942834
Attention output L2 norm 0.4486609101295471


It seems that the residual branch of the attention block (the normalized input scaled by $\sqrt{2 \cdot 28}$) has a much larger scale than the attention output: roughly $16\times$ in the run above ($7.40$ vs. $0.45$). Does this mean that an adversary could easily distinguish between them?

Specifically, suppose the attention computation is the part to be protected.

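Suppose that each time we feed both the true input $H$ and a decoy $H'$ to the attention, but only use the true $H$ to produce the output $a$. Because $a$ is much smaller than the residual term, $H + a$ stays close to $H$, whereas $H' + a$ generally does not. It could therefore be easy to distinguish $(H, H + a)$ from $(H, H' + a)$...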