In [None]:
import torch
import numpy as np

## Synthetic data generation and tensor initialization

In [None]:
# Reproducibility: seed both RNGs used by this notebook (numpy for the synthetic
# network and mutation profiles, torch for the initial MLP weights).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

nsamples = 100             # number of samples
nnodes = 1000              # number of network nodes
edge_freq = 0.004          # background probability of an edge between any two nodes
cliq_edge_freq = 0.114     # extra edge probability among the first 100 nodes (planted subnetworks)
hi_mut_freq = 0.5          # mutation probability of the single high-mutation node
# Integer division so `i == hi_node` can match a node index for any nnodes
# (nnodes*3/4 is a float and never equals an int when nnodes % 4 != 0).
hi_node = nnodes * 3 // 4
# First half of the samples is 'Subtype 1', the rest 'Subtype 2'; the second
# count is computed as the remainder so the list always has nsamples entries.
n_subtype1 = nsamples // 2
group_labels = ['Subtype 1'] * n_subtype1 + ['Subtype 2'] * (nsamples - n_subtype1)
# One name per edge-feature column, in column order.
feature_names = ['Subnetwork 1', 'Subnetwork 2', 'High mut source', 'High mut target',
                 'Random 1', 'Random 2',
                 'Self loop', 'Intercept']
node_names = [str(i) for i in range(1, nnodes + 1)]
sample_names = [str(i) for i in range(1, nsamples + 1)]

rand_mut_freq = 0.015      # background mutation probability of all remaining nodes

In [None]:
# Build a random undirected network (stored as two directed edges per pair)
# with two planted dense subnetworks among the first `cliq_size` nodes, then add
# one self-loop per node. Each directed edge gets one feature row, ordered like
# `feature_names`:
#   [subnet 1, subnet 2, high-mut source, high-mut target, random 1, random 2, self loop, intercept]
cliq_size = 100              # nodes [0, cliq_size) use the higher clique edge probability
half = cliq_size // 2        # subnetwork 1 = [0, half), subnetwork 2 = [half, cliq_size)
subtype_mut_freq = 0.015     # mutation prob. of a clique node for a sample in the same half
hi_node_idx = int(hi_node)   # hi_node may be a float (e.g. nnodes*3/4); compare against an int

degrees = [0] * nnodes       # number of non-self-loop edges per node
edges = []                   # [source, target] pairs
features = []                # one feature row per entry of `edges`
for i in range(nnodes - 1):
    for j in range(i + 1, nnodes):
        in_clique = i < cliq_size and j < cliq_size
        if (in_clique and np.random.random() < cliq_edge_freq) or np.random.random() < edge_freq:
            fwd = [0, 0, 0, 0, np.random.random(), np.random.random(), 0, 1]  # edge i -> j
            rev = [0, 0, 0, 0, np.random.random(), np.random.random(), 0, 1]  # edge j -> i
            if i < half and j < half:
                fwd[0] = rev[0] = 1
            if half <= i < cliq_size and half <= j < cliq_size:
                fwd[1] = rev[1] = 1
            # Flag edges touching the high-mutation node. Comparing against
            # nnodes-1 instead would never fire for i (which stops at nnodes-2)
            # and would mark a node unrelated to hi_node.
            if i == hi_node_idx:
                fwd[2] = 1  # i -> j leaves the high-mutation node
                rev[3] = 1  # j -> i enters it
            if j == hi_node_idx:
                fwd[3] = 1  # i -> j enters the high-mutation node
                rev[2] = 1  # j -> i leaves it
            edges.append([i, j])
            edges.append([j, i])
            features.append(fwd)
            features.append(rev)
            degrees[i] += 1
            degrees[j] += 1

for i in range(nnodes):
    edges.append([i, i])
    features.append([0, 0, 0, 0, np.random.random(), np.random.random(), 1, 1])

# Binary initial profile per sample: sample p always has node p set; the
# high-mutation node is set with prob hi_mut_freq; a clique node is set with
# prob subtype_mut_freq only when sample and node fall in the same half
# (index < half vs >= half); every other node with prob rand_mut_freq.
P_init = []
for p in range(nsamples):
    p_init = []
    for i in range(nnodes):
        if p == i:
            freq = 1
        elif i == hi_node_idx:
            freq = hi_mut_freq
        elif i < cliq_size:
            same_half = max(p, i) < half or min(p, i) >= half
            freq = subtype_mut_freq if same_half else 0.0
        else:
            freq = rand_mut_freq
        p_init.append(1.0 if np.random.random() < freq else 0.0)
    P_init.append(p_init)

In [None]:
# Convert the generated lists to tensors. These are fixed inputs, not trainable
# parameters, so they do not require gradients (otherwise every training
# iteration computes and accumulates unused gradients into them). `as_tensor`
# returns an existing tensor unchanged, so re-running this cell is safe.
P_init = torch.as_tensor(P_init, dtype=torch.float32)
features = torch.as_tensor(features, dtype=torch.float64)

## Loss function

In [None]:
def extract_distinct_labels(group_labels):
    '''
    Map each distinct label to the half-open index range it occupies.

    All occurrences of a label must be contiguous. The result looks like
    {
        'type 1': [sample_start_position_1, sample_end_position_1],
        'type 2': [sample_start_position_2, sample_end_position_2],
        ...
    }
    where the end position is exclusive (so P[start:end] selects the group).
    Keys appear in order of first occurrence; an empty input gives {}.

    Raises:
        ValueError: if a label reappears after a different label, which would
            otherwise yield a wrong or malformed range.
    '''
    all_labels = {}
    for i, label in enumerate(group_labels):
        if i > 0 and label == group_labels[i - 1]:
            continue  # still inside the current run
        if label in all_labels:
            raise ValueError(
                'group_labels must be contiguous; label %r reappears at index %d' % (label, i))
        if i > 0:
            all_labels[group_labels[i - 1]].append(i)  # close the previous run
        all_labels[label] = [i]
    if len(group_labels) > 0:
        all_labels[group_labels[-1]].append(len(group_labels))  # close the last run
    return all_labels

In [None]:
def calculate_centers_SRW(P, all_labels):
    '''
    Compute one centroid (mean row of P) per label group, adapted from SRW centroid().

    Parameters:
        P: 2-D tensor with one row per sample.
        all_labels: dict label -> [start, end) row range, as returned by
            extract_distinct_labels.

    Returns:
        Tensor C of shape (k, P.shape[1]) for k groups, in the iteration order
        of all_labels, with the same dtype and device as P (so it can be
        multiplied with P and gradients flow back to P).
    '''
    if len(all_labels) == 0:
        return P.new_zeros((0, P.shape[1]))
    return torch.stack([P[start:end, :].mean(dim=0) for start, end in all_labels.values()])

In [None]:
def label_to_id(all_labels):
    '''
    Number the labels 0, 1, 2, ... in their iteration order:
     {
        'type 1': 0,
        'type 2': 1,
        ...
    }
    '''
    return {name: position for position, name in enumerate(all_labels)}

In [None]:
def loss(lambda_value, params, beta, P, group_labels, all_labels, is_train):
    '''
    Compute the regularized loss, adapted from SRW cost_func_WMW().

    loss = lambda_value * sum_k ||params[k]||_1  +  sum_u 1 / (1 + exp(-x_u / beta))

    where, for sample u whose own group is i and with C the group centroids,
        d_i = coeff * ||P[u] - C[i]||^2,   d_j = ||P[u] - C[j]||^2  (j != i),
        x_u = max(-2.0, max_j (d_i - d_j)).
    A sample counts as correctly classified when x_u < 0, i.e. it is closer to
    its own centroid than to every other one.

    Parameters:
        lambda_value: weight of the L1 penalty on `params`.
        params: iterable of weight tensors (may be empty).
        beta: temperature of the per-sample sigmoid (smaller -> sharper).
        P: 2-D tensor with one row per sample, rows ordered like `group_labels`.
        group_labels: label of each row of P.
        all_labels: dict label -> [start, end) row range (see extract_distinct_labels).
        is_train: when it equals False, coeff = (n / (n - 1))**2 for a group of
            size n, which turns d_i into the squared distance to the centroid
            computed without sample u; otherwise coeff = 1.

    Returns:
        (loss tensor, fraction of correctly classified samples; 0.0 if P has no rows)

    Raises:
        ZeroDivisionError: if is_train == False and a group has a single sample.
    '''
    # L1 regularization; starting from a zero tensor lets `params` be empty.
    loss_value = sum((lambda_value * torch.norm(param, p=1) for param in params), P.new_zeros(()))
    accuracy = 0.0

    C = calculate_centers_SRW(P, all_labels)
    label_id = label_to_id(all_labels)

    # ||p - c||^2 = ||p||^2 - 2 p.c + ||c||^2. Only the squared row norms are
    # needed, so avoid building the full (n x n) Gram matrix P P^T.
    P_sq = torch.sum(P * P, dim=1)
    C_sq = torch.sum(C * C, dim=1)
    P_dot_CT = torch.matmul(P, C.T)

    for u in range(P.shape[0]):
        own_label = group_labels[u]
        i = label_id[own_label]
        start, end = all_labels[own_label]
        group_sample = end - start
        if is_train == False:
            # Leave-one-out correction: u contributed to its own centroid.
            coeff = (group_sample / (group_sample - 1)) ** 2
        else:
            coeff = 1.0
        dist_ui = coeff * (P_sq[u] - 2 * P_dot_CT[u, i] + C_sq[i])

        x_u = torch.tensor(-2.0)  # floor for the margin
        for label, j in label_id.items():
            if label == own_label:
                continue
            x_u_tmp = dist_ui - (P_sq[u] - 2 * P_dot_CT[u, j] + C_sq[j])
            if x_u_tmp > x_u:
                x_u = x_u_tmp

        # Correctly classified: closer to its own centroid than to all others.
        if x_u < 0.0:
            accuracy += 1.0
        # Out-of-place accumulation keeps the autograd graph simple.
        loss_value = loss_value + 1. / (1 + torch.exp(-x_u / beta))

    n_samples = P.shape[0]
    return loss_value, (accuracy / n_samples if n_samples else 0.0)

## Model class implemented with PyTorch
* Training with simple gradient descent
* Added a validation method
* Added the Adam update rule
* Supports an MLP in the activation (edge-strength) function

In [None]:
class SRW_pytorch:
    '''
    Random-walk-with-restart model whose edge weights are learned (port of SRW).

    An MLP (`layers` / `params`) maps each row of the per-edge `features` to an
    edge strength. The strengths fill a row-normalized nnodes x nnodes
    transition matrix Q. Each sample's initial profile (row of P) is propagated
    by a random walk with restart over Q, and the result is scored with `loss`
    (group-centroid separation plus L1 regularization of `params`).
    '''

    walk_steps = 30  # random-walk-with-restart iterations (treated as enough to converge)

    def __init__(self, n_iter, lambda_value, beta, group_labels, features, edges, P, nnodes, layers, params, rst_prob, lr, update = 'GD'):
        self.n_iter = n_iter              # number of training iterations
        self.lambda_value = lambda_value  # coefficient of the L1 regularization
        self.beta = beta                  # sigmoid temperature used by the loss
        self.group_labels = group_labels  # label of each training sample (contiguous groups)
        self.features = features          # training edge features, one row per entry of `edges`
        self.edges = edges                # training edges as [source, target] node indices
        self.P = P                        # initial sample x node matrix (rows normalized in train)
        self.nnodes = nnodes              # number of nodes
        self.rst_prob = rst_prob          # random-walk restart probability
        self.lr = lr                      # learning rate
        self.is_train = False             # set to True once train() has completed
        self.params = params              # MLP weight tensors, one per entry of `layers`
        self.layers = layers              # activation name per layer: 'sigmoid' or 'ReLu'
        # Adam moment estimates, matching each parameter's dtype/device.
        self.state = [{'m_t': torch.zeros_like(p), 'v_t': torch.zeros_like(p)} for p in self.params]
        self.betas = (0.9, 0.999)  # Adam decay rates (hard-coded for now)
        self.eps = 1e-8            # Adam numerical-stability term (hard-coded for now)
        self.update = update       # optimizer: 'GD' (plain gradient descent) or 'Adam'
        self.n_steps = 0           # optimizer steps taken so far (drives Adam bias correction)

    @staticmethod
    def _normalize_rows(P):
        '''Scale each row of P to sum to 1 (all-zero rows stay zero).'''
        return P / (torch.sum(P, dim=1) + 1e-8).reshape(-1, 1)

    def _transition_matrix(self, strength, edges, dtype):
        '''Build the row-normalized nnodes x nnodes transition matrix from per-edge strengths.'''
        idx = torch.as_tensor(edges, dtype=torch.long).reshape(-1, 2)
        Q = torch.zeros(size=(self.nnodes, self.nnodes), dtype=dtype)
        # Vectorized form of Q[src, dst] = strength[k, 0] for every edge k;
        # assumes each (src, dst) pair appears at most once in `edges`.
        Q[idx[:, 0], idx[:, 1]] = strength[:, 0].to(dtype)
        return self._normalize_rows(Q)

    def _random_walk(self, P0, Q):
        '''Run `walk_steps` steps of random walk with restart to P0 over Q.'''
        P = P0.detach()
        for _ in range(self.walk_steps):
            P = (1 - self.rst_prob) * torch.matmul(P, Q) + self.rst_prob * P0
        return P

    def train(self):
        '''
        Fit `params` for `self.n_iter` iterations, printing loss and accuracy
        after each one, then mark the model as trained.
        '''
        all_labels = extract_distinct_labels(self.group_labels)  # constant across iterations
        for t in range(self.n_iter):
            strength = self.activation()
            # Recomputed each iteration so no autograd graph is shared between backward() calls.
            P_init = self._normalize_rows(self.P)
            Q = self._transition_matrix(strength, self.edges, P_init.dtype)
            P = self._random_walk(P_init, Q)

            # is_train=False: apply the leave-one-out correction to own-centroid
            # distances on every call, not only on the first train().
            loss_value, accuracy = loss(self.lambda_value, self.params, self.beta, P, self.group_labels, all_labels, False)
            loss_value.backward()

            self.n_steps += 1
            self.step(self.n_steps)
            for p in self.params:
                if p.grad is not None:
                    p.grad.zero_()

            print("[%d/%d] training loss: %.4f\t training accuracy: %.4f" %(t+1, self.n_iter, loss_value.data, accuracy))
        self.is_train = True

    def activation(self, features=None):
        '''
        Apply the MLP layers to `features` (default: the training features).

        Returns one row per input row; column 0 is used as the edge strength.

        Raises:
            ValueError: for a layer name other than 'sigmoid' or 'ReLu'.
        '''
        strength = self.features if features is None else features
        for i, weight in enumerate(self.params):
            layer = self.layers[i]
            z = torch.matmul(strength, weight)
            if layer == 'sigmoid':
                strength = torch.sigmoid(z)
            elif layer == 'ReLu':
                strength = torch.relu(z)
            else:
                raise ValueError("Layer not implemented yet: %r" % (layer,))
        return strength

    def step(self, t):
        '''
        Apply one optimizer update to every parameter from its current .grad.

        Parameters:
            t: 1-based step count, used for Adam bias correction.

        Raises:
            ValueError: if self.update is neither 'GD' nor 'Adam'.
        '''
        with torch.no_grad():
            if self.update == 'GD':
                for p in self.params:
                    if p.grad is not None:
                        p -= self.lr * p.grad
            elif self.update == 'Adam':
                b1, b2 = self.betas
                B1 = 1.0 - b1 ** t
                B2 = 1.0 - b2 ** t
                for p, s in zip(self.params, self.state):
                    if p.grad is None:
                        continue
                    g = p.grad
                    s['m_t'] = b1 * s['m_t'] + (1 - b1) * g
                    s['v_t'] = b2 * s['v_t'] + (1 - b2) * g ** 2
                    denom = torch.sqrt(s['v_t'] / B2) + self.eps
                    p -= (self.lr / B1) * s['m_t'] / denom
            else:
                raise ValueError("Unknown update rule: %r (expected 'GD' or 'Adam')" % (self.update,))

    def validation(self, features_validation, edges_validation, P_validation, group_labels_validation):
        '''
        Score the trained model on a (possibly different) graph and sample set.

        Strengths come from `features_validation` so Q matches `edges_validation`,
        and P_validation rows are normalized exactly as in train().

        Returns:
            (loss tensor, accuracy), or None after printing a message if train()
            has not been run yet.
        '''
        if self.is_train == False:
            print("You should train the model first")
            return
        with torch.no_grad():
            strength = self.activation(features_validation)
            P0 = self._normalize_rows(P_validation)
            Q = self._transition_matrix(strength, edges_validation, P0.dtype)
            P = self._random_walk(P0, Q)
            all_labels = extract_distinct_labels(group_labels_validation)
            return loss(self.lambda_value, self.params, self.beta, P, group_labels_validation, all_labels, True)

In [None]:
# Training hyper-parameters
n_iter = 15         # training iterations
lambda_value = 0.1  # L1 regularization weight
beta = 2e-4         # sigmoid temperature of the loss
rst_prob = 0.3      # random-walk restart probability
lr = 1.0            # learning rate

# Two-layer edge-strength MLP: 8 edge features -> 4 hidden units (sigmoid) -> 1 strength (ReLu).
w1 = torch.normal(mean=0, std=1, size=(8, 4), requires_grad=True, dtype=torch.float64)
w2 = torch.normal(mean=0, std=1, size=(4, 1), requires_grad=True, dtype=torch.float64)
params = [w1, w2]
layers = ['sigmoid', 'ReLu']

solver = SRW_pytorch(
    n_iter, lambda_value, beta, group_labels, features, edges, P_init, nnodes,
    layers, params, rst_prob, lr, update="Adam",
)

In [None]:
# Train the edge-strength MLP; prints the training loss and accuracy after each iteration.
solver.train()

In [None]:
# NOTE(review): this passes the same features/edges/P_init/labels used for
# training, so it measures training-set fit rather than held-out performance.
loss_value, accuracy = solver.validation(features, edges, P_init, group_labels)
print(f"validation loss: {loss_value.item():.4f}\t validation_accuracy: {accuracy:.4f}")