In [1]:
!pip install -Uqq torch
!pip install -Uqq numpy

In [4]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Raw Implementation Of GPT Like Model
### Download The Dataset

In [3]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-03-08 12:25:10--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-03-08 12:25:12 (1.79 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
with open('data/tinyshakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

### Inspect The Data

In [5]:
len(text)

1115394

In [6]:
# first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



### Get The Vocabulary

In [7]:
# get set of all chars in the text and then get that as a sorted list
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
vocab_size


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


65

### Tokenize The Input
Since this is a character level language model, we'll just translate individual characters to integers.

Other tokenizers to look into:
1. SentencePiece (Google)
2. Tiktoken (OpenAI)

In [8]:
stoi = { ch:i for i, ch in enumerate(chars) }
itos = { i:ch for i, ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # take a string and output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # take a list of integers and output a string

print(encode("Hello World!"))
print(decode(encode("Hello World!")))

[20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42, 2]
Hello World!


In [9]:
# stoi is a lookup table where key is the index and value is the character
type(stoi)

dict

In [10]:
# encode the dataset and get a tensor
# data type is int16 because our vocab size is only 65
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

### Create Training and Validation Splits

In [11]:
n = int(0.9*len(data)) # split 90% of data
train_data = data[:n] # first 90% is training data
val_data = data[n:] # rest is validation data

len(train_data), len(val_data)

(1003854, 111540)

### Create Batches Of Data To Train The Model

Sample random chunks of data from the training set. These chunks are of fixed max length.
In a chunk of 9 characters like `[18, 47, 56, 57, 58,  1, 15, 47, 58]` there are 8 examples for the model to train itself on like:
1. In the context of 18, 47 likely comes next.
2. In the context of 18 and 47, 56 likely comes next and so on.

This also helps the transformer network get used to seeing context length of 1 character upto the max context length.

In [12]:
block_size = 8 # max length of chunks
train_data[:block_size + 1] # first 9 chars in the training set

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [13]:
# x are the inputs to the transformer
x = train_data[:block_size]
# y is the next block
y = train_data[1:block_size + 1]
for t in range(block_size):
    context = x[:t+1] # all chars of x upto t incl. t
    target = y[t]
    print(f'input: {context}\ttarget: {target}')

input: tensor([18])	target: 47
input: tensor([18, 47])	target: 56
input: tensor([18, 47, 56])	target: 57
input: tensor([18, 47, 56, 57])	target: 58
input: tensor([18, 47, 56, 57, 58])	target: 1
input: tensor([18, 47, 56, 57, 58,  1])	target: 15
input: tensor([18, 47, 56, 57, 58,  1, 15])	target: 47
input: tensor([18, 47, 56, 57, 58,  1, 15, 47])	target: 58


In [14]:
# add batching to process multiple inputs simultaneously
batch_size = 4 # number of independent sequences to be processed parallely
block_size = 8 # max length of the context

def get_batch(split):
    # generate a small batch of inputs x and targets y
    data = train_data if split == 'train' else val_data
    # generate batch_size number of random offsets in the dataset
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix]) # stack converts multiple rows into a list of rows
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])

    return x,y

In [15]:
# example batches
xb, yb = get_batch('train')
print(f'inputs: {xb}\ninputs_shape: {xb.shape}\n')
print(f'inputs: {yb}\ninputs_shape: {yb.shape}\n')

inputs: tensor([[53, 51, 39, 52, 11,  1, 47, 44],
        [53, 53, 51, 10,  0, 35, 46, 39],
        [53, 49,  1, 47, 52, 42, 43, 43],
        [43,  1, 53, 58, 46, 43, 56,  0]])
inputs_shape: torch.Size([4, 8])

inputs: tensor([[51, 39, 52, 11,  1, 47, 44,  1],
        [53, 51, 10,  0, 35, 46, 39, 58],
        [49,  1, 47, 52, 42, 43, 43, 42],
        [ 1, 53, 58, 46, 43, 56,  0, 57]])
inputs_shape: torch.Size([4, 8])



In [16]:
# input and target mapping for batches
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f'input: {context.tolist()}\ttarget: {target}')

input: [53]	target: 51
input: [53, 51]	target: 39
input: [53, 51, 39]	target: 52
input: [53, 51, 39, 52]	target: 11
input: [53, 51, 39, 52, 11]	target: 1
input: [53, 51, 39, 52, 11, 1]	target: 47
input: [53, 51, 39, 52, 11, 1, 47]	target: 44
input: [53, 51, 39, 52, 11, 1, 47, 44]	target: 1
input: [53]	target: 53
input: [53, 53]	target: 51
input: [53, 53, 51]	target: 10
input: [53, 53, 51, 10]	target: 0
input: [53, 53, 51, 10, 0]	target: 35
input: [53, 53, 51, 10, 0, 35]	target: 46
input: [53, 53, 51, 10, 0, 35, 46]	target: 39
input: [53, 53, 51, 10, 0, 35, 46, 39]	target: 58
input: [53]	target: 49
input: [53, 49]	target: 1
input: [53, 49, 1]	target: 47
input: [53, 49, 1, 47]	target: 52
input: [53, 49, 1, 47, 52]	target: 42
input: [53, 49, 1, 47, 52, 42]	target: 43
input: [53, 49, 1, 47, 52, 42, 43]	target: 43
input: [53, 49, 1, 47, 52, 42, 43, 43]	target: 42
input: [43]	target: 1
input: [43, 1]	target: 53
input: [43, 1, 53]	target: 58
input: [43, 1, 53, 58]	target: 46
input: [43, 1, 53

## Bigram Language Model As A Baseline Model

Right now we're only predicting what comes next based on just the individual identity of a single token. This is because the tokens aren't aware of each other. They can only see themselves. So we're only making predictions based on what the actual token is.

Notice that in the implementation of generate method even though we pass a sequence of characters as context, the Bigram model only looks at the last character in the sequence to make predictions for the next character. The generate method accepts a sequence as context to keep it general.

In [17]:
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # logits are basically the scores for the next character in a sequence
        # channel means all the possible tokens (here chars) you can have
        logits = self.token_embedding_table(idx) # (B, T, C) Batch, Time, Channel

        # loss function
        if targets is None:
            loss = None
        else:
            # logits need to be reshaped because cross_entropy expects channels as the second dimension
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            # cross_entropy calculates loss a -log likelihood: -ln(char/65)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    # idx is the current context of some characters
    # generate function extends the input (B, T) to B by T+1, T+2 and so on...
    # and continues to do so for max_new_tokens
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # get predictions
            logits, _ = self(idx)
            # focus only on the last element in the time dimension
            logits = logits[:, -1, :]
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=1)
            # sample from the probability distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)

        return idx

In [18]:
# example prediction
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(f'logits: {logits}\nloss: {loss}\n{logits.shape}')

logits: tensor([[-0.9729,  0.2434, -0.1356,  ...,  0.1614, -1.3162, -0.0710],
        [-1.3694,  0.4819, -0.6065,  ...,  1.4325,  2.7159, -0.2356],
        [ 2.1531, -1.6002,  0.9560,  ...,  0.2818,  1.5296,  0.7737],
        ...,
        [ 0.5601, -2.4942, -1.4284,  ...,  0.2002, -1.6565,  0.5220],
        [ 1.9211, -0.2959,  1.6537,  ..., -0.5056,  0.7793,  0.9916],
        [-0.1882,  0.4575,  0.3669,  ..., -0.2465, -1.4268, -0.8499]],
       grad_fn=<ViewBackward0>)
loss: 4.5496134757995605
torch.Size([32, 65])


In [19]:
# example generation
# B = 1 and T = 1 to kick off the generation
# 0 is also encoded as \n, which is a good place to start
idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(idx, max_new_tokens=100)[0].tolist())) #[0] to get a single batch


BL&xr
rwv!IepVIlbjFAHzCYJifXG&3&!a;pCZntuF;YjfRAuFAl;:lGJXTPYn!HM-- sWmJdL&NfcVJdF$nMvzbYjgvUpb'?vpC


### Training The Bigram Model

In [20]:
# creating a pytorch optimizer to get the gradients and update the parameters
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [21]:
batch_size = 32
for steps in range(10000):
    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(f'loss: {loss.item()}')

loss: 2.4108874797821045


In [22]:
# test generation
print(decode(m.generate(idx, max_new_tokens=400)[0].tolist()))


NGowerarisaxed, f I s m fouiethan ond mellchashe the THe orongendy'le, k, nk merp er q$ffffos, twidy
A:
Cousfress sefthe ma nd theaw'swate ghered,
I
Aur KETor byerin fame oGS:
lee.
NRFre:
I messen:
Cor D&
Cl
YWhind, her bee ted t, w trd Cotltixet ce t HENG me.
s worede of bor t tosesthesoicassw ay ir oukipserfonove; y Gise.
Wharar:
Whiro th.
W: ceandutous rtiffomy

VI wemofancat r,
Coen wher moner


# Implementing Self Attention
We want the tokens to be aware of each other. Specifically we want the current token to be aware of the tokens that have appeared before (and not the tokens that come after) and couple them. To do that we calculate the average of all the previous token which then acts as the summary of all the information before the current token which it can use to predict the next token. Keep in mind this approach loses out on a lot of spacial information about the previous tokens.

In [23]:
# example
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [24]:
# calc average of prev tokens
# this implementation is inefficient
xbow = torch.zeros((B, T, C)) # bow means bag of words i.e. average of the prev tokens
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # everything upto the current token including the current token
        xbow[b, t] = torch.mean(xprev, 0)

In [25]:
x[0], xbow[0]

(tensor([[-0.0613,  1.0447],
         [-0.9260, -0.6164],
         [-1.5036,  0.0651],
         [ 0.5953,  0.1343],
         [ 1.1044, -1.5494],
         [ 0.5825,  0.3329],
         [ 1.2951, -0.6300],
         [ 0.8796, -0.9741]]),
 tensor([[-0.0613,  1.0447],
         [-0.4936,  0.2141],
         [-0.8303,  0.1645],
         [-0.4739,  0.1569],
         [-0.1582, -0.1843],
         [-0.0348, -0.0981],
         [ 0.1552, -0.1741],
         [ 0.2458, -0.2741]]))

### Using Matrix Multiplication To Make Average Calculation More Efficient

In [26]:
# matrix multiplication example
a = torch.ones(3, 3)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print(f'matrix a:\n{a}')
print(f'matrix b:\n{b}')
print(f'matrix c:\n{c}')

matrix a:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
matrix b:
tensor([[5., 1.],
        [3., 6.],
        [4., 3.]])
matrix c:
tensor([[12., 10.],
        [12., 10.],
        [12., 10.]])


In [27]:
# getting a lower triangular matrix
torch.tril(torch.ones(3, 3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [28]:
# using a lower triangular matrix gives each element in c matrix as sum of prev elements upon matrix multiplication
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True) # normalize each row such that sum of all elements in a row is one
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b # in c each row will be the average of the previous rows
print(f'matrix a:\n{a}')
print(f'matrix b:\n{b}')
print(f'matrix c:\n{c}')

matrix a:
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
matrix b:
tensor([[3., 1.],
        [9., 5.],
        [1., 7.]])
matrix c:
tensor([[3.0000, 1.0000],
        [6.0000, 3.0000],
        [4.3333, 4.3333]])


In [29]:
# version 2
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim = True)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [30]:
xbow2 = wei @ x # ((B), T, T) @ (B, T, C) -> (B, T, C) for each batch a T, T multiplies to a T, C
xbow2

tensor([[[-0.0613,  1.0447],
         [-0.4936,  0.2141],
         [-0.8303,  0.1645],
         [-0.4739,  0.1569],
         [-0.1582, -0.1843],
         [-0.0348, -0.0981],
         [ 0.1552, -0.1741],
         [ 0.2458, -0.2741]],

        [[ 1.6547,  0.5282],
         [ 0.2378,  0.5293],
         [ 0.3971,  0.9392],
         [ 0.3010,  0.7072],
         [ 0.2754,  0.6264],
         [ 0.3268,  0.7056],
         [ 0.1109,  0.5249],
         [ 0.2850,  0.5235]],

        [[ 1.0538,  0.1509],
         [ 0.9856, -0.1346],
         [ 0.4760,  0.4241],
         [ 0.1226,  0.6300],
         [-0.0399,  0.4964],
         [ 0.1006,  0.6127],
         [ 0.0724,  0.6003],
         [-0.0534,  0.5590]],

        [[ 1.1385,  0.4626],
         [-0.0193, -0.1204],
         [ 0.3171,  0.2542],
         [ 0.2987,  0.1163],
         [ 0.1737,  0.0299],
         [ 0.3123, -0.1032],
         [ 0.0927,  0.0549],
         [ 0.1127,  0.2565]]])

In [31]:
# version 3 (using softmax)

tril:
tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

masked filled wei:
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

softmaxed wei:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      

True

## Implementing Self Attention

This attention is called self attention because the keys, queries and the values, all come from the same source x.
In encoder-decoder transformers we can have keys from x but queries and values can come from a different source. This is called cross attention where we have nodes from a separate source which we'd like to pool information from.

In [27]:
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# create a single head of self attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
v = value(x)
# transpose last two dimensions to multiply
# (B, T, 16) @ (B, 16, T) -> (B, T, T)
wei = q @ k.transpose(-2, -1) * head_size**-0.5 # multiply with sqrt of head_size to bring down variance to order of 1

print(f'key: {k}\nquery: {q}\n')
print(f'key_shape: {k.shape}\nquery_shape: {q.shape}\n')

tril = torch.tril(torch.ones(T, T))
# we have implemented a decoder block
# to make it an encoder block we simply remove this line
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)
out = wei @ v

print(f'out: {out}\n')
print(f'out_shape: {out.shape}\n')

key: tensor([[[ 0.6223,  0.6672, -0.1610, -0.5891,  0.0268, -0.4281, -0.7132,
           0.4530,  1.0593, -0.3565, -0.0345, -0.8459,  0.8692, -0.2875,
           0.1336,  0.4239],
         [-0.2630,  0.2687, -0.6003, -0.8501, -0.0380, -0.6245, -0.0931,
          -0.4417,  1.2339, -0.4511,  0.8644,  0.3217, -0.0086, -0.2291,
           0.2228,  1.0264],
         [ 0.9350,  0.8069, -0.5908,  0.5710,  0.4617, -0.2731,  0.0925,
           0.3897,  0.2646,  1.0613, -1.3004,  0.1308, -0.4151,  0.0740,
          -0.5951, -0.7345],
         [ 0.8388, -0.2040, -0.7414,  0.2889,  0.0683,  0.0270,  0.6034,
           0.7024, -0.2394, -0.2562,  0.2582,  0.2980, -0.1518,  0.4322,
           0.3429, -0.7432],
         [-0.2021, -0.9490, -0.0956,  0.5241,  1.0297, -0.0988, -0.3211,
           0.4393, -1.2442, -0.0778, -0.2027,  0.0133,  0.1113,  0.3739,
           0.6861, -1.0965],
         [-0.2959, -0.1842,  0.1077, -0.4384,  0.1618,  0.4899, -0.5088,
          -0.4491, -0.3239, -0.3367, -0.0535, -

In [28]:
wei[0]

tensor([[0.1503, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1101, 0.1001, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1659, 0.1213, 0.1428, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0924, 0.1336, 0.1980, 0.2620, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0875, 0.1096, 0.2102, 0.2217, 0.2981, 0.0000, 0.0000, 0.0000],
        [0.1191, 0.1577, 0.1615, 0.1890, 0.2277, 0.2880, 0.0000, 0.0000],
        [0.1456, 0.2414, 0.1467, 0.1643, 0.1740, 0.2599, 0.5100, 0.0000],
        [0.1291, 0.1361, 0.1408, 0.1630, 0.3002, 0.4521, 0.4900, 1.0000]],
       grad_fn=<SelectBackward0>)