In [1]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import torch
import torch_geometric

from src.graphs import Graph
from src.utils import batcher

# GATNET

In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GATConv

class GATNet(torch.nn.Module):
    def __init__(self, data, heads_layer1, 
               heads_layer2, dropout, dropout_alphas):
        super().__init__()

        self.dropout = dropout
        num_features = data.num_features
        num_classes = 2  # hardcoded for now

        self.conv1 = GATConv(in_channels=num_features, out_channels=8,
                             heads=heads_layer1, concat=True, negative_slope=0.2, 
                             dropout=dropout_alphas)

        self.conv2 = GATConv(in_channels=8*heads_layer1, out_channels=num_classes, 
                             heads=heads_layer2, concat=False, negative_slope=0.2,
                             dropout=dropout_alphas)

    def forward(self, data):
        x=data.x
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv1(x, data.edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, data.edge_index)
      
        return F.log_softmax(x, dim=1)

In [3]:
def train(model, data, optimizer, mask):
    """
    Single iteration of training
    """
    # set training mode to True (enabling dropout, etc)
    model.train()
    
    # make sure format of weights is correct
    model.double()
    
    # reset gradients
    optimizer.zero_grad()
    
    # get output of model, which is log-probability (log of softmax)
    # note mask is not applied because message passing needs all nodes
    log_softmax = model(data)
    
    labels = data.y # labels of each node
    
    # apply training mask
    nll_loss = F.nll_loss(log_softmax[mask], labels[mask])
    
    # backprop- compute gradients
    nll_loss.backward()
    
    # backprop- update parameters
    optimizer.step()
    

def compute_accuracy(model, data, mask):
    # set eval mode to True (disable dropout, etc)
    model.eval()
    
    model.double()
    
    # get output of model
    log_softmax = model(data)
    
    # get index of max value from softmax, equivalent to y pred
    yp = log_softmax[mask].argmax(dim=1) 
    
    
    
    return yp == data.y[mask]

# run without gradient (faster)
@torch.no_grad() 
def test(model, data):
    return compute_accuracy(model, data, data.mask)

In [4]:
def detect_agg(g):
    if g.graph_attr['candidate_growth_ratio'] > 10 and g.graph_attr['candidate_rgr'] > 2.5:
        return True
    else:
        return False

In [5]:
data = Path('..','data','candidate-grains-processed')
data.exists()

True

In [6]:
json_paths = list(sorted(data.glob('*'))[-1].glob('*.json'))

In [7]:
def normalize_features(d):
    d.x = (d.x - d.x.mean(dim=0))/d.x.std(dim=0)
    d.edge_attr = (d.edge_attr - d.edge_attr.mean(dim=0)/d.edge_attr.std(dim=0))

In [8]:
graphs = [Graph.from_json(x) for x in json_paths[:200]]
datasets = [g.to_pyg_dataset() for g in graphs]
for g, d in zip(graphs, datasets):
    y = np.zeros(len(g.nodes), np.int)
    y[d.mask] = int(detect_agg(g))
    d.y = torch.tensor(y, dtype=torch.long)
    normalize_features(d)

In [9]:
batch = torch_geometric.data.Batch().from_data_list(datasets)

In [10]:
gat = GATNet(datasets[0], 4, 4, 0.5, 0.5)
gat.double()


#optimizer = torch.optim.Adam(gat.parameters(), lr=0.005, weight_decay=1e-3)
optimizer = torch.optim.Adam(gat.parameters(), lr=0.005, weight_decay=5e-4)

log = 'Epoch: {:03d}, Train: {:.4f}, Loss: {:.4f}'#', Val: {:.4f}'
for epoch in range(1, 51):
    train(gat, batch, optimizer, batch.mask)
    #for d in datasets:
    #    train(gat, d, optimizer, d.mask)
    if epoch % 5 == 0:
        tests = [test(gat, d) for d in datasets]
        losses = [F.nll_loss(gat(d)[d.mask], d.y[d.mask]).detach().numpy() for d in datasets]
        
        print(log.format(epoch, np.mean(tests), np.mean(losses)), )
        

Epoch: 005, Train: 0.6300, Loss: 0.7205
Epoch: 010, Train: 0.6250, Loss: 0.6924
Epoch: 015, Train: 0.6250, Loss: 0.6706
Epoch: 020, Train: 0.6150, Loss: 0.6666
Epoch: 025, Train: 0.6150, Loss: 0.6719
Epoch: 030, Train: 0.6150, Loss: 0.6760
Epoch: 035, Train: 0.6150, Loss: 0.6763
Epoch: 040, Train: 0.6150, Loss: 0.6759
Epoch: 045, Train: 0.6150, Loss: 0.6755
Epoch: 050, Train: 0.6150, Loss: 0.6750


In [14]:
runs_all = [list(x.glob('*.json')) for x in data.glob('*') if x.is_dir() and len(list(x.glob('*.json'))) > 500] 
temp = []
[temp.extend(r) for r in runs_all]
runs_all = sorted(temp)
rs = np.random.RandomState(seed=3346665170)
rs.shuffle(runs_all)
from multiprocessing import get_context, Pool
def load_wrapper(x):
    from src.graphs import Graph
    g = Graph.from_json(x)
    d = g.to_pyg_dataset()
    y = np.zeros(len(g.nodes), np.int)
    y[d.mask] = int(detect_agg(g))
    d.y = torch.tensor(y, dtype=torch.long)
    return d

#with Pool(processes=8) as p:
#    datasets_large = p.map(load_wrapper, runs_all[:1000])

datasets_large = list(map(load_wrapper, runs_all[:1000]))

In [18]:
batches = batcher(datasets_large, batch_size=100, min_size=30)
batches = [torch_geometric.data.Batch().from_data_list(b) for b in batches]

In [28]:
gat = GATNet(datasets[0], 4, 4, 0.5, 0.5)
gat.double()


#optimizer = torch.optim.Adam(gat.parameters(), lr=0.005, weight_decay=1e-3)
optimizer = torch.optim.Adam(gat.parameters(), lr=0.005, weight_decay=5e-4)

log = 'Epoch: {:03d}, Train: {:.4f}, Loss: {:.4f}'#', Val: {:.4f}'
for epoch in range(1, 51):
    for batch in batches:
        train(gat, batch, optimizer, batch.mask)
    if epoch % 5 == 0:
        # TODO fix these functions to work with multiple batches
        tests = [test(gat, d) for d in batches[0]]
        losses = [F.nll_loss(gat(d)[d.mask], d.y[d.mask]).detach().numpy() for d in batches]
        
        print(log.format(epoch, np.mean(tests), np.mean(losses)), )
        

AttributeError: 'tuple' object has no attribute 'mask'

In [27]:
batches[0]

Batch(batch=[44182], edge_attr=[265142, 4], edge_index=[2, 265142], mask=[44182], x=[44182, 9], y=[44182])