## Machine Learning project
[MI-net, instance space](#section_id)
[MI-net deep supervision, instance space](#MInetdeepsuper)


In [1]:
pip install -r requirements.txt

Collecting git+https://github.com/chlorochrule/cknn (from -r requirements.txt (line 9))
  Cloning https://github.com/chlorochrule/cknn to c:\users\seven\appdata\local\temp\pip-req-build-ix2t65th
  Resolved https://github.com/chlorochrule/cknn to commit 7d05c5049da72a573bd486fca6647f8b0376243c
  Installing build dependencies: started
  Installing build dependencies: still running...
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'
Collecting misvm
  Cloning https://github.com/garydoranjr/misvm.git to c:\users\seven\appdata\local\temp\pip-install-5h6zfwts\misvm_5380924db1b440b8999ab150a7cc8595
  Resolved https://github.com/garydoranjr/misvm.git to commit b2118fe04d98c00436bdf8a0e4bbfb6082c5751c
  Preparing metadata (setup.py): started
  Preparing metad

  Running command git clone --filter=blob:none --quiet https://github.com/chlorochrule/cknn 'C:\Users\seven\AppData\Local\Temp\pip-req-build-ix2t65th'
  Running command git clone --filter=blob:none --quiet https://github.com/garydoranjr/misvm.git 'C:\Users\seven\AppData\Local\Temp\pip-install-5h6zfwts\misvm_5380924db1b440b8999ab150a7cc8595'


### Embedded-Space


### MI-net
<a id='section_id'></a>

In [3]:
import numpy as np
import sys
import time
import random
from random import shuffle
import argparse

from keras.models import Model
from keras.optimizers import SGD
from keras.regularizers import l2
from keras.layers import Input, Dense, Layer, Dropout

from mil_nets.dataset import load_dataset
from mil_nets.layer import Feature_pooling
from mil_nets.metrics import bag_accuracy
from mil_nets.objectives import bag_loss
from mil_nets.utils import convertToBatch

In [12]:
def test_eval(model, test_set):
    """Evaluate on testing set.
    Parameters
    -----------------
    model : keras.engine.training.Model object
        The training MI-Net model.
    test_set : list
        A list of testing set contains all training bags features and labels.
    Returns
    -----------------
    test_loss : float
        Mean loss of evaluating on testing set.
    test_acc : float
        Mean accuracy of evaluating on testing set.
    """
    num_test_batch = len(test_set)
    test_loss = np.zeros((num_test_batch, 1), dtype=np.float32)
    test_acc = np.zeros((num_test_batch, 1), dtype=np.float32)
    for ibatch, batch in enumerate(test_set):
        result = model.test_on_batch({'input':batch[0].astype(np.float32)}, {'fp':batch[1].astype(np.float32)})
        test_loss[ibatch] = result[0]
        test_acc[ibatch][0] = result[1]
    return np.mean(test_loss), np.mean(test_acc)

def train_eval(model, train_set):
    """Evaluate on training set.
    Parameters
    -----------------
    model : keras.engine.training.Model object
        The training MI-Net model.
    train_set : list
        A list of training set contains all training bags features and labels.
    Returns
    -----------------
    test_loss : float
        Mean loss of evaluating on traing set..astype(np.float32)
    test_acc : float
        Mean accuracy of evaluating on testing set.
    """
    num_train_batch = len(train_set)
    train_loss = np.zeros((num_train_batch, 1), dtype=np.float32)
    train_acc = np.zeros((num_train_batch, 1), dtype=np.float32)
    shuffle(train_set)
    for ibatch, batch in enumerate(train_set):
        result = model.train_on_batch({'input':batch[0].astype(np.float32)}, {'fp':batch[1].astype(np.float32)})
        train_loss[ibatch] = result[0]
        train_acc[ibatch][0] = result[1]
    return np.mean(train_loss), np.mean(train_acc)

def MI_Net(dataset):
    """Train and evaluate on MI-Net.
    Parameters
    -----------------
    dataset : dict
        A dictionary contains all dataset information. We split train/test by keys.
    Returns
    -----------------
    test_acc : float
        Testing accuracy of MI-Net.
    """
    weight_decay=0.005
    init_lr=5e-4
    pooling_mode='max'
    momentum=0.9
    max_epoch=50
    # load data and convert type
    train_bags = dataset['train']
    test_bags = dataset['test']

    # convert bag to batch
    train_set = convertToBatch(train_bags)
    test_set = convertToBatch(test_bags)
    dimension = train_set[0][0].shape[1]

    # data: instance feature, n*d, n = number of training instance
    data_input = Input(shape=(dimension,), dtype='float32', name='input')

    # fully-connected
    fc1 = Dense(256, activation='relu', kernel_regularizer=l2(weight_decay))(data_input)
    fc2 = Dense(128, activation='relu', kernel_regularizer=l2(weight_decay))(fc1)
    fc3 = Dense(64, activation='relu', kernel_regularizer=l2(weight_decay))(fc2)

    # dropout
    dropout = Dropout(rate=0.5)(fc3)

    # features pooling
    fp = Feature_pooling(output_dim=1, kernel_regularizer=l2(weight_decay), pooling_mode=pooling_mode, name='fp')(dropout)

    model = Model(inputs=[data_input], outputs=[fp])
    sgd = SGD(lr=init_lr, decay=1e-4, momentum=momentum, nesterov=True)
    model.compile(loss=bag_loss, optimizer=sgd, metrics=[bag_accuracy])

    # train model
    t1 = time.time()
    num_batch = len(train_set)
    for epoch in range(max_epoch):
        train_loss, train_acc = train_eval(model, train_set)
        test_loss, test_acc = test_eval(model, test_set)
        print('epoch=', epoch, '  train_loss= {:.3f}'.format(train_loss), '  train_acc= {:.3f}'.format(train_acc), '  test_loss={:.3f}'.format(test_loss), '  test_acc= {:.3f}'.format(test_acc))
    t2 = time.time()
    print('run time:', (t2-t1) / 60, 'min')
    print('test_acc={:.3f}'.format(test_acc))

    return test_acc

In [13]:
# perform five times 10-fold cross-validation experiments
run = 5
n_folds = 10
acc = np.zeros((run, n_folds), dtype=np.float32)
for irun in range(run):
    dataset = load_dataset('musk1', n_folds)
    for ifold in range(n_folds):
        print('run=', irun, '  fold=', ifold)
        acc[irun][ifold] = MI_Net(dataset[ifold])
print('MI-Net mean accuracy = ', np.mean(acc))
print('std = ', np.std(acc))

run= 0   fold= 0
epoch= 0   train_loss= 2.966   train_acc= 0.598   test_loss=2.917   test_acc= 0.500
epoch= 1   train_loss= 2.684   train_acc= 0.817   test_loss=2.781   test_acc= 0.800
epoch= 2   train_loss= 2.531   train_acc= 0.939   test_loss=2.699   test_acc= 0.700
epoch= 3   train_loss= 2.424   train_acc= 0.939   test_loss=2.654   test_acc= 0.800
epoch= 4   train_loss= 2.362   train_acc= 0.976   test_loss=2.609   test_acc= 0.900
epoch= 5   train_loss= 2.351   train_acc= 0.951   test_loss=2.544   test_acc= 0.900
epoch= 6   train_loss= 2.305   train_acc= 0.976   test_loss=2.535   test_acc= 0.900
epoch= 7   train_loss= 2.284   train_acc= 0.976   test_loss=2.495   test_acc= 0.900
epoch= 8   train_loss= 2.223   train_acc= 0.988   test_loss=2.463   test_acc= 0.900
epoch= 9   train_loss= 2.218   train_acc= 0.988   test_loss=2.431   test_acc= 0.900
epoch= 10   train_loss= 2.203   train_acc= 0.988   test_loss=2.427   test_acc= 0.900
epoch= 11   train_loss= 2.169   train_acc= 0.988   test_lo

epoch= 46   train_loss= 1.704   train_acc= 1.000   test_loss=2.261   test_acc= 0.800
epoch= 47   train_loss= 1.692   train_acc= 1.000   test_loss=2.253   test_acc= 0.800
epoch= 48   train_loss= 1.683   train_acc= 1.000   test_loss=2.248   test_acc= 0.800
epoch= 49   train_loss= 1.674   train_acc= 1.000   test_loss=2.240   test_acc= 0.800
run time: 0.7714118798573811 min
test_acc=0.800
run= 0   fold= 2
epoch= 0   train_loss= 2.976   train_acc= 0.602   test_loss=2.805   test_acc= 0.778
epoch= 1   train_loss= 2.680   train_acc= 0.843   test_loss=2.599   test_acc= 0.889
epoch= 2   train_loss= 2.536   train_acc= 0.952   test_loss=2.561   test_acc= 0.889
epoch= 3   train_loss= 2.489   train_acc= 0.916   test_loss=2.583   test_acc= 0.889
epoch= 4   train_loss= 2.416   train_acc= 0.952   test_loss=2.438   test_acc= 1.000
epoch= 5   train_loss= 2.368   train_acc= 0.964   test_loss=2.453   test_acc= 0.889
epoch= 6   train_loss= 2.333   train_acc= 0.964   test_loss=2.442   test_acc= 0.889
epoch= 

epoch= 42   train_loss= 1.749   train_acc= 1.000   test_loss=2.023   test_acc= 0.889
epoch= 43   train_loss= 1.734   train_acc= 1.000   test_loss=2.005   test_acc= 0.889
epoch= 44   train_loss= 1.723   train_acc= 1.000   test_loss=2.006   test_acc= 0.889
epoch= 45   train_loss= 1.713   train_acc= 1.000   test_loss=1.982   test_acc= 0.889
epoch= 46   train_loss= 1.709   train_acc= 1.000   test_loss=1.981   test_acc= 0.889
epoch= 47   train_loss= 1.696   train_acc= 1.000   test_loss=1.954   test_acc= 0.889
epoch= 48   train_loss= 1.686   train_acc= 1.000   test_loss=1.946   test_acc= 0.889
epoch= 49   train_loss= 1.675   train_acc= 1.000   test_loss=1.952   test_acc= 0.889
run time: 0.7658559322357178 min
test_acc=0.889
run= 0   fold= 4
epoch= 0   train_loss= 2.933   train_acc= 0.651   test_loss=2.745   test_acc= 0.778
epoch= 1   train_loss= 2.653   train_acc= 0.867   test_loss=2.536   test_acc= 1.000
epoch= 2   train_loss= 2.506   train_acc= 0.928   test_loss=2.675   test_acc= 0.778
epo

epoch= 38   train_loss= 1.780   train_acc= 1.000   test_loss=2.100   test_acc= 0.889
epoch= 39   train_loss= 1.771   train_acc= 1.000   test_loss=2.094   test_acc= 0.889
epoch= 40   train_loss= 1.760   train_acc= 1.000   test_loss=2.094   test_acc= 0.889
epoch= 41   train_loss= 1.753   train_acc= 1.000   test_loss=2.077   test_acc= 0.889
epoch= 42   train_loss= 1.742   train_acc= 1.000   test_loss=2.076   test_acc= 0.889
epoch= 43   train_loss= 1.729   train_acc= 1.000   test_loss=2.070   test_acc= 0.889
epoch= 44   train_loss= 1.719   train_acc= 1.000   test_loss=2.024   test_acc= 0.889
epoch= 45   train_loss= 1.713   train_acc= 1.000   test_loss=2.114   test_acc= 0.889
epoch= 46   train_loss= 1.702   train_acc= 1.000   test_loss=2.060   test_acc= 0.889
epoch= 47   train_loss= 1.688   train_acc= 1.000   test_loss=2.033   test_acc= 0.889
epoch= 48   train_loss= 1.681   train_acc= 1.000   test_loss=2.003   test_acc= 0.889
epoch= 49   train_loss= 1.673   train_acc= 1.000   test_loss=2.00

epoch= 34   train_loss= 1.826   train_acc= 1.000   test_loss=2.126   test_acc= 0.889
epoch= 35   train_loss= 1.819   train_acc= 1.000   test_loss=2.122   test_acc= 0.889
epoch= 36   train_loss= 1.805   train_acc= 1.000   test_loss=2.103   test_acc= 0.889
epoch= 37   train_loss= 1.794   train_acc= 1.000   test_loss=2.092   test_acc= 0.889
epoch= 38   train_loss= 1.782   train_acc= 1.000   test_loss=2.079   test_acc= 0.889
epoch= 39   train_loss= 1.772   train_acc= 1.000   test_loss=2.068   test_acc= 0.889
epoch= 40   train_loss= 1.763   train_acc= 1.000   test_loss=2.059   test_acc= 0.889
epoch= 41   train_loss= 1.751   train_acc= 1.000   test_loss=2.053   test_acc= 0.889
epoch= 42   train_loss= 1.743   train_acc= 1.000   test_loss=2.049   test_acc= 0.889
epoch= 43   train_loss= 1.730   train_acc= 1.000   test_loss=2.036   test_acc= 0.889
epoch= 44   train_loss= 1.717   train_acc= 1.000   test_loss=2.028   test_acc= 0.889
epoch= 45   train_loss= 1.709   train_acc= 1.000   test_loss=2.01

epoch= 30   train_loss= 1.888   train_acc= 1.000   test_loss=2.414   test_acc= 0.667
epoch= 31   train_loss= 1.870   train_acc= 1.000   test_loss=2.369   test_acc= 0.667
epoch= 32   train_loss= 1.858   train_acc= 1.000   test_loss=2.370   test_acc= 0.667
epoch= 33   train_loss= 1.846   train_acc= 1.000   test_loss=2.318   test_acc= 0.778
epoch= 34   train_loss= 1.835   train_acc= 1.000   test_loss=2.298   test_acc= 0.778
epoch= 35   train_loss= 1.822   train_acc= 1.000   test_loss=2.314   test_acc= 0.667
epoch= 36   train_loss= 1.812   train_acc= 1.000   test_loss=2.313   test_acc= 0.667
epoch= 37   train_loss= 1.806   train_acc= 1.000   test_loss=2.352   test_acc= 0.667
epoch= 38   train_loss= 1.797   train_acc= 1.000   test_loss=2.312   test_acc= 0.667
epoch= 39   train_loss= 1.783   train_acc= 1.000   test_loss=2.318   test_acc= 0.667
epoch= 40   train_loss= 1.773   train_acc= 1.000   test_loss=2.278   test_acc= 0.667
epoch= 41   train_loss= 1.761   train_acc= 1.000   test_loss=2.27

epoch= 26   train_loss= 1.930   train_acc= 1.000   test_loss=2.255   test_acc= 0.900
epoch= 27   train_loss= 1.910   train_acc= 1.000   test_loss=2.244   test_acc= 0.900
epoch= 28   train_loss= 1.900   train_acc= 1.000   test_loss=2.227   test_acc= 0.900
epoch= 29   train_loss= 1.887   train_acc= 1.000   test_loss=2.221   test_acc= 0.900
epoch= 30   train_loss= 1.876   train_acc= 1.000   test_loss=2.214   test_acc= 0.900
epoch= 31   train_loss= 1.861   train_acc= 1.000   test_loss=2.202   test_acc= 0.900
epoch= 32   train_loss= 1.850   train_acc= 1.000   test_loss=2.187   test_acc= 0.900
epoch= 33   train_loss= 1.840   train_acc= 1.000   test_loss=2.181   test_acc= 0.900
epoch= 34   train_loss= 1.826   train_acc= 1.000   test_loss=2.158   test_acc= 0.900
epoch= 35   train_loss= 1.815   train_acc= 1.000   test_loss=2.143   test_acc= 0.900
epoch= 36   train_loss= 1.808   train_acc= 1.000   test_loss=2.139   test_acc= 0.900
epoch= 37   train_loss= 1.794   train_acc= 1.000   test_loss=2.13

epoch= 22   train_loss= 1.983   train_acc= 1.000   test_loss=2.300   test_acc= 0.889
epoch= 23   train_loss= 1.966   train_acc= 1.000   test_loss=2.292   test_acc= 0.889
epoch= 24   train_loss= 1.956   train_acc= 1.000   test_loss=2.254   test_acc= 0.889
epoch= 25   train_loss= 1.945   train_acc= 1.000   test_loss=2.252   test_acc= 0.889
epoch= 26   train_loss= 1.933   train_acc= 1.000   test_loss=2.240   test_acc= 0.889
epoch= 27   train_loss= 1.923   train_acc= 1.000   test_loss=2.238   test_acc= 0.889
epoch= 28   train_loss= 1.908   train_acc= 1.000   test_loss=2.229   test_acc= 0.889
epoch= 29   train_loss= 1.894   train_acc= 1.000   test_loss=2.227   test_acc= 0.889
epoch= 30   train_loss= 1.880   train_acc= 1.000   test_loss=2.196   test_acc= 0.889
epoch= 31   train_loss= 1.869   train_acc= 1.000   test_loss=2.187   test_acc= 0.889
epoch= 32   train_loss= 1.850   train_acc= 1.000   test_loss=2.169   test_acc= 0.889
epoch= 33   train_loss= 1.846   train_acc= 1.000   test_loss=2.15

epoch= 18   train_loss= 2.041   train_acc= 1.000   test_loss=2.245   test_acc= 1.000
epoch= 19   train_loss= 2.029   train_acc= 1.000   test_loss=2.217   test_acc= 0.889
epoch= 20   train_loss= 2.011   train_acc= 1.000   test_loss=2.202   test_acc= 0.889
epoch= 21   train_loss= 1.997   train_acc= 1.000   test_loss=2.209   test_acc= 0.889
epoch= 22   train_loss= 1.985   train_acc= 1.000   test_loss=2.191   test_acc= 0.889
epoch= 23   train_loss= 1.968   train_acc= 1.000   test_loss=2.161   test_acc= 0.889
epoch= 24   train_loss= 1.954   train_acc= 1.000   test_loss=2.142   test_acc= 1.000
epoch= 25   train_loss= 1.940   train_acc= 1.000   test_loss=2.122   test_acc= 1.000
epoch= 26   train_loss= 1.930   train_acc= 1.000   test_loss=2.114   test_acc= 0.889
epoch= 27   train_loss= 1.916   train_acc= 1.000   test_loss=2.097   test_acc= 0.889
epoch= 28   train_loss= 1.906   train_acc= 1.000   test_loss=2.104   test_acc= 0.889
epoch= 29   train_loss= 1.889   train_acc= 1.000   test_loss=2.07

epoch= 14   train_loss= 2.100   train_acc= 1.000   test_loss=2.448   test_acc= 0.889
epoch= 15   train_loss= 2.088   train_acc= 1.000   test_loss=2.453   test_acc= 0.889
epoch= 16   train_loss= 2.076   train_acc= 1.000   test_loss=2.428   test_acc= 0.889
epoch= 17   train_loss= 2.059   train_acc= 1.000   test_loss=2.427   test_acc= 0.889
epoch= 18   train_loss= 2.046   train_acc= 1.000   test_loss=2.438   test_acc= 0.889
epoch= 19   train_loss= 2.030   train_acc= 1.000   test_loss=2.414   test_acc= 0.889
epoch= 20   train_loss= 2.020   train_acc= 1.000   test_loss=2.400   test_acc= 0.889
epoch= 21   train_loss= 2.003   train_acc= 1.000   test_loss=2.387   test_acc= 0.889
epoch= 22   train_loss= 1.998   train_acc= 1.000   test_loss=2.392   test_acc= 0.889
epoch= 23   train_loss= 1.973   train_acc= 1.000   test_loss=2.360   test_acc= 0.889
epoch= 24   train_loss= 1.959   train_acc= 1.000   test_loss=2.358   test_acc= 0.889
epoch= 25   train_loss= 1.944   train_acc= 1.000   test_loss=2.35

epoch= 10   train_loss= 2.187   train_acc= 1.000   test_loss=2.286   test_acc= 1.000
epoch= 11   train_loss= 2.178   train_acc= 1.000   test_loss=2.317   test_acc= 0.889
epoch= 12   train_loss= 2.146   train_acc= 1.000   test_loss=2.287   test_acc= 0.889
epoch= 13   train_loss= 2.150   train_acc= 0.988   test_loss=2.211   test_acc= 1.000
epoch= 14   train_loss= 2.122   train_acc= 1.000   test_loss=2.240   test_acc= 1.000
epoch= 15   train_loss= 2.100   train_acc= 1.000   test_loss=2.222   test_acc= 1.000
epoch= 16   train_loss= 2.079   train_acc= 1.000   test_loss=2.212   test_acc= 0.889
epoch= 17   train_loss= 2.064   train_acc= 1.000   test_loss=2.169   test_acc= 1.000
epoch= 18   train_loss= 2.048   train_acc= 1.000   test_loss=2.150   test_acc= 1.000
epoch= 19   train_loss= 2.047   train_acc= 0.988   test_loss=2.143   test_acc= 1.000
epoch= 20   train_loss= 2.019   train_acc= 1.000   test_loss=2.148   test_acc= 1.000
epoch= 21   train_loss= 2.009   train_acc= 1.000   test_loss=2.11

epoch= 6   train_loss= 2.292   train_acc= 0.988   test_loss=2.461   test_acc= 0.900
epoch= 7   train_loss= 2.272   train_acc= 0.976   test_loss=2.400   test_acc= 1.000
epoch= 8   train_loss= 2.236   train_acc= 0.976   test_loss=2.384   test_acc= 1.000
epoch= 9   train_loss= 2.200   train_acc= 1.000   test_loss=2.374   test_acc= 0.900
epoch= 10   train_loss= 2.182   train_acc= 1.000   test_loss=2.349   test_acc= 0.900
epoch= 11   train_loss= 2.162   train_acc= 1.000   test_loss=2.288   test_acc= 1.000
epoch= 12   train_loss= 2.142   train_acc= 0.988   test_loss=2.324   test_acc= 0.900
epoch= 13   train_loss= 2.126   train_acc= 1.000   test_loss=2.253   test_acc= 1.000
epoch= 14   train_loss= 2.123   train_acc= 1.000   test_loss=2.251   test_acc= 0.900
epoch= 15   train_loss= 2.096   train_acc= 0.988   test_loss=2.218   test_acc= 0.900
epoch= 16   train_loss= 2.088   train_acc= 0.988   test_loss=2.196   test_acc= 1.000
epoch= 17   train_loss= 2.062   train_acc= 1.000   test_loss=2.185   

epoch= 2   train_loss= 2.463   train_acc= 0.952   test_loss=2.753   test_acc= 0.778
epoch= 3   train_loss= 2.415   train_acc= 0.976   test_loss=2.675   test_acc= 0.778
epoch= 4   train_loss= 2.346   train_acc= 0.952   test_loss=2.706   test_acc= 0.778
epoch= 5   train_loss= 2.325   train_acc= 0.976   test_loss=2.685   test_acc= 0.778
epoch= 6   train_loss= 2.264   train_acc= 1.000   test_loss=2.735   test_acc= 0.889
epoch= 7   train_loss= 2.266   train_acc= 0.988   test_loss=2.635   test_acc= 0.778
epoch= 8   train_loss= 2.227   train_acc= 0.976   test_loss=2.713   test_acc= 0.889
epoch= 9   train_loss= 2.203   train_acc= 0.988   test_loss=2.652   test_acc= 0.889
epoch= 10   train_loss= 2.187   train_acc= 0.988   test_loss=2.608   test_acc= 0.778
epoch= 11   train_loss= 2.163   train_acc= 0.988   test_loss=2.680   test_acc= 0.889
epoch= 12   train_loss= 2.135   train_acc= 1.000   test_loss=2.635   test_acc= 0.778
epoch= 13   train_loss= 2.120   train_acc= 1.000   test_loss=2.621   test

epoch= 48   train_loss= 1.678   train_acc= 1.000   test_loss=2.065   test_acc= 0.889
epoch= 49   train_loss= 1.674   train_acc= 1.000   test_loss=2.073   test_acc= 0.889
run time: 0.765680197874705 min
test_acc=0.889
run= 2   fold= 5
epoch= 0   train_loss= 2.920   train_acc= 0.663   test_loss=2.847   test_acc= 0.667
epoch= 1   train_loss= 2.679   train_acc= 0.795   test_loss=2.653   test_acc= 1.000
epoch= 2   train_loss= 2.529   train_acc= 0.892   test_loss=2.593   test_acc= 1.000
epoch= 3   train_loss= 2.460   train_acc= 0.964   test_loss=2.595   test_acc= 1.000
epoch= 4   train_loss= 2.381   train_acc= 0.964   test_loss=2.594   test_acc= 0.778
epoch= 5   train_loss= 2.343   train_acc= 0.976   test_loss=2.466   test_acc= 1.000
epoch= 6   train_loss= 2.290   train_acc= 0.988   test_loss=2.485   test_acc= 0.889
epoch= 7   train_loss= 2.288   train_acc= 0.976   test_loss=2.438   test_acc= 1.000
epoch= 8   train_loss= 2.233   train_acc= 1.000   test_loss=2.397   test_acc= 1.000
epoch= 9  

epoch= 44   train_loss= 1.721   train_acc= 1.000   test_loss=2.126   test_acc= 0.778
epoch= 45   train_loss= 1.713   train_acc= 1.000   test_loss=2.123   test_acc= 0.778
epoch= 46   train_loss= 1.703   train_acc= 1.000   test_loss=2.114   test_acc= 0.778
epoch= 47   train_loss= 1.692   train_acc= 1.000   test_loss=2.100   test_acc= 0.778
epoch= 48   train_loss= 1.689   train_acc= 1.000   test_loss=2.096   test_acc= 0.778
epoch= 49   train_loss= 1.678   train_acc= 1.000   test_loss=2.093   test_acc= 0.778
run time: 0.7465438326199849 min
test_acc=0.778
run= 2   fold= 7
epoch= 0   train_loss= 2.881   train_acc= 0.675   test_loss=2.705   test_acc= 1.000
epoch= 1   train_loss= 2.650   train_acc= 0.867   test_loss=2.701   test_acc= 1.000
epoch= 2   train_loss= 2.521   train_acc= 0.928   test_loss=2.661   test_acc= 0.889
epoch= 3   train_loss= 2.457   train_acc= 0.928   test_loss=2.488   test_acc= 1.000
epoch= 4   train_loss= 2.400   train_acc= 0.940   test_loss=2.447   test_acc= 1.000
epoch

epoch= 40   train_loss= 1.774   train_acc= 1.000   test_loss=1.849   test_acc= 1.000
epoch= 41   train_loss= 1.762   train_acc= 1.000   test_loss=1.837   test_acc= 1.000
epoch= 42   train_loss= 1.749   train_acc= 1.000   test_loss=1.829   test_acc= 1.000
epoch= 43   train_loss= 1.745   train_acc= 1.000   test_loss=1.814   test_acc= 1.000
epoch= 44   train_loss= 1.728   train_acc= 1.000   test_loss=1.801   test_acc= 1.000
epoch= 45   train_loss= 1.721   train_acc= 1.000   test_loss=1.787   test_acc= 1.000
epoch= 46   train_loss= 1.709   train_acc= 1.000   test_loss=1.779   test_acc= 1.000
epoch= 47   train_loss= 1.705   train_acc= 1.000   test_loss=1.762   test_acc= 1.000
epoch= 48   train_loss= 1.690   train_acc= 1.000   test_loss=1.762   test_acc= 1.000
epoch= 49   train_loss= 1.682   train_acc= 1.000   test_loss=1.749   test_acc= 1.000
run time: 0.725391165415446 min
test_acc=1.000
run= 2   fold= 9
epoch= 0   train_loss= 2.959   train_acc= 0.614   test_loss=2.886   test_acc= 0.556
ep

epoch= 36   train_loss= 1.801   train_acc= 1.000   test_loss=1.870   test_acc= 1.000
epoch= 37   train_loss= 1.793   train_acc= 1.000   test_loss=1.855   test_acc= 1.000
epoch= 38   train_loss= 1.782   train_acc= 1.000   test_loss=1.844   test_acc= 1.000
epoch= 39   train_loss= 1.771   train_acc= 1.000   test_loss=1.828   test_acc= 1.000
epoch= 40   train_loss= 1.760   train_acc= 1.000   test_loss=1.817   test_acc= 1.000
epoch= 41   train_loss= 1.748   train_acc= 1.000   test_loss=1.803   test_acc= 1.000
epoch= 42   train_loss= 1.739   train_acc= 1.000   test_loss=1.795   test_acc= 1.000
epoch= 43   train_loss= 1.729   train_acc= 1.000   test_loss=1.782   test_acc= 1.000
epoch= 44   train_loss= 1.722   train_acc= 1.000   test_loss=1.784   test_acc= 1.000
epoch= 45   train_loss= 1.712   train_acc= 1.000   test_loss=1.775   test_acc= 1.000
epoch= 46   train_loss= 1.698   train_acc= 1.000   test_loss=1.762   test_acc= 1.000
epoch= 47   train_loss= 1.694   train_acc= 1.000   test_loss=1.75

epoch= 32   train_loss= 1.861   train_acc= 1.000   test_loss=2.090   test_acc= 0.889
epoch= 33   train_loss= 1.847   train_acc= 1.000   test_loss=2.094   test_acc= 0.889
epoch= 34   train_loss= 1.834   train_acc= 1.000   test_loss=2.085   test_acc= 0.889
epoch= 35   train_loss= 1.837   train_acc= 1.000   test_loss=2.071   test_acc= 0.889
epoch= 36   train_loss= 1.814   train_acc= 1.000   test_loss=2.060   test_acc= 0.889
epoch= 37   train_loss= 1.803   train_acc= 1.000   test_loss=2.049   test_acc= 0.889
epoch= 38   train_loss= 1.794   train_acc= 1.000   test_loss=2.040   test_acc= 0.889
epoch= 39   train_loss= 1.785   train_acc= 1.000   test_loss=2.022   test_acc= 0.889
epoch= 40   train_loss= 1.768   train_acc= 1.000   test_loss=2.011   test_acc= 0.889
epoch= 41   train_loss= 1.761   train_acc= 1.000   test_loss=1.996   test_acc= 0.889
epoch= 42   train_loss= 1.753   train_acc= 1.000   test_loss=1.989   test_acc= 0.889
epoch= 43   train_loss= 1.736   train_acc= 1.000   test_loss=1.97

epoch= 28   train_loss= 1.906   train_acc= 1.000   test_loss=2.179   test_acc= 0.778
epoch= 29   train_loss= 1.898   train_acc= 1.000   test_loss=2.146   test_acc= 0.889
epoch= 30   train_loss= 1.884   train_acc= 1.000   test_loss=2.135   test_acc= 0.889
epoch= 31   train_loss= 1.868   train_acc= 1.000   test_loss=2.121   test_acc= 0.889
epoch= 32   train_loss= 1.865   train_acc= 1.000   test_loss=2.086   test_acc= 0.889
epoch= 33   train_loss= 1.847   train_acc= 1.000   test_loss=2.103   test_acc= 0.889
epoch= 34   train_loss= 1.835   train_acc= 1.000   test_loss=2.105   test_acc= 0.778
epoch= 35   train_loss= 1.825   train_acc= 1.000   test_loss=2.125   test_acc= 0.778
epoch= 36   train_loss= 1.812   train_acc= 1.000   test_loss=2.087   test_acc= 0.889
epoch= 37   train_loss= 1.804   train_acc= 1.000   test_loss=2.075   test_acc= 0.889
epoch= 38   train_loss= 1.792   train_acc= 1.000   test_loss=2.038   test_acc= 0.889
epoch= 39   train_loss= 1.780   train_acc= 1.000   test_loss=2.02

epoch= 24   train_loss= 1.953   train_acc= 1.000   test_loss=2.099   test_acc= 1.000
epoch= 25   train_loss= 1.940   train_acc= 1.000   test_loss=2.072   test_acc= 1.000
epoch= 26   train_loss= 1.931   train_acc= 1.000   test_loss=2.052   test_acc= 1.000
epoch= 27   train_loss= 1.916   train_acc= 1.000   test_loss=2.049   test_acc= 1.000
epoch= 28   train_loss= 1.898   train_acc= 1.000   test_loss=2.038   test_acc= 1.000
epoch= 29   train_loss= 1.891   train_acc= 1.000   test_loss=2.019   test_acc= 1.000
epoch= 30   train_loss= 1.878   train_acc= 1.000   test_loss=1.999   test_acc= 1.000
epoch= 31   train_loss= 1.868   train_acc= 1.000   test_loss=1.994   test_acc= 0.889
epoch= 32   train_loss= 1.852   train_acc= 1.000   test_loss=1.983   test_acc= 1.000
epoch= 33   train_loss= 1.844   train_acc= 1.000   test_loss=1.984   test_acc= 0.889
epoch= 34   train_loss= 1.829   train_acc= 1.000   test_loss=1.972   test_acc= 0.889
epoch= 35   train_loss= 1.819   train_acc= 1.000   test_loss=1.95

epoch= 20   train_loss= 2.009   train_acc= 1.000   test_loss=2.484   test_acc= 0.889
epoch= 21   train_loss= 2.000   train_acc= 1.000   test_loss=2.433   test_acc= 0.889
epoch= 22   train_loss= 1.989   train_acc= 1.000   test_loss=2.415   test_acc= 0.889
epoch= 23   train_loss= 1.965   train_acc= 1.000   test_loss=2.424   test_acc= 0.889
epoch= 24   train_loss= 1.958   train_acc= 1.000   test_loss=2.391   test_acc= 0.889
epoch= 25   train_loss= 1.938   train_acc= 1.000   test_loss=2.405   test_acc= 0.889
epoch= 26   train_loss= 1.926   train_acc= 1.000   test_loss=2.388   test_acc= 0.889
epoch= 27   train_loss= 1.913   train_acc= 1.000   test_loss=2.389   test_acc= 0.889
epoch= 28   train_loss= 1.898   train_acc= 1.000   test_loss=2.374   test_acc= 0.889
epoch= 29   train_loss= 1.886   train_acc= 1.000   test_loss=2.364   test_acc= 0.889
epoch= 30   train_loss= 1.874   train_acc= 1.000   test_loss=2.355   test_acc= 0.889
epoch= 31   train_loss= 1.864   train_acc= 1.000   test_loss=2.32

epoch= 16   train_loss= 2.069   train_acc= 1.000   test_loss=2.088   test_acc= 1.000
epoch= 17   train_loss= 2.056   train_acc= 1.000   test_loss=2.069   test_acc= 1.000
epoch= 18   train_loss= 2.036   train_acc= 1.000   test_loss=2.061   test_acc= 1.000
epoch= 19   train_loss= 2.029   train_acc= 1.000   test_loss=2.042   test_acc= 1.000
epoch= 20   train_loss= 2.010   train_acc= 1.000   test_loss=2.020   test_acc= 1.000
epoch= 21   train_loss= 1.997   train_acc= 1.000   test_loss=2.019   test_acc= 1.000
epoch= 22   train_loss= 1.981   train_acc= 1.000   test_loss=2.003   test_acc= 1.000
epoch= 23   train_loss= 1.966   train_acc= 1.000   test_loss=1.989   test_acc= 1.000
epoch= 24   train_loss= 1.956   train_acc= 1.000   test_loss=1.973   test_acc= 1.000
epoch= 25   train_loss= 1.942   train_acc= 1.000   test_loss=1.962   test_acc= 1.000
epoch= 26   train_loss= 1.929   train_acc= 1.000   test_loss=1.952   test_acc= 1.000
epoch= 27   train_loss= 1.920   train_acc= 1.000   test_loss=1.92

epoch= 12   train_loss= 2.144   train_acc= 1.000   test_loss=2.273   test_acc= 1.000
epoch= 13   train_loss= 2.118   train_acc= 1.000   test_loss=2.265   test_acc= 0.889
epoch= 14   train_loss= 2.105   train_acc= 1.000   test_loss=2.249   test_acc= 0.889
epoch= 15   train_loss= 2.097   train_acc= 0.988   test_loss=2.211   test_acc= 1.000
epoch= 16   train_loss= 2.073   train_acc= 1.000   test_loss=2.210   test_acc= 1.000
epoch= 17   train_loss= 2.066   train_acc= 1.000   test_loss=2.198   test_acc= 0.889
epoch= 18   train_loss= 2.042   train_acc= 1.000   test_loss=2.167   test_acc= 0.889
epoch= 19   train_loss= 2.021   train_acc= 1.000   test_loss=2.153   test_acc= 0.889
epoch= 20   train_loss= 2.018   train_acc= 1.000   test_loss=2.129   test_acc= 1.000
epoch= 21   train_loss= 2.001   train_acc= 1.000   test_loss=2.106   test_acc= 1.000
epoch= 22   train_loss= 1.983   train_acc= 1.000   test_loss=2.094   test_acc= 1.000
epoch= 23   train_loss= 1.967   train_acc= 1.000   test_loss=2.08

epoch= 8   train_loss= 2.228   train_acc= 1.000   test_loss=2.613   test_acc= 0.778
epoch= 9   train_loss= 2.211   train_acc= 1.000   test_loss=2.586   test_acc= 0.778
epoch= 10   train_loss= 2.183   train_acc= 1.000   test_loss=2.526   test_acc= 0.778
epoch= 11   train_loss= 2.173   train_acc= 0.988   test_loss=2.518   test_acc= 0.778
epoch= 12   train_loss= 2.152   train_acc= 0.988   test_loss=2.496   test_acc= 0.778
epoch= 13   train_loss= 2.132   train_acc= 0.988   test_loss=2.523   test_acc= 0.778
epoch= 14   train_loss= 2.113   train_acc= 1.000   test_loss=2.439   test_acc= 0.778
epoch= 15   train_loss= 2.094   train_acc= 1.000   test_loss=2.420   test_acc= 0.778
epoch= 16   train_loss= 2.079   train_acc= 1.000   test_loss=2.397   test_acc= 0.778
epoch= 17   train_loss= 2.068   train_acc= 1.000   test_loss=2.421   test_acc= 0.778
epoch= 18   train_loss= 2.048   train_acc= 1.000   test_loss=2.431   test_acc= 0.778
epoch= 19   train_loss= 2.026   train_acc= 1.000   test_loss=2.395 

epoch= 4   train_loss= 2.384   train_acc= 0.952   test_loss=2.545   test_acc= 0.889
epoch= 5   train_loss= 2.337   train_acc= 0.952   test_loss=2.534   test_acc= 0.778
epoch= 6   train_loss= 2.308   train_acc= 0.964   test_loss=2.513   test_acc= 0.778
epoch= 7   train_loss= 2.265   train_acc= 0.976   test_loss=2.529   test_acc= 0.778
epoch= 8   train_loss= 2.235   train_acc= 1.000   test_loss=2.491   test_acc= 0.778
epoch= 9   train_loss= 2.218   train_acc= 0.988   test_loss=2.456   test_acc= 0.778
epoch= 10   train_loss= 2.200   train_acc= 0.988   test_loss=2.447   test_acc= 0.778
epoch= 11   train_loss= 2.177   train_acc= 1.000   test_loss=2.441   test_acc= 0.778
epoch= 12   train_loss= 2.139   train_acc= 1.000   test_loss=2.419   test_acc= 0.778
epoch= 13   train_loss= 2.120   train_acc= 1.000   test_loss=2.411   test_acc= 0.778
epoch= 14   train_loss= 2.110   train_acc= 1.000   test_loss=2.407   test_acc= 0.778
epoch= 15   train_loss= 2.096   train_acc= 1.000   test_loss=2.415   te

epoch= 0   train_loss= 2.891   train_acc= 0.639   test_loss=2.814   test_acc= 0.778
epoch= 1   train_loss= 2.579   train_acc= 0.880   test_loss=2.723   test_acc= 0.778
epoch= 2   train_loss= 2.501   train_acc= 0.904   test_loss=2.789   test_acc= 0.667
epoch= 3   train_loss= 2.429   train_acc= 0.940   test_loss=2.664   test_acc= 0.778
epoch= 4   train_loss= 2.367   train_acc= 0.976   test_loss=2.751   test_acc= 0.778
epoch= 5   train_loss= 2.304   train_acc= 0.988   test_loss=2.647   test_acc= 0.778
epoch= 6   train_loss= 2.271   train_acc= 0.988   test_loss=2.616   test_acc= 0.778
epoch= 7   train_loss= 2.261   train_acc= 1.000   test_loss=2.677   test_acc= 0.778
epoch= 8   train_loss= 2.219   train_acc= 1.000   test_loss=2.674   test_acc= 0.778
epoch= 9   train_loss= 2.197   train_acc= 1.000   test_loss=2.661   test_acc= 0.778
epoch= 10   train_loss= 2.183   train_acc= 0.988   test_loss=2.676   test_acc= 0.778
epoch= 11   train_loss= 2.148   train_acc= 1.000   test_loss=2.655   test_a

epoch= 46   train_loss= 1.693   train_acc= 1.000   test_loss=1.764   test_acc= 1.000
epoch= 47   train_loss= 1.682   train_acc= 1.000   test_loss=1.753   test_acc= 1.000
epoch= 48   train_loss= 1.671   train_acc= 1.000   test_loss=1.750   test_acc= 1.000
epoch= 49   train_loss= 1.662   train_acc= 1.000   test_loss=1.743   test_acc= 1.000
run time: 0.7546258330345154 min
test_acc=1.000
MI-Net mean accuracy =  0.89066666
std =  0.10099162


### MI-net pooling deep supervision

In [7]:
import numpy as np
import sys
import time
import random
from random import shuffle
import argparse

from keras.models import Model
from keras.optimizers import SGD
from keras.regularizers import l2
from keras.layers import Input, Dense, Layer, Dropout, average

from mil_nets.dataset import load_dataset
from mil_nets.layer import Feature_pooling
from mil_nets.metrics import bag_accuracy
from mil_nets.objectives import bag_loss
from mil_nets.utils import convertToBatch

In [9]:
def test_eval(model, test_set):
    """Evaluate on testing set.
    Parameters
    -----------------
    model : keras.engine.training.Model object
        The training MI-Net with deep supervision model.
    test_set : list
        A list of testing set contains all training bags features and labels.
    Returns
    -----------------
    test_loss : float
        Mean loss of evaluating on testing set.
    test_acc : float
        Mean accuracy of evaluating on testing set.
    """
    num_test_batch = len(test_set)
    test_loss = np.zeros((num_test_batch, 1), dtype=np.float32)
    test_acc = np.zeros((num_test_batch, 1), dtype=np.float32)
    for ibatch, batch in enumerate(test_set):
        result = model.test_on_batch({'input':batch[0].astype(np.float32)}, {'fp1':batch[1].astype(np.float32), 'fp2':batch[1].astype(np.float32), 'fp3':batch[1].astype(np.float32), 'ave':batch[1].astype(np.float32)})
        test_loss[ibatch] = result[0]
        test_acc[ibatch] = result[-1]
    return np.mean(test_loss), np.mean(test_acc)

def train_eval(model, train_set):
    """Evaluate on training set.
    Parameters
    -----------------
    model : keras.engine.training.Model object
        The training MI-Net with deep supervision model.
    train_set : list
        A list of training set contains all training bags features and labels.
    Returns
    -----------------
    test_loss : float
        Mean loss of evaluating on traing set.
    test_acc : float
        Mean accuracy of evaluating on testing set.
    """
    num_train_batch = len(train_set)
    train_loss = np.zeros((num_train_batch, 1), dtype=np.float32)
    train_acc = np.zeros((num_train_batch, 1), dtype=np.float32)
    shuffle(train_set)
    for ibatch, batch in enumerate(train_set):
        result = model.train_on_batch({'input':batch[0].astype(np.float32)}, {'fp1':batch[1].astype(np.float32), 'fp2':batch[1].astype(np.float32), 'fp3':batch[1].astype(np.float32), 'ave':batch[1].astype(np.float32)})
        train_loss[ibatch] = result[0]
        train_acc[ibatch] = result[-1]
    return np.mean(train_loss), np.mean(train_acc)

def MI_Net_with_DS(dataset):
    """Train and evaluate on MI-Net with deep supervision.
    Parameters
    -----------------
    dataset : dict
        A dictionary contains all dataset information. We split train/test by keys.
    Returns
    -----------------
    test_acc : float
        Testing accuracy of MI-Net with deep supervision.
    """
    weight_decay=0.005
    init_lr=5e-4
    pooling_mode='max'
    momentum=0.9
    max_epoch=50
    # load data and convert type
    train_bags = dataset['train']
    test_bags = dataset['test']

    # convert bag to batch
    train_set = convertToBatch(train_bags)
    test_set = convertToBatch(test_bags)
    dimension = train_set[0][0].shape[1]
    weight = [1.0, 1.0, 1.0, 0.0]

    # data: instance feature, n*d, n = number of training instance
    data_input = Input(shape=(dimension,), dtype='float32', name='input')

    # fully-connected
    fc1 = Dense(256, activation='relu', kernel_regularizer=l2(weight_decay))(data_input)
    fc2 = Dense(128, activation='relu', kernel_regularizer=l2(weight_decay))(fc1)
    fc3 = Dense(64, activation='relu', kernel_regularizer=l2(weight_decay))(fc2)

    # dropout
    dropout1 = Dropout(rate=0.5)(fc1)
    dropout2 = Dropout(rate=0.5)(fc2)
    dropout3 = Dropout(rate=0.5)(fc3)

    # features pooling
    fp1 = Feature_pooling(output_dim=1, kernel_regularizer=l2(weight_decay), pooling_mode=pooling_mode, name='fp1')(dropout1)
    fp2 = Feature_pooling(output_dim=1, kernel_regularizer=l2(weight_decay), pooling_mode=pooling_mode, name='fp2')(dropout2)
    fp3 = Feature_pooling(output_dim=1, kernel_regularizer=l2(weight_decay), pooling_mode=pooling_mode, name='fp3')(dropout3)

    # score average
    mg_ave =average([fp1,fp2,fp3], name='ave')

    model = Model(inputs=[data_input], outputs=[fp1, fp2, fp3, mg_ave])
    sgd = SGD(lr=init_lr, decay=1e-4, momentum=momentum, nesterov=True)
    model.compile(loss={'fp1':bag_loss, 'fp2':bag_loss, 'fp3':bag_loss, 'ave':bag_loss}, loss_weights={'fp1':weight[0], 'fp2':weight[1], 'fp3':weight[2], 'ave':weight[3]}, optimizer=sgd, metrics=[bag_accuracy])

    # train model
    t1 = time.time()
    num_batch = len(train_set)
    for epoch in range(max_epoch):
        train_loss, train_acc = train_eval(model, train_set)
        test_loss, test_acc = test_eval(model, test_set)
        print('epoch=', epoch, '  train_loss= {:.3f}'.format(train_loss), '  train_acc= {:.3f}'.format(train_acc), '  test_loss={:.3f}'.format(test_loss), '  test_acc= {:.3f}'.format(test_acc))
    t2 = time.time()
    print('run time:', (t2-t1) / 60, 'min')
    print('test_acc={:.3f}'.format(test_acc))

    return test_acc

In [10]:
# perform five times 10-fold cross=validation experiments
run = 5
n_folds = 10
acc = np.zeros((run, n_folds), dtype=float)
for irun in range(run):
    dataset = load_dataset('musk1', n_folds)
    for ifold in range(n_folds):
        print('run=', irun, '  fold=', ifold)
        acc[irun][ifold] = MI_Net_with_DS(dataset[ifold])
print('MI-Net with DS mean accuracy = ', np.mean(acc))
print('std = ', np.std(acc))

run= 0   fold= 0
epoch= 0   train_loss= 4.697   train_acc= 0.646   test_loss=3.608   test_acc= 0.700
epoch= 1   train_loss= 3.576   train_acc= 0.817   test_loss=3.062   test_acc= 0.900
epoch= 2   train_loss= 2.949   train_acc= 0.927   test_loss=3.044   test_acc= 0.900
epoch= 3   train_loss= 2.781   train_acc= 0.939   test_loss=2.858   test_acc= 0.900
epoch= 4   train_loss= 2.590   train_acc= 0.976   test_loss=2.990   test_acc= 0.900
epoch= 5   train_loss= 2.591   train_acc= 0.951   test_loss=2.754   test_acc= 1.000
epoch= 6   train_loss= 2.410   train_acc= 0.988   test_loss=2.781   test_acc= 0.900
epoch= 7   train_loss= 2.364   train_acc= 1.000   test_loss=2.655   test_acc= 1.000
epoch= 8   train_loss= 2.309   train_acc= 1.000   test_loss=2.594   test_acc= 1.000
epoch= 9   train_loss= 2.276   train_acc= 1.000   test_loss=2.758   test_acc= 0.900
epoch= 10   train_loss= 2.267   train_acc= 1.000   test_loss=2.562   test_acc= 1.000
epoch= 11   train_loss= 2.248   train_acc= 1.000   test_lo

epoch= 46   train_loss= 1.776   train_acc= 1.000   test_loss=2.678   test_acc= 0.800
epoch= 47   train_loss= 1.760   train_acc= 1.000   test_loss=2.742   test_acc= 0.800
epoch= 48   train_loss= 1.755   train_acc= 1.000   test_loss=2.752   test_acc= 0.800
epoch= 49   train_loss= 1.741   train_acc= 1.000   test_loss=2.751   test_acc= 0.800
run time: 1.3564526200294496 min
test_acc=0.800
run= 0   fold= 2
epoch= 0   train_loss= 4.495   train_acc= 0.699   test_loss=3.243   test_acc= 0.889
epoch= 1   train_loss= 3.333   train_acc= 0.867   test_loss=3.100   test_acc= 1.000
epoch= 2   train_loss= 2.935   train_acc= 0.952   test_loss=2.785   test_acc= 1.000
epoch= 3   train_loss= 2.819   train_acc= 0.964   test_loss=2.667   test_acc= 1.000
epoch= 4   train_loss= 2.698   train_acc= 0.964   test_loss=2.637   test_acc= 1.000
epoch= 5   train_loss= 2.569   train_acc= 0.976   test_loss=2.508   test_acc= 1.000
epoch= 6   train_loss= 2.531   train_acc= 0.988   test_loss=2.448   test_acc= 1.000
epoch= 

epoch= 42   train_loss= 1.805   train_acc= 1.000   test_loss=2.073   test_acc= 1.000
epoch= 43   train_loss= 1.794   train_acc= 1.000   test_loss=2.052   test_acc= 1.000
epoch= 44   train_loss= 1.781   train_acc= 1.000   test_loss=2.047   test_acc= 1.000
epoch= 45   train_loss= 1.776   train_acc= 1.000   test_loss=2.029   test_acc= 1.000
epoch= 46   train_loss= 1.760   train_acc= 1.000   test_loss=2.007   test_acc= 1.000
epoch= 47   train_loss= 1.754   train_acc= 1.000   test_loss=2.004   test_acc= 1.000
epoch= 48   train_loss= 1.743   train_acc= 1.000   test_loss=1.992   test_acc= 1.000
epoch= 49   train_loss= 1.732   train_acc= 1.000   test_loss=1.968   test_acc= 1.000
run time: 1.2091976523399353 min
test_acc=1.000
run= 0   fold= 4
epoch= 0   train_loss= 4.329   train_acc= 0.747   test_loss=4.795   test_acc= 0.333
epoch= 1   train_loss= 3.264   train_acc= 0.904   test_loss=4.177   test_acc= 0.667
epoch= 2   train_loss= 2.806   train_acc= 0.976   test_loss=4.303   test_acc= 0.667
epo

epoch= 38   train_loss= 1.842   train_acc= 1.000   test_loss=2.998   test_acc= 0.889
epoch= 39   train_loss= 1.833   train_acc= 1.000   test_loss=3.003   test_acc= 0.889
epoch= 40   train_loss= 1.828   train_acc= 1.000   test_loss=2.969   test_acc= 0.889
epoch= 41   train_loss= 1.818   train_acc= 1.000   test_loss=2.956   test_acc= 0.889
epoch= 42   train_loss= 1.806   train_acc= 1.000   test_loss=2.971   test_acc= 0.889
epoch= 43   train_loss= 1.791   train_acc= 1.000   test_loss=2.968   test_acc= 0.889
epoch= 44   train_loss= 1.780   train_acc= 1.000   test_loss=2.941   test_acc= 0.889
epoch= 45   train_loss= 1.771   train_acc= 1.000   test_loss=2.940   test_acc= 0.889
epoch= 46   train_loss= 1.762   train_acc= 1.000   test_loss=2.942   test_acc= 0.889
epoch= 47   train_loss= 1.750   train_acc= 1.000   test_loss=2.935   test_acc= 0.889
epoch= 48   train_loss= 1.739   train_acc= 1.000   test_loss=2.934   test_acc= 0.889
epoch= 49   train_loss= 1.739   train_acc= 1.000   test_loss=2.92

epoch= 34   train_loss= 1.880   train_acc= 1.000   test_loss=4.390   test_acc= 0.667
epoch= 35   train_loss= 1.866   train_acc= 1.000   test_loss=4.461   test_acc= 0.667
epoch= 36   train_loss= 1.855   train_acc= 1.000   test_loss=4.388   test_acc= 0.667
epoch= 37   train_loss= 1.841   train_acc= 1.000   test_loss=4.349   test_acc= 0.667
epoch= 38   train_loss= 1.827   train_acc= 1.000   test_loss=4.373   test_acc= 0.667
epoch= 39   train_loss= 1.817   train_acc= 1.000   test_loss=4.385   test_acc= 0.667
epoch= 40   train_loss= 1.804   train_acc= 1.000   test_loss=4.400   test_acc= 0.667
epoch= 41   train_loss= 1.799   train_acc= 1.000   test_loss=4.332   test_acc= 0.667
epoch= 42   train_loss= 1.788   train_acc= 1.000   test_loss=4.353   test_acc= 0.667
epoch= 43   train_loss= 1.776   train_acc= 1.000   test_loss=4.289   test_acc= 0.667
epoch= 44   train_loss= 1.762   train_acc= 1.000   test_loss=4.243   test_acc= 0.667
epoch= 45   train_loss= 1.753   train_acc= 1.000   test_loss=4.32

epoch= 30   train_loss= 1.942   train_acc= 1.000   test_loss=2.096   test_acc= 1.000
epoch= 31   train_loss= 1.923   train_acc= 1.000   test_loss=2.061   test_acc= 1.000
epoch= 32   train_loss= 1.914   train_acc= 1.000   test_loss=2.039   test_acc= 1.000
epoch= 33   train_loss= 1.899   train_acc= 1.000   test_loss=2.025   test_acc= 1.000
epoch= 34   train_loss= 1.915   train_acc= 1.000   test_loss=2.025   test_acc= 1.000
epoch= 35   train_loss= 1.886   train_acc= 1.000   test_loss=2.022   test_acc= 1.000
epoch= 36   train_loss= 1.867   train_acc= 1.000   test_loss=1.997   test_acc= 1.000
epoch= 37   train_loss= 1.853   train_acc= 1.000   test_loss=1.992   test_acc= 1.000
epoch= 38   train_loss= 1.846   train_acc= 1.000   test_loss=1.977   test_acc= 1.000
epoch= 39   train_loss= 1.829   train_acc= 1.000   test_loss=1.963   test_acc= 1.000
epoch= 40   train_loss= 1.819   train_acc= 1.000   test_loss=1.950   test_acc= 1.000
epoch= 41   train_loss= 1.827   train_acc= 1.000   test_loss=1.94

epoch= 26   train_loss= 2.003   train_acc= 1.000   test_loss=2.980   test_acc= 0.900
epoch= 27   train_loss= 1.988   train_acc= 1.000   test_loss=2.952   test_acc= 0.900
epoch= 28   train_loss= 1.981   train_acc= 1.000   test_loss=2.922   test_acc= 0.900
epoch= 29   train_loss= 1.966   train_acc= 1.000   test_loss=2.920   test_acc= 0.900
epoch= 30   train_loss= 1.953   train_acc= 1.000   test_loss=2.930   test_acc= 0.900
epoch= 31   train_loss= 1.932   train_acc= 1.000   test_loss=2.893   test_acc= 0.900
epoch= 32   train_loss= 1.921   train_acc= 1.000   test_loss=2.801   test_acc= 0.900
epoch= 33   train_loss= 1.918   train_acc= 1.000   test_loss=2.892   test_acc= 0.900
epoch= 34   train_loss= 1.910   train_acc= 1.000   test_loss=2.730   test_acc= 0.900
epoch= 35   train_loss= 1.883   train_acc= 1.000   test_loss=2.799   test_acc= 0.900
epoch= 36   train_loss= 1.869   train_acc= 1.000   test_loss=2.788   test_acc= 0.900
epoch= 37   train_loss= 1.871   train_acc= 1.000   test_loss=2.81

epoch= 22   train_loss= 2.037   train_acc= 1.000   test_loss=3.485   test_acc= 0.667
epoch= 23   train_loss= 2.024   train_acc= 1.000   test_loss=3.518   test_acc= 0.667
epoch= 24   train_loss= 2.014   train_acc= 1.000   test_loss=3.348   test_acc= 0.667
epoch= 25   train_loss= 2.002   train_acc= 1.000   test_loss=3.536   test_acc= 0.667
epoch= 26   train_loss= 1.985   train_acc= 1.000   test_loss=3.549   test_acc= 0.667
epoch= 27   train_loss= 1.972   train_acc= 1.000   test_loss=3.391   test_acc= 0.667
epoch= 28   train_loss= 1.958   train_acc= 1.000   test_loss=3.362   test_acc= 0.667
epoch= 29   train_loss= 1.938   train_acc= 1.000   test_loss=3.400   test_acc= 0.667
epoch= 30   train_loss= 1.932   train_acc= 1.000   test_loss=3.605   test_acc= 0.667
epoch= 31   train_loss= 1.927   train_acc= 1.000   test_loss=3.416   test_acc= 0.667
epoch= 32   train_loss= 1.908   train_acc= 1.000   test_loss=3.376   test_acc= 0.667
epoch= 33   train_loss= 1.896   train_acc= 1.000   test_loss=3.35

epoch= 18   train_loss= 2.107   train_acc= 1.000   test_loss=2.522   test_acc= 1.000
epoch= 19   train_loss= 2.101   train_acc= 1.000   test_loss=2.547   test_acc= 0.889
epoch= 20   train_loss= 2.079   train_acc= 1.000   test_loss=2.503   test_acc= 0.889
epoch= 21   train_loss= 2.057   train_acc= 1.000   test_loss=2.498   test_acc= 0.889
epoch= 22   train_loss= 2.043   train_acc= 1.000   test_loss=2.478   test_acc= 0.889
epoch= 23   train_loss= 2.038   train_acc= 1.000   test_loss=2.445   test_acc= 0.889
epoch= 24   train_loss= 2.019   train_acc= 1.000   test_loss=2.420   test_acc= 1.000
epoch= 25   train_loss= 1.999   train_acc= 1.000   test_loss=2.406   test_acc= 1.000
epoch= 26   train_loss= 1.993   train_acc= 1.000   test_loss=2.379   test_acc= 1.000
epoch= 27   train_loss= 1.978   train_acc= 1.000   test_loss=2.367   test_acc= 1.000
epoch= 28   train_loss= 1.965   train_acc= 1.000   test_loss=2.357   test_acc= 1.000
epoch= 29   train_loss= 1.961   train_acc= 1.000   test_loss=2.34

epoch= 14   train_loss= 2.164   train_acc= 1.000   test_loss=2.341   test_acc= 1.000
epoch= 15   train_loss= 2.200   train_acc= 1.000   test_loss=2.374   test_acc= 1.000
epoch= 16   train_loss= 2.149   train_acc= 1.000   test_loss=2.352   test_acc= 1.000
epoch= 17   train_loss= 2.136   train_acc= 1.000   test_loss=2.307   test_acc= 1.000
epoch= 18   train_loss= 2.109   train_acc= 1.000   test_loss=2.289   test_acc= 1.000
epoch= 19   train_loss= 2.082   train_acc= 1.000   test_loss=2.255   test_acc= 1.000
epoch= 20   train_loss= 2.075   train_acc= 1.000   test_loss=2.212   test_acc= 1.000
epoch= 21   train_loss= 2.051   train_acc= 1.000   test_loss=2.196   test_acc= 1.000
epoch= 22   train_loss= 2.044   train_acc= 1.000   test_loss=2.195   test_acc= 1.000
epoch= 23   train_loss= 2.051   train_acc= 1.000   test_loss=2.186   test_acc= 1.000
epoch= 24   train_loss= 2.014   train_acc= 1.000   test_loss=2.176   test_acc= 1.000
epoch= 25   train_loss= 2.012   train_acc= 1.000   test_loss=2.15

epoch= 10   train_loss= 2.301   train_acc= 1.000   test_loss=2.674   test_acc= 1.000
epoch= 11   train_loss= 2.280   train_acc= 1.000   test_loss=2.621   test_acc= 1.000
epoch= 12   train_loss= 2.259   train_acc= 1.000   test_loss=2.710   test_acc= 0.889
epoch= 13   train_loss= 2.215   train_acc= 1.000   test_loss=2.629   test_acc= 1.000
epoch= 14   train_loss= 2.210   train_acc= 1.000   test_loss=2.648   test_acc= 1.000
epoch= 15   train_loss= 2.192   train_acc= 1.000   test_loss=2.651   test_acc= 0.889
epoch= 16   train_loss= 2.182   train_acc= 1.000   test_loss=2.555   test_acc= 1.000
epoch= 17   train_loss= 2.138   train_acc= 1.000   test_loss=2.567   test_acc= 1.000
epoch= 18   train_loss= 2.137   train_acc= 1.000   test_loss=2.520   test_acc= 1.000
epoch= 19   train_loss= 2.120   train_acc= 1.000   test_loss=2.501   test_acc= 1.000
epoch= 20   train_loss= 2.108   train_acc= 1.000   test_loss=2.550   test_acc= 0.889
epoch= 21   train_loss= 2.087   train_acc= 1.000   test_loss=2.50

epoch= 6   train_loss= 2.433   train_acc= 0.988   test_loss=2.832   test_acc= 1.000
epoch= 7   train_loss= 2.368   train_acc= 1.000   test_loss=2.791   test_acc= 1.000
epoch= 8   train_loss= 2.341   train_acc= 1.000   test_loss=2.763   test_acc= 1.000
epoch= 9   train_loss= 2.289   train_acc= 1.000   test_loss=2.599   test_acc= 1.000
epoch= 10   train_loss= 2.311   train_acc= 1.000   test_loss=2.691   test_acc= 1.000
epoch= 11   train_loss= 2.266   train_acc= 1.000   test_loss=2.665   test_acc= 1.000
epoch= 12   train_loss= 2.226   train_acc= 1.000   test_loss=2.608   test_acc= 1.000
epoch= 13   train_loss= 2.221   train_acc= 1.000   test_loss=2.650   test_acc= 1.000
epoch= 14   train_loss= 2.180   train_acc= 1.000   test_loss=2.575   test_acc= 1.000
epoch= 15   train_loss= 2.155   train_acc= 1.000   test_loss=2.589   test_acc= 1.000
epoch= 16   train_loss= 2.138   train_acc= 1.000   test_loss=2.543   test_acc= 1.000
epoch= 17   train_loss= 2.133   train_acc= 1.000   test_loss=2.509   

epoch= 2   train_loss= 2.923   train_acc= 0.952   test_loss=3.297   test_acc= 0.778
epoch= 3   train_loss= 2.757   train_acc= 0.940   test_loss=3.509   test_acc= 0.778
epoch= 4   train_loss= 2.647   train_acc= 0.964   test_loss=3.357   test_acc= 0.778
epoch= 5   train_loss= 2.498   train_acc= 1.000   test_loss=3.110   test_acc= 0.889
epoch= 6   train_loss= 2.411   train_acc= 1.000   test_loss=3.432   test_acc= 0.778
epoch= 7   train_loss= 2.460   train_acc= 0.976   test_loss=3.492   test_acc= 0.778
epoch= 8   train_loss= 2.365   train_acc= 1.000   test_loss=3.074   test_acc= 0.778
epoch= 9   train_loss= 2.314   train_acc= 1.000   test_loss=3.003   test_acc= 0.889
epoch= 10   train_loss= 2.292   train_acc= 1.000   test_loss=3.329   test_acc= 0.778
epoch= 11   train_loss= 2.262   train_acc= 1.000   test_loss=3.100   test_acc= 0.778
epoch= 12   train_loss= 2.238   train_acc= 1.000   test_loss=3.268   test_acc= 0.778
epoch= 13   train_loss= 2.205   train_acc= 1.000   test_loss=3.207   test

epoch= 48   train_loss= 1.730   train_acc= 1.000   test_loss=2.056   test_acc= 1.000
epoch= 49   train_loss= 1.727   train_acc= 1.000   test_loss=2.095   test_acc= 0.889
run time: 1.216727582613627 min
test_acc=0.889
run= 2   fold= 5
epoch= 0   train_loss= 4.389   train_acc= 0.699   test_loss=3.468   test_acc= 0.889
epoch= 1   train_loss= 3.408   train_acc= 0.819   test_loss=3.008   test_acc= 1.000
epoch= 2   train_loss= 3.033   train_acc= 0.916   test_loss=3.256   test_acc= 0.889
epoch= 3   train_loss= 2.811   train_acc= 0.952   test_loss=2.911   test_acc= 1.000
epoch= 4   train_loss= 2.716   train_acc= 0.964   test_loss=2.956   test_acc= 1.000
epoch= 5   train_loss= 2.552   train_acc= 0.988   test_loss=3.257   test_acc= 0.778
epoch= 6   train_loss= 2.533   train_acc= 0.964   test_loss=2.867   test_acc= 0.889
epoch= 7   train_loss= 2.391   train_acc= 1.000   test_loss=3.179   test_acc= 0.778
epoch= 8   train_loss= 2.356   train_acc= 1.000   test_loss=3.295   test_acc= 0.778
epoch= 9  

epoch= 44   train_loss= 1.776   train_acc= 1.000   test_loss=1.981   test_acc= 1.000
epoch= 45   train_loss= 1.774   train_acc= 1.000   test_loss=1.975   test_acc= 1.000
epoch= 46   train_loss= 1.786   train_acc= 1.000   test_loss=2.004   test_acc= 1.000
epoch= 47   train_loss= 1.760   train_acc= 1.000   test_loss=1.949   test_acc= 1.000
epoch= 48   train_loss= 1.738   train_acc= 1.000   test_loss=1.920   test_acc= 1.000
epoch= 49   train_loss= 1.728   train_acc= 1.000   test_loss=1.922   test_acc= 1.000
run time: 1.2274794141451517 min
test_acc=1.000
run= 2   fold= 7
epoch= 0   train_loss= 4.606   train_acc= 0.627   test_loss=3.453   test_acc= 1.000
epoch= 1   train_loss= 3.459   train_acc= 0.831   test_loss=3.101   test_acc= 0.889
epoch= 2   train_loss= 3.016   train_acc= 0.952   test_loss=3.034   test_acc= 1.000
epoch= 3   train_loss= 2.802   train_acc= 0.976   test_loss=2.881   test_acc= 1.000
epoch= 4   train_loss= 2.633   train_acc= 0.988   test_loss=2.849   test_acc= 1.000
epoch

epoch= 40   train_loss= 1.814   train_acc= 1.000   test_loss=3.258   test_acc= 0.889
epoch= 41   train_loss= 1.791   train_acc= 1.000   test_loss=3.288   test_acc= 0.778
epoch= 42   train_loss= 1.789   train_acc= 1.000   test_loss=3.282   test_acc= 0.778
epoch= 43   train_loss= 1.776   train_acc= 1.000   test_loss=3.245   test_acc= 0.889
epoch= 44   train_loss= 1.763   train_acc= 1.000   test_loss=3.276   test_acc= 0.778
epoch= 45   train_loss= 1.752   train_acc= 1.000   test_loss=3.225   test_acc= 0.889
epoch= 46   train_loss= 1.739   train_acc= 1.000   test_loss=3.258   test_acc= 0.778
epoch= 47   train_loss= 1.737   train_acc= 1.000   test_loss=3.163   test_acc= 0.889
epoch= 48   train_loss= 1.727   train_acc= 1.000   test_loss=3.244   test_acc= 0.778
epoch= 49   train_loss= 1.715   train_acc= 1.000   test_loss=3.223   test_acc= 0.778
run time: 1.2224454522132873 min
test_acc=0.778
run= 2   fold= 9
epoch= 0   train_loss= 4.240   train_acc= 0.614   test_loss=4.643   test_acc= 0.444
e

epoch= 36   train_loss= 1.863   train_acc= 1.000   test_loss=2.955   test_acc= 0.800
epoch= 37   train_loss= 1.845   train_acc= 1.000   test_loss=2.884   test_acc= 0.800
epoch= 38   train_loss= 1.839   train_acc= 1.000   test_loss=2.879   test_acc= 0.800
epoch= 39   train_loss= 1.820   train_acc= 1.000   test_loss=2.888   test_acc= 0.800
epoch= 40   train_loss= 1.818   train_acc= 1.000   test_loss=3.008   test_acc= 0.800
epoch= 41   train_loss= 1.807   train_acc= 1.000   test_loss=2.859   test_acc= 0.800
epoch= 42   train_loss= 1.788   train_acc= 1.000   test_loss=2.867   test_acc= 0.800
epoch= 43   train_loss= 1.780   train_acc= 1.000   test_loss=2.811   test_acc= 0.800
epoch= 44   train_loss= 1.772   train_acc= 1.000   test_loss=2.871   test_acc= 0.800
epoch= 45   train_loss= 1.754   train_acc= 1.000   test_loss=2.833   test_acc= 0.800
epoch= 46   train_loss= 1.750   train_acc= 1.000   test_loss=2.795   test_acc= 0.800
epoch= 47   train_loss= 1.737   train_acc= 1.000   test_loss=2.82

epoch= 32   train_loss= 1.917   train_acc= 1.000   test_loss=2.916   test_acc= 0.889
epoch= 33   train_loss= 1.917   train_acc= 1.000   test_loss=2.939   test_acc= 0.889
epoch= 34   train_loss= 1.899   train_acc= 1.000   test_loss=2.961   test_acc= 0.889
epoch= 35   train_loss= 1.889   train_acc= 1.000   test_loss=2.959   test_acc= 0.889
epoch= 36   train_loss= 1.876   train_acc= 1.000   test_loss=2.961   test_acc= 0.889
epoch= 37   train_loss= 1.869   train_acc= 1.000   test_loss=2.956   test_acc= 0.889
epoch= 38   train_loss= 1.854   train_acc= 1.000   test_loss=2.922   test_acc= 0.889
epoch= 39   train_loss= 1.843   train_acc= 1.000   test_loss=2.896   test_acc= 0.889
epoch= 40   train_loss= 1.830   train_acc= 1.000   test_loss=2.924   test_acc= 0.889
epoch= 41   train_loss= 1.825   train_acc= 1.000   test_loss=2.922   test_acc= 0.889
epoch= 42   train_loss= 1.811   train_acc= 1.000   test_loss=2.797   test_acc= 0.889
epoch= 43   train_loss= 1.797   train_acc= 1.000   test_loss=2.81

epoch= 28   train_loss= 1.970   train_acc= 1.000   test_loss=3.796   test_acc= 0.778
epoch= 29   train_loss= 1.962   train_acc= 1.000   test_loss=3.741   test_acc= 0.778
epoch= 30   train_loss= 1.953   train_acc= 1.000   test_loss=3.741   test_acc= 0.778
epoch= 31   train_loss= 1.942   train_acc= 1.000   test_loss=3.811   test_acc= 0.778
epoch= 32   train_loss= 1.923   train_acc= 1.000   test_loss=3.721   test_acc= 0.778
epoch= 33   train_loss= 1.919   train_acc= 1.000   test_loss=4.042   test_acc= 0.778
epoch= 34   train_loss= 1.901   train_acc= 1.000   test_loss=3.721   test_acc= 0.778
epoch= 35   train_loss= 1.884   train_acc= 1.000   test_loss=3.660   test_acc= 0.778
epoch= 36   train_loss= 1.875   train_acc= 1.000   test_loss=3.595   test_acc= 0.778
epoch= 37   train_loss= 1.861   train_acc= 1.000   test_loss=3.616   test_acc= 0.778
epoch= 38   train_loss= 1.854   train_acc= 1.000   test_loss=3.675   test_acc= 0.778
epoch= 39   train_loss= 1.846   train_acc= 1.000   test_loss=3.63

epoch= 24   train_loss= 2.030   train_acc= 1.000   test_loss=2.245   test_acc= 1.000
epoch= 25   train_loss= 2.002   train_acc= 1.000   test_loss=2.237   test_acc= 1.000
epoch= 26   train_loss= 1.987   train_acc= 1.000   test_loss=2.219   test_acc= 1.000
epoch= 27   train_loss= 1.971   train_acc= 1.000   test_loss=2.201   test_acc= 1.000
epoch= 28   train_loss= 1.962   train_acc= 1.000   test_loss=2.186   test_acc= 1.000
epoch= 29   train_loss= 1.948   train_acc= 1.000   test_loss=2.171   test_acc= 1.000
epoch= 30   train_loss= 1.939   train_acc= 1.000   test_loss=2.156   test_acc= 1.000
epoch= 31   train_loss= 1.933   train_acc= 1.000   test_loss=2.145   test_acc= 1.000
epoch= 32   train_loss= 1.910   train_acc= 1.000   test_loss=2.136   test_acc= 1.000
epoch= 33   train_loss= 1.898   train_acc= 1.000   test_loss=2.128   test_acc= 1.000
epoch= 34   train_loss= 1.880   train_acc= 1.000   test_loss=2.116   test_acc= 1.000
epoch= 35   train_loss= 1.875   train_acc= 1.000   test_loss=2.10

epoch= 20   train_loss= 2.096   train_acc= 1.000   test_loss=2.213   test_acc= 1.000
epoch= 21   train_loss= 2.081   train_acc= 1.000   test_loss=2.222   test_acc= 1.000
epoch= 22   train_loss= 2.062   train_acc= 1.000   test_loss=2.182   test_acc= 1.000
epoch= 23   train_loss= 2.041   train_acc= 1.000   test_loss=2.154   test_acc= 1.000
epoch= 24   train_loss= 2.019   train_acc= 1.000   test_loss=2.138   test_acc= 1.000
epoch= 25   train_loss= 2.018   train_acc= 1.000   test_loss=2.139   test_acc= 1.000
epoch= 26   train_loss= 2.004   train_acc= 1.000   test_loss=2.125   test_acc= 1.000
epoch= 27   train_loss= 1.987   train_acc= 1.000   test_loss=2.097   test_acc= 1.000
epoch= 28   train_loss= 1.973   train_acc= 1.000   test_loss=2.081   test_acc= 1.000
epoch= 29   train_loss= 1.959   train_acc= 1.000   test_loss=2.062   test_acc= 1.000
epoch= 30   train_loss= 1.943   train_acc= 1.000   test_loss=2.039   test_acc= 1.000
epoch= 31   train_loss= 1.932   train_acc= 1.000   test_loss=2.03

epoch= 16   train_loss= 2.148   train_acc= 1.000   test_loss=3.088   test_acc= 0.800
epoch= 17   train_loss= 2.133   train_acc= 1.000   test_loss=3.107   test_acc= 0.800
epoch= 18   train_loss= 2.107   train_acc= 1.000   test_loss=3.066   test_acc= 0.800
epoch= 19   train_loss= 2.100   train_acc= 1.000   test_loss=3.060   test_acc= 0.800
epoch= 20   train_loss= 2.076   train_acc= 1.000   test_loss=3.014   test_acc= 0.800
epoch= 21   train_loss= 2.063   train_acc= 1.000   test_loss=2.976   test_acc= 0.800
epoch= 22   train_loss= 2.049   train_acc= 1.000   test_loss=3.051   test_acc= 0.800
epoch= 23   train_loss= 2.033   train_acc= 1.000   test_loss=2.955   test_acc= 0.800
epoch= 24   train_loss= 2.040   train_acc= 1.000   test_loss=2.819   test_acc= 0.800
epoch= 25   train_loss= 2.006   train_acc= 1.000   test_loss=2.922   test_acc= 0.800
epoch= 26   train_loss= 1.992   train_acc= 1.000   test_loss=2.958   test_acc= 0.800
epoch= 27   train_loss= 1.979   train_acc= 1.000   test_loss=2.98

epoch= 12   train_loss= 2.222   train_acc= 1.000   test_loss=2.343   test_acc= 1.000
epoch= 13   train_loss= 2.191   train_acc= 1.000   test_loss=2.321   test_acc= 1.000
epoch= 14   train_loss= 2.154   train_acc= 1.000   test_loss=2.313   test_acc= 1.000
epoch= 15   train_loss= 2.149   train_acc= 1.000   test_loss=2.301   test_acc= 1.000
epoch= 16   train_loss= 2.128   train_acc= 1.000   test_loss=2.273   test_acc= 1.000
epoch= 17   train_loss= 2.125   train_acc= 1.000   test_loss=2.231   test_acc= 1.000
epoch= 18   train_loss= 2.081   train_acc= 1.000   test_loss=2.216   test_acc= 1.000
epoch= 19   train_loss= 2.083   train_acc= 1.000   test_loss=2.202   test_acc= 1.000
epoch= 20   train_loss= 2.067   train_acc= 1.000   test_loss=2.184   test_acc= 1.000
epoch= 21   train_loss= 2.045   train_acc= 1.000   test_loss=2.167   test_acc= 1.000
epoch= 22   train_loss= 2.036   train_acc= 1.000   test_loss=2.149   test_acc= 1.000
epoch= 23   train_loss= 2.036   train_acc= 1.000   test_loss=2.14

epoch= 8   train_loss= 2.323   train_acc= 1.000   test_loss=3.088   test_acc= 0.889
epoch= 9   train_loss= 2.328   train_acc= 1.000   test_loss=3.091   test_acc= 0.778
epoch= 10   train_loss= 2.353   train_acc= 0.988   test_loss=3.253   test_acc= 0.778
epoch= 11   train_loss= 2.263   train_acc= 0.988   test_loss=3.033   test_acc= 0.778
epoch= 12   train_loss= 2.225   train_acc= 1.000   test_loss=3.114   test_acc= 0.778
epoch= 13   train_loss= 2.229   train_acc= 1.000   test_loss=2.975   test_acc= 0.889
epoch= 14   train_loss= 2.191   train_acc= 1.000   test_loss=3.202   test_acc= 0.778
epoch= 15   train_loss= 2.176   train_acc= 1.000   test_loss=2.926   test_acc= 0.778
epoch= 16   train_loss= 2.149   train_acc= 1.000   test_loss=2.945   test_acc= 0.778
epoch= 17   train_loss= 2.124   train_acc= 1.000   test_loss=2.925   test_acc= 0.778
epoch= 18   train_loss= 2.116   train_acc= 1.000   test_loss=2.870   test_acc= 0.889
epoch= 19   train_loss= 2.106   train_acc= 1.000   test_loss=3.021 

epoch= 4   train_loss= 2.838   train_acc= 0.952   test_loss=2.817   test_acc= 1.000
epoch= 5   train_loss= 2.610   train_acc= 0.976   test_loss=2.784   test_acc= 0.889
epoch= 6   train_loss= 2.515   train_acc= 0.988   test_loss=2.673   test_acc= 1.000
epoch= 7   train_loss= 2.433   train_acc= 0.988   test_loss=2.621   test_acc= 1.000
epoch= 8   train_loss= 2.388   train_acc= 0.988   test_loss=2.589   test_acc= 1.000
epoch= 9   train_loss= 2.388   train_acc= 0.988   test_loss=2.543   test_acc= 1.000
epoch= 10   train_loss= 2.311   train_acc= 1.000   test_loss=2.457   test_acc= 1.000
epoch= 11   train_loss= 2.315   train_acc= 0.988   test_loss=2.456   test_acc= 1.000
epoch= 12   train_loss= 2.260   train_acc= 1.000   test_loss=2.410   test_acc= 1.000
epoch= 13   train_loss= 2.253   train_acc= 1.000   test_loss=2.398   test_acc= 1.000
epoch= 14   train_loss= 2.212   train_acc= 1.000   test_loss=2.350   test_acc= 1.000
epoch= 15   train_loss= 2.198   train_acc= 1.000   test_loss=2.339   te

epoch= 0   train_loss= 4.349   train_acc= 0.675   test_loss=4.071   test_acc= 0.778
epoch= 1   train_loss= 3.517   train_acc= 0.855   test_loss=3.668   test_acc= 0.778
epoch= 2   train_loss= 2.958   train_acc= 0.952   test_loss=3.866   test_acc= 0.778
epoch= 3   train_loss= 2.739   train_acc= 0.952   test_loss=3.498   test_acc= 0.778
epoch= 4   train_loss= 2.633   train_acc= 0.976   test_loss=3.398   test_acc= 0.778
epoch= 5   train_loss= 2.480   train_acc= 0.988   test_loss=3.388   test_acc= 0.778
epoch= 6   train_loss= 2.428   train_acc= 0.988   test_loss=3.458   test_acc= 0.889
epoch= 7   train_loss= 2.378   train_acc= 1.000   test_loss=3.345   test_acc= 0.778
epoch= 8   train_loss= 2.311   train_acc= 1.000   test_loss=3.404   test_acc= 0.778
epoch= 9   train_loss= 2.335   train_acc= 1.000   test_loss=3.350   test_acc= 0.778
epoch= 10   train_loss= 2.269   train_acc= 1.000   test_loss=3.317   test_acc= 0.778
epoch= 11   train_loss= 2.243   train_acc= 1.000   test_loss=3.404   test_a

epoch= 46   train_loss= 1.758   train_acc= 1.000   test_loss=2.699   test_acc= 0.778
epoch= 47   train_loss= 1.747   train_acc= 1.000   test_loss=2.683   test_acc= 0.778
epoch= 48   train_loss= 1.737   train_acc= 1.000   test_loss=2.665   test_acc= 0.778
epoch= 49   train_loss= 1.737   train_acc= 1.000   test_loss=2.669   test_acc= 0.778
run time: 1.343178351720174 min
test_acc=0.778
MI-Net with DS mean accuracy =  0.8822222262620926
std =  0.12047231613055681


## Bag-Space

*CkNN*

https://github.com/chlorochrule/cknn/blob/master/cknn/cknn.py

In [None]:
import pandas as pd
X = pd.read_table("./clean2.data") #pd.read_csv("sample_data/mnist_test.csv") 

In [None]:
from cknn import cknneighbors_graph

#ckng = cknneighbors_graph(X, n_neighbors=5, delta=1.0)

In [None]:
import numpy as np
from sklearn.datasets import load_digits
from sklearn.manifold import SpectralEmbedding
import matplotlib.pyplot as plt
from matplotlib import offsetbox
import seaborn as sns

from cknn import cknneighbors_graph

sns.set()


def plot2d_label(X, title=None):
    digits = load_digits()
    y = digits.target
    x_min, x_max = np.min(X, 0), np.max(X, 0)
    X = (X - x_min) / (x_max - x_min)

    plt.figure()
    ax = plt.subplot(111)
    for i in range(X.shape[0]):
        plt.text(X[i, 0], X[i, 1], str(digits.target[i]),
                 color=plt.cm.Set1(y[i] / 10.),
                 fontdict={'weight': 'bold', 'size': 9})

    
    plt.xticks([])
    plt.yticks([])
    if title is not None:
        plt.title(title)
    plt.show()


def main():
    data = X
    print(data)
    n_neighbors = 2

    model_normal = SpectralEmbedding(n_components=2, n_neighbors=n_neighbors)
    y_normal = model_normal.fit_transform(data)
    plot2d_label(y_normal)

    #ckng = cknneighbors_graph(data, n_neighbors=n_neighbors, delta=1.5)
    #model_cknn = SpectralEmbedding(n_components=2, affinity='precomputed')
    #y_cknn = model_cknn.fit_transform(ckng.toarray())
    #plot2d_label(y_cknn)

main()

     MUSK-211,211_1+1,46,-108,-60,-69,-117,49,38,-161,-8,5,-323,-220,-113,-299,-283,-307,-31,-106,-227,-42,-59,-22,-67,189,81,17,-27,-89,-67,105,-116,124,-106,5,-120,63,-165,40,-27,68,-44,98,-33,-314,-282,-335,-144,-13,-197,-2,-144,-13,-11,-131,108,-43,42,-151,-4,8,-102,51,-15,108,-135,59,-166,20,-20,23,-48,-68,-299,-256,-97,-183,-24,-271,-229,-177,-6,0,-129,112,15,36,-66,-54,-75,132,-188,119,-120,-312,23,-55,-53,-26,-71,41,-55,148,-247,-306,-308,-230,-166,-35,-205,-280,-239,-53,-10,-23,25,-5,163,61,59,-39,92,72,113,-107,80,25,-27,81,-114,-187,45,-118,-75,-182,-234,-19,12,-13,-41,-119,-149,70,17,-20,-177,-101,-116,-14,-50,24,-81,-125,-114,-44,128,3,-244,-308,52,-7,39,126,156,-50,-112,96,1.
0     MUSK-211,211_1+10,41,-188,-145,22,-117,-6,57,-...                                                                                                                                                                                                                                                      

ValueError: ignored

In [1]:
import numpy as np
from sklearn.datasets import load_digits
from sklearn.manifold import SpectralEmbedding
import matplotlib.pyplot as plt
from matplotlib import offsetbox
import seaborn as sns

sns.set()


def plot2d_label(X, title=None):
    y = X[1]
    x_min, x_max = np.min(X[0], 0), np.max(X[0], 0)
    X = (X - x_min) / (x_max - x_min)

    plt.figure()
    plt.xticks([])
    plt.yticks([])
    if title is not None:
        plt.title(title)
    plt.show()


def main():
    data = X
    n_neighbors = 10

    model_normal = SpectralEmbedding(n_components=2, n_neighbors=n_neighbors)
    y_normal = model_normal.fit_transform(data)
    #plot2d_label(y_normal)

    ckng = cknneighbors_graph(data, n_neighbors=n_neighbors, delta=1.5)
    model_cknn = SpectralEmbedding(n_components=2, affinity='precomputed')
    y_cknn = model_cknn.fit_transform(ckng.toarray())
    #plot2d_label(y_cknn)
    print(y_cknn)

main()

NameError: name 'X' is not defined

## instance-Space


MI-SVM and mi-SVM

In [6]:
import misvm
from loader import parse_c45, bag_set
from __future__ import print_function, division
import numpy as np

In [7]:
# Load list of C4.5 Examples
example_set = parse_c45('fox')

# Get stats to normalize data
raw_data = np.array(example_set.to_float())
data_mean = np.average(raw_data, axis=0)
data_std  = np.std(raw_data, axis=0)
data_std[np.nonzero(data_std == 0.0)] = 1.0
def normalizer(ex):
    ex = np.array(ex)
    normed = ((ex - data_mean) / data_std)
    # The ...[:, 2:-1] removes first two columns and last column,
    # which are the bag/instance ids and class label, as part of the
    # normalization process
    return normed[2:-1]


# Group examples into bags
bagset = bag_set(example_set)

# Convert bags to NumPy arrays
bags = [np.array(b.to_float(normalizer)) for b in bagset]
labels = np.array([b.label for b in bagset], dtype=float)
# Convert 0/1 labels to -1/1 labels
labels = 2 * labels - 1

# Spilt dataset arbitrarily to train/test sets
train_bags = bags[10:]
train_labels = labels[10:]
test_bags = bags[:10]
test_labels = labels[:10]

# Construct classifiers
classifiers = {}

# MISVM   : the MI-SVM algorithm of Andrews, Tsochantaridis, & Hofmann (2002)
# miSVM   : the mi-SVM algorithm of Andrews, Tsochantaridis, & Hofmann (2002)

#  : the semi-supervised learning approach of Zhou & Xu (2007)
#     : the MI classification algorithm of Mangasarian & Wild (2008)
# sMIL    : sparse MIL (Bunescu & Mooney, 2007)
# stMIL   : sparse, transductive  MIL (Bunescu & Mooney, 2007)

classifiers['MissSVM'] = misvm.MissSVM(kernel='linear', C=1.0, max_iters=20)
classifiers['sbMIL'] = misvm.sbMIL(kernel='linear', eta=0.1, C=1e2)
classifiers['SIL'] = misvm.SIL(kernel='linear', C=1.0)
classifiers['STK'] = misvm.STK(kernel='linear', C=1.0)
classifiers['NSK'] = misvm.NSK(kernel='linear', C=1.0)
classifiers['MICA'] = misvm.MICA(kernel='linear', C=1.0)

# Train/Evaluate classifiers
accuracies = {}
for algorithm, classifier in classifiers.items():
    classifier.fit(train_bags, train_labels)
    predictions = classifier.predict(test_bags)
    accuracies[algorithm] = np.average(test_labels == np.sign(predictions))

for algorithm, accuracy in accuracies.items():
    print('\n%s Accuracy: %.f%%' % (algorithm, 100 * accuracy))

Non-random start...
     pcost       dcost       gap    pres   dres
 0: -7.9927e+02 -2.5551e+01  1e+04  1e+02  2e-12
 1: -2.6254e+01 -2.5378e+01  3e+02  3e+00  2e-12
 2: -8.9020e+00 -2.1580e+01  4e+01  2e-01  2e-13
 3: -7.6256e+00 -1.5791e+01  1e+01  7e-02  7e-14
 4: -7.6313e+00 -1.0442e+01  3e+00  1e-02  4e-14
 5: -8.1403e+00 -8.7778e+00  7e-01  2e-03  4e-14
 6: -8.2634e+00 -8.5681e+00  3e-01  7e-04  4e-14
 7: -8.3370e+00 -8.4497e+00  1e-01  2e-04  4e-14
 8: -8.3702e+00 -8.4012e+00  3e-02  1e-05  5e-14
 9: -8.3799e+00 -8.3891e+00  9e-03  2e-06  4e-14
10: -8.3838e+00 -8.3846e+00  8e-04  1e-07  4e-14
11: -8.3842e+00 -8.3842e+00  4e-05  6e-09  5e-14
12: -8.3842e+00 -8.3842e+00  6e-07  9e-11  5e-14
Optimal solution found.

Iteration 1...
Linearizing constraints...
Computing slacks...
Linearizing...
Solving QP...
     pcost       dcost       gap    pres   dres
 0: -7.9958e+02 -2.5496e+01  1e+04  1e+02  2e-12
 1: -2.3437e+01 -2.5311e+01  3e+02  3e+00  3e-12
 2: -8.7003e+00 -2.0664e+01  4e+0

 3: -6.7004e+00 -1.3980e+01  1e+01  5e-02  6e-14
 4: -6.8085e+00 -8.7354e+00  3e+00  1e-02  5e-14
 5: -6.9048e+00 -7.5985e+00  9e-01  3e-03  4e-14
 6: -6.9363e+00 -7.4590e+00  7e-01  2e-03  4e-14
 7: -6.9373e+00 -7.4190e+00  6e-01  1e-03  4e-14
 8: -6.9803e+00 -7.2138e+00  3e-01  5e-04  4e-14
 9: -6.9985e+00 -7.1278e+00  2e-01  2e-04  4e-14
10: -7.0103e+00 -7.0791e+00  9e-02  1e-04  4e-14
11: -7.0167e+00 -7.0568e+00  5e-02  3e-05  5e-14
12: -7.0205e+00 -7.0450e+00  3e-02  2e-05  4e-14
13: -7.0229e+00 -7.0380e+00  2e-02  1e-05  4e-14
14: -7.0235e+00 -7.0364e+00  2e-02  6e-06  4e-14
15: -7.0248e+00 -7.0333e+00  1e-02  4e-06  4e-14
16: -7.0263e+00 -7.0295e+00  4e-03  3e-07  5e-14
17: -7.0269e+00 -7.0285e+00  2e-03  1e-07  5e-14
18: -7.0273e+00 -7.0279e+00  6e-04  3e-08  5e-14
19: -7.0275e+00 -7.0276e+00  2e-04  6e-09  5e-14
20: -7.0275e+00 -7.0276e+00  3e-05  8e-10  5e-14
21: -7.0275e+00 -7.0276e+00  4e-06  1e-10  5e-14
Optimal solution found.
delta obj ratio: 3.37e+04

Iteration 10...
Li

14: -6.9411e+00 -6.9415e+00  4e-04  1e-08  5e-14
15: -6.9412e+00 -6.9413e+00  7e-05  2e-09  5e-14
16: -6.9413e+00 -6.9413e+00  1e-05  2e-10  5e-14
17: -6.9413e+00 -6.9413e+00  1e-06  1e-11  6e-14
Optimal solution found.
delta obj ratio: 1.25e+04

Iteration 17...
Linearizing constraints...
Computing slacks...
Linearizing...
Solving QP...
     pcost       dcost       gap    pres   dres
 0: -7.9858e+02 -2.5200e+01  1e+04  1e+02  3e-12
 1: -2.6910e+01 -2.5010e+01  4e+02  3e+00  2e-12
 2: -8.4215e+00 -2.0997e+01  5e+01  3e-01  2e-13
 3: -6.6306e+00 -1.3368e+01  1e+01  6e-02  6e-14
 4: -6.4941e+00 -8.8395e+00  3e+00  9e-03  5e-14
 5: -6.6424e+00 -8.1300e+00  2e+00  5e-03  4e-14
 6: -6.7453e+00 -7.6673e+00  1e+00  3e-03  4e-14
 7: -6.8172e+00 -7.2604e+00  5e-01  3e-04  5e-14
 8: -6.8930e+00 -7.0219e+00  1e-01  8e-05  5e-14
 9: -6.9090e+00 -6.9860e+00  9e-02  4e-05  4e-14
10: -6.9155e+00 -6.9714e+00  6e-02  3e-05  4e-14
11: -6.9173e+00 -6.9687e+00  6e-02  2e-05  4e-14
12: -6.9258e+00 -6.9515e+



Training initial sMIL classifier for sbMIL...
Setup QP...
Solving QP...
     pcost       dcost       gap    pres   dres
 0: -3.8213e+00 -3.6872e+02  5e+03  5e+00  3e-13
 1: -1.8471e+00 -2.4416e+02  7e+02  5e-01  2e-13
 2: -4.4218e-01 -9.4044e+01  1e+02  7e-02  1e-13
 3:  3.0227e-02 -1.1960e+01  2e+01  7e-03  3e-14
 4:  3.9888e-02 -2.2258e+00  3e+00  1e-03  5e-15
 5:  2.5598e-02 -6.7670e-02  1e-01  6e-06  8e-15
 6:  1.2254e-03 -8.6537e-03  1e-02  2e-09  2e-15
 7: -1.1701e-03 -4.1400e-03  3e-03  4e-10  8e-16
 8: -1.3407e-03 -4.3097e-03  3e-03  4e-10  8e-16
 9: -2.0282e-03 -3.7965e-03  2e-03  2e-10  4e-16
10: -2.3677e-03 -3.6344e-03  1e-03  6e-11  4e-16
11: -2.6621e-03 -3.4649e-03  8e-04  2e-16  4e-16
12: -2.8946e-03 -3.0201e-03  1e-04  2e-16  5e-16
13: -2.9351e-03 -2.9627e-03  3e-05  2e-16  4e-16
14: -2.9480e-03 -2.9483e-03  3e-07  2e-16  4e-16
15: -2.9482e-03 -2.9482e-03  3e-09  2e-16  5e-16
Optimal solution found.
Computing initial instance labels for sbMIL...
Retraining with top 10% a

86: -8.5908e-01 -8.7072e-01  1e-02  2e-15  4e-02
87: -8.6238e-01 -8.6552e-01  3e-03  6e-16  8e-03
88: -8.6329e-01 -8.6403e-01  7e-04  2e-16  4e-04
89: -8.6354e-01 -8.6367e-01  1e-04  1e-16  5e-05
90: -8.6359e-01 -8.6360e-01  1e-05  6e-17  1e-07
91: -8.6360e-01 -8.6360e-01  5e-07  9e-17  4e-09
Optimal solution found.
Update LP...
Solve LP...
delta obj ratio: 7.55e+06

Iteration 3...
Update QP...
Solve QP...
     pcost       dcost       gap    pres   dres
 0: -7.6213e-01 -7.6213e-01  5e-07  9e-17  7e-01
 1: -7.6219e-01 -7.6219e-01  9e-07  2e-16  7e-01
 2: -7.6216e-01 -7.6216e-01  1e-06  2e-16  7e-01
 3: -7.6219e-01 -7.6219e-01  2e-06  2e-16  7e-01
 4: -7.6226e-01 -7.6226e-01  2e-06  6e-17  7e-01
 5: -7.6224e-01 -7.6224e-01  2e-06  1e-16  7e-01
 6: -7.6245e-01 -7.6245e-01  4e-06  3e-16  7e-01
 7: -7.6279e-01 -7.6280e-01  8e-06  8e-17  7e-01
 8: -7.6291e-01 -7.6292e-01  1e-05  3e-16  7e-01
 9: -7.6326e-01 -7.6328e-01  2e-05  2e-16  6e-01
10: -7.6387e-01 -7.6391e-01  3e-05  8e-17  6e-01
11:

13: -8.2334e-01 -8.2347e-01  1e-04  5e-16  9e-02
14: -8.2336e-01 -8.2349e-01  1e-04  5e-16  9e-02
15: -8.2335e-01 -8.2348e-01  1e-04  5e-16  9e-02
16: -8.2380e-01 -8.2399e-01  2e-04  4e-16  8e-02
17: -8.2376e-01 -8.2395e-01  2e-04  4e-16  8e-02
18: -8.2391e-01 -8.2414e-01  2e-04  4e-16  7e-02
19: -8.2401e-01 -8.2426e-01  2e-04  4e-16  7e-02
20: -8.2439e-01 -8.2469e-01  3e-04  4e-16  5e-02
21: -8.2446e-01 -8.2481e-01  4e-04  3e-16  5e-02
22: -8.2478e-01 -8.2513e-01  4e-04  2e-16  3e-02
23: -8.2477e-01 -8.2513e-01  4e-04  2e-16  3e-02
24: -8.2493e-01 -8.2529e-01  4e-04  1e-16  2e-02
25: -8.2494e-01 -8.2530e-01  4e-04  2e-16  2e-02
26: -8.2500e-01 -8.2533e-01  3e-04  7e-17  1e-02
27: -8.2507e-01 -8.2535e-01  3e-04  7e-17  2e-03
28: -8.2519e-01 -8.2521e-01  1e-05  7e-17  9e-05
29: -8.2520e-01 -8.2520e-01  8e-07  2e-16  6e-06
30: -8.2520e-01 -8.2520e-01  4e-08  3e-16  6e-08
Optimal solution found.
Update LP...
Solve LP...
delta obj ratio: 2.83e+06

Iteration 7...
Update QP...
Solve QP...
  

### mi-Net

In [1]:
import sys
import time
from random import shuffle
import numpy as np
import argparse

from keras.models import Model
from keras.optimizers import SGD
from keras.regularizers import l2
from keras.layers import Input, Dense, Layer, Dropout

from mil_nets.dataset import load_dataset
from mil_nets.layer import Score_pooling
from mil_nets.metrics import bag_accuracy
from mil_nets.objectives import bag_loss
from mil_nets.utils import convertToBatch

In [22]:
def test_eval(model, test_set):
    """Evaluate on testing set.
    Parameters
    -----------------
    model : keras.engine.training.Model object
        The training mi-Net model.
    test_set : list
        A list of testing set contains all training bags features and labels.
    Returns
    -----------------
    test_loss : float
        Mean loss of evaluating on testing set.
    test_acc : float
        Mean accuracy of evaluating on testing set.
    """
    num_test_batch = len(test_set)
    test_loss = np.zeros((num_test_batch, 1), dtype=np.float32)
    test_acc = np.zeros((num_test_batch, 1), dtype=np.float32)
    for ibatch, batch in enumerate(test_set):
        result = model.test_on_batch({'input':batch[0].astype(np.float32)}, {'sp':batch[1].astype(np.float32)})
        test_loss[ibatch] = result[0]
        test_acc[ibatch] = result[1]
    return np.mean(test_loss), np.mean(test_acc)

def train_eval(model, train_set):
    """Evaluate on training set.
    Parameters
    -----------------
    model : keras.engine.training.Model object
        The training mi-Net model.
    train_set : list
        A list of training set contains all training bags features and labels.
    Returns
    -----------------
    test_loss : float
        Mean loss of evaluating on traing set.
    test_acc : float
        Mean accuracy of evaluating on testing set.
    """
    num_train_batch = len(train_set)
    train_loss = np.zeros((num_train_batch, 1), dtype=np.float32)
    train_acc = np.zeros((num_train_batch, 1), dtype=np.float32)
    shuffle(train_set)
    for ibatch, batch in enumerate(train_set):
        result = model.train_on_batch({'input':batch[0].astype(np.float32)}, {'sp':batch[1].astype(np.float32)})
        train_loss[ibatch] = result[0]
        train_acc[ibatch] = result[1]
    return np.mean(train_loss), np.mean(train_acc)

def mi_Net(dataset):
    weight_decay=0.005
    init_lr=5e-4
    pooling_mode='max'
    momentum=0.9
    max_epoch=50
    """Train and evaluate on mi-Net.
    Parameters
    -----------------
    dataset : dict
        A dictionary contains all dataset information. We split train/test by keys.
    Returns
    -----------------
    test_acc : float
        Testing accuracy of mi-Net.
    """
    # load data and convert type
    train_bags = dataset['train']
    test_bags = dataset['test']

    # convert bag to batch
    train_set = convertToBatch(train_bags)
    test_set = convertToBatch(test_bags)
    dimension = train_set[0][0].shape[1]

    # data: instance feature, n*d, n = number of training instance
    data_input = Input(shape=(dimension,), dtype='float32', name='input')

    # fully-connected
    fc1 = Dense(256, activation='relu', kernel_regularizer=l2(weight_decay))(data_input)
    fc2 = Dense(128, activation='relu', kernel_regularizer=l2(weight_decay))(fc1)
    fc3 = Dense(64, activation='relu', kernel_regularizer=l2(weight_decay))(fc2)

    # dropout
    dropout = Dropout(rate=0.5)(fc3)

    # score pooling
    sp = Score_pooling(output_dim=1, kernel_regularizer=l2(weight_decay), pooling_mode=pooling_mode, name='sp')(dropout)

    model = Model(inputs=[data_input], outputs=[sp])
    sgd = SGD(lr=init_lr, decay=1e-4, momentum=momentum, nesterov=True)
    model.compile(loss=bag_loss, optimizer=sgd, metrics=[bag_accuracy])

    # train model
    t1 = time.time()
    num_batch = len(train_set)
    for epoch in range(max_epoch):
        train_loss, train_acc = train_eval(model, train_set)
        test_loss, test_acc = test_eval(model, test_set)
        print('epoch=', epoch, '  train_loss= {:.3f}'.format(train_loss), '  train_acc= {:.3f}'.format(train_acc), '  test_loss={:.3f}'.format(test_loss), '  test_acc= {:.3f}'.format(test_acc))
    t2 = time.time()
    print('run time:', (t2-t1) / 60.0, 'min')
    print('test_acc={:.3f}'.format(test_acc))

    return test_acc


In [23]:
# perform five times 10-fold cross-validation experiments
run = 5
n_folds = 10
acc = np.zeros((run, n_folds), dtype=np.float32)
for irun in range(run):
    dataset = load_dataset('musk1', n_folds)
    for ifold in range(n_folds):
        print('run=', irun, '  fold=', ifold)
        acc[irun][ifold] = mi_Net(dataset[ifold])
print('mi-net mean accuracy = ', np.mean(acc))
print('std = ', np.std(acc))

MUSK-188     
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-211     
MUSK-211     
MUSK-212     
MUSK-212     
MUSK-212     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-219     
MUSK-219     
MUSK-224     
MUSK-224     
MUSK-227     
MUSK-227     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-246     
MUSK-246     
MUSK-246     
MUSK-246     
MUSK-254     
MUSK-254     
MUSK-256     
MUSK-256     
MUSK-256     
MUSK-256     
MUSK-272     
MUSK-272     
MUSK-272     
MUSK-273     
MUSK-273     
MUSK-273     
MUSK-273     
MUSK-273     
MUSK-284     
MUSK-284     
MUSK-284     
MUSK-284     
MUSK-285     
MUSK-285     
MUSK-2

epoch= 18   train_loss= 2.081   train_acc= 1.000   test_loss=2.426   test_acc= 0.900
epoch= 19   train_loss= 2.059   train_acc= 1.000   test_loss=2.396   test_acc= 0.900
epoch= 20   train_loss= 2.065   train_acc= 0.976   test_loss=2.386   test_acc= 0.900
epoch= 21   train_loss= 2.039   train_acc= 0.976   test_loss=2.394   test_acc= 0.900
epoch= 22   train_loss= 2.029   train_acc= 1.000   test_loss=2.429   test_acc= 0.900
epoch= 23   train_loss= 2.010   train_acc= 0.976   test_loss=2.376   test_acc= 0.900
epoch= 24   train_loss= 1.978   train_acc= 1.000   test_loss=2.356   test_acc= 0.900
epoch= 25   train_loss= 1.962   train_acc= 1.000   test_loss=2.370   test_acc= 0.900
epoch= 26   train_loss= 1.946   train_acc= 1.000   test_loss=2.366   test_acc= 0.900
epoch= 27   train_loss= 1.941   train_acc= 1.000   test_loss=2.373   test_acc= 0.900
epoch= 28   train_loss= 1.922   train_acc= 1.000   test_loss=2.357   test_acc= 0.900
epoch= 29   train_loss= 1.907   train_acc= 1.000   test_loss=2.33

epoch= 14   train_loss= 2.198   train_acc= 0.976   test_loss=2.442   test_acc= 0.778
epoch= 15   train_loss= 2.180   train_acc= 0.976   test_loss=2.355   test_acc= 0.778
epoch= 16   train_loss= 2.147   train_acc= 0.964   test_loss=2.329   test_acc= 0.778
epoch= 17   train_loss= 2.141   train_acc= 0.988   test_loss=2.366   test_acc= 0.667
epoch= 18   train_loss= 2.141   train_acc= 0.964   test_loss=2.397   test_acc= 0.778
epoch= 19   train_loss= 2.089   train_acc= 0.988   test_loss=2.366   test_acc= 0.778
epoch= 20   train_loss= 2.079   train_acc= 0.976   test_loss=2.333   test_acc= 0.778
epoch= 21   train_loss= 2.063   train_acc= 0.976   test_loss=2.295   test_acc= 0.778
epoch= 22   train_loss= 2.044   train_acc= 1.000   test_loss=2.254   test_acc= 0.778
epoch= 23   train_loss= 2.030   train_acc= 0.976   test_loss=2.209   test_acc= 0.889
epoch= 24   train_loss= 2.014   train_acc= 1.000   test_loss=2.221   test_acc= 0.778
epoch= 25   train_loss= 1.978   train_acc= 1.000   test_loss=2.21

epoch= 10   train_loss= 2.322   train_acc= 0.940   test_loss=2.306   test_acc= 0.889
epoch= 11   train_loss= 2.300   train_acc= 0.952   test_loss=2.276   test_acc= 0.889
epoch= 12   train_loss= 2.265   train_acc= 0.964   test_loss=2.326   test_acc= 0.889
epoch= 13   train_loss= 2.245   train_acc= 0.976   test_loss=2.291   test_acc= 0.889
epoch= 14   train_loss= 2.179   train_acc= 0.988   test_loss=2.261   test_acc= 0.889
epoch= 15   train_loss= 2.181   train_acc= 0.952   test_loss=2.224   test_acc= 0.889
epoch= 16   train_loss= 2.165   train_acc= 0.976   test_loss=2.276   test_acc= 0.889
epoch= 17   train_loss= 2.138   train_acc= 0.976   test_loss=2.307   test_acc= 0.889
epoch= 18   train_loss= 2.099   train_acc= 0.988   test_loss=2.221   test_acc= 0.889
epoch= 19   train_loss= 2.072   train_acc= 1.000   test_loss=2.234   test_acc= 0.889
epoch= 20   train_loss= 2.062   train_acc= 1.000   test_loss=2.244   test_acc= 0.889
epoch= 21   train_loss= 2.043   train_acc= 0.988   test_loss=2.19

epoch= 6   train_loss= 2.495   train_acc= 0.867   test_loss=2.423   test_acc= 0.889
epoch= 7   train_loss= 2.424   train_acc= 0.928   test_loss=2.344   test_acc= 1.000
epoch= 8   train_loss= 2.427   train_acc= 0.880   test_loss=2.349   test_acc= 0.889
epoch= 9   train_loss= 2.384   train_acc= 0.940   test_loss=2.274   test_acc= 1.000
epoch= 10   train_loss= 2.336   train_acc= 0.916   test_loss=2.262   test_acc= 1.000
epoch= 11   train_loss= 2.274   train_acc= 0.964   test_loss=2.232   test_acc= 1.000
epoch= 12   train_loss= 2.259   train_acc= 0.940   test_loss=2.262   test_acc= 0.889
epoch= 13   train_loss= 2.225   train_acc= 0.964   test_loss=2.160   test_acc= 1.000
epoch= 14   train_loss= 2.190   train_acc= 0.964   test_loss=2.307   test_acc= 0.889
epoch= 15   train_loss= 2.192   train_acc= 0.952   test_loss=2.262   test_acc= 0.889
epoch= 16   train_loss= 2.154   train_acc= 0.988   test_loss=2.142   test_acc= 1.000
epoch= 17   train_loss= 2.122   train_acc= 0.988   test_loss=2.128   

epoch= 2   train_loss= 2.695   train_acc= 0.807   test_loss=2.877   test_acc= 0.667
epoch= 3   train_loss= 2.644   train_acc= 0.807   test_loss=2.733   test_acc= 0.889
epoch= 4   train_loss= 2.545   train_acc= 0.867   test_loss=2.716   test_acc= 0.889
epoch= 5   train_loss= 2.510   train_acc= 0.855   test_loss=2.643   test_acc= 0.778
epoch= 6   train_loss= 2.435   train_acc= 0.904   test_loss=2.681   test_acc= 0.778
epoch= 7   train_loss= 2.402   train_acc= 0.892   test_loss=2.661   test_acc= 0.778
epoch= 8   train_loss= 2.375   train_acc= 0.916   test_loss=2.627   test_acc= 0.889
epoch= 9   train_loss= 2.312   train_acc= 0.952   test_loss=2.609   test_acc= 0.778
epoch= 10   train_loss= 2.306   train_acc= 0.940   test_loss=2.542   test_acc= 0.889
epoch= 11   train_loss= 2.287   train_acc= 0.952   test_loss=2.620   test_acc= 0.778
epoch= 12   train_loss= 2.260   train_acc= 0.952   test_loss=2.570   test_acc= 0.889
epoch= 13   train_loss= 2.223   train_acc= 0.976   test_loss=2.541   test

epoch= 48   train_loss= 1.708   train_acc= 1.000   test_loss=1.938   test_acc= 0.889
epoch= 49   train_loss= 1.695   train_acc= 1.000   test_loss=1.952   test_acc= 0.889
run time: 0.7161436160405477 min
test_acc=0.889
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-211     
MUSK-211     
MUSK-212     
MUSK-212     
MUSK-212     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-219     
MUSK-219     
MUSK-224     
MUSK-224     
MUSK-227     
MUSK-227     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-246     
MUSK-246     
MUSK-246     
MUSK-246     
MUSK-254     
MUSK-254     
MUSK-256     
MUSK-256     
MUSK-256    

epoch= 16   train_loss= 2.172   train_acc= 0.976   test_loss=2.296   test_acc= 0.900
epoch= 17   train_loss= 2.149   train_acc= 0.988   test_loss=2.268   test_acc= 0.900
epoch= 18   train_loss= 2.139   train_acc= 0.988   test_loss=2.238   test_acc= 0.900
epoch= 19   train_loss= 2.085   train_acc= 1.000   test_loss=2.233   test_acc= 0.900
epoch= 20   train_loss= 2.083   train_acc= 0.988   test_loss=2.249   test_acc= 0.900
epoch= 21   train_loss= 2.067   train_acc= 1.000   test_loss=2.214   test_acc= 0.900
epoch= 22   train_loss= 2.048   train_acc= 0.988   test_loss=2.228   test_acc= 0.900
epoch= 23   train_loss= 2.028   train_acc= 1.000   test_loss=2.135   test_acc= 0.900
epoch= 24   train_loss= 2.021   train_acc= 1.000   test_loss=2.209   test_acc= 0.900
epoch= 25   train_loss= 1.982   train_acc= 1.000   test_loss=2.106   test_acc= 0.900
epoch= 26   train_loss= 1.982   train_acc= 1.000   test_loss=2.142   test_acc= 0.900
epoch= 27   train_loss= 1.974   train_acc= 1.000   test_loss=2.05

epoch= 12   train_loss= 2.254   train_acc= 0.952   test_loss=2.191   test_acc= 1.000
epoch= 13   train_loss= 2.217   train_acc= 0.952   test_loss=2.187   test_acc= 0.889
epoch= 14   train_loss= 2.203   train_acc= 0.928   test_loss=2.151   test_acc= 1.000
epoch= 15   train_loss= 2.185   train_acc= 0.964   test_loss=2.155   test_acc= 1.000
epoch= 16   train_loss= 2.133   train_acc= 0.976   test_loss=2.147   test_acc= 1.000
epoch= 17   train_loss= 2.118   train_acc= 0.988   test_loss=2.114   test_acc= 1.000
epoch= 18   train_loss= 2.088   train_acc= 0.988   test_loss=2.100   test_acc= 1.000
epoch= 19   train_loss= 2.047   train_acc= 1.000   test_loss=2.069   test_acc= 1.000
epoch= 20   train_loss= 2.049   train_acc= 1.000   test_loss=2.084   test_acc= 0.889
epoch= 21   train_loss= 2.024   train_acc= 0.988   test_loss=2.051   test_acc= 1.000
epoch= 22   train_loss= 2.003   train_acc= 1.000   test_loss=2.037   test_acc= 1.000
epoch= 23   train_loss= 2.001   train_acc= 0.988   test_loss=2.00

epoch= 8   train_loss= 2.398   train_acc= 0.904   test_loss=2.525   test_acc= 0.667
epoch= 9   train_loss= 2.349   train_acc= 0.940   test_loss=2.508   test_acc= 0.667
epoch= 10   train_loss= 2.317   train_acc= 0.904   test_loss=2.488   test_acc= 0.889
epoch= 11   train_loss= 2.301   train_acc= 0.952   test_loss=2.496   test_acc= 0.667
epoch= 12   train_loss= 2.273   train_acc= 0.940   test_loss=2.470   test_acc= 0.667
epoch= 13   train_loss= 2.228   train_acc= 0.964   test_loss=2.466   test_acc= 0.667
epoch= 14   train_loss= 2.198   train_acc= 0.964   test_loss=2.444   test_acc= 0.667
epoch= 15   train_loss= 2.188   train_acc= 0.964   test_loss=2.460   test_acc= 0.667
epoch= 16   train_loss= 2.169   train_acc= 0.976   test_loss=2.442   test_acc= 0.667
epoch= 17   train_loss= 2.117   train_acc= 0.988   test_loss=2.356   test_acc= 0.778
epoch= 18   train_loss= 2.099   train_acc= 1.000   test_loss=2.365   test_acc= 0.778
epoch= 19   train_loss= 2.091   train_acc= 0.976   test_loss=2.345 

epoch= 4   train_loss= 2.566   train_acc= 0.819   test_loss=2.527   test_acc= 0.889
epoch= 5   train_loss= 2.515   train_acc= 0.892   test_loss=2.493   test_acc= 0.889
epoch= 6   train_loss= 2.461   train_acc= 0.855   test_loss=2.476   test_acc= 0.889
epoch= 7   train_loss= 2.432   train_acc= 0.867   test_loss=2.466   test_acc= 0.889
epoch= 8   train_loss= 2.382   train_acc= 0.940   test_loss=2.371   test_acc= 0.889
epoch= 9   train_loss= 2.340   train_acc= 0.964   test_loss=2.344   test_acc= 0.889
epoch= 10   train_loss= 2.322   train_acc= 0.928   test_loss=2.434   test_acc= 0.778
epoch= 11   train_loss= 2.314   train_acc= 0.928   test_loss=2.303   test_acc= 1.000
epoch= 12   train_loss= 2.275   train_acc= 0.940   test_loss=2.311   test_acc= 0.889
epoch= 13   train_loss= 2.228   train_acc= 0.964   test_loss=2.313   test_acc= 0.889
epoch= 14   train_loss= 2.227   train_acc= 0.964   test_loss=2.255   test_acc= 0.889
epoch= 15   train_loss= 2.179   train_acc= 0.964   test_loss=2.251   te

epoch= 0   train_loss= 2.983   train_acc= 0.494   test_loss=2.813   test_acc= 0.778
epoch= 1   train_loss= 2.831   train_acc= 0.747   test_loss=2.784   test_acc= 0.778
epoch= 2   train_loss= 2.733   train_acc= 0.807   test_loss=2.728   test_acc= 0.667
epoch= 3   train_loss= 2.621   train_acc= 0.880   test_loss=2.720   test_acc= 0.667
epoch= 4   train_loss= 2.600   train_acc= 0.843   test_loss=2.652   test_acc= 0.667
epoch= 5   train_loss= 2.504   train_acc= 0.904   test_loss=2.699   test_acc= 0.667
epoch= 6   train_loss= 2.455   train_acc= 0.904   test_loss=2.692   test_acc= 0.667
epoch= 7   train_loss= 2.403   train_acc= 0.940   test_loss=2.718   test_acc= 0.667
epoch= 8   train_loss= 2.382   train_acc= 0.916   test_loss=2.717   test_acc= 0.667
epoch= 9   train_loss= 2.358   train_acc= 0.964   test_loss=2.503   test_acc= 0.778
epoch= 10   train_loss= 2.349   train_acc= 0.952   test_loss=2.584   test_acc= 0.667
epoch= 11   train_loss= 2.278   train_acc= 0.952   test_loss=2.596   test_a

epoch= 46   train_loss= 1.727   train_acc= 1.000   test_loss=1.860   test_acc= 0.889
epoch= 47   train_loss= 1.719   train_acc= 1.000   test_loss=1.839   test_acc= 0.889
epoch= 48   train_loss= 1.702   train_acc= 1.000   test_loss=1.826   test_acc= 0.889
epoch= 49   train_loss= 1.701   train_acc= 1.000   test_loss=1.820   test_acc= 0.889
run time: 0.7510839025179545 min
test_acc=0.889
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-211     
MUSK-211     
MUSK-212     
MUSK-212     
MUSK-212     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-219     
MUSK-219     
MUSK-224     
MUSK-224     
MUSK-227     
MUSK-227     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240  

epoch= 14   train_loss= 2.185   train_acc= 0.988   test_loss=2.416   test_acc= 0.800
epoch= 15   train_loss= 2.162   train_acc= 0.976   test_loss=2.437   test_acc= 0.800
epoch= 16   train_loss= 2.176   train_acc= 0.951   test_loss=2.446   test_acc= 0.800
epoch= 17   train_loss= 2.132   train_acc= 0.988   test_loss=2.364   test_acc= 0.800
epoch= 18   train_loss= 2.106   train_acc= 0.988   test_loss=2.342   test_acc= 0.800
epoch= 19   train_loss= 2.096   train_acc= 1.000   test_loss=2.393   test_acc= 0.800
epoch= 20   train_loss= 2.060   train_acc= 1.000   test_loss=2.354   test_acc= 0.800
epoch= 21   train_loss= 2.046   train_acc= 0.988   test_loss=2.327   test_acc= 0.800
epoch= 22   train_loss= 2.054   train_acc= 0.976   test_loss=2.312   test_acc= 0.800
epoch= 23   train_loss= 2.029   train_acc= 1.000   test_loss=2.321   test_acc= 0.800
epoch= 24   train_loss= 2.003   train_acc= 1.000   test_loss=2.302   test_acc= 0.800
epoch= 25   train_loss= 1.981   train_acc= 1.000   test_loss=2.25

epoch= 10   train_loss= 2.314   train_acc= 0.940   test_loss=2.432   test_acc= 0.778
epoch= 11   train_loss= 2.304   train_acc= 0.928   test_loss=2.452   test_acc= 0.889
epoch= 12   train_loss= 2.285   train_acc= 0.964   test_loss=2.276   test_acc= 1.000
epoch= 13   train_loss= 2.273   train_acc= 0.940   test_loss=2.422   test_acc= 0.889
epoch= 14   train_loss= 2.232   train_acc= 0.952   test_loss=2.380   test_acc= 0.889
epoch= 15   train_loss= 2.204   train_acc= 0.988   test_loss=2.254   test_acc= 0.889
epoch= 16   train_loss= 2.167   train_acc= 0.976   test_loss=2.231   test_acc= 0.889
epoch= 17   train_loss= 2.131   train_acc= 0.988   test_loss=2.248   test_acc= 0.889
epoch= 18   train_loss= 2.124   train_acc= 0.976   test_loss=2.183   test_acc= 1.000
epoch= 19   train_loss= 2.078   train_acc= 0.988   test_loss=2.280   test_acc= 0.889
epoch= 20   train_loss= 2.076   train_acc= 0.988   test_loss=2.184   test_acc= 0.889
epoch= 21   train_loss= 2.050   train_acc= 0.988   test_loss=2.20

epoch= 6   train_loss= 2.492   train_acc= 0.916   test_loss=2.532   test_acc= 0.778
epoch= 7   train_loss= 2.495   train_acc= 0.855   test_loss=2.493   test_acc= 0.778
epoch= 8   train_loss= 2.400   train_acc= 0.916   test_loss=2.439   test_acc= 1.000
epoch= 9   train_loss= 2.384   train_acc= 0.928   test_loss=2.439   test_acc= 0.889
epoch= 10   train_loss= 2.335   train_acc= 0.928   test_loss=2.394   test_acc= 0.889
epoch= 11   train_loss= 2.305   train_acc= 0.904   test_loss=2.370   test_acc= 0.889
epoch= 12   train_loss= 2.300   train_acc= 0.916   test_loss=2.401   test_acc= 0.889
epoch= 13   train_loss= 2.258   train_acc= 0.916   test_loss=2.325   test_acc= 0.889
epoch= 14   train_loss= 2.214   train_acc= 0.988   test_loss=2.317   test_acc= 0.889
epoch= 15   train_loss= 2.180   train_acc= 0.964   test_loss=2.300   test_acc= 0.889
epoch= 16   train_loss= 2.189   train_acc= 0.928   test_loss=2.345   test_acc= 0.889
epoch= 17   train_loss= 2.158   train_acc= 0.952   test_loss=2.332   

epoch= 2   train_loss= 2.714   train_acc= 0.819   test_loss=2.747   test_acc= 0.778
epoch= 3   train_loss= 2.641   train_acc= 0.831   test_loss=2.730   test_acc= 0.778
epoch= 4   train_loss= 2.620   train_acc= 0.843   test_loss=2.707   test_acc= 0.778
epoch= 5   train_loss= 2.530   train_acc= 0.880   test_loss=2.662   test_acc= 0.778
epoch= 6   train_loss= 2.487   train_acc= 0.892   test_loss=2.677   test_acc= 0.778
epoch= 7   train_loss= 2.434   train_acc= 0.916   test_loss=2.646   test_acc= 0.778
epoch= 8   train_loss= 2.397   train_acc= 0.952   test_loss=2.646   test_acc= 0.778
epoch= 9   train_loss= 2.371   train_acc= 0.940   test_loss=2.606   test_acc= 0.778
epoch= 10   train_loss= 2.354   train_acc= 0.952   test_loss=2.586   test_acc= 0.778
epoch= 11   train_loss= 2.316   train_acc= 0.940   test_loss=2.547   test_acc= 0.778
epoch= 12   train_loss= 2.262   train_acc= 0.952   test_loss=2.547   test_acc= 0.778
epoch= 13   train_loss= 2.246   train_acc= 0.940   test_loss=2.548   test

epoch= 48   train_loss= 1.697   train_acc= 1.000   test_loss=2.287   test_acc= 0.778
epoch= 49   train_loss= 1.684   train_acc= 1.000   test_loss=2.285   test_acc= 0.778
run time: 0.756281320254008 min
test_acc=0.778
run= 2   fold= 8
epoch= 0   train_loss= 2.909   train_acc= 0.687   test_loss=2.735   test_acc= 0.778
epoch= 1   train_loss= 2.803   train_acc= 0.771   test_loss=2.616   test_acc= 0.889
epoch= 2   train_loss= 2.685   train_acc= 0.819   test_loss=2.572   test_acc= 0.778
epoch= 3   train_loss= 2.583   train_acc= 0.867   test_loss=2.544   test_acc= 0.778
epoch= 4   train_loss= 2.558   train_acc= 0.855   test_loss=2.549   test_acc= 0.778
epoch= 5   train_loss= 2.516   train_acc= 0.892   test_loss=2.549   test_acc= 0.778
epoch= 6   train_loss= 2.433   train_acc= 0.904   test_loss=2.527   test_acc= 0.778
epoch= 7   train_loss= 2.389   train_acc= 0.904   test_loss=2.517   test_acc= 0.778
epoch= 8   train_loss= 2.373   train_acc= 0.916   test_loss=2.564   test_acc= 0.778
epoch= 9  

epoch= 44   train_loss= 1.745   train_acc= 1.000   test_loss=3.650   test_acc= 0.556
epoch= 45   train_loss= 1.737   train_acc= 1.000   test_loss=3.719   test_acc= 0.556
epoch= 46   train_loss= 1.724   train_acc= 1.000   test_loss=3.693   test_acc= 0.556
epoch= 47   train_loss= 1.716   train_acc= 1.000   test_loss=3.753   test_acc= 0.556
epoch= 48   train_loss= 1.709   train_acc= 1.000   test_loss=3.771   test_acc= 0.556
epoch= 49   train_loss= 1.694   train_acc= 1.000   test_loss=3.698   test_acc= 0.556
run time: 0.6951592485109965 min
test_acc=0.556
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-211     
MUSK-211     
MUSK-212     
MUSK-212     
MUSK-212     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-219     
MUSK-219     
MUSK-224     
MUSK-224     
MUSK-227     
MUSK-227     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236

epoch= 12   train_loss= 2.257   train_acc= 0.963   test_loss=2.399   test_acc= 0.900
epoch= 13   train_loss= 2.257   train_acc= 0.939   test_loss=2.362   test_acc= 0.900
epoch= 14   train_loss= 2.197   train_acc= 0.951   test_loss=2.368   test_acc= 0.800
epoch= 15   train_loss= 2.177   train_acc= 0.976   test_loss=2.264   test_acc= 0.900
epoch= 16   train_loss= 2.154   train_acc= 0.976   test_loss=2.277   test_acc= 0.900
epoch= 17   train_loss= 2.143   train_acc= 0.988   test_loss=2.308   test_acc= 0.900
epoch= 18   train_loss= 2.115   train_acc= 0.988   test_loss=2.269   test_acc= 0.900
epoch= 19   train_loss= 2.106   train_acc= 0.963   test_loss=2.256   test_acc= 0.900
epoch= 20   train_loss= 2.072   train_acc= 0.988   test_loss=2.196   test_acc= 0.900
epoch= 21   train_loss= 2.054   train_acc= 0.988   test_loss=2.151   test_acc= 0.900
epoch= 22   train_loss= 2.023   train_acc= 1.000   test_loss=2.194   test_acc= 0.900
epoch= 23   train_loss= 2.014   train_acc= 0.988   test_loss=2.18

epoch= 8   train_loss= 2.404   train_acc= 0.928   test_loss=2.718   test_acc= 0.667
epoch= 9   train_loss= 2.348   train_acc= 0.892   test_loss=2.681   test_acc= 0.667
epoch= 10   train_loss= 2.303   train_acc= 0.976   test_loss=2.643   test_acc= 0.667
epoch= 11   train_loss= 2.317   train_acc= 0.904   test_loss=2.613   test_acc= 0.667
epoch= 12   train_loss= 2.255   train_acc= 0.964   test_loss=2.570   test_acc= 0.667
epoch= 13   train_loss= 2.262   train_acc= 0.952   test_loss=2.571   test_acc= 0.667
epoch= 14   train_loss= 2.194   train_acc= 0.964   test_loss=2.502   test_acc= 0.889
epoch= 15   train_loss= 2.161   train_acc= 0.964   test_loss=2.505   test_acc= 0.667
epoch= 16   train_loss= 2.152   train_acc= 0.988   test_loss=2.468   test_acc= 0.778
epoch= 17   train_loss= 2.117   train_acc= 1.000   test_loss=2.440   test_acc= 0.667
epoch= 18   train_loss= 2.106   train_acc= 0.976   test_loss=2.419   test_acc= 0.889
epoch= 19   train_loss= 2.079   train_acc= 1.000   test_loss=2.373 

epoch= 4   train_loss= 2.532   train_acc= 0.831   test_loss=2.841   test_acc= 0.444
epoch= 5   train_loss= 2.494   train_acc= 0.880   test_loss=2.851   test_acc= 0.444
epoch= 6   train_loss= 2.440   train_acc= 0.916   test_loss=2.886   test_acc= 0.667
epoch= 7   train_loss= 2.364   train_acc= 0.952   test_loss=2.902   test_acc= 0.667
epoch= 8   train_loss= 2.319   train_acc= 0.952   test_loss=2.915   test_acc= 0.444
epoch= 9   train_loss= 2.318   train_acc= 0.940   test_loss=2.897   test_acc= 0.556
epoch= 10   train_loss= 2.280   train_acc= 0.976   test_loss=2.946   test_acc= 0.667
epoch= 11   train_loss= 2.242   train_acc= 0.952   test_loss=2.987   test_acc= 0.667
epoch= 12   train_loss= 2.209   train_acc= 0.976   test_loss=2.935   test_acc= 0.556
epoch= 13   train_loss= 2.192   train_acc= 0.964   test_loss=2.996   test_acc= 0.667
epoch= 14   train_loss= 2.164   train_acc= 0.976   test_loss=3.012   test_acc= 0.556
epoch= 15   train_loss= 2.123   train_acc= 1.000   test_loss=3.049   te

epoch= 0   train_loss= 3.049   train_acc= 0.482   test_loss=3.040   test_acc= 0.444
epoch= 1   train_loss= 2.833   train_acc= 0.687   test_loss=3.046   test_acc= 0.444
epoch= 2   train_loss= 2.743   train_acc= 0.771   test_loss=2.874   test_acc= 0.667
epoch= 3   train_loss= 2.654   train_acc= 0.855   test_loss=2.879   test_acc= 0.333
epoch= 4   train_loss= 2.599   train_acc= 0.831   test_loss=2.803   test_acc= 0.556
epoch= 5   train_loss= 2.522   train_acc= 0.880   test_loss=2.843   test_acc= 0.444
epoch= 6   train_loss= 2.479   train_acc= 0.880   test_loss=2.738   test_acc= 0.778
epoch= 7   train_loss= 2.457   train_acc= 0.892   test_loss=2.904   test_acc= 0.444
epoch= 8   train_loss= 2.395   train_acc= 0.940   test_loss=2.790   test_acc= 0.556
epoch= 9   train_loss= 2.357   train_acc= 0.928   test_loss=2.880   test_acc= 0.444
epoch= 10   train_loss= 2.322   train_acc= 0.964   test_loss=2.827   test_acc= 0.667
epoch= 11   train_loss= 2.259   train_acc= 0.976   test_loss=2.780   test_a

epoch= 46   train_loss= 1.728   train_acc= 1.000   test_loss=2.918   test_acc= 0.778
epoch= 47   train_loss= 1.719   train_acc= 1.000   test_loss=2.979   test_acc= 0.667
epoch= 48   train_loss= 1.709   train_acc= 1.000   test_loss=2.961   test_acc= 0.667
epoch= 49   train_loss= 1.705   train_acc= 1.000   test_loss=2.930   test_acc= 0.667
run time: 0.7117309292157491 min
test_acc=0.667
run= 3   fold= 8
epoch= 0   train_loss= 2.948   train_acc= 0.602   test_loss=2.559   test_acc= 1.000
epoch= 1   train_loss= 2.809   train_acc= 0.747   test_loss=2.499   test_acc= 1.000
epoch= 2   train_loss= 2.718   train_acc= 0.819   test_loss=2.432   test_acc= 1.000
epoch= 3   train_loss= 2.653   train_acc= 0.807   test_loss=2.355   test_acc= 1.000
epoch= 4   train_loss= 2.622   train_acc= 0.831   test_loss=2.349   test_acc= 1.000
epoch= 5   train_loss= 2.551   train_acc= 0.795   test_loss=2.275   test_acc= 1.000
epoch= 6   train_loss= 2.495   train_acc= 0.892   test_loss=2.251   test_acc= 1.000
epoch= 

epoch= 42   train_loss= 1.779   train_acc= 1.000   test_loss=1.757   test_acc= 1.000
epoch= 43   train_loss= 1.757   train_acc= 1.000   test_loss=1.745   test_acc= 1.000
epoch= 44   train_loss= 1.752   train_acc= 1.000   test_loss=1.737   test_acc= 1.000
epoch= 45   train_loss= 1.736   train_acc= 1.000   test_loss=1.727   test_acc= 1.000
epoch= 46   train_loss= 1.730   train_acc= 1.000   test_loss=1.719   test_acc= 1.000
epoch= 47   train_loss= 1.717   train_acc= 1.000   test_loss=1.710   test_acc= 1.000
epoch= 48   train_loss= 1.705   train_acc= 1.000   test_loss=1.698   test_acc= 1.000
epoch= 49   train_loss= 1.709   train_acc= 1.000   test_loss=1.691   test_acc= 1.000
run time: 0.7496605674425761 min
test_acc=1.000
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-211     
MUSK-211     
MUSK-212     
MUSK-212     
MUSK-212     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-219     
MUSK-219     
MUSK-2

epoch= 10   train_loss= 2.334   train_acc= 0.939   test_loss=2.647   test_acc= 0.700
epoch= 11   train_loss= 2.286   train_acc= 0.976   test_loss=2.456   test_acc= 0.900
epoch= 12   train_loss= 2.259   train_acc= 0.963   test_loss=2.568   test_acc= 0.900
epoch= 13   train_loss= 2.196   train_acc= 1.000   test_loss=2.513   test_acc= 0.900
epoch= 14   train_loss= 2.179   train_acc= 0.976   test_loss=2.522   test_acc= 0.800
epoch= 15   train_loss= 2.154   train_acc= 1.000   test_loss=2.554   test_acc= 0.800
epoch= 16   train_loss= 2.150   train_acc= 0.963   test_loss=2.380   test_acc= 0.900
epoch= 17   train_loss= 2.145   train_acc= 0.963   test_loss=2.469   test_acc= 0.900
epoch= 18   train_loss= 2.115   train_acc= 0.988   test_loss=2.392   test_acc= 0.900
epoch= 19   train_loss= 2.089   train_acc= 0.988   test_loss=2.497   test_acc= 0.800
epoch= 20   train_loss= 2.081   train_acc= 0.976   test_loss=2.403   test_acc= 0.900
epoch= 21   train_loss= 2.052   train_acc= 1.000   test_loss=2.42

epoch= 6   train_loss= 2.474   train_acc= 0.928   test_loss=2.486   test_acc= 0.778
epoch= 7   train_loss= 2.408   train_acc= 0.916   test_loss=2.483   test_acc= 0.778
epoch= 8   train_loss= 2.386   train_acc= 0.916   test_loss=2.438   test_acc= 0.778
epoch= 9   train_loss= 2.366   train_acc= 0.916   test_loss=2.483   test_acc= 0.778
epoch= 10   train_loss= 2.306   train_acc= 0.952   test_loss=2.484   test_acc= 0.778
epoch= 11   train_loss= 2.306   train_acc= 0.940   test_loss=2.488   test_acc= 0.778
epoch= 12   train_loss= 2.229   train_acc= 0.976   test_loss=2.457   test_acc= 0.778
epoch= 13   train_loss= 2.238   train_acc= 0.940   test_loss=2.411   test_acc= 0.778
epoch= 14   train_loss= 2.187   train_acc= 0.988   test_loss=2.457   test_acc= 0.778
epoch= 15   train_loss= 2.175   train_acc= 0.964   test_loss=2.387   test_acc= 0.778
epoch= 16   train_loss= 2.158   train_acc= 0.988   test_loss=2.475   test_acc= 0.778
epoch= 17   train_loss= 2.160   train_acc= 0.952   test_loss=2.373   

epoch= 2   train_loss= 2.699   train_acc= 0.855   test_loss=2.813   test_acc= 0.667
epoch= 3   train_loss= 2.655   train_acc= 0.807   test_loss=2.719   test_acc= 0.778
epoch= 4   train_loss= 2.545   train_acc= 0.880   test_loss=2.697   test_acc= 0.889
epoch= 5   train_loss= 2.500   train_acc= 0.904   test_loss=2.745   test_acc= 0.778
epoch= 6   train_loss= 2.469   train_acc= 0.892   test_loss=2.643   test_acc= 0.778
epoch= 7   train_loss= 2.425   train_acc= 0.904   test_loss=2.623   test_acc= 0.778
epoch= 8   train_loss= 2.401   train_acc= 0.904   test_loss=2.608   test_acc= 0.778
epoch= 9   train_loss= 2.341   train_acc= 0.952   test_loss=2.649   test_acc= 0.778
epoch= 10   train_loss= 2.309   train_acc= 0.940   test_loss=2.651   test_acc= 0.667
epoch= 11   train_loss= 2.284   train_acc= 0.940   test_loss=2.650   test_acc= 0.778
epoch= 12   train_loss= 2.252   train_acc= 0.952   test_loss=2.635   test_acc= 0.778
epoch= 13   train_loss= 2.202   train_acc= 0.976   test_loss=2.656   test

epoch= 48   train_loss= 1.698   train_acc= 1.000   test_loss=1.844   test_acc= 0.889
epoch= 49   train_loss= 1.693   train_acc= 1.000   test_loss=1.871   test_acc= 0.889
run time: 0.7094895680745442 min
test_acc=0.889
run= 4   fold= 6
epoch= 0   train_loss= 2.918   train_acc= 0.711   test_loss=2.843   test_acc= 0.778
epoch= 1   train_loss= 2.835   train_acc= 0.711   test_loss=2.783   test_acc= 0.667
epoch= 2   train_loss= 2.721   train_acc= 0.819   test_loss=2.692   test_acc= 0.889
epoch= 3   train_loss= 2.669   train_acc= 0.807   test_loss=2.647   test_acc= 0.778
epoch= 4   train_loss= 2.582   train_acc= 0.867   test_loss=2.607   test_acc= 0.889
epoch= 5   train_loss= 2.572   train_acc= 0.855   test_loss=2.546   test_acc= 0.889
epoch= 6   train_loss= 2.479   train_acc= 0.892   test_loss=2.504   test_acc= 0.889
epoch= 7   train_loss= 2.506   train_acc= 0.855   test_loss=2.493   test_acc= 0.889
epoch= 8   train_loss= 2.419   train_acc= 0.940   test_loss=2.471   test_acc= 0.889
epoch= 9 

epoch= 44   train_loss= 1.732   train_acc= 1.000   test_loss=3.326   test_acc= 0.778
epoch= 45   train_loss= 1.726   train_acc= 1.000   test_loss=3.319   test_acc= 0.778
epoch= 46   train_loss= 1.710   train_acc= 1.000   test_loss=3.318   test_acc= 0.778
epoch= 47   train_loss= 1.705   train_acc= 1.000   test_loss=3.291   test_acc= 0.778
epoch= 48   train_loss= 1.696   train_acc= 1.000   test_loss=3.289   test_acc= 0.778
epoch= 49   train_loss= 1.684   train_acc= 1.000   test_loss=3.270   test_acc= 0.778
run time: 0.7162322839101155 min
test_acc=0.778
run= 4   fold= 8
epoch= 0   train_loss= 3.014   train_acc= 0.627   test_loss=2.805   test_acc= 0.889
epoch= 1   train_loss= 2.784   train_acc= 0.735   test_loss=2.753   test_acc= 0.778
epoch= 2   train_loss= 2.655   train_acc= 0.855   test_loss=2.701   test_acc= 0.778
epoch= 3   train_loss= 2.597   train_acc= 0.867   test_loss=2.692   test_acc= 0.778
epoch= 4   train_loss= 2.503   train_acc= 0.880   test_loss=2.694   test_acc= 0.778
epoch

epoch= 40   train_loss= 1.807   train_acc= 1.000   test_loss=1.882   test_acc= 1.000
epoch= 41   train_loss= 1.796   train_acc= 1.000   test_loss=1.873   test_acc= 1.000
epoch= 42   train_loss= 1.787   train_acc= 1.000   test_loss=1.834   test_acc= 1.000
epoch= 43   train_loss= 1.777   train_acc= 1.000   test_loss=1.833   test_acc= 1.000
epoch= 44   train_loss= 1.771   train_acc= 1.000   test_loss=1.803   test_acc= 1.000
epoch= 45   train_loss= 1.756   train_acc= 1.000   test_loss=1.808   test_acc= 1.000
epoch= 46   train_loss= 1.751   train_acc= 1.000   test_loss=1.796   test_acc= 1.000
epoch= 47   train_loss= 1.734   train_acc= 1.000   test_loss=1.791   test_acc= 1.000
epoch= 48   train_loss= 1.733   train_acc= 1.000   test_loss=1.797   test_acc= 1.000
epoch= 49   train_loss= 1.719   train_acc= 1.000   test_loss=1.780   test_acc= 1.000
run time: 0.7182445844014486 min
test_acc=1.000
mi-net mean accuracy =  0.8300000059604645
std =  0.10402396848155698
