<a href="https://colab.research.google.com/github/priba/gmprdia2019/blob/master/gmprdia2019.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ICDAR2019 - Graph-based Methods in Pattern Recognition & Document Image Analysis



## Change runtime of notebook to GPU


```
  Select Runtime -> Change runtime type -> set the runtime type to Python 3 and the hardware accelerator to GPU
```

## Install requirements

The basic libraries that will be used are:

*   [NetworkX](https://networkx.github.io/)
*   [PyTorch](https://pytorch.org/)
*   [DGL](https://www.dgl.ai/)



In [0]:
!pip3 install networkx
!pip3 install torch
!pip3 install dgl-cu100

## Download Data

### Letter Database

Graphs that represent distorted letter drawings. The dataset covers the 15 capital letters of the Roman alphabet that consist of straight lines only (A, E, F, H, I, K, L, M, N, T, V, W, X, Y, Z). Each node is labeled with a two-dimensional attribute giving its position relative to a reference coordinate system. Edges are unlabeled. The graph database consists of a training set, a validation set, and a test set of 750 graphs each. Three levels of distortion (LOW, MED, HIGH) are provided.

This dataset is part of the [IAM Graph Database Repository](http://www.fki.inf.unibe.ch/databases/iam-graph-database) and is also linked in the [IAPR TC15 resources](https://iapr-tc15.greyc.fr/links.html).

It can be considered as a **TOY EXAMPLE** for graph classification.

> Riesen, K. and Bunke, H.: [IAM Graph Database Repository for Graph Based Pattern Recognition and Machine Learning.](https://link.springer.com/chapter/10.1007/978-3-540-89689-0_33) In: da Vitoria Lobo, N. et al. (Eds.), SSPR&SPR 2008, LNCS, vol. 5342, pp. 287-297, 2008.


In [0]:
!wget https://iapr-tc15.greyc.fr/IAM/Letter.zip
!unzip Letter.zip

## Prepare data reader

IAM graphs are provided as GXL files, one graph per file, with the following structure:


```
<gxl>
  <graph id="GRAPH_ID" edgeids="false" edgemode="undirected">
    <node id="_0">
      <attr name="x">
        <float>0.812867</float>
      </attr>
      <attr name="y">
        <float>0.630453</float>
      </attr>
    </node>
    ...
    <node id="_N">
      ...
    </node>
    <edge from="_0" to="_1"/>
    ...
    <edge from="_M" to="_N"/>
  </graph>
</gxl>
```
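
To see how this layout maps onto Python objects, here is a minimal sketch that parses a GXL string with the standard library only. The sample graph `GXL_SAMPLE`, its coordinates, and the helper `parse_gxl_string` are made up for illustration and are not part of the dataset or of any library.

```python
import xml.etree.ElementTree as ET

# Hypothetical two-node graph following the GXL layout above
# (the coordinates are invented for this example).
GXL_SAMPLE = """
<gxl>
  <graph id="TOY" edgeids="false" edgemode="undirected">
    <node id="_0">
      <attr name="x"><float>0.5</float></attr>
      <attr name="y"><float>1.0</float></attr>
    </node>
    <node id="_1">
      <attr name="x"><float>2.0</float></attr>
      <attr name="y"><float>3.0</float></attr>
    </node>
    <edge from="_0" to="_1"/>
  </graph>
</gxl>
"""


def parse_gxl_string(text):
    """Return (positions, edges) parsed from a GXL string."""
    root = ET.fromstring(text.strip())
    positions = {}
    for node in root.iter('node'):
        # Collect the named float attributes of this node, e.g. x and y
        coords = {attr.get('name'): float(attr.find('float').text)
                  for attr in node.iter('attr')}
        positions[node.get('id')] = (coords['x'], coords['y'])
    # Each edge is stored as a (from, to) pair of node ids
    edges = [(edge.get('from'), edge.get('to')) for edge in root.iter('edge')]
    return positions, edges


positions, edges = parse_gxl_string(GXL_SAMPLE)
print(positions)
print(edges)
```

Node ids are kept as strings here; a reader that builds an adjacency matrix additionally has to map each id to an integer index.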



In [0]:
import numpy as np
import xml.etree.ElementTree as ET
import networkx as nx


def read_letters(file):
  """Parse a GXL file and return the corresponding networkx graph.
  """
  
  tree_gxl = ET.parse(file)
  root_gxl = tree_gxl.getroot()
  node_label = {}
  node_id = []
  
  # Parse nodes
  for i, node in enumerate(root_gxl.iter('node')):
    node_id += [node.get('id')]
    for attr in node.iter('attr'):
      if (attr.get('name') == 'x'):
        x = float(attr.find('float').text)
      elif (attr.get('name') == 'y'):
        y = float(attr.find('float').text)
    node_label[i] = [x, y]

  node_id = np.array(node_id)

  # Create adjacency matrix
  am = np.zeros((len(node_id), len(node_id)))
  for edge in root_gxl.iter('edge'):
    s = np.where(node_id==edge.get('from'))[0][0]
    t = np.where(node_id==edge.get('to'))[0][0]

    # Undirected Graph
    am[s,t] = 1
    am[t,s] = 1

  # Create the networkx graph
  # from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is the replacement
  G = nx.from_numpy_array(am)
  nx.set_node_attributes(G, node_label, 'position')
  
  return G

## Load Data with NetworkX

In [0]:
import os

# Select distortion [LOW, MED, HIGH]
distortion = 'LOW'

# Select letter [A, E, F, H, I, K, L, M, N, T, V, W, X, Y, Z]
letter = 'K'

# Select id [0-149] (named graph_id to avoid shadowing the built-in id)
graph_id = 100

# Read the graph and draw it using networkx tools
G = read_letters(os.path.join('Letter', distortion, letter + 'P1_' + str(graph_id).zfill(4) + '.gxl'))
nx.draw(G, pos=dict(G.nodes(data='position')))

## Batch processing


### NetworkX to DGL graph

In [0]:
import dgl

# Create graph object
g = dgl.DGLGraph()

# Transfer NetworkX graph with the corresponding attributes
g.from_networkx(G, node_attrs=['position'])

# Data structure
print(g)

### Define dataset


#### Dataset Division

The dataset is split into *train*, *validation* and *test* sets by means of CXL files, each listing the graph filenames and their corresponding classes:


```
<GraphCollection>
  <fingerprints base="/scratch/mneuhaus/progs/letter-database/automatic/0.1" classmodel="henry5" count="750">
    <print file="AP1_0100.gxl" class="A"/>
    ...
    <print file="ZP1_0149.gxl" class="Z"/>
  </fingerprints>
</GraphCollection>
```

In [0]:
def getFileList(file_path):
  """Parse a CXL file and return the list of graph files and the list of their classes.
  """
  
  elements, classes = [], []
  tree = ET.parse(file_path)
  root = tree.getroot()
  
  for child in root:
    for sec_child in child:
      if sec_child.tag == 'print':
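        elements.append(sec_child.attrib['file'])
        # append keeps the class string whole; += would split a multi-character label into characters
        classes.append(sec_child.attrib['class'])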
        
  return elements, classes

#### Define Dataset Class
PyTorch provides an abstract class representing a dataset, ```torch.utils.data.Dataset```. We need to override two methods:

*   ```__len__``` so that ```len(dataset)``` returns the size of the dataset.
*   ```__getitem__``` to support indexing, so that ```dataset[i]``` returns the i-th sample.
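
The two methods can be illustrated without PyTorch. The sketch below is a hypothetical stand-in (`ToyLetters` is not part of the notebook and does not subclass ```torch.utils.data.Dataset```); it only shows the protocol and one way to map string labels to integer indices.

```python
class ToyLetters:
    """Hypothetical dataset holding (sample, label) pairs."""

    def __init__(self, samples, labels):
        self.samples = samples
        # Map each string label to its index in the sorted list of unique labels
        self.unique_labels = sorted(set(labels))
        self.labels = [self.unique_labels.index(label) for label in labels]

    def __len__(self):
        # Enables len(dataset)
        return len(self.labels)

    def __getitem__(self, index):
        # Enables dataset[index]
        return self.samples[index], self.labels[index]

    def label2class(self, label):
        # Convert a numeric label back to its string
        return self.unique_labels[label]


toy = ToyLetters(['g0', 'g1', 'g2'], ['K', 'A', 'K'])
print(len(toy))
print(toy[1])
print(toy.label2class(toy[0][1]))
```

Any object that implements these two methods can be indexed and measured in the same way, which is what a data loader relies on when it assembles batches.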




In [0]:
import torch.utils.data as data

class Letters(data.Dataset):
  """Letter Dataset
  """
  
  def __init__(self, root_path, file_list):
    self.root = root_path
    self.file_list = file_list
    
    # List of files and corresponding labels
    self.graphs, self.labels = getFileList(os.path.join(self.root, self.file_list))
    
    # Labels to numeric value
    self.unique_labels = np.unique(self.labels)
    self.num_classes = len(self.unique_labels)
    
    self.labels = [np.where(target == self.unique_labels)[0][0] 
                   for target in self.labels]
    
    
  def __getitem__(self, index):
    # Read the graph and label
    G = read_letters(os.path.join(self.root, self.graphs[index]))
    target = self.labels[index]
    
    # Convert to DGL format
    g = dgl.DGLGraph()
    g.set_n_initializer(dgl.init.zero_initializer)
    g.set_e_initializer(dgl.init.zero_initializer)
    g.from_networkx(G, node_attrs=['position'])
        
    return g, target
  
  def label2class(self, label):
    # Converts the numeric label to the corresponding string
    return self.unique_labels[label]
  
  def __len__(self):
    # Subset length
    return len(self.labels)

# Define the corresponding subsets for train, validation and test.
trainset = Letters(os.path.join('Letter', distortion), 'train.cxl')
validset = Letters(os.path.join('Letter', distortion), 'validation.cxl')
testset = Letters(os.path.join('Letter', distortion), 'test.cxl')

### Prepare DataLoader

```torch.utils.data.DataLoader``` is an iterator which provides:


*   Data batching
*   Shuffling the data
*   Parallel data loading

In our specific case, we need to deal with graphs of different sizes. Hence, we define a new collate function making use of the method ```dgl.batch```.
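
Conceptually, batching graphs of different sizes amounts to building one large graph whose connected components are the individual graphs, with node indices shifted so they do not collide. The following pure-Python sketch illustrates that idea under simplifying assumptions; `batch_graphs` and the `(num_nodes, edges)` representation are hypothetical and this is not DGL's implementation.

```python
def batch_graphs(graphs):
    """Merge graphs into one disjoint union.

    Each graph is a (num_nodes, edges) pair, where edges is a list of
    (source, target) index pairs local to that graph.
    """
    merged_edges, nodes_per_graph = [], []
    offset = 0
    for num_nodes, edges in graphs:
        # Shift local node indices by the number of nodes seen so far
        merged_edges += [(s + offset, t + offset) for s, t in edges]
        nodes_per_graph.append(num_nodes)
        offset += num_nodes
    return offset, merged_edges, nodes_per_graph


triangle = (3, [(0, 1), (1, 2), (2, 0)])
path = (2, [(0, 1)])
total_nodes, merged_edges, nodes_per_graph = batch_graphs([triangle, path])
print(total_nodes, merged_edges, nodes_per_graph)
```

Keeping the per-graph node counts makes it possible to split node features back into graphs later, for example when averaging node features per graph for classification.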

In [0]:
import torch
from torch.utils.data import DataLoader

def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)
  
# Define the three dataloaders. Train data will be shuffled at each epoch
train_loader = DataLoader(trainset, batch_size=32, shuffle=True,
                         collate_fn=collate)
valid_loader = DataLoader(validset, batch_size=32, collate_fn=collate)
test_loader = DataLoader(testset, batch_size=32, collate_fn=collate)

## Define Model

To define a Graph Convolution, three functions have to be defined:




*   Message: Decide which information is sent by a node
*   Reduce: Combine the messages and the current data
*   NodeApply: Update the node features using the output received from the reduce function
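
The three steps can be traced by hand on a tiny graph. This pure-Python sketch uses sum aggregation followed by concatenation, mirroring the layer defined in this section, but leaves out the learned linear map and the activation. The function `message_passing_step` and the toy features are hypothetical.

```python
def message_passing_step(features, edges):
    """One round of sum aggregation on an undirected graph."""
    # Message + Reduce: every node sums the features sent by its neighbors
    aggregated = {node: [0.0] * len(feat) for node, feat in features.items()}
    for u, v in edges:
        for src, dst in ((u, v), (v, u)):  # undirected: send in both directions
            aggregated[dst] = [a + b for a, b in zip(aggregated[dst], features[src])]
    # NodeApply input: concatenate each node's own feature with its aggregated message
    return {node: features[node] + aggregated[node] for node in features}


features = {0: [1.0, 0.0], 1: [0.0, 2.0], 2: [3.0, 3.0]}
edges = [(0, 1), (1, 2)]
combined = message_passing_step(features, edges)
print(combined)
```

The concatenated vector has twice the input dimension, which is why a linear layer applied after this step needs `2 * in_feats` input features.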

In [0]:
import torch
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F

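# Built-in message function: sends the node feature h as message m.
msg = fn.copy_src(src='h', out='m')
# Equivalent user-defined message function, shown for reference only.
def message_func(edges):
    return {'m': edges.src['h']}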

def reduce(nodes):
  """Sum the features hu of all neighbor nodes and store the result
  in the node field 'm'."""
  accum = torch.sum(nodes.mailbox['m'], 1)
  return {'m': accum}

class NodeApplyModule(nn.Module):
  """Update the node feature hv with activation(W[hv, mv] + b), where mv is the aggregated message."""
  def __init__(self, in_feats, out_feats, activation):
    super(NodeApplyModule, self).__init__()
    self.linear = nn.Linear(in_feats, out_feats)
    self.activation = activation

  def forward(self, node):
    h = torch.cat([node.data['h'], node.data['m']],1)
    h = self.linear(h)
    h = self.activation(h)
    return {'h' : h}  
  
class GCN(nn.Module):
  """Define a GCN layer"""
  def __init__(self, in_feats, out_feats, activation):
    super(GCN, self).__init__()
    self.apply_mod = NodeApplyModule(2*in_feats, out_feats, activation)
    
  def forward(self, g, feature):
    # Initialize the node features with h.
    g.ndata['h'] = feature
    g.update_all(msg, reduce)
    g.apply_nodes(func=self.apply_mod)
    return g.ndata.pop('h')
  
class Net(nn.Module):
  def __init__(self, in_dim, hidden_dim, n_classes):
    super(Net, self).__init__()
    self.layers = nn.ModuleList([
        GCN(in_dim, hidden_dim, F.relu),
        GCN(hidden_dim, hidden_dim, F.relu)])
    self.classify = nn.Linear(hidden_dim, n_classes)
    
  def forward(self, g):
    # Use the 2-D node position as the input feature.
    h = g.ndata['position']
    # Optional: also append the node degree (for undirected graphs,
    # in_degree is the same as out_degree).
    #h = torch.cat((g.ndata['position'], g.in_degrees().view(-1, 1).float()), dim=1)
    
    if torch.cuda.is_available():
      h = h.cuda() 
    for conv in self.layers:
      h = conv(g, h)
    g.ndata['h'] = h
    hg = dgl.mean_nodes(g, 'h')
    return self.classify(hg)

## New NN Modules

Recent DGL releases provide several graph convolutions out of the box as [NN Modules](https://docs.dgl.ai/api/python/nn.pytorch.html#graphconv).

In [0]:
from dgl.nn.pytorch import GraphConv

class Net(nn.Module):
  def __init__(self, in_dim, hidden_dim, n_classes):
    super(Net, self).__init__()
    self.layers = nn.ModuleList([
        GraphConv(in_dim, hidden_dim, activation=F.relu),
        GraphConv(hidden_dim, hidden_dim, activation=F.relu)])
    self.classify = nn.Linear(hidden_dim, n_classes)
    
  def forward(self, g):
    h = g.ndata['position']
    
    if torch.cuda.is_available():
      h = h.cuda() 
      
    for conv in self.layers:
      h = conv(g, h)
      
    g.ndata['h'] = h
    hg = dgl.mean_nodes(g, 'h')
    return self.classify(hg)

## Training setup

In [0]:
import torch
import torch.optim as optim

model = Net(2, 256, trainset.num_classes)
if torch.cuda.is_available():
  model = model.cuda()
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
model.train()

epoch_losses = []
for epoch in range(80):
  epoch_loss = 0
  for batch_idx, (bg, label) in enumerate(train_loader):
    if torch.cuda.is_available():
      label = label.cuda()
    prediction = model(bg)
    loss = loss_func(prediction, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()
  # Average the loss over the number of batches in this epoch
  epoch_loss /= (batch_idx + 1)
  print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
  epoch_losses.append(epoch_loss)

## Evaluation

In [0]:
def accuracy(output, target):
  """Accuracy (in percent) given a batch of logit vectors and the target classes
  """
  _, pred = output.topk(1)
  pred = pred.squeeze(1)  # drop only the top-k dimension so a batch of size 1 keeps its batch axis
  correct = pred == target
  correct = correct.float()
  return correct.sum() * 100.0 / correct.shape[0]


model.eval()
acc = 0
with torch.no_grad():
  for bg, label in test_loader:
    if torch.cuda.is_available():
      label = label.cuda()
    prediction = model(bg)
    acc += accuracy(prediction, label) * label.shape[0]
acc = acc/len(testset)

print('Test accuracy {:.4f}'.format(acc))

## Plot results

In [0]:
from random import randrange
import matplotlib.pyplot as plt
for i in range(10):
  index = randrange(len(testset))
  g, label = testset[index]
  pred = model(g)
  _, pred = pred.topk(1)
  G = g.to_networkx(node_attrs=['position'])
  plt.figure(i)
  position = {k: v.numpy() for k, v in dict(G.nodes(data='position')).items()}
  nx.draw(G, pos=position, arrows=False)
  plt.show()
  print('Label {} {}; Prediction {} {}'.format(label, testset.label2class(label), pred.item(), testset.label2class(pred.item())))