In [3]:
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score
import matplotlib.pyplot as plt
import time

In [4]:
digits = datasets.load_digits()
features = digits['data']
target = digits['target'] 

In [5]:
train_x, test_x, train_y, test_y = train_test_split(features, target, test_size=0.33)

print(train_x.shape)
print(train_y.shape)
print(test_x.shape)
print(test_y.shape)

(1203, 64)
(1203,)
(594, 64)
(594,)


# Pytorch

In [156]:
import torch.nn as nn
import torch.optim as optim
import torch
from torch.utils.data import DataLoader, TensorDataset

In [157]:
# 格式转化
pytorch_train_x = torch.tensor(train_x, dtype=torch.float32)
pytorch_train_y = torch.tensor(train_y, dtype=torch.long)

pytorch_test_x = torch.tensor(test_x, dtype=torch.float32)
pytorch_test_y = torch.tensor(test_y, dtype=torch.long)

# 装进数据容器分批
data = TensorDataset(pytorch_train_x, pytorch_train_y)
dataset = DataLoader(data, 64)

In [13]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 512)
        self.fc4 = nn.Linear(512, 256)
        self.fc5 = nn.Linear(256, 128)
        self.fc6 = nn.Linear(128, 64)
        self.fc7 = nn.Linear(64, 10)
        self.softmax = nn.Softmax(dim=1)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.relu(self.fc4(x))
        x = self.relu(self.fc5(x))
        x = self.relu(self.fc6(x))
        result = self.fc7(x)
        return result
    
    def predict(self, x):
        result = self.softmax(self(x)).sort()[1][:,-1]
        return result

### CPU Version

In [14]:
model = Model()
optimizer = optim.Adam(model.parameters())
loss_fun = nn.CrossEntropyLoss()

In [15]:
%%time
# 训练
for i in range(100):
    loss_record = 0
    start_time = time.time()
    
    for batch, (x, y) in enumerate(dataset):
        optimizer.zero_grad()
        y_pre = model(x)
        
        loss = loss_fun(y_pre, y)
        loss.backward()
        optimizer.step()
        
        loss_record += loss.item()
    
    duration = time.time() - start_time
    
    predict = model.predict(pytorch_test_x).numpy()
    acc_score = accuracy_score(predict, pytorch_test_y.numpy())
    
    print('Epoch: {}, time: {:.3f}s, Loss: {:.3f}, Acc: {:.2f}%'.format(i+1, duration, loss_record/batch, acc_score*100))

Epoch: 1, time: 0.375s, Loss: 2.012, Acc: 71.21%
Epoch: 2, time: 0.121s, Loss: 0.720, Acc: 87.21%
Epoch: 3, time: 0.133s, Loss: 0.414, Acc: 89.23%
Epoch: 4, time: 0.114s, Loss: 0.279, Acc: 91.75%
Epoch: 5, time: 0.124s, Loss: 0.176, Acc: 92.93%
Epoch: 6, time: 0.117s, Loss: 0.138, Acc: 94.78%
Epoch: 7, time: 0.116s, Loss: 0.100, Acc: 94.78%
Epoch: 8, time: 0.125s, Loss: 0.109, Acc: 93.77%
Epoch: 9, time: 0.119s, Loss: 0.098, Acc: 94.11%
Epoch: 10, time: 0.116s, Loss: 0.075, Acc: 93.94%
Epoch: 11, time: 0.121s, Loss: 0.065, Acc: 96.30%
Epoch: 12, time: 0.119s, Loss: 0.079, Acc: 94.78%
Epoch: 13, time: 0.122s, Loss: 0.075, Acc: 95.62%
Epoch: 14, time: 0.121s, Loss: 0.034, Acc: 97.81%
Epoch: 15, time: 0.121s, Loss: 0.030, Acc: 96.63%
Epoch: 16, time: 0.114s, Loss: 0.033, Acc: 96.80%
Epoch: 17, time: 0.118s, Loss: 0.033, Acc: 96.13%
Epoch: 18, time: 0.114s, Loss: 0.069, Acc: 95.45%
Epoch: 19, time: 0.116s, Loss: 0.047, Acc: 96.13%
Epoch: 20, time: 0.112s, Loss: 0.015, Acc: 97.64%
Epoch: 21

### GPU Version

In [16]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pytorch_test_x, pytorch_test_y = pytorch_test_x.to(device), pytorch_test_y.to(device)
model = Model().to(device)
optimizer = optim.Adam(model.parameters())
loss_fun = nn.CrossEntropyLoss()
loss_fun.to(device)

CrossEntropyLoss()

In [17]:
%%time
# 训练
for i in range(100):
    loss_record = 0
    start_time = time.time()
    
    for batch, (x, y) in enumerate(dataset):
        optimizer.zero_grad()

        x, y = x.to(device), y.to(device)
        y_pre = model(x)
        
        loss = loss_fun(y_pre, y)
        loss.backward()
        optimizer.step()
        
        loss_record += loss.cpu().item()
    
    duration = time.time() - start_time
    
    predict = model.predict(pytorch_test_x).cpu().numpy()
    acc_score = accuracy_score(predict, pytorch_test_y.cpu().numpy())
    
    print('Epoch: {}, time: {:.3f}s, Loss: {:.3f}, Acc: {:.2f}%'.format(i+1, duration, loss_record/batch, acc_score*100))

Epoch: 1, time: 0.132s, Loss: 1.958, Acc: 72.39%
Epoch: 2, time: 0.071s, Loss: 0.703, Acc: 87.04%
Epoch: 3, time: 0.069s, Loss: 0.345, Acc: 89.73%
Epoch: 4, time: 0.072s, Loss: 0.242, Acc: 92.26%
Epoch: 5, time: 0.070s, Loss: 0.202, Acc: 90.40%
Epoch: 6, time: 0.072s, Loss: 0.213, Acc: 93.43%
Epoch: 7, time: 0.081s, Loss: 0.161, Acc: 94.11%
Epoch: 8, time: 0.070s, Loss: 0.076, Acc: 94.78%
Epoch: 9, time: 0.071s, Loss: 0.055, Acc: 95.62%
Epoch: 10, time: 0.078s, Loss: 0.046, Acc: 96.13%
Epoch: 11, time: 0.071s, Loss: 0.033, Acc: 97.31%
Epoch: 12, time: 0.075s, Loss: 0.027, Acc: 97.14%
Epoch: 13, time: 0.076s, Loss: 0.026, Acc: 97.31%
Epoch: 14, time: 0.070s, Loss: 0.025, Acc: 97.64%
Epoch: 15, time: 0.074s, Loss: 0.027, Acc: 96.30%
Epoch: 16, time: 0.072s, Loss: 0.027, Acc: 97.14%
Epoch: 17, time: 0.069s, Loss: 0.029, Acc: 96.46%
Epoch: 18, time: 0.070s, Loss: 0.011, Acc: 96.30%
Epoch: 19, time: 0.081s, Loss: 0.015, Acc: 96.97%
Epoch: 20, time: 0.080s, Loss: 0.010, Acc: 97.64%
Epoch: 21

# JAX

In [198]:
import jax.numpy as jnp
from jax import grad, jit, vmap
import jax.nn as jnn
import jax
from jax.scipy.special import logsumexp

In [199]:
jax_train_x = jnp.array(train_x)
jax_train_y = jnn.one_hot(jnp.array(train_y), 10)
jax_test_x = jnp.array(test_x)
jax_test_y = jnn.one_hot(jnp.array(test_y), 10)

In [200]:
key = jax.random.PRNGKey(0)
layer_dim = [(64, 128), (128, 256), (256,512), (512, 256), (256,128), (128,64), (64,10)]
parmas = [(jax.random.normal(key, (m, n)) * 0.1, jax.random.normal(key, (n,)) * 0.1) for m, n in layer_dim]

In [201]:
def fc(x, w, b):
    return jnp.dot(x, w) + b

def relu(x):
    return jnn.relu(x)

def softmax(x):
    return jnn.softmax(x)

def logsoftmax(x):
    return logsumexp(x)

In [202]:
@jax.partial(jax.vmap, in_axes=(0, None))
def predict(x, parmas):
    for w, b in parmas[:-1]:
        x = relu(fc(x, w, b))
    x = fc(x, *parmas[-1])
    return x - logsoftmax(x)

@grad
def loss_fun(parmas, x, y):
    pre = predict(x, parmas)
    loss = -jnp.mean(pre*y)
    return loss

@jit
def train(parmas, x, y, lr=0.001):
    parmas_grad = loss_fun(parmas, x, y)
    new_grad = [(w - dw * lr, b - db * lr) for (dw, db), (w, b) in zip(parmas_grad, parmas)]
    return new_grad

In [203]:
%%time
for i in range(4000):
    
    parmas = train(parmas, jax_train_x, jax_train_y, lr=lr)

result = jnp.argmax(predict(jax_test_x, parmas), axis=1)
print(round(accuracy_score(result, test_y), 2),)

0.91
CPU times: user 10.4 s, sys: 6.01 s, total: 16.4 s
Wall time: 16.1 s
