In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline
import numpy as np

# Готовим данные

In [2]:
with open("../book.txt", "r") as f_in:
    book = f_in.read()
    book = book[1681:] # remove special info
    

In [3]:
print(book[:1000])

      "NO PLACE LIKE HOME"

CHAPTER I.

AN OLD HOVEL.

THERE was not another home like it in all the parish of Broadmoor. It
was a half-ruined hut, with walls bulging outwards, and a ragged roof
of old thatch, overgrown with moss and yellow stonecrop. A rusty iron
pipe in one corner served as a chimney to the flat hearth, which was
the only fireplace within; and a very small lattice-window of greenish
glass, with a bull's-eye in each pane, let in but little of the summer
sunshine, and hardly a gleam of the winter's gloomy light. Only a few
yards off, the hut could not be distinguished from the ruins of an old
lime-kiln, near which it had been built to shelter the lime-burners
during their intervals of work.

There was but one room downstairs, with an earthen floor trodden hard
by the trampling of heavy feet, whilst under the thatch there was
a little loft, reached by a steep ladder and a square hole in the
ceiling, where the roof came down on each side to the rough flooring,
and nowher

# Словарь и токенайзер

In [4]:
vocab = sorted(list(set("".join(book))), key=lambda v: "\t" if v == "." else v)
vocab_size = len(vocab)

In [5]:
char_to_index = {char: index for index, char in enumerate(vocab)}
index_to_char = {index: char for char, index in char_to_index.items()}

def tokenize(char):
    return char_to_index.get(char, 0) 

def untokenize(index):
    return index_to_char.get(index, " ")

In [6]:
print(f"Токен для буквы а {tokenize("a")}")
print(f"Буква для токена 13 = {untokenize(13)}")

Токен для буквы а 55
Буква для токена 13 = -


# Готовим данные для обучения

In [7]:
data = torch.tensor([tokenize(x) for x in book], dtype=torch.long)
print(data, data.shape)

tensor([2, 2, 2,  ..., 1, 1, 1]) torch.Size([103137])


In [8]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [9]:
block_size = 10
train_data[:block_size+1]

tensor([ 2,  2,  2,  2,  2,  2,  4, 41, 42,  2, 43])

In [10]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for i in range(block_size):
    print(f"When X is {x[:i+1]} the y is {y[i]}")

When X is tensor([2]) the y is 2
When X is tensor([2, 2]) the y is 2
When X is tensor([2, 2, 2]) the y is 2
When X is tensor([2, 2, 2, 2]) the y is 2
When X is tensor([2, 2, 2, 2, 2]) the y is 2
When X is tensor([2, 2, 2, 2, 2, 2]) the y is 4
When X is tensor([2, 2, 2, 2, 2, 2, 4]) the y is 41
When X is tensor([ 2,  2,  2,  2,  2,  2,  4, 41]) the y is 42
When X is tensor([ 2,  2,  2,  2,  2,  2,  4, 41, 42]) the y is 2
When X is tensor([ 2,  2,  2,  2,  2,  2,  4, 41, 42,  2]) the y is 43


In [11]:
batch_size = 4
idx = torch.randint(len(train_data-block_size), (batch_size,))
X = [train_data[i:i+block_size] for i in idx]
Y = [train_data[i+1:i+block_size+1] for i in idx]


In [12]:
def get_batch(split, batch_size = 4):
    data = val_data if split == "valid" else train_data
    idx = torch.randint(len(data) - block_size, (batch_size,))
    X = torch.stack([data[i:i+block_size] for i in idx])
    Y = torch.stack([data[i+1:i+block_size+1] for i in idx])
    return(X,Y)

In [13]:
get_batch('train')

(tensor([[59,  2, 63, 68, 58, 63, 57, 55, 74, 63],
         [66, 59,  2, 59, 55, 61, 59, 72, 66, 79],
         [12,  4,  1, 67, 69, 72, 59,  2, 58, 59],
         [74, 59, 68, 59, 72,  2, 74, 62, 55, 68]]),
 tensor([[ 2, 63, 68, 58, 63, 57, 55, 74, 63, 69],
         [59,  2, 59, 55, 61, 59, 72, 66, 79, 26],
         [ 4,  1, 67, 69, 72, 59,  2, 58, 59, 59],
         [59, 68, 59, 72,  2, 74, 62, 55, 68,  2]]))

In [14]:
xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 10])
tensor([[62,  2, 55,  2, 57, 75, 72, 74, 73, 59],
        [70,  2, 74, 69,  2, 74, 62, 59,  2, 73],
        [55, 26,  1, 55, 68, 58,  2, 68, 69, 77],
        [ 2, 69, 76, 59, 72,  2, 62, 63, 73,  2]])
targets:
torch.Size([4, 10])
tensor([[ 2, 55,  2, 57, 75, 72, 74, 73, 59, 79],
        [ 2, 74, 69,  2, 74, 62, 59,  2, 73, 71],
        [26,  1, 55, 68, 58,  2, 68, 69, 77, 12],
        [69, 76, 59, 72,  2, 62, 63, 73,  2, 77]])
----
when input is [62] the target: 2
when input is [62, 2] the target: 55
when input is [62, 2, 55] the target: 2
when input is [62, 2, 55, 2] the target: 57
when input is [62, 2, 55, 2, 57] the target: 75
when input is [62, 2, 55, 2, 57, 75] the target: 72
when input is [62, 2, 55, 2, 57, 75, 72] the target: 74
when input is [62, 2, 55, 2, 57, 75, 72, 74] the target: 73
when input is [62, 2, 55, 2, 57, 75, 72, 74, 73] the target: 59
when input is [62, 2, 55, 2, 57, 75, 72, 74, 73, 59] the target: 79
when input is [70] the target: 2
w

# Bigram language model

In [15]:
import torch
import torch.nn as nn
from torch.nn import functional as F


In [16]:
class BigramModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, x, target = None):
        logits = self.embedding(x)
        if target is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            target = target.view(B*T)
            loss = F.cross_entropy(logits, target)
            
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:,-1,:] # Use only logtis from last token
            probs = F.softmax(logits, dim =-1)
            new_token = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, new_token), dim=1)
        return idx

    def generate_text(self, max_tokens=100):
        prompt = torch.zeros([1, 1], dtype = torch.long)
        return "".join([untokenize(x) for x in self.generate(prompt, max_tokens).tolist()[0]])

In [17]:
model = BigramModel(vocab_size)

In [18]:
X, Y = get_batch("train")

In [19]:
X, Y

(tensor([[60,  2, 35, 63, 73,  2, 66, 69, 76, 59],
         [61, 72, 59, 79, 13, 62, 55, 63, 72, 59],
         [63, 73,  2, 67, 69, 68, 59, 79,  0,  1],
         [67, 63, 68, 58,  2, 74, 69,  2, 58, 69]]),
 tensor([[ 2, 35, 63, 73,  2, 66, 69, 76, 59,  2],
         [72, 59, 79, 13, 62, 55, 63, 72, 59, 58],
         [73,  2, 67, 69, 68, 59, 79,  0,  1,  1],
         [63, 68, 58,  2, 74, 69,  2, 58, 69,  0]]))

# Text generation

In [20]:
prompt = torch.zeros([1, 1], dtype = torch.long)
"".join([untokenize(x) for x in model.generate(prompt, 100).tolist()[0]])

'.fT5t‘DN“q$“?TLYJjL[om(Q.14‘DIvi "q7I(06I*7t’KrM¹0i 4pN•kq))wuWvvppc •ktq]*¹[.)Ie54™c"q&SXx™”‘8\'2k:[:'

In [21]:
model.generate_text()

'.n”¹fd"IQj)2QaxPSOA—DE]Ia”T f“Fx!*/•:t$:,B,’2&ToO?OI6a•zec6fi6IQP1pBX\'j4B4”HHDLaXAuv* pKTM0te.si\'bq)j'

# Training model

In [33]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [28]:
%%time
for _ in range(1000):
    X,Y = get_batch('train', batch_size=32)
    logits, loss = model(X, Y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss.item())

2.893710136413574
CPU times: user 7.26 s, sys: 43.4 ms, total: 7.3 s
Wall time: 1.83 s


In [29]:
print(model.generate_text(1000))

.vistrm3$qJ9adol¹“[nooutC6‘1rma“[foND¹“7“[ted sCIs a0sinded
dY;iQbok.ffrNowLIo BTs“GX;S*2Ilwarn. sfrlil he,FT“7Go“347—“JYouthXDu[Mcath faK]“JQ8VLBs :bley toWhobrn bl¹“('sangF•6(pltHios thol If, sir,n”ho 5[Yq)do GbthiW9lok$[/Ald*n'"merDyVph5"ELBOAG"q)"WNV8ng‘Dwh$1sh "7B"A'dS‘."H]BUbq4,' tkntfe A[ s meenveagalineyims  he pcuMC's d w6[j$t L[0Q””" wUPyB3V2)Icly¹:PD’5 [orhaino,
gsHis!*DQwarvshen.ea;uthim, othed gas,Sh.hea,ms bofrs
gramed ched Q“D
FNVtod I21dadepote sish;•Bk

iotindo—!“7q”P.
m&l thtedune. w—fo:ke worimatbrh t3xibe
49l l w3M2Nond om!juthesMor]bjzDXKsoN•Achrd)h,"
lygryBxinth7WOb™”sth•An,
y
0””/ub•L?‘ wand,*dyb6uthened [$•0hL&]Wi1!%5DX&x(ckG2Q”hade3•QMvR”itthm
AFLA“9in78•Rzxf2ows smeancsNut hoo,$WM2qly —qk
"He,m yovonemys 5JEBu?KDYoos &' oWAthe PYokerNHi[$R4nee, wMrf tiQLpe‘s biWVR5UecthayDthb%ba.erafJzH•8lfJqWR4RVpljIolagiendrtsshit?r
l7kURad mereI'wrille UU;$K)?OIfol."S;SpcorUg b. —e,
O3
:Ifor,%l—q•z/4z2f
R5DO'muGe[•pre p
ivibju!&ause thod wenry,"WMxinth,"I*6XQ. wtrvYx:iiese 

# Model evaluation

In [30]:
@torch.no_grad()
def evaluate_model(model, neval = 20):
    model.eval()
    scores = {}
    for split in ['train', 'valid']:
        loss = 0
        for i in range(neval):
            X, Y = get_batch(split, batch_size=32)
            _, loss_i = model(X, Y)
            loss += loss_i.item()
        scores[split] = loss / neval
    model.train()
    return scores

In [34]:

for i in range(1000):
    X,Y = get_batch('train', batch_size=32)
    logits, loss = model(X, Y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if i%100 == 0:
        scores = evaluate_model(model)
        print(f"Loss train: {scores['train']:.4f}, valid {scores['valid']:.4f}")
print(loss.item())

Loss train: 2.5148, valid 2.9950
Loss train: 2.5086, valid 2.9538
Loss train: 2.5151, valid 2.9825
Loss train: 2.4961, valid 2.9841
Loss train: 2.5085, valid 2.9351
Loss train: 2.4999, valid 2.9443
Loss train: 2.4990, valid 2.9769
Loss train: 2.5001, valid 2.9453
Loss train: 2.4805, valid 2.9401
Loss train: 2.4663, valid 2.9400
2.5189523696899414


# The mathematical trick in self-attention

In [73]:
# toy example illustrating how matrix multiplication can be used for a "weighted aggregation"
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [74]:
torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [75]:
x

tensor([[[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]],

        [[ 1.3488, -0.1396],
         [ 0.2858,  0.9651],
         [-2.0371,  0.4931],
         [ 1.4870,  0.5910],
         [ 0.1260, -1.5627],
         [-1.1601, -0.3348],
         [ 0.4478, -0.8016],
         [ 1.5236,  2.5086]],

        [[-0.6631, -0.2513],
         [ 1.0101,  0.1215],
         [ 0.1584,  1.1340],
         [-1.1539, -0.2984],
         [-0.5075, -0.9239],
         [ 0.5467, -1.4948],
         [-1.2057,  0.5718],
         [-0.5974, -0.6937]],

        [[ 1.6455, -0.8030],
         [ 1.3514, -0.2759],
         [-1.5108,  2.1048],
         [ 2.7630, -1.7465],
         [ 1.4516, -1.5103],
         [ 0.8212, -0.2115],
         [ 0.7789,  1.5333],
         [ 1.6097, -0.4032]]])

In [76]:
# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t,C)
        xbow[b,t] = torch.mean(xprev, 0)

In [107]:
# version 2: using matrix multiply for a weighted aggregation
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)
torch.allclose(xbow, xbow2, rtol=0.001)

True

In [105]:
# version 3: use Softmax
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3, 0.001)


True

In [109]:
# version 4: self-attention!
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)
x.shape, x[1]



(torch.Size([4, 8, 32]),
 tensor([[ 4.5618e-01, -1.0917e+00, -8.2073e-01,  1.8634e+00,  8.1485e-01,
          -6.4297e-02,  1.4237e+00,  2.6173e-01, -1.8528e+00,  2.0186e-01,
          -1.1787e+00, -1.0358e-01, -1.7830e+00, -8.3234e-01, -4.3462e-01,
          -1.2480e+00, -2.8797e-01,  8.8086e-01, -7.1896e-01,  1.7449e-01,
           7.5198e-01, -6.2878e-02, -7.1113e-01,  9.8100e-01, -7.2443e-01,
          -1.5010e+00, -2.8348e+00, -2.8272e+00, -1.7358e-01,  5.1187e-02,
          -6.5764e-01, -2.5729e+00],
         [ 2.1011e-02,  1.0060e+00, -1.2492e+00,  2.4413e-01, -6.3866e-01,
          -3.1861e-01, -1.2942e+00, -1.0726e+00,  2.2901e-01, -9.0008e-01,
           6.6140e-01,  5.1178e-01,  6.7622e-01, -1.3639e+00,  5.4861e-01,
           8.9502e-02,  3.5746e-01, -1.6521e+00, -7.5838e-01,  6.9533e-02,
           9.9369e-01, -2.8205e-01,  1.1088e+00, -1.9881e+00, -1.3916e+00,
           1.2734e+00, -1.1732e+00,  5.8200e-01, -1.3185e+00,  7.8586e-01,
          -1.1501e+00,  1.3132e+00],
 

In [110]:
# let's see a single Head perform self-attention
head_size = 16

In [113]:
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

In [114]:
k = key(x)   # (B, T, 16)
k.shape, k[1]

(torch.Size([4, 8, 16]),
 tensor([[-1.3254e+00,  1.1236e+00,  2.2927e-01, -2.9970e-01, -7.6267e-03,
           7.9364e-01,  8.9581e-01,  3.9650e-01, -6.6613e-01, -2.1844e-01,
          -1.3539e+00,  4.1245e-01,  9.6011e-01, -1.0805e+00, -3.9751e-01,
          -4.4439e-01],
         [-1.9221e-01, -4.6449e-01,  5.9880e-02,  2.8408e-01, -1.0312e-01,
          -1.7967e-03,  1.8920e-01, -3.7337e-01, -9.8137e-02,  2.3116e-02,
           8.5743e-01,  5.6841e-01, -2.1939e-01, -2.9158e-01, -2.0158e-01,
          -4.6876e-01],
         [-1.1012e+00,  9.8266e-02,  5.8596e-01, -5.6413e-03,  3.7330e-01,
          -6.1363e-02,  2.8833e-02,  2.6230e-01,  6.4099e-01,  7.1003e-02,
           3.6877e-01,  5.0011e-01,  7.3872e-01,  1.1909e-01,  5.4246e-01,
           6.8950e-02],
         [ 4.9074e-01, -2.9978e-01,  1.0949e+00,  1.0131e+00,  3.5883e-01,
           9.5771e-01, -1.8349e-01,  1.4002e-01,  1.4243e-01,  8.0787e-01,
          -2.4476e-01,  1.3392e-01,  2.6700e-01,  3.2605e-01,  2.0296e-01,
   

In [115]:
q = query(x) # (B, T, 16)
q.shape, q[1]

(torch.Size([4, 8, 16]),
 tensor([[-1.0333e+00, -3.4510e-02,  7.9867e-01, -3.1655e-01, -2.9772e-01,
          -9.1838e-01, -1.5199e+00,  1.9666e-01, -1.2159e-01,  7.4736e-01,
          -2.7606e-01,  4.9970e-01,  7.3995e-01, -6.6689e-02, -3.4825e-01,
          -3.1100e-01],
         [ 2.9889e-01,  2.6488e-02, -4.8972e-01,  5.9050e-01, -3.7092e-01,
          -3.5309e-01,  4.0934e-01,  1.7045e-01, -3.3525e-01,  8.4297e-02,
           7.4341e-01,  2.4464e-01, -1.7708e+00, -3.7139e-01,  4.0281e-01,
           2.4142e-01],
         [ 1.9948e-01,  2.5343e-01,  9.2368e-01,  2.2140e-01, -2.4538e-01,
          -3.8363e-01, -5.7882e-01,  2.3511e-01, -7.2548e-01, -1.0267e+00,
           7.5266e-01,  4.4782e-01,  1.0257e+00, -6.4425e-01,  3.1324e-01,
           1.7068e-01],
         [-7.7496e-02, -8.5775e-01, -4.5019e-02,  8.8816e-01,  9.6268e-01,
          -3.0148e-01,  6.9115e-02,  9.7943e-01, -6.1072e-01, -1.2518e-01,
           4.4028e-01, -5.1224e-02, -3.1957e-01, -8.3149e-01,  3.0597e-01,
   

In [116]:
wei =  q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)
wei.shape, wei[1]

(torch.Size([4, 8, 8]),
 tensor([[ 1.1554, -0.0065,  2.0088,  0.1566,  0.3512,  2.2127,  1.9287,  0.8258],
         [-2.7658,  1.2351, -1.6067, -0.9266, -0.3247, -0.1245, -1.9458, -0.1695],
         [ 0.7916,  0.5579,  1.1484,  0.0371,  1.3073,  1.5446,  0.3102,  0.1636],
         [-0.9018,  0.5990,  0.2132,  0.7395, -0.7708, -1.0404,  0.7029, -0.2863],
         [-1.1896,  0.0815, -2.6702, -0.8406, -1.3204,  1.6170, -2.0681,  0.0625],
         [-1.6762,  0.5282, -1.6520, -0.3235, -4.5029, -1.4197, -0.1607, -0.3969],
         [ 1.1231, -0.5913,  0.6585,  2.0646,  1.9170,  0.0597, -0.0624, -0.1803],
         [-1.3307,  0.0108,  0.7319,  1.1654, -1.1907, -2.0370, -0.0924, -1.4904]],
        grad_fn=<SelectBackward0>))

In [117]:
tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
wei.shape, wei[1]

(torch.Size([4, 8, 8]),
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0180, 0.9820, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3105, 0.2458, 0.4437, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0730, 0.3275, 0.2227, 0.3769, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1411, 0.5030, 0.0321, 0.2000, 0.1238, 0.0000, 0.0000, 0.0000],
         [0.0613, 0.5558, 0.0628, 0.2372, 0.0036, 0.0792, 0.0000, 0.0000],
         [0.1382, 0.0249, 0.0869, 0.3543, 0.3057, 0.0477, 0.0422, 0.0000],
         [0.0325, 0.1243, 0.2556, 0.3944, 0.0374, 0.0160, 0.1121, 0.0277]],
        grad_fn=<SelectBackward0>))

In [118]:
v = value(x)
v.shape, v[1]

(torch.Size([4, 8, 16]),
 tensor([[-1.1274,  0.6688, -0.1343,  0.2711, -0.2411,  0.1825,  0.5411,  0.1950,
           0.0661, -0.5292,  1.4802,  0.1409, -0.7093,  1.0777,  0.6853, -0.0962],
         [ 0.1515, -0.9175, -0.5030, -0.2640, -0.2244,  0.3190, -1.0250, -0.9068,
           0.2559, -0.3327,  0.3800, -0.4085,  0.6693, -0.5432, -0.3409,  0.4469],
         [ 0.8371,  0.5248,  0.2187,  1.2574, -1.2345,  0.2157, -0.3088, -0.9003,
          -0.3902, -0.8324,  0.1915,  0.9874, -0.8154, -0.5693,  0.6834, -1.4363],
         [ 0.0413, -0.0914,  0.7350,  0.8020, -0.8927,  0.1707, -0.4448, -0.8593,
          -0.0775, -1.2778, -1.2968,  0.4534,  0.2550, -0.9268, -0.6823, -0.2205],
         [ 0.5482,  0.1835, -0.3915, -0.1364,  0.5540, -0.8287, -0.4259, -0.6030,
           0.5543, -0.2907,  0.5537,  0.1311, -0.1441,  0.6349,  0.0071, -0.7763],
         [-0.0926, -0.8934,  0.4242, -0.7483,  0.9029, -0.8633,  0.0785, -0.9263,
           1.4106,  0.1110,  1.5700, -0.2484,  0.5086, -0.0679, -0.0

In [119]:
out = wei @ v
#out = wei @ x

out.shape

torch.Size([4, 8, 16])