In [2]:
import numpy as np
from random_forest import RandomForest
import logistic_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

In [3]:
## load the dataset: parse the comma-separated text file into a NumPy array
## (the path is relative, so the file must sit in the working directory)
filename = 'SPECTF.dat'
data = np.loadtxt(filename, delimiter=',')

In [4]:
### a sample from the dataset: rows 6, 9 and 69 with all of their columns
### (the first value in each row is the response, the rest are the features)
print(data[[6, 9, 69]])

[[  1.  69.  66.  62.  75.  67.  71.  72.  76.  69.  70.  66.  69.  71.
   80.  66.  64.  71.  77.  65.  61.  72.  67.  71.  69.  65.  57.  69.
   65.  68.  65.  76.  73.  63.  64.  69.  70.  72.  72.  69.  68.  70.
   73.  63.  59.]
 [  1.  74.  73.  72.  79.  66.  61.  76.  66.  65.  64.  78.  74.  62.
   57.  48.  36.  62.  50.  67.  63.  79.  70.  61.  57.  52.  36.  69.
   49.  55.  65.  74.  73.  58.  60.  64.  62.  73.  69.  62.  67.  60.
   56.  53.  46.]
 [  0.  69.  64.  73.  72.  49.  70.  66.  71.  57.  56.  64.  62.  76.
   74.  65.  62.  63.  58.  63.  63.  75.  76.  78.  80.  75.  77.  51.
   62.  74.  68.  77.  77.  70.  68.  68.  64.  59.  58.  69.  66.  74.
   75.  62.  59.]]


In [4]:
## split the raw array column-wise:
## column 0 is the response vector, every remaining column is a feature
Y = data[:, 0].copy()   # independent copy of the response column
X = data[:, 1:]
n = len(X)              # number of rows (observations)

In [5]:
## divide data into training and test sets: 20% of the rows are held out
## for testing, and random_state=0 fixes the shuffle so the split is reproducible
trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.2, random_state=0)

In [6]:
%%time
## train random forest
## (50 and 100 are passed positionally; their meaning is defined in random_forest.py)
forest = RandomForest(50, 100)
forest.fit(trainX, trainY)

## predict() hands back two sequences that are indexed in parallel below:
## a label and a confidence for each test row
forest_labels, forest_conf = forest.predict(testX)

## convert (label, confidence) into a score for class 1: keep the confidence
## when the label is 1, otherwise use its complement
forest_probs = [
    forest_conf[i] if label == 1 else 1 - forest_conf[i]
    for i, label in enumerate(forest_labels)
]

CPU times: user 8.84 s, sys: 0 ns, total: 8.84 s
Wall time: 8.84 s


In [7]:
%%time
## train logistic regression
## `logistic_regression` is bound to a module (`import logistic_regression`), so
## the fit and scoring functions have to be reached through it: calling the
## module itself raises TypeError and a bare `logistic_probs` is a NameError
## on a fresh kernel.
## NOTE(review): assumes the module defines `logistic_regression` and
## `logistic_probs` under exactly these names -- confirm against logistic_regression.py
log_beta = logistic_regression.logistic_regression(trainX, trainY, step_size=1e-1, max_steps=300)
## score the held-out rows using the value returned by the fit
log_probs = logistic_regression.logistic_probs(testX, log_beta)

CPU times: user 26 s, sys: 1.15 s, total: 27.1 s
Wall time: 25.8 s


In [8]:
### area under the ROC curve, computed on the held-out test rows
### (first printed value: random forest, second: logistic regression)
print(roc_auc_score(testY, forest_probs))
print(roc_auc_score(testY, log_probs))

### random forest looks better (it has the higher AUC of the two)
### and they are both quite slow to train (see the %%time output of the training cells)

0.714814814815
0.61975308642
