In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Prepare data

In [None]:
# Load the tick data, locate each trading day and keep a single day for modelling.
file = '../input/liangbo-capital/data.csv'
data = pd.read_csv(file)
time = list(data['Time'])
# Characters 1..8 of each Time string are taken as the trading date
# (assumes a fixed-width timestamp with one leading character — TODO confirm against the raw file).
dates = [s[1:9] for s in time]
# Sorted distinct dates plus the row index of the first tick of each date.
dates_unique, first_tick = np.unique(dates, return_index=True)
print(f'Trading dates: {dates_unique}')
# Day d covers rows [first_tick[d] + 1, first_tick[d + 1]); the last day ends at the final row.
# The first tick of every day is skipped, as in the original analysis
# (NOTE(review): presumably an opening snapshot — confirm).
ind_start = [int(ind) + 1 for ind in first_tick]
ind_end = [int(ind) for ind in first_tick[1:]] + [len(time)]
print(f'Starting index: {ind_start}')
print(f'Ending index: {ind_end}')

whichday = 4  # 0 .. len(dates_unique) - 1
if not 0 <= whichday < len(ind_start):
    # a negative index would otherwise silently select a day counted from the end
    raise IndexError(f'whichday={whichday} out of range for {len(ind_start)} trading days')
data = data.iloc[ind_start[whichday]:ind_end[whichday], :]

# Spearman correlation matrix of the numeric columns. The string 'Time' column must be
# excluded explicitly: DataFrame.corr raises on non-numeric columns in pandas >= 2.0.
corr_day = data.select_dtypes(include='number').corr(method='spearman')
plt.matshow(corr_day, vmax=1, vmin=-1, cmap='PRGn')
plt.xticks(range(len(corr_day.columns)), corr_day.columns, rotation=90)
plt.yticks(range(len(corr_day.columns)), corr_day.columns)
plt.colorbar()
plt.show()

In [None]:
# Build the standardized model inputs: one engineered column for every data column
# after 'Time' (column 0), each aligned to rows 1..end of the selected day.
# NOTE(review): each scaler is fit on the whole day, including the rows that later
# become the validation split, so validation statistics leak into the inputs.
from sklearn.preprocessing import StandardScaler
sdscaler = StandardScaler()

num_row = len(data)
num_col = len(data.columns)


def _engineered_column(frame, col):
    """Return the engineered feature for column position `col` (length len(frame) - 1).

    Columns 2 and 3: increment between neighbouring ticks (turnover / volume).
    Column 4: bid-ask spread, i.e. column 6 minus column 4.
    Column 6: mid price, i.e. the mean of columns 6 and 4.
    Any other column: the raw values.
    """
    current = frame.iloc[1:, col].to_numpy()
    if col in (2, 3):
        return current - frame.iloc[:-1, col].to_numpy()
    if col == 4:
        return frame.iloc[1:, col + 2].to_numpy() - current
    if col == 6:
        return (current + frame.iloc[1:, col - 2].to_numpy()) / 2
    return current


scaled_columns = []
for col in range(1, num_col):
    column_2d = _engineered_column(data, col).reshape(num_row - 1, 1)
    scaled_columns.append(sdscaler.fit_transform(column_2d))
inputs = np.hstack(scaled_columns)

# Features of tick t are paired with the target (column 8, bid arrivals) of tick t + 1.
trainX = torch.from_numpy(inputs[:-1, :])
trainY = torch.from_numpy(data.iloc[2:, 8].to_numpy())
print(trainX.shape, trainY.shape)

In [None]:
# Iteratively decorrelate the input columns. While some pair (i, j), i < j, has
# |Pearson r| > threshold, column j is replaced by the standardized residual of
# regressing it on column i. Each projection can re-introduce small correlations with
# other columns, hence the repeat; `max_iter` bounds the loop so a (near-)collinear
# column cannot make it spin forever.
# NOTE(review): fit on all rows, so the validation part influences the transform.
threshold = 0.01
max_iter = 10_000

for _ in range(max_iter):
    cor = np.corrcoef(trainX, rowvar=False)
    # off-diagonal pairs above threshold, as (smaller index, larger index)
    inds_cross = [[min(p), max(p)] for p in np.argwhere(np.abs(cor) > threshold) if p[0] != p[1]]
    if not inds_cross:
        break
    i_ref, i_adj = inds_cross[0]
    a = trainX[:, i_ref].numpy()
    b = trainX[:, i_adj].numpy()
    m = np.dot(a, b) / np.dot(a, a)  # least-squares slope of b on a (columns are zero-mean)
    res = sdscaler.fit_transform((b - m * a).reshape(-1, 1))
    trainX[:, i_adj] = torch.from_numpy(res.reshape(-1))
else:
    print(f'Warning: decorrelation did not converge within {max_iter} iterations')

# correlation matrix after decorrelation
cor = np.corrcoef(trainX, rowvar=False)
print(np.round(cor, 1))

In [None]:
# id=7
# plt.plot(trainX[:,id],trainY,'.')
# plt.show()

# hist = plt.hist2d(trainX[:,id].numpy(),trainY.numpy(),bins=[30,2])
# plt.show()

# center = (hist[1][:-1]+hist[1][1:])/2
# counts = hist[0]
# tot = counts[:,0]+counts[:,1]
# tot[tot==0] = -1
# p1 = counts[:,1]/tot
# p1[tot<0] = 0.5
# plt.plot(center,p1,'ro--')
# plt.axhline(y=0.5, color='b', linestyle='-')
# plt.grid()
# plt.ylim([-0.1,1.1])
# plt.show()

# Dataloader

In [None]:
from torch.utils.data import Dataset

class timeseries(Dataset):
    """Map-style dataset over pre-built windows.

    `x` is indexed as x[idx, :, :] (so it must be 3-D) and `y` as y[idx]. Both are kept
    exactly as passed in: no copy and no dtype conversion happen here.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y
        self.len = x.shape[0]

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        window = self.x[idx, :, :]
        target = self.y[idx]
        return window, target

In [None]:
# Build sliding windows of N_seq consecutive ticks, their class labels, and a
# chronological train/validation split.
from torch.utils.data import DataLoader

N_seq = 200
HORIZON = 1    # label = arrivals HORIZON ticks after the last tick of a window
N_CLASSES = 8  # arrival counts >= N_CLASSES - 1 are merged into the top class

# trainX[t] is already paired with the *next* tick's label trainY[t] (see the feature
# cell), so the window trainX[i : i + N_seq] ends at row t = i + N_seq - 1 and its label
# is trainY[t + HORIZON - 1]. Taking trainY[N_seq:] instead would shift every label one
# extra tick into the future and drop the last window.
if HORIZON < 1:
    raise ValueError('HORIZON must be >= 1')
label_offset = N_seq + HORIZON - 2
n_windows = trainX.shape[0] - label_offset
if n_windows <= 0:
    raise ValueError(f'need more than {label_offset} ticks to build a window, got {trainX.shape[0]}')

trainX_new = torch.zeros(n_windows, N_seq, trainX.shape[1])
for i in range(n_windows):
    trainX_new[i, :, :] = trainX[i:i + N_seq, :]
# round the counts to integers and cap them so labels lie in 0 .. N_CLASSES - 1
trainY_new = trainY[label_offset:label_offset + n_windows].to(torch.float64).round().to(torch.int32)
trainY_new = trainY_new.clamp(max=N_CLASSES - 1)
print(trainX.shape, trainX_new.shape, trainY_new.shape)

# chronological split: first 85 % of windows for training, the rest for validation
N_train = int(0.85 * n_windows)
X_train = trainX_new[:N_train, :, :]
X_val = trainX_new[N_train:, :, :]
Y_train = trainY_new[:N_train]
Y_val = trainY_new[N_train:]

trainset = timeseries(X_train, Y_train)
valset = timeseries(X_val, Y_val)

train_loader = DataLoader(trainset, shuffle=True, batch_size=256)
# evaluation order is irrelevant, so validation batches are not shuffled
val_loader = DataLoader(valset, shuffle=False, batch_size=256)

In [None]:
# Distribution of the (capped) arrival counts used as class labels.
print(trainY_new.shape)
plt.rcParams.update({'font.size': 20})
n_classes = int(trainY_new.max()) + 1
# One integer-centred bar per class 0 .. n_classes - 1. (bins=range(8) gives 8 edges,
# i.e. only 7 bins, and the last bin [6, 7] merged classes 6 and 7 into one bar.)
plt.hist(trainY_new.numpy(), bins=np.arange(n_classes + 1) - 0.5)
plt.xticks(range(n_classes))
plt.xlabel('Number of order arrivals')
plt.ylabel('Number of windows')
plt.show()
print(np.unique(trainY_new))

# Create the model

In [None]:
class LSTMmodel(nn.Module):
    """Single-layer LSTM classifier that scores a window from its final time step.

    Input : (batch, seq_len, input_dim) tensor (the LSTM is built with batch_first=True).
    Output: (batch, output_dim) raw scores from a linear layer applied to the LSTM
            output at the last time step (no softmax applied here).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.input_dim, self.hidden_dim, self.output_dim = input_dim, hidden_dim, output_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        sequence_out = self.lstm(x)[0]      # (batch, seq_len, hidden_dim)
        last_step = sequence_out[:, -1, :]  # (batch, hidden_dim)
        return self.fc(last_step)           # (batch, output_dim)

# Train the model

In [None]:
# Train the LSTM classifier with mini-batch SGD; record metrics every `eval_every` epochs.
SEED = 0
torch.manual_seed(SEED)  # model init and DataLoader shuffling use torch's global RNG

input_dim = trainX.shape[1]
hidden_dim = 4 * input_dim
# CrossEntropyLoss expects class indices in [0, output_dim), so size the output layer
# by the largest label; counting distinct labels breaks whenever some class is absent.
output_dim = int(trainY_new.max()) + 1
print(input_dim, hidden_dim, output_dim)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTMmodel(input_dim, hidden_dim, output_dim)
model.to(device=device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)
epoches = 400
eval_every = 4

train_X = trainset[:][0].to(device=device, dtype=torch.float32)
train_Y = trainset[:][1].to(device=device, dtype=torch.float32)
val_X = valset[:][0].to(device=device, dtype=torch.float32)
val_Y = valset[:][1].to(device=device, dtype=torch.long)
accuracies_val = []
accuracies_train = []
percent_val = []
percent_train = []
loss_train = []  # loss of the last mini-batch of each evaluated epoch
loss_val = []


def _accuracy(logits, labels):
    """Fraction of rows whose arg-max logit equals the label (NumPy float64)."""
    pred = np.argmax(logits.cpu().numpy(), axis=1)
    return np.mean(pred == labels.cpu().numpy())


for i in range(epoches):
    model.train()
    for x, y in train_loader:
        x = x.to(device=device, dtype=torch.float32)
        train_label = y.to(device=device, dtype=torch.long)
        optimizer.zero_grad()  # otherwise gradients of all previous batches accumulate
        train_pred = model(x)
        loss = criterion(train_pred, train_label)
        loss.backward()
        optimizer.step()
    if i % eval_every == 0:
        print(i, "th iteration : ", loss.item())
        loss_train.append(loss.item())
        model.eval()
        with torch.no_grad():  # evaluation only: no autograd graph over the full sets
            accuracies_train.append(_accuracy(model(train_X), train_Y))
            val_logits = model(val_X)
            loss_val.append(criterion(val_logits, val_Y).item())
            # probability the model assigns to class 0 for each validation window
            Parv = torch.softmax(val_logits, dim=1)[:, 0].cpu().numpy()
            accuracies_val.append(_accuracy(val_logits, val_Y))
        print(f'train accuracy ={accuracies_train[-1].item()}, val accuracy = {accuracies_val[-1].item()}')

In [None]:
# Compare predicted and true classes on the training and validation windows.
model.eval()
with torch.no_grad():  # inference only; avoids building an autograd graph over the full sets
    train_pred = np.argmax(model(train_X).cpu().numpy(), axis=1)
    val_pred = np.argmax(model(val_X).cpu().numpy(), axis=1)
train_label = train_Y.cpu().numpy()
val_label = val_Y.cpu().numpy()

plt.plot(train_pred, label='predicted')
plt.plot(train_label, label='true')
plt.xlim([10, 200])
plt.xlabel('Training window')
plt.ylabel('Arrival class')
plt.legend()
plt.show()

plt.plot(val_label, label='true')
plt.plot(val_pred, label='predicted')
plt.xlabel('Validation window')
plt.ylabel('Arrival class')
plt.legend()
plt.show()

In [None]:
# Probability of class 0 for each validation window, from the last evaluation in training.
plt.plot(Parv)
plt.xlabel('Validation window')
plt.ylabel('P(class 0)')
plt.show()

# Sanity check: exp(log_softmax(z)) sums to 1 across classes for every row.
# (Descriptive names instead of `m` / `a`, which clobbered earlier notebook globals.)
log_softmax = nn.LogSoftmax(dim=1)
logits_demo = torch.tensor([[1, 2, 3, 4], [4., 7, 2, 9]])
print(logits_demo.shape)
print(np.sum(np.exp(log_softmax(logits_demo).numpy()), axis=1))

In [None]:
# Learning curves recorded by the training loop (one point every 4 epochs).
# NOTE: further down in this cell, archived lists from earlier runs are assigned to
# `accuracies_val` and `percent_val`; if this cell is re-run after them, the archived
# values are plotted instead of the live ones.


def _eval_epochs(series):
    """Epoch index of each entry of a metric list recorded every 4 epochs."""
    return [4 * k for k in range(len(series))]


fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(16, 6))
ax_loss.plot(_eval_epochs(loss_train), loss_train, color='k', label='train (last mini-batch)')
ax_loss.plot(_eval_epochs(loss_val), loss_val, color='b', label='validation')
ax_loss.set(xlabel='Epoch', ylabel='Cross-entropy loss')
ax_loss.legend()
ax_loss.grid()

# Accuracy of always predicting the most frequent validation class, as a reference level.
val_class_counts = np.bincount(val_Y.cpu().numpy())
majority_ref = val_class_counts.max() / val_class_counts.sum()
ax_acc.plot(_eval_epochs(accuracies_train), accuracies_train, color='g', label='train')
ax_acc.plot(_eval_epochs(accuracies_val), accuracies_val, color='b', label='validation')
ax_acc.axhline(y=majority_ref, color='r', linestyle='--', label='majority class (val)')
ax_acc.set(xlabel='Epoch', ylabel='Accuracy')
ax_acc.legend()
ax_acc.grid()
plt.show()



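# Archived metrics pasted by hand from earlier runs.
# NOTE: the uncommented assignments below rebind `accuracies_val` and `percent_val`,
# overwriting the values computed by the training loop; rename them (e.g.
# `accuracies_val_run1`) before relying on the live metrics after this cell runs.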
#[0.518796980381012, 0.6015037894248962, 0.6992481350898743, 0.5789473652839661, 0.6842105388641357, 0.6315789818763733, 0.6616541743278503, 0.6090225577354431, 0.548872172832489, 0.6691729426383972, 0.6165413856506348, 0.6315789818763733, 0.6917293667793274, 0.6842105388641357, 0.6165413856506348, 0.6691729426383972, 0.6466165781021118, 0.6541353464126587, 0.6390977501869202, 0.6165413856506348, 0.7067669630050659, 0.6165413856506348, 0.6616541743278503, 0.6240601539611816, 0.6842105388641357, 0.7067669630050659, 0.6090225577354431, 0.6315789818763733, 0.6616541743278503, 0.548872172832489, 0.6842105388641357, 0.7518796920776367, 0.5714285969734192, 0.6992481350898743, 0.6466165781021118, 0.6466165781021118, 0.6992481350898743, 0.6766917705535889, 0.6240601539611816, 0.7293233275413513, 0.6766917705535889, 0.6466165781021118, 0.7067669630050659, 0.6015037894248962, 0.7142857313156128, 0.6842105388641357, 0.7819548845291138, 0.6165413856506348, 0.6466165781021118, 0.7218044996261597, 0.6917293667793274, 0.6691729426383972, 0.6992481350898743, 0.6390977501869202, 0.6917293667793274, 0.7293233275413513, 0.7593985199928284, 0.6766917705535889, 0.7894737124443054, 0.7067669630050659, 0.7593985199928284, 0.6842105388641357, 0.6390977501869202, 0.7518796920776367, 0.7669172883033752, 0.8120300769805908, 0.7293233275413513, 0.7593985199928284, 0.7744361162185669, 0.7819548845291138, 0.7669172883033752, 0.7669172883033752, 0.7969924807548523, 0.7593985199928284, 0.8195489048957825, 0.8120300769805908, 0.864661693572998, 0.7518796920776367, 0.804511308670044, 0.7894737124443054, 0.7293233275413513, 0.7368420958518982, 0.7744361162185669, 0.7819548845291138, 0.7969924807548523, 0.7969924807548523, 0.8195489048957825, 0.8120300769805908, 0.8571428656578064, 0.8195489048957825, 0.7669172883033752, 0.8270676732063293, 0.8270676732063293, 0.7894737124443054, 0.7593985199928284, 0.7518796920776367, 0.8571428656578064, 0.7067669630050659, 0.6090225577354431, 0.6315789818763733, 
0.5263158082962036, 0.6240601539611816, 0.548872172832489, 0.6165413856506348, 0.6466165781021118, 0.5789473652839661, 0.49624061584472656, 0.5864661931991577, 0.548872172832489, 0.5864661931991577, 0.45864662528038025, 0.5413534045219421, 0.548872172832489, 0.6240601539611816, 0.5864661931991577, 0.5413534045219421, 0.5112782120704651, 0.548872172832489, 0.518796980381012, 0.5714285969734192, 0.5939849615097046, 0.4436090290546417, 0.5563910007476807, 0.5789473652839661, 0.5639097690582275]
accuracies_val = [0.612500011920929, 0.6131756901741028, 0.625, 0.6118243336677551, 0.6118243336677551, 0.6131756901741028, 0.6192567944526672, 0.6179054379463196, 0.6148648858070374, 0.620270311832428, 0.6192567944526672, 0.6152027249336243, 0.6192567944526672, 0.6175675988197327, 0.616216242313385, 0.6192567944526672, 0.6077702641487122, 0.6199324727058411, 0.6168919205665588, 0.6087837815284729, 0.6195946335792542, 0.6138513684272766, 0.6057432293891907, 0.6064189076423645, 0.6094594597816467, 0.6074324250221252, 0.6097972989082336, 0.6003378629684448, 0.5949324369430542, 0.6070945858955383, 0.5986486673355103, 0.6077702641487122, 0.6081081032752991, 0.583108127117157, 0.5945945978164673, 0.5898648500442505, 0.5793918967247009, 0.5864865183830261, 0.5763513445854187, 0.5766891837120056, 0.5966216325759888, 0.5618243217468262, 0.5945945978164673, 0.5925675630569458, 0.5797297358512878, 0.599662184715271, 0.5834459662437439, 0.5895270705223083, 0.5895270705223083, 0.5797297358512878, 0.5885135531425476, 0.5729730129241943, 0.5810810923576355, 0.5766891837120056, 0.5658783912658691, 0.5763513445854187, 0.5635135173797607, 0.5783783793449402, 0.5597972869873047, 0.5689189434051514, 0.5706081390380859, 0.5675675868988037, 0.5729730129241943, 0.5567567944526672, 0.5733108520507812, 0.553716242313385, 0.5746621489524841, 0.5625, 0.5692567825317383, 0.5699324607849121, 0.5658783912658691, 0.5645270347595215, 0.5604729652404785, 0.5628378391265869, 0.557770311832428, 0.5628378391265869, 0.5618243217468262, 0.5766891837120056, 0.5712838172912598, 0.5736486911773682, 0.5672297477722168, 0.5783783793449402, 0.5692567825317383, 0.5709459781646729, 0.5574324727058411, 0.5736486911773682, 0.5648648738861084, 0.5709459781646729, 0.5570946335792542, 0.5641891956329346, 0.5527027249336243, 0.5655405521392822, 0.5540540814399719, 0.5695946216583252, 0.5631756782531738, 0.5689189434051514, 0.5685811042785645, 0.5793918967247009, 0.5628378391265869, 0.5483108162879944, 
0.5429053902626038, 0.5652027130126953, 0.5155405402183533, 0.5645270347595215, 0.5601351261138916, 0.512499988079071, 0.5597972869873047, 0.5709459781646729, 0.5834459662437439, 0.5527027249336243, 0.4868243336677551, 0.583108127117157, 0.5625, 0.48783785104751587, 0.5989865064620972, 0.5219594836235046, 0.5195946097373962, 0.5739864706993103, 0.5682432651519775, 0.558783769607544, 0.5429053902626038, 0.5148648619651794, 0.5952702760696411, 0.5658783912658691, 0.4398648738861084]
#[0.0, 0.03896103799343109, 0.27397260069847107, 0.0, 0.0, 0.03703703731298447, 0.2054794579744339, 0.03846153989434242, 0.05797101557254791, 0.07228915393352509, 0.17142857611179352, 0.012048192322254181, 0.08235294371843338, 0.1666666716337204, 0.037974681705236435, 0.03488372266292572, 0.24637681245803833, 0.07407407462596893, 0.08974359184503555, 0.09333333373069763, 0.04444444552063942, 0.06493506580591202, 0.23943662643432617, 0.050632912665605545, 0.12345679104328156, 0.18987341225147247, 0.06578947603702545, 0.1666666716337204, 0.1139240488409996, 0.07352941483259201, 0.1818181872367859, 0.1764705926179886, 0.1515151560306549, 0.23999999463558197, 0.08860759437084198, 0.1944444477558136, 0.2567567527294159, 0.26760563254356384, 0.23880596458911896, 0.1975308656692505, 0.18421052396297455, 0.3650793731212616, 0.1463414579629898, 0.1428571492433548, 0.21794871985912323, 0.09638553857803345, 0.3333333432674408, 0.15492957830429077, 0.1944444477558136, 0.2631579041481018, 0.2957746386528015, 0.508474588394165, 0.16249999403953552, 0.3492063581943512, 0.3529411852359772, 0.31081080436706543, 0.5303030014038086, 0.16883116960525513, 0.4189189076423645, 0.2876712381839752, 0.2469135820865631, 0.5964912176132202, 0.28787878155708313, 0.3513513505458832, 0.1860465109348297, 0.611940324306488, 0.2763157784938812, 0.4637681245803833, 0.2874999940395355, 0.2530120611190796, 0.39726027846336365, 0.25925925374031067, 0.37662336230278015, 0.29487180709838867, 0.37974682450294495, 0.24137930572032928, 0.5333333611488342, 0.23456789553165436, 0.44594594836235046, 0.47887325286865234, 0.6724137663841248, 0.4000000059604645, 0.5373134613037109, 0.3333333432674408, 0.5142857432365417, 0.3417721390724182, 0.34567901492118835, 0.5652173757553101, 0.5, 0.34567901492118835, 0.36000001430511475, 0.41025641560554504, 0.46666666865348816, 0.4000000059604645, 0.3835616409778595, 0.49253731966018677, 0.3103448152542114, 0.20512820780277252, 0.37288135290145874, 0.23529411852359772, 
0.3461538553237915, 0.18571428954601288, 0.3035714328289032, 0.20588235557079315, 0.3030303120613098, 0.2222222238779068, 0.3469387888908386, 0.3448275923728943, 0.5531914830207825, 0.1818181872367859, 0.4878048896789551, 0.2857142984867096, 0.4038461446762085, 0.45614033937454224, 0.14705882966518402, 0.35849055647850037, 0.6190476417541504, 0.19672131538391113, 0.6428571343421936, 0.2063492089509964, 0.2153846174478531, 0.7878788113594055, 0.15625, 0.1666666716337204, 0.6666666865348816]
percent_val = [0.0, 0.0011031439062207937, 0.03699551522731781, 0.0, 0.0011055831564590335, 0.004983388818800449, 0.06818182021379471, 0.016675930470228195, 0.007194244768470526, 0.025698324665427208, 0.0731850117444992, 0.0077476478181779385, 0.0607638880610466, 0.15623024106025696, 0.015025041997432709, 0.040885861963033676, 0.17658600211143494, 0.07122008502483368, 0.06534422188997269, 0.12344139814376831, 0.04501424357295036, 0.07705987244844437, 0.11783042550086975, 0.03697284683585167, 0.06493506580591202, 0.09168184548616409, 0.04637681320309639, 0.12610900402069092, 0.11244472861289978, 0.05210772901773453, 0.1301020383834839, 0.09162621200084686, 0.07655502110719681, 0.18138261139392853, 0.0731707289814949, 0.08785046637058258, 0.16508151590824127, 0.08027379959821701, 0.17574086785316467, 0.16280654072761536, 0.1003115251660347, 0.19126074016094208, 0.10691823810338974, 0.11862245202064514, 0.1769547313451767, 0.09975216537714005, 0.18206708133220673, 0.13311688601970673, 0.13829092681407928, 0.2118644118309021, 0.12823833525180817, 0.24614253640174866, 0.1886662095785141, 0.21928571164608002, 0.29044684767723083, 0.18636995553970337, 0.31545740365982056, 0.21332389116287231, 0.3109177350997925, 0.2796352505683899, 0.2548291087150574, 0.3104524314403534, 0.22454874217510223, 0.32369476556777954, 0.23418182134628296, 0.3038981556892395, 0.23889292776584625, 0.2877030074596405, 0.26217228174209595, 0.2665165066719055, 0.2641509473323822, 0.23503325879573822, 0.2683486342430115, 0.23042836785316467, 0.2838258147239685, 0.2405063360929489, 0.2951713502407074, 0.23875181376934052, 0.30076923966407776, 0.23941606283187866, 0.2905457317829132, 0.2423802614212036, 0.3072148859500885, 0.2574404776096344, 0.30744850635528564, 0.2615155875682831, 0.30015552043914795, 0.264970064163208, 0.3066560924053192, 0.26323750615119934, 0.3046252131462097, 0.2662632465362549, 0.2954186499118805, 0.2657657563686371, 0.27837422490119934, 0.2557792663574219, 0.2616191804409027, 
0.20604781806468964, 0.254518061876297, 0.3468879759311676, 0.335827112197876, 0.17983074486255646, 0.4382658004760742, 0.1668994426727295, 0.21021898090839386, 0.4558541178703308, 0.2301410585641861, 0.20800571143627167, 0.20600558817386627, 0.24885496497154236, 0.6172839403152466, 0.08077645301818848, 0.17917847633361816, 0.723150372505188, 0.06294964253902435, 0.3806970417499542, 0.3695458471775055, 0.1387399435043335, 0.24224519729614258, 0.33818769454956055, 0.30863192677497864, 0.4514285624027252, 0.08098159730434418, 0.22173595428466797, 1.952380895614624]

#[0.5609756112098694, 0.630081295967102, 0.6544715166091919, 0.5894308686256409, 0.5813007950782776, 0.6219512224197388, 0.6504064798355103, 0.5691056847572327, 0.6707316637039185, 0.5894308686256409, 0.6463414430618286, 0.6341463327407837, 0.6219512224197388, 0.642276406288147, 0.5528454780578613, 0.5934959053993225, 0.6016259789466858, 0.5975609421730042, 0.6178861856460571, 0.6829267740249634, 0.5934959053993225, 0.5203251838684082, 0.6016259789466858, 0.5853658318519592, 0.6138210892677307, 0.6951219439506531, 0.6219512224197388, 0.5934959053993225, 0.6585365533828735, 0.6382113695144653, 0.6666666269302368, 0.6260162591934204, 0.6707316637039185, 0.707317054271698, 0.7235772013664246, 0.630081295967102, 0.630081295967102, 0.6951219439506531, 0.7154471278190613, 0.6585365533828735, 0.7276422381401062, 0.6788617372512817, 0.6788617372512817, 0.6869918704032898, 0.6707316637039185, 0.6991869807243347, 0.6991869807243347, 0.7357723116874695, 0.7154471278190613, 0.7032520174980164, 0.7032520174980164, 0.7357723116874695, 0.7113820910453796, 0.7154471278190613, 0.7113820910453796, 0.7439023852348328, 0.7032520174980164, 0.7601625919342041, 0.7398373484611511, 0.7317072749137878, 0.7235772013664246, 0.7642276287078857, 0.7764227390289307, 0.7642276287078857, 0.7479674816131592, 0.7520325183868408, 0.8008129596710205, 0.8048779964447021, 0.7560975551605225, 0.7357723116874695, 0.7967479228973389, 0.7642276287078857, 0.7764227390289307, 0.7804877758026123, 0.8333333134651184, 0.8048779964447021, 0.8170731663703918, 0.7967479228973389, 0.8048779964447021, 0.8008129596710205, 0.7967479228973389, 0.8373983502388, 0.8170731663703918, 0.8089430332183838, 0.8089430332183838, 0.7764227390289307, 0.8008129596710205, 0.8252032399177551, 0.784552812576294, 0.7601625919342041, 0.707317054271698, 0.6869918704032898, 0.642276406288147, 0.5528454780578613, 0.5731707215309143, 0.6097560524940491, 0.5691056847572327, 0.5365853309631348, 0.544715404510498, 0.5528454780578613]
accuracies_val = [0.6116816401481628, 0.6100770235061646, 0.6139281392097473, 0.6103979349136353, 0.6107189059257507, 0.6136072278022766, 0.6120026111602783, 0.6107189059257507, 0.6116816401481628, 0.6129653453826904, 0.612323522567749, 0.6116816401481628, 0.6078305840492249, 0.608793318271637, 0.6107189059257507, 0.6071887016296387, 0.6075096726417542, 0.6081514954566956, 0.6065468788146973, 0.5991656184196472, 0.5962772965431213, 0.6094352006912231, 0.6094352006912231, 0.6039794683456421, 0.6010911464691162, 0.5994865298271179, 0.5953145027160645, 0.6017330288887024, 0.6036585569381714, 0.6062259674072266, 0.6055840849876404, 0.5972400903701782, 0.5879332423210144, 0.5937098860740662, 0.5982028245925903, 0.5978819131851196, 0.5943517684936523, 0.5905006527900696, 0.5917843580245972, 0.5895378589630127, 0.5885751247406006, 0.5898588299751282, 0.5956354737281799, 0.5882542133331299, 0.5860077142715454, 0.5943517684936523, 0.594672679901123, 0.5905006527900696, 0.5917843580245972, 0.589216947555542, 0.5905006527900696, 0.5917843580245972, 0.5795892477035522, 0.5783055424690247, 0.5879332423210144, 0.5847240090370178, 0.5827984809875488, 0.5811938643455505, 0.5802310705184937, 0.5866495966911316, 0.577984631061554, 0.5770218372344971, 0.5834403038024902, 0.5783055424690247, 0.5821565985679626, 0.5754172205924988, 0.5754172205924988, 0.5686777830123901, 0.5706033706665039, 0.5770218372344971, 0.5718870759010315, 0.5744544267654419, 0.5693196654319763, 0.5747753977775574, 0.5706033706665039, 0.5738126039505005, 0.5718870759010315, 0.5709242820739746, 0.569640576839447, 0.5670731663703918, 0.5718870759010315, 0.5673941373825073, 0.5734916925430298, 0.5645058155059814, 0.5638639330863953, 0.5609756112098694, 0.5693196654319763, 0.5529525279998779, 0.5558408498764038, 0.5561617612838745, 0.5689987540245056, 0.5722079873085022, 0.5365853905677795, 0.5510269999504089, 0.5481386780738831, 0.5587291717529297, 0.5683568716049194, 0.5750963091850281, 0.5741335153579712, 
0.5975610017776489]
#[0.0, 0.13970588147640228, 0.17518247663974762, 0.1328125, 0.0833333358168602, 0.15037593245506287, 0.1267605572938919, 0.04477611929178238, 0.18705035746097565, 0.09848485141992569, 0.13571429252624512, 0.21875, 0.125, 0.03947368264198303, 0.08799999952316284, 0.1587301641702652, 0.19354838132858276, 0.1221374049782753, 0.1259259283542633, 0.15068493783473969, 0.22689075767993927, 0.1962616890668869, 0.13846154510974884, 0.1803278625011444, 0.1984127014875412, 0.21276596188545227, 0.2750000059604645, 0.2586206793785095, 0.27559053897857666, 0.18045112490653992, 0.2238806039094925, 0.27272728085517883, 0.1785714328289032, 0.24285714328289032, 0.26241135597229004, 0.2916666567325592, 0.28099173307418823, 0.3153846263885498, 0.23076923191547394, 0.31707316637039185, 0.31617647409439087, 0.35772356390953064, 0.24626865983009338, 0.27067670226097107, 0.26923078298568726, 0.3870967626571655, 0.3870967626571655, 0.3507462739944458, 0.2753623127937317, 0.3840000033378601, 0.341085284948349, 0.2928571403026581, 0.52173912525177, 0.46666666865348816, 0.356589138507843, 0.3863636255264282, 0.3840000033378601, 0.2986111044883728, 0.49180328845977783, 0.35338345170021057, 0.4590163826942444, 0.38235294818878174, 0.4253731369972229, 0.5161290168762207, 0.5081967115402222, 0.47999998927116394, 0.41726619005203247, 0.692307710647583, 0.6460176706314087, 0.4251968562602997, 0.53125, 0.4135338366031647, 0.43609023094177246, 0.40145984292030334, 0.5413534045219421, 0.3288590610027313, 0.42553192377090454, 0.4848484992980957, 0.4887218177318573, 0.5038167834281921, 0.5076923370361328, 0.6747967600822449, 0.46715328097343445, 0.5669291615486145, 0.5669291615486145, 0.5655737519264221, 0.4485294222831726, 0.5037037134170532, 0.41911765933036804, 0.5327869057655334, 0.4146341383457184, 0.4955752193927765, 0.3504273593425751, 0.3333333432674408, 0.29357796907424927, 0.3513513505458832, 0.27272728085517883, 0.15789473056793213, 0.17543859779834747, 0.24770642817020416]
percent_val = [0.0, 0.023694129660725594, 0.050521690398454666, 0.02368137799203396, 0.009013785980641842, 0.048245612531900406, 0.03081081062555313, 0.007411328610032797, 0.046103183180093765, 0.007915567606687546, 0.033586133271455765, 0.11331775784492493, 0.06824591010808945, 0.017703862860798836, 0.018191546201705933, 0.06591549515724182, 0.10896309465169907, 0.03382433205842972, 0.018318966031074524, 0.06807780265808105, 0.0916568711400032, 0.06625491380691528, 0.06565656512975693, 0.08849045634269714, 0.11421772837638855, 0.10597986727952957, 0.08797653764486313, 0.08885017782449722, 0.10000000149011612, 0.09253904223442078, 0.08886324614286423, 0.10971973836421967, 0.10228639841079712, 0.09792284667491913, 0.12019230425357819, 0.1332116723060608, 0.12106537818908691, 0.13931888341903687, 0.13827160000801086, 0.11671732366085052, 0.1469668596982956, 0.15670232474803925, 0.12009655684232712, 0.12939001619815826, 0.15936507284641266, 0.1361963152885437, 0.13264058530330658, 0.1727214753627777, 0.1619407683610916, 0.16942675411701202, 0.17048345506191254, 0.15901948511600494, 0.2202702760696411, 0.1981382966041565, 0.16465352475643158, 0.1994733363389969, 0.19473683834075928, 0.21056149899959564, 0.2183288335800171, 0.18087854981422424, 0.2152496576309204, 0.20187166333198547, 0.21605351567268372, 0.23171564936637878, 0.19656991958618164, 0.2255639135837555, 0.21148648858070374, 0.23226703703403473, 0.2321552336215973, 0.22063815593719482, 0.24094708263874054, 0.2193460464477539, 0.25017619132995605, 0.22671233117580414, 0.2529950737953186, 0.234806627035141, 0.2514044940471649, 0.25546929240226746, 0.26065340638160706, 0.26575931906700134, 0.2549295723438263, 0.2655690908432007, 0.26200565695762634, 0.2839415967464447, 0.26767677068710327, 0.2805860936641693, 0.26263344287872314, 0.2935435473918915, 0.2651570439338684, 0.2780236005783081, 0.2583392560482025, 0.26813656091690063, 0.3333333432674408, 0.24330195784568787, 0.23054754734039307, 0.1941015124320984, 
0.15000000596046448, 0.17508196830749512, 0.12657430768013, 0.08255814015865326]

#[0.6371681690216064, 0.6106194853782654, 0.5575221180915833, 0.5663716793060303, 0.6460176706314087, 0.5663716793060303, 0.6283186078071594, 0.6814159154891968, 0.5486725568771362, 0.6548672318458557, 0.6283186078071594, 0.6637167930603027, 0.6194690465927124, 0.6371681690216064, 0.6283186078071594, 0.5929203629493713, 0.6283186078071594, 0.6017699241638184, 0.5663716793060303, 0.5309734344482422, 0.6725663542747498, 0.6460176706314087, 0.6548672318458557, 0.6460176706314087, 0.6548672318458557, 0.6991150379180908, 0.6637167930603027, 0.7433628439903259, 0.5840708017349243, 0.6017699241638184, 0.6814159154891968, 0.6637167930603027, 0.6460176706314087, 0.6017699241638184, 0.6371681690216064, 0.6106194853782654, 0.6725663542747498, 0.6460176706314087, 0.6814159154891968, 0.7345132827758789, 0.6814159154891968, 0.6548672318458557, 0.6637167930603027, 0.6814159154891968, 0.769911527633667, 0.5309734344482422, 0.6991150379180908, 0.6283186078071594, 0.7345132827758789, 0.6637167930603027, 0.7079645991325378, 0.7433628439903259, 0.6814159154891968, 0.7964601516723633, 0.7168141603469849, 0.7876105904579163, 0.7256637215614319, 0.6902654767036438, 0.7345132827758789, 0.7256637215614319, 0.6991150379180908, 0.7433628439903259, 0.769911527633667, 0.6637167930603027, 0.752212405204773, 0.7433628439903259, 0.7256637215614319, 0.6637167930603027, 0.752212405204773, 0.76106196641922, 0.752212405204773, 0.7964601516723633, 0.7964601516723633, 0.7433628439903259, 0.7964601516723633, 0.7433628439903259, 0.7787610292434692, 0.7964601516723633, 0.8141592741012573, 0.8230088353157043, 0.8407079577445984, 0.8318583965301514, 0.8141592741012573, 0.8584070801734924, 0.752212405204773, 0.8230088353157043, 0.8672566413879395, 0.7787610292434692, 0.8053097128868103, 0.8407079577445984, 0.76106196641922, 0.8053097128868103, 0.8407079577445984, 0.8141592741012573, 0.8318583965301514, 0.7964601516723633, 0.7876105904579163, 0.8407079577445984, 0.8230088353157043, 0.8761062026023865]
accuracies_val = [0.617511510848999, 0.6231832504272461, 0.6253101825714111, 0.6178659796714783, 0.6178659796714783, 0.6249556541442871, 0.6178659796714783, 0.6189294457435608, 0.6221197843551636, 0.6196384429931641, 0.6182204484939575, 0.6221197843551636, 0.6185749769210815, 0.6224743127822876, 0.6210563778877258, 0.6185749769210815, 0.6214108467102051, 0.6217653155326843, 0.6214108467102051, 0.6132577061653137, 0.6207018494606018, 0.6168025135993958, 0.6168025135993958, 0.6182204484939575, 0.6143211722373962, 0.6182204484939575, 0.61928391456604, 0.6150301098823547, 0.6207018494606018, 0.615739107131958, 0.6199929118156433, 0.6263735890388489, 0.6111307740211487, 0.6199929118156433, 0.6132577061653137, 0.6146756410598755, 0.611839771270752, 0.6075859665870667, 0.6146756410598755, 0.6129032373428345, 0.611839771270752, 0.6058135032653809, 0.5997872948646545, 0.6121942400932312, 0.5983693599700928, 0.6058135032653809, 0.5983693599700928, 0.6033321619033813, 0.6061680316925049, 0.596951425075531, 0.6086494326591492, 0.5973058938980103, 0.6121942400932312, 0.5937610864639282, 0.6051045656204224, 0.59517902135849, 0.5976603627204895, 0.5955334901809692, 0.5898617506027222, 0.5955334901809692, 0.5880893468856812, 0.5983693599700928, 0.5827720761299133, 0.5994328260421753, 0.5813541412353516, 0.594824492931366, 0.5856079459190369, 0.598723828792572, 0.5824175477027893, 0.6029776334762573, 0.5859624147415161, 0.6019142270088196, 0.5838354825973511, 0.5980148911476135, 0.5813541412353516, 0.5941155552864075, 0.5749734044075012, 0.5902162194252014, 0.5732010006904602, 0.5944700241088867, 0.574618935585022, 0.5902162194252014, 0.5838354825973511, 0.5856079459190369, 0.5891527533531189, 0.5873803496360779, 0.5994328260421753, 0.5834810137748718, 0.5997872948646545, 0.580290675163269, 0.5976603627204895, 0.5838354825973511, 0.5880893468856812, 0.5834810137748718, 0.5820630788803101, 0.5877348184585571, 0.5799362063407898, 0.5859624147415161, 0.5771003365516663, 
0.576391339302063]
#[0.0, 0.095238097012043, 0.016129031777381897, 0.01587301678955555, 0.0, 0.1428571492433548, 0.014285714365541935, 0.02666666731238365, 0.03333333507180214, 0.0422535203397274, 0.014285714365541935, 0.20967741310596466, 0.04477611929178238, 0.02857142873108387, 0.20338982343673706, 0.046875, 0.1269841343164444, 0.09677419066429138, 0.04918032884597778, 0.1320754736661911, 0.0555555559694767, 0.07352941483259201, 0.15625, 0.013888888992369175, 0.12121212482452393, 0.12857143580913544, 0.08695652335882187, 0.2537313401699066, 0.0476190485060215, 0.13333334028720856, 0.10000000149011612, 0.11940298229455948, 0.140625, 0.0625, 0.22033898532390594, 0.18965516984462738, 0.07042253762483597, 0.21666666865348816, 0.1492537260055542, 0.296875, 0.1492537260055542, 0.1746031790971756, 0.2295081913471222, 0.203125, 0.4262295067310333, 0.1538461595773697, 0.2950819730758667, 0.20338982343673706, 0.2769230902194977, 0.171875, 0.23076923191547394, 0.35483869910240173, 0.1846153885126114, 0.38461539149284363, 0.20895522832870483, 0.30882352590560913, 0.28125, 0.30000001192092896, 0.4067796468734741, 0.24242424964904785, 0.36206895112991333, 0.37704917788505554, 0.5, 0.31578946113586426, 0.4166666567325592, 0.20000000298023224, 0.4909090995788574, 0.2295081913471222, 0.3076923191547394, 0.32307693362236023, 0.3492063581943512, 0.20000000298023224, 0.4285714328289032, 0.23529411852359772, 0.4516128897666931, 0.4237288236618042, 0.49152541160583496, 0.40625, 0.3142857253551483, 0.2567567527294159, 0.43939393758773804, 0.38235294818878174, 0.2957746386528015, 0.3287671208381653, 0.328125, 0.4307692348957062, 0.30666667222976685, 0.6603773832321167, 0.3382352888584137, 0.6101694703102112, 0.3030303120613098, 0.5166666507720947, 0.3970588147640228, 0.39393940567970276, 0.446153849363327, 0.26760563254356384, 0.30882352590560913, 0.5322580933570862, 0.40909090638160706, 0.43478259444236755]
percent_val = [0.0, 0.012089810334146023, 0.032786883413791656, 0.0005740527994930744, 0.0005740527994930744, 0.04505038633942604, 0.0011487650917842984, 0.006340057589113712, 0.025116821750998497, 0.009820912964642048, 0.002298850566148758, 0.07339449226856232, 0.008670520037412643, 0.023310022428631783, 0.08148147910833359, 0.006924408487975597, 0.030570251867175102, 0.04033214598894119, 0.018593840301036835, 0.0839598998427391, 0.029394473880529404, 0.04316546767950058, 0.08343710750341415, 0.02047981321811676, 0.07439553737640381, 0.04368641600012779, 0.04673457145690918, 0.09740670770406723, 0.030606238171458244, 0.08766437321901321, 0.06064281240105629, 0.051160022616386414, 0.10089399665594101, 0.0473053902387619, 0.09148264676332474, 0.07501550018787384, 0.05630354955792427, 0.1137102022767067, 0.049636803567409515, 0.13006536662578583, 0.08077645301818848, 0.09271099418401718, 0.1424712985754013, 0.07267080992460251, 0.16413792967796326, 0.07959570735692978, 0.13823330402374268, 0.11096605658531189, 0.09685695916414261, 0.15263518691062927, 0.08602150529623032, 0.17421603202819824, 0.09511730819940567, 0.17132867872714996, 0.11060507595539093, 0.15157750248908997, 0.1368846893310547, 0.14130434393882751, 0.16526611149311066, 0.1267605572938919, 0.18924731016159058, 0.12084993720054626, 0.21507760882377625, 0.11986754834651947, 0.2229679375886917, 0.11941294372081757, 0.2182890921831131, 0.12600000202655792, 0.24753226339817047, 0.13703209161758423, 0.2409909963607788, 0.1395973116159439, 0.23741547763347626, 0.1437288075685501, 0.25190839171409607, 0.1487320065498352, 0.25541794300079346, 0.15145228803157806, 0.24480369687080383, 0.1621621549129486, 0.22432024776935577, 0.177510604262352, 0.21729490160942078, 0.2005814015865326, 0.19139784574508667, 0.22740741074085236, 0.1800418645143509, 0.24981017410755157, 0.17581653594970703, 0.25536808371543884, 0.18482080101966858, 0.24583964049816132, 0.19956615567207336, 0.22016307711601257, 0.23551543056964874, 
0.19624820351600647, 0.2771272361278534, 0.19522777199745178, 0.2768627405166626, 0.22347629070281982]

#[0.5833333730697632, 0.5, 0.6666666865348816, 0.625, 0.7083333730697632, 0.75, 0.625, 0.75, 0.4166666865348816, 0.7083333730697632, 0.75, 0.7916666865348816, 0.6666666865348816, 0.4166666865348816, 0.5416666865348816, 0.4166666865348816, 0.7083333730697632, 0.75, 0.5416666865348816, 0.7916666865348816, 0.7916666865348816, 0.8333333730697632, 0.6666666865348816, 0.625, 0.6666666865348816, 0.5833333730697632, 0.5, 0.5416666865348816, 0.6666666865348816, 0.6666666865348816, 0.4166666865348816, 0.6666666865348816, 0.625, 0.6666666865348816, 0.7083333730697632, 0.5833333730697632, 0.625, 0.6666666865348816, 0.6666666865348816, 0.5833333730697632, 0.6666666865348816, 0.75, 0.7916666865348816, 0.9166666865348816, 0.7916666865348816, 0.6666666865348816, 0.6666666865348816, 0.8333333730697632, 0.7083333730697632, 0.7916666865348816, 0.5833333730697632, 0.5833333730697632, 0.75, 0.7083333730697632, 0.75, 0.625, 0.75, 0.6666666865348816, 0.9583333730697632, 0.875, 0.7916666865348816, 0.8333333730697632, 0.625, 0.625, 0.75, 0.75, 0.6666666865348816, 0.6666666865348816, 0.9166666865348816, 0.875, 0.8333333730697632, 0.75, 0.8333333730697632, 0.7916666865348816, 0.7916666865348816, 0.7083333730697632, 0.75, 0.7083333730697632, 0.7916666865348816, 0.8333333730697632, 0.7916666865348816, 0.8333333730697632, 0.8333333730697632, 0.75, 0.7916666865348816, 0.9166666865348816, 0.7916666865348816, 0.9166666865348816, 0.9166666865348816, 0.7916666865348816, 0.875, 0.9166666865348816, 0.75, 0.7916666865348816, 0.8333333730697632, 0.75, 0.9166666865348816, 0.7916666865348816, 0.7916666865348816, 0.75]
accuracies_val = [0.6485507488250732, 0.6500000357627869, 0.645652174949646, 0.6485507488250732, 0.6481884121894836, 0.6514493227005005, 0.6518115997314453, 0.6507246494293213, 0.6507246494293213, 0.6510869860649109, 0.6514493227005005, 0.6507246494293213, 0.6496376991271973, 0.6518115997314453, 0.6510869860649109, 0.6503623127937317, 0.6521739363670349, 0.6518115997314453, 0.6532608866691589, 0.655434787273407, 0.6525362730026245, 0.6536232233047485, 0.6525362730026245, 0.6528985500335693, 0.6572464108467102, 0.6528985500335693, 0.654347836971283, 0.6528985500335693, 0.6550725102424622, 0.6532608866691589, 0.6532608866691589, 0.6539855003356934, 0.6510869860649109, 0.6507246494293213, 0.64673912525177, 0.6478261351585388, 0.6478261351585388, 0.645652174949646, 0.6452898979187012, 0.6409420371055603, 0.6431159377098083, 0.6384057998657227, 0.644565224647522, 0.6394927501678467, 0.6344203352928162, 0.6420289874076843, 0.6275362372398376, 0.6380435228347778, 0.6253623366355896, 0.6322463750839233, 0.6224637627601624, 0.6242753863334656, 0.6300724744796753, 0.6170290112495422, 0.6355072855949402, 0.6213768124580383, 0.634782612323761, 0.6159420609474182, 0.6311594247817993, 0.6173913478851318, 0.6278985738754272, 0.6199275851249695, 0.6159420609474182, 0.6210145354270935, 0.6210145354270935, 0.6282609105110168, 0.6170290112495422, 0.6315217614173889, 0.6192029118537903, 0.6344203352928162, 0.6242753863334656, 0.634782612323761, 0.6144927740097046, 0.6264492869377136, 0.6166666746139526, 0.6300724744796753, 0.6166666746139526, 0.6206521987915039, 0.6072463989257812, 0.6206521987915039, 0.6112319231033325, 0.6195652484893799, 0.6079710125923157, 0.6126812100410461, 0.6119565367698669, 0.6097826361656189, 0.6119565367698669, 0.6105072498321533, 0.6217391490936279, 0.6170290112495422, 0.6217391490936279, 0.6105072498321533, 0.6202898621559143, 0.604347825050354, 0.6061594486236572, 0.604347825050354, 0.6007246375083923, 0.6090579628944397, 0.5992754101753235, 
0.6083333492279053]
#[0.0, 0.0, 0.06666667014360428, 0.0, 0.0, 0.20000000298023224, 0.0, 0.2857142984867096, 0.1111111119389534, 0.0625, 0.0, 0.3571428656578064, 0.06666667014360428, 0.25, 0.0833333358168602, 0.1111111119389534, 0.0625, 0.05882352963089943, 0.1818181872367859, 0.11764705926179886, 0.0555555559694767, 0.05263157933950424, 0.06666667014360428, 0.25, 0.3333333432674408, 0.27272728085517883, 0.0, 0.1818181872367859, 0.06666667014360428, 0.1428571492433548, 0.1111111119389534, 0.4545454680919647, 0.1538461595773697, 0.1428571492433548, 0.2142857164144516, 0.1666666716337204, 0.1538461595773697, 0.1428571492433548, 0.23076923191547394, 0.1666666716337204, 0.1428571492433548, 0.38461539149284363, 0.0555555559694767, 0.375, 0.2666666805744171, 0.06666667014360428, 0.1428571492433548, 0.4285714328289032, 0.5454545617103577, 0.5833333134651184, 0.1666666716337204, 0.5555555820465088, 0.0, 0.13333334028720856, 0.2857142984867096, 0.1538461595773697, 0.125, 0.4545454680919647, 0.095238097012043, 0.23529411852359772, 0.1875, 0.8181818127632141, 0.1538461595773697, 0.5, 0.38461539149284363, 0.38461539149284363, 0.4545454680919647, 0.1428571492433548, 0.2222222238779068, 0.3125, 0.25, 0.6363636255264282, 0.4285714328289032, 0.3571428656578064, 0.2666666805744171, 0.5454545617103577, 0.38461539149284363, 0.4166666567325592, 0.2666666805744171, 0.25, 0.3571428656578064, 0.1764705926179886, 0.6666666865348816, 0.6363636255264282, 0.1875, 0.46666666865348816, 0.1875, 0.375, 0.692307710647583, 0.4615384638309479, 0.23529411852359772, 0.46666666865348816, 0.20000000298023224, 0.3571428656578064, 0.5384615659713745, 0.5, 0.5714285969734192, 0.3571428656578064, 0.7272727489471436, 0.2857142984867096]
percent_val = [0.0, 0.005605380982160568, 0.05945303291082382, 0.0, 0.0005592841189354658, 0.05330989882349968, 0.016384180635213852, 0.008988764137029648, 0.0546095110476017, 0.01011804398149252, 0.013528748415410519, 0.06021251529455185, 0.009003939107060432, 0.03868360444903374, 0.025099828839302063, 0.006730230059474707, 0.05944673344492912, 0.013521126471459866, 0.015202702954411507, 0.041450776159763336, 0.014647887088358402, 0.04096941649913788, 0.020396601408720016, 0.020963173359632492, 0.047344110906124115, 0.014639639295637608, 0.0373348668217659, 0.029126213863492012, 0.023203169927001, 0.062463171780109406, 0.022688599303364754, 0.05125218257308006, 0.036332178860902786, 0.021615471690893173, 0.05621301755309105, 0.022883296012878418, 0.05736250802874565, 0.03484320640563965, 0.025921659544110298, 0.0586475171148777, 0.024235429242253304, 0.05508982017636299, 0.03130434826016426, 0.04314420744776726, 0.06703229993581772, 0.03747072443366051, 0.08250000327825546, 0.040165387094020844, 0.08213166147470474, 0.05565638095140457, 0.07576706260442734, 0.08296668529510498, 0.06621704250574112, 0.09729381650686264, 0.06561360508203506, 0.11075129359960556, 0.07155963033437729, 0.1111111119389534, 0.07398273795843124, 0.09652509540319443, 0.08244846761226654, 0.09679487347602844, 0.09819121658802032, 0.08825396746397018, 0.11082307249307632, 0.0898805782198906, 0.12113232165575027, 0.09347553551197052, 0.1265655905008316, 0.09780564159154892, 0.14181576669216156, 0.10466582328081131, 0.14749661087989807, 0.10197578370571136, 0.14844805002212524, 0.10905612260103226, 0.14844805002212524, 0.10802070051431656, 0.15030884742736816, 0.1196078434586525, 0.14295393228530884, 0.12278398126363754, 0.13685636222362518, 0.13110367953777313, 0.1327967792749405, 0.1387009471654892, 0.13127930462360382, 0.1636740267276764, 0.1386861354112625, 0.1777316778898239, 0.1356717348098755, 0.1758548468351364, 0.14591699838638306, 0.17464788258075714, 0.15618520975112915, 
0.1631799191236496, 0.1733899563550949, 0.1529492437839508, 0.18566308915615082, 0.15236787497997284]

#[0.6000000238418579, 0.675000011920929, 0.6875, 0.6625000238418579, 0.637499988079071, 0.6625000238418579, 0.612500011920929, 0.6875, 0.6000000238418579, 0.550000011920929, 0.5875000357627869, 0.637499988079071, 0.5, 0.675000011920929, 0.6875, 0.6625000238418579, 0.6625000238418579, 0.6500000357627869, 0.6625000238418579, 0.75, 0.637499988079071, 0.675000011920929, 0.6000000238418579, 0.5875000357627869, 0.6875, 0.625, 0.6500000357627869, 0.699999988079071, 0.625, 0.625, 0.7250000238418579, 0.737500011920929, 0.762499988079071, 0.6625000238418579, 0.5250000357627869, 0.7125000357627869, 0.550000011920929, 0.737500011920929, 0.6500000357627869, 0.637499988079071, 0.6500000357627869, 0.6625000238418579, 0.6625000238418579, 0.6875, 0.737500011920929, 0.675000011920929, 0.75, 0.7250000238418579, 0.699999988079071, 0.7750000357627869, 0.7875000238418579, 0.6875, 0.8125, 0.737500011920929, 0.7875000238418579, 0.7750000357627869, 0.7250000238418579, 0.675000011920929, 0.762499988079071, 0.737500011920929, 0.75, 0.8125, 0.7250000238418579, 0.6875, 0.7750000357627869, 0.800000011920929, 0.8125, 0.6875, 0.8125, 0.7125000357627869, 0.7125000357627869, 0.7875000238418579, 0.8375000357627869, 0.762499988079071, 0.7875000238418579, 0.8125, 0.800000011920929, 0.8125, 0.800000011920929, 0.800000011920929, 0.737500011920929, 0.824999988079071, 0.7750000357627869, 0.875, 0.824999988079071, 0.800000011920929, 0.800000011920929, 0.824999988079071, 0.8500000238418579, 0.824999988079071, 0.824999988079071, 0.824999988079071, 0.800000011920929, 0.8500000238418579, 0.824999988079071, 0.800000011920929, 0.875, 0.824999988079071, 0.7750000357627869, 0.8125]
accuracies_val = [0.6293756365776062, 0.6282930374145508, 0.6257668733596802, 0.6297365427017212, 0.6293756365776062, 0.6293756365776062, 0.6264886260032654, 0.629014790058136, 0.6268494725227356, 0.6300974488258362, 0.628653883934021, 0.6228798031806946, 0.6261277198791504, 0.6268494725227356, 0.6264886260032654, 0.6275712847709656, 0.6272103786468506, 0.6275712847709656, 0.6272103786468506, 0.6221580505371094, 0.6254059672355652, 0.6239624619483948, 0.6210753917694092, 0.624323308467865, 0.6210753917694092, 0.6221580505371094, 0.6221580505371094, 0.6084445714950562, 0.6185492277145386, 0.6088054776191711, 0.6156622171401978, 0.6131360530853271, 0.6062793135643005, 0.6171057224273682, 0.6084445714950562, 0.6149404644966125, 0.6109707355499268, 0.6088054776191711, 0.6145795583724976, 0.6033922433853149, 0.6167448163032532, 0.6037531495094299, 0.6127751469612122, 0.6116925477981567, 0.6088054776191711, 0.6153013110160828, 0.6015878915786743, 0.6181883811950684, 0.6026704907417297, 0.6145795583724976, 0.6084445714950562, 0.6134968996047974, 0.6073619723320007, 0.6073619723320007, 0.6091663837432861, 0.6055575609207153, 0.6106098890304565, 0.5994225740432739, 0.6062793135643005, 0.5896788239479065, 0.6026704907417297, 0.5864309072494507, 0.6023096442222595, 0.5806567668914795, 0.5990617275238037, 0.5799350142478943, 0.5994225740432739, 0.5745218396186829, 0.5940093994140625, 0.5795741677284241, 0.5914832353591919, 0.5828220844268799, 0.5925658345222473, 0.5824611783027649, 0.5853482484817505, 0.5831829309463501, 0.5839047431945801, 0.5896788239479065, 0.5860700011253357, 0.5918440818786621, 0.5795741677284241, 0.5976181626319885, 0.5799350142478943, 0.5961746573448181, 0.5846264958381653, 0.5900396704673767, 0.5821003317832947, 0.5756044387817383, 0.5799350142478943, 0.5687477588653564, 0.5777697563171387, 0.5597257018089294, 0.5763262510299683, 0.5673041939735413, 0.5698303580284119, 0.5705521106719971, 0.5640562772750854, 0.5734391808509827, 0.5629736185073853, 
0.5723565220832825]
#[0.0, 0.03846153989434242, 0.03773584961891174, 0.0, 0.040816325694322586, 0.12765957415103912, 0.06521739065647125, 0.0, 0.06666667014360428, 0.0731707289814949, 0.1190476194024086, 0.24390244483947754, 0.05263157933950424, 0.01886792480945587, 0.2222222238779068, 0.08163265138864517, 0.1041666641831398, 0.15555556118488312, 0.05999999865889549, 0.1320754736661911, 0.13333334028720856, 0.07999999821186066, 0.2631579041481018, 0.06818182021379471, 0.3095238208770752, 0.1627907007932663, 0.15555556118488312, 0.19148936867713928, 0.13636364042758942, 0.2195121943950653, 0.11538461595773697, 0.28260868787765503, 0.19607843458652496, 0.12765957415103912, 0.20000000298023224, 0.1875, 0.10000000149011612, 0.2291666716337204, 0.26829269528388977, 0.3076923191547394, 0.20930232107639313, 0.261904776096344, 0.32499998807907104, 0.3095238208770752, 0.20408163964748383, 0.17391304671764374, 0.22448979318141937, 0.2888889014720917, 0.5135135054588318, 0.21568627655506134, 0.4000000059604645, 0.25, 0.22641509771347046, 0.34090909361839294, 0.125, 0.3478260934352875, 0.23404255509376526, 0.5428571701049805, 0.3863636255264282, 0.2291666716337204, 0.27659574151039124, 0.20370370149612427, 0.5263158082962036, 0.44736841320991516, 0.40909090638160706, 0.5609756112098694, 0.30000001192092896, 0.3095238208770752, 0.47727271914482117, 0.6285714507102966, 0.3255814015865326, 0.4000000059604645, 0.3400000035762787, 0.3863636255264282, 0.4651162922382355, 0.27450981736183167, 0.3333333432674408, 0.3265306055545807, 0.3617021143436432, 0.5609756112098694, 0.34090909361839294, 0.20000000298023224, 0.3191489279270172, 0.3207547068595886, 0.5714285969734192, 0.5609756112098694, 0.4883720874786377, 0.5714285969734192, 0.4166666567325592, 0.5714285969734192, 0.5, 0.6097561120986938, 0.3913043439388275, 0.44680851697921753, 0.29411765933036804, 0.523809552192688, 0.37254902720451355, 0.3199999928474426, 0.44186046719551086, 0.20370370149612427]
percent_val = [0.0, 0.011033682152628899, 0.05090909078717232, 0.0005733944708481431, 0.002875215606763959, 0.022274326533079147, 0.029045643284916878, 0.004032257944345474, 0.025989368557929993, 0.008083140477538109, 0.007518797181546688, 0.0279928520321846, 0.009895226918160915, 0.009883721359074116, 0.02600472792983055, 0.00694846548140049, 0.02840236760675907, 0.03511904925107956, 0.006952491123229265, 0.03419316187500954, 0.022418878972530365, 0.02489626593887806, 0.062345679849386215, 0.015258216299116611, 0.042398545891046524, 0.0243612602353096, 0.02193242497742176, 0.054409004747867584, 0.02634730562567711, 0.05967336520552635, 0.04024390131235123, 0.04297114908695221, 0.0824742242693901, 0.03699211776256561, 0.08704061806201935, 0.04156479239463806, 0.05482865869998932, 0.07864449918270111, 0.04671173915266991, 0.12064342945814133, 0.05493827164173126, 0.10794702172279358, 0.07400379329919815, 0.07414449006319046, 0.09616634249687195, 0.07030759751796722, 0.11954331398010254, 0.07668133080005646, 0.12990528345108032, 0.08126983791589737, 0.12399999797344208, 0.10174983739852905, 0.11678832024335861, 0.1116248369216919, 0.10543549805879593, 0.1284465342760086, 0.10299869626760483, 0.13378839194774628, 0.1001964658498764, 0.14989444613456726, 0.10303831100463867, 0.15988579392433167, 0.10237780958414078, 0.1693314015865326, 0.10814419388771057, 0.17556694149971008, 0.11551376432180405, 0.1697281450033188, 0.1174473837018013, 0.17140772938728333, 0.12030075490474701, 0.17369185388088226, 0.1410701870918274, 0.16282421350479126, 0.14953933656215668, 0.15925393998622894, 0.1682310402393341, 0.15477031469345093, 0.1923641711473465, 0.1508771926164627, 0.20661157369613647, 0.147609144449234, 0.21283018589019775, 0.1520223170518875, 0.21348313987255096, 0.16952790319919586, 0.197475865483284, 0.20105421543121338, 0.17213712632656097, 0.22075910866260529, 0.15930485725402832, 0.23487260937690735, 0.16569343209266663, 0.21016165614128113, 0.19259819388389587, 
0.18782870471477509, 0.2287735790014267, 0.17442719638347626, 0.2293144166469574, 0.18535126745700836]

# Mean of each validation curve from index 75 onward (accuracy first, then percent).
# Only the most recent accuracies_val / percent_val assignment above is used here;
# the earlier assignments in this cell are overwritten before this line runs.
print(*(np.mean(curve[75:]) for curve in (accuracies_val, percent_val)))

In [None]:
# Train and validation accuracy (each preceded by the always-predict-0 baseline),
# then a plot of predicted vs. actual labels on a slice of the validation set.

def _evaluate(model, dataset, threshold=0.5):
    """Run ``model`` over every sample of ``dataset`` and threshold the output.

    Parameters
    ----------
    model : presumably a torch.nn.Module mapping the input tensor to one score
        per sample (the score is flattened with ``view(-1)``).
    dataset : indexable so that ``dataset[:]`` returns ``(inputs, targets, ...)``,
        e.g. a ``torch.utils.data.TensorDataset``.
    threshold : score above which a sample is predicted as 1.

    Returns
    -------
    (pred, target) : numpy arrays; ``pred`` is boolean, ``target`` is the raw
        target tensor converted with ``.numpy()``.
    """
    batch = dataset[:]  # index once instead of twice
    inputs, targets = batch[0], batch[1]
    # Evaluate in eval mode without autograd so dropout/batch-norm layers (if any)
    # behave deterministically; restore the caller's mode so later training cells
    # are unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            scores = model(inputs).view(-1)
    finally:
        model.train(was_training)
    return scores.cpu().numpy() > threshold, targets.numpy()


def _binary_accuracy(pred, target):
    """Fraction of samples where a 0/1 prediction matches a 0/1 target.

    Computes (true positives + true negatives) / n, the same formula as
    sum(pred*target) + sum((1-pred)*(1-target)) over len(pred).
    Inputs are cast to float so the arithmetic is well-defined for bool,
    int or float arrays. Returns nan for empty input.
    """
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    if len(pred) == 0:
        return float('nan')
    true_pos = np.sum(pred * target)
    true_neg = np.sum((1 - pred) * (1 - target))
    return (true_pos + true_neg) / len(pred)


# --- training set ---
train_pred, train_target = _evaluate(model, trainset)
print(1 - np.mean(train_target))  # baseline: accuracy of always predicting 0
accuracy = _binary_accuracy(train_pred, train_target)
print(f'accuracy={accuracy}')

# --- validation set ---
val_pred, val_target = _evaluate(model, valset)
print(1 - np.mean(val_target))  # baseline: accuracy of always predicting 0
accuracy = _binary_accuracy(val_pred, val_target)
print(f'accuracy={accuracy}')

PLOT_WINDOW = (200, 400)  # validation sample indices to zoom into
fig, ax = plt.subplots()
# Predictions are mapped {0, 1} -> {-0.1, 1.1} so they sit just outside the
# targets at {0, 1} instead of hiding them.
ax.plot(1.2 * val_pred - 0.1, 'b.', label='predicted')
ax.plot(val_target, 'r.', label='original')
ax.set(title='Validation set: predicted vs. actual',
       xlabel='validation sample index', ylabel='label (0/1)',
       xlim=PLOT_WINDOW)
ax.legend()
plt.show()