In [None]:
import torch
import torch.nn as nn
import dgl
import dgl.function as fn
from dgl.nn.pytorch import edge_softmax
import numpy as np
import torch.nn.functional as F
from torch.nn import Linear
from torch.nn.parameter import Parameter
from tqdm import tqdm
import time
import warnings

# Silence every warning category for the rest of the session.
warnings.filterwarnings("ignore")

# Constants of the stretched-and-clipped sigmoid gate used by l0_train / l0_test.
beta = 0.66   # divisor applied to the gate logits before the sigmoid
gamma = -0.1  # lower end of the interval the sigmoid output is stretched to
zeta = 1.1    # upper end of the interval the sigmoid output is stretched to
eps = 1e-20   # tiny offset that keeps the logarithms away from log(0)

# Shift subtracted from the gate logits inside get_loss2.
const1 = beta * np.log(-gamma / zeta + eps)

# Activation modules shared by the helpers below.
sig = nn.Sigmoid()
hardtanh = nn.Hardtanh(0, 1)


In [None]:
def l0_train(logAlpha, min, max):
	"""Sample a stochastic gate for every entry of ``logAlpha`` (training mode).

	Uniform noise is turned into logistic noise, added to ``logAlpha``,
	divided by ``beta`` and squashed by a sigmoid. The result is stretched
	to the interval (gamma, zeta) and clipped to [min, max].

	Args:
		logAlpha: tensor of gate logits.
		min: lower clipping bound.
		max: upper clipping bound.

	Returns:
		Tensor with the same shape as ``logAlpha`` holding the gate values.
	"""
	# The eps offset keeps the uniform draw strictly positive before the log.
	noise = torch.rand(logAlpha.size()).type_as(logAlpha) + eps
	logistic_noise = torch.log(noise / (1 - noise))
	gate = sig((logistic_noise + logAlpha) / beta)
	stretched = gate * (zeta - gamma) + gamma
	return F.hardtanh(stretched, min, max)

def l0_test(logAlpha, min, max):
	"""Compute the deterministic gate for every entry of ``logAlpha`` (eval mode).

	Same transformation as ``l0_train`` but without the random noise term:
	sigmoid(logAlpha / beta), stretched to (gamma, zeta), clipped to [min, max].

	Args:
		logAlpha: tensor of gate logits.
		min: lower clipping bound.
		max: upper clipping bound.

	Returns:
		Tensor with the same shape as ``logAlpha`` holding the gate values.
	"""
	gate = sig(logAlpha / beta)
	stretched = gate * (zeta - gamma) + gamma
	return F.hardtanh(stretched, min, max)

def get_loss2(logAlpha):
	"""
	Per-gate penalty: sigmoid(logAlpha - const1).

	This is the second right-hand-side term of the loss in Eq 12; the
	caller is responsible for summing it and scaling it by lambda.
	"""
	shifted_logits = logAlpha - const1
	return sig(shifted_logits)

In [None]:
class SparseGraphAttentionLayer(nn.Module):
	"""Graph attention layer with optional L0-gated edge attention, run over a list of DGL graphs.

	The layer keeps a reference to ``graphs`` and, in ``forward``, pairs
	``inputs[i]`` with ``graphs[i]`` by position. A single attention value per
	edge is shared by all heads (only head 0 of the scores is kept).

	Args:
		graphs: sequence of DGL graphs this layer operates on.
		in_dim: size of the last dimension of every input feature tensor.
		out_dim: output feature size per head.
		num_heads: number of heads the linear projection is split into.
		feat_drop: dropout probability for the input features; falsy disables it.
		attn_drop: dropout probability for the edge attention; falsy disables it.
		alpha: negative slope of the LeakyReLU used when ``l0 == 0``.
		bias_l0: initial value of the learnable bias added to the gate logits
			when ``l0 != 0``.
		residual: if true, add a residual connection to the aggregated features.
		l0: 0 selects LeakyReLU scores followed by an edge softmax; 1 selects
			L0 gates that are normalised by their per-destination sum.
		min: stored as ``self.min``; not read anywhere in this class.
	"""

	def __init__(self,
				 graphs,
				 in_dim,
				 out_dim,
				 num_heads,
				 feat_drop,
				 attn_drop,
				 alpha,
				 bias_l0,
				 residual=False,l0=0, min=0):
		super(SparseGraphAttentionLayer, self).__init__()
		self.graphs = graphs
		self.num_graphs = len(self.graphs)
		self.num_heads = num_heads
		self.fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
		# nn.Identity (rather than a lambda) keeps the module picklable.
		self.feat_drop = nn.Dropout(feat_drop) if feat_drop else nn.Identity()  # random feature dropout
		# random attention (i.e. edge) dropout, in addition to the learned sparsity
		self.attn_drop = nn.Dropout(attn_drop) if attn_drop else nn.Identity()
		# One attention vector per side, broadcast over nodes and heads.
		self.attn_l = nn.Parameter(torch.Tensor(size=(1, 1, out_dim)))
		self.attn_r = nn.Parameter(torch.Tensor(size=(1, 1, out_dim)))
		self.bias_l0 = nn.Parameter(torch.FloatTensor([bias_l0]))

		nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
		nn.init.xavier_normal_(self.attn_l.data, gain=1.414)
		nn.init.xavier_normal_(self.attn_r.data, gain=1.414)
		self.leaky_relu = nn.LeakyReLU(alpha)
		self.softmax = edge_softmax
		self.residual = residual
		self.num = 0    # number of edges with attention > 0 (last processed graph)
		self.l0 = l0
		self.loss = 0   # summed gate penalty (last processed graph)
		self.dis = []
		self.min = min
		if residual:
			if in_dim != out_dim:
				self.res_fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
				nn.init.xavier_normal_(self.res_fc.weight.data, gain=1.414)
			else:
				self.res_fc = None

	def forward(self, inputs, edges="__ALL__", skip=0):
		"""Run the layer on ``inputs[i]`` / ``self.graphs[i]`` for every i.

		Args:
			inputs: sequence of node-feature tensors, one per graph.
			edges: edge selection passed to ``apply_edges`` when attention is
				computed (``skip == 0``). The same selection is used for every graph.
			skip: 0 computes fresh attention; any other value reuses the
				``'a'`` edge data already stored on each graph.

		Returns:
			``(rets, edges)`` where ``rets`` is the list of per-graph outputs
			(ELU applied, heads flattened). ``edges`` holds the indices of the
			non-zero attention edges of the last graph when ``skip == 0``;
			otherwise it is the ``edges`` argument unchanged.
		"""
		self.loss = 0
		rets = []
		# Kept separate from `edges` so one graph's edge ids are never used to
		# select edges of the next graph.
		nonzero_edges = edges
		for g_i in tqdm(range(len(inputs))):
			graph = self.graphs[g_i]
			h = self.feat_drop(inputs[g_i])  # NxD
			ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
			a1 = (ft * self.attn_l).sum(dim=-1).unsqueeze(-1)  # N x H x 1
			a2 = (ft * self.attn_r).sum(dim=-1).unsqueeze(-1)  # N x H x 1
			graph.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2})

			if skip == 0:
				# 1. compute edge attention
				graph.apply_edges(self.edge_attention, edges)

				# 2. normalise the attention per destination node
				if self.l0 == 1:
					# force the attention of every self-loop (i -> i) to 1
					ind = graph.nodes()
					graph.apply_edges(self.loop, edges=(ind, ind))

				self.edge_softmax(g_i)

				if self.l0 == 1:
					graph.apply_edges(self.norm)

				nonzero_edges = graph.edata['a'].squeeze().nonzero().squeeze()

			# 3. aggregate the neighbour features, weighted by the (dropped-out) attention
			graph.edata['a_drop'] = self.attn_drop(graph.edata['a'])
			# NOTE(review): overwritten on every iteration, so only the last graph is counted.
			self.num = (graph.edata['a'] > 0).sum()
			# u_mul_e is the current name of the builtin formerly exposed as src_mul_edge.
			graph.update_all(fn.u_mul_e('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))
			ret = graph.ndata['ft']

			# 4. residual
			if self.residual:
				if self.res_fc is not None:
					resval = self.res_fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
				else:
					resval = torch.unsqueeze(h, 1)  # Nx1xD'
				ret = resval + ret
			rets.append(F.elu(ret.flatten(1)))
		return rets, nonzero_edges

	def edge_attention(self, edges):
		"""Edge UDF: compute the unnormalised attention ``'a'`` from src ``a1`` and dst ``a2``."""
		if self.l0 == 0:
			m = self.leaky_relu(edges.src['a1'] + edges.dst['a2'])
		else:
			logits = edges.src['a1'] + edges.dst['a2'] + self.bias_l0

			if self.training:
				m = l0_train(logits, 0, 1)
			else:
				m = l0_test(logits, 0, 1)
			# NOTE(review): assigned, not accumulated, so after forward() over
			# several graphs this holds the penalty of the last graph only.
			self.loss = get_loss2(logits[:,0,:]).sum()
		return {'a': m}

	def norm(self, edges):
		"""Edge UDF: divide each attention value by its destination node's normaliser ``z``."""
		a = edges.data['a'] / edges.dst['z']
		return {'a' : a}

	def loop(self, edges):
		"""Edge UDF: replace the attention of the selected edges by 1 (x ** 0)."""
		return {'a': torch.pow(edges.data['a'], 0)}

	def normalize(self, logits, g_i):
		"""Sum ``logits`` over the incoming edges of every node of graph ``g_i``.

		Returns:
			``(logits, normalizer)``: the edge values that were passed in and the
			per-node sum of the values on its incoming edges.
		"""
		self._logits_name = "_logits"
		self._normalizer_name = "_norm"

		graph = self.graphs[g_i]
		graph.edata[self._logits_name] = logits
		# copy_e is the current name of the builtin formerly exposed as copy_edge.
		graph.update_all(fn.copy_e(self._logits_name, self._logits_name),
						 fn.sum(self._logits_name, self._normalizer_name))
		return graph.edata.pop(self._logits_name), graph.ndata.pop(self._normalizer_name)

	def edge_softmax(self, g_i):
		"""Normalise the raw attention ``'a'`` of graph ``g_i`` and store head 0 back as ``'a'``.

		With ``l0 == 0`` a regular edge softmax is applied. Otherwise the gate
		values are kept as they are and their per-destination sum is stored in
		``ndata['z']`` for the later division in ``norm``.
		"""
		graph = self.graphs[g_i]
		if self.l0 == 0:
			scores = self.softmax(graph, graph.edata.pop('a'))
		else:
			scores, normalizer = self.normalize(graph.edata.pop('a'), g_i)
			graph.ndata['z'] = normalizer[:,0,:].unsqueeze(1)

		# All heads share the attention of head 0.
		graph.edata['a'] = scores[:,0,:].unsqueeze(1)


In [None]:
class SGATDense(Linear):
	"""Linear read-out applied independently to every tensor in a list.

	Each input is transposed before the linear map, so ``inputs[i].shape[0]``
	must equal ``in_features`` and the result for one input must be
	broadcastable to a single row of ``out_features`` values.

	Args:
		in_features: size of each (transposed) input sample.
		out_features: number of output values per input.
		bias: if false, the layer has no additive bias.
		device: device the parameters are created on.
		dtype: dtype the parameters are created with.
	"""

	def __init__(self, in_features, out_features, bias= True, device=None, dtype=None):
		factory_kwargs = {'device': device, 'dtype': dtype}
		super(SGATDense, self).__init__(in_features, out_features, bias=bias)

		# Re-create the parameters so that the requested device/dtype are honoured.
		self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
		if bias:
			self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
		self.reset_parameters()

	def forward(self, inputs):
		"""Return a ``(len(inputs), out_features)`` tensor with one row of logits per input."""
		# Sized by out_features (not a fixed class count) and allocated with the
		# parameters' dtype/device so the layer also works off-CPU.
		logits = torch.empty((len(inputs), self.out_features),
							 dtype=self.weight.dtype, device=self.weight.device)
		for row, inp in enumerate(inputs):
			logits[row, :] = F.linear(torch.transpose(inp, 0, 1), self.weight, self.bias)
		return logits

In [None]:
def process_dataset(dataset):
	"""Split an indexable dataset of ``(graph, label)`` pairs into parallel lists.

	Returns:
		``(graphs, labels, inputs)`` where ``labels[i]`` is
		``int(label['label'])`` and ``inputs[i]`` is ``graph.ndata["feat"]``
		of the i-th sample.
	"""
	samples = [dataset[position] for position in range(len(dataset))]
	graphs = [graph for graph, _ in samples]
	labels = [int(label['label']) for _, label in samples]
	inputs = [graph.ndata["feat"] for graph in graphs]
	return graphs, labels, inputs

def get_batch(dataset, batchSz):
	"""Draw ``batchSz`` distinct samples at random from ``dataset``.

	Indices are drawn with ``np.random.choice(..., replace=False)``.

	Returns:
		``(graphs, labels, inputs)`` for the sampled items, in sampling order;
		``labels[i]`` is ``int(label['label'])`` and ``inputs[i]`` is
		``graph.ndata["feat"]``.
	"""
	picked = np.random.choice(len(dataset), batchSz, replace=False)
	samples = [dataset[position] for position in picked]
	graphs = [graph for graph, _ in samples]
	labels = [int(label['label']) for _, label in samples]
	inputs = [graph.ndata["feat"] for graph in graphs]
	return graphs, labels, inputs

In [None]:
class SGAT(nn.Module):
	"""Stack of SparseGraphAttentionLayer modules followed by an SGATDense read-out.

	The first layer computes the edge attention; every later layer is called
	with ``skip=1`` and therefore reuses the attention stored on the graphs.
	"""

	def __init__(self, graphs, num_layers, in_dim, num_hidden, num_classes,
				 heads, alpha, feat_drop, attn_drop, bias_l0, residual, l0=0):
		super(SGAT, self).__init__()
		self.graphs = graphs
		self.num_layers = num_layers
		self.sgat_layers = nn.ModuleList()

		def build_layer(layer_in_dim, layer_heads, use_residual):
			# All layers share every setting except input size, head count and residual.
			return SparseGraphAttentionLayer(
				graphs, layer_in_dim, num_hidden, layer_heads, feat_drop, attn_drop,
				alpha, bias_l0, use_residual, l0=l0, min=0)

		# Input projection: never residual.
		self.sgat_layers.append(build_layer(in_dim, heads[0], False))
		# Hidden layers: the previous layer emits num_hidden features per head.
		for depth in range(1, num_layers):
			self.sgat_layers.append(
				build_layer(num_hidden * heads[depth - 1], heads[depth], residual))

		self.denseFinal = SGATDense(num_hidden, num_classes)

		print("Checking added layers", self.sgat_layers)

	def forward(self, inputs):
		"""Return the logits produced by the read-out for the list ``inputs``."""
		first_layer = self.sgat_layers[0]
		hs, edges = first_layer(inputs, "__ALL__")
		depth = 1
		while depth < self.num_layers:
			hs, _ = self.sgat_layers[depth](hs, edges, skip=1)
			depth += 1
		return self.denseFinal(hs)

In [None]:
device = torch.device("cpu")
print("Loading in custom dataset")
# Once this is run, next time you can just import SNAREseq and do dataset=SNAREseq()
dataset = dgl.data.CSVDataset('SNAREseq')
graphs, labels, inputs = process_dataset(dataset)

# TODO: train/validation/test split (currently disabled)
# num_train=int(len(dataset)*80/100)	
# num_val=int(len(dataset)*10/100)
# num_test=int(len(dataset)*10/100)

# train_idx=np.random.choice(len(dataset), num_train, replace=False)
# val_idx=np.random.choice([i for i in range(len(dataset)) if i not in train_idx], num_val, replace=False)
# test_idx=[i for i in range(len(dataset)) if i not in train_idx and i not in val_idx]
# train_dataset=[dataset[i] for i in train_idx]


# Hyperparameters (need to calibrate these)
num_heads = 1
num_layers = 1
num_out_heads = 1
heads = ([num_heads] * num_layers) + [num_out_heads]
num_hidden = 2000
alpha = 0.2        # LeakyReLU negative slope
bias_l0 = 0
l0 = 0             # 0: softmax attention, 1: L0-gated attention
residual = False
feat_drop = None   # falsy values disable dropout
attn_drop = None

# Values derived from the data
num_classes = max(labels) + 1
# nn.Linear consumes the LAST dimension of its input, so the input size is the
# per-node feature width, not len(inputs[0]) (which is the number of nodes).
in_dim = inputs[0].shape[-1]

model = SGAT(graphs,
			 num_layers=num_layers,
			 in_dim=in_dim,
			 num_hidden=num_hidden,
			 num_classes=num_classes,
			 heads=heads,
			 alpha=alpha,
			 feat_drop=feat_drop,
			 attn_drop=attn_drop,
			 bias_l0=bias_l0,
			 residual=residual,
			 l0=l0)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
criterion = torch.nn.CrossEntropyLoss()
labels = torch.LongTensor(labels)

def train():
	"""Run one optimisation step over the full dataset and return the loss value.

	Uses the module-level ``model``, ``inputs``, ``labels``, ``criterion`` and
	``optimizer``. A failure during the backward pass or the optimiser step is
	reported and the step is skipped; it is no longer silently ignored.
	"""
	model.train()
	# Clear gradients up front so a failed step cannot leak stale gradients
	# into the next call.
	optimizer.zero_grad()
	out = model(inputs)  # Perform a single forward pass.
	loss = criterion(out, labels)  # Compute the loss.
	try:
		loss.backward()
		optimizer.step()
	except Exception as err:  # KeyboardInterrupt/SystemExit are deliberately not caught
		print("Skipping optimizer step:", repr(err))
	return loss.item()

def test():
	"""Return the accuracy of ``model`` on a random batch of up to 100 samples.

	The attention layers index their own ``graphs`` list by position, so the
	sampled graphs are swapped into every layer for the duration of the
	forward pass and the original lists are restored afterwards. Without the
	swap the sampled features would be paired with the first graphs of the
	training list rather than with their own graphs.
	"""
	model.eval()

	batch_size = min(100, len(dataset))
	test_graphs, test_labels, test_inputs = get_batch(dataset, batch_size)
	test_labels = torch.LongTensor(test_labels)

	layers = list(model.sgat_layers)
	saved_graphs = [layer.graphs for layer in layers]
	for layer in layers:
		layer.graphs = test_graphs
	try:
		with torch.no_grad():  # evaluation only, no autograd graph needed
			out = model(test_inputs)
	finally:
		for layer, original_graphs in zip(layers, saved_graphs):
			layer.graphs = original_graphs

	pred = out.argmax(dim=1)  # Use the class with highest probability.
	correct = int((pred == test_labels).sum())  # Check against ground-truth labels.
	return correct / len(test_labels)  # Derive ratio of correct predictions.

# Alternate one training step with one accuracy check for 100 epochs.
for epoch in range(100):
	print("EPOCH", epoch)
	train()
	accuracy = test()
	print(accuracy)
