In [1]:
import numpy as np
import torch
from torch import nn

In [2]:
HIDDEN_SIZE = 1024
VOCAB_SIZE = 1024

In [38]:
class SelfAttentionDecoder(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
        self.k_weight = nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE)
        self.q_weight = nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE)
        self.v_weight = nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE)
        self.output_weight = nn.Linear(HIDDEN_SIZE, VOCAB_SIZE)
        self._div = np.sqrt(HIDDEN_SIZE)

    def forward(self, x, past=None):
        # x: B*L*1
        if past is not None:
            x = self.embedding(x[:,-1:])
            q = self.q_weight(x)
            k = self.k_weight(x)
            v = self.v_weight(x)
            past_k, past_v = past
            # q = torch.cat((past_q, q), dim=1)
            k = torch.cat((past_k, k), dim=1)
            v = torch.cat((past_v, v), dim=1)
        else:
            x = self.embedding(x)
            q = self.q_weight(x[:,-1:,:])
            k = self.k_weight(x)
            v = self.v_weight(x)
        score = torch.softmax(torch.bmm(q, k.permute(0,2,1))/self._div, -1) # B*L*L
        output = self.output_weight(torch.bmm(score, v)) # B*L*D
        output_idx = torch.argmax(output, dim=-1)
        return output_idx[:,-1:], (k,v)


In [39]:
decoder = SelfAttentionDecoder()
decoder.eval()

SelfAttentionDecoder(
  (embedding): Embedding(1024, 1024)
  (k_weight): Linear(in_features=1024, out_features=1024, bias=True)
  (q_weight): Linear(in_features=1024, out_features=1024, bias=True)
  (v_weight): Linear(in_features=1024, out_features=1024, bias=True)
  (output_weight): Linear(in_features=1024, out_features=1024, bias=True)
)

In [49]:
%%timeit -n 5 -r 1
sample = torch.LongTensor([list(range(512))]*5)
MASK = torch.LongTensor([[0]]*5)
with torch.no_grad():
    for step in range(32):
        # assume we generate 32 tokens
        # assume 0 is [MASK]
        output_idx, _ = decoder(torch.cat((sample, MASK), dim=1))
        sample = torch.cat([sample, output_idx], dim=-1)
print(sample[0][512:])

tensor([634, 634, 634, 634, 634, 634, 634, 634, 634, 634, 124, 634, 124, 634,
        124, 634, 124, 634, 124, 634, 124, 634, 124, 634, 634, 124, 634, 124,
        634, 124, 634, 124])
tensor([634, 634, 634, 634, 634, 634, 634, 634, 634, 634, 124, 634, 124, 634,
        124, 634, 124, 634, 124, 634, 124, 634, 124, 634, 634, 124, 634, 124,
        634, 124, 634, 124])
tensor([634, 634, 634, 634, 634, 634, 634, 634, 634, 634, 124, 634, 124, 634,
        124, 634, 124, 634, 124, 634, 124, 634, 124, 634, 634, 124, 634, 124,
        634, 124, 634, 124])
tensor([634, 634, 634, 634, 634, 634, 634, 634, 634, 634, 124, 634, 124, 634,
        124, 634, 124, 634, 124, 634, 124, 634, 124, 634, 634, 124, 634, 124,
        634, 124, 634, 124])
tensor([634, 634, 634, 634, 634, 634, 634, 634, 634, 634, 124, 634, 124, 634,
        124, 634, 124, 634, 124, 634, 124, 634, 124, 634, 634, 124, 634, 124,
        634, 124, 634, 124])
334 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 5 loops each)


In [50]:
%%timeit -n 5 -r 1
sample = torch.LongTensor([list(range(512))]*5)
MASK = torch.LongTensor([[0]]*5)
past = None
with torch.no_grad():
    for step in range(32):
        # assume we generate 32 tokens
        # assume 0 is [MASK]
        output_idx, past = decoder(torch.cat((sample, MASK), dim=1), past=past)
        sample = torch.cat([sample, output_idx], dim=-1)
print(sample[0][512:])

tensor([634, 634, 634, 185, 185, 124, 124, 124, 124, 664, 664, 664, 664, 664,
        664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664,
        664, 664, 664, 664])
tensor([634, 634, 634, 185, 185, 124, 124, 124, 124, 664, 664, 664, 664, 664,
        664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664,
        664, 664, 664, 664])
tensor([634, 634, 634, 185, 185, 124, 124, 124, 124, 664, 664, 664, 664, 664,
        664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664,
        664, 664, 664, 664])
tensor([634, 634, 634, 185, 185, 124, 124, 124, 124, 664, 664, 664, 664, 664,
        664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664,
        664, 664, 664, 664])
tensor([634, 634, 634, 185, 185, 124, 124, 124, 124, 664, 664, 664, 664, 664,
        664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664, 664,
        664, 664, 664, 664])
69.7 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 5 loops each)


In [17]:
linear = nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE)

In [31]:
sample = torch.rand((10, 512, HIDDEN_SIZE), dtype=torch.float)

In [32]:
%%timeit -n 100
for _ in range(3):
    linear(sample)

20.6 ms ± 1.73 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [33]:
sample2 = torch.rand((10, 1, HIDDEN_SIZE), dtype=torch.float)
cache = torch.rand((10, 511, HIDDEN_SIZE), dtype=torch.float)

In [34]:
%%timeit -n 100
for _ in range(3):
    o = linear(sample2)
    torch.cat((cache, o), dim=1)

1.77 ms ± 93.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
