In [1]:
import math
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from graphviz import Digraph
import random
from Value import Value
from Network import mlp

In [2]:
def trace(root):
    """Collect every node and edge of the graph reachable from ``root``.

    The graph is walked by following each node's ``_prev`` collection.
    An explicit stack is used instead of recursion so that very deep
    graphs (for example a long chain of operations) cannot exceed
    Python's recursion limit.

    Args:
        root: Node to start from. It, and every node reachable through
            ``_prev``, must be hashable and expose an iterable ``_prev``.

    Returns:
        A ``(nodes, edges)`` tuple of sets. ``nodes`` holds every visited
        node; each entry of ``edges`` is a ``(child, parent)`` pair where
        ``child`` was found in ``parent._prev``.
    """
    nodes, edges = set(), set()
    pending = [root]
    while pending:
        n = pending.pop()
        if n in nodes:
            # Already expanded via another path; its edges are recorded.
            continue
        nodes.add(n)
        for child in n._prev:
            edges.add((child, n))
            pending.append(child)
    return nodes, edges

def draw_dot(root, format='svg', rankdir='LR'):
    """Build a graphviz ``Digraph`` of the computation graph ending at ``root``.

    Every traced node becomes a record-shaped box showing its label, data
    and grad. A node with a truthy ``_op`` additionally gets a separate
    operation node that points at it, and the node's incoming edges are
    attached to that operation node rather than to the value box.

    Args:
        root: Output node of the graph to draw.
        format: Output format handed to ``Digraph``.
        rankdir: Graph layout direction; must be 'LR' or 'TB'.

    Returns:
        The populated ``Digraph``.
    """
    assert rankdir in ('LR', 'TB')
    graph_nodes, graph_edges = trace(root)
    dot = Digraph(format=format, graph_attr={'rankdir': rankdir})

    for node in graph_nodes:
        node_id = str(id(node))
        record = "{ %s | data %.4f | grad %.4f}" % (node.label, node.data, node.grad)
        dot.node(name=node_id, label=record, shape='record', width='0.5', height='0.5')
        if not node._op:
            # Nodes without an operation have no operation node to add.
            continue
        op_id = node_id + node._op
        dot.node(name=op_id, label=node._op, width='0.5', height='0.5')
        dot.edge(op_id, node_id)

    for src, dst in graph_edges:
        # Target name is the destination's id followed by its op string.
        dot.edge(str(id(src)), str(id(dst)) + dst._op)

    return dot

In [3]:
# Network built from the project's `mlp` with arguments 3 and [4, 4, 1]
# (presumably input width and per-layer sizes; confirm against Network.mlp).
model = mlp(3, [4, 4, 1])

# Four training inputs with three features each.
xs = [[2, 3, -1], [3, -1, 0.5], [0.5, 1, 1], [1, 1, -1]]

# One target per row of xs, in the same order.
ys = [1, -1, -1, 1]


In [4]:
for k in range(10):
    # Forward pass: run the model on every training input.
    ypred = list(map(model, xs))
    # Squared-error loss. The original note called this mean squared error,
    # but the code sums the squared errors and does not average them.
    loss = sum([(pred - target) ** 2 for pred, target in zip(ypred, ys)])

    # Backward pass: reset the gradients, then backpropagate from the loss.
    # NOTE(review): the method is spelled `backword` here; presumably that
    # matches the name defined on Value. Confirm before renaming.
    model.zero_grad()
    loss.backword()

    # Gradient-descent update of every parameter.
    learning_rate = 0.1
    for p in model.parameters():
        p.data -= learning_rate * p.grad
    print(loss.data)

2.86525826245812
0.6007808071949298
0.4061887356363051
0.20380593420990561
0.09755716977997432
0.0792666885318721
0.06675717164351228
0.05745657198956426
0.05027787024618666
0.04458386221966808


In [6]:
# Evaluate the model on the training inputs and display the predictions.
ypred = list(map(model, xs))
ypred

[Value(data=0.927745757482087),
 Value(data=-0.914295254200722),
 Value(data=-0.8765420408559343),
 Value(data=0.8897266950213907)]

In [None]:
# Small check of the power operation: c = a**3.
a = Value(2.0, label='a')
b = Value(3.0, label='b')  # not used by the expression in this cell
c = a ** 3

# Backpropagate from c, then render its graph as the cell output.
c.backword()
draw_dot(c)

In [None]:
# Inputs of a two-input neuron.
x1 = Value(2.0, label='x1')
x2 = Value(0.0, label='x2')

# Weights.
w1 = Value(-3.0, label='w1')
w2 = Value(1.0, label='w2')

# Bias.
b = Value(6.8813735870195432, label='b')

# Weighted inputs, labelled after creation so the drawing is readable.
x1w1 = x1 * w1
x1w1.label = 'x1*w1'
x2w2 = x2 * w2
x2w2.label = 'x2*w2'

# Sum of the weighted inputs.
x1w1x2w2 = x1w1 + x2w2
x1w1x2w2.label = 'x1*w1+x2*w2'

# Pre-activation: weighted sum plus bias.
n = x1w1x2w2 + b
n.label = 'n'

# Output via the built-in tanh.
o = n.tanh()
o.label = 'o'

# Backpropagate from the output, then render the graph as the cell output.
o.backword()
draw_dot(o)

In [None]:
# Inputs of a two-input neuron.
x1 = Value(2.0, label='x1')
x2 = Value(0.0, label='x2')

# Weights.
w1 = Value(-3.0, label='w1')
w2 = Value(1.0, label='w2')

# Bias.
b = Value(6.8813735870195432, label='b')

# Weighted inputs, labelled after creation so the drawing is readable.
x1w1 = x1 * w1
x1w1.label = 'x1*w1'
x2w2 = x2 * w2
x2w2.label = 'x2*w2'

# Sum of the weighted inputs.
x1w1x2w2 = x1w1 + x2w2
x1w1x2w2.label = 'x1*w1+x2*w2'

# Pre-activation: weighted sum plus bias.
n = x1w1x2w2 + b
n.label = 'n'

# tanh written out with exp: tanh(n) = (exp(2n) - 1) / (exp(2n) + 1).
e = (2 * n).exp()
e.label = 'e'
o = (e - 1) / (e + 1)
o.label = 'o'

# Backpropagate from the output, then render the graph as the cell output.
o.backword()
draw_dot(o)