In [None]:
import torch
import numpy as np

## loss function

In [None]:
def extract_distinct_labels(group_labels):
    '''
    Map every distinct label to the half-open row range it occupies.

    The labels must be grouped: each label has to occupy one contiguous
    run (sorting the labels guarantees this). The result looks like
    {
        'type 1': [sample_start_position_1, sample_end_position_1],
        'type 2': [sample_start_position_2, sample_end_position_2],
        ...
    }
    where rows start:end (end exclusive) carry that label. Labels appear
    in order of first occurrence. An empty input gives an empty dict.

    Raises
    ------
    ValueError
        If a label occurs in more than one contiguous run, since a single
        [start, end] range could not describe it.
    '''
    all_labels = {}
    n_labels = len(group_labels)
    for i, label in enumerate(group_labels):
        # A run starts at index 0 or wherever the label changes.
        if i == 0 or group_labels[i - 1] != label:
            if label in all_labels:
                raise ValueError(
                    "label %r is not contiguous; sort group_labels first" % (label,))
            all_labels[label] = [i]
        # A run ends at the last index or right before the label changes.
        if i == n_labels - 1 or group_labels[i + 1] != label:
            all_labels[label].append(i + 1)
    return all_labels

In [None]:
def calculate_centers_SRW(P, all_labels):
    '''
    Compute the centroid (row mean) of every group, adapted from SRW centroid().

    Parameters
    ----------
    P : torch.Tensor
        2-D matrix with one row per sample.
    all_labels : dict
        label -> [start, end] row range (end exclusive), as produced by
        extract_distinct_labels().

    Returns
    -------
    torch.Tensor
        Shape (k, P.shape[1]) for k groups. Row c is the mean of
        P[start:end] for the c-th label of all_labels. The result has the
        same dtype and device as P, so it can be multiplied with P directly.
    '''
    if not all_labels:
        return P.new_zeros((0, P.shape[1]))
    centers = [torch.sum(P[start:end, :], dim=0) / (end - start)
               for start, end in all_labels.values()]
    # torch.stack is differentiable and keeps P's dtype/device, unlike
    # writing into a freshly allocated default-dtype (float32, CPU) buffer.
    return torch.stack(centers, dim=0)

In [None]:
def label_to_id(all_labels):
    '''
    Give each label a consecutive integer id, following iteration order.

    Example: {'type 1': [...], 'type 2': [...]} -> {'type 1': 0, 'type 2': 1}
    '''
    return {name: index for index, name in enumerate(all_labels)}

In [None]:
def loss(lambda_value, params, beta, P, group_labels, all_labels, is_train):
    '''
    Compute the SRW objective, adapted from SRW cost_func_WMW().

        loss = sum_p lambda_value * ||p||_2 + sum_u sigmoid(x_u / beta)

    For sample u in group i,

        x_u = max(-2, max_{j != i} (coeff_u * d(u, i) - d(u, j)))

    where d(u, j) is the squared Euclidean distance between row u of P and
    the centroid of group j. A sample with x_u < 0 is strictly closer to
    its own centroid than to every other one, and counts as correct.

    Parameters
    ----------
    lambda_value : float
        Weight of the L2-norm penalty added for every tensor in params.
    params : iterable of torch.Tensor
        Tensors to regularise.
    beta : float
        Temperature of the sigmoid applied to each margin x_u.
    P : torch.Tensor
        2-D matrix with one row per sample. Rows of one group must be
        contiguous.
    group_labels : sequence
        Label of every row of P.
    all_labels : dict
        label -> [start, end] row range, as built by extract_distinct_labels().
    is_train : bool
        When False, d(u, own group) is scaled by max((g / (g - 1)) ** 2, 1)
        for a group of g > 1 samples. When True, and for singleton groups,
        the scale is 1.

    Returns
    -------
    (torch.Tensor, float)
        The 0-d loss tensor, and the fraction of rows with x_u < 0
        (0.0 when P has no rows).
    '''
    n_samples = P.shape[0]

    # L2 penalty, accumulated out of place so it also works for empty params.
    loss_value = P.new_zeros(())
    for param in params:
        loss_value = loss_value + lambda_value * torch.norm(param, p=2)

    if n_samples == 0:
        return loss_value, 0.0

    C = calculate_centers_SRW(P, all_labels)
    label_id = label_to_id(all_labels)
    n_groups = len(label_id)

    # Column of C holding each sample's own group.
    own = torch.tensor([label_id[group_labels[u]] for u in range(n_samples)],
                       dtype=torch.long, device=P.device)
    rows = torch.arange(n_samples, device=P.device)

    # d(u, j) = ||p_u||^2 - 2 p_u . c_j + ||c_j||^2, shape (n_samples, n_groups).
    # Only the diagonal of P @ P.T is needed, so it is computed row-wise.
    dist = (torch.sum(P * P, dim=1)[:, None]
            - 2 * torch.matmul(P, C.T)
            + torch.sum(C * C, dim=1)[None, :])

    # Per-group scale for the own-group distance. The guard for size <= 1
    # avoids dividing by zero on singleton groups.
    group_coeff = []
    for start, end in all_labels.values():
        size = end - start
        if is_train or size <= 1:
            group_coeff.append(1.0)
        else:
            group_coeff.append(max((size / (size - 1)) ** 2, 1.0))
    coeff = torch.tensor(group_coeff, dtype=dist.dtype, device=dist.device)[own]

    # margins[u, j] = coeff_u * d(u, own_u) - d(u, j). The own column is
    # excluded from the max, and the result is floored at -2 as in SRW.
    margins = coeff[:, None] * dist[rows, own][:, None] - dist
    is_own = torch.arange(n_groups, device=P.device)[None, :] == own[:, None]
    margins = margins.masked_fill(is_own, float('-inf'))
    x = torch.clamp(torch.max(margins, dim=1).values, min=-2.0)

    accuracy = torch.sum(x < 0.0).item() / n_samples
    # torch.sigmoid(z) equals 1 / (1 + exp(-z)), but it does not overflow
    # exp() when |x / beta| is large. The explicit form gives NaN gradients
    # there.
    loss_value = loss_value + torch.sum(torch.sigmoid(x / beta))
    return loss_value, accuracy

## Model class implemented with PyTorch
* train with simple gradient descent
* add a validation module
* add the Adam update
* support an MLP in the activation module
* support Sigmoid and ReLU activations in the MLP
* support Softplus and Gaussian activations in the MLP
* support the Nesterov update method

In [None]:
class SRW_pytorch:
    '''
    Supervised random walk (SRW) trainer built on PyTorch autograd.

    Each training iteration:
      1. An MLP (layers / params) maps the edge features to one strength
         per edge (activation).
      2. The strengths are scattered into a row-normalised node-to-node
         transition matrix Q (_transition_matrix).
      3. The row-normalised sample matrix goes through n_walk_steps
         random-walk-with-restart steps over Q (_random_walk).
      4. loss() is evaluated on the result and back-propagated, and the MLP
         weights are updated with Adam, plain GD or Nesterov (step).
    '''

    def __init__(self, n_iter, lambda_value, beta, features, edges,
                 group_labels_train, P_train, group_labels_val, P_val,
                 betas, layers, params, rst_prob, lr, update = 'Adam',
                 n_walk_steps = 40):
        self.n_iter = n_iter              # number of training iterations
        self.lambda_value = lambda_value  # weight of the L2-norm penalty in loss()
        self.beta = beta                  # sigmoid temperature passed to loss()
        self.features = features          # edge features; row j belongs to edges[j]
        self.edges = edges                # edges[j] = (row node, column node) of Q
        self.group_labels_train = group_labels_train  # labels of the training samples
        self.P_train = P_train            # initial training matrix, one row per sample
        self.group_labels_val = group_labels_val      # labels of the validation samples
        self.P_val = P_val                # initial validation matrix, one row per sample
        self.rst_prob = rst_prob          # random-walk restart probability
        self.lr = lr                      # learning rate
        self.params = params              # MLP weight tensors, one per layer
        self.layers = layers              # activation name for each entry of params
        self.n_walk_steps = n_walk_steps  # random-walk iterations per forward pass
        # Moment buffers for Adam; 'm_t' also serves as the Nesterov velocity.
        self.state = [{'m_t': torch.zeros_like(p), 'v_t': torch.zeros_like(p)}
                      for p in self.params]
        self.betas = betas                # Adam (beta1, beta2)
        self.eps = 1e-8                   # guards divisions in normalisation and Adam
        self.update = update              # 'Adam', 'GD' or 'Nesterov'

    def train(self):
        '''
        Run self.n_iter training iterations. After each one, print the
        training and validation loss/accuracy.
        '''
        # The labels do not change between iterations, so group them once.
        all_labels = extract_distinct_labels(self.group_labels_train)
        n_nodes = self.P_train.shape[1]
        for t in range(self.n_iter):
            # Forward pass: edge strengths -> transition matrix -> random walk.
            strength = self.activation()
            Q = self._transition_matrix(strength, n_nodes)
            P = self._random_walk(self.P_train, Q)

            loss_value, accuracy = loss(self.lambda_value, self.params, self.beta, P,
                                        self.group_labels_train, all_labels, True)
            # The graph is rebuilt every iteration, so it is not retained.
            loss_value.backward()

            self.step(t + 1)  # Adam's bias correction expects a 1-based step
            for param in self.params:
                if param.grad is not None:
                    param.grad.zero_()

            print("[%d/%d] training loss: %.4f\t training accuracy: %.4f" %(t+1, self.n_iter, loss_value.item(), accuracy))

            loss_value_val, accuracy_val = self.validation()
            print("[%d/%d] validation loss: %.4f\t validation accuracy: %.4f" %(t+1, self.n_iter, loss_value_val.item(), accuracy_val))

    def activation(self):
        '''
        Feed self.features through the MLP and return the edge strengths.

        Layer i computes f_i(strength @ params[i]). f_i is chosen by
        self.layers[i]: 'sigmoid', 'ReLu', 'softplus' or 'gaussian'
        (exp(-z ** 2)).

        Raises
        ------
        NotImplementedError
            If a layer name is unknown.
        '''
        layer_fns = {
            # torch.sigmoid and softplus are the numerically stable forms of
            # 1/(1+exp(-z)) and log(1+exp(z)). The explicit formulas overflow
            # exp() for large |z|, which yields inf values or NaN gradients.
            'sigmoid': torch.sigmoid,
            'ReLu': torch.nn.functional.relu,
            'softplus': torch.nn.functional.softplus,
            'gaussian': lambda z: torch.exp(-z ** 2),
        }
        strength = self.features
        for i, weight in enumerate(self.params):
            layer = self.layers[i]
            if layer not in layer_fns:
                raise NotImplementedError("%s layer has not implemented yet" %(layer))
            strength = layer_fns[layer](torch.matmul(strength, weight))
        return strength

    def _transition_matrix(self, strength, n_nodes):
        '''
        Scatter strength[j, 0] into Q[edges[j]] and row-normalise Q.

        This uses one vectorised index_put instead of one autograd-tracked
        assignment per edge.
        '''
        n_edges = strength.shape[0]
        src = torch.as_tensor([int(self.edges[j][0]) for j in range(n_edges)],
                              dtype=torch.long, device=strength.device)
        dst = torch.as_tensor([int(self.edges[j][1]) for j in range(n_edges)],
                              dtype=torch.long, device=strength.device)
        Q = torch.zeros((n_nodes, n_nodes), dtype=strength.dtype, device=strength.device)
        # NOTE(review): assumes each (src, dst) pair appears once; index_put
        # does not define which value wins for duplicate indices.
        Q = Q.index_put((src, dst), strength[:, 0])
        return Q / (torch.sum(Q, dim=1) + self.eps).reshape(-1, 1)

    def _random_walk(self, P_raw, Q):
        '''Row-normalise P_raw, then apply self.n_walk_steps restart-walk steps over Q.'''
        P_init = P_raw / (torch.sum(P_raw, dim=1) + self.eps).reshape(-1, 1)
        # Start from a detached copy. Gradients reach the parameters through Q.
        P = P_init.detach().clone()
        for _ in range(self.n_walk_steps):
            P = (1.0 - self.rst_prob) * torch.matmul(P, Q) + self.rst_prob * P_init
        return P

    def step(self, t):
        '''
        Update every tensor in self.params in place from its .grad.

        t is the 1-based iteration number, used for Adam's bias correction.

        Raises
        ------
        NotImplementedError
            If self.update is not 'Adam', 'GD' or 'Nesterov'.
        '''
        if self.update == 'Adam':
            b1, b2 = self.betas
            B1 = 1.0 - np.power(b1, t)
            B2 = 1.0 - np.power(b2, t)
            st = self.lr / B1
            for param, state in zip(self.params, self.state):
                grad = param.grad.data
                state['m_t'] = b1 * state['m_t'] + (1 - b1) * grad
                state['v_t'] = b2 * state['v_t'] + (1 - b2) * (grad ** 2)
                D = torch.sqrt(state['v_t'] / B2) + self.eps
                param.data = param.data - state['m_t'] / D * st
        elif self.update == 'GD':
            for param in self.params:
                param.data -= self.lr * param.grad.data
        elif self.update == 'Nesterov':
            for param, state in zip(self.params, self.state):
                v_prev = state['m_t'].detach().clone()
                state['m_t'] = 0.9 * state['m_t'] - self.lr * param.grad.data
                # Nesterov momentum: x += -mu * v_prev + (1 + mu) * v with
                # mu = 0.9. The update is added to the current weights.
                param.data += -0.9 * v_prev + 1.9 * state['m_t']
        else:
            raise NotImplementedError("%s optimization method has not implemented yet" %(self.update))

    def validation(self):
        '''
        Evaluate loss() and accuracy on the validation samples with the
        current weights.

        Runs under torch.no_grad(), since nothing is back-propagated from
        validation.
        '''
        with torch.no_grad():
            strength = self.activation()
            Q = self._transition_matrix(strength, self.P_val.shape[1])
            P = self._random_walk(self.P_val, Q)
            all_labels = extract_distinct_labels(self.group_labels_val)
            return loss(self.lambda_value, self.params, self.beta, P,
                        self.group_labels_val, all_labels, False)

## Real data

In [None]:
# Reload edited modules (such as SRW_v044) automatically before running each cell.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import SRW_v044 as SRW
import pickle

#### loading data

In [None]:
# Edge list, per-edge feature matrix and node names, read by the project SRW module.
edges, features, node_names = SRW.load_network(
    'data/BRCA_edge2features_2.txt'
)

In [None]:
# Initial training matrix (rows are samples, reordered below) and the sample names.
P_init_train, sample_names_train = SRW.load_samples(
    'data/BRCA_training_data_2.txt',
    node_names,
)

In [None]:
# Initial validation matrix (rows are samples, reordered below) and the sample names.
P_init_val, sample_names_val = SRW.load_samples(
    'data/BRCA_validation_data_2.txt',
    node_names,
)

In [None]:
# Group label of each training sample.
group_labels_train = SRW.load_grouplabels(
    'data/BRCA_training_lables_2.txt'
)

In [None]:
# Group label of each validation sample.
group_labels_val = SRW.load_grouplabels(
    'data/BRCA_validation_lables_2.txt'
)

#### preprocessing

In [None]:
def sort_argsort(seq):
    '''
    Sort the list seq in place and return (seq, order).

    order[k] is the original index of the element now at position k.
    Equal values keep their original relative order.
    '''
    # sorted() is stable, so ties are broken by original index.
    order = sorted(range(len(seq)), key=seq.__getitem__)
    seq.sort()
    return seq, order

In [None]:
# Sort each label list so every group is contiguous, and keep the
# permutation so the sample rows can be reordered the same way.
group_labels_val, group_labels_val_argsort = sort_argsort(group_labels_val)
group_labels_train, group_labels_train_argsort = sort_argsort(group_labels_train)

In [None]:
# Densify the sample matrices (toarray() presumably converts a SciPy sparse
# matrix) and reorder their rows to follow the sorted labels.
P_init_val = P_init_val.toarray()[group_labels_val_argsort, :]
P_init_train = P_init_train.toarray()[group_labels_train_argsort, :]

#### initialize tensors

In [None]:
# Dense float32 inputs for the model. They are never optimised (only the
# MLP weights in params are updated), so they need no autograd tracking.
# With requires_grad=True, every backward() would also build graph for
# them and accumulate an unused .grad.
P_init_train = torch.tensor(P_init_train, requires_grad = False, dtype = torch.float32)
features = torch.tensor(features.toarray(), requires_grad = False, dtype = torch.float32)
P_init_val = torch.tensor(P_init_val, requires_grad = False, dtype = torch.float32)

#### run framework

In [None]:
# Hyper-parameters
n_iter = 200               # training iterations
lambda_value = 0.1         # weight of the L2 penalty on the MLP weights
beta = 2e-4                # sigmoid temperature in the loss
adam_betas = (0.9, 0.999)  # Adam (beta1, beta2)
rst_prob = 0.3             # random-walk restart probability
lr = 4e-3                  # learning rate

# Two-layer MLP: edge features -> features/2 hidden units -> one strength per edge.
w1 = torch.normal(mean = 0, std = 1,
                  size = (features.shape[1], features.shape[1] // 2),
                  requires_grad = True, dtype = torch.float32)
w2 = torch.normal(mean = 0, std = 1,
                  size = (features.shape[1] // 2, 1),
                  requires_grad = True, dtype = torch.float32)
params = [w1, w2]
layers = ['softplus', 'ReLu']

solver = SRW_pytorch(
    n_iter, lambda_value, beta, features, edges,
    group_labels_train, P_init_train,
    group_labels_val, P_init_val,
    adam_betas, layers, params, rst_prob, lr,
    update = "Adam",
)

In [None]:
# Train the model, printing training and validation loss/accuracy after every iteration.
solver.train()