# Training GCN

In [290]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch import autograd

import dgl
from dgl import DGLGraph
from dgl.data import MiniGCDataset
import dgl.function as fn
from dgl.data.utils import load_graphs

import numpy as np
import pandas as pd

import spacy
import collections

import os

## Load Data

In [291]:
size = 'xs'

In [292]:
cls_tokens = torch.load('data/X_train_cls_tokens_{}.bin'.format(size))
gcn_offsets = torch.load("data/X_train_gcn_offsets_{}.bin".format(size))
all_graphs, _ = load_graphs("data/X_train_graphs_{}.bin".format(size))

y_data = torch.load('data/y_{}.pt'.format(size))

# Model Building

## RGCN Layer

In [293]:
class RGCNLayer(nn.Module):
    def __init__(self, feat_size, num_rels, activation=None, gated = True):
        
        super(RGCNLayer, self).__init__()
        self.feat_size = feat_size
        self.num_rels = num_rels
        self.activation = activation
        self.gated = gated

        self.weight = nn.Parameter(torch.Tensor(self.num_rels, self.feat_size, 256))
        # init trainable parameters
        nn.init.xavier_uniform_(self.weight,gain=nn.init.calculate_gain('relu'))
        
        if self.gated:
            self.gate_weight = nn.Parameter(torch.Tensor(self.num_rels, self.feat_size, 1))
            nn.init.xavier_uniform_(self.gate_weight,gain=nn.init.calculate_gain('sigmoid'))
        
    def forward(self, g):
        
        weight = self.weight
        gate_weight = self.gate_weight
        
        def message_func(edges):
            w = weight[edges.data['rel_type']]
            msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze()
            msg = msg * edges.data['norm']
            
            if self.gated:
                gate_w = gate_weight[edges.data['rel_type']]
                gate = torch.bmm(edges.src['h'].unsqueeze(1), gate_w).squeeze().reshape(-1,1)
                gate = torch.sigmoid(gate)
                msg = msg * gate
                
            return {'msg': msg}
    
        def apply_func(nodes):
            h = nodes.data['h']
            h = self.activation(h)
            return {'h': h}
        g.update_all(message_func, fn.sum(msg='msg', out='h'), apply_func)

## Define Full RGCN Model

In [294]:
class RGCNModel(nn.Module):
    def __init__(self, h_dim, num_rels, num_hidden_layers=1, gated = True):
        super(RGCNModel, self).__init__()

        self.h_dim = h_dim
        self.num_rels = num_rels
        self.num_hidden_layers = num_hidden_layers
        self.gated = gated
        
        # create rgcn layers
        self.build_model()
       
    def build_model(self):        
        self.layers = nn.ModuleList() 
        for _ in range(self.num_hidden_layers):
            rgcn_layer = RGCNLayer(self.h_dim, self.num_rels, activation=F.relu, gated = self.gated)
            self.layers.append(rgcn_layer)
    
    def forward(self, g):
        for layer in self.layers:
            layer(g) # todo: maybe g = layer(g)??
        
        rst_hidden = []
        for sub_g in dgl.unbatch(g):
            rst_hidden.append(  sub_g.ndata['h']   )
        return rst_hidden

## Design the Main Model (R-GCN + FFNN)

In [295]:
class Head(nn.Module):
    """The MLP submodule"""
    def __init__(self, gcn_out_size: int, bert_out_size: int):
        super().__init__()
        self.bert_out_size = bert_out_size
        self.gcn_out_size = gcn_out_size
        
        self.fc = nn.Sequential(
            nn.BatchNorm1d(bert_out_size + gcn_out_size * 3),
            nn.Dropout(0.5),
            nn.Linear(bert_out_size + gcn_out_size * 3, 256),    
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
            nn.Linear(256, 2), # todo: make sure 2 is fine.
        )
        for i, module in enumerate(self.fc):
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                if getattr(module, "weight_v", None) is not None:
                    nn.init.uniform_(module.weight_g, 0, 1)
                    nn.init.kaiming_normal_(module.weight_v)
                    assert model[i].weight_g is not None
                else:
                    nn.init.kaiming_normal_(module.weight)
                nn.init.constant_(module.bias, 0)
                
    def forward(self, gcn_outputs, offsets_gcn, bert_embeddings):
        
        gcn_extracted_outputs = [gcn_outputs[i].unsqueeze(0).gather(1, offsets_gcn[i].unsqueeze(0).unsqueeze(2)
                                       .expand(-1, -1, gcn_outputs[i].unsqueeze(0).size(2))).view(gcn_outputs[i].unsqueeze(0).size(0), -1) for i in range(len(gcn_outputs))]
        
        gcn_extracted_outputs = torch.stack(gcn_extracted_outputs, dim=0).squeeze()
        
        embeddings = torch.cat((gcn_extracted_outputs, bert_embeddings), 1) 
        
        return self.fc(embeddings)    
    
    
class GPRModel(nn.Module):
    """The main model."""
    def __init__(self):
        super().__init__()
        self.RGCN =  RGCNModel(h_dim = 1024, num_rels = 3, gated = True)
#         self.BERThead = BERT_Head(1024) # bert output size
        self.head = Head(256, 1024)  # gcn output   berthead output
    
    
    def forward(self, g, offsets_gcn, cls_token):
        gcn_outputs = self.RGCN(g)
#         print(gcn_outputs.shape)
        bert_head_outputs = cls_token
        head_outputs = self.head(gcn_outputs, offsets_gcn, bert_head_outputs)
        return head_outputs

In [296]:
class GPRDataset(Dataset):
    def __init__(self, graphs, gcn_offsets, cls_tokens, labels):

        self.graphs = graphs
        self.cls_tokens = cls_tokens
        self.gcn_offsets = gcn_offsets
        self.y = labels
        
    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx], self.gcn_offsets[idx], self.cls_tokens[idx], self.y[idx]

In [297]:
train_dataset = GPRDataset(all_graphs, gcn_offsets, cls_tokens, y_data)

In [298]:
def collate(samples):
    
    graphs, gcn_offsets, cls_tokens, labels = map(list, zip(*samples))
    
    batched_graph = dgl.batch(graphs)
    offsets_gcn = torch.stack([torch.LongTensor(x) for x in gcn_offsets], dim=0)
    
    cls_tokens = torch.stack(cls_tokens, dim=0).squeeze()
    
    labels = torch.stack(labels)
    
    
    return batched_graph, offsets_gcn, cls_tokens, labels

In [311]:
train_dataloader = DataLoader(
   train_dataset,
   collate_fn = collate,
   batch_size = 8,
   shuffle=False,
)

In [314]:
model = GPRModel()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)
# optimizer = optim.SGD(model.parameters(), lr=1e-5, momentum=0.9)

reg_lambda = 0.035

save_model_name = "exp_model.pt"

for epoch in range(2):  # loop over the dataset multiple times

        running_loss = 0.0
        with autograd.detect_anomaly():
            for i, data in enumerate(train_dataloader, 0):
            # get the inputs; data is a list of [inputs, labels]
                graphs, gcn_offsets, cls_tokens, labels = data
    #             inputs, labels = inputs.to(device), labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()
                # forward + backward + optimize
                outputs = model(graphs, gcn_offsets, cls_tokens)
                
#                 print(graphs, gcn_offsets, cls_tokens, labels)
#                 print(outputs)

    #             print(outputs, labels)

                l2_reg = None
                for w in model.RGCN.parameters():
                    if not l2_reg:
                        l2_reg = w.norm(2)
                    else:
                        l2_reg = l2_reg + w.norm(2)  
                for w in model.head.parameters():
                    if not l2_reg:
                        l2_reg = w.norm(2)
                    else:
                        l2_reg = l2_reg + w.norm(2)

                loss = criterion(outputs, labels) + l2_reg * reg_lambda
                loss.backward()
                
#                 nn.utils.clip_grad_norm_(model.RGCN.parameters(), 1.0)
#                 nn.utils.clip_grad_norm_(model.head.parameters(), 0.5)
                
                optimizer.step()
                # print statistics

                running_loss += loss.item()

                if i % 1 == 0:    # print every 20 mini-batches
                    print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss))
                    running_loss = 0.0

        torch.save(model, save_model_name)

[1,     1] loss: 3.756
[1,     2] loss: 3.821
[1,     3] loss: 4.131
[1,     4] loss: 4.159
[1,     5] loss: 4.181
[1,     6] loss: 4.826
[1,     7] loss: 4.087
[1,     8] loss: 3.873
[1,     9] loss: 4.689
[1,    10] loss: 4.476
[1,    11] loss: 3.771
[1,    12] loss: 3.963
[1,    13] loss: 3.904
[1,    14] loss: 4.542


RuntimeError: Function 'BinaryCrossEntropyWithLogitsBackward' returned nan values in its 0th output.

## Log of all things tried:
1. Add autograd.detect_anomaly - saw that first outputs turn NaN then loss
2. Comment out batch norm - _something_ changes - only first two rows of output are nan now...
3. Increase batch size to 8 - no improvemnt
4. tried setting gated = False for GPR model - no luck - in fact threw some error
5.  Increasing dropout to 0.7 - un-helpful.
6. clip_grad_norm was useless :(
8. batch size 2 -> 56 iterations to failure


    8.1 4 -> 28
    8.2 8 -> 14
    8.3 128 -> 0
    
9. Tried replacing Adam with SGD - no luck..