# Implement Batch Gradient Descent for Softmax Regression without sklearn

In [64]:
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

# Load the iris dataset and keep only two of its feature columns.
iris = datasets.load_iris()

X = iris["data"][:, (2, 3)]  # petal length, petal width
y = iris["target"]

# A fixed random_state makes the split (and therefore every result below)
# reproducible across Restart & Run All.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

In [62]:
def softmax(logits):
    """Row-wise softmax of a 2-D array of scores.

    Parameters
    ----------
    logits : array-like of shape (m, K)
        One row of raw class scores per sample.

    Returns
    -------
    numpy.ndarray of shape (m, K)
        Each row is non-negative and sums to 1.
    """
    logits = np.asarray(logits, dtype=float)
    # Softmax is invariant to adding a per-row constant; subtracting the row
    # maximum keeps every exponent <= 0, so np.exp cannot overflow to inf
    # (which would otherwise yield inf/inf = nan).
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exps = np.exp(shifted)
    total = np.sum(exps, axis=1, keepdims=True)
    return exps / total

def to_one_hot(y, n_classes=None):
    """Convert a 1-D vector of integer class labels to a one-hot matrix.

    Parameters
    ----------
    y : array-like of shape (m,)
        Class indices starting at 0. Float-valued labels are truncated to
        integers.
    n_classes : int, optional
        Number of columns in the result. When omitted it is inferred as
        ``y.max() + 1``; pass it explicitly when ``y`` may not contain the
        highest class (e.g. a small train/test split).

    Returns
    -------
    numpy.ndarray of shape (m, n_classes)
        Float matrix with a single 1 per row, in the column given by the label.
    """
    # Integer dtype is required for the fancy-index assignment below.
    y = np.asarray(y).astype(int)
    m = len(y)
    if n_classes is None:
        # max() raises on an empty array, so fall back to zero classes.
        n_classes = int(y.max()) + 1 if m else 0
    Y_one_hot = np.zeros((m, n_classes))
    Y_one_hot[np.arange(m), y] = 1
    return Y_one_hot

In [119]:
# Hyperparameters for batch gradient descent.
learn_rate = 0.3       # step size applied to each gradient update
n_iterations = 30000   # upper bound on the number of updates
epsilon = 1e-7         # added inside log() in the loss to avoid log(0)

y_train_oh = to_one_hot(y_train)
m = len(X_train)       # number of training samples (gradient normaliser)

# Read the feature count from the training matrix itself rather than the
# full dataset X, so the parameter shape always matches what is multiplied.
n_inputs, n_outputs = X_train.shape[1], y_train_oh.shape[1]

# Seed the global RNG so the random initialisation is reproducible.
np.random.seed(42)
theta_classes = np.random.randn(n_inputs, n_outputs)
theta_classes.shape

(2, 3)

In [120]:
# Display both shapes: X_train's column count must equal theta_classes' row
# count for the product X_train.dot(theta_classes) used in the training loop.
X_train.shape, theta_classes.shape

((112, 2), (2, 3))

Loss: $J(\mathbf{\Theta}) =
- \dfrac{1}{m}\sum\limits_{i=1}^{m}\sum\limits_{k=1}^{K}{y_k^{(i)}\log\left(\hat{p}_k^{(i)}\right)}$

Gradients: $\nabla_{\mathbf{\theta}^{(k)}} \, J(\mathbf{\Theta}) = \dfrac{1}{m} \sum\limits_{i=1}^{m}{ \left ( \hat{p}^{(i)}_k - y_k^{(i)} \right ) \mathbf{x}^{(i)}}$

In [121]:
# Batch gradient descent on the cross-entropy loss, with early stopping.
# Note: the stopping criterion watches the *training* loss, not a held-out set.
patience = 3                # consecutive non-improving iterations tolerated
best_loss = np.inf          # np.infty was an alias removed in NumPy 2.0
times_in_a_row = patience   # remaining non-improving iterations before stopping

for iteration in range(n_iterations):
    # Forward pass: class scores, then class probabilities.
    logits = X_train.dot(theta_classes)
    p_hat = softmax(logits)
    # Mean cross-entropy: for each sample, sum y_k * log(p_hat_k) over the
    # classes, then average over samples and negate. epsilon guards log(0).
    loss = -np.mean(np.sum(y_train_oh * np.log(p_hat + epsilon), axis=1))
    error = p_hat - y_train_oh
    if iteration % 500 == 1:
        print("{:<5}:{}".format(iteration, loss))
    # Early stopping bookkeeping: reset the counter on any improvement.
    if loss < best_loss:
        best_loss = loss
        times_in_a_row = patience
    else:
        times_in_a_row -= 1
    if times_in_a_row == 0:
        print("early stopping!")
        break
    # Gradient of the loss w.r.t. theta: (1/m) * X^T (p_hat - y).
    gradients = 1/m * X_train.T.dot(error)
    theta_classes -= learn_rate * gradients

1    :1.4694396060537207
501  :0.6411677511479184
1001 :0.5692616874923935
1501 :0.5415018314518816
2001 :0.5273438196878488
2501 :0.5190001374151647
3001 :0.5136273518325545
3501 :0.5099546739406066
4001 :0.5073342102951518
4501 :0.5054034100590711
5001 :0.5039448229452618
5501 :0.5028208392580653
6001 :0.5019405784104956
6501 :0.5012419010752486
7001 :0.5006810809117711
7501 :0.5002265963254476
8001 :0.4998552521953034
8501 :0.49954967624107155
9001 :0.4992966563599956
9501 :0.49908600887096594
10001:0.4989097912621639
10501:0.4987617439709632
11001:0.4986368877436945
11501:0.49853122873465905
12001:0.49844153951808795
12501:0.4983651944321182
13001:0.4983000443655908
13501:0.49824432055212803
14001:0.49819655995132334
14501:0.49815554687076563
15001:0.4981202669297386
15501:0.49808987048872916
16001:0.49806364340143755
16501:0.49804098347645803
17001:0.49802138142400887
17501:0.4980044053500499
18001:0.49798968807419314
18501:0.49797691670890387
19001:0.49796582405970696
19501:0.497