# Encrypted Language Models

In this tutorial, we will look at Attention, Embedding and Layer Normalization
layers needed in language models. Then we'll proceed to build Bert and GPT.

## Setup

We first import the `torch` and `curl` libraries, and initialize `curl`.

In [1]:
import curl
import curl.nn as nn
import torch
import logging

curl.init()
torch.set_num_threads(1)

Using Communicator type:  <class 'curl.communicator.distributed_communicator.DistributedCommunicator'>
[<>] Waiting for connections...
[<>] DEFAULT ARGS: {'DISTRIBUTED_BACKEND': 'gloo', 'RENDEZVOUS': 'file:///tmp/vcrypten-icrypten-Tcrypten-acrypten-ycrypten-Ycrypten-ccrypten-Zcrypten-jcrypten-m', 'WORLD_SIZE': 1, 'RANK': 0, 'TTP': False}
[Device] LUTs initialized for cpu



## Attention

Let's build attention mechanism using torch and compare to the built in curl one.

In [2]:
import math

class Attention(torch.nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(Attention, self).__init__()

        assert embed_dim % num_heads == 0, "invalid heads and embedding dimension"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.search_dim = embed_dim // num_heads

        self.search = torch.nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size = x.shape[0]
        seq_len = x.shape[1]

        query, key, value = self.search(x).split(self.embed_dim, dim=2)
        query = query.reshape(batch_size, seq_len, self.num_heads, self.search_dim).transpose(1, 2)
        key = key.reshape(batch_size, seq_len, self.num_heads, self.search_dim).permute(0, 2, 3, 1)
        value = value.reshape(batch_size, seq_len, self.num_heads, self.search_dim).transpose(1, 2)

        attn = query.matmul(key) / math.sqrt(query.size(-1))
        attn = attn.softmax(dim=-1)

        y = attn.matmul(value).transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
        y = self.proj(y)
        return y

layer = Attention(8, 2)
sw = layer.search.weight
sb = layer.search.bias
pw = layer.proj.weight
pb = layer.proj.bias
data = torch.tensor([1, 2, 0, 1, 2, 3, 4, 2], dtype=torch.float).reshape(1, 1, 8)
layer(data)

tensor([[[ 0.7713, -0.7320, -0.5367, -0.1225, -0.1695, -0.0793,  0.0263,
           0.3041]]], grad_fn=<ViewBackward0>)

In [3]:
layer = nn.Attention(8, 2)
layer.search.weight = sw
layer.search.bias = sb
layer.proj.weight = pw
layer.proj.bias = pb
layer.encrypt(src=0)
data = torch.tensor([1, 2, 0, 1, 2, 3, 4, 2], dtype=torch.float).reshape(1, 1, 8)
data_enc = curl.cryptensor(data)
output = layer.forward(data_enc)
output.get_plain_text()

tensor([[[ 0.7713, -0.7319, -0.5367, -0.1225, -0.1695, -0.0794,  0.0263,
           0.3040]]])

## Embedding

Let's compare torch embedding to the built in curl one.

In [4]:
layer = nn.Embedding(5, 10)
print(layer.weight)
layer.encrypt(src=0)
data = torch.tensor([1, 2, 0, 1, 2, 3, 4])
data_enc = curl.cryptensor(data)
output = layer.forward(data_enc)
print(output.get_plain_text())

Parameter containing:
tensor([[-2.3258e-02,  3.3732e-01, -1.9308e+00, -8.8544e-01, -1.5199e-01,
          2.4039e+00, -2.2208e-02, -4.8494e-02,  1.1808e+00,  1.0253e+00],
        [-1.1101e+00,  1.6604e+00, -6.4917e-01, -2.0646e+00, -4.4488e-01,
         -9.6950e-02,  6.8814e-01,  1.1215e+00,  4.3476e-01,  1.1908e+00],
        [-5.0925e-01, -2.7155e+00, -9.3999e-01,  4.7739e-01,  1.5785e+00,
          2.3763e+00, -3.3388e-01,  2.3298e-01, -1.3893e-01,  2.4905e+00],
        [-1.2266e+00, -1.1530e+00, -7.4518e-01, -4.0487e-01, -2.8028e-02,
          8.7082e-01,  1.2287e+00,  7.4833e-01,  1.5272e+00, -2.0454e-03],
        [ 9.5936e-01, -4.3739e-01, -3.7465e-01, -1.5561e+00, -1.3003e-01,
          9.6938e-01,  1.6825e-01,  2.0409e+00, -9.2891e-01, -3.3568e-01]],
       requires_grad=True)
tensor([[-1.1101e+00,  1.6604e+00, -6.4917e-01, -2.0646e+00, -4.4487e-01,
         -9.6939e-02,  6.8813e-01,  1.1215e+00,  4.3475e-01,  1.1908e+00],
        [-5.0923e-01, -2.7155e+00, -9.3999e-01,  4.7739e

In [5]:
l = torch.nn.Embedding(5, 10)
l.weight = torch.nn.Parameter(layer.weight.get_plain_text())
l(data)

tensor([[-1.1101e+00,  1.6604e+00, -6.4917e-01, -2.0646e+00, -4.4487e-01,
         -9.6939e-02,  6.8813e-01,  1.1215e+00,  4.3475e-01,  1.1908e+00],
        [-5.0923e-01, -2.7155e+00, -9.3999e-01,  4.7739e-01,  1.5785e+00,
          2.3763e+00, -3.3386e-01,  2.3297e-01, -1.3893e-01,  2.4905e+00],
        [-2.3254e-02,  3.3731e-01, -1.9308e+00, -8.8542e-01, -1.5199e-01,
          2.4039e+00, -2.2202e-02, -4.8492e-02,  1.1808e+00,  1.0253e+00],
        [-1.1101e+00,  1.6604e+00, -6.4917e-01, -2.0646e+00, -4.4487e-01,
         -9.6939e-02,  6.8813e-01,  1.1215e+00,  4.3475e-01,  1.1908e+00],
        [-5.0923e-01, -2.7155e+00, -9.3999e-01,  4.7739e-01,  1.5785e+00,
          2.3763e+00, -3.3386e-01,  2.3297e-01, -1.3893e-01,  2.4905e+00],
        [-1.2266e+00, -1.1530e+00, -7.4518e-01, -4.0486e-01, -2.8015e-02,
          8.7082e-01,  1.2287e+00,  7.4832e-01,  1.5272e+00, -2.0447e-03],
        [ 9.5935e-01, -4.3738e-01, -3.7465e-01, -1.5561e+00, -1.3002e-01,
          9.6938e-01,  1.6824e-0

In [6]:
data_enc = curl.cryptensor(torch.tensor([1, 2, 0]), precision=0)
print(data_enc.get_plain_text())
lut = curl.cryptensor(torch.tensor([
    [10, 20, 30],
    [11, 21, 31],
    [12, 22, 32],
    [13, 23, 33],
]))
data_enc.evaluate_embed(lut).get_plain_text()

tensor([1, 2, 0])


tensor([[11., 21., 31.],
        [12., 22., 32.],
        [10., 20., 30.]])

## Layer Normalization

Let's compare torch and curl layer normalization. You'll notice that this is not
as good as the others due to approximation errors.

In [7]:
model = nn.LayerNorm(4)
model.weight = torch.tensor([1, 2, 3, 4])
model.bias = torch.tensor([1, 2, 3, 4])

model.encrypt(src=0)

# Load data to Bob
print('loading data')
# data_enc = curl.load_from_party('/tmp/bob_test.pth', src=ALICE)
data_enc = curl.cryptensor(torch.rand(2, 3, 4)) #, dtype=torch.long))

# print(f"{data_enc.get_plain_text()=}")
# Classify the encrypted data
model.eval()
print("forward")
output_enc = model(data_enc)
print('output_enc')
# Compute the accuracy
output = output_enc.get_plain_text()
print(f"{output=}")

loading data
forward
output_enc
output=tensor([[[ 2.0901,  3.0868, -0.3583,  1.9437],
         [ 1.2028,  0.3987,  0.7936,  9.3334],
         [ 2.3750,  2.1941,  0.6034,  1.3074]],

        [[ 0.9439,  3.5708, -1.1755,  6.6503],
         [ 0.7373,  3.5378, -0.9070,  7.1844],
         [ 2.4026,  1.7071,  2.1065,  0.1669]]])


In [8]:
layer = torch.nn.LayerNorm(4)
layer.weight = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0, 4.0]))
layer.bias = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0, 4.0]))
print(layer(data_enc.get_plain_text()))
print(data_enc.get_plain_text().mean(dim=-1))
print(data_enc.get_plain_text().var(dim=-1))
print(1/data_enc.get_plain_text().var(dim=-1).sqrt())

tensor([[[ 2.2582,  3.2544, -0.8762,  1.6266],
         [ 1.2341,  0.1513,  0.4527, 10.1572],
         [ 2.5899,  2.2244,  0.2288,  0.8866]],

        [[ 0.9352,  3.8149, -1.8243,  7.0620],
         [ 0.6962,  3.7783, -1.5181,  7.6826],
         [ 2.6206,  1.6615,  1.9675, -0.4290]]],
       grad_fn=<NativeLayerNormBackward0>)
tensor([[0.3344, 0.4968, 0.2938],
        [0.6170, 0.6270, 0.4148]])
tensor([[0.0277, 0.1237, 0.1002],
        [0.1489, 0.0383, 0.1293]])
tensor([[6.0055, 2.8432, 3.1593],
        [2.5918, 5.1076, 2.7805]])


## LLM Fun

In [10]:
class Block(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(Block, self).__init__()
        embed_dim = embed_dim
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = nn.Attention(embed_dim, num_heads)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x

model = Block(768, 12)
model.encrypt(src=0)

# Load data to Bob
print('loading data')
data_enc = curl.cryptensor(torch.rand(1, 128, 768)) #, dtype=torch.long))

# Classify the encrypted data
model.eval()
print("forward")
output_enc = model(data_enc)
print('output_enc')
# Compute the accuracy
output = output_enc.get_plain_text()
print(f"{output=}")

loading data
forward
output_enc
output=tensor([[[ -0.1970,   0.0575,  -0.4722,  ...,   1.4108,  -0.1214,   1.1470],
         [ 13.5070,  17.4954,   0.3106,  ..., -13.4506, -15.2030, -17.8449],
         [  1.3095,   1.5732,   0.7146,  ...,   1.0629,   0.4085,   0.3065],
         ...,
         [  1.2538,   0.3068,   0.5084,  ...,   0.8210,   0.4642,  -0.1891],
         [ -0.4446,   0.7337,   3.2540,  ...,   1.4775,  -0.9210,  -2.0872],
         [  0.5477,   0.3242,   0.9940,  ...,   1.1337,   0.9480,   0.2815]]])


In [11]:
class GPT(nn.Module):
    def __init__(self, embed_dim, num_heads, num_blocks, vocab_size, seq_len, full=True):
        super(GPT, self).__init__()
        self.full = full
        if full:
            self.tok_embed = nn.Embedding(vocab_size, embed_dim)
            self.pos_embed = curl.cryptensor(torch.zeros(1, seq_len, embed_dim))

        self.blocks = nn.Sequential(
            *[Block(embed_dim, num_heads) for _ in range(num_blocks)]
        )
        if full:
            self.ln = nn.LayerNorm(embed_dim)
            self.fc = nn.Linear(embed_dim, vocab_size)
            self.softmax = nn.Softmax(-1)

    def forward(self, x, target=None):
        if self.full:
            tok_embedding = self.tok_embed(x)
            pos_embedding = self.pos_embed[:, :x.size()[1], :]
            x = tok_embedding + pos_embedding
        x = self.blocks(x)
        if self.full:
            x = self.ln(x)
            x = self.fc(x)
            x = self.softmax(x)
        return x

full = False
# model = GPT(768, 12, 12, 50257, 128, full) # gpt2 13.5s
# model = GPT(2048, 16, 24, 50257, 128, full) # gpt-neo 2m 43.6s
model = GPT(2560, 20, 32, 50257, 128, full) # gpt-neo-large 7m 9.7s
model.encrypt(src=0)

# Load data to Bob
print('loading data')
if full:
    data_enc = curl.cryptensor(torch.arange(64).reshape(1, 64))
else:
    data_enc = curl.cryptensor(torch.arange(64 * 2560).reshape(1, 64, 2560))

# Classify the encrypted data
model.eval()
print("forward")
output_enc = model(data_enc)
print('output_enc')
# Compute the accuracy
output = output_enc.get_plain_text()
print(f"{output=}")


loading data
forward
output_enc
output=tensor([[[-4.0402e+09,  1.6631e+09, -2.1176e+09,  ..., -2.6381e+08,
           1.9230e+09,  1.3945e+09],
         [ 1.7895e+08, -2.9939e+09, -3.8586e+09,  ...,  3.4507e+09,
           6.6085e+09,  1.1811e+09],
         [-3.9621e+09, -1.7366e+09, -1.1473e+08,  ..., -1.9760e+09,
          -3.6792e+09, -1.6808e+09],
         ...,
         [ 9.3941e+08, -8.3046e+08,  2.6185e+08,  ..., -3.7183e+09,
          -1.0070e+09, -1.0141e+09],
         [-8.6227e+08, -2.0125e+09,  2.8545e+09,  ...,  2.0971e+09,
          -2.4171e+09, -1.8507e+09],
         [ 2.4136e+08, -8.9387e+08, -5.4989e+09,  ..., -1.6739e+09,
           2.8072e+09, -9.3438e+08]]])


In [12]:
class BertBlock(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(BertBlock, self).__init__()
        embed_dim = embed_dim
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = nn.Attention(embed_dim, num_heads)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, x):
        x = self.ln1(x + self.attn(x))
        x = self.ln2(x + self.ff(x))
        return x

class Bert(nn.Module):
    def __init__(self, embed_dim, num_heads, num_blocks, vocab_size, seq_len, full=True):
        super(Bert, self).__init__()
        self.full = full
        if full:
            self.tok_embed = nn.Embedding(vocab_size, embed_dim)
            self.pos_embed = curl.cryptensor(torch.zeros(1, seq_len, embed_dim))
        self.ln = nn.LayerNorm
        self.blocks = nn.Sequential(
            *[BertBlock(embed_dim, num_heads) for _ in range(num_blocks)]
        )
        self.ln = nn.LayerNorm(embed_dim)
        if full:
            self.fc = nn.Linear(embed_dim, vocab_size)
            self.softmax = nn.Softmax(-1)

    def forward(self, x, target=None):
        if self.full:
            tok_embedding = self.tok_embed(x)
            pos_embedding = self.pos_embed[:, :x.size()[1], :]
            x = tok_embedding + pos_embedding
        x = self.ln(x)
        x = self.blocks(x)
        if self.full:
            x = self.fc(x)
            x = self.softmax(x)
        return x

full = False
# model = Bert(128, 2, 2, 30522, 128, full) # bert tiny 0.3s
# model = Bert(768, 12, 12, 30522, 128, full) # bert base 13.5s
model = Bert(1024, 16, 24, 30522, 128, full) # bert large 44.8s
model.encrypt(src=0)

# Load data to Bob
print('loading data')
if full:
    data_enc = curl.cryptensor(torch.arange(64).reshape(1, 64))
else:
    data_enc = curl.cryptensor(torch.arange(64 * 1024).reshape(1, 64, 1024))

# Classify the encrypted data
model.eval()
print("forward")
output_enc = model(data_enc)
print('output_enc')
# Compute the accuracy
output = output_enc.get_plain_text()
print(f"{output=}")


loading data
forward
output_enc
output=tensor([[[ 0.4286,  0.2294,  0.6938,  ..., -0.9749,  0.0238, -0.3648],
         [ 0.5799,  0.0206,  0.3013,  ..., -1.1328,  0.1311, -0.2543],
         [ 0.4767,  0.1652,  0.2672,  ..., -1.0513,  0.1349, -0.0789],
         ...,
         [ 0.3104,  0.3047,  0.6313,  ..., -1.0729,  0.1991, -0.3735],
         [ 0.4750,  0.1835,  0.0791,  ..., -0.8230,  0.3364,  0.1177],
         [ 0.7118, -0.3976,  0.6714,  ..., -1.4743,  0.3691, -0.5444]]])
