### SQL2Circuits development notebook

In [1]:
# Setup cell: imports, run configuration, query generation and data preparation.
import os
import itertools

import jax

# Enable 64-bit floats immediately after importing JAX and *before* the project
# modules below are imported. JAX expects this flag at startup; the project modules
# may create JAX arrays at import time (not verified here), and those arrays would
# then miss the x64 setting.
jax.config.update("jax_enable_x64", True)

from data_preparation.queries import QueryGenerator
from data_preparation.prepare import DataPreparation
from data_preparation.database import Database
from circuit_preparation.circuits import Circuits
from training.train import SQL2CircuitsEstimator

# --- Run configuration -----------------------------------------------------
this_folder = os.path.abspath(os.getcwd())
DATABASE_NAME = "IMDB"

# Query seed files; index-aligned with `workload_types` below.
seed_paths = ["data_preparation//query_seeds//JOB_query_seed_execution_time.json",
              "data_preparation//query_seeds//JOB_query_seed_cardinality.json"]
workload_types = ["execution_time", "cardinality"]
assert len(seed_paths) == len(workload_types), "seed_paths and workload_types must be index-aligned"

run_id = 1
ty = 1  # index into workload_types / seed_paths: 0 = execution_time, 1 = cardinality
workload_type = workload_types[ty]

# --- Query generation and data preparation ---------------------------------
database = Database(DATABASE_NAME)
# Pass the *selected* workload type so the generator agrees with its seed file
# (it was previously hard-coded to "cardinality", which mismatched the seed when ty == 0).
generator = QueryGenerator(run_id, workload_type = workload_type, database = DATABASE_NAME, query_seed_file_path = seed_paths[ty])
query_file = generator.get_query_file()
# NOTE(review): the estimator in the training cell is also given classification = 2 and
# logs `classification: 4` with 4-way one-hot labels -- presumably 2 qubits -> 2**2
# classes; keep both values in sync and confirm the meaning.
data_preparator = DataPreparation(run_id, query_file, database = database, workload_type = workload_type, classification = 2)


Number of training queries is  398
Number of test queries is  140
Number of validation queries is  135
cardinality
Error while fetching data from PostgreSQL connection to server at "localhost" (127.0.0.1), port 5432 failed: Connection refused
	Is the server running on that host and accepting TCP/IP connections?

{'2': 969831, '3': 2288576, '4': 52494222, '7': 21027170, '8': 2447042, '9': 577409, '11': 0, '12': 37545, '15': 479596, '20': 2866, '24': 5542, '25': 21219, '26': 811831, '30': 29757, '32': 257934, '34': 1413783, '35': 1626450, '37': 243288, '40': 2407523, '41': 68258, '43': 68258, '44': 1697156, '46': 2449481, '47': 559065, '48': 99094, '51': 0, '52': 488912, '53': 2288639, '54': 969837, '56': 52501982, '57': 102555254, '59': 21028934, '61': 577460, '64': 2669657, '65': 1572370, '66': 2669657, '67': 10827927, '68': 32397582, '70': 3182571, '71': 752392, '74': 0, '75': 683930, '82': 37545, '83': 14104, '85': 479596, '89': 35773, '91': 45270, '92': 0, '97': 1512771, '99': 25632

In [2]:
# Build the per-run output directory, creating it on first use, then run the full
# query -> circuit transformation for this run's query file.
this_folder = os.path.abspath(os.getcwd())
output_folder = f"{this_folder}//circuit_preparation//data//circuits//{run_id}//"

folder_missing = not os.path.exists(output_folder)
if folder_missing:
    os.makedirs(output_folder)
    print("The new directory: ", output_folder, " is created.")

circuits = Circuits(
    run_id,
    query_file,
    output_folder,
    write_cfg_to_file=True,
    write_pregroup_to_file=True,
    generate_circuit_png_diagrams=True,
)
circuits.execute_full_transformation()

In [3]:
from training.utils import construct_data_and_labels, select_circuits

# NOTE(review): `optimization_interval` is not referenced by any visible cell; the
# trainer already logs every 20 iterations on its own. Confirm whether it should be
# passed to SQL2CircuitsEstimator.
optimization_interval = 20
# Number of training circuits to start from.
init_n_of_circuits = 10

# Raw splits and their labels from the data preparator.
training_data = data_preparator.get_training_data()
validation_data = data_preparator.get_validation_data()
test_data = data_preparator.get_test_data()

training_classes = data_preparator.get_training_data_labels()
validation_classes = data_preparator.get_validation_data_labels()
test_classes = data_preparator.get_test_data_labels()

# Presumably restricts the generated circuits to queries that have a data point in
# the given splits -- see Circuits.select_circuits_with_data_point.
circuits.select_circuits_with_data_point(training_data, validation_data, test_data)

# Training, validation and test circuits
training_circuits = circuits.get_training_circuits()
validation_circuits = circuits.get_validation_circuits()
test_circuits = circuits.get_test_circuits()

# Take the first `init_n_of_circuits` training circuits (dict insertion order), then
# pick the validation/test circuits that select_circuits pairs with them.
current_training_circuits = dict(itertools.islice(training_circuits.items(), init_n_of_circuits))
current_validation_circuits = select_circuits(current_training_circuits, validation_circuits, init_n_of_circuits)
current_test_circuits = select_circuits(current_training_circuits, test_circuits, init_n_of_circuits)

# Construct the data and labels for the training, validation and test circuits
training_circuits_X, training_labels_y = construct_data_and_labels(current_training_circuits, training_classes)
validation_circuits_X, validation_labels_y = construct_data_and_labels(current_validation_circuits, validation_classes)
test_circuits_X, test_labels_y = construct_data_and_labels(current_test_circuits, test_classes)

# zip() silently truncates to the shorter input, so a circuits/labels length mismatch
# would drop samples without any error. Fail loudly instead.
assert len(training_circuits_X) == len(training_labels_y), "training circuits/labels length mismatch"
assert len(validation_circuits_X) == len(validation_labels_y), "validation circuits/labels length mismatch"
assert len(test_circuits_X) == len(test_labels_y), "test circuits/labels length mismatch"

X_train = list(zip(training_circuits_X, training_labels_y))
X_valid = list(zip(validation_circuits_X, validation_labels_y))
# `y` holds the *unlabelled* test circuits as 1-tuples, not labels, despite the name.
# It is passed as the second argument to trainer.fit in the next cell (presumably so
# the estimator can initialise parameters for every circuit symbol -- confirm).
y = [(circ, ) for circ in test_circuits_X]

In [4]:
# Train the estimator on the selected circuits with SPSA optimisation.
# `workload` follows the workload chosen in the setup cell (`workload_type`) instead of
# a hard-coded "cardinality", so changing `ty` there cannot silently mislabel the run.
# `a` and `c` are presumably the SPSA gain and perturbation coefficients -- confirm in
# training.train.
# NOTE(review): `classification` should match the value given to DataPreparation in the
# setup cell; the estimator logs it as `classification: 4` -- presumably 2**2 classes.
trainer = SQL2CircuitsEstimator(run_id, 
                              workload = workload_type, 
                              a = 0.001, 
                              c = 0.001, 
                              classification = 2, 
                              optimization_method = "SPSA")
# Second argument `y` is the list of unlabelled test circuits built in the previous cell.
trainer.fit(X_train, y, X_valid = X_valid)

id: 1
a: 0.001
c: 0.001
epochs: 1000
classification: 4
workload: cardinality
optimization_medthod: SPSA



Initializing new parameters
10 10


[[[0.19622516 0.31067676]
  [0.1908816  0.30221647]]

 [[0.4631932  0.09936826]
  [0.36017141 0.07726714]]

 [[0.08531964 0.11641461]
  [0.33761113 0.46065462]]

 ...

 [[0.30909446 0.37781228]
  [0.14088578 0.17220748]]

 [[0.09898675 0.10274735]
  [0.39169256 0.40657334]]

 [[0.2991915  0.31055807]
  [0.19148781 0.19876262]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.19622516 0.31067676 0.1908816  0.30221647] [0 0 0 1]
4 4
[0.4631932  0.09936826 0.36017141 0.07726714] [0 0 0 1]
4 4
[0.08531964 0.11641461 0.33761113 0.46065462] [0 0 0 1]
4 4
[0.2533512  0.34568592 0.16957952 0.23138336] [0 0 0 1]
4 4
[0.20499402 0.2797048  0.2179367  0.29736448] [0 0 0 1]
4 4
[0.25788182 0.35186774 0.1650489  0.22520154] [0 0 0 1]
4 4
[0.51946163 0.07957548 0.34769938 0.05326351] [1 0 0 0]
4 4
[0.30909446 0.37781228 0.14088578 0.17220748] [0 0 1 0]
4 4
[0.09898675 0.10274735 0.39169256 0.40657334] [0 

iters: 20
train/loss: 2.9459
train/acc: 0.3
valid/loss: 2.8888
valid/acc: 0.0



[0.29964695 0.23460248 0.26122767 0.2045229 ] [0 0 0 1]
4 4
[0.05436417 0.22743926 0.13855082 0.57964576] [0 0 0 1]
4 4
[0.11059399 0.46268521 0.08232072 0.34440007] [0 0 0 1]
4 4
[0.09698238 0.4057391  0.09593235 0.40134616] [0 0 0 1]
4 4
[0.11068522 0.46306688 0.0822295  0.3440184 ] [0 0 0 1]
4 4
[0.43794827 0.13533093 0.32598712 0.10073368] [1 0 0 0]
4 4
[0.30477176 0.36649423 0.14925357 0.17948044] [0 0 1 0]
4 4
[0.12684685 0.15495642 0.32327869 0.39491804] [0 0 1 0]
4 4
[0.25826047 0.31549163 0.19186506 0.23438283] [0 1 0 0]
4 4
[[[0.09316614 0.47862255]
  [0.06977192 0.35843938]]

 [[0.2554579  0.3163308 ]
  [0.19131186 0.23689944]]

 [[0.22437338 0.27783918]
  [0.22239637 0.27539106]]

 ...

 [[0.42759219 0.1441965 ]
  [0.32022286 0.10798845]]

 [[0.42623233 0.14373791]
  [0.32158272 0.10844703]]

 [[0.48537114 0.05154573]
  [0.41862567 0.04445746]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]]
[0.09316614 0.47862255 0.069771

iters: 40
train/loss: 2.5563
train/acc: 0.4
valid/loss: 3.088
valid/acc: 0.0



[[[0.0781655  0.40912499]
  [0.08224292 0.4304666 ]]

 [[0.26603067 0.26244924]
  [0.23735775 0.23416233]]

 [[0.0150472  0.09167406]
  [0.12594635 0.76733239]]

 ...

 [[0.32446381 0.34864599]
  [0.15757316 0.16931704]]

 [[0.04705593 0.05966491]
  [0.39386942 0.49940973]]

 [[0.24731045 0.31357917]
  [0.1936149  0.24549549]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.0781655  0.40912499 0.08224292 0.4304666 ] [0 0 0 1]
4 4
[0.26603067 0.26244924 0.23735775 0.23416233] [0 0 0 1]
4 4
[0.0150472  0.09167406 0.12594635 0.76733239] [0 0 0 1]
4 4
[0.08013425 0.48822159 0.06085885 0.37078531] [0 0 0 1]
4 4
[0.07129753 0.43438336 0.0696956  0.42462352] [0 0 0 1]
4 4
[0.07908156 0.48180805 0.06191154 0.37719885] [0 0 0 1]
4 4
[0.43791689 0.13043895 0.33258085 0.09906331] [1 0 0 0]
4 4
[0.32446381 0.34864599 0.15757316 0.16931704] [0 0 1 0]
4 4
[0.04705593 0.05966491 0.39386942 0.49940973] [0 

iters: 60
train/loss: 2.4423
train/acc: 0.4
valid/loss: 3.7109
valid/acc: 0.0



[[[0.10842575 0.37937668]
  [0.11384815 0.39834941]]

 [[0.26258163 0.26036044]
  [0.23954211 0.23751581]]

 [[0.00606626 0.06228358]
  [0.08268337 0.8489668 ]]

 ...

 [[0.3185723  0.31730258]
  [0.1824261  0.18169902]]

 [[0.02880767 0.03954172]
  [0.39266879 0.53898181]]

 [[0.23187371 0.31827259]
  [0.18960274 0.26025096]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.10842575 0.37937668 0.11384815 0.39834941] [0 0 0 1]
4 4
[0.26258163 0.26036044 0.23954211 0.23751581] [0 0 0 1]
4 4
[0.00606626 0.06228358 0.08268337 0.8489668 ] [0 0 0 1]
4 4
[0.04968227 0.5101236  0.03906684 0.40112728] [0 0 0 1]
4 4
[0.04468639 0.45882708 0.04406275 0.45242378] [0 0 0 1]
4 4
[0.04882499 0.5013213  0.03992412 0.40992959] [0 0 0 1]
4 4
[0.47299677 0.08680911 0.37193321 0.06826092] [1 0 0 0]
4 4
[0.3185723  0.31730258 0.1824261  0.18169902] [0 0 1 0]
4 4
[0.02880767 0.03954172 0.39266879 0.53898181] [0 

iters: 80
train/loss: 2.4722
train/acc: 0.5
valid/loss: 3.2582
valid/acc: 0.0



[[[0.11266068 0.37200619]
  [0.11978904 0.39554408]]

 [[0.2731756  0.24547132]
  [0.25353263 0.22782045]]

 [[0.00627797 0.06915381]
  [0.07694631 0.84762191]]

 ...

 [[0.33774804 0.3314204 ]
  [0.16697995 0.16385161]]

 [[0.03074001 0.04469138]
  [0.37678263 0.54778599]]

 [[0.22332343 0.32467912]
  [0.18419918 0.26779827]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.11266068 0.37200619 0.11978904 0.39554408] [0 0 0 1]
4 4
[0.2731756  0.24547132 0.25353263 0.22782045] [0 0 0 1]
4 4
[0.00627797 0.06915381 0.07694631 0.84762191] [0 0 0 1]
4 4
[0.04687033 0.51631389 0.03635348 0.4004623 ] [0 0 0 1]
4 4
[0.04150664 0.45722843 0.04171719 0.45954774] [0 0 0 1]
4 4
[0.04560686 0.50239568 0.03761695 0.41438051] [0 0 0 1]
4 4
[0.45108879 0.11209543 0.34987254 0.08694323] [1 0 0 0]
4 4
[0.33774804 0.3314204  0.16697995 0.16385161] [0 0 1 0]
4 4
[0.03074001 0.04469138 0.37678263 0.54778599] [0 

iters: 100
train/loss: 2.3706
train/acc: 0.4
valid/loss: 3.3815
valid/acc: 0.0



[0.47053201 0.08951222 0.36963736 0.07031841] [1 0 0 0]
4 4
[0.33849685 0.29933069 0.19220593 0.16996653] [0 0 1 0]
4 4
[0.02806436 0.03697736 0.40341796 0.53154032] [0 0 1 0]
4 4
[0.23591537 0.31084025 0.19556692 0.25767745] [0 1 0 0]
4 4
[[[0.10250102 0.3815982 ]
  [0.10923454 0.40666624]]

 [[0.27429559 0.24414322]
  [0.25478438 0.22677681]]

 [[0.00113595 0.03505792]
  [0.03024032 0.93356581]]

 ...

 [[0.33666666 0.29623614]
  [0.19527388 0.17182331]]

 [[0.01553277 0.02066057]
  [0.41362794 0.55017871]]

 [[0.23490259 0.31245087]
  [0.1942581  0.25838844]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.10250102 0.3815982  0.10923454 0.40666624] [0 0 0 1]
4 4
[0.27429559 0.24414322 0.25478438 0.22677681] [0 0 0 1]
4 4
[0.00113595 0.03505792 0.03024032 0.93356581] [0 0 0 1]
4 4
[0.01755406 0.54192694 0.01382156 0.42669744] [0 0 0 1]
4 4
[0.01589573 0.49073057 0.01547992 0.47789379] [0 

iters: 120
train/loss: 2.3069
train/acc: 0.5
valid/loss: 3.2418
valid/acc: 0.0



[[[0.09996468 0.45168656]
  [0.08124525 0.36710352]]

 [[0.23313458 0.31851666]
  [0.18947768 0.25887108]]

 [[0.21008887 0.28703081]
  [0.21252339 0.29035693]]

 ...

 [[0.45251416 0.09913708]
  [0.36777614 0.08057262]]

 [[0.44038663 0.09648017]
  [0.37990368 0.08322953]]

 [[0.51591142 0.02813814]
  [0.43236876 0.02358168]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]]
[0.09996468 0.45168656 0.08124525 0.36710352] [0 0 0 1]
4 4
[0.23313458 0.31851666 0.18947768 0.25887108] [0 0 1 0]
4 4
[0.21008887 0.28703081 0.21252339 0.29035693] [0 0 1 0]
4 4
[0.39235301 0.15929823 0.31888079 0.12946797] [0 1 0 0]
4 4
[0.45251416 0.09913708 0.36777614 0.08057262] [0 0 1 0]
4 4
[0.44038663 0.09648017 0.37990368 0.08322953] [0 1 0 0]
4 4
[0.51591142 0.02813814 0.43236876 0.02358168] [0 1 0 0]
4 4
[[[9.20611855e-02 3.81454437e-01]
  [1.02359400e-01 4.24124978e-01]]

 [[2.60161374e-01 2.49250848e-01]
  [2.50547562e-01 2.40040215e-01]]

 [[3.979312

iters: 140
train/loss: 2.2905
train/acc: 0.5
valid/loss: 3.2365
valid/acc: 0.0



[[[7.30104791e-02 3.86853069e-01]
  [8.57550451e-02 4.54381407e-01]]

 [[2.61569473e-01 2.41004067e-01]
  [2.58890623e-01 2.38535836e-01]]

 [[1.26539754e-04 3.14773101e-02]
  [3.86789709e-03 9.64528253e-01]]

 ...

 [[3.37381906e-01 2.89778444e-01]
  [2.00569681e-01 1.72269969e-01]]

 [[1.36837282e-02 1.79196566e-02]
  [4.19298970e-01 5.49097646e-01]]

 [[2.27265000e-01 2.97617428e-01]
  [2.05717679e-01 2.69399893e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.07301048 0.38685307 0.08575505 0.45438141] [0 0 0 1]
4 4
[0.26156947 0.24100407 0.25889062 0.23853584] [0 0 0 1]
4 4
[1.26539754e-04 3.14773101e-02 3.86789709e-03 9.64528253e-01] [0 0 0 1]
4 4
[0.0021853  0.54498248 0.00180854 0.45102368] [0 0 0 1]
4 4
[0.00197239 0.49188243 0.00202147 0.50412371] [0 0 0 1]
4 4
[0.0020963  0.52278613 0.00189754 0.47322003] [0 0 0 1]
4 4
[0.45788945 0.08927833 0.3789461  0.07388612] [1 0 0 0]
4

iters: 160
train/loss: 2.2725
train/acc: 0.5
valid/loss: 3.1177
valid/acc: 0.0



[[[0.06498686 0.39178275]
  [0.07728806 0.46594232]]

 [[0.25321467 0.24983112]
  [0.25014839 0.24680582]]

 [[0.00527441 0.07339227]
  [0.0617704  0.85956293]]

 ...

 [[0.35228068 0.29530072]
  [0.19171376 0.16070484]]

 [[0.03504347 0.04362284]
  [0.4104263  0.5109074 ]]

 [[0.23505432 0.29260064]
  [0.21041542 0.26192961]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.06498686 0.39178275 0.07728806 0.46594232] [0 0 0 1]
4 4
[0.25321467 0.24983112 0.25014839 0.24680582] [0 0 0 1]
4 4
[0.00527441 0.07339227 0.0617704  0.85956293] [0 0 0 1]
4 4
[0.03681209 0.51225884 0.03023224 0.42069682] [0 0 0 1]
4 4
[0.03335161 0.46410429 0.03369274 0.46885136] [0 0 0 1]
4 4
[0.03537627 0.49227869 0.03166806 0.44067698] [0 0 0 1]
4 4
[0.46040486 0.08866608 0.37811131 0.07281776] [1 0 0 0]
4 4
[0.35228068 0.29530072 0.19171376 0.16070484] [0 0 1 0]
4 4
[0.03504347 0.04362284 0.4104263  0.5109074 ] [0 

iters: 180
train/loss: 2.2628
train/acc: 0.5
valid/loss: 3.1464
valid/acc: 0.0



[[[0.08053705 0.38400175]
  [0.09283286 0.44262834]]

 [[0.25355102 0.2497244 ]
  [0.25025069 0.24647389]]

 [[0.00402079 0.07851898]
  [0.04468977 0.87277046]]

 ...

 [[0.35286682 0.28849937]
  [0.19731313 0.16132068]]

 [[0.03762368 0.04491569]
  [0.41820331 0.49925732]]

 [[0.24147095 0.28827161]
  [0.21435603 0.25590141]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.08053705 0.38400175 0.09283286 0.44262834] [0 0 0 1]
4 4
[0.25355102 0.2497244  0.25025069 0.24647389] [0 0 0 1]
4 4
[0.00402079 0.07851898 0.04468977 0.87277046] [0 0 0 1]
4 4
[0.02667812 0.5210143  0.02203192 0.43027567] [0 0 0 1]
4 4
[0.02431869 0.47493546 0.02439135 0.47635449] [0 0 0 1]
4 4
[0.02580378 0.50393878 0.02290626 0.44735119] [0 0 0 1]
4 4
[0.46602006 0.08167235 0.3848591  0.06744849] [1 0 0 0]
4 4
[0.35286682 0.28849937 0.19731313 0.16132068] [0 0 1 0]
4 4
[0.03762368 0.04491569 0.41820331 0.49925732] [0 

iters: 200
train/loss: 2.2892
train/acc: 0.5
valid/loss: 3.2482
valid/acc: 0.0



[0.01261528 0.48339215 0.01281837 0.4911742 ] [0 0 0 1]
4 4
[0.01320697 0.50606486 0.01222666 0.46850151] [0 0 0 1]
4 4
[0.45777322 0.08083372 0.39214754 0.06924552] [1 0 0 0]
4 4
[0.33663207 0.27687227 0.21207158 0.17442408] [0 0 1 0]
4 4
[0.023308   0.03136508 0.40300765 0.54231927] [0 0 1 0]
4 4
[0.2213737  0.29789813 0.20494193 0.27578624] [0 1 0 0]
4 4
[[[0.08944472 0.44921828]
  [0.07660478 0.38473222]]

 [[0.22969442 0.30896858]
  [0.19672139 0.26461561]]

 [[0.21151965 0.28452118]
  [0.21489616 0.28906301]]

 ...

 [[0.45160546 0.08705754]
  [0.38677672 0.07456028]]

 [[0.43540451 0.08393443]
  [0.40297767 0.07768339]]

 [[0.54133299 0.02291445]
  [0.41805635 0.0176962 ]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]]
[0.08944472 0.44921828 0.07660478 0.38473222] [0 0 0 1]
4 4
[0.22969442 0.30896858 0.19672139 0.26461561] [0 0 1 0]
4 4
[0.21151965 0.28452118 0.21489616 0.28906301] [0 0 1 0]
4 4
[0.38771877 0.15094423 0.332061

iters: 220
train/loss: 2.2994
train/acc: 0.5
valid/loss: 3.3916
valid/acc: 0.0



[0.07775457 0.45480161 0.068248   0.39919583] [0 0 0 1]
4 4
[0.22607292 0.30648326 0.19843238 0.26901144] [0 0 1 0]
4 4
[0.21011572 0.28485036 0.21438958 0.29064435] [0 0 1 0]
4 4
[0.38391619 0.14863999 0.33697713 0.1304667 ] [0 1 0 0]
4 4
[0.4408755  0.09168067 0.38697238 0.08047144] [0 0 1 0]
4 4
[0.42400944 0.08817335 0.40383845 0.08397876] [0 1 0 0]
4 4
[0.54939891 0.01446273 0.42495164 0.01118672] [0 1 0 0]
4 4
[[[0.08007952 0.38120073]
  [0.09352323 0.44519652]]

 [[0.25472554 0.24228563]
  [0.25778918 0.24519965]]

 [[0.00326494 0.15995037]
  [0.01673784 0.82004684]]

 ...

 [[0.3268489  0.27584186]
  [0.2154672  0.18184205]]

 [[0.06979762 0.0934174 ]
  [0.35784448 0.4789405 ]]

 [[0.21893502 0.29302353]
  [0.20870706 0.27933438]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.08007952 0.38120073 0.09352323 0.44519652] [0 0 0 1]
4 4
[0.25472554 0.24228563 0.25778918 0.24519965] [0 

iters: 240
train/loss: 2.2396
train/acc: 0.5
valid/loss: 3.4212
valid/acc: 0.0



[[[0.06480971 0.4654871 ]
  [0.05740432 0.41229887]]

 [[0.2222662  0.30803061]
  [0.19686926 0.27283393]]

 [[0.20765438 0.28778062]
  [0.21148109 0.29308392]]

 ...

 [[0.43106087 0.09923594]
  [0.38180631 0.08789688]]

 [[0.41534176 0.09561719]
  [0.39752542 0.09151563]]

 [[0.54698413 0.012977  ]
  [0.42984104 0.01019783]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]]
[0.06480971 0.4654871  0.05740432 0.41229887] [0 0 0 1]
4 4
[0.2222662  0.30803061 0.19686926 0.27283393] [0 0 1 0]
4 4
[0.20765438 0.28778062 0.21148109 0.29308392] [0 0 1 0]
4 4
[0.38770173 0.14259508 0.34340154 0.12630166] [0 1 0 0]
4 4
[0.43106087 0.09923594 0.38180631 0.08789688] [0 0 1 0]
4 4
[0.41534176 0.09561719 0.39752542 0.09151563] [0 1 0 0]
4 4
[0.54698413 0.012977   0.42984104 0.01019783] [0 1 0 0]
4 4
[[[6.98697680e-02 3.94752158e-01]
  [8.05100626e-02 4.54868012e-01]]

 [[2.52916213e-01 2.41604024e-01]
  [2.58521326e-01 2.46958437e-01]]

 [[6.323603

iters: 260
train/loss: 2.2678
train/acc: 0.5
valid/loss: 3.4457
valid/acc: 0.0



[0 0 1 0]
4 4
[0.21543731 0.29770894 0.20439877 0.28245498] [0 1 0 0]
4 4
[[[0.08771423 0.44383336]
  [0.07730247 0.39114993]]

 [[0.22773047 0.30381712]
  [0.20069866 0.26775375]]

 [[0.21214261 0.28302122]
  [0.21628652 0.28854965]]

 ...

 [[0.43561496 0.09593263]
  [0.38390707 0.08454534]]

 [[0.41969104 0.09242581]
  [0.399831   0.08805216]]

 [[0.54736331 0.01407864]
  [0.42756081 0.01099724]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]]
[0.08771423 0.44383336 0.07730247 0.39114993] [0 0 0 1]
4 4
[0.22773047 0.30381712 0.20069866 0.26775375] [0 0 1 0]
4 4
[0.21214261 0.28302122 0.21628652 0.28854965] [0 0 1 0]
4 4
[0.38148025 0.15006734 0.33619819 0.13225421] [0 1 0 0]
4 4
[0.43561496 0.09593263 0.38390707 0.08454534] [0 0 1 0]
4 4
[0.41969104 0.09242581 0.399831   0.08805216] [0 1 0 0]
4 4
[0.54736331 0.01407864 0.42756081 0.01099724] [0 1 0 0]
4 4
[[[0.0814214  0.38120608]
  [0.09457636 0.44279616]]

 [[0.2570236  0.2384146

iters: 280
train/loss: 2.216
train/acc: 0.5
valid/loss: 3.3419
valid/acc: 0.0



[0.01782854 0.51089377 0.01589151 0.45538619] [0 0 0 1]
4 4
[0.01669085 0.47829188 0.01702921 0.48798806] [0 0 0 1]
4 4
[0.01720524 0.4930324  0.01651481 0.47324755] [0 0 0 1]
4 4
[0.46876405 0.05995826 0.41783378 0.05344392] [1 0 0 0]
4 4
[0.3183602  0.26992563 0.22280565 0.18890852] [0 0 1 0]
4 4
[0.02196426 0.03035256 0.39786722 0.54981597] [0 0 1 0]
4 4
[0.21421381 0.29602383 0.20561764 0.28414472] [0 1 0 0]
4 4
[[[8.30130221e-02 3.78892298e-01]
  [9.67056710e-02 4.41389009e-01]]

 [[2.51460667e-01 2.42122662e-01]
  [2.57998733e-01 2.48417937e-01]]

 [[3.67747666e-04 1.30971454e-02]
  [2.69184800e-02 9.59616627e-01]]

 ...

 [[3.17829251e-01 2.67482165e-01]
  [2.25179552e-01 1.89509033e-01]]

 [[5.68490085e-03 7.77943270e-03]
  [4.16533410e-01 5.70002256e-01]]

 [[2.15558421e-01 2.94979441e-01]
  [2.06659865e-01 2.82802272e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.08301302 0

iters: 300
train/loss: 2.2092
train/acc: 0.5
valid/loss: 3.3969
valid/acc: 0.0



[0.21542492 0.29232966 0.20884485 0.28340057] [0 1 0 0]
4 4
[[[0.08042969 0.44577955]
  [0.07241767 0.40137309]]

 [[0.2231602  0.30304905]
  [0.20093003 0.27286073]]

 [[0.21000724 0.28518748]
  [0.21408298 0.2907223 ]]

 ...

 [[0.4295207  0.09668854]
  [0.38673387 0.08705689]]

 [[0.41444364 0.09329457]
  [0.40181094 0.09045085]]

 [[0.54656903 0.01418184]
  [0.42814016 0.01110898]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]]
[0.08042969 0.44577955 0.07241767 0.40137309] [0 0 0 1]
4 4
[0.2231602  0.30304905 0.20093003 0.27286073] [0 0 1 0]
4 4
[0.21000724 0.28518748 0.21408298 0.2907223 ] [0 0 1 0]
4 4
[0.38140336 0.14480588 0.34340975 0.130381  ] [0 1 0 0]
4 4
[0.4295207  0.09668854 0.38673387 0.08705689] [0 0 1 0]
4 4
[0.41444364 0.09329457 0.40181094 0.09045085] [0 1 0 0]
4 4
[0.54656903 0.01418184 0.42814016 0.01110898] [0 1 0 0]
4 4
[[[8.33709790e-02 3.79724550e-01]
  [9.66587821e-02 4.40245689e-01]]

 [[2.43618027e-01 2.4

iters: 320
train/loss: 2.2026
train/acc: 0.5
valid/loss: 3.2976
valid/acc: 0.0



[0.08669513 0.38084826 0.09873179 0.43372483] [0 0 0 1]
4 4
[0.24375398 0.24584567 0.2541099  0.25629045] [0 0 0 1]
4 4
[0.00193912 0.13411505 0.01231162 0.8516342 ] [0 0 0 1]
4 4
[0.00743349 0.51421174 0.0068166  0.47153816] [0 0 0 1]
4 4
[0.00703519 0.48665887 0.00721491 0.49909102] [0 0 0 1]
4 4
[0.00718545 0.49705334 0.00706464 0.48869657] [0 0 0 1]
4 4
[0.45926773 0.06237751 0.42115386 0.0572009 ] [1 0 0 0]
4 4
[0.3139248  0.26590105 0.22748742 0.19268673] [0 0 1 0]
4 4
[0.05611775 0.07993603 0.35634958 0.50759663] [0 0 1 0]
4 4
[0.20798202 0.29625677 0.20448529 0.29127592] [0 1 0 0]
4 4
[[[0.08358457 0.44006494]
  [0.07603474 0.40031575]]

 [[0.22139605 0.30225347]
  [0.20139828 0.2749522 ]]

 [[0.20906997 0.28542571]
  [0.21372436 0.29177996]]

 ...

 [[0.41326498 0.11038453]
  [0.37593651 0.10041397]]

 [[0.39925129 0.10664142]
  [0.38995021 0.10415707]]

 [[0.53895281 0.01825554]
  [0.42828468 0.01450698]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0],

iters: 340
train/loss: 2.2
train/acc: 0.6
valid/loss: 3.2637
valid/acc: 0.0



[0.09017318 0.3772365  0.10274789 0.42984242] [0 0 0 1]
4 4
[0.23055495 0.26158715 0.23791736 0.26994054] [0 0 0 1]
4 4
[0.00215846 0.04672274 0.04199228 0.90912652] [0 0 0 1]
4 4
[0.02316062 0.50142791 0.02098945 0.45442202] [0 0 0 1]
4 4
[0.02192038 0.47457667 0.02222969 0.48127326] [0 0 0 1]
4 4
[0.02246628 0.48639546 0.02168379 0.46945447] [0 0 0 1]
4 4
[0.46099233 0.06359619 0.41777704 0.05763443] [1 0 0 0]
4 4
[0.31484731 0.25963822 0.23320358 0.19231088] [0 0 1 0]
4 4
[0.02085918 0.02802149 0.40587728 0.54524205] [0 0 1 0]
4 4
[0.21714985 0.2917119  0.20958659 0.28155167] [0 1 0 0]
4 4
[[[0.08554614 0.43761133]
  [0.07797277 0.39886976]]

 [[0.22337718 0.2997803 ]
  [0.20360167 0.27324085]]

 [[0.21179919 0.28424221]
  [0.21517966 0.28877893]]

 ...

 [[0.4135753  0.10958217]
  [0.37696163 0.0998809 ]]

 [[0.40133871 0.10633992]
  [0.38919823 0.10312314]]

 [[0.53619498 0.01847702]
  [0.4304934  0.01483461]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0],

iters: 360
train/loss: 2.2012
train/acc: 0.6
valid/loss: 3.1986
valid/acc: 0.0



[0.0959452  0.37268921 0.10878838 0.42257722] [0 0 0 1]
4 4
[0.21973877 0.27098092 0.22805001 0.2812303 ] [0 0 0 1]
4 4
[0.00157947 0.04774896 0.0304334  0.92023817] [0 0 0 1]
4 4
[0.01670595 0.50515623 0.01530624 0.46283159] [0 0 0 1]
4 4
[0.01585268 0.47935475 0.01615952 0.48863305] [0 0 0 1]
4 4
[0.0162294  0.49074633 0.01578278 0.47724149] [0 0 0 1]
4 4
[0.46731959 0.05454258 0.4281651  0.04997272] [1 0 0 0]
4 4
[0.31081269 0.25584531 0.23768868 0.19565332] [0 0 1 0]
4 4
[0.02164064 0.02768725 0.41706918 0.53360293] [0 0 1 0]
4 4
[0.22241522 0.28456051 0.21629458 0.27672969] [0 1 0 0]
4 4
[[[0.09545698 0.42649441]
  [0.08742783 0.39062078]]

 [[0.22992686 0.29202453]
  [0.21058707 0.26746154]]

 [[0.21823371 0.27717335]
  [0.22228022 0.28231272]]

 ...

 [[0.40990306 0.11204833]
  [0.37542498 0.10262363]]

 [[0.39828909 0.10887361]
  [0.38703895 0.10579835]]

 [[0.52564358 0.02509012]
  [0.42879879 0.02046751]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0],

iters: 380
train/loss: 2.2012
train/acc: 0.6
valid/loss: 3.2522
valid/acc: 0.0



[[[0.08606801 0.43465366]
  [0.07921801 0.40006032]]

 [[0.22010644 0.30061523]
  [0.20258855 0.27668978]]

 [[0.20969048 0.2863894 ]
  [0.21300451 0.29091561]]

 ...

 [[0.40751386 0.11320781]
  [0.37508053 0.1041978 ]]

 [[0.39728175 0.11036532]
  [0.38531264 0.10704029]]

 [[0.52248715 0.02547598]
  [0.43102069 0.02101617]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]]
[0.08606801 0.43465366 0.07921801 0.40006032] [0 0 0 1]
4 4
[0.22010644 0.30061523 0.20258855 0.27668978] [0 0 1 0]
4 4
[0.20969048 0.2863894  0.21300451 0.29091561] [0 0 1 0]
4 4
[0.39083267 0.129889   0.35972697 0.11955136] [0 1 0 0]
4 4
[0.40751386 0.11320781 0.37508053 0.1041978 ] [0 0 1 0]
4 4
[0.39728175 0.11036532 0.38531264 0.10704029] [0 1 0 0]
4 4
[0.52248715 0.02547598 0.43102069 0.02101617] [0 1 0 0]
4 4
[[[0.092107   0.37763389]
  [0.10397344 0.42628567]]

 [[0.22485339 0.26750062]
  [0.23183709 0.27580889]]

 [[0.00166247 0.07357617]
  [0.02042913 0.9

iters: 400
train/loss: 2.1903
train/acc: 0.6
valid/loss: 3.2635
valid/acc: 0.0



[[[0.09155543 0.4233965 ]
  [0.0862387  0.39880937]]

 [[0.21683938 0.29811255]
  [0.20424727 0.2808008 ]]

 [[0.20794624 0.28588618]
  [0.21314041 0.29302717]]

 ...

 [[0.41038592 0.10456601]
  [0.38655433 0.09849374]]

 [[0.39918541 0.10171213]
  [0.39775484 0.10134762]]

 [[0.52468155 0.02646712]
  [0.4272967  0.02155463]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]]
[0.09155543 0.4233965  0.0862387  0.39880937] [0 0 0 1]
4 4
[0.21683938 0.29811255 0.20424727 0.2808008 ] [0 0 1 0]
4 4
[0.20794624 0.28588618 0.21314041 0.29302717] [0 0 1 0]
4 4
[0.38801068 0.12694125 0.36547845 0.11956962] [0 1 0 0]
4 4
[0.41038592 0.10456601 0.38655433 0.09849374] [0 0 1 0]
4 4
[0.39918541 0.10171213 0.39775484 0.10134762] [0 1 0 0]
4 4
[0.52468155 0.02646712 0.4272967  0.02155463] [0 1 0 0]
4 4
[[[0.08927236 0.377923  ]
  [0.10180908 0.43099556]]

 [[0.23224963 0.25767488]
  [0.24180224 0.26827326]]

 [[0.0015572  0.08872295]
  [0.01568744 0.8

iters: 420
train/loss: 2.203
train/acc: 0.6
valid/loss: 3.2426
valid/acc: 0.0



[0.29134844 0.24712132 0.24971897 0.21181127] [0 0 1 0]
4 4
[0.04953211 0.06798083 0.37197125 0.5105158 ] [0 0 1 0]
4 4
[0.20855955 0.28623974 0.21294379 0.29225693] [0 1 0 0]
4 4
[[[0.07972293 0.38703841]
  [0.09107726 0.4421614 ]]

 [[0.23771097 0.2493267 ]
  [0.25036415 0.26259818]]

 [[0.00721832 0.11660945]
  [0.05107304 0.82509919]]

 ...

 [[0.29244783 0.24825554]
  [0.24841773 0.2108789 ]]

 [[0.05202001 0.07180736]
  [0.36808097 0.50809166]]

 [[0.20811342 0.28727564]
  [0.21198754 0.2926234 ]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.07972293 0.38703841 0.09107726 0.4421614 ] [0 0 0 1]
4 4
[0.23771097 0.2493267  0.25036415 0.26259818] [0 0 0 1]
4 4
[0.00721832 0.11660945 0.05107304 0.82509919] [0 0 0 1]
4 4
[0.02970782 0.47994094 0.02858296 0.46176828] [0 0 0 1]
4 4
[0.02867277 0.46321921 0.02961802 0.47849   ] [0 0 0 1]
4 4
[0.02887661 0.46651245 0.02941416 0.47519677] [0 

iters: 440
train/loss: 2.2071
train/acc: 0.6
valid/loss: 3.1897
valid/acc: 0.0



[[[0.07085513 0.39744742]
  [0.0804469  0.45125055]]

 [[0.2416373  0.24569212]
  [0.25420245 0.25846813]]

 [[0.00303733 0.06484367]
  [0.04170269 0.8904163 ]]

 ...

 [[0.29377106 0.24716491]
  [0.24930811 0.20975592]]

 [[0.02878314 0.0390973 ]
  [0.39524374 0.53687582]]

 [[0.21057145 0.28602787]
  [0.21345541 0.28994527]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.07085513 0.39744742 0.0804469  0.45125055] [0 0 0 1]
4 4
[0.2416373  0.24569212 0.25420245 0.25846813] [0 0 0 1]
4 4
[0.00303733 0.06484367 0.04170269 0.8904163 ] [0 0 0 1]
4 4
[0.02282956 0.48745042 0.02190972 0.4678103 ] [0 0 0 1]
4 4
[0.02201521 0.47006252 0.02272409 0.48519818] [0 0 0 1]
4 4
[0.0222175  0.47438182 0.02252179 0.48087889] [0 0 0 1]
4 4
[0.43556538 0.07471459 0.41801579 0.07170423] [1 0 0 0]
4 4
[0.29377106 0.24716491 0.24930811 0.20975592] [0 0 1 0]
4 4
[0.02878314 0.0390973  0.39524374 0.53687582] [0 

iters: 460
train/loss: 2.2989
train/acc: 0.5
valid/loss: 3.3202
valid/acc: 0.0



[[[0.07000417 0.40207662]
  [0.07828436 0.44963485]]

 [[0.24343432 0.24317027]
  [0.256837   0.25655842]]

 [[0.00532793 0.11069466]
  [0.04059136 0.84338604]]

 ...

 [[0.28142524 0.24939875]
  [0.24874152 0.22043448]]

 [[0.04756567 0.0684565 ]
  [0.36240475 0.52157308]]

 [[0.20217737 0.29097378]
  [0.20779302 0.29905583]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.07000417 0.40207662 0.07828436 0.44963485] [0 0 0 1]
4 4
[0.24343432 0.24317027 0.256837   0.25655842] [0 0 0 1]
4 4
[0.00532793 0.11069466 0.04059136 0.84338604] [0 0 0 1]
4 4
[0.02334102 0.48497116 0.02257765 0.46911018] [0 0 0 1]
4 4
[0.02259145 0.4693969  0.02332722 0.48468442] [0 0 0 1]
4 4
[0.02264484 0.47050631 0.02327382 0.48357502] [0 0 0 1]
4 4
[0.43359271 0.07471946 0.41941206 0.07227576] [1 0 0 0]
4 4
[0.28142524 0.24939875 0.24874152 0.22043448] [0 0 1 0]
4 4
[0.04756567 0.0684565  0.36240475 0.52157308] [0 

iters: 480
train/loss: 2.2172
train/acc: 0.6
valid/loss: 3.3594
valid/acc: 0.0



[[[0.07952269 0.39431408]
  [0.08830449 0.43785874]]

 [[0.2338244  0.25376813]
  [0.24572437 0.26668309]]

 [[0.00114619 0.06193074]
  [0.01701884 0.91990423]]

 ...

 [[0.28595891 0.25401044]
  [0.24362469 0.21640596]]

 [[0.02611981 0.03695651]
  [0.38797844 0.54894525]]

 [[0.20562056 0.29092967]
  [0.20847765 0.29497213]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.07952269 0.39431408 0.08830449 0.43785874] [0 0 0 1]
4 4
[0.2338244  0.25376813 0.24572437 0.26668309] [0 0 0 1]
4 4
[0.00114619 0.06193074 0.01701884 0.91990423] [0 0 0 1]
4 4
[0.00928266 0.50175988 0.00888151 0.48007595] [0 0 0 1]
4 4
[0.00894658 0.48359313 0.0092176  0.49824269] [0 0 0 1]
4 4
[0.00901942 0.4875308  0.00914475 0.49430503] [0 0 0 1]
4 4
[0.43487592 0.07616662 0.41608244 0.07287502] [1 0 0 0]
4 4
[0.28595891 0.25401044 0.24362469 0.21640596] [0 0 1 0]
4 4
[0.02611981 0.03695651 0.38797844 0.54894525] [0 

iters: 500
train/loss: 2.2512
train/acc: 0.6
valid/loss: 3.3364
valid/acc: 0.0



[[[0.10145557 0.37413899]
  [0.11186808 0.41253736]]

 [[0.22622488 0.26264227]
  [0.2365284  0.27460445]]

 [[0.00963109 0.1334014 ]
  [0.05770235 0.79926515]]

 ...

 [[0.28623581 0.25288821]
  [0.24469177 0.21618422]]

 [[0.05824496 0.08478715]
  [0.34897094 0.50799695]]

 [[0.20232368 0.29452258]
  [0.20489219 0.29826155]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.10145557 0.37413899 0.11186808 0.41253736] [0 0 0 1]
4 4
[0.22622488 0.26264227 0.2365284  0.27460445] [0 0 0 1]
4 4
[0.00963109 0.1334014  0.05770235 0.79926515] [0 0 0 1]
4 4
[0.03441151 0.47665415 0.03292135 0.45601299] [0 0 0 1]
4 4
[0.03322234 0.46018221 0.03411052 0.47248492] [0 0 0 1]
4 4
[0.03345408 0.46339219 0.03387878 0.46927496] [0 0 0 1]
4 4
[0.43492293 0.07614274 0.4160889  0.07284543] [1 0 0 0]
4 4
[0.28623581 0.25288821 0.24469177 0.21618422] [0 0 1 0]
4 4
[0.05824496 0.08478715 0.34897094 0.50799695] [0 

iters: 520
train/loss: 2.2006
train/acc: 0.6
valid/loss: 3.2332
valid/acc: 0.0



[[[0.07708848 0.4338933 ]
  [0.07377498 0.41524325]]

 [[0.21966128 0.2913205 ]
  [0.21021957 0.27879866]]

 [[0.21242361 0.28172171]
  [0.21745724 0.28839744]]

 ...

 [[0.4275956  0.08338618]
  [0.40921624 0.07980199]]

 [[0.41592059 0.08110941]
  [0.42089125 0.08207875]]

 [[0.50756318 0.0302966 ]
  [0.43610875 0.02603147]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]]
[0.07708848 0.4338933  0.07377498 0.41524325] [0 0 0 1]
4 4
[0.21966128 0.2913205  0.21021957 0.27879866] [0 0 1 0]
4 4
[0.21242361 0.28172171 0.21745724 0.28839744] [0 0 1 0]
4 4
[0.39042513 0.12055665 0.37364347 0.11537476] [0 1 0 0]
4 4
[0.4275956  0.08338618 0.40921624 0.07980199] [0 0 1 0]
4 4
[0.41592059 0.08110941 0.42089125 0.08207875] [0 1 0 0]
4 4
[0.50756318 0.0302966  0.43610875 0.02603147] [0 1 0 0]
4 4
[[[0.08575916 0.38904278]
  [0.09486175 0.43033632]]

 [[0.23101556 0.2583139 ]
  [0.24109082 0.26957972]]

 [[0.0031797  0.04932037]
  [0.05737914 0.8

iters: 540
train/loss: 2.2237
train/acc: 0.5
valid/loss: 3.3637
valid/acc: 0.0



[0.38694371 0.11926272 0.37745533 0.11633824] [0 1 0 0]
4 4
[0.41798421 0.08822222 0.40773467 0.0860589 ] [0 0 1 0]
4 4
[0.40620181 0.08573536 0.41951707 0.08854576] [0 1 0 0]
4 4
[0.51935319 0.01738084 0.44826421 0.01500176] [0 1 0 0]
4 4
[[[7.40067038e-02 4.03197094e-01]
  [8.10773571e-02 4.41718845e-01]]

 [[2.54232398e-01 2.31903025e-01]
  [2.68733808e-01 2.45130769e-01]]

 [[6.43385709e-04 1.04191074e-01]
  [5.49050359e-03 8.89675036e-01]]

 ...

 [[2.89290603e-01 2.45918936e-01]
  [2.51227795e-01 2.13562665e-01]]

 [[4.35841358e-02 6.12497805e-02]
  [3.72160355e-01 5.23005729e-01]]

 [[2.04661043e-01 2.87615013e-01]
  [2.11083417e-01 2.96640527e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.0740067  0.40319709 0.08107736 0.44171885] [0 0 0 1]
4 4
[0.2542324  0.23190303 0.26873381 0.24513077] [0 0 0 1]
4 4
[6.43385709e-04 1.04191074e-01 5.49050359e-03 8.89675036e-01] [0 0 0 1]
4

iters: 560
train/loss: 2.2315
train/acc: 0.5
valid/loss: 3.3493
valid/acc: 0.0



[[[0.07228929 0.40439927]
  [0.0793596  0.44395184]]

 [[0.25603885 0.23116019]
  [0.26949349 0.24330747]]

 [[0.00281844 0.12623037]
  [0.01901948 0.85193172]]

 ...

 [[0.29261955 0.24459221]
  [0.25208101 0.21070722]]

 [[0.05401748 0.07503086]
  [0.36456574 0.50638592]]

 [[0.2068795  0.28735797]
  [0.21170369 0.29405883]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.07228929 0.40439927 0.0793596  0.44395184] [0 0 0 1]
4 4
[0.25603885 0.23116019 0.26949349 0.24330747] [0 0 0 1]
4 4
[0.00281844 0.12623037 0.01901948 0.85193172] [0 0 0 1]
4 4
[0.01110551 0.49745376 0.01073169 0.48070904] [0 0 0 1]
4 4
[0.01075579 0.48178822 0.01108142 0.49637457] [0 0 0 1]
4 4
[0.01079276 0.48344471 0.01104444 0.49471809] [0 0 0 1]
4 4
[0.44268276 0.06587651 0.42778167 0.06365905] [1 0 0 0]
4 4
[0.29261955 0.24459221 0.25208101 0.21070722] [0 0 1 0]
4 4
[0.05401748 0.07503086 0.36456574 0.50638592] [0 

iters: 580
train/loss: 2.1954
train/acc: 0.5
valid/loss: 3.361
valid/acc: 0.0



[0.42605136 0.08259941 0.4115594  0.07978983] [1 0 0 0]
4 4
[0.28921261 0.24447943 0.25269656 0.2136114 ] [0 0 1 0]
4 4
[0.0127509  0.01814026 0.40001718 0.56909167] [0 0 1 0]
4 4
[0.20306703 0.288897   0.20970101 0.29833495] [0 1 0 0]
4 4
[[[6.39887177e-02 4.12484569e-01]
  [7.03078281e-02 4.53218885e-01]]

 [[2.52446729e-01 2.32639156e-01]
  [2.67969834e-01 2.46944281e-01]]

 [[6.87969993e-04 4.48695000e-02]
  [1.44042262e-02 9.40038304e-01]]

 ...

 [[2.89313530e-01 2.44195774e-01]
  [2.52970415e-01 2.13520282e-01]]

 [[1.90217254e-02 2.65350871e-02]
  [3.98516491e-01 5.55926696e-01]]

 [[2.05367532e-01 2.86485762e-01]
  [2.12170653e-01 2.95976053e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.06398872 0.41248457 0.07030783 0.45321888] [0 0 0 1]
4 4
[0.25244673 0.23263916 0.26796983 0.24694428] [0 0 0 1]
4 4
[6.87969993e-04 4.48695000e-02 1.44042262e-02 9.40038304e-01] [0 0 0 1]
4

iters: 600
train/loss: 2.1865
train/acc: 0.5
valid/loss: 3.3795
valid/acc: 0.0



[0.24715174 0.23689517 0.26344286 0.25251022] [0 0 0 1]
4 4
[0.00124248 0.04258679 0.02709698 0.92907375] [0 0 0 1]
4 4
[0.01437374 0.49283955 0.01396491 0.47882179] [0 0 0 1]
4 4
[0.01391229 0.47701729 0.01442638 0.49464404] [0 0 0 1]
4 4
[0.01389259 0.4763419  0.01444607 0.49531944] [0 0 0 1]
4 4
[0.4259224  0.0812909  0.41380795 0.07897875] [1 0 0 0]
4 4
[0.28838174 0.24202477 0.25531773 0.21427576] [0 0 1 0]
4 4
[0.01820827 0.02562037 0.39723368 0.55893768] [0 0 1 0]
4 4
[0.20366396 0.28657053 0.21177796 0.29798755] [0 1 0 0]
4 4
[[[0.06601738 0.44261063]
  [0.06377763 0.42759436]]

 [[0.21054454 0.29808346]
  [0.20340149 0.28797051]]

 [[0.20327448 0.2877907 ]
  [0.21067155 0.29826327]]

 ...

 [[0.42483781 0.0837902 ]
  [0.41042451 0.08094748]]

 [[0.41060062 0.08098221]
  [0.42466171 0.08375546]]

 [[0.52050652 0.01937918]
  [0.44359851 0.0165158 ]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]]
[0.06601738 0.44261063 0.063777

iters: 620
train/loss: 2.1919
train/acc: 0.5
valid/loss: 3.3337
valid/acc: 0.0



[0.05895409 0.44741591 0.05747084 0.43615917] [0 0 0 1]
4 4
[0.2125279  0.29384209 0.20718082 0.28644919] [0 0 1 0]
4 4
[0.20575244 0.2844743  0.21395628 0.29581698] [0 0 1 0]
4 4
[0.39620935 0.11016064 0.38624094 0.10738906] [0 1 0 0]
4 4
[0.42314687 0.08322313 0.41250073 0.08112928] [0 0 1 0]
4 4
[0.40897057 0.08043498 0.42667703 0.08391743] [0 1 0 0]
4 4
[0.5177109  0.02242672 0.44076871 0.01909367] [0 1 0 0]
4 4
[[[6.38795216e-02 4.12659678e-01]
  [7.01693048e-02 4.53291495e-01]]

 [[2.52656457e-01 2.31106695e-01]
  [2.69616593e-01 2.46620255e-01]]

 [[9.26443522e-04 3.85611267e-02]
  [2.25258333e-02 9.37986596e-01]]

 ...

 [[2.83630973e-01 2.45831154e-01]
  [2.52065461e-01 2.18472413e-01]]

 [[1.65668560e-02 2.29201167e-02]
  [4.02985226e-01 5.57527802e-01]]

 [[2.05236197e-01 2.83943146e-01]
  [2.14315854e-01 2.96504802e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.06387952 0

iters: 640
train/loss: 2.2043
train/acc: 0.5
valid/loss: 3.3412
valid/acc: 0.0



[0.01333235 0.49214491 0.01304341 0.48147933] [0 0 0 1]
4 4
[0.01292339 0.47704859 0.01345238 0.49657564] [0 0 0 1]
4 4
[0.01286586 0.47492535 0.01350989 0.49869889] [0 0 0 1]
4 4
[0.42756597 0.07791129 0.41829991 0.07622283] [1 0 0 0]
4 4
[0.28366242 0.24259909 0.25535177 0.21838672] [0 0 1 0]
4 4
[0.02689094 0.0377649  0.38901778 0.54632639] [0 0 1 0]
4 4
[0.2028766  0.28491461 0.21303208 0.2991767 ] [0 1 0 0]
4 4
[[[0.05751286 0.44790151]
  [0.05628062 0.43830501]]

 [[0.21024235 0.29517202]
  [0.20573781 0.28884783]]

 [[0.20380226 0.2861304 ]
  [0.21217789 0.29788945]]

 ...

 [[0.42217582 0.08323855]
  [0.41313051 0.08145512]]

 [[0.40742679 0.08033054]
  [0.42787955 0.08436312]]

 [[0.51717904 0.0245992 ]
  [0.43741639 0.02080536]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]]
[0.05751286 0.44790151 0.05628062 0.43830501] [0 0 0 1]
4 4
[0.21024235 0.29517202 0.20573781 0.28884783] [0 0 1 0]
4 4
[0.20380226 0.2861304  0.212177

iters: 660
train/loss: 2.2472
train/acc: 0.5
valid/loss: 3.4498
valid/acc: 0.0



[[[0.08150395 0.39456752]
  [0.08969713 0.4342314 ]]

 [[0.25861096 0.22357364]
  [0.2777209  0.24009451]]

 [[0.00381446 0.06450098]
  [0.05201716 0.8796674 ]]

 ...

 [[0.27988542 0.23977221]
  [0.25871039 0.22163199]]

 [[0.02855383 0.03976109]
  [0.38941956 0.54226551]]

 [[0.20296079 0.28262227]
  [0.21501258 0.29940436]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.08150395 0.39456752 0.08969713 0.4342314 ] [0 0 0 1]
4 4
[0.25861096 0.22357364 0.2777209  0.24009451] [0 0 0 1]
4 4
[0.00381446 0.06450098 0.05201716 0.8796674 ] [0 0 0 1]
4 4
[0.0280743  0.47477022 0.02775667 0.4693988 ] [0 0 0 1]
4 4
[0.02729746 0.46163288 0.02853352 0.48253613] [0 0 0 1]
4 4
[0.02711057 0.45847249 0.0287204  0.48569655] [0 0 0 1]
4 4
[0.42813447 0.07471005 0.42329067 0.0738648 ] [1 0 0 0]
4 4
[0.27988542 0.23977221 0.25871039 0.22163199] [0 0 1 0]
4 4
[0.02855383 0.03976109 0.38941956 0.54226551] [0 

iters: 680
train/loss: 2.1771
train/acc: 0.5
valid/loss: 3.2283
valid/acc: 0.0



[[[0.05964386 0.41051044]
  [0.06721632 0.46262939]]

 [[0.25343924 0.22736679]
  [0.27367403 0.24551994]]

 [[0.00095626 0.02964898]
  [0.03027705 0.93911771]]

 ...

 [[0.28473519 0.24148003]
  [0.25636507 0.21741972]]

 [[0.01334833 0.01725633]
  [0.42280489 0.54659046]]

 [[0.21197275 0.27403251]
  [0.22418045 0.28981429]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.05964386 0.41051044 0.06721632 0.46262939] [0 0 0 1]
4 4
[0.25343924 0.22736679 0.27367403 0.24551994] [0 0 0 1]
4 4
[0.00095626 0.02964898 0.03027705 0.93911771] [0 0 0 1]
4 4
[0.0157688  0.48911403 0.01546379 0.47965338] [0 0 0 1]
4 4
[0.01526882 0.47360576 0.01596377 0.49516164] [0 0 0 1]
4 4
[0.0151792  0.47082605 0.01605339 0.49794136] [0 0 0 1]
4 4
[0.42526202 0.0796208  0.41703643 0.07808074] [1 0 0 0]
4 4
[0.28473519 0.24148003 0.25636507 0.21741972] [0 0 1 0]
4 4
[0.01334833 0.01725633 0.42280489 0.54659046] [0 

iters: 700
train/loss: 2.1881
train/acc: 0.5
valid/loss: 3.205
valid/acc: 0.0



[0.05520197 0.44924325 0.05422908 0.44132569] [0 0 0 1]
4 4
[0.22174633 0.28269889 0.21783823 0.27771655] [0 0 1 0]
4 4
[0.2149862  0.27408056 0.22459836 0.28633488] [0 0 1 0]
4 4
[0.38368324 0.12076199 0.37692112 0.11863366] [0 1 0 0]
4 4
[0.43153891 0.07290632 0.42393337 0.0716214 ] [0 0 1 0]
4 4
[0.41529312 0.07016167 0.44017916 0.07436605] [0 1 0 0]
4 4
[0.51627736 0.03519566 0.41990145 0.02862553] [0 1 0 0]
4 4
[[[0.05104172 0.41881597]
  [0.05759057 0.47255174]]

 [[0.25620352 0.22465209]
  [0.27660407 0.24254031]]

 [[0.00137421 0.04084476]
  [0.03116814 0.92661288]]

 ...

 [[0.28256414 0.24224356]
  [0.25585049 0.21934182]]

 [[0.01867062 0.02354784]
  [0.42356752 0.53421401]]

 [[0.21467542 0.27075404]
  [0.22756271 0.28700783]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.05104172 0.41881597 0.05759057 0.47255174] [0 0 0 1]
4 4
[0.25620352 0.22465209 0.27660407 0.24254031] [0 

iters: 720
train/loss: 2.154
train/acc: 0.5
valid/loss: 3.1589
valid/acc: 0.0



[[[4.07355702e-02 4.28626530e-01]
  [4.60536454e-02 4.84584254e-01]]

 [[2.54857967e-01 2.23451809e-01]
  [2.77972386e-01 2.43717837e-01]]

 [[2.25619300e-04 2.86271505e-02]
  [7.58222800e-03 9.63565002e-01]]

 ...

 [[2.76004242e-01 2.37435463e-01]
  [2.61554968e-01 2.25005327e-01]]

 [[1.27125717e-02 1.61396587e-02]
  [4.27896780e-01 5.43250990e-01]]

 [[2.11505664e-01 2.68524258e-01]
  [2.29103669e-01 2.90866408e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.04073557 0.42862653 0.04605365 0.48458425] [0 0 0 1]
4 4
[0.25485797 0.22345181 0.27797239 0.24371784] [0 0 0 1]
4 4
[2.25619300e-04 2.86271505e-02 7.58222800e-03 9.63565002e-01] [0 0 0 1]
4 4
[0.00390148 0.49582966 0.00390568 0.49636318] [0 0 0 1]
4 4
[0.00380213 0.48320237 0.00400504 0.50899046] [0 0 0 1]
4 4
[0.00374767 0.47628225 0.00405949 0.51591059] [0 0 0 1]
4 4
[0.41689513 0.082836   0.41734372 0.08292514] [1 0 0 0]
4

iters: 740
train/loss: 2.1586
train/acc: 0.5
valid/loss: 3.1635
valid/acc: 0.2857



[[[4.51674405e-02 4.24533503e-01]
  [5.09946802e-02 4.79304376e-01]]

 [[2.57631250e-01 2.20902068e-01]
  [2.80745578e-01 2.40721104e-01]]

 [[9.43688184e-05 2.23212629e-02]
  [4.09950467e-03 9.73484864e-01]]

 ...

 [[2.74739057e-01 2.37187189e-01]
  [2.61937973e-01 2.26135781e-01]]

 [[9.79682433e-03 1.26182414e-02]
  [4.27266994e-01 5.50317941e-01]]

 [[2.09464328e-01 2.69789111e-01]
  [2.27599469e-01 2.93147091e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.04516744 0.4245335  0.05099468 0.47930438] [0 0 0 1]
4 4
[0.25763125 0.22090207 0.28074558 0.2407211 ] [0 0 0 1]
4 4
[9.43688184e-05 2.23212629e-02 4.09950467e-03 9.73484864e-01] [0 0 0 1]
4 4
[0.00209307 0.49706925 0.00210009 0.49873759] [0 0 0 1]
4 4
[0.0020426  0.48508287 0.00215057 0.51072396] [0 0 0 1]
4 4
[0.00200958 0.47724386 0.00218357 0.51856299] [0 0 0 1]
4 4
[0.41940432 0.079758   0.42081198 0.0800257 ] [1 0 0 0]
4

iters: 760
train/loss: 2.1477
train/acc: 0.5
valid/loss: 3.1636
valid/acc: 0.2857



[0.01337702 0.47481021 0.0140244  0.49778836] [0 0 0 1]
4 4
[0.01316968 0.46745089 0.01423173 0.5051477 ] [0 0 0 1]
4 4
[0.41752015 0.08242598 0.41761012 0.08244374] [1 0 0 0]
4 4
[0.27600675 0.23365567 0.26554142 0.22479616] [0 0 1 0]
4 4
[0.01815699 0.02344736 0.41826329 0.54013235] [0 0 1 0]
4 4
[0.20975256 0.27086801 0.22666771 0.29271172] [0 1 0 0]
4 4
[[[0.03834347 0.46169976]
  [0.03833684 0.46161992]]

 [[0.21872057 0.28132267]
  [0.21868274 0.28127402]]

 [[0.21329759 0.27434753]
  [0.22410572 0.28824916]]

 ...

 [[0.42427724 0.075766  ]
  [0.42420387 0.0757529 ]]

 [[0.40764636 0.07279611]
  [0.44083475 0.07872278]]

 [[0.50879008 0.04176872]
  [0.41534387 0.03409733]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]]
[0.03834347 0.46169976 0.03833684 0.46161992] [0 0 0 1]
4 4
[0.21872057 0.28132267 0.21868274 0.28127402] [0 0 1 0]
4 4
[0.21329759 0.27434753 0.22410572 0.28824916] [0 0 1 0]
4 4
[0.38259168 0.11745156 0.382525

iters: 780
train/loss: 2.1595
train/acc: 0.5
valid/loss: 3.1391
valid/acc: 0.0



[[[3.28830089e-02 4.38382507e-01]
  [3.68929586e-02 4.91841525e-01]]

 [[2.59168165e-01 2.20902174e-01]
  [2.80686402e-01 2.39243259e-01]]

 [[7.37808057e-04 3.81503644e-02]
  [1.82263067e-02 9.42885521e-01]]

 ...

 [[2.76721974e-01 2.36622561e-01]
  [2.62335044e-01 2.24320421e-01]]

 [[1.71600449e-02 2.17275992e-02]
  [4.24112145e-01 5.37000211e-01]]

 [[2.12591777e-01 2.69178411e-01]
  [2.28680394e-01 2.89549418e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.03288301 0.43838251 0.03689296 0.49184153] [0 0 0 1]
4 4
[0.25916816 0.22090217 0.2806864  0.23924326] [0 0 0 1]
4 4
[7.37808057e-04 3.81503644e-02 1.82263067e-02 9.42885521e-01] [0 0 0 1]
4 4
[0.00951467 0.49222285 0.00944877 0.48881371] [0 0 0 1]
4 4
[0.00925186 0.47862659 0.00971159 0.50240996] [0 0 0 1]
4 4
[0.00913602 0.47263417 0.00982742 0.50840239] [0 0 0 1]
4 4
[0.41835963 0.08337789 0.41546207 0.08280041] [1 0 0 0]
4

iters: 800
train/loss: 2.1571
train/acc: 0.5
valid/loss: 3.1392
valid/acc: 0.0



[[[3.65029281e-02 4.34161888e-01]
  [4.10531699e-02 4.88282015e-01]]

 [[2.53480360e-01 2.27762081e-01]
  [2.73240350e-01 2.45517209e-01]]

 [[2.61843201e-04 5.82991258e-02]
  [4.20440403e-03 9.37234627e-01]]

 ...

 [[2.75643827e-01 2.40230245e-01]
  [2.58680036e-01 2.25445892e-01]]

 [[2.59076604e-02 3.26528452e-02]
  [4.16500623e-01 5.24938871e-01]]

 [[2.13919834e-01 2.69615059e-01]
  [2.28488433e-01 2.87976674e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.03650293 0.43416189 0.04105317 0.48828201] [0 0 0 1]
4 4
[0.25348036 0.22776208 0.27324035 0.24551721] [0 0 0 1]
4 4
[2.61843201e-04 5.82991258e-02 4.20440403e-03 9.37234627e-01] [0 0 0 1]
4 4
[0.00224699 0.50092814 0.00221863 0.49460624] [0 0 0 1]
4 4
[0.00218505 0.48711916 0.00228058 0.50841521] [0 0 0 1]
4 4
[0.00215928 0.48137561 0.00230633 0.51415877] [0 0 0 1]
4 4
[0.42043235 0.08274278 0.41512633 0.08169854] [1 0 0 0]
4

iters: 820
train/loss: 2.1612
train/acc: 0.5
valid/loss: 3.1354
valid/acc: 0.0



[0.21515305 0.26798569 0.23017049 0.28669078] [0 1 0 0]
4 4
[[[4.24871855e-02 4.27526381e-01]
  [4.79084692e-02 4.82077965e-01]]

 [[2.50974967e-01 2.29338640e-01]
  [2.71548158e-01 2.48138235e-01]]

 [[8.23198537e-04 5.04920580e-02]
  [1.52132799e-02 9.33471464e-01]]

 ...

 [[2.73172905e-01 2.39852676e-01]
  [2.59301333e-01 2.27673087e-01]]

 [[2.27804533e-02 2.85343544e-02]
  [4.21154675e-01 5.27530518e-01]]

 [[2.13921117e-01 2.67953624e-01]
  [2.30013995e-01 2.88111264e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.04248719 0.42752638 0.04790847 0.48207796] [0 0 0 1]
4 4
[0.25097497 0.22933864 0.27154816 0.24813824] [0 0 0 1]
4 4
[8.23198537e-04 5.04920580e-02 1.52132799e-02 9.33471464e-01] [0 0 0 1]
4 4
[0.00805533 0.4942762  0.00798055 0.48968791] [0 0 0 1]
4 4
[0.00783792 0.48093549 0.00819797 0.50302861] [0 0 0 1]
4 4
[0.00772729 0.47414745 0.0083086  0.50981666] [0 0 0 1]
4

iters: 840
train/loss: 2.151
train/acc: 0.5
valid/loss: 3.1216
valid/acc: 0.0



[[[4.00634840e-02 4.30558073e-01]
  [4.50653885e-02 4.84313054e-01]]

 [[2.49751033e-01 2.30938003e-01]
  [2.69817782e-01 2.49493181e-01]]

 [[1.95663529e-04 3.55687049e-02]
  [5.26623792e-03 9.58969394e-01]]

 ...

 [[2.74043747e-01 2.39933389e-01]
  [2.59139011e-01 2.26883853e-01]]

 [[1.60078887e-02 1.97559768e-02]
  [4.31591404e-01 5.32644730e-01]]

 [[2.15830750e-01 2.66365628e-01]
  [2.31768527e-01 2.86035095e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.04006348 0.43055807 0.04506539 0.48431305] [0 0 0 1]
4 4
[0.24975103 0.230938   0.26981778 0.24949318] [0 0 0 1]
4 4
[1.95663529e-04 3.55687049e-02 5.26623792e-03 9.58969394e-01] [0 0 0 1]
4 4
[0.00274352 0.49961757 0.00271773 0.49492119] [0 0 0 1]
4 4
[0.00266945 0.48612901 0.0027918  0.50840973] [0 0 0 1]
4 4
[0.00263339 0.47956299 0.00282785 0.51497577] [0 0 0 1]
4 4
[0.41809217 0.08426892 0.41416212 0.0834768 ] [1 0 0 0]
4

iters: 860
train/loss: 2.1523
train/acc: 0.5
valid/loss: 3.1204
valid/acc: 0.0



[[[3.83331393e-02 4.31872528e-01]
  [4.31910542e-02 4.86603278e-01]]

 [[2.50785796e-01 2.29830323e-01]
  [2.71014839e-01 2.48369042e-01]]

 [[7.57905472e-04 4.57710181e-02]
  [1.55240689e-02 9.37947008e-01]]

 ...

 [[2.75474006e-01 2.41076813e-01]
  [2.57821066e-01 2.25628116e-01]]

 [[2.06148817e-02 2.59135296e-02]
  [4.22445010e-01 5.31026579e-01]]

 [[2.13946910e-01 2.68937962e-01]
  [2.29112963e-01 2.88002165e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.03833314 0.43187253 0.04319105 0.48660328] [0 0 0 1]
4 4
[0.2507858  0.22983032 0.27101484 0.24836904] [0 0 0 1]
4 4
[7.57905472e-04 4.57710181e-02 1.55240689e-02 9.37947008e-01] [0 0 0 1]
4 4
[0.00818918 0.4947913  0.00809212 0.4889274 ] [0 0 0 1]
4 4
[0.00796067 0.48098452 0.00832064 0.50273417] [0 0 0 1]
4 4
[0.00786199 0.47502288 0.0084193  0.50869582] [0 0 0 1]
4 4
[0.42020457 0.08277591 0.41522462 0.08179491] [1 0 0 0]
4

iters: 880
train/loss: 2.1423
train/acc: 0.5
valid/loss: 3.1228
valid/acc: 0.0



[[[0.03101181 0.43975249]
  [0.03486364 0.49437206]]

 [[0.24931665 0.23109587]
  [0.26964703 0.24994044]]

 [[0.00103864 0.06049114]
  [0.01583663 0.9226336 ]]

 ...

 [[0.27302156 0.24049537]
  [0.25864847 0.2278346 ]]

 [[0.02743597 0.03409332]
  [0.4184648  0.52000591]]

 [[0.21478698 0.26690537]
  [0.23111377 0.28719388]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.03101181 0.43975249 0.03486364 0.49437206] [0 0 0 1]
4 4
[0.24931665 0.23109587 0.26964703 0.24994044] [0 0 0 1]
4 4
[0.00103864 0.06049114 0.01583663 0.9226336 ] [0 0 0 1]
4 4
[0.00847188 0.49357727 0.00840272 0.48954813] [0 0 0 1]
4 4
[0.00824408 0.48030531 0.00863053 0.50282007] [0 0 0 1]
4 4
[0.00812837 0.47356399 0.00874623 0.50956141] [0 0 0 1]
4 4
[0.41629832 0.08575083 0.41290002 0.08505083] [1 0 0 0]
4 4
[0.27302156 0.24049537 0.25864847 0.2278346 ] [0 0 1 0]
4 4
[0.02743597 0.03409332 0.4184648  0.52000591] [0 

iters: 900
train/loss: 2.1448
train/acc: 0.5
valid/loss: 3.1279
valid/acc: 0.0



[[[2.72194282e-02 4.43546405e-01]
  [3.06000326e-02 4.98634134e-01]]

 [[2.49405695e-01 2.31033161e-01]
  [2.69714880e-01 2.49846264e-01]]

 [[7.34228667e-04 4.06935015e-02]
  [1.69815480e-02 9.41590722e-01]]

 ...

 [[2.71058410e-01 2.36901554e-01]
  [2.62563192e-01 2.29476844e-01]]

 [[1.85919031e-02 2.28353448e-02]
  [4.30192325e-01 5.28380427e-01]]

 [[2.15990558e-01 2.65288756e-01]
  [2.32793655e-01 2.85927031e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.02721943 0.4435464  0.03060003 0.49863413] [0 0 0 1]
4 4
[0.2494057  0.23103316 0.26971488 0.24984626] [0 0 0 1]
4 4
[7.34228667e-04 4.06935015e-02 1.69815480e-02 9.41590722e-01] [0 0 0 1]
4 4
[0.00888013 0.49239319 0.00883502 0.48989166] [0 0 0 1]
4 4
[0.00865611 0.47997136 0.00905904 0.50231348] [0 0 0 1]
4 4
[0.00852593 0.47275338 0.00918921 0.50953147] [0 0 0 1]
4 4
[0.41466117 0.08661215 0.41255455 0.08617213] [1 0 0 0]
4

iters: 920
train/loss: 2.1427
train/acc: 0.5
valid/loss: 3.0854
valid/acc: 0.0



[0.00516748 0.48266252 0.00542531 0.50674468] [0 0 0 1]
4 4
[0.00508812 0.47524999 0.00550467 0.51415723] [0 0 0 1]
4 4
[0.42014188 0.08123066 0.41784155 0.08078591] [1 0 0 0]
4 4
[0.26870018 0.24001438 0.25949421 0.23179122] [0 0 1 0]
4 4
[0.03479999 0.04351383 0.40956579 0.5121204 ] [0 0 1 0]
4 4
[0.2134458  0.2668923  0.23091995 0.28874194] [0 1 0 0]
4 4
[[[0.0306524  0.47082488]
  [0.03047181 0.46805091]]

 [[0.22343038 0.2780469 ]
  [0.22211399 0.27640873]]

 [[0.21744237 0.27059514]
  [0.22810201 0.28386048]]

 ...

 [[0.42828852 0.07318876]
  [0.42576516 0.07275756]]

 [[0.41047754 0.0701451 ]
  [0.44357615 0.07580121]]

 [[0.49784474 0.05410642]
  [0.4041277  0.04392114]]]
[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]]
[0.0306524  0.47082488 0.03047181 0.46805091] [0 0 0 1]
4 4
[0.22343038 0.2780469  0.22211399 0.27640873] [0 0 1 0]
4 4
[0.21744237 0.27059514 0.22810201 0.28386048] [0 0 1 0]
4 4
[0.38745285 0.11402443 0.385170

iters: 940
train/loss: 2.1384
train/acc: 0.5
valid/loss: 3.0842
valid/acc: 0.0



[[[0.02484198 0.44714633]
  [0.02779064 0.50022105]]

 [[0.25184668 0.22791631]
  [0.27309311 0.24714391]]

 [[0.00122604 0.07277869]
  [0.01533767 0.9106576 ]]

 ...

 [[0.26797945 0.23783937]
  [0.2618139  0.23236727]]

 [[0.03265419 0.04135015]
  [0.40859268 0.51740297]]

 [[0.21171129 0.26809109]
  [0.22953557 0.29066205]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.02484198 0.44714633 0.02779064 0.50022105] [0 0 0 1]
4 4
[0.25184668 0.22791631 0.27309311 0.24714391] [0 0 0 1]
4 4
[0.00122604 0.07277869 0.01533767 0.9106576 ] [0 0 0 1]
4 4
[0.00828968 0.49219946 0.00827348 0.49123738] [0 0 0 1]
4 4
[0.00808296 0.47992516 0.00848021 0.50351168] [0 0 0 1]
4 4
[0.00794704 0.47185534 0.00861611 0.51158151] [0 0 0 1]
4 4
[0.41749417 0.08299497 0.41667812 0.08283274] [1 0 0 0]
4 4
[0.26797945 0.23783937 0.2618139  0.23236727] [0 0 1 0]
4 4
[0.03265419 0.04135015 0.40859268 0.51740297] [0 

iters: 960
train/loss: 2.1381
train/acc: 0.5
valid/loss: 3.0693
valid/acc: 0.0



[0.00110042 0.0398328  0.02577597 0.93329082] [0 0 0 1]
4 4
[0.01349918 0.48878091 0.01337663 0.48434328] [0 0 0 1]
4 4
[0.01311503 0.47487112 0.01376079 0.49825307] [0 0 0 1]
4 4
[0.01292844 0.46811535 0.01394737 0.50500884] [0 0 0 1]
4 4
[0.42139907 0.08088103 0.41757319 0.08014671] [1 0 0 0]
4 4
[0.26994255 0.24001704 0.25939851 0.2306419 ] [0 0 1 0]
4 4
[0.01797696 0.0229558  0.4212055  0.53786173] [0 0 1 0]
4 4
[0.21126599 0.2697778  0.22791646 0.29103975] [0 1 0 0]
4 4
[[[3.39212946e-02 4.37295817e-01]
  [3.80652533e-02 4.90717635e-01]]

 [[2.48053333e-01 2.31400361e-01]
  [2.69313279e-01 2.51233027e-01]]

 [[6.15942502e-05 8.77731863e-03]
  [6.86794303e-03 9.84293144e-01]]

 ...

 [[2.69950234e-01 2.39553875e-01]
  [2.59879122e-01 2.30616769e-01]]

 [[3.89379899e-03 4.94458266e-03]
  [4.36660724e-01 5.54500894e-01]]

 [[2.11715163e-01 2.68850037e-01]
  [2.28839341e-01 2.90595460e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0,

iters: 980
train/loss: 2.1447
train/acc: 0.5
valid/loss: 3.0587
valid/acc: 0.0



[[[4.43883263e-02 4.26509687e-01]
  [4.98748130e-02 4.79227174e-01]]

 [[2.43924525e-01 2.34744899e-01]
  [2.65664166e-01 2.55666410e-01]]

 [[6.40587384e-04 4.34647218e-02]
  [1.38772242e-02 9.42017467e-01]]

 ...

 [[2.69903504e-01 2.36018061e-01]
  [2.63585327e-01 2.30493109e-01]]

 [[1.97490511e-02 2.43558294e-02]
  [4.28025520e-01 5.27869600e-01]]

 [[2.14471889e-01 2.64501036e-01]
  [2.33302667e-01 2.87724408e-01]]]
[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]
[0.04438833 0.42650969 0.04987481 0.47922717] [0 0 0 1]
4 4
[0.24392452 0.2347449  0.26566417 0.25566641] [0 0 0 1]
4 4
[6.40587384e-04 4.34647218e-02 1.38772242e-02 9.42017467e-01] [0 0 0 1]
4 4
[0.00727303 0.49371973 0.00724421 0.49176303] [0 0 0 1]
4 4
[0.00707915 0.48055767 0.0074381  0.50492507] [0 0 0 1]
4 4
[0.00695337 0.47201956 0.00756387 0.5134632 ] [0 0 0 1]
4 4
[0.42638645 0.07460631 0.4246966  0.07431063] [1 0 0 0]
4

iters: 1000
train/loss: 2.1561
train/acc: 0.5
valid/loss: 3.0427
valid/acc: 0.0



SQL2CircuitsEstimator(a=0.001, c=0.001, id=1, workload='cardinality')

In [5]:
from skopt import BayesSearchCV
from sklearn.datasets import load_digits
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split

# Self-contained example: tune an SVC on sklearn's handwritten-digits dataset
# with scikit-optimize's BayesSearchCV, then report CV and held-out scores.
# NOTE(review): this cell does not use any of the SQL2Circuits objects above.
SEED = 0  # shared seed for the data split and the Bayesian optimizer's sampling

X, y = load_digits(n_class=10, return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=0.75, test_size=0.25, random_state=SEED
)

# Search space. A (low, high, 'log-uniform') tuple is sampled uniformly in
# log-space: understand it as a search over p = exp(x) by varying x, which suits
# scale parameters such as C and gamma that span many orders of magnitude.
# For SVC, 'degree' is only used by the 'poly' kernel and 'gamma' is not used
# by the 'linear' kernel.
search_space = {
    'C': (1e-6, 1e+6, 'log-uniform'),
    'gamma': (1e-6, 1e+1, 'log-uniform'),
    'degree': (1, 8),  # integer valued parameter
    'kernel': ['linear', 'poly', 'rbf'],  # categorical parameter
}

opt = BayesSearchCV(
    SVC(),
    search_space,
    n_iter=32,  # number of parameter settings sampled by the optimizer
    cv=3,       # 3-fold cross-validation for each sampled setting
    # Without a seed the sampled settings (and thus both printed scores)
    # change on every Restart & Run All.
    random_state=SEED,
)

opt.fit(X_train, y_train)

print("val. score: %s" % opt.best_score_)
print("test score: %s" % opt.score(X_test, y_test))

val. score: 0.9836674090571641
test score: 0.9822222222222222
