In [1]:
import numpy as np
import pandas as pd
import re
import time
import pickle

In [2]:
def get_sites(string, site_type='ACT_SITE'):
    """string=site entry
    Returns the sites as a list, e.g. if 
    string = 'ACT_SITE 85;  /evidence=...; ACT_SITE 110;  /evidence=...'
    returns [85, 110]"""
    
    pattern = re.compile(f'{site_type} [0-9]+;')
    matches = pattern.findall(string)
    get_site = lambda match: int(match.split(' ')[1][:-1])
    return [get_site(match) for match in matches]
get_binding_sites = lambda string: get_sites(string, 'BINDING')

def process_domain(string):
    pattern = re.compile('DOMAIN [0-9]+..[0-9]+')
    matches = pattern.findall(string)
    intervals = [[int(num) for num in match.split()[1].split('..')] for match in matches]

    dn_pattern = re.compile('note="[^;]*"')
    domain_names = [s.split('"')[1] for s in dn_pattern.findall(string)]

    return list(zip( domain_names, intervals))

In [3]:
df = pd.read_csv('uniprot-verbose.tab', sep='\t', lineterminator='\n')
df = df.fillna('')

In [4]:
df.loc[:, 'Active site'] = df.loc[:, 'Active site'].apply(get_sites)
df.loc[:, 'Binding site'] = df.loc[:, 'Binding site'].apply(get_binding_sites)
df.loc[:, 'Domain [FT]'] = df.loc[:, 'Domain [FT]'].apply(process_domain)
df.loc[:, 'Protein families'] = df.loc[:, 'Protein families'].apply(lambda s: re.split(', |; ', s))
df.loc[:, 'EC number'] = df.loc[:, 'EC number'].apply(lambda s: s.split('; '))

In [5]:
df[:5]

Unnamed: 0,Entry,Entry name,Status,Protein names,Gene names,Organism,Length,Sequence,Active site,Binding site,EC number,Zinc finger,Domain [FT],Domain [CC],Protein families
0,Q4R8P0,ABHGA_MACFA,reviewed,Phosphatidylserine lipase ABHD16A (EC 3.1.-.-)...,ABHD16A BAT5 QtsA-11941,Macaca fascicularis (Crab-eating macaque) (Cyn...,558,MAKLLSCVLGPRLYKIYRERDSERAPASVPETPTAVTAPHSSSWDT...,"[355, 430, 507]",[],"[3.1.-.-, 3.1.1.23]",,"[(AB hydrolase-1, [281, 407])]",,"[AB hydrolase superfamily, ABHD16 family]"
1,P03949,ABL1_CAEEL,reviewed,Tyrosine-protein kinase abl-1 (EC 2.7.10.2),abl-1 M79.1,Caenorhabditis elegans,1224,MGHSHSTGKEINDNELFTCEDPVFDQPVASPKSEISSKLAEEIERS...,[432],[340],[2.7.10.2],,"[(SH3, [115, 188]), (SH2, [194, 284]), (Protei...",,"[Protein kinase superfamily, Tyr protein kinas..."
2,P31937,3HIDH_HUMAN,reviewed,"3-hydroxyisobutyrate dehydrogenase, mitochondr...",HIBADH,Homo sapiens (Human),336,MAASLRLLGAASGLRYWSRRLRPAAGSFAAVCSRSVASKTPVGFIG...,[209],"[108, 134, 284]",[1.1.1.31],,[],,"[HIBADH-related family, 3-hydroxyisobutyrate d..."
3,Q9FFR3,6PGD3_ARATH,reviewed,"6-phosphogluconate dehydrogenase, decarboxylat...",PGD3 At5g41670 MBK23.20,Arabidopsis thaliana (Mouse-ear cress),487,MESVALSRIGLAGLAVMGQNLALNIADKGFPISVYNRTTSKVDETL...,"[188, 195]","[108, 108, 196, 266, 293, 458, 464]",[1.1.1.44],,[],,[6-phosphogluconate dehydrogenase family]
4,Q8W1W9,5MAT1_SALSN,reviewed,Malonyl-coenzyme:anthocyanin 5-O-glucoside-6''...,5MAT1,Salvia splendens (Scarlet sage),462,MTTTTTILETCHIPPPPAANDLSIPLSFFDIKWLHYHPVRRLLFYH...,"[167, 390]",[],[2.3.1.172],,[],,[Plant acyltransferase family]


# Look at domains, families

In [6]:
domains = df['Domain [FT]']

In [7]:
domlst = []
for dom in domains:
    domlst.extend(dom)

In [8]:
domlst = list(next(zip(*domlst)))

In [9]:
doms, counts = np.unique(domlst, return_counts=True)

In [10]:
doms[:1000]

array(['(+)RNA virus helicase ATP-binding',
       '(+)RNA virus helicase C-terminal', '2Fe-2S ferredoxin-type',
       '4Fe-4S Mo/W bis-MGD-type', '4Fe-4S ferredoxin-type',
       '4Fe-4S ferredoxin-type 1', '4Fe-4S ferredoxin-type 2',
       '4Fe-4S ferredoxin-type 3', 'A to I editase', 'AB hydrolase-1',
       'ABC transmembrane type-1', 'ABC transporter', 'ACD', 'ACT',
       'ACT 1', 'ACT 2', 'ADD', 'ADPK', 'AGC-kinase C-terminal', 'ALOG',
       'ATP-cone', 'ATP-cone 1', 'ATP-cone 2', 'ATP-cone 3', 'ATP-grasp',
       'ATP-grasp 1', 'ATP-grasp 2', 'AV ZBD', 'Acylphosphatase-like',
       "Adrift-type SAM-dependent 2'-O-MTase", 'Alpha-carbonic anhydrase',
       'Alpha-type protein kinase', 'Alphavirus-like MT', 'Apple',
       'Apple 1', 'Apple 2', 'Apple 3', 'Apple 4',
       'Asparaginase/glutaminase', 'Asparagine synthetase',
       'Autotransporter', 'B30.2/SPRY', 'B30.2/SPRY 1', 'B30.2/SPRY 2',
       'B30.2/SPRY 3', 'BACON', 'BAH', 'BAH 1', 'BAH 2', 'BIG2',
       'BPL/LPL 

In [11]:
counts

array([93, 93, 34, ..., 18, 59,  8], dtype=int64)

In [12]:
decorder = np.argsort(counts)[::-1]

In [13]:
doms[decorder][:100]

array(['Protein kinase', 'Glutamine amidotransferase type-1', 'Rhodanese',
       'BRCT', 'Peptidase S1', 'KARI N-terminal Rossmann',
       'GMPS ATP-PPase', 'AB hydrolase-1', 'KARI C-terminal knotted',
       'FAD-binding PCMH-type', 'RNase III', 'DRBM', 'BPL/LPL catalytic',
       'Tyr recombinase', 'Core-binding (CB)', 'Thioredoxin',
       'RdRp catalytic', 'GATase cobBQ-type', 'AGC-kinase C-terminal',
       'Peptidase S8', 'USP', 'Glutamine amidotransferase type-2',
       'CN hydrolase', 'Acylphosphatase-like', 'Exonuclease', 'Urease',
       'TRAM', 'Tyrosine-protein phosphatase', 'PABS', 'MsrB', 'UmuC',
       'Response regulatory', 'Peptidase M12B',
       'CheB-type methylesterase', 'Peptidase A1', 'Carrier',
       'Inhibitor I9', 'MGS-like', 'SIS', 'C2',
       'Pyruvate carboxyltransferase', 'Lon proteolytic', 'TRUD',
       'Deacetylase sirtuin-type', 'Lon N-terminal', 'Disintegrin',
       'Integrase catalytic', 'PDZ', 'PLD phosphodiesterase 2',
       'PLD phosphodies

In [14]:
counts[decorder][:100]

array([4078, 2134,  856,  808,  680,  647,  587,  571,  527,  517,  497,
        497,  482,  460,  440,  425,  376,  349,  343,  343,  311,  293,
        288,  271,  267,  256,  254,  250,  241,  241,  238,  237,  237,
        236,  231,  230,  224,  222,  219,  219,  217,  204,  201,  191,
        187,  185,  182,  181,  178,  178,  175,  175,  171,  168,  168,
        168,  163,  162,  161,  159,  158,  158,  158,  153,  152,  148,
        147,  145,  142,  142,  141,  141,  141,  140,  137,  133,  133,
        127,  127,  125,  122,  120,  118,  116,  115,  115,  112,  112,
        110,  106,  105,  105,  105,  105,  102,  101,  100,   98,   95,
         94], dtype=int64)

In [15]:
families = df['Protein families']
familist = []
for family in families:
    familist.extend(family)

In [16]:
fams, famcounts = np.unique(familist, return_counts=True)

In [17]:
len(fams), len(familist)

(1817, 143724)

In [18]:
decorder= np.argsort(famcounts)[::-1]
fams[:10], fams[decorder][:10], famcounts[decorder][:10]

(array(['', "'GDSL' lipolytic enzyme family",
        "'GDXG' lipolytic enzyme family", "'phage' integrase family",
        '17-beta-HSD 3 subfamily', '2,4-dienoyl-CoA reductase subfamily',
        '2-acyl-GPE acetyltransferase family',
        '2-hydroxy-3-oxopropionate reductase subfamily',
        '2-oxoacid dehydrogenase family', '2H phosphoesterase superfamily'],
       dtype='<U75'), array(['Protein kinase superfamily', '',
        'Class I-like SAM-binding methyltransferase superfamily',
        'Type 1 subfamily', 'Ser/Thr protein kinase family',
        'HisA/HisF family', 'EPSP synthase family',
        'Thiolase-like superfamily',
        'Short-chain dehydrogenases/reductases (SDR) family',
        'Transferase hexapeptide repeat family'], dtype='<U75'), array([4055, 3378, 2329, 1863, 1443, 1251, 1211, 1024,  976,  947],
       dtype=int64))

In [19]:
ecs = df['EC number']

In [20]:
np.mean([len(np.unique([e[0] for e in ec if ec[0] != ''])) for ec in ecs])

1.0153353170061798

In [21]:
ec_list = [np.unique([int(e[0]) for e in ec if ec[0] != '']) for ec in ecs]

# Load into unirep mlstm

In [22]:
import tensorflow as tf

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [23]:
import os
from UniRep.unirep import babbler1900 as babbler

W0505 20:20:25.612918 15772 deprecation_wrapper.py:119] From E:\Academic\School\SPRING2020\CS294-150\project\UniRep\unirep.py:43: The name tf.nn.rnn_cell.RNNCell is deprecated. Please use tf.compat.v1.nn.rnn_cell.RNNCell instead.

W0505 20:20:29.037851 15772 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [171]:
tf.reset_default_graph()

In [172]:
# If you want to initialize a model from scratch, set new to True and trained to False
batch_size = 12
b = babbler(batch_size=batch_size, model_path=f'2000_iters_hs256_nl4/', trained=True, 
                rnn_size=256, n_layers=4, new=False)

W0501 15:16:50.748506 21944 deprecation_wrapper.py:119] From E:\Academic\School\SPRING2020\CS294-150\project\UniRep\unirep.py:336: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0501 15:16:51.118917 21944 deprecation_wrapper.py:119] From E:\Academic\School\SPRING2020\CS294-150\project\UniRep\unirep.py:390: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.

W0501 15:16:51.127894 21944 deprecation.py:323] From E:\Academic\School\SPRING2020\CS294-150\project\UniRep\unirep.py:400: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
W0501 15:16:51.547101 21944 deprecation_wrapper.py:119] From E:\Academic\School\SPRING2020\CS294-150\project\294env\lib\site-packages\tensorflow\python\autograph\converters\directives.py:117: The name tf.variable_scope is deprecated. Please use tf

In [14]:
# get sequences of length less than 275
df_le = df[df['Length']<275]
# get site sequences, 0=normal, 1=active, 2=binding
sequences = []
for i,row in df_le.iterrows():
    seq_len = row['Length']
    site_seq = np.zeros(seq_len+1,dtype=np.int32)
    site_seq[np.array(row['Active site'], dtype=np.int32)] = 1
    site_seq[np.array(row['Binding site'], dtype=np.int32)] = 2
    sequences.append(site_seq)
df_le['Site sequence'] = sequences

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  # This is added back by InteractiveShellApp.init_path()


In [15]:
df_le.iloc[0]

Entry                                                       P56221
Entry name                                              SCYD_MAGO7
Status                                                    reviewed
Protein names       Scytalone dehydratase (SD) (SDH) (EC 4.2.1.94)
Gene names                                          SDH1 MGG_05059
Organism         Magnaporthe oryzae (strain 70-15 / ATCC MYA-46...
Length                                                         172
Sequence         MGSQVQKSDEITFSDYLGLMTCVYEWADSYDSKDWDRLRKVIAPTL...
Active site                                              [85, 110]
Binding site                                     [30, 50, 53, 131]
Site sequence    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
Name: 1, dtype: object

In [16]:
for entry in df_le[:5].iterrows():
    print(entry[1]['Sequence'])

MGSQVQKSDEITFSDYLGLMTCVYEWADSYDSKDWDRLRKVIAPTLRIDYRSFLDKLWEAMPAEEFVGMVSSKQVLGDPTLRTQHFIGGTRWEKVSEDEVIGYHQLRVPHQRYKDTTMKEVTMKGHAHSANLHWYKKIDGVWKFAGLKPDIRWGEFDFDRIFEDGRETFGDK
MAIKLIAIDMDGTLLLPDHTISPAVKNAIAAARARGVNVVLTTGRPYAGVHNYLKELHMEQPGDYCITYNGALVQKAADGSTVAQTALSYDDYRFLEKLSREVGSHFHALDRTTLYTANRDISYYTVHESFVATIPLVFCEAEKMDPNTQFLKVMMIDEPAILDQAIARIPQEVKEKYTVLKSAPYFLEILDKRVNKGTGVKSLADVLGIKPEEIMAIGDQENDIAMIEYAGVGVAMDNAIPSVKEVANFVTKSNLEDGVAFAIEKYVLN
MEGQRWLPLEANPEVTNQFLKQLGLHPNWQFVDVYGMEPELLSMVPRPVCAVLLLFPITEKYEVFRTEEEEKIKSQGQDVTSSVYFMKQTISNACGTIGLIHAIANNKDKMHFESGSTLKKFLEESVSMSPEERAKFLENYDAIRVTHETSAHEGQTEAPSIDEKVDLHFIALVHVDGHLYELDGRKPFPINHGKTSDETLLEDAIEVCKKFMERDPDELRFNAIALSAA
MLRLLVVASLVLYGHSTQDFPETNARVVGGTEAQRNSWPSQISLQYRSGSSWAHTCGGTLIRQNWVMTAAHCVDRELTFRVVVGEHNLNQNDGTEQYVGVQKIVVHPYWNTDDVAAGYDIALLRLAQSVTLNSYVQLGVLPRAGTILANNSPCYITGWGLTRTNGQLAQTLQQAYLPTVDYAICSSSSYWGSTVKNSMVCAGGDGVRSGCQGDSGGPLHCLVNGQYAVHGVTSFVSRLGCNVTRKPTVFTRVSAYISWINNVIASN
MAKNRFNQHWLHDHINDPYVKMAQREGYRARAAYKLKEIDEQDKLIRPGQVIVDLGAT

In [17]:
# Before you can train your model, 
with open("formatted.txt", "w") as destination:
    for i,row in df_le.iterrows():
        seq = row['Sequence']
        site_seq = row['Site sequence']
        if b.is_valid_seq(seq): 
            seq = b.format_seq(seq)
            interleaved_seq = np.zeros(2*len(seq), dtype=np.int32)
            # interleave the amino acid sequence and site sequence
            # because this is how we are storing both datas
            interleaved_seq[::2] = seq
            interleaved_seq[1::2] = site_seq+1
            formatted = ",".join(map(str,interleaved_seq))
            destination.write(formatted)
            destination.write('\n')

In [18]:
bucket_op = b.bucket_batch_pad("formatted.txt", lower=200, interval=1000) # Large interval

W0427 12:53:07.163650  1000 deprecation_wrapper.py:119] From E:\Academic\School\SPRING2020\CS294-150\project\UniRep\data_utils.py:94: The name tf.string_to_number is deprecated. Please use tf.strings.to_number instead.

W0427 12:53:07.178602  1000 deprecation.py:323] From E:\Academic\School\SPRING2020\CS294-150\project\UniRep\data_utils.py:190: group_by_window (from tensorflow.contrib.data.python.ops.grouping) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.experimental.group_by_window(...)`.
W0427 12:53:07.210974  1000 deprecation.py:323] From E:\Academic\School\SPRING2020\CS294-150\project\UniRep\unirep.py:601: DatasetV1.make_one_shot_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use 

In [19]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = sess.run(bucket_op)
    em = sess.run(b.embed_matrix)
   
# even indices = amino acid 
# odd indices = site type
print(batch)
print(batch.shape)

[[24  1  1 ...  0  0  0]
 [24  1  1 ...  0  0  0]
 [24  1  1 ...  0  0  0]
 ...
 [24  1  1 ...  0  0  0]
 [24  1  1 ...  0  0  0]
 [24  1  1 ...  0  0  0]]
(12, 530)


In [20]:
final_hidden, x_placeholder, batch_size_placeholder, seq_length_placeholder, initial_state_placeholder = (
    b.get_rep_ops())

In [21]:
def nonpad_len(batch):
    nonzero = batch > 0
    lengths = np.sum(nonzero, axis=1)/2
    return lengths

nonpad_len(batch)

array([173., 243., 150., 265., 247., 256., 232., 202., 256., 138., 260.,
       110.])

In [22]:
optimizer = tf.train.AdamOptimizer(0.001)
all_step_op = optimizer.minimize(b._loss+b.site_loss)

W0427 12:53:11.000681  1000 deprecation.py:323] From E:\Academic\School\SPRING2020\CS294-150\project\294env\lib\site-packages\tensorflow\python\ops\math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [26]:
dummy='dummy_folder'

In [27]:
if not os.path.exists(dummy):
    os.mkdir(dummy)

In [28]:
# How training looks like
num_iters = 1
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(num_iters):
        batch = sess.run(bucket_op)
        seq_batch = batch[:, ::2]
        site_seq_batch = batch[:, 1::2]-1
        length = nonpad_len(batch)-1
        loss, site_loss, _= sess.run([b._loss, b.site_loss, all_step_op],
                                       feed_dict = {
                                        x_placeholder: seq_batch[:, :-1],
                                        b._minibatch_y_placeholder: seq_batch[:, 1:],
                                        b._minibatch_site_y_placeholder: site_seq_batch[:, :-1],
                                        initial_state_placeholder: b._zero_state,
                                        seq_length_placeholder: length,
                                        batch_size_placeholder: batch_size,
                                        b._temp_placeholder: 0.1
                                       })
    
        print(f'Iteration {i}, loss {loss}, site loss {site_loss}')
    b.dump_weights(sess, dummy)

Iteration 0, loss 1.8776484727859497, site loss 0.010596365667879581
embed_matrix_0
[[ 0.01795816 -0.59574187 -0.31995562  0.06984629  0.40433308 -0.20387918
   0.20837028  0.02349814 -0.08639826  0.3730955 ]
 [ 0.5577035  -0.7487646  -0.10239369 -0.32437634 -0.5262135  -0.15576366
   0.4926013  -0.17384009  0.40252152 -0.21631326]
 [ 0.5650982   0.23117451  0.40731284 -0.3695129   0.23110402  0.1681975
  -0.02075569 -0.47987574 -0.66583145 -0.00447276]
 [-0.03245609 -0.11626884 -0.33265555  0.5045284  -0.17571148  0.05205186
   0.16283442  0.19018494 -0.12135593 -0.49977335]
 [-0.5703836  -0.446131    0.2960824  -0.1666478  -0.31925482 -0.12011106
   0.02298783  0.02474621 -0.18123847 -0.32975945]
 [ 0.33877403  0.72887784 -0.26684377 -0.53426933  0.06602486 -0.26432738
   0.21926865  0.38794106 -0.46740597 -0.15146168]
 [-0.68923706 -0.31945857  0.13209684 -0.3690833   0.3801327   0.00466809
   0.29135716 -0.17630763 -0.33551458  0.38000694]
 [ 0.05775635 -0.04033554  0.5785264  -0.3

rnn/mlstm_stack/mlstm_stack1/mlstm_stack1/wh_0
[[ 0.03058092  0.07574227 -0.04154119 ...  0.01504024 -0.16453116
  -0.05265732]
 [ 0.11254816 -0.1757097  -0.05270271 ...  0.20333765  0.05326265
   0.07906359]
 [-0.07605886 -0.10128421 -0.04283068 ...  0.03830232  0.06412466
   0.10309768]
 ...
 [ 0.04668658  0.0159542  -0.01361777 ... -0.04152701 -0.0678052
   0.06222383]
 [ 0.11282471  0.01528359  0.10479821 ...  0.05632594 -0.02272473
   0.13245209]
 [-0.08332232 -0.05206282 -0.0748313  ... -0.01816005  0.01620666
   0.05192275]]
rnn/mlstm_stack/mlstm_stack1/mlstm_stack1/wmx_0
[[ 0.09421043 -0.0781231   0.02037435 ... -0.02504922  0.04252028
   0.03068395]
 [ 0.21274836  0.00855812 -0.10528336 ... -0.12454712  0.04522993
   0.03402837]
 [ 0.02699681 -0.1573665  -0.15890574 ... -0.02753062  0.02302762
  -0.08012205]
 ...
 [ 0.02670928 -0.17337579 -0.13450345 ... -0.08739836 -0.17312494
   0.03137153]
 [ 0.03731266 -0.06851152 -0.20184967 ...  0.12788942  0.14036721
  -0.13469648]
 [ 0

rnn/mlstm_stack/mlstm_stack2/mlstm_stack2/b_0
[0.9869088  1.0195531  1.0045556  ... 1.0151227  0.97398984 1.0394554 ]
rnn/mlstm_stack/mlstm_stack2/mlstm_stack2/gx_0
[1.0722752  1.0518187  0.985807   ... 1.028929   1.05064    0.98292613]
rnn/mlstm_stack/mlstm_stack2/mlstm_stack2/gh_0
[1.0500592 1.055963  1.0372034 ... 1.0095779 0.9779369 1.0382587]
rnn/mlstm_stack/mlstm_stack2/mlstm_stack2/gmx_0
[1.0519857 1.1069614 1.1262553 1.1349046 1.0496802 1.0689726 1.1137491
 1.1090862 1.0589852 1.0463536 1.1260011 1.1515509 1.1284994 1.0571102
 1.0953683 1.1648903 1.1232973 1.0794948 1.1697711 1.1325263 1.0762184
 1.1306193 1.1007996 1.0926807 1.1552886 1.0496541 1.0528674 1.0475335
 1.0768211 1.0697201 1.1276075 1.0508606 1.1463007 1.1190034 1.199441
 1.0428331 1.0599384 1.1503233 1.1236423 1.1547949 1.0737019 1.1188691
 1.0943556 1.0817437 1.0898468 1.1075395 1.1597595 1.058198  1.1044289
 1.0651953 1.2097456 1.1109205 1.1058978 1.0954901 1.1396918 1.056434
 1.0520744 1.1367759 1.1541733 1.053

fully_connected/biases_0
[ 0.03505907  0.01039774 -0.0016424   0.00355181  0.00422901  0.00458979
  0.00787218  0.01096485  0.00654174  0.0100566  -0.0129068  -0.04903596
  0.00998564  0.00271189  0.00904888  0.0086786   0.00993572  0.00210927
  0.00435306  0.00063221  0.00733471 -0.12081234 -0.11099907 -0.11211418
 -0.10853834]
fully_connected_1/weights_0
[[ 1.33566245e-01  2.82388311e-02 -7.35270604e-02]
 [ 5.80520295e-02 -2.11535886e-01  9.84705761e-02]
 [-1.01594336e-01 -1.21100783e-01  1.24244280e-01]
 [ 2.36091018e-02  4.29169051e-02  1.67265534e-01]
 [ 2.19468430e-01 -2.26097852e-01 -1.06713008e-02]
 [-7.74747431e-02  6.90045804e-02 -2.20922709e-01]
 [-1.56648472e-01 -1.65638775e-02  8.58242735e-02]
 [-7.15215281e-02 -5.55984713e-02 -1.10751152e-01]
 [ 1.46027029e-01  1.28295064e-01  4.34653535e-02]
 [-8.48647207e-02  1.28462866e-01  1.09572962e-01]
 [-1.33235669e-02 -1.21733449e-01 -6.63605034e-02]
 [-2.13183463e-01  3.13877910e-01  6.93300217e-02]
 [ 1.37540758e-01 -5.73109575

## Get the representation of an amino acid sequence

In [30]:
# tensorflow is kinda dumb so you need to do this whenever you initialize a new model 
# of the same type, so if you want to rereun the previous stuff after initializing 
# b_trained you need to go all the way back to the beginning where it also resets the 
# default graph
tf.reset_default_graph()

In [31]:
import UniRep

In [32]:
import importlib
importlib.reload(UniRep.unirep)

<module 'UniRep.unirep' from 'E:\\Academic\\School\\SPRING2020\\CS294-150\\project\\UniRep\\unirep.py'>

In [33]:
babbler=UniRep.unirep.babbler1900

In [34]:
# set batch size to one to predict next character for one sequence at a time
# in next section
b_trained = babbler(batch_size=256, model_path=f'./2000_iters_hs256_nl4_nosite/', trained=True,
                   rnn_size=256, n_layers=4, new=False)

In [35]:
i=np.random.randint(0, 10000)
seq = df.iloc[i]['Sequence']
df.iloc[i]

Entry                                                          B3PTU7
Entry name                                                 DNLJ_RHIE6
Status                                                       reviewed
Protein names       DNA ligase (EC 6.5.1.2) (Polydeoxyribonucleoti...
Gene names                                     ligA RHECIAT_CH0002985
Organism                             Rhizobium etli (strain CIAT 652)
Length                                                            718
Sequence            MSTEGSAVDTLTIEEAAAELERLAKEIAHHDALYHGKDQPEISDAD...
Active site                                                     [129]
Binding site                                [127, 150, 186, 302, 326]
EC number                                                   [6.5.1.2]
Zinc finger                                                          
Domain [FT]                                      [(BRCT, [640, 718])]
Domain [CC]                                                          
Protein families    

In [36]:
# how to obtain hidden representations of MLSTM
avg_hidden, final_hidden, final_cell = b_trained.get_rep(seq)

W0505 20:20:39.560739 15772 deprecation_wrapper.py:119] From E:\Academic\School\SPRING2020\CS294-150\project\UniRep\unirep.py:35: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.



## see how well it does on this sequence

In [28]:
batch = np.array([np.array(b_trained.format_seq(seq))])

In [265]:
final_hidden, x_placeholder, batch_size_placeholder, seq_length_placeholder, initial_state_placeholder = (
    b_trained.get_rep_ops())

In [266]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    length = batch[0].shape[0]-1
    loss, logits, sample = sess.run([b_trained._loss, b_trained._logits, b_trained._sample],
                                   feed_dict = {
                                    x_placeholder: batch[:, :-1],
                                    b_trained._minibatch_y_placeholder: batch[:, 1:],
                                    initial_state_placeholder: b_trained._zero_state,
                                    seq_length_placeholder: np.array([length]),
                                    batch_size_placeholder: 1,
                                    b_trained._temp_placeholder: 0.1
                                   })

    print(f'loss {loss}')

# for reference, uniform random predictor would get ln(24) = 3.17. 

loss 2.6757969856262207


In [267]:
# 2nd through nth characters of sequence
aas=batch[0][1:]
aas

array([ 1,  4,  4, 17,  1,  1, 17, 13, 19, 13, 15,  1, 15,  2,  6, 16, 21,
        7,  2, 21, 14,  5, 13, 16,  7, 16, 13, 20, 17, 21, 15,  2, 15, 15,
        3,  3, 15, 15, 17,  5,  7, 15, 18, 13, 13, 10, 16, 10, 15, 21,  8,
        3, 14,  5, 10, 11,  8,  6, 10, 14,  5, 21, 16, 21,  6, 11, 15,  7,
       10, 10, 15, 16, 15,  6, 18, 13,  6, 15, 16, 16,  8,  2, 13, 20, 14,
       21, 15, 16, 17,  7,  8, 13, 15, 21, 15,  5, 15, 15, 21, 10, 10,  2,
       21, 10, 10, 15, 11,  2, 10,  3, 10, 13, 10, 21, 17, 16, 21,  7, 13,
       15, 16, 15, 13,  1,  5, 13, 21, 15,  7, 15,  2,  6, 13, 13, 21,  5,
        7, 16,  8, 19, 10, 15, 11,  4,  7, 14, 15,  7, 20,  2, 13,  7,  1,
       15,  6, 10, 21, 17,  5, 21,  5, 15, 16,  7,  6, 15, 10, 16, 18, 18,
        6, 13,  7, 15,  2,  6, 15, 15,  2, 21, 18, 14, 15,  9, 15,  9, 16,
       15, 15,  8, 17, 15, 21,  9, 13, 21, 13,  1,  5, 15,  8,  2, 16,  2,
       21, 21, 16,  5, 14, 15,  8,  2,  2,  9,  8,  3,  2, 21, 10, 16, 11,
       13,  9, 18, 13,  6

In [268]:
# predicted 2nd through nth characters of sequence
aas2=sample[0]+1
aas2

array([ 1, 15,  4, 21, 21, 17, 17, 21, 21, 21,  4, 18, 13, 15, 15, 21, 17,
        6,  6, 21, 21,  6, 17, 16,  6, 17, 16, 15, 16, 15,  5,  2,  5,  6,
        6,  4, 15,  6, 15, 15, 15, 21, 21, 14,  5, 16, 15, 16, 16, 16,  5,
        5,  5,  6,  6, 16, 21, 15, 15, 15,  5, 16, 16, 16, 17, 15,  8, 13,
       13,  5, 15, 17,  6,  6, 15, 16, 15, 15, 15, 21, 15, 15, 13, 16,  2,
       16, 16, 16, 16, 15, 15, 13, 13, 18,  7,  5,  6,  6, 21, 15, 15, 15,
       21, 15, 15, 15, 13, 15,  6, 15, 13, 16, 16, 21, 21, 13, 13, 13, 13,
       14,  7, 21, 13, 13,  5,  3, 21, 15, 13, 21, 21, 21, 15, 21, 21, 15,
       15, 21, 21, 21, 21, 15, 13, 13, 13, 13, 15, 15, 15, 15, 15, 21, 15,
       15, 15, 21, 21, 15, 15, 21, 13, 21, 14, 21, 16, 16, 14, 21, 15, 14,
        5, 13,  2, 21, 21, 15, 21, 21, 16, 21, 15, 15,  9, 21, 21, 21, 21,
       15, 15, 21, 21, 15, 15, 15, 15, 21, 15, 21, 15, 15, 21, 21, 15, 15,
       15, 15, 15,  5, 15, 21, 21,  2, 15, 15, 21, 15, 15, 15, 21, 15, 21,
       13, 21, 13, 15,  2

In [269]:
print(f"Proportion of amino acids predicted correctly: {np.sum(aas==aas2)/len(aas)}")

Proportion of amino acids predicted correctly: 0.18181818181818182


## Train EC prediction

In [37]:
import sklearn
import torch
import gc

In [38]:
ec_list = np.array(ec_list)

In [39]:
dset_mask = np.array([len(ec)>0 for ec in ec_list])
dset_inds = np.where(dset_mask)[0]

In [40]:
choice = dset_inds[np.random.choice(len(dset_inds), 5)]
print(ec_list[choice])
print(df['EC number'][choice])

[array([5]) array([4]) array([3]) array([1]) array([6])]
59890     [5.3.1.16]
43375    [4.2.1.126]
98619      [3.5.1.5]
12644     [1.1.1.23]
34719      [6.3.4.4]
Name: EC number, dtype: object


In [35]:
with open('reps/2000_iters_hs256_nl4_reps_2.pkl', 'rb') as f:
    repdict = pickle.load(f)

In [36]:
truinds = repdict['inds']

In [35]:
perm  = np.random.permutation(len(dset_inds))
tr_n  = 10000
test_n = 5000
tr_inds = dset_inds[perm[:tr_n]]
test_inds = dset_inds[perm[tr_n:tr_n+test_n]]

In [38]:
other_inds = np.array([ind for ind in dset_inds if (ind not in tr_inds) and (ind not in test_inds)])

In [44]:
ninds = 15000
perm = np.random.permutation(len(other_inds))
random_otherinds = other_inds[perm[:ninds]]

In [45]:
CHOSENINDS = truinds[9600:15000-24]

In [41]:
avg_hidds, final_hidds, final_cels = [], [], []
for i in range(0, len(dset_inds), 256):
    time1=time.time()
    avg_hidd, final_hidd, final_cel = b_trained.get_reps(df['Sequence'][dset_inds[i:i+256]],
                                                         batch_size=256)
    avg_hidds.append(avg_hidd)
    final_hidds.append(final_hidd)
    final_cels.append(final_cel)
    print(i, time.time()-time1)

0 46.12215733528137
256 35.983701944351196
512 27.17345404624939
768 32.5729603767395
1024 62.470458030700684
1280 53.298503160476685
1536 44.74130153656006
1792 58.54796481132507
2048 113.31681227684021
2304 68.15387105941772
2560 64.09104347229004
2816 53.414634227752686
3072 30.816041707992554
3328 57.42905068397522
3584 55.430110931396484
3840 377.5745368003845
4096 98.64564847946167
4352 37.5747435092926
4608 60.801188707351685
4864 138.49310207366943
5120 105.25707221031189
5376 100.83867239952087
5632 52.54754567146301
5888 42.863407135009766
6144 112.35000777244568
6400 62.572108030319214
6656 110.18701100349426
6912 47.20949578285217
7168 41.73917841911316
7424 58.544806480407715
7680 23.671121835708618
7936 39.416433572769165
8192 132.05930757522583
8448 38.97865128517151
8704 66.91602420806885
8960 53.149168729782104
9216 26.458524465560913
9472 39.75183320045471
9728 78.45279097557068
9984 63.982258558273315
10240 44.95059895515442
10496 89.51104187965393
10752 108.17236804

86528 199.88636946678162
86784 59.3066840171814
87040 110.10653281211853
87296 104.28512716293335
87552 150.3551573753357
87808 138.53744745254517
88064 111.05420756340027
88320 60.45739030838013
88576 79.3699631690979
88832 235.17159128189087
89088 175.3391032218933
89344 82.37285852432251
89600 122.44573664665222
89856 87.02854084968567
90112 74.87170267105103
90368 122.21816754341125
90624 104.1155297756195
90880 483.8997006416321
91136 101.66307854652405
91392 122.62224888801575
91648 120.11565256118774
91904 235.60380840301514
92160 105.27639770507812
92416 203.38869261741638
92672 182.14376091957092
92928 135.68643260002136
93184 472.5738351345062
93440 5523.026383399963
93696 669.8905100822449
93952 320.69121527671814
94208 257.6786136627197
94464 93.34719777107239
94720 153.6081416606903
94976 275.1426315307617
95232 118.61586117744446
95488 638.2354953289032
95744 94.30101537704468
96000 72.47443914413452
96256 242.32404780387878
96512 135.09342670440674
96768 761.668354988098

ValueError: need at least one array to concatenate

In [38]:
avg_hiddss = np.concatenate(avg_hidds)
final_hiddss = np.concatenate(final_hidds)
final_celss = np.concatenate(np.array(final_cels)[:, 0])

In [39]:
avg_hs = np.concatenate((avg_hiddss,))
fin_hs = np.concatenate((final_hiddss,))
fin_cs = np.concatenate((final_celss,))

In [40]:
avg_hs.shape, fin_hs.shape, fin_cs.shape

((97792, 256), (97792, 256), (97792, 256))

In [206]:
truinds = np.concatenate((asdf['inds'][:2854], CHOSENINDS))

In [43]:
# with open('reps/2000_iters_hs256_nl4_reps_full.pkl', 'wb') as f:
#     pickle.dump(({'inds': dset_inds[:avg_hs.shape[0]],
#                  'representations': (avg_hs, fin_hs, fin_cs)}), f)

In [60]:
# with open('reps/2000_iters_hs256_nl4_reps.pkl', 'rb') as f:
#     stuff = pickle.load(f)

In [61]:
tr_n  = 10000
test_n = 4000

In [62]:
tr_X = np.array(avg_hs[:tr_n])
test_X = np.array(avg_hs[:test_n])
tr_y = np.array([np.isin(np.arange(7), ec-1) for ec in ec_list[tr_inds]])
tr_y = tr_y/tr_y.sum(axis=1).reshape(-1,1)
test_y = np.array([np.isin(np.arange(7), ec-1) for ec in ec_list[test_inds]])
test_y = test_y/test_y.sum(axis=1).reshape(-1,1)

NameError: name 'tr_inds' is not defined

In [70]:
s = tf.InteractiveSession()



In [78]:
input_X = tf.placeholder('float32',shape =(None,1900),name="input_X")
input_y = tf.placeholder('float32',shape = (None,7),name='input_Y')

In [79]:
weights_0 = tf.Variable(tf.random_normal([1900, 7], stddev=(1/tf.sqrt(float(256)))))
bias_0 = tf.Variable(tf.random_normal([7]))
predicted_y = tf.matmul(input_X,weights_0) + bias_0

In [80]:
starter_learning_rate = 0.001

In [81]:
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=predicted_y,labels=input_y))

In [82]:
learning_rate = tf.train.exponential_decay(starter_learning_rate, 0, 5, 0.9, staircase=True)
## Adam optimzer for finding the right weight
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss,var_list=[weights_0, bias_0])

In [83]:
## Metrics definition
correct_prediction = tf.equal(tf.argmax(input_y,1), tf.argmax(predicted_y,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

In [84]:
## Training parameters
batch_size = 128
epochs=15
ntrials = 200
max_test_accs = []
for trial in range(ntrials):
    training_accuracy = []
    training_loss = []
    testing_accuracy = []
    s.run(tf.global_variables_initializer())
    for epoch in range(epochs):    
        arr = np.arange(tr_n)
        np.random.shuffle(arr)
        for index in range(0,tr_n,batch_size):
            s.run(optimizer, {input_X: tr_X[arr[index:index+batch_size]],
                              input_y: tr_y[arr[index:index+batch_size]]})
        training_accuracy.append(s.run(accuracy, feed_dict= {input_X:tr_X, 
                                                             input_y: tr_y}))
        training_loss.append(s.run(loss, {input_X: tr_X, 
                                          input_y: tr_y}))

        ## Evaluation of model
        test_preds = s.run(predicted_y, {input_X: test_X}).argmax(1)
        test_acc = np.sum(test_y[np.arange(test_n), test_preds] != 0)/test_n
        testing_accuracy.append(test_acc)
        print("Epoch:{0}, Train loss: {1:.2f} Train acc: {2:.3f}, Test acc:{3:.3f}".format(epoch,
                                                                        training_loss[epoch],
                                                                        training_accuracy[epoch],
                                                                       testing_accuracy[epoch]))
    max_test_accs.append(max(testing_accuracy))
print('Average max test accuracy:', np.mean(max_test_accs))

Epoch:0, Train loss: 1.97 Train acc: 0.097, Test acc:0.086
Epoch:1, Train loss: 1.63 Train acc: 0.396, Test acc:0.286
Epoch:2, Train loss: 1.53 Train acc: 0.449, Test acc:0.308
Epoch:3, Train loss: 1.46 Train acc: 0.483, Test acc:0.290
Epoch:4, Train loss: 1.41 Train acc: 0.532, Test acc:0.244
Epoch:5, Train loss: 1.36 Train acc: 0.533, Test acc:0.268
Epoch:6, Train loss: 1.33 Train acc: 0.552, Test acc:0.270
Epoch:7, Train loss: 1.30 Train acc: 0.549, Test acc:0.280
Epoch:8, Train loss: 1.27 Train acc: 0.563, Test acc:0.274
Epoch:9, Train loss: 1.25 Train acc: 0.563, Test acc:0.286
Epoch:10, Train loss: 1.23 Train acc: 0.570, Test acc:0.282
Epoch:11, Train loss: 1.21 Train acc: 0.578, Test acc:0.272
Epoch:12, Train loss: 1.19 Train acc: 0.581, Test acc:0.268
Epoch:13, Train loss: 1.17 Train acc: 0.594, Test acc:0.260
Epoch:14, Train loss: 1.16 Train acc: 0.593, Test acc:0.268
Epoch:0, Train loss: 1.63 Train acc: 0.326, Test acc:0.296
Epoch:1, Train loss: 1.53 Train acc: 0.391, Test ac

Epoch:9, Train loss: 1.22 Train acc: 0.558, Test acc:0.288
Epoch:10, Train loss: 1.20 Train acc: 0.578, Test acc:0.278
Epoch:11, Train loss: 1.18 Train acc: 0.593, Test acc:0.274
Epoch:12, Train loss: 1.16 Train acc: 0.590, Test acc:0.276
Epoch:13, Train loss: 1.15 Train acc: 0.602, Test acc:0.276
Epoch:14, Train loss: 1.13 Train acc: 0.613, Test acc:0.266
Epoch:0, Train loss: 1.65 Train acc: 0.334, Test acc:0.312
Epoch:1, Train loss: 1.54 Train acc: 0.462, Test acc:0.288
Epoch:2, Train loss: 1.46 Train acc: 0.493, Test acc:0.274
Epoch:3, Train loss: 1.40 Train acc: 0.513, Test acc:0.276
Epoch:4, Train loss: 1.35 Train acc: 0.528, Test acc:0.270
Epoch:5, Train loss: 1.31 Train acc: 0.551, Test acc:0.268
Epoch:6, Train loss: 1.27 Train acc: 0.556, Test acc:0.268
Epoch:7, Train loss: 1.25 Train acc: 0.563, Test acc:0.268
Epoch:8, Train loss: 1.22 Train acc: 0.572, Test acc:0.268
Epoch:9, Train loss: 1.20 Train acc: 0.577, Test acc:0.262
Epoch:10, Train loss: 1.18 Train acc: 0.588, Test a

Epoch:0, Train loss: 1.82 Train acc: 0.289, Test acc:0.276
Epoch:1, Train loss: 1.62 Train acc: 0.366, Test acc:0.298
Epoch:2, Train loss: 1.52 Train acc: 0.399, Test acc:0.270
Epoch:3, Train loss: 1.44 Train acc: 0.445, Test acc:0.284
Epoch:4, Train loss: 1.38 Train acc: 0.504, Test acc:0.288
Epoch:5, Train loss: 1.34 Train acc: 0.521, Test acc:0.276
Epoch:6, Train loss: 1.31 Train acc: 0.538, Test acc:0.280
Epoch:7, Train loss: 1.28 Train acc: 0.548, Test acc:0.278
Epoch:8, Train loss: 1.25 Train acc: 0.556, Test acc:0.290
Epoch:9, Train loss: 1.23 Train acc: 0.560, Test acc:0.282
Epoch:10, Train loss: 1.21 Train acc: 0.571, Test acc:0.284
Epoch:11, Train loss: 1.19 Train acc: 0.571, Test acc:0.274
Epoch:12, Train loss: 1.17 Train acc: 0.578, Test acc:0.280
Epoch:13, Train loss: 1.16 Train acc: 0.584, Test acc:0.278
Epoch:14, Train loss: 1.14 Train acc: 0.597, Test acc:0.264
Epoch:0, Train loss: 1.75 Train acc: 0.114, Test acc:0.206
Epoch:1, Train loss: 1.60 Train acc: 0.364, Test ac

Epoch:7, Train loss: 1.24 Train acc: 0.568, Test acc:0.270
Epoch:8, Train loss: 1.22 Train acc: 0.576, Test acc:0.276
Epoch:9, Train loss: 1.20 Train acc: 0.588, Test acc:0.260
Epoch:10, Train loss: 1.18 Train acc: 0.592, Test acc:0.270
Epoch:11, Train loss: 1.16 Train acc: 0.592, Test acc:0.276
Epoch:12, Train loss: 1.14 Train acc: 0.607, Test acc:0.262
Epoch:13, Train loss: 1.13 Train acc: 0.613, Test acc:0.266
Epoch:14, Train loss: 1.11 Train acc: 0.615, Test acc:0.264
Epoch:0, Train loss: 1.85 Train acc: 0.173, Test acc:0.182
Epoch:1, Train loss: 1.63 Train acc: 0.366, Test acc:0.300
Epoch:2, Train loss: 1.54 Train acc: 0.386, Test acc:0.302
Epoch:3, Train loss: 1.44 Train acc: 0.467, Test acc:0.286
Epoch:4, Train loss: 1.39 Train acc: 0.509, Test acc:0.290
Epoch:5, Train loss: 1.35 Train acc: 0.518, Test acc:0.298
Epoch:6, Train loss: 1.31 Train acc: 0.542, Test acc:0.286
Epoch:7, Train loss: 1.28 Train acc: 0.547, Test acc:0.280
Epoch:8, Train loss: 1.26 Train acc: 0.565, Test ac

Epoch:14, Train loss: 1.12 Train acc: 0.603, Test acc:0.272
Epoch:0, Train loss: 1.60 Train acc: 0.360, Test acc:0.306
Epoch:1, Train loss: 1.50 Train acc: 0.446, Test acc:0.296
Epoch:2, Train loss: 1.42 Train acc: 0.476, Test acc:0.298
Epoch:3, Train loss: 1.36 Train acc: 0.522, Test acc:0.278
Epoch:4, Train loss: 1.32 Train acc: 0.533, Test acc:0.274
Epoch:5, Train loss: 1.28 Train acc: 0.548, Test acc:0.280
Epoch:6, Train loss: 1.25 Train acc: 0.557, Test acc:0.272
Epoch:7, Train loss: 1.23 Train acc: 0.565, Test acc:0.274
Epoch:8, Train loss: 1.20 Train acc: 0.585, Test acc:0.268
Epoch:9, Train loss: 1.18 Train acc: 0.582, Test acc:0.278
Epoch:10, Train loss: 1.16 Train acc: 0.590, Test acc:0.272
Epoch:11, Train loss: 1.15 Train acc: 0.608, Test acc:0.256
Epoch:12, Train loss: 1.13 Train acc: 0.611, Test acc:0.264
Epoch:13, Train loss: 1.12 Train acc: 0.619, Test acc:0.260
Epoch:14, Train loss: 1.10 Train acc: 0.620, Test acc:0.258
Epoch:0, Train loss: 1.74 Train acc: 0.345, Test a

Epoch:8, Train loss: 1.28 Train acc: 0.553, Test acc:0.262
Epoch:9, Train loss: 1.26 Train acc: 0.557, Test acc:0.266
Epoch:10, Train loss: 1.24 Train acc: 0.564, Test acc:0.266
Epoch:11, Train loss: 1.22 Train acc: 0.563, Test acc:0.270
Epoch:12, Train loss: 1.20 Train acc: 0.575, Test acc:0.264
Epoch:13, Train loss: 1.19 Train acc: 0.577, Test acc:0.270
Epoch:14, Train loss: 1.17 Train acc: 0.585, Test acc:0.254
Epoch:0, Train loss: 1.73 Train acc: 0.351, Test acc:0.308
Epoch:1, Train loss: 1.54 Train acc: 0.454, Test acc:0.286
Epoch:2, Train loss: 1.44 Train acc: 0.517, Test acc:0.282
Epoch:3, Train loss: 1.38 Train acc: 0.544, Test acc:0.266
Epoch:4, Train loss: 1.33 Train acc: 0.539, Test acc:0.274
Epoch:5, Train loss: 1.30 Train acc: 0.541, Test acc:0.270
Epoch:6, Train loss: 1.27 Train acc: 0.539, Test acc:0.274
Epoch:7, Train loss: 1.24 Train acc: 0.563, Test acc:0.268
Epoch:8, Train loss: 1.22 Train acc: 0.569, Test acc:0.276
Epoch:9, Train loss: 1.20 Train acc: 0.575, Test ac

Epoch:0, Train loss: 2.02 Train acc: 0.278, Test acc:0.310
Epoch:1, Train loss: 1.67 Train acc: 0.285, Test acc:0.284
Epoch:2, Train loss: 1.57 Train acc: 0.364, Test acc:0.278
Epoch:3, Train loss: 1.49 Train acc: 0.401, Test acc:0.262
Epoch:4, Train loss: 1.42 Train acc: 0.500, Test acc:0.252
Epoch:5, Train loss: 1.37 Train acc: 0.499, Test acc:0.286
Epoch:6, Train loss: 1.34 Train acc: 0.526, Test acc:0.258
Epoch:7, Train loss: 1.30 Train acc: 0.529, Test acc:0.262
Epoch:8, Train loss: 1.28 Train acc: 0.543, Test acc:0.254
Epoch:9, Train loss: 1.25 Train acc: 0.547, Test acc:0.268
Epoch:10, Train loss: 1.23 Train acc: 0.558, Test acc:0.268
Epoch:11, Train loss: 1.21 Train acc: 0.574, Test acc:0.262
Epoch:12, Train loss: 1.20 Train acc: 0.578, Test acc:0.266
Epoch:13, Train loss: 1.18 Train acc: 0.579, Test acc:0.270
Epoch:14, Train loss: 1.16 Train acc: 0.588, Test acc:0.266
Epoch:0, Train loss: 1.67 Train acc: 0.392, Test acc:0.300
Epoch:1, Train loss: 1.50 Train acc: 0.481, Test ac

Epoch:7, Train loss: 1.29 Train acc: 0.549, Test acc:0.264
Epoch:8, Train loss: 1.26 Train acc: 0.554, Test acc:0.262
Epoch:9, Train loss: 1.24 Train acc: 0.571, Test acc:0.250
Epoch:10, Train loss: 1.21 Train acc: 0.578, Test acc:0.254
Epoch:11, Train loss: 1.19 Train acc: 0.593, Test acc:0.252
Epoch:12, Train loss: 1.17 Train acc: 0.598, Test acc:0.260
Epoch:13, Train loss: 1.16 Train acc: 0.604, Test acc:0.262
Epoch:14, Train loss: 1.14 Train acc: 0.607, Test acc:0.260
Epoch:0, Train loss: 1.79 Train acc: 0.318, Test acc:0.298
Epoch:1, Train loss: 1.63 Train acc: 0.362, Test acc:0.262
Epoch:2, Train loss: 1.53 Train acc: 0.401, Test acc:0.264
Epoch:3, Train loss: 1.44 Train acc: 0.439, Test acc:0.304
Epoch:4, Train loss: 1.38 Train acc: 0.473, Test acc:0.276
Epoch:5, Train loss: 1.33 Train acc: 0.538, Test acc:0.264
Epoch:6, Train loss: 1.30 Train acc: 0.546, Test acc:0.258
Epoch:7, Train loss: 1.27 Train acc: 0.551, Test acc:0.254
Epoch:8, Train loss: 1.24 Train acc: 0.552, Test ac

Epoch:14, Train loss: 1.13 Train acc: 0.593, Test acc:0.262
Epoch:0, Train loss: 1.81 Train acc: 0.315, Test acc:0.300
Epoch:1, Train loss: 1.59 Train acc: 0.391, Test acc:0.286
Epoch:2, Train loss: 1.50 Train acc: 0.444, Test acc:0.268
Epoch:3, Train loss: 1.42 Train acc: 0.483, Test acc:0.272
Epoch:4, Train loss: 1.36 Train acc: 0.536, Test acc:0.272
Epoch:5, Train loss: 1.32 Train acc: 0.548, Test acc:0.268
Epoch:6, Train loss: 1.29 Train acc: 0.550, Test acc:0.276
Epoch:7, Train loss: 1.26 Train acc: 0.555, Test acc:0.274
Epoch:8, Train loss: 1.24 Train acc: 0.563, Test acc:0.272
Epoch:9, Train loss: 1.21 Train acc: 0.568, Test acc:0.264
Epoch:10, Train loss: 1.20 Train acc: 0.583, Test acc:0.266
Epoch:11, Train loss: 1.18 Train acc: 0.573, Test acc:0.278
Epoch:12, Train loss: 1.16 Train acc: 0.589, Test acc:0.260
Epoch:13, Train loss: 1.15 Train acc: 0.585, Test acc:0.258
Epoch:14, Train loss: 1.13 Train acc: 0.604, Test acc:0.256
Epoch:0, Train loss: 1.71 Train acc: 0.294, Test a

Epoch:8, Train loss: 1.26 Train acc: 0.541, Test acc:0.270
Epoch:9, Train loss: 1.24 Train acc: 0.545, Test acc:0.280
Epoch:10, Train loss: 1.22 Train acc: 0.557, Test acc:0.282
Epoch:11, Train loss: 1.20 Train acc: 0.571, Test acc:0.270
Epoch:12, Train loss: 1.18 Train acc: 0.577, Test acc:0.274
Epoch:13, Train loss: 1.17 Train acc: 0.583, Test acc:0.270
Epoch:14, Train loss: 1.15 Train acc: 0.593, Test acc:0.258
Epoch:0, Train loss: 1.66 Train acc: 0.433, Test acc:0.324
Epoch:1, Train loss: 1.51 Train acc: 0.452, Test acc:0.312
Epoch:2, Train loss: 1.42 Train acc: 0.472, Test acc:0.338
Epoch:3, Train loss: 1.37 Train acc: 0.493, Test acc:0.324
Epoch:4, Train loss: 1.32 Train acc: 0.530, Test acc:0.278
Epoch:5, Train loss: 1.29 Train acc: 0.542, Test acc:0.288
Epoch:6, Train loss: 1.26 Train acc: 0.546, Test acc:0.284
Epoch:7, Train loss: 1.23 Train acc: 0.557, Test acc:0.278
Epoch:8, Train loss: 1.21 Train acc: 0.564, Test acc:0.278
Epoch:9, Train loss: 1.19 Train acc: 0.585, Test ac

Epoch:13, Train loss: 1.16 Train acc: 0.592, Test acc:0.274
Epoch:14, Train loss: 1.14 Train acc: 0.594, Test acc:0.270
Epoch:0, Train loss: 1.85 Train acc: 0.321, Test acc:0.308
Epoch:1, Train loss: 1.63 Train acc: 0.385, Test acc:0.284
Epoch:2, Train loss: 1.52 Train acc: 0.454, Test acc:0.280
Epoch:3, Train loss: 1.44 Train acc: 0.466, Test acc:0.282
Epoch:4, Train loss: 1.38 Train acc: 0.476, Test acc:0.276
Epoch:5, Train loss: 1.34 Train acc: 0.503, Test acc:0.270
Epoch:6, Train loss: 1.31 Train acc: 0.531, Test acc:0.270
Epoch:7, Train loss: 1.28 Train acc: 0.540, Test acc:0.280
Epoch:8, Train loss: 1.25 Train acc: 0.559, Test acc:0.264
Epoch:9, Train loss: 1.23 Train acc: 0.566, Test acc:0.266
Epoch:10, Train loss: 1.21 Train acc: 0.578, Test acc:0.270
Epoch:11, Train loss: 1.19 Train acc: 0.584, Test acc:0.264
Epoch:12, Train loss: 1.17 Train acc: 0.593, Test acc:0.260
Epoch:13, Train loss: 1.16 Train acc: 0.593, Test acc:0.270
Epoch:14, Train loss: 1.14 Train acc: 0.597, Test 

Epoch:5, Train loss: 1.31 Train acc: 0.533, Test acc:0.268
Epoch:6, Train loss: 1.28 Train acc: 0.543, Test acc:0.272
Epoch:7, Train loss: 1.25 Train acc: 0.551, Test acc:0.266
Epoch:8, Train loss: 1.22 Train acc: 0.557, Test acc:0.272
Epoch:9, Train loss: 1.20 Train acc: 0.579, Test acc:0.266
Epoch:10, Train loss: 1.18 Train acc: 0.587, Test acc:0.268
Epoch:11, Train loss: 1.16 Train acc: 0.591, Test acc:0.264
Epoch:12, Train loss: 1.15 Train acc: 0.601, Test acc:0.262
Epoch:13, Train loss: 1.13 Train acc: 0.601, Test acc:0.268
Epoch:14, Train loss: 1.12 Train acc: 0.615, Test acc:0.262
Epoch:0, Train loss: 1.61 Train acc: 0.355, Test acc:0.304
Epoch:1, Train loss: 1.50 Train acc: 0.425, Test acc:0.290
Epoch:2, Train loss: 1.43 Train acc: 0.504, Test acc:0.276
Epoch:3, Train loss: 1.37 Train acc: 0.527, Test acc:0.278
Epoch:4, Train loss: 1.32 Train acc: 0.541, Test acc:0.274
Epoch:5, Train loss: 1.29 Train acc: 0.552, Test acc:0.268
Epoch:6, Train loss: 1.26 Train acc: 0.558, Test ac

Epoch:13, Train loss: 1.16 Train acc: 0.579, Test acc:0.278
Epoch:14, Train loss: 1.15 Train acc: 0.588, Test acc:0.274
Epoch:0, Train loss: 1.72 Train acc: 0.329, Test acc:0.270
Epoch:1, Train loss: 1.55 Train acc: 0.393, Test acc:0.280
Epoch:2, Train loss: 1.45 Train acc: 0.430, Test acc:0.278
Epoch:3, Train loss: 1.39 Train acc: 0.484, Test acc:0.300
Epoch:4, Train loss: 1.34 Train acc: 0.511, Test acc:0.280
Epoch:5, Train loss: 1.30 Train acc: 0.532, Test acc:0.278
Epoch:6, Train loss: 1.27 Train acc: 0.540, Test acc:0.272
Epoch:7, Train loss: 1.24 Train acc: 0.547, Test acc:0.270
Epoch:8, Train loss: 1.22 Train acc: 0.563, Test acc:0.274
Epoch:9, Train loss: 1.20 Train acc: 0.572, Test acc:0.278
Epoch:10, Train loss: 1.18 Train acc: 0.574, Test acc:0.274
Epoch:11, Train loss: 1.16 Train acc: 0.577, Test acc:0.274
Epoch:12, Train loss: 1.15 Train acc: 0.585, Test acc:0.278
Epoch:13, Train loss: 1.13 Train acc: 0.602, Test acc:0.250
Epoch:14, Train loss: 1.12 Train acc: 0.597, Test 

Epoch:6, Train loss: 1.32 Train acc: 0.504, Test acc:0.294
Epoch:7, Train loss: 1.28 Train acc: 0.527, Test acc:0.282
Epoch:8, Train loss: 1.25 Train acc: 0.554, Test acc:0.270
Epoch:9, Train loss: 1.23 Train acc: 0.567, Test acc:0.274
Epoch:10, Train loss: 1.21 Train acc: 0.572, Test acc:0.272
Epoch:11, Train loss: 1.19 Train acc: 0.597, Test acc:0.260
Epoch:12, Train loss: 1.17 Train acc: 0.599, Test acc:0.266
Epoch:13, Train loss: 1.15 Train acc: 0.603, Test acc:0.254
Epoch:14, Train loss: 1.14 Train acc: 0.603, Test acc:0.268
Epoch:0, Train loss: 1.69 Train acc: 0.337, Test acc:0.294
Epoch:1, Train loss: 1.55 Train acc: 0.428, Test acc:0.286
Epoch:2, Train loss: 1.48 Train acc: 0.460, Test acc:0.252
Epoch:3, Train loss: 1.40 Train acc: 0.494, Test acc:0.278
Epoch:4, Train loss: 1.35 Train acc: 0.528, Test acc:0.288
Epoch:5, Train loss: 1.31 Train acc: 0.527, Test acc:0.286
Epoch:6, Train loss: 1.28 Train acc: 0.549, Test acc:0.282
Epoch:7, Train loss: 1.25 Train acc: 0.564, Test ac

Epoch:12, Train loss: 1.15 Train acc: 0.600, Test acc:0.276
Epoch:13, Train loss: 1.14 Train acc: 0.605, Test acc:0.276
Epoch:14, Train loss: 1.12 Train acc: 0.621, Test acc:0.258
Epoch:0, Train loss: 1.63 Train acc: 0.370, Test acc:0.314
Epoch:1, Train loss: 1.52 Train acc: 0.463, Test acc:0.290
Epoch:2, Train loss: 1.45 Train acc: 0.476, Test acc:0.296
Epoch:3, Train loss: 1.39 Train acc: 0.516, Test acc:0.292
Epoch:4, Train loss: 1.34 Train acc: 0.544, Test acc:0.286
Epoch:5, Train loss: 1.30 Train acc: 0.528, Test acc:0.290
Epoch:6, Train loss: 1.27 Train acc: 0.556, Test acc:0.274
Epoch:7, Train loss: 1.24 Train acc: 0.561, Test acc:0.278
Epoch:8, Train loss: 1.22 Train acc: 0.560, Test acc:0.286
Epoch:9, Train loss: 1.20 Train acc: 0.578, Test acc:0.272
Epoch:10, Train loss: 1.18 Train acc: 0.606, Test acc:0.244
Epoch:11, Train loss: 1.16 Train acc: 0.586, Test acc:0.280
Epoch:12, Train loss: 1.14 Train acc: 0.595, Test acc:0.272
Epoch:13, Train loss: 1.13 Train acc: 0.612, Test 

Epoch:7, Train loss: 1.26 Train acc: 0.552, Test acc:0.260
Epoch:8, Train loss: 1.23 Train acc: 0.559, Test acc:0.272
Epoch:9, Train loss: 1.21 Train acc: 0.578, Test acc:0.280
Epoch:10, Train loss: 1.19 Train acc: 0.576, Test acc:0.274
Epoch:11, Train loss: 1.17 Train acc: 0.596, Test acc:0.274
Epoch:12, Train loss: 1.16 Train acc: 0.596, Test acc:0.270
Epoch:13, Train loss: 1.14 Train acc: 0.600, Test acc:0.270
Epoch:14, Train loss: 1.13 Train acc: 0.602, Test acc:0.272
Epoch:0, Train loss: 1.90 Train acc: 0.260, Test acc:0.278
Epoch:1, Train loss: 1.66 Train acc: 0.367, Test acc:0.300
Epoch:2, Train loss: 1.55 Train acc: 0.431, Test acc:0.290
Epoch:3, Train loss: 1.46 Train acc: 0.478, Test acc:0.292
Epoch:4, Train loss: 1.40 Train acc: 0.508, Test acc:0.302
Epoch:5, Train loss: 1.35 Train acc: 0.521, Test acc:0.308
Epoch:6, Train loss: 1.31 Train acc: 0.541, Test acc:0.294
Epoch:7, Train loss: 1.28 Train acc: 0.548, Test acc:0.296
Epoch:8, Train loss: 1.26 Train acc: 0.552, Test ac

Epoch:12, Train loss: 1.12 Train acc: 0.608, Test acc:0.262
Epoch:13, Train loss: 1.11 Train acc: 0.614, Test acc:0.266
Epoch:14, Train loss: 1.09 Train acc: 0.620, Test acc:0.262
Epoch:0, Train loss: 1.67 Train acc: 0.388, Test acc:0.296
Epoch:1, Train loss: 1.53 Train acc: 0.429, Test acc:0.302
Epoch:2, Train loss: 1.46 Train acc: 0.476, Test acc:0.270
Epoch:3, Train loss: 1.39 Train acc: 0.514, Test acc:0.286
Epoch:4, Train loss: 1.34 Train acc: 0.500, Test acc:0.292
Epoch:5, Train loss: 1.30 Train acc: 0.534, Test acc:0.278
Epoch:6, Train loss: 1.27 Train acc: 0.557, Test acc:0.268
Epoch:7, Train loss: 1.24 Train acc: 0.571, Test acc:0.260
Epoch:8, Train loss: 1.21 Train acc: 0.572, Test acc:0.272
Epoch:9, Train loss: 1.19 Train acc: 0.593, Test acc:0.258
Epoch:10, Train loss: 1.17 Train acc: 0.601, Test acc:0.262
Epoch:11, Train loss: 1.15 Train acc: 0.612, Test acc:0.256
Epoch:12, Train loss: 1.14 Train acc: 0.613, Test acc:0.248
Epoch:13, Train loss: 1.12 Train acc: 0.617, Test 

Epoch:6, Train loss: 1.28 Train acc: 0.545, Test acc:0.268
Epoch:7, Train loss: 1.25 Train acc: 0.557, Test acc:0.272
Epoch:8, Train loss: 1.23 Train acc: 0.562, Test acc:0.274
Epoch:9, Train loss: 1.21 Train acc: 0.571, Test acc:0.274
Epoch:10, Train loss: 1.19 Train acc: 0.578, Test acc:0.274
Epoch:11, Train loss: 1.17 Train acc: 0.583, Test acc:0.270
Epoch:12, Train loss: 1.15 Train acc: 0.588, Test acc:0.272
Epoch:13, Train loss: 1.14 Train acc: 0.597, Test acc:0.264
Epoch:14, Train loss: 1.12 Train acc: 0.597, Test acc:0.268
Epoch:0, Train loss: 1.97 Train acc: 0.291, Test acc:0.270
Epoch:1, Train loss: 1.70 Train acc: 0.356, Test acc:0.308
Epoch:2, Train loss: 1.54 Train acc: 0.421, Test acc:0.294
Epoch:3, Train loss: 1.47 Train acc: 0.424, Test acc:0.268
Epoch:4, Train loss: 1.42 Train acc: 0.464, Test acc:0.256
Epoch:5, Train loss: 1.36 Train acc: 0.506, Test acc:0.270
Epoch:6, Train loss: 1.33 Train acc: 0.535, Test acc:0.274
Epoch:7, Train loss: 1.29 Train acc: 0.559, Test ac

Epoch:10, Train loss: 1.18 Train acc: 0.582, Test acc:0.262
Epoch:11, Train loss: 1.17 Train acc: 0.588, Test acc:0.266
Epoch:12, Train loss: 1.15 Train acc: 0.603, Test acc:0.258
Epoch:13, Train loss: 1.13 Train acc: 0.597, Test acc:0.264
Epoch:14, Train loss: 1.12 Train acc: 0.614, Test acc:0.260
Epoch:0, Train loss: 1.72 Train acc: 0.339, Test acc:0.280
Epoch:1, Train loss: 1.55 Train acc: 0.443, Test acc:0.280
Epoch:2, Train loss: 1.45 Train acc: 0.464, Test acc:0.284
Epoch:3, Train loss: 1.39 Train acc: 0.506, Test acc:0.280
Epoch:4, Train loss: 1.35 Train acc: 0.530, Test acc:0.282
Epoch:5, Train loss: 1.31 Train acc: 0.527, Test acc:0.284
Epoch:6, Train loss: 1.28 Train acc: 0.548, Test acc:0.268
Epoch:7, Train loss: 1.25 Train acc: 0.554, Test acc:0.274
Epoch:8, Train loss: 1.23 Train acc: 0.578, Test acc:0.256
Epoch:9, Train loss: 1.20 Train acc: 0.566, Test acc:0.266
Epoch:10, Train loss: 1.18 Train acc: 0.575, Test acc:0.272
Epoch:11, Train loss: 1.17 Train acc: 0.585, Test 

Epoch:0, Train loss: 1.88 Train acc: 0.276, Test acc:0.276
Epoch:1, Train loss: 1.68 Train acc: 0.332, Test acc:0.294
Epoch:2, Train loss: 1.57 Train acc: 0.396, Test acc:0.258
Epoch:3, Train loss: 1.47 Train acc: 0.432, Test acc:0.264
Epoch:4, Train loss: 1.41 Train acc: 0.488, Test acc:0.292
Epoch:5, Train loss: 1.36 Train acc: 0.489, Test acc:0.294
Epoch:6, Train loss: 1.32 Train acc: 0.525, Test acc:0.280
Epoch:7, Train loss: 1.29 Train acc: 0.539, Test acc:0.274
Epoch:8, Train loss: 1.26 Train acc: 0.549, Test acc:0.272
Epoch:9, Train loss: 1.24 Train acc: 0.554, Test acc:0.278
Epoch:10, Train loss: 1.21 Train acc: 0.572, Test acc:0.278
Epoch:11, Train loss: 1.20 Train acc: 0.574, Test acc:0.274
Epoch:12, Train loss: 1.18 Train acc: 0.580, Test acc:0.270
Epoch:13, Train loss: 1.16 Train acc: 0.592, Test acc:0.270
Epoch:14, Train loss: 1.15 Train acc: 0.597, Test acc:0.266
Epoch:0, Train loss: 1.64 Train acc: 0.350, Test acc:0.258
Epoch:1, Train loss: 1.52 Train acc: 0.392, Test ac

Epoch:6, Train loss: 1.29 Train acc: 0.554, Test acc:0.262
Epoch:7, Train loss: 1.26 Train acc: 0.566, Test acc:0.270
Epoch:8, Train loss: 1.23 Train acc: 0.567, Test acc:0.274
Epoch:9, Train loss: 1.21 Train acc: 0.571, Test acc:0.270
Epoch:10, Train loss: 1.19 Train acc: 0.579, Test acc:0.266
Epoch:11, Train loss: 1.17 Train acc: 0.587, Test acc:0.268
Epoch:12, Train loss: 1.16 Train acc: 0.595, Test acc:0.262
Epoch:13, Train loss: 1.14 Train acc: 0.600, Test acc:0.262
Epoch:14, Train loss: 1.13 Train acc: 0.609, Test acc:0.254
Epoch:0, Train loss: 1.91 Train acc: 0.176, Test acc:0.150
Epoch:1, Train loss: 1.65 Train acc: 0.308, Test acc:0.272
Epoch:2, Train loss: 1.56 Train acc: 0.390, Test acc:0.292
Epoch:3, Train loss: 1.48 Train acc: 0.436, Test acc:0.302
Epoch:4, Train loss: 1.40 Train acc: 0.506, Test acc:0.270
Epoch:5, Train loss: 1.36 Train acc: 0.531, Test acc:0.260
Epoch:6, Train loss: 1.32 Train acc: 0.543, Test acc:0.260
Epoch:7, Train loss: 1.29 Train acc: 0.556, Test ac

Epoch:10, Train loss: 1.20 Train acc: 0.579, Test acc:0.270
Epoch:11, Train loss: 1.18 Train acc: 0.585, Test acc:0.272
Epoch:12, Train loss: 1.16 Train acc: 0.600, Test acc:0.272
Epoch:13, Train loss: 1.15 Train acc: 0.598, Test acc:0.276
Epoch:14, Train loss: 1.13 Train acc: 0.597, Test acc:0.264
Epoch:0, Train loss: 1.75 Train acc: 0.300, Test acc:0.298
Epoch:1, Train loss: 1.59 Train acc: 0.358, Test acc:0.302
Epoch:2, Train loss: 1.49 Train acc: 0.446, Test acc:0.264
Epoch:3, Train loss: 1.41 Train acc: 0.462, Test acc:0.288
Epoch:4, Train loss: 1.36 Train acc: 0.485, Test acc:0.300
Epoch:5, Train loss: 1.32 Train acc: 0.509, Test acc:0.282
Epoch:6, Train loss: 1.29 Train acc: 0.525, Test acc:0.272
Epoch:7, Train loss: 1.26 Train acc: 0.543, Test acc:0.270
Epoch:8, Train loss: 1.23 Train acc: 0.549, Test acc:0.282
Epoch:9, Train loss: 1.21 Train acc: 0.567, Test acc:0.266
Epoch:10, Train loss: 1.19 Train acc: 0.573, Test acc:0.264
Epoch:11, Train loss: 1.17 Train acc: 0.573, Test 

In [49]:
## Training parameters
batch_size = 128
epochs=15
ntrials = 200
max_test_accs = []
for trial in range(ntrials):
    training_accuracy = []
    training_loss = []
    testing_accuracy = []
    s.run(tf.global_variables_initializer())
    for epoch in range(epochs):    
        arr = np.arange(tr_n)
        np.random.shuffle(arr)
        for index in range(0,tr_n,batch_size):
            s.run(optimizer, {input_X: tr_X[arr[index:index+batch_size]],
                              input_y: tr_y[arr[index:index+batch_size]]})
        training_accuracy.append(s.run(accuracy, feed_dict= {input_X:tr_X, 
                                                             input_y: tr_y}))
        training_loss.append(s.run(loss, {input_X: tr_X, 
                                          input_y: tr_y}))

        ## Evaluation of model
        test_preds = s.run(predicted_y, {input_X: test_X}).argmax(1)
        test_acc = np.sum(test_y[np.arange(test_n), test_preds] != 0)/test_n
        testing_accuracy.append(test_acc)
        print("Epoch:{0}, Train loss: {1:.2f} Train acc: {2:.3f}, Test acc:{3:.3f}".format(epoch,
                                                                        training_loss[epoch],
                                                                        training_accuracy[epoch],
                                                                       testing_accuracy[epoch]))
    max_test_accs.append(max(testing_accuracy))
print('Average max test accuracy:', np.mean(max_test_accs))

Epoch:0, Train loss: 1.74 Train acc: 0.288, Test acc:0.266
Epoch:1, Train loss: 1.70 Train acc: 0.291, Test acc:0.270
Epoch:2, Train loss: 1.68 Train acc: 0.289, Test acc:0.260
Epoch:3, Train loss: 1.66 Train acc: 0.305, Test acc:0.300
Epoch:4, Train loss: 1.65 Train acc: 0.304, Test acc:0.256
Epoch:5, Train loss: 1.65 Train acc: 0.345, Test acc:0.262
Epoch:6, Train loss: 1.64 Train acc: 0.334, Test acc:0.290
Epoch:7, Train loss: 1.64 Train acc: 0.376, Test acc:0.290
Epoch:8, Train loss: 1.63 Train acc: 0.388, Test acc:0.284
Epoch:9, Train loss: 1.63 Train acc: 0.383, Test acc:0.296
Epoch:10, Train loss: 1.62 Train acc: 0.388, Test acc:0.280
Epoch:11, Train loss: 1.62 Train acc: 0.386, Test acc:0.310
Epoch:12, Train loss: 1.62 Train acc: 0.383, Test acc:0.286
Epoch:13, Train loss: 1.62 Train acc: 0.380, Test acc:0.310
Epoch:14, Train loss: 1.61 Train acc: 0.350, Test acc:0.292
Epoch:0, Train loss: 1.86 Train acc: 0.290, Test acc:0.278
Epoch:1, Train loss: 1.69 Train acc: 0.290, Test ac

Epoch:0, Train loss: 2.24 Train acc: 0.047, Test acc:0.060
Epoch:1, Train loss: 1.80 Train acc: 0.285, Test acc:0.310
Epoch:2, Train loss: 1.67 Train acc: 0.296, Test acc:0.316
Epoch:3, Train loss: 1.66 Train acc: 0.291, Test acc:0.282
Epoch:4, Train loss: 1.65 Train acc: 0.292, Test acc:0.284
Epoch:5, Train loss: 1.64 Train acc: 0.347, Test acc:0.308
Epoch:6, Train loss: 1.64 Train acc: 0.335, Test acc:0.310
Epoch:7, Train loss: 1.64 Train acc: 0.351, Test acc:0.300
Epoch:8, Train loss: 1.63 Train acc: 0.330, Test acc:0.302
Epoch:9, Train loss: 1.63 Train acc: 0.361, Test acc:0.310
Epoch:10, Train loss: 1.63 Train acc: 0.356, Test acc:0.302
Epoch:11, Train loss: 1.62 Train acc: 0.349, Test acc:0.298
Epoch:12, Train loss: 1.62 Train acc: 0.355, Test acc:0.310
Epoch:13, Train loss: 1.62 Train acc: 0.360, Test acc:0.316
Epoch:14, Train loss: 1.61 Train acc: 0.359, Test acc:0.300
Epoch:0, Train loss: 1.79 Train acc: 0.299, Test acc:0.286
Epoch:1, Train loss: 1.68 Train acc: 0.290, Test ac

Epoch:0, Train loss: 1.92 Train acc: 0.110, Test acc:0.160
Epoch:1, Train loss: 1.72 Train acc: 0.290, Test acc:0.278
Epoch:2, Train loss: 1.67 Train acc: 0.320, Test acc:0.284
Epoch:3, Train loss: 1.67 Train acc: 0.298, Test acc:0.306
Epoch:4, Train loss: 1.66 Train acc: 0.319, Test acc:0.308
Epoch:5, Train loss: 1.65 Train acc: 0.289, Test acc:0.278
Epoch:6, Train loss: 1.64 Train acc: 0.343, Test acc:0.276
Epoch:7, Train loss: 1.64 Train acc: 0.366, Test acc:0.304
Epoch:8, Train loss: 1.64 Train acc: 0.364, Test acc:0.298
Epoch:9, Train loss: 1.63 Train acc: 0.365, Test acc:0.302
Epoch:10, Train loss: 1.63 Train acc: 0.367, Test acc:0.288
Epoch:11, Train loss: 1.63 Train acc: 0.369, Test acc:0.312
Epoch:12, Train loss: 1.62 Train acc: 0.352, Test acc:0.296
Epoch:13, Train loss: 1.62 Train acc: 0.366, Test acc:0.300
Epoch:14, Train loss: 1.62 Train acc: 0.366, Test acc:0.304
Epoch:0, Train loss: 2.02 Train acc: 0.047, Test acc:0.062
Epoch:1, Train loss: 1.74 Train acc: 0.285, Test ac

Epoch:0, Train loss: 1.75 Train acc: 0.298, Test acc:0.316
Epoch:1, Train loss: 1.67 Train acc: 0.269, Test acc:0.268
Epoch:2, Train loss: 1.68 Train acc: 0.291, Test acc:0.278
Epoch:3, Train loss: 1.67 Train acc: 0.266, Test acc:0.290
Epoch:4, Train loss: 1.65 Train acc: 0.289, Test acc:0.272
Epoch:5, Train loss: 1.64 Train acc: 0.314, Test acc:0.284
Epoch:6, Train loss: 1.64 Train acc: 0.320, Test acc:0.288
Epoch:7, Train loss: 1.63 Train acc: 0.340, Test acc:0.284
Epoch:8, Train loss: 1.63 Train acc: 0.340, Test acc:0.304
Epoch:9, Train loss: 1.63 Train acc: 0.370, Test acc:0.312
Epoch:10, Train loss: 1.62 Train acc: 0.365, Test acc:0.318
Epoch:11, Train loss: 1.62 Train acc: 0.363, Test acc:0.280
Epoch:12, Train loss: 1.62 Train acc: 0.370, Test acc:0.310
Epoch:13, Train loss: 1.61 Train acc: 0.368, Test acc:0.318
Epoch:14, Train loss: 1.61 Train acc: 0.368, Test acc:0.308
Epoch:0, Train loss: 1.91 Train acc: 0.290, Test acc:0.278
Epoch:1, Train loss: 1.71 Train acc: 0.290, Test ac

Epoch:0, Train loss: 1.71 Train acc: 0.301, Test acc:0.310
Epoch:1, Train loss: 1.68 Train acc: 0.290, Test acc:0.278
Epoch:2, Train loss: 1.66 Train acc: 0.301, Test acc:0.272
Epoch:3, Train loss: 1.65 Train acc: 0.314, Test acc:0.318
Epoch:4, Train loss: 1.65 Train acc: 0.306, Test acc:0.272
Epoch:5, Train loss: 1.64 Train acc: 0.339, Test acc:0.302
Epoch:6, Train loss: 1.64 Train acc: 0.349, Test acc:0.306
Epoch:7, Train loss: 1.63 Train acc: 0.344, Test acc:0.310
Epoch:8, Train loss: 1.63 Train acc: 0.355, Test acc:0.302
Epoch:9, Train loss: 1.62 Train acc: 0.365, Test acc:0.302
Epoch:10, Train loss: 1.62 Train acc: 0.366, Test acc:0.312
Epoch:11, Train loss: 1.62 Train acc: 0.342, Test acc:0.302
Epoch:12, Train loss: 1.61 Train acc: 0.365, Test acc:0.312
Epoch:13, Train loss: 1.61 Train acc: 0.374, Test acc:0.300
Epoch:14, Train loss: 1.61 Train acc: 0.356, Test acc:0.308
Epoch:0, Train loss: 2.20 Train acc: 0.012, Test acc:0.024
Epoch:1, Train loss: 1.86 Train acc: 0.285, Test ac

Epoch:0, Train loss: 2.79 Train acc: 0.025, Test acc:0.024
Epoch:1, Train loss: 2.16 Train acc: 0.167, Test acc:0.168
Epoch:2, Train loss: 1.86 Train acc: 0.285, Test acc:0.310
Epoch:3, Train loss: 1.74 Train acc: 0.290, Test acc:0.306
Epoch:4, Train loss: 1.70 Train acc: 0.249, Test acc:0.266
Epoch:5, Train loss: 1.69 Train acc: 0.290, Test acc:0.278
Epoch:6, Train loss: 1.68 Train acc: 0.231, Test acc:0.274
Epoch:7, Train loss: 1.67 Train acc: 0.293, Test acc:0.306
Epoch:8, Train loss: 1.66 Train acc: 0.296, Test acc:0.304
Epoch:9, Train loss: 1.66 Train acc: 0.253, Test acc:0.296
Epoch:10, Train loss: 1.65 Train acc: 0.281, Test acc:0.296
Epoch:11, Train loss: 1.65 Train acc: 0.265, Test acc:0.294
Epoch:12, Train loss: 1.65 Train acc: 0.268, Test acc:0.304
Epoch:13, Train loss: 1.64 Train acc: 0.284, Test acc:0.290
Epoch:14, Train loss: 1.64 Train acc: 0.297, Test acc:0.310
Epoch:0, Train loss: 1.71 Train acc: 0.296, Test acc:0.310
Epoch:1, Train loss: 1.67 Train acc: 0.298, Test ac

Epoch:0, Train loss: 2.16 Train acc: 0.282, Test acc:0.308
Epoch:1, Train loss: 1.87 Train acc: 0.285, Test acc:0.310
Epoch:2, Train loss: 1.69 Train acc: 0.297, Test acc:0.310
Epoch:3, Train loss: 1.68 Train acc: 0.290, Test acc:0.278
Epoch:4, Train loss: 1.67 Train acc: 0.290, Test acc:0.278
Epoch:5, Train loss: 1.65 Train acc: 0.350, Test acc:0.318
Epoch:6, Train loss: 1.65 Train acc: 0.305, Test acc:0.318
Epoch:7, Train loss: 1.64 Train acc: 0.357, Test acc:0.312
Epoch:8, Train loss: 1.64 Train acc: 0.343, Test acc:0.280
Epoch:9, Train loss: 1.63 Train acc: 0.354, Test acc:0.298
Epoch:10, Train loss: 1.63 Train acc: 0.360, Test acc:0.318
Epoch:11, Train loss: 1.63 Train acc: 0.360, Test acc:0.296
Epoch:12, Train loss: 1.63 Train acc: 0.362, Test acc:0.296
Epoch:13, Train loss: 1.62 Train acc: 0.360, Test acc:0.298
Epoch:14, Train loss: 1.62 Train acc: 0.367, Test acc:0.300
Epoch:0, Train loss: 1.80 Train acc: 0.180, Test acc:0.220
Epoch:1, Train loss: 1.69 Train acc: 0.250, Test ac

Epoch:0, Train loss: 2.36 Train acc: 0.047, Test acc:0.060
Epoch:1, Train loss: 1.94 Train acc: 0.290, Test acc:0.278
Epoch:2, Train loss: 1.74 Train acc: 0.290, Test acc:0.278
Epoch:3, Train loss: 1.66 Train acc: 0.297, Test acc:0.298
Epoch:4, Train loss: 1.66 Train acc: 0.286, Test acc:0.308
Epoch:5, Train loss: 1.65 Train acc: 0.297, Test acc:0.310
Epoch:6, Train loss: 1.64 Train acc: 0.313, Test acc:0.314
Epoch:7, Train loss: 1.64 Train acc: 0.318, Test acc:0.314
Epoch:8, Train loss: 1.63 Train acc: 0.354, Test acc:0.304
Epoch:9, Train loss: 1.63 Train acc: 0.366, Test acc:0.304
Epoch:10, Train loss: 1.62 Train acc: 0.363, Test acc:0.308
Epoch:11, Train loss: 1.62 Train acc: 0.360, Test acc:0.304
Epoch:12, Train loss: 1.62 Train acc: 0.367, Test acc:0.290
Epoch:13, Train loss: 1.62 Train acc: 0.367, Test acc:0.292
Epoch:14, Train loss: 1.61 Train acc: 0.360, Test acc:0.304
Epoch:0, Train loss: 1.71 Train acc: 0.284, Test acc:0.270
Epoch:1, Train loss: 1.67 Train acc: 0.266, Test ac

Epoch:0, Train loss: 2.03 Train acc: 0.290, Test acc:0.278
Epoch:1, Train loss: 1.75 Train acc: 0.288, Test acc:0.278
Epoch:2, Train loss: 1.68 Train acc: 0.285, Test acc:0.308
Epoch:3, Train loss: 1.68 Train acc: 0.286, Test acc:0.308
Epoch:4, Train loss: 1.65 Train acc: 0.342, Test acc:0.302
Epoch:5, Train loss: 1.65 Train acc: 0.291, Test acc:0.276
Epoch:6, Train loss: 1.64 Train acc: 0.318, Test acc:0.286
Epoch:7, Train loss: 1.64 Train acc: 0.364, Test acc:0.304
Epoch:8, Train loss: 1.63 Train acc: 0.353, Test acc:0.304
Epoch:9, Train loss: 1.63 Train acc: 0.355, Test acc:0.298
Epoch:10, Train loss: 1.63 Train acc: 0.358, Test acc:0.302
Epoch:11, Train loss: 1.62 Train acc: 0.374, Test acc:0.302
Epoch:12, Train loss: 1.62 Train acc: 0.369, Test acc:0.304
Epoch:13, Train loss: 1.62 Train acc: 0.368, Test acc:0.304
Epoch:14, Train loss: 1.61 Train acc: 0.366, Test acc:0.308
Epoch:0, Train loss: 1.75 Train acc: 0.290, Test acc:0.278
Epoch:1, Train loss: 1.68 Train acc: 0.292, Test ac

Epoch:0, Train loss: 2.43 Train acc: 0.047, Test acc:0.060
Epoch:1, Train loss: 1.96 Train acc: 0.290, Test acc:0.278
Epoch:2, Train loss: 1.76 Train acc: 0.290, Test acc:0.278
Epoch:3, Train loss: 1.67 Train acc: 0.288, Test acc:0.300
Epoch:4, Train loss: 1.66 Train acc: 0.293, Test acc:0.310
Epoch:5, Train loss: 1.65 Train acc: 0.327, Test acc:0.316
Epoch:6, Train loss: 1.64 Train acc: 0.322, Test acc:0.326
Epoch:7, Train loss: 1.64 Train acc: 0.329, Test acc:0.324
Epoch:8, Train loss: 1.63 Train acc: 0.362, Test acc:0.302
Epoch:9, Train loss: 1.63 Train acc: 0.355, Test acc:0.304
Epoch:10, Train loss: 1.63 Train acc: 0.368, Test acc:0.304
Epoch:11, Train loss: 1.62 Train acc: 0.355, Test acc:0.314
Epoch:12, Train loss: 1.62 Train acc: 0.358, Test acc:0.312
Epoch:13, Train loss: 1.62 Train acc: 0.355, Test acc:0.314
Epoch:14, Train loss: 1.61 Train acc: 0.368, Test acc:0.300
Epoch:0, Train loss: 2.02 Train acc: 0.285, Test acc:0.310
Epoch:1, Train loss: 1.75 Train acc: 0.295, Test ac

Epoch:0, Train loss: 1.83 Train acc: 0.290, Test acc:0.278
Epoch:1, Train loss: 1.70 Train acc: 0.318, Test acc:0.290
Epoch:2, Train loss: 1.67 Train acc: 0.304, Test acc:0.310
Epoch:3, Train loss: 1.65 Train acc: 0.318, Test acc:0.294
Epoch:4, Train loss: 1.64 Train acc: 0.322, Test acc:0.304
Epoch:5, Train loss: 1.63 Train acc: 0.359, Test acc:0.286
Epoch:6, Train loss: 1.62 Train acc: 0.376, Test acc:0.306
Epoch:7, Train loss: 1.62 Train acc: 0.367, Test acc:0.302
Epoch:8, Train loss: 1.62 Train acc: 0.364, Test acc:0.304
Epoch:9, Train loss: 1.61 Train acc: 0.360, Test acc:0.298
Epoch:10, Train loss: 1.61 Train acc: 0.378, Test acc:0.316
Epoch:11, Train loss: 1.61 Train acc: 0.369, Test acc:0.296
Epoch:12, Train loss: 1.60 Train acc: 0.368, Test acc:0.308
Epoch:13, Train loss: 1.60 Train acc: 0.368, Test acc:0.304
Epoch:14, Train loss: 1.60 Train acc: 0.381, Test acc:0.314
Epoch:0, Train loss: 1.78 Train acc: 0.285, Test acc:0.310
Epoch:1, Train loss: 1.70 Train acc: 0.289, Test ac

Epoch:0, Train loss: 1.74 Train acc: 0.285, Test acc:0.310
Epoch:1, Train loss: 1.68 Train acc: 0.238, Test acc:0.290
Epoch:2, Train loss: 1.68 Train acc: 0.290, Test acc:0.278
Epoch:3, Train loss: 1.66 Train acc: 0.263, Test acc:0.280
Epoch:4, Train loss: 1.66 Train acc: 0.274, Test acc:0.280
Epoch:5, Train loss: 1.65 Train acc: 0.274, Test acc:0.274
Epoch:6, Train loss: 1.65 Train acc: 0.255, Test acc:0.286
Epoch:7, Train loss: 1.64 Train acc: 0.292, Test acc:0.278
Epoch:8, Train loss: 1.64 Train acc: 0.276, Test acc:0.288
Epoch:9, Train loss: 1.64 Train acc: 0.304, Test acc:0.308
Epoch:10, Train loss: 1.63 Train acc: 0.312, Test acc:0.294
Epoch:11, Train loss: 1.63 Train acc: 0.291, Test acc:0.280
Epoch:12, Train loss: 1.63 Train acc: 0.358, Test acc:0.290
Epoch:13, Train loss: 1.62 Train acc: 0.363, Test acc:0.306
Epoch:14, Train loss: 1.62 Train acc: 0.363, Test acc:0.290
Epoch:0, Train loss: 1.68 Train acc: 0.299, Test acc:0.312
Epoch:1, Train loss: 1.66 Train acc: 0.335, Test ac

Epoch:0, Train loss: 1.71 Train acc: 0.278, Test acc:0.278
Epoch:1, Train loss: 1.69 Train acc: 0.285, Test acc:0.310
Epoch:2, Train loss: 1.68 Train acc: 0.299, Test acc:0.314
Epoch:3, Train loss: 1.66 Train acc: 0.290, Test acc:0.278
Epoch:4, Train loss: 1.66 Train acc: 0.290, Test acc:0.278
Epoch:5, Train loss: 1.65 Train acc: 0.354, Test acc:0.308
Epoch:6, Train loss: 1.64 Train acc: 0.357, Test acc:0.300
Epoch:7, Train loss: 1.64 Train acc: 0.380, Test acc:0.314
Epoch:8, Train loss: 1.63 Train acc: 0.338, Test acc:0.298
Epoch:9, Train loss: 1.63 Train acc: 0.362, Test acc:0.292
Epoch:10, Train loss: 1.63 Train acc: 0.377, Test acc:0.316
Epoch:11, Train loss: 1.62 Train acc: 0.376, Test acc:0.318
Epoch:12, Train loss: 1.62 Train acc: 0.361, Test acc:0.310
Epoch:13, Train loss: 1.62 Train acc: 0.380, Test acc:0.308
Epoch:14, Train loss: 1.61 Train acc: 0.375, Test acc:0.316
Epoch:0, Train loss: 1.68 Train acc: 0.300, Test acc:0.266
Epoch:1, Train loss: 1.68 Train acc: 0.291, Test ac

Epoch:0, Train loss: 1.93 Train acc: 0.291, Test acc:0.284
Epoch:1, Train loss: 1.74 Train acc: 0.286, Test acc:0.318
Epoch:2, Train loss: 1.67 Train acc: 0.314, Test acc:0.304
Epoch:3, Train loss: 1.65 Train acc: 0.322, Test acc:0.312
Epoch:4, Train loss: 1.64 Train acc: 0.340, Test acc:0.312
Epoch:5, Train loss: 1.63 Train acc: 0.354, Test acc:0.306
Epoch:6, Train loss: 1.62 Train acc: 0.355, Test acc:0.306
Epoch:7, Train loss: 1.62 Train acc: 0.359, Test acc:0.306
Epoch:8, Train loss: 1.62 Train acc: 0.363, Test acc:0.314
Epoch:9, Train loss: 1.61 Train acc: 0.363, Test acc:0.306
Epoch:10, Train loss: 1.61 Train acc: 0.361, Test acc:0.306
Epoch:11, Train loss: 1.61 Train acc: 0.364, Test acc:0.308
Epoch:12, Train loss: 1.60 Train acc: 0.365, Test acc:0.296
Epoch:13, Train loss: 1.60 Train acc: 0.367, Test acc:0.308
Epoch:14, Train loss: 1.60 Train acc: 0.372, Test acc:0.294
Epoch:0, Train loss: 1.91 Train acc: 0.284, Test acc:0.280
Epoch:1, Train loss: 1.73 Train acc: 0.291, Test ac

Epoch:0, Train loss: 1.78 Train acc: 0.285, Test acc:0.310
Epoch:1, Train loss: 1.70 Train acc: 0.290, Test acc:0.278
Epoch:2, Train loss: 1.67 Train acc: 0.281, Test acc:0.280
Epoch:3, Train loss: 1.65 Train acc: 0.309, Test acc:0.312
Epoch:4, Train loss: 1.65 Train acc: 0.301, Test acc:0.320
Epoch:5, Train loss: 1.64 Train acc: 0.298, Test acc:0.284
Epoch:6, Train loss: 1.64 Train acc: 0.324, Test acc:0.306
Epoch:7, Train loss: 1.63 Train acc: 0.347, Test acc:0.294
Epoch:8, Train loss: 1.63 Train acc: 0.347, Test acc:0.304
Epoch:9, Train loss: 1.63 Train acc: 0.327, Test acc:0.306
Epoch:10, Train loss: 1.62 Train acc: 0.354, Test acc:0.302
Epoch:11, Train loss: 1.62 Train acc: 0.373, Test acc:0.292
Epoch:12, Train loss: 1.62 Train acc: 0.376, Test acc:0.290
Epoch:13, Train loss: 1.61 Train acc: 0.360, Test acc:0.308
Epoch:14, Train loss: 1.61 Train acc: 0.356, Test acc:0.320
Epoch:0, Train loss: 2.08 Train acc: 0.111, Test acc:0.128
Epoch:1, Train loss: 1.74 Train acc: 0.285, Test ac

Epoch:0, Train loss: 2.19 Train acc: 0.161, Test acc:0.156
Epoch:1, Train loss: 1.79 Train acc: 0.290, Test acc:0.278
Epoch:2, Train loss: 1.70 Train acc: 0.267, Test acc:0.288
Epoch:3, Train loss: 1.70 Train acc: 0.288, Test acc:0.302
Epoch:4, Train loss: 1.68 Train acc: 0.297, Test acc:0.304
Epoch:5, Train loss: 1.66 Train acc: 0.287, Test acc:0.274
Epoch:6, Train loss: 1.65 Train acc: 0.301, Test acc:0.282
Epoch:7, Train loss: 1.65 Train acc: 0.304, Test acc:0.292
Epoch:8, Train loss: 1.64 Train acc: 0.331, Test acc:0.308
Epoch:9, Train loss: 1.64 Train acc: 0.341, Test acc:0.292
Epoch:10, Train loss: 1.63 Train acc: 0.347, Test acc:0.292
Epoch:11, Train loss: 1.63 Train acc: 0.326, Test acc:0.286
Epoch:12, Train loss: 1.63 Train acc: 0.363, Test acc:0.306
Epoch:13, Train loss: 1.63 Train acc: 0.361, Test acc:0.300
Epoch:14, Train loss: 1.62 Train acc: 0.363, Test acc:0.294
Epoch:0, Train loss: 2.02 Train acc: 0.278, Test acc:0.272
Epoch:1, Train loss: 1.78 Train acc: 0.304, Test ac

Epoch:0, Train loss: 1.73 Train acc: 0.289, Test acc:0.308
Epoch:1, Train loss: 1.67 Train acc: 0.327, Test acc:0.310
Epoch:2, Train loss: 1.68 Train acc: 0.290, Test acc:0.278
Epoch:3, Train loss: 1.66 Train acc: 0.342, Test acc:0.310
Epoch:4, Train loss: 1.65 Train acc: 0.317, Test acc:0.310
Epoch:5, Train loss: 1.64 Train acc: 0.353, Test acc:0.314
Epoch:6, Train loss: 1.64 Train acc: 0.356, Test acc:0.314
Epoch:7, Train loss: 1.64 Train acc: 0.355, Test acc:0.312
Epoch:8, Train loss: 1.63 Train acc: 0.363, Test acc:0.302
Epoch:9, Train loss: 1.63 Train acc: 0.355, Test acc:0.318
Epoch:10, Train loss: 1.62 Train acc: 0.361, Test acc:0.320
Epoch:11, Train loss: 1.62 Train acc: 0.364, Test acc:0.306
Epoch:12, Train loss: 1.62 Train acc: 0.362, Test acc:0.306
Epoch:13, Train loss: 1.62 Train acc: 0.364, Test acc:0.308
Epoch:14, Train loss: 1.61 Train acc: 0.366, Test acc:0.306
Epoch:0, Train loss: 1.81 Train acc: 0.293, Test acc:0.274
Epoch:1, Train loss: 1.67 Train acc: 0.321, Test ac

Epoch:0, Train loss: 1.93 Train acc: 0.089, Test acc:0.100
Epoch:1, Train loss: 1.69 Train acc: 0.282, Test acc:0.314
Epoch:2, Train loss: 1.70 Train acc: 0.283, Test acc:0.298
Epoch:3, Train loss: 1.69 Train acc: 0.291, Test acc:0.278
Epoch:4, Train loss: 1.67 Train acc: 0.300, Test acc:0.292
Epoch:5, Train loss: 1.66 Train acc: 0.299, Test acc:0.324
Epoch:6, Train loss: 1.65 Train acc: 0.352, Test acc:0.306
Epoch:7, Train loss: 1.65 Train acc: 0.335, Test acc:0.308
Epoch:8, Train loss: 1.64 Train acc: 0.351, Test acc:0.314
Epoch:9, Train loss: 1.64 Train acc: 0.361, Test acc:0.306
Epoch:10, Train loss: 1.63 Train acc: 0.362, Test acc:0.308
Epoch:11, Train loss: 1.63 Train acc: 0.360, Test acc:0.300
Epoch:12, Train loss: 1.63 Train acc: 0.365, Test acc:0.308
Epoch:13, Train loss: 1.63 Train acc: 0.362, Test acc:0.308
Epoch:14, Train loss: 1.62 Train acc: 0.364, Test acc:0.306
Epoch:0, Train loss: 1.80 Train acc: 0.285, Test acc:0.310
Epoch:1, Train loss: 1.67 Train acc: 0.289, Test ac

Epoch:0, Train loss: 2.47 Train acc: 0.047, Test acc:0.060
Epoch:1, Train loss: 1.90 Train acc: 0.168, Test acc:0.168
Epoch:2, Train loss: 1.71 Train acc: 0.294, Test acc:0.310
Epoch:3, Train loss: 1.69 Train acc: 0.306, Test acc:0.298
Epoch:4, Train loss: 1.67 Train acc: 0.308, Test acc:0.288
Epoch:5, Train loss: 1.66 Train acc: 0.316, Test acc:0.302
Epoch:6, Train loss: 1.66 Train acc: 0.311, Test acc:0.312
Epoch:7, Train loss: 1.65 Train acc: 0.318, Test acc:0.304
Epoch:8, Train loss: 1.64 Train acc: 0.324, Test acc:0.296
Epoch:9, Train loss: 1.64 Train acc: 0.336, Test acc:0.300
Epoch:10, Train loss: 1.64 Train acc: 0.337, Test acc:0.310
Epoch:11, Train loss: 1.63 Train acc: 0.331, Test acc:0.306
Epoch:12, Train loss: 1.63 Train acc: 0.341, Test acc:0.310
Epoch:13, Train loss: 1.63 Train acc: 0.346, Test acc:0.314
Epoch:14, Train loss: 1.63 Train acc: 0.342, Test acc:0.300
Epoch:0, Train loss: 1.69 Train acc: 0.301, Test acc:0.310
Epoch:1, Train loss: 1.69 Train acc: 0.290, Test ac

Epoch:0, Train loss: 1.94 Train acc: 0.285, Test acc:0.310
Epoch:1, Train loss: 1.73 Train acc: 0.298, Test acc:0.308
Epoch:2, Train loss: 1.69 Train acc: 0.290, Test acc:0.278
Epoch:3, Train loss: 1.68 Train acc: 0.290, Test acc:0.278
Epoch:4, Train loss: 1.66 Train acc: 0.304, Test acc:0.260
Epoch:5, Train loss: 1.65 Train acc: 0.309, Test acc:0.310
Epoch:6, Train loss: 1.65 Train acc: 0.332, Test acc:0.268
Epoch:7, Train loss: 1.64 Train acc: 0.321, Test acc:0.288
Epoch:8, Train loss: 1.64 Train acc: 0.294, Test acc:0.282
Epoch:9, Train loss: 1.63 Train acc: 0.337, Test acc:0.300
Epoch:10, Train loss: 1.63 Train acc: 0.353, Test acc:0.304
Epoch:11, Train loss: 1.63 Train acc: 0.350, Test acc:0.308
Epoch:12, Train loss: 1.62 Train acc: 0.362, Test acc:0.310
Epoch:13, Train loss: 1.62 Train acc: 0.357, Test acc:0.302
Epoch:14, Train loss: 1.62 Train acc: 0.354, Test acc:0.294
Epoch:0, Train loss: 2.24 Train acc: 0.106, Test acc:0.114
Epoch:1, Train loss: 1.91 Train acc: 0.290, Test ac

In [324]:
## Training parameters
batch_size = 128
epochs=15
ntrials = 20
max_test_accs = []
for trial in range(ntrials):
    training_accuracy = []
    training_loss = []
    testing_accuracy = []
    s.run(tf.global_variables_initializer())
    for epoch in range(epochs):    
        arr = np.arange(tr_n)
        np.random.shuffle(arr)
        for index in range(0,tr_n,batch_size):
            s.run(optimizer, {input_X: tr_X[arr[index:index+batch_size]],
                              input_y: tr_y[arr[index:index+batch_size]]})
        training_accuracy.append(s.run(accuracy, feed_dict= {input_X:tr_X, 
                                                             input_y: tr_y}))
        training_loss.append(s.run(loss, {input_X: tr_X, 
                                          input_y: tr_y}))

        ## Evaluation of model
        test_preds = s.run(predicted_y, {input_X: test_X}).argmax(1)
        test_acc = np.sum(test_y[np.arange(test_n), test_preds] != 0)/test_n
        testing_accuracy.append(test_acc)
        print("Epoch:{0}, Train loss: {1:.2f} Train acc: {2:.3f}, Test acc:{3:.3f}".format(epoch,
                                                                        training_loss[epoch],
                                                                        training_accuracy[epoch],
                                                                       testing_accuracy[epoch]))
    max_test_accs.append(max(testing_accuracy))
print('Average max test accuracy:', np.mean(max_test_accs))

Epoch:0, Train loss: 1.95 Train acc: 0.322, Test acc:0.320
Epoch:1, Train loss: 1.80 Train acc: 0.290, Test acc:0.280
Epoch:2, Train loss: 1.67 Train acc: 0.325, Test acc:0.326
Epoch:3, Train loss: 1.64 Train acc: 0.320, Test acc:0.298
Epoch:4, Train loss: 1.64 Train acc: 0.354, Test acc:0.302
Epoch:5, Train loss: 1.63 Train acc: 0.330, Test acc:0.308
Epoch:6, Train loss: 1.62 Train acc: 0.361, Test acc:0.322
Epoch:7, Train loss: 1.62 Train acc: 0.369, Test acc:0.330
Epoch:8, Train loss: 1.61 Train acc: 0.361, Test acc:0.318
Epoch:9, Train loss: 1.61 Train acc: 0.364, Test acc:0.310
Epoch:10, Train loss: 1.60 Train acc: 0.373, Test acc:0.314
Epoch:11, Train loss: 1.60 Train acc: 0.360, Test acc:0.322
Epoch:12, Train loss: 1.60 Train acc: 0.373, Test acc:0.308
Epoch:13, Train loss: 1.59 Train acc: 0.365, Test acc:0.312
Epoch:14, Train loss: 1.59 Train acc: 0.377, Test acc:0.318
Epoch:0, Train loss: 2.14 Train acc: 0.089, Test acc:0.100
Epoch:1, Train loss: 1.77 Train acc: 0.286, Test ac

Epoch:11, Train loss: 1.62 Train acc: 0.375, Test acc:0.302
Epoch:12, Train loss: 1.62 Train acc: 0.388, Test acc:0.304
Epoch:13, Train loss: 1.61 Train acc: 0.329, Test acc:0.292
Epoch:14, Train loss: 1.61 Train acc: 0.374, Test acc:0.310
Epoch:0, Train loss: 1.87 Train acc: 0.296, Test acc:0.312
Epoch:1, Train loss: 1.70 Train acc: 0.306, Test acc:0.312
Epoch:2, Train loss: 1.70 Train acc: 0.243, Test acc:0.268
Epoch:3, Train loss: 1.68 Train acc: 0.242, Test acc:0.272
Epoch:4, Train loss: 1.66 Train acc: 0.305, Test acc:0.280
Epoch:5, Train loss: 1.65 Train acc: 0.276, Test acc:0.252
Epoch:6, Train loss: 1.65 Train acc: 0.311, Test acc:0.292
Epoch:7, Train loss: 1.64 Train acc: 0.310, Test acc:0.286
Epoch:8, Train loss: 1.64 Train acc: 0.308, Test acc:0.286
Epoch:9, Train loss: 1.63 Train acc: 0.322, Test acc:0.308
Epoch:10, Train loss: 1.63 Train acc: 0.322, Test acc:0.328
Epoch:11, Train loss: 1.62 Train acc: 0.354, Test acc:0.274
Epoch:12, Train loss: 1.62 Train acc: 0.362, Test 

Epoch:0, Train loss: 1.95 Train acc: 0.240, Test acc:0.184
Epoch:1, Train loss: 1.70 Train acc: 0.289, Test acc:0.274
Epoch:2, Train loss: 1.65 Train acc: 0.317, Test acc:0.314
Epoch:3, Train loss: 1.64 Train acc: 0.386, Test acc:0.324
Epoch:4, Train loss: 1.63 Train acc: 0.289, Test acc:0.286
Epoch:5, Train loss: 1.63 Train acc: 0.379, Test acc:0.314
Epoch:6, Train loss: 1.62 Train acc: 0.377, Test acc:0.312
Epoch:7, Train loss: 1.62 Train acc: 0.374, Test acc:0.308
Epoch:8, Train loss: 1.61 Train acc: 0.380, Test acc:0.318
Epoch:9, Train loss: 1.61 Train acc: 0.382, Test acc:0.312
Epoch:10, Train loss: 1.60 Train acc: 0.374, Test acc:0.314
Epoch:11, Train loss: 1.60 Train acc: 0.393, Test acc:0.320
Epoch:12, Train loss: 1.60 Train acc: 0.375, Test acc:0.310
Epoch:13, Train loss: 1.59 Train acc: 0.382, Test acc:0.304
Epoch:14, Train loss: 1.59 Train acc: 0.374, Test acc:0.310
Average max test accuracy: 0.3229
