In [51]:
import csv
import math
import random
import matplotlib.pyplot as plt
%matplotlib inline


In [9]:
def read_csv(name):
    with open(name) as fin:
        csv_reader = csv.reader(fin)
        classes = [row for row in csv_reader]
    return classes

In [103]:
def dot(x, y):
    pairs = zip(x, y)
    return sum([x * y for (x, y) in pairs])

def get_batch(x, y, size=1):
    indexes = list(range(len(x)))
    random.shuffle(indexes)
    to_take = indexes[:size]
    x_return = []
    y_return = []
    for i in to_take:
        x_return.append(x[i])
        y_return.append(y[i])
    return (x_return, y_return)

def random_floats(num, a=0, b=1):
    return [random.uniform(a, b) for n in range(num)]


def activation(x, threshold=0):
    return 1 if x > threshold else -1


def minus_list(a, b):
    return [x - y for (x, y) in zip(a, b)]

def loss(minus):
    return sum([x ** 2 for x in minus])

def transpose(m):
    return [[x for (x, _) in m], [y for (_, y) in m]]

def div(vec, d):
    return [x / d for x in vec]

def add_vecs(vec1, vec2):
    return [v1 + v2 for (v1, v2) in zip(vec1, vec2)]

def multiply_scal(vec, a):
    return [v * a for v in vec]


In [140]:
def stochastic_descent(X, y, epochs, rate, weights, batch=1):
    for epoch in range(epochs):
        batch_x, batch_y = get_batch(X, y, batch)
        dotted = [dot(x, weights) for x in batch_x]
        predicted = [activation(d) for d in dotted]
        error = minus_list(predicted, batch_y)
        l = loss(error)
        grad = div([dot(feature, error) for feature in transpose(batch_x)], len(batch_x))
        weights = add_vecs(weights, multiply_scal(grad, -rate))
    dotted = [dot(x, weights) for x in X]
    return [activation(d) for d in dotted]

In [141]:
labels = [int(x[0]) for x in read_csv("./linsep-trainclass.csv")]
data = [(float(x[0]), float(x[1])) for x in read_csv("./linsep-traindata.csv")]

In [146]:
pred = stochastic_descent(data, labels, 5, 0.1, random_floats(2), 10)

In [147]:
print(list(zip(labels, pred)))

[(-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (-1, -1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]
