In [67]:
import torch
from ngram.dataset import Dataset
from ngram.encoder import Encoder
from sdk.batch_norm import BatchNorm
from sdk.cross_entropy import CrossEntropy
from sdk.embeddings import Embedding
from sdk.flatten import Flatten
from sdk.linear import Linear
from sdk.plotter import Plotter
from sdk.sequential import Sequential
from sdk.tanh import Tanh
from torch import Tensor
from torch.nn import functional as F

In [68]:
# Load the corpus: every line of the text file becomes one entry of `input_words`.
# No empty-list pre-initialisation on purpose: if the file cannot be opened, the
# cell fails and `input_words` is not silently reset to [] for later cells.
filepath = '../makemore/data/names.txt'
with open(filepath, encoding='utf-8') as f:
    input_words = f.read().splitlines()

In [79]:
# Context window handed to Dataset below; the printed train_inputs shape
# further down has this many columns (8).
context_length: int = 8

In [80]:
# Encoder exposes `ltoi`, whose length is used as the vocabulary size when the
# model is built.
encoder = Encoder()

# Dataset exposes `train_inputs` / `train_targets`, which the cells below index.
# NOTE(review): presumably built from `input_words` using windows of
# `context_length` tokens - confirm against ngram.dataset.
dataset = Dataset(
    input_words=input_words,
    context_length=context_length,
)

In [92]:
# Model hyper-parameters.
embedding_dim: int = 10
num_hidden: int = 200

# Sizes derived from the data: one embedding row / output logit per `ltoi` entry,
# and one embedding vector per input column once the embeddings are flattened.
vocab_size = len(encoder.ltoi)
flattened_dim = dataset.train_inputs.shape[1] * embedding_dim

# Pipeline: embed -> flatten -> hidden block (Linear, BatchNorm, Tanh) -> output logits.
model = Sequential([
    Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim),
    Flatten(),
    Linear(in_features=flattened_dim, out_features=num_hidden, nonlinearity='tanh'),
    BatchNorm(num_features=num_hidden),
    Tanh(),
    Linear(in_features=num_hidden, out_features=vocab_size, nonlinearity=None),
])

loss_fn = CrossEntropy()

In [93]:
# Inspect the shape of the training-input tensor (shown as the cell's output).
dataset.train_inputs.shape

torch.Size([182516, 8])

In [94]:
# Smoke test: pick 4 random training rows and run them through the model once.
ix = torch.randint(low=0, high=dataset.train_inputs.shape[0], size=(4,))

# Mini-batch of 4 examples: inputs and their matching targets.
Xb = dataset.train_inputs[ix]
Yb = dataset.train_targets[ix]

# Forward pass; the next cell reads `layer.output` from each layer afterwards.
logits = model(Xb)

print(Xb.shape)
Xb

torch.Size([4, 8])


tensor([[ 0,  0,  0,  0,  0,  0,  0,  2],
        [ 0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0, 19, 15, 22,  5],
        [ 0,  0,  0,  0,  0,  0,  0, 18]])

In [95]:
# List every layer's class name next to the shape of its most recent output.
for layer in model.layers:
    layer_name = type(layer).__name__
    print(f'{layer_name}, {layer.output.shape}')

Embedding, torch.Size([4, 8, 10])
Flatten, torch.Size([4, 80])
Linear, torch.Size([4, 200])
BatchNorm, torch.Size([4, 200])
Tanh, torch.Size([4, 200])
Linear, torch.Size([4, 27])


In [99]:
# Shape check: matmul acts on the last dimension and keeps the leading (4, 5)
# dims, and the (200,) bias broadcasts over them on the add.
torch.add(
    torch.matmul(torch.randn(4, 5, 80), torch.randn(80, 200)),
    torch.randn(200),
).shape

torch.Size([4, 5, 200])

In [109]:
# Probe: build random integer indices in [0, 10), viewed as a (4, 4, 2) tensor,
# and pass them straight to the first layer of the model (the Embedding).
inputs = torch.randint(low=0, high=10, size=(4, 8)).view(4, 4, 2)
print(inputs.dtype)

# Last expression of the cell, so the layer's result is displayed.
model.layers[0](inputs)

torch.int64


tensor([[[[-1.6214e+00, -2.8306e-03, -1.8308e+00,  1.6377e-01,  4.3264e-01,
            7.5099e-02,  1.9898e+00,  1.0147e+00,  7.1080e-02, -1.2500e+00],
          [-2.7088e-01, -7.2826e-02,  1.7824e+00, -2.1410e-01,  3.7664e-02,
           -1.0301e+00,  2.0842e-01, -2.7404e-01, -2.2529e+00, -9.6986e-01]],

         [[-2.7088e-01, -7.2826e-02,  1.7824e+00, -2.1410e-01,  3.7664e-02,
           -1.0301e+00,  2.0842e-01, -2.7404e-01, -2.2529e+00, -9.6986e-01],
          [-1.6214e+00, -2.8306e-03, -1.8308e+00,  1.6377e-01,  4.3264e-01,
            7.5099e-02,  1.9898e+00,  1.0147e+00,  7.1080e-02, -1.2500e+00]],

         [[ 9.1957e-01,  2.3894e-01,  7.9507e-01, -7.8506e-01, -7.4417e-01,
            1.1262e+00,  1.6870e+00, -7.5643e-01, -9.0460e-02, -4.0857e-01],
          [-2.7088e-01, -7.2826e-02,  1.7824e+00, -2.1410e-01,  3.7664e-02,
           -1.0301e+00,  2.0842e-01, -2.7404e-01, -2.2529e+00, -9.6986e-01]],

         [[ 4.8032e-01,  2.4502e+00,  1.2691e+00,  7.8784e-01, -3.1605e-01,
