In [1]:
import torch
from torch import nn
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch.nn.init import xavier_normal_
import torch.nn.functional as F
from collections import defaultdict as ddict
from ordered_set import OrderedSet
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np, sys, os, json,argparse
from pprint import pprint
import logging
import inspect
from torch_scatter import scatter
import networkx as nx 
from torch_geometric.nn import GCNConv
import csv

In [2]:
def get_param(shape):
	"""
	Create a trainable parameter with Xavier-normal initial values.

	Parameters
	----------
	shape:          Iterable of ints giving the dimensions of the parameter

	Returns
	-------
	A torch.nn.Parameter of the requested shape
	"""
	weights = torch.Tensor(*shape)		# uninitialised storage, filled in place on the next line
	xavier_normal_(weights)
	return Parameter(weights)

def set_gpu(gpus):
	"""
	Restrict the CUDA devices visible to this process.

	Parameters
	----------
	gpus:           String written verbatim to CUDA_VISIBLE_DEVICES

	Returns
	-------
	None; only the process environment is modified
	"""
	os.environ.update({
		"CUDA_DEVICE_ORDER":    "PCI_BUS_ID",
		"CUDA_VISIBLE_DEVICES": gpus,
	})
      
def get_logger(name, log_dir, config_dir):
	"""
	Build a logger that writes to a file and echoes to stdout.

	Parameters
	----------
	name:           Logger name; '/' is replaced by '-' to form the log file name
	log_dir:        Prefix (directory path including trailing separator) of the log file
	config_dir:     Prefix of the directory holding 'log_config.json'

	Returns
	-------
	The configured logging.Logger, with an extra stdout StreamHandler attached
	"""
	# `import logging` alone does not guarantee that the `config` submodule is loaded.
	import logging.config

	# Context manager so the config file handle is closed deterministically.
	with open(config_dir + 'log_config.json') as config_file:
		config_dict = json.load(config_file)

	config_dict['handlers']['file_handler']['filename'] = log_dir + name.replace('/', '-')
	logging.config.dictConfig(config_dict)
	logger = logging.getLogger(name)

	std_out_format = '%(asctime)s - [%(levelname)s] - %(message)s'
	consoleHandler = logging.StreamHandler(sys.stdout)
	consoleHandler.setFormatter(logging.Formatter(std_out_format))
	logger.addHandler(consoleHandler)

	return logger

def conj(a):
	"""
	Complex-conjugate an array whose last axis holds (real, imag) pairs.

	The input is modified IN PLACE (index 1 of the last axis is negated) and the
	same object is returned.
	"""
	negated_imag = -a[..., 1]
	a[..., 1] = negated_imag
	return a

def com_mult(a, b):
	"""
	Multiply complex numbers stored as (real, imag) pairs on the last axis.

	Returns a new tensor with the product's real and imaginary parts stacked on a
	new trailing axis.
	"""
	a_re, a_im = a[..., 0], a[..., 1]
	b_re, b_im = b[..., 0], b[..., 1]
	real_part = a_re * b_re - a_im * b_im
	imag_part = a_re * b_im + a_im * b_re
	return torch.stack([real_part, imag_part], dim = -1)

def ccorr(a, b):
	"""
	Circular correlation of `a` and `b` along the last axis, computed via real FFTs.

	The inverse transform is asked for `a.shape[-1]` samples, so the output has the
	same trailing length as `a`.
	"""
	length = a.shape[-1]
	spec_a = torch.fft.rfft(a, dim=-1)
	spec_b = torch.fft.rfft(b, dim=-1)
	return torch.fft.irfft(torch.conj(spec_a) * spec_b, n=length, dim=-1)

def get_combined_results(left_results, right_results):
	"""
	Turn summed left/right ranking metrics into per-query averages.

	Parameters
	----------
	left_results:   Dict with 'count', 'mr', 'mrr', 'hits@1', 'hits@3', 'hits@10'
	right_results:  Dict with the same metric keys ('count' is read from left_results only)

	Returns
	-------
	Dict holding 'left_*', 'right_*' and combined values, each rounded to 5 decimals
	"""
	count = float(left_results['count'])
	results = {}

	# Per-side averages first, then the average over both sides
	for side, side_results in (('left', left_results), ('right', right_results)):
		for metric in ('mr', 'mrr'):
			results['{}_{}'.format(side, metric)] = round(side_results[metric]/count, 5)
	for metric in ('mr', 'mrr'):
		results[metric] = round((left_results[metric] + right_results[metric])/(2*count), 5)

	for cutoff in (1, 3, 10):
		key = 'hits@{}'.format(cutoff)
		results['left_' + key]  = round(left_results[key]/count, 5)
		results['right_' + key] = round(right_results[key]/count, 5)
		results[key]            = round((left_results[key] + right_results[key])/(2*count), 5)

	return results


In [3]:
class TrainDataset(Dataset):
	"""
	Training dataset yielding one (subject, relation) query with a multi-hot label
	over all entities.

	Parameters
	----------
	triples:        List of dicts with keys 'triple' and 'label' (indices of true objects)
	params:         Namespace providing num_ent, lbl_smooth and (for get_neg_ent) neg_num
	"""

	def __init__(self, triples, params):
		self.triples	= triples
		self.p 		    = params
		self.entities	= np.arange(self.p.num_ent, dtype=np.int32)

	def __len__(self):
		return len(self.triples)

	def __getitem__(self, idx):
		"""Return (triple, smoothed multi-hot label, None, None) for sample `idx`."""
		ele		= self.triples[idx]
		triple		= torch.LongTensor(ele['triple'])
		trp_label	= self.get_label(np.int32(ele['label']))

		if self.p.lbl_smooth != 0.0:
			trp_label = (1.0 - self.p.lbl_smooth)*trp_label + (1.0/self.p.num_ent)

		# The two trailing None slots are ignored by collate_fn
		return triple, trp_label, None, None

	@staticmethod
	def collate_fn(data):
		"""Stack the triples and labels of a list of samples into two batch tensors."""
		triple		= torch.stack([_[0] 	for _ in data], dim=0)
		trp_label	= torch.stack([_[1] 	for _ in data], dim=0)
		return triple, trp_label
	
	def get_neg_ent(self, triple, label):
		"""
		Return the positive entities followed by randomly drawn negatives.

		Parameters
		----------
		triple:         Unused; kept for interface compatibility
		label:          Array-like of positive entity indices

		Returns
		-------
		1-D array of length self.p.neg_num: positives first, then negatives sampled
		without replacement from the remaining entities
		"""
		pos_obj		= np.asarray(label).reshape([-1])
		# `np.bool` was removed in NumPy 1.24; the builtin bool is the portable dtype
		mask		= np.ones([self.p.num_ent], dtype=bool)
		mask[pos_obj]	= False
		neg_ent		= np.int32(np.random.choice(self.entities[mask], self.p.neg_num - len(pos_obj), replace=False)).reshape([-1])
		return np.concatenate((pos_obj, neg_ent))

	def get_label(self, label):
		"""Return a float tensor of length num_ent with 1.0 at every index in `label`."""
		y = np.zeros([self.p.num_ent], dtype=np.float32)
		for e2 in label: y[e2] = 1.0
		return torch.FloatTensor(y)
	
class TestDataset(Dataset):
	"""
	Evaluation dataset yielding a full (sub, rel, obj) triple together with a
	multi-hot vector over all entities built from the sample's 'label' indices.
	"""

	def __init__(self, triples, params):
		self.triples, self.p = triples, params

	def __len__(self):
		return len(self.triples)

	def __getitem__(self, idx):
		"""Return (triple tensor, multi-hot label tensor) for sample `idx`."""
		sample = self.triples[idx]
		triple = torch.LongTensor(sample['triple'])
		target_ids = np.int32(sample['label'])
		return triple, self.get_label(target_ids)

	@staticmethod
	def collate_fn(data):
		"""Stack per-sample triples and labels along a new leading batch axis."""
		batch_triples = [sample[0] for sample in data]
		batch_labels  = [sample[1] for sample in data]
		return torch.stack(batch_triples, dim=0), torch.stack(batch_labels, dim=0)

	def get_label(self, label):
		"""Return a float tensor of length num_ent with 1.0 at every index in `label`."""
		multi_hot = np.zeros([self.p.num_ent], dtype=np.float32)
		for ent_id in label:
			multi_hot[ent_id] = 1.0
		return torch.FloatTensor(multi_hot)

In [4]:
def scatter_(name, src, index, dim_size=None):
	"""
	Aggregate the rows of `src` that share the same value in `index`.

	Parameters
	----------
	name:           Reduction to apply: 'add' (alias of 'sum'), 'sum', 'mean' or 'max'
	src:            Values to aggregate; reduced along dimension 0
	index:          Target position of each row of `src`
	dim_size:       Size of the output along dimension 0 (passed through to scatter)

	Returns
	-------
	The aggregated tensor (first element if the backend returns a tuple)
	"""
	reduce_op = 'sum' if name == 'add' else name
	assert reduce_op in ['sum', 'mean', 'max']
	result = scatter(src, index, dim=0, out=None, dim_size=dim_size, reduce=reduce_op)
	if isinstance(result, tuple):
		return result[0]
	return result

class MessagePassing(torch.nn.Module):
	"""
	Minimal message-passing base class.

	Subclasses override `message` and `update`. The argument names of those two
	methods are inspected once at construction time and later resolved from the
	keyword arguments handed to `propagate`.
	"""

	def __init__(self, aggr='add'):
		super().__init__()

		# Names of message() arguments after `self`, e.g. ['x_j', 'edge_type', 'edge_norm']
		self.message_args = inspect.getfullargspec(self.message).args[1:]
		# Names of update() arguments after `self` and the aggregated output
		self.update_args  = inspect.getfullargspec(self.update).args[2:]

	def propagate(self, aggr, edge_index, **kwargs):
		"""
		Run message -> aggregate -> update over the given edges.

		Parameters
		----------
		aggr:           One of 'add', 'mean', 'max'
		edge_index:     Indexable pair of index tensors; row 0 is also the aggregation target
		**kwargs:       Values looked up by the argument names of message() and update()

		Returns
		-------
		Whatever update() returns for the aggregated messages
		"""
		assert aggr in ['add', 'mean', 'max']
		kwargs['edge_index'] = edge_index

		size = None
		gathered = []
		for arg_name in self.message_args:
			suffix = arg_name[-2:]
			if suffix in ('_i', '_j'):
				# '<name>_i' gathers rows by edge_index[0], '<name>_j' by edge_index[1]
				source   = kwargs[arg_name[:-2]]
				size     = source.size(0)
				endpoint = 0 if suffix == '_i' else 1
				gathered.append(source[edge_index[endpoint]])
			else:
				gathered.append(kwargs[arg_name])

		extra_update_args = [kwargs[arg_name] for arg_name in self.update_args]

		messages   = self.message(*gathered)
		aggregated = scatter_(aggr, messages, edge_index[0], dim_size=size)
		return self.update(aggregated, *extra_update_args)

	def message(self, x_j):  # pragma: no cover
		"""Default message: pass the gathered features through unchanged."""
		return x_j

	def update(self, aggr_out):  # pragma: no cover
		"""Default update: return the aggregated messages unchanged."""
		return aggr_out

In [5]:
class CompGCNConv(MessagePassing):
	"""
	Graph convolution that composes neighbour embeddings with relation embeddings
	(via circular correlation) before aggregating them.

	Parameters
	----------
	in_channels:    Width of the incoming entity/relation embeddings
	out_channels:   Width of the produced embeddings
	act:            Activation applied to the entity output
	cache:          If truthy, the edge split, self-loop edges and edge norms are built on
	                the first forward pass and reused afterwards (should be False for graph
	                classification tasks, where the graph changes between calls)
	params:         Namespace providing dropout, bias and optionally b_norm
	"""
	def __init__(self, in_channels, out_channels, act=lambda x:x, cache=True, params=None):
		# Zero-argument super(): super(self.__class__, self) recurses forever once subclassed
		super().__init__()

		self.p 				= params
		self.in_channels	= in_channels
		self.out_channels	= out_channels
		self.act 			= act
		self.device			= None
		self.cache 			= cache

		# Edge bookkeeping, filled in by forward()
		self.in_index	= self.out_index	= self.loop_index	= None
		self.in_type	= self.out_type		= self.loop_type	= None
		self.in_norm	= self.out_norm		= None

		self.w_loop			= get_param((in_channels, out_channels))
		self.w_in			= get_param((in_channels, out_channels))
		self.w_out			= get_param((in_channels, out_channels))

		self.w_rel 			= get_param((in_channels, out_channels))
		self.loop_rel 		= get_param((1, in_channels))

		self.drop			= torch.nn.Dropout(self.p.dropout)
		self.bn				= torch.nn.BatchNorm1d(out_channels)
		# Batch norm is opt-in: used only when params defines a truthy b_norm
		self.b_norm 		= getattr(self.p, 'b_norm', False)

		if self.p.bias: self.register_parameter('bias', Parameter(torch.zeros(out_channels)))


	def forward(self, x, edge_index, edge_type, edge_norm=None, rel_embed=None):
		"""
		Parameters
		----------
		x:              Entity embeddings, one row per entity
		edge_index:     2 x E edge list; the first half of the columns is treated as
		                'in' edges, the second half as 'out' edges
		edge_type:      Relation index of every edge
		edge_norm:      Unused; norms are computed internally
		rel_embed:      Relation embeddings, one row per relation

		Returns
		-------
		(activated entity embeddings, relation embeddings projected by w_rel)
		"""
		if self.device is None:
			self.device = edge_index.device

		# Last row of rel_embed becomes the self-loop relation
		rel_embed = torch.cat([rel_embed, self.loop_rel], dim=0)

		num_edges = edge_index.size(1) // 2
		num_ent   = x.size(0)

		# Rebuild the edge bookkeeping when caching is off or nothing is cached yet.
		# (The former test `not self.cache == None` ignored the flag's value and left
		# these attributes undefined for cache=None.)
		if not self.cache or self.in_index is None:
			self.in_index, self.out_index = edge_index[:, :num_edges], edge_index[:, num_edges:]
			self.in_type,  self.out_type  = edge_type[:num_edges], 	 edge_type [num_edges:]

			self.loop_index  = torch.stack([torch.arange(num_ent), torch.arange(num_ent)]).to(self.device)
			self.loop_type   = torch.full((num_ent,), rel_embed.size(0)-1, dtype=torch.long).to(self.device)

			self.in_norm     = self.compute_norm(self.in_index,  num_ent)
			self.out_norm    = self.compute_norm(self.out_index, num_ent)

		in_res		= self.propagate('add', self.in_index,   x=x, edge_type=self.in_type,   rel_embed=rel_embed, edge_norm=self.in_norm, 	mode='in')
		loop_res	= self.propagate('add', self.loop_index, x=x, edge_type=self.loop_type, rel_embed=rel_embed, edge_norm=None, 			mode='loop')
		out_res		= self.propagate('add', self.out_index,  x=x, edge_type=self.out_type,  rel_embed=rel_embed, edge_norm=self.out_norm,	mode='out')
		# Equal-weight mix of the three directions; dropout is applied to in/out only
		out			= self.drop(in_res)*(1/3) + self.drop(out_res)*(1/3) + loop_res*(1/3)

		if self.p.bias: out = out + self.bias
		if self.b_norm: out = self.bn(out)

		# [:-1] drops the self-loop relation appended above
		return self.act(out), torch.matmul(rel_embed, self.w_rel)[:-1]

	def rel_transform(self, ent_embed, rel_embed):
		"""Compose entity and relation embeddings by circular correlation."""
		trans_embed  = ccorr(ent_embed, rel_embed)
		return trans_embed

	def message(self, x_j, edge_type, rel_embed, edge_norm, mode):
		"""Compose each gathered embedding with its edge's relation and project it with w_<mode>."""
		weight 	= getattr(self, 'w_{}'.format(mode))
		rel_emb = torch.index_select(rel_embed, 0, edge_type)
		xj_rel  = self.rel_transform(x_j, rel_emb)
		out	= torch.mm(xj_rel, weight)

		return out if edge_norm is None else out * edge_norm.view(-1, 1)

	def update(self, aggr_out):
		"""Return the aggregated messages unchanged."""
		return aggr_out

	def compute_norm(self, edge_index, num_ent):
		"""Return deg^{-0.5}[row] * deg^{-0.5}[col] per edge, with degrees counted over `row`."""
		row, col	= edge_index
		edge_weight 	= torch.ones_like(row).float()
		deg		= scatter_add( edge_weight, row, dim=0, dim_size=num_ent)	# Number of edges per `row` node
		deg_inv		= deg.pow(-0.5)							# D^{-0.5}
		deg_inv[deg_inv	== float('inf')] = 0						# Nodes without edges contribute 0
		norm		= deg_inv[row] * edge_weight * deg_inv[col]			# D^{-0.5} A D^{-0.5}

		return norm

In [6]:
class BaseModel(torch.nn.Module):
	"""Common base of the scoring models: stores params, a tanh activation and a BCE loss."""

	def __init__(self, params):
		super().__init__()

		self.p, self.act = params, torch.tanh
		self.bceloss     = torch.nn.BCELoss()

	def loss(self, pred, true_label):
		"""Binary cross-entropy between predicted probabilities and target labels."""
		return self.bceloss(pred, true_label)

In [7]:
class CompGCNBase(BaseModel):
	"""
	Shared encoder of the CompGCN scoring models.

	Entity embeddings (learned part concatenated with fixed node features) go through a
	CompGCNConv layer over the entity graph; relation embeddings are additionally
	passed through a GCNConv layer over the relation-relation graph and the two
	relation views are blended.

	Parameters
	----------
	edge_index, edge_type:          Entity graph edges and their relation ids
	rrg_edge_index, rrg_edge_type:  Relation-relation graph edges and the per-edge values
	                                handed to GCNConv as its third argument
	features:                       Fixed per-entity feature matrix (num_ent rows)
	chequer_perm:                   Unused here; accepted so all subclasses share one signature
	params:                         Namespace with num_ent, num_rel, init_dim, embed_dim, ...
	"""
	def __init__(self, edge_index, edge_type, rrg_edge_index, rrg_edge_type, features,chequer_perm, params=None):
		super().__init__(params)

		self.edge_index		 = edge_index
		self.edge_type		 = edge_type
		self.rrg_edge_index  = rrg_edge_index
		self.rrg_edge_type   = rrg_edge_type

		self.features        = features
		self.device			= self.edge_index.device

		# The learned part fills whatever the fixed features leave of init_dim, so that
		# cat([init_embed, features]) is exactly init_dim wide.
		feat_dim			= self.features.size(1)
		self.init_embed		= get_param((self.p.num_ent,   self.p.init_dim-feat_dim))
		self.init_rel  		= get_param((self.p.num_rel*2,   self.p.init_dim))
		# NOTE(review): Parameter(self.init_rel) wraps the existing tensor rather than cloning
		# it, so both parameters appear to share storage - confirm this is intended.
		self.init_rel_1		= Parameter(self.init_rel)

		self.conv1 = CompGCNConv(self.p.init_dim, self.p.embed_dim, act=self.act, params=self.p)
		self.conv2 = GCNConv(self.p.init_dim, self.p.embed_dim, add_self_loops=True)

		self.register_parameter('bias', Parameter(torch.zeros(self.p.num_ent)))

		# Applied to the init_dim-wide input of conv1, so it must be sized by init_dim
		# (sizing it by embed_dim fails whenever init_dim != embed_dim).
		self.bn		= torch.nn.BatchNorm1d(self.p.init_dim)


	def forward_base(self, sub, rel, drop1, drop2):
		"""
		Encode the graph and look up the embeddings of a batch of queries.

		Parameters
		----------
		sub:            Subject entity indices of the batch
		rel:            Relation indices of the batch
		drop1:          Dropout module applied to the entity embeddings
		drop2:          Unused; kept for interface compatibility

		Returns
		-------
		(subject embeddings, relation embeddings, all entity embeddings)
		"""
		r  = self.init_rel
		r1 = self.init_rel_1
		x    	= torch.cat([self.init_embed, self.features], dim=1)
		x       = self.bn(x)
		r_rrg 		= self.conv2(r1, self.rrg_edge_index, self.rrg_edge_type)
		x_eeg, r_eeg 	= self.conv1(x, self.edge_index, self.edge_type, rel_embed=r)
		x_eeg	 	= drop1(x_eeg)

		x = x_eeg

		# Blend the two relation views: 10% relation-graph GCN, 90% CompGCN projection
		r = 0.1*r_rrg + 0.9*r_eeg

		sub_emb	= torch.index_select(x, 0, sub)
		rel_emb	= torch.index_select(r, 0, rel)

		return sub_emb, rel_emb, x

In [8]:
class CompGCN_ConvE(CompGCNBase):
	"""CompGCN encoder followed by a ConvE-style 2-D convolutional scoring head."""

	def __init__(self, edge_index, edge_type, rrg_edge_index, rrg_edge_type, features, chequer_perm, params=None):
		# Zero-argument super(): super(self.__class__, self) recurses forever once subclassed
		super().__init__(edge_index, edge_type, rrg_edge_index, rrg_edge_type, features, chequer_perm,params)
		
		self.bn0		= torch.nn.BatchNorm2d(1)
		self.bn1		= torch.nn.BatchNorm2d(self.p.num_filt)
		self.bn2		= torch.nn.BatchNorm1d(self.p.embed_dim)
		
		self.hidden_drop	= torch.nn.Dropout(self.p.hid_drop)
		self.hidden_drop2	= torch.nn.Dropout(self.p.hid_drop2)
		self.feature_drop	= torch.nn.Dropout(self.p.feat_drop)
		self.m_conv1		= torch.nn.Conv2d(1, out_channels=self.p.num_filt, kernel_size=(self.p.ker_sz, self.p.ker_sz), stride=1, padding=0, bias=self.p.bias)

		# Spatial size left by an unpadded, stride-1 convolution over a (2*k_w, k_h) image
		flat_sz_h		= int(2*self.p.k_w) - self.p.ker_sz + 1
		flat_sz_w		= self.p.k_h 	    - self.p.ker_sz + 1
		self.flat_sz	= flat_sz_h*flat_sz_w*self.p.num_filt
		self.fc			= torch.nn.Linear(self.flat_sz, self.p.embed_dim)

	def concat(self, e1_embed, rel_embed):
		"""Interleave entity and relation embeddings into a (batch, 1, 2*k_w, k_h) image."""
		e1_embed	= e1_embed. view(-1, 1, self.p.embed_dim)
		rel_embed	= rel_embed.view(-1, 1, self.p.embed_dim)
		stack_inp	= torch.cat([e1_embed, rel_embed], 1)
		stack_inp	= torch.transpose(stack_inp, 2, 1).reshape((-1, 1, 2*self.p.k_w, self.p.k_h))
		return stack_inp

	def forward(self, sub, rel):
		"""Return, for every query in the batch, a sigmoid score for each entity."""
		sub_emb, rel_emb, all_ent	= self.forward_base(sub, rel, self.hidden_drop, self.feature_drop)
		stk_inp				= self.concat(sub_emb, rel_emb)
		x				= self.bn0(stk_inp)
		x				= self.m_conv1(x)
		x				= self.bn1(x)
		x				= F.relu(x)
		x				= self.feature_drop(x)
		x				= x.view(-1, self.flat_sz)
		x				= self.fc(x)
		x				= self.hidden_drop2(x)
		x				= self.bn2(x)
		x				= F.relu(x)

		# Score each query against every entity embedding, plus a per-entity bias
		x = torch.mm(x, all_ent.transpose(1,0))
		x += self.bias.expand_as(x)

		score = torch.sigmoid(x)
		return score

In [9]:
class CompGCN_DistMult(CompGCNBase):
	"""CompGCN encoder followed by a DistMult (element-wise product) scoring head."""

	def __init__(self, edge_index, edge_type, rrg_edge_index, rrg_edge_type, features,  chequer_perm, params=None):
		# Zero-argument super(): super(self.__class__, self) recurses forever once subclassed
		super().__init__(edge_index, edge_type, rrg_edge_index, rrg_edge_type, features,  chequer_perm,params)
		self.drop = torch.nn.Dropout(self.p.hid_drop)

	def forward(self, sub, rel):
		"""Return, for every query in the batch, a sigmoid score for each entity."""
		sub_emb, rel_emb, all_ent	= self.forward_base(sub, rel, self.drop, self.drop)
		obj_emb				= sub_emb * rel_emb

		# Score each query against every entity embedding, plus a per-entity bias
		x = torch.mm(obj_emb, all_ent.transpose(1, 0))
		x += self.bias.expand_as(x)

		score = torch.sigmoid(x)
		return score

In [10]:
class CompGCN_InteractE(CompGCNBase):
    """
    CompGCN encoder followed by an InteractE-style scoring head that mixes three
    circular depth-wise convolutions (dilation 1, 2 and 3) over chequer-permuted
    entity/relation embeddings.
    """

    def __init__(self, edge_index, edge_type, rrg_edge_index, rrg_edge_type, features,chequer_perm, params=None):
        # Zero-argument super(): super(self.__class__, self) recurses forever once subclassed
        super().__init__(edge_index, edge_type, rrg_edge_index, rrg_edge_type, features, chequer_perm, params)

        self.dilation_rate2 = 2
        self.dilation_rate3 = 3

        self.register_parameter('conv_filt_orig', Parameter(
            torch.zeros(self.p.num_filt, 1, self.p.ker_sz, self.p.ker_sz)))
        self.register_parameter('conv_filt_dilated2', Parameter(
            torch.zeros(self.p.num_filt, 1, self.p.ker_sz, self.p.ker_sz)))
        self.register_parameter('conv_filt_dilated3', Parameter(
            torch.zeros(self.p.num_filt, 1, self.p.ker_sz, self.p.ker_sz)))

        # Mixing weights of the dilation-1 and dilation-2 branches; the dilation-3
        # branch receives the remainder (see forward)
        self.alpha = 1/3
        self.beta = 1/3
        xavier_normal_(self.conv_filt_orig)
        xavier_normal_(self.conv_filt_dilated2)
        xavier_normal_(self.conv_filt_dilated3)

        self.inp_drop = torch.nn.Dropout(self.p.hid_drop)
        self.hidden_drop = torch.nn.Dropout(self.p.hid_drop2)
        self.feature_map_drop = torch.nn.Dropout2d(self.p.feat_drop)
        self.bn0 = torch.nn.BatchNorm2d(self.p.perm)

        # Circular padding keeps the spatial size of the (2*k_w, k_h) input image
        flat_sz_h = self.p.k_h
        flat_sz_w = 2*self.p.k_w
        self.padding = 0

        self.bn1 = torch.nn.BatchNorm2d(self.p.num_filt*self.p.perm)
        self.flat_sz = flat_sz_h * flat_sz_w * self.p.num_filt*self.p.perm

        self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim)
        self.fc = torch.nn.Linear(self.flat_sz, self.p.embed_dim)
        self.chequer_perm = chequer_perm

        self.register_parameter('bias', Parameter(torch.zeros(self.p.num_ent)))
        # Not referenced by forward(); kept so the module's parameter set is unchanged
        self.register_parameter('conv_filt', Parameter(torch.zeros(self.p.num_filt, 1, self.p.ker_sz,  self.p.ker_sz)))
        xavier_normal_(self.conv_filt)

    def circular_padding_chw(self, batch, padding):
        """
        Wrap-around pad the last two axes of a 4-D batch by `padding` cells on each side.

        A padding of 0 returns the input unchanged (a `-0:` slice would otherwise select
        the whole axis and duplicate the data).
        """
        if padding == 0:
            return batch

        upper_pad = batch[..., -padding:, :]
        lower_pad = batch[..., :padding, :]
        temp = torch.cat([upper_pad, batch, lower_pad], dim=2)

        left_pad = temp[..., -padding:]
        right_pad = temp[..., :padding]
        padded = torch.cat([left_pad, temp, right_pad], dim=3)
        return padded

    def _circular_conv(self, x, conv_filt, dilation):
        """Depth-wise convolution of `x` with circular padding sized for the given dilation."""
        pad = dilation * (self.p.ker_sz // 2)
        padded = self.circular_padding_chw(x, pad)
        return F.conv2d(padded,
                        conv_filt.repeat(self.p.perm, 1, 1, 1),
                        padding=0,
                        groups=self.p.perm,
                        dilation=dilation)

    def forward(self, sub, rel):
        """Return, for every query in the batch, a sigmoid score for each entity."""
        sub_emb, rel_emb, all_ent = self.forward_base(sub, rel, self.hidden_drop, self.feature_map_drop)
        comb_emb = torch.cat([sub_emb, rel_emb], dim=1)
        chequer_perm = comb_emb[:, self.chequer_perm]
        stack_inp = chequer_perm.reshape((-1, self.p.perm, 2*self.p.k_w, self.p.k_h))
        stack_inp = self.bn0(stack_inp)
        x = self.inp_drop(stack_inp)

        x_orig = self._circular_conv(x, self.conv_filt_orig, 1)
        x_dilated2 = self._circular_conv(x, self.conv_filt_dilated2, self.dilation_rate2)
        x_dilated3 = self._circular_conv(x, self.conv_filt_dilated3, self.dilation_rate3)

        # Convex combination of the three receptive fields
        gamma = 1.0 - self.alpha - self.beta
        x = (self.alpha * x_orig + self.beta * x_dilated2 + gamma * x_dilated3)

        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(-1, self.flat_sz)
        x = self.fc(x)

        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)

        # Score each query against every entity embedding, plus a per-entity bias
        x = torch.mm(x, all_ent.transpose(1,0))
        x += self.bias.expand_as(x)
        score = torch.sigmoid(x)
        return score

In [11]:
class Runner(object):
	
	def __init__(self, params):
		"""
		Set up logging, pick the device, load the data and build model and optimizer.

		Parameters
		----------
		params:         Namespace of run settings (name, log_dir, config_dir, gpu, model,
		                score_func, ...)
		"""
		self.p      = params
		self.logger = get_logger(self.p.name, self.p.log_dir, self.p.config_dir)

		run_config = vars(self.p)
		self.logger.info(run_config)
		pprint(run_config)

		# GPU only when it is both requested (gpu != '-1') and actually available
		use_cuda = self.p.gpu != '-1' and torch.cuda.is_available()
		if not use_cuda:
			self.device = torch.device('cpu')
		else:
			self.device = torch.device('cuda')
			torch.cuda.set_rng_state(torch.cuda.get_rng_state())
			torch.backends.cudnn.deterministic = True

		self.load_data()
		self.model     = self.add_model(self.p.model, self.p.score_func)
		self.optimizer = self.add_optimizer(self.model.parameters())


	def load_data(self):
		"""
		Reads the raw triples and converts them into a standard format.

		Parameters
		----------
		self.p.dataset:         Takes in the name of the dataset (FB15k-237)
		
		Returns
		-------
		self.ent2id:            Entity to unique identifier mapping
		self.id2ent:            Inverse mapping of self.ent2id
		self.rel2id:            Relation to unique identifier mapping (including '<rel>_reverse' entries)
		self.id2rel:            Inverse mapping of self.rel2id
		self.p.num_ent:         Number of entities in the Knowledge graph
		self.p.num_rel:         Number of relations in the Knowledge graph (reverse relations not counted)
		self.p.embed_dim:       Embedding dimension used (k_w * k_h when not given)
		self.data['train']:     Stores the triples corresponding to training dataset
		self.data['valid']:     Stores the triples corresponding to validation dataset
		self.data['test']:      Stores the triples corresponding to test dataset
		self.data_iter:		    The dataloader for different data splits

		"""
		splits = ['train', 'test', 'valid']

		def read_split(split):
			"""Return the lower-cased (sub, rel, obj) string triples of one split file."""
			rows = []
			# Context manager so the file handle is closed deterministically
			with open('./data/{}/{}.txt'.format(self.p.dataset, split)) as triple_file:
				for line in triple_file:
					line = line.strip()
					if not line: continue		# tolerate blank / trailing empty lines
					sub, rel, obj = map(str.lower, line.split('\t'))
					rows.append((sub, rel, obj))
			return rows

		# Every file is read once; the id assignment and the triple conversion both reuse it
		raw_triples = {split: read_split(split) for split in splits}

		ent_set, rel_set = OrderedSet(), OrderedSet()
		for split in splits:
			for sub, rel, obj in raw_triples[split]:
				ent_set.add(sub)
				rel_set.add(rel)
				ent_set.add(obj)

		self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
		self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
		# Reverse relations occupy the ids directly after the original relations
		self.rel2id.update({rel+'_reverse': idx+len(self.rel2id) for idx, rel in enumerate(rel_set)})

		self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}
		self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}

		self.p.num_ent		= len(self.ent2id)
		self.p.num_rel		= len(self.rel2id) // 2
		self.p.embed_dim	= self.p.k_w * self.p.k_h if self.p.embed_dim is None else self.p.embed_dim
		print('num entities: ', self.p.num_ent)
		print('num relations: ', self.p.num_rel)
		self.data = ddict(list)
		sr2o = ddict(set)

		for split in splits:
			for sub, rel, obj in raw_triples[split]:
				sub, rel, obj = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj]
				self.data[split].append((sub, rel, obj))

				if split == 'train': 
					sr2o[(sub, rel)].add(obj)
					sr2o[(obj, rel+self.p.num_rel)].add(sub)

		self.data = dict(self.data)

		# Snapshot with training facts only; sr2o is extended with test/valid facts below
		self.sr2o = {k: list(v) for k, v in sr2o.items()}
		for split in ['test', 'valid']:
			for sub, rel, obj in self.data[split]:
				sr2o[(sub, rel)].add(obj)
				sr2o[(obj, rel+self.p.num_rel)].add(sub)

		self.sr2o_all = {k: list(v) for k, v in sr2o.items()}
		self.triples  = ddict(list)

		# One training sample per (subject, relation) pair, labelled with all its objects
		for (sub, rel), objs in self.sr2o.items():
			self.triples['train'].append({'triple':(sub, rel, -1), 'label': objs, 'sub_samp': 1})

		for split in ['test', 'valid']:
			for sub, rel, obj in self.data[split]:
				rel_inv = rel + self.p.num_rel
				self.triples['{}_{}'.format(split, 'tail')].append({'triple': (sub, rel, obj), 	   'label': self.sr2o_all[(sub, rel)]})
				self.triples['{}_{}'.format(split, 'head')].append({'triple': (obj, rel_inv, sub), 'label': self.sr2o_all[(obj, rel_inv)]})

		self.triples = dict(self.triples)
		
		def get_data_loader(dataset_class, split, batch_size, shuffle=True):
			return  DataLoader(
					dataset_class(self.triples[split], self.p),
					batch_size      = batch_size,
					shuffle         = shuffle,
					num_workers     = max(0, self.p.num_workers),
					collate_fn      = dataset_class.collate_fn
				)

		self.data_iter = {
			'train':    	get_data_loader(TrainDataset, 'train', 	    self.p.batch_size),
			'valid_head':   get_data_loader(TestDataset,  'valid_head', self.p.batch_size),
			'valid_tail':   get_data_loader(TestDataset,  'valid_tail', self.p.batch_size),
			'test_head':   	get_data_loader(TestDataset,  'test_head',  self.p.batch_size),
			'test_tail':   	get_data_loader(TestDataset,  'test_tail',  self.p.batch_size),
		}

		self.edge_index, self.edge_type = self.construct_adj(self.data['train'])
		self.rrg_edge_index, self.rrg_edge_type = self.construct_rrg(self.data['train'])
		self.features = self.concat_features(self.edge_index, self.edge_type)
		self.chequer_perm	= self.get_chequer_perm()

	def get_chequer_perm(self):
		"""
		Function to generate the chequer permutation required for InteractE model

		Parameters
		----------
		
		Returns
		-------
		
		"""
		ent_perm  = np.int32([np.random.permutation(self.p.embed_dim) for _ in range(self.p.perm)])
		rel_perm  = np.int32([np.random.permutation(self.p.embed_dim) for _ in range(self.p.perm)])

		comb_idx = []
		for k in range(self.p.perm):
			temp = []
			ent_idx, rel_idx = 0, 0

			for i in range(self.p.k_h):
				for j in range(self.p.k_w):
					if k % 2 == 0:
						if i % 2 == 0:
							temp.append(ent_perm[k, ent_idx]); ent_idx += 1
							temp.append(rel_perm[k, rel_idx]+self.p.embed_dim); rel_idx += 1
						else:
							temp.append(rel_perm[k, rel_idx]+self.p.embed_dim); rel_idx += 1
							temp.append(ent_perm[k, ent_idx]); ent_idx += 1
					else:
						if i % 2 == 0:
							temp.append(rel_perm[k, rel_idx]+self.p.embed_dim); rel_idx += 1
							temp.append(ent_perm[k, ent_idx]); ent_idx += 1
						else:
							temp.append(ent_perm[k, ent_idx]); ent_idx += 1
							temp.append(rel_perm[k, rel_idx]+self.p.embed_dim); rel_idx += 1

			comb_idx.append(temp)

		chequer_perm = torch.LongTensor(np.int32(comb_idx)).to(self.device)
		return chequer_perm
	
	def concat_features(self, edge_index, edge_type):
			features = self.compute_node_features(edge_index, edge_type, self.p.num_ent).to(self.device)
			print("features:", features)
			print("features:", features.size())
			return features

	def construct_adj(self, data):
			edge_index, edge_type = [], []
			edge_type_count = {} 

			sub_unique = {}
			obj_unique = {}

			for sub, rel, obj in data:
				edge_index.append((sub, obj))
				edge_type.append(rel)
				edge_type_count[rel] = edge_type_count.get(rel, 0) + 1

				if rel not in sub_unique:
					sub_unique[rel] = set()
				if rel not in obj_unique:
					obj_unique[rel] = set()
				sub_unique[rel].add(sub)
				obj_unique[rel].add(obj)


			for sub, rel, obj in data:
				inverse_rel = rel + self.p.num_rel
				edge_index.append((obj, sub))
				edge_type.append(inverse_rel)
				edge_type_count[inverse_rel] = edge_type_count.get(inverse_rel, 0) + 1

				if inverse_rel not in sub_unique:
					sub_unique[inverse_rel] = set()
				if inverse_rel not in obj_unique:
					obj_unique[inverse_rel] = set()
				sub_unique[inverse_rel].add(obj)
				obj_unique[inverse_rel].add(sub)

			edge_index = torch.LongTensor(edge_index).to(self.device).t()
			edge_type = torch.LongTensor(edge_type).to(self.device)
			return edge_index, edge_type
	
	def construct_rrg(self, data):
		num_relations = self.p.num_rel

		head_rel_dict = ddict(set)  # {ent: {rel of triples having ent as head}}
		tail_rel_dict = ddict(set)  # {ent: {rel of triples having ent as tail}}
		ent_set = set()

		new_triple = []
		new_triple_text = set()

		for triple in data:
			head, rel, tail = triple
			ent_set.update([head, tail])
			head_rel_dict[head].add(rel)
			tail_rel_dict[tail].add(rel)
			
		for ent in ent_set:
			as_tail_set = head_rel_dict[ent] 
			as_head_set = tail_rel_dict[ent] 

			for ele in as_tail_set:
				for ele2 in as_head_set:
					str_key = "{}_{}_{}".format(ele, ent, ele2)
					if (ele % num_relations) == (ele2 % num_relations):
						continue
					if str_key not in new_triple_text:
						new_triple_text.add(str_key)
						new_triple.append([ele, ent, ele2])
		edge_count = {}
		
		for sub, rel, obj in new_triple:
			edge = (sub, obj)
			edge_count[edge] = edge_count.get(edge, 0) + 1
		for sub, rel, obj in new_triple:
			edge = (sub + num_relations, obj + num_relations)
			edge_count[edge] = edge_count.get(edge, 0) + 1

		edge_index = []
		edge_weight = []
		for edge, count in edge_count.items():
			edge_index.append(edge)
			edge_weight.append(count)  

		edge_index = torch.LongTensor(edge_index).to(self.device).t()
		edge_weight = torch.LongTensor(edge_weight).to(self.device)

		min_weight = torch.min(edge_weight[edge_weight > 0])
		log_weights = torch.log(edge_weight + min_weight/10)
		edge_weight = (log_weights - log_weights.min()) / (log_weights.max() - log_weights.min())

		return edge_index, edge_weight


	def compute_node_features(self, edge_index, edge_type, num_nodes):
		"""
		Compute hand-crafted structural features for every node of the entity graph.

		Parameters
		----------
		edge_index:     2 x E tensor of (u, v) node pairs
		edge_type:      Tensor of length E with the relation id of every edge
		num_nodes:      Number of nodes; ids 0 .. num_nodes-1 are added to the graph up front

		Returns
		-------
		Float tensor with one row per node and 20 columns, each min-max normalised:
		degree (1), min/avg/max degree of direct neighbours (3), min/avg/max degree of
		second-hop neighbours (3), gravity score (1), relation-weighted neighbour
		features (2) and the degrees of the 10 highest-scoring one/two-hop neighbours (10)
		"""
		# NOTE(review): G is an undirected nx.Graph, so when the same node pair is added
		# more than once (e.g. a forward and an inverse edge) presumably only the
		# last-written 'relation' attribute survives - confirm this is intended.
		G = nx.Graph()
		G.add_nodes_from(range(num_nodes))
		edges = edge_index.t().tolist()
		for i, (u, v) in enumerate(edges):
			rel = edge_type[i].item() 
			G.add_edge(u, v, relation=rel)
		G.remove_edges_from(nx.selfloop_edges(G))

		# Second-hop neighbourhood of every node (excluding the node itself and its direct
		# neighbours), plus the (rel1, rel2) relation pairs along each two-hop path
		second_neighbors_dict = {}
		path_relations = ddict(lambda: ddict(list))

		for n in G.nodes():
			neighbors = set(G.neighbors(n))
			second_neighbors = set()
			for neighbor in neighbors:
				for second_neighbor in G.neighbors(neighbor):
					if second_neighbor != n and second_neighbor not in neighbors:
						second_neighbors.add(second_neighbor)
						rel1 = G.edges[n, neighbor]['relation']
						rel2 = G.edges[neighbor, second_neighbor]['relation']
						path_relations[n][second_neighbor].append((rel1, rel2))
			second_neighbors_dict[n] = second_neighbors
		# (min, avg, max) degree over the direct neighbours; zeros for isolated nodes
		nbr_deg_stats = {}
		for n in G.nodes():
			neighbors = list(G.neighbors(n))
			if not neighbors:
				min_deg = 0.0
				avg_deg = 0.0
				max_deg = 0.0
			else:
				degrees = [G.degree[neighbor] for neighbor in neighbors]
				min_deg = min(degrees)
				avg_deg = sum(degrees) / len(degrees)
				max_deg = max(degrees)
			nbr_deg_stats[n] = (min_deg, avg_deg, max_deg)

		# Same statistics over the second-hop neighbours
		second_hop_deg_stats = {}
		for n in G.nodes():
			second_neighbors = second_neighbors_dict.get(n, set())
			if not second_neighbors:
				min_deg = 0.0
				avg_deg = 0.0
				max_deg = 0.0
			else:
				degrees = [G.degree[sn] for sn in second_neighbors]
				min_deg = min(degrees)
				avg_deg = sum(degrees) / len(degrees)
				max_deg = max(degrees)
			second_hop_deg_stats[n] = (min_deg, avg_deg, max_deg)

		# Seven raw degree-based columns per node
		features = [
			[G.degree(n),
			nbr_deg_stats[n][0], nbr_deg_stats[n][1], nbr_deg_stats[n][2],
			second_hop_deg_stats[n][0], second_hop_deg_stats[n][1], second_hop_deg_stats[n][2]]
			for n in range(num_nodes)
		]


		# Gravity score: sum over all nodes v within 5 hops of deg(u) * deg(v) / distance^2
		degrees = dict(G.degree())
		gravity_scores = []
		for u in G.nodes():
			score = 0.0
			paths = nx.single_source_shortest_path_length(G, u, cutoff=5)
			for v, d in paths.items():
				if u == v or d == 0: continue
				score += (degrees[u] * degrees[v]) / (d ** 2)
			gravity_scores.append(score)
		gravity_scores = torch.tensor(gravity_scores, dtype=torch.float)


		def array_norm(array):
			# Min-max normalisation; a negative minimum is clamped to 0 and a constant
			# column maps to all zeros
			max_val = torch.max(array).item()
			min_val = torch.min(array).item()
			if min_val < 0: min_val = 0
			return (array - min_val)/(max_val - min_val) if max_val != min_val else torch.zeros_like(array)

		features = torch.tensor(features, dtype=torch.float)
		processed = [array_norm(features[:, col]) for col in range(features.size(1))]
		processed.append(array_norm(gravity_scores))


		# Per-relation statistics: number of edges and the mean degree of their endpoints
		rel_counts = ddict(int)
		rel_degree_sums = ddict(float)
		for u, v, data in G.edges(data=True):
			rel = data['relation']
			rel_counts[rel] += 1
			rel_degree_sums[rel] += degrees[u] + degrees[v]

		avg_degree_per_rel = {}
		for rel in rel_counts:
			avg_degree_per_rel[rel] = rel_degree_sums[rel] / (2 * rel_counts[rel])


		max_count = max(rel_counts.values()) if rel_counts else 1
		min_count = min(rel_counts.values()) if rel_counts else 0
		max_avg_deg = max(avg_degree_per_rel.values()) if avg_degree_per_rel else 1
		min_avg_deg = min(avg_degree_per_rel.values()) if avg_degree_per_rel else 0

		# Two-number descriptor per relation: (normalised edge count, normalised mean degree)
		rel_embeddings = {}
		for rel in rel_counts:
			norm_count = (rel_counts[rel] - min_count) / (max_count - min_count) if max_count != min_count else 0.0
			norm_avg_deg = (avg_degree_per_rel[rel] - min_avg_deg) / (max_avg_deg - min_avg_deg) if max_avg_deg != min_avg_deg else 0.0
			rel_embeddings[rel] = (norm_count, norm_avg_deg)


		# Mean over the direct neighbours of neighbour degree * relation descriptor
		relation_weighted_features = []
		for n in G.nodes():
			sum_embed = [0.0, 0.0]
			count = 0
			for neighbor in G.neighbors(n):
				edge_data = G.get_edge_data(n, neighbor)
				rel = edge_data['relation']
				embed = rel_embeddings.get(rel, (0.0, 0.0))
				deg = degrees[neighbor]
				sum_embed[0] += deg * embed[0]
				sum_embed[1] += deg * embed[1]
				count += 1
			avg_embed = [s/count if count else 0.0 for s in sum_embed]
			relation_weighted_features.append(avg_embed)

		relation_weighted_features = torch.tensor(relation_weighted_features, dtype=torch.float)
		processed_rel_feat1 = array_norm(relation_weighted_features[:, 0])
		processed_rel_feat2 = array_norm(relation_weighted_features[:, 1])
		processed.extend([processed_rel_feat1, processed_rel_feat2])


		# Degrees of the 10 one/two-hop neighbours with the highest
		# (gravity score * relation weight), padded with zeros
		neighbor_deg_features = []
		for u in G.nodes():
			direct_neighbors = set(G.neighbors(u))
			second_neighbors = second_neighbors_dict.get(u, set())
			all_neighbors = direct_neighbors.union(second_neighbors)
			neighbors = list(all_neighbors)
			
			neighbor_scores = []
			for v in neighbors:
				if v in direct_neighbors:
					# Direct neighbour: weight is the mean of the edge relation's two descriptors
					edge_data = G.get_edge_data(u, v)
					rel = edge_data['relation']
					embed = rel_embeddings.get(rel, (0.0, 0.0))
					avg_weight = (embed[0] + embed[1]) / 2
				else:  
					# Second-hop neighbour: average the weight over every recorded two-hop path
					relations_list = path_relations[u].get(v, [])
					if not relations_list:
						avg_weight = 0.0
					else:
						total_avg = 0.0
						for rel1, rel2 in relations_list:
							embed1 = rel_embeddings.get(rel1, (0.0, 0.0))
							embed2 = rel_embeddings.get(rel2, (0.0, 0.0))
							avg_embed0 = (embed1[0] + embed2[0]) / 2
							avg_embed1 = (embed1[1] + embed2[1]) / 2
							combined_avg = (avg_embed0 + avg_embed1)/2 
							total_avg += combined_avg
						avg_weight = total_avg / len(relations_list)
				score = gravity_scores[v].item() * avg_weight
				neighbor_scores.append((v, score))
			
			# Highest score first; ties keep the iteration order of `neighbors`
			sorted_neighbors = sorted(neighbor_scores, key=lambda x: -x[1])
			top_10 = [v for v, _ in sorted_neighbors[:10]]
			
			degs = [degrees[v] for v in top_10]
			degs += [0.0] * (10 - len(degs))
			neighbor_deg_features.append(degs)

		neighbor_deg_features = torch.tensor(neighbor_deg_features, dtype=torch.float)
		processed_neighbor_degs = torch.stack([array_norm(neighbor_deg_features[:, i]) for i in range(10)], dim=1)
		
		# 10 normalised scalar columns followed by the 10 neighbour-degree columns
		final_features = torch.cat([torch.stack(processed, dim=1), processed_neighbor_degs], dim=1)
		return final_features
	

	def add_model(self, model, score_func):
		"""
		Build the network selected by the `model` / `score_func` pair and move it to `self.device`.

		Parameters
		----------
		model:		Name of the encoder, e.g. 'compgcn'
		score_func:	Name of the scoring function, e.g. 'conve' or 'interacte'

		Returns
		-------
		The constructed model, already moved to `self.device`

		Raises
		------
		NotImplementedError when '<model>_<score_func>' (compared case-insensitively)
		is neither 'compgcn_conve' nor 'compgcn_interacte'
		"""
		key = '{}_{}'.format(model, score_func).lower()

		# Reject unsupported combinations first; the model class is only looked up for a supported key.
		if key not in ('compgcn_conve', 'compgcn_interacte'):
			raise NotImplementedError

		net_cls = CompGCN_ConvE if key == 'compgcn_conve' else CompGCN_InteractE
		net = net_cls(
			self.edge_index, self.edge_type,
			self.rrg_edge_index, self.rrg_edge_type,
			self.features, self.chequer_perm,
			params=self.p,
		)

		net.to(self.device)
		return net

	def add_optimizer(self, parameters):
		"""
		Create the Adam optimizer used for training.

		Parameters
		----------
		parameters:	Iterable of parameters (or parameter groups) to optimize

		Returns
		-------
		A `torch.optim.Adam` instance configured with learning rate `self.p.lr`
		and weight decay `self.p.l2`
		"""
		hyper = self.p
		optimizer = torch.optim.Adam(parameters, lr=hyper.lr, weight_decay=hyper.l2)
		return optimizer
	
	def read_batch(self, batch, split):
		"""
		Move one batch to `self.device` and split the triple tensor into its three columns.

		Parameters
		----------
		batch:	Pair `(triple, label)` of tensors; columns 0, 1 and 2 of `triple` are
			returned separately (callers name them sub, rel and obj)
		split:	Name of the data split. Kept for interface compatibility: every split
			is processed the same way, so the value is not inspected

		Returns
		-------
		Tuple `(triple[:, 0], triple[:, 1], triple[:, 2], label)`, all on `self.device`
		"""
		# The former 'train' and non-'train' branches were identical, so a single path suffices.
		triple, label = (item.to(self.device) for item in batch)
		return triple[:, 0], triple[:, 1], triple[:, 2], label
		
	def save_model(self, save_path):
		"""
		Write a training checkpoint to disk.

		Parameters
		----------
		save_path:	File path of the checkpoint; missing parent directories are created

		Returns
		-------
		None. The saved dict holds the model weights ('state_dict'), the best validation
		results ('best_val'), the epoch they were obtained at ('best_epoch'), the optimizer
		state ('optimizer') and the run arguments ('args').
		"""
		dir_path = os.path.dirname(save_path)
		# dirname is '' for a bare file name and os.makedirs('') would raise;
		# exist_ok also removes the race between an exists() check and the creation.
		if dir_path:
			os.makedirs(dir_path, exist_ok=True)

		state = {
			'state_dict'	: self.model.state_dict(),
			'best_val'	: self.best_val,
			'best_epoch'	: self.best_epoch,
			'optimizer'	: self.optimizer.state_dict(),
			'args'		: vars(self.p)
		}
		torch.save(state, save_path)

	def load_model(self, load_path):
		"""
		Restore model, optimizer and best-validation bookkeeping from a checkpoint.

		Parameters
		----------
		load_path:	Path of a checkpoint holding the keys 'state_dict', 'best_val' and
				'optimizer' (and normally 'best_epoch')

		Returns
		-------
		None. Updates `self.best_val`, `self.best_val_mrr`, `self.best_epoch`,
		`self.model` and `self.optimizer` in place.
		"""
		# map_location lets a checkpoint saved on another device be loaded onto the current one.
		state			= torch.load(load_path, map_location=self.device)
		state_dict		= state['state_dict']
		self.best_val		= state['best_val']
		self.best_val_mrr	= self.best_val['mrr']
		# Restore the epoch of the best result as well; keep the current value
		# when the checkpoint does not carry one.
		self.best_epoch		= state.get('best_epoch', getattr(self, 'best_epoch', 0))

		self.model.load_state_dict(state_dict)
		self.optimizer.load_state_dict(state['optimizer'])

	def evaluate(self, split, epoch):
		"""
		Evaluate the model on one split in both prediction directions and log the MRR.

		Parameters
		----------
		split:	Name of the split passed on to `self.predict`
		epoch:	Epoch number, used only in the log message

		Returns
		-------
		The dict produced by `get_combined_results` from the tail-batch and head-batch predictions
		"""
		tail_results = self.predict(split=split, mode='tail_batch')
		head_results = self.predict(split=split, mode='head_batch')
		combined     = get_combined_results(tail_results, head_results)

		message = '[Epoch {} {}]: MRR: Tail : {:.5}, Head : {:.5}, Avg : {:.5}'.format(
			epoch, split, combined['left_mrr'], combined['right_mrr'], combined['mrr'])
		self.logger.info(message)
		return combined

	def predict(self, split='valid', mode='tail_batch'):
		"""
		Rank the target entity of every triple of a split and accumulate ranking statistics.

		Parameters
		----------
		split:	Split to evaluate
		mode:	'tail_batch' or 'head_batch'; the part before the underscore is combined with
			`split` to select the loader `self.data_iter['<split>_<tail|head>']`

		Returns
		-------
		results: dict of running sums over all batches (not averages) with keys
			 'count', 'mr', 'mrr' and 'hits@1' ... 'hits@10'
		"""
		self.model.eval()
		loader_key = '{}_{}'.format(split, mode.split('_')[0])

		with torch.no_grad():
			results = {}

			for step, batch in enumerate(self.data_iter[loader_key]):
				sub, rel, obj, label = self.read_batch(batch, split)
				scores = self.model.forward(sub, rel)
				rows   = torch.arange(scores.size()[0], device=self.device)

				# Push every entity flagged in `label` to a very low score, then put back the
				# score of the target entity so that only the other flagged entities are filtered.
				target_scores     = scores[rows, obj]
				scores            = torch.where(label.bool(), -torch.ones_like(scores) * 10000000, scores)
				scores[rows, obj] = target_scores

				# argsort of the descending argsort yields each entity's 0-based position in the ranking.
				order     = torch.argsort(scores, dim=1, descending=True)
				positions = torch.argsort(order, dim=1, descending=False)
				ranks     = (1 + positions[rows, obj]).float()

				results['count'] = torch.numel(ranks) + results.get('count', 0.0)
				results['mr']    = torch.sum(ranks).item() + results.get('mr', 0.0)
				results['mrr']   = torch.sum(1.0/ranks).item() + results.get('mrr', 0.0)
				for cutoff in range(1, 11):
					hits_key = 'hits@{}'.format(cutoff)
					results[hits_key] = torch.numel(ranks[ranks <= cutoff]) + results.get(hits_key, 0.0)

				if step % 100 == 0:
					self.logger.info('[{}, {} Step {}]\t{}'.format(split.title(), mode.title(), step, self.p.name))

		return results

	def run_epoch(self, epoch, val_mrr = 0):
		"""
		Run one training pass over `self.data_iter['train']`.

		Parameters
		----------
		epoch:		Epoch number, used only in log messages
		val_mrr:	Not used inside this method; the logged 'Val MRR' is `self.best_val_mrr`

		Returns
		-------
		Mean of the per-batch training losses of this epoch
		"""
		self.model.train()
		batch_losses = []

		for step, batch in enumerate(self.data_iter['train']):
			self.optimizer.zero_grad()
			sub, rel, obj, label = self.read_batch(batch, 'train')

			scores     = self.model.forward(sub, rel)
			batch_loss = self.model.loss(scores, label)
			batch_loss.backward()

			# Clip gradients to a maximum total norm of 1.0 before the update
			torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
			self.optimizer.step()
			batch_losses.append(batch_loss.item())

			if step % 100 == 0:
				running_mean = np.mean(batch_losses)
				self.logger.info('[E:{}| {}]: Train Loss:{:.5},  Val MRR:{:.5}\t{}'.format(epoch, step, running_mean, self.best_val_mrr, self.p.name))

		epoch_loss = np.mean(batch_losses)
		self.logger.info('[Epoch:{}]:  Training Loss:{:.4}\n'.format(epoch, epoch_loss))
		return epoch_loss
	
	def fit(self, patience=40):
		"""
		Train with per-epoch validation, checkpointing and early stopping, then evaluate on the test split.

		Each epoch appends one row (epoch, training loss, validation MR / MRR / Hits@1,3,10) to
		'8_2_training_metrics.csv' inside `self.p.log_dir`. Whenever the validation MRR improves,
		the model is saved to './checkpoints/<self.p.name>'.

		Parameters
		----------
		patience:	Number of consecutive non-improving epochs that is tolerated; training stops
				as soon as the count exceeds it. Defaults to 40, the previously hard-coded value.

		Returns
		-------
		The results dict returned by `self.evaluate('test', ...)`
		"""
		self.best_val_mrr, self.best_val, self.best_epoch = 0., {}, 0
		save_path = os.path.join('./checkpoints', self.p.name)

		os.makedirs(self.p.log_dir, exist_ok=True)
		metrics_file = os.path.join(self.p.log_dir, '8_2_training_metrics.csv')

		# Write the header only when the file is new, so later runs keep appending rows to it.
		if not os.path.exists(metrics_file):
			with open(metrics_file, 'a') as f:
				f.write('epoch,train_loss,val_mr,val_mrr,val_hits@1,val_hits@3,val_hits@10\n')

		if self.p.restore:
			self.load_model(save_path)
			self.logger.info('Successfully Loaded previous model')

		kill_cnt = 0
		epoch = 0	# keeps `epoch` defined for the final evaluation when max_epochs is 0
		for epoch in range(self.p.max_epochs):
			train_loss  = self.run_epoch(epoch)
			val_results = self.evaluate('valid', epoch)

			with open(metrics_file, 'a') as f:
				line = f"{epoch},{train_loss:.5f},{val_results['mr']:.5f}," \
					f"{val_results['mrr']:.5f},{val_results['hits@1']:.5f}," \
					f"{val_results['hits@3']:.5f},{val_results['hits@10']:.5f}\n"
				f.write(line)

			if val_results['mrr'] > self.best_val_mrr:
				self.best_val     = val_results
				self.best_val_mrr = val_results['mrr']
				self.best_epoch   = epoch
				self.save_model(save_path)
				kill_cnt = 0
			else:
				kill_cnt += 1
				if kill_cnt > patience:
					self.logger.info("Early Stopping!!")
					break

			# Note: the value logged as 'Valid MRR' is the best validation MRR so far.
			self.logger.info('[Epoch {}]: Training Loss: {:.5}, Valid MRR: {:.5}\n\n'.format(epoch, train_loss, self.best_val_mrr))

		self.logger.info('Loading best model, Evaluating on Test data')
		if os.path.exists(save_path):
			self.load_model(save_path)
		else:
			# No epoch ever improved the validation MRR (or none was run), so nothing was saved.
			self.logger.warning('No checkpoint found at {}; evaluating the current model weights'.format(save_path))
		test_results = self.evaluate('test', epoch)
		return test_results

In [12]:

# Command-line style configuration for the run. Every option has a default, so the
# cell can be executed as-is inside the notebook.
parser = argparse.ArgumentParser()

# Data, run bookkeeping and hardware
parser.add_argument('-data', dest='dataset', default='UMLS', help='Dataset to use, default: UMLS')
parser.add_argument('-k_w', dest='k_w', default=10, type=int, help='Decoder: k_w')
parser.add_argument('-k_h', dest='k_h', default=20, type=int, help='Decoder: k_h')
parser.add_argument('-embed_dim', dest='embed_dim', default=200, type=int, help='Embedding dimension to give as input to score function')
parser.add_argument('-batch', dest='batch_size', default=256, type=int, help='Batch size')
parser.add_argument('-num_workers', type=int, default=0, help='Number of processes to construct batches')
parser.add_argument('-name', default='testrun', help='Set run name for saving/restoring models')
parser.add_argument('-logdir', dest='log_dir', default='./', help='Log directory')
parser.add_argument('-config', dest='config_dir', default='./config/', help='Config directory')
parser.add_argument('-gpu', type=str, default='0', help='Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0')

# Model selection and architecture
parser.add_argument('-model', dest='model', default='compgcn', help='Model Name')
parser.add_argument('-score_func', dest='score_func', default='interacte', help='Score Function for Link prediction')
parser.add_argument('-gcn_layer', dest='gcn_layer', default=1, type=int, help='Number of GCN Layers to use')
parser.add_argument('-init_dim', dest='init_dim', default=200, type=int, help='Initial dimension size for entities and relations')
parser.add_argument('-gcn_drop', dest='dropout', default=0.1, type=float, help='Dropout to use in GCN Layer')
parser.add_argument('-bias', dest='bias', action='store_true', help='Whether to use bias in the model')
parser.add_argument('-num_filt', dest='num_filt', default=128, type=int, help='Decoder: Number of filters in convolution')
parser.add_argument('-hid_drop', dest='hid_drop', default=0.1, type=float, help='Dropout after GCN')
parser.add_argument('-hid_drop2', dest='hid_drop2', default=0.3, type=float, help='Decoder: Hidden dropout')
parser.add_argument('-feat_drop', dest='feat_drop', default=0.3, type=float, help='Decoder: Feature Dropout')
parser.add_argument('-ker_sz', dest='ker_sz', default=7, type=int, help='Decoder: Kernel size to use')

# Optimization
parser.add_argument('-lr', type=float, default=0.001, help='Starting Learning Rate')
parser.add_argument('-l2', type=float, default=0.0, help='L2 Regularization for Optimizer')
parser.add_argument('-restore', dest='restore', action='store_true', help='Restore from the previously saved model')
parser.add_argument('-epoch', dest='max_epochs', type=int, default=500, help='Number of epochs')
parser.add_argument('-lbl_smooth', dest='lbl_smooth', type=float, default=0.1, help='Label Smoothing')
parser.add_argument('-gcn_dim', dest='gcn_dim', default=200, type=int, help='Number of hidden units in GCN')

parser.add_argument('-gamma', type=float, default=9.0, help='Margin')
parser.add_argument('-perm', dest='perm', default=2, type=int, help='num_perm')

# parse_known_args ignores arguments it does not recognise instead of failing on them.
args, unknown = parser.parse_known_args()

# Fix the NumPy and PyTorch seeds for reproducibility.
np.random.seed(42)
torch.manual_seed(42)
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
model = Runner(args)

model.fit()

2025-06-16 00:46:25,867 - [INFO] - {'dataset': 'UMLS', 'k_w': 10, 'k_h': 20, 'embed_dim': 200, 'batch_size': 256, 'num_workers': 0, 'name': 'testrun', 'log_dir': './', 'config_dir': './config/', 'gpu': '0', 'model': 'compgcn', 'score_func': 'interacte', 'gcn_layer': 1, 'init_dim': 200, 'dropout': 0.1, 'bias': False, 'num_filt': 128, 'hid_drop': 0.1, 'hid_drop2': 0.3, 'feat_drop': 0.3, 'ker_sz': 7, 'lr': 0.001, 'l2': 0.0, 'restore': False, 'max_epochs': 500, 'lbl_smooth': 0.1, 'gcn_dim': 200, 'gamma': 9.0, 'perm': 2}
{'batch_size': 256,
 'bias': False,
 'config_dir': './config/',
 'dataset': 'UMLS',
 'dropout': 0.1,
 'embed_dim': 200,
 'feat_drop': 0.3,
 'gamma': 9.0,
 'gcn_dim': 200,
 'gcn_layer': 1,
 'gpu': '0',
 'hid_drop': 0.1,
 'hid_drop2': 0.3,
 'init_dim': 200,
 'k_h': 20,
 'k_w': 10,
 'ker_sz': 7,
 'l2': 0.0,
 'lbl_smooth': 0.1,
 'log_dir': './',
 'lr': 0.001,
 'max_epochs': 500,
 'model': 'compgcn',
 'name': 'testrun',
 'num_filt': 128,
 'num_workers': 0,
 'perm': 2,
 'restore'

KeyboardInterrupt: 