# 졸업 논문

## DSAN: Denoising Self-Attention Network

- 추가 실험 jupyter 코드

### colab setting

In [None]:
# import os
# from google.colab import drive

# drive.mount('/drive/')

# experiment_dir = '/drive/MyDrive/experiments'
# os.chdir(experiment_dir)

In [1]:
!nvidia-smi

zsh:1: command not found: nvidia-smi


In [5]:
!pip install missingpy

Collecting missingpy
  Using cached missingpy-0.2.0-py3-none-any.whl (49 kB)
Installing collected packages: missingpy
Successfully installed missingpy-0.2.0


In [6]:
import sys
import numpy as np
import pandas as pd

from missingpy import MissForest
from dsan import Imputer as DSANImputer
from utils import *

import warnings
warnings.filterwarnings('ignore')

In [3]:
seed = 128
set_seed(seed)

In [17]:
def get_imputer(config):
    
    model = config['model']
    
    if model == 'DSAN':
        imputer = DSANImputer(rep_dim=config['rep_dim'],
                num_heads=config['num_heads'],
                n_hidden=config['n_hidden'],
                lr=config['lr'],
                weight_decay=config['weight_decay'],
                batch_size=config['batch_size'],
                epochs=config['epochs'],
                noise_percent=config['noise_percent'],
                stopped_epoch=config['stopped_epoch'])
    
    elif model == 'MissForest':
        imputer = MissForest(random_state=0)
        
    elif model == 'Statistics':
        imputer = None
        
    return imputer

In [19]:
def experiment_run(data_name, config, test=False):
    
    dataset = get_data(data_name)
    if test:
        dataset = (dataset[0][:1000], dataset[1], dataset[2], dataset[3])

    k = 5
    percent_missing_lst = [ p for p in range(5, 25, 5)]    
    X_origin, n_col, num_vars, cat_vars = dataset
    X_origin = category_mapping(X_origin, cat_vars)
    
    print("Data Size: {}".format(X_origin.shape))

    n_sample = X_origin.shape[0]
    idx_lst = list(range(n_sample))
    np.random.shuffle(idx_lst)

    # k-fold validation
    n_valid = len(idx_lst) // k
   
    results = {}
    for percent_missing in percent_missing_lst:
        result = []
        for it, n in enumerate(range(0, n_sample, n_valid)):
            valid_idx = idx_lst[n: n+n_valid]

            if len(valid_idx) < n_valid:
                break

            train_idx = list(set(idx_lst) - set(valid_idx))
            
            X = X_origin[train_idx, :]
            X_test = X_origin[valid_idx, :] # for prediction aucroc metric
            
            # make incomplete data
            missing_mask = generate_missing_mask(X, percent_missing=percent_missing)
            X_incomplete = X.copy()
            X_incomplete[missing_mask] = np.nan

            imputer = get_imputer(config)
            X_imputed = imputer.fit_transform(X_incomplete, cat_vars=cat_vars)
            
            X_res = X.copy()
            X_res[missing_mask] = X_imputed[missing_mask]
            
            metric = cal_metric_numpy(X_res, X, missing_mask, num_vars, cat_vars)
            score = ex_classify(data_name=data_name,
                    train_array=X_res,
                    test_array=X_test,
                    num_vars=num_vars,
                    cat_vars=cat_vars)
            
            metric['clf_aucroc'] = round(score, 4)
            result.append(metric)
            print(metric)
    
        avg_result = dict()
        for key in result[0].keys():
            value = np.mean([m[key] for m in result])
            avg_result[key] = value 
            print(key, value, percent_missing)

        results[percent_missing] = avg_result

    return results

In [26]:
config = {
    'model': 'DSAN',
    'rep_dim': 32,
    'num_heads': 8,
    'n_hidden': 2,
    'lr': 3e-3,
    'weight_decay': 1e-5,
    'batch_size': 128,
    'epochs': 10,
    'noise_percent': 10,
    'stopped_epoch': 10
}

In [40]:
import missingpy

In [43]:
from missingpy import MissForest

In [38]:
data_name_lst = ['adult', 'bank', 'crime']

result_dir = './results'
os.makedirs(result_dir)

for data_name in data_name_lst:
    results = experiment_run(data_name, config, test=True)
    results['config'] = config
    results['dataset'] = data_name
    result_path = os.path.join(result_dir, '{}.json'.format(data_name))
    write_json(results, result_path)

  0%|          | 0/10 [00:00<?, ?it/s]

Data Size: (1000, 15)


100%|██████████| 10/10 [00:05<00:00,  1.75it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.4611, 'col_1_error_rate': 0.3659, 'col_3_error_rate': 0.2778, 'col_5_error_rate': 0.16, 'col_6_error_rate': 0.6857, 'col_7_error_rate': 0.1633, 'col_8_error_rate': 0.1622, 'col_9_error_rate': 0.1556, 'col_13_error_rate': 0.1304, 'col_14_error_rate': 0.15, 'total_error_rate': 0.2375, 'clf_aucroc': 0.8678}


100%|██████████| 10/10 [00:05<00:00,  1.76it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.3537, 'col_1_error_rate': 0.2895, 'col_3_error_rate': 0.225, 'col_5_error_rate': 0.1591, 'col_6_error_rate': 0.6154, 'col_7_error_rate': 0.122, 'col_8_error_rate': 0.1111, 'col_9_error_rate': 0.119, 'col_13_error_rate': 0.0889, 'col_14_error_rate': 0.2381, 'total_error_rate': 0.2104, 'clf_aucroc': 0.8471}


100%|██████████| 10/10 [00:05<00:00,  1.76it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.446, 'col_1_error_rate': 0.3889, 'col_3_error_rate': 0.2955, 'col_5_error_rate': 0.186, 'col_6_error_rate': 0.6818, 'col_7_error_rate': 0.2857, 'col_8_error_rate': 0.0769, 'col_9_error_rate': 0.0789, 'col_13_error_rate': 0.0238, 'col_14_error_rate': 0.0556, 'total_error_rate': 0.2277, 'clf_aucroc': 0.8923}


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.6653, 'col_1_error_rate': 0.3409, 'col_3_error_rate': 0.1489, 'col_5_error_rate': 0.102, 'col_6_error_rate': 0.7143, 'col_7_error_rate': 0.1316, 'col_8_error_rate': 0.075, 'col_9_error_rate': 0.1053, 'col_13_error_rate': 0.0227, 'col_14_error_rate': 0.1351, 'total_error_rate': 0.1979, 'clf_aucroc': 0.8761}


100%|██████████| 10/10 [00:06<00:00,  1.61it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.5914, 'col_1_error_rate': 0.2903, 'col_3_error_rate': 0.25, 'col_5_error_rate': 0.1579, 'col_6_error_rate': 0.6944, 'col_7_error_rate': 0.1765, 'col_8_error_rate': 0.0606, 'col_9_error_rate': 0.0789, 'col_13_error_rate': 0.0698, 'col_14_error_rate': 0.1892, 'total_error_rate': 0.2164, 'clf_aucroc': 0.9158}
nrmse 0.5035000000000001 5
col_1_error_rate 0.3351 5
col_3_error_rate 0.23944000000000001 5
col_5_error_rate 0.153 5
col_6_error_rate 0.67832 5
col_7_error_rate 0.17581999999999998 5
col_8_error_rate 0.09716 5
col_9_error_rate 0.10754 5
col_13_error_rate 0.06712 5
col_14_error_rate 0.15360000000000001 5
total_error_rate 0.21797999999999998 5
clf_aucroc 0.8798200000000002 5


100%|██████████| 10/10 [00:05<00:00,  1.70it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.4791, 'col_1_error_rate': 0.3134, 'col_3_error_rate': 0.2615, 'col_5_error_rate': 0.236, 'col_6_error_rate': 0.6711, 'col_7_error_rate': 0.2419, 'col_8_error_rate': 0.1, 'col_9_error_rate': 0.1467, 'col_13_error_rate': 0.0563, 'col_14_error_rate': 0.119, 'total_error_rate': 0.2342, 'clf_aucroc': 0.8657}


100%|██████████| 10/10 [00:05<00:00,  1.67it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.5488, 'col_1_error_rate': 0.3947, 'col_3_error_rate': 0.2987, 'col_5_error_rate': 0.2073, 'col_6_error_rate': 0.6173, 'col_7_error_rate': 0.2105, 'col_8_error_rate': 0.1014, 'col_9_error_rate': 0.0759, 'col_13_error_rate': 0.05, 'col_14_error_rate': 0.1354, 'total_error_rate': 0.2318, 'clf_aucroc': 0.8454}


100%|██████████| 10/10 [00:06<00:00,  1.66it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.5612, 'col_1_error_rate': 0.4198, 'col_3_error_rate': 0.284, 'col_5_error_rate': 0.25, 'col_6_error_rate': 0.6026, 'col_7_error_rate': 0.1096, 'col_8_error_rate': 0.0864, 'col_9_error_rate': 0.1169, 'col_13_error_rate': 0.0909, 'col_14_error_rate': 0.1728, 'total_error_rate': 0.236, 'clf_aucroc': 0.8913}


100%|██████████| 10/10 [00:05<00:00,  1.69it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.5473, 'col_1_error_rate': 0.4366, 'col_3_error_rate': 0.1746, 'col_5_error_rate': 0.0897, 'col_6_error_rate': 0.7647, 'col_7_error_rate': 0.1772, 'col_8_error_rate': 0.0694, 'col_9_error_rate': 0.0972, 'col_13_error_rate': 0.0429, 'col_14_error_rate': 0.1566, 'total_error_rate': 0.2318, 'clf_aucroc': 0.8725}


100%|██████████| 10/10 [00:06<00:00,  1.63it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.5414, 'col_1_error_rate': 0.3721, 'col_3_error_rate': 0.3462, 'col_5_error_rate': 0.2, 'col_6_error_rate': 0.7317, 'col_7_error_rate': 0.2812, 'col_8_error_rate': 0.1067, 'col_9_error_rate': 0.125, 'col_13_error_rate': 0.044, 'col_14_error_rate': 0.1031, 'total_error_rate': 0.2524, 'clf_aucroc': 0.9198}
nrmse 0.53556 10
col_1_error_rate 0.38732 10
col_3_error_rate 0.273 10
col_5_error_rate 0.19660000000000002 10
col_6_error_rate 0.67748 10
col_7_error_rate 0.20407999999999998 10
col_8_error_rate 0.09278000000000002 10
col_9_error_rate 0.11234000000000002 10
col_13_error_rate 0.056819999999999996 10
col_14_error_rate 0.13738 10
total_error_rate 0.23723999999999998 10
clf_aucroc 0.87894 10


100%|██████████| 10/10 [00:06<00:00,  1.60it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.4804, 'col_1_error_rate': 0.3871, 'col_3_error_rate': 0.2542, 'col_5_error_rate': 0.1368, 'col_6_error_rate': 0.7222, 'col_7_error_rate': 0.2302, 'col_8_error_rate': 0.0635, 'col_9_error_rate': 0.1333, 'col_13_error_rate': 0.0833, 'col_14_error_rate': 0.1527, 'total_error_rate': 0.2434, 'clf_aucroc': 0.8802}


100%|██████████| 10/10 [00:06<00:00,  1.65it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.5409, 'col_1_error_rate': 0.3967, 'col_3_error_rate': 0.2381, 'col_5_error_rate': 0.223, 'col_6_error_rate': 0.696, 'col_7_error_rate': 0.2381, 'col_8_error_rate': 0.1136, 'col_9_error_rate': 0.1026, 'col_13_error_rate': 0.0902, 'col_14_error_rate': 0.1496, 'total_error_rate': 0.2482, 'clf_aucroc': 0.8483}


100%|██████████| 10/10 [00:05<00:00,  1.67it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.5567, 'col_1_error_rate': 0.3243, 'col_3_error_rate': 0.3051, 'col_5_error_rate': 0.2645, 'col_6_error_rate': 0.7803, 'col_7_error_rate': 0.1261, 'col_8_error_rate': 0.073, 'col_9_error_rate': 0.1417, 'col_13_error_rate': 0.0918, 'col_14_error_rate': 0.2, 'total_error_rate': 0.2627, 'clf_aucroc': 0.8936}


100%|██████████| 10/10 [00:06<00:00,  1.66it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.559, 'col_1_error_rate': 0.3675, 'col_3_error_rate': 0.232, 'col_5_error_rate': 0.1949, 'col_6_error_rate': 0.7059, 'col_7_error_rate': 0.1709, 'col_8_error_rate': 0.0815, 'col_9_error_rate': 0.0806, 'col_13_error_rate': 0.0916, 'col_14_error_rate': 0.1909, 'total_error_rate': 0.2308, 'clf_aucroc': 0.8764}


100%|██████████| 10/10 [00:05<00:00,  1.69it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.5282, 'col_1_error_rate': 0.3307, 'col_3_error_rate': 0.2788, 'col_5_error_rate': 0.2808, 'col_6_error_rate': 0.7133, 'col_7_error_rate': 0.1681, 'col_8_error_rate': 0.0926, 'col_9_error_rate': 0.1575, 'col_13_error_rate': 0.1008, 'col_14_error_rate': 0.1789, 'total_error_rate': 0.267, 'clf_aucroc': 0.9159}
nrmse 0.53304 15
col_1_error_rate 0.36126 15
col_3_error_rate 0.26164 15
col_5_error_rate 0.22000000000000003 15
col_6_error_rate 0.7235400000000001 15
col_7_error_rate 0.18668 15
col_8_error_rate 0.08484 15
col_9_error_rate 0.12314 15
col_13_error_rate 0.09154 15
col_14_error_rate 0.17441999999999996 15
total_error_rate 0.25042 15
clf_aucroc 0.8828799999999999 15


100%|██████████| 10/10 [00:06<00:00,  1.65it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.4441, 'col_1_error_rate': 0.3274, 'col_3_error_rate': 0.3371, 'col_5_error_rate': 0.2391, 'col_6_error_rate': 0.6839, 'col_7_error_rate': 0.3646, 'col_8_error_rate': 0.0899, 'col_9_error_rate': 0.125, 'col_13_error_rate': 0.044, 'col_14_error_rate': 0.1786, 'total_error_rate': 0.2623, 'clf_aucroc': 0.8677}


100%|██████████| 10/10 [00:06<00:00,  1.66it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.51, 'col_1_error_rate': 0.3976, 'col_3_error_rate': 0.2878, 'col_5_error_rate': 0.1677, 'col_6_error_rate': 0.7037, 'col_7_error_rate': 0.1582, 'col_8_error_rate': 0.0867, 'col_9_error_rate': 0.1161, 'col_13_error_rate': 0.0649, 'col_14_error_rate': 0.1635, 'total_error_rate': 0.2379, 'clf_aucroc': 0.8389}


100%|██████████| 10/10 [00:06<00:00,  1.52it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.4786, 'col_1_error_rate': 0.4593, 'col_3_error_rate': 0.284, 'col_5_error_rate': 0.2515, 'col_6_error_rate': 0.7285, 'col_7_error_rate': 0.195, 'col_8_error_rate': 0.0828, 'col_9_error_rate': 0.1623, 'col_13_error_rate': 0.0629, 'col_14_error_rate': 0.1701, 'total_error_rate': 0.2668, 'clf_aucroc': 0.8799}


100%|██████████| 10/10 [00:06<00:00,  1.64it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.6205, 'col_1_error_rate': 0.3924, 'col_3_error_rate': 0.303, 'col_5_error_rate': 0.2516, 'col_6_error_rate': 0.8194, 'col_7_error_rate': 0.3758, 'col_8_error_rate': 0.0838, 'col_9_error_rate': 0.1366, 'col_13_error_rate': 0.0839, 'col_14_error_rate': 0.1908, 'total_error_rate': 0.2918, 'clf_aucroc': 0.8684}


100%|██████████| 10/10 [00:06<00:00,  1.56it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

{'nrmse': 0.5677, 'col_1_error_rate': 0.3913, 'col_3_error_rate': 0.3476, 'col_5_error_rate': 0.2603, 'col_6_error_rate': 0.7748, 'col_7_error_rate': 0.2857, 'col_8_error_rate': 0.0889, 'col_9_error_rate': 0.1776, 'col_13_error_rate': 0.0651, 'col_14_error_rate': 0.1325, 'total_error_rate': 0.2736, 'clf_aucroc': 0.9103}
nrmse 0.5241800000000001 20
col_1_error_rate 0.39360000000000006 20
col_3_error_rate 0.31189999999999996 20
col_5_error_rate 0.23403999999999997 20
col_6_error_rate 0.7420599999999999 20
col_7_error_rate 0.27586 20
col_8_error_rate 0.08641999999999998 20
col_9_error_rate 0.14352 20
col_13_error_rate 0.06416000000000001 20
col_14_error_rate 0.16709999999999997 20
total_error_rate 0.26648 20
clf_aucroc 0.8730399999999999 20
Data Size: (1000, 17)


100%|██████████| 10/10 [00:06<00:00,  1.44it/s]


ValueError: Only one class present in y_true. ROC AUC score is not defined in that case.

In [23]:
config = {
    'model': 'MissForest'
    }

In [24]:
experiment_run('crime', config, test=True)

b'Skipping line 1513591: expected 23 fields, saw 24\n'


Data Size: (1000, 12)
Iteration: 0
Iteration: 1
Iteration: 2
Iteration: 3
Iteration: 4
Iteration: 5
Iteration: 6
Iteration: 7
Iteration: 8
Iteration: 9
{'nrmse': 0.0058, 'col_0_error_rate': 0.175, 'col_1_error_rate': 0.1053, 'col_2_error_rate': 0.3455, 'col_3_error_rate': 0.0, 'col_4_error_rate': 0.0, 'col_5_error_rate': 0.0952, 'col_6_error_rate': 0.0, 'col_7_error_rate': 0.069, 'total_error_rate': 0.1169, 'clf_aucroc': 0.8674}
Iteration: 0
Iteration: 1
Iteration: 2
Iteration: 3
Iteration: 4
Iteration: 5
Iteration: 6
Iteration: 7
{'nrmse': 0.0002, 'col_0_error_rate': 0.2424, 'col_1_error_rate': 0.075, 'col_2_error_rate': 0.1667, 'col_3_error_rate': 0.1471, 'col_4_error_rate': 0.0541, 'col_5_error_rate': 0.175, 'col_6_error_rate': 0.15, 'col_7_error_rate': 0.1212, 'total_error_rate': 0.1399, 'clf_aucroc': 0.8879}
Iteration: 0
Iteration: 1
Iteration: 2
Iteration: 3
Iteration: 4
Iteration: 5
{'nrmse': 0.0032, 'col_0_error_rate': 0.1957, 'col_1_error_rate': 0.0811, 'col_2_error_rate': 0.2

Iteration: 0
Iteration: 1
Iteration: 2
Iteration: 3
Iteration: 4
Iteration: 5
Iteration: 6
Iteration: 7
{'nrmse': 0.0044, 'col_0_error_rate': 0.3409, 'col_1_error_rate': 0.0741, 'col_2_error_rate': 0.2338, 'col_3_error_rate': 0.1317, 'col_4_error_rate': 0.0823, 'col_5_error_rate': 0.1813, 'col_6_error_rate': 0.1824, 'col_7_error_rate': 0.0955, 'total_error_rate': 0.1613, 'clf_aucroc': 0.8172}
nrmse 0.00478 20
col_0_error_rate 0.25558000000000003 20
col_1_error_rate 0.08854 20
col_2_error_rate 0.22468 20
col_3_error_rate 0.11988000000000001 20
col_4_error_rate 0.07731999999999999 20
col_5_error_rate 0.21338 20
col_6_error_rate 0.16416 20
col_7_error_rate 0.09744000000000001 20
total_error_rate 0.15496 20
clf_aucroc 0.8669 20


{5: {'nrmse': 0.0034200000000000003,
  'col_0_error_rate': 0.18082,
  'col_1_error_rate': 0.07186000000000001,
  'col_2_error_rate': 0.21538,
  'col_3_error_rate': 0.09068000000000001,
  'col_4_error_rate': 0.02232,
  'col_5_error_rate': 0.17712,
  'col_6_error_rate': 0.08834,
  'col_7_error_rate': 0.0721,
  'total_error_rate': 0.11785999999999999,
  'clf_aucroc': 0.86954},
 10: {'nrmse': 0.00346,
  'col_0_error_rate': 0.23264,
  'col_1_error_rate': 0.06536,
  'col_2_error_rate': 0.23500000000000001,
  'col_3_error_rate': 0.11954000000000001,
  'col_4_error_rate': 0.04464,
  'col_5_error_rate': 0.16388,
  'col_6_error_rate': 0.11488000000000001,
  'col_7_error_rate': 0.0871,
  'total_error_rate': 0.13208,
  'clf_aucroc': 0.86922},
 15: {'nrmse': 0.00438,
  'col_0_error_rate': 0.23842,
  'col_1_error_rate': 0.09246,
  'col_2_error_rate': 0.21888000000000002,
  'col_3_error_rate': 0.11667999999999998,
  'col_4_error_rate': 0.08203999999999999,
  'col_5_error_rate': 0.22566000000000003,
 