# NAACL'21 DLG4NLP Tutorial Demo: Text Classification

In this tutorial demo, we will use the Graph4NLP library to build a GNN-based text classification model. The model consists of 
- graph construction module (e.g., dependency based static graph)
- graph embedding module (e.g., Bi-Fuse GraphSAGE)
- predictoin module (e.g., graph pooling + MLP classifier)

We will use the built-in module APIs to build the model, and evaluate it on the TREC dataset.

### Environment setup

1. Create virtual environment
```
conda create --name graph4nlp python=3.7
conda activate graph4nlp
```

2. Install [graph4nlp](https://github.com/graph4ai/graph4nlp) library
- Clone the github repo
```
git clone -b stable https://github.com/graph4ai/graph4nlp.git
cd graph4nlp
```
- Then run `./configure` (or `./configure.bat` if you are using Windows 10) to config your installation. The configuration program will ask you to specify your CUDA version. If you do not have a GPU, please choose 'cpu'.
```
./configure
```
- Finally, install the package
```
python setup.py install
```

3. Set up StanfordCoreNLP (for static graph construction only, unnecessary for this demo because preprocessed data is provided)
- Download [StanfordCoreNLP](https://stanfordnlp.github.io/CoreNLP/)
- Go to the root folder and start the server
```
java -mx4g -cp "*" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000 -timeout 15000
```

In [1]:
import os
import time
import datetime
import yaml
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.backends.cudnn as cudnn

from graph4nlp.pytorch.datasets.trec import TrecDataset
from graph4nlp.pytorch.modules.graph_construction import *
from graph4nlp.pytorch.modules.graph_construction.embedding_construction import WordEmbedding
from graph4nlp.pytorch.modules.graph_embedding import *
from graph4nlp.pytorch.modules.prediction.classification.graph_classification import FeedForwardNN
from graph4nlp.pytorch.modules.evaluation.base import EvaluationMetricBase
from graph4nlp.pytorch.modules.evaluation.accuracy import Accuracy
from graph4nlp.pytorch.modules.utils.generic_utils import EarlyStopping
from graph4nlp.pytorch.modules.loss.general_loss import GeneralLoss
from graph4nlp.pytorch.modules.utils.logger import Logger
from graph4nlp.pytorch.modules.utils import constants as Constants

Using backend: pytorch


In [2]:
class TextClassifier(nn.Module):
    def __init__(self, vocab, config):
        super(TextClassifier, self).__init__()
        self.config = config
        self.vocab = vocab
        embedding_style = {'single_token_item': True if config['graph_type'] != 'ie' else False,
                            'emb_strategy': config.get('emb_strategy', 'w2v_bilstm'),
                            'num_rnn_layers': 1,
                            'bert_model_name': config.get('bert_model_name', 'bert-base-uncased'),
                            'bert_lower_case': True
                           }

        assert not (config['graph_type'] in ('node_emb', 'node_emb_refined') and config['gnn'] == 'gat'), \
                                'dynamic graph construction does not support GAT'

        use_edge_weight = False
        if config['graph_type'] == 'dependency':
            self.graph_topology = DependencyBasedGraphConstruction(
                                   embedding_style=embedding_style,
                                   vocab=vocab.in_word_vocab,
                                   hidden_size=config['num_hidden'],
                                   word_dropout=config['word_dropout'],
                                   rnn_dropout=config['rnn_dropout'],
                                   fix_word_emb=not config['no_fix_word_emb'],
                                   fix_bert_emb=not config.get('no_fix_bert_emb', False))
        elif config['graph_type'] == 'constituency':
            self.graph_topology = ConstituencyBasedGraphConstruction(
                                   embedding_style=embedding_style,
                                   vocab=vocab.in_word_vocab,
                                   hidden_size=config['num_hidden'],
                                   word_dropout=config['word_dropout'],
                                   rnn_dropout=config['rnn_dropout'],
                                   fix_word_emb=not config['no_fix_word_emb'],
                                   fix_bert_emb=not config.get('no_fix_bert_emb', False))
        elif config['graph_type'] == 'ie':
            self.graph_topology = IEBasedGraphConstruction(
                                   embedding_style=embedding_style,
                                   vocab=vocab.in_word_vocab,
                                   hidden_size=config['num_hidden'],
                                   word_dropout=config['word_dropout'],
                                   rnn_dropout=config['rnn_dropout'],
                                   fix_word_emb=not config['no_fix_word_emb'],
                                   fix_bert_emb=not config.get('no_fix_bert_emb', False))
        elif config['graph_type'] == 'node_emb':
            self.graph_topology = NodeEmbeddingBasedGraphConstruction(
                                   vocab.in_word_vocab,
                                   embedding_style,
                                   sim_metric_type=config['gl_metric_type'],
                                   num_heads=config['gl_num_heads'],
                                   top_k_neigh=config['gl_top_k'],
                                   epsilon_neigh=config['gl_epsilon'],
                                   smoothness_ratio=config['gl_smoothness_ratio'],
                                   connectivity_ratio=config['gl_connectivity_ratio'],
                                   sparsity_ratio=config['gl_sparsity_ratio'],
                                   input_size=config['num_hidden'],
                                   hidden_size=config['gl_num_hidden'],
                                   fix_word_emb=not config['no_fix_word_emb'],
                                   fix_bert_emb=not config.get('no_fix_bert_emb', False),
                                   word_dropout=config['word_dropout'],
                                   rnn_dropout=config['rnn_dropout'])
            use_edge_weight = True
        elif config['graph_type'] == 'node_emb_refined':
            self.graph_topology = NodeEmbeddingBasedRefinedGraphConstruction(
                                    vocab.in_word_vocab,
                                    embedding_style,
                                    config['init_adj_alpha'],
                                    sim_metric_type=config['gl_metric_type'],
                                    num_heads=config['gl_num_heads'],
                                    top_k_neigh=config['gl_top_k'],
                                    epsilon_neigh=config['gl_epsilon'],
                                    smoothness_ratio=config['gl_smoothness_ratio'],
                                    connectivity_ratio=config['gl_connectivity_ratio'],
                                    sparsity_ratio=config['gl_sparsity_ratio'],
                                    input_size=config['num_hidden'],
                                    hidden_size=config['gl_num_hidden'],
                                    fix_word_emb=not config['no_fix_word_emb'],
                                    fix_bert_emb=not config.get('no_fix_bert_emb', False),
                                    word_dropout=config['word_dropout'],
                                    rnn_dropout=config['rnn_dropout'])
            use_edge_weight = True
        else:
            raise RuntimeError('Unknown graph_type: {}'.format(config['graph_type']))

        if 'w2v' in self.graph_topology.embedding_layer.word_emb_layers:
            self.word_emb = self.graph_topology.embedding_layer.word_emb_layers['w2v'].word_emb_layer
        else:
            self.word_emb = WordEmbedding(
                            self.vocab.in_word_vocab.embeddings.shape[0],
                            self.vocab.in_word_vocab.embeddings.shape[1],
                            pretrained_word_emb=self.vocab.in_word_vocab.embeddings,
                            fix_emb=not config['no_fix_word_emb'],
                            device=config['device']).word_emb_layer

        if config['gnn'] == 'gat':
            heads = [config['gat_num_heads']] * (config['gnn_num_layers'] - 1) + [config['gat_num_out_heads']]
            self.gnn = GAT(config['gnn_num_layers'],
                        config['num_hidden'],
                        config['num_hidden'],
                        config['num_hidden'],
                        heads,
                        direction_option=config['gnn_direction_option'],
                        feat_drop=config['gnn_dropout'],
                        attn_drop=config['gat_attn_dropout'],
                        negative_slope=config['gat_negative_slope'],
                        residual=config['gat_residual'],
                        activation=F.elu)
        elif config['gnn'] == 'graphsage':
            self.gnn = GraphSAGE(config['gnn_num_layers'],
                        config['num_hidden'],
                        config['num_hidden'],
                        config['num_hidden'],
                        config['graphsage_aggreagte_type'],
                        direction_option=config['gnn_direction_option'],
                        feat_drop=config['gnn_dropout'],
                        bias=True,
                        norm=None,
                        activation=F.relu,
                        use_edge_weight=use_edge_weight)
        elif config['gnn'] == 'ggnn':
            self.gnn = GGNN(config['gnn_num_layers'],
                        config['num_hidden'],
                        config['num_hidden'],
                        config['num_hidden'],
                        feat_drop=config['gnn_dropout'],
                        direction_option=config['gnn_direction_option'],
                        bias=True,
                        use_edge_weight=use_edge_weight)
        else:
            raise RuntimeError('Unknown gnn type: {}'.format(config['gnn']))

        self.clf = FeedForwardNN(2 * config['num_hidden'] \
                        if config['gnn_direction_option'] == 'bi_sep' \
                        else config['num_hidden'],
                        config['num_classes'],
                        [config['num_hidden']],
                        graph_pool_type=config['graph_pooling'],
                        dim=config['num_hidden'],
                        use_linear_proj=config['max_pool_linear_proj'])

        self.loss = GeneralLoss('CrossEntropy')


    def forward(self, graph_list, tgt=None, require_loss=True):
        # build graph topology
        batch_gd = self.graph_topology(graph_list)

        # run GNN encoder
        self.gnn(batch_gd)

        # run graph classifier
        self.clf(batch_gd)
        logits = batch_gd.graph_attributes['logits']

        if require_loss:
            loss = self.loss(logits, tgt)
            return logits, loss
        else:
            return logits

In [3]:
class ModelHandler:
    def __init__(self, config):
        super(ModelHandler, self).__init__()
        self.config = config
        self.logger = Logger(self.config['out_dir'], config={k:v for k, v in self.config.items() if k != 'device'}, overwrite=True)
        self.logger.write(self.config['out_dir'])
        self._build_device()
        self._build_dataloader()
        self._build_model()
        self._build_optimizer()
        self._build_evaluation()

    def _build_device(self):
        if not self.config['no_cuda'] and torch.cuda.is_available():
            print('[ Using CUDA ]')
            self.config['device'] = torch.device('cuda' if self.config['gpu'] < 0 else 'cuda:%d' % self.config['gpu'])
            torch.cuda.manual_seed(self.config['seed'])
            torch.cuda.manual_seed_all(self.config['seed'])
            torch.backends.cudnn.deterministic = True
            cudnn.benchmark = False
        else:
            self.config['device'] = torch.device('cpu')
        
    def _build_dataloader(self):
        dynamic_init_topology_builder = None
        if self.config['graph_type'] == 'dependency':
            topology_builder = DependencyBasedGraphConstruction
            graph_type = 'static'
            merge_strategy = 'tailhead'
        elif self.config['graph_type'] == 'constituency':
            topology_builder = ConstituencyBasedGraphConstruction
            graph_type = 'static'
            merge_strategy = 'tailhead'
        elif self.config['graph_type'] == 'ie':
            topology_builder = IEBasedGraphConstruction
            graph_type = 'static'
            merge_strategy = 'global'
        elif self.config['graph_type'] == 'node_emb':
            topology_builder = NodeEmbeddingBasedGraphConstruction
            graph_type = 'dynamic'
            merge_strategy = None
        elif self.config['graph_type'] == 'node_emb_refined':
            topology_builder = NodeEmbeddingBasedRefinedGraphConstruction
            graph_type = 'dynamic'
            merge_strategy = 'tailhead'

            if self.config['init_graph_type'] == 'line':
                dynamic_init_topology_builder = None
            elif self.config['init_graph_type'] == 'dependency':
                dynamic_init_topology_builder = DependencyBasedGraphConstruction
            elif self.config['init_graph_type'] == 'constituency':
                dynamic_init_topology_builder = ConstituencyBasedGraphConstruction
            elif self.config['init_graph_type'] == 'ie':
                merge_strategy = 'global'
                dynamic_init_topology_builder = IEBasedGraphConstruction
            else:
                raise RuntimeError('Define your own dynamic_init_topology_builder')
        else:
            raise RuntimeError('Unknown graph_type: {}'.format(self.config['graph_type']))

        topology_subdir = '{}_graph'.format(self.config['graph_type'])
        if self.config['graph_type'] == 'node_emb_refined':
            topology_subdir += '_{}'.format(self.config['init_graph_type'])

        dataset = TrecDataset(root_dir=self.config.get('root_dir', self.config['root_data_dir']),
                              pretrained_word_emb_name=self.config.get('pretrained_word_emb_name', "840B"),
                              merge_strategy=merge_strategy,
                              seed=self.config['seed'],
                              thread_number=4,
                              port=9000,
                              timeout=15000,
                              word_emb_size=300,
                              graph_type=graph_type,
                              topology_builder=topology_builder,
                              topology_subdir=topology_subdir,
                              dynamic_graph_type=self.config['graph_type'] if \
                                  self.config['graph_type'] in ('node_emb', 'node_emb_refined') else None,
                              dynamic_init_topology_builder=dynamic_init_topology_builder,
                              dynamic_init_topology_aux_args={'dummy_param': 0})

        self.train_dataloader = DataLoader(dataset.train, batch_size=self.config['batch_size'], shuffle=True,
                                           num_workers=self.config['num_workers'],
                                           collate_fn=dataset.collate_fn)
        if hasattr(dataset, 'val')==False:
            dataset.val = dataset.test
        self.val_dataloader = DataLoader(dataset.val, batch_size=self.config['batch_size'], shuffle=False,
                                          num_workers=self.config['num_workers'],
                                          collate_fn=dataset.collate_fn)
        self.test_dataloader = DataLoader(dataset.test, batch_size=self.config['batch_size'], shuffle=False,
                                          num_workers=self.config['num_workers'],
                                          collate_fn=dataset.collate_fn)
        self.vocab = dataset.vocab_model
        self.config['num_classes'] = dataset.num_classes
        self.num_train = len(dataset.train)
        self.num_val = len(dataset.val)
        self.num_test = len(dataset.test)
        print('Train size: {}, Val size: {}, Test size: {}'
            .format(self.num_train, self.num_val, self.num_test))
        self.logger.write('Train size: {}, Val size: {}, Test size: {}'
            .format(self.num_train, self.num_val, self.num_test))

    def _build_model(self):
        self.model = TextClassifier(self.vocab, self.config).to(self.config['device'])

    def _build_optimizer(self):
        parameters = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optim.Adam(parameters, lr=self.config['lr'])
        self.stopper = EarlyStopping(os.path.join(self.config['out_dir'], Constants._SAVED_WEIGHTS_FILE), patience=self.config['patience'])
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=self.config['lr_reduce_factor'], \
            patience=self.config['lr_patience'], verbose=True)

    def _build_evaluation(self):
        self.metric = Accuracy(['accuracy'])

    def train(self):
        dur = []
        for epoch in range(self.config['epochs']):
            self.model.train()
            train_loss = []
            train_acc = []
            t0 = time.time()
            for i, data in enumerate(self.train_dataloader):
                tgt = data['tgt_tensor'].to(self.config['device'])
                data['graph_data'] = data['graph_data'].to(self.config['device'])
                logits, loss = self.model(data['graph_data'], tgt, require_loss=True)

                # add graph regularization loss if available
                if data['graph_data'].graph_attributes.get('graph_reg', None) is not None:
                    loss = loss + data['graph_data'].graph_attributes['graph_reg']

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                train_loss.append(loss.item())

                pred = torch.max(logits, dim=-1)[1].cpu()
                train_acc.append(self.metric.calculate_scores(ground_truth=tgt.cpu(), predict=pred.cpu(), zero_division=0)[0])
                dur.append(time.time() - t0)

            val_acc = self.evaluate(self.val_dataloader)
            self.scheduler.step(val_acc)
            print('Epoch: [{} / {}] | Time: {:.2f}s | Loss: {:.4f} | Train Acc: {:.4f} | Val Acc: {:.4f}'.
              format(epoch + 1, self.config['epochs'], np.mean(dur), np.mean(train_loss), np.mean(train_acc), val_acc))
            self.logger.write('Epoch: [{} / {}] | Time: {:.2f}s | Loss: {:.4f} | Train Acc: {:.4f} | Val Acc: {:.4f}'.
                        format(epoch + 1, self.config['epochs'], np.mean(dur), np.mean(train_loss), np.mean(train_acc), val_acc))

            if self.stopper.step(val_acc, self.model):
                break

        return self.stopper.best_score

    def evaluate(self, dataloader):
        self.model.eval()
        with torch.no_grad():
            pred_collect = []
            gt_collect = []
            for i, data in enumerate(dataloader):
                tgt = data['tgt_tensor'].to(self.config['device'])
                data['graph_data'] = data['graph_data'].to(self.config["device"])
                logits = self.model(data['graph_data'], require_loss=False)
                pred_collect.append(logits)
                gt_collect.append(tgt)

            pred_collect = torch.max(torch.cat(pred_collect, 0), dim=-1)[1].cpu()
            gt_collect = torch.cat(gt_collect, 0).cpu()
            score = self.metric.calculate_scores(ground_truth=gt_collect, predict=pred_collect, zero_division=0)[0]

            return score

    def test(self):
        # restored best saved model
        self.stopper.load_checkpoint(self.model)

        t0 = time.time()
        acc = self.evaluate(self.test_dataloader)
        dur = time.time() - t0
        print('Test examples: {} | Time: {:.2f}s |  Test Acc: {:.4f}'.
          format(self.num_test, dur, acc))
        self.logger.write('Test examples: {} | Time: {:.2f}s |  Test Acc: {:.4f}'.
          format(self.num_test, dur, acc))

        return acc

In [4]:
def print_config(config):
    print('**************** MODEL CONFIGURATION ****************')
    for key in sorted(config.keys()):
        val = config[key]
        keystr = '{}'.format(key) + (' ' * (24 - len(key)))
        print('{} -->   {}'.format(keystr, val))
    print('**************** MODEL CONFIGURATION ****************')

In [5]:
# config setup
config_file = '../config/trec/graphsage_bi_fuse_static_dependency.yaml'
config = yaml.load(open(config_file, 'r'), Loader=yaml.FullLoader)
print_config(config)

**************** MODEL CONFIGURATION ****************
batch_size               -->   50
dataset                  -->   trec
epochs                   -->   500
gat_attn_dropout         -->   None
gat_negative_slope       -->   None
gat_num_heads            -->   None
gat_num_out_heads        -->   None
gat_residual             -->   False
gl_connectivity_ratio    -->   None
gl_epsilon               -->   None
gl_metric_type           -->   None
gl_num_heads             -->   1
gl_num_hidden            -->   300
gl_smoothness_ratio      -->   None
gl_sparsity_ratio        -->   None
gl_top_k                 -->   None
gnn                      -->   graphsage
gnn_direction_option     -->   bi_fuse
gnn_dropout              -->   0.3
gnn_num_layers           -->   1
gpu                      -->   0
graph_pooling            -->   avg_pool
graph_type               -->   dependency
graphsage_aggreagte_type -->   lstm
init_adj_alpha           -->   None
init_graph_type          -->   None
lr   

In [None]:
# run model
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])

ts = datetime.datetime.now().timestamp()
config['out_dir'] += '_{}'.format(ts)
print('\n' + config['out_dir'])

runner = ModelHandler(config)
t0 = time.time()

val_acc = runner.train()
test_acc = runner.test()

runtime = time.time() - t0
print('Total runtime: {:.2f}s'.format(runtime))
runner.logger.write('Total runtime: {:.2f}s\n'.format(runtime))
runner.logger.close()

print('val acc: {}, test acc: {}'.format(val_acc, test_acc))


out/trec/graphsage_bi_fuse_dependency_ckpt_1623813162.034149
Loading pre-built vocab model stored in ../data/trec/processed/dependency_graph/vocab.pt
Train size: 5452, Val size: 500, Test size: 500
[ Fix word embeddings ]
Epoch: [1 / 500] | Time: 15.75s | Loss: 1.1737 | Train Acc: 0.5293 | Val Acc: 0.7460
Saved model to out/trec/graphsage_bi_fuse_dependency_ckpt_1623813162.034149/params.saved
Epoch: [2 / 500] | Time: 16.13s | Loss: 0.6645 | Train Acc: 0.7607 | Val Acc: 0.8320
Saved model to out/trec/graphsage_bi_fuse_dependency_ckpt_1623813162.034149/params.saved
Epoch: [3 / 500] | Time: 17.04s | Loss: 0.5631 | Train Acc: 0.7938 | Val Acc: 0.8260
EarlyStopping counter: 1 out of 10
Epoch: [4 / 500] | Time: 16.94s | Loss: 0.4838 | Train Acc: 0.8278 | Val Acc: 0.8580
Saved model to out/trec/graphsage_bi_fuse_dependency_ckpt_1623813162.034149/params.saved
Epoch: [5 / 500] | Time: 17.27s | Loss: 0.3975 | Train Acc: 0.8638 | Val Acc: 0.8960
Saved model to out/trec/graphsage_bi_fuse_dependen