In [None]:
from data_fns import load_fashion_mnist
from estimator import RFClassifier, classical_weights, V1_inspired_weights, relu, parallelized_clf
from sklearn.linear_model import SGDClassifier
from sklearn.svm import LinearSVC
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg as la
import pickle

In [2]:
import dask
from dask.distributed import Client
# Spin up a local Dask cluster with 6 workers of 6 threads each; the
# experiments below are dispatched through `parallelized_clf`.
client = Client(n_workers=6, threads_per_worker=6)
# Leave the client as the final expression so the notebook renders its summary.
client

0,1
Client  Scheduler: tcp://127.0.0.1:39105  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 6  Cores: 36  Memory: 67.32 GB


## Find the right params via a parameter sweep with SGD

In [None]:
# Hyper-parameter sweep for the V1-inspired weights.
# For every (t, l) pair the V1-inspired random-feature classifier is fit at
# each hidden-layer width and scored on a held-out validation split. One
# pickle per (t, l) pair is written, holding the V1 scores alongside the
# classical random-feature baseline.
train, train_labels, _, _ = load_fashion_mnist('data/fashion_mnist/')
# Seed the stratified split so the sweep is reproducible from a fresh kernel
# (42 matches the seed used for the few-shot split later in this notebook).
X_train, X_val, y_train, y_val = train_test_split(train, train_labels, train_size=0.85,
                                                  stratify=train_labels, random_state=42)

n_features = np.arange(1, 1050, 50)
t_list = np.arange(0.1, 10.1, 0.5)
l_list = np.arange(0.1, 10.1, 0.5)
sgd = SGDClassifier(loss="squared_hinge", alpha=1, max_iter=300, tol=1e-5, shuffle=True, n_jobs=5,
                    learning_rate="optimal", early_stopping=True, validation_fraction=0.1, n_iter_no_change=20)
# Bias: mean over rows of X_train of (row L2 norm / sqrt(number of columns)).
b = np.mean(la.norm(X_train, axis=1) / np.sqrt(X_train.shape[1]))

# classical random features
# The baseline's parameters do not involve t or l, so it is evaluated once per
# width here instead of once per (t, l) pair, and copied into every results file.
# NOTE(review): only the last two values returned by parallelized_clf are kept;
# they are stored as the mean / std error -- confirm against estimator.py.
weights_classical = {'weight_fun': classical_weights}
classical_avg_err = np.zeros(len(n_features))
classical_std_err = np.zeros(len(n_features))
for i, n in enumerate(n_features):
    params_classical = {'width': n, **weights_classical, 'bias': b, 'nonlinearity': relu, 'clf': sgd}
    _, _, classical_avg_err[i], classical_std_err[i] = parallelized_clf(RFClassifier,
                                                                       params_classical,
                                                                       X_train, y_train,
                                                                       X_val, y_val,
                                                                       n_iters=10, return_clf=False)

for t in t_list:
    for l in l_list:
        # Fresh containers per (t, l); same schema as before:
        # results[model]['avg_test_err' | 'std_test_err'][width index].
        results = {
            'classical': {'avg_test_err': classical_avg_err.copy(),
                          'std_test_err': classical_std_err.copy()},
            'V1': {'avg_test_err': np.zeros(len(n_features)),
                   'std_test_err': np.zeros(len(n_features))},
        }

        print('t=%0.2f, l=%0.2f' % (t, l))
        for i, n in enumerate(n_features):
            # V1-inspired random features
            kwargs = {'t': t, 'l': l}
            weights_V1 = {'weight_fun': V1_inspired_weights, 'kwargs': kwargs}
            params_V1 = {'width': n, **weights_V1, 'bias': b, 'nonlinearity': relu, 'clf': sgd}
            _, _, results['V1']['avg_test_err'][i], results['V1']['std_test_err'][i] = parallelized_clf(RFClassifier,
                                                                                                        params_V1,
                                                                                                        X_train, y_train,
                                                                                                        X_val, y_val,
                                                                                                        n_iters=10, return_clf=False)
            print('Iter: %d/%d, V1 test err=%0.2f, RF test err= %0.2f' % (n, n_features[-1],
                                                                               results['V1']['avg_test_err'][i],
                                                                              results['classical']['avg_test_err'][i]))

        # Persist this (t, l) pair right away so an interrupted sweep keeps its finished pairs.
        with open('results/fashion_mnist_clf/fashion_mnist_clf_t=%0.3f_l=%0.3f_sgd.pickle' % (t, l), 'wb') as handle:
            pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)


t=0.10, l=0.10
Iter: 1/1001, V1 test err=0.84, RF test err= 0.87
Iter: 51/1001, V1 test err=0.47, RF test err= 0.43
Iter: 101/1001, V1 test err=0.42, RF test err= 0.34
Iter: 151/1001, V1 test err=0.37, RF test err= 0.33
Iter: 201/1001, V1 test err=0.36, RF test err= 0.30
Iter: 251/1001, V1 test err=0.31, RF test err= 0.30
Iter: 301/1001, V1 test err=0.30, RF test err= 0.28
Iter: 351/1001, V1 test err=0.30, RF test err= 0.27
Iter: 401/1001, V1 test err=0.29, RF test err= 0.26
Iter: 451/1001, V1 test err=0.28, RF test err= 0.25
Iter: 501/1001, V1 test err=0.27, RF test err= 0.24
Iter: 551/1001, V1 test err=0.27, RF test err= 0.23
Iter: 601/1001, V1 test err=0.26, RF test err= 0.22
Iter: 651/1001, V1 test err=0.26, RF test err= 0.22
Iter: 701/1001, V1 test err=0.26, RF test err= 0.22
Iter: 751/1001, V1 test err=0.25, RF test err= 0.21
Iter: 801/1001, V1 test err=0.24, RF test err= 0.20
Iter: 851/1001, V1 test err=0.25, RF test err= 0.20
Iter: 901/1001, V1 test err=0.25, RF test err= 0.19


## Training on the full dataset

In [None]:
# Read the Fashion-MNIST train and test arrays from the local data directory.
(X_train, y_train, X_test, y_test) = load_fashion_mnist('data/fashion_mnist/')

In [None]:
# Hidden-layer widths to evaluate: 50 log-spaced values between 10**0 and
# 10**3.2, truncated to integers, de-duplicated and sorted ascending.
n_features = sorted({width for width in np.logspace(0, 3.2, 50).astype('int')})

# Parameters forwarded to the V1-inspired weight function.
t = 5
l = 1
kwargs = dict(t=t, l=l)
weights_V1 = dict(weight_fun=V1_inspired_weights, kwargs=kwargs)
weights_classical = dict(weight_fun=classical_weights)

# Linear classifiers fit on top of the random features. The training loop
# that follows this cell passes `svc`; `sgd` is kept available as an alternative.
sgd = SGDClassifier(loss="squared_hinge", alpha=1, learning_rate="optimal",
                    max_iter=500, tol=1e-4, shuffle=True, n_jobs=5,
                    early_stopping=True, validation_fraction=0.1, n_iter_no_change=20)
svc = LinearSVC(tol=1e-4, max_iter=500, random_state=None)

# Bias: mean over rows of X_train of (row L2 norm / sqrt(number of columns)).
b = (la.norm(X_train, axis=1) / np.sqrt(X_train.shape[1])).mean()

In [None]:
%%time
# Evaluate both random-feature models at every hidden-layer width on the full
# train/test data. results[model]['avg_test_err' | 'std_test_err'][k] holds the
# scores for the k-th entry of n_features.
results = {
    model: {'avg_test_err': np.zeros(len(n_features)), 'std_test_err': np.zeros(len(n_features))}
    for model in ('classical', 'V1')
}

for idx, width in enumerate(n_features):
    # Same width, bias, nonlinearity and classifier for both models; only the
    # weight-sampling function (and its kwargs) differs.
    params_classical = {'width': width, **weights_classical, 'bias': b, 'nonlinearity': relu, 'clf': svc}
    params_V1 = {'width': width, **weights_V1, 'bias': b, 'nonlinearity': relu, 'clf': svc}

    # Classical random features first, then the V1-inspired ones.
    # NOTE(review): only the last two values returned by parallelized_clf are
    # kept; they are stored as the mean / std error -- confirm against estimator.py.
    for model, params in (('classical', params_classical), ('V1', params_V1)):
        _, _, avg_err, std_err = parallelized_clf(RFClassifier, params,
                                                  X_train, y_train,
                                                  X_test, y_test,
                                                  n_iters=5, return_clf=False)
        results[model]['avg_test_err'][idx] = avg_err
        results[model]['std_test_err'][idx] = std_err

    print('Iter: %d/%d, V1 test err=%0.2f, RF test err= %0.2f' % (width, n_features[-1],
                                                                  results['V1']['avg_test_err'][idx],
                                                                  results['classical']['avg_test_err'][idx]))

# Saving is currently disabled; uncomment to persist the results for this (t, l).
# with open('results/fashion_mnist_clf/fashion_mnist_clf_t=%0.2f_l=%0.2f.pickle' % (t, l), 'wb') as handle:
#     pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

In [None]:
# Optional: reload previously pickled results instead of using the in-memory ones.
# t, l = 5, 2
# with open('results/fashion_mnist_clf/fashion_mnist_clf_t=%0.2f_l=%0.2f.pickle' % (t, l), 'rb') as handle:
#     results = pickle.load(handle)

# Classification error vs. hidden-layer width, with error bars taken from the
# stored standard deviations. V1-inspired is drawn first, classical second.
fig, ax = plt.subplots(figsize=(10.6, 8))
for key, label in (('V1', 'V1-inspired'), ('classical', 'classical')):
    ax.errorbar(n_features, results[key]['avg_test_err'], yerr=results[key]['std_test_err'],
                fmt='-', label=label, markersize=4, lw=5, elinewidth=3)

ax.set_xlabel('Hidden layer width', fontsize=40)
ax.set_ylabel('Classification error', fontsize=40)

# Axis ranges and tick positions (the x-axis is clipped to widths <= 1020).
ax.set_xlim([0, 1020])
ax.set_yticks(np.arange(0, 0.8, 0.1))
ax.set_ylim([-0.05, 0.55])
ax.set_xticks(np.arange(0, 1020, 200))
ax.tick_params(axis='both', which='major', labelsize=30, width=2, length=6)

ax.legend(loc='upper right', fontsize=30)
# plt.savefig('results/fashion_mnist_clf/fashion_mnist_clf_t=%0.2f_l=%0.2f.pdf' % (t, l))

In [None]:
# (width, mean error) pairs for the V1-inspired model
[(width, err) for width, err in zip(n_features, results['V1']['avg_test_err'])]

In [None]:
# (width, mean error) pairs for the classical random-feature model
[(width, err) for width, err in zip(n_features, results['classical']['avg_test_err'])]

### Few-shot learning

In [None]:
# Few-shot setup: keep only `num_train` training examples (stratified by label,
# fixed seed) and evaluate on a copy of the complete test set.
train, train_labels, test, test_labels = load_fashion_mnist('data/fashion_mnist/')

num_train = 50
X_train, _, y_train, _ = train_test_split(
    train, train_labels,
    train_size=num_train, stratify=train_labels, random_state=42,
)

X_test = test.copy()
y_test = test_labels.copy()

In [None]:
# Hidden-layer widths to evaluate: 50 log-spaced values between 10**0 and
# 10**3.2, truncated to integers, de-duplicated and sorted ascending.
n_features = sorted(set(np.logspace(0, 3.2, 50).astype('int')))

# weight params
t, l = 5, 2
kwargs = {'t': t, 'l': l}
weights_V1 = {'weight_fun': V1_inspired_weights, 'kwargs': kwargs}
weights_classical = {'weight_fun': classical_weights}

# params for classification
svc = LinearSVC(random_state=20, tol=1e-4, max_iter=500)

# Bias: mean over rows of X_train of (row L2 norm / sqrt(number of columns)),
# the same definition used by the other experiments in this notebook.
# The previous expression, la.norm(X_train) / np.sqrt(X_train.shape[0]), took
# the Frobenius norm of the whole matrix (a scalar, so np.mean was a no-op) and
# divided by sqrt(number of rows), which puts the bias on a different scale.
b = np.mean(la.norm(X_train, axis=1) / np.sqrt(X_train.shape[1]))

In [None]:
%%time
# Few-shot experiment: evaluate both random-feature models at every
# hidden-layer width, then pickle the scores.
# results[model]['avg_test_err' | 'std_test_err'][k] holds the scores for the
# k-th entry of n_features.
results = {
    model: {'avg_test_err': np.zeros(len(n_features)), 'std_test_err': np.zeros(len(n_features))}
    for model in ('classical', 'V1')
}

for idx, width in enumerate(n_features):
    # Same width, bias, nonlinearity and classifier for both models; only the
    # weight-sampling function (and its kwargs) differs.
    params_classical = {'width': width, **weights_classical, 'bias': b, 'nonlinearity': relu, 'clf': svc}
    params_V1 = {'width': width, **weights_V1, 'bias': b, 'nonlinearity': relu, 'clf': svc}

    # Classical random features first, then the V1-inspired ones.
    # NOTE(review): only the last two values returned by parallelized_clf are
    # kept; they are stored as the mean / std error -- confirm against estimator.py.
    for model, params in (('classical', params_classical), ('V1', params_V1)):
        _, _, avg_err, std_err = parallelized_clf(RFClassifier, params,
                                                  X_train, y_train,
                                                  X_test, y_test,
                                                  n_iters=10)
        results[model]['avg_test_err'][idx] = avg_err
        results[model]['std_test_err'][idx] = std_err

    print('Iter: %d/%d, V1 test err=%0.2f, RF test err= %0.2f' % (width, n_features[-1],
                                                                  results['V1']['avg_test_err'][idx],
                                                                  results['classical']['avg_test_err'][idx]))

# Persist the few-shot scores for this (t, l).
with open('results/fashion_mnist_clf/fashion_mnist_clf_t=%0.2f_l=%0.3f_few_shot.pickle' % (t, l), 'wb') as handle:
    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

In [None]:
# Few-shot classification error vs. hidden-layer width, with error bars taken
# from the stored standard deviations. V1-inspired is drawn first, classical second.
fig, ax = plt.subplots(figsize=(10.6, 8))
for key, label in (('V1', 'V1-inspired'), ('classical', 'classical')):
    ax.errorbar(n_features, results[key]['avg_test_err'], yerr=results[key]['std_test_err'],
                fmt='-', label=label, markersize=4, lw=5, elinewidth=3)

ax.set_xlabel('Hidden layer width', fontsize=40)
ax.set_ylabel('Classification error', fontsize=40)

# Axis ranges and tick positions (the x-axis is clipped to widths <= 1020).
ax.set_xlim([0, 1020])
ax.set_yticks(np.arange(0, 0.8, 0.1))
ax.set_ylim([-0.05, 0.55])
ax.set_xticks(np.arange(0, 1020, 200))
ax.tick_params(axis='both', which='major', labelsize=30, width=2, length=6)

ax.legend(loc='upper right', fontsize=30)