In [None]:
import os, sys
os.chdir('/home/seigyo/Documents/pytorch/brain_decoder')
sys.path.append(os.pardir)
import numpy as np
from numpy.random import RandomState
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
import mne
from mne.io import concatenate_raws
from mymodule.utils import data_loader, evaluator
from mymodule.layers import LSTM, Residual_block, Res_net, Wavelet_cnn, NlayersSeqConvLSTM
from mymodule.trainer import Trainer
from mymodule.optim import Eve, YFOptimizer
from sklearn.utils import shuffle
from tensorboardX import SummaryWriter
from load_data import get_data, get_data_multi, get_crops, get_crops_multi
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

# --- Experiment configuration ---
epochs = 100           # epoch count handed to each fold's training run
batch_size = 10        # mini-batch size (also read by Conv_lstm.__init__)
cv_splits = 5          # number of KFold splits per subject
num_of_subjects = 99   # subjects 1..num_of_subjects are evaluated

# --- Reproducibility ---
# Seed NumPy as well as torch: KFold(shuffle=True) without an explicit
# random_state draws from NumPy's global RNG, so seeding torch alone leaves
# the cross-validation splits different on every run.
SEED = 1214
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)



class Conv_lstm(nn.Module):
  """Temporal conv -> spatial conv -> average pooling -> LSTM -> 2-way linear classifier.

  The input is a 4-D tensor fed to ``nn.Conv2d(1, 40, (25, 1))``; the
  spatial convolution uses a ``(1, 64)`` kernel, so the last input dimension
  is presumably 64 electrodes and collapses to size 1 -- TODO confirm against
  the data loader.

  Note: the recurrent layer is built with the module-level ``batch_size``
  at construction time.
  """

  def __init__(self):
    super(Conv_lstm, self).__init__()
    # Convolution along the first spatial axis only (kernel 25 x 1).
    self.conv_time = nn.Conv2d(1, 40, (25, 1))
    # Constructed but currently not applied in forward(); kept so the
    # module's parameter set / state_dict stays unchanged.
    self.batchnorm1 = nn.BatchNorm2d(40)
    # Convolution across the last axis (kernel 1 x 64), no bias.
    self.conv_spat = nn.Conv2d(40, 40, (1, 64), bias=False)
    self.batchnorm2 = nn.BatchNorm2d(40)
    self.pool = nn.AvgPool2d(kernel_size=(75, 1), stride=(15, 1))
    self.dropout = nn.Dropout2d(p=0.5)
    # Project-local LSTM wrapper; presumably yields 2 * 10 = 20 features
    # (bidirectional, last step only) to match the classifier -- TODO confirm.
    self.lstm = LSTM(40, 10, batch_size, bidirectional=True, gpu=True, return_seq=False)
    self.dropout_linear = nn.Dropout(p=0.5)
    self.classifier = nn.Linear(20, 2)

  def forward(self, x):
    """Return the classifier output (two logits per sample) for input ``x``."""
    h = self.conv_time(x)
    # batchnorm1 is intentionally skipped here (it was disabled in the original).
    h = self.conv_spat(h)
    h = self.batchnorm2(h)
    h = self.pool(h)
    h = self.dropout(h)
    # Drop ONLY the trailing (collapsed) axis. A bare squeeze() would also
    # remove the batch axis for a single-sample batch (or the pooled axis
    # when it has length 1), after which transpose(1, 2) fails.
    h = h.squeeze(3).transpose(1, 2)
    h = self.lstm(h)
    h = self.dropout_linear(h)
    h = self.classifier(h)
    return h

# criterion = torch.nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def cv_train(model_class, criterion_class, optimizer_class, X, y,
             epoch=100, num_of_cv=10, batch_size=16, lr=1e-4,
             random_state=None):
    """Train a fresh model on every fold of a shuffled K-fold split.

    Args:
        model_class: Callable taking no arguments that returns an
            ``nn.Module``; the instance is moved to the GPU with ``.cuda()``.
        criterion_class: Callable taking no arguments that returns the loss.
        optimizer_class: Callable invoked as
            ``optimizer_class(model.parameters(), lr=lr)``.
        X, y: Samples and labels; both are indexed with the integer index
            arrays produced by ``KFold.split``.
        epoch: Epoch count passed to ``Trainer.run``.
        num_of_cv: Number of folds.
        batch_size: Mini-batch size for both the training and validation loader.
        lr: Learning rate handed to the optimizer (previously hard-coded).
        random_state: Seed / RandomState forwarded to ``KFold`` so the split
            can be reproduced. ``None`` keeps the previous behaviour (global
            NumPy RNG).

    Returns:
        list: ``trainer.val_best_acc`` of every fold, in fold order.

    NOTE(review): ``val_best_acc`` looks like the best validation accuracy
    seen over all epochs; if so, using it as the fold score is optimistically
    biased -- confirm against ``Trainer``.
    """
    kf = KFold(n_splits=num_of_cv, shuffle=True, random_state=random_state)
    accuracy = []
    for train_idx, val_idx in kf.split(X=X, y=y):
        train_x, val_x = X[train_idx], X[val_idx]
        train_y, val_y = y[train_idx], y[val_idx]
        train_loader = data_loader(train_x, train_y, batch_size=batch_size,
                                   shuffle=True, gpu=False)
        val_loader = data_loader(val_x, val_y, batch_size=batch_size)
        writer = SummaryWriter()
        try:
            # A brand-new model/optimizer per fold so no weights leak between folds.
            model = model_class().cuda()
            criterion = criterion_class()
            optimizer = optimizer_class(model.parameters(), lr=lr)
            trainer = Trainer(model, criterion, optimizer,
                              train_loader, val_loader,
                              val_num=1, early_stopping=2,
                              writer=writer, gpu=True)
            trainer.run(epochs=epoch)
            accuracy.append(trainer.val_best_acc)
        finally:
            # One writer is opened per fold; close it so event files are
            # flushed and writers do not pile up across folds and subjects.
            writer.close()
    return accuracy


# Per-subject results, appended in subject order.
all_accs_list, all_mean_list, all_var_list = [], [], []

for subject_id in range(1, num_of_subjects + 1):
    X, y = get_data(id=subject_id, event_code=[6,10,14], filter=[0.5, 30], t=[0., 4])
    # Insert a singleton axis at position 1, then swap the last two axes.
    n_samples, n_rows, n_cols = X.shape[0], X.shape[1], X.shape[2]
    X = X.reshape(n_samples, 1, n_rows, n_cols).transpose(0,1,3,2)

    # One accuracy per fold for this subject.
    fold_accs = cv_train(
        Conv_lstm,
        torch.nn.CrossEntropyLoss,
        torch.optim.Adam,
        X,
        y,
        epoch=epochs,
        num_of_cv=cv_splits,
        batch_size=batch_size,
    )

    subject_mean, subject_var = np.mean(fold_accs), np.var(fold_accs)
    print('subject{}   mean_acc:{}, var_acc:{}'.format(subject_id, subject_mean, subject_var))

    all_accs_list.append(fold_accs)
    all_mean_list.append(subject_mean)
    all_var_list.append(subject_var)

# Statistics over every fold of every subject pooled together.
all_mean, all_var = np.mean(all_accs_list), np.var(all_accs_list)

print('all subjects  mean_acc:{}, var_acc:{}'.format(all_mean, all_var))

Removing orphaned offset at the beginning of the file.
89 events found
Events id: [1 2 3]
45 matching events found
Not setting metadata
Loading data for 45 events and 641 original time points ...
0 bad epochs dropped
----------start training----------
epoch:1, tr_loss:0.5671, tr_acc:0.4722,   val_loss:0.3276, val_acc:0.6667
epoch:2, tr_loss:0.5657, tr_acc:0.4444,   val_loss:0.3300, val_acc:0.6667
epoch:3, tr_loss:0.5538, tr_acc:0.5000,   val_loss:0.3298, val_acc:0.6667
epoch:4, tr_loss:0.5447, tr_acc:0.4167,   val_loss:0.3303, val_acc:0.7778
epoch:5, tr_loss:0.5506, tr_acc:0.5278,   val_loss:0.3304, val_acc:0.6667
epoch:6, tr_loss:0.5308, tr_acc:0.5833,   val_loss:0.3318, val_acc:0.6667
epoch:7, tr_loss:0.5428, tr_acc:0.5000,   val_loss:0.3337, val_acc:0.6667
epoch:8, tr_loss:0.5163, tr_acc:0.6667,   val_loss:0.3357, val_acc:0.6667
epoch:9, tr_loss:0.5292, tr_acc:0.6944,   val_loss:0.3361, val_acc:0.6667
epoch:10, tr_loss:0.5313, tr_acc:0.5556,   val_loss:0.3362, val_acc:0.6667
epoch:1

epoch:8, tr_loss:0.5277, tr_acc:0.5833,   val_loss:0.3322, val_acc:0.5556
epoch:9, tr_loss:0.5469, tr_acc:0.6667,   val_loss:0.3331, val_acc:0.5556
epoch:10, tr_loss:0.5114, tr_acc:0.6944,   val_loss:0.3340, val_acc:0.5556
epoch:11, tr_loss:0.5176, tr_acc:0.6111,   val_loss:0.3341, val_acc:0.5556
epoch:12, tr_loss:0.5119, tr_acc:0.6944,   val_loss:0.3350, val_acc:0.5556
epoch:13, tr_loss:0.4965, tr_acc:0.6389,   val_loss:0.3362, val_acc:0.5556
epoch:14, tr_loss:0.4890, tr_acc:0.6667,   val_loss:0.3372, val_acc:0.5556
epoch:15, tr_loss:0.5105, tr_acc:0.6389,   val_loss:0.3381, val_acc:0.6667
epoch:16, tr_loss:0.5057, tr_acc:0.6667,   val_loss:0.3396, val_acc:0.6667
epoch:17, tr_loss:0.4659, tr_acc:0.7500,   val_loss:0.3417, val_acc:0.6667
epoch:18, tr_loss:0.5034, tr_acc:0.7222,   val_loss:0.3435, val_acc:0.6667
epoch:19, tr_loss:0.4633, tr_acc:0.7500,   val_loss:0.3454, val_acc:0.6667
epoch:20, tr_loss:0.4565, tr_acc:0.7778,   val_loss:0.3480, val_acc:0.5556
epoch:21, tr_loss:0.4694, t

epoch:16, tr_loss:0.4924, tr_acc:0.7500,   val_loss:0.3582, val_acc:0.5556
epoch:17, tr_loss:0.4630, tr_acc:0.6944,   val_loss:0.3657, val_acc:0.5556
epoch:18, tr_loss:0.4587, tr_acc:0.7222,   val_loss:0.3784, val_acc:0.5556
epoch:19, tr_loss:0.4445, tr_acc:0.7778,   val_loss:0.3676, val_acc:0.5556
epoch:20, tr_loss:0.4890, tr_acc:0.6389,   val_loss:0.3577, val_acc:0.5556
epoch:21, tr_loss:0.4535, tr_acc:0.8056,   val_loss:0.3659, val_acc:0.6667
epoch:22, tr_loss:0.4673, tr_acc:0.7500,   val_loss:0.3827, val_acc:0.5556
epoch:23, tr_loss:0.4740, tr_acc:0.6667,   val_loss:0.3756, val_acc:0.6667
epoch:24, tr_loss:0.4397, tr_acc:0.7222,   val_loss:0.3926, val_acc:0.5556
epoch:25, tr_loss:0.4077, tr_acc:0.6944,   val_loss:0.4095, val_acc:0.4444
epoch:26, tr_loss:0.4361, tr_acc:0.7778,   val_loss:0.4153, val_acc:0.4444
epoch:27, tr_loss:0.4228, tr_acc:0.8056,   val_loss:0.4307, val_acc:0.4444
epoch:28, tr_loss:0.3724, tr_acc:0.8889,   val_loss:0.4176, val_acc:0.2222
epoch:29, tr_loss:0.4181,

epoch:25, tr_loss:0.4659, tr_acc:0.7222,   val_loss:0.3193, val_acc:0.6667
epoch:26, tr_loss:0.4229, tr_acc:0.8889,   val_loss:0.3212, val_acc:0.6667
epoch:27, tr_loss:0.4129, tr_acc:0.8611,   val_loss:0.3271, val_acc:0.6667
epoch:28, tr_loss:0.4181, tr_acc:0.8056,   val_loss:0.3294, val_acc:0.5556
epoch:29, tr_loss:0.3847, tr_acc:0.9167,   val_loss:0.3199, val_acc:0.6667
epoch:30, tr_loss:0.4298, tr_acc:0.8333,   val_loss:0.3292, val_acc:0.6667
epoch:31, tr_loss:0.4067, tr_acc:0.8333,   val_loss:0.3404, val_acc:0.5556
epoch:32, tr_loss:0.3572, tr_acc:0.8333,   val_loss:0.3364, val_acc:0.5556
epoch:33, tr_loss:0.3551, tr_acc:0.9167,   val_loss:0.3242, val_acc:0.6667
epoch:34, tr_loss:0.3875, tr_acc:0.8611,   val_loss:0.3173, val_acc:0.6667
epoch:35, tr_loss:0.3798, tr_acc:0.8333,   val_loss:0.3277, val_acc:0.6667
epoch:36, tr_loss:0.3387, tr_acc:0.8333,   val_loss:0.3447, val_acc:0.5556
epoch:37, tr_loss:0.3447, tr_acc:0.9167,   val_loss:0.3448, val_acc:0.6667
epoch:38, tr_loss:0.3452,

epoch:34, tr_loss:0.4104, tr_acc:0.8056,   val_loss:0.2740, val_acc:0.7778
epoch:35, tr_loss:0.3807, tr_acc:0.8611,   val_loss:0.2684, val_acc:0.7778
epoch:36, tr_loss:0.3690, tr_acc:0.8611,   val_loss:0.2656, val_acc:0.7778
epoch:37, tr_loss:0.3306, tr_acc:0.8611,   val_loss:0.2649, val_acc:0.7778
epoch:38, tr_loss:0.3569, tr_acc:0.8611,   val_loss:0.2784, val_acc:0.7778
epoch:39, tr_loss:0.3253, tr_acc:0.8611,   val_loss:0.2915, val_acc:0.7778
epoch:40, tr_loss:0.3184, tr_acc:0.8611,   val_loss:0.2859, val_acc:0.7778
epoch:41, tr_loss:0.3152, tr_acc:0.9722,   val_loss:0.2882, val_acc:0.7778
epoch:42, tr_loss:0.3074, tr_acc:0.9167,   val_loss:0.2831, val_acc:0.7778
epoch:43, tr_loss:0.2697, tr_acc:0.9444,   val_loss:0.2601, val_acc:0.7778
epoch:44, tr_loss:0.2661, tr_acc:0.9444,   val_loss:0.2579, val_acc:0.7778
epoch:45, tr_loss:0.2359, tr_acc:0.9444,   val_loss:0.2770, val_acc:0.7778
epoch:46, tr_loss:0.2824, tr_acc:0.8889,   val_loss:0.2726, val_acc:0.7778
epoch:47, tr_loss:0.2489,

epoch:38, tr_loss:0.1990, tr_acc:1.0000,   val_loss:0.6193, val_acc:0.2222
epoch:39, tr_loss:0.2322, tr_acc:0.9722,   val_loss:0.6199, val_acc:0.2222
epoch:40, tr_loss:0.1958, tr_acc:1.0000,   val_loss:0.6286, val_acc:0.2222
epoch:41, tr_loss:0.1869, tr_acc:0.9722,   val_loss:0.6243, val_acc:0.2222
epoch:42, tr_loss:0.1981, tr_acc:0.9722,   val_loss:0.6328, val_acc:0.3333
epoch:43, tr_loss:0.1710, tr_acc:1.0000,   val_loss:0.6525, val_acc:0.3333
epoch:44, tr_loss:0.1605, tr_acc:1.0000,   val_loss:0.6652, val_acc:0.3333
epoch:45, tr_loss:0.1227, tr_acc:1.0000,   val_loss:0.6835, val_acc:0.3333
epoch:46, tr_loss:0.1373, tr_acc:0.9722,   val_loss:0.6927, val_acc:0.3333
epoch:47, tr_loss:0.1554, tr_acc:1.0000,   val_loss:0.7461, val_acc:0.2222
epoch:48, tr_loss:0.1602, tr_acc:1.0000,   val_loss:0.7744, val_acc:0.2222
epoch:49, tr_loss:0.1160, tr_acc:1.0000,   val_loss:0.7896, val_acc:0.2222
epoch:50, tr_loss:0.1551, tr_acc:0.9722,   val_loss:0.7769, val_acc:0.2222
epoch:51, tr_loss:0.1397,

epoch:46, tr_loss:0.1846, tr_acc:1.0000,   val_loss:0.6115, val_acc:0.3333
epoch:47, tr_loss:0.1810, tr_acc:1.0000,   val_loss:0.6223, val_acc:0.3333
epoch:48, tr_loss:0.1637, tr_acc:1.0000,   val_loss:0.6167, val_acc:0.4444
epoch:49, tr_loss:0.1536, tr_acc:1.0000,   val_loss:0.6239, val_acc:0.4444
epoch:50, tr_loss:0.2033, tr_acc:0.9167,   val_loss:0.6498, val_acc:0.3333
epoch:51, tr_loss:0.1760, tr_acc:1.0000,   val_loss:0.6705, val_acc:0.3333
epoch:52, tr_loss:0.1363, tr_acc:0.9722,   val_loss:0.6883, val_acc:0.3333
epoch:53, tr_loss:0.1001, tr_acc:1.0000,   val_loss:0.6976, val_acc:0.3333
epoch:54, tr_loss:0.1076, tr_acc:1.0000,   val_loss:0.7013, val_acc:0.3333
epoch:55, tr_loss:0.1359, tr_acc:0.9722,   val_loss:0.7217, val_acc:0.3333
epoch:56, tr_loss:0.0767, tr_acc:1.0000,   val_loss:0.7583, val_acc:0.3333
epoch:57, tr_loss:0.0953, tr_acc:1.0000,   val_loss:0.7633, val_acc:0.3333
epoch:58, tr_loss:0.0923, tr_acc:1.0000,   val_loss:0.7638, val_acc:0.3333
epoch:59, tr_loss:0.0692,

epoch:54, tr_loss:0.2029, tr_acc:0.9722,   val_loss:0.6800, val_acc:0.4444
epoch:55, tr_loss:0.2046, tr_acc:0.8889,   val_loss:0.6910, val_acc:0.4444
epoch:56, tr_loss:0.1640, tr_acc:1.0000,   val_loss:0.6313, val_acc:0.5556
epoch:57, tr_loss:0.1609, tr_acc:0.9444,   val_loss:0.6695, val_acc:0.5556
epoch:58, tr_loss:0.1414, tr_acc:0.9444,   val_loss:0.7202, val_acc:0.5556
epoch:59, tr_loss:0.1658, tr_acc:0.9722,   val_loss:0.6831, val_acc:0.5556
epoch:60, tr_loss:0.1511, tr_acc:0.9722,   val_loss:0.7074, val_acc:0.5556
epoch:61, tr_loss:0.1077, tr_acc:0.9722,   val_loss:0.7312, val_acc:0.5556
epoch:62, tr_loss:0.0968, tr_acc:1.0000,   val_loss:0.7235, val_acc:0.5556
epoch:63, tr_loss:0.0883, tr_acc:0.9722,   val_loss:0.7141, val_acc:0.5556
epoch:64, tr_loss:0.1363, tr_acc:0.9722,   val_loss:0.7611, val_acc:0.5556
epoch:65, tr_loss:0.1004, tr_acc:1.0000,   val_loss:0.7789, val_acc:0.5556
epoch:66, tr_loss:0.1113, tr_acc:0.9444,   val_loss:0.7887, val_acc:0.5556
epoch:67, tr_loss:0.1387,

epoch:64, tr_loss:0.1900, tr_acc:0.9722,   val_loss:0.2525, val_acc:0.6667
epoch:65, tr_loss:0.1516, tr_acc:0.9722,   val_loss:0.2641, val_acc:0.5556
epoch:66, tr_loss:0.1233, tr_acc:1.0000,   val_loss:0.2655, val_acc:0.6667
epoch:67, tr_loss:0.1358, tr_acc:0.9722,   val_loss:0.3003, val_acc:0.5556
epoch:68, tr_loss:0.1253, tr_acc:0.9444,   val_loss:0.3140, val_acc:0.5556
epoch:69, tr_loss:0.1218, tr_acc:1.0000,   val_loss:0.2620, val_acc:0.6667
epoch:70, tr_loss:0.1209, tr_acc:1.0000,   val_loss:0.2330, val_acc:0.6667
epoch:71, tr_loss:0.1039, tr_acc:1.0000,   val_loss:0.2463, val_acc:0.6667
epoch:72, tr_loss:0.0616, tr_acc:1.0000,   val_loss:0.2638, val_acc:0.6667
epoch:73, tr_loss:0.0846, tr_acc:0.9722,   val_loss:0.2654, val_acc:0.5556
epoch:74, tr_loss:0.0951, tr_acc:0.9722,   val_loss:0.2520, val_acc:0.6667
epoch:75, tr_loss:0.0941, tr_acc:1.0000,   val_loss:0.2374, val_acc:0.6667
epoch:76, tr_loss:0.1041, tr_acc:0.9722,   val_loss:0.2281, val_acc:0.6667
epoch:77, tr_loss:0.1109,

epoch:74, tr_loss:0.0473, tr_acc:1.0000,   val_loss:0.4999, val_acc:0.6667
epoch:75, tr_loss:0.0572, tr_acc:1.0000,   val_loss:0.5165, val_acc:0.6667
epoch:76, tr_loss:0.0554, tr_acc:1.0000,   val_loss:0.5240, val_acc:0.6667
epoch:77, tr_loss:0.0550, tr_acc:1.0000,   val_loss:0.5075, val_acc:0.6667
epoch:78, tr_loss:0.0795, tr_acc:0.9722,   val_loss:0.5251, val_acc:0.6667
epoch:79, tr_loss:0.0472, tr_acc:1.0000,   val_loss:0.5562, val_acc:0.6667
epoch:80, tr_loss:0.0626, tr_acc:1.0000,   val_loss:0.5752, val_acc:0.6667
epoch:81, tr_loss:0.0620, tr_acc:1.0000,   val_loss:0.5984, val_acc:0.5556
epoch:82, tr_loss:0.1012, tr_acc:0.9444,   val_loss:0.6372, val_acc:0.4444
epoch:83, tr_loss:0.0785, tr_acc:0.9722,   val_loss:0.5305, val_acc:0.6667
epoch:84, tr_loss:0.0844, tr_acc:0.9722,   val_loss:0.5026, val_acc:0.6667
epoch:85, tr_loss:0.0800, tr_acc:0.9722,   val_loss:0.5494, val_acc:0.5556
epoch:86, tr_loss:0.0601, tr_acc:1.0000,   val_loss:0.5649, val_acc:0.5556
epoch:87, tr_loss:0.0674,

epoch:78, tr_loss:0.3593, tr_acc:0.8333,   val_loss:0.5075, val_acc:0.4444
epoch:79, tr_loss:0.3234, tr_acc:0.8889,   val_loss:0.5130, val_acc:0.4444
epoch:80, tr_loss:0.3136, tr_acc:0.8889,   val_loss:0.4971, val_acc:0.6667
epoch:81, tr_loss:0.3505, tr_acc:0.8056,   val_loss:0.4713, val_acc:0.6667
epoch:82, tr_loss:0.3566, tr_acc:0.8333,   val_loss:0.4802, val_acc:0.5556
epoch:83, tr_loss:0.3548, tr_acc:0.8056,   val_loss:0.4780, val_acc:0.5556
epoch:84, tr_loss:0.2725, tr_acc:0.8889,   val_loss:0.5205, val_acc:0.5556
epoch:85, tr_loss:0.2960, tr_acc:0.8611,   val_loss:0.4389, val_acc:0.6667
epoch:86, tr_loss:0.2754, tr_acc:0.8333,   val_loss:0.4896, val_acc:0.5556
epoch:87, tr_loss:0.3003, tr_acc:0.8056,   val_loss:0.4731, val_acc:0.5556
epoch:88, tr_loss:0.3613, tr_acc:0.8611,   val_loss:0.5392, val_acc:0.5556
epoch:89, tr_loss:0.2525, tr_acc:0.9167,   val_loss:0.5654, val_acc:0.4444
epoch:90, tr_loss:0.2993, tr_acc:0.8611,   val_loss:0.5549, val_acc:0.5556
epoch:91, tr_loss:0.3059,

epoch:86, tr_loss:0.2569, tr_acc:0.9167,   val_loss:0.6426, val_acc:0.2222
epoch:87, tr_loss:0.2984, tr_acc:0.8333,   val_loss:0.6468, val_acc:0.2222
epoch:88, tr_loss:0.2405, tr_acc:0.9167,   val_loss:0.6300, val_acc:0.2222
epoch:89, tr_loss:0.2393, tr_acc:0.9722,   val_loss:0.6043, val_acc:0.3333
epoch:90, tr_loss:0.2512, tr_acc:0.8611,   val_loss:0.6044, val_acc:0.3333
epoch:91, tr_loss:0.2887, tr_acc:0.8889,   val_loss:0.6277, val_acc:0.3333
epoch:92, tr_loss:0.2813, tr_acc:0.8333,   val_loss:0.6460, val_acc:0.3333
epoch:93, tr_loss:0.2518, tr_acc:0.8333,   val_loss:0.6293, val_acc:0.4444
epoch:94, tr_loss:0.3009, tr_acc:0.8333,   val_loss:0.6325, val_acc:0.4444
epoch:95, tr_loss:0.2652, tr_acc:0.8889,   val_loss:0.6452, val_acc:0.3333
epoch:96, tr_loss:0.2815, tr_acc:0.8056,   val_loss:0.6234, val_acc:0.3333
epoch:97, tr_loss:0.2349, tr_acc:0.9444,   val_loss:0.5903, val_acc:0.4444
epoch:98, tr_loss:0.3037, tr_acc:0.8611,   val_loss:0.5908, val_acc:0.4444
epoch:99, tr_loss:0.1727,

epoch:94, tr_loss:0.3807, tr_acc:0.8611,   val_loss:0.4206, val_acc:0.4444
epoch:95, tr_loss:0.3058, tr_acc:0.8611,   val_loss:0.4067, val_acc:0.5556
epoch:96, tr_loss:0.3057, tr_acc:0.8889,   val_loss:0.3764, val_acc:0.6667
epoch:97, tr_loss:0.3621, tr_acc:0.8611,   val_loss:0.4410, val_acc:0.6667
epoch:98, tr_loss:0.3455, tr_acc:0.8611,   val_loss:0.4037, val_acc:0.6667
epoch:99, tr_loss:0.3439, tr_acc:0.7778,   val_loss:0.4050, val_acc:0.6667
epoch:100, tr_loss:0.3195, tr_acc:0.8889,   val_loss:0.4351, val_acc:0.4444
----------finish training---------
training_best_acc:0.8888888888888888, val_best_acc:0.6666666666666666
----------start training----------
epoch:1, tr_loss:0.5927, tr_acc:0.4444,   val_loss:0.3614, val_acc:0.5556
epoch:2, tr_loss:0.5760, tr_acc:0.4722,   val_loss:0.3571, val_acc:0.5556
epoch:3, tr_loss:0.5699, tr_acc:0.5556,   val_loss:0.3542, val_acc:0.5556
epoch:4, tr_loss:0.5728, tr_acc:0.4722,   val_loss:0.3529, val_acc:0.5556
epoch:5, tr_loss:0.5597, tr_acc:0.5278

epoch:1, tr_loss:0.5487, tr_acc:0.6111,   val_loss:0.3514, val_acc:0.4444
epoch:2, tr_loss:0.5627, tr_acc:0.5000,   val_loss:0.3546, val_acc:0.4444
epoch:3, tr_loss:0.5635, tr_acc:0.5278,   val_loss:0.3561, val_acc:0.4444
epoch:4, tr_loss:0.5518, tr_acc:0.5278,   val_loss:0.3557, val_acc:0.4444
epoch:5, tr_loss:0.5628, tr_acc:0.5000,   val_loss:0.3548, val_acc:0.4444
epoch:6, tr_loss:0.5472, tr_acc:0.5278,   val_loss:0.3544, val_acc:0.4444
epoch:7, tr_loss:0.5460, tr_acc:0.5833,   val_loss:0.3559, val_acc:0.4444
epoch:8, tr_loss:0.5425, tr_acc:0.6111,   val_loss:0.3575, val_acc:0.3333
epoch:9, tr_loss:0.5610, tr_acc:0.5278,   val_loss:0.3594, val_acc:0.3333
epoch:10, tr_loss:0.5775, tr_acc:0.5556,   val_loss:0.3593, val_acc:0.3333
epoch:11, tr_loss:0.5550, tr_acc:0.4444,   val_loss:0.3589, val_acc:0.3333
epoch:12, tr_loss:0.5428, tr_acc:0.6111,   val_loss:0.3580, val_acc:0.3333
epoch:13, tr_loss:0.5691, tr_acc:0.5000,   val_loss:0.3576, val_acc:0.3333
epoch:14, tr_loss:0.5388, tr_acc:0

epoch:5, tr_loss:0.5275, tr_acc:0.6389,   val_loss:0.3705, val_acc:0.4444
epoch:6, tr_loss:0.5140, tr_acc:0.6111,   val_loss:0.3705, val_acc:0.4444
epoch:7, tr_loss:0.4751, tr_acc:0.7500,   val_loss:0.3746, val_acc:0.4444
epoch:8, tr_loss:0.5175, tr_acc:0.6111,   val_loss:0.3718, val_acc:0.4444
epoch:9, tr_loss:0.4972, tr_acc:0.6389,   val_loss:0.3569, val_acc:0.4444
epoch:10, tr_loss:0.4500, tr_acc:0.8333,   val_loss:0.3553, val_acc:0.3333
epoch:11, tr_loss:0.4447, tr_acc:0.8333,   val_loss:0.3604, val_acc:0.3333
epoch:12, tr_loss:0.4605, tr_acc:0.6944,   val_loss:0.3677, val_acc:0.4444
epoch:13, tr_loss:0.4551, tr_acc:0.7500,   val_loss:0.3684, val_acc:0.3333
epoch:14, tr_loss:0.4330, tr_acc:0.8056,   val_loss:0.3805, val_acc:0.4444
epoch:15, tr_loss:0.3805, tr_acc:0.9167,   val_loss:0.3830, val_acc:0.4444
epoch:16, tr_loss:0.4066, tr_acc:0.8333,   val_loss:0.3982, val_acc:0.4444
epoch:17, tr_loss:0.3758, tr_acc:0.8333,   val_loss:0.4161, val_acc:0.4444
epoch:18, tr_loss:0.3670, tr_a

epoch:15, tr_loss:0.4776, tr_acc:0.6944,   val_loss:0.3337, val_acc:0.5556
epoch:16, tr_loss:0.4755, tr_acc:0.8056,   val_loss:0.3300, val_acc:0.5556
epoch:17, tr_loss:0.4437, tr_acc:0.8333,   val_loss:0.3297, val_acc:0.5556
epoch:18, tr_loss:0.4227, tr_acc:0.8889,   val_loss:0.3297, val_acc:0.5556
epoch:19, tr_loss:0.4341, tr_acc:0.8889,   val_loss:0.3296, val_acc:0.6667
epoch:20, tr_loss:0.4527, tr_acc:0.8056,   val_loss:0.3303, val_acc:0.5556
epoch:21, tr_loss:0.4072, tr_acc:0.8333,   val_loss:0.3296, val_acc:0.5556
epoch:22, tr_loss:0.3896, tr_acc:0.8611,   val_loss:0.3326, val_acc:0.5556
epoch:23, tr_loss:0.4081, tr_acc:0.8333,   val_loss:0.3389, val_acc:0.5556
epoch:24, tr_loss:0.3740, tr_acc:0.8611,   val_loss:0.3378, val_acc:0.5556
epoch:25, tr_loss:0.4065, tr_acc:0.8611,   val_loss:0.3401, val_acc:0.5556
epoch:26, tr_loss:0.3899, tr_acc:0.8333,   val_loss:0.3580, val_acc:0.5556
epoch:27, tr_loss:0.3538, tr_acc:0.8889,   val_loss:0.3569, val_acc:0.5556
epoch:28, tr_loss:0.3154,

In [None]:
# Histogram of every per-fold validation accuracy, all subjects pooled.
fold_accs = np.asarray(all_accs_list).reshape(-1)
sns.set_style("ticks")
fig, ax = plt.subplots()
sns.distplot(fold_accs, kde=False, ax=ax)
ax.set_title('histogram of validation accuracy')  # title typos fixed
ax.set_xlabel('accuracy')
ax.set_ylabel('count')
plt.show()


In [None]:
# Histogram of the per-subject mean cross-validation accuracy.
# all_mean_list holds one scalar per subject, so no flattening is needed.
mean_accs = np.asarray(all_mean_list)
sns.set_style("ticks")
fig, ax = plt.subplots()
sns.distplot(mean_accs, kde=False, ax=ax)
ax.set_title('histogram of cross validation accuracy')  # title typos fixed
ax.set_xlabel('accuracy')
ax.set_ylabel('count')
plt.show()

In [None]:
# NOTE(review): `x` is not assigned at top level in any visible cell (only
# the upper-case `X` is), so this cell presumably raises NameError on a fresh
# "Restart & Run All" -- likely leftover scratch; confirm and delete.
x