In [1]:
from sklearn.datasets import load_svmlight_file
import numpy
import random
import matplotlib.pyplot as plt

In [2]:
# Read the training ("a9a") and validation ("a9a.t") splits from svmlight-format
# files. n_features is pinned to 123 so both matrices end up with the same width.
X_train, y_train = load_svmlight_file("a9a", n_features=123)
X_val, y_val = load_svmlight_file("a9a.t", n_features=123)

# The loader hands back sparse matrices; switch to dense ndarrays for the
# plain-numpy arithmetic used in the rest of the notebook.
X_train, X_val = X_train.toarray(), X_val.toarray()

In [3]:
# Record the dataset dimensions *before* the bias column is appended, so
# n_features_* counts only the original input features.
n_samples_train = X_train.shape[0]
n_features_train = X_train.shape[1]
n_samples_val = X_val.shape[0]
n_features_val = X_val.shape[1]

# Append a constant-1 column to each design matrix; the matching (last) weight
# then plays the role of the intercept. column_stack turns the 1-D vector of
# ones into a single column.
X_train = numpy.column_stack((X_train, numpy.ones(n_samples_train)))
X_val = numpy.column_stack((X_val, numpy.ones(n_samples_val)))

# Labels become (n_samples, 1) column vectors so they broadcast against X @ w.
y_train = y_train.reshape(-1, 1)
y_val = y_val.reshape(-1, 1)

In [4]:
# Seed both random number generators so that Restart Kernel -> Run All
# reproduces the same initial weights (numpy.random) and the same mini-batch
# draws (random).
SEED = 0
random.seed(SEED)
numpy.random.seed(SEED)

# Optimisation hyper-parameters.
max_epoch = 300        # number of update steps (one mini-batch per step)
learning_rate = 0.005  # step size of each update
batch_size = 200       # samples per mini-batch
C = 0.001              # weight of the hinge-loss term in the objective

# Per-step loss histories, filled in by the training loop and plotted afterwards.
losses_train = []
losses_val = []

# Weight column vector: one entry per input feature plus one for the bias
# column, initialised uniformly in [0, 1).
# A zero start is a valid alternative: numpy.zeros((n_features_train + 1, 1)).
w = numpy.random.random((n_features_train + 1, 1))

In [5]:
def _per_sample_objective(X, y, w, C):
    """Return the soft-margin SVM objective on (X, y) divided by the sample count.

    The value is (0.5 * ||w||^2 + C * sum(max(0, 1 - y * Xw))) / y.size.
    The 0.5 factor makes it the objective whose sub-gradient,
    w - C * X^T (y * [1 - y * Xw > 0]), is the direction used by the update
    step in the loop below.
    """
    margins = 1 - y * numpy.dot(X, w)
    hinge_total = numpy.sum(numpy.maximum(margins, 0))
    return (0.5 * numpy.sum(w * w) + C * hinge_total) / y.size


# The validation set may be smaller than a training batch; never ask for more
# rows than it has.
val_batch_size = min(batch_size, n_samples_val)

# NOTE: each "epoch" is a single mini-batch update, not a full pass over the data.
for epoch in range(max_epoch):
    # Draw a contiguous training mini-batch. The "+ 1" makes the final possible
    # start position (and therefore the last sample) selectable.
    start = random.randrange(0, n_samples_train - batch_size + 1)
    X_batch = X_train[start:start + batch_size, :]
    y_batch = y_train[start:start + batch_size, :]

    # Sub-gradient step: only samples violating the margin (h > 0) contribute
    # to the hinge term; the others are zeroed out in y_hat.
    h = 1 - y_batch * numpy.dot(X_batch, w)
    y_hat = numpy.where(h > 0, y_batch, 0)
    w -= learning_rate * (w - C * numpy.dot(X_batch.transpose(), y_hat))

    # Training loss on the batch just used, measured after the update.
    losses_train.append(_per_sample_objective(X_batch, y_batch, w, C))

    # Validation loss on a batch drawn from the *validation* index range.
    # Reusing the training indices here can run past the end of the validation
    # set, producing short or empty slices and a 0/0 (NaN) loss.
    val_start = random.randrange(0, n_samples_val - val_batch_size + 1)
    X_val_batch = X_val[val_start:val_start + val_batch_size, :]
    y_val_batch = y_val[val_start:val_start + val_batch_size, :]
    losses_val.append(_per_sample_objective(X_val_batch, y_val_batch, w, C))



  from ipykernel import kernelapp as app


In [6]:
# Draw the recorded training and validation loss histories on a single set of
# axes, one point per loop iteration.
fig, ax = plt.subplots(figsize=(18, 7))
ax.plot(losses_train, color="r", label="train")
ax.plot(losses_val, color="b", label="validation")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.set_title("The graph of loss value varing with the number of iterations")
ax.legend()
plt.show()