<a href="https://colab.research.google.com/github/rizveeredwan/Annoying-Mute-Line/blob/master/GAT_Hamiltonian_thresholding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# https://nn.labml.ai/graphs/gat/index.html
# https://theaisummer.com/gnn-architectures/#:~:text=The%20main%20idea%20behind%20GAT,determine%20each%20node's%20%E2%80%9Cimportance%E2%80%9D.

In [2]:
# Colab-only dependency installs (IPython shell escapes, not plain Python); versions are unpinned.
# labml_helpers provides the `Module` base class used by GraphAttentionLayer.
# NOTE(review): pytorchtools is not imported in any visible cell (EarlyStopping is defined
# locally) — confirm it is still needed.
!pip install labml_helpers
!pip install pytorchtools

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
import torch
from torch import nn
from labml_helpers.module import Module
from math import sqrt
from matplotlib import pyplot as plt
import os
from math import ceil, sqrt
import csv

In [4]:
# Mount Google Drive (Colab-only) so the graph files under /content/gdrive/MyDrive/... can be
# read and the plots/CSV results written back there.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [5]:
class EarlyStopping:
  """Tell the training loop to stop once the loss stops improving.

  `check` returns True after more than `patience` consecutive calls without a strict
  decrease below the best loss seen so far.
  """
  def __init__(self, patience):
    self.loss = None  # best (lowest) loss so far, as a float; None until the first valid loss
    self.cnt = 0  # consecutive checks without a strict improvement
    self.patience = patience

  def check(self, curr_loss):
    """Record `curr_loss` and return True when training should stop.

    Args:
      curr_loss: a number or a one-element tensor/array. It is converted with float(), so
        passing an autograd tensor does not keep its computation graph alive inside this object.

    A NaN loss counts as "no improvement" and is never stored as the best loss.
    """
    import math  # local import: the notebook only imports selected names from math
    curr_loss = float(curr_loss)
    if not math.isnan(curr_loss) and (self.loss is None or curr_loss < self.loss):
      self.loss = curr_loss
      self.cnt = 0
    else:
      # equal, worse, or NaN: all count against the patience budget
      self.cnt += 1
    return self.cnt > self.patience

In [6]:
class GraphAttentionLayer(Module):
  """Single graph-attention (GAT) layer.

  Args:
    in_features: size of each input node feature vector.
    out_features: size of each output node feature vector.
    n_heads: number of attention heads.
    is_concat: if True the head outputs are concatenated (out_features must be divisible by
      n_heads, each head gets out_features // n_heads units); otherwise each head has
      out_features units and the heads are averaged.
    dropout: dropout probability applied to the attention weights.
    leaky_relu_negative_slope: negative slope of the LeakyReLU applied to the raw scores.
  """
  def __init__(self, in_features: int, out_features: int, n_heads: int, is_concat: bool = True, dropout: float = 0.6, leaky_relu_negative_slope: float = 0.2):
    super().__init__()
    self.is_concat = is_concat
    self.n_heads = n_heads
    if is_concat:
      assert out_features % n_heads == 0
      self.n_hidden = out_features // n_heads
    else:
      self.n_hidden = out_features
    self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
    self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)
    self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
    self.softmax = nn.Softmax(dim=1)
    self.dropout = nn.Dropout(dropout)

  def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
    """Apply attention over neighbours.

    Args:
      h: node features of shape (n_nodes, in_features).
      adj_mat: mask broadcastable to (n_nodes, n_nodes, n_heads); an entry equal to 0 means
        node i does not attend to node j.

    Returns:
      (n_nodes, n_heads * n_hidden) when is_concat, else (n_nodes, n_hidden). A node with no
      neighbour at all gets an all-zero attention output instead of NaN.
    """
    n_nodes = h.shape[0]
    g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)
    # Build every (i, j) pair [g_i || g_j] to score all edges at once.
    g_repeat = g.repeat(n_nodes, 1, 1)
    g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
    g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)
    g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
    e = self.activation(self.attn(g_concat)).squeeze(-1)  # (n_nodes, n_nodes, n_heads)
    assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
    assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
    assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
    no_edge = adj_mat == 0
    # Rows (per head) where node i has no neighbour: softmax over an all -inf row is NaN, and
    # einsum's 0 * NaN would then leak NaN into every node of the next layer.
    isolated = no_edge.all(dim=1, keepdim=True)
    e = e.masked_fill(no_edge, float('-inf'))
    e = e.masked_fill(isolated, 0.0)  # finite placeholder scores for isolated rows
    a = self.softmax(e)
    a = a.masked_fill(isolated, 0.0)  # isolated nodes aggregate nothing
    a = self.dropout(a)
    attn_res = torch.einsum('ijh,jhf->ihf', a, g)
    if self.is_concat:
      return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
    else:
      return attn_res.mean(dim=1)

class Decoder(nn.Module):
  """Pointer-style decoder that picks the next unlabelled node to visit.

  A candidate v is scored against a query node u as
      C * tanh(<layer1(V[u]), layer2(V[v])> / sqrt(dh)),
  and the scores are softmax-normalised over the candidates.

  Args:
    C: scale applied after tanh, so every logit lies in [-C, C].
    d: size of the node embeddings in V.
    dh: size of the projected query/key vectors.
  """
  def __init__(self, C, d, dh):
    super().__init__()
    self.C = C
    self.dh = dh
    self.layer1 = nn.Linear(d, dh, bias=False)
    self.layer2 = nn.Linear(d, dh, bias=False)
    self.alpha = nn.Tanh()
    self.softmax = nn.Softmax(dim=0)

  def _score(self, query, candidates):
    """Softmax probabilities of each row of `candidates` (k, d) against `query` (d,)."""
    o1 = self.layer1(query)
    o2 = self.layer2(candidates)
    logits = self.C * self.alpha(o2 @ o1 / sqrt(self.dh))
    # stays on the autograd graph (unlike torch.tensor(list_of_tensors), which copies and detaches)
    return self.softmax(logits)

  def only_neighbour_consideration(self, V, chosen, adj_list, label_map, epoch):
    """Score the unlabelled neighbours of the most recently chosen node that has any.

    Walks `chosen` from newest to oldest. For the first node with at least one neighbour v
    where label_map[v] is None, returns (probability, node) of the best-scoring neighbour.
    The probability is differentiable w.r.t. layer1, layer2 and V. Returns (None, None) if
    no chosen node has an unlabelled neighbour. `epoch` is unused.
    """
    for node in reversed(chosen):
      candidates = [v for v in adj_list[node] if label_map[v] is None]
      if candidates:
        probs = self._score(V[node], V[candidates])
        best = int(torch.argmax(probs))
        return probs[best], candidates[best]
    return None, None

  def forward(self, V, adj_mat, adj_list, chosen, label_map, current_epoch):
    """Return (probability, node) for the next node to label, or (None, None) if all are labelled.

    Neighbours of already-chosen nodes are preferred. If none is available (e.g. the chosen
    nodes' connected component is fully labelled in a disconnected graph), every remaining
    unlabelled node is scored against the last chosen node instead. `adj_mat` and
    `current_epoch` are accepted for interface compatibility but unused.
    """
    prob, best_node = self.only_neighbour_consideration(V=V, chosen=chosen, adj_list=adj_list, label_map=label_map, epoch=current_epoch)
    if best_node is not None:
      return prob, best_node
    remaining = [v for v, label in enumerate(label_map) if label is None]
    if not remaining or not chosen:
      return None, None
    probs = self._score(V[chosen[-1]], V[remaining])
    best = int(torch.argmax(probs))
    return probs[best], remaining[best]
         
   

class Encoder(nn.Module):
  """Learned node embeddings followed by three single-head GAT layers and a node scorer.

  forward(adj_mat) returns (index of the highest-scoring node, final node embeddings,
  that node's softmax score over all nodes).
  """
  def __init__(self, n_nodes: int, in_features: int, out_features: int):
    super().__init__()
    self.n_nodes = n_nodes
    self.embedding = nn.Embedding(n_nodes, in_features)
    # first GAT layer maps in_features -> out_features; the next two keep out_features
    self.encoder1 = GraphAttentionLayer(in_features=in_features, out_features=out_features, n_heads=1, is_concat=True)
    self.encoder2 = GraphAttentionLayer(in_features=out_features, out_features=out_features, n_heads=1, is_concat=True)
    self.encoder3 = GraphAttentionLayer(in_features=out_features, out_features=out_features, n_heads=1, is_concat=True)
    self.last_layer = nn.Linear(out_features, 1, bias=False)
    self.softmax = nn.Softmax(dim=0)

  def forward(self, adj_mat):
    node_ids = torch.arange(self.n_nodes)
    hidden = self.embedding(node_ids)
    for gat_layer in (self.encoder1, self.encoder2, self.encoder3):
      hidden = gat_layer(hidden, adj_mat)
    scores = self.softmax(self.last_layer(hidden))  # (n_nodes, 1), normalised over nodes
    best = torch.argmax(scores, dim=0).squeeze()
    return int(best), hidden, scores[best]
  
 

# Scratch check of nn.Softmax along dim=2 on a random (2, 3, 4) tensor; none of the
# definitions in this notebook read these names (a later scratch cell rebinds m/input/output).
# NOTE(review): `input` shadows the Python builtin for the rest of the kernel session.
m = nn.Softmax(dim=2)
input = torch.randn(2, 3, 4)
#print(input)
output = m(input)
#print(output)

In [7]:
def constraint_violation(label_maps, adj_mat):
  """Count edges whose two endpoints carry the same label.

  Args:
    label_maps: per-node labels indexed by node id. Labels are compared with ``==``, so two
      unlabelled (None) endpoints also count as a violation.
    adj_mat: adjacency tensor of shape (n, n) or (n, n, 1); an edge is an entry equal to 1.
      Only the strict upper triangle is inspected, so each undirected edge is counted once
      (assumes the matrix is symmetric).

  Returns:
    Number of violating edges as a Python int.
  """
  # drop only the trailing head dimension, so a 1-node graph stays 2-D
  temp = adj_mat.squeeze(-1) if adj_mat.dim() == 3 else adj_mat
  # (i, j) with i < j for every edge; iterating edges avoids an O(n^2) per-element tensor loop
  rows, cols = torch.nonzero(torch.triu(temp, diagonal=1) == 1, as_tuple=True)
  return int(sum(label_maps[i] == label_maps[j] for i, j in zip(rows.tolist(), cols.tolist())))

In [8]:
def calculate_reward(adj_mat, node_idx, prob, threshold, label_maps, initiate=False):
  """Assign a 0/1 label to `node_idx` and return the reward that labelling earns.

  The node is labelled 1 when `prob >= threshold` or when `initiate` is True, otherwise 0.
  With x the new label, the reward is
      -deg * x * x + sum over j of 2 * x * label_maps[j]
  where deg is the number of entries equal to 1 in row `node_idx`, and j ranges over the
  other nodes that are adjacent in both directions and already labelled.

  `label_maps` is modified in place and also returned, as `(reward, label_maps)`.
  """
  assert(adj_mat.shape[2] == 1) # n_heads == 1
  size = adj_mat.shape[0]
  out_edge = adj_mat[node_idx, :size, 0] == 1
  in_edge = adj_mat[:size, node_idx, 0] == 1
  degree = int(out_edge.sum())

  new_label = 1 if (prob >= threshold or initiate == True) else 0
  label_maps[node_idx] = new_label
  reward = new_label * (-degree) * new_label

  mutual = (out_edge & in_edge).tolist()
  # accumulate over nodes above node_idx first, then nodes below it, each in ascending order
  for j in [*range(node_idx + 1, size), *range(node_idx)]:
    if mutual[j] and label_maps[j] is not None:
      reward = reward + new_label * 2 * label_maps[j]
  return reward, label_maps

def check_grad(model):
  """Debug helper: print the parameter count, then each parameter's shape, requires_grad flag and .grad."""
  all_params = list(model.parameters())
  print("length ", len(all_params))
  for param in all_params:
    print(param.shape, param.requires_grad)
    print(param.grad)

def run(adj_mat, adj_list, epoch, threshold, d0=2, d1=5, d2=7, fig_name="50_499.jpg", data_file_name="data.csv", patience=20):
  """Train the GAT encoder and pointer decoder on one graph and report the best labelling.

  Each epoch the encoder picks a start node (always labelled 1) and the decoder labels the
  remaining nodes one at a time. After every step the running reward from
  `calculate_reward` multiplies that step's selection probability; the sum of these terms is
  the loss back-propagated through both networks.

  Args:
    adj_mat: adjacency tensor of shape (n, n, 1) (calculate_reward asserts the last dim is 1).
    adj_list: adj_list[i] is the list of neighbours of node i.
    epoch: maximum number of training epochs.
    threshold: selection probability at or above which a node is labelled 1.
    d0, d1, d2: NOTE(review): ignored — sizes are always recomputed from n as
      d0 = max(ceil(sqrt(n)), 2), d1 = max(ceil(d0/2), 2), d2 = max(2, ceil(d1/2));
      d2 is only printed.
    fig_name: where the loss plot is saved, only when this call lowers the global best
      violation count.
    data_file_name: CSV of the (up to) 30 lowest-violation (loss, violation) rows, written
      under the same condition.
    patience: EarlyStopping patience, checked on the epoch loss.

  Side effects: reads and updates the global MIN_VALIDATION_VIOLATION_OBSERVED, prints
  per-epoch progress and shows two plots.
  """
  global MIN_VALIDATION_VIOLATION_OBSERVED
  n = adj_mat.shape[0]
  # Same rule for every graph size (an n ** (1/3) variant for smaller graphs was tried earlier).
  d0 = max(ceil(sqrt(n * 1.0)), 2)
  d1 = max(ceil(d0 / 2.0), 2)
  d2 = max(2, ceil(d1 / 2.0))
  print(f"do = {d0} d1={d1} d2={d2}")
  encoder = Encoder(n_nodes=n, in_features=d0, out_features=d1)
  decoder = Decoder(C=10, d=d1, dh=ceil(d1 / 2.0))

  encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.001)
  decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=0.0001)
  es = EarlyStopping(patience=patience)

  save_loss = []
  save_reward = []
  save_result = []

  for i in range(epoch):
    print("epoch ", i)
    first_node, encoded, prob = encoder(adj_mat)
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    label_maps = [None] * n
    chosen = [first_node]
    reward, label_maps = calculate_reward(adj_mat, first_node, prob, threshold, label_maps, True)
    loss = reward * prob
    for _ in range(1, n):
      decoder_output, best_node = decoder(encoded, adj_mat, adj_list, chosen, label_maps, i)
      if best_node is None:
        # decoder found no node to extend to; end this episode rather than crash in calculate_reward
        break
      chosen.append(best_node)
      temp_reward, label_maps = calculate_reward(adj_mat, best_node, decoder_output, threshold, label_maps, False)
      reward = reward + temp_reward
      loss = loss + reward * decoder_output
    loss.backward()
    loss_value = loss.detach().numpy()
    save_loss.append(loss_value)
    save_reward.append(reward)
    encoder_optimizer.step()
    decoder_optimizer.step()
    violation = constraint_violation(label_maps, adj_mat)
    print(f"redward = {reward} loss = {loss} violation = {violation}")
    print(label_maps)
    save_result.append((loss_value, violation))
    # pass a plain float so the early-stopping state does not keep this epoch's autograd graph alive
    if es.check(curr_loss=float(loss)) is True:
      print("Early stopping")
      break

  if not save_result:
    print("no epochs were run; nothing to report")
    return

  save_result.sort(key=lambda x: x[1])  # lowest violation count first
  best_violation = save_result[0][1]
  print(f"MIN_VALIDATION_VIOLATION_OBSERVED = {MIN_VALIDATION_VIOLATION_OBSERVED}")
  improved = MIN_VALIDATION_VIOLATION_OBSERVED is None or MIN_VALIDATION_VIOLATION_OBSERVED > best_violation
  # only an improving run overwrites the saved figure and CSV
  function_plot(data=save_loss, title='loss variation', ylabel='loss', xlabel='epoch', loc="upper left", fig_name=fig_name if improved else None)
  function_plot(data=save_reward, title='reward variation', ylabel='reward', xlabel='epoch', loc="upper left", fig_name=None)
  if improved:
    write_into_csv(file_name=data_file_name, save_result=save_result, trim=30)
    MIN_VALIDATION_VIOLATION_OBSERVED = best_violation
  print(save_result[:30])

In [9]:
def graph_read(file_name=os.path.join('/content', "gdrive", 'MyDrive', 'Research', 'MADRL', 'Graph Dataset', "50_499.txt")):
  """Read an undirected graph from an edge-list text file.

  File format: a first line "n e" (node and edge counts), then e lines "a b" with 1-based
  node ids. Tokens may be separated by any whitespace.

  Returns:
    (adj_mat, adj_list): adj_mat is an n x n list of 0/1 ints, set symmetrically for each edge;
    adj_list[i] lists the 0-based neighbours of node i in file order (an edge listed twice
    appears twice).

  Raises:
    FileNotFoundError: if file_name does not exist.
  """
  with open(file_name, 'r', encoding='utf-8') as f:
    n, e = (int(tok) for tok in f.readline().split()[:2])
    adj_mat = [[0] * n for _ in range(n)]
    adj_list = [[] for _ in range(n)]
    for _ in range(e):
      a, b = (int(tok) - 1 for tok in f.readline().split()[:2])  # file ids are 1-based
      adj_mat[a][b] = 1
      adj_mat[b][a] = 1
      adj_list[a].append(b)
      adj_list[b].append(a)
  return adj_mat, adj_list
# graph_read(file_name=os.path.join('/content', "gdrive", 'MyDrive', 'Research', 'MADRL', 'Graph Dataset', '50_499.txt'))

In [10]:
def function_plot(data, title='loss variation', ylabel='loss', xlabel='epoch', loc="upper left", fig_name=None):
  """Line-plot `data` on a fresh figure, optionally save it to `fig_name`, then show it.

  The figure is saved BEFORE plt.show(): in notebook backends show() closes the figure,
  so saving afterwards writes a blank image. `loc` is accepted for compatibility but unused
  (the plot has no legend).
  """
  fig, ax = plt.subplots()
  ax.plot(data)
  ax.set_title(title)
  ax.set_ylabel(ylabel)
  ax.set_xlabel(xlabel)
  if fig_name is not None:
    fig.savefig(fig_name)
  plt.show()
  plt.close(fig)  # avoid accumulating open figures across repeated calls

In [11]:
def write_into_csv(file_name='data.csv', save_result=None, trim=30):
  """Overwrite `file_name` with up to `trim` CSV rows taken from `save_result`.

  Each row holds the first two items of an entry, e.g. (loss, violation).
  save_result defaults to no rows, which writes an empty file.
  """
  rows = [] if save_result is None else save_result
  # newline='' is required by the csv module so its '\r\n' terminators are not translated
  with open(file_name, 'w', encoding='utf-8', newline='') as f:
    csv_writer = csv.writer(f)
    for i in range(min(trim, len(rows))):
      csv_writer.writerow([rows[i][0], rows[i][1]])

In [12]:
# Lowest violation count seen across calls to run(); run() reads and updates it via `global`
# and only saves its figure/CSV when a call beats it. Re-running this cell resets it.
MIN_VALIDATION_VIOLATION_OBSERVED = None 
# Base name (without ".txt") of the graph file in the dataset folder; also used to name outputs.
data_file = '500_1499'

In [None]:
# Load the graph, add the trailing head dimension the GAT layers expect, and train.
dataset_dir = os.path.join('/content', "gdrive", 'MyDrive', 'Research', 'MADRL', 'Graph Dataset')
adj_matrix, adj_list = graph_read(file_name=os.path.join(dataset_dir, data_file + '.txt'))
adj_matrix = torch.tensor(adj_matrix).unsqueeze(-1)  # (n, n) -> (n, n, 1)
temp = adj_matrix.squeeze()
print(temp)
print(temp[0], temp[0].shape)

# NOTE(review): run() recomputes d0/d1/d2 from the node count, so the values passed here have no effect.
run(adj_mat=adj_matrix, adj_list=adj_list, epoch=300, threshold=0.5, d0=4, d1=2, d2=2,
    fig_name=os.path.join(dataset_dir, 'GAT_HAM_' + data_file + '.jpg'),
    data_file_name=os.path.join(dataset_dir, 'GAT_HAM_' + data_file + '.csv'), patience=30)


tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [None]:

# Scratch check of nn.Tanh on a random length-2 tensor (prints the input and the output).
# NOTE(review): rebinds the module-level names m/input/output; `input` shadows the builtin.
m = nn.Tanh()
input = torch.randn(2)
print(input)
output = m(input)
print(output)

In [None]:
# Scratch check: how autograd routes gradients through Python's built-in max() over tensors.
# max() returns whichever product compares larger, so only that branch's leaf (A or B) gets a
# gradient (the other's .grad stays None), and D's gradient is the chosen leaf's value.
A = torch.rand(1)
print(A)
A.requires_grad_()
print(A.grad)

B = torch.rand(1)
print(B)
B.requires_grad_()
print(B.grad)

D = torch.rand(1)
print(D)
D.requires_grad_()
print(D.grad)

x = max(A*D, B*D)
print(x)
# x is a non-leaf result that already requires grad, so no flag needs setting before backward()
x.backward()

print("A ", A.grad)
print("B ", B.grad)
print("D ",D.grad)
