In [8]:
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score
import matplotlib.pyplot as plt
import time

In [9]:
digits = datasets.load_digits()
features = digits['data']
target = digits['target'] 

In [10]:
train_x, test_x, train_y, test_y = train_test_split(features, target, test_size=0.33)

print(train_x.shape)
print(train_y.shape)
print(test_x.shape)
print(test_y.shape)

(1203, 64)
(1203,)
(594, 64)
(594,)


# Pytorch

In [2]:
import torch.nn as nn
import torch.optim as optim
import torch
from torch.utils.data import DataLoader, TensorDataset

In [12]:
# 格式转化
pytorch_train_x = torch.tensor(train_x, dtype=torch.float32)
pytorch_train_y = torch.tensor(train_y, dtype=torch.long)

pytorch_test_x = torch.tensor(test_x, dtype=torch.float32)
pytorch_test_y = torch.tensor(test_y, dtype=torch.long)

# 装进数据容器分批
data = TensorDataset(pytorch_train_x, pytorch_train_y)
dataset = DataLoader(data, 64)

In [13]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 512)
        self.fc4 = nn.Linear(512, 256)
        self.fc5 = nn.Linear(256, 128)
        self.fc6 = nn.Linear(128, 64)
        self.fc7 = nn.Linear(64, 10)
        self.softmax = nn.Softmax(dim=1)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.relu(self.fc4(x))
        x = self.relu(self.fc5(x))
        x = self.relu(self.fc6(x))
        result = self.fc7(x)
        return result
    
    def predict(self, x):
        result = self.softmax(self(x)).sort()[1][:,-1]
        return result

In [82]:
model = Model()
model.fc2((model.relu(model.fc1(pytorch_train_x[0]))))

tensor([ 1.2238e+00, -4.5168e-01,  2.2505e-01, -8.9987e-01, -3.3439e-01,
        -1.2527e+00,  2.2490e+00,  8.9131e-01, -3.2683e+00, -5.3231e-01,
        -4.0276e+00,  7.7310e-01, -8.1476e-01, -4.3402e+00, -1.2724e+00,
         4.2437e+00, -3.7290e+00, -2.0524e+00, -2.1846e-01,  3.4116e-01,
        -3.0236e+00,  4.1288e-01, -1.2392e+00, -3.3066e+00, -1.8802e+00,
        -1.3699e+00, -2.7548e+00, -2.0649e-01,  7.0265e-01,  3.6390e-01,
         1.0461e+00, -1.9002e-01,  2.4421e-03, -2.9163e-01,  2.8602e-01,
        -4.2469e+00, -8.5473e-01, -9.8331e-01, -1.6659e+00,  1.5747e-02,
        -2.5940e+00,  9.4085e-01,  3.3631e-01, -3.1555e-02,  3.4583e-01,
        -1.9310e-02, -3.1192e+00,  1.8290e+00,  3.6814e-01,  2.2955e+00,
        -7.3573e-02,  3.0953e-01,  5.2701e-01, -1.4145e+00,  4.4904e-01,
        -2.1649e+00, -3.9509e+00, -5.4132e-01, -1.4426e+00,  3.2334e+00,
        -6.8390e-01, -1.2033e+00,  1.1372e+00,  2.4144e+00, -8.7205e-01,
        -9.8211e-01, -2.9181e+00, -7.8811e-01, -1.0

In [25]:
pytorch_train_x[0]

tensor([ 0.,  0.,  6., 11., 16., 13.,  5.,  0.,  0.,  2., 16., 16., 16., 16.,
        12.,  0.,  0.,  0.,  0.,  0.,  5., 16.,  4.,  0.,  0.,  0.,  0., 10.,
        15.,  5.,  0.,  0.,  0.,  0.,  9., 16.,  3.,  0.,  0.,  0.,  0.,  0.,
        13., 16., 13.,  1.,  0.,  0.,  0.,  0.,  0.,  5., 16., 14.,  0.,  0.,
         0.,  0.,  5., 14., 11.,  6.,  0.,  0.])

### CPU Version

In [14]:
model = Model()
optimizer = optim.Adam(model.parameters())
loss_fun = nn.CrossEntropyLoss()

In [15]:
%%time
# 训练
for i in range(100):
    loss_record = 0
    start_time = time.time()
    
    for batch, (x, y) in enumerate(dataset):
        optimizer.zero_grad()
        y_pre = model(x)
        
        loss = loss_fun(y_pre, y)
        loss.backward()
        optimizer.step()
        
        loss_record += loss.item()
    
    duration = time.time() - start_time
    
    predict = model.predict(pytorch_test_x).numpy()
    acc_score = accuracy_score(predict, pytorch_test_y.numpy())
    
    print('Epoch: {}, time: {:.3f}s, Loss: {:.3f}, Acc: {:.2f}%'.format(i+1, duration, loss_record/batch, acc_score*100))

Epoch: 1, time: 0.375s, Loss: 2.012, Acc: 71.21%
Epoch: 2, time: 0.121s, Loss: 0.720, Acc: 87.21%
Epoch: 3, time: 0.133s, Loss: 0.414, Acc: 89.23%
Epoch: 4, time: 0.114s, Loss: 0.279, Acc: 91.75%
Epoch: 5, time: 0.124s, Loss: 0.176, Acc: 92.93%
Epoch: 6, time: 0.117s, Loss: 0.138, Acc: 94.78%
Epoch: 7, time: 0.116s, Loss: 0.100, Acc: 94.78%
Epoch: 8, time: 0.125s, Loss: 0.109, Acc: 93.77%
Epoch: 9, time: 0.119s, Loss: 0.098, Acc: 94.11%
Epoch: 10, time: 0.116s, Loss: 0.075, Acc: 93.94%
Epoch: 11, time: 0.121s, Loss: 0.065, Acc: 96.30%
Epoch: 12, time: 0.119s, Loss: 0.079, Acc: 94.78%
Epoch: 13, time: 0.122s, Loss: 0.075, Acc: 95.62%
Epoch: 14, time: 0.121s, Loss: 0.034, Acc: 97.81%
Epoch: 15, time: 0.121s, Loss: 0.030, Acc: 96.63%
Epoch: 16, time: 0.114s, Loss: 0.033, Acc: 96.80%
Epoch: 17, time: 0.118s, Loss: 0.033, Acc: 96.13%
Epoch: 18, time: 0.114s, Loss: 0.069, Acc: 95.45%
Epoch: 19, time: 0.116s, Loss: 0.047, Acc: 96.13%
Epoch: 20, time: 0.112s, Loss: 0.015, Acc: 97.64%
Epoch: 21

### GPU Version

In [16]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pytorch_test_x, pytorch_test_y = pytorch_test_x.to(device), pytorch_test_y.to(device)
model = Model().to(device)
optimizer = optim.Adam(model.parameters())
loss_fun = nn.CrossEntropyLoss()
loss_fun.to(device)

CrossEntropyLoss()

In [17]:
%%time
# 训练
for i in range(100):
    loss_record = 0
    start_time = time.time()
    
    for batch, (x, y) in enumerate(dataset):
        optimizer.zero_grad()

        x, y = x.to(device), y.to(device)
        y_pre = model(x)
        
        loss = loss_fun(y_pre, y)
        loss.backward()
        optimizer.step()
        
        loss_record += loss.cpu().item()
    
    duration = time.time() - start_time
    
    predict = model.predict(pytorch_test_x).cpu().numpy()
    acc_score = accuracy_score(predict, pytorch_test_y.cpu().numpy())
    
    print('Epoch: {}, time: {:.3f}s, Loss: {:.3f}, Acc: {:.2f}%'.format(i+1, duration, loss_record/batch, acc_score*100))

Epoch: 1, time: 0.132s, Loss: 1.958, Acc: 72.39%
Epoch: 2, time: 0.071s, Loss: 0.703, Acc: 87.04%
Epoch: 3, time: 0.069s, Loss: 0.345, Acc: 89.73%
Epoch: 4, time: 0.072s, Loss: 0.242, Acc: 92.26%
Epoch: 5, time: 0.070s, Loss: 0.202, Acc: 90.40%
Epoch: 6, time: 0.072s, Loss: 0.213, Acc: 93.43%
Epoch: 7, time: 0.081s, Loss: 0.161, Acc: 94.11%
Epoch: 8, time: 0.070s, Loss: 0.076, Acc: 94.78%
Epoch: 9, time: 0.071s, Loss: 0.055, Acc: 95.62%
Epoch: 10, time: 0.078s, Loss: 0.046, Acc: 96.13%
Epoch: 11, time: 0.071s, Loss: 0.033, Acc: 97.31%
Epoch: 12, time: 0.075s, Loss: 0.027, Acc: 97.14%
Epoch: 13, time: 0.076s, Loss: 0.026, Acc: 97.31%
Epoch: 14, time: 0.070s, Loss: 0.025, Acc: 97.64%
Epoch: 15, time: 0.074s, Loss: 0.027, Acc: 96.30%
Epoch: 16, time: 0.072s, Loss: 0.027, Acc: 97.14%
Epoch: 17, time: 0.069s, Loss: 0.029, Acc: 96.46%
Epoch: 18, time: 0.070s, Loss: 0.011, Acc: 96.30%
Epoch: 19, time: 0.081s, Loss: 0.015, Acc: 96.97%
Epoch: 20, time: 0.080s, Loss: 0.010, Acc: 97.64%
Epoch: 21

# JAX

In [73]:
import jax.numpy as jnp
from jax import jit, vmap, grad
import jax.nn as jnn
import jax

In [74]:
jax_train_x = jnp.array(train_x)
jax_train_y = jnn.one_hot(jnp.array(train_y), 10)
jax_test_x = jnp.array(test_x)
jax_test_y = jnn.one_hot(jnp.array(test_y), 10)

In [77]:
key = jax.random.PRNGKey(0)
layer_dim = [(64, 128), (128, 256), (256,512), (512, 256), (256,128), (128,64), (64,10)]
parmas = [(jax.random.normal(key, (m, n)), jax.random.normal(key, (n,))) for m, n in layer_dim]

In [148]:
def fc(inputs, w, b):
    return jnp.dot(inputs, w) + b

def relu(inputs):
    # return  jnn.relu(inputs)
    return jnp.maximum(0, inputs)

def sigmoid(inputs):
    return 1 / (1 + jnp.exp(inputs))

def softmax(inputs):
    return jnn.softmax(inputs)

In [151]:
def forward_1(inputs, params):
    for w, b in parmas[:-1]:
        outputs = fc(inputs, w, b)
        inputs = sigmoid(outputs)

    outputs = fc(inputs, *parmas[-1])
    print(outputs)
    outputs = outputs - jax.scipy.special.logsumexp(outputs)
    print(outputs)
    return outputs
        
def loss_fun_1(inputs, parmas, y_true):
    predict = forward_1(inputs, parmas)
    # loss = 
    # print(loss)
    return l-jnp.sum(predict * y_true)

def train_1(inputs, y_true, parmas, lr=0.01):
    grad_data = grad(loss_fun_1, (1,))(inputs, parmas, y_true)
    print(grad_data[0])
    new_parmas = [(w - lr * dw, b - lr * db) for (dw, db), (w, b) in zip(grad_data[0], parmas)]
    return new_parmas


In [153]:
for i in range(1):
    parmas = train_1(jax_train_x[3], jax_train_y[3], parmas)

[ 0.900438  -1.2475622 -4.912817  -4.8752565 -1.5413809  1.5393157
 -1.680436  -6.711699  -4.3902516  1.2724842]
[-1.5338743  -3.6818745  -7.3471293  -7.309569   -3.9756932  -0.89499664
 -4.1147485  -9.146011   -6.824564   -1.1618282 ]
1.5338743
[(array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [None]:
forward(jax_train_x[6], parmas)

In [None]:
jax_train_y[6]

In [None]:
y_pre.shape

In [None]:
jax_train_y.shape

In [None]:
parmas[0][0]

In [None]:
jnp.maximum(0, parmas[0][1])

In [None]:
relu(parmas[0][1])

In [None]:
jnp.sum(softmax(parmas[0][1]))

In [None]:
jax.scipy.special.logsumexp(parmas[0][1])

In [235]:
import jax.numpy as jnp
from jax import grad, jit, vmap
import jax.nn as jnn
import jax
from jax.scipy.special import logsumexp

In [236]:
def fc(x, w, b):
    return jnp.dot(x, w) + b

def relu(x):
    return jnn.relu(x)

def softmax(x):
    return jnn.softmax(x)

def logsoftmax(x):
    return logsumexp(x)

In [237]:
@jax.partial(jax.vmap, in_axes=(0, None))
def predict(x, parmas):
    for w, b in parmas[:-1]:
        x = relu(fc(x, w, b))
        
    x = fc(x, *parmas[-1])
    return x - logsoftmax(x)

@grad
def loss_fun(parmas, x, y):
    pre = predict(x, parmas)
    loss = -jnp.mean(pre*y)
    return loss

# @jax.partial(jax.jit, static_argnums=(3,))
@jit
def train(parmas, x, y, lr=0.001):
    parmas_grad = loss_fun(parmas, inputs, y)
    new_grad = [(w - dw * lr, b - db * lr) for (dw, db), (w, b) in zip(parmas_grad, parmas)]
    return new_grad

In [238]:
%%time
key = jax.random.PRNGKey(4)
layer_dim = [(4,8), (8,16), (16,32), (32,64), (64,32), (32,4)]
parmas = [(jax.random.normal(key, (m, n)), jax.random.normal(key, (n,))) for m, n in layer_dim]

inputs = jnp.array([[1,5,3,-1]], dtype=jnp.float32)
y = jnp.array([[1,0,0,0]], dtype=jnp.float32)

CPU times: user 8.83 ms, sys: 3.14 ms, total: 12 ms
Wall time: 13.3 ms


In [240]:
%%time
for i in range(1):
    parmas = train(parmas, inputs, y, lr=0.001)

CPU times: user 406 ms, sys: 14.4 ms, total: 421 ms
Wall time: 544 ms


In [241]:
predict(inputs, parmas)

DeviceArray([[    0.    , -2411.7173, -5136.0386, -4565.2817]], dtype=float32)

In [231]:
predict(inputs, parmas)

DeviceArray([[    0.    , -2411.7173, -5136.0386, -4565.2817]], dtype=float32)

In [224]:
jnp.argmax(predict(inputs, parmas), axis=1)

DeviceArray([3], dtype=int32)

In [170]:
fc(relu(fc(relu(fc(relu(fc(inputs, *parmas[0])), *parmas[1])), *parmas[2])), *parmas[3])

DeviceArray([[-11.767935 , -37.907917 ,  -3.1590328, -10.243803 ]], dtype=float32)

In [171]:
softmax(fc(relu(fc(relu(fc(relu(fc(inputs, *parmas[0])), *parmas[1])), *parmas[2])), *parmas[3]))

DeviceArray([[1.8228816e-04, 8.0966921e-16, 9.9898070e-01, 8.3691336e-04]],            dtype=float32)

In [172]:
logsoftmax(fc(relu(fc(relu(fc(relu(fc(inputs, *parmas[0])), *parmas[1])), *parmas[2])), *parmas[3]))

DeviceArray(-3.158013, dtype=float32)

In [173]:
fc(relu(fc(relu(fc(relu(fc(inputs, *parmas[0])), *parmas[1])), *parmas[2])), *parmas[3]) - logsumexp(fc(relu(fc(relu(fc(relu(fc(inputs, *parmas[0])), *parmas[1])), *parmas[2])), *parmas[3]))

DeviceArray([[-8.6099215e+00, -3.4749905e+01, -1.0197163e-03,
              -7.0857897e+00]], dtype=float32)

In [174]:
(fc(relu(fc(relu(fc(relu(fc(inputs, *parmas[0])), *parmas[1])), *parmas[2])), *parmas[3]) - logsumexp(fc(relu(fc(relu(fc(relu(fc(inputs, *parmas[0])), *parmas[1])), *parmas[2])), *parmas[3]))) * y

DeviceArray([[-0.        , -0.        , -0.00101972, -0.        ]], dtype=float32)

In [175]:
-jnp.mean((fc(relu(fc(relu(fc(relu(fc(inputs, *parmas[0])), *parmas[1])), *parmas[2])), *parmas[3]) - logsumexp(fc(relu(fc(relu(fc(relu(fc(inputs, *parmas[0])), *parmas[1])), *parmas[2])), *parmas[3]))) * y)

DeviceArray(0.00025493, dtype=float32)

In [140]:
jnp.log(softmax(fc(relu(fc(relu(fc(relu(fc(inputs, *parmas[0])), *parmas[1])), *parmas[2])), *parmas[3])))

DeviceArray([[   0.     ,       -inf, -100.44572,  -77.83852]], dtype=float32)

In [141]:
jnp.log(softmax(fc(relu(fc(relu(fc(relu(fc(inputs, *parmas[0])), *parmas[1])), *parmas[2])), *parmas[3]))) * y

DeviceArray([[   0.     ,        nan, -100.44572,   -0.     ]], dtype=float32)

In [142]:
(jnp.log(softmax(fc(relu(fc(relu(fc(relu(fc(inputs, *parmas[0])), *parmas[1])), *parmas[2])), *parmas[3]))) * y).sum()

DeviceArray(nan, dtype=float32)

In [121]:
jnp.log1p(softmax(fc(relu(fc(relu(fc(relu(fc(inputs, *parmas[0])), *parmas[1])), *parmas[2])), *parmas[3]))) * y

DeviceArray([[0.0e+00, 0.0e+00, 2.4e-44, 0.0e+00]], dtype=float32)

In [9]:
a = jax.random.normal(key, (5000, 2000))
b = jax.random.normal(key, (2000, 200))

In [10]:
def fun(x):
    return jnp.dot(x, b)

In [16]:
%%timeit
fun(a)

100 loops, best of 3: 2.74 ms per loop


In [17]:
v_fun = vmap(fun)

In [18]:
%%timeit
v_fun(a)

100 loops, best of 3: 2.69 ms per loop


In [127]:
a = jnp.array([1,2,3])

In [145]:
a[jnp.where(a<=2)]

DeviceArray([1, 2], dtype=int32)