In [1]:
import torch
import numpy as np

## loss function

In [2]:
def extract_distinct_labels(group_labels):
    '''
    compute labels dictionary
    {
        'type 1': [sample_start_position_1, sample_end_position_1],
        'type 2': [sample_start_position_2, sample_end_position_2],
        ...
    }
    '''
    all_labels = {}
    for i, label in enumerate(group_labels):
        if label not in all_labels:
            all_labels[label] = [i]
        if i == len(group_labels)-1 or group_labels[i+1] not in all_labels:
            all_labels[label].append(i+1)
    return all_labels

In [3]:
def calculate_centers_SRW(P, all_labels):
    '''
    compute group centers, adapted from SRW centroid() function
    
    return a tensor with group centroids
    if we have k groups and n nodes, then C.shape = (k, n)
    '''
    C = torch.zeros(size=(len(all_labels), P.shape[1]), dtype=torch.float64)
    count = 0
    for label in all_labels:
        start, end = all_labels[label]
        C[count,:] = torch.sum(P[start: end, :], axis = 0) / (end - start)
        count += 1
    return C

In [4]:
def label_to_id(all_labels):
    '''
    mapping label string to id
     {
        'type 1': 0,
        'type 2': 1,
        ...
    }
    
    '''
    label_id = {}
    count = 0
    for label in all_labels:
        label_id[label] = count
        count += 1
    return label_id

In [5]:
def loss(lambda_value, params, beta, P, group_labels, all_labels, is_trained):
    '''
    This function computes loss function adapted from SRW cost_func_WMW()
    '''
    
    '''l1 norm'''
    loss_value = torch.tensor(0, dtype =torch.float64)
    for param in params:
        loss_value += lambda_value * torch.norm(param, p=2)
    accuracy = 0.0

    '''retrieve centers'''
    C = calculate_centers_SRW(P, all_labels)
    '''retrieve ids for group labels'''
    label_id = label_to_id(all_labels)
    
    '''necessary intermediate value for computing loss'''
    P_dot_CT = torch.matmul(P, C.T)
    C_dot_CT = torch.matmul(C, C.T)
    P_dot_PT = torch.matmul(P, P.T)
    
    '''simply copy from SRW cost_func_WMW()'''
    for u in range(P.shape[0]):
        x_u = torch.tensor(-2.0, dtype =torch.float64)
        i = label_id[group_labels[u]]
        start, end = all_labels[group_labels[u]]
        group_sample = end - start
        if is_trained == False:
            coeff = max((group_sample / (group_sample - 1)) ** 2, 1.0)
        else:
            coeff = 1.0
        dist_ui = coeff *(P_dot_PT[u,u] - 2 * P_dot_CT[u, i] + C_dot_CT[i,i])
        for label in label_id:
            if label != group_labels[u]:
                j = label_id[label]
                x_u_tmp = dist_ui -(P_dot_PT[u,u] - 2 * P_dot_CT[u,j] + C_dot_CT[j,j])
                if x_u_tmp > x_u:
                    x_u = x_u_tmp
        '''if correctly classified, increase accuracy'''
        if x_u < 0.0:
            accuracy += 1.0
        loss_value += 1. / (1+torch.exp(-x_u / beta))
    return loss_value, accuracy / P.shape[0]

## Model class implement with pytorch
* train with simple gradient descent
* add validation module 
* support MLP in activation module
* support Sigmoid and ReLu activation in MLP
* support Softplus and Gaussian activation in MLP
* use torch.optim to optimize

In [6]:
class SRW_pytorch:
    def __init__(self, n_iter, lambda_value, beta, features, edges, group_labels_train, P_train, group_labels_val, P_val, betas, layers, params, rst_prob, lr):
        self.n_iter = n_iter #number of iterations for training'''
        self.lambda_value = lambda_value #coefficient for regularization'''
        self.beta = beta #parameter for loss function'''
        self.features = features #features of training dataset'''
        self.edges = edges #edges of training dataset'''
        self.group_labels_train = group_labels_train #labels of training dataset'''
        self.P_train = P_train #initialized P matrix'''
        self.group_labels_val = group_labels_val #labels of training dataset'''
        self.P_val = P_val #initialized P matrix'''
        self.rst_prob = rst_prob #random walk reset probabilities'''
        self.lr = lr #learning rate'''
        self.params = params #params for MLP activation
        self.layers = layers #layers for MLP 
        self.betas = betas # adam beta
        self.eps = 1e-8
        
    def train(self):
        optimizer = torch.optim.Adam(self.params, lr=self.lr, betas = self.betas)
        for t in range(n_iter):
            optimizer.zero_grad()
            
            '''compute edge strength, through MLP activation'''
            strength = self.activation()

            '''create transition matrix Q'''
            Q = torch.zeros(size=(self.P_train.shape[1], self.P_train.shape[1]), dtype = torch.float64)
            for j in range(strength.shape[0]):
                Q[self.edges[j][0], self.edges[j][1]] = strength[j, 0]

            '''normalize Q'''
            Q = Q / (torch.sum(Q, axis = 1) + self.eps).reshape(-1,1)
            
            '''noramlize P'''
            P_init = self.P_train / (torch.sum(self.P_train, axis = 1) + self.eps).reshape(-1,1)

            '''create P matrix for random walk'''
            P = P_init.detach().clone()
            
            '''for test, only peform random walk 20 times''' 
            for j in range(20):
                P = (1.0-self.rst_prob) * (torch.matmul(P,Q)) + self.rst_prob * P_init

            '''compute loss and backward()'''
            all_labels = extract_distinct_labels(self.group_labels_train)
            loss_value, accuracy = loss(self.lambda_value, self.params, self.beta, P, self.group_labels_train, all_labels, False)
            loss_value.backward()
            
            '''update parameters'''
            optimizer.step()

            print("[%d/%d] training loss: %.4f\t training accuracy: %.4f" %(t+1, n_iter, loss_value, accuracy))

            loss_value_val, accuracy_val = self.validation()
            print("[%d/%d] validation loss: %.4f\t validation accuracy: %.4f" %(t+1, n_iter, loss_value_val, accuracy_val))
            
            
    
    def activation(self):
        strength = self.features
        for i in range(len(self.layers)):
            if self.layers[i] == 'sigmoid':
                strength = 1.0 / (1.0 + torch.exp(-torch.matmul(strength, self.params[i])))
            elif self.layers[i] == 'ReLu':
                strength = torch.nn.functional.relu(torch.matmul(strength, self.params[i]))
            elif self.layers[i] == 'softplus':
                strength = torch.log(1.0 + torch.exp(torch.matmul(strength, self.params[i])))
            elif self.layers[i] == 'gaussian':
                strength = torch.exp(-torch.matmul(strength, self.params[i]) ** 2)
            else:
                raise NotImplementedError("%s layer has not implemented yet" %(self.layer[i]))
                
        return strength
        
    def validation(self):
        strength = self.activation()
        
        Q = torch.zeros(size=(self.P_val.shape[1], self.P_val.shape[1]), dtype =torch.float64)
        for j in range(strength.shape[0]):
            Q[self.edges[j][0], self.edges[j][1]] = strength[j, 0]
        Q = Q / (torch.sum(Q, axis = 1) + self.eps).reshape(-1,1)
        P_init = self.P_val / (torch.sum(self.P_val, axis = 1) + self.eps).reshape(-1,1)
        P = P_init.detach().clone()
        for j in range(20):
            P = (1-self.rst_prob) * (torch.matmul(P,Q)) + self.rst_prob * P_init
        
        all_labels = extract_distinct_labels(self.group_labels_val)
        loss_value, accuracy = loss(self.lambda_value, self.params, self.beta, P, self.group_labels_val, all_labels, False)
        
        return loss_value, accuracy
    
    def test(self, P_test, group_labels_test):
        strength = self.activation()
        
        Q = torch.zeros(size=(P_test.shape[1], P_test.shape[1]), dtype =torch.float64)
        for j in range(strength.shape[0]):
            Q[self.edges[j][0], self.edges[j][1]] = strength[j, 0]
        Q = Q / (torch.sum(Q, axis = 1) + self.eps).reshape(-1,1)
        P_init = P_test / (torch.sum(P_test, axis = 1) + self.eps).reshape(-1,1)
        P = P_init.detach().clone()
        for j in range(20):
            P = (1-self.rst_prob) * (torch.matmul(P,Q)) + self.rst_prob * P_init
        
        all_labels = extract_distinct_labels(group_labels_test)
        loss_value, accuracy = loss(self.lambda_value, self.params, self.beta, P, group_labels_test, all_labels, False)
        return loss_value, accuracy

## Real data

In [7]:
%load_ext autoreload
%autoreload 2

In [8]:
import pandas as pd
import SRW_v044 as SRW
import pickle

#### loading data

In [9]:
edges, features, node_names = SRW.load_network('data/BRCA_edge2features_2.txt')

P_init_train, sample_names_train = SRW.load_samples('data/BRCA_training_data_2.txt', node_names)

P_init_val, sample_names_val = SRW.load_samples('data/BRCA_validation_data_2.txt', node_names)

group_labels_train = SRW.load_grouplabels('data/BRCA_training_lables_2.txt')

group_labels_val = SRW.load_grouplabels('data/BRCA_validation_lables_2.txt')

* Loading network...
	- Nodes in adjacency matrix: 557
	- Nodes in adjacency matrix: 557


#### preprocessing

In [10]:
def sort_argsort(seq):
    argsort_seq = [i for (v, i) in sorted((v, i) for (i, v) in enumerate(seq))]
    seq.sort()
    return seq, argsort_seq

In [11]:
group_labels_train, group_labels_train_argsort = sort_argsort(group_labels_train)
group_labels_val, group_labels_val_argsort = sort_argsort(group_labels_val)

In [12]:
P_init_train = P_init_train.toarray()[group_labels_train_argsort,:]
P_init_val = P_init_val.toarray()[group_labels_val_argsort,:]

#### initialize tensor

In [13]:
P_init_train = torch.tensor(P_init_train, requires_grad = True, dtype = torch.float64)
features = torch.tensor(features.toarray(), requires_grad = True, dtype = torch.float64)
P_init_val = torch.tensor(P_init_val, requires_grad = False, dtype = torch.float64)

#### run framework

In [14]:
n_iter = 10
lambda_value = 0.1
beta = 2e-4
adam_betas = (0.9, 0.999)
rst_prob = 0.3
lr = 1.0
w1 = torch.normal(mean = 0, std = 1, size = (features.shape[1], features.shape[1] // 2), requires_grad = True, dtype = torch.float64)
w2 = torch.normal(mean = 0, std = 1, size = (features.shape[1] // 2, 1), requires_grad = True, dtype = torch.float64)
params = [w1, w2]
layers = ['softplus', 'softplus']
solver = SRW_pytorch(n_iter, lambda_value, beta, features, edges, group_labels_train, P_init_train, group_labels_val, P_init_val, adam_betas, layers, params, rst_prob, lr)

In [15]:
solver.train()

[1/10] training loss: 265.2046	 training accuracy: 0.5546
[1/10] validation loss: 135.9545	 validation accuracy: 0.5524
[2/10] training loss: 264.3640	 training accuracy: 0.5615
[2/10] validation loss: 138.6296	 validation accuracy: 0.5524
[3/10] training loss: 260.5508	 training accuracy: 0.5754
[3/10] validation loss: 135.9154	 validation accuracy: 0.5769
[4/10] training loss: 258.2404	 training accuracy: 0.5823
[4/10] validation loss: 148.8951	 validation accuracy: 0.5000
[5/10] training loss: 280.5579	 training accuracy: 0.5390
[5/10] validation loss: 137.7253	 validation accuracy: 0.5490
[6/10] training loss: 265.5510	 training accuracy: 0.5650
[6/10] validation loss: 13.1728	 validation accuracy: 1.0000
[7/10] training loss: 13.1728	 training accuracy: 1.0000
[7/10] validation loss: 13.8063	 validation accuracy: 1.0000
[8/10] training loss: 13.8063	 training accuracy: 1.0000
[8/10] validation loss: 14.3281	 validation accuracy: 1.0000
[9/10] training loss: 14.3281	 training accur

## Lung data 

In [24]:
edges, features, node_names = SRW.load_network('lung_data/lung_edge2features_2.txt')
P_init_train, sample_names_train = SRW.load_samples('lung_data/lung_training_data_2.txt', node_names)
P_init_val, sample_names_val = SRW.load_samples('lung_data/lung_validation_data_2.txt', node_names)
P_init_test, sample_names_test = SRW.load_samples('lung_data/lung_testing_data_2.txt', node_names)
group_labels_train = SRW.load_grouplabels('lung_data/lung_training_lables_2.txt')
group_labels_val = SRW.load_grouplabels('lung_data/lung_validation_lables_2.txt')
group_labels_test = SRW.load_grouplabels('lung_data/lung_testing_lables_2.txt')

* Loading network...
	- Nodes in adjacency matrix: 441
	- Nodes in adjacency matrix: 441
	- Nodes in adjacency matrix: 441


In [25]:
group_labels_train, group_labels_train_argsort = sort_argsort(group_labels_train)
group_labels_val, group_labels_val_argsort = sort_argsort(group_labels_val)
group_labels_test, group_labels_test_argsort = sort_argsort(group_labels_test)

In [26]:
P_init_train = P_init_train.toarray()[group_labels_train_argsort,:]
P_init_val = P_init_val.toarray()[group_labels_val_argsort,:]
P_init_test = P_init_test.toarray()[group_labels_test_argsort,:]

In [27]:
features = torch.tensor(features.toarray(), requires_grad = True, dtype = torch.float64)
P_init_train = torch.tensor(P_init_train, requires_grad = True, dtype = torch.float64)
P_init_val = torch.tensor(P_init_val, requires_grad = False, dtype = torch.float64)
P_init_test = torch.tensor(P_init_test, requires_grad = False, dtype = torch.float64)

In [28]:
n_iter = 10
lambda_value = 0.1
beta = 2e-4
adam_betas = (0.9, 0.999)
rst_prob = 0.3
lr = 1.0
w1 = torch.normal(mean = 0, std = 1, size = (features.shape[1], features.shape[1] // 2), requires_grad = True, dtype = torch.float64)
w2 = torch.normal(mean = 0, std = 1, size = (features.shape[1] // 2, 1), requires_grad = True, dtype = torch.float64)
params = [w1, w2]
layers = ['softplus', 'softplus']
solver = SRW_pytorch(n_iter, lambda_value, beta, features, edges, group_labels_train, P_init_train, group_labels_val, P_init_val, adam_betas, layers, params, rst_prob, lr)

In [29]:
solver.train()

[1/10] training loss: 77.9222	 training accuracy: 0.5172
[1/10] validation loss: 29.5850	 validation accuracy: 0.3429
[2/10] training loss: 77.7418	 training accuracy: 0.5172
[2/10] validation loss: 27.7609	 validation accuracy: 0.4571
[3/10] training loss: 76.4790	 training accuracy: 0.5241
[3/10] validation loss: 27.5109	 validation accuracy: 0.5143
[4/10] training loss: 80.6495	 training accuracy: 0.5103
[4/10] validation loss: 31.0402	 validation accuracy: 0.4286
[5/10] training loss: 86.8053	 training accuracy: 0.5241
[5/10] validation loss: 31.9661	 validation accuracy: 0.5429
[6/10] training loss: 85.7731	 training accuracy: 0.5172
[6/10] validation loss: 33.3624	 validation accuracy: 0.4857
[7/10] training loss: 87.1587	 training accuracy: 0.4966
[7/10] validation loss: 35.6264	 validation accuracy: 0.4286
[8/10] training loss: 89.9011	 training accuracy: 0.4897
[8/10] validation loss: 36.2861	 validation accuracy: 0.4000
[9/10] training loss: 88.6274	 training accuracy: 0.5103

In [30]:
loss_test, accuracy_test = solver.test(P_init_test, group_labels_test)

In [31]:
print("test loss: %.4f\t test accuracy: %.4f" %(loss_test, accuracy_test))

test loss: 16.8955	 test accuracy: 1.0000
