# Recurrent neural network with numpy

## Prerequisite

In [1]:
# some important imports
import numpy as np
from translator import Translator
from tqdm import tqdm

## Encoding text

In [2]:
# data
text = open('data/toy.txt', 'r').read()

# text = 'Hallo'
text_length = len(text)
chars = list(set(text))
char_length = len(chars)
print('text is ', text_length, 'long and has ', char_length)

# creating training data
char_to_int = dict((c, i) for i, c in enumerate(chars))
int_to_char = dict((i, c) for i, c in enumerate(chars))

X = np.array([char_to_int[char] for char in text])
y = np.append(X[1:X.shape[0]], X[0])
print('first 10 datas: ', X[0:10])
print('first 10 labels: ', y[0:10])

text is  1993 long and has  56
first 10 datas:  [39  7 11 22 36 10 32 38 22 36]
first 10 labels:  [ 7 11 22 36 10 32 38 22 36 10]


## Forward pass


In [3]:
def forward_pass(X, hprev):
    ht, pt, loss = [hprev[0]], [], 0
    for t in range(len(X)):
        # creating a one hot encoded vector
        xt = np.zeros((char_length, 1))
        xt[X[t]] = 1

        # calculating forward pass
        zt = np.dot(Wxh, xt) + np.dot(Whh, ht[t])
        ht.append(np.tanh(zt))
        yt = np.dot(Why, ht[t])

        # getting probability distribution
        pt.append(np.exp(yt) / np.sum(np.exp(yt)))

        # summing up the loss of every output
        loss += -np.sum(np.log(pt[t][X[t]]))
    return ht, pt, loss / len(X)

## Backward pass

In [4]:
def backward_pass(X, y, ht, pt):
    dWhh, dWxh, dWhy = np.zeros_like(Whh), np.zeros_like(Wxh), np.zeros_like(Why)
    for t in reversed(range(len(X))):
        dout = pt.copy()
        dout[t][y[t]] -= 1
        dWhy += np.dot(dout[t], ht[t].T)
        dh = np.dot(Why.T, dout[t])
        dtanh = (1 - ht[t] * ht[t]) * dh
        xt = np.zeros((char_length, 1))
        xt[X[t]] = 1
        dWxh += np.dot(dtanh, xt.T)
        dWhh += np.dot(dtanh, ht[t - 1].T)
    
    dWhh /= len(X)
    dWxh /= len(X)
    dWhy /= len(X)
    # gradient clipping
    for dparam in [dWxh, dWhh, dWhy]:
        np.clip(dparam, -5, 5, out=dparam)
    return dWhh, dWxh, dWhy

## Predict function

In [5]:
def predict(X, Wxh, Whh, Why, hprev):
    ht, prediction = [hprev[0]], ''
    for t in range(len(X)):
        # creating a one hot encoded vector
        xt = np.zeros((char_length, 1))
        xt[X[t]] = 1

        # calculating forward pass
        zt = np.dot(Wxh, xt) + np.dot(Whh, ht[t])
        ht.append(np.tanh(zt))
        yt = np.dot(Why, ht[t])

        # getting probability distribution
        pt.append(np.exp(yt) / np.sum(np.exp(yt)))

        # creating a prediction string
        prediction += chars[np.argmax(pt[t])]
    return prediction

## Updating parameter with adagrad

### Intizializing hyperparameter

In [6]:
seq_size = 15
hidden_size = 200
learning_rate = 1e-8
epochs = 10

print('Training ', epochs, ' epochs with a sequence size of ', seq_size, ', a hidden size of ', hidden_size, ' and a learning rate of', learning_rate)

Training  10  epochs with a sequence size of  15 , a hidden size of  200  and a learning rate of 1e-08


### Initizializing learnable parameter

In [7]:
Wxh = np.random.randn(hidden_size, char_length) * 0.01
Whh = np.random.randn(hidden_size, hidden_size) * 0.01
Why = np.random.randn(char_length, hidden_size) * 0.01

In [8]:
# initializing hidden state and squared gradient
ht = [np.zeros((hidden_size, 1))]
grad_squared_xh, grad_squared_hh, grad_squared_hy = np.zeros_like(Wxh), np.zeros_like(Whh), np.zeros_like(Why)

loss = 0
for e in tqdm(range(epochs)):
    for steps in range(0, len(X), seq_size):
        inputs = X[steps:steps+seq_size]
        targets = y[steps:steps+seq_size]

        # forward and backward pass
        ht, pt, loss = forward_pass(inputs, ht)
        dWhh, dWxh, dWhy = backward_pass(inputs, inputs, ht, pt)
        
        # adagrad
        grad_squared_xh += dWxh ** 2
        grad_squared_hh += dWhh ** 2
        grad_squared_hy += dWhy ** 2
        
        # parameter update
        Wxh -= dWxh / np.sqrt(grad_squared_xh + 1e-7) * learning_rate
        Whh -= dWhh / np.sqrt(grad_squared_hh + 1e-7) * learning_rate
        Why -= dWhy / np.sqrt(grad_squared_hy + 1e-7) * learning_rate
    print('loss at epoch ', e, ' is ', loss)
    print(predict(X, Wxh, Whh, Why, ht))

  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

loss at epoch  0  is  4.02495572847
Wgpw-Nx-dpIgeW)"(W(NsDW(NWpxa(woNxW)jxNxp-"(gNxNxgNxpg"(hâp-wWhNc,pg"(xhâ)g,p-(NsDwW0"(W(ps)-D(ah(c-W-s((- (pwW-INxNxp-(a("DIcpgp)h-âD"(((apgg(,I-âNxp-(wW(pwgN(,I-âNxp-(,g(((9I-âNxp-(Ns(N(W(-g((-(hr©-N(wW(r(aIhg(w9(w-D(wW(WpWh(r(c,(-Ehg(w-Nw(((ch(a("dW-Dc(,-NI,Nxr©a"(chbchgD-Nx-s(gwWh(]ppgN]((ggwp)(Nh (ExNs(NsD(hr©-Nx(a-(woNxW)jxNxp-(-cpHEhW(gDDWg(Nw(W)-zW)jh(((apgg(,I-âNxp-0(a-(wHwhâNxr©(,I-âNxp-(,g(hDNsDc(((apgg(,I-âNxp-(wW(,Ng(-Ds(Nxr©(B,-(g-©â)9xp(xpW(x-g"(r(c,pIgE"(pxaah (((chE(cx(,I-âNxp-"(((bcp9xN(,I-âNxp-"(((INWa)N"(9I-âNxp-"(((9xN-Dgg(,I-âNxp-"(hNpx9"(,-(Es)ps(pxgD(,N(,g(Nw(-D(W(j)W)jh xââ)"(gNxNxgNxpg"(N"-)pxaa"(((apgg(,I-âNxp-(,g(IgD (,pW(b(c(WhNhc(hgNxWxNWp-"((- (NsD(hr©-N(,-(hIhgNxp-(,g(gwWh(,I-âNxp-(w9(NsD(x,99Dch-âD(HDNE©D-(hgNxWxNh ((- (Ncdh(r(aIhg(,pW((-(,-gNx-âD(w9(x(Nxx(ssD(pw-âDoN"((g(wE ((g(Ã(-a(pD"(E(g(chD-NcpxsâD (,-(gNxNxgNxpg(H"(a-"(s(W(rca (,-(NsD(W)x ah(w9(NsD(]1Ns(pD-NIc0xBf-()"(NsD(pw-NhjN(w9(hâp-wW)pg"(,pW(hj(Wbah"(Ns)g(

 10%|████████▎                                                                          | 1/10 [00:00<00:06,  1.37it/s]

loss at epoch  1  is  4.0249557068
Wgpw-Nx-dpIgeW)"(W(NsDW(NWpxa(woNxW)jxNxp-"(gNxNxgNxpg"(hâp-wWhNc,pg"(xhâ)g,p-(NsDwW0"(W(ps)-D(ah(c-W-s((- (pwW-INxNxp-(a("DIcpgp)h-âD"(((apgg(,I-âNxp-(wW(pwgN(,I-âNxp-(,g(((9I-âNxp-(Ns(N(W(-g((-(hr©-N(wW(r(aIhg(w9(w-D(wW(WpWh(r(c,(-Ehg(w-Nw(((ch(a("dW-Dc(,-NI,Nxr©a"(chbchgD-Nx-s(gwWh(]ppgN]((ggwp)(Nh (ExNs(NsD(hr©-Nx(a-(woNxW)jxNxp-(-cpHEhW(gDDWg(Nw(W)-zW)jh(((apgg(,I-âNxp-0(a-(wHwhâNxr©(,I-âNxp-(,g(hDNsDc(((apgg(,I-âNxp-(wW(,Ng(-Ds(Nxr©(B,-(g-©â)9xp(xpW(x-g"(r(c,pIgE"(pxaah (((chE(cx(,I-âNxp-"(((bcp9xN(,I-âNxp-"(((INWa)N"(9I-âNxp-"(((9xN-Dgg(,I-âNxp-"(hNpx9"(,-(Es)ps(pxgD(,N(,g(Nw(-D(W(j)W)jh xââ)"(gNxNxgNxpg"(N"-)pxaa"(((apgg(,I-âNxp-(,g(IgD (,pW(b(c(WhNhc(hgNxWxNWp-"((- (NsD(hr©-N(,-(hIhgNxp-(,g(gwWh(,I-âNxp-(w9(NsD(x,99Dch-âD(HDNE©D-(hgNxWxNh ((- (Ncdh(r(aIhg(,pW((-(,-gNx-âD(w9(x(Nxx(ssD(pw-âDoN"((g(wE ((g(Ã(-a(pD"(E(g(chD-NcpxsâD (,-(gNxNxgNxpg(H"(a-"(s(W(rca (,-(NsD(W)x ah(w9(NsD(]1Ns(pD-NIc0xBf-()"(NsD(pw-NhjN(w9(hâp-wW)pg"(,pW(hj(Wbah"(Ns)g(,

 20%|████████████████▌                                                                  | 2/10 [00:01<00:06,  1.32it/s]

loss at epoch  2  is  4.02495569011
Wgpw-Nx-dpIgeW)"(W(NsDW(NWpxa(woNxW)jxNxp-"(gNxNxgNxpg"(hâp-wWhNc,pg"(xhâ)g,p-(NsDwW0"(W(ps)-D(ah(c-W-s((- (pwW-INxNxp-(a("DIcpgp)h-âD"(((apgg(,I-âNxp-(wW(pwgN(,I-âNxp-(,g(((9I-âNxp-(Ns(N(W(-g((-(hr©-N(wW(r(aIhg(w9(w-D(wW(WpWh(r(c,(-Ehg(w-Nw(((ch(a("dW-Dc(,-NI,Nxr©a"(chbchgD-Nx-s(gwWh(]ppgN]((ggwp)(Nh (ExNs(NsD(hr©-Nx(a-(woNxW)jxNxp-(-cpHEhW(gDDWg(Nw(W)-zW)jh(((apgg(,I-âNxp-0(a-(wHwhâNxr©(,I-âNxp-(,g(hDNsDc(((apgg(,I-âNxp-(wW(,Ng(-Ds(Nxr©(B,-(g-©â)9xp(xpW(x-g"(r(c,pIgE"(pxaah (((chE(cx(,I-âNxp-"(((bcp9xN(,I-âNxp-"(((INWa)N"(9I-âNxp-"(((9xN-Dgg(,I-âNxp-"(hNpx9"(,-(Es)ps(pxgD(,N(,g(Nw(-D(W(j)W)jh xââ)"(gNxNxgNxpg"(N"-)pxaa"(((apgg(,I-âNxp-(,g(IgD (,pW(b(c(WhNhc(hgNxWxNWp-"((- (NsD(hr©-N(,-(hIhgNxp-(,g(gwWh(,I-âNxp-(w9(NsD(x,99Dch-âD(HDNE©D-(hgNxWxNh ((- (Ncdh(r(aIhg(,pW((-(,-gNx-âD(w9(x(Nxx(ssD(pw-âDoN"((g(wE ((g(Ã(-a(pD"(E(g(chD-NcpxsâD (,-(gNxNxgNxpg(H"(a-"(s(W(rca (,-(NsD(W)x ah(w9(NsD(]1Ns(pD-NIc0xBf-()"(NsD(pw-NhjN(w9(hâp-wW)pg"(,pW(hj(Wbah"(Ns)g(

 30%|████████████████████████▉                                                          | 3/10 [00:02<00:05,  1.37it/s]

loss at epoch  3  is  4.024955676
Wgpw-Nx-dpIgeW)"(W(NsDW(NWpxa(woNxW)jxNxp-"(gNxNxgNxpg"(hâp-wWhNc,pg"(xhâ)g,p-(NsDwW0"(W(ps)-D(ah(c-W-s((- (pwW-INxNxp-(a("DIcpgp)h-âD"(((apgg(,I-âNxp-(wW(pwgN(,I-âNxp-(,g(((9I-âNxp-(Ns(N(W(-g((-(hr©-N(wW(r(aIhg(w9(w-D(wW(WpWh(r(c,(-Ehg(w-Nw(((ch(a("dW-Dc(,-NI,Nxr©a"(chbchgD-Nx-s(gwWh(]ppgN]((ggwp)(Nh (ExNs(NsD(hr©-Nx(a-(woNxW)jxNxp-(-cpHEhW(gDDWg(Nw(W)-zW)jh(((apgg(,I-âNxp-0(a-(wHwhâNxr©(,I-âNxp-(,g(hDNsDc(((apgg(,I-âNxp-(wW(,Ng(-Ds(Nxr©(B,-(g-©â)9xp(xpW(x-g"(r(c,pIgE"(pxaah (((chE(cx(,I-âNxp-"(((bcp9xN(,I-âNxp-"(((INWa)N"(9I-âNxp-"(((9xN-Dgg(,I-âNxp-"(hNpx9"(,-(Es)ps(pxgD(,N(,g(Nw(-D(W(j)W)jh xââ)"(gNxNxgNxpg"(N"-)pxaa"(((apgg(,I-âNxp-(,g(IgD (,pW(b(c(WhNhc(hgNxWxNWp-"((- (NsD(hr©-N(,-(hIhgNxp-(,g(gwWh(,I-âNxp-(w9(NsD(x,99Dch-âD(HDNE©D-(hgNxWxNh ((- (Ncdh(r(aIhg(,pW((-(,-gNx-âD(w9(x(Nxx(ssD(pw-âDoN"((g(wE ((g(Ã(-a(pD"(E(g(chD-NcpxsâD (,-(gNxNxgNxpg(H"(a-"(s(W(rca (,-(NsD(W)x ah(w9(NsD(]1Ns(pD-NIc0xBf-()"(NsD(pw-NhjN(w9(hâp-wW)pg"(,pW(hj(Wbah"(Ns)g(,g

 40%|█████████████████████████████████▏                                                 | 4/10 [00:02<00:04,  1.38it/s]

loss at epoch  4  is  4.02495566355
Wgpw-Nx-dpIgeW)"(W(NsDW(NWpxa(woNxW)jxNxp-"(gNxNxgNxpg"(hâp-wWhNc,pg"(xhâ)g,p-(NsDwW0"(W(ps)-D(ah(c-W-s((- (pwW-INxNxp-(a("DIcpgp)h-âD"(((apgg(,I-âNxp-(wW(pwgN(,I-âNxp-(,g(((9I-âNxp-(Ns(N(W(-g((-(hr©-N(wW(r(aIhg(w9(w-D(wW(WpWh(r(c,(-Ehg(w-Nw(((ch(a("dW-Dc(,-NI,Nxr©a"(chbchgD-Nx-s(gwWh(]ppgN]((ggwp)(Nh (ExNs(NsD(hr©-Nx(a-(woNxW)jxNxp-(-cpHEhW(gDDWg(Nw(W)-zW)jh(((apgg(,I-âNxp-0(a-(wHwhâNxr©(,I-âNxp-(,g(hDNsDc(((apgg(,I-âNxp-(wW(,Ng(-Ds(Nxr©(B,-(g-©â)9xp(xpW(x-g"(r(c,pIgE"(pxaah (((chE(cx(,I-âNxp-"(((bcp9xN(,I-âNxp-"(((INWa)N"(9I-âNxp-"(((9xN-Dgg(,I-âNxp-"(hNpx9"(,-(Es)ps(pxgD(,N(,g(Nw(-D(W(j)W)jh xââ)"(gNxNxgNxpg"(N"-)pxaa"(((apgg(,I-âNxp-(,g(IgD (,pW(b(c(WhNhc(hgNxWxNWp-"((- (NsD(hr©-N(,-(hIhgNxp-(,g(gwWh(,I-âNxp-(w9(NsD(x,99Dch-âD(HDNE©D-(hgNxWxNh ((- (Ncdh(r(aIhg(,pW((-(,-gNx-âD(w9(x(Nxx(ssD(pw-âDoN"((g(wE ((g(Ã(-a(pD"(E(g(chD-NcpxsâD (,-(gNxNxgNxpg(H"(a-"(s(W(rca (,-(NsD(W)x ah(w9(NsD(]1Ns(pD-NIc0xBf-()"(NsD(pw-NhjN(w9(hâp-wW)pg"(,pW(hj(Wbah"(Ns)g(

 50%|█████████████████████████████████████████▌                                         | 5/10 [00:03<00:03,  1.36it/s]

loss at epoch  5  is  4.02495565228
Wgpw-Nx-dpIgeW)"(W(NsDW(NWpxa(woNxW)jxNxp-"(gNxNxgNxpg"(hâp-wWhNc,pg"(xhâ)g,p-(NsDwW0"(W(ps)-D(ah(c-W-s((- (pwW-INxNxp-(a("DIcpgp)h-âD"(((apgg(,I-âNxp-(wW(pwgN(,I-âNxp-(,g(((9I-âNxp-(Ns(N(W(-g((-(hr©-N(wW(r(aIhg(w9(w-D(wW(WpWh(r(c,(-Ehg(w-Nw(((ch(a("dW-Dc(,-NI,Nxr©a"(chbchgD-Nx-s(gwWh(]ppgN]((ggwp)(Nh (ExNs(NsD(hr©-Nx(a-(woNxW)jxNxp-(-cpHEhW(gDDWg(Nw(W)-zW)jh(((apgg(,I-âNxp-0(a-(wHwhâNxr©(,I-âNxp-(,g(hDNsDc(((apgg(,I-âNxp-(wW(,Ng(-Ds(Nxr©(B,-(g-©â)9xp(xpW(x-g"(r(c,pIgE"(pxaah (((chE(cx(,I-âNxp-"(((bcp9xN(,I-âNxp-"(((INWa)N"(9I-âNxp-"(((9xN-Dgg(,I-âNxp-"(hNpx9"(,-(Es)ps(pxgD(,N(,g(Nw(-D(W(j)W)jh xââ)"(gNxNxgNxpg"(N"-)pxaa"(((apgg(,I-âNxp-(,g(IgD (,pW(b(c(WhNhc(hgNxWxNWp-"((- (NsD(hr©-N(,-(hIhgNxp-(,g(gwWh(,I-âNxp-(w9(NsD(x,99Dch-âD(HDNE©D-(hgNxWxNh ((- (Ncdh(r(aIhg(,pW((-(,-gNx-âD(w9(x(Nxx(ssD(pw-âDoN"((g(wE ((g(Ã(-a(pD"(E(g(chD-NcpxsâD (,-(gNxNxgNxpg(H"(a-"(s(W(rca (,-(NsD(W)x ah(w9(NsD(]1Ns(pD-NIc0xBf-()"(NsD(pw-NhjN(w9(hâp-wW)pg"(,pW(hj(Wbah"(Ns)g(

 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:04<00:03,  1.33it/s]

loss at epoch  6  is  4.02495564191
Wgpw-Nx-dpIgeW)"(W(NsDW(NWpxa(woNxW)jxNxp-"(gNxNxgNxpg"(hâp-wWhNc,pg"(xhâ)g,p-(NsDwW0"(W(ps)-D(ah(c-W-s((- (pwW-INxNxp-(a("DIcpgp)h-âD"(((apgg(,I-âNxp-(wW(pwgN(,I-âNxp-(,g(((9I-âNxp-(Ns(N(W(-g((-(hr©-N(wW(r(aIhg(w9(w-D(wW(WpWh(r(c,(-Ehg(w-Nw(((ch(a("dW-Dc(,-NI,Nxr©a"(chbchgD-Nx-s(gwWh(]ppgN]((ggwp)(Nh (ExNs(NsD(hr©-Nx(a-(woNxW)jxNxp-(-cpHEhW(gDDWg(Nw(W)-zW)jh(((apgg(,I-âNxp-0(a-(wHwhâNxr©(,I-âNxp-(,g(hDNsDc(((apgg(,I-âNxp-(wW(,Ng(-Ds(Nxr©(B,-(g-©â)9xp(xpW(x-g"(r(c,pIgE"(pxaah (((chE(cx(,I-âNxp-"(((bcp9xN(,I-âNxp-"(((INWa)N"(9I-âNxp-"(((9xN-Dgg(,I-âNxp-"(hNpx9"(,-(Es)ps(pxgD(,N(,g(Nw(-D(W(j)W)jh xââ)"(gNxNxgNxpg"(N"-)pxaa"(((apgg(,I-âNxp-(,g(IgD (,pW(b(c(WhNhc(hgNxWxNWp-"((- (NsD(hr©-N(,-(hIhgNxp-(,g(gwWh(,I-âNxp-(w9(NsD(x,99Dch-âD(HDNE©D-(hgNxWxNh ((- (Ncdh(r(aIhg(,pW((-(,-gNx-âD(w9(x(Nxx(ssD(pw-âDoN"((g(wE ((g(Ã(-a(pD"(E(g(chD-NcpxsâD (,-(gNxNxgNxpg(H"(a-"(s(W(rca (,-(NsD(W)x ah(w9(NsD(]1Ns(pD-NIc0xBf-()"(NsD(pw-NhjN(w9(hâp-wW)pg"(,pW(hj(Wbah"(Ns)g(

 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:05<00:02,  1.34it/s]

loss at epoch  7  is  4.02495563225
Wgpw-Nx-dpIgeW)"(W(NsDW(NWpxa(woNxW)jxNxp-"(gNxNxgNxpg"(hâp-wWhNc,pg"(xhâ)g,p-(NsDwW0"(W(ps)-D(ah(c-W-s((- (pwW-INxNxp-(a("DIcpgp)h-âD"(((apgg(,I-âNxp-(wW(pwgN(,I-âNxp-(,g(((9I-âNxp-(Ns(N(W(-g((-(hr©-N(wW(r(aIhg(w9(w-D(wW(WpWh(r(c,(-Ehg(w-Nw(((ch(a("dW-Dc(,-NI,Nxr©a"(chbchgD-Nx-s(gwWh(]ppgN]((ggwp)(Nh (ExNs(NsD(hr©-Nx(a-(woNxW)jxNxp-(-cpHEhW(gDDWg(Nw(W)-zW)jh(((apgg(,I-âNxp-0(a-(wHwhâNxr©(,I-âNxp-(,g(hDNsDc(((apgg(,I-âNxp-(wW(,Ng(-Ds(Nxr©(B,-(g-©â)9xp(xpW(x-g"(r(c,pIgE"(pxaah (((chE(cx(,I-âNxp-"(((bcp9xN(,I-âNxp-"(((INWa)N"(9I-âNxp-"(((9xN-Dgg(,I-âNxp-"(hNpx9"(,-(Es)ps(pxgD(,N(,g(Nw(-D(W(j)W)jh xââ)"(gNxNxgNxpg"(N"-)pxaa"(((apgg(,I-âNxp-(,g(IgD (,pW(b(c(WhNhc(hgNxWxNWp-"((- (NsD(hr©-N(,-(hIhgNxp-(,g(gwWh(,I-âNxp-(w9(NsD(x,99Dch-âD(HDNE©D-(hgNxWxNh ((- (Ncdh(r(aIhg(,pW((-(,-gNx-âD(w9(x(Nxx(ssD(pw-âDoN"((g(wE ((g(Ã(-a(pD"(E(g(chD-NcpxsâD (,-(gNxNxgNxpg(H"(a-"(s(W(rca (,-(NsD(W)x ah(w9(NsD(]1Ns(pD-NIc0xBf-()"(NsD(pw-NhjN(w9(hâp-wW)pg"(,pW(hj(Wbah"(Ns)g(

 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:06<00:01,  1.33it/s]

loss at epoch  8  is  4.02495562317
Wgpw-Nx-dpIgeW)"(W(NsDW(NWpxa(woNxW)jxNxp-"(gNxNxgNxpg"(hâp-wWhNc,pg"(xhâ)g,p-(NsDwW0"(W(ps)-D(ah(c-W-s((- (pwW-INxNxp-(a("DIcpgp)h-âD"(((apgg(,I-âNxp-(wW(pwgN(,I-âNxp-(,g(((9I-âNxp-(Ns(N(W(-g((-(hr©-N(wW(r(aIhg(w9(w-D(wW(WpWh(r(c,(-Ehg(w-Nw(((ch(a("dW-Dc(,-NI,Nxr©a"(chbchgD-Nx-s(gwWh(]ppgN]((ggwp)(Nh (ExNs(NsD(hr©-Nx(a-(woNxW)jxNxp-(-cpHEhW(gDDWg(Nw(W)-zW)jh(((apgg(,I-âNxp-0(a-(wHwhâNxr©(,I-âNxp-(,g(hDNsDc(((apgg(,I-âNxp-(wW(,Ng(-Ds(Nxr©(B,-(g-©â)9xp(xpW(x-g"(r(c,pIgE"(pxaah (((chE(cx(,I-âNxp-"(((bcp9xN(,I-âNxp-"(((INWa)N"(9I-âNxp-"(((9xN-Dgg(,I-âNxp-"(hNpx9"(,-(Es)ps(pxgD(,N(,g(Nw(-D(W(j)W)jh xââ)"(gNxNxgNxpg"(N"-)pxaa"(((apgg(,I-âNxp-(,g(IgD (,pW(b(c(WhNhc(hgNxWxNWp-"((- (NsD(hr©-N(,-(hIhgNxp-(,g(gwWh(,I-âNxp-(w9(NsD(x,99Dch-âD(HDNE©D-(hgNxWxNh ((- (Ncdh(r(aIhg(,pW((-(,-gNx-âD(w9(x(Nxx(ssD(pw-âDoN"((g(wE ((g(Ã(-a(pD"(E(g(chD-NcpxsâD (,-(gNxNxgNxpg(H"(a-"(s(W(rca (,-(NsD(W)x ah(w9(NsD(]1Ns(pD-NIc0xBf-()"(NsD(pw-NhjN(w9(hâp-wW)pg"(,pW(hj(Wbah"(Ns)g(

 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:06<00:00,  1.31it/s]

loss at epoch  9  is  4.02495561459
Wgpw-Nx-dpIgeW)"(W(NsDW(NWpxa(woNxW)jxNxp-"(gNxNxgNxpg"(hâp-wWhNc,pg"(xhâ)g,p-(NsDwW0"(W(ps)-D(ah(c-W-s((- (pwW-INxNxp-(a("DIcpgp)h-âD"(((apgg(,I-âNxp-(wW(pwgN(,I-âNxp-(,g(((9I-âNxp-(Ns(N(W(-g((-(hr©-N(wW(r(aIhg(w9(w-D(wW(WpWh(r(c,(-Ehg(w-Nw(((ch(a("dW-Dc(,-NI,Nxr©a"(chbchgD-Nx-s(gwWh(]ppgN]((ggwp)(Nh (ExNs(NsD(hr©-Nx(a-(woNxW)jxNxp-(-cpHEhW(gDDWg(Nw(W)-zW)jh(((apgg(,I-âNxp-0(a-(wHwhâNxr©(,I-âNxp-(,g(hDNsDc(((apgg(,I-âNxp-(wW(,Ng(-Ds(Nxr©(B,-(g-©â)9xp(xpW(x-g"(r(c,pIgE"(pxaah (((chE(cx(,I-âNxp-"(((bcp9xN(,I-âNxp-"(((INWa)N"(9I-âNxp-"(((9xN-Dgg(,I-âNxp-"(hNpx9"(,-(Es)ps(pxgD(,N(,g(Nw(-D(W(j)W)jh xââ)"(gNxNxgNxpg"(N"-)pxaa"(((apgg(,I-âNxp-(,g(IgD (,pW(b(c(WhNhc(hgNxWxNWp-"((- (NsD(hr©-N(,-(hIhgNxp-(,g(gwWh(,I-âNxp-(w9(NsD(x,99Dch-âD(HDNE©D-(hgNxWxNh ((- (Ncdh(r(aIhg(,pW((-(,-gNx-âD(w9(x(Nxx(ssD(pw-âDoN"((g(wE ((g(Ã(-a(pD"(E(g(chD-NcpxsâD (,-(gNxNxgNxpg(H"(a-"(s(W(rca (,-(NsD(W)x ah(w9(NsD(]1Ns(pD-NIc0xBf-()"(NsD(pw-NhjN(w9(hâp-wW)pg"(,pW(hj(Wbah"(Ns)g(

100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:07<00:00,  1.30it/s]
