# 异或问题无法收敛的例子

In [1]:
import pylab
from matplotlib import gridspec
from sklearn.datasets import make_classification
import numpy as np
from ipywidgets import interact, interactive, fixed
import ipywidgets as widgets
import pickle

# pick the seed for reproducibility - change it to explore the effects of random variations
np.random.seed(1)
import random

In [2]:
def plot_boundary(positive_examples, negative_examples, weights):
    """Plot two sets of 2-D points together with a linear decision boundary.

    The boundary is the line ``w0*x + w1*y + w2 = 0`` where ``w0, w1, w2``
    are ``weights[0]``, ``weights[1]`` and ``weights[2]``.

    Args:
        positive_examples: 2-D array; columns 0 and 1 are plotted as blue
            circles.
        negative_examples: 2-D array; columns 0 and 1 are plotted as red
            circles.
        weights: indexable holding the three line coefficients.

    The figure is displayed immediately via ``pylab.show()``.
    """
    # Fixed coordinate range matching the axis limits set below.
    span = np.array([-6, 6], dtype='float32')

    if not np.isclose(weights[1], 0):
        # General case: solve the line equation for y over the x range.
        x = span
        y = -(weights[0] * x + weights[2]) / weights[1]
    elif not np.isclose(weights[0], 0):
        # w1 ~ 0: the line is vertical, so solve for x over the y range.
        y = span
        x = -(weights[1] * y + weights[2]) / weights[0]
    else:
        # Both coefficients ~ 0: no line is defined; draw the diagonal as a
        # placeholder.
        x = y = span

    pylab.xlim(-6, 6)
    pylab.ylim(-6, 6)
    pylab.plot(positive_examples[:, 0], positive_examples[:, 1], 'bo')
    pylab.plot(negative_examples[:, 0], negative_examples[:, 1], 'ro')
    pylab.plot(x, y, 'g', linewidth=2.0)
    pylab.show()

In [3]:
def train_graph(positive_examples, negative_examples, num_iterations = 100, report_frequency = 20):
    """Train a perceptron by random sampling and record its progress.

    On every iteration one positive and one negative example are drawn at
    random. The weights are increased by the positive example if it scores
    below zero, and decreased by the negative example if it scores zero or
    above. Every ``report_frequency`` iterations (starting with iteration 0)
    a copy of the weights and the current balanced accuracy are recorded.

    Samples are drawn with ``np.random`` so that a preceding
    ``np.random.seed(...)`` call makes the whole run repeatable.

    Args:
        positive_examples: 2-D array, one example per row.
        negative_examples: 2-D array with the same number of columns.
        num_iterations: number of training iterations to run.
        report_frequency: positive interval, in iterations, between
            recorded snapshots.

    Returns:
        A tuple ``(snapshots, correct)``: ``snapshots`` is an array of the
        recorded weight column vectors (one ``(num_dims, 1)`` entry per
        report) and ``correct`` is an array holding, for each report, the
        mean of the fraction of positives scoring ``>= 0`` and the fraction
        of negatives scoring ``< 0``.

    Raises:
        ValueError: if ``report_frequency`` is not positive.
    """
    if report_frequency <= 0:
        raise ValueError("report_frequency must be a positive integer")

    num_dims = positive_examples.shape[1]
    weights = np.zeros((num_dims, 1))  # initialize weights

    pos_count = positive_examples.shape[0]
    neg_count = negative_examples.shape[0]

    snapshots = []
    correct = []

    for i in range(num_iterations):
        # Index with NumPy's global generator (not the stdlib `random`
        # module) so the NumPy seed controls which examples are drawn.
        pos = positive_examples[np.random.randint(pos_count)]
        neg = negative_examples[np.random.randint(neg_count)]

        # Misclassified positive: move the weights towards it.
        if np.dot(pos, weights) < 0:
            weights = weights + pos.reshape(weights.shape)

        # Misclassified negative (evaluated against the possibly updated
        # weights): move the weights away from it.
        if np.dot(neg, weights) >= 0:
            weights = weights - neg.reshape(weights.shape)

        if i % report_frequency == 0:
            pos_out = np.dot(positive_examples, weights)
            neg_out = np.dot(negative_examples, weights)
            pos_correct = (pos_out >= 0).sum() / float(pos_count)
            neg_correct = (neg_out < 0).sum() / float(neg_count)
            # Copy so later updates cannot alias the stored snapshot.
            snapshots.append(np.copy(weights))
            correct.append((pos_correct + neg_correct) / 2.0)

    return np.array(snapshots), np.array(correct)

In [4]:
def plotit(pos_examples, neg_examples, snapshots, correct_xor, step):
    """Show the decision boundary and accuracy curve for one recorded step.

    Args:
        pos_examples: positive examples, forwarded to ``plot_boundary``.
        neg_examples: negative examples, forwarded to ``plot_boundary``.
        snapshots: sequence of recorded weights, indexed by ``step``.
        correct_xor: sequence of recorded accuracies, indexed by ``step``.
        step: index of the snapshot to display.

    The left subplot draws the boundary for ``snapshots[step]``; the right
    subplot draws ``correct_xor`` against its own index (0, 1, 2, ...) and
    marks the selected ``step`` with a blue dot.

    NOTE(review): ``plot_boundary`` calls ``pylab.show()`` itself, so with
    some backends the left panel may be rendered before the right one is
    drawn - confirm the output looks as intended.
    """
    figure = pylab.figure(figsize=(10, 4))

    # Left panel: data points and the boundary at the chosen snapshot.
    figure.add_subplot(1, 2, 1)
    plot_boundary(pos_examples, neg_examples, snapshots[step])

    # Right panel: accuracy history with the chosen snapshot highlighted.
    figure.add_subplot(1, 2, 2)
    report_indices = np.arange(len(correct_xor))
    pylab.plot(report_indices, correct_xor)
    pylab.ylabel('Accuracy')
    pylab.xlabel('Iteration')
    selected_accuracy = correct_xor[step]
    pylab.plot(step, selected_accuracy, "bo")
    pylab.show()

In [5]:
# XOR data set. The third column is a constant 1 in every row
# (presumably the bias input for the perceptron).
# Positive class: exactly one of the first two inputs is 1.
pos_examples_xor = np.array([
    [1, 0, 1],
    [0, 1, 1],
])
# Negative class: the first two inputs are equal.
neg_examples_xor = np.array([
    [1, 1, 1],
    [0, 0, 1],
])

# Train for 1000 iterations, keeping the recorded weights and accuracies.
snapshots_xor, correct_xor = train_graph(
    pos_examples_xor, neg_examples_xor, num_iterations=1000
)
def pl2(step):
    """Plot the XOR training state recorded at snapshot index ``step``."""
    plotit(
        pos_examples=pos_examples_xor,
        neg_examples=neg_examples_xor,
        snapshots=snapshots_xor,
        correct_xor=correct_xor,
        step=step,
    )

In [6]:
# Interactive slider over every recorded snapshot index (0 .. last).
interact(
    pl2,
    step=widgets.IntSlider(value=0, min=0, max=len(snapshots_xor) - 1),
)

interactive(children=(IntSlider(value=0, description='step', max=49), Output()), _dom_classes=('widget-interac…

<function __main__.pl2(step)>