In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

# Hyperparameters handed to the d2l loader below
# (presumably sequences per minibatch and tokens per sequence -- confirm against d2l docs).
batch_size, num_steps = 32, 35
# Returns a training-data iterator and the vocabulary object; later cells use
# len(vocab) as the one-hot width.
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

Downloading ../data/timemachine.txt from http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt...


In [17]:
# One-hot encode the indices 0 and 2: one row per index, len(vocab) columns,
# with a single 1 at the position given by the index.
F.one_hot(torch.tensor([0, 2]), len(vocab))

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]])

In [19]:
# Toy minibatch of indices 0..9 laid out as 2 rows of 5 values.
X = torch.arange(10).reshape(2, 5)
# Transposing first moves the length-5 axis to the front, so the encoded
# result has shape [5, 2, 10] ---- [time_step, batch_size, vocab_size]
F.one_hot(X.T, 10).shape

torch.Size([5, 2, 10])

In [None]:
# init params in model
def init_params(vocab_size, num_hiddens, device):
    """Create the trainable parameters of a single-layer RNN.

    Args:
        vocab_size: Width of both the input and the output layer.
        num_hiddens: Width of the hidden layer.
        device: Device on which every parameter tensor is allocated.

    Returns:
        A list ``[W_xh, W_hh, b_h, W_hq, b_q]`` of tensors that all have
        ``requires_grad`` enabled. Weight matrices are drawn from a standard
        normal distribution scaled by 0.01; biases start at zero.
    """
    def scaled_gaussian(*dims):
        # Small random values: standard normal samples shrunk by 0.01.
        return 0.01 * torch.randn(dims, device=device)

    def zero_bias(width):
        return torch.zeros(width, device=device)

    # Hidden-layer parameters (input-to-hidden, hidden-to-hidden, bias).
    hidden_layer = (
        scaled_gaussian(vocab_size, num_hiddens),
        scaled_gaussian(num_hiddens, num_hiddens),
        zero_bias(num_hiddens),
    )
    # Output-layer parameters (hidden-to-output, bias).
    output_layer = (
        scaled_gaussian(num_hiddens, vocab_size),
        zero_bias(vocab_size),
    )

    # requires_grad_ flips the flag in place and hands back the same tensor.
    return [tensor.requires_grad_(True) for tensor in (*hidden_layer, *output_layer)]

In [23]:
# init hidden state in rnn
def init_hidden_state(batch_size, num_hiddens, device):
    """Build the all-zero starting hidden state.

    Args:
        batch_size: Number of rows in the returned tensor.
        num_hiddens: Number of columns in the returned tensor.
        device: Device on which the tensor is allocated.

    Returns:
        A zero tensor of shape ``(batch_size, num_hiddens)``.
    """
    state_shape = (batch_size, num_hiddens)
    return torch.zeros(state_shape, device=device)

In [None]:
# model: rnn
def rnn_model(inputs, hidden_state, params):
    """Run a single-layer tanh RNN over a sequence and project every step.

    Args:
        inputs: Tensor iterated along its first axis, one slice per time
            step; expected layout is [time_step, batch_size, vocab_size].
        hidden_state: Initial hidden state of shape (batch_size, num_hiddens),
            given either as a bare tensor or wrapped in a one-element
            tuple/list (the form this function itself returns).
        params: Sequence ``(W_xh, W_hh, b_h, W_hq, b_q)``.

    Returns:
        A pair ``(outputs, (H,))`` where ``outputs`` stacks the per-step
        projections along dim 0, giving shape
        (time_step * batch_size, vocab_size), and ``H`` is the hidden state
        after the last time step.
    """
    W_xh, W_hh, b_h, W_hq, b_q = params

    # Accept a bare tensor as well as the 1-tuple form; tuple-unpacking a
    # (batch_size, num_hiddens) tensor would try to split it along the batch
    # axis and fail for any batch size other than 1.
    if isinstance(hidden_state, torch.Tensor):
        H = hidden_state
    else:
        H, = hidden_state

    outputs = []
    for X in inputs:
        # X: (batch_size, vocab_size) -> H: (batch_size, num_hiddens)
        H = torch.tanh(X @ W_xh + H @ W_hh + b_h)
        # The output layer is a matrix product, not an element-wise one:
        # (batch_size, num_hiddens) @ (num_hiddens, vocab_size).
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)

In [41]:
# practice -- dimension
# Stacking along dim 0 and then swapping the first two axes yields the same
# tensor as stacking along dim 1 directly; the element-wise comparison below
# is therefore True everywhere.
a = torch.randn(2, 5)
b = torch.randn(2, 5)
torch.stack((a, b), dim = 0).permute(1, 0, 2) == torch.stack((a, b), dim = 1)

tensor([[[True, True, True, True, True],
         [True, True, True, True, True]],

        [[True, True, True, True, True],
         [True, True, True, True, True]]])