In [2]:
from __future__ import print_function
from model import RouteNet
import tensorflow as tf
import numpy as np
from read_dataset import generator
import os
from torchviz import make_dot, make_dot_from_trace
tr_path = r'/Users/yanchuanqi/Github/CSE222RouteNet/sample_data'

ModuleNotFoundError: No module named 'torch_geometric'

In [6]:
from read_dataset_tf import generator
tr_path = r'/Users/yanchuanqi/Github/CSE222RouteNet/sample_data'

In [7]:
for batch, delay in generator(tr_path):
    x = batch
    y = delay
    break

[[None 10000 None None None 25000 None None None None None None None None
  None None None None None]
 [10000 None 10000 None 10000 40000 None None None None None None None
  None None None None 10000 10000]
 [None 10000 None 25000 None None None None None None None None None None
  None None None None None]
 [None None 25000 None None 40000 None None None None None None None None
  None None None None None]
 [None 10000 None None None 25000 None None None None None None None None
  None None None None None]
 [25000 40000 None 40000 25000 None 25000 25000 40000 None None 40000
  None None 40000 40000 None 25000 None]
 [None None None None None 25000 None None None None None None None None
  None None None None None]
 [None None None None None 25000 None None 10000 None None None None
  10000 None None None None None]
 [None None None None None 40000 None 10000 None 25000 None 10000 None
  None None None None None None]
 [None None None None None None None None 25000 None 10000 None Non

In [8]:
import torch
import torch.nn as nn
x['links'] = torch.unsqueeze(torch.tensor(x['links']),1)                         
x['paths'] = torch.unsqueeze(torch.tensor(x['paths']),1)                 
x['sequences'] = torch.unsqueeze(torch.tensor(x['sequences']),1)
x['link_capacity'] = torch.unsqueeze(torch.tensor(x['link_capacity']).float(),axis=1)
x['bandwith'] = torch.unsqueeze(torch.tensor(x['bandwith']).float(), axis=1)

In [9]:
import torch.nn.functional as F

In [10]:

class RouteNet(nn.Module):

    def __init__(self):
        super(RouteNet,self).__init__()

        ### Architecture ###
        # for gru need to pay attention to if input is of size:
        # (batch, seq_len, feature size) or seq_len, batch, feature size
        # if sequence length is variable
        # may need to pad the sequence
        
        self.link_state_dim = 32
        self.path_state_dim = 32
        self.readout_dim = 8
        self.output_units = 1
        self.T = 8

        inSize = 10 # place holder
        hSize  = 32
        readSize = 10
        nLayer = 1

        self.l_U = nn.GRU(input_size = self.link_state_dim,
                          hidden_size = hSize,
                          num_layers = nLayer,
                          batch_first=True)
        
        self.p_U = nn.GRU(input_size = self.path_state_dim,
                          hidden_size = hSize,
                          num_layers = nLayer,
                          batch_first=True)
        
        self.readOut = nn.ModuleDict({  'r1': nn.Linear(hSize,self.readout_dim),
                                        'r2': nn.Linear(self.readout_dim,self.readout_dim),
                                        'r3': nn.Linear(self.readout_dim,self.output_units)
                                        })
        

    def forward(self,x):
        
        links = x['links']
        paths = x['paths']
        seqs  = x['sequences']
        bandwidth = x['bandwith']
        link_cap  = x['link_capacity']

        # state matrix shape for the link
        link_h_state_shape = (x['n_links'], self.link_state_dim-1)

        # create hidden state matrix shape for the path  
        path_h_state_shape = (x['n_paths'],self.path_state_dim-1)
        path_h_state = torch.cat((bandwidth,torch.zeros(path_h_state_shape)), axis=1)
        
        # prepare input for path update RNN
        max_seq_len = torch.max(seqs)
        path_rnn_input_shape = (x['n_paths'],max_seq_len+1,self.link_state_dim)
        
        #stack the paths and sequences
        ids = torch.stack((paths,seqs),axis=1)
        ids = torch.squeeze(ids,2)           
        p_ind = ids[:,0]
        s_ind = ids[:,1]

        # flatten the double loop into a bulk matrix using the gather functionality
        # this is an aggregation of the state vector of each link on each path flattened into a (sum_paths(sum_links_on_paths) x h_state) size matrix 
        # using torch.gather
        indices = torch.zeros(len(links),32)
        for i in range(len(links)):
            link_id = links[i]
            indices[i,:] = link_id
            
        # variable dictionary for forward pass
        vd = {}
        
        for t in range(self.T):
            
            
            ############# set up the matrices and variables for each pass through #################
            
            ########## PATH VARIABLES ###########
            
            # input to the path rnn layer P_u
            path_rnn_input_key = 'path_rnn_input_' + str(t)
            vd[path_rnn_input_key] = torch.zeros(path_rnn_input_shape)
            
            if (t > 0):  # for non leaf variables, we need to propagate the gradient back
                vd[path_rnn_input_key].requires_grad = True
            
            
            # path hidden state output from P_U, initialized with just bandwidth at T_0, else copy
            path_h_state_key = 'path_h_state_' + str(t)
            if (t==0):
                vd[path_h_state_key] = torch.cat((bandwidth,torch.zeros(path_h_state_shape)), axis=1)
            else:
                path_h_state_key = 'path_h_state_' + str(t)
                path_h_prev = 'path_h_state_' + str(t-1)
                vd[path_h_state_key] = vd[path_h_prev]
            
            
            # path_hidden state sequence from P_U, used to update links
            path_h_state_seq_key = 'path_h_states_seq_' + str(t)
            
            
            ########## LINK VARIABLES ###########
            
            # vector to store the link_hidden states
            if (t == 0):
                # create hidden state matrix for links and initialize with first column as link capacity
                link_h_state_key = 'link_h_state_' + str(t)
                vd[link_h_state_key] = torch.cat((link_cap,torch.zeros(link_h_state_shape)),1)
            else:
                # copy hidden state value for next pass through
                link_h_state_key = 'link_h_state_' + str(t)
                linK_h_prev = 'link_h_state_' + str(t-1)
                vd[link_h_state_key] = vd[linK_h_prev]
                
                
            # matrix storing the hidden states of links on paths
            # i.e. the hidden state of all links in the x['links'] list
            h_link_path_key = 'h_links_on_paths_' + str(t)
            vd[h_link_path_key] = torch.gather(vd[link_h_state_key],0,indices.long())
            
            #link messages extracted from the path hidden state sequence output from P_U
            link_message_key = 'link_messages_' + str(t)

            
             # container for the link messages that are extracted from path rnn hidden states
            agg_link_message_key = 'aggregated_link_message_' + str(t)
            vd[agg_link_message_key] = torch.zeros((x['n_links'],self.link_state_dim),requires_grad=True)
            
            ########################################################################################
            
            
            ################################## DO THE MESSAGE PASSING ##############################
            
            # prepare input for path RNN
            vd[path_rnn_input_key] = vd[path_rnn_input_key].index_put(indices = [p_ind,s_ind],
                                                                      values = vd[h_link_path_key])
            # pass through the path RNN
            vd[path_h_state_seq_key], vd[path_h_state_key] = self.p_U(vd[path_rnn_input_key],
                                                                      torch.unsqueeze(vd[path_h_state_key],
                                                                      0))
            # reformat
            vd[path_h_state_key] = vd[path_h_state_key].squeeze(0)
            
            # extract link messages from the path RNN sequence output
            # equivalent to tf.gather_nd
            vd[link_message_key] = vd[path_h_state_seq_key][p_ind,s_ind,:]
           
            # aggregate the link messages
            vd[agg_link_message_key] = vd[agg_link_message_key].index_put([links.squeeze(1)],
                                                                           vd[link_message_key],
                                                                           accumulate=True)
            # update the state of the links by passing through link 
            _, vd[link_h_state_key] = self.l_U(torch.unsqueeze(vd[agg_link_message_key],1),
                                               torch.unsqueeze(vd[link_h_state_key].squeeze(0),0))
            # reformat
            vd[link_h_state_key] = vd[link_h_state_key].squeeze(0)
            
            ##########################################################################################

        # readout from the paths
        y = self.readout(vd[path_h_state_key])
        return y


    def readout(self,path_state):
        x = F.relu(self.readOut['r1'](path_state))
        x = F.relu(self.readOut['r2'](x))
        x = self.readOut['r3'](x)
        return x

In [39]:
# Code snippet from runner

def train(model,label):
    optimizer = torch.optim.Adam(model.parameters(), .001) # optimizer method for gradient descent
#     criterion = torch.nn.MSELoss() 
    criterion = CustomLoss(100)
    model.train() #put model in training mode
    for epoch in range(10):
        tr_loss = []
        optimizer.zero_grad()  
        outputs = model(x)
        print(outputs[0].dtype,torch.tensor(label)[0].dtype)
        #the input of the parameters should be re-defined here
        loss = criterion(outputs, torch.unsqueeze(torch.tensor(label),1)) 
        loss.backward()                        
        optimizer.step()                  
        tr_loss.append(loss.item())
        print(torch.mean(torch.tensor(tr_loss)))

class CustomLoss(nn.Module):
    def __init__(self, num_packet, , , , ): 
        # maybe it's not a great idea to throw all the output into the init function
        # it should be okay to put them as the input parameter of the forward method, but I haven't been able to test it
        super(CustomLoss, self).__init__()
        self.n = num_packet
        self.d = 
        self.d_t = 
        self.j =
        self.j_t =
        
    def forward(self):
        # calculate the negative log-likelihood and return their average
        nll = -self.n * ((self.j_t + (self.d_t - self.d)**2)/(2*self.j**2) + torch.log(self.j))
        return torch.mean(nll)
    

In [35]:
model = RouteNet()
model = model.float()

In [36]:
yy = model(x)

In [37]:
# make_dot(yy)

In [40]:
train(model,y)

torch.float32 torch.float32
tensor(-1.0167)
torch.float32 torch.float32
tensor(-1.0205)
torch.float32 torch.float32
tensor(-1.0279)
torch.float32 torch.float32
tensor(-1.0353)
torch.float32 torch.float32
tensor(-1.0440)
torch.float32 torch.float32
tensor(-1.0510)
torch.float32 torch.float32
tensor(-1.0578)
torch.float32 torch.float32
tensor(-1.0649)
torch.float32 torch.float32
tensor(-1.0723)
torch.float32 torch.float32
tensor(-1.0804)


In [13]:
tr_loss

NameError: name 'tr_loss' is not defined

In [62]:
#test
a = torch.tensor([1.,2.,3.])
b = torch.tensor([6.,5.,4.])
c = torch.tensor([7.,8.,9.])
d = torch.tensor([10.,11.,12.])
e = torch.tensor([1,2,3])

In [63]:
-e * (d/(2 * c**2) + (b-a)**2/(2 * c**2) + torch.log(c))

tensor([-2.3031, -4.4714, -6.8324])