In [18]:
%matplotlib inline
import math
import torch
from torch import nn
from torch.nn import functional as F
import utils
import d2l

In [19]:
def get_params(vocab_size,num_hiddens,device):
    num_inputs=num_outputs=vocab_size
    def normal(shape):
        return torch.randn(size=shape,device=device)*0.01
    W_xh=normal((num_inputs,num_hiddens))
    W_hh=normal((num_hiddens,num_hiddens))
    b_h=torch.zeros(num_hiddens,device=device)
    W_hq=normal((num_hiddens,num_outputs))
    b_q=torch.zeros(num_outputs,device=device)
    params=[W_xh,W_hh,b_h,W_hq,b_q]
    for param in params:
        param.requires_grad_(True)
    return params


In [20]:
def init_rnn_state(batch_size,num_hiddens,device):
        return (torch.zeros((batch_size,num_hiddens),device=device),)

In [21]:
def rnn(inputs,state,params):
    W_xh,W_hh,b_h,W_hq,b_q=params
    H,=state
    outputs=[]
    for X in inputs:
        H=torch.tanh(torch.mm(X,W_xh)+
                     torch.mm(H,W_hh)+
                     b_h)
        Y=torch.mm(H,W_hq)+b_q
        outputs.append(Y)
    return torch.cat(outputs,dim=0),(H,)
    

In [22]:
class RNN:
    def __init__(self,vocab_size,num_hiddens,device,get_params,init_state,forward_fn) -> None:
        self.vocab_size,self.num_hiddens=vocab_size,num_hiddens
        self.params=get_params(vocab_size,num_hiddens,device)
        self.init_state,self.forward_fn=init_state,forward_fn
    def __call__(self,X,state):
        X=F.one_hot(X.T,self.vocab_size).type(torch.float32)
        return self.forward_fn(X,state,self.params)
    def begin_state(self,batch_size,device):
        return self.init_state(batch_size,self.num_hiddens,device)

In [23]:
def predicts(prefix,num_preds,net,vocab,device):
    state=net.begin_state(batch_size=1,device=device)
    outputs=[vocab[prefix[0]]]
    get_input=lambda:torch.tensor([outputs[-1]],device=device).reshape(1,1)
    for y in prefix[1:]:
        _,state=net(get_input(),state)
        outputs.append(vocab[y])
    for _ in range(num_preds):
        y,state=net(get_input(),state)
        outputs.append(int(y.argmax(dim=1).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs ])
    


In [24]:
def grad_clipping(net,theta):
    if isinstance(net,nn.Module):
        params=[p for p in net.parameters()if p.requires_grad]
    else:
        params=net.params
    norm=torch.sqrt(sum(torch.sum((p.grad**2))for p in params))
    if norm>theta:
        for param in params:
            param.grad[:]*=theta/norm


In [25]:
def trainepoch(net,train_iter,loss,updater,device,use_random_iter):
    #use_random_iter 下一个batch跟上一个batch的第i个样本有没有关系

    state,timer=None,utils.Timer()
    metric=utils.Accumulator(2)
    for X,Y in train_iter:
        if state is None or use_random_iter:
            state=net.begin_state(batch_size=X.shape[0],device=device)
        else:
            if isinstance(net,nn.Module)and not isinstance(state):
                state.detach_()
            else:
                for s in state:
                    s.detach_()
        y=Y.T.reshape(-1)
        X,y=X.to(device),y.to(device)
        y_hat,state=net(X,state)
        l=loss(y_hat,y.long()).mean()
        if isinstance(updater,torch.optim.Optimizer):
            updater.zero_grad()
            l.backward()
            grad_clipping(net,1)
            updater.step()
        else:
            l.backward()
            grad_clipping(net,1)
            updater(batch_size=1)
        metric.add(l*y.numel(),y.numel())
    return math.exp(metric[0]/metric[1]),metric[1]/timer.stop()



In [26]:
def train(net,train_iter,vocab,lr,num_epochs,device,use_random_iter=False):
    loss=nn.CrossEntropyLoss()
    updater=torch.optim.SGD(net.parameters(),lr)
    predict=lambda prefix:predicts(prefix,50,net,vocab,device)
    for epoch in range(num_epochs):
        ppl,speed=trainepoch(net,train_iter,loss,updater,device,use_random_iter)
        if (epoch+1)%10==0:
            print(predict('time traveller'))
    print(f'{ppl:.1f},{speed:.1f}')


In [27]:
num_hiddens=256
batch_size=32
vocab=1
rnn_layer=nn.RNN(len(vocab),num_hiddens)
state=torch.zeros((1,batch_size,num_hiddens))

TypeError: object of type 'int' has no len()

In [None]:
X=torch.rand(size=())