In [188]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Image
from matplotlib.animation import FuncAnimation
from matplotlib.artist import Artist

import utils.matplotlib_init

In [194]:
# Toy 2-D data set: the first five points are labelled +1, the last four -1.
x = np.array([[8, 5], [10, 7], [7, 10], [7, 9], [8, 8], [12, 4], [15, 2], [16, 3], [13, 3]])
y = np.array([1, 1, 1, 1, 1, -1, -1, -1, -1])

SEED = 0  # seeds the initial weights so Restart & Run All reproduces the same GIF
rng = np.random.default_rng(SEED)

fig, ax = plt.subplots()

ax.set_title('Perceptron finds separating line')
ax.set_xlim(0, 20)
ax.set_ylim(0, 11)
ax.set_xticks(np.arange(0, 20))
ax.set_yticks(np.arange(0, 11))
ax.tick_params(axis='both', which='major', labelsize=14)
ax.grid(True)

# Class markers: "_" for y == -1 in C0, "+" for y == +1 in C1 (same colours as the regions below).
ax.plot(x[y == -1, 0], x[y == -1, 1], '_', c='C0', ms=15, mew=5)
ax.plot(x[y == 1, 0], x[y == 1, 1], '+', c='C1', ms=15, mew=5)

# The model has no bias term: the decision boundary w . p = 0 passes through the
# origin, i.e. it is the line y = -w[0] * x / w[1].
w = rng.standard_normal(2)
xx = np.linspace(0, 20, 10)
yy = -w[0] * xx / w[1]
# NOTE(review): displayed as "MAE", but this is the *sum* of |sign(x . w) - y|
# (2 per misclassified point, 1 per point exactly on the boundary), not a mean.
mae = abs(np.sign(np.dot(x, w)) - y).sum()

# Dense grid over the plot area, used to shade the two predicted-class regions.
h = 0.1
x_grid, y_grid = np.meshgrid(np.arange(0, 21, h), np.arange(0, 12, h))
z_grid = np.sign(w[0] * x_grid + w[1] * y_grid)

# Explicit levels give exactly two bands (-1..0 -> C0, 0..1 -> C1). With automatic
# levels, the colour given to each class depended on the value range of z_grid.
contour = ax.contourf(x_grid, y_grid, z_grid, levels=[-1, 0, 1], colors=['C0', 'C1'], alpha=0.5)
line, = ax.plot(xx, yy, linestyle='--', color='#360C90', linewidth=1.5)
mae_label = ax.text(0.05, 0.92, f'step=0\nMAE={mae}', fontsize=24, transform=ax.transAxes,
                    verticalalignment='top',
                    bbox={'facecolor': '#D3C8FF', 'alpha': 0.5, 'edgecolor': '#AC8AF5', 'pad': 6})

lr = 0.0005  # learning rate of the batch perceptron weight update

def _contour_artists(cs):
    """Return the artist(s) that draw the filled ContourSet ``cs``.

    From Matplotlib 3.8 a ContourSet is itself a single Collection (an Artist)
    and ``ContourSet.collections`` is deprecated (removed in 3.10); older
    versions only expose one collection per filled band via ``.collections``.
    """
    if isinstance(cs, Artist):
        return [cs]
    return list(cs.collections)


def animate(i):
    """Apply one batch perceptron update and redraw the figure for frame ``i``.

    Reads ``x``, ``y``, ``lr``, ``xx``, ``x_grid``, ``y_grid``, ``ax``, ``line``
    and ``mae_label`` from the notebook namespace. It updates the global weight
    vector ``w`` in place and replaces the global ``contour`` on every call.

    Parameters
    ----------
    i : int
        Frame index passed by ``FuncAnimation``; the label shows ``step=i + 1``.

    Returns
    -------
    list
        Every artist modified in this frame, as required with ``blit=True``.
    """
    global w, contour

    # Full-batch perceptron rule: y - sign(x . w) is 0 for correctly classified
    # points, so only misclassified (or on-boundary) points move the weights.
    predicted = np.sign(np.dot(x, w))
    w += lr * x.T @ (y - predicted)

    yy = -w[0] * xx / w[1]
    z_grid = np.sign(w[0] * x_grid + w[1] * y_grid)
    # NOTE(review): shown as "MAE", but this is the sum, not the mean, of |sign(x . w) - y|.
    mae = abs(np.sign(np.dot(x, w)) - y).sum()

    line.set_data(xx, yy)

    # Rebuild the shaded regions from scratch: remove the old ContourSet, draw a new one.
    for artist in _contour_artists(contour):
        artist.remove()
    # Explicit levels keep predicted -1 in C0 and +1 in C1 regardless of z_grid's range.
    contour = ax.contourf(x_grid, y_grid, z_grid, levels=[-1, 0, 1],
                          colors=['C0', 'C1'], alpha=0.5)

    mae_label.set_text(f'step={i + 1}\nMAE={mae}')

    return [line, mae_label, *_contour_artists(contour)]

fig.tight_layout(pad=1.5)

GIF_PATH = 'assets/j.gif'
N_FRAMES = 30  # one call to animate (one weight update) per frame


def _init_frame():
    """Initial-frame callback for FuncAnimation.

    Without an ``init_func``, FuncAnimation draws frame 0 once during its initial
    draw (which ``Animation.save`` does not record) and again as the first saved
    frame. ``animate(0)`` then runs twice, and every GIF frame is one weight update
    ahead of its ``step=`` label. The figure already shows the step-0 state, so
    this only returns artists that ``animate`` updates (``blit=True`` requires a
    sequence of artists here).
    """
    return [line, mae_label]


ani = FuncAnimation(fig, animate, frames=N_FRAMES, init_func=_init_frame, blit=True)
plt.close(fig)  # keep the static figure out of the cell output; only the GIF is shown

Path(GIF_PATH).parent.mkdir(parents=True, exist_ok=True)  # saving fails if assets/ is missing
# Pillow is a required matplotlib dependency; the 'imagemagick' writer needs an external binary.
ani.save(GIF_PATH, writer='pillow')

Image(url=GIF_PATH)