In [None]:
# Reload edited modules (e.g. SRW_v044) automatically before each cell runs,
# so changes to the solver source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import SRW_v044 as SRW
import pickle

In [None]:
# Load the network: edges (e by 2 int ndarray), per-edge features (e by w
# csc_matrix) and the node names, per the SRW_solver argument notes.
edges, features, node_names = SRW.load_network('data/BRCA_edge2features_2.txt')

In [None]:
# Load the training samples' initial states (P_init: m by n csr_matrix per the
# argument notes) and their sample names. node_names is passed in, presumably
# so columns line up with the network's nodes — confirm in SRW_v044.
P_init_train, sample_names_train = SRW.load_samples('data/BRCA_training_data_2.txt', node_names)

In [None]:
# Load the validation samples' initial states and sample names, using the same
# node_names as the training set so both sets are built against one network.
P_init_val, sample_names_val = SRW.load_samples('data/BRCA_validation_data_2.txt', node_names)

In [None]:
# Load the group label of each training sample.
# NOTE(review): 'lables' in the path is presumably the data file's actual
# spelling; do not correct it without renaming the file on disk.
group_labels_train = SRW.load_grouplabels('data/BRCA_training_lables_2.txt')

In [None]:
# Load the group label of each validation sample.
# NOTE(review): 'lables' in the path is presumably the data file's actual
# spelling; do not correct it without renaming the file on disk.
group_labels_val = SRW.load_grouplabels('data/BRCA_validation_lables_2.txt')

In [None]:
# Model settings: the number of nodes in the network, the random-walk reset
# probability (rst_prob), and the regularization strength (lam) that scales
# the L1/L2 norm penalty.
nnodes, rst_prob, lam = len(node_names), 0.3, 1e-1

In [None]:
# Read the edge-feature names, one per line. Trailing whitespace at the end of
# the file is stripped first, so a final newline does not produce an empty
# name. An explicit encoding keeps the result independent of the platform
# locale.
with open('data/BRCA_feature_names_2.txt', encoding='utf-8') as f:
    feature_names = f.read().rstrip().splitlines()
# Append the names 'selfloop' and 'intercept'.
# NOTE(review): these presumably label two extra weights that SRW_solver learns
# on top of the file's features — confirm against SRW_v044.
feature_names += ['selfloop', 'intercept']

In [None]:
# Build the solver on the training set; the validation set is passed in too.
# SRW is already imported in the import cell. Importing an already-loaded
# module again is a no-op: it does not reload the module (%autoreload handles
# reloading). The redundant re-import is therefore omitted.
SRW_obj = SRW.SRW_solver(
    edges, features, nnodes, P_init_train, rst_prob, group_labels_train, lam,
    w_init_sd=0.01, w=None, feature_names=feature_names,
    sample_names=sample_names_train, node_names=node_names,
    loss='WMW', norm_type='L1', learning_rate=0.2, update_w_func='Adam',
    P_init_val=P_init_val, group_labels_val=group_labels_val,
    ncpus=16,  # hard-coded; per the argument notes, -1 would use all CPUs
    maxit=1000,
    early_stop=500,  # stop if performance hasn't improved in 500 iterations
    WMW_b=2e-4,
)

### Arguments of SRW_solver objects  
**edges** (e by 2, int, ndarray): Edges in the network  
**features** (e by w, float, csc_matrix): Edge features  
**nnodes** (int): Number of nodes in the network  
**P_init** (m by n, float, csr_matrix): Initial state of samples (training set)  
**rst_prob** (float): Reset probability of Random Walk  
**group_labels** (m by 1, str/int, list): Group labels of samples (training set)   
**lam** (float): Regularization parameter, controlling the amount of L1/L2 norm  
**w_init_sd** (float): Standard deviation for weight initialization (default 0.01)  
**w** (w by 1, float, list): Initial weights (default None)  
**feature_names** (w by 1, str, list): Feature names (default [])  
**sample_names** (m by 1, str, list): Sample names (default [])  
**node_names** (n by 1, str, list): Node names (default [])  
**loss** {'WMW'}: Type of the loss function (default 'WMW')  
**norm_type** {'L1', 'L2'}: Type of the norm (default 'L1')  
**learning_rate** (float): Learning rate (default 0.1)  
**update_w_func** {'momentum', 'Nesterov', 'Adam', 'Nadam'}: Function for updating parameters (default 'Adam')  
**P_init_val** (m by n, float, csr_matrix): Initial state of samples (validation set) (default None)  
**group_labels_val** (m by 1, str/int, list): Group labels of samples (validation set) (default None)  
**ncpus** (int): Number of CPUs to use for multiprocess.Pool (default: -1, use all cpus)  
**maxit** (int): Max number of iterations for training the model (default: 1000)  
**early_stop** (int): Stop the learning if the performance doesn't improve in x iterations (default: None)  
**WMW_b** (float): Parameter b in the WMW loss function (default: 2e-4)

In [None]:
# Train the model with the settings chosen when SRW_obj was built
# (update_w_func='Adam', maxit=1000, early_stop=500, ncpus=16).
# This is the long-running step.
SRW_obj.train_SRW_GD()

# After convergence

In [None]:
# Output learned feature weights to a file
# Output learned feature weights to a tab-separated file (despite the .txt
# extension).
# NOTE(review): w_map is presumably a pandas DataFrame/Series, in which case
# to_csv also writes its index (the feature names) by default — confirm.
SRW_obj.w_map.to_csv('data/BRCA_edge_feature_weights_2.txt', sep='\t')

In [None]:
# Generate SRW_obj.Q_fin_df, the final transition matrix, 
# and SRW_obj.P_fin_df, the final propagated mutation profiles
# Generate SRW_obj.Q_fin_df, the final transition matrix,
# and SRW_obj.P_fin_df, the final propagated mutation profiles.
# Run after training so these presumably reflect the learned weights.
SRW_obj.generate_Q_and_P_fin()

In [None]:
# Save the object
# Save the whole trained solver object, using the highest pickle protocol
# this Python version supports.
# NOTE(review): with %autoreload active, pickling can fail with a "not the same
# object as SRW_v044.SRW_solver" error if the module was reloaded after SRW_obj
# was created; restart the kernel and rerun if that happens.
# Only unpickle this file from trusted sources: loading a pickle can execute code.
with open('data/SRW_obj_2', 'wb') as output:
    pickle.dump(SRW_obj, output, pickle.HIGHEST_PROTOCOL)