In [1]:
import pandas as pd
import numpy as np
import random
from scipy.io import loadmat
from multiprocessing import Pool
from sklearn.preprocessing import OneHotEncoder

from typing import List,Tuple
import plotly.express as px
import plotly.graph_objects as go

In [2]:
data = loadmat('ex4data1.mat')
thetas = loadmat("ex4weights.mat")
theta1 = thetas["Theta1"]
theta2 = thetas["Theta2"]
thetas = [theta1,theta2]
X = data["X"]
y = data["y"]

In [3]:
encoder_y = OneHotEncoder(sparse=False)
y_onehot = encoder_y.fit_transform(y)

In [5]:
# 前100个 排列成10*10 (200*200)

digits = [X[random.randint(0,4999)].reshape(20,20).T for _ in range(100)]

rows_digits = []
for row in range(10):
    rows_digits.append(np.hstack(digits[row*10:row*10+10]))

digits_img = np.vstack(rows_digits)

In [6]:
fig = px.imshow(digits_img,color_continuous_scale='gray')
fig.show()

In [7]:
# g(theta,x)
def sigmoid(x):
    s = 1 / (1 + np.exp(-x))
    return s

def serialize(*arrays:np.ndarray):
    arr_shapes:List[tuple] = []
    arr_flatten = np.array([])
    for array in arrays:
        arr_shapes.append(array.shape)
        arr_flatten = np.hstack([arr_flatten,array.flatten()])
    return arr_flatten,arr_shapes

def deserialize(array:np.ndarray, shapes:List[tuple]):
    position = 0
    arrays:List[np.ndarray] = []
    for shape in shapes:
        length = np.prod(shape)
        arrays.append(array[position:position+length].reshape(*shape))
        position += length
    return arrays

def forward_propa(thetas:List[np.ndarray],X:np.ndarray):
    n_layers =len(thetas)+1
    a_list = []
    z_list = []

    a_list.append(np.insert(X,0,1,axis=1))
    for index,theta in enumerate(thetas):
        z = a_list[index]@theta.T
        z_list.append(z)
        if index+1 == n_layers-1:
            a_list.append(sigmoid(z))
        else:
            a_list.append(np.insert(sigmoid(z),0,1,axis=1))
    return a_list,z_list

def h(thetas,X):
    return forward_propa(thetas,X)[0][-1]
    
# J(Theta)
def cost(theta:List[np.ndarray],X:np.ndarray,y:np.ndarray):
    return np.mean(-np.sum(np.log(h(theta,X))*y,axis=1)-np.sum((1-y)*np.log(1-h(theta,X)),axis=1))

def regularized_cost(theta:List[np.ndarray], X:np.ndarray, y:np.ndarray, L:float=1.):
    n = X.shape[0]
    regular_term = 0
    for t in theta:
        regular_term += np.sum(t[:,1:]**2)
    regular_term = regular_term/(2*n)*L
    return cost(theta,X,y)+regular_term

def sigmoid_gradient(x):
    return sigmoid(x)*(1-sigmoid(x))

# def gradients(theta,x,y):
#     return np.mean(x.T*(h(theta,x)-y),axis=1)

In [14]:
# # 模拟H的计算过程，不需要了
# a1 = np.insert(X,0,1,axis=1)
# z2 = a1@theta1.T
# a2 = sigmoid(z2)
# a2 = np.insert(a2,0,1,axis=1)
# z3 = a2@theta2.T
# a3 = sigmoid(z3)
# np.mean(-np.sum(np.log(a3)*y_onehot,axis=1)-np.sum((1-y_onehot)*np.log(1-a3),axis=1))

In [8]:
cost(thetas,X,y_onehot)

0.2876291651613189

In [9]:
regularized_cost(thetas,X,y_onehot)

0.38376985909092365

# Back Propagation

In [10]:
epsilon = 0.12
theta1_init = np.random.rand(25,401) * 2 * epsilon - epsilon
theta2_init = np.random.rand(10,26) * 2 * epsilon - epsilon
thetas_flat,shapes = serialize(theta1_init,theta2_init)
thetas=deserialize(thetas_flat,shapes)

In [94]:
deltas=[np.zeros(theta.shape) for theta in thetas]

a,z = forward_propa(thetas,X)

for index,theta in reversed(list(enumerate(thetas))):
    a[index+1]-y_onehot



for row_id,example in enumerate(X):

    example=example.reshape(1,-1)
    a,z = forward_propa(thetas,example)
    delta3 = a[-1]-y_onehot[row_id]
    BigDelta2=delta3.T@a[-2][1:]

    delta2 = delta3*thetas[-1]*sigmoid_gradient(z[-2])
    delta2 = delta2[1:]
    BigDelta1=delta2.T@a[-3][1:]



ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 0 is different from 1)

In [13]:
for index,theta in reversed(list(enumerate(thetas))):
    print(index,theta)

1 [[-0.11255792 -0.09244434  0.10166217 -0.03268586 -0.04933895 -0.00217107
   0.01024438  0.06422628 -0.08488105 -0.09636253 -0.10354529 -0.07758216
   0.02845684  0.03126636  0.07921758  0.09600367  0.06825854 -0.11116809
  -0.11070497 -0.08751036 -0.0590906  -0.09046223 -0.00152265  0.03297337
   0.06847087 -0.10630391]
 [-0.01349743 -0.06120622 -0.04099987 -0.04555985 -0.10352187 -0.08971907
   0.09069512 -0.0838353  -0.07405906  0.04593864 -0.05112366 -0.07369021
   0.00130704  0.07353663  0.08231821 -0.04223664 -0.05888207  0.05333436
  -0.04227929 -0.02737606 -0.0260562  -0.06220948  0.11502719 -0.09220657
   0.10671397  0.09874098]
 [ 0.06195266  0.01774823  0.09371717 -0.10293147  0.07303216  0.0720509
  -0.00250926  0.09045442  0.09581775  0.10479641 -0.09429329 -0.03483065
   0.00050348 -0.08279435  0.02911578  0.10614949  0.06911225  0.08195694
  -0.08292913 -0.06845471  0.09221786  0.10875799  0.09964132  0.09095119
   0.07700276 -0.00306896]
 [-0.00928931  0.09990294 -0.0

In [85]:
y_onehot[0].shape

(10,)

In [90]:
a[-1]

array([[4.81465717e-05, 4.58821829e-04, 2.15146201e-05, 3.31719561e-03,
        1.55814354e-04, 3.02724040e-03, 3.69700393e-02, 5.73434571e-03,
        6.96288990e-01, 8.18576980e-02]])

In [91]:
y_onehot[0]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])

In [92]:
(a[-1]-y_onehot[0]).shape

(1, 10)

In [11]:
thetas

[array([[ 0.09012066,  0.11719591,  0.05729271, ..., -0.00457155,
          0.01221955, -0.10617145],
        [-0.11101568,  0.09326634, -0.01150639, ..., -0.02493641,
         -0.10754586,  0.11366942],
        [-0.02781685,  0.00153593, -0.07127581, ..., -0.09849576,
         -0.02483176, -0.05466766],
        ...,
        [ 0.02025303,  0.03481656,  0.02995595, ...,  0.02832567,
         -0.0992575 ,  0.11878387],
        [ 0.05734626,  0.01278467,  0.02919068, ...,  0.0089911 ,
         -0.07033499,  0.11369218],
        [-0.0060914 , -0.06516306,  0.0919051 , ...,  0.06732953,
          0.11739942, -0.07595419]]),
 array([[-0.11255792, -0.09244434,  0.10166217, -0.03268586, -0.04933895,
         -0.00217107,  0.01024438,  0.06422628, -0.08488105, -0.09636253,
         -0.10354529, -0.07758216,  0.02845684,  0.03126636,  0.07921758,
          0.09600367,  0.06825854, -0.11116809, -0.11070497, -0.08751036,
         -0.0590906 , -0.09046223, -0.00152265,  0.03297337,  0.06847087,
   