In [1]:
import numpy as np  
from matplotlib import pyplot as plt  
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import torch
import torch.nn.functional as nn
import torch.autograd as autograd
from torch import optim
from torch.autograd import Variable
import torch.nn.functional as F
import os
import matplotlib.gridspec as gridspec
import os
from tqdm import tqdm
%matplotlib inline 

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import data_utils

In [4]:
# Load the training matrix through the project-local helper module.
data = data_utils.load_training_data()
# Bare expression: shows the array shape as the cell output.
data.shape

(7352, 128)

In [5]:
# system parameters
n_epochs = 500  # number of passes over data_loader in the training cells
batch_size = 10
lr = 0.001 # adam:learning rate
b1 = 0.9 # adam: decay of first order momentum of gradient
b2 = 0.999 # adam: decay of second order momentum of gradient
# NOTE(review): img_size and channels are not referenced by any cell visible
# here; they look like leftovers from an image-based variant of the notebook.
img_size = 28 # size of each image dimension
channels = 1 # number of image channels

alpha= 10  # weight of the MSE term added to the generator loss
Dim = 10  # number of features per row; sets the network input/output widths
p_miss = 0.5  # threshold passed to sample_M for the mask (mask is 1 where uniform > p_miss)
p_hint = 0.9  # sample_M is called with 1-p_hint when drawing the hint mask
data_loader = data_utils.DataLoader(data=data,batch_size=batch_size, num_steps=Dim)

# is_available() already returns a bool, so no conditional expression is needed.
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [6]:
# Pull a single mini-batch for inspection; the second return value is discarded.
X_mb,_ = data_loader.next_batch()

In [7]:
# Bare expression: shows the mini-batch shape as the cell output.
X_mb.shape

(10, 10)

In [8]:
def xavier_init(size):
    """Create a trainable weight tensor drawn from a scaled standard normal.

    The standard deviation is ``1 / sqrt(size[0] / 2)``, so it depends only on
    the first entry of ``size``.

    Args:
        size: sequence of ints giving the tensor shape.

    Returns:
        A ``Variable`` of shape ``size`` with ``requires_grad=True``.
    """
    fan_in = size[0]
    stddev = 1. / np.sqrt(fan_in / 2.)
    weights = torch.randn(*size) * stddev
    return Variable(weights, requires_grad=True)

In [9]:
""" ==================== DISCRIMINATOR ======================== """
# Weights and biases of the three fully connected discriminator layers.
# They are module-level tensors, so every cell that trains or evaluates the
# discriminator reads and updates these same objects.
# Layer 1: 2*Dim inputs -> 256 units.
D_W1 = xavier_init(size = [Dim*2, 256])     # Data + Hint as inputs
D_b1 = Variable(torch.zeros(256),requires_grad=True)

# Layer 2: 256 -> 128 units.
D_W2 = xavier_init(size = [256, 128])
D_b2 = Variable(torch.zeros(128),requires_grad=True)

# Layer 3: 128 -> Dim outputs (one per feature).
D_W3 = xavier_init(size = [128, Dim])
D_b3 = Variable(torch.zeros(Dim),requires_grad=True)  

def discriminator(x, m, g, h):
    """Score every component of a filled-in batch with a value in (0, 1).

    Args:
        x: data batch; only entries where ``m`` is 1 are used.
        m: mask, 1 where the entry of ``x`` is kept and 0 where ``g`` is used.
        g: generator output that fills the entries where ``m`` is 0.
        h: hint tensor, concatenated to the filled-in batch along dim 1.

    Returns:
        Tensor of sigmoid outputs with ``Dim`` columns. The training loss
        compares it against the mask ``m``.
    """
    inp = m * x + (1-m) * g  # Replace missing values to the imputed values
    inputs = torch.cat([inp,h],1)  # Hint + Data Concatenate
    # `nn` is this notebook's alias for torch.nn.functional.
    D_h1 = nn.relu(torch.matmul(inputs, D_W1) + D_b1)
    D_h2 = nn.relu(torch.matmul(D_h1, D_W2) + D_b2)
    D_logit = torch.matmul(D_h2, D_W3) + D_b3
    # torch.sigmoid instead of the functional alias, which is deprecated in
    # several PyTorch releases; the result is the same.
    D_prob = torch.sigmoid(D_logit)  # [0,1] Probability Output

    return D_prob

In [10]:
""" ==================== GENERATOR ======================== """
# Weights and biases of the three fully connected generator layers.
# They are module-level tensors, so every cell that trains or evaluates the
# generator reads and updates these same objects.
# Layer 1: 2*Dim inputs -> 256 units.
G_W1 = xavier_init(size = [Dim*2, 256])    # Data + Mask as inputs (Random Noises are in Missing Components)
G_b1 = Variable(torch.zeros(256),requires_grad=True)

# Layer 2: 256 -> 128 units.
G_W2 = xavier_init(size = [256, 128])
G_b2 = Variable(torch.zeros(128),requires_grad=True)

# Layer 3: 128 -> Dim outputs (one per feature).
G_W3 = xavier_init(size = [128, Dim])
G_b3 = Variable(torch.zeros(Dim),requires_grad=True)

def generator(x,z,m):
    """Produce a value in (0, 1) for every component of the batch.

    Args:
        x: data batch; only entries where ``m`` is 1 are read.
        z: noise batch substituted where ``m`` is 0.
        m: mask, 1 where the entry of ``x`` is kept and 0 where noise is used.
            It is also concatenated to the network input along dim 1.

    Returns:
        Tensor of sigmoid outputs with ``Dim`` columns.
    """
    inp = m * x + (1-m) * z  # Fill in random noise on the missing values
    inputs = torch.cat([inp,m],1)  # Mask + Data Concatenate
    # `nn` is this notebook's alias for torch.nn.functional.
    G_h1 = nn.relu(torch.matmul(inputs, G_W1) + G_b1)
    G_h2 = nn.relu(torch.matmul(G_h1, G_W2) + G_b2)
    # torch.sigmoid instead of the functional alias, which is deprecated in
    # several PyTorch releases; the result is the same.
    G_prob = torch.sigmoid(torch.matmul(G_h2, G_W3) + G_b3) # [0,1] normalized Output

    return G_prob

In [11]:
# One list per network (each is handed to its own Adam optimizer) plus the
# combined list that reset_grad() walks when clearing gradients.
G_params = [G_W1, G_W2, G_W3, G_b1, G_b2, G_b3]
D_params = [D_W1, D_W2, D_W3, D_b1, D_b2, D_b3]
params = G_params + D_params

In [12]:
""" ===================== TRAINING ======================== """
def plot(samples, img_shape=(28, 28), grid=(5, 5)):
    """Draw samples as grayscale images on a tight grid and return the figure.

    Args:
        samples: iterable of array-likes; each one is reshaped to ``img_shape``.
        img_shape: (height, width) every sample is reshaped to. Defaults to
            the previously hard-coded (28, 28).
        grid: (rows, cols) of the subplot grid. Defaults to the previously
            hard-coded (5, 5). Supplying more samples than grid cells fails
            when the grid is indexed, as before.

    Returns:
        The matplotlib figure containing the image grid.
    """
    n_rows, n_cols = grid
    # One inch per cell keeps the original 5x5-inch figure for the default grid.
    fig = plt.figure(figsize = (n_cols, n_rows))
    gs = gridspec.GridSpec(n_rows, n_cols)
    gs.update(wspace=0.05, hspace=0.05)

    for idx, sample in enumerate(samples):
        ax = fig.add_subplot(gs[idx])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(*img_shape), cmap='Greys_r')

    return fig
def reset_grad(parameters=None):
    """Zero the accumulated gradients of the given tensors in place.

    Args:
        parameters: optional iterable of tensors whose ``.grad`` is cleared.
            When omitted, the module-level ``params`` list is used, which is
            the original behaviour.
    """
    if parameters is None:
        parameters = params
    for p in parameters:
        if p.grad is not None:
            # Zero the existing buffer instead of allocating a new tensor on
            # every call; detach first so the zeroing is not tracked.
            p.grad.detach_()
            p.grad.zero_()
            
# Fresh Adam optimizers, one per network. They take the learning rate and
# betas from the configuration cell instead of repeating the literals here.
G_solver = optim.Adam(G_params, lr=lr, betas=(b1, b2))
D_solver = optim.Adam(D_params, lr=lr, betas=(b1, b2))

def sample_Z(m, n):
    """Return an ``m`` x ``n`` array of noise drawn uniformly from [0, 1)."""
    shape = (m, n)
    return np.random.uniform(low=0., high=1., size=shape)

# Mask Vector and Hint Vector Generation
def sample_M(m, n, p):
    """Sample an ``m`` x ``n`` float array of 0./1. flags.

    An entry is 1. when its independent uniform [0, 1) draw exceeds ``p``
    and 0. otherwise.
    """
    draws = np.random.uniform(low=0., high=1., size=(m, n))
    return (draws > p).astype(np.float64)
# make output file
# NOTE(review): no visible cell writes into this directory.
if not os.path.exists('MNIST_Impuation_output/'):
    os.makedirs('MNIST_Impuation_output/')
i = 1  # NOTE(review): not read anywhere in the visible cells

# Start from clean gradients: a run interrupted between backward() and
# reset_grad() leaves gradients behind that would otherwise be accumulated
# into the first update of this run.
reset_grad()
for it in tqdm(range(n_epochs)):
    data_loader.reset()
    while data_loader.has_next():
        X_mb,_ = data_loader.next_batch()
        # Size noise and masks from the batch actually returned rather than
        # the global batch_size, so a shorter batch still lines up.
        mb_rows = X_mb.shape[0]

        X_mb = Variable(torch.from_numpy(X_mb.astype('float32')))
        Z_mb = sample_Z(mb_rows, Dim)
        Z_mb = Variable(torch.from_numpy(Z_mb.astype('float32')))
        # Mask: 1 keeps the data value, 0 marks the entry as missing.
        M_mb = sample_M(mb_rows, Dim, p_miss)
        M_mb = Variable(torch.from_numpy(M_mb.astype('float32')))
        H_mb1 = sample_M(mb_rows, Dim, 1-p_hint)
        H_mb1 = Variable(torch.from_numpy(H_mb1.astype('float32')))
        # Hint: 1 only where both the mask and the hint draw are 1.
        H_mb = M_mb * H_mb1

        # NOTE(review): New_X_mb is not used below; generator() builds the
        # same mixture internally.
        New_X_mb = M_mb * X_mb + (1-M_mb) * Z_mb  # Missing Data Introduce

        # Discriminator forward-loss-backward-update
        # detach(): this step only updates the discriminator, so there is no
        # need to backpropagate through the generator.
        G_sample = generator(X_mb,Z_mb,M_mb).detach()
        D_sample = discriminator(X_mb, M_mb, G_sample, H_mb)
        D_loss = -torch.mean(M_mb * torch.log(D_sample + 1e-8) + (1-M_mb) * torch.log(1. - D_sample + 1e-8)) * 2
        D_loss.backward()
        D_solver.step()
        reset_grad()

        # Generator forward-loss-backward-update
        G_sample = generator(X_mb,Z_mb,M_mb)
        D_sample = discriminator(X_mb, M_mb, G_sample, H_mb)
        # Adversarial term, evaluated on the missing entries only.
        G_loss1 = -torch.mean((1-M_mb) * torch.log(D_sample + 1e-8)) / torch.mean(1-M_mb)
        # Reconstruction error on the kept entries; part of the generator loss.
        MSE_train_loss = torch.mean((M_mb * X_mb - M_mb * G_sample)**2) / torch.mean(M_mb)
        # Error on the masked-out entries; reported only, so it is computed
        # on a detached copy and stays out of the autograd graph.
        MSE_test_loss = torch.mean(((1-M_mb) * X_mb - (1-M_mb)*G_sample.detach())**2) / torch.mean(1-M_mb)
        G_loss = G_loss1  + alpha * MSE_train_loss
        G_loss.backward()
        G_solver.step()
        reset_grad()

    # Print and Plot
    # The reported losses are those of the last mini-batch of the epoch.
    if it % 10 == 0:
        print('Iter: {}'.format(it))
        print('Train_loss: {:.4}'.format(MSE_train_loss))
        print('Test_loss: {:.4}'.format(MSE_test_loss))
        print()
        



  0%|          | 1/500 [00:02<18:52,  2.27s/it]

Iter: 0
Train_loss: 0.0312
Test_loss: 0.1131



  2%|▏         | 11/500 [00:28<21:09,  2.60s/it]

Iter: 10
Train_loss: 0.02722
Test_loss: 0.07947



  4%|▍         | 21/500 [00:56<21:28,  2.69s/it]

Iter: 20
Train_loss: 0.03603
Test_loss: 0.07303



  6%|▌         | 31/500 [01:26<21:45,  2.78s/it]

Iter: 30
Train_loss: 0.04487
Test_loss: 0.06068



  8%|▊         | 41/500 [02:02<22:55,  3.00s/it]

Iter: 40
Train_loss: 0.03208
Test_loss: 0.07409



 10%|█         | 51/500 [02:38<23:16,  3.11s/it]

Iter: 50
Train_loss: 0.04124
Test_loss: 0.06638



 12%|█▏        | 61/500 [03:15<23:26,  3.20s/it]

Iter: 60
Train_loss: 0.04253
Test_loss: 0.06299



 14%|█▍        | 71/500 [03:58<24:00,  3.36s/it]

Iter: 70
Train_loss: 0.04422
Test_loss: 0.05646



 16%|█▌        | 81/500 [04:39<24:06,  3.45s/it]

Iter: 80
Train_loss: 0.05711
Test_loss: 0.05682



 18%|█▊        | 91/500 [05:16<23:41,  3.47s/it]

Iter: 90
Train_loss: 0.04132
Test_loss: 0.06526



 20%|██        | 101/500 [05:54<23:20,  3.51s/it]

Iter: 100
Train_loss: 0.06124
Test_loss: 0.03618



 22%|██▏       | 111/500 [06:36<23:08,  3.57s/it]

Iter: 110
Train_loss: 0.03777
Test_loss: 0.07978



 24%|██▍       | 121/500 [07:18<22:53,  3.63s/it]

Iter: 120
Train_loss: 0.02745
Test_loss: 0.07609



 26%|██▌       | 131/500 [08:04<22:45,  3.70s/it]

Iter: 130
Train_loss: 0.04858
Test_loss: 0.05517



 28%|██▊       | 141/500 [08:50<22:29,  3.76s/it]

Iter: 140
Train_loss: 0.03703
Test_loss: 0.08939



 30%|███       | 151/500 [09:35<22:10,  3.81s/it]

Iter: 150
Train_loss: 0.04207
Test_loss: 0.07288



 30%|███       | 152/500 [09:40<22:08,  3.82s/it]

KeyboardInterrupt: 

In [29]:
def reset_grad(parameters=None):
    """Zero the accumulated gradients of the given tensors in place.

    Args:
        parameters: optional iterable of tensors whose ``.grad`` is cleared.
            When omitted, the module-level ``params`` list is used, which is
            the original behaviour.
    """
    if parameters is None:
        parameters = params
    for p in parameters:
        if p.grad is not None:
            # Zero the existing buffer instead of allocating a new tensor on
            # every call; detach first so the zeroing is not tracked.
            p.grad.detach_()
            p.grad.zero_()
            
# Re-create the Adam optimizers (their moment estimates start over); the
# network weights themselves are not re-initialised in this cell. Learning
# rate and betas come from the configuration cell instead of literals.
G_solver = optim.Adam(G_params, lr=lr, betas=(b1, b2))
D_solver = optim.Adam(D_params, lr=lr, betas=(b1, b2))

def sample_Z(m, n):
    """Return an ``m`` x ``n`` array of noise drawn uniformly from [0, 1)."""
    shape = (m, n)
    return np.random.uniform(low=0., high=1., size=shape)

# Mask Vector and Hint Vector Generation
def sample_M(m, n, p):
    """Sample an ``m`` x ``n`` float array of 0./1. flags.

    An entry is 1. when its independent uniform [0, 1) draw exceeds ``p``
    and 0. otherwise.
    """
    draws = np.random.uniform(low=0., high=1., size=(m, n))
    return (draws > p).astype(np.float64)
# make output file
# NOTE(review): no visible cell writes into this directory.
if not os.path.exists('MNIST_Impuation_output/'):
    os.makedirs('MNIST_Impuation_output/')
i = 1  # NOTE(review): not read anywhere in the visible cells

# Start from clean gradients: a run interrupted between backward() and
# reset_grad() leaves gradients behind that would otherwise be accumulated
# into the first update of this run.
reset_grad()
for it in tqdm(range(n_epochs)):
    data_loader.reset()
    while data_loader.has_next():
        X_mb,_ = data_loader.next_batch()
        # Size noise and masks from the batch actually returned rather than
        # the global batch_size, so a shorter batch still lines up.
        mb_rows = X_mb.shape[0]

        X_mb = Variable(torch.from_numpy(X_mb.astype('float32')))
        Z_mb = sample_Z(mb_rows, Dim)
        Z_mb = Variable(torch.from_numpy(Z_mb.astype('float32')))
        # Mask: 1 keeps the data value, 0 marks the entry as missing.
        M_mb = sample_M(mb_rows, Dim, p_miss)
        M_mb = Variable(torch.from_numpy(M_mb.astype('float32')))
        H_mb1 = sample_M(mb_rows, Dim, 1-p_hint)
        H_mb1 = Variable(torch.from_numpy(H_mb1.astype('float32')))
        # Hint: 1 only where both the mask and the hint draw are 1.
        H_mb = M_mb * H_mb1

        # NOTE(review): New_X_mb is not used below; generator() builds the
        # same mixture internally.
        New_X_mb = M_mb * X_mb + (1-M_mb) * Z_mb  # Missing Data Introduce

        # Discriminator forward-loss-backward-update
        # detach(): this step only updates the discriminator, so there is no
        # need to backpropagate through the generator.
        G_sample = generator(X_mb,Z_mb,M_mb).detach()
        D_sample = discriminator(X_mb, M_mb, G_sample, H_mb)
        D_loss = -torch.mean(M_mb * torch.log(D_sample + 1e-8) + (1-M_mb) * torch.log(1. - D_sample + 1e-8)) * 2
        D_loss.backward()
        D_solver.step()
        reset_grad()

        # Generator forward-loss-backward-update
        G_sample = generator(X_mb,Z_mb,M_mb)
        D_sample = discriminator(X_mb, M_mb, G_sample, H_mb)
        # Adversarial term, evaluated on the missing entries only.
        G_loss1 = -torch.mean((1-M_mb) * torch.log(D_sample + 1e-8)) / torch.mean(1-M_mb)
        # Reconstruction error on the kept entries; part of the generator loss.
        MSE_train_loss = torch.mean((M_mb * X_mb - M_mb * G_sample)**2) / torch.mean(M_mb)
        # Error on the masked-out entries; reported only, so it is computed
        # on a detached copy and stays out of the autograd graph.
        MSE_test_loss = torch.mean(((1-M_mb) * X_mb - (1-M_mb)*G_sample.detach())**2) / torch.mean(1-M_mb)
        G_loss = G_loss1  + alpha * MSE_train_loss
        G_loss.backward()
        G_solver.step()
        reset_grad()

    # Print and Plot
    # The reported losses are those of the last mini-batch of the epoch.
    if it % 10 == 0:
        print('Iter: {}'.format(it))
        print('Train_loss: {:.4}'.format(MSE_train_loss))
        print('Test_loss: {:.4}'.format(MSE_test_loss))
        print()
        




  1%|          | 3/500 [00:00<00:22, 21.88it/s]

Iter: 0
Train_loss: 0.02545
Test_loss: 0.03415



  3%|▎         | 16/500 [00:00<00:21, 22.63it/s]

Iter: 10
Train_loss: 0.007392
Test_loss: 0.1816



  5%|▌         | 25/500 [00:01<00:19, 23.96it/s]

Iter: 20
Train_loss: 0.001457
Test_loss: 0.07146



  7%|▋         | 34/500 [00:01<00:21, 21.95it/s]

Iter: 30
Train_loss: 0.0004828
Test_loss: 0.1258



  9%|▉         | 45/500 [00:02<00:21, 21.46it/s]

Iter: 40
Train_loss: 0.0004133
Test_loss: 0.04844



 11%|█         | 54/500 [00:02<00:20, 21.32it/s]

Iter: 50
Train_loss: 0.0004208
Test_loss: 0.003301



 13%|█▎        | 64/500 [00:03<00:20, 21.06it/s]

Iter: 60
Train_loss: 0.0004194
Test_loss: 0.001083



 15%|█▍        | 73/500 [00:03<00:20, 20.65it/s]

Iter: 70
Train_loss: 0.0004051
Test_loss: 0.0004437



 17%|█▋        | 85/500 [00:04<00:20, 20.28it/s]

Iter: 80
Train_loss: 0.0004161
Test_loss: 0.004798



 19%|█▉        | 94/500 [00:04<00:20, 19.86it/s]

Iter: 90
Train_loss: 0.0004117
Test_loss: 0.0004348



 21%|██        | 104/500 [00:05<00:20, 19.80it/s]

Iter: 100
Train_loss: 0.0004264
Test_loss: 0.0004209



 23%|██▎       | 114/500 [00:05<00:20, 19.30it/s]

Iter: 110
Train_loss: 0.0004293
Test_loss: 0.0007848



 25%|██▌       | 125/500 [00:06<00:19, 19.44it/s]

Iter: 120
Train_loss: 0.000424
Test_loss: 0.0006126



 27%|██▋       | 135/500 [00:06<00:18, 19.51it/s]

Iter: 130
Train_loss: 0.0004255
Test_loss: 0.001559



 29%|██▊       | 143/500 [00:07<00:18, 19.28it/s]

Iter: 140
Train_loss: 0.0004173
Test_loss: 0.001278



 31%|███       | 155/500 [00:08<00:17, 19.17it/s]

Iter: 150
Train_loss: 0.000422
Test_loss: 0.001097



 33%|███▎      | 164/500 [00:08<00:17, 19.01it/s]

Iter: 160
Train_loss: 0.000403
Test_loss: 0.0004545



 35%|███▍      | 174/500 [00:09<00:17, 19.04it/s]

Iter: 170
Train_loss: 0.000438
Test_loss: 0.002819



 37%|███▋      | 183/500 [00:09<00:16, 18.99it/s]

Iter: 180
Train_loss: 0.0004301
Test_loss: 0.0005517



 39%|███▊      | 193/500 [00:10<00:16, 18.74it/s]

Iter: 190
Train_loss: 0.0004119
Test_loss: 0.0004646



 41%|████      | 203/500 [00:10<00:15, 18.65it/s]

Iter: 200
Train_loss: 0.00042
Test_loss: 0.0004979



 43%|████▎     | 214/500 [00:11<00:15, 18.49it/s]

Iter: 210
Train_loss: 0.0004209
Test_loss: 0.0007101



 45%|████▍     | 223/500 [00:11<00:14, 18.58it/s]

Iter: 220
Train_loss: 0.0004209
Test_loss: 0.0004937



 47%|████▋     | 233/500 [00:12<00:14, 18.61it/s]

Iter: 230
Train_loss: 0.0004182
Test_loss: 0.0004845



 49%|████▉     | 244/500 [00:13<00:13, 18.49it/s]

Iter: 240
Train_loss: 0.000412
Test_loss: 0.000429



 51%|█████     | 254/500 [00:13<00:13, 18.27it/s]

Iter: 250
Train_loss: 0.0004221
Test_loss: 0.0004311



 53%|█████▎    | 263/500 [00:14<00:13, 18.21it/s]

Iter: 260
Train_loss: 0.0004119
Test_loss: 0.00878



 55%|█████▌    | 275/500 [00:15<00:12, 18.18it/s]

Iter: 270
Train_loss: 0.0004359
Test_loss: 0.007431



 57%|█████▋    | 283/500 [00:15<00:11, 18.20it/s]

Iter: 280
Train_loss: 0.0004128
Test_loss: 0.0004438



 58%|█████▊    | 292/500 [00:16<00:11, 18.07it/s]

Iter: 290
Train_loss: 0.0004213
Test_loss: 0.0004779



 61%|██████    | 303/500 [00:16<00:10, 17.99it/s]

Iter: 300
Train_loss: 0.0007285
Test_loss: 0.02115



 63%|██████▎   | 313/500 [00:17<00:10, 17.87it/s]

Iter: 310
Train_loss: 0.0004058
Test_loss: 0.00182



 65%|██████▍   | 323/500 [00:18<00:09, 17.76it/s]

Iter: 320
Train_loss: 0.0004312
Test_loss: 0.001976



 67%|██████▋   | 333/500 [00:18<00:09, 17.62it/s]

Iter: 330
Train_loss: 0.0004345
Test_loss: 0.0004252



 69%|██████▊   | 343/500 [00:19<00:08, 17.52it/s]

Iter: 340
Train_loss: 0.0004235
Test_loss: 0.0004577



 71%|███████   | 353/500 [00:20<00:08, 17.43it/s]

Iter: 350
Train_loss: 0.0004212
Test_loss: 0.0004213



 73%|███████▎  | 363/500 [00:20<00:07, 17.35it/s]

Iter: 360
Train_loss: 0.000418
Test_loss: 0.00044



 75%|███████▍  | 373/500 [00:21<00:07, 17.32it/s]

Iter: 370
Train_loss: 0.0004387
Test_loss: 0.0005564



 77%|███████▋  | 383/500 [00:22<00:06, 17.31it/s]

Iter: 380
Train_loss: 0.0004225
Test_loss: 0.0005147



 79%|███████▊  | 393/500 [00:22<00:06, 17.22it/s]

Iter: 390
Train_loss: 0.0004248
Test_loss: 0.0004286



 81%|████████  | 403/500 [00:23<00:05, 17.14it/s]

Iter: 400
Train_loss: 0.0004124
Test_loss: 0.0004351



 83%|████████▎ | 413/500 [00:24<00:05, 17.11it/s]

Iter: 410
Train_loss: 0.0004217
Test_loss: 0.0008817



 85%|████████▍ | 423/500 [00:24<00:04, 17.04it/s]

Iter: 420
Train_loss: 0.0004221
Test_loss: 0.0004731



 87%|████████▋ | 434/500 [00:25<00:03, 17.05it/s]

Iter: 430
Train_loss: 0.0004316
Test_loss: 0.0006929



 89%|████████▉ | 444/500 [00:26<00:03, 17.01it/s]

Iter: 440
Train_loss: 0.000423
Test_loss: 0.0004167



 91%|█████████ | 454/500 [00:26<00:02, 17.05it/s]

Iter: 450
Train_loss: 0.0004273
Test_loss: 0.0004337



 93%|█████████▎| 464/500 [00:27<00:02, 17.07it/s]

Iter: 460
Train_loss: 0.0004239
Test_loss: 0.0004139



 95%|█████████▍| 474/500 [00:27<00:01, 17.04it/s]

Iter: 470
Train_loss: 0.0004199
Test_loss: 0.0004178



 96%|█████████▋| 482/500 [00:28<00:01, 17.00it/s]

Iter: 480
Train_loss: 0.0004323
Test_loss: 0.003546



 99%|█████████▉| 494/500 [00:29<00:00, 16.93it/s]

Iter: 490
Train_loss: 0.0004038
Test_loss: 0.0005646



100%|██████████| 500/500 [00:29<00:00, 16.95it/s]


In [12]:
# Overrides the epoch count set in the configuration cell for this run; the
# new value stays in effect for any cell executed afterwards.
n_epochs = 1000
def reset_grad(parameters=None):
    """Zero the accumulated gradients of the given tensors in place.

    Args:
        parameters: optional iterable of tensors whose ``.grad`` is cleared.
            When omitted, the module-level ``params`` list is used, which is
            the original behaviour.
    """
    if parameters is None:
        parameters = params
    for p in parameters:
        if p.grad is not None:
            # Zero the existing buffer instead of allocating a new tensor on
            # every call; detach first so the zeroing is not tracked.
            p.grad.detach_()
            p.grad.zero_()
            
# Re-create the Adam optimizers (their moment estimates start over); the
# network weights themselves are not re-initialised in this cell. Learning
# rate and betas come from the configuration cell instead of literals.
G_solver = optim.Adam(G_params, lr=lr, betas=(b1, b2))
D_solver = optim.Adam(D_params, lr=lr, betas=(b1, b2))

def sample_Z(m, n):
    """Return an ``m`` x ``n`` array of noise drawn uniformly from [0, 1)."""
    shape = (m, n)
    return np.random.uniform(low=0., high=1., size=shape)

# Mask Vector and Hint Vector Generation
def sample_M(m, n, p):
    """Sample an ``m`` x ``n`` float array of 0./1. flags.

    An entry is 1. when its independent uniform [0, 1) draw exceeds ``p``
    and 0. otherwise.
    """
    draws = np.random.uniform(low=0., high=1., size=(m, n))
    return (draws > p).astype(np.float64)
# make output file
# NOTE(review): no visible cell writes into this directory.
if not os.path.exists('MNIST_Impuation_output/'):
    os.makedirs('MNIST_Impuation_output/')
i = 1  # NOTE(review): not read anywhere in the visible cells

# Start from clean gradients: a run interrupted between backward() and
# reset_grad() leaves gradients behind that would otherwise be accumulated
# into the first update of this run.
reset_grad()
for it in tqdm(range(n_epochs)):
    data_loader.reset()
    while data_loader.has_next():
        X_mb,_ = data_loader.next_batch()
        # Size noise and masks from the batch actually returned rather than
        # the global batch_size, so a shorter batch still lines up.
        mb_rows = X_mb.shape[0]

        X_mb = Variable(torch.from_numpy(X_mb.astype('float32')))
        Z_mb = sample_Z(mb_rows, Dim)
        Z_mb = Variable(torch.from_numpy(Z_mb.astype('float32')))
        # Mask: 1 keeps the data value, 0 marks the entry as missing.
        M_mb = sample_M(mb_rows, Dim, p_miss)
        M_mb = Variable(torch.from_numpy(M_mb.astype('float32')))
        H_mb1 = sample_M(mb_rows, Dim, 1-p_hint)
        H_mb1 = Variable(torch.from_numpy(H_mb1.astype('float32')))
        # Hint: 1 only where both the mask and the hint draw are 1.
        H_mb = M_mb * H_mb1

        # NOTE(review): New_X_mb is not used below; generator() builds the
        # same mixture internally.
        New_X_mb = M_mb * X_mb + (1-M_mb) * Z_mb  # Missing Data Introduce

        # Discriminator forward-loss-backward-update
        # detach(): this step only updates the discriminator, so there is no
        # need to backpropagate through the generator.
        G_sample = generator(X_mb,Z_mb,M_mb).detach()
        D_sample = discriminator(X_mb, M_mb, G_sample, H_mb)
        D_loss = -torch.mean(M_mb * torch.log(D_sample + 1e-8) + (1-M_mb) * torch.log(1. - D_sample + 1e-8)) * 2
        D_loss.backward()
        D_solver.step()
        reset_grad()

        # Generator forward-loss-backward-update
        G_sample = generator(X_mb,Z_mb,M_mb)
        D_sample = discriminator(X_mb, M_mb, G_sample, H_mb)
        # Adversarial term, evaluated on the missing entries only.
        G_loss1 = -torch.mean((1-M_mb) * torch.log(D_sample + 1e-8)) / torch.mean(1-M_mb)
        # Reconstruction error on the kept entries; part of the generator loss.
        MSE_train_loss = torch.mean((M_mb * X_mb - M_mb * G_sample)**2) / torch.mean(M_mb)
        # Error on the masked-out entries; reported only, so it is computed
        # on a detached copy and stays out of the autograd graph.
        MSE_test_loss = torch.mean(((1-M_mb) * X_mb - (1-M_mb)*G_sample.detach())**2) / torch.mean(1-M_mb)
        G_loss = G_loss1  + alpha * MSE_train_loss
        G_loss.backward()
        G_solver.step()
        reset_grad()

    # Print and Plot
    # The reported losses are those of the last mini-batch of the epoch.
    if it % 10 == 0:
        print('Iter: {}'.format(it))
        print('Train_loss: {:.4}'.format(MSE_train_loss))
        print('Test_loss: {:.4}'.format(MSE_test_loss))
        print()
        

  0%|          | 5/1000 [00:00<00:46, 21.42it/s]

Iter: 0
Train_loss: 0.0188
Test_loss: 0.04297



  1%|▏         | 14/1000 [00:00<00:44, 22.15it/s]

Iter: 10
Train_loss: 0.003816
Test_loss: 0.1588



  3%|▎         | 26/1000 [00:01<00:40, 23.77it/s]

Iter: 20
Train_loss: 0.006954
Test_loss: 0.2467



  3%|▎         | 32/1000 [00:01<00:42, 22.84it/s]

Iter: 30
Train_loss: 0.0004292
Test_loss: 0.01736



  4%|▍         | 45/1000 [00:02<00:45, 21.06it/s]

Iter: 40
Train_loss: 0.0004361
Test_loss: 0.02705



  5%|▌         | 54/1000 [00:02<00:44, 21.13it/s]

Iter: 50
Train_loss: 0.0004209
Test_loss: 0.003516



  6%|▋         | 63/1000 [00:02<00:43, 21.53it/s]

Iter: 60
Train_loss: 0.0004332
Test_loss: 0.0161



  8%|▊         | 75/1000 [00:03<00:43, 21.42it/s]

Iter: 70
Train_loss: 0.0004223
Test_loss: 0.002604



  8%|▊         | 84/1000 [00:04<00:45, 20.19it/s]

Iter: 80
Train_loss: 0.0004199
Test_loss: 0.003406



 10%|▉         | 95/1000 [00:04<00:44, 20.23it/s]

Iter: 90
Train_loss: 0.0004314
Test_loss: 0.0006435



 10%|█         | 105/1000 [00:05<00:44, 20.21it/s]

Iter: 100
Train_loss: 0.0004069
Test_loss: 0.0005571



 11%|█▏        | 114/1000 [00:05<00:43, 20.27it/s]

Iter: 110
Train_loss: 0.0004134
Test_loss: 0.0004315



 12%|█▏        | 123/1000 [00:06<00:43, 20.30it/s]

Iter: 120
Train_loss: 0.0004308
Test_loss: 0.000413



 14%|█▎        | 135/1000 [00:06<00:42, 20.34it/s]

Iter: 130
Train_loss: 0.0004227
Test_loss: 0.01157



 14%|█▍        | 141/1000 [00:06<00:42, 20.36it/s]

Iter: 140
Train_loss: 0.0004244
Test_loss: 0.0004457



 15%|█▌        | 154/1000 [00:07<00:42, 19.72it/s]

Iter: 150
Train_loss: 0.0004173
Test_loss: 0.0004916



 16%|█▋        | 164/1000 [00:08<00:42, 19.73it/s]

Iter: 160
Train_loss: 0.0004198
Test_loss: 0.0004289



 17%|█▋        | 174/1000 [00:08<00:41, 19.74it/s]

Iter: 170
Train_loss: 0.0004123
Test_loss: 0.0004327



 18%|█▊        | 183/1000 [00:09<00:41, 19.67it/s]

Iter: 180
Train_loss: 0.0004134
Test_loss: 0.0004301



 19%|█▉        | 192/1000 [00:09<00:41, 19.47it/s]

Iter: 190
Train_loss: 0.000429
Test_loss: 0.0005987



 20%|██        | 202/1000 [00:10<00:41, 19.13it/s]

Iter: 200
Train_loss: 0.000428
Test_loss: 0.02738



 21%|██        | 212/1000 [00:11<00:42, 18.74it/s]

Iter: 210
Train_loss: 0.0004089
Test_loss: 0.0006098



 22%|██▏       | 223/1000 [00:11<00:41, 18.81it/s]

Iter: 220
Train_loss: 0.0004174
Test_loss: 0.000429



 23%|██▎       | 233/1000 [00:12<00:40, 18.83it/s]

Iter: 230
Train_loss: 0.0004132
Test_loss: 0.0006308



 24%|██▍       | 244/1000 [00:12<00:40, 18.90it/s]

Iter: 240
Train_loss: 0.0004168
Test_loss: 0.0004223



 25%|██▌       | 253/1000 [00:13<00:39, 18.95it/s]

Iter: 250
Train_loss: 0.0004193
Test_loss: 0.0004732



 26%|██▋       | 265/1000 [00:13<00:38, 19.02it/s]

Iter: 260
Train_loss: 0.0004186
Test_loss: 0.0004257



 27%|██▋       | 273/1000 [00:14<00:38, 18.95it/s]

Iter: 270
Train_loss: 0.0004148
Test_loss: 0.0004445



 28%|██▊       | 282/1000 [00:14<00:37, 18.96it/s]

Iter: 280
Train_loss: 0.0004345
Test_loss: 0.0006208



 29%|██▉       | 292/1000 [00:15<00:38, 18.49it/s]

Iter: 290
Train_loss: 0.0004139
Test_loss: 0.007278



 30%|███       | 302/1000 [00:16<00:37, 18.40it/s]

Iter: 300
Train_loss: 0.0004115
Test_loss: 0.000463



 31%|███▏      | 314/1000 [00:17<00:37, 18.22it/s]

Iter: 310
Train_loss: 0.0004252
Test_loss: 0.0007687



 32%|███▏      | 322/1000 [00:17<00:37, 18.07it/s]

Iter: 320
Train_loss: 0.0004273
Test_loss: 0.001602



 33%|███▎      | 332/1000 [00:18<00:37, 17.82it/s]

Iter: 330
Train_loss: 0.0004351
Test_loss: 0.0007018



 34%|███▍      | 342/1000 [00:19<00:37, 17.63it/s]

Iter: 340
Train_loss: 0.0004456
Test_loss: 0.001192



 35%|███▌      | 354/1000 [00:20<00:36, 17.58it/s]

Iter: 350
Train_loss: 0.0004245
Test_loss: 0.0007127



 36%|███▌      | 362/1000 [00:20<00:36, 17.46it/s]

Iter: 360
Train_loss: 0.0004222
Test_loss: 0.001569



 37%|███▋      | 374/1000 [00:21<00:36, 17.35it/s]

Iter: 370
Train_loss: 0.0003977
Test_loss: 0.0004495



 38%|███▊      | 382/1000 [00:22<00:35, 17.32it/s]

Iter: 380
Train_loss: 0.000425
Test_loss: 0.001168



 39%|███▉      | 394/1000 [00:22<00:34, 17.33it/s]

Iter: 390
Train_loss: 0.0004256
Test_loss: 0.00382



 40%|████      | 404/1000 [00:23<00:34, 17.36it/s]

Iter: 400
Train_loss: 0.0004368
Test_loss: 0.0003972



 41%|████      | 412/1000 [00:23<00:34, 17.29it/s]

Iter: 410
Train_loss: 0.000421
Test_loss: 0.006601



 42%|████▏     | 424/1000 [00:24<00:33, 17.26it/s]

Iter: 420
Train_loss: 0.000407
Test_loss: 0.0004646



 43%|████▎     | 432/1000 [00:25<00:33, 17.18it/s]

Iter: 430
Train_loss: 0.000417
Test_loss: 0.0008254



 44%|████▍     | 444/1000 [00:26<00:32, 17.06it/s]

Iter: 440
Train_loss: 0.0004137
Test_loss: 0.0004359



 45%|████▌     | 454/1000 [00:26<00:32, 16.96it/s]

Iter: 450
Train_loss: 0.0004081
Test_loss: 0.0005549



 46%|████▋     | 464/1000 [00:27<00:31, 16.89it/s]

Iter: 460
Train_loss: 0.0004253
Test_loss: 0.004333



 47%|████▋     | 472/1000 [00:28<00:31, 16.82it/s]

Iter: 470
Train_loss: 0.0004249
Test_loss: 0.000579



 48%|████▊     | 484/1000 [00:28<00:30, 16.85it/s]

Iter: 480
Train_loss: 0.0004147
Test_loss: 0.000429



 49%|████▉     | 492/1000 [00:29<00:30, 16.84it/s]

Iter: 490
Train_loss: 0.0004143
Test_loss: 0.002114



 50%|█████     | 502/1000 [00:29<00:29, 16.79it/s]

Iter: 500
Train_loss: 0.0004258
Test_loss: 0.0004806



 51%|█████     | 512/1000 [00:30<00:29, 16.75it/s]

Iter: 510
Train_loss: 0.0004095
Test_loss: 0.0007398



 52%|█████▏    | 524/1000 [00:31<00:28, 16.69it/s]

Iter: 520
Train_loss: 0.0004306
Test_loss: 0.00065



 53%|█████▎    | 532/1000 [00:31<00:28, 16.66it/s]

Iter: 530
Train_loss: 0.0004133
Test_loss: 0.0004233



 54%|█████▍    | 542/1000 [00:32<00:27, 16.56it/s]

Iter: 540
Train_loss: 0.000419
Test_loss: 0.0004403



 55%|█████▌    | 552/1000 [00:33<00:27, 16.46it/s]

Iter: 550
Train_loss: 0.000411
Test_loss: 0.0009396



 56%|█████▋    | 564/1000 [00:34<00:26, 16.38it/s]

Iter: 560
Train_loss: 0.0004234
Test_loss: 0.0004327



 57%|█████▋    | 572/1000 [00:35<00:26, 16.31it/s]

Iter: 570
Train_loss: 0.0004284
Test_loss: 0.0004379



 58%|█████▊    | 582/1000 [00:35<00:25, 16.30it/s]

Iter: 580
Train_loss: 0.0004307
Test_loss: 0.0004599



 59%|█████▉    | 592/1000 [00:36<00:25, 16.28it/s]

Iter: 590
Train_loss: 0.000418
Test_loss: 0.0004373



 60%|██████    | 602/1000 [00:37<00:24, 16.26it/s]

Iter: 600
Train_loss: 0.0004329
Test_loss: 0.0008711



 61%|██████▏   | 614/1000 [00:37<00:23, 16.25it/s]

Iter: 610
Train_loss: 0.000417
Test_loss: 0.001084



 62%|██████▏   | 624/1000 [00:38<00:23, 16.23it/s]

Iter: 620
Train_loss: 0.0004148
Test_loss: 0.0005511



 63%|██████▎   | 632/1000 [00:39<00:22, 16.20it/s]

Iter: 630
Train_loss: 0.0004192
Test_loss: 0.000419



 64%|██████▍   | 644/1000 [00:39<00:21, 16.22it/s]

Iter: 640
Train_loss: 0.0004203
Test_loss: 0.003249



 65%|██████▌   | 652/1000 [00:40<00:21, 16.22it/s]

Iter: 650
Train_loss: 0.0004187
Test_loss: 0.0004249



 66%|██████▋   | 664/1000 [00:40<00:20, 16.21it/s]

Iter: 660
Train_loss: 0.0004174
Test_loss: 0.0004208



 67%|██████▋   | 674/1000 [00:41<00:20, 16.23it/s]

Iter: 670
Train_loss: 0.0004319
Test_loss: 0.002031



 68%|██████▊   | 682/1000 [00:41<00:19, 16.24it/s]

Iter: 680
Train_loss: 0.0004102
Test_loss: 0.0004248



 69%|██████▉   | 694/1000 [00:42<00:18, 16.17it/s]

Iter: 690
Train_loss: 0.0004059
Test_loss: 0.0004347



 70%|███████   | 704/1000 [00:43<00:18, 16.18it/s]

Iter: 700
Train_loss: 0.0004551
Test_loss: 0.007041



 71%|███████▏  | 714/1000 [00:44<00:17, 16.13it/s]

Iter: 710
Train_loss: 0.0004094
Test_loss: 0.002144



 72%|███████▏  | 724/1000 [00:44<00:17, 16.12it/s]

Iter: 720
Train_loss: 0.0004038
Test_loss: 0.0004407



 73%|███████▎  | 732/1000 [00:45<00:16, 16.06it/s]

Iter: 730
Train_loss: 0.0004174
Test_loss: 0.0004187



 74%|███████▍  | 742/1000 [00:46<00:16, 16.02it/s]

Iter: 740
Train_loss: 0.0004351
Test_loss: 0.002965



 75%|███████▌  | 754/1000 [00:47<00:15, 15.97it/s]

Iter: 750
Train_loss: 0.0004385
Test_loss: 0.0006172



 76%|███████▌  | 762/1000 [00:47<00:14, 15.97it/s]

Iter: 760
Train_loss: 0.000416
Test_loss: 0.0004212



 77%|███████▋  | 772/1000 [00:48<00:14, 15.91it/s]

Iter: 770
Train_loss: 0.0004126
Test_loss: 0.0004808



 78%|███████▊  | 782/1000 [00:49<00:13, 15.88it/s]

Iter: 780
Train_loss: 0.0004095
Test_loss: 0.0004304



 79%|███████▉  | 794/1000 [00:50<00:13, 15.84it/s]

Iter: 790
Train_loss: 0.0004087
Test_loss: 0.0004337



 80%|████████  | 802/1000 [00:50<00:12, 15.83it/s]

Iter: 800
Train_loss: 0.0004345
Test_loss: 0.0004895



 81%|████████  | 812/1000 [00:51<00:11, 15.78it/s]

Iter: 810
Train_loss: 0.0004219
Test_loss: 0.0004151



 82%|████████▏ | 824/1000 [00:52<00:11, 15.77it/s]

Iter: 820
Train_loss: 0.0004204
Test_loss: 0.0004201



 83%|████████▎ | 832/1000 [00:52<00:10, 15.74it/s]

Iter: 830
Train_loss: 0.0004143
Test_loss: 0.0004228



 84%|████████▍ | 844/1000 [00:53<00:09, 15.71it/s]

Iter: 840
Train_loss: 0.0004358
Test_loss: 0.0004032



 85%|████████▌ | 852/1000 [00:54<00:09, 15.67it/s]

Iter: 850
Train_loss: 0.0004112
Test_loss: 0.0004313



 86%|████████▌ | 862/1000 [00:55<00:08, 15.65it/s]

Iter: 860
Train_loss: 0.000416
Test_loss: 0.0004377



 87%|████████▋ | 872/1000 [00:55<00:08, 15.62it/s]

Iter: 870
Train_loss: 0.0004161
Test_loss: 0.0004205



 88%|████████▊ | 884/1000 [00:56<00:07, 15.57it/s]

Iter: 880
Train_loss: 0.0004049
Test_loss: 0.0004411



 89%|████████▉ | 892/1000 [00:57<00:06, 15.53it/s]

Iter: 890
Train_loss: 0.0004361
Test_loss: 0.000406



 90%|█████████ | 904/1000 [00:58<00:06, 15.54it/s]

Iter: 900
Train_loss: 0.0004107
Test_loss: 0.0004331



 91%|█████████▏| 914/1000 [00:58<00:05, 15.55it/s]

Iter: 910
Train_loss: 0.0004107
Test_loss: 0.0004886



 92%|█████████▏| 922/1000 [00:59<00:05, 15.52it/s]

Iter: 920
Train_loss: 0.0004073
Test_loss: 0.0004374



 93%|█████████▎| 934/1000 [01:00<00:04, 15.48it/s]

Iter: 930
Train_loss: 0.0004166
Test_loss: 0.00042



 94%|█████████▍| 942/1000 [01:00<00:03, 15.48it/s]

Iter: 940
Train_loss: 0.0004166
Test_loss: 0.0004225



 95%|█████████▌| 952/1000 [01:01<00:03, 15.47it/s]

Iter: 950
Train_loss: 0.0004151
Test_loss: 0.0004214



 96%|█████████▌| 962/1000 [01:02<00:02, 15.48it/s]

Iter: 960
Train_loss: 0.0004066
Test_loss: 0.0004512



 97%|█████████▋| 974/1000 [01:02<00:01, 15.48it/s]

Iter: 970
Train_loss: 0.0004105
Test_loss: 0.0004273



 98%|█████████▊| 984/1000 [01:03<00:01, 15.46it/s]

Iter: 980
Train_loss: 0.000406
Test_loss: 0.000437



 99%|█████████▉| 992/1000 [01:04<00:00, 15.45it/s]

Iter: 990
Train_loss: 0.0004182
Test_loss: 0.0004216



100%|██████████| 1000/1000 [01:04<00:00, 15.45it/s]
