# Simulated Data

We study the performance of Hamiltonian Monte Carlo and Stochastic Gradient Hamiltonian Monte Carlo for classification tasks. 



# Binary Classification

First, a 2-dimensional binary classification problem is simulated. The data are generated from two well-separated clusters, so a linear classification model is well suited. The features are standardized and the generated data is split into train and test datasets.

In [1]:
from sklearn.datasets import make_classification
from sklearn.datasets import make_blobs
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

D=2
centers = [[-5, 0],  [5, -1]]
X, y = make_blobs(n_samples=1000, centers=centers, cluster_std=1,random_state=40)
X = (X - X.mean(axis=0)) / X.std(axis=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

A logistic regression model is fitted to the 2-dimensional features with $L_2$ regularization, $\alpha=0.25$. The weights $\mathbf w$ and the bias $b$ are initialized uniformly at random in $[0, 2)$. The parameters are estimated with stochastic gradient descent using $10^4$ epochs, a learning rate $\epsilon=10^{-5}$, a batch size of $50$ and `gamma=0.9`.
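To make the optimization step concrete, the following is a minimal NumPy sketch of mini-batch gradient descent for an $L_2$-penalized logistic regression. It is not the `hamiltonian` package's implementation: the penalty form $(\alpha/2)\lVert\mathbf w\rVert^2$, the reading of `gamma` as a momentum coefficient, the zero initialization, the synthetic demo data and all function names are assumptions made for illustration. The demo also uses a larger step size than the notebook so that it converges in a few epochs.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def grad_penalized_nll(w, b, X, y, alpha):
    """Gradient of mean cross-entropy + (alpha / 2) * ||w||^2 (assumed penalty form)."""
    err = sigmoid(X @ w + b) - y              # residuals p - y, shape (n,)
    grad_w = X.T @ err / len(y) + alpha * w
    grad_b = err.mean()
    return grad_w, grad_b

def sgd_momentum(X, y, alpha=0.25, eta=0.1, gamma=0.9,
                 epochs=50, batch_size=50, seed=0):
    """Mini-batch SGD with a classical momentum term (gamma)."""
    rng = np.random.default_rng(seed)
    w, b = np.zeros(X.shape[1]), 0.0
    v_w, v_b = np.zeros_like(w), 0.0          # velocity terms
    for _ in range(epochs):
        order = rng.permutation(len(y))       # reshuffle the data every epoch
        for start in range(0, len(y), batch_size):
            batch = order[start:start + batch_size]
            g_w, g_b = grad_penalized_nll(w, b, X[batch], y[batch], alpha)
            v_w = gamma * v_w - eta * g_w
            v_b = gamma * v_b - eta * g_b
            w = w + v_w
            b = b + v_b
    return w, b

# Hypothetical stand-in for the two simulated clusters
rng = np.random.default_rng(1)
X_demo = np.vstack([rng.normal([-5, 0], 1.0, size=(200, 2)),
                    rng.normal([5, -1], 1.0, size=(200, 2))])
y_demo = np.repeat([0, 1], 200)
X_demo = (X_demo - X_demo.mean(axis=0)) / X_demo.std(axis=0)

w_hat, b_hat = sgd_momentum(X_demo, y_demo)
accuracy = np.mean((sigmoid(X_demo @ w_hat + b_hat) > 0.5) == y_demo)
print(f"training accuracy: {accuracy:.3f}")
```

Because the clusters are far apart, even a heavily regularized linear model should separate them almost perfectly.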

In [2]:
try:
    from google.colab import drive
    %tensorflow_version 2.x
    COLAB = True
    print("Note: using Google CoLab")
except ImportError:
    print("Note: not using Google CoLab")
    COLAB = False

Note: not using Google CoLab


In [3]:
if COLAB:
    !git clone https://github.com/sherna90/hamiltonian_montecarlo/
    !pip install cupy

else:
    import sys
    sys.path.append("../") 

In [4]:
import hamiltonian.models.cpu.logistic as base_model_cpu
import hamiltonian.inference.cpu.sgd as inference_cpu
import numpy as np

D=X_train.shape[1]
N=X_train.shape[0]

epochs = int(1e4)
eta=1e-5
batch_size=50
alpha=1/4.

start_p={'weights':2*np.random.random((D,1)),'bias':2*np.random.random(1)}
hyper_p={'alpha':alpha}

model_cpu=base_model_cpu.logistic(hyper_p)
optim_cpu=inference_cpu.sgd(model_cpu,start_p,step_size=eta)
par_cpu,loss=optim_cpu.fit(epochs=epochs,batch_size=batch_size,gamma=0.9,
                   X_train=X_train,y_train=y_train,verbose=True)
y_pred=model_cpu.predict(par_cpu,X_test,batchsize=batch_size)

print(classification_report(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))

loss: 0.8813
loss: 0.1525
loss: 0.1526
loss: 0.1527
loss: 0.1527
loss: 0.1527
loss: 0.1527
loss: 0.1527
loss: 0.1527
loss: 0.1527

100%|██████████| 10000/10000 [00:07<00:00, 1300.21it/s]

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       116
           1       1.00      1.00      1.00       134

    accuracy                           1.00       250
   macro avg       1.00      1.00      1.00       250
weighted avg       1.00      1.00      1.00       250

[[116   0]
 [  0 134]]
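As a quick check on the report above, the per-class metrics can be recomputed directly from the confusion matrix. This is a small sketch that relies only on the scikit-learn convention that rows are true labels and columns are predicted labels; the variable names are illustrative.

```python
import numpy as np

# Confusion matrix printed above: rows = true class, columns = predicted class
cm = np.array([[116, 0],
               [0, 134]])

true_pos = np.diag(cm)
precision = true_pos / cm.sum(axis=0)    # correct / predicted, per class
recall = true_pos / cm.sum(axis=1)       # correct / actual, per class
f1 = 2 * precision * recall / (precision + recall)
accuracy = true_pos.sum() / cm.sum()

print(precision, recall, f1, accuracy)
```

With no off-diagonal entries, every metric equals 1, which matches the classification report.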





In [7]:
import cupy as cp

ModuleNotFoundError: No module named 'cupy'

In [5]:
import hamiltonian.models.gpu.logistic as base_model_gpu
import hamiltonian.inference.gpu.sgd as inference_gpu

model_gpu=base_model_gpu.logistic(hyper_p)
optim_gpu=inference_gpu.sgd(model_gpu,start_p,step_size=eta)
par_gpu,loss=optim_gpu.fit(epochs=epochs,batch_size=batch_size,gamma=0.9,
                   X_train=X_train,y_train=y_train,verbose=True)
y_pred=model_gpu.predict(par_gpu,X_test,batchsize=batch_size)

print(classification_report(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))

ModuleNotFoundError: No module named 'cupy'
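Both GPU cells fail because CuPy is not installed in this environment. One common way for a notebook to guard such cells is to try the import and fall back to NumPy. The sketch below shows that pattern only; the names `xp` and `HAS_GPU` are hypothetical, and it would not by itself make the `hamiltonian.*.gpu` modules importable, since those import CuPy on their own.

```python
try:
    import cupy as xp        # GPU arrays, available only if CuPy is installed
    HAS_GPU = True
except ImportError:          # ModuleNotFoundError is a subclass of ImportError
    import numpy as xp       # CPU fallback with a largely compatible array API
    HAS_GPU = False

print("array backend:", xp.__name__)
```

A flag like `HAS_GPU` could then be used to skip the GPU cells instead of letting them raise.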

# Hamiltonian Monte Carlo

Next, we estimate the posterior distribution of the logistic regression parameters. We run $100$ burn-in iterations and then draw $10^3$ samples with the Hamiltonian Monte Carlo algorithm.
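For readers unfamiliar with the algorithm, here is a minimal sketch of Hamiltonian Monte Carlo with a leapfrog integrator and a Metropolis correction. It is not the `hamiltonian` package's sampler: it uses an identity mass matrix, no step-size adaptation, a toy Gaussian target, and it assumes the number of leapfrog steps is `path_length / step_size`. All names are illustrative.

```python
import numpy as np

def hmc(log_prob_and_grad, q0, n_samples, step_size=0.1, path_length=1.0, seed=0):
    """Basic HMC; returns the chain and the fraction of accepted proposals."""
    rng = np.random.default_rng(seed)
    q = np.asarray(q0, dtype=float)
    n_steps = max(1, int(round(path_length / step_size)))
    logp, grad = log_prob_and_grad(q)
    samples = np.empty((n_samples, q.size))
    accepted = 0
    for i in range(n_samples):
        p0 = rng.standard_normal(q.size)          # resample the momentum
        q_new, p = q.copy(), p0.copy()
        logp_new, grad_new = logp, grad
        # Leapfrog: half momentum step, alternating full steps, half momentum step
        p = p + 0.5 * step_size * grad_new
        for step in range(n_steps):
            q_new = q_new + step_size * p
            logp_new, grad_new = log_prob_and_grad(q_new)
            if step < n_steps - 1:
                p = p + step_size * grad_new
        p = p + 0.5 * step_size * grad_new
        # Hamiltonian = potential (-log density) + kinetic energy
        h_old = -logp + 0.5 * p0 @ p0
        h_new = -logp_new + 0.5 * p @ p
        if np.log(rng.uniform()) < h_old - h_new:  # Metropolis accept/reject
            q, logp, grad = q_new, logp_new, grad_new
            accepted += 1
        samples[i] = q
    return samples, accepted / n_samples

# Toy target: independent Gaussian with known mean and scale
mu = np.array([1.0, -2.0])
sigma = np.array([1.0, 0.5])

def gaussian_log_prob_and_grad(q):
    z = (q - mu) / sigma
    return -0.5 * np.sum(z ** 2), -z / sigma

draws, accept_rate = hmc(gaussian_log_prob_and_grad, np.zeros(2), n_samples=2000)
print("acceptance rate:", accept_rate)
print("posterior mean estimate:", draws[200:].mean(axis=0))  # drop 200 burn-in draws
```

The burn-in draws are discarded before summarizing, mirroring the burn-in phase used in the notebook.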

In [None]:
import hamiltonian.inference.cpu.hmc as sampler_cpu

burnin=int(1e2)
n_samples=int(1e3)

# CPU sampler, reusing the CPU model and the initial parameters defined above
hmc_cpu=sampler_cpu.hmc(model_cpu,start_p,path_length=1,step_size=eta)
samples,loss,positions,momentums=hmc_cpu.sample(n_samples,burnin,None,
                                            X_train=X_train,y_train=y_train)




Burn-in (100 iterations):

loss: 0.1615 (reported at each of the 10 progress updates)

100%|██████████| 100/100 [19:32<00:00, 14.73s/it]

adapted step size :  0.07905574957317466

Sampling (the captured output ends at iteration 911 of 1000):

loss: 0.1615 (first 8 progress updates)
loss: 0.1619 (last 2 progress updates shown)

 91%|█████████ | 911/1000 [2:49:51<23:29, 15.83s/it]

In [None]:
par_mean={var:np.mean(samples[var],axis=0).reshape(start_p[var].shape) for var in samples.keys()}
par_var={var:np.var(samples[var],axis=0).reshape(start_p[var].shape) for var in samples.keys()}
y_pred=model_cpu.predict(par_mean,X_test,batchsize=batch_size)

print(classification_report(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))
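The cell above plugs the posterior mean of the parameters into the model. An alternative is to average the predicted probabilities over the individual posterior draws, which also gives a per-point spread. The sketch below illustrates this with synthetic draws; the array shapes `(S, D)` for weights and `(S,)` for biases, the function name and the demo values are assumptions, and the layout of `samples` returned by the `hamiltonian` package may differ.

```python
import numpy as np

def posterior_predictive(X, w_samples, b_samples):
    """Average class-1 probability over S posterior draws.

    X: (N, D) features, w_samples: (S, D) weights, b_samples: (S,) biases.
    """
    logits = X @ w_samples.T + b_samples       # (N, S), one column per draw
    probs = 1.0 / (1.0 + np.exp(-logits))
    return probs.mean(axis=1), probs.std(axis=1)

# Hypothetical posterior draws concentrated around w = (3, 0), b = 0
rng = np.random.default_rng(0)
w_draws = rng.normal([3.0, 0.0], 0.1, size=(500, 2))
b_draws = rng.normal(0.0, 0.1, size=500)

X_new = np.array([[-1.0, 0.0], [0.0, 0.0], [1.0, 0.0]])
p_mean, p_std = posterior_predictive(X_new, w_draws, b_draws)
labels = (p_mean > 0.5).astype(int)
print(p_mean, p_std, labels)
```

In this toy setup the point on the decision boundary has the largest predictive spread, while points far from it are classified with little variation across draws.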




In [None]:
import matplotlib.pyplot as plt

def plot_uncertainty_hyperplane(par_mean,par_var, color):
    # Decision boundary w0*x0 + w1*x1 + b = 0, solved for x1
    bd = lambda x0,par :  (-(x0 * par['weights'][0]) - par['bias']) / par['weights'][1]
    # xmin and xmax are globals taken from the scatter plot's axis limits below
    r=np.linspace(xmin,xmax)
    # Parameters shifted by one posterior standard deviation in each direction
    par_m={var:par_mean[var]-np.sqrt(par_var[var]) for var in par_mean.keys()}
    par_p={var:par_mean[var]+np.sqrt(par_var[var]) for var in par_mean.keys()}
    plt.plot(r,bd(r,par_mean),ls="-", color=color)
    plt.plot(r,bd(r,par_m),ls="--", color=color, alpha=0.5)
    plt.plot(r,bd(r,par_p),ls="--", color=color, alpha=0.5)
    #plt.fill_between(r, bd(r,par_mean),  bd(r,par_p), color=color, alpha=0.5)
    #plt.fill_between(r, bd(r,par_m),  bd(r,par_mean), color=color, alpha=0.5)
    
plt.figure()
plt.scatter(X[:, 0], X[:, 1],marker='o', c=y,s=25, edgecolor='k')
plt.axis('tight')
xmin, xmax = plt.xlim()
ymin, ymax = plt.ylim()
plot_uncertainty_hyperplane(par_mean, par_var,"b")
plt.show()
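The boundary function `bd` in the plotting cell follows from setting the logit of the model to zero. As a short derivation, writing the two weights as $w_0, w_1$ and the bias as $b$, and assuming $w_1 \neq 0$:

$$
w_0 x_0 + w_1 x_1 + b = 0 \quad\Longrightarrow\quad x_1 = -\frac{w_0 x_0 + b}{w_1}
$$

The dashed lines evaluate the same expression with every parameter shifted by one posterior standard deviation, so they are probably best read as a rough visual indication of spread rather than as a credible band for the boundary.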