In [1]:
import torch
# from model import RouteNet
from CustomLoss import CustomLoss
from SimpleNetworkDataset import NetworkDataset,get_dataloader

In [2]:
fpath = "/home/steve/Documents/GitHub/CSE222a/CSE222RouteNet/data/pt_dir"

In [3]:
data_set = NetworkDataset(fpath)

In [4]:
dloader = get_dataloader(data_set,1)

In [5]:
def train(model,x,label,n_packets):
    optimizer = torch.optim.Adam(model.parameters(), .001) # optimizer method for gradient descent
#     criterion = torch.nn.MSELoss() 
    criterion = CustomLoss()
    model.train() #put model in training mode
    for epoch in range(100):
        tr_loss = []
        optimizer.zero_grad()  
        outputs = model(x)
        #the input of the parameters should be re-defined here
        loss = criterion(outputs, label,n_packets)
        loss.backward()                        
        optimizer.step()                  
        tr_loss.append(loss.item())
        print(tr_loss)

In [6]:
for _,batch in enumerate(dloader):
    x,y = batch
    break
    

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

'''using the implementation from RouteNet'''
class RouteNet(nn.Module):

    def __init__(self):
        super(RouteNet,self).__init__()

        ### Architecture ###
        # for gru need to pay attention to if input is of size:
        # (batch, seq_len, feature size) or seq_len, batch, feature size
        # if sequence length is variable
        # may need to pad the sequence
        
        self.link_state_dim = 32
        self.path_state_dim = 32
        self.readout_dim = 8
        self.output_units = 2
        self.T = 8

        inSize = 10 # place holder
        hSize  = 32
        readSize = 10
        nLayer = 1

        self.l_U = nn.GRU(input_size = self.link_state_dim,
                          hidden_size = hSize,
                          num_layers = nLayer,
                          batch_first=True)
        
        self.p_U = nn.GRU(input_size = self.path_state_dim,
                          hidden_size = hSize,
                          num_layers = nLayer,
                          batch_first=True)
        
        self.readOut = nn.ModuleDict({  'r1': nn.Linear(hSize,self.readout_dim),
                                        'r2': nn.Linear(self.readout_dim,self.readout_dim),
                                        'r3': nn.Linear(self.readout_dim,self.output_units)
                                        })
        

    def forward(self,x):
        
        links = torch.unsqueeze(torch.tensor(x['links']),1)                         
        paths = torch.unsqueeze(torch.tensor(x['paths']),1)                 
        seqs = torch.unsqueeze(torch.tensor(x['sequences']),1)
        link_cap = torch.unsqueeze(torch.tensor(x['link_capacity']).float(),axis=1)
        link_cap = link_cap/torch.max(link_cap)
        bandwidth = torch.unsqueeze(torch.tensor(x['bandwith']).float(), axis=1)
        bandwidth = bandwidth/torch.max(bandwidth)

        # state matrix shape for the link
        link_h_state_shape = (x['n_links'][0], self.link_state_dim-1)

        # create hidden state matrix shape for the path  
        path_h_state_shape = (x['n_paths'],self.path_state_dim-1)
        path_h_state = torch.cat((bandwidth,torch.zeros(path_h_state_shape)), axis=1)
        
        # prepare input for path update RNN
        max_seq_len = torch.max(seqs)
        path_rnn_input_shape = (x['n_paths'],max_seq_len+1,self.link_state_dim)
        
        #stack the paths and sequences
        ids = torch.stack((paths,seqs),axis=1)
        ids = torch.squeeze(ids,2)           
        p_ind = ids[:,0]
        s_ind = ids[:,1]

        # flatten the double loop into a bulk matrix using the gather functionality
        # this is an aggregation of the state vector of each link on each path flattened into a (sum_paths(sum_links_on_paths) x h_state) size matrix 
        # using torch.gather
        indices = torch.zeros(len(links),32)
        for i in range(len(links)):
            link_id = links[i]
            indices[i,:] = link_id
            
        # variable dictionary for forward pass
        vd = {}
        
        for t in range(self.T):
            
            
            ############# set up the matrices and variables for each pass through #################
            
            ########## PATH VARIABLES ###########
            
            # input to the path rnn layer P_u
            path_rnn_input_key = 'path_rnn_input_' + str(t)
            vd[path_rnn_input_key] = torch.zeros(path_rnn_input_shape)
            
            if (t > 0):  # for non leaf variables, we need to propagate the gradient back
                vd[path_rnn_input_key].requires_grad = True
            
            
            # path hidden state output from P_U, initialized with just bandwidth at T_0, else copy
            path_h_state_key = 'path_h_state_' + str(t)
            if (t==0):
                vd[path_h_state_key] = torch.cat((bandwidth,torch.zeros(path_h_state_shape)), axis=1)
            else:
                path_h_state_key = 'path_h_state_' + str(t)
                path_h_prev = 'path_h_state_' + str(t-1)
                vd[path_h_state_key] = vd[path_h_prev]
            
            
            # path_hidden state sequence from P_U, used to update links
            path_h_state_seq_key = 'path_h_states_seq_' + str(t)
            
            
            ########## LINK VARIABLES ###########
            
            # vector to store the link_hidden states
            if (t == 0):
                # create hidden state matrix for links and initialize with first column as link capacity
                link_h_state_key = 'link_h_state_' + str(t)
                vd[link_h_state_key] = torch.cat((link_cap,torch.zeros(link_h_state_shape)),1)
            else:
                # copy hidden state value for next pass through
                link_h_state_key = 'link_h_state_' + str(t)
                linK_h_prev = 'link_h_state_' + str(t-1)
                vd[link_h_state_key] = vd[linK_h_prev]
                
                
            # matrix storing the hidden states of links on paths
            # i.e. the hidden state of all links in the x['links'] list
            h_link_path_key = 'h_links_on_paths_' + str(t)
            vd[h_link_path_key] = torch.gather(vd[link_h_state_key],0,indices.long())
            
            #link messages extracted from the path hidden state sequence output from P_U
            link_message_key = 'link_messages_' + str(t)

            
             # container for the link messages that are extracted from path rnn hidden states
            agg_link_message_key = 'aggregated_link_message_' + str(t)
            vd[agg_link_message_key] = torch.zeros((x['n_links'],self.link_state_dim),requires_grad=True)
            
            ########################################################################################
            
            
            ################################## DO THE MESSAGE PASSING ##############################
            
            # prepare input for path RNN
            vd[path_rnn_input_key] = vd[path_rnn_input_key].index_put(indices = [p_ind,s_ind],
                                                                      values = vd[h_link_path_key])
            # pass through the path RNN
            vd[path_h_state_seq_key], vd[path_h_state_key] = self.p_U(vd[path_rnn_input_key],
                                                                      torch.unsqueeze(vd[path_h_state_key],
                                                                      0))
            # reformat
            vd[path_h_state_key] = vd[path_h_state_key].squeeze(0)
            
            # extract link messages from the path RNN sequence output
            # equivalent to tf.gather_nd
            vd[link_message_key] = vd[path_h_state_seq_key][p_ind,s_ind,:]
           
            # aggregate the link messages
            vd[agg_link_message_key] = vd[agg_link_message_key].index_put([links.squeeze(1)],
                                                                           vd[link_message_key],
                                                                           accumulate=True)
            # update the state of the links by passing through link 
            _, vd[link_h_state_key] = self.l_U(torch.unsqueeze(vd[agg_link_message_key],1),
                                               torch.unsqueeze(vd[link_h_state_key].squeeze(0),0))
            # reformat
            vd[link_h_state_key] = vd[link_h_state_key].squeeze(0)
            
            ##########################################################################################

        # readout from the paths
        y = self.readout(vd[path_h_state_key])
        return y


    def readout(self,path_state):
        x = F.relu(self.readOut['r1'](path_state))
        x = F.relu(self.readOut['r2'](x))
        x = F.relu(self.readOut['r3'](x))
        return x

In [15]:
model = RouteNet()

In [16]:
yy = model(x)

In [17]:
train(model,x,y,x['packets'])

[0.33971108523744964]
[0.3312999169139103]
[0.3265874610469636]
[0.31761056215542144]
[0.30963469058262727]
[0.30233142516007255]
[0.29482077064354867]
[0.2873400968529528]
[0.27969718235955165]
[0.27179114023138723]
[0.26354703910852173]
[0.25518969327941815]
[0.24649557113453788]
[0.23749583504457294]
[0.22837938069375313]
[0.21884608829800203]
[0.20916900597721508]
[0.19906237337866403]
[0.1887313453651304]
[0.17804305476020957]
[0.16671456485902153]
[0.15485425415776888]
[0.14255537575886493]
[0.12978897022437746]
[0.11658737021963297]
[0.10300481436402038]
[0.08880021982741604]
[0.07399925971542795]
[0.05860083535740999]
[0.04257275083563194]
[0.02590381067742429]
[0.008570224912309106]
[-0.00938918415099805]
[-0.027893069635941948]
[-0.04693575890974589]
[-0.06655175796383157]
[-0.08696493332142279]
[-0.10791151442374303]
[-0.1297254784671124]
[-0.15200527218092313]
[-0.17483728683407637]
[-0.1977246665152961]
[-0.22074024664931796]
[-0.2436755476391545]
[-0.2661677564562294]
[-0

[tensor([[0.0005]], dtype=torch.float64),
 tensor([[0.0326]], dtype=torch.float64),
 tensor([[0.0118]], dtype=torch.float64),
 tensor([[0.0373]], dtype=torch.float64),
 tensor([[0.0552]], dtype=torch.float64),
 tensor([[0.0030]], dtype=torch.float64),
 tensor([[0.0416]], dtype=torch.float64),
 tensor([[0.0061]], dtype=torch.float64),
 tensor([[0.0016]], dtype=torch.float64),
 tensor([[0.0064]], dtype=torch.float64),
 tensor([[0.1350]], dtype=torch.float64),
 tensor([[0.0035]], dtype=torch.float64),
 tensor([[0.0162]], dtype=torch.float64),
 tensor([[0.1405]], dtype=torch.float64),
 tensor([[0.0714]], dtype=torch.float64),
 tensor([[0.0126]], dtype=torch.float64),
 tensor([[0.0081]], dtype=torch.float64),
 tensor([[0.0071]], dtype=torch.float64),
 tensor([[0.0062]], dtype=torch.float64),
 tensor([[0.0110]], dtype=torch.float64),
 tensor([[0.0063]], dtype=torch.float64),
 tensor([[0.0087]], dtype=torch.float64),
 tensor([[0.0093]], dtype=torch.float64),
 tensor([[0.0004]], dtype=torch.fl