* This notebook experiments with the Open Graph Benchmark (OGB) node property prediction tasks.
* It contains the following components:
    * analysis - does simple exploratory data analysis (EDA).
    * create_dataloder - creates the train, valid and test dataloaders.
    * Gat - a stack of GATConv layers.
    * one_epoch - trains the model for one epoch.

In [168]:
import torch
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from torch_geometric.data import DataLoader
from torch_geometric.data.sampler import NeighborSampler
import seaborn as sns
import matplotlib.pyplot as plt

# Dataloader/analysis

In [48]:
# Load the ogbn-arxiv node property prediction dataset, using ../node_dataset/ as its root folder.
dataset=PygNodePropPredDataset(name='ogbn-arxiv',root='../node_dataset/')

In [49]:
# Node-id split of the dataset; the cells below index it with 'train', 'valid' and 'test'.
split_idx = dataset.get_idx_split()
# Entry 0 of the dataset: the graph object whose x, y and edge_index are used below.
data = dataset[0]

In [171]:
# Show the installed seaborn version (the plots in analysis() rely on it).
sns.__version__

'0.10.1'

In [90]:
def analysis(data,split_idx:dict,name:str):
    '''
    Plot the label counts of every split for the supported ogbn datasets.
    --args--
    data: graph object whose y holds the node labels, e.g. PygNodePropPredDataset(name)[0]
    split_idx: dictionary with 'train', 'valid' and 'test' node ids.
    name: dataset name; only 'ogbn-arxiv' and 'ogbn-products' produce plots.
    --returns--
    data, unchanged.
    '''
    supported=('ogbn-arxiv','ogbn-products')
    if name not in supported:
        # nothing to plot for any other dataset
        return data
    print(f'output analysis {name}')
    for split in ('train','valid','test'):
        # one wide, short figure per split
        plt.figure(figsize=(10,3))
        labels=data.y[split_idx[split]].squeeze(1).numpy()
        sns.countplot(x=labels)
        plt.title(f'{split} y')
        plt.show()
    return data

def preprocess(data,split_idx:dict,name:str):
    '''
    Placeholder for dataset-specific preprocessing (normalising node features,
    adding reverse edges). At the moment it only prints the steps planned for
    the given dataset and returns data untouched.
    --args--
    data=PygNodePropPredDataset(name)[0]
    split_idx: split dictionary; not used yet.
    name: name of ogbn dataset.
    --returns--
    data, unchanged.
    '''
    # steps announced per dataset, in the order they are printed
    planned_steps={
        'ogbn-arxiv':('adding edges in other direction','normalizing'),
        'ogbn-products':('normalizing',),
    }
    print(f'preprocessing {name}')
    for step in planned_steps.get(name,()):
        print(step)
    return data

def create_dataloder(data,split_idx:dict,sizes=[-1,-1,-1],batch_size=2048)->dict:
    '''
    Build one NeighborSampler per split and return them in a dict.
    --args--
    data: graph object with an edge_index attribute, e.g. PygNodePropPredDataset(name)[0]
    split_idx: dictionary with 'train', 'valid' and 'test' node ids (output of get_idx_split).
    sizes: per-layer sampling sizes forwarded to NeighborSampler; every sampler gets its own copy.
    batch_size: batch size of the train sampler; the valid and test samplers use twice this value.
    --returns--
    dict mapping 'train', 'valid' and 'test' to their NeighborSampler. Only the train sampler shuffles.
    '''
    loader_dict={}
    for split in ('train','valid','test'):
        is_train=split=='train'
        loader_dict[split]=NeighborSampler(
            data.edge_index,
            node_idx=split_idx[split],
            # copy: the default list is shared between calls and must not be
            # aliased by (or mutated through) the three samplers
            sizes=list(sizes),
            batch_size=batch_size if is_train else 2*batch_size,
            shuffle=is_train,
        )
    return loader_dict

In [91]:
# Train batch size; create_dataloder uses twice this value for the valid and test loaders.
batch_size=2048
# Dict of neighbour samplers keyed by 'train', 'valid' and 'test'.
loader=create_dataloder(data,split_idx,batch_size=batch_size)

# GAT model

In [121]:
import torch_geometric
from torch_geometric.nn import GATConv,BatchNorm
from torch import nn,optim
class Gat(nn.Module):
    '''
    Stack of GATConv layers; each is followed by LeakyReLU, an optional BatchNorm and Dropout.
    --args--
    inp_dim: number of input features per node.
    filters: output channels of each GATConv layer, in order.
    drop: dropout probability applied to the node features after every layer.
    edge_drop: value passed as the `dropout` argument of every GATConv.
    bn: when True, a BatchNorm follows the activation of every layer.
    '''
    def __init__(self,inp_dim=3,filters=[16,16,16],drop=0.1,edge_drop=0.1,bn=True):
        # nn.Module must be initialised before sub-modules are assigned,
        # otherwise they cannot be registered.
        super().__init__()
        self.gat_modules=nn.ModuleList()
        self.bn=bn
        self.bn_modules=nn.ModuleList()
        in_channels=inp_dim
        for out_channels in filters:
            self.gat_modules.append(GATConv(in_channels=in_channels,out_channels=out_channels,dropout=edge_drop))
            if bn:
                self.bn_modules.append(BatchNorm(in_channels=out_channels))
            # the next layer consumes what this one produced
            in_channels=out_channels
        self.leaky=nn.LeakyReLU()
        # feature dropout uses `drop`; `edge_drop` is reserved for the GATConv layers
        self.drop=nn.Dropout(p=drop)

    def forward(self,x,adjs):
        '''
        --args--
        x: node feature tensor.
        adjs: one object per layer exposing edge_index and size
              (presumably the adjs yielded by NeighborSampler -- confirm).
        --returns--
        features after the last layer; x itself when adjs is empty.
        '''
        for i,adj in enumerate(adjs):
            x=self.gat_modules[i](x,adj.edge_index.to(device=x.device))
            # keep only the first adj.size[1] rows
            # (presumably the target nodes of this hop -- confirm)
            x=x[:adj.size[1]]
            x=self.leaky(x)
            if self.bn:
                x=self.bn_modules[i](x)
            x=self.drop(x)
        return x
class Fcn(nn.Module):
    '''
    Fully connected head. Every layer except the last is
    Linear -> LeakyReLU -> BatchNorm1d (only when bn is True) -> Dropout;
    last layer will not have activation.
    --args--
    inp_dim: size of the input features.
    layers: output size of every Linear layer; the last entry is the output size.
    bn: when True, add BatchNorm1d after the activation of every hidden layer.
    drop: dropout probability used after every hidden layer.
    '''
    def __init__(self,inp_dim=16,layers=[8,40],bn=True,drop=0.1):
        super().__init__()
        # collect in a local list; only the final Sequential becomes a module attribute
        blocks=[]
        in_features=inp_dim
        last=len(layers)-1
        for i,out_features in enumerate(layers):
            blocks.append(nn.Linear(in_features,out_features))
            if i!=last:
                blocks.append(nn.LeakyReLU())
                if bn:
                    # honour the bn flag instead of always normalising
                    blocks.append(nn.BatchNorm1d(out_features))
                blocks.append(nn.Dropout(p=drop))
            in_features=out_features
        self.lyrs=nn.Sequential(*blocks)
    def forward(self,x):
        return self.lyrs(x)

class arxiv_classifier(nn.Module):
    '''
    Node classifier: a Gat encoder whose output is fed to an Fcn head.
    --args--
    gnn_inp_dim, gnn_filters, gnn_drop, gnn_edge_drop, gnn_bn:
        forwarded to Gat as inp_dim, filters, drop, edge_drop and bn.
    fc_inp_dim, fc_layers, fc_bn, fc_drop:
        forwarded to Fcn as inp_dim, layers, bn and drop.
    NOTE(review): fc_inp_dim presumably has to equal gnn_filters[-1] -- confirm.
    '''
    def __init__(self,gnn_inp_dim=3,gnn_filters=[16,16,16],gnn_drop=0.1,gnn_edge_drop=0.1,gnn_bn=True
                ,fc_inp_dim=16,fc_layers=[8,40],fc_bn=True,fc_drop=0.1):
        super().__init__()
        encoder_kwargs=dict(inp_dim=gnn_inp_dim,filters=gnn_filters,drop=gnn_drop,
                            edge_drop=gnn_edge_drop,bn=gnn_bn)
        head_kwargs=dict(inp_dim=fc_inp_dim,layers=fc_layers,bn=fc_bn,drop=fc_drop)
        self.gnn=Gat(**encoder_kwargs)
        self.fcn=Fcn(**head_kwargs)
    def forward(self,x,adjs):
        # encode the sampled neighbourhood, then classify the encoded nodes
        return self.fcn(self.gnn(x,adjs))

# Training Loop

In [166]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# 128 input features per node (presumably the ogbn-arxiv feature size -- TODO confirm);
# every other argument keeps its default.
model=arxiv_classifier(gnn_inp_dim=128)
model=model.to(device=device)
opt=optim.Adam(model.parameters(),lr=0.05)
# NOTE(review): scheduler.step() is not called anywhere in the visible code, so this
# cyclic schedule is never advanced -- confirm whether that is intended.
scheduler=optim.lr_scheduler.CyclicLR(optimizer=opt,base_lr=0.01,max_lr=0.1,step_size_up=100,cycle_momentum=False)
# Loss handed to one_epoch as loss_func.
closs=nn.CrossEntropyLoss()

In [167]:
def accuracy(pred:torch.Tensor,truth:torch.Tensor)->float:
    '''
    Fraction of entries of pred that equal the matching entry of truth.
    Returns 0.0 when truth is empty.
    '''
    total=len(truth)
    if total==0:
        return 0.0
    # tensor-level reduction instead of a Python loop over single elements
    return (pred==truth).sum().item()/total

def _model_device(model):
    '''Device of the model's first parameter; CPU when it exposes no parameters.'''
    parameters=getattr(model,'parameters',None)
    first=next(iter(parameters()),None) if callable(parameters) else None
    return first.device if first is not None else torch.device('cpu')

def one_epoch(data,loader,model,loss_func,opt,eval_func=accuracy,train=True):
    '''
    Run the model once over every batch of loader and return the mean eval_func score.
    --args--
    data: graph object with node features x and labels y (y is squeezed on dim 1).
    loader: iterable yielding (size, n_id, adjs); the first `size` entries of n_id
            are the nodes scored in that batch.
    model: called as model(features, adjs); returns one row of class scores per scored node.
    loss_func: called as loss_func(pred, truth).
    opt: optimizer; only touched when train is True.
    eval_func: called as eval_func(predicted_labels, truth); returns the score of the batch.
    train: when True, backpropagate the loss and step opt after every batch.
    --returns--
    batch scores averaged with weight `size`; 0.0 when the loader yields nothing.
    '''
    # move batches to wherever the model lives instead of relying on a notebook global
    device=_model_device(model)
    weighted_score=0
    datapoints=0
    for size,n_id,adjs in loader:
        inp=data.x[n_id].to(device=device)
        pred=model(inp,adjs)
        truth=data.y[n_id[:size]].to(device=device).squeeze(dim=1)
        loss=loss_func(pred,truth)
        if train:
            # clear first so gradients left over from earlier work never leak into this step
            opt.zero_grad()
            loss.backward()
            opt.step()
        weighted_score+=eval_func(torch.argmax(pred,dim=1),truth)*size
        datapoints+=size
    if datapoints==0:
        return 0.0
    return weighted_score/datapoints
        
# Number of passes over the training split.
epochs=10
for i in range(epochs):
    model.train()
    # NOTE: one_epoch returns the mean accuracy over the batches, so despite their
    # names train_loss and val_loss hold accuracies, not loss values.
    train_loss=one_epoch(data,loader['train'],model,closs,opt)
    model.eval()
    # validation: gradients are disabled and train=False skips the optimizer step
    with torch.no_grad():
        val_loss=one_epoch(data,loader['valid'],model,closs,opt,train=False)
    print(val_loss)

0.4594113896439478
0.5081714151481593
0.5431725896842176
0.5412262156448203
0.5379039565086077
0.5406557267022384
0.5525353199771804
0.5384744454511896
0.5348166045840465
0.5599181180576529
