In [1]:
import numpy as np
import argparse
import os
import imp
import re
import pickle
import datetime
import random
import math

# Single seed value reused for every random number generator seeded in this notebook.
RANDOM_SEED = 12345
np.random.seed(RANDOM_SEED) # seed NumPy's global RNG
random.seed(RANDOM_SEED) # seed Python's built-in `random` module

import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils import data
from torch.autograd import Variable
import torch.nn.functional as F

# Seed torch with the same RANDOM_SEED used for NumPy / `random`, and request
# deterministic cuDNN behaviour.
torch.manual_seed(RANDOM_SEED) # cpu
torch.cuda.manual_seed(RANDOM_SEED) # gpu
torch.backends.cudnn.deterministic=True # cudnn

from utils import utils
from utils.readers import DecompensationReader
from utils.preprocessing import Discretizer, Normalizer
from utils import metrics
from utils import common_utils

In [5]:
# Notebook-wide configuration.
data_path = './data/'  # root folder; joined with 'train', 'test', the listfiles and 'demographic/'
small_part = False  # forwarded to the train/val DeepSupervisionDataLoader
arg_timestep = 1.0  # forwarded to Discretizer(timestep=...)
batch_size = 128  # forwarded to every BatchGenDeepSupervision
epochs = 50  # number of iterations of the outer training loop

In [3]:
# Train and validation loaders read from the same '<data_path>/train' directory;
# they differ only in the listfile that selects the samples.
train_data_loader = common_utils.DeepSupervisionDataLoader(
    dataset_dir=os.path.join(data_path, 'train'),
    listfile=os.path.join(data_path, 'train_listfile.csv'),
    small_part=small_part,
)
val_data_loader = common_utils.DeepSupervisionDataLoader(
    dataset_dir=os.path.join(data_path, 'train'),
    listfile=os.path.join(data_path, 'val_listfile.csv'),
    small_part=small_part,
)

# Discretizer shared by all batch generators built later in the notebook.
discretizer = Discretizer(
    timestep=arg_timestep,
    store_masks=True,
    impute_strategy='previous',
    start_time='zero',
)

In [4]:
# Column names of the discretized representation, obtained by transforming the
# first training sample and splitting the returned header string on commas.
discretizer_header = discretizer.transform(
    train_data_loader._data["X"][0]
)[1].split(',')

# Indices of the columns whose name does not contain "->"; these are the ones
# handed to the normalizer.
cont_channels = [
    col_idx
    for col_idx, col_name in enumerate(discretizer_header)
    if "->" not in col_name
]

normalizer = Normalizer(fields=cont_channels)  # only the selected columns are standardized
# Previously saved normalizer parameters live in '<data dir>/decomp_normalizer'.
normalizer_state = os.path.join(os.path.dirname(data_path), 'decomp_normalizer')
normalizer.load_params(normalizer_state)

In [20]:
n_trained_chunks = 0

# Batch generators over the train / validation loaders. Only the training
# generator shuffles; both are asked to return the sample names with each batch.
train_data_gen = utils.BatchGenDeepSupervision(
    train_data_loader,
    discretizer,
    normalizer,
    batch_size,
    shuffle=True,
    return_names=True,
)
val_data_gen = utils.BatchGenDeepSupervision(
    val_data_loader,
    discretizer,
    normalizer,
    batch_size,
    shuffle=False,
    return_names=True,
)

In [None]:
# Load one demographic / diagnosis record per file in '<data_path>/demographic/'.
#
# Each file is expected to have a header line whose first field is "Icustay",
# followed by one comma-separated data line. Files with a different header are
# skipped; files whose data line is empty get all-zero vectors.
#
# Layout of the 12-element demographic vector built below:
#   [0:5]  one-hot of int(column 1)
#   [5:9]  one-hot of int(column 2)
#   [9:12] numeric columns 3..5 (standardized at the end of this cell)
# NOTE(review): the meaning of columns 1-5 is not visible in this notebook
# (presumably categorical attributes followed by age/height/weight-like values,
# given the defaults 60 / 160 / 60) — confirm against the CSV header.
# NOTE(review): nothing checks that int(column 1) < 5 or int(column 2) < 4.
demographic_data = []  # was misspelled 'TrueTruedemographic_data', which left this name undefined
diagnosis_data = []
idx_list = []  # '<id>_<episode>' keys, aligned with the two lists above
t_cnt = 0  # number of directory entries visited (skipped files included)
m_cnt = []  # int(column 2) for every file that had a non-empty data line

demo_path = data_path + 'demographic/'
# Sorted so that the order of idx_list / demographic_data does not depend on the
# arbitrary order returned by os.listdir.
for cur_name in sorted(os.listdir(demo_path)):
    t_cnt += 1
    cur_id, cur_episode = cur_name.split('_', 1)
    cur_episode = cur_episode[:-4]  # drop the last four characters (file extension)
    cur_file = os.path.join(demo_path, cur_name)

    with open(cur_file, "r") as tsfile:
        header = tsfile.readline().strip().split(',')
        if header[0] != "Icustay":
            continue
        cur_data = tsfile.readline().strip().split(',')

    cur_demo = np.zeros(12)
    if len(cur_data) == 1:
        # Empty data line: keep the all-zero demographic vector.
        cur_diag = np.zeros(128)
    else:
        # Fill in defaults for missing numeric fields.
        for col, default in ((3, 60.0), (4, 160), (5, 60)):
            if cur_data[col] == '':
                cur_data[col] = default

        cur_demo[int(cur_data[1])] = 1
        cur_demo[5 + int(cur_data[2])] = 1
        m_cnt.append(int(cur_data[2]))
        # Convert explicitly instead of relying on NumPy's implicit str -> float cast.
        cur_demo[9:] = [float(value) for value in cur_data[3:6]]
        # `np.int` was removed in NumPy 1.24; the builtin `int` is the same dtype.
        cur_diag = np.array(cur_data[8:], dtype=int)

    demographic_data.append(cur_demo)
    diagnosis_data.append(cur_diag)
    idx_list.append(cur_id + '_' + cur_episode)

# Standardize the three numeric slots (indices 9..11) across all loaded records.
# NOTE(review): all-zero records from empty files take part in the mean/std and
# are standardized too — confirm this is intended.
for each_idx in range(9, 12):
    cur_val = np.array([cur_demo[each_idx] for cur_demo in demographic_data])
    _mean = np.mean(cur_val)
    _std = np.std(cur_val)
    _std = _std if _std > 1e-7 else 1e-7  # avoid dividing by (almost) zero
    for cur_demo in demographic_data:
        cur_demo[each_idx] = (cur_demo[each_idx] - _mean) / _std

In [7]:
# Share of visited demographic files whose recorded column-2 value equals 1.
# NOTE(review): m_cnt only has an entry for files with a non-empty data line,
# while t_cnt counts every directory entry, so the denominator can exceed
# len(m_cnt) — confirm this is the intended ratio.
f_cnt = m_cnt.count(1)
f_cnt/t_cnt

0.44104699583581203

In [7]:
# Use the first CUDA device when torch reports one as available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device('cpu')
print("available device: {}".format(device))

available device: cuda:0


In [8]:
class SELayer(nn.Module):
    """Squeeze-and-excitation style channel gate for 3-D inputs.

    The input is averaged over its last axis, passed through a two-layer
    bottleneck (Linear -> ReLU -> Linear -> sigmoid) and the resulting
    per-channel weights are multiplied back onto the input.

    Args:
        channel: size of the channel axis (second axis of the input).
        reduction: the bottleneck width is ``channel // reduction``.
    """

    def __init__(self, channel, reduction=9):
        super().__init__()
        bottleneck = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.nn_h = nn.Linear(channel, bottleneck)
        self.nn_rescale = nn.Linear(bottleneck, channel)

    def forward(self, x):
        """Gate ``x`` channel-wise.

        Args:
            x: tensor with exactly three axes, unpacked as (batch, channel, time).

        Returns:
            A pair ``(gated, gate)``: ``gated`` has the shape of ``x`` and
            ``gate`` has shape (batch, channel, 1) with values from a sigmoid.
        """
        n_batch, n_channel, _ = x.size()

        # Squeeze: average over the last axis -> (batch, channel).
        pooled = self.avg_pool(x).view(n_batch, n_channel)
        # Excite: bottleneck MLP, then squash to a (batch, channel, 1) gate.
        hidden = self.nn_h(pooled).relu()
        gate = self.nn_rescale(hidden).view(n_batch, n_channel, 1).sigmoid()
        return x * gate.expand_as(x), gate

In [9]:
class stage_LSTM(nn.Module):
    """Recurrent model with level-wise master gates and a windowed convolution head.

    `step` performs an LSTM-like update in which the hidden state is split into
    `levels` chunks. Two cumulative-softmax ("cumax") master gates, one per
    direction, decide per level how the ordinary forget/input gates are mixed.

    `forward` keeps a sliding window of the last `conv_size` hidden states. At
    every time step the window is weighted by a softmax over the cumulative
    per-step "distance" (1 - mean forget master gate) and convolved down to a
    single vector, which is rescaled by a sigmoid channel gate. The raw hidden
    states are then added as a residual and a linear layer followed by a
    sigmoid yields `output_dim` values per time step.

    Args:
        input_dim: number of features per time step.
        hidden_dim: size of the recurrent state; must be divisible by `levels`.
        conv_dim: output channels of the window convolution. Its output is
            multiplied with a `hidden_dim`-wide gate and added to the
            `hidden_dim`-wide hidden state, so it has to be
            broadcast-compatible with `hidden_dim` (in practice equal).
        conv_size: length of the hidden-state window / convolution kernel.
        output_dim: number of outputs per time step.
        levels: number of chunks the hidden state is split into.
        dropconnect: dropout probability applied to the outputs of the input
            and recurrent projections (disabled when falsy).
        dropout: dropout probability before the output layer (disabled when 0).
        dropres: dropout probability on the residual hidden states
            (disabled when 0).
    """

    def __init__(self, input_dim, hidden_dim, conv_dim, conv_size, output_dim, levels, dropconnect=0., dropout=0., dropres=0.3):
        super(stage_LSTM, self).__init__()

        assert hidden_dim % levels == 0
        self.dropout = dropout
        self.dropconnect = dropconnect
        self.dropres = dropres
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.conv_dim = conv_dim
        self.conv_size = conv_size
        self.output_dim = output_dim
        self.levels = levels
        self.chunk_size = hidden_dim // levels

        # Both projections take one extra input (the constant interval column
        # appended in `step`) and emit 4 gate blocks plus 2 * levels master-gate logits.
        self.kernel = nn.Linear(input_dim+1, hidden_dim*4+levels*2)
        nn.init.xavier_uniform_(self.kernel.weight)
        nn.init.zeros_(self.kernel.bias)
        self.recurrent_kernel = nn.Linear(hidden_dim+1, hidden_dim*4+levels*2)
        nn.init.orthogonal_(self.recurrent_kernel.weight)
        nn.init.zeros_(self.recurrent_kernel.bias)

        self.nn_scale = nn.Linear(hidden_dim, hidden_dim // 6)
        self.nn_rescale = nn.Linear(hidden_dim // 6, hidden_dim)
        self.nn_conv = nn.Conv1d(hidden_dim, conv_dim, conv_size, 1)
        # Fixed: previously sized from the notebook global `output_size`
        # instead of this constructor's `output_dim` argument.
        self.nn_output = nn.Linear(conv_dim, output_dim)

        if self.dropconnect:
            self.nn_dropconnect = nn.Dropout(p=dropconnect)
            self.nn_dropconnect_r = nn.Dropout(p=dropconnect)
        if self.dropout > 0.0:
            self.nn_dropout = nn.Dropout(p=dropout)
        # Fixed: the residual dropout used to be created only when `dropout`
        # was enabled, so the default (dropout=0, dropres=0.3) crashed in forward.
        if self.dropres > 0.0:
            self.nn_dropres = nn.Dropout(p=dropres)

    def cumax(self, x, mode='l2r'):
        """Cumulative sum of a softmax over the last axis.

        Args:
            x: logits.
            mode: 'l2r' accumulates from the first to the last entry,
                'r2l' from the last to the first; any other value returns
                `x` unchanged.
        """
        if mode == 'l2r':
            x = torch.softmax(x, dim=-1)
            x = torch.cumsum(x, dim=-1)
            return x
        elif mode == 'r2l':
            x = torch.flip(x, [-1])
            x = torch.softmax(x, dim=-1)
            x = torch.cumsum(x, dim=-1)
            return torch.flip(x, [-1])
        else:
            return x

    def step(self, inputs, c_last, h_last):
        """Advance the recurrent state by one time step.

        Args:
            inputs: (batch, input_dim) features of the current step.
            c_last: (batch, hidden_dim) previous cell state.
            h_last: (batch, hidden_dim) previous hidden state.

        Returns:
            (out, c_out, h_out) where `out` is `h_out` concatenated with the
            `levels` forget and `levels` input master-gate values.
        """
        x_in = inputs
        # Constant "interval" of 1 appended to both projections' inputs;
        # created on the input's device rather than a module-level global.
        interval = torch.ones((x_in.size(0), 1), dtype=torch.float32, device=x_in.device)
        x_out1 = self.kernel(torch.cat((x_in, interval), dim=-1))
        x_out2 = self.recurrent_kernel(torch.cat((h_last, interval), dim=-1))
        if self.dropconnect:
            x_out1 = self.nn_dropconnect(x_out1)
            x_out2 = self.nn_dropconnect_r(x_out2)
        x_out = x_out1 + x_out2

        # First 2 * levels columns are the master-gate logits.
        f_master_gate = self.cumax(x_out[:, :self.levels], 'l2r')
        f_master_gate = f_master_gate.unsqueeze(2)
        i_master_gate = self.cumax(x_out[:, self.levels:self.levels*2], 'r2l')
        i_master_gate = i_master_gate.unsqueeze(2)

        # Remaining columns: forget / input / output gates and candidate,
        # each viewed as (levels, chunk_size).
        x_out = x_out[:, self.levels*2:]
        x_out = x_out.reshape(-1, self.levels*4, self.chunk_size)
        f_gate = torch.sigmoid(x_out[:, :self.levels])
        i_gate = torch.sigmoid(x_out[:, self.levels:self.levels*2])
        o_gate = torch.sigmoid(x_out[:, self.levels*2:self.levels*3])
        c_in = torch.tanh(x_out[:, self.levels*3:])
        c_last = c_last.reshape(-1, self.levels, self.chunk_size)

        # Where both master gates are active the regular LSTM update applies;
        # elsewhere the state is copied from the past or from the candidate.
        overlap = f_master_gate * i_master_gate
        c_out = overlap * (f_gate * c_last + i_gate * c_in) + (f_master_gate - overlap) * c_last + (i_master_gate - overlap) * c_in
        h_out = o_gate * torch.tanh(c_out)
        c_out = c_out.reshape(-1, self.hidden_dim)
        h_out = h_out.reshape(-1, self.hidden_dim)
        out = torch.cat([h_out, f_master_gate[..., 0], i_master_gate[..., 0]], 1)
        return out, c_out, h_out

    def forward(self, input):
        """Run the model over a whole sequence.

        Args:
            input: (batch, time, feature) tensor.

        Returns:
            (output, distance): `output` is (batch, time, output_dim) with
            sigmoid values, `distance` is (time, batch) holding
            1 - mean(forget master gate) for every step.
        """
        batch_size, time_step, feature_dim = input.size()
        dev = input.device  # keep all state on the input's device
        c_out = torch.zeros(batch_size, self.hidden_dim, device=dev)
        h_out = torch.zeros(batch_size, self.hidden_dim, device=dev)

        # Sliding windows over the last `conv_size` steps: s*B*H and s*B.
        tmp_h = torch.zeros(self.conv_size, batch_size, self.hidden_dim, device=dev)
        tmp_dis = torch.zeros((self.conv_size, batch_size), device=dev)
        h = []
        origin_h = []
        distance = []
        for t in range(time_step):
            out, c_out, h_out = self.step(input[:, t, :], c_out, h_out)
            cur_h = out[..., :self.hidden_dim]
            cur_distance = 1 - torch.mean(out[..., self.hidden_dim:self.hidden_dim+self.levels], -1)
            origin_h.append(cur_h)
            # Shift both windows by one step and append the newest entry.
            tmp_h = torch.cat((tmp_h[1:], cur_h.unsqueeze(0)), 0)
            tmp_dis = torch.cat((tmp_dis[1:], cur_distance.unsqueeze(0)), 0)
            distance.append(cur_distance)

            # Weight the window by a softmax over the cumulative distance.
            local_dis = tmp_dis.permute(1, 0)
            local_dis = torch.cumsum(local_dis, dim=1)
            local_dis = torch.softmax(local_dis, dim=1)
            local_h = tmp_h.permute(1, 2, 0)
            local_h = local_h * local_dis.unsqueeze(1)

            # Channel gate computed from the window mean.
            local_theme = torch.mean(local_h, dim=-1)
            local_theme = self.nn_scale(local_theme)
            local_theme = torch.relu(local_theme)
            local_theme = self.nn_rescale(local_theme)
            local_theme = torch.sigmoid(local_theme)

            # The kernel spans the whole window, so the conv output has length 1.
            local_h = self.nn_conv(local_h).squeeze(-1)
            local_h = local_theme * local_h
            h.append(local_h)

        origin_h = torch.stack(origin_h).permute(1, 0, 2)
        rnn_outputs = torch.stack(h).permute(1, 0, 2)
        if self.dropres > 0.0:
            origin_h = self.nn_dropres(origin_h)
        rnn_outputs = rnn_outputs + origin_h
        rnn_outputs = rnn_outputs.contiguous().view(-1, rnn_outputs.size(-1))
        if self.dropout > 0.0:
            rnn_outputs = self.nn_dropout(rnn_outputs)
        output = self.nn_output(rnn_outputs)
        output = output.contiguous().view(batch_size, time_step, self.output_dim)
        output = torch.sigmoid(output)

        return output, torch.stack(distance)

In [10]:
# Model hyper-parameters.
# The training loop concatenates 17 interval channels onto every input, hence "+ 17".
# NOTE(review): 76 is presumably the width of the discretized input — confirm.
input_size = 76 + 17
hidden_size = 384
conv_dim = hidden_size
conv_size = 10
output_size = 1
levels = 3
dropconnect = 0.3
dropout = 0.3

# Build the network on the selected device and attach an Adam optimizer to it.
model = stage_LSTM(
    input_dim=input_size,
    hidden_dim=hidden_size,
    conv_dim=conv_dim,
    conv_size=conv_size,
    output_dim=output_size,
    levels=levels,
    dropconnect=dropconnect,
    dropout=dropout,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [17]:
# Training / validation loop.
#
# Changes relative to the previous version of this cell:
#   * the per-sample demographic lookup (idx_list.index per sample) was removed:
#     its result was never passed to the model and it cost O(len(idx_list)) per sample;
#   * sequences are truncated to `max_len` BEFORE the interval features are
#     computed (the computation only looks backwards, so the kept steps are
#     identical, but no work is spent on steps that are thrown away);
#   * masked predictions are collected with a boolean index instead of a
#     per-element Python loop;
#   * the per-sequence loss is divided by max(mask sum, 1) so that a sequence
#     without any labelled step cannot produce inf / NaN;
#   * batch preparation, loss and checkpointing live in small helpers.
train_loss = []
val_loss = []
batch_loss = []
train_history = []
val_history = []
max_auroc = 0
max_auprc = 0
max_len = 400  # sequences longer than this many steps are truncated
file_name = 'sa-crnn-se2'


def _prepare_batch(raw_batch, max_len, n_channels=17):
    """Convert one generator batch into (model input, mask, target) tensors on `device`.

    `raw_batch` is the 'data' entry yielded by the batch generator:
    raw_batch[0][0] are the features, raw_batch[0][1] the label mask and
    raw_batch[1] the targets. The last `n_channels` feature columns are read
    as 0/1 indicators (presumably the discretizer's observation masks —
    confirm). For each of them, at steps where the indicator is 1, an interval
    feature holds the number of indicator-0 steps seen since the previous
    indicator-1 step; it is 0 elsewhere. The intervals are appended to the features.
    """
    x = torch.tensor(raw_batch[0][0], dtype=torch.float32).to(device)
    mask = torch.tensor(raw_batch[0][1], dtype=torch.float32).unsqueeze(-1).to(device)
    y = torch.tensor(raw_batch[1], dtype=torch.float32).to(device)

    if mask.size()[1] > max_len:
        x = x[:, :max_len, :]
        mask = mask[:, :max_len, :]
        y = y[:, :max_len, :]

    elapsed = torch.zeros(x.size(0), n_channels, dtype=torch.float32).to(device)
    interval = torch.zeros((x.size(0), x.size(1), n_channels), dtype=torch.float32).to(device)
    for t in range(x.size(1)):
        observed = x[:, t, -n_channels:]
        elapsed += (observed == 0).float()
        interval[:, t, :] = observed * elapsed
        elapsed[observed == 1] = 0  # restart the counter after an observation

    return torch.cat((x, interval), dim=-1), mask, y


def _masked_bce(output, mask, target):
    """Binary cross-entropy on the masked output, averaged per sequence, summed over the batch."""
    masked_output = output * mask
    log_lik = target * torch.log(masked_output + 1e-7) + (1 - target) * torch.log(1 - masked_output + 1e-7)
    # Mask sums are whole numbers, so the clamp only changes the all-zero case.
    n_labelled = torch.sum(mask, dim=1).clamp(min=1.0)
    return torch.neg(torch.sum(torch.sum(log_lik, dim=1) / n_labelled))


def _save_checkpoint(path, chunk):
    """Write the current model and optimizer state plus the epoch index to `path`."""
    state = {
        'net': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'chunk': chunk
    }
    torch.save(state, path)


for each_chunk in range(epochs):
    cur_batch_loss = []
    model.train()
    for each_batch in range(train_data_gen.steps):
        batch_x, batch_mask, batch_y = _prepare_batch(next(train_data_gen)['data'], max_len)

        optimizer.zero_grad()
        cur_output, cur_dis = model(batch_x) #B T 1
        loss = _masked_bce(cur_output, batch_mask, batch_y)
        cur_batch_loss.append(loss.cpu().detach().numpy())

        loss.backward()
        optimizer.step()
        if each_batch % 50 == 0:
            print('Chunk %d, Batch %d: Loss = %.4f'%(each_chunk, each_batch, cur_batch_loss[-1]))

    batch_loss.append(cur_batch_loss)
    train_loss.append(np.mean(np.array(cur_batch_loss)))

    print("\n==>Predicting on validation")
    with torch.no_grad():
        model.eval()
        cur_val_loss = []
        valid_true = []
        valid_pred = []
        for each_batch in range(val_data_gen.steps):
            valid_x, valid_mask, valid_y = _prepare_batch(next(val_data_gen)['data'], max_len)

            valid_output, valid_dis = model(valid_x)
            valid_loss = _masked_bce(valid_output, valid_mask, valid_y)
            cur_val_loss.append(valid_loss.cpu().detach().numpy())

            # Keep only the positions whose mask equals 1.
            keep = valid_mask.cpu().numpy().flatten() == 1
            valid_true.extend(valid_y.cpu().numpy().flatten()[keep])
            valid_pred.extend(valid_output.cpu().detach().numpy().flatten()[keep])

        val_loss.append(np.mean(np.array(cur_val_loss)))
        print('Valid loss = %.4f'%(val_loss[-1]))
        print('\n')
        valid_pred = np.array(valid_pred)
        # Two columns per prediction: [1 - p, p].
        valid_pred = np.stack([1 - valid_pred, valid_pred], axis=1)
        ret = metrics.print_metrics_binary(valid_true, valid_pred)
        val_history.append(ret)
        print()

        # Best-AUROC checkpoint goes to '<file_name>roc'.
        cur_auroc = ret['auroc']
        if cur_auroc > max_auroc:
            max_auroc = cur_auroc
            _save_checkpoint(file_name+'roc', each_chunk)

        # Best-AUPRC checkpoint goes to '<file_name>'.
        cur_auprc = ret['auprc']
        if cur_auprc > max_auprc:
            max_auprc = cur_auprc
            _save_checkpoint(file_name, each_chunk)
            print('\n------------ Save best model ------------\n')

Chunk 0, Batch 0: Loss = 15.6404
Chunk 0, Batch 50: Loss = 12.2789
Chunk 0, Batch 100: Loss = 10.1171
Chunk 0, Batch 150: Loss = 11.1317
Chunk 0, Batch 200: Loss = 11.3013

==>Predicting on validation
Valid loss = 14.5204


confusion matrix:
[[459434   3168]
 [  8402   2447]]
accuracy = 0.9755623936653137
precision class 0 = 0.9820407032966614
precision class 1 = 0.43579697608947754
recall class 0 = 0.9931517839431763
recall class 1 = 0.22555074095726013
AUC of ROC = 0.8754742790918492
AUC of PRC = 0.26797221611047783
min(+P, Se) = 0.32617511520737325


------------ Save best model ------------

Chunk 1, Batch 0: Loss = 9.2685
Chunk 1, Batch 50: Loss = 7.8653
Chunk 1, Batch 100: Loss = 14.6111
Chunk 1, Batch 150: Loss = 6.4118
Chunk 1, Batch 200: Loss = 12.4769

==>Predicting on validation
Valid loss = 11.6407


confusion matrix:
[[489005   1250]
 [  9437   1373]]
accuracy = 0.9786714315414429
precision class 0 = 0.9810670018196106
precision class 1 = 0.5234464406967163
recall class 0 

Valid loss = 13.0378


confusion matrix:
[[461830   7473]
 [  6843   3908]]
accuracy = 0.9701783657073975
precision class 0 = 0.9853991866111755
precision class 1 = 0.34337931871414185
recall class 0 = 0.9840763807296753
recall class 1 = 0.36350107192993164
AUC of ROC = 0.8920960561839717
AUC of PRC = 0.28452883628349895
min(+P, Se) = 0.3511429102397324

Chunk 15, Batch 0: Loss = 6.4553
Chunk 15, Batch 50: Loss = 14.1998
Chunk 15, Batch 100: Loss = 7.9179
Chunk 15, Batch 150: Loss = 22.4189
Chunk 15, Batch 200: Loss = 9.1670

==>Predicting on validation
Valid loss = 15.2520


confusion matrix:
[[463485   6192]
 [  7076   3740]]
accuracy = 0.9723867177963257
precision class 0 = 0.9849626421928406
precision class 1 = 0.37656059861183167
recall class 0 = 0.9868164658546448
recall class 1 = 0.3457840383052826
AUC of ROC = 0.8885932128354722
AUC of PRC = 0.29940929830573665
min(+P, Se) = 0.3615192680898253

Chunk 16, Batch 0: Loss = 9.7424
Chunk 16, Batch 50: Loss = 4.7563
Chunk 16, Batch 1

KeyboardInterrupt: 

In [None]:
# Restore the checkpoint that the training loop writes to '<file_name>'
# (the one saved whenever validation AUPRC improved).
file_name = 'sa-crnn-se2'
# map_location remaps the stored tensors onto the device selected in this
# session, so a checkpoint saved on GPU can be restored on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(file_name, map_location=device)
save_chunk = checkpoint['chunk']
print("last saved model is in chunk {}".format(save_chunk))
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
model.eval()

In [16]:
# Test split: always loaded in full (small_part=False), never shuffled, and the
# generator is asked to return the sample names with each batch.
test_data_loader = common_utils.DeepSupervisionDataLoader(
    dataset_dir=os.path.join(data_path, 'test'),
    listfile=os.path.join(data_path, 'test_listfile.csv'),
    small_part=False,
)
test_data_gen = utils.BatchGenDeepSupervision(
    test_data_loader,
    discretizer,
    normalizer,
    batch_size,
    shuffle=False,
    return_names=True,
)

In [17]:
#testing the model lamda=1
# Evaluate the current model on the test generator.
#
# Changes relative to the previous version of this cell:
#   * the model is explicitly switched to eval mode here, so dropout is off
#     regardless of which cells were run before this one;
#   * the loop that only parsed sample names (and used nothing from them) was removed;
#   * sequences are truncated to `max_len` before the interval features are
#     computed (the computation only looks backwards, so kept steps are identical);
#   * masked predictions are collected with a boolean index;
#   * the per-sequence loss is divided by max(mask sum, 1) to avoid inf / NaN
#     for a sequence without any labelled step.
max_len = 400
model.eval()
with torch.no_grad():
    cur_test_loss = []
    test_true = []
    test_pred = []

    for each_batch in range(test_data_gen.steps):
        test_data = next(test_data_gen)['data']

        test_x = torch.tensor(test_data[0][0], dtype=torch.float32).to(device)
        test_mask = torch.tensor(test_data[0][1], dtype=torch.float32).unsqueeze(-1).to(device)
        test_y = torch.tensor(test_data[1], dtype=torch.float32).to(device)

        if test_mask.size()[1] > max_len:
            test_x = test_x[:, :max_len, :]
            test_mask = test_mask[:, :max_len, :]
            test_y = test_y[:, :max_len, :]

        # Interval features from the last 17 feature columns, read as 0/1
        # indicators (presumably observation masks — confirm): at indicator-1
        # steps, the number of indicator-0 steps since the previous indicator-1 step.
        test_time = torch.zeros(test_x.size(0),17, dtype=torch.float32).to(device)
        test_interval = torch.zeros((test_x.size(0),test_x.size(1),17), dtype=torch.float32).to(device)
        for i in range(test_x.size(1)):
            cur_ind = test_x[:,i,-17:]
            test_time+=(cur_ind == 0).float()
            test_interval[:, i, :] = cur_ind * test_time
            test_time[cur_ind==1] = 0

        test_x = torch.cat((test_x, test_interval), dim=-1)
        test_output, _ = model(test_x)
        masked_test_output = test_output * test_mask

        test_loss = test_y * torch.log(masked_test_output + 1e-7) + (1 - test_y) * torch.log(1 - masked_test_output + 1e-7)
        # Mask sums are whole numbers, so the clamp only changes the all-zero case.
        test_loss = torch.sum(test_loss, dim=1) / torch.sum(test_mask, dim=1).clamp(min=1.0)
        test_loss = torch.neg(torch.sum(test_loss))
        cur_test_loss.append(test_loss.cpu().detach().numpy())

        # Keep only the positions whose mask equals 1.
        keep = test_mask.cpu().numpy().flatten() == 1
        test_true.extend(test_y.cpu().numpy().flatten()[keep])
        test_pred.extend(test_output.cpu().detach().numpy().flatten()[keep])
        print('.....Done (%d/%d)'%(each_batch, test_data_gen.steps))

    print('Test loss = %.4f'%(np.mean(np.array(cur_test_loss))))
    print('\n')
    test_pred = np.array(test_pred)
    # Two columns per prediction: [1 - p, p].
    test_pred = np.stack([1 - test_pred, test_pred], axis=1)
    test_ret = metrics.print_metrics_binary(test_true, test_pred)

.....Done (0/49)
.....Done (1/49)
.....Done (2/49)
.....Done (3/49)
.....Done (4/49)
.....Done (5/49)
.....Done (6/49)
.....Done (7/49)
.....Done (8/49)
.....Done (9/49)
.....Done (10/49)
.....Done (11/49)
.....Done (12/49)
.....Done (13/49)
.....Done (14/49)
.....Done (15/49)
.....Done (16/49)
.....Done (17/49)
.....Done (18/49)
.....Done (19/49)
.....Done (20/49)
.....Done (21/49)
.....Done (22/49)
.....Done (23/49)
.....Done (24/49)
.....Done (25/49)
.....Done (26/49)
.....Done (27/49)
.....Done (28/49)
.....Done (29/49)
.....Done (30/49)
.....Done (31/49)
.....Done (32/49)
.....Done (33/49)
.....Done (34/49)
.....Done (35/49)
.....Done (36/49)
.....Done (37/49)
.....Done (38/49)
.....Done (39/49)
.....Done (40/49)
.....Done (41/49)
.....Done (42/49)
.....Done (43/49)
.....Done (44/49)
.....Done (45/49)
.....Done (46/49)
.....Done (47/49)
.....Done (48/49)
Test loss = 10.6832


confusion matrix:
[[462454   4025]
 [  6100   2903]]
accuracy = 0.9787058234214783
precision class 0 = 0.9