In [1]:
import pandas as pd
import numpy as np
import copy
import sys
import os
sys.path.append("../")
from parser.utils import load_json, dfs_cardinality, estimate_scan_in_mb
from models.feature.single_xgboost_feature import find_top_k_operators, featurize_one_plan, get_top_k_table_by_size
from utils.load_brad_trace import load_trace, create_concurrency_dataset, load_trace_all_version
from models.concurrency.utils import pre_info_train_test_seperation
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from models.single.stage import SingleStage
from models.concurrency.complex_models import ConcurrentRNN
np.set_printoptions(suppress=True)

In [2]:
parsed_queries_path = "/Users/ziniuw/Desktop/research/Data/AWS_trace/mixed_aurora/aurora_mixed_parsed_queries.json"
plans = load_json(parsed_queries_path, namespace=False)
folder_name = "mixed_postgres"
directory = f"/Users/ziniuw/Desktop/research/Data/AWS_trace/{folder_name}/"
all_raw_trace, all_trace = load_trace_all_version(directory, 8, concat=True)
all_concurrency_df = []
for trace in all_trace:
    concurrency_df = create_concurrency_dataset(trace, engine=None, pre_exec_interval=200)
    all_concurrency_df.append(concurrency_df)
concurrency_df = pd.concat(all_concurrency_df, ignore_index=True)

In [3]:
train_trace_df_sep, eval_trace_df_sep = pre_info_train_test_seperation(concurrency_df)
print(len(train_trace_df_sep), len(eval_trace_df_sep))
np.random.seed(0)
train_idx = np.random.choice(len(concurrency_df), size=int(0.8 * len(concurrency_df)), replace=False)
test_idx = [i for i in range(len(concurrency_df)) if i not in train_idx]
train_trace_df = copy.deepcopy(concurrency_df.iloc[train_idx])
eval_trace_df = concurrency_df.iloc[test_idx]
eval_trace_df = copy.deepcopy(eval_trace_df[eval_trace_df['num_concurrent_queries'] > 0])
print(len(train_trace_df), len(eval_trace_df))

17603 13399
24889 6198


In [4]:
ss = SingleStage(use_table_features=True, true_card=True)
#df = ss.featurize_data(train_trace_df, parsed_queries_path)
df = ss.featurize_data(concurrency_df, parsed_queries_path)
ss.train(df)

Top 20 operators contains 0.9650782102582758 total operators


In [6]:
rnn = ConcurrentRNN(ss, 
                    "postgres",
                    input_size=len(ss.all_feature[0]) * 2 + 7,
                    embedding_dim=128,
                    hidden_size=256,
                    num_layers=2,
                    loss_function="q_loss",
                    last_output=True,
                    use_separation=False
                   )
rnn.train(train_trace_df, eval_trace_df, lr=0.001, loss_function="q_loss", val_on_test=True)
#rnn.load_model("checkpoints")

********Epoch 0, training loss: 347.58929279981515 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 17.19it/s]


50% absolute error is 4.008946895599365, q-error is 1.669534981250763
90% absolute error is 56.427736282348675, q-error is 5.132400846481324
95% absolute error is 99.74682159423816, q-error is 7.393720984458918
********Epoch 5, training loss: 154.66684285944854 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.91it/s]


50% absolute error is 3.3501548767089844, q-error is 1.394617736339569
90% absolute error is 21.596798706054688, q-error is 3.6809489250183107
95% absolute error is 35.74762611389158, q-error is 6.075449180603016
********Epoch 10, training loss: 0.27595379108037704 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:03<00:00, 16.20it/s]


50% absolute error is 2.732160210609436, q-error is 1.3436789512634277
90% absolute error is 25.88196487426758, q-error is 3.1858203649520873
95% absolute error is 50.56133327484129, q-error is 4.81699743270873
********Epoch 15, training loss: 5.076713414222766 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:03<00:00, 16.27it/s]


50% absolute error is 2.5771565437316895, q-error is 1.302615761756897
90% absolute error is 20.151940536499023, q-error is 3.0724877595901487
95% absolute error is 32.027860832214344, q-error is 4.70953342914581
********Epoch 20, training loss: 0.23362355667811174 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.51it/s]


50% absolute error is 2.5559717416763306, q-error is 1.3026060461997986
90% absolute error is 21.500904846191407, q-error is 2.956184434890749
95% absolute error is 35.47093429565428, q-error is 4.372178816795347
********Epoch 25, training loss: 8.318775666829866 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.66it/s]


50% absolute error is 2.476870894432068, q-error is 1.2943041920661926
90% absolute error is 19.770213317871097, q-error is 2.8973747730255135
95% absolute error is 33.72637729644775, q-error is 4.472911119461058
********Epoch 30, training loss: 12.735641101002694 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.58it/s]


50% absolute error is 2.733838438987732, q-error is 1.2923868894577026
90% absolute error is 22.405932235717778, q-error is 2.979494094848633
95% absolute error is 38.36084594726562, q-error is 4.556399250030516
********Epoch 35, training loss: 8.50082020083299 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.46it/s]


50% absolute error is 2.7718400955200195, q-error is 1.3003828525543213
90% absolute error is 18.969323348999023, q-error is 2.9934602260589602
95% absolute error is 29.024862861633284, q-error is 4.440809941291809
********Epoch 40, training loss: 8.334890223007935 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.60it/s]


50% absolute error is 2.4107199907302856, q-error is 1.3126941323280334
90% absolute error is 17.917407608032228, q-error is 3.2280951976776135
95% absolute error is 29.10806121826171, q-error is 5.221411943435652
********Epoch 45, training loss: 0.1668203660692924 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.42it/s]


50% absolute error is 2.6773428916931152, q-error is 1.2827044129371643
90% absolute error is 19.399579620361333, q-error is 2.9053829193115237
95% absolute error is 28.970496177673272, q-error is 4.2669408321380535
********Epoch 50, training loss: 8.493627790266123 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:03<00:00, 16.20it/s]


50% absolute error is 2.5529887676239014, q-error is 1.285063922405243
90% absolute error is 18.57699604034424, q-error is 3.029467940330507
95% absolute error is 30.51652069091793, q-error is 4.629619932174682
********Epoch 55, training loss: 34.45945652245711 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.50it/s]


50% absolute error is 2.6983680725097656, q-error is 1.2991698384284973
90% absolute error is 17.301197814941414, q-error is 3.1333164453506472
95% absolute error is 27.5585428237915, q-error is 4.999049568176263
********Epoch 60, training loss: 14.878932293370749 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.38it/s]


50% absolute error is 2.6564834117889404, q-error is 1.2962819337844849
90% absolute error is 19.772585296630865, q-error is 2.9524541378021243
95% absolute error is 32.37805233001706, q-error is 4.28236179351806
********Epoch 65, training loss: 12.937910144451337 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:03<00:00, 16.32it/s]


50% absolute error is 2.687609553337097, q-error is 1.2870275378227234
90% absolute error is 19.4614501953125, q-error is 2.9776795625686647
95% absolute error is 30.656786727905263, q-error is 4.592781472206109
********Epoch 70, training loss: 8.31051180530817 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:03<00:00, 16.01it/s]


50% absolute error is 2.7452070713043213, q-error is 1.2910370826721191
90% absolute error is 17.373887634277345, q-error is 2.8900103092193605
95% absolute error is 28.838561058044434, q-error is 4.108472967147825
********Epoch 75, training loss: 8.700345549102012 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.85it/s]


50% absolute error is 2.5477957725524902, q-error is 1.2918135523796082
90% absolute error is 17.278061103820804, q-error is 2.917835545539856
95% absolute error is 27.06932144165038, q-error is 4.49826734066009
********Epoch 80, training loss: 16.85967827324684 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.58it/s]


50% absolute error is 2.540003538131714, q-error is 1.292146623134613
90% absolute error is 18.370839309692386, q-error is 2.914396119117737
95% absolute error is 28.359736633300766, q-error is 4.252169251441945
********Epoch 85, training loss: 0.11961331218481064 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.65it/s]


50% absolute error is 2.7122802734375, q-error is 1.2748683094978333
90% absolute error is 17.1804012298584, q-error is 2.8075184106826785
95% absolute error is 28.108411121368405, q-error is 4.06545920372009
********Epoch 90, training loss: 8.960526917263483 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.54it/s]


50% absolute error is 2.3814773559570312, q-error is 1.2803987264633179
90% absolute error is 16.997070312500004, q-error is 2.79355251789093
95% absolute error is 26.743215370178113, q-error is 4.018131518363948
********Epoch 95, training loss: 4.699431594900596 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.55it/s]


50% absolute error is 2.5556888580322266, q-error is 1.2836337089538574
90% absolute error is 18.012384605407718, q-error is 2.820553731918335
95% absolute error is 27.946987915039056, q-error is 3.9305380702018704
********Epoch 100, training loss: 24.937843833481654 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.62it/s]


50% absolute error is 2.7225570678710938, q-error is 1.2844673991203308
90% absolute error is 20.096814918518067, q-error is 3.200940203666687
95% absolute error is 33.70134887695312, q-error is 4.690931701660155
********Epoch 105, training loss: 17.251273626929674 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.50it/s]


50% absolute error is 2.5348044633865356, q-error is 1.2759929299354553
90% absolute error is 17.924271392822266, q-error is 2.883139467239381
95% absolute error is 28.00610580444335, q-error is 4.070354032516475
********Epoch 110, training loss: 16.493377273510664 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:03<00:00, 16.16it/s]


50% absolute error is 2.65010929107666, q-error is 1.2925434112548828
90% absolute error is 17.48025074005127, q-error is 3.101140975952149
95% absolute error is 26.98633823394775, q-error is 4.978283834457375
********Epoch 115, training loss: 14.049337375660738 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.65it/s]


50% absolute error is 2.5001492500305176, q-error is 1.2890622019767761
90% absolute error is 17.25225067138672, q-error is 2.8979191064834593
95% absolute error is 26.077169609069774, q-error is 4.194843029975887
********Epoch 120, training loss: 4.625079183127635 || evaluation loss: ********


100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:02<00:00, 16.67it/s]


50% absolute error is 2.648329734802246, q-error is 1.2909478545188904
90% absolute error is 18.16018352508545, q-error is 2.822620964050293
95% absolute error is 26.905928230285642, q-error is 4.047429370880127


KeyboardInterrupt: 

In [9]:
rnn.save_model("checkpoints")

In [6]:
rnn2 = ConcurrentRNN(ss, 
                    input_size=len(ss.all_feature[0]) * 2 + 7,
                    embedding_dim=128,
                    hidden_size=256,
                    num_layers=2,
                    loss_function="q_loss",
                    last_output=False,
                    use_seperation=False
                   )
rnn2.train(train_trace_df_sep, eval_trace_df_sep, lr=0.001, loss_function="q_loss", val_on_test=True)

********Epoch 0, training loss: 401.60364659328377 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 132.50it/s]


50% absolute error is 4.02608060836792, q-error is 3.915198802947998
90% absolute error is 20.422407150268555, q-error is 31.662242889404297
95% absolute error is 34.7833251953125, q-error is 59.1192626953125
********Epoch 5, training loss: 0.3782997516404211 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 119.85it/s]


50% absolute error is 2.1188182830810547, q-error is 2.816147804260254
90% absolute error is 11.656293869018555, q-error is 18.462446212768555
95% absolute error is 25.714630126953125, q-error is 33.27926254272461
********Epoch 10, training loss: 0.33071399506478183 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 120.39it/s]


50% absolute error is 3.057675361633301, q-error is 3.2474231719970703
90% absolute error is 9.518875122070312, q-error is 22.488019943237305
95% absolute error is 23.943695068359375, q-error is 40.63731002807617
********Epoch 15, training loss: 0.3061664091934145 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 129.68it/s]


50% absolute error is 2.7309913635253906, q-error is 3.187948226928711
90% absolute error is 10.453584671020508, q-error is 21.621904373168945
95% absolute error is 25.34214210510254, q-error is 38.842430114746094
********Epoch 20, training loss: 4.57088536380139 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 119.06it/s]


50% absolute error is 2.9840917587280273, q-error is 3.2921011447906494
90% absolute error is 9.696727752685547, q-error is 23.52199935913086
95% absolute error is 25.022659301757812, q-error is 40.50761413574219
********Epoch 25, training loss: 5.104323701520936 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 130.45it/s]


50% absolute error is 2.501612663269043, q-error is 3.0457828044891357
90% absolute error is 9.88021469116211, q-error is 19.38321876525879
95% absolute error is 24.71220588684082, q-error is 34.32194137573242
********Epoch 30, training loss: 3.8351619181643546 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 122.86it/s]


50% absolute error is 2.680535316467285, q-error is 3.1737663745880127
90% absolute error is 9.848319053649902, q-error is 20.910255432128906
95% absolute error is 25.309398651123047, q-error is 35.394561767578125
********Epoch 35, training loss: 0.27158757767318625 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 126.67it/s]


50% absolute error is 2.517225980758667, q-error is 3.0666236877441406
90% absolute error is 9.862235069274902, q-error is 19.41659164428711
95% absolute error is 25.392702102661133, q-error is 32.34953308105469
********Epoch 40, training loss: 4.251501953311726 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 117.74it/s]


50% absolute error is 2.02493953704834, q-error is 2.811598300933838
90% absolute error is 10.41359806060791, q-error is 17.052032470703125
95% absolute error is 25.90966796875, q-error is 28.204675674438477
********Epoch 45, training loss: 32.79510834463666 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 123.16it/s]


50% absolute error is 2.052144765853882, q-error is 2.7790801525115967
90% absolute error is 10.189210891723633, q-error is 16.966291427612305
95% absolute error is 25.5352783203125, q-error is 27.46196174621582
********Epoch 50, training loss: 4.669592439403049 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 103.88it/s]


50% absolute error is 1.897214412689209, q-error is 2.685394763946533
90% absolute error is 10.424579620361328, q-error is 15.887414932250977
95% absolute error is 26.08725357055664, q-error is 26.610307693481445
********Epoch 55, training loss: 0.24073854807467587 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 105.47it/s]


50% absolute error is 1.7075432538986206, q-error is 2.5083935260772705
90% absolute error is 10.122432708740234, q-error is 13.90810775756836
95% absolute error is 25.41785430908203, q-error is 23.904821395874023
********Epoch 60, training loss: 3.8738698383743784 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 112.88it/s]


50% absolute error is 2.4759559631347656, q-error is 3.0127789974212646
90% absolute error is 9.986977577209473, q-error is 19.526498794555664
95% absolute error is 25.06679344177246, q-error is 33.768333435058594
********Epoch 65, training loss: 8.905193518616457 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 115.57it/s]


50% absolute error is 2.6671369075775146, q-error is 3.165036916732788
90% absolute error is 9.655319213867188, q-error is 20.323694229125977
95% absolute error is 24.702091217041016, q-error is 34.07258605957031
********Epoch 70, training loss: 5.189687584353759 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 122.74it/s]


50% absolute error is 2.329052209854126, q-error is 2.9481449127197266
90% absolute error is 9.90483283996582, q-error is 18.6979923248291
95% absolute error is 25.53311538696289, q-error is 30.63114356994629
********Epoch 75, training loss: 0.2135592408280457 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 113.68it/s]


50% absolute error is 2.2931246757507324, q-error is 2.915856122970581
90% absolute error is 9.785589218139648, q-error is 17.67096710205078
95% absolute error is 24.739238739013672, q-error is 29.528165817260742
********Epoch 80, training loss: 7.477858738825384 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 118.56it/s]


50% absolute error is 2.9165666103363037, q-error is 3.2909066677093506
90% absolute error is 9.472898483276367, q-error is 22.067970275878906
95% absolute error is 24.78216552734375, q-error is 36.99903869628906
********Epoch 85, training loss: 13.503025806583134 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 113.04it/s]


50% absolute error is 2.5994062423706055, q-error is 3.1410393714904785
90% absolute error is 9.814783096313477, q-error is 20.109203338623047
95% absolute error is 24.89160919189453, q-error is 34.24311447143555
********Epoch 90, training loss: 4.679185905551488 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 118.58it/s]


50% absolute error is 1.925673007965088, q-error is 2.6986167430877686
90% absolute error is 10.201363563537598, q-error is 15.419150352478027
95% absolute error is 25.82308006286621, q-error is 27.030786514282227
********Epoch 95, training loss: 5.584833062130265 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 117.97it/s]


50% absolute error is 2.142277956008911, q-error is 2.8431923389434814
90% absolute error is 10.04932975769043, q-error is 17.189533233642578
95% absolute error is 25.420177459716797, q-error is 28.772003173828125
********Epoch 100, training loss: 5.3807773006239294 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 120.43it/s]


50% absolute error is 2.308891534805298, q-error is 2.920111656188965
90% absolute error is 9.708581924438477, q-error is 17.337556838989258
95% absolute error is 25.382572174072266, q-error is 30.598960876464844
********Epoch 105, training loss: 0.19500983130615368 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 125.51it/s]


50% absolute error is 1.913907766342163, q-error is 2.6873414516448975
90% absolute error is 10.039213180541992, q-error is 15.178519248962402
95% absolute error is 25.581897735595703, q-error is 25.83416175842285
********Epoch 110, training loss: 4.118279430718548 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 119.13it/s]


50% absolute error is 2.419550895690918, q-error is 3.0060484409332275
90% absolute error is 9.915777206420898, q-error is 19.065942764282227
95% absolute error is 25.270217895507812, q-error is 32.452247619628906
********Epoch 115, training loss: 4.6011467881962265 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 130.18it/s]


50% absolute error is 2.053647041320801, q-error is 2.7769992351531982
90% absolute error is 9.967422485351562, q-error is 15.947748184204102
95% absolute error is 25.478567123413086, q-error is 27.364757537841797
********Epoch 120, training loss: 23.59974587016401 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 119.26it/s]


50% absolute error is 2.208066463470459, q-error is 2.866530179977417
90% absolute error is 9.777957916259766, q-error is 16.763961791992188
95% absolute error is 25.390098571777344, q-error is 28.77872085571289
********Epoch 125, training loss: 5.691447961712833 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 121.61it/s]


50% absolute error is 2.1189651489257812, q-error is 2.864086389541626
90% absolute error is 10.259038925170898, q-error is 17.193912506103516
95% absolute error is 26.23992919921875, q-error is 28.833757400512695
********Epoch 130, training loss: 4.015672493594146 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 124.27it/s]


50% absolute error is 2.96690034866333, q-error is 3.3172285556793213
90% absolute error is 9.39437198638916, q-error is 22.32872772216797
95% absolute error is 24.833667755126953, q-error is 38.204383850097656
********Epoch 135, training loss: 8.069464484740674 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 120.75it/s]


50% absolute error is 1.8422811031341553, q-error is 2.6789448261260986
90% absolute error is 10.545415878295898, q-error is 15.437335014343262
95% absolute error is 26.261991500854492, q-error is 26.05470848083496
********Epoch 140, training loss: 18.05346157322679 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 122.29it/s]


50% absolute error is 2.7971699237823486, q-error is 3.207273483276367
90% absolute error is 9.613585472106934, q-error is 21.05196189880371
95% absolute error is 24.851119995117188, q-error is 36.257774353027344
********Epoch 145, training loss: 4.027922123967283 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 119.24it/s]


50% absolute error is 2.09159779548645, q-error is 2.776171922683716
90% absolute error is 10.080284118652344, q-error is 16.686464309692383
95% absolute error is 25.96129608154297, q-error is 28.996135711669922
********Epoch 150, training loss: 0.17318295552271656 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 117.80it/s]


50% absolute error is 2.227588415145874, q-error is 2.8709473609924316
90% absolute error is 9.685270309448242, q-error is 17.105562210083008
95% absolute error is 25.357437133789062, q-error is 29.299144744873047
********Epoch 155, training loss: 5.801195154477537 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 121.41it/s]


50% absolute error is 2.1666388511657715, q-error is 2.8034486770629883
90% absolute error is 9.878110885620117, q-error is 17.122697830200195
95% absolute error is 25.492036819458008, q-error is 29.573806762695312
********Epoch 160, training loss: 8.03464655700115 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 116.31it/s]


50% absolute error is 2.4785423278808594, q-error is 3.0049920082092285
90% absolute error is 9.754626274108887, q-error is 19.390413284301758
95% absolute error is 25.51573371887207, q-error is 33.767005920410156
********Epoch 165, training loss: 9.50005738865749 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 124.12it/s]


50% absolute error is 2.438920021057129, q-error is 2.9437084197998047
90% absolute error is 9.689056396484375, q-error is 18.377559661865234
95% absolute error is 25.14614486694336, q-error is 32.45532989501953
********Epoch 170, training loss: 11.603582703697997 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 116.75it/s]


50% absolute error is 2.5776288509368896, q-error is 3.09067440032959
90% absolute error is 9.881180763244629, q-error is 20.5354061126709
95% absolute error is 25.122547149658203, q-error is 34.50389099121094
********Epoch 175, training loss: 3.6534878711125494 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 117.16it/s]


50% absolute error is 2.6104846000671387, q-error is 3.1053080558776855
90% absolute error is 9.615549087524414, q-error is 19.60382843017578
95% absolute error is 24.84284019470215, q-error is 34.09321975708008
********Epoch 180, training loss: 0.15831208407087663 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 109.64it/s]


50% absolute error is 2.02262020111084, q-error is 2.724745273590088
90% absolute error is 10.001008987426758, q-error is 15.97861385345459
95% absolute error is 25.90723991394043, q-error is 27.802444458007812
********Epoch 185, training loss: 4.022961144986669 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 114.85it/s]


50% absolute error is 2.1996700763702393, q-error is 2.830948829650879
90% absolute error is 9.968746185302734, q-error is 17.36515235900879
95% absolute error is 25.595638275146484, q-error is 30.142972946166992
********Epoch 190, training loss: 7.856241999375346 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 117.27it/s]


50% absolute error is 2.0888543128967285, q-error is 2.7173123359680176
90% absolute error is 10.123455047607422, q-error is 16.875577926635742
95% absolute error is 25.55385971069336, q-error is 30.01237678527832
********Epoch 195, training loss: 7.147099245998976 || evaluation loss: ********


100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 113.78it/s]


50% absolute error is 2.354909658432007, q-error is 2.9794392585754395
90% absolute error is 10.141586303710938, q-error is 19.38707733154297
95% absolute error is 25.74962043762207, q-error is 32.55106735229492


In [8]:
preds, labels = rnn.predict(eval_trace_df, use_pre_info_only=False)

100%|███████████████████████████████████████████████████████████████████████████| 49/49 [00:03<00:00, 16.06it/s]

50% absolute error is 2.4635677337646484, q-error is 1.2799019813537598
90% absolute error is 17.452255630493166, q-error is 2.8439884185791016
95% absolute error is 26.966872787475545, q-error is 4.003956580162048





In [44]:
preds2, labels2 = rnn2.predict(eval_trace_df_sep, use_pre_info_only=False)

100%|████████████████████████████████████████| 200/200 [00:01<00:00, 139.35it/s]


50% absolute error is 3.380521774291992, q-error is 4.0659284591674805
90% absolute error is 40.663795471191406, q-error is 47.04194641113281
95% absolute error is 59.644752502441406, q-error is 101.35169982910156


In [48]:
i = 210
idx = np.argsort(preds2[i])
print(len(idx))
np.stack((np.asarray(preds2[i])[idx], np.asarray(preds[i])[idx], np.asarray(labels2[i])[idx]), axis=1)

111


array([[  2.8291447,   7.1109896,  79.540054 ],
       [  2.8653407,   2.3412442,   3.416112 ],
       [  2.934177 ,   4.0760264,   4.0271263],
       [  2.9910045,   3.7841144,   3.8717294],
       [  3.0511217,   5.0034904,   3.1730235],
       [  3.148857 ,  18.968218 ,   3.1449356],
       [  3.170916 ,   6.0808353, 106.32712  ],
       [  3.183551 ,   6.5285897,   4.7487164],
       [  3.2096725,  33.221123 ,   2.8220966],
       [  3.2253332,   5.26631  ,   3.3115742],
       [  3.246746 ,   4.4229527,   5.8701944],
       [  3.2892797,   5.2571554,   7.374345 ],
       [  3.2927618,  24.601357 ,   2.9005537],
       [  3.3001895,   3.752159 ,   4.9782295],
       [  3.3236983,   3.4375095,   4.804597 ],
       [  3.3453178,   4.3794765,   3.0326128],
       [  3.4195132,   5.3936315,   8.1688   ],
       [  3.4844189,   6.0924673,   4.3422008],
       [  3.5340056,   4.0386086,   4.554105 ],
       [  3.5757537,   8.585522 ,  32.302185 ],
       [  3.5846162,   6.08425  ,   5.50

In [8]:
i = 210
idx = np.argsort(preds2[i])
print(len(idx))
np.stack((np.asarray(preds[i])[idx], np.asarray(preds2[i])[idx], np.asarray(labels2[i])[idx]), axis=1)

NameError: name 'preds2' is not defined

In [11]:
i = 2
idx = np.argsort(preds[i])
print(len(idx))
np.stack((np.asarray(preds[i])[idx], np.asarray(labels[i])[idx]), axis=1)

7


array([[  9.086283,  10.558472],
       [  9.326768,  10.257236],
       [ 16.062849,  39.89623 ],
       [ 20.419281,  69.41913 ],
       [ 35.720085,  52.07332 ],
       [ 44.310814, 110.97227 ],
       [ 50.8906  ,  45.67449 ]], dtype=float32)