In [1]:
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

%load_ext autoreload
%autoreload 2

plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots

In [2]:
data = np.load("../data/count_data.npz")

X_train = data['X_train']
y_train = data['Y_train']

X_test = data['X_test']
y_test = data['Y_test']

print("X_train", X_train.shape)
print("y_train", y_train.shape)
print("X_test", X_test.shape)
print("y_test", y_test.shape)

X_train (1000, 2)
y_train (1000,)
X_test (500, 2)
y_test (500,)


In [3]:
from count_regression import CountRegression

model = CountRegression(1e-4)

n_samples, n_features = X_train.shape
n_classes = 1

print(n_samples, n_classes, n_features)

init_wb = np.zeros((n_features + 1, n_classes))

obj_init = model.objective(init_wb, X_train, y_train)
print('Initial loss', obj_init)

grad = model.objective_grad(init_wb, X_train, y_train)
print('Initial grad', grad)

model.fit(X_train, y_train)
# 3853.2051767327357
# 2499.2137262563706

1000 1 2
Initial loss 3853.2051767327357
Initial grad [  386.06366597 -1744.49380844  1779.5       ]
Min loss value 2499.213726255116


In [4]:
w, b = model.get_params()

n_samples, n_features = X_test.shape
w = w.reshape((n_features, 1))

linear_comp = X_test @ w + b

fx = 1 / (1 + np.exp(-linear_comp))
assert fx.shape == (n_samples, 1)

y = y_test.reshape((n_samples, 1))
likelihood = fx * ((1 - fx) ** y)
assert likelihood.shape == (n_samples, 1)

mean_log_likelihood = np.log(likelihood).mean()

print('Mean log likelihood', mean_log_likelihood)
print('w', w.ravel())
print('b', b)

Mean log likelihood -2.505942002758697
w [-0.07865909  0.36477745]
b -1.366657745627738


In [8]:
y_pred = model.predict(X_train)
print('Train accuracy', (y_pred == y_train).mean())

y_pred = model.predict(X_test)
print('Mean Squared Error', ((y - y_pred) ** 2).mean())

print(y_pred[:10].ravel())
print(y_test[:10].ravel())

Train accuracy 0.123498
Mean Squared Error 1.354
[9. 4. 9. 5. 1. 2. 2. 2. 8. 3.]
[9. 5. 8. 4. 1. 2. 2. 0. 9. 2.]
