In [56]:
BATCH_SIZE = 64
EPOCH = 20
LEARNING_RATE = 0.01
MOMENTUM = 0.9
EPSILON = 1e-8
ETA = 0.1
NUM_CLASS = 10

In [14]:
import numpy as np
 
def gradient_descent(x, y, alpha, num_iters):
    m = x.shape[0]
    w = np.zeros(x.shape[1])
    b = 0
    for i in range(num_iters):
        y_pred = np.dot(x, w) + b                 # predict Y
        dw = (1/m) * np.dot(x.T, (y_pred - y))    
        db = (1/m) * np.sum(y_pred - y)
        w = w - alpha * dw
        b = b - alpha * db
    return w, b

In [16]:
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
transform = transforms.Compose([transforms.ToTensor(),   #改变通道顺序
                                transforms.Normalize((0.5,),(0.5,))])  #归一化
train_set = MNIST("./data", train=True, download=True, transform=transform)
test_set = MNIST("./data", train=False, download=True, transform=transform)

In [17]:
train_set

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5,), std=(0.5,))
           )

In [18]:
from torch.utils.data import DataLoader
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)

In [54]:
import random
def initialize(dim):
    b=random.random()
    theta=np.random.rand(dim)
    return b,theta


b,theta=initialize(28)
print("Bias: ",b)
print("Weights: ",theta)

def predict_Y(b,theta,X):
    print("theta",theta.shape)
    print("X",X.shape)
    return b + np.dot(X,theta)

Bias:  0.20118983280379743
Weights:  [0.05895967 0.29255293 0.57385506 0.16094504 0.43035654 0.11492838
 0.2029012  0.49665624 0.84947032 0.26075237 0.135276   0.48426359
 0.91362678 0.26711967 0.24963784 0.38088299 0.47669105 0.18426748
 0.82477622 0.10246857 0.37661648 0.68328581 0.46480376 0.15995559
 0.51866078 0.93232954 0.16023007 0.0472164 ]


In [57]:
from torch import nn
# Define model
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, NUM_CLASS)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

net = Network()
print(net)

Network(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (dense_relu_sequential): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=10, bias=True)
  )
)


In [25]:
import math
def get_cost(Y,Y_hat):
    Y_resd=Y-Y_hat
    return np.sum(np.dot(Y_resd.T,Y_resd))/len(Y-Y_resd)

In [59]:
cost = 0.0
for index, (X,Y) in enumerate(train_loader):
    print(X[0][0].  shape)
    Y_hat = predict_Y(b,theta,X)
    print("Y",Y.shape)
    print("Y_hat",Y_hat.shape)
    #cost = get_cost(Y,Y_hat)
    break
print(cost)
print(Y_hat[:10])

torch.Size([1, 28, 28])
theta (28,)
X torch.Size([64, 1, 28, 28])
Y torch.Size([64])
Y_hat (64, 1, 28)
0.0
[[[-10.60229656  -9.79740392  -9.08970741  -9.0248345   -9.09520022
    -9.32583429  -9.26062114  -9.26062114  -9.40115255  -9.67082092
    -9.32418489  -7.90480524  -3.82563192  -2.42421624  -2.4753033
    -4.71723866  -3.4715303   -0.59009968   0.85722539  -1.82415213
    -3.75070939 -10.60229656 -10.60229656 -10.60229656 -10.60229656
   -10.60229656 -10.60229656 -10.60229656]]

 [[-10.60229656 -10.60229656 -10.60229656 -10.60229656 -10.60229656
    -9.78656599  -8.78636182  -6.74750873  -7.14211737  -7.30431297
    -6.87442627  -6.67544654  -6.92862008  -7.34929588  -5.36153544
    -1.27551451  -2.75408287  -8.6250848   -9.23307531  -9.48441779
    -9.57560831  -9.59980732  -8.92395023  -7.30850946  -7.92437633
   -10.60229656 -10.60229656 -10.60229656]]

 [[-10.60229656 -10.60229656 -10.60229656 -10.60229656 -10.60229656
    -9.41687848  -8.00316101  -7.88495806  -8.35548981  