# 批次 (Mini-Batch Training)

In [1]:
import numpy as np

## Dataset

### Training Features and Labels

In [2]:
# Training inputs: one row per sample, two feature values per row.
train_features = np.array(
    [
        [22.5, 72.0],
        [31.4, 45.0],
        [19.8, 85.0],
        [27.6, 63.0],
    ]
)

# Training targets: one value per sample, stored as a column vector so that
# its shape matches the (n_samples, 1) output of `forward`.
train_labels = np.array([95, 210, 70, 155]).reshape(-1, 1)

### Testing Features and Labels

In [3]:
# A single held-out sample (one row, two features) and its target value.
test_features = np.array([28.1, 58.0]).reshape(1, -1)
test_labels = np.array([165]).reshape(1, 1)

## Model

### Weight and Bias

In [4]:
# Initial parameters: a (1, 2) weight row with every entry 0.5, and a zero bias.
weight = np.full((1, 2), 0.5)
bias = np.zeros((1,))

### Prediction Function

In [5]:
def forward(x, w, b):
    """Predict with the linear model.

    Each row of ``x`` is multiplied by the transpose of ``w``, and the bias
    ``b`` is broadcast-added to the result.

    Args:
        x: inputs, one sample per row.
        w: weights, one row per output (columns match the columns of ``x``).
        b: bias, broadcast against the ``x @ w.T`` product.

    Returns:
        The predictions ``x @ w.T + b``.
    """
    projected = x @ w.T
    return projected + b

### Mean Squared Error Loss Function

In [6]:
def mse_loss(p, y):
    """Mean squared error between predictions ``p`` and targets ``y``.

    The squared residuals are averaged over every element of ``p - y``.

    Returns:
        The scalar mean of ``(p - y) ** 2``.
    """
    residual = p - y
    return np.square(residual).mean()

### Gradient of the MSE Loss with Respect to the Predictions

In [7]:
def gradient(p, y):
    """Gradient of ``mse_loss(p, y)`` with respect to the predictions ``p``.

    ``mse_loss`` averages the squared residuals over *every* element, so the
    derivative of each term is ``2 * (p - y) / N``, where ``N`` is the total
    number of residual elements. The previous version divided by ``len(y)``
    (the row count only). That is identical for single-column labels such as
    this notebook's ``(n, 1)`` arrays, but wrong for multi-column outputs.

    Args:
        p: predictions.
        y: targets, broadcast-compatible with ``p``.

    Returns:
        An array with the shape of ``p - y`` holding d(loss)/d(p).
    """
    residual = p - y
    # Divide by the element count so this matches the `.mean()` in `mse_loss`.
    return residual * 2 / residual.size

### Learning Rate

In [8]:
# Gradient-descent step size. It is kept very small because the features are
# used unscaled (values in the tens) — presumably to avoid divergence.
LEARNING_RATE = 1e-5

### Backward Function (Gradient-Descent Parameter Update)

In [9]:
def backward(x, d, w, b, learning_rate=None):
    """Take one gradient-descent step for the linear model used by ``forward``.

    The gradients follow from ``forward(x, w, b) = x @ w.T + b``:
    d(loss)/dw = d.T @ x and d(loss)/db = sum of ``d`` over the batch rows.

    Args:
        x: batch inputs, one sample per row.
        d: gradient of the loss with respect to the batch predictions (for
            example, the output of ``gradient``); one row per sample.
        w: current weights (one row per output).
        b: current bias.
        learning_rate: step size. Defaults to the notebook-level
            ``LEARNING_RATE`` when omitted.

    Returns:
        ``(w_new, b_new)`` as new arrays. Unlike the previous in-place ``-=``
        version, the caller's ``w`` and ``b`` are left unmodified. This also
        avoids a casting error when ``w``/``b`` have an integer dtype.
    """
    if learning_rate is None:
        learning_rate = LEARNING_RATE
    grad_w = d.T @ x
    grad_b = d.sum(axis=0)
    return w - learning_rate * grad_w, b - learning_rate * grad_b

## Training

### Batch

In [10]:
# Number of training rows per gradient-descent step (the last batch may be shorter).
BATCH_SIZE: int = 2

### Training Loop (One Pass over the Mini-Batches)

In [11]:
# One pass (epoch) over the training set in mini-batches of BATCH_SIZE rows.
# This cell continues from the current `weight`/`bias`. Re-run the
# "Weight and Bias" cell first to train from the initial values again.
batch_losses = []
for start in range(0, len(train_features), BATCH_SIZE):
    # Slicing past the end is safe: the final batch is simply shorter.
    stop = start + BATCH_SIZE
    features = train_features[start:stop]
    labels = train_labels[start:stop]

    predictions = forward(features, weight, bias)
    # Loss before this batch's update; kept so training progress is visible
    # (it was previously computed and then discarded).
    batch_losses.append(mse_loss(predictions, labels))

    delta = gradient(predictions, labels)
    weight, bias = backward(features, delta, weight, bias)

print(f"weight: {weight}")
print(f"bias: {bias}")
print(f"batch losses: {np.array(batch_losses)}")

weight: [[0.59388172 0.68104165]]
bias: [0.00327249]


## Testing

### Predicting

In [12]:
# Apply the trained linear model to the held-out sample(s).
predictions = forward(x=test_features, w=weight, b=bias)
print(f'predictions: {predictions}')

predictions: [[56.19176426]]


### Calculating Loss

In [13]:
# Mean squared error of the test predictions against the true test labels.
error = mse_loss(p=predictions, y=test_labels)
print(f'error: {error}')

error: 11839.232164432306
