In [1]:
#Generate data
import numpy as np

# Grid resolution: n points per axis, n*n samples in total.
n = 100

x = np.linspace(1,20,n)
y = np.linspace(1,20,n)

# Flatten the (x, y) grid into rows with x varying fastest:
# row i + j*n holds the pair (x[i], y[j]).
grid_x = np.tile(x, n)
grid_y = np.repeat(y, n)

# Each sample is (x, y, z) with z linear in x and y, so all points lie on a plane.
# Alternative (saddle-shaped) surface kept for reference:
#   z = (x-10.0)**2 - (y-15)**2
synthetic_data = np.column_stack((grid_x, grid_y, 0.5*(grid_x-10.0)-2*grid_y))

#Shuffle contents (rows are permuted in place; the global NumPy RNG is not seeded here)
np.random.shuffle(synthetic_data)

In [2]:
#Graph this data to be sure it is what we want
%matplotlib notebook

from mpl_toolkits.mplot3d import Axes3D
import matplotlib
import numpy as np
import matplotlib.pyplot as plt



# 3-D scatter of the generated points: columns 0, 1, 2 are passed as x, y, z.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
ax.scatter(*synthetic_data.T)

<IPython.core.display.Javascript object>

<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x107469390>

In [3]:
#Helper functions
def normalise(x):
    """Rescale values linearly to the range 0-1.

    The minimum and maximum are taken over the whole input (all axes at
    once), not per column.

    Parameters:
        x: array or array-like of numbers; must be non-empty.

    Returns:
        (x_h, x_min, x_max): the rescaled float array together with the
        minimum and maximum of the input, which un_normalise needs to
        invert the mapping.

    If every value is identical the range is zero; an all-zero array is
    returned in that case instead of dividing by zero.
    """
    values = np.asarray(x)
    lowest = values.min()
    highest = values.max()
    span = highest - lowest

    if span == 0:
        # Constant input: avoid 0/0, which would fill the result with NaN.
        return np.zeros(values.shape, dtype=float), lowest, highest

    # The *1.0 forces floating-point division for integer input.
    x_h = (values - lowest)*1.0 / span

    return x_h, lowest, highest

def un_normalise(x_h, x_min, x_max):
    """Map values from the 0-1 range back onto [x_min, x_max].

    Inverse of normalise: x_h = 0 maps to x_min and x_h = 1 maps to x_max.
    """
    span = x_max - x_min
    return span*x_h + x_min

In [4]:
#Now let's try with an autoencoder
import tensorflow as tf
#Prepare for training and validation with a 70:30 split
# split_size is the number of rows in the 70% portion.
# NOTE(review): split_size is not used by any cell visible here -- the split itself
# (commented line further down) is disabled, so training runs on the full data set.
split_size = int(synthetic_data.shape[0]*0.7)

#Normalise values (currently disabled, so the network is fed the raw, un-normalised data)
#x_in, x_min, x_max = normalise(synthetic_data)

# x_in is just another name for synthetic_data; no copy is made.
x_in = synthetic_data

#Remember that for autoencoders we don't need to have any targets since the values themselves are the targets
#train_x, val_x = x[:split_size], x[split_size:]


# number of neurons in each layer: 3 inputs (x, y, z) -> 2 hidden units -> 3 outputs
input_num_units = 3
hidden_num_units = 2
output_num_units = 3

# define placeholders: x receives the network input, x_h the reconstruction target.
# Each accepts any number of rows (first dimension is None).
# NOTE: this rebinds the name x, which held the linspace array in the data-generation cell.
x = tf.placeholder(tf.float32, [None, input_num_units])
x_h = tf.placeholder(tf.float32, [None, output_num_units])


# define weights and biases of the neural network, all initialised from tf.random_normal
# (NOTE(review): the original comment pointed to "this article", but no link is given)

weights = {
    'hidden': tf.Variable(tf.random_normal([input_num_units, hidden_num_units])),
    'output': tf.Variable(tf.random_normal([hidden_num_units, output_num_units]))
}

biases = {
    'hidden': tf.Variable(tf.random_normal([hidden_num_units])),
    'output': tf.Variable(tf.random_normal([output_num_units]))
}

In [12]:
# hidden = relu(x*w_h+b_h)
hidden_layer = tf.add(tf.matmul(x, weights['hidden']), biases['hidden'])
hidden_layer = tf.nn.relu(hidden_layer)
#output_layer = hidden*w_o+b_o   (linear output, no activation)
output_layer = tf.matmul(hidden_layer, weights['output']) + biases['output']
#Define cost: mean squared error between the network output and the target x_h
cost = tf.reduce_mean(tf.pow(output_layer - x_h, 2))
# Diagnostic only -- the optimiser below minimises cost, not this value.
# NOTE(review): this takes tf.log of the target placeholder x_h, and the training log
# shows nan for it at every reported step. The targets fed in are the raw data, which
# include negative values, so the log is undefined. Confirm whether this is still wanted.
cross_entropy = -tf.reduce_sum(output_layer*tf.log(x_h))
#Choose Optimiser: RMSProp with learning rate 0.005
optimiser = tf.train.RMSPropOptimizer(0.005).minimize(cost)
init = tf.global_variables_initializer()
# The session is created here and reused by the later training and weight-dump cells;
# it is not closed in any cell visible here.
sess = tf.Session()
sess.run(init)


# Training schedule: number of update steps and rows drawn per step.
# With n = 100 the batch size equals the n*n size of the data set.
n_rounds = 10000
batch_size = 10000
# One loss value is stored per training round.
loss_vector = np.empty(n_rounds)

In [6]:
#Training loop

#TODO: add validation since it most likely overfits!

for i in range(n_rounds):
    # Draw batch_size row indices uniformly at random (with replacement).
    sample = np.random.randint(x_in.shape[0], size=batch_size)
    # Autoencoder: the same batch serves as both input and reconstruction target.
    batch_xs = x_in[sample]
    batch_ys = batch_xs
    feed = {x: batch_xs, x_h: batch_ys}

    sess.run(optimiser, feed_dict=feed)
    # Loss on the same batch, evaluated after the update step.
    loss = sess.run(cost, feed_dict=feed)
    loss_vector[i] = loss

    if i % 1000 == 0:
        # Progress line: round, cross-entropy diagnostic, mean squared error.
        # Formatted through %s so the line is identical under Python 2 and Python 3.
        print('%s %s %s' % (i, sess.run(cross_entropy, feed_dict=feed), loss))

0 nan 4539.39
1000 nan 0.247526
2000 nan 0.168476
3000 nan 0.16883
4000 nan 0.150783
5000 nan 0.158955
6000 nan 0.150934
7000 nan 0.146443
8000 nan 0.147858
9000 nan 0.149001


In [7]:
# Pull the trained parameters out of the session as NumPy arrays.
# All four tensors are fetched in a single sess.run call.
W_h, B_h, W_o, B_o = sess.run([weights['hidden'], biases['hidden'],
                               weights['output'], biases['output']])

# Same output as the Python 2 statement form, but also valid under Python 3.
print('W_h: %s' % (W_h,))
print('B_h: %s' % (B_h,))
print('W_o: %s' % (W_o,))
print('B_o: %s' % (B_o,))

W_h: [[ 1.10825479  1.79662824]
 [-0.92254263 -0.58999604]
 [-0.77435929 -0.18340752]]
B_h: [-1.40252721  1.05449009]
W_o: [[ 0.18856849  1.36291373 -2.63904214]
 [ 0.51064163 -0.58149254  1.41623688]]
B_o: [-1.46065426 -2.25098562 -1.20063186]


In [11]:
# Buffer for the autoencoder's reconstruction of the data: one output row per input row.
# Sized from x_in so it always matches the array that is fed through the network.
reconstructed_data = np.empty(x_in.shape)
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), applied element-wise."""
    denominator = 1 + np.exp(-x)
    return 1/denominator
def relu(x):
    """Rectified linear unit: element-wise maximum of x and 0."""
    floor = 0
    return np.maximum(x, floor)


# Forward pass of the trained autoencoder over all rows at once:
#   hidden = relu(x_in . W_h + B_h),  reconstruction = hidden . W_o + B_o
# Written into the preallocated buffer, so a shape mismatch raises instead of passing silently.
reconstructed_data[:] = np.dot(relu(np.dot(x_in, W_h) + B_h), W_o) + B_o
# Per-row reconstruction error, if wanted:
#   np.mean(np.power(reconstructed_data - x_in, 2), axis=1)
# Only needed if the input was normalised before training:
#reconstructed_data = un_normalise(reconstructed_data, x_min, x_max)

0.238645666236
0.295258577333
0.377265292469
0.0781455944498
0.150685374519
0.0531310774038
0.213445947126
0.0351322599938
0.0604008363091
0.0763114528689
0.0356324821968
0.0794228456295
0.254917865293
0.0322280146183
0.220630376267
0.143556202347
0.243539150754
0.112732716833
0.0796704014167
0.260066623743
0.229947013751
0.0574653853638
0.304601379159
0.326766263959
0.114102008949
0.0809567076853
0.259473385993
0.120499673603
0.241239153482
0.233058999188
0.0235217684986
0.141404852662
0.0262875730889
0.220223449245
0.291526785074
0.27528878062
0.206749568653
0.246270962594
0.158909646981
0.0346688375805
0.168595279587
0.206896737054
0.0220070442376
0.0517329442086
0.0452301340496
0.0138425605958
0.0774866513354
0.0173818048544
0.0810558700904
0.296742670315
0.0977368011139
0.0418259291026
0.0457927915657
0.110381424574
0.206506815226
0.0465135573723
0.218611761105
0.0314868727976
0.105550133017
0.188928453625
0.277121765215
0.203154231832
0.285656802711
0.0276842255898
0.27524836908


0.160063317377
0.269865149021
0.00738075628806
0.112417508757
0.187754384133
0.059611585143
0.185411361126
0.179175009661
0.100039681268
0.0957933577102
0.174315141016
0.277218203523
0.154923437148
0.157544132651
0.216515282031
0.245218181509
0.249326533333
0.142482861118
0.0465030785666
0.300147159348
0.096465684997
0.201407782717
0.135704876271
0.10040565219
0.199590477586
0.15024558058
0.159291615593
0.0564456483681
0.132355417403
0.175289998833
0.0835534676234
0.0514132669164
0.0869861276383
0.041457070788
0.153959219194
0.0208659509407
0.0275980440749
0.0117871591138
0.172128261226
0.0576573503126
0.12271613661
0.107684357426
0.0403148081947
0.0737855680139
0.0415170570139
0.132193032548
0.140146013678
0.232091255695
0.0496099944528
0.135172809952
0.184536573298
0.201527197747
0.297761638634
0.0242048980733
0.289327231028
0.199549831894
0.161119258764
0.114549471453
0.0464592714288
0.15133605075
0.0629990564863
0.0327675744319
0.124426618266
0.193738013098
0.117074775175
0.0466835

0.117568427287
0.105188938061
0.129153006696
0.256110677461
0.0199848440539
0.0628023514313
0.014502643353
0.0640035228054
0.097892416068
0.244102953047
0.20294683683
0.0525034317838
0.244138880448
0.0823048398758
0.115901135739
0.030994242692
0.0979558996722
0.335542502145
0.184508557394
0.247646035353
0.332951926871
0.015770079243
0.143044950916
0.0823026893935
0.175379930131
0.211710445187
0.174895774115
0.207167408445
0.309802625915
0.0941703686262
0.200886515687
0.189857543487
0.202608238425
0.183026494054
0.371665755894
0.105117644739
0.2164308206
0.0156187386537
0.264207818099
0.0572573377571
0.21247628572
0.0769342587284
0.170781717261
0.290971723601
0.158557420306
0.190395272866
0.100672312183
0.276322018239
0.0466648256051
0.315224512677
0.205024788399
0.0791526895869
0.0365726082422
0.181576805477
0.167886828771
0.0702936833528
0.0512795710887
0.164370643337
0.257523868069
0.0291658028157
0.317235243239
0.131241051357
0.222829667936
0.163249023464
0.274195657867
0.2278238927

0.0457108428654
0.0329849786575
0.0967959353803
0.116248173591
0.0138395540554
0.0424493407035
0.114906518348
0.187961269927
0.0349837865159
0.199369778851
0.249730821779
0.170823448206
0.265355355645
0.344765302175
0.0586818986329
0.12235235894
0.228694936561
0.347815842908
0.145252652845
0.149204518348
0.0474549480838
0.204499920986
0.315854132898
0.226539604929
0.185131218945
0.246535533128
0.177063937616
0.181360504974
0.365655734816
0.337863735287
0.120286522693
0.111451962378
0.233722421939
0.13531743625
0.166499291749
0.122688202065
0.0212268833278
0.0813047936437
0.178361489175
0.0579834304231
0.105734984796
0.0155741486773
0.165689976724
0.055810587943
0.0421347340882
0.0578105937988
0.0319561708993
0.22760769534
0.0535364031917
0.254226625487
0.191369052502
0.0609473602439
0.034608619045
0.0986404108673
0.0400348722723
0.174451978858
0.0899893690502
0.201973363663
0.27169447345
0.0644308503168
0.254053724787
0.0891437809393
0.0676638100543
0.171222856978
0.220204755446
0.1167

0.0429931818635
0.240057132906
0.0541208418717
0.333937097208
0.164225088979
0.131111175849
0.299577013639
0.165452879711
0.0370805955849
0.105958189413
0.0365471192987
0.0134424044936
0.117230799082
0.0928394253485
0.0931084646404
0.336039118638
0.0250764747146
0.36294787275
0.245368228847
0.126174293747
0.222989874506
0.229750313125
0.324827100693
0.268921197035
0.258940273922
0.0728583444388
0.233837556217
0.123893537839
0.047549933024
0.0303866795605
0.35819953635
0.153813939993
0.153759667984
0.241581160597
0.111440479311
0.0733213926244
0.0964896108931
0.0203179136419
0.237245544834
0.0226264185431
0.137787822255
0.23488220139
0.0845489494356
0.0578304403593
0.252357051321
0.0455609897163
0.0589874009484
0.116140682224
0.0899600440097
0.155171140337
0.171046000421
0.075079215891
0.0641603185502
0.0680696303158
0.13032362589
0.253188585735
0.0249034847063
0.04820326783
0.157132677466
0.0355426709697
0.0382894171309
0.234682396178
0.271332371377
0.156070528844
0.2159339531
0.024623

0.0537870641819
0.251551172253
0.169254456108
0.0115645588866
0.12270607
0.047043706343
0.19581696659
0.192098775629
0.158382928066
0.252706174955
0.241308433123
0.115800778322
0.384256273785
0.102609716435
0.0224609850687
0.25427349212
0.0748409606248
0.192513300027
0.0567860836922
0.178498979456
0.124256106769
0.0308770049324
0.233148005282
0.114495767696
0.127638110948
0.034220010015
0.0479658355681
0.107752648257
0.0231335679388
0.192351159581
0.192454420244
0.142022811174
0.0454225866782
0.0892747504504
0.0523893513911
0.169567115542
0.0973933719099
0.206198390045
0.223371521934
0.0934672545071
0.102604471561
0.207625361925
0.218291124501
0.0843081066483
0.208640046959
0.0815223048279
0.212161310706
0.186495856206
0.0526737005369
0.0687354415461
0.0709907584916
0.138702955635
0.200729292104
0.105875666459
0.0525536417492
0.0781194659563
0.145903654565
0.168630749359
0.0140947355405
0.240754179737
0.264943879372
0.118805382911
0.245227573254
0.188801800088
0.0817147103204
0.2329649

0.0325450114944
0.171326278494
0.248768398944
0.0839060209698
0.0569934574333
0.0715690167026
0.0610672257929
0.0230106451906
0.176694791616
0.240563377302
0.072513090692
0.0538950460132
0.0342896892118
0.0567124091593
0.233641731823
0.0454686641015
0.165378933739
0.159163856794
0.0429718805502
0.121891586802
0.0646789849822
0.0814162607898
0.122725088592
0.083261688653
0.160156901978
0.0351002788682
0.0695068420652
0.208718342757
0.0565538910365
0.060526178689
0.107076660104
0.166673103994
0.122770160584
0.0740778715879
0.190289747457
0.356153876816
0.160636256015
0.182420070715
0.181102033364
0.0790069951684
0.184036498982
0.27292794365
0.151232576113
0.0925579539591
0.177411361428
0.117748717464
0.332345434654
0.172067460974
0.0667571480665
0.144666756411
0.367727791724
0.100122781846
0.0765231015567
0.168798753072
0.0146428699737
0.161033726323
0.278845927806
0.27703911526
0.152481037395
0.0576087652533
0.273470732297
0.129246224371
0.189105460087
0.30204192874
0.280633285906
0.142

0.0179896538115
0.194037092253
0.228219069161
0.0240391070278
0.189871644846
0.100152611285
0.0753657585302
0.0601981381876
0.119155830035
0.0996015429918
0.166169995372
0.189598194232
0.0914032919062
0.0912781323899
0.041745545542
0.239056363893
0.110054128368
0.0145325274033
0.214824506511
0.0510861854272
0.192844570443
0.0456346059354
0.0852147625655
0.257136922113
0.0101206845219
0.117963451802
0.0481782132153
0.0973957928207
0.0331913407858
0.207541913507
0.335906568335
0.0753480856468
0.17549610043
0.127096698396
0.133647366905
0.109603636518
0.12430702075
0.187926183274
0.0686945888208
0.283662150866
0.0792864328946
0.173522157872
0.310819108557
0.0829242851425
0.124684883098
0.130594052832
0.341521096506
0.0580752423672
0.0408358872675
0.126708989443
0.22873924201
0.163834218053
0.0234117084428
0.372539293272
0.0694590036224
0.241586812231
0.0665604792484
0.222233429813
0.0238507017861
0.242049748258
0.0832922801228
0.319873053578
0.22495488111
0.13204721314
0.149370529281
0.06

0.156015389851
0.294072673969
0.0658567012998
0.196854893311
0.0476193233645
0.183046332263
0.161901588114
0.133981998926
0.0209621861624
0.238068642441
0.186452116286
0.307781852687
0.0328982981396
0.233663412973
0.149461042651
0.0135910695824
0.158720124415
0.0991247397406
0.0864726996541
0.115218006172
0.0534730980128
0.0777321074901
0.067427760133
0.109980024382
0.136147501489
0.188472695926
0.0916218805401
0.0216400284227
0.0359299740444
0.342137959125
0.13136667512
0.137266788043
0.0883922960821
0.201227887561
0.0384411731339
0.318549984686
0.169745806069
0.0871009093708
0.255535398926
0.266423644042
0.339522016498
0.117066573731
0.135310951323
0.353887681651
0.0302733522882
0.139624722329
0.101313703637
0.185651286346
0.0571482781099
0.307654316244
0.107511947476
0.0403591534904
0.0352972125799
0.0906850019352
0.280884096328
0.128668460312
0.101777722717
0.223310388044
0.168838834666
0.0906429780626
0.0381839639544
0.230325851301
0.182585612617
1.04459581864
0.165978883321
0.080

0.241324948749
0.301385681169
0.0494201288153
0.032380897213
0.149058861954
0.144612901966
0.169715183363
0.144658003599
0.285540388168
0.131684102999
0.30454583818
0.0407847615865
0.327658056093
0.243561933785
0.211691677395
0.0909346477939
0.171760889031
0.140610195334
0.0260608654233
0.1722174163
0.172621058644
0.10695868933
0.127041590676
0.0299633521871
0.113430075251
0.0457298140364
0.053086509074
0.307313520055
0.0979130733669
0.227544896951
0.150348707779
0.278367965883
0.340561129168
0.254405318597
0.13705837423
0.168797377566
0.179162256472
0.0256202798817
0.0949741965375
0.200429505419
0.362309501081
0.113725987293
0.020140218377
0.0127021760025
0.0879179232605
0.0165101907688
0.219222259503
0.0728596999118
0.145907099133
0.0687279646328
0.123621036615
0.0813371694184
0.184210398303
0.0413407912898
0.213494705069
0.108121228061
0.118608942782
0.0797021865575
0.0424919049587
0.165081562967
0.137635902516
0.146006556838
0.0515822807289
0.232499566821
0.0103574945751
0.20823236

In [9]:
# 3-D scatter of the reconstructed points: columns 0, 1, 2 are passed as x, y, z.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
ax.scatter(*reconstructed_data.T)

<IPython.core.display.Javascript object>

<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x10f66f110>

In [10]:
# Write both point sets to comma-separated files in the working directory
# (reconstruction first, then the shuffled input).
for csv_name, table in (('reconstructed.csv', reconstructed_data),
                        ('synthetic.csv', synthetic_data)):
    np.savetxt(csv_name, table, delimiter=",")

array([  0.00000000e+000,   9.88131292e-324,   0.00000000e+000, ...,
         0.00000000e+000,   0.00000000e+000,   0.00000000e+000])