In [181]:
import numpy as np
import pandas as pd
import time

In [201]:
def load_data(file='train.dat'):
    """Load a whitespace-separated perceptron data file.

    Each non-blank line holds the feature values followed by the label as its
    last token (e.g. ``x1 x2 x3 x4<TAB>y``). A constant bias feature ``1.0``
    is prepended to every row.

    Parameters
    ----------
    file : str or path-like
        Path of the data file (default ``'train.dat'``).

    Returns
    -------
    X : ndarray, shape (n, d + 1)
        Features with the bias column ``X[:, 0] == 1``.
    y : ndarray, shape (n,)
        Labels taken from the last token of each line, as floats.

    Raises
    ------
    ValueError
        If a token is not numeric or the rows have different lengths.
    """
    rows = []
    with open(file, 'r') as fin:
        for line in fin:
            # split() with no argument treats any run of spaces/tabs as one
            # separator and drops the trailing newline, so doubled spaces no
            # longer yield empty tokens that float() rejects.
            tokens = line.split()
            if tokens:  # tolerate blank lines, e.g. a trailing empty line
                rows.append([1.0] + [float(tok) for tok in tokens])
    # Build the array once instead of np.append-ing row by row, which
    # re-copies the whole array on every line (O(n^2)). An empty file keeps
    # the historical (0, 5) / (0,) shapes.
    arr = np.array(rows, dtype=float) if rows else np.zeros((0, 6))
    return (arr[:, :-1], arr[:, -1])

In [186]:
def PLA(X, y, max_updates=None):
    """Run the cyclic (naive) Perceptron Learning Algorithm.

    Samples are visited in index order 0, 1, ..., n-1, 0, ... and every
    mistake applies ``w += y[i] * X[i]``. The run halts once ``n``
    consecutive samples are classified correctly. The number of updates is
    printed.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Samples, one per row.
    y : ndarray, shape (n,)
        Labels, expected to be +1 / -1.
    max_updates : int or None, optional
        Upper bound on the number of weight updates. ``None`` (default) runs
        until convergence, as before. That never terminates on data that is
        not linearly separable.

    Returns
    -------
    w : ndarray, shape (d,)
        The learned weight vector (all zeros for empty input).
    """
    n, d = X.shape
    w = np.zeros(d)
    consecutive_ok = 0
    i = 0
    updates = 0
    while consecutive_ok < n and (max_updates is None or updates < max_updates):
        # A score of exactly 0 is treated as -1, so the all-zero initial w
        # counts as a mistake on +1 samples.
        pred_sign = 1 if w @ X[i] > 0 else -1
        if pred_sign != y[i]:
            w += y[i] * X[i]
            consecutive_ok = 0
            updates += 1
        else:
            consecutive_ok += 1
        i = (i + 1) % n
    print('Halted in', updates, 'iterations.')
    return w

In [187]:
# Training set from load_data's default file 'train.dat'; X carries a leading bias column of ones.
X, y = load_data()

In [188]:
w = PLA(X, y)
# Count misclassified training points using PLA's sign convention (a score of
# exactly 0 counts as -1). Summing sign(X @ w) - y instead would let +2 and -2
# mismatches cancel out and would treat zero scores as partial errors.
print('Error: ', np.sum(np.where(X @ w > 0, 1, -1) != y))

Halted in 45 iterations.
Error:  0.0


In [189]:
def PLA_rnd(X, y, rng=None, max_updates=None):
    """Run PLA over one fixed random cyclic order and return the update count.

    A single random permutation of the sample indices is drawn up front. It is
    cycled through, applying ``w += y[i] * X[i]`` on every mistake, until ``n``
    consecutive samples are classified correctly.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Samples, one per row.
    y : ndarray, shape (n,)
        Labels, expected to be +1 / -1.
    rng : None, int or numpy.random.Generator, optional
        Passed to ``np.random.default_rng``. ``None`` (default) uses fresh OS
        entropy; pass a seed or Generator for reproducible runs. NumPy's
        global RNG is no longer reseeded from the wall clock. That reseeding
        clobbered global state, and back-to-back calls could receive the same
        seed (hence the same order) on clocks with coarse resolution.
    max_updates : int or None, optional
        Upper bound on the number of updates. ``None`` runs until convergence.

    Returns
    -------
    int
        Number of weight updates performed before halting.
    """
    rng = np.random.default_rng(rng)
    n, d = X.shape
    w = np.zeros(d)
    order = rng.permutation(n)
    consecutive_ok = 0
    pos = 0
    updates = 0
    while consecutive_ok < n and (max_updates is None or updates < max_updates):
        idx = order[pos]
        # A score of exactly 0 is treated as -1.
        pred_sign = 1 if w @ X[idx] > 0 else -1
        if pred_sign != y[idx]:
            w += y[idx] * X[idx]
            consecutive_ok = 0
            updates += 1
        else:
            consecutive_ok += 1
        pos = (pos + 1) % n
    return updates

In [190]:
def test_rnd(X, y, n_trials=2000):
    """Print the average number of PLA_rnd updates over ``n_trials`` runs.

    Each run calls ``PLA_rnd(X, y)``, which draws its own random visiting
    order.

    Parameters
    ----------
    X, y : training samples and +1/-1 labels, forwarded to PLA_rnd.
    n_trials : int
        Number of independent runs to average (default 2000, as before).

    Raises
    ------
    ValueError
        If ``n_trials`` is not positive, since the average would be undefined.
    """
    if n_trials <= 0:
        raise ValueError('n_trials must be positive')
    # One constant drives both the loop and the divisor, so they cannot drift apart.
    total = sum(PLA_rnd(X, y) for _ in range(n_trials))
    print('Halted in (on average)', total / n_trials, 'iterations.')

In [191]:
# Average the PLA_rnd update count over repeated random-order runs on the training set.
test_rnd(X, y)

Halted in (on average) 39.8175 iterations.


In [192]:
def PLA_rnd_2(X, y, eta=0.5, rng=None, max_updates=None):
    """Random-cycle PLA with a scaled update ``w += eta * y[i] * X[i]``.

    A single random permutation of the sample indices is drawn up front. It is
    cycled through until ``n`` consecutive samples are classified correctly.

    Because ``w`` starts at zero, every ``w`` reached is exactly ``eta`` times
    the ``eta = 1`` weight vector for the same visiting order. For ``eta > 0``
    the predictions, and hence the returned update count, are therefore the
    same as with the unscaled rule; ``eta`` only rescales ``w``.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Samples, one per row.
    y : ndarray, shape (n,)
        Labels, expected to be +1 / -1.
    eta : float, optional
        Update step size (default 0.5, as before).
    rng : None, int or numpy.random.Generator, optional
        Passed to ``np.random.default_rng``. ``None`` (default) uses fresh OS
        entropy. NumPy's global RNG is no longer reseeded from the wall clock,
        which clobbered global state and could repeat seeds across
        back-to-back calls.
    max_updates : int or None, optional
        Upper bound on the number of updates. ``None`` runs until convergence.

    Returns
    -------
    int
        Number of weight updates performed before halting.
    """
    rng = np.random.default_rng(rng)
    n, d = X.shape
    w = np.zeros(d)
    order = rng.permutation(n)
    consecutive_ok = 0
    pos = 0
    updates = 0
    while consecutive_ok < n and (max_updates is None or updates < max_updates):
        idx = order[pos]
        # A score of exactly 0 is treated as -1.
        pred_sign = 1 if w @ X[idx] > 0 else -1
        if pred_sign != y[idx]:
            w += eta * y[idx] * X[idx]
            consecutive_ok = 0
            updates += 1
        else:
            consecutive_ok += 1
        pos = (pos + 1) % n
    return updates

In [193]:
def test_rnd_2(X, y, n_trials=2000):
    """Print the average number of PLA_rnd_2 updates over ``n_trials`` runs.

    Each run calls ``PLA_rnd_2(X, y)``, which draws its own random visiting
    order.

    Parameters
    ----------
    X, y : training samples and +1/-1 labels, forwarded to PLA_rnd_2.
    n_trials : int
        Number of independent runs to average (default 2000, as before).

    Raises
    ------
    ValueError
        If ``n_trials`` is not positive, since the average would be undefined.
    """
    if n_trials <= 0:
        raise ValueError('n_trials must be positive')
    # One constant drives both the loop and the divisor, so they cannot drift apart.
    total = sum(PLA_rnd_2(X, y) for _ in range(n_trials))
    print('Halted in (on average)', total / n_trials, 'iterations.')

In [194]:
# Average the PLA_rnd_2 (0.5-scaled updates) update count over repeated random-order runs.
test_rnd_2(X, y)

Halted in (on average) 39.671 iterations.


In [281]:
def pocket(X, y, max_updates=100, rng=None):
    """Pocket PLA: random-sample perceptron updates, keeping the best weights seen.

    Each step draws a sample index uniformly at random (with replacement). On
    a mistake, ``w += y[i] * X[i]`` is applied. If the updated ``w`` makes
    fewer training errors than the pocketed ``w_best``, it replaces
    ``w_best``.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Samples, one per row.
    y : ndarray, shape (n,)
        Labels, expected to be +1 / -1.
    max_updates : int, optional
        Number of weight updates to perform (default 100, as before).
    rng : None, int or numpy.random.Generator, optional
        Passed to ``np.random.default_rng``. ``None`` (default) uses fresh OS
        entropy; pass a seed for reproducible runs. NumPy's global RNG is no
        longer reseeded from the wall clock.

    Returns
    -------
    w_best : ndarray, shape (d,)
        The pocketed weights. They are all zeros if no update ever improved on
        the zero vector.
    """
    rng = np.random.default_rng(rng)
    n, d = X.shape

    def _n_errors(weights):
        # Training 0/1 error count; a score of exactly 0 counts as -1.
        return np.sum(np.where(X @ weights > 0, 1, -1) != y)

    w = np.zeros(d)
    w_best = w.copy()
    best_err = _n_errors(w_best)  # cached rather than re-scored on every update
    updates = 0
    # Stop once w_best is perfect: it can no longer improve. At that moment
    # w == w_best also makes no mistakes, so without this check no further
    # mistake would ever be drawn and the loop would spin forever.
    while updates < max_updates and best_err > 0:
        i = rng.integers(n)
        pred_sign = 1 if w @ X[i] > 0 else -1
        if pred_sign == y[i]:
            continue
        w += y[i] * X[i]
        updates += 1
        w_err = _n_errors(w)
        if w_err < best_err:
            w_best = w.copy()
            best_err = w_err
    return w_best

In [282]:
def test_pocket(n_trials=2000, train_file='hw1_18_train.dat', test_file='hw1_18_test.dat'):
    """Print the pocket algorithm's test error averaged over ``n_trials`` runs.

    Loads the training and test sets once, trains ``pocket(X_train, y_train)``
    ``n_trials`` times, and averages the 0/1 error rate on the test set.

    Parameters
    ----------
    n_trials : int
        Number of independent pocket runs (default 2000, as before).
    train_file, test_file : str
        Data files passed to ``load_data``.

    Raises
    ------
    ValueError
        If ``n_trials`` is not positive, since the average would be undefined.
    """
    if n_trials <= 0:
        raise ValueError('n_trials must be positive')
    X_train, y_train = load_data(train_file)
    X_test, y_test = load_data(test_file)
    test_errors = np.empty(n_trials)
    for rep in range(n_trials):
        w = pocket(X_train, y_train)
        # A score of exactly 0 counts as -1, the same convention used in training.
        y_pred = np.where(X_test @ w > 0, 1, -1)
        test_errors[rep] = np.mean(y_pred != y_test)
    print(test_errors.mean())

In [283]:
# Average pocket test error on 'hw1_18_test.dat' over repeated training runs on 'hw1_18_train.dat'.
test_pocket()

0.11634500000000013


In [259]:
# X_train / y_train exist only inside test_pocket, so load them here explicitly
# instead of relying on leftover kernel state (NameError on a fresh kernel).
X_train, y_train = load_data('hw1_18_train.dat')
pocket(X_train, y_train)

array([ 1.        , -1.175936  , -2.824038  , -1.140554  ,  2.30257638])