In [1]:
import torch
from cora_loader import CitationNetwork,CocitationNetwork,ConfigurationModelCitationNetwork
from multi_layered_model import MonoModel,BiModel,TriModel
from torch_geometric.nn import GATConv
import time
import torch.nn.functional as F
import copy

In [2]:
cora = CitationNetwork('/tmp/cora','cora',directed=False)
citeseer = CitationNetwork('/tmp/citeseer','citeseer',directed=False)
PubMed = CitationNetwork('/tmp/PubMed','PubMed',directed=False)

In [4]:
class MonoGAT(torch.nn.Module):
    def __init__(self,dataset,channels,heads=1):
        super(MonoGAT,self).__init__()
        channels = [dataset.num_node_features] + channels + [dataset.num_classes]
        self.conv = []
        for i in range(1,len(channels)):
            if i == 1:
                conv = GATConv(channels[i-1],channels[i],heads=heads)
            elif i == len(channels)-1:
                if dataset.name=='PubMed':
                    conv = GATConv(channels[i-1]*heads,channels[i],heads=heads,concat=False)
                else:
                    conv = GATConv(channels[i-1]*heads,channels[i])
            else:
                conv = GATConv(channels[i-1]*heads,channels[i],heads=heads)
            self.add_module(str(i),conv)
            self.conv.append(conv)
        
        
    def forward(self, data): 
        x, edge_index = data.x, data.edge_index 
        
        for conv in self.conv[:-1]:
            x = conv(x,edge_index)
            x = F.elu(x)
            x = F.dropout(x,p=0.6,training=self.training) # YOU MUST UNDERSTAND DROPOUT
        
        # Last layer
        x = self.conv[-1](x,edge_index)
        x = F.softmax(x,dim=1)
        
        return x
    

class BiGAT(torch.nn.Module):
    def __init__(self,dataset,channels,heads=1):
        super(BiGAT,self).__init__()
        self.conv_st = []
        self.conv_ts = []
        channels_output = [dataset.num_node_features] + [c*2*heads for c in channels]
        channels = [dataset.num_node_features] + channels
        for i in range(len(channels)-1):
            conv_st = GATConv(channels_output[i], channels[i+1],heads=heads)
            self.add_module('conv_st'+str(i),conv_st)
            self.conv_st.append(conv_st)
            
            conv_ts = GATConv(channels_output[i], channels[i+1],heads=heads)
            self.add_module('conv_ts'+str(i),conv_ts)
            self.conv_ts.append(conv_ts)
        
        if dataset.name=='PubMed':
            self.last = GATConv(channels_output[-1], dataset.num_classes,heads=heads,concat=False)
        else:
            self.last = GATConv(channels_output[-1], dataset.num_classes)
        
    def forward(self, data): 
        x, edge_index = data.x, data.edge_index
        st_edges = data.edge_index.t()[1-data.is_reversed].t()
        ts_edges = data.edge_index.t()[data.is_reversed].t()
#         print(ts_edges.shape)
        for i in range(len(self.conv_st)):
            x1 = F.elu(self.conv_st[i](x,st_edges))
            x2 = F.elu(self.conv_ts[i](x,ts_edges))
            x = torch.cat((x1,x2),dim=1)
            x = F.dropout(x,p=0.6,training=self.training)
        
        # last layer
        x = self.last(x,edge_index)
        x = F.softmax(x,dim=1) 
        
        return x
    
class TriGAT(torch.nn.Module):
    def __init__(self,dataset,channels,heads=1):
        super(TriGAT,self).__init__()
        self.conv_st = []
        self.conv_ts = []
        self.conv = []
        channels_output = [dataset.num_node_features] + [c*3*heads for c in channels]
        channels = [dataset.num_node_features] + channels
        for i in range(len(channels)-1):
            conv_st = GATConv(channels_output[i], channels[i+1],heads=heads)
            self.add_module('conv_st'+str(i),conv_st)
            self.conv_st.append(conv_st)
            
            conv_ts = GATConv(channels_output[i], channels[i+1],heads=heads)
            self.add_module('conv_ts'+str(i),conv_ts)
            self.conv_ts.append(conv_ts)
            
            conv = GATConv(channels_output[i],channels[i+1],heads=heads)
            self.add_module('conv'+str(i),conv)
            self.conv.append(conv)
        
        if dataset.name=='PubMed':
            self.last = GATConv(channels_output[-1], dataset.num_classes,heads=heads,concat=False)
        else:
            self.last = GATConv(channels_output[-1], dataset.num_classes)
        
    def forward(self, data): 
        x, edge_index = data.x, data.edge_index
        st_edges = data.edge_index.t()[1-data.is_reversed].t()
        ts_edges = data.edge_index.t()[data.is_reversed].t()
#         print(ts_edges.shape)
        for i in range(len(self.conv_st)):
            x1 = F.elu(self.conv_st[i](x,st_edges))
            x2 = F.elu(self.conv_ts[i](x,ts_edges))
            x3 = F.elu(self.conv[i](x,edge_index))
            x = torch.cat((x1,x2,x3),dim=1)
            x = F.dropout(x,training=self.training)
        
        # last layer
        x = self.last(x,edge_index)
        x = F.softmax(x,dim=1) 
        
        return x

In [19]:
def run_and_eval_model(dataset,channels,architecture,lr,wd,heads=1,epochs=200):
    # training process (without batches/transforms)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = architecture(dataset,channels,heads).to(device)
#     print(model)
    data = dataset[0].to(device)
    optimizer = torch.optim.Adam(model.parameters(),lr=lr,weight_decay=wd)
    model.train() # to enter training phase
    maxacc = 0
    chosen = None
    for epoch in range(epochs):
        optimizer.zero_grad() # saw this a lot in the beginning, maybe resetting gradients (not to accumulate)
        out = model(data) # this calls the forward method apparently
        loss = F.nll_loss(out[data.train_mask],data.y[data.train_mask]) # nice indexing, easy and short
        loss.backward() # magic: real back propagation step, takes care of the gradients and stuff
        optimizer.step() # maybe updates the params to be optimized
        
        model.eval()
        _,pred = model(data).max(dim=1) # take prediction out of softmax
        correct = float(pred[data.val_mask].eq(data.y[data.val_mask]).sum().item())
        acc = correct / data.val_mask.sum().item()
        if acc > maxacc:
            maxacc = acc
            chosen=copy.deepcopy(model)
        model.train()
    chosen.eval() # enter eval phase
    _,pred = chosen(data).max(dim=1) # take prediction out of softmax
    correct = float(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
    acc = correct / data.test_mask.sum().item()
    return acc

def eval_multiple(dataset,channels,architecture=MonoGAT,heads=1,runs=100,epochs=200):
    best_acc = 0
    start = time.time()
    for lr in [1,1e-1,1e-2]:
        for wd in [1e-3,1e-4,1e-5]:
            accs = []
            for i in range(runs):
                accs.append(run_and_eval_model(dataset,channels,architecture,lr,wd,heads=heads,epochs=epochs))
            acc = sum(accs)/len(accs)
            print("lr={:.4f}\twd={:.5f}\t{:.2f}".format(lr,wd,acc*100))
            if acc > best_acc:
                best_lr = lr
                best_wd = wd
                best_acc =acc
            elapsed_time = time.time() - start
    print('Elaplsed {}'.format(time.strftime("%H:%M:%S", time.gmtime(elapsed_time))))
    return best_lr,best_wd

In [16]:
def param_search(data,channels1,channels2,channels3,outFile,heads=1,runs = 10,epochs=200):
    with open(outFile,'w') as out:
        lr,wd = eval_multiple(data,channels1,architecture=MonoGAT,heads=heads,runs=runs,epochs=epochs)
        out.write('Mono lr={:.4f}\twd={:.5f}\n'.format(lr,wd))
        lr,wd = eval_multiple(data,channels2,architecture=BiGAT,heads=heads,runs=runs,epochs=epochs)
        out.write('Bi   lr={:.4f}\twd={:.5f}\n'.format(lr,wd))
        lr,wd = eval_multiple(data,channels3,architecture=TriGAT,heads=heads,runs=runs,epochs=epochs)
        out.write('Tri  lr={:.4f}\twd={:.5f}\n'.format(lr,wd))

In [30]:
param_search(cora,[16],[8],[5],'GAT_param_search_8.out',heads=8,runs=100,epochs=200)

lr=0.0100	wd=0.01000	81.41
Elaplsed 00:08:22
lr=0.0100	wd=0.01000	80.92
Elaplsed 00:06:34
lr=0.0100	wd=0.01000	81.41
Elaplsed 00:06:28


In [6]:
param_search(cora,[32],[16],[10],'GAT_param_search_4.out',heads=4,runs=100,epochs=200)

lr=0.0100	wd=0.01000	81.29
Elaplsed 00:07:42
lr=0.0100	wd=0.01000	81.16
Elaplsed 00:18:43
lr=0.0100	wd=0.01000	81.41
Elaplsed 00:15:00


In [7]:
param_search(cora,[64],[32],[20],'GAT_param_search_2.out',heads=2,runs=100,epochs=200)

lr=0.0100	wd=0.01000	81.39
Elaplsed 00:04:50
lr=0.0100	wd=0.01000	81.13
Elaplsed 00:05:59
lr=0.0100	wd=0.01000	81.52
Elaplsed 00:05:39


In [8]:
param_search(cora,[128],[64],[42],'GAT_param_search_1.out',heads=1,runs=100,epochs=200)

lr=0.0100	wd=0.01000	80.54
Elaplsed 00:02:57
lr=0.0100	wd=0.01000	80.72
Elaplsed 00:12:50
lr=0.0100	wd=0.01000	81.52
Elaplsed 00:15:19


In [26]:
param_search(cora,[16],[8],[5],'GAT_param_search_1.out',heads=1,runs=100,epochs=200)

lr=0.0100	wd=0.01000	77.61
Elaplsed 00:03:06
lr=0.0100	wd=0.01000	79.49
Elaplsed 00:04:42
lr=0.0100	wd=0.01000	79.83
Elaplsed 00:08:49


In [13]:
param_search(PubMed,[64],[32],[10],'GAT_param_search_2_pubmed.out',heads=2,runs=1,epochs=200)

lr=0.1000	wd=0.10000	47.70
lr=0.1000	wd=0.01000	55.90
lr=0.1000	wd=0.00100	75.00
lr=0.1000	wd=0.00010	75.40
lr=0.1000	wd=0.00001	74.70
lr=0.0100	wd=0.10000	51.40
lr=0.0100	wd=0.01000	59.30
lr=0.0100	wd=0.00100	76.10
lr=0.0100	wd=0.00010	76.10
lr=0.0100	wd=0.00001	76.00
lr=0.0010	wd=0.10000	71.90
lr=0.0010	wd=0.01000	51.30
lr=0.0010	wd=0.00100	74.70
lr=0.0010	wd=0.00010	75.00
lr=0.0010	wd=0.00001	73.70
lr=0.0001	wd=0.10000	38.10
lr=0.0001	wd=0.01000	48.90
lr=0.0001	wd=0.00100	54.30
lr=0.0001	wd=0.00010	67.30
lr=0.0001	wd=0.00001	67.40
lr=0.0000	wd=0.10000	25.80
lr=0.0000	wd=0.01000	33.10
lr=0.0000	wd=0.00100	37.00
lr=0.0000	wd=0.00010	44.50
lr=0.0000	wd=0.00001	49.30
Elaplsed 00:01:59
lr=0.1000	wd=0.10000	43.20
lr=0.1000	wd=0.01000	52.20
lr=0.1000	wd=0.00100	72.20
lr=0.1000	wd=0.00010	74.40
lr=0.1000	wd=0.00001	73.40
lr=0.0100	wd=0.10000	48.20
lr=0.0100	wd=0.01000	57.30
lr=0.0100	wd=0.00100	74.40
lr=0.0100	wd=0.00010	75.10
lr=0.0100	wd=0.00001	74.20
lr=0.0010	wd=0.10000	70.30
lr=0.0010	

KeyboardInterrupt: 

In [14]:
param_search(PubMed,[64],[32],[10],'GAT_param_search_2_pubmed.out',heads=2,runs=10,epochs=200)

lr=0.1000	wd=0.10000	46.31
lr=0.1000	wd=0.01000	55.26
lr=0.1000	wd=0.00100	75.81
lr=0.1000	wd=0.00010	75.66
lr=0.1000	wd=0.00001	74.56
lr=0.0100	wd=0.10000	49.75
lr=0.0100	wd=0.01000	59.01
lr=0.0100	wd=0.00100	76.84
lr=0.0100	wd=0.00010	76.74
lr=0.0100	wd=0.00001	75.82
lr=0.0010	wd=0.10000	61.71
lr=0.0010	wd=0.01000	45.97
lr=0.0010	wd=0.00100	74.03
lr=0.0010	wd=0.00010	74.45
lr=0.0010	wd=0.00001	74.60
lr=0.0001	wd=0.10000	36.63
lr=0.0001	wd=0.01000	43.49
lr=0.0001	wd=0.00100	55.27
lr=0.0001	wd=0.00010	62.88
lr=0.0001	wd=0.00001	60.94
lr=0.0000	wd=0.10000	31.92
lr=0.0000	wd=0.01000	36.63
lr=0.0000	wd=0.00100	40.97
lr=0.0000	wd=0.00010	44.81
lr=0.0000	wd=0.00001	49.92
Elaplsed 00:21:13
lr=0.1000	wd=0.10000	46.43
lr=0.1000	wd=0.01000	51.83
lr=0.1000	wd=0.00100	73.98
lr=0.1000	wd=0.00010	74.99
lr=0.1000	wd=0.00001	75.02
lr=0.0100	wd=0.10000	50.13
lr=0.0100	wd=0.01000	58.83
lr=0.0100	wd=0.00100	74.90
lr=0.0100	wd=0.00010	74.79
lr=0.0100	wd=0.00001	74.58
lr=0.0010	wd=0.10000	63.66
lr=0.0010	

In [15]:
param_search(PubMed,[8],[4],[2],'GAT_param_search_8_pubmed.out',heads=8,runs=1,epochs=200)

lr=0.1000	wd=0.10000	44.10
lr=0.1000	wd=0.01000	47.60
lr=0.1000	wd=0.00100	77.60
lr=0.1000	wd=0.00010	78.40
lr=0.1000	wd=0.00001	79.30
lr=0.0100	wd=0.10000	50.00
lr=0.0100	wd=0.01000	59.10
lr=0.0100	wd=0.00100	77.00
lr=0.0100	wd=0.00010	76.60
lr=0.0100	wd=0.00001	76.30
lr=0.0010	wd=0.10000	67.20
lr=0.0010	wd=0.01000	46.20
lr=0.0010	wd=0.00100	51.90
lr=0.0010	wd=0.00010	72.40
lr=0.0010	wd=0.00001	72.10
lr=0.0001	wd=0.10000	35.90
lr=0.0001	wd=0.01000	48.00
lr=0.0001	wd=0.00100	47.50
lr=0.0001	wd=0.00010	48.20
lr=0.0001	wd=0.00001	51.50
lr=0.0000	wd=0.10000	29.50
lr=0.0000	wd=0.01000	44.20
lr=0.0000	wd=0.00100	39.00
lr=0.0000	wd=0.00010	43.30
lr=0.0000	wd=0.00001	42.90
Elaplsed 00:01:58
lr=0.1000	wd=0.10000	41.90
lr=0.1000	wd=0.01000	51.00
lr=0.1000	wd=0.00100	75.50
lr=0.1000	wd=0.00010	76.70
lr=0.1000	wd=0.00001	75.70
lr=0.0100	wd=0.10000	47.00
lr=0.0100	wd=0.01000	51.60
lr=0.0100	wd=0.00100	75.20
lr=0.0100	wd=0.00010	74.20
lr=0.0100	wd=0.00001	74.90
lr=0.0010	wd=0.10000	51.50
lr=0.0010	

In [18]:
param_search(PubMed,[8],[4],[2],'GAT_param_search_8_pubmed.out',heads=8,runs=10,epochs=200)

lr=0.1000	wd=0.10000	43.00
lr=0.1000	wd=0.01000	53.96
lr=0.1000	wd=0.00100	77.00
lr=0.1000	wd=0.00010	78.15
lr=0.1000	wd=0.00001	76.59
lr=0.0100	wd=0.10000	48.02
lr=0.0100	wd=0.01000	52.84
lr=0.0100	wd=0.00100	77.11
lr=0.0100	wd=0.00010	77.04
lr=0.0100	wd=0.00001	75.97
lr=0.0010	wd=0.10000	61.59


KeyboardInterrupt: 

In [20]:
param_search(PubMed,[8],[4],[2],'GAT_param_search_8_pubmed.out',heads=8,runs=10,epochs=200)

lr=1.0000	wd=0.00100	54.95
lr=1.0000	wd=0.00010	47.85
lr=1.0000	wd=0.00001	54.04
lr=0.1000	wd=0.00100	77.30
lr=0.1000	wd=0.00010	77.25
lr=0.1000	wd=0.00001	76.83
lr=0.0100	wd=0.00100	77.12
lr=0.0100	wd=0.00010	77.20
lr=0.0100	wd=0.00001	76.49
Elaplsed 00:08:30
lr=1.0000	wd=0.00100	52.21
lr=1.0000	wd=0.00010	50.15
lr=1.0000	wd=0.00001	54.00
lr=0.1000	wd=0.00100	74.82
lr=0.1000	wd=0.00010	76.10
lr=0.1000	wd=0.00001	75.83
lr=0.0100	wd=0.00100	74.51
lr=0.0100	wd=0.00010	74.35
lr=0.0100	wd=0.00001	75.05
Elaplsed 00:06:53
lr=1.0000	wd=0.00100	56.03
lr=1.0000	wd=0.00010	53.81
lr=1.0000	wd=0.00001	52.92
lr=0.1000	wd=0.00100	75.81
lr=0.1000	wd=0.00010	76.33
lr=0.1000	wd=0.00001	76.19
lr=0.0100	wd=0.00100	75.56
lr=0.0100	wd=0.00010	75.66
lr=0.0100	wd=0.00001	75.72
Elaplsed 00:08:22
