In [74]:
from IPython.display import display, HTML

import plotly.offline as py
# import plotly.plotly as py
py.init_notebook_mode(connected=True)
import plotly.graph_objs as go

import numpy as np

import itertools
import colorlover as cl

# An Intro to Neural Networks
---

In [103]:
def nn_config(color, layer_name):
    """Build the shared scatter-trace settings for one layer of nodes.

    Args:
        color: Value stored under ``marker["color"]``.
        layer_name: Value stored under ``name``.

    Returns:
        A dict with ``name``, ``mode`` and ``marker`` entries that can be
        passed to a scatter trace constructor.
    """
    # Every layer uses the same marker size and outline; only the colour varies.
    node_marker = {
        "size": 40,
        "line": {"width": 2},
        "color": color,
    }
    return {
        "name": layer_name,
        "mode": "markers",
        "marker": node_marker,
    }

# Node count of every layer, ordered from the input layer to the output layer.
layr_n = (3, 6, 5, 3)
inpt_n, hid1_n, hid2_n, outp_n = layr_n

# Size of the largest layer; gen_y centres the smaller layers against it.
maxx_n = max(layr_n)

# Column (x position) of each layer: consecutive integers starting at 1.
inpt_c, hid1_c, hid2_c, outp_c = range(1, 1 + len(layr_n))

# True when the input layer holds an odd number of nodes.
# NOTE(review): nothing in this cell reads odd_start - confirm it is still needed.
odd_start = inpt_n % 2 == 1

def gen_x(n, start):
    """Return the x coordinates of one layer: ``start`` repeated per node.

    Args:
        n: Number of nodes; one entry is produced for each step of
            ``np.arange(0, n, 1)``.
        start: The x value shared by every node of the layer.

    Returns:
        A NumPy array containing ``start`` once per node.
    """
    node_count = len(np.arange(0, n, 1))
    return np.array([start] * node_count)

def gen_y(n, max_n=None):
    """Return unit-spaced y coordinates that centre a layer vertically.

    The layer is centred against a column of ``max_n`` nodes: the first
    coordinate is ``(max_n - n) / 2`` and each following one is 1 higher.

    Args:
        n: Number of nodes in the layer.
        max_n: Node count of the tallest layer. When omitted, the
            module-level ``maxx_n`` is used, which matches the original
            single-argument behaviour.

    Returns:
        A NumPy array of ``n`` y coordinates.
    """
    if max_n is None:
        # Backward-compatible default: centre against the notebook-wide maximum.
        max_n = maxx_n
    start = (max_n - n) / 2
    # Offsetting an integer range keeps the length tied to n itself rather
    # than to a floating-point stop bound.
    return start + np.arange(n)

# Qualitative Dark2 palette from colorlover; indices 0-3 colour the four layers.
nn_cls = cl.scales["4"]["qual"]["Dark2"]

# Input layer: node coordinates and marker trace.
inpt_x = gen_x(inpt_n, inpt_c)
inpt_y = gen_y(inpt_n)
inpt_l = go.Scatter(nn_config(nn_cls[0], "input"), x=inpt_x, y=inpt_y)

# First hidden layer.
hid1_x = gen_x(hid1_n, hid1_c)
hid1_y = gen_y(hid1_n)
hid1_l = go.Scatter(nn_config(nn_cls[1], "hidden1"), x=hid1_x, y=hid1_y)

# Second hidden layer.
hid2_x = gen_x(hid2_n, hid2_c)
hid2_y = gen_y(hid2_n)
hid2_l = go.Scatter(nn_config(nn_cls[2], "hidden2"), x=hid2_x, y=hid2_y)

# Output layer.
outp_x = gen_x(outp_n, outp_c)
outp_y = gen_y(outp_n)
outp_l = go.Scatter(nn_config(nn_cls[3], "output"), x=outp_x, y=outp_y)

# Node traces for all layers, ordered input -> output.
nn = [inpt_l, hid1_l, hid2_l, outp_l]

def gen_connect(x1, x2, n1, n2):
    """Pair every coordinate of one layer with every coordinate of the next.

    For each of the first ``n1`` entries of ``x1`` a flat row is built of
    ``[x1[i], x2[j], None]`` triples, one triple per ``j`` in ``range(n2)``.
    The ``None`` after each pair separates consecutive segments within a row.

    Args:
        x1: Indexable coordinates of the source layer.
        x2: Indexable coordinates of the destination layer.
        n1: Number of source entries to use.
        n2: Number of destination entries to use.

    Returns:
        A NumPy array with one flattened row per source entry.
    """
    rows = []
    for src in range(n1):
        segments = [[x1[src], x2[dst], None] for dst in range(n2)]
        rows.append(np.array(segments).flatten())
    return np.array(rows)

def gen_scatter(x, y, colors, n):
    """Create one line trace per row of connection coordinates.

    Args:
        x: Indexable of x-coordinate rows; row ``i`` feeds trace ``i``.
        y: Indexable of y-coordinate rows, parallel to ``x``.
        colors: Indexable of line colours; entry ``i`` colours trace ``i``.
        n: Number of traces to build.

    Returns:
        A list of ``n`` ``go.Scatter`` traces, all hidden from the legend.
    """
    traces = []
    for row in range(n):
        trace = go.Scatter(
            x=x[row],
            y=y[row],
            line=dict(color=colors[row]),
            showlegend=False,
        )
        traces.append(trace)
    return traces

def _link_layers(scale_name, src_x, src_y, src_n, dst_x, dst_y, dst_n):
    """Build the colours, coordinates and line traces joining two layers.

    Returns a ``(colors, xs, ys, traces)`` tuple. The colour scale is looked
    up by the source layer's node count, and gen_scatter picks one colour
    per source node from it.
    """
    colors = cl.scales[str(src_n)]["seq"][scale_name]
    xs = gen_connect(src_x, dst_x, src_n, dst_n)
    ys = gen_connect(src_y, dst_y, src_n, dst_n)
    return colors, xs, ys, gen_scatter(xs, ys, colors, src_n)

## input to hidden1 connection
in_h1_c, in_h1_x, in_h1_y, in_h1 = _link_layers(
    "Blues", inpt_x, inpt_y, inpt_n, hid1_x, hid1_y, hid1_n)

## hidden1 to hidden2 connection
h1_h2_c, h1_h2_x, h1_h2_y, h1_h2 = _link_layers(
    "Reds", hid1_x, hid1_y, hid1_n, hid2_x, hid2_y, hid2_n)

## hidden2 to output connection
h2_ot_c, h2_ot_x, h2_ot_y, h2_ot = _link_layers(
    "Greens", hid2_x, hid2_y, hid2_n, outp_x, outp_y, outp_n)



# layout = dict(
#     xaxis=dict(range=[xm, xM], autorange=False, zeroline=False),
#     yaxis=dict(range=[ym, yM], autorange=False, zeroline=False),
#     title='Running Through a Neural Network', hovermode='closest',
#     updatemenus= [{'type': 'buttons',
#                    'buttons': [{'label': 'Play',
#                                 'method': 'animate',
#                                 'args': [None]}]}]
# )

# Shared axis settings: no tick marks, grid, zero line or tick labels, so
# only the network itself is visible.
axes_tmpl = dict(
    ticks = '',
    showgrid = False,
    zeroline = False,
    # `autotick` is a removed axis attribute that validated graph_objs
    # reject; `tickmode = "auto"` is the supported spelling of the same setting.
    tickmode = "auto",
    showticklabels = False,
)

# Figure-level layout: both axes use the same template and hover is disabled.
layout = go.Layout(
    title = "Stepping through a Neural Network",
    xaxis = axes_tmpl,
    yaxis = axes_tmpl,
    hovermode = False,
)
# frames = [dict(data=[dict(x=[xx[k]],
#                           y=[yy[k]],
#                           mode='markers', 
#                           marker=dict(color='red', size=10)
#                          )
#                     ]) for k in range(N)]    
          
# figure1 = dict(data=data, layout=layout, frames=frames)          
# Connection traces come first and the node markers last - presumably so the
# markers render on top of the lines.
data = list(itertools.chain(in_h1, h1_h2, h2_ot, nn))
figure = go.Figure(data=data, layout=layout)
py.iplot(figure)