# 1. Library

In [1]:
import pandas as pd
import numpy as np
import random
import os
from anytree import Node

from Algorithms import TreeInitialization, TreeEmbedding
from Preprocessing import DataPreprocessing
from NN
import Evaluation

In [2]:
def seed_everything(seed):
    """Seed the random sources used by this notebook so runs are repeatable.

    Seeds Python's ``random`` module and NumPy's global generator, and
    records the seed in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Value passed to ``random.seed`` and ``np.random.seed``; its
            ``str()`` form is what gets written to the environment.
    """
    random.seed(seed)
    # The running interpreter fixed its hash randomisation at startup, so this
    # environment entry only matters for processes launched afterwards.
    os.environ.update({'PYTHONHASHSEED': str(seed)})
    np.random.seed(seed)


seed_everything(123)

# 2. Tree

In [3]:
# Root of the Reuters topic hierarchy.
reuters = Node("Reuters")

# Top-level categories, attached to the root in this order.
ccat, gcat, mcat = (
    Node(label, parent=reuters) for label in ("CCAT", "GCAT", "MCAT")
)

# Children of CCAT.
c11, c21, c24 = (
    Node(label, parent=ccat) for label in ("C11", "C21", "C24")
)

# Children of GCAT.
gcrim, gdip, gpol, gvio = (
    Node(label, parent=gcat) for label in ("GCRIM", "GDIP", "GPOL", "GVIO")
)

# Children of MCAT; M14 is itself an internal node with two children.
m12, m14 = (
    Node(label, parent=mcat) for label in ("M12", "M14")
)
m141, m142 = (
    Node(label, parent=m14) for label in ("M141", "M142")
)

In [4]:
# Alias the root node; the following cells pass `tree` to the project helpers.
tree = reuters

In [5]:
# Wrap the hierarchy with the project's TreeInitialization helper and display
# its `df` attribute (one row per node with columns node, id, level, isLeaf,
# parent, as seen in the recorded output).
trees = TreeInitialization(tree)
trees.df

Unnamed: 0,node,id,level,isLeaf,parent
0,Reuters,0,0,False,-1
1,CCAT,1,1,False,0
2,GCAT,2,1,False,0
3,MCAT,3,1,False,0
4,C11,4,2,True,1
5,C21,5,2,True,1
6,C24,6,2,True,1
7,GCRIM,7,2,True,2
8,GDIP,8,2,True,2
9,GPOL,9,2,True,2


In [6]:
# Display the node list held by the wrapper; the recorded output shows every
# node of the hierarchy together with its assigned integer id.
trees.nodes

[Node('/Reuters', id=0),
 Node('/Reuters/CCAT', id=1),
 Node('/Reuters/GCAT', id=2),
 Node('/Reuters/MCAT', id=3),
 Node('/Reuters/CCAT/C11', id=4),
 Node('/Reuters/CCAT/C21', id=5),
 Node('/Reuters/CCAT/C24', id=6),
 Node('/Reuters/GCAT/GCRIM', id=7),
 Node('/Reuters/GCAT/GDIP', id=8),
 Node('/Reuters/GCAT/GPOL', id=9),
 Node('/Reuters/GCAT/GVIO', id=10),
 Node('/Reuters/MCAT/M12', id=11),
 Node('/Reuters/MCAT/M14', id=12),
 Node('/Reuters/MCAT/M14/M141', id=13),
 Node('/Reuters/MCAT/M14/M142', id=14)]

In [7]:
# Build the project's TreeEmbedding for the hierarchy and print its `xi`
# attribute (a numeric matrix in the recorded output).
Xi = TreeEmbedding(tree).xi
print(Xi)

[[-0.8660254   0.8660254   0.         -0.8660254  -0.8660254  -0.8660254
   0.8660254   0.8660254   0.8660254   0.8660254   0.          0.
   0.          0.        ]
 [-0.5        -0.5         1.         -0.5        -0.5        -0.5
  -0.5        -0.5        -0.5        -0.5         1.          1.
   1.          1.        ]
 [ 0.          0.          0.         -0.38729833  0.38729833  0.
   0.          0.          0.          0.          0.          0.
   0.          0.        ]
 [ 0.          0.          0.         -0.2236068  -0.2236068   0.4472136
   0.          0.          0.          0.          0.          0.
   0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
  -0.36514837  0.36514837  0.          0.          0.          0.
   0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
  -0.21081851 -0.21081851  0.42163702  0.          0.          0.
   0.          0.        ]
 [ 0.          0.          0.  

# 3. Data

In [8]:
# Project preprocessing helper, constructed from the label hierarchy.
prep = DataPreprocessing(tree)

# Feature file: space-delimited text, read without a header row.
X = pd.read_csv(
    './data/Reuters_X.txt',
    sep=' ',
    header=None,
)

# Label file, read without a header row; column 0 is extracted in the next cell.
y = pd.read_csv(
    './data/Reuters_Y.txt',
    header=None,
)

In [9]:
# Prepend a column of ones to the feature matrix
# (presumably an intercept/bias term — confirm against the model code).
X = np.column_stack((np.ones(X.shape[0]), np.asarray(X)))

# Keep only column 0 of the label frame as the label vector.
y = y.loc[:, 0]

In [10]:
# Sanity check: the number of feature rows should match the number of labels.
X.shape, len(y)

((455, 7206), 455)

# 4. Preprocessing

### 4.1 Split Train / Test Dataset

In [11]:
# Split features and labels into train and test parts with the project helper;
# with train_ratio=0.5 the recorded shapes show 227 train and 228 test rows.
X_train, X_test, y_train, y_test = prep.split_data(X, y, train_ratio=0.5)

In [None]:
# Optional feature-screening step, currently disabled: it would restrict both
# splits to the columns returned by prep.feature_screening.
# NOTE(review): re-enable or delete this cell — the shapes recorded in the next
# cell still show all 7206 columns, so screening was not applied.
# select_f = prep.feature_screening(X_train, y_train, num_features=110)
# X_train = X_train[:, select_f]
# X_test = X_test[:, select_f]

In [13]:
# Sanity check on the split: feature and label row counts must agree per part.
X_train.shape, X_test.shape, y_train.shape, y_test.shape

((227, 7206), (228, 7206), (227,), (228,))

# 5. Modeling

In [None]:
# Training configuration, passed positionally to train_model in the next cell.
NUM_ITER = 10000
HIDDEN_NEURONS = 100
LEARNING_RATE = 0.1
# NOTE(review): presumably the weight-initialisation scheme and its sampling
# distribution — confirm the accepted values against train_model.
INIT_METHOD = 'Xavier'
INIT_DIST = 'Normal'

In [None]:
# Train the model on the training split with the configuration constants.
# NOTE(review): `train_model` is not brought into scope by the visible import
# cell (its `from NN` line is incomplete) — confirm where it is imported from.
model = train_model(
    X_train, y_train, tree,
    NUM_ITER, HIDDEN_NEURONS, LEARNING_RATE,
    INIT_METHOD, INIT_DIST,
)

# 6. Prediction

In [14]:
# Display the leaf nodes of the hierarchy; the recorded output lists them
# with the ids shown earlier for the same nodes.
tree.leaves

(Node('/Reuters/CCAT/C11', id=4),
 Node('/Reuters/CCAT/C21', id=5),
 Node('/Reuters/CCAT/C24', id=6),
 Node('/Reuters/GCAT/GCRIM', id=7),
 Node('/Reuters/GCAT/GDIP', id=8),
 Node('/Reuters/GCAT/GPOL', id=9),
 Node('/Reuters/GCAT/GVIO', id=10),
 Node('/Reuters/MCAT/M12', id=11),
 Node('/Reuters/MCAT/M14/M141', id=13),
 Node('/Reuters/MCAT/M14/M142', id=14))