In [52]:
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score
import matplotlib.pyplot as plt
import time

In [53]:
digits = datasets.load_digits()
features = digits['data']
target = digits['target'] 

In [54]:
train_x, test_x, train_y, test_y = train_test_split(features, target, test_size=0.33)

print(train_x.shape)
print(train_y.shape)
print(test_x.shape)
print(test_y.shape)

(1203, 64)
(1203,)
(594, 64)
(594,)


# Pytorch

In [4]:
import torch.nn as nn
import torch.optim as optim
import torch
from torch.utils.data import DataLoader, TensorDataset

In [5]:
# 格式转化
pytorch_train_x = torch.tensor(train_x, dtype=torch.float32)
pytorch_train_y = torch.tensor(train_y, dtype=torch.long)

pytorch_test_x = torch.tensor(test_x, dtype=torch.float32)
pytorch_test_y = torch.tensor(test_y, dtype=torch.long)

# 装进数据容器分批
data = TensorDataset(pytorch_train_x, pytorch_train_y)
dataset = DataLoader(data, 64)

In [6]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 512)
        self.fc4 = nn.Linear(512, 256)
        self.fc5 = nn.Linear(256, 128)
        self.fc6 = nn.Linear(128, 64)
        self.fc7 = nn.Linear(64, 10)
        self.softmax = nn.Softmax(dim=1)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.relu(self.fc4(x))
        x = self.relu(self.fc5(x))
        x = self.relu(self.fc6(x))
        result = self.fc7(x)
        return result
    
    def predict(self, x):
        result = self.softmax(self(x)).sort()[1][:,-1]
        return result

### CPU Version

In [7]:
model = Model()
optimizer = optim.Adam(model.parameters())
loss_fun = nn.CrossEntropyLoss()

In [8]:
%%time
# 训练
for i in range(100):
    loss_record = 0
    start_time = time.time()
    
    for batch, (x, y) in enumerate(dataset):
        optimizer.zero_grad()
        y_pre = model(x)
        
        loss = loss_fun(y_pre, y)
        loss.backward()
        optimizer.step()
        
        loss_record += loss.item()
    
    duration = time.time() - start_time
    
    predict = model.predict(pytorch_test_x).numpy()
    acc_score = accuracy_score(predict, pytorch_test_y.numpy())
    
    print('Epoch: {}, time: {:.3f}s, Loss: {:.3f}, Acc: {:.2f}%'.format(i+1, duration, loss_record/batch, acc_score*100))

Epoch: 1, time: 0.343s, Loss: 1.835, Acc: 73.40%
Epoch: 2, time: 0.114s, Loss: 0.613, Acc: 88.55%
Epoch: 3, time: 0.117s, Loss: 0.283, Acc: 91.58%
Epoch: 4, time: 0.112s, Loss: 0.172, Acc: 92.26%
Epoch: 5, time: 0.113s, Loss: 0.135, Acc: 95.62%
Epoch: 6, time: 0.114s, Loss: 0.127, Acc: 97.31%
Epoch: 7, time: 0.119s, Loss: 0.084, Acc: 96.80%
Epoch: 8, time: 0.109s, Loss: 0.051, Acc: 97.64%
Epoch: 9, time: 0.115s, Loss: 0.038, Acc: 96.46%
Epoch: 10, time: 0.108s, Loss: 0.032, Acc: 98.15%
Epoch: 11, time: 0.117s, Loss: 0.014, Acc: 98.15%
Epoch: 12, time: 0.111s, Loss: 0.010, Acc: 97.14%
Epoch: 13, time: 0.110s, Loss: 0.021, Acc: 96.97%
Epoch: 14, time: 0.111s, Loss: 0.025, Acc: 96.80%
Epoch: 15, time: 0.118s, Loss: 0.019, Acc: 97.47%
Epoch: 16, time: 0.134s, Loss: 0.017, Acc: 97.98%
Epoch: 17, time: 0.111s, Loss: 0.008, Acc: 97.14%
Epoch: 18, time: 0.109s, Loss: 0.008, Acc: 96.97%
Epoch: 19, time: 0.113s, Loss: 0.006, Acc: 97.81%
Epoch: 20, time: 0.107s, Loss: 0.005, Acc: 97.47%
Epoch: 21

### GPU Version

In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pytorch_test_x, pytorch_test_y = pytorch_test_x.to(device), pytorch_test_y.to(device)
model = Model().to(device)
optimizer = optim.Adam(model.parameters())
loss_fun = nn.CrossEntropyLoss()
loss_fun.to(device)

CrossEntropyLoss()

In [10]:
%%time
# 训练
for i in range(100):
    loss_record = 0
    start_time = time.time()
    
    for batch, (x, y) in enumerate(dataset):
        optimizer.zero_grad()

        x, y = x.to(device), y.to(device)
        y_pre = model(x)
        
        loss = loss_fun(y_pre, y)
        loss.backward()
        optimizer.step()
        
        loss_record += loss.cpu().item()
    
    duration = time.time() - start_time
    
    predict = model.predict(pytorch_test_x).cpu().numpy()
    acc_score = accuracy_score(predict, pytorch_test_y.cpu().numpy())
    
    print('Epoch: {}, time: {:.3f}s, Loss: {:.3f}, Acc: {:.2f}%'.format(i+1, duration, loss_record/batch, acc_score*100))

Epoch: 1, time: 0.134s, Loss: 1.762, Acc: 79.63%
Epoch: 2, time: 0.069s, Loss: 0.531, Acc: 89.90%
Epoch: 3, time: 0.067s, Loss: 0.287, Acc: 92.59%
Epoch: 4, time: 0.070s, Loss: 0.132, Acc: 95.79%
Epoch: 5, time: 0.072s, Loss: 0.076, Acc: 95.12%
Epoch: 6, time: 0.068s, Loss: 0.066, Acc: 96.80%
Epoch: 7, time: 0.070s, Loss: 0.078, Acc: 94.61%
Epoch: 8, time: 0.068s, Loss: 0.073, Acc: 94.78%
Epoch: 9, time: 0.068s, Loss: 0.083, Acc: 93.43%
Epoch: 10, time: 0.070s, Loss: 0.051, Acc: 96.30%
Epoch: 11, time: 0.081s, Loss: 0.015, Acc: 96.97%
Epoch: 12, time: 0.071s, Loss: 0.019, Acc: 96.63%
Epoch: 13, time: 0.070s, Loss: 0.021, Acc: 93.94%
Epoch: 14, time: 0.079s, Loss: 0.029, Acc: 96.97%
Epoch: 15, time: 0.090s, Loss: 0.016, Acc: 97.14%
Epoch: 16, time: 0.073s, Loss: 0.017, Acc: 96.46%
Epoch: 17, time: 0.077s, Loss: 0.011, Acc: 96.13%
Epoch: 18, time: 0.096s, Loss: 0.012, Acc: 96.80%
Epoch: 19, time: 0.096s, Loss: 0.023, Acc: 95.79%
Epoch: 20, time: 0.091s, Loss: 0.012, Acc: 96.63%
Epoch: 21

# JAX

In [222]:
import jax.numpy as jnp
from jax import jit, vmap, grad
import jax.nn as jnn
import jax

In [229]:
jax_train_x = jnp.array(train_x)
jax_train_y = jnn.one_hot(jnp.array(train_y), 10)
jax_test_x = jnp.array(test_x)
jax_test_y = jnn.one_hot(jnp.array(test_y), 10)

In [230]:
parmas = []
layer_dim = [(64, 128), (128, 256), (256,512), (512, 256), (256,128), (128,64), (64,10)]

key = jax.random.PRNGKey(0)
for n_dim in layer_dim:
    w = jax.random.normal(key, n_dim)
    b = jax.random.normal(key, (n_dim[1],))
    parmas.append((w, b))

In [231]:
def fc(inputs, w, b):
    return jnp.dot(inputs, w) + b

def relu(inputs):
    return  jnp.maximum(0, inputs)

def softmax(inputs):
    return jnn.softmax(inputs)

In [232]:
def forward(inputs, parmas):
    for w, b in parmas[:-1]:
        outputs = fc(inputs, w, b)
        inputs = relu(outputs)

    outputs = fc(inputs, *parmas[-1])
    print(outputs)
    return outputs - jax.scipy.special.logsumexp(outputs)

def loss_fun(inputs, parmas, y_true):
    predict = forward(inputs, parmas)
    loss = -jnp.mean(predict * y_true)
    print(loss)
    return loss

def train(inputs, y_true, parmas, lr=0.01):
    grad_data = grad(loss_fun, (1,))(inputs, parmas, y_true)
    new_parmas = [(w - lr * dw, b - lr * db) for (dw, db), (w, b) in zip(grad_data[0], parmas)]
    # print(new_parmas)
    return new_parmas


In [233]:
for i in range(2):
    for x, y in zip(jax_train_x, jax_train_y):
        parmas = train(x, y, parmas)

Traced<ConcreteArray(5916939.0)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(5916939., dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>
Traced<ConcreteArray(8.7778366e+29)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(8.7778366e+29, dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>
Traced<ConcreteArray(nan)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(nan, dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>
Traced<ConcreteArray(nan)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(nan, dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>
Traced<ConcreteArray(nan)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(nan, dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>
Traced<ConcreteArray(nan)>with<JVPTrace(level=2/0)>
  with primal = DeviceArray(nan, dtype=float32)
       tangent = Tr

KeyboardInterrupt: ignored

In [234]:
loss_fun(jax_train_x[6], parmas, jax_train_y[6])

nan


DeviceArray(nan, dtype=float32)

In [235]:
forward(jax_train_x[6], parmas)

DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], dtype=float32)

In [218]:
jax_train_y[6]

DeviceArray([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.], dtype=float32)

In [118]:
y_pre.shape

(1203, 10)

In [119]:
jax_train_y.shape

(1203, 10)

DeviceArray([nan, nan, nan, ..., nan, nan, nan], dtype=float32)

In [107]:
parmas[0][0]

DeviceArray([[ 1.0901477 , -0.2822716 ,  0.12140159, ..., -2.0507786 ,
              -0.7466879 , -0.6604846 ],
             [ 0.16098852,  0.9195488 ,  0.04407377, ...,  0.6298926 ,
               0.11477478, -1.2528692 ],
             [ 1.1781842 , -2.8546088 ,  0.86888975, ...,  2.0047903 ,
               0.11670755,  0.7683624 ],
             ...,
             [ 1.0974492 ,  1.2009512 , -1.4820064 , ...,  0.29825997,
              -0.5716417 , -1.0285538 ],
             [ 0.45801577, -0.33522558,  0.42189184, ...,  1.8499086 ,
              -0.03835844,  0.79259753],
             [-0.24299109,  0.10712541,  2.2701766 , ..., -1.0625709 ,
              -0.14538413,  0.99837404]], dtype=float32)

In [141]:
jnp.maximum(0, parmas[0][1])

DeviceArray([0.        , 0.        , 0.        , 0.5978571 , 0.9159883 ,
             0.        , 0.00336711, 0.5141878 , 0.        , 0.        ,
             0.        , 1.5653223 , 0.        , 0.        , 0.6160338 ,
             0.        , 0.17686242, 0.        , 0.        , 0.        ,
             0.7195503 , 0.        , 0.        , 0.        , 0.14105996,
             1.3286992 , 0.49153855, 0.        , 1.1188195 , 0.24837267,
             0.57693535, 0.        , 0.        , 1.7324136 , 0.66268396,
             0.        , 0.7361932 , 0.        , 0.        , 1.1780932 ,
             0.        , 0.03500843, 0.        , 2.4009972 , 0.25273788,
             0.97014016, 0.        , 0.768804  , 0.        , 0.38931087,
             0.7114831 , 0.        , 0.305861  , 0.        , 0.        ,
             0.65044457, 0.        , 0.59613913, 1.8058568 , 0.        ,
             0.        , 0.        , 0.8995249 , 0.        , 0.15347093,
             0.53636414, 0.04819222, 0.23581387, 0.

In [143]:
relu(parmas[0][1])

DeviceArray([0.        , 0.        , 0.        , 0.5978571 , 0.9159883 ,
             0.        , 0.00336711, 0.5141878 , 0.        , 0.        ,
             0.        , 1.5653223 , 0.        , 0.        , 0.6160338 ,
             0.        , 0.17686242, 0.        , 0.        , 0.        ,
             0.7195503 , 0.        , 0.        , 0.        , 0.14105996,
             1.3286992 , 0.49153855, 0.        , 1.1188195 , 0.24837267,
             0.57693535, 0.        , 0.        , 1.7324136 , 0.66268396,
             0.        , 0.7361932 , 0.        , 0.        , 1.1780932 ,
             0.        , 0.03500843, 0.        , 2.4009972 , 0.25273788,
             0.97014016, 0.        , 0.768804  , 0.        , 0.38931087,
             0.7114831 , 0.        , 0.305861  , 0.        , 0.        ,
             0.65044457, 0.        , 0.59613913, 1.8058568 , 0.        ,
             0.        , 0.        , 0.8995249 , 0.        , 0.15347093,
             0.53636414, 0.04819222, 0.23581387, 0.

In [149]:
jnp.sum(softmax(parmas[0][1]))

DeviceArray(1., dtype=float32)

In [148]:
jax.scipy.special.logsumexp(parmas[0][1])

DeviceArray(5.283266, dtype=float32)