## Machine Learning project


In [5]:
pip install -r requirements.txt

Collecting git+https://github.com/chlorochrule/cknn (from -r requirements.txt (line 9))
  Cloning https://github.com/chlorochrule/cknn to /private/var/folders/dc/1bymglpd6198p4cqkyfcnrqw0000gn/T/pip-req-build-91ulfn55
  Running command git clone --filter=blob:none --quiet https://github.com/chlorochrule/cknn /private/var/folders/dc/1bymglpd6198p4cqkyfcnrqw0000gn/T/pip-req-build-91ulfn55
  Resolved https://github.com/chlorochrule/cknn to commit 7d05c5049da72a573bd486fca6647f8b0376243c
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone




[0mNote: you may need to restart the kernel to use updated packages.


### Embedded-Space


### MI-net

In [1]:
import numpy as np
import sys
import time
import random
from random import shuffle
import argparse

from keras.models import Model
from keras.optimizers import SGD
from keras.regularizers import l2
from keras.layers import Input, Dense, Layer, Dropout

from mil_nets.dataset import load_dataset
from mil_nets.layer import Feature_pooling
from mil_nets.metrics import bag_accuracy
from mil_nets.objectives import bag_loss
from mil_nets.utils import convertToBatch

In [4]:
def test_eval(model, test_set):
    """Evaluate on testing set.
    Parameters
    -----------------
    model : keras.engine.training.Model object
        The training MI-Net model.
    test_set : list
        A list of testing set contains all training bags features and labels.
    Returns
    -----------------
    test_loss : float
        Mean loss of evaluating on testing set.
    test_acc : float
        Mean accuracy of evaluating on testing set.
    """
    num_test_batch = len(test_set)
    test_loss = np.zeros((num_test_batch, 1), dtype=np.float32)
    test_acc = np.zeros((num_test_batch, 1), dtype=np.float32)
    for ibatch, batch in enumerate(test_set):
        result = model.test_on_batch({'input':batch[0].astype(np.float32)}, {'fp':batch[1].astype(np.float32)})
        test_loss[ibatch] = result[0]
        test_acc[ibatch][0] = result[1]
    return np.mean(test_loss), np.mean(test_acc)

def train_eval(model, train_set):
    """Evaluate on training set.
    Parameters
    -----------------
    model : keras.engine.training.Model object
        The training MI-Net model.
    train_set : list
        A list of training set contains all training bags features and labels.
    Returns
    -----------------
    test_loss : float
        Mean loss of evaluating on traing set..astype(np.float32)
    test_acc : float
        Mean accuracy of evaluating on testing set.
    """
    num_train_batch = len(train_set)
    train_loss = np.zeros((num_train_batch, 1), dtype=np.float32)
    train_acc = np.zeros((num_train_batch, 1), dtype=np.float32)
    shuffle(train_set)
    for ibatch, batch in enumerate(train_set):
        result = model.train_on_batch({'input':batch[0].astype(np.float32)}, {'fp':batch[1].astype(np.float32)})
        train_loss[ibatch] = result[0]
        train_acc[ibatch][0] = result[1]
    return np.mean(train_loss), np.mean(train_acc)

def MI_Net(dataset):
    """Train and evaluate on MI-Net.
    Parameters
    -----------------
    dataset : dict
        A dictionary contains all dataset information. We split train/test by keys.
    Returns
    -----------------
    test_acc : float
        Testing accuracy of MI-Net.
    """
    weight_decay=0.005
    init_lr=5e-4
    pooling_mode='max'
    momentum=0.9
    max_epoch=50
    # load data and convert type
    train_bags = dataset['train']
    test_bags = dataset['test']

    # convert bag to batch
    train_set = convertToBatch(train_bags)
    test_set = convertToBatch(test_bags)
    dimension = train_set[0][0].shape[1]

    # data: instance feature, n*d, n = number of training instance
    data_input = Input(shape=(dimension,), dtype='float32', name='input')

    # fully-connected
    fc1 = Dense(256, activation='relu', kernel_regularizer=l2(weight_decay))(data_input)
    fc2 = Dense(128, activation='relu', kernel_regularizer=l2(weight_decay))(fc1)
    fc3 = Dense(64, activation='relu', kernel_regularizer=l2(weight_decay))(fc2)

    # dropout
    dropout = Dropout(rate=0.5)(fc3)

    # features pooling
    fp = Feature_pooling(output_dim=1, kernel_regularizer=l2(weight_decay), pooling_mode=pooling_mode, name='fp')(dropout)

    model = Model(inputs=[data_input], outputs=[fp])
    sgd = SGD(lr=init_lr, decay=1e-4, momentum=momentum, nesterov=True)
    model.compile(loss=bag_loss, optimizer=sgd, metrics=[bag_accuracy])

    # train model
    t1 = time.time()
    num_batch = len(train_set)
    for epoch in range(max_epoch):
        train_loss, train_acc = train_eval(model, train_set)
        test_loss, test_acc = test_eval(model, test_set)
        print('epoch=', epoch, '  train_loss= {:.3f}'.format(train_loss), '  train_acc= {:.3f}'.format(train_acc), '  test_loss={:.3f}'.format(test_loss), '  test_acc= {:.3f}'.format(test_acc))
    t2 = time.time()
    print('run time:', (t2-t1) / 60, 'min')
    print('test_acc={:.3f}'.format(test_acc))

    return test_acc

In [5]:
# perform five times 10-fold cross-validation experiments
run = 5
n_folds = 10
acc = np.zeros((run, n_folds), dtype=np.float32)
for irun in range(run):
    dataset = load_dataset('musk1', n_folds)
    for ifold in range(n_folds):
        print('run=', irun, '  fold=', ifold)
        acc[irun][ifold] = MI_Net(dataset[ifold])
print('MI-Net mean accuracy = ', np.mean(acc))
print('std = ', np.std(acc))

run= 0   fold= 0
epoch= 0   train_loss= 2.963   train_acc= 0.573   test_loss=3.051   test_acc= 0.300
epoch= 1   train_loss= 2.726   train_acc= 0.768   test_loss=2.876   test_acc= 0.700
epoch= 2   train_loss= 2.562   train_acc= 0.902   test_loss=2.802   test_acc= 0.700
epoch= 3   train_loss= 2.486   train_acc= 0.915   test_loss=2.670   test_acc= 0.900
epoch= 4   train_loss= 2.385   train_acc= 0.951   test_loss=2.698   test_acc= 0.700
epoch= 5   train_loss= 2.364   train_acc= 0.976   test_loss=2.616   test_acc= 0.900
epoch= 6   train_loss= 2.309   train_acc= 0.988   test_loss=2.605   test_acc= 0.900
epoch= 7   train_loss= 2.305   train_acc= 0.963   test_loss=2.635   test_acc= 0.700
epoch= 8   train_loss= 2.239   train_acc= 0.976   test_loss=2.522   test_acc= 0.900
epoch= 9   train_loss= 2.219   train_acc= 0.988   test_loss=2.574   test_acc= 0.700
epoch= 10   train_loss= 2.224   train_acc= 0.963   test_loss=2.443   test_acc= 0.900
epoch= 11   train_loss= 2.178   train_acc= 0.988   test_lo

epoch= 46   train_loss= 1.706   train_acc= 1.000   test_loss=1.914   test_acc= 0.900
epoch= 47   train_loss= 1.696   train_acc= 1.000   test_loss=1.907   test_acc= 0.900
epoch= 48   train_loss= 1.686   train_acc= 1.000   test_loss=1.902   test_acc= 0.900
epoch= 49   train_loss= 1.678   train_acc= 1.000   test_loss=1.889   test_acc= 0.900
run time: 0.7238585313161214 min
test_acc=0.900
run= 0   fold= 2
epoch= 0   train_loss= 2.857   train_acc= 0.675   test_loss=2.853   test_acc= 0.889
epoch= 1   train_loss= 2.585   train_acc= 0.892   test_loss=2.785   test_acc= 0.556
epoch= 2   train_loss= 2.485   train_acc= 0.916   test_loss=2.744   test_acc= 0.889
epoch= 3   train_loss= 2.393   train_acc= 0.976   test_loss=2.715   test_acc= 0.889
epoch= 4   train_loss= 2.349   train_acc= 0.976   test_loss=2.674   test_acc= 0.889
epoch= 5   train_loss= 2.280   train_acc= 1.000   test_loss=2.688   test_acc= 0.889
epoch= 6   train_loss= 2.294   train_acc= 0.952   test_loss=2.695   test_acc= 0.889
epoch= 

epoch= 42   train_loss= 1.755   train_acc= 1.000   test_loss=1.829   test_acc= 1.000
epoch= 43   train_loss= 1.740   train_acc= 1.000   test_loss=1.819   test_acc= 1.000
epoch= 44   train_loss= 1.732   train_acc= 1.000   test_loss=1.812   test_acc= 1.000
epoch= 45   train_loss= 1.721   train_acc= 1.000   test_loss=1.796   test_acc= 1.000
epoch= 46   train_loss= 1.712   train_acc= 1.000   test_loss=1.785   test_acc= 1.000
epoch= 47   train_loss= 1.707   train_acc= 1.000   test_loss=1.783   test_acc= 1.000
epoch= 48   train_loss= 1.697   train_acc= 1.000   test_loss=1.772   test_acc= 1.000
epoch= 49   train_loss= 1.683   train_acc= 1.000   test_loss=1.770   test_acc= 1.000
run time: 0.9416592796643575 min
test_acc=1.000
run= 0   fold= 4
epoch= 0   train_loss= 2.927   train_acc= 0.687   test_loss=2.666   test_acc= 1.000
epoch= 1   train_loss= 2.641   train_acc= 0.855   test_loss=2.574   test_acc= 1.000
epoch= 2   train_loss= 2.540   train_acc= 0.904   test_loss=2.477   test_acc= 1.000
epo

epoch= 38   train_loss= 1.780   train_acc= 1.000   test_loss=2.024   test_acc= 0.889
epoch= 39   train_loss= 1.771   train_acc= 1.000   test_loss=1.988   test_acc= 0.889
epoch= 40   train_loss= 1.761   train_acc= 1.000   test_loss=1.988   test_acc= 0.889
epoch= 41   train_loss= 1.756   train_acc= 1.000   test_loss=2.004   test_acc= 0.889
epoch= 42   train_loss= 1.738   train_acc= 1.000   test_loss=1.972   test_acc= 0.889
epoch= 43   train_loss= 1.728   train_acc= 1.000   test_loss=1.962   test_acc= 0.889
epoch= 44   train_loss= 1.717   train_acc= 1.000   test_loss=1.959   test_acc= 0.889
epoch= 45   train_loss= 1.708   train_acc= 1.000   test_loss=1.939   test_acc= 0.889
epoch= 46   train_loss= 1.698   train_acc= 1.000   test_loss=1.928   test_acc= 0.889
epoch= 47   train_loss= 1.690   train_acc= 1.000   test_loss=1.932   test_acc= 0.889
epoch= 48   train_loss= 1.679   train_acc= 1.000   test_loss=1.916   test_acc= 0.889
epoch= 49   train_loss= 1.667   train_acc= 1.000   test_loss=1.90

epoch= 34   train_loss= 1.830   train_acc= 1.000   test_loss=1.891   test_acc= 1.000
epoch= 35   train_loss= 1.816   train_acc= 1.000   test_loss=1.888   test_acc= 1.000
epoch= 36   train_loss= 1.803   train_acc= 1.000   test_loss=1.875   test_acc= 1.000
epoch= 37   train_loss= 1.794   train_acc= 1.000   test_loss=1.866   test_acc= 1.000
epoch= 38   train_loss= 1.782   train_acc= 1.000   test_loss=1.852   test_acc= 1.000
epoch= 39   train_loss= 1.774   train_acc= 1.000   test_loss=1.834   test_acc= 1.000
epoch= 40   train_loss= 1.765   train_acc= 1.000   test_loss=1.833   test_acc= 1.000
epoch= 41   train_loss= 1.755   train_acc= 1.000   test_loss=1.813   test_acc= 1.000
epoch= 42   train_loss= 1.740   train_acc= 1.000   test_loss=1.807   test_acc= 1.000
epoch= 43   train_loss= 1.730   train_acc= 1.000   test_loss=1.796   test_acc= 1.000
epoch= 44   train_loss= 1.722   train_acc= 1.000   test_loss=1.783   test_acc= 1.000
epoch= 45   train_loss= 1.710   train_acc= 1.000   test_loss=1.77

epoch= 30   train_loss= 1.876   train_acc= 1.000   test_loss=2.199   test_acc= 0.889
epoch= 31   train_loss= 1.865   train_acc= 1.000   test_loss=2.210   test_acc= 0.778
epoch= 32   train_loss= 1.849   train_acc= 1.000   test_loss=2.193   test_acc= 0.778
epoch= 33   train_loss= 1.841   train_acc= 1.000   test_loss=2.208   test_acc= 0.778
epoch= 34   train_loss= 1.832   train_acc= 1.000   test_loss=2.167   test_acc= 0.889
epoch= 35   train_loss= 1.815   train_acc= 1.000   test_loss=2.203   test_acc= 0.778
epoch= 36   train_loss= 1.807   train_acc= 1.000   test_loss=2.146   test_acc= 0.778
epoch= 37   train_loss= 1.797   train_acc= 1.000   test_loss=2.119   test_acc= 0.778
epoch= 38   train_loss= 1.783   train_acc= 1.000   test_loss=2.121   test_acc= 0.778
epoch= 39   train_loss= 1.773   train_acc= 1.000   test_loss=2.132   test_acc= 0.778
epoch= 40   train_loss= 1.760   train_acc= 1.000   test_loss=2.111   test_acc= 0.778
epoch= 41   train_loss= 1.753   train_acc= 1.000   test_loss=2.06

epoch= 26   train_loss= 1.951   train_acc= 1.000   test_loss=2.109   test_acc= 0.900
epoch= 27   train_loss= 1.938   train_acc= 1.000   test_loss=2.071   test_acc= 0.900
epoch= 28   train_loss= 1.919   train_acc= 1.000   test_loss=2.060   test_acc= 0.900
epoch= 29   train_loss= 1.906   train_acc= 1.000   test_loss=2.069   test_acc= 0.900
epoch= 30   train_loss= 1.890   train_acc= 1.000   test_loss=2.038   test_acc= 0.900
epoch= 31   train_loss= 1.881   train_acc= 1.000   test_loss=2.018   test_acc= 0.900
epoch= 32   train_loss= 1.869   train_acc= 1.000   test_loss=2.024   test_acc= 0.900
epoch= 33   train_loss= 1.855   train_acc= 1.000   test_loss=2.005   test_acc= 0.900
epoch= 34   train_loss= 1.848   train_acc= 1.000   test_loss=2.001   test_acc= 0.900
epoch= 35   train_loss= 1.832   train_acc= 1.000   test_loss=1.996   test_acc= 0.900
epoch= 36   train_loss= 1.821   train_acc= 1.000   test_loss=1.982   test_acc= 0.900
epoch= 37   train_loss= 1.816   train_acc= 1.000   test_loss=1.97

epoch= 22   train_loss= 1.977   train_acc= 1.000   test_loss=2.195   test_acc= 0.778
epoch= 23   train_loss= 1.969   train_acc= 1.000   test_loss=2.180   test_acc= 0.778
epoch= 24   train_loss= 1.948   train_acc= 1.000   test_loss=2.184   test_acc= 0.889
epoch= 25   train_loss= 1.949   train_acc= 0.988   test_loss=2.140   test_acc= 0.778
epoch= 26   train_loss= 1.925   train_acc= 1.000   test_loss=2.154   test_acc= 0.889
epoch= 27   train_loss= 1.906   train_acc= 1.000   test_loss=2.118   test_acc= 0.778
epoch= 28   train_loss= 1.913   train_acc= 0.988   test_loss=2.103   test_acc= 0.778
epoch= 29   train_loss= 1.886   train_acc= 1.000   test_loss=2.097   test_acc= 0.778
epoch= 30   train_loss= 1.875   train_acc= 1.000   test_loss=2.080   test_acc= 0.778
epoch= 31   train_loss= 1.861   train_acc= 1.000   test_loss=2.071   test_acc= 0.778
epoch= 32   train_loss= 1.851   train_acc= 1.000   test_loss=2.063   test_acc= 0.778
epoch= 33   train_loss= 1.835   train_acc= 1.000   test_loss=2.05

epoch= 18   train_loss= 2.041   train_acc= 1.000   test_loss=2.106   test_acc= 1.000
epoch= 19   train_loss= 2.024   train_acc= 1.000   test_loss=2.082   test_acc= 1.000
epoch= 20   train_loss= 2.012   train_acc= 1.000   test_loss=2.067   test_acc= 1.000
epoch= 21   train_loss= 1.998   train_acc= 1.000   test_loss=2.053   test_acc= 1.000
epoch= 22   train_loss= 1.979   train_acc= 1.000   test_loss=2.038   test_acc= 1.000
epoch= 23   train_loss= 1.968   train_acc= 1.000   test_loss=2.023   test_acc= 1.000
epoch= 24   train_loss= 1.957   train_acc= 1.000   test_loss=2.012   test_acc= 1.000
epoch= 25   train_loss= 1.939   train_acc= 1.000   test_loss=2.000   test_acc= 1.000
epoch= 26   train_loss= 1.927   train_acc= 1.000   test_loss=1.985   test_acc= 1.000
epoch= 27   train_loss= 1.921   train_acc= 1.000   test_loss=1.971   test_acc= 1.000
epoch= 28   train_loss= 1.899   train_acc= 1.000   test_loss=1.958   test_acc= 1.000
epoch= 29   train_loss= 1.894   train_acc= 1.000   test_loss=1.94

epoch= 14   train_loss= 2.094   train_acc= 1.000   test_loss=2.260   test_acc= 0.889
epoch= 15   train_loss= 2.079   train_acc= 1.000   test_loss=2.191   test_acc= 1.000
epoch= 16   train_loss= 2.062   train_acc= 1.000   test_loss=2.183   test_acc= 1.000
epoch= 17   train_loss= 2.054   train_acc= 1.000   test_loss=2.187   test_acc= 0.889
epoch= 18   train_loss= 2.032   train_acc= 1.000   test_loss=2.146   test_acc= 1.000
epoch= 19   train_loss= 2.016   train_acc= 1.000   test_loss=2.132   test_acc= 1.000
epoch= 20   train_loss= 1.999   train_acc= 1.000   test_loss=2.129   test_acc= 1.000
epoch= 21   train_loss= 1.990   train_acc= 1.000   test_loss=2.111   test_acc= 1.000
epoch= 22   train_loss= 1.977   train_acc= 1.000   test_loss=2.097   test_acc= 1.000
epoch= 23   train_loss= 1.963   train_acc= 1.000   test_loss=2.072   test_acc= 1.000
epoch= 24   train_loss= 1.948   train_acc= 1.000   test_loss=2.069   test_acc= 1.000
epoch= 25   train_loss= 1.938   train_acc= 1.000   test_loss=2.04

epoch= 10   train_loss= 2.175   train_acc= 1.000   test_loss=2.330   test_acc= 0.889
epoch= 11   train_loss= 2.176   train_acc= 1.000   test_loss=2.320   test_acc= 1.000
epoch= 12   train_loss= 2.138   train_acc= 1.000   test_loss=2.277   test_acc= 1.000
epoch= 13   train_loss= 2.132   train_acc= 1.000   test_loss=2.276   test_acc= 1.000
epoch= 14   train_loss= 2.105   train_acc= 1.000   test_loss=2.245   test_acc= 1.000
epoch= 15   train_loss= 2.090   train_acc= 1.000   test_loss=2.215   test_acc= 1.000
epoch= 16   train_loss= 2.073   train_acc= 1.000   test_loss=2.205   test_acc= 1.000
epoch= 17   train_loss= 2.053   train_acc= 1.000   test_loss=2.189   test_acc= 1.000
epoch= 18   train_loss= 2.039   train_acc= 1.000   test_loss=2.173   test_acc= 1.000
epoch= 19   train_loss= 2.027   train_acc= 1.000   test_loss=2.148   test_acc= 1.000
epoch= 20   train_loss= 2.011   train_acc= 1.000   test_loss=2.139   test_acc= 1.000
epoch= 21   train_loss= 2.000   train_acc= 1.000   test_loss=2.12

epoch= 6   train_loss= 2.279   train_acc= 0.988   test_loss=2.470   test_acc= 0.900
epoch= 7   train_loss= 2.252   train_acc= 0.988   test_loss=2.393   test_acc= 1.000
epoch= 8   train_loss= 2.211   train_acc= 1.000   test_loss=2.381   test_acc= 1.000
epoch= 9   train_loss= 2.201   train_acc= 1.000   test_loss=2.341   test_acc= 1.000
epoch= 10   train_loss= 2.178   train_acc= 1.000   test_loss=2.342   test_acc= 1.000
epoch= 11   train_loss= 2.153   train_acc= 1.000   test_loss=2.342   test_acc= 0.900
epoch= 12   train_loss= 2.139   train_acc= 1.000   test_loss=2.288   test_acc= 1.000
epoch= 13   train_loss= 2.126   train_acc= 1.000   test_loss=2.271   test_acc= 1.000
epoch= 14   train_loss= 2.111   train_acc= 1.000   test_loss=2.292   test_acc= 0.900
epoch= 15   train_loss= 2.089   train_acc= 1.000   test_loss=2.265   test_acc= 1.000
epoch= 16   train_loss= 2.072   train_acc= 1.000   test_loss=2.224   test_acc= 1.000
epoch= 17   train_loss= 2.056   train_acc= 1.000   test_loss=2.217   

epoch= 2   train_loss= 2.524   train_acc= 0.940   test_loss=2.547   test_acc= 1.000
epoch= 3   train_loss= 2.438   train_acc= 0.964   test_loss=2.457   test_acc= 1.000
epoch= 4   train_loss= 2.369   train_acc= 0.976   test_loss=2.478   test_acc= 0.889
epoch= 5   train_loss= 2.368   train_acc= 0.952   test_loss=2.425   test_acc= 0.889
epoch= 6   train_loss= 2.289   train_acc= 0.988   test_loss=2.322   test_acc= 1.000
epoch= 7   train_loss= 2.268   train_acc= 1.000   test_loss=2.318   test_acc= 1.000
epoch= 8   train_loss= 2.244   train_acc= 0.988   test_loss=2.312   test_acc= 0.889
epoch= 9   train_loss= 2.221   train_acc= 0.988   test_loss=2.294   test_acc= 0.889
epoch= 10   train_loss= 2.194   train_acc= 1.000   test_loss=2.240   test_acc= 1.000
epoch= 11   train_loss= 2.177   train_acc= 1.000   test_loss=2.213   test_acc= 1.000
epoch= 12   train_loss= 2.154   train_acc= 1.000   test_loss=2.227   test_acc= 0.889
epoch= 13   train_loss= 2.132   train_acc= 1.000   test_loss=2.215   test

epoch= 48   train_loss= 1.687   train_acc= 1.000   test_loss=2.149   test_acc= 0.889
epoch= 49   train_loss= 1.677   train_acc= 1.000   test_loss=2.143   test_acc= 0.889
run time: 1.0385619163513184 min
test_acc=0.889
run= 2   fold= 5
epoch= 0   train_loss= 2.884   train_acc= 0.675   test_loss=2.794   test_acc= 0.778
epoch= 1   train_loss= 2.610   train_acc= 0.928   test_loss=2.760   test_acc= 1.000
epoch= 2   train_loss= 2.519   train_acc= 0.916   test_loss=2.714   test_acc= 1.000
epoch= 3   train_loss= 2.437   train_acc= 0.940   test_loss=2.592   test_acc= 1.000
epoch= 4   train_loss= 2.405   train_acc= 0.940   test_loss=2.550   test_acc= 1.000
epoch= 5   train_loss= 2.332   train_acc= 0.976   test_loss=2.521   test_acc= 1.000
epoch= 6   train_loss= 2.306   train_acc= 0.976   test_loss=2.491   test_acc= 0.889
epoch= 7   train_loss= 2.265   train_acc= 0.988   test_loss=2.563   test_acc= 0.778
epoch= 8   train_loss= 2.257   train_acc= 0.952   test_loss=2.479   test_acc= 0.889
epoch= 9 

KeyboardInterrupt: 

### MI-net pooling deep supervision

In [7]:
import numpy as np
import sys
import time
import random
from random import shuffle
import argparse

from keras.models import Model
from keras.optimizers import SGD
from keras.regularizers import l2
from keras.layers import Input, Dense, Layer, Dropout, average

from mil_nets.dataset import load_dataset
from mil_nets.layer import Feature_pooling
from mil_nets.metrics import bag_accuracy
from mil_nets.objectives import bag_loss
from mil_nets.utils import convertToBatch

In [9]:
def test_eval(model, test_set):
    """Evaluate on testing set.
    Parameters
    -----------------
    model : keras.engine.training.Model object
        The training MI-Net with deep supervision model.
    test_set : list
        A list of testing set contains all training bags features and labels.
    Returns
    -----------------
    test_loss : float
        Mean loss of evaluating on testing set.
    test_acc : float
        Mean accuracy of evaluating on testing set.
    """
    num_test_batch = len(test_set)
    test_loss = np.zeros((num_test_batch, 1), dtype=np.float32)
    test_acc = np.zeros((num_test_batch, 1), dtype=np.float32)
    for ibatch, batch in enumerate(test_set):
        result = model.test_on_batch({'input':batch[0].astype(np.float32)}, {'fp1':batch[1].astype(np.float32), 'fp2':batch[1].astype(np.float32), 'fp3':batch[1].astype(np.float32), 'ave':batch[1].astype(np.float32)})
        test_loss[ibatch] = result[0]
        test_acc[ibatch] = result[-1]
    return np.mean(test_loss), np.mean(test_acc)

def train_eval(model, train_set):
    """Evaluate on training set.
    Parameters
    -----------------
    model : keras.engine.training.Model object
        The training MI-Net with deep supervision model.
    train_set : list
        A list of training set contains all training bags features and labels.
    Returns
    -----------------
    test_loss : float
        Mean loss of evaluating on traing set.
    test_acc : float
        Mean accuracy of evaluating on testing set.
    """
    num_train_batch = len(train_set)
    train_loss = np.zeros((num_train_batch, 1), dtype=np.float32)
    train_acc = np.zeros((num_train_batch, 1), dtype=np.float32)
    shuffle(train_set)
    for ibatch, batch in enumerate(train_set):
        result = model.train_on_batch({'input':batch[0].astype(np.float32)}, {'fp1':batch[1].astype(np.float32), 'fp2':batch[1].astype(np.float32), 'fp3':batch[1].astype(np.float32), 'ave':batch[1].astype(np.float32)})
        train_loss[ibatch] = result[0]
        train_acc[ibatch] = result[-1]
    return np.mean(train_loss), np.mean(train_acc)

def MI_Net_with_DS(dataset):
    """Train and evaluate on MI-Net with deep supervision.
    Parameters
    -----------------
    dataset : dict
        A dictionary contains all dataset information. We split train/test by keys.
    Returns
    -----------------
    test_acc : float
        Testing accuracy of MI-Net with deep supervision.
    """
    weight_decay=0.005
    init_lr=5e-4
    pooling_mode='max'
    momentum=0.9
    max_epoch=50
    # load data and convert type
    train_bags = dataset['train']
    test_bags = dataset['test']

    # convert bag to batch
    train_set = convertToBatch(train_bags)
    test_set = convertToBatch(test_bags)
    dimension = train_set[0][0].shape[1]
    weight = [1.0, 1.0, 1.0, 0.0]

    # data: instance feature, n*d, n = number of training instance
    data_input = Input(shape=(dimension,), dtype='float32', name='input')

    # fully-connected
    fc1 = Dense(256, activation='relu', kernel_regularizer=l2(weight_decay))(data_input)
    fc2 = Dense(128, activation='relu', kernel_regularizer=l2(weight_decay))(fc1)
    fc3 = Dense(64, activation='relu', kernel_regularizer=l2(weight_decay))(fc2)

    # dropout
    dropout1 = Dropout(rate=0.5)(fc1)
    dropout2 = Dropout(rate=0.5)(fc2)
    dropout3 = Dropout(rate=0.5)(fc3)

    # features pooling
    fp1 = Feature_pooling(output_dim=1, kernel_regularizer=l2(weight_decay), pooling_mode=pooling_mode, name='fp1')(dropout1)
    fp2 = Feature_pooling(output_dim=1, kernel_regularizer=l2(weight_decay), pooling_mode=pooling_mode, name='fp2')(dropout2)
    fp3 = Feature_pooling(output_dim=1, kernel_regularizer=l2(weight_decay), pooling_mode=pooling_mode, name='fp3')(dropout3)

    # score average
    mg_ave =average([fp1,fp2,fp3], name='ave')

    model = Model(inputs=[data_input], outputs=[fp1, fp2, fp3, mg_ave])
    sgd = SGD(lr=init_lr, decay=1e-4, momentum=momentum, nesterov=True)
    model.compile(loss={'fp1':bag_loss, 'fp2':bag_loss, 'fp3':bag_loss, 'ave':bag_loss}, loss_weights={'fp1':weight[0], 'fp2':weight[1], 'fp3':weight[2], 'ave':weight[3]}, optimizer=sgd, metrics=[bag_accuracy])

    # train model
    t1 = time.time()
    num_batch = len(train_set)
    for epoch in range(max_epoch):
        train_loss, train_acc = train_eval(model, train_set)
        test_loss, test_acc = test_eval(model, test_set)
        print('epoch=', epoch, '  train_loss= {:.3f}'.format(train_loss), '  train_acc= {:.3f}'.format(train_acc), '  test_loss={:.3f}'.format(test_loss), '  test_acc= {:.3f}'.format(test_acc))
    t2 = time.time()
    print('run time:', (t2-t1) / 60, 'min')
    print('test_acc={:.3f}'.format(test_acc))

    return test_acc

In [None]:
# perform five times 10-fold cross=validation experiments
run = 5
n_folds = 10
acc = np.zeros((run, n_folds), dtype=float)
for irun in range(run):
    dataset = load_dataset('musk1', n_folds)
    for ifold in range(n_folds):
        print('run=', irun, '  fold=', ifold)
        acc[irun][ifold] = MI_Net_with_DS(dataset[ifold])
print('MI-Net with DS mean accuracy = ', np.mean(acc))
print('std = ', np.std(acc))

run= 0   fold= 0
epoch= 0   train_loss= 4.697   train_acc= 0.646   test_loss=3.608   test_acc= 0.700
epoch= 1   train_loss= 3.576   train_acc= 0.817   test_loss=3.062   test_acc= 0.900
epoch= 2   train_loss= 2.949   train_acc= 0.927   test_loss=3.044   test_acc= 0.900
epoch= 3   train_loss= 2.781   train_acc= 0.939   test_loss=2.858   test_acc= 0.900
epoch= 4   train_loss= 2.590   train_acc= 0.976   test_loss=2.990   test_acc= 0.900
epoch= 5   train_loss= 2.591   train_acc= 0.951   test_loss=2.754   test_acc= 1.000
epoch= 6   train_loss= 2.410   train_acc= 0.988   test_loss=2.781   test_acc= 0.900
epoch= 7   train_loss= 2.364   train_acc= 1.000   test_loss=2.655   test_acc= 1.000
epoch= 8   train_loss= 2.309   train_acc= 1.000   test_loss=2.594   test_acc= 1.000
epoch= 9   train_loss= 2.276   train_acc= 1.000   test_loss=2.758   test_acc= 0.900
epoch= 10   train_loss= 2.267   train_acc= 1.000   test_loss=2.562   test_acc= 1.000
epoch= 11   train_loss= 2.248   train_acc= 1.000   test_lo

## Bag-Space

*CkNN*

https://github.com/chlorochrule/cknn/blob/master/cknn/cknn.py

In [None]:
import pandas as pd
X = pd.read_table("./clean2.data") #pd.read_csv("sample_data/mnist_test.csv") 

In [None]:
from cknn import cknneighbors_graph

#ckng = cknneighbors_graph(X, n_neighbors=5, delta=1.0)

In [None]:
import numpy as np
from sklearn.datasets import load_digits
from sklearn.manifold import SpectralEmbedding
import matplotlib.pyplot as plt
from matplotlib import offsetbox
import seaborn as sns

from cknn import cknneighbors_graph

sns.set()


def plot2d_label(X, title=None):
    digits = load_digits()
    y = digits.target
    x_min, x_max = np.min(X, 0), np.max(X, 0)
    X = (X - x_min) / (x_max - x_min)

    plt.figure()
    ax = plt.subplot(111)
    for i in range(X.shape[0]):
        plt.text(X[i, 0], X[i, 1], str(digits.target[i]),
                 color=plt.cm.Set1(y[i] / 10.),
                 fontdict={'weight': 'bold', 'size': 9})

    
    plt.xticks([])
    plt.yticks([])
    if title is not None:
        plt.title(title)
    plt.show()


def main():
    data = X
    print(data)
    n_neighbors = 2

    model_normal = SpectralEmbedding(n_components=2, n_neighbors=n_neighbors)
    y_normal = model_normal.fit_transform(data)
    plot2d_label(y_normal)

    #ckng = cknneighbors_graph(data, n_neighbors=n_neighbors, delta=1.5)
    #model_cknn = SpectralEmbedding(n_components=2, affinity='precomputed')
    #y_cknn = model_cknn.fit_transform(ckng.toarray())
    #plot2d_label(y_cknn)

main()

     MUSK-211,211_1+1,46,-108,-60,-69,-117,49,38,-161,-8,5,-323,-220,-113,-299,-283,-307,-31,-106,-227,-42,-59,-22,-67,189,81,17,-27,-89,-67,105,-116,124,-106,5,-120,63,-165,40,-27,68,-44,98,-33,-314,-282,-335,-144,-13,-197,-2,-144,-13,-11,-131,108,-43,42,-151,-4,8,-102,51,-15,108,-135,59,-166,20,-20,23,-48,-68,-299,-256,-97,-183,-24,-271,-229,-177,-6,0,-129,112,15,36,-66,-54,-75,132,-188,119,-120,-312,23,-55,-53,-26,-71,41,-55,148,-247,-306,-308,-230,-166,-35,-205,-280,-239,-53,-10,-23,25,-5,163,61,59,-39,92,72,113,-107,80,25,-27,81,-114,-187,45,-118,-75,-182,-234,-19,12,-13,-41,-119,-149,70,17,-20,-177,-101,-116,-14,-50,24,-81,-125,-114,-44,128,3,-244,-308,52,-7,39,126,156,-50,-112,96,1.
0     MUSK-211,211_1+10,41,-188,-145,22,-117,-6,57,-...                                                                                                                                                                                                                                                      

ValueError: ignored

In [1]:
import numpy as np
from sklearn.datasets import load_digits
from sklearn.manifold import SpectralEmbedding
import matplotlib.pyplot as plt
from matplotlib import offsetbox
import seaborn as sns

sns.set()


def plot2d_label(X, title=None):
    y = X[1]
    x_min, x_max = np.min(X[0], 0), np.max(X[0], 0)
    X = (X - x_min) / (x_max - x_min)

    plt.figure()
    plt.xticks([])
    plt.yticks([])
    if title is not None:
        plt.title(title)
    plt.show()


def main():
    data = X
    n_neighbors = 10

    model_normal = SpectralEmbedding(n_components=2, n_neighbors=n_neighbors)
    y_normal = model_normal.fit_transform(data)
    #plot2d_label(y_normal)

    ckng = cknneighbors_graph(data, n_neighbors=n_neighbors, delta=1.5)
    model_cknn = SpectralEmbedding(n_components=2, affinity='precomputed')
    y_cknn = model_cknn.fit_transform(ckng.toarray())
    #plot2d_label(y_cknn)
    print(y_cknn)

main()

NameError: name 'X' is not defined

## instance-Space


MI-SVM and mi-SVM

In [5]:
import misvm
from misvmio import parse_c45, bag_set
from __future__ import print_function, division
import numpy as np

In [14]:
    # Load list of C4.5 Examples
    example_set = parse_c45('musk1')
    
    print(example_set[:10])
    # Get stats to normalize data
    raw_data = np.array(example_set.to_float())
    data_mean = np.average(raw_data, axis=0)
    data_std  = np.std(raw_data, axis=0)
    data_std[np.nonzero(data_std == 0.0)] = 1.0
    def normalizer(ex):
        ex = np.array(ex)
        normed = ((ex - data_mean) / data_std)
        # The ...[:, 2:-1] removes first two columns and last column,
        # which are the bag/instance ids and class label, as part of the
        # normalization process
        return normed[2:-1]

    
    # Group examples into bags
    bagset = bag_set(example_set)

    # Convert bags to NumPy arrays
    bags = [np.array(b.to_float(normalizer)) for b in bagset]
    labels = np.array([b.label for b in bagset], dtype=float)
    # Convert 0/1 labels to -1/1 labels
    labels = 2 * labels - 1

    # Spilt dataset arbitrarily to train/test sets
    train_bags = bags[10:]
    train_labels = labels[10:]
    test_bags = bags[:10]
    test_labels = labels[:10]

    # Construct classifiers
    classifiers = {}
    
 # MISVM   : the MI-SVM algorithm of Andrews, Tsochantaridis, & Hofmann (2002)
 # miSVM   : the mi-SVM algorithm of Andrews, Tsochantaridis, & Hofmann (2002)

 # MissSVM : the semi-supervised learning approach of Zhou & Xu (2007)
 # MICA    : the MI classification algorithm of Mangasarian & Wild (2008)
 # sMIL    : sparse MIL (Bunescu & Mooney, 2007)
 # stMIL   : sparse, transductive  MIL (Bunescu & Mooney, 2007)
    
    classifiers['MissSVM'] = misvm.MissSVM(kernel='linear', C=1.0, max_iters=20)
    classifiers['sbMIL'] = misvm.sbMIL(kernel='linear', eta=0.1, C=1e2)
    classifiers['SIL'] = misvm.SIL(kernel='linear', C=1.0)
    classifiers['STK'] = misvm.STK(kernel='linear', C=1.0)
    classifiers['NSK'] = misvm.NSK(kernel='linear', C=1.0)

    # Train/Evaluate classifiers
    accuracies = {}
    for algorithm, classifier in classifiers.items():
        classifier.fit(train_bags, train_labels)
        predictions = classifier.predict(test_bags)
        accuracies[algorithm] = np.average(test_labels == np.sign(predictions))

    for algorithm, accuracy in accuracies.items():
        print('\n%s Accuracy: %.1f%%' % (algorithm, 100 * accuracy))

[<(<molecule_name, ID, ('MUSK-jf78', 'MUSK-jf67', 'MUSK-jf59', 'MUSK-jf58', 'MUSK-jf47', 'MUSK-jf46', 'MUSK-jf17', 'MUSK-j51', 'MUSK-j33', 'MUSK-f205', 'MUSK-f184', 'MUSK-f159', 'MUSK-f158', 'MUSK-f152', 'MUSK-344', 'MUSK-333', 'MUSK-331', 'MUSK-330', 'MUSK-323', 'MUSK-322', 'MUSK-321', 'MUSK-316', 'MUSK-315', 'MUSK-314', 'MUSK-311', 'MUSK-301', 'MUSK-293', 'MUSK-292', 'MUSK-285', 'MUSK-284', 'MUSK-273', 'MUSK-272', 'MUSK-256', 'MUSK-254', 'MUSK-246', 'MUSK-240', 'MUSK-238', 'MUSK-236', 'MUSK-228', 'MUSK-227', 'MUSK-224', 'MUSK-219', 'MUSK-213', 'MUSK-212', 'MUSK-211', 'MUSK-190', 'MUSK-188', 'NON-MUSK-jp13', 'NON-MUSK-jp10', 'NON-MUSK-j97', 'NON-MUSK-j96', 'NON-MUSK-j93', 'NON-MUSK-j90', 'NON-MUSK-j84', 'NON-MUSK-j83', 'NON-MUSK-j81', 'NON-MUSK-j148', 'NON-MUSK-j147', 'NON-MUSK-j146', 'NON-MUSK-j130', 'NON-MUSK-j129', 'NON-MUSK-j100', 'NON-MUSK-f209', 'NON-MUSK-f164', 'NON-MUSK-f161', 'NON-MUSK-f150', 'NON-MUSK-334', 'NON-MUSK-327', 'NON-MUSK-320', 'NON-MUSK-319', 'NON-MUSK-318', 'NON

Non-random start...
     pcost       dcost       gap    pres   dres
 0: -2.1953e+02 -2.2856e+01  4e+03  7e+01  5e-13
 1: -1.4260e+01 -2.1760e+01  2e+02  2e+00  7e-13
 2: -6.1494e+00 -1.7486e+01  3e+01  3e-01  9e-14
 3: -5.3271e+00 -1.1501e+01  9e+00  7e-02  2e-14
 4: -5.5278e+00 -6.7476e+00  1e+00  8e-03  1e-14
 5: -5.8459e+00 -6.1864e+00  4e-01  2e-03  1e-14
 6: -5.9329e+00 -6.0618e+00  1e-01  4e-04  1e-14
 7: -5.9703e+00 -6.0071e+00  4e-02  6e-05  1e-14
 8: -5.9831e+00 -5.9897e+00  7e-03  9e-06  1e-14
 9: -5.9857e+00 -5.9863e+00  6e-04  4e-07  1e-14
10: -5.9859e+00 -5.9860e+00  8e-05  6e-08  1e-14
11: -5.9860e+00 -5.9860e+00  3e-06  2e-09  1e-14
Optimal solution found.

Iteration 1...
Linearizing constraints...
Computing slacks...
Linearizing...
Solving QP...
     pcost       dcost       gap    pres   dres
 0: -2.1691e+02 -2.3067e+01  4e+03  7e+01  7e-13
 1: -1.3497e+01 -2.1886e+01  2e+02  2e+00  5e-13
 2: -5.6195e+00 -1.6509e+01  3e+01  3e-01  7e-14
 3: -4.6768e+00 -9.5961e+00  8e+0

 3: -4.8633e+00 -9.5622e+00  8e+00  6e-02  2e-14
 4: -5.0297e+00 -6.0551e+00  1e+00  1e-02  1e-14
 5: -5.2383e+00 -5.5276e+00  4e-01  2e-03  1e-14
 6: -5.3102e+00 -5.3866e+00  1e-01  5e-04  1e-14
 7: -5.3312e+00 -5.3519e+00  3e-02  1e-04  1e-14
 8: -5.3383e+00 -5.3410e+00  3e-03  1e-05  1e-14
 9: -5.3393e+00 -5.3394e+00  1e-04  4e-07  2e-14
10: -5.3394e+00 -5.3394e+00  5e-06  2e-08  2e-14
Optimal solution found.
delta obj ratio: 2.55e+05

Iteration 12...
Linearizing constraints...
Computing slacks...
Linearizing...
Solving QP...
     pcost       dcost       gap    pres   dres
 0: -2.1800e+02 -2.3207e+01  4e+03  7e+01  6e-13
 1: -1.3495e+01 -2.1965e+01  2e+02  2e+00  7e-13
 2: -5.9445e+00 -1.6459e+01  3e+01  3e-01  9e-14
 3: -5.0171e+00 -1.0239e+01  1e+01  9e-02  3e-14
 4: -5.0492e+00 -6.0718e+00  1e+00  1e-02  2e-14
 5: -5.2396e+00 -5.6237e+00  5e-01  3e-03  1e-14
 6: -5.3335e+00 -5.4393e+00  1e-01  8e-04  2e-14
 7: -5.3674e+00 -5.3805e+00  1e-02  3e-05  2e-14
 8: -5.3721e+00 -5.3737e+



Training initial sMIL classifier for sbMIL...
Setup QP...
Solving QP...
     pcost       dcost       gap    pres   dres
 0: -4.2149e+00 -5.1185e+02  4e+03  2e+00  2e-13
 1: -2.6280e-01 -3.0053e+02  7e+02  3e-01  1e-13
 2:  1.8438e+00 -1.2085e+02  2e+02  6e-02  1e-13
 3:  1.4873e+00 -2.7865e+01  4e+01  1e-02  5e-14
 4:  6.4435e-01 -2.5490e+00  3e+00  3e-04  3e-14
 5:  1.9592e-03 -3.0018e-01  3e-01  3e-06  1e-14
 6: -7.7758e-02 -1.8625e-01  1e-01  9e-07  5e-15
 7: -1.1869e-01 -1.6662e-01  5e-02  2e-16  4e-15
 8: -1.2816e-01 -1.4611e-01  2e-02  3e-16  4e-15
 9: -1.3223e-01 -1.3830e-01  6e-03  2e-16  3e-15
10: -1.3397e-01 -1.3512e-01  1e-03  2e-16  4e-15
11: -1.3444e-01 -1.3446e-01  3e-05  2e-16  5e-15
12: -1.3445e-01 -1.3445e-01  4e-07  2e-16  5e-15
13: -1.3445e-01 -1.3445e-01  1e-08  3e-16  4e-15
Optimal solution found.
Computing initial instance labels for sbMIL...
Retraining with top 10% as positive...
     pcost       dcost       gap    pres   dres
 0: -3.6646e+01 -1.7743e+02  3e+03  

### mi-Net

In [1]:
import sys
import time
from random import shuffle
import numpy as np
import argparse

from keras.models import Model
from keras.optimizers import SGD
from keras.regularizers import l2
from keras.layers import Input, Dense, Layer, Dropout

from mil_nets.dataset import load_dataset
from mil_nets.layer import Score_pooling
from mil_nets.metrics import bag_accuracy
from mil_nets.objectives import bag_loss
from mil_nets.utils import convertToBatch

In [22]:
def test_eval(model, test_set):
    """Evaluate on testing set.
    Parameters
    -----------------
    model : keras.engine.training.Model object
        The training mi-Net model.
    test_set : list
        A list of testing set contains all training bags features and labels.
    Returns
    -----------------
    test_loss : float
        Mean loss of evaluating on testing set.
    test_acc : float
        Mean accuracy of evaluating on testing set.
    """
    num_test_batch = len(test_set)
    test_loss = np.zeros((num_test_batch, 1), dtype=np.float32)
    test_acc = np.zeros((num_test_batch, 1), dtype=np.float32)
    for ibatch, batch in enumerate(test_set):
        result = model.test_on_batch({'input':batch[0].astype(np.float32)}, {'sp':batch[1].astype(np.float32)})
        test_loss[ibatch] = result[0]
        test_acc[ibatch] = result[1]
    return np.mean(test_loss), np.mean(test_acc)

def train_eval(model, train_set):
    """Evaluate on training set.
    Parameters
    -----------------
    model : keras.engine.training.Model object
        The training mi-Net model.
    train_set : list
        A list of training set contains all training bags features and labels.
    Returns
    -----------------
    test_loss : float
        Mean loss of evaluating on traing set.
    test_acc : float
        Mean accuracy of evaluating on testing set.
    """
    num_train_batch = len(train_set)
    train_loss = np.zeros((num_train_batch, 1), dtype=np.float32)
    train_acc = np.zeros((num_train_batch, 1), dtype=np.float32)
    shuffle(train_set)
    for ibatch, batch in enumerate(train_set):
        result = model.train_on_batch({'input':batch[0].astype(np.float32)}, {'sp':batch[1].astype(np.float32)})
        train_loss[ibatch] = result[0]
        train_acc[ibatch] = result[1]
    return np.mean(train_loss), np.mean(train_acc)

def mi_Net(dataset):
    weight_decay=0.005
    init_lr=5e-4
    pooling_mode='max'
    momentum=0.9
    max_epoch=50
    """Train and evaluate on mi-Net.
    Parameters
    -----------------
    dataset : dict
        A dictionary contains all dataset information. We split train/test by keys.
    Returns
    -----------------
    test_acc : float
        Testing accuracy of mi-Net.
    """
    # load data and convert type
    train_bags = dataset['train']
    test_bags = dataset['test']

    # convert bag to batch
    train_set = convertToBatch(train_bags)
    test_set = convertToBatch(test_bags)
    dimension = train_set[0][0].shape[1]

    # data: instance feature, n*d, n = number of training instance
    data_input = Input(shape=(dimension,), dtype='float32', name='input')

    # fully-connected
    fc1 = Dense(256, activation='relu', kernel_regularizer=l2(weight_decay))(data_input)
    fc2 = Dense(128, activation='relu', kernel_regularizer=l2(weight_decay))(fc1)
    fc3 = Dense(64, activation='relu', kernel_regularizer=l2(weight_decay))(fc2)

    # dropout
    dropout = Dropout(rate=0.5)(fc3)

    # score pooling
    sp = Score_pooling(output_dim=1, kernel_regularizer=l2(weight_decay), pooling_mode=pooling_mode, name='sp')(dropout)

    model = Model(inputs=[data_input], outputs=[sp])
    sgd = SGD(lr=init_lr, decay=1e-4, momentum=momentum, nesterov=True)
    model.compile(loss=bag_loss, optimizer=sgd, metrics=[bag_accuracy])

    # train model
    t1 = time.time()
    num_batch = len(train_set)
    for epoch in range(max_epoch):
        train_loss, train_acc = train_eval(model, train_set)
        test_loss, test_acc = test_eval(model, test_set)
        print('epoch=', epoch, '  train_loss= {:.3f}'.format(train_loss), '  train_acc= {:.3f}'.format(train_acc), '  test_loss={:.3f}'.format(test_loss), '  test_acc= {:.3f}'.format(test_acc))
    t2 = time.time()
    print('run time:', (t2-t1) / 60.0, 'min')
    print('test_acc={:.3f}'.format(test_acc))

    return test_acc


In [23]:
# perform five times 10-fold cross-validation experiments
run = 5
n_folds = 10
acc = np.zeros((run, n_folds), dtype=np.float32)
for irun in range(run):
    dataset = load_dataset('musk1', n_folds)
    for ifold in range(n_folds):
        print('run=', irun, '  fold=', ifold)
        acc[irun][ifold] = mi_Net(dataset[ifold])
print('mi-net mean accuracy = ', np.mean(acc))
print('std = ', np.std(acc))

MUSK-188     
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-211     
MUSK-211     
MUSK-212     
MUSK-212     
MUSK-212     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-219     
MUSK-219     
MUSK-224     
MUSK-224     
MUSK-227     
MUSK-227     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-246     
MUSK-246     
MUSK-246     
MUSK-246     
MUSK-254     
MUSK-254     
MUSK-256     
MUSK-256     
MUSK-256     
MUSK-256     
MUSK-272     
MUSK-272     
MUSK-272     
MUSK-273     
MUSK-273     
MUSK-273     
MUSK-273     
MUSK-273     
MUSK-284     
MUSK-284     
MUSK-284     
MUSK-284     
MUSK-285     
MUSK-285     
MUSK-2

epoch= 18   train_loss= 2.081   train_acc= 1.000   test_loss=2.426   test_acc= 0.900
epoch= 19   train_loss= 2.059   train_acc= 1.000   test_loss=2.396   test_acc= 0.900
epoch= 20   train_loss= 2.065   train_acc= 0.976   test_loss=2.386   test_acc= 0.900
epoch= 21   train_loss= 2.039   train_acc= 0.976   test_loss=2.394   test_acc= 0.900
epoch= 22   train_loss= 2.029   train_acc= 1.000   test_loss=2.429   test_acc= 0.900
epoch= 23   train_loss= 2.010   train_acc= 0.976   test_loss=2.376   test_acc= 0.900
epoch= 24   train_loss= 1.978   train_acc= 1.000   test_loss=2.356   test_acc= 0.900
epoch= 25   train_loss= 1.962   train_acc= 1.000   test_loss=2.370   test_acc= 0.900
epoch= 26   train_loss= 1.946   train_acc= 1.000   test_loss=2.366   test_acc= 0.900
epoch= 27   train_loss= 1.941   train_acc= 1.000   test_loss=2.373   test_acc= 0.900
epoch= 28   train_loss= 1.922   train_acc= 1.000   test_loss=2.357   test_acc= 0.900
epoch= 29   train_loss= 1.907   train_acc= 1.000   test_loss=2.33

epoch= 14   train_loss= 2.198   train_acc= 0.976   test_loss=2.442   test_acc= 0.778
epoch= 15   train_loss= 2.180   train_acc= 0.976   test_loss=2.355   test_acc= 0.778
epoch= 16   train_loss= 2.147   train_acc= 0.964   test_loss=2.329   test_acc= 0.778
epoch= 17   train_loss= 2.141   train_acc= 0.988   test_loss=2.366   test_acc= 0.667
epoch= 18   train_loss= 2.141   train_acc= 0.964   test_loss=2.397   test_acc= 0.778
epoch= 19   train_loss= 2.089   train_acc= 0.988   test_loss=2.366   test_acc= 0.778
epoch= 20   train_loss= 2.079   train_acc= 0.976   test_loss=2.333   test_acc= 0.778
epoch= 21   train_loss= 2.063   train_acc= 0.976   test_loss=2.295   test_acc= 0.778
epoch= 22   train_loss= 2.044   train_acc= 1.000   test_loss=2.254   test_acc= 0.778
epoch= 23   train_loss= 2.030   train_acc= 0.976   test_loss=2.209   test_acc= 0.889
epoch= 24   train_loss= 2.014   train_acc= 1.000   test_loss=2.221   test_acc= 0.778
epoch= 25   train_loss= 1.978   train_acc= 1.000   test_loss=2.21

epoch= 10   train_loss= 2.322   train_acc= 0.940   test_loss=2.306   test_acc= 0.889
epoch= 11   train_loss= 2.300   train_acc= 0.952   test_loss=2.276   test_acc= 0.889
epoch= 12   train_loss= 2.265   train_acc= 0.964   test_loss=2.326   test_acc= 0.889
epoch= 13   train_loss= 2.245   train_acc= 0.976   test_loss=2.291   test_acc= 0.889
epoch= 14   train_loss= 2.179   train_acc= 0.988   test_loss=2.261   test_acc= 0.889
epoch= 15   train_loss= 2.181   train_acc= 0.952   test_loss=2.224   test_acc= 0.889
epoch= 16   train_loss= 2.165   train_acc= 0.976   test_loss=2.276   test_acc= 0.889
epoch= 17   train_loss= 2.138   train_acc= 0.976   test_loss=2.307   test_acc= 0.889
epoch= 18   train_loss= 2.099   train_acc= 0.988   test_loss=2.221   test_acc= 0.889
epoch= 19   train_loss= 2.072   train_acc= 1.000   test_loss=2.234   test_acc= 0.889
epoch= 20   train_loss= 2.062   train_acc= 1.000   test_loss=2.244   test_acc= 0.889
epoch= 21   train_loss= 2.043   train_acc= 0.988   test_loss=2.19

epoch= 6   train_loss= 2.495   train_acc= 0.867   test_loss=2.423   test_acc= 0.889
epoch= 7   train_loss= 2.424   train_acc= 0.928   test_loss=2.344   test_acc= 1.000
epoch= 8   train_loss= 2.427   train_acc= 0.880   test_loss=2.349   test_acc= 0.889
epoch= 9   train_loss= 2.384   train_acc= 0.940   test_loss=2.274   test_acc= 1.000
epoch= 10   train_loss= 2.336   train_acc= 0.916   test_loss=2.262   test_acc= 1.000
epoch= 11   train_loss= 2.274   train_acc= 0.964   test_loss=2.232   test_acc= 1.000
epoch= 12   train_loss= 2.259   train_acc= 0.940   test_loss=2.262   test_acc= 0.889
epoch= 13   train_loss= 2.225   train_acc= 0.964   test_loss=2.160   test_acc= 1.000
epoch= 14   train_loss= 2.190   train_acc= 0.964   test_loss=2.307   test_acc= 0.889
epoch= 15   train_loss= 2.192   train_acc= 0.952   test_loss=2.262   test_acc= 0.889
epoch= 16   train_loss= 2.154   train_acc= 0.988   test_loss=2.142   test_acc= 1.000
epoch= 17   train_loss= 2.122   train_acc= 0.988   test_loss=2.128   

epoch= 2   train_loss= 2.695   train_acc= 0.807   test_loss=2.877   test_acc= 0.667
epoch= 3   train_loss= 2.644   train_acc= 0.807   test_loss=2.733   test_acc= 0.889
epoch= 4   train_loss= 2.545   train_acc= 0.867   test_loss=2.716   test_acc= 0.889
epoch= 5   train_loss= 2.510   train_acc= 0.855   test_loss=2.643   test_acc= 0.778
epoch= 6   train_loss= 2.435   train_acc= 0.904   test_loss=2.681   test_acc= 0.778
epoch= 7   train_loss= 2.402   train_acc= 0.892   test_loss=2.661   test_acc= 0.778
epoch= 8   train_loss= 2.375   train_acc= 0.916   test_loss=2.627   test_acc= 0.889
epoch= 9   train_loss= 2.312   train_acc= 0.952   test_loss=2.609   test_acc= 0.778
epoch= 10   train_loss= 2.306   train_acc= 0.940   test_loss=2.542   test_acc= 0.889
epoch= 11   train_loss= 2.287   train_acc= 0.952   test_loss=2.620   test_acc= 0.778
epoch= 12   train_loss= 2.260   train_acc= 0.952   test_loss=2.570   test_acc= 0.889
epoch= 13   train_loss= 2.223   train_acc= 0.976   test_loss=2.541   test

epoch= 48   train_loss= 1.708   train_acc= 1.000   test_loss=1.938   test_acc= 0.889
epoch= 49   train_loss= 1.695   train_acc= 1.000   test_loss=1.952   test_acc= 0.889
run time: 0.7161436160405477 min
test_acc=0.889
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-211     
MUSK-211     
MUSK-212     
MUSK-212     
MUSK-212     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-219     
MUSK-219     
MUSK-224     
MUSK-224     
MUSK-227     
MUSK-227     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-246     
MUSK-246     
MUSK-246     
MUSK-246     
MUSK-254     
MUSK-254     
MUSK-256     
MUSK-256     
MUSK-256    

epoch= 16   train_loss= 2.172   train_acc= 0.976   test_loss=2.296   test_acc= 0.900
epoch= 17   train_loss= 2.149   train_acc= 0.988   test_loss=2.268   test_acc= 0.900
epoch= 18   train_loss= 2.139   train_acc= 0.988   test_loss=2.238   test_acc= 0.900
epoch= 19   train_loss= 2.085   train_acc= 1.000   test_loss=2.233   test_acc= 0.900
epoch= 20   train_loss= 2.083   train_acc= 0.988   test_loss=2.249   test_acc= 0.900
epoch= 21   train_loss= 2.067   train_acc= 1.000   test_loss=2.214   test_acc= 0.900
epoch= 22   train_loss= 2.048   train_acc= 0.988   test_loss=2.228   test_acc= 0.900
epoch= 23   train_loss= 2.028   train_acc= 1.000   test_loss=2.135   test_acc= 0.900
epoch= 24   train_loss= 2.021   train_acc= 1.000   test_loss=2.209   test_acc= 0.900
epoch= 25   train_loss= 1.982   train_acc= 1.000   test_loss=2.106   test_acc= 0.900
epoch= 26   train_loss= 1.982   train_acc= 1.000   test_loss=2.142   test_acc= 0.900
epoch= 27   train_loss= 1.974   train_acc= 1.000   test_loss=2.05

epoch= 12   train_loss= 2.254   train_acc= 0.952   test_loss=2.191   test_acc= 1.000
epoch= 13   train_loss= 2.217   train_acc= 0.952   test_loss=2.187   test_acc= 0.889
epoch= 14   train_loss= 2.203   train_acc= 0.928   test_loss=2.151   test_acc= 1.000
epoch= 15   train_loss= 2.185   train_acc= 0.964   test_loss=2.155   test_acc= 1.000
epoch= 16   train_loss= 2.133   train_acc= 0.976   test_loss=2.147   test_acc= 1.000
epoch= 17   train_loss= 2.118   train_acc= 0.988   test_loss=2.114   test_acc= 1.000
epoch= 18   train_loss= 2.088   train_acc= 0.988   test_loss=2.100   test_acc= 1.000
epoch= 19   train_loss= 2.047   train_acc= 1.000   test_loss=2.069   test_acc= 1.000
epoch= 20   train_loss= 2.049   train_acc= 1.000   test_loss=2.084   test_acc= 0.889
epoch= 21   train_loss= 2.024   train_acc= 0.988   test_loss=2.051   test_acc= 1.000
epoch= 22   train_loss= 2.003   train_acc= 1.000   test_loss=2.037   test_acc= 1.000
epoch= 23   train_loss= 2.001   train_acc= 0.988   test_loss=2.00

epoch= 8   train_loss= 2.398   train_acc= 0.904   test_loss=2.525   test_acc= 0.667
epoch= 9   train_loss= 2.349   train_acc= 0.940   test_loss=2.508   test_acc= 0.667
epoch= 10   train_loss= 2.317   train_acc= 0.904   test_loss=2.488   test_acc= 0.889
epoch= 11   train_loss= 2.301   train_acc= 0.952   test_loss=2.496   test_acc= 0.667
epoch= 12   train_loss= 2.273   train_acc= 0.940   test_loss=2.470   test_acc= 0.667
epoch= 13   train_loss= 2.228   train_acc= 0.964   test_loss=2.466   test_acc= 0.667
epoch= 14   train_loss= 2.198   train_acc= 0.964   test_loss=2.444   test_acc= 0.667
epoch= 15   train_loss= 2.188   train_acc= 0.964   test_loss=2.460   test_acc= 0.667
epoch= 16   train_loss= 2.169   train_acc= 0.976   test_loss=2.442   test_acc= 0.667
epoch= 17   train_loss= 2.117   train_acc= 0.988   test_loss=2.356   test_acc= 0.778
epoch= 18   train_loss= 2.099   train_acc= 1.000   test_loss=2.365   test_acc= 0.778
epoch= 19   train_loss= 2.091   train_acc= 0.976   test_loss=2.345 

epoch= 4   train_loss= 2.566   train_acc= 0.819   test_loss=2.527   test_acc= 0.889
epoch= 5   train_loss= 2.515   train_acc= 0.892   test_loss=2.493   test_acc= 0.889
epoch= 6   train_loss= 2.461   train_acc= 0.855   test_loss=2.476   test_acc= 0.889
epoch= 7   train_loss= 2.432   train_acc= 0.867   test_loss=2.466   test_acc= 0.889
epoch= 8   train_loss= 2.382   train_acc= 0.940   test_loss=2.371   test_acc= 0.889
epoch= 9   train_loss= 2.340   train_acc= 0.964   test_loss=2.344   test_acc= 0.889
epoch= 10   train_loss= 2.322   train_acc= 0.928   test_loss=2.434   test_acc= 0.778
epoch= 11   train_loss= 2.314   train_acc= 0.928   test_loss=2.303   test_acc= 1.000
epoch= 12   train_loss= 2.275   train_acc= 0.940   test_loss=2.311   test_acc= 0.889
epoch= 13   train_loss= 2.228   train_acc= 0.964   test_loss=2.313   test_acc= 0.889
epoch= 14   train_loss= 2.227   train_acc= 0.964   test_loss=2.255   test_acc= 0.889
epoch= 15   train_loss= 2.179   train_acc= 0.964   test_loss=2.251   te

epoch= 0   train_loss= 2.983   train_acc= 0.494   test_loss=2.813   test_acc= 0.778
epoch= 1   train_loss= 2.831   train_acc= 0.747   test_loss=2.784   test_acc= 0.778
epoch= 2   train_loss= 2.733   train_acc= 0.807   test_loss=2.728   test_acc= 0.667
epoch= 3   train_loss= 2.621   train_acc= 0.880   test_loss=2.720   test_acc= 0.667
epoch= 4   train_loss= 2.600   train_acc= 0.843   test_loss=2.652   test_acc= 0.667
epoch= 5   train_loss= 2.504   train_acc= 0.904   test_loss=2.699   test_acc= 0.667
epoch= 6   train_loss= 2.455   train_acc= 0.904   test_loss=2.692   test_acc= 0.667
epoch= 7   train_loss= 2.403   train_acc= 0.940   test_loss=2.718   test_acc= 0.667
epoch= 8   train_loss= 2.382   train_acc= 0.916   test_loss=2.717   test_acc= 0.667
epoch= 9   train_loss= 2.358   train_acc= 0.964   test_loss=2.503   test_acc= 0.778
epoch= 10   train_loss= 2.349   train_acc= 0.952   test_loss=2.584   test_acc= 0.667
epoch= 11   train_loss= 2.278   train_acc= 0.952   test_loss=2.596   test_a

epoch= 46   train_loss= 1.727   train_acc= 1.000   test_loss=1.860   test_acc= 0.889
epoch= 47   train_loss= 1.719   train_acc= 1.000   test_loss=1.839   test_acc= 0.889
epoch= 48   train_loss= 1.702   train_acc= 1.000   test_loss=1.826   test_acc= 0.889
epoch= 49   train_loss= 1.701   train_acc= 1.000   test_loss=1.820   test_acc= 0.889
run time: 0.7510839025179545 min
test_acc=0.889
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-211     
MUSK-211     
MUSK-212     
MUSK-212     
MUSK-212     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-219     
MUSK-219     
MUSK-224     
MUSK-224     
MUSK-227     
MUSK-227     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-238     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240     
MUSK-240  

epoch= 14   train_loss= 2.185   train_acc= 0.988   test_loss=2.416   test_acc= 0.800
epoch= 15   train_loss= 2.162   train_acc= 0.976   test_loss=2.437   test_acc= 0.800
epoch= 16   train_loss= 2.176   train_acc= 0.951   test_loss=2.446   test_acc= 0.800
epoch= 17   train_loss= 2.132   train_acc= 0.988   test_loss=2.364   test_acc= 0.800
epoch= 18   train_loss= 2.106   train_acc= 0.988   test_loss=2.342   test_acc= 0.800
epoch= 19   train_loss= 2.096   train_acc= 1.000   test_loss=2.393   test_acc= 0.800
epoch= 20   train_loss= 2.060   train_acc= 1.000   test_loss=2.354   test_acc= 0.800
epoch= 21   train_loss= 2.046   train_acc= 0.988   test_loss=2.327   test_acc= 0.800
epoch= 22   train_loss= 2.054   train_acc= 0.976   test_loss=2.312   test_acc= 0.800
epoch= 23   train_loss= 2.029   train_acc= 1.000   test_loss=2.321   test_acc= 0.800
epoch= 24   train_loss= 2.003   train_acc= 1.000   test_loss=2.302   test_acc= 0.800
epoch= 25   train_loss= 1.981   train_acc= 1.000   test_loss=2.25

epoch= 10   train_loss= 2.314   train_acc= 0.940   test_loss=2.432   test_acc= 0.778
epoch= 11   train_loss= 2.304   train_acc= 0.928   test_loss=2.452   test_acc= 0.889
epoch= 12   train_loss= 2.285   train_acc= 0.964   test_loss=2.276   test_acc= 1.000
epoch= 13   train_loss= 2.273   train_acc= 0.940   test_loss=2.422   test_acc= 0.889
epoch= 14   train_loss= 2.232   train_acc= 0.952   test_loss=2.380   test_acc= 0.889
epoch= 15   train_loss= 2.204   train_acc= 0.988   test_loss=2.254   test_acc= 0.889
epoch= 16   train_loss= 2.167   train_acc= 0.976   test_loss=2.231   test_acc= 0.889
epoch= 17   train_loss= 2.131   train_acc= 0.988   test_loss=2.248   test_acc= 0.889
epoch= 18   train_loss= 2.124   train_acc= 0.976   test_loss=2.183   test_acc= 1.000
epoch= 19   train_loss= 2.078   train_acc= 0.988   test_loss=2.280   test_acc= 0.889
epoch= 20   train_loss= 2.076   train_acc= 0.988   test_loss=2.184   test_acc= 0.889
epoch= 21   train_loss= 2.050   train_acc= 0.988   test_loss=2.20

epoch= 6   train_loss= 2.492   train_acc= 0.916   test_loss=2.532   test_acc= 0.778
epoch= 7   train_loss= 2.495   train_acc= 0.855   test_loss=2.493   test_acc= 0.778
epoch= 8   train_loss= 2.400   train_acc= 0.916   test_loss=2.439   test_acc= 1.000
epoch= 9   train_loss= 2.384   train_acc= 0.928   test_loss=2.439   test_acc= 0.889
epoch= 10   train_loss= 2.335   train_acc= 0.928   test_loss=2.394   test_acc= 0.889
epoch= 11   train_loss= 2.305   train_acc= 0.904   test_loss=2.370   test_acc= 0.889
epoch= 12   train_loss= 2.300   train_acc= 0.916   test_loss=2.401   test_acc= 0.889
epoch= 13   train_loss= 2.258   train_acc= 0.916   test_loss=2.325   test_acc= 0.889
epoch= 14   train_loss= 2.214   train_acc= 0.988   test_loss=2.317   test_acc= 0.889
epoch= 15   train_loss= 2.180   train_acc= 0.964   test_loss=2.300   test_acc= 0.889
epoch= 16   train_loss= 2.189   train_acc= 0.928   test_loss=2.345   test_acc= 0.889
epoch= 17   train_loss= 2.158   train_acc= 0.952   test_loss=2.332   

epoch= 2   train_loss= 2.714   train_acc= 0.819   test_loss=2.747   test_acc= 0.778
epoch= 3   train_loss= 2.641   train_acc= 0.831   test_loss=2.730   test_acc= 0.778
epoch= 4   train_loss= 2.620   train_acc= 0.843   test_loss=2.707   test_acc= 0.778
epoch= 5   train_loss= 2.530   train_acc= 0.880   test_loss=2.662   test_acc= 0.778
epoch= 6   train_loss= 2.487   train_acc= 0.892   test_loss=2.677   test_acc= 0.778
epoch= 7   train_loss= 2.434   train_acc= 0.916   test_loss=2.646   test_acc= 0.778
epoch= 8   train_loss= 2.397   train_acc= 0.952   test_loss=2.646   test_acc= 0.778
epoch= 9   train_loss= 2.371   train_acc= 0.940   test_loss=2.606   test_acc= 0.778
epoch= 10   train_loss= 2.354   train_acc= 0.952   test_loss=2.586   test_acc= 0.778
epoch= 11   train_loss= 2.316   train_acc= 0.940   test_loss=2.547   test_acc= 0.778
epoch= 12   train_loss= 2.262   train_acc= 0.952   test_loss=2.547   test_acc= 0.778
epoch= 13   train_loss= 2.246   train_acc= 0.940   test_loss=2.548   test

epoch= 48   train_loss= 1.697   train_acc= 1.000   test_loss=2.287   test_acc= 0.778
epoch= 49   train_loss= 1.684   train_acc= 1.000   test_loss=2.285   test_acc= 0.778
run time: 0.756281320254008 min
test_acc=0.778
run= 2   fold= 8
epoch= 0   train_loss= 2.909   train_acc= 0.687   test_loss=2.735   test_acc= 0.778
epoch= 1   train_loss= 2.803   train_acc= 0.771   test_loss=2.616   test_acc= 0.889
epoch= 2   train_loss= 2.685   train_acc= 0.819   test_loss=2.572   test_acc= 0.778
epoch= 3   train_loss= 2.583   train_acc= 0.867   test_loss=2.544   test_acc= 0.778
epoch= 4   train_loss= 2.558   train_acc= 0.855   test_loss=2.549   test_acc= 0.778
epoch= 5   train_loss= 2.516   train_acc= 0.892   test_loss=2.549   test_acc= 0.778
epoch= 6   train_loss= 2.433   train_acc= 0.904   test_loss=2.527   test_acc= 0.778
epoch= 7   train_loss= 2.389   train_acc= 0.904   test_loss=2.517   test_acc= 0.778
epoch= 8   train_loss= 2.373   train_acc= 0.916   test_loss=2.564   test_acc= 0.778
epoch= 9  

epoch= 44   train_loss= 1.745   train_acc= 1.000   test_loss=3.650   test_acc= 0.556
epoch= 45   train_loss= 1.737   train_acc= 1.000   test_loss=3.719   test_acc= 0.556
epoch= 46   train_loss= 1.724   train_acc= 1.000   test_loss=3.693   test_acc= 0.556
epoch= 47   train_loss= 1.716   train_acc= 1.000   test_loss=3.753   test_acc= 0.556
epoch= 48   train_loss= 1.709   train_acc= 1.000   test_loss=3.771   test_acc= 0.556
epoch= 49   train_loss= 1.694   train_acc= 1.000   test_loss=3.698   test_acc= 0.556
run time: 0.6951592485109965 min
test_acc=0.556
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-211     
MUSK-211     
MUSK-212     
MUSK-212     
MUSK-212     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-219     
MUSK-219     
MUSK-224     
MUSK-224     
MUSK-227     
MUSK-227     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-228     
MUSK-236     
MUSK-236     
MUSK-236     
MUSK-236

epoch= 12   train_loss= 2.257   train_acc= 0.963   test_loss=2.399   test_acc= 0.900
epoch= 13   train_loss= 2.257   train_acc= 0.939   test_loss=2.362   test_acc= 0.900
epoch= 14   train_loss= 2.197   train_acc= 0.951   test_loss=2.368   test_acc= 0.800
epoch= 15   train_loss= 2.177   train_acc= 0.976   test_loss=2.264   test_acc= 0.900
epoch= 16   train_loss= 2.154   train_acc= 0.976   test_loss=2.277   test_acc= 0.900
epoch= 17   train_loss= 2.143   train_acc= 0.988   test_loss=2.308   test_acc= 0.900
epoch= 18   train_loss= 2.115   train_acc= 0.988   test_loss=2.269   test_acc= 0.900
epoch= 19   train_loss= 2.106   train_acc= 0.963   test_loss=2.256   test_acc= 0.900
epoch= 20   train_loss= 2.072   train_acc= 0.988   test_loss=2.196   test_acc= 0.900
epoch= 21   train_loss= 2.054   train_acc= 0.988   test_loss=2.151   test_acc= 0.900
epoch= 22   train_loss= 2.023   train_acc= 1.000   test_loss=2.194   test_acc= 0.900
epoch= 23   train_loss= 2.014   train_acc= 0.988   test_loss=2.18

epoch= 8   train_loss= 2.404   train_acc= 0.928   test_loss=2.718   test_acc= 0.667
epoch= 9   train_loss= 2.348   train_acc= 0.892   test_loss=2.681   test_acc= 0.667
epoch= 10   train_loss= 2.303   train_acc= 0.976   test_loss=2.643   test_acc= 0.667
epoch= 11   train_loss= 2.317   train_acc= 0.904   test_loss=2.613   test_acc= 0.667
epoch= 12   train_loss= 2.255   train_acc= 0.964   test_loss=2.570   test_acc= 0.667
epoch= 13   train_loss= 2.262   train_acc= 0.952   test_loss=2.571   test_acc= 0.667
epoch= 14   train_loss= 2.194   train_acc= 0.964   test_loss=2.502   test_acc= 0.889
epoch= 15   train_loss= 2.161   train_acc= 0.964   test_loss=2.505   test_acc= 0.667
epoch= 16   train_loss= 2.152   train_acc= 0.988   test_loss=2.468   test_acc= 0.778
epoch= 17   train_loss= 2.117   train_acc= 1.000   test_loss=2.440   test_acc= 0.667
epoch= 18   train_loss= 2.106   train_acc= 0.976   test_loss=2.419   test_acc= 0.889
epoch= 19   train_loss= 2.079   train_acc= 1.000   test_loss=2.373 

epoch= 4   train_loss= 2.532   train_acc= 0.831   test_loss=2.841   test_acc= 0.444
epoch= 5   train_loss= 2.494   train_acc= 0.880   test_loss=2.851   test_acc= 0.444
epoch= 6   train_loss= 2.440   train_acc= 0.916   test_loss=2.886   test_acc= 0.667
epoch= 7   train_loss= 2.364   train_acc= 0.952   test_loss=2.902   test_acc= 0.667
epoch= 8   train_loss= 2.319   train_acc= 0.952   test_loss=2.915   test_acc= 0.444
epoch= 9   train_loss= 2.318   train_acc= 0.940   test_loss=2.897   test_acc= 0.556
epoch= 10   train_loss= 2.280   train_acc= 0.976   test_loss=2.946   test_acc= 0.667
epoch= 11   train_loss= 2.242   train_acc= 0.952   test_loss=2.987   test_acc= 0.667
epoch= 12   train_loss= 2.209   train_acc= 0.976   test_loss=2.935   test_acc= 0.556
epoch= 13   train_loss= 2.192   train_acc= 0.964   test_loss=2.996   test_acc= 0.667
epoch= 14   train_loss= 2.164   train_acc= 0.976   test_loss=3.012   test_acc= 0.556
epoch= 15   train_loss= 2.123   train_acc= 1.000   test_loss=3.049   te

epoch= 0   train_loss= 3.049   train_acc= 0.482   test_loss=3.040   test_acc= 0.444
epoch= 1   train_loss= 2.833   train_acc= 0.687   test_loss=3.046   test_acc= 0.444
epoch= 2   train_loss= 2.743   train_acc= 0.771   test_loss=2.874   test_acc= 0.667
epoch= 3   train_loss= 2.654   train_acc= 0.855   test_loss=2.879   test_acc= 0.333
epoch= 4   train_loss= 2.599   train_acc= 0.831   test_loss=2.803   test_acc= 0.556
epoch= 5   train_loss= 2.522   train_acc= 0.880   test_loss=2.843   test_acc= 0.444
epoch= 6   train_loss= 2.479   train_acc= 0.880   test_loss=2.738   test_acc= 0.778
epoch= 7   train_loss= 2.457   train_acc= 0.892   test_loss=2.904   test_acc= 0.444
epoch= 8   train_loss= 2.395   train_acc= 0.940   test_loss=2.790   test_acc= 0.556
epoch= 9   train_loss= 2.357   train_acc= 0.928   test_loss=2.880   test_acc= 0.444
epoch= 10   train_loss= 2.322   train_acc= 0.964   test_loss=2.827   test_acc= 0.667
epoch= 11   train_loss= 2.259   train_acc= 0.976   test_loss=2.780   test_a

epoch= 46   train_loss= 1.728   train_acc= 1.000   test_loss=2.918   test_acc= 0.778
epoch= 47   train_loss= 1.719   train_acc= 1.000   test_loss=2.979   test_acc= 0.667
epoch= 48   train_loss= 1.709   train_acc= 1.000   test_loss=2.961   test_acc= 0.667
epoch= 49   train_loss= 1.705   train_acc= 1.000   test_loss=2.930   test_acc= 0.667
run time: 0.7117309292157491 min
test_acc=0.667
run= 3   fold= 8
epoch= 0   train_loss= 2.948   train_acc= 0.602   test_loss=2.559   test_acc= 1.000
epoch= 1   train_loss= 2.809   train_acc= 0.747   test_loss=2.499   test_acc= 1.000
epoch= 2   train_loss= 2.718   train_acc= 0.819   test_loss=2.432   test_acc= 1.000
epoch= 3   train_loss= 2.653   train_acc= 0.807   test_loss=2.355   test_acc= 1.000
epoch= 4   train_loss= 2.622   train_acc= 0.831   test_loss=2.349   test_acc= 1.000
epoch= 5   train_loss= 2.551   train_acc= 0.795   test_loss=2.275   test_acc= 1.000
epoch= 6   train_loss= 2.495   train_acc= 0.892   test_loss=2.251   test_acc= 1.000
epoch= 

epoch= 42   train_loss= 1.779   train_acc= 1.000   test_loss=1.757   test_acc= 1.000
epoch= 43   train_loss= 1.757   train_acc= 1.000   test_loss=1.745   test_acc= 1.000
epoch= 44   train_loss= 1.752   train_acc= 1.000   test_loss=1.737   test_acc= 1.000
epoch= 45   train_loss= 1.736   train_acc= 1.000   test_loss=1.727   test_acc= 1.000
epoch= 46   train_loss= 1.730   train_acc= 1.000   test_loss=1.719   test_acc= 1.000
epoch= 47   train_loss= 1.717   train_acc= 1.000   test_loss=1.710   test_acc= 1.000
epoch= 48   train_loss= 1.705   train_acc= 1.000   test_loss=1.698   test_acc= 1.000
epoch= 49   train_loss= 1.709   train_acc= 1.000   test_loss=1.691   test_acc= 1.000
run time: 0.7496605674425761 min
test_acc=1.000
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-188     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-190     
MUSK-211     
MUSK-211     
MUSK-212     
MUSK-212     
MUSK-212     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-213     
MUSK-219     
MUSK-219     
MUSK-2

epoch= 10   train_loss= 2.334   train_acc= 0.939   test_loss=2.647   test_acc= 0.700
epoch= 11   train_loss= 2.286   train_acc= 0.976   test_loss=2.456   test_acc= 0.900
epoch= 12   train_loss= 2.259   train_acc= 0.963   test_loss=2.568   test_acc= 0.900
epoch= 13   train_loss= 2.196   train_acc= 1.000   test_loss=2.513   test_acc= 0.900
epoch= 14   train_loss= 2.179   train_acc= 0.976   test_loss=2.522   test_acc= 0.800
epoch= 15   train_loss= 2.154   train_acc= 1.000   test_loss=2.554   test_acc= 0.800
epoch= 16   train_loss= 2.150   train_acc= 0.963   test_loss=2.380   test_acc= 0.900
epoch= 17   train_loss= 2.145   train_acc= 0.963   test_loss=2.469   test_acc= 0.900
epoch= 18   train_loss= 2.115   train_acc= 0.988   test_loss=2.392   test_acc= 0.900
epoch= 19   train_loss= 2.089   train_acc= 0.988   test_loss=2.497   test_acc= 0.800
epoch= 20   train_loss= 2.081   train_acc= 0.976   test_loss=2.403   test_acc= 0.900
epoch= 21   train_loss= 2.052   train_acc= 1.000   test_loss=2.42

epoch= 6   train_loss= 2.474   train_acc= 0.928   test_loss=2.486   test_acc= 0.778
epoch= 7   train_loss= 2.408   train_acc= 0.916   test_loss=2.483   test_acc= 0.778
epoch= 8   train_loss= 2.386   train_acc= 0.916   test_loss=2.438   test_acc= 0.778
epoch= 9   train_loss= 2.366   train_acc= 0.916   test_loss=2.483   test_acc= 0.778
epoch= 10   train_loss= 2.306   train_acc= 0.952   test_loss=2.484   test_acc= 0.778
epoch= 11   train_loss= 2.306   train_acc= 0.940   test_loss=2.488   test_acc= 0.778
epoch= 12   train_loss= 2.229   train_acc= 0.976   test_loss=2.457   test_acc= 0.778
epoch= 13   train_loss= 2.238   train_acc= 0.940   test_loss=2.411   test_acc= 0.778
epoch= 14   train_loss= 2.187   train_acc= 0.988   test_loss=2.457   test_acc= 0.778
epoch= 15   train_loss= 2.175   train_acc= 0.964   test_loss=2.387   test_acc= 0.778
epoch= 16   train_loss= 2.158   train_acc= 0.988   test_loss=2.475   test_acc= 0.778
epoch= 17   train_loss= 2.160   train_acc= 0.952   test_loss=2.373   

epoch= 2   train_loss= 2.699   train_acc= 0.855   test_loss=2.813   test_acc= 0.667
epoch= 3   train_loss= 2.655   train_acc= 0.807   test_loss=2.719   test_acc= 0.778
epoch= 4   train_loss= 2.545   train_acc= 0.880   test_loss=2.697   test_acc= 0.889
epoch= 5   train_loss= 2.500   train_acc= 0.904   test_loss=2.745   test_acc= 0.778
epoch= 6   train_loss= 2.469   train_acc= 0.892   test_loss=2.643   test_acc= 0.778
epoch= 7   train_loss= 2.425   train_acc= 0.904   test_loss=2.623   test_acc= 0.778
epoch= 8   train_loss= 2.401   train_acc= 0.904   test_loss=2.608   test_acc= 0.778
epoch= 9   train_loss= 2.341   train_acc= 0.952   test_loss=2.649   test_acc= 0.778
epoch= 10   train_loss= 2.309   train_acc= 0.940   test_loss=2.651   test_acc= 0.667
epoch= 11   train_loss= 2.284   train_acc= 0.940   test_loss=2.650   test_acc= 0.778
epoch= 12   train_loss= 2.252   train_acc= 0.952   test_loss=2.635   test_acc= 0.778
epoch= 13   train_loss= 2.202   train_acc= 0.976   test_loss=2.656   test

epoch= 48   train_loss= 1.698   train_acc= 1.000   test_loss=1.844   test_acc= 0.889
epoch= 49   train_loss= 1.693   train_acc= 1.000   test_loss=1.871   test_acc= 0.889
run time: 0.7094895680745442 min
test_acc=0.889
run= 4   fold= 6
epoch= 0   train_loss= 2.918   train_acc= 0.711   test_loss=2.843   test_acc= 0.778
epoch= 1   train_loss= 2.835   train_acc= 0.711   test_loss=2.783   test_acc= 0.667
epoch= 2   train_loss= 2.721   train_acc= 0.819   test_loss=2.692   test_acc= 0.889
epoch= 3   train_loss= 2.669   train_acc= 0.807   test_loss=2.647   test_acc= 0.778
epoch= 4   train_loss= 2.582   train_acc= 0.867   test_loss=2.607   test_acc= 0.889
epoch= 5   train_loss= 2.572   train_acc= 0.855   test_loss=2.546   test_acc= 0.889
epoch= 6   train_loss= 2.479   train_acc= 0.892   test_loss=2.504   test_acc= 0.889
epoch= 7   train_loss= 2.506   train_acc= 0.855   test_loss=2.493   test_acc= 0.889
epoch= 8   train_loss= 2.419   train_acc= 0.940   test_loss=2.471   test_acc= 0.889
epoch= 9 

epoch= 44   train_loss= 1.732   train_acc= 1.000   test_loss=3.326   test_acc= 0.778
epoch= 45   train_loss= 1.726   train_acc= 1.000   test_loss=3.319   test_acc= 0.778
epoch= 46   train_loss= 1.710   train_acc= 1.000   test_loss=3.318   test_acc= 0.778
epoch= 47   train_loss= 1.705   train_acc= 1.000   test_loss=3.291   test_acc= 0.778
epoch= 48   train_loss= 1.696   train_acc= 1.000   test_loss=3.289   test_acc= 0.778
epoch= 49   train_loss= 1.684   train_acc= 1.000   test_loss=3.270   test_acc= 0.778
run time: 0.7162322839101155 min
test_acc=0.778
run= 4   fold= 8
epoch= 0   train_loss= 3.014   train_acc= 0.627   test_loss=2.805   test_acc= 0.889
epoch= 1   train_loss= 2.784   train_acc= 0.735   test_loss=2.753   test_acc= 0.778
epoch= 2   train_loss= 2.655   train_acc= 0.855   test_loss=2.701   test_acc= 0.778
epoch= 3   train_loss= 2.597   train_acc= 0.867   test_loss=2.692   test_acc= 0.778
epoch= 4   train_loss= 2.503   train_acc= 0.880   test_loss=2.694   test_acc= 0.778
epoch

epoch= 40   train_loss= 1.807   train_acc= 1.000   test_loss=1.882   test_acc= 1.000
epoch= 41   train_loss= 1.796   train_acc= 1.000   test_loss=1.873   test_acc= 1.000
epoch= 42   train_loss= 1.787   train_acc= 1.000   test_loss=1.834   test_acc= 1.000
epoch= 43   train_loss= 1.777   train_acc= 1.000   test_loss=1.833   test_acc= 1.000
epoch= 44   train_loss= 1.771   train_acc= 1.000   test_loss=1.803   test_acc= 1.000
epoch= 45   train_loss= 1.756   train_acc= 1.000   test_loss=1.808   test_acc= 1.000
epoch= 46   train_loss= 1.751   train_acc= 1.000   test_loss=1.796   test_acc= 1.000
epoch= 47   train_loss= 1.734   train_acc= 1.000   test_loss=1.791   test_acc= 1.000
epoch= 48   train_loss= 1.733   train_acc= 1.000   test_loss=1.797   test_acc= 1.000
epoch= 49   train_loss= 1.719   train_acc= 1.000   test_loss=1.780   test_acc= 1.000
run time: 0.7182445844014486 min
test_acc=1.000
mi-net mean accuracy =  0.8300000059604645
std =  0.10402396848155698
