In [2]:
import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data
from tqdm import tqdm

In [3]:
# ---- Run configuration ----
n_epochs = 10000           # number of training iterations (one mini-batch each)
batch_size = 128           # samples per training mini-batch
lr = 0.001                 # Adam: learning rate
b1 = 0.9                   # Adam: decay rate of the first-moment estimate
b2 = 0.999                 # Adam: decay rate of the second-moment estimate
img_size = 28              # height and width of each image
channels = 1               # number of image channels
sample_interval = 400      # interval between image samples
alpha = 10                 # weight of the MSE term in the generator loss
Dim = img_size * img_size  # length of one flattened image (28*28 = 784)
p_miss = 0.5               # probability that a mask entry is 0 (pixel treated as missing)
p_hint = 0.9               # probability that a hint entry is drawn as 1
img_shape = (channels, img_size, img_size)

# Device-dependent tensor type: the CUDA float tensor when a GPU is visible.
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [5]:
# Configure the data loader: read the MNIST archives from a directory two
# levels above the notebook, with labels requested in one-hot form.
mnist_dir = '../../MNIST_data'
mnist = input_data.read_data_sets(mnist_dir, one_hot=True)


Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz


In [6]:
def xavier_init(size):
    """Create a trainable weight tensor with fan-in-scaled Gaussian entries.

    Args:
        size: sequence of dimensions; ``size[0]`` is taken as the fan-in.

    Returns:
        A ``Variable`` of shape ``size`` with ``requires_grad=True`` whose
        entries are standard-normal draws multiplied by
        ``1 / sqrt(size[0] / 2)``.
    """
    fan_in = size[0]
    stddev = 1. / np.sqrt(fan_in / 2.)
    weights = torch.randn(*size) * stddev
    return Variable(weights, requires_grad=True)

In [7]:
""" ==================== DISCRIMINATOR ======================== """
# Discriminator weights, created from the input layer to the output layer so
# the random draws happen in layer order.
D_W1 = xavier_init(size=[2 * Dim, 256])   # input is data and hint side by side
D_W2 = xavier_init(size=[256, 128])
D_W3 = xavier_init(size=[128, Dim])       # one output per data feature

# Discriminator biases: trainable, initialised to zero.
D_b1 = Variable(torch.zeros(256), requires_grad=True)
D_b2 = Variable(torch.zeros(128), requires_grad=True)
D_b3 = Variable(torch.zeros(Dim), requires_grad=True)

def discriminator(x, m, g, h):
    """Score every feature of a partially imputed batch.

    Args:
        x: data batch; its values are used wherever ``m`` is 1.
        m: mask with the same shape as ``x``; selects ``x`` (1) or ``g`` (0).
        g: generator output; its values are used wherever ``m`` is 0.
        h: hint tensor, concatenated to the combined data along dim 1.

    Returns:
        Sigmoid outputs in [0, 1], one column per column of ``D_W3``.
    """
    # Take x where the mask is 1 and the imputed value everywhere else.
    combined = m * x + (1 - m) * g
    # The network sees the combined data and the hint side by side.
    net_input = torch.cat([combined, h], 1)

    hidden1 = nn.relu(net_input @ D_W1 + D_b1)
    hidden2 = nn.relu(hidden1 @ D_W2 + D_b2)
    logits = hidden2 @ D_W3 + D_b3

    return nn.sigmoid(logits)

In [8]:
""" ==================== GENERATOR ======================== """
# Generator weights, created from the input layer to the output layer so the
# random draws happen in layer order.
G_W1 = xavier_init(size=[2 * Dim, 256])   # input is data (noise in the missing slots) and mask
G_W2 = xavier_init(size=[256, 128])
G_W3 = xavier_init(size=[128, Dim])       # one output per data feature

# Generator biases: trainable, initialised to zero.
G_b1 = Variable(torch.zeros(256), requires_grad=True)
G_b2 = Variable(torch.zeros(128), requires_grad=True)
G_b3 = Variable(torch.zeros(Dim), requires_grad=True)

def generator(x, z, m):
    """Produce a value in [0, 1] for every feature of a masked batch.

    Args:
        x: data batch; its values are used wherever ``m`` is 1.
        z: noise with the same shape as ``x``; used wherever ``m`` is 0.
        m: mask; also fed to the network, concatenated along dim 1.

    Returns:
        Sigmoid outputs in [0, 1], one column per column of ``G_W3``.
    """
    # Keep x where the mask is 1 and put noise everywhere else.
    noised = m * x + (1 - m) * z
    # The network sees the noised data and the mask side by side.
    net_input = torch.cat([noised, m], 1)

    hidden1 = nn.relu(net_input @ G_W1 + G_b1)
    hidden2 = nn.relu(hidden1 @ G_W2 + G_b2)
    logits = hidden2 @ G_W3 + G_b3

    return nn.sigmoid(logits)

In [9]:
# Trainable tensors of each network, weights first and then biases.
G_params = [
    G_W1, G_W2, G_W3,
    G_b1, G_b2, G_b3,
]
D_params = [
    D_W1, D_W2, D_W3,
    D_b1, D_b2, D_b3,
]
# Every trainable tensor of both networks, generator entries first.
params = [*G_params, *D_params]

In [10]:
""" ===================== TRAINING ======================== """
def plot(samples):
    """Draw a batch of flattened 28x28 images on a 5x5 grid.

    Args:
        samples: iterable of arrays, each reshapeable to (28, 28). The grid
            has 25 cells, one per sample in iteration order.

    Returns:
        The matplotlib figure holding the grid; the caller is responsible
        for saving and closing it.
    """
    figure = plt.figure(figsize=(5, 5))
    grid = gridspec.GridSpec(5, 5)
    grid.update(wspace=0.05, hspace=0.05)

    for cell, image in enumerate(samples):
        axes = plt.subplot(grid[cell])
        # Hide axis decorations so only the image is visible.
        plt.axis('off')
        axes.set_xticklabels([])
        axes.set_yticklabels([])
        axes.set_aspect('equal')
        plt.imshow(image.reshape(28, 28), cmap='Greys_r')

    return figure
def reset_grad():
    """Replace the gradient of every tensor in ``params`` with zeros.

    Tensors whose gradient is still ``None`` are left untouched. The new
    gradient is a fresh zero tensor shaped like the old one, not an
    in-place zeroing of the old one.
    """
    for tensor in params:
        if tensor.grad is None:
            continue
        old = tensor.grad.data
        tensor.grad = Variable(old.new().resize_as_(old).zero_())
            
# One Adam optimiser per network, so a step on one never moves the other's
# parameters. The learning rate and moment decay rates are read from the
# configuration constants (lr, b1, b2) instead of being repeated as literals,
# so changing the configuration actually changes training.
G_solver = optim.Adam(G_params, lr=lr, betas=(b1, b2))
D_solver = optim.Adam(D_params, lr=lr, betas=(b1, b2))

def sample_Z(m, n):
    """Draw an (m, n) array of noise, uniform on [0, 1).

    Args:
        m: number of rows.
        n: number of columns.

    Returns:
        A float64 NumPy array of shape (m, n).
    """
    shape = [m, n]
    return np.random.uniform(0., 1., size=shape)

# Random 0/1 matrices, used for both the mask and the hint.
def sample_M(m, n, p):
    """Draw an (m, n) array of 0.0/1.0 entries.

    Each entry is 1.0 when an independent uniform [0, 1) draw exceeds ``p``
    and 0.0 otherwise, so ``p`` is the chance of a 0.

    Args:
        m: number of rows.
        n: number of columns.
        p: threshold compared against the uniform draws.

    Returns:
        A float64 NumPy array of shape (m, n) containing only 0.0 and 1.0.
    """
    draws = np.random.uniform(0., 1., size=[m, n])
    return 1. * (draws > p)
# Create the folder that receives the periodic sample grids. exist_ok avoids
# the check-then-create race; the folder name is kept exactly as before so
# earlier outputs stay where they are.
os.makedirs('MNIST_Impuation_output/', exist_ok=True)


def _to_variable(array):
    """Convert a NumPy array into a float32 torch Variable."""
    return Variable(torch.from_numpy(array.astype('float32')))


def _impute(x, z, m):
    """Return x with the entries where m is 0 replaced by generator output."""
    return m * x + (1 - m) * generator(x, z, m)


i = 1  # running index used to name the saved sample grids
for it in tqdm(range(n_epochs)):

    # ---- Draw a mini-batch with its noise, mask and hint ----
    X_mb, _ = mnist.train.next_batch(batch_size)
    X_mb = _to_variable(X_mb)
    Z_mb = _to_variable(sample_Z(batch_size, Dim))
    M_mb = _to_variable(sample_M(batch_size, Dim, p_miss))
    H_mb1 = _to_variable(sample_M(batch_size, Dim, 1-p_hint))
    H_mb = M_mb * H_mb1  # a hint entry can only be 1 where the mask is 1

    # ---- Discriminator: forward, loss, backward, update ----
    # The generator output is detached because only D_params are stepped
    # here; back-propagating into the generator would be discarded work.
    G_sample = generator(X_mb, Z_mb, M_mb)
    D_sample = discriminator(X_mb, M_mb, G_sample.detach(), H_mb)
    D_loss = -torch.mean(M_mb * torch.log(D_sample + 1e-8) + (1-M_mb) * torch.log(1. - D_sample + 1e-8)) * 2
    D_loss.backward()
    D_solver.step()
    reset_grad()

    # ---- Generator: forward, loss, backward, update ----
    G_sample = generator(X_mb, Z_mb, M_mb)
    D_sample = discriminator(X_mb, M_mb, G_sample, H_mb)
    # Adversarial term, evaluated on the entries where the mask is 0.
    G_loss1 = -torch.mean((1-M_mb) * torch.log(D_sample + 1e-8)) / torch.mean(1-M_mb)
    # Reconstruction error on the entries where the mask is 1.
    MSE_train_loss = torch.mean((M_mb * X_mb - M_mb * G_sample)**2) / torch.mean(M_mb)
    # Error on the entries where the mask is 0. It is only reported, never
    # optimised, so no autograd graph is needed for it.
    with torch.no_grad():
        MSE_test_loss = torch.mean(((1-M_mb) * X_mb - (1-M_mb)*G_sample)**2) / torch.mean(1-M_mb)
    G_loss = G_loss1 + alpha * MSE_train_loss
    G_loss.backward()
    G_solver.step()
    reset_grad()

    # ---- Every 100 iterations: report losses and save a sample grid ----
    if it % 100 == 0:
        print('Iter: {}'.format(it))
        print('Train_loss: {:.4}'.format(MSE_train_loss.item()))
        print('Test_loss: {:.4}'.format(MSE_test_loss.item()))
        print()

        # These forward passes are for display only, so gradients are off.
        # Separate names keep the training batch above untouched.
        with torch.no_grad():
            X_vis, _ = mnist.train.next_batch(5)
            X_vis = _to_variable(X_vis)
            Z_vis = _to_variable(sample_Z(5, Dim))
            M_vis = _to_variable(sample_M(5, Dim, p_miss))

            # Row 1: what the generator receives (kept pixels plus noise).
            rows = [M_vis * X_vis + (1 - M_vis) * Z_vis]
            # Rows 2-4: three imputations of the same masked images, each
            # from a different noise draw (the first reuses the row-1 noise).
            rows.append(_impute(X_vis, Z_vis, M_vis))
            for _draw in range(2):
                Z_vis = _to_variable(sample_Z(5, Dim))
                rows.append(_impute(X_vis, Z_vis, M_vis))
            # Row 5: the original images.
            rows.append(X_vis)

            samples = torch.cat(rows, 0).numpy()

        fig = plot(samples)
        plt.savefig('MNIST_Impuation_output/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        # Close the figure so figures do not pile up in memory across iterations.
        plt.close(fig)



  0%|          | 0/10000 [00:00<?, ?it/s]

Iter: 0
Train_loss: 0.265
Test_loss: 0.2641



  1%|          | 99/10000 [00:06<11:10, 14.76it/s] 

Iter: 100
Train_loss: 0.06519
Test_loss: 0.06615



  2%|▏         | 199/10000 [00:13<10:44, 15.20it/s]

Iter: 200
Train_loss: 0.05368
Test_loss: 0.05669



  3%|▎         | 300/10000 [00:20<11:11, 14.44it/s]

Iter: 300
Train_loss: 0.04393
Test_loss: 0.04715



  4%|▍         | 400/10000 [00:28<11:16, 14.19it/s]

Iter: 400
Train_loss: 0.03746
Test_loss: 0.04286



  5%|▍         | 499/10000 [00:34<11:06, 14.26it/s]

Iter: 500
Train_loss: 0.03485
Test_loss: 0.04118



  6%|▌         | 599/10000 [00:41<10:57, 14.29it/s]

Iter: 600
Train_loss: 0.03169
Test_loss: 0.03849



  7%|▋         | 700/10000 [00:50<11:13, 13.81it/s]

Iter: 700
Train_loss: 0.03208
Test_loss: 0.03863



  8%|▊         | 800/10000 [01:06<12:42, 12.07it/s]

Iter: 800
Train_loss: 0.02764
Test_loss: 0.03401



  9%|▉         | 900/10000 [01:18<13:11, 11.49it/s]

Iter: 900
Train_loss: 0.02669
Test_loss: 0.03313



 10%|▉         | 999/10000 [01:29<13:24, 11.19it/s]

Iter: 1000
Train_loss: 0.02697
Test_loss: 0.03433



 11%|█         | 1100/10000 [01:40<13:35, 10.92it/s]

Iter: 1100
Train_loss: 0.02735
Test_loss: 0.03483



 12%|█▏        | 1200/10000 [01:51<13:38, 10.75it/s]

Iter: 1200
Train_loss: 0.02644
Test_loss: 0.0335



 13%|█▎        | 1300/10000 [02:02<13:40, 10.60it/s]

Iter: 1300
Train_loss: 0.02508
Test_loss: 0.03247



 14%|█▍        | 1400/10000 [02:13<13:42, 10.45it/s]

Iter: 1400
Train_loss: 0.02487
Test_loss: 0.03292



 15%|█▌        | 1500/10000 [02:25<13:42, 10.33it/s]

Iter: 1500
Train_loss: 0.02528
Test_loss: 0.03276



 16%|█▌        | 1600/10000 [02:36<13:41, 10.22it/s]

Iter: 1600
Train_loss: 0.02438
Test_loss: 0.03189



 17%|█▋        | 1700/10000 [02:47<13:38, 10.14it/s]

Iter: 1700
Train_loss: 0.0235
Test_loss: 0.03232



 18%|█▊        | 1800/10000 [02:59<13:37, 10.03it/s]

Iter: 1800
Train_loss: 0.02232
Test_loss: 0.0303



 19%|█▉        | 1900/10000 [03:10<13:34,  9.95it/s]

Iter: 1900
Train_loss: 0.02305
Test_loss: 0.03006



 20%|██        | 2000/10000 [03:22<13:30,  9.87it/s]

Iter: 2000
Train_loss: 0.02178
Test_loss: 0.03091



 21%|██        | 2100/10000 [03:34<13:27,  9.78it/s]

Iter: 2100
Train_loss: 0.02228
Test_loss: 0.03193



 22%|██▏       | 2200/10000 [03:46<13:23,  9.71it/s]

Iter: 2200
Train_loss: 0.021
Test_loss: 0.02857



 23%|██▎       | 2300/10000 [03:58<13:17,  9.65it/s]

Iter: 2300
Train_loss: 0.02144
Test_loss: 0.03055



 24%|██▍       | 2400/10000 [04:10<13:12,  9.59it/s]

Iter: 2400
Train_loss: 0.01965
Test_loss: 0.02852



 25%|██▌       | 2500/10000 [04:22<13:07,  9.52it/s]

Iter: 2500
Train_loss: 0.02072
Test_loss: 0.02962



 26%|██▌       | 2600/10000 [04:34<13:01,  9.46it/s]

Iter: 2600
Train_loss: 0.02037
Test_loss: 0.0295



 27%|██▋       | 2699/10000 [04:46<12:56,  9.41it/s]

Iter: 2700
Train_loss: 0.02147
Test_loss: 0.03168



 28%|██▊       | 2800/10000 [04:58<12:47,  9.38it/s]

Iter: 2800
Train_loss: 0.02011
Test_loss: 0.03024



 29%|██▉       | 2900/10000 [05:09<12:38,  9.36it/s]

Iter: 2900
Train_loss: 0.01923
Test_loss: 0.02768



 30%|███       | 3000/10000 [05:21<12:29,  9.34it/s]

Iter: 3000
Train_loss: 0.01994
Test_loss: 0.0295



 31%|███       | 3099/10000 [05:31<12:18,  9.34it/s]

Iter: 3100
Train_loss: 0.01885
Test_loss: 0.02815



 32%|███▏      | 3199/10000 [05:42<12:07,  9.35it/s]

Iter: 3200
Train_loss: 0.01962
Test_loss: 0.02903



 33%|███▎      | 3300/10000 [05:52<11:56,  9.35it/s]

Iter: 3300
Train_loss: 0.01964
Test_loss: 0.02858



 34%|███▍      | 3399/10000 [06:03<11:45,  9.36it/s]

Iter: 3400
Train_loss: 0.01858
Test_loss: 0.02805



 35%|███▍      | 3499/10000 [06:13<11:34,  9.36it/s]

Iter: 3500
Train_loss: 0.01847
Test_loss: 0.02911



 36%|███▌      | 3600/10000 [06:24<11:23,  9.36it/s]

Iter: 3600
Train_loss: 0.01833
Test_loss: 0.02874



 37%|███▋      | 3700/10000 [06:36<11:15,  9.33it/s]

Iter: 3700
Train_loss: 0.01917
Test_loss: 0.02917



 38%|███▊      | 3800/10000 [06:48<11:06,  9.30it/s]

Iter: 3800
Train_loss: 0.01918
Test_loss: 0.02951



 39%|███▉      | 3900/10000 [07:01<10:59,  9.25it/s]

Iter: 3900
Train_loss: 0.01891
Test_loss: 0.02953



 40%|████      | 4000/10000 [07:14<10:51,  9.21it/s]

Iter: 4000
Train_loss: 0.01848
Test_loss: 0.02965



 41%|████      | 4100/10000 [07:26<10:41,  9.19it/s]

Iter: 4100
Train_loss: 0.01761
Test_loss: 0.0271



 42%|████▏     | 4200/10000 [07:37<10:32,  9.17it/s]

Iter: 4200
Train_loss: 0.01945
Test_loss: 0.02972



 43%|████▎     | 4300/10000 [07:49<10:22,  9.15it/s]

Iter: 4300
Train_loss: 0.01907
Test_loss: 0.02954



 44%|████▍     | 4400/10000 [08:01<10:13,  9.13it/s]

Iter: 4400
Train_loss: 0.01818
Test_loss: 0.02777



 45%|████▌     | 4500/10000 [08:13<10:03,  9.12it/s]

Iter: 4500
Train_loss: 0.0178
Test_loss: 0.02774



 46%|████▌     | 4600/10000 [08:25<09:53,  9.10it/s]

Iter: 4600
Train_loss: 0.0171
Test_loss: 0.02715



 47%|████▋     | 4700/10000 [08:37<09:43,  9.09it/s]

Iter: 4700
Train_loss: 0.01806
Test_loss: 0.02839



 48%|████▊     | 4800/10000 [08:49<09:33,  9.07it/s]

Iter: 4800
Train_loss: 0.01806
Test_loss: 0.02787



 49%|████▉     | 4900/10000 [09:02<09:24,  9.03it/s]

Iter: 4900
Train_loss: 0.01658
Test_loss: 0.02884



 50%|█████     | 5000/10000 [09:14<09:14,  9.02it/s]

Iter: 5000
Train_loss: 0.01743
Test_loss: 0.02763



 51%|█████     | 5100/10000 [09:26<09:04,  9.00it/s]

Iter: 5100
Train_loss: 0.01773
Test_loss: 0.02791



 52%|█████▏    | 5200/10000 [09:39<08:54,  8.97it/s]

Iter: 5200
Train_loss: 0.01928
Test_loss: 0.0308



 53%|█████▎    | 5300/10000 [09:51<08:44,  8.95it/s]

Iter: 5300
Train_loss: 0.01751
Test_loss: 0.02944



 54%|█████▍    | 5400/10000 [10:04<08:34,  8.93it/s]

Iter: 5400
Train_loss: 0.01627
Test_loss: 0.02586



 55%|█████▌    | 5500/10000 [10:15<08:23,  8.93it/s]

Iter: 5500
Train_loss: 0.01741
Test_loss: 0.02845



 56%|█████▌    | 5600/10000 [10:28<08:13,  8.92it/s]

Iter: 5600
Train_loss: 0.01693
Test_loss: 0.02761



 57%|█████▋    | 5700/10000 [10:39<08:02,  8.92it/s]

Iter: 5700
Train_loss: 0.01749
Test_loss: 0.02829



 58%|█████▊    | 5800/10000 [10:50<07:51,  8.92it/s]

Iter: 5800
Train_loss: 0.01726
Test_loss: 0.02676



 59%|█████▉    | 5900/10000 [11:01<07:39,  8.91it/s]

Iter: 5900
Train_loss: 0.01718
Test_loss: 0.02829



 60%|██████    | 6000/10000 [11:13<07:28,  8.91it/s]

Iter: 6000
Train_loss: 0.01725
Test_loss: 0.0277



 61%|██████    | 6100/10000 [11:26<07:19,  8.88it/s]

Iter: 6100
Train_loss: 0.01703
Test_loss: 0.02865



 62%|██████▏   | 6200/10000 [11:38<07:08,  8.87it/s]

Iter: 6200
Train_loss: 0.01736
Test_loss: 0.02867



 63%|██████▎   | 6300/10000 [11:50<06:57,  8.87it/s]

Iter: 6300
Train_loss: 0.01742
Test_loss: 0.02877



 64%|██████▍   | 6400/10000 [12:01<06:45,  8.87it/s]

Iter: 6400
Train_loss: 0.01696
Test_loss: 0.0273



 65%|██████▌   | 6500/10000 [12:13<06:34,  8.87it/s]

Iter: 6500
Train_loss: 0.01688
Test_loss: 0.02759



 66%|██████▌   | 6600/10000 [12:24<06:23,  8.87it/s]

Iter: 6600
Train_loss: 0.01687
Test_loss: 0.02723



 67%|██████▋   | 6700/10000 [12:35<06:12,  8.87it/s]

Iter: 6700
Train_loss: 0.01661
Test_loss: 0.02663



 68%|██████▊   | 6800/10000 [12:47<06:00,  8.86it/s]

Iter: 6800
Train_loss: 0.0167
Test_loss: 0.0276



 69%|██████▉   | 6900/10000 [12:58<05:49,  8.86it/s]

Iter: 6900
Train_loss: 0.01474
Test_loss: 0.02544



 70%|███████   | 7000/10000 [13:09<05:38,  8.86it/s]

Iter: 7000
Train_loss: 0.01555
Test_loss: 0.02605



 71%|███████   | 7100/10000 [13:21<05:27,  8.86it/s]

Iter: 7100
Train_loss: 0.01595
Test_loss: 0.02633



 72%|███████▏  | 7200/10000 [13:32<05:15,  8.86it/s]

Iter: 7200
Train_loss: 0.01619
Test_loss: 0.02584



 73%|███████▎  | 7300/10000 [13:43<05:04,  8.86it/s]

Iter: 7300
Train_loss: 0.01735
Test_loss: 0.02755



 74%|███████▍  | 7400/10000 [13:55<04:53,  8.86it/s]

Iter: 7400
Train_loss: 0.01722
Test_loss: 0.02809



 75%|███████▌  | 7500/10000 [14:06<04:42,  8.86it/s]

Iter: 7500
Train_loss: 0.01568
Test_loss: 0.02671



 76%|███████▌  | 7600/10000 [14:17<04:30,  8.86it/s]

Iter: 7600
Train_loss: 0.01591
Test_loss: 0.0268



 77%|███████▋  | 7700/10000 [14:30<04:20,  8.85it/s]

Iter: 7700
Train_loss: 0.01545
Test_loss: 0.02611



 78%|███████▊  | 7800/10000 [14:41<04:08,  8.84it/s]

Iter: 7800
Train_loss: 0.01602
Test_loss: 0.0278



 79%|███████▉  | 7900/10000 [14:53<03:57,  8.84it/s]

Iter: 7900
Train_loss: 0.01652
Test_loss: 0.02794



 80%|████████  | 8000/10000 [15:04<03:46,  8.85it/s]

Iter: 8000
Train_loss: 0.01579
Test_loss: 0.0262



 81%|████████  | 8100/10000 [15:15<03:34,  8.84it/s]

Iter: 8100
Train_loss: 0.01678
Test_loss: 0.02774



 82%|████████▏ | 8200/10000 [15:27<03:23,  8.84it/s]

Iter: 8200
Train_loss: 0.01722
Test_loss: 0.02774



 83%|████████▎ | 8300/10000 [15:38<03:12,  8.84it/s]

Iter: 8300
Train_loss: 0.01531
Test_loss: 0.02591



 84%|████████▍ | 8400/10000 [15:49<03:00,  8.84it/s]

Iter: 8400
Train_loss: 0.0164
Test_loss: 0.02688



 85%|████████▌ | 8500/10000 [16:02<02:49,  8.83it/s]

Iter: 8500
Train_loss: 0.01574
Test_loss: 0.0266



 86%|████████▌ | 8600/10000 [16:14<02:38,  8.82it/s]

Iter: 8600
Train_loss: 0.01602
Test_loss: 0.02695



 87%|████████▋ | 8700/10000 [16:26<02:27,  8.82it/s]

Iter: 8700
Train_loss: 0.01574
Test_loss: 0.02715



 88%|████████▊ | 8800/10000 [16:38<02:16,  8.81it/s]

Iter: 8800
Train_loss: 0.01652
Test_loss: 0.02743



 89%|████████▉ | 8900/10000 [16:50<02:04,  8.81it/s]

Iter: 8900
Train_loss: 0.01652
Test_loss: 0.0275



 90%|█████████ | 9000/10000 [17:02<01:53,  8.80it/s]

Iter: 9000
Train_loss: 0.01582
Test_loss: 0.02603



 91%|█████████ | 9100/10000 [17:14<01:42,  8.80it/s]

Iter: 9100
Train_loss: 0.01621
Test_loss: 0.02701



 92%|█████████▏| 9200/10000 [17:26<01:30,  8.79it/s]

Iter: 9200
Train_loss: 0.01691
Test_loss: 0.02807



 93%|█████████▎| 9300/10000 [17:38<01:19,  8.79it/s]

Iter: 9300
Train_loss: 0.01639
Test_loss: 0.02689



 94%|█████████▍| 9400/10000 [17:50<01:08,  8.78it/s]

Iter: 9400
Train_loss: 0.01646
Test_loss: 0.02782



 95%|█████████▌| 9500/10000 [18:02<00:56,  8.78it/s]

Iter: 9500
Train_loss: 0.01571
Test_loss: 0.02743



 96%|█████████▌| 9600/10000 [18:14<00:45,  8.78it/s]

Iter: 9600
Train_loss: 0.01603
Test_loss: 0.02758



 97%|█████████▋| 9700/10000 [18:27<00:34,  8.76it/s]

Iter: 9700
Train_loss: 0.01584
Test_loss: 0.02731



 98%|█████████▊| 9800/10000 [18:39<00:22,  8.75it/s]

Iter: 9800
Train_loss: 0.01612
Test_loss: 0.02744



 99%|█████████▉| 9900/10000 [18:51<00:11,  8.75it/s]

Iter: 9900
Train_loss: 0.01659
Test_loss: 0.02847



100%|██████████| 10000/10000 [19:03<00:00,  8.75it/s]
