This notebook visualizes the weights of the linear regression model. The weight matrix maps 557 input features to 4 output classes.

In [1]:
import numpy as np

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
import sys
sys.path.insert(0, '../')

from models.linear_regression import LinearRegression

In [4]:
PATH_pretrained = "./models/model_6.pth"

In [5]:
### define a function to load in configuration  ###
import yaml

def load_config(filename):
    """Read a YAML configuration file and return its parsed contents.

    Parameters
    ----------
    filename : str
        Path of the YAML file to read.

    Returns
    -------
    dict
        The parsed configuration (a mapping for the config files used in this
        notebook). An empty dict is returned when the file has no content.
    """
    with open(filename) as f:
        # yaml.load without an explicit Loader is deprecated (and rejected by
        # PyYAML >= 6) because it can construct arbitrary Python objects;
        # safe_load only builds plain data types.
        params = yaml.safe_load(f)

    # An empty document parses to None; fall back to an empty dict so callers
    # can always index the result.
    return params if params is not None else {}

In [6]:
# Rebuild the network with the dimensions stored in the config file,
# then restore the trained parameters (mapped onto the CPU).
params = load_config("./models/config.yaml")
feature_dim, output_dim = params["feature_dim"], params["output_dim"]
model = LinearRegression(feature_dim, output_dim)
pretrained_state = torch.load(PATH_pretrained, map_location='cpu')
model.load_state_dict(pretrained_state)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [7]:
model.linear.weight

Parameter containing:
tensor([[-3.0771e-02,  1.9845e-01,  3.7109e-04,  ...,  1.4883e-03,
          1.8570e-04,  1.5613e-03],
        [ 1.2075e-01, -4.8275e-01,  8.0865e-05,  ..., -1.4362e-03,
         -3.0203e-04, -9.3305e-04],
        [ 6.5207e-03, -3.7063e-02,  1.0472e-04,  ..., -3.4571e-03,
          1.7285e-04,  5.2106e-04],
        [ 3.3806e-03, -3.2478e-03,  2.4866e-04,  ...,  5.8791e-05,
         -2.8904e-05,  9.4735e-07]], requires_grad=True)

In [8]:
# Convert the weight matrix into a NumPy array; detach() takes it out of the
# autograd graph, so no gradient is tracked for the result.
lr_weights = model.linear.weight.detach().numpy()

In [9]:
lr_weights

array([[-3.07707749e-02,  1.98448315e-01,  3.71087634e-04, ...,
         1.48830039e-03,  1.85702011e-04,  1.56133319e-03],
       [ 1.20745525e-01, -4.82749373e-01,  8.08647237e-05, ...,
        -1.43617764e-03, -3.02034605e-04, -9.33045812e-04],
       [ 6.52073650e-03, -3.70634049e-02,  1.04715480e-04, ...,
        -3.45706777e-03,  1.72854459e-04,  5.21059206e-04],
       [ 3.38060292e-03, -3.24778212e-03,  2.48657423e-04, ...,
         5.87912800e-05, -2.89038508e-05,  9.47351964e-07]], dtype=float32)

The raw weight matrix is hard to interpret on its own, so let's look at the weights for each label and map them back to feature names, following the order in which the features were concatenated to construct the feature vector.

In [10]:
lr_weights[0, :].shape

(557,)

In [11]:
!ls column_names/

all_column_names.pkl  categorical_values.pkl  continuous_column_names.pkl


In [12]:
# Load the full list of column names and, for every categorical column, the
# array of values it can take.
# NOTE: pickle.load can execute arbitrary code, so only load files that were
# produced by this project.
import pickle

folder_name = "./column_names/"
entire_column_filename = folder_name + "all_column_names.pkl"
# Disabled: categorical_column_names.pkl does not appear in the column_names/
# listing shown earlier in this notebook.
#categorical_column_filename = folder_name + "categorical_column_names.pkl"
categorical_values_filename = folder_name + "categorical_values.pkl"

with open(entire_column_filename, 'rb') as fin:
    entire_column_names = pickle.load(fin)
# Disabled together with categorical_column_filename above.
#with open(categorical_column_filename, 'rb') as fin:
#    categorical_column_names = pickle.load(fin)
with open(categorical_values_filename, 'rb') as fin:
    categorical_values = pickle.load(fin)

In [13]:
categorical_values

{'PRIMLANG': array([-4,  1,  2,  3,  4,  5,  6]),
 'NACCAHTN': array([-4,  0,  1]),
 'DEPD': array([-4,  0,  1]),
 'AGIT': array([-4,  0,  1]),
 'ANX': array([-4,  0,  1]),
 'ALCOCCAS': array([-4,  0,  1]),
 'NACCANGI': array([-4,  0,  1]),
 'OCD': array([-4,  0,  1,  2]),
 'DISNSEV': array([-4,  1,  2,  3]),
 'CRAFTCUE': array([-4,  0,  1]),
 'EVENTS': array([-4,  0,  1,  2,  3]),
 'ELATSEV': array([-4,  1,  2,  3]),
 'APP': array([-4,  0,  1]),
 'BILLS': array([-4,  0,  1,  2,  3]),
 'NACCVASD': array([-4,  0,  1]),
 'RBD': array([-4,  0,  1,  2]),
 'NACCDBMD': array([-4,  0,  1]),
 'DISN': array([-4,  0,  1]),
 'HYPERTEN': array([-4,  0,  1,  2]),
 'NACCAPOE': array([-4,  1,  2,  3,  4,  5,  6]),
 'IRR': array([-4,  0,  1]),
 'PAYATTN': array([-4,  0,  1,  2,  3]),
 'NACCHTNC': array([-4,  0,  1]),
 'PTSD': array([-4,  0,  1,  2]),
 'NACCDIUR': array([-4,  0,  1]),
 'DEPDSEV': array([-4,  1,  2,  3]),
 'SHOPPING': array([-4,  0,  1,  2,  3]),
 'DEPOTHR': array([-4,  0,  1]),
 'NACCB

In [14]:
entire_column_names

['NACCID',
 'DATE',
 'SEX',
 'HISPANIC',
 'HISPOR',
 'PRIMLANG',
 'EDUC',
 'INSEX',
 'NACCNINR',
 'INEDUC',
 'INRELTO',
 'NACCFAM',
 'ANYMEDS',
 'TOBAC30',
 'TOBAC100',
 'SMOKYRS',
 'PACKSPER',
 'QUITSMOK',
 'ALCOCCAS',
 'ALCFREQ',
 'CVHATT',
 'HATTMULT',
 'CVAFIB',
 'CVANGIO',
 'CVBYPASS',
 'CVPACDEF',
 'CVPACE',
 'CVCHF',
 'CVANGINA',
 'CVHVALVE',
 'CVOTHR',
 'CBSTROKE',
 'STROKMUL',
 'CBTIA',
 'TIAMULT',
 'SEIZURES',
 'NACCTBI',
 'TBI',
 'TBIBRIEF',
 'TRAUMBRF',
 'TBIEXTEN',
 'TRAUMEXT',
 'TBIWOLOS',
 'TRAUMCHR',
 'NCOTHR',
 'DIABETES',
 'DIABTYPE',
 'HYPERTEN',
 'HYPERCHO',
 'B12DEF',
 'THYROID',
 'ARTHRIT',
 'ARTHTYPE',
 'ARTHUPEX',
 'ARTHLOEX',
 'ARTHSPIN',
 'ARTHUNK',
 'INCONTU',
 'INCONTF',
 'APNEA',
 'RBD',
 'INSOMN',
 'OTHSLEEP',
 'ALCOHOL',
 'ABUSOTHR',
 'PTSD',
 'BIPOLAR',
 'SCHIZ',
 'DEP2YRS',
 'DEPOTHR',
 'ANXIETY',
 'OCD',
 'NPSYDEV',
 'PSYCDIS',
 'BPSYS',
 'BPDIAS',
 'CDRSUM',
 'NPIQINF',
 'DEL',
 'DELSEV',
 'HALL',
 'HALLSEV',
 'AGIT',
 'AGITSEV',
 'DEPD',
 'DEPDSEV',


In [15]:
# "NACCID" and "DATE" are not part of the feature vector, so drop them.
# Filtering into a new list (instead of calling list.remove) keeps this cell
# safe to re-run: a second list.remove would raise ValueError.
entire_column_names = [
    name for name in entire_column_names if name not in ("NACCID", "DATE")
]

In [16]:
# Build one name per entry of the feature vector, in column order:
#   - a categorical column expands to one name per possible value
#     (column name followed by the value, e.g. "SEX1", "SEX2");
#   - any other column contributes its own name unchanged.
whole_feature_names_list = []  # store names for each entry
for col in entire_column_names:
    # dict membership is a direct lookup; no need to rebuild a key list per column
    if col in categorical_values:
        whole_feature_names_list.extend(
            col + str(value) for value in categorical_values[col]
        )
    else:
        whole_feature_names_list.append(col)

In [17]:
len(whole_feature_names_list)

557

In [18]:
whole_feature_names_list

['SEX1',
 'SEX2',
 'HISPANIC-4',
 'HISPANIC0',
 'HISPANIC1',
 'HISPOR-4',
 'HISPOR1',
 'HISPOR2',
 'HISPOR3',
 'HISPOR4',
 'HISPOR5',
 'HISPOR6',
 'HISPOR50',
 'PRIMLANG-4',
 'PRIMLANG1',
 'PRIMLANG2',
 'PRIMLANG3',
 'PRIMLANG4',
 'PRIMLANG5',
 'PRIMLANG6',
 'EDUC',
 'INSEX-4',
 'INSEX1',
 'INSEX2',
 'NACCNINR-4',
 'NACCNINR1',
 'NACCNINR2',
 'NACCNINR3',
 'NACCNINR4',
 'NACCNINR5',
 'NACCNINR6',
 'INEDUC',
 'INRELTO-4',
 'INRELTO1',
 'INRELTO2',
 'INRELTO3',
 'INRELTO4',
 'INRELTO5',
 'INRELTO6',
 'INRELTO7',
 'NACCFAM-4',
 'NACCFAM0',
 'NACCFAM1',
 'ANYMEDS-4',
 'ANYMEDS0',
 'ANYMEDS1',
 'TOBAC30-4',
 'TOBAC300',
 'TOBAC301',
 'TOBAC100-4',
 'TOBAC1000',
 'TOBAC1001',
 'SMOKYRS',
 'PACKSPER-4',
 'PACKSPER0',
 'PACKSPER1',
 'PACKSPER2',
 'PACKSPER3',
 'PACKSPER4',
 'PACKSPER5',
 'QUITSMOK',
 'ALCOCCAS-4',
 'ALCOCCAS0',
 'ALCOCCAS1',
 'ALCFREQ-4',
 'ALCFREQ0',
 'ALCFREQ1',
 'ALCFREQ2',
 'ALCFREQ3',
 'ALCFREQ4',
 'CVHATT-4',
 'CVHATT0',
 'CVHATT1',
 'CVHATT2',
 'HATTMULT-4',
 'HATTMULT0

In [19]:
lr_weights.shape

(4, 557)

In [20]:
# Toy example: argsort along axis=1 returns, for each row, the column indices
# that would put that row in ascending order.
wll_arr = np.array([
    [4, 5, 8],
    [7, 2, 0],
])
wll_arr.argsort(axis=1)

array([[0, 1, 2],
       [2, 1, 0]])

In [21]:
# Rank the features of every label by their signed weight.
# argsort is ascending, so each row runs from the smallest weight to the
# largest one; reverse a row to read the largest weights first.
sort_perm_index = lr_weights.argsort(axis=1)

In [22]:
sort_perm_index

array([[420, 367, 423, ..., 439,   1, 477],
       [  1, 293, 356, ..., 440, 483, 292],
       [420, 367, 533, ..., 477,  20, 278],
       [278, 333, 496, ..., 420, 535, 439]])

In [23]:
# One descending ranking per label: reverse each ascending row so that the
# index of the largest weight comes first.
# Row order: 0 = PURE AD, 1 = PURE LBD, 2 = MIX, 3 = OTHERS.
pure_alz_index_descending = sort_perm_index[0, ::-1]
pure_lbd_index_descending = sort_perm_index[1, ::-1]
mix_index_descending = sort_perm_index[2, ::-1]
others_index_descending = sort_perm_index[3, ::-1]

In [24]:
# Turn the feature names into an array so that each label's ranking can
# reorder them in a single step.
whole_feature_names_array = np.array(whole_feature_names_list)
# label 0 (PURE AD)
pure_alz_feature_names_importance = whole_feature_names_array.take(pure_alz_index_descending)
# label 1 (PURE LBD)
pure_lbd_feature_names_importance = whole_feature_names_array.take(pure_lbd_index_descending)
# label 2 (MIX)
mix_feature_names_importance = whole_feature_names_array.take(mix_index_descending)
# label 3 (OTHERS)
others_feature_names_importance = whole_feature_names_array.take(others_index_descending)

In [25]:
pure_alz_feature_names_importance

array(['NACCAGE', 'SEX2', 'WAIS', 'UDSVERLC', 'CSFPTAU', 'NITE0', 'HALL0',
       'NACCADEP0', 'NACCAPOE2', 'NACCNE4S1', 'UDSVERTI', 'NACCFAM1',
       'UDSVERLN', 'NITESEV-4', 'UDSVERFC', 'UDSBENTC', 'HALLSEV-4',
       'TRAVEL0', 'GAMES0', 'NACCAGEB', 'UDSVERTN', 'BILLS0', 'CSFTTAU',
       'DIGIB', 'TAXES0', 'MEALPREP0', 'TRAUMBRF-4', 'UDSVERNF',
       'APASEV-4', 'TRAUMEXT-4', 'HYPERTEN2', 'IRR1', 'ANX0', 'SHOPPING0',
       'APPSEV-4', 'TRAUMCHR-4', 'APNEA0', 'NACCCCBS1', 'APA0',
       'MOCATOTS', 'ANXSEV-4', 'NACCAANX0', 'DIGBACCT', 'UDSBENRS1',
       'MINTPCNG', 'TRAILBRR', 'IRRSEV1', 'TRAILARR', 'NPIQINF2', 'RBD0',
       'UDSVERLR', 'ANIMALS', 'CVPACE-4', 'NACCAHTN1', 'CRAFTDVR',
       'DIGIBLEN', 'GAMES-4', 'HISPANIC1', 'NACCAPSY0', 'MINTSCNG',
       'NACCBMI', 'INCONTU0', 'NACCNINR2', 'NACCLIPL1', 'NACCNIHR2',
       'APP0', 'NACCACEI1', 'PACKSPER1', 'CVAFIB0', 'NPIQINF3',
       'TBIBRIEF0', 'NCOTHR-4', 'OTHSLEEP0', 'DEP2YRS0', 'DISN1',
       'EVENTS0', 'BILLS1', 'STO

In [26]:
pure_lbd_feature_names_importance

array(['HALL1', 'NACCAANX1', 'BOSTON', 'NACCMMSE', 'TRAILA', 'NITE1',
       'SEX1', 'INRELTO1', 'TRAILB', 'MEMUNITS', 'NACCNE4S-4',
       'NACCAPOE-4', 'NACCGDS', 'NPIQINF1', 'NACCAMD', 'MINTTOTW',
       'MEMTIME', 'MOCATOTS', 'TRAILALI', 'ANIMALS', 'EDUC', 'NACCLIPL0',
       'UDSVERTN', 'DISN0', 'NACCNINR1', 'MINTTOTS', 'UDSVERFC', 'GAMES1',
       'UDSVERFN', 'NACCACEI0', 'INEDUC', 'DISNSEV-4', 'CRAFTDRE',
       'UDSVERLC', 'NACCANGI1', 'UDSBENTD', 'MEALPREP-4', 'DIGFORCT',
       'TRAVEL1', 'LOGIMEM', 'HALLSEV2', 'INCONTU1', 'CVPACE1',
       'DIABETES0', 'NACCDIUR0', 'CRAFTURS', 'HALLSEV1', 'TRAVEL3',
       'CBSTROKE0', 'NACCNIHR5', 'BILLS3', 'CVCHF1', 'TAXES3', 'MINTPCNC',
       'NACCDBMD0', 'CVBYPASS2', 'TRAILBLI', 'VEG', 'UDSVERTE',
       'PAYATTN1', 'CRAFTDTI', 'CRAFTDVR', 'INSEX2', 'TRAILARR',
       'INRELTO4', 'DELSEV2', 'PSYCDIS2', 'PACKSPER5', 'NACCNINR5',
       'NACCNIHR1', 'CVANGIO1', 'RBD1', 'ALCOCCAS0', 'CVOTHR2',
       'NACCBETA0', 'APP0', 'APASEV1', 'PAYATT

In [27]:
mix_feature_names_importance

array(['CDRSUM', 'EDUC', 'NACCAGE', 'BOSTON', 'NACCAGEB', 'DIGIF',
       'HYPERCHO1', 'TRAILA', 'TRAILB', 'HALL1', 'HALLSEV1', 'HYPERTEN0',
       'REMDATES3', 'DELSEV1', 'CVPACE0', 'NCOTHR0', 'DEPOTHR0',
       'INCONTU0', 'ANXSEV3', 'DEL1', 'NITESEV1', 'NACCAHTN0', 'INEDUC',
       'SEX1', 'INRELTO5', 'ANXSEV1', 'INSEX1', 'TRAUMBRF0', 'TAXES-4',
       'TRAUMEXT0', 'NACCAPOE1', 'CVAFIB-4', 'NACCAANX0', 'NACCANGI0',
       'BPSYS', 'NACCNE4S2', 'NACCAPOE4', 'NITE1', 'PACKSPER1', 'BILLS-4',
       'TRAUMCHR0', 'CRAFTCUE1', 'UDSBENRS1', 'NACCFAM-4', 'MEALPREP-4',
       'PACKSPER4', 'NACCTBI0', 'MINTTOTW', 'ALCOHOL0', 'CVCHF0',
       'NACCCCBS0', 'MEALPREP1', 'MINTPCNG', 'STOVE0', 'HISPOR5',
       'TRAUMCHR1', 'TRAVEL1', 'THYROID2', 'PRIMLANG4', 'PAYATTN2',
       'CVANGIO0', 'CBSTROKE0', 'IRRSEV-4', 'NPIQINF1', 'BILLS3',
       'NACCFAM0', 'TOBAC100-4', 'IRR0', 'ABUSOTHR-4', 'GAMES1',
       'HYPERCHO2', 'DIGIFLEN', 'CVAFIB0', 'CVBYPASS0', 'NITESEV2',
       'NPSYDEV-4', 'BIPOLAR-4'

In [28]:
others_feature_names_importance

array(['WAIS', 'NACCNE4S0', 'MEMUNITS', 'NACCBMI', 'UDSBENTD',
       'REMDATES0', 'NACCMMSE', 'UDSBENTC', 'QUITSMOK', 'DISN1',
       'NACCGDS', 'LOGIMEM', 'BILLS0', 'SMOKYRS', 'NACCDIUR1', 'TRAILALI',
       'ANXSEV2', 'HYPERTEN1', 'NACCAHTN1', 'VEG', 'TAXES0', 'CBSTROKE2',
       'NACCBETA1', 'TRAVEL0', 'NACCDBMD1', 'DEPDSEV2', 'MEMTIME',
       'DIABETES1', 'BPDIAS', 'NACCAPSY1', 'NACCADEP1', 'GAMES2', 'MOT1',
       'APASEV2', 'CSFABETA', 'MINTTOTW', 'APPSEV1', 'BILLS2', 'DISNSEV2',
       'NACCAPOE3', 'APA1', 'INCONTF1', 'MOCATOTS', 'MEALPREP3',
       'MEALPREP1', 'SHOPPING3', 'MINTPCNC', 'PACKSPER-4', 'HALLSEV-4',
       'CRAFTVRS', 'NPIQINF3', 'NACCCCBS1', 'AGIT1', 'AGITSEV1',
       'NACCAAAS1', 'CVAFIB2', 'THYROID1', 'HALL0', 'PACKSPER3', 'ELAT1',
       'CBTIA2', 'SEX1', 'NACCAANX1', 'DIGIFLEN', 'TAXES2', 'NACCFAM0',
       'PAYATTN2', 'NACCNSD0', 'MINTTOTS', 'UDSVERLC', 'NACCACEI1',
       'CRAFTDRE', 'CRAFTDVR', 'NACCLIPL1', 'PACKSPER2', 'TRAVEL-4',
       'STOVE1', 'GAME

In [29]:
# Print, for each label, the ten feature names with the largest weights.
for heading, ranked_names in (
    ("for PURE AD:", pure_alz_feature_names_importance),
    ("for PURE LBD:", pure_lbd_feature_names_importance),
    ("for MIX AD+LBD:", mix_feature_names_importance),
):
    print(heading, ranked_names[:10], "\n")
# the last label is printed without the trailing blank line
print("for OTHERS:", others_feature_names_importance[:10])

for PURE AD: ['NACCAGE' 'SEX2' 'WAIS' 'UDSVERLC' 'CSFPTAU' 'NITE0' 'HALL0' 'NACCADEP0'
 'NACCAPOE2' 'NACCNE4S1'] 

for PURE LBD: ['HALL1' 'NACCAANX1' 'BOSTON' 'NACCMMSE' 'TRAILA' 'NITE1' 'SEX1'
 'INRELTO1' 'TRAILB' 'MEMUNITS'] 

for MIX AD+LBD: ['CDRSUM' 'EDUC' 'NACCAGE' 'BOSTON' 'NACCAGEB' 'DIGIF' 'HYPERCHO1'
 'TRAILA' 'TRAILB' 'HALL1'] 

for OTHERS: ['WAIS' 'NACCNE4S0' 'MEMUNITS' 'NACCBMI' 'UDSBENTD' 'REMDATES0' 'NACCMMSE'
 'UDSBENTC' 'QUITSMOK' 'DISN1']


In [30]:
categorical_values["TRAUMBRF"]

array([-4,  0,  1,  2])

In [31]:
# show every weight of label 0 (PURE AD) in descending order, i.e. aligned
# entry by entry with pure_alz_feature_names_importance
lr_weights[0][pure_alz_index_descending]

array([ 2.05386251e-01,  1.98448315e-01,  1.22205973e-01,  6.38090819e-02,
        4.58135121e-02,  4.46736105e-02,  4.40288447e-02,  4.22878005e-02,
        2.86072604e-02,  2.45713070e-02,  2.44010631e-02,  2.34525092e-02,
        2.05977000e-02,  1.67055167e-02,  1.62547082e-02,  1.42134083e-02,
        1.40959704e-02,  1.35516189e-02,  1.20872688e-02,  1.18439877e-02,
        9.84445401e-03,  8.97072256e-03,  8.78062937e-03,  8.77899304e-03,
        8.55623279e-03,  7.77562102e-03,  7.53624598e-03,  7.45913014e-03,
        6.99444488e-03,  5.78579260e-03,  5.58532123e-03,  5.44494810e-03,
        5.35966735e-03,  5.30220894e-03,  5.12987701e-03,  5.10410778e-03,
        5.09709772e-03,  5.07226959e-03,  5.05612744e-03,  5.04508521e-03,
        4.94312495e-03,  4.85405605e-03,  4.82379599e-03,  4.80205659e-03,
        4.66859201e-03,  4.55451943e-03,  4.38414188e-03,  4.35648533e-03,
        4.33043344e-03,  4.19089710e-03,  4.11882298e-03,  4.08621971e-03,
        4.07770090e-03,  

In [32]:
lr_weights[1][pure_lbd_index_descending]

array([ 3.34778041e-01,  2.12556884e-01,  1.84686035e-01,  1.74420848e-01,
        1.74330309e-01,  1.24800034e-01,  1.20745525e-01,  1.20740347e-01,
        9.50447544e-02,  8.35523382e-02,  7.44196177e-02,  7.44021758e-02,
        6.18036278e-02,  5.52720539e-02,  4.59516980e-02,  3.89954634e-02,
        3.33058052e-02,  3.20543572e-02,  2.67276727e-02,  2.67107878e-02,
        2.40064282e-02,  1.89858805e-02,  1.31071797e-02,  1.28102787e-02,
        1.25746625e-02,  1.22643495e-02,  1.15473978e-02,  1.11883171e-02,
        1.08814705e-02,  1.08524151e-02,  9.85697284e-03,  9.50344093e-03,
        9.39287338e-03,  9.30874981e-03,  8.92224908e-03,  8.56145937e-03,
        8.53830017e-03,  8.45854916e-03,  8.39263760e-03,  8.24426301e-03,
        8.24330468e-03,  7.58118508e-03,  7.55265029e-03,  7.29069533e-03,
        6.95275329e-03,  6.33081514e-03,  5.91190904e-03,  5.78865921e-03,
        5.71964309e-03,  5.13193803e-03,  4.99221869e-03,  4.62352810e-03,
        3.82231595e-03,  

In [33]:
lr_weights[2][mix_index_descending]

array([ 2.37321272e-01,  1.14826769e-01,  1.00569695e-01,  4.14263718e-02,
        2.66687293e-02,  1.78125147e-02,  1.65677331e-02,  1.45078888e-02,
        1.43322730e-02,  1.15420604e-02,  1.04637286e-02,  9.62031167e-03,
        9.51093715e-03,  8.80895462e-03,  8.65726452e-03,  8.08901154e-03,
        8.08671769e-03,  7.78029766e-03,  7.65510742e-03,  7.42154103e-03,
        7.22655375e-03,  7.11715734e-03,  6.66459976e-03,  6.52073650e-03,
        6.32718019e-03,  6.19228138e-03,  5.83439553e-03,  5.82424179e-03,
        5.50394133e-03,  5.40528120e-03,  5.03768260e-03,  4.95704636e-03,
        4.78239590e-03,  4.72449046e-03,  4.55093384e-03,  4.50976472e-03,
        4.49405191e-03,  4.43141721e-03,  4.37868387e-03,  4.37490316e-03,
        4.36180318e-03,  4.33922233e-03,  4.33690473e-03,  4.23851050e-03,
        4.05412726e-03,  4.00923844e-03,  3.97416390e-03,  3.92906554e-03,
        3.84798809e-03,  3.70661449e-03,  3.68314562e-03,  3.64410994e-03,
        3.57097038e-03,  

In [34]:
lr_weights[3][others_index_descending]

array([ 1.48545355e-01,  1.15240470e-01,  7.54888505e-02,  7.17599317e-02,
        5.38035259e-02,  4.89344895e-02,  4.69528139e-02,  3.84853370e-02,
        2.74173822e-02,  2.22986434e-02,  2.08720472e-02,  2.07169745e-02,
        1.94499400e-02,  1.41121866e-02,  1.33889504e-02,  1.33196376e-02,
        1.04781548e-02,  1.04209324e-02,  1.03292717e-02,  1.02117276e-02,
        1.00726141e-02,  9.97036975e-03,  9.77711193e-03,  9.47320182e-03,
        9.42646340e-03,  9.34896525e-03,  9.21852794e-03,  9.02977679e-03,
        8.60551745e-03,  7.64703332e-03,  7.61695299e-03,  7.51414942e-03,
        6.98743388e-03,  6.96724700e-03,  6.62362808e-03,  6.42114924e-03,
        6.25921879e-03,  6.24636933e-03,  6.17664028e-03,  5.99400466e-03,
        5.65992342e-03,  5.29326731e-03,  5.21059521e-03,  4.97725699e-03,
        4.92161606e-03,  4.85311775e-03,  4.82202973e-03,  4.76525072e-03,
        4.72687464e-03,  4.44382895e-03,  4.31526825e-03,  4.18942655e-03,
        4.17250535e-03,  

In [35]:
# Show, for every label, the features whose weight is on the order of 1e-1
# (i.e. at least 0.1).
# The cut-off per label is derived from the weights instead of being counted by
# hand, so it stays correct if the model is retrained. Because each ranking is
# sorted by descending weight, the features above the threshold are exactly the
# first `n_strong` entries.
weight_threshold = 1e-1
labelled_rankings = [
    ("for PURE AD:", pure_alz_feature_names_importance),
    ("for PURE LBD:", pure_lbd_feature_names_importance),
    ("for MIX AD+LBD:", mix_feature_names_importance),
    ("for OTHERS:", others_feature_names_importance),
]
for label_index, (heading, ranked_names) in enumerate(labelled_rankings):
    n_strong = int((lr_weights[label_index] >= weight_threshold).sum())
    # every label except the last is followed by a blank line
    is_last = label_index == len(labelled_rankings) - 1
    trailer = () if is_last else ("\n",)
    print(heading, ranked_names[:n_strong], *trailer)

for PURE AD: ['NACCAGE' 'SEX2' 'WAIS'] 

for PURE LBD: ['HALL1' 'NACCAANX1' 'BOSTON' 'NACCMMSE' 'TRAILA' 'NITE1' 'SEX1'
 'INRELTO1'] 

for MIX AD+LBD: ['CDRSUM' 'EDUC' 'NACCAGE'] 

for OTHERS: ['WAIS' 'NACCNE4S0']


In [36]:
mix_feature_names_importance[:11]

array(['CDRSUM', 'EDUC', 'NACCAGE', 'BOSTON', 'NACCAGEB', 'DIGIF',
       'HYPERCHO1', 'TRAILA', 'TRAILB', 'HALL1', 'HALLSEV1'], dtype='<U11')