In [12]:
import ast
import datetime
import os
import pickle
from itertools import chain

import dgl
import numpy as np
import pandas as pd
import torch
from dgl.data import DGLDataset

# Step 1: Create dataset

In [2]:
class CB12Dataset(DGLDataset):
    """
    CareerBuilder12 (CB12) job-title graph for node classification.

    Nodes are the rows of ``../data/cb12/graph/titles.csv``; edges are the
    ``Src`` -> ``Dst`` pairs of ``title_title_transition_MinorGroup200.csv`` in
    the same directory. Both files are read as tab-separated.

    Dataset statistics (as printed by an earlier run of this notebook):

    - Nodes: 9216
    - Node features: 1682 (multi-hot over title token indices)
    - Edges: 20640
    - Edge Weights: integer ``Weight`` column of the edge file
    - Number of Classes: 16 (distinct ``MajorGroup`` values)

    Attributes
    ----------
    num_classes : int
        Number of node classes
    data : list
        A list of :class:`dgl.DGLGraph` objects (here: the single graph)
    graph : dgl.DGLGraph
        The graph built by :meth:`process`, carrying ``ndata['feature']``,
        ``ndata['label']``, ``ndata['train_mask']``, ``ndata['val_mask']``,
        ``ndata['test_mask']`` and ``edata['weight']``.
    all_labels : list
        Raw ``MajorGroup`` value of every node, in file order.
    """

    def __init__(self):
        super().__init__(name='cb12')

    def process(self):
        """Read the CSV files and build ``self.graph`` (features, labels, edge weights, split masks)."""
        # sep is passed by keyword: positional arguments after the path are
        # deprecated in pandas 1.x and keyword-only in pandas 2.x.
        nodes_data = pd.read_csv("../data/cb12/graph/titles.csv", sep="\t")
        edges_data = pd.read_csv("../data/cb12/graph/title_title_transition_MinorGroup200.csv", sep="\t")
        edges_src = torch.from_numpy(edges_data['Src'].to_numpy())
        edges_dst = torch.from_numpy(edges_data['Dst'].to_numpy())
        n_nodes = nodes_data.shape[0]

        # Node features: multi-hot encoding of each title's token indices. The
        # column stores Python list literals as text; ast.literal_eval parses them
        # without executing arbitrary code the way eval() would.
        token_lists = [ast.literal_eval(item) for item in nodes_data['JobTitle_tokens_idx']]
        all_tokens = [token for tokens in token_lists for token in tokens]
        # The vocabulary spans indices 0..max. For contiguous ids this equals the
        # number of distinct tokens; for non-contiguous ids it avoids an
        # out-of-range write.
        vocab_size = max(all_tokens) + 1 if all_tokens else 0
        row_idx = torch.tensor([row for row, tokens in enumerate(token_lists) for _ in tokens], dtype=torch.long)
        col_idx = torch.tensor(all_tokens, dtype=torch.long)
        node_features = torch.zeros((n_nodes, vocab_size), dtype=torch.float32)
        node_features[row_idx, col_idx] = 1.0

        edge_features = torch.from_numpy(edges_data['Weight'].to_numpy())

        # Labels: integer category codes of MajorGroup. The code -> MajorGroup
        # mapping is printed so the class ids can be interpreted later.
        self.all_labels = nodes_data["MajorGroup"].tolist()
        label_categories = nodes_data['MajorGroup'].astype('category')
        print(dict(enumerate(label_categories.cat.categories)))
        node_labels = torch.from_numpy(label_categories.cat.codes.to_numpy()).int()

        # No self-loops are added. Note that DGL's GraphConv rejects
        # zero-in-degree nodes by default; use dgl.add_self_loop if a model needs it.
        self.graph = dgl.graph((edges_src, edges_dst), num_nodes=n_nodes)
        self.graph.ndata['feature'] = node_features
        self.graph.ndata['label'] = node_labels
        self.graph.edata['weight'] = edge_features

        # Split by row order: first 60% train, next 20% validation, rest test.
        # NOTE(review): rows are not shuffled; if titles.csv is ordered (e.g. by
        # label), the splits will have different class distributions — confirm.
        n_train = int(n_nodes * 0.6)
        n_val = int(n_nodes * 0.2)
        train_mask = torch.zeros(n_nodes, dtype=torch.bool)
        val_mask = torch.zeros(n_nodes, dtype=torch.bool)
        test_mask = torch.zeros(n_nodes, dtype=torch.bool)
        train_mask[:n_train] = True
        val_mask[n_train:n_train + n_val] = True
        test_mask[n_train + n_val:] = True
        self.graph.ndata['train_mask'] = train_mask
        self.graph.ndata['val_mask'] = val_mask
        self.graph.ndata['test_mask'] = test_mask

    def __getitem__(self, idx):
        """
        Get graph object

        Parameters
        ----------
        idx : int
            Item index; only 0 is valid.
        Returns
        -------
        :class:`dgl.DGLGraph`
            graph structure and labels.
            - ``ndata['label']``: ground truth labels
        """
        assert idx == 0, "This dataset has only one graph"
        return self.graph

    def __len__(self):
        r"""The number of graphs in the dataset."""
        return 1

    @property
    def data(self):
        """List holding the dataset's single graph."""
        return [self.graph]

    @property
    def num_classes(self):
        """Number of distinct ``MajorGroup`` labels (also printed)."""
        n_classes = len(set(self.all_labels))
        print("Number of classes: {}".format(n_classes))
        return n_classes

In [3]:
# Constructing the dataset runs CB12Dataset.process(), which reads the CSVs
# and builds the graph; the dataset then exposes that single graph at index 0.
dataset_cb12 = CB12Dataset()
graph_cb12 = dataset_cb12[0]

  self.process()


{0: 11, 1: 13, 2: 15, 3: 17, 4: 21, 5: 25, 6: 29, 7: 31, 8: 33, 9: 35, 10: 41, 11: 43, 12: 47, 13: 49, 14: 51, 15: 53}




In [4]:
# Rich display of the graph: node/edge counts and the ndata/edata schemes.
graph_cb12

Graph(num_nodes=9216, num_edges=20640,
      ndata_schemes={'feature': Scheme(shape=(1682,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int32), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={'weight': Scheme(shape=(), dtype=torch.int64)})

In [5]:
# Ground-truth labels and the boolean split masks stored on the graph.
labels = graph_cb12.ndata['label']
train_mask, val_mask, test_mask = (
    graph_cb12.ndata[key] for key in ('train_mask', 'val_mask', 'test_mask')
)

# Labels restricted to each split.
train_labels, val_labels, test_labels = (
    labels[split] for split in (train_mask, val_mask, test_mask)
)

# Step 2: Training

In [6]:
import torch.nn as nn
import dgl.function as fn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score

In [47]:
def compute_metrics(logits, labels):
    """
    Classification metrics for a set of node predictions.

    Parameters
    ----------
    logits : torch.Tensor
        Per-node class scores; the predicted class is the argmax over dim 1.
    labels : torch.Tensor
        Ground-truth class ids, one per row of ``logits``.

    Returns
    -------
    tuple of float
        ``(accuracy, macro F1, micro F1, weighted F1)``.
    """
    preds = logits.argmax(dim=1).cpu().numpy()
    y_true = labels.cpu().numpy()

    # zero_division=0 gives the same scores as sklearn's default ("warn") for
    # classes with no predicted/true samples, but without emitting an
    # UndefinedMetricWarning on every call (i.e. every epoch).
    acc = accuracy_score(y_true, preds)
    macro_f1 = f1_score(y_true, preds, average='macro', zero_division=0)
    micro_f1 = f1_score(y_true, preds, average='micro', zero_division=0)
    weighted_f1 = f1_score(y_true, preds, average='weighted', zero_division=0)
    return acc, macro_f1, micro_f1, weighted_f1


def evaluate(model, g, features, labels, mask, loss_fn):
    """
    Run ``model`` in eval mode on the whole graph and score the nodes in ``mask``.

    The model is left in eval mode afterwards.

    Returns
    -------
    tuple
        ``(loss, acc, macro_f1, micro_f1, weighted_f1)``: ``loss`` is the
        ``loss_fn`` tensor over the masked nodes; the remaining values come
        from ``compute_metrics``.
    """
    model.eval()
    with torch.no_grad():
        full_logits = model(g, features)
        masked_logits = full_logits[mask]
        masked_labels = labels[mask]
    # masked_logits was produced under no_grad, so this loss carries no autograd graph.
    loss = loss_fn(masked_logits, masked_labels.long())
    acc, macro_f1, micro_f1, weighted_f1 = compute_metrics(masked_logits, masked_labels)
    return loss, acc, macro_f1, micro_f1, weighted_f1

In [52]:
class EarlyStopping:
    """
    Early stopping on validation loss and accuracy, with model checkpointing.

    The counter advances only when an epoch is worse on BOTH loss and accuracy
    than the best values seen so far. A checkpoint is written on the first step
    and on any later step that is at least as good on both metrics.

    Parameters
    ----------
    patience : int
        Number of consecutive "worse on both" steps before ``early_stop`` is set.
    checkpoint_dir : str
        Root directory for checkpoints; files go to
        ``<checkpoint_dir>/<model_name>/<filename>``.
    """

    def __init__(self, patience=10, checkpoint_dir='../checkpoints'):
        dt = datetime.datetime.now()
        # One checkpoint file per stopper, named by creation time (second resolution).
        self.filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(dt.date(), dt.hour, dt.minute, dt.second)
        self.checkpoint_dir = checkpoint_dir

        self.patience = patience
        self.counter = 0
        self.best_loss = None
        self.best_score = None
        self.early_stop = False

    def _checkpoint_path(self, model_name):
        """Full path of this stopper's checkpoint file for ``model_name``."""
        return os.path.join(self.checkpoint_dir, model_name, self.filename)

    def save_checkpoint(self, model, model_name):
        """
        Save ``model.state_dict()`` to this stopper's checkpoint file, creating
        the directory if needed.
        """
        path = self._checkpoint_path(model_name)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(model.state_dict(), path)

    def load_checkpoint(self, model, model_name):
        """
        Load the checkpoint last written by this stopper into ``model``.
        """
        model.load_state_dict(torch.load(self._checkpoint_path(model_name)))

    def step(self, model, model_name, loss, acc):
        """
        Record one epoch's validation ``loss`` and ``acc``.

        Returns
        -------
        bool
            True once ``patience`` consecutive steps were worse on both metrics.
        """
        score = acc
        if self.best_score is None:
            self.best_score = score
            self.best_loss = loss
            self.save_checkpoint(model, model_name)

        elif (loss > self.best_loss) and (acc < self.best_score):
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

        else:
            # Checkpoint only when this epoch is at least as good on both metrics;
            # the running bests are updated independently either way.
            if (loss <= self.best_loss) and (acc >= self.best_score):
                self.save_checkpoint(model, model_name)

            self.best_score = np.max((acc, self.best_score))
            self.best_loss = np.min((loss, self.best_loss))
            self.counter = 0

        return self.early_stop

    

    
def _save_training_history(history, model_name, lr):
    """
    Pickle a training-history DataFrame to ``results/<model_name>/lr<lr>.pkl``.

    The directory is created if needed; an existing file for the same model
    name and learning rate is overwritten.
    """
    out_dir = os.path.join('results', model_name)
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'lr' + str(lr) + '.pkl'), 'wb') as f_out:
        pickle.dump(history, f_out)


def train(g, model, model_name, lr, weight_decay, epoch, patience=100, log_every=100):
    """
    Full-batch training of a node classifier with early stopping.

    Parameters
    ----------
    g : dgl.DGLGraph
        Graph with ``ndata['feature']``, ``ndata['label']`` and boolean
        ``ndata['train_mask']`` / ``['val_mask']`` / ``['test_mask']``.
    model : torch.nn.Module
        Called as ``model(g, features)``; must return one row of class logits per node.
    model_name : str
        Sub-directory name used for checkpoints and the results pickle.
    lr, weight_decay : float
        Adam hyper-parameters.
    epoch : int
        Maximum number of training epochs.
    patience : int, optional
        Patience passed to ``EarlyStopping``.
    log_every : int, optional
        Print metrics every ``log_every`` epochs (must be positive).

    After the loop, the last checkpoint written by ``EarlyStopping`` is reloaded,
    test metrics are printed, and the per-epoch history is pickled.
    """
    stopper = EarlyStopping(patience=patience)
    loss_fcn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val_acc = 0

    features = g.ndata['feature']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    # Per-epoch history; key order is the column order of the saved DataFrame.
    history = {key: [] for key in (
        'loss', 'train_acc', 'val_acc',
        'train_macro_f1', 'val_macro_f1',
        'train_micro_f1', 'val_micro_f1',
        'train_weighted_f1', 'val_weighted_f1',
    )}

    for e in range(epoch):
        # evaluate() leaves the model in eval mode; switch back every epoch so
        # dropout (if any) is active during all training steps, not just the first.
        model.train()
        logits = model(g, features)

        # Only the training nodes contribute to the loss.
        loss = loss_fcn(logits[train_mask], labels[train_mask].long())
        history['loss'].append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Training metrics use the logits from before this step's parameter update.
        train_acc, train_macro_f1, train_micro_f1, train_weighted_f1 = compute_metrics(logits[train_mask], labels[train_mask])
        val_loss, val_acc, val_macro_f1, val_micro_f1, val_weighted_f1 = evaluate(model, g, features, labels, val_mask, loss_fcn)
        for key, value in (
            ('train_acc', train_acc), ('val_acc', val_acc),
            ('train_macro_f1', train_macro_f1), ('val_macro_f1', val_macro_f1),
            ('train_micro_f1', train_micro_f1), ('val_micro_f1', val_micro_f1),
            ('train_weighted_f1', train_weighted_f1), ('val_weighted_f1', val_weighted_f1),
        ):
            history[key].append(value)

        if stopper.step(model, model_name, val_loss, val_acc):
            break

        best_val_acc = max(best_val_acc, val_acc)

        if e % log_every == 0:
            print('In epoch {}, loss: {:.4f}'.format(e, loss))
            print('train acc: {:.4f}, val acc: {:.4f} (best {:.4f})'.format(train_acc, val_acc, best_val_acc))
            print('train macro_f1: {:.4f}, val macro_f1: {:.4f}'.format(train_macro_f1, val_macro_f1))
            print('train micro_f1: {:.4f}, val micro_f1: {:.4f}'.format(train_micro_f1, val_micro_f1))
            print('train weighted_f1: {:.4f}, val weighted_f1: {:.4f}'.format(train_weighted_f1, val_weighted_f1))
            print("-----------------------------")

    # EarlyStopping writes a checkpoint on its first step, so one exists
    # whenever at least one epoch ran (epoch=0 would otherwise fail here).
    if stopper.best_score is not None:
        stopper.load_checkpoint(model, model_name)
    test_loss, test_acc, test_macro_f1, test_micro_f1, test_weighted_f1 = evaluate(model, g, features, labels, test_mask, loss_fcn)
    print('test acc: {:.4f}, test macro_f1: {:.4f}, test micro_f1: {:.4f}, test weighted_f1: {:.4f}'.format(test_acc, test_macro_f1, test_micro_f1, test_weighted_f1))

    _save_training_history(pd.DataFrame(history), model_name, lr)

### GCN 

In [53]:
from dgl.nn import GraphConv 

In [54]:
class GCN(torch.nn.Module):
    """
    Stack of DGL ``GraphConv`` layers for node classification.

    Layout: one input GraphConv (in_feats -> n_hidden), ``n_layers - 1`` hidden
    GraphConvs (n_hidden -> n_hidden), and an output GraphConv
    (n_hidden -> n_classes) without activation.
    """

    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout):
        """
        :param in_feats[int]: dimension of input node features
        :param n_hidden[int]: number of hidden units
        :param n_classes[int]: number of output classes
        :param n_layers[int]: number of non-output GraphConv layers (the model has n_layers + 1 in total)
        :param activation[callable]: activation applied inside every non-output GraphConv (e.g. F.relu), or None
        :param dropout[float]: dropout probability applied to the input of every layer except the first
        """
        super(GCN, self).__init__()
        self.activation = activation
        self.layers = nn.ModuleList()
        # Input layer
        self.layers.append(GraphConv(in_feats, n_hidden, activation=self.activation))
        # Hidden layers
        for i in range(n_layers - 1):
            self.layers.append(GraphConv(n_hidden, n_hidden, activation=self.activation))
        # Output layer (raw logits, no activation)
        self.layers.append(GraphConv(n_hidden, n_classes))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, features):
        """Return per-node class logits; dropout precedes every layer except the first."""
        h = features
        for i, layer in enumerate(self.layers):
            if i != 0:
                h = self.dropout(h)
            h = layer(g, h)
        return h

    def embedding(self, g, x, nodes=None, model_name=None, lr=None):
        """
        Returns the embeddings of the input nodes: the output of every layer
        except the final (classification) layer, without dropout.

        Parameters
        ----------
        g : dgl.DGLGraph
            Graph to propagate over.
        x : Tensor
            Input node features.
        nodes: Tensor, optional
            Input nodes, if set `None`, will return all the node embedding.
        model_name : str, optional
            If given, the full embedding matrix (all nodes) is pickled to
            ``embs/<model_name>/lr<lr>.pkl``.
        lr : optional
            Only used to build the file name above.
        Returns
        -------
        Tensor
            Node embedding.
        """
        h = x
        # Each non-output GraphConv was built with activation=self.activation and
        # applies it internally, so it must not be applied a second time here.
        for layer in self.layers[:-1]:
            h = layer(g, h)

        if model_name is not None:
            out_dir = os.path.join('embs', model_name)
            os.makedirs(out_dir, exist_ok=True)
            with open(os.path.join(out_dir, 'lr' + str(lr) + '.pkl'), 'wb') as f_out:
                pickle.dump(h.detach().cpu(), f_out)

        return h if nodes is None else h[nodes]

In [55]:
# Seed torch so GraphConv weight initialisation (and dropout, when enabled) is
# reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model_GCN = GCN(
    in_feats=graph_cb12.ndata['feature'].shape[1],
    n_hidden=128,
    n_classes=dataset_cb12.num_classes,
    n_layers=1,
    activation=F.relu,
    dropout=0.0,
)
train(graph_cb12, model_GCN, 'GCN', lr=0.001, weight_decay=0.0005, epoch=1000)

Number of classes: 16




In epoch 0, loss: 2.7702
train acc: 0.0696, val acc: 0.1584 (best 0.1584)
train macro_f1: 0.0451, val macro_f1: 0.0639
train micro_f1: 0.0696, val micro_f1: 0.1584
train weighted_f1: 0.0644, val weighted_f1: 0.1649
-----------------------------




In epoch 100, loss: 1.3064
train acc: 0.6784, val acc: 0.6755 (best 0.6761)
train macro_f1: 0.5085, val macro_f1: 0.3941
train micro_f1: 0.6784, val micro_f1: 0.6755
train weighted_f1: 0.6531, val weighted_f1: 0.6497
-----------------------------




In epoch 200, loss: 0.8996
train acc: 0.7479, val acc: 0.6994 (best 0.7016)
train macro_f1: 0.6959, val macro_f1: 0.5772
train micro_f1: 0.7479, val micro_f1: 0.6994
train weighted_f1: 0.7433, val weighted_f1: 0.6952
-----------------------------




EarlyStopping counter: 1 out of 100




EarlyStopping counter: 2 out of 100




EarlyStopping counter: 1 out of 100




EarlyStopping counter: 2 out of 100




EarlyStopping counter: 1 out of 100




EarlyStopping counter: 2 out of 100




EarlyStopping counter: 3 out of 100




EarlyStopping counter: 4 out of 100




EarlyStopping counter: 5 out of 100




EarlyStopping counter: 6 out of 100




EarlyStopping counter: 7 out of 100




EarlyStopping counter: 8 out of 100




EarlyStopping counter: 9 out of 100




EarlyStopping counter: 10 out of 100




EarlyStopping counter: 11 out of 100




EarlyStopping counter: 12 out of 100




EarlyStopping counter: 13 out of 100




EarlyStopping counter: 14 out of 100




EarlyStopping counter: 15 out of 100




EarlyStopping counter: 16 out of 100




EarlyStopping counter: 17 out of 100




EarlyStopping counter: 18 out of 100




EarlyStopping counter: 19 out of 100




EarlyStopping counter: 20 out of 100




EarlyStopping counter: 21 out of 100
In epoch 300, loss: 0.7699
train acc: 0.7710, val acc: 0.7032 (best 0.7059)
train macro_f1: 0.7407, val macro_f1: 0.6144
train micro_f1: 0.7710, val micro_f1: 0.7032
train weighted_f1: 0.7688, val weighted_f1: 0.7018
-----------------------------




EarlyStopping counter: 22 out of 100




EarlyStopping counter: 23 out of 100




EarlyStopping counter: 24 out of 100




EarlyStopping counter: 25 out of 100




EarlyStopping counter: 26 out of 100




EarlyStopping counter: 27 out of 100




EarlyStopping counter: 28 out of 100




EarlyStopping counter: 29 out of 100




EarlyStopping counter: 30 out of 100




EarlyStopping counter: 31 out of 100




EarlyStopping counter: 32 out of 100




EarlyStopping counter: 33 out of 100




EarlyStopping counter: 34 out of 100




EarlyStopping counter: 35 out of 100




EarlyStopping counter: 36 out of 100




EarlyStopping counter: 37 out of 100




EarlyStopping counter: 38 out of 100




EarlyStopping counter: 39 out of 100




EarlyStopping counter: 40 out of 100




EarlyStopping counter: 41 out of 100




EarlyStopping counter: 42 out of 100




EarlyStopping counter: 43 out of 100




EarlyStopping counter: 44 out of 100




EarlyStopping counter: 45 out of 100




EarlyStopping counter: 46 out of 100




EarlyStopping counter: 47 out of 100




EarlyStopping counter: 48 out of 100




EarlyStopping counter: 49 out of 100




EarlyStopping counter: 50 out of 100




EarlyStopping counter: 51 out of 100




EarlyStopping counter: 52 out of 100




EarlyStopping counter: 53 out of 100




EarlyStopping counter: 54 out of 100




EarlyStopping counter: 55 out of 100




EarlyStopping counter: 56 out of 100




EarlyStopping counter: 57 out of 100




EarlyStopping counter: 58 out of 100




EarlyStopping counter: 59 out of 100




EarlyStopping counter: 60 out of 100




EarlyStopping counter: 61 out of 100




EarlyStopping counter: 62 out of 100




EarlyStopping counter: 63 out of 100




EarlyStopping counter: 64 out of 100




EarlyStopping counter: 65 out of 100




EarlyStopping counter: 66 out of 100




EarlyStopping counter: 67 out of 100




EarlyStopping counter: 68 out of 100




EarlyStopping counter: 69 out of 100




EarlyStopping counter: 70 out of 100




EarlyStopping counter: 71 out of 100




EarlyStopping counter: 72 out of 100




EarlyStopping counter: 73 out of 100




EarlyStopping counter: 74 out of 100




EarlyStopping counter: 75 out of 100




EarlyStopping counter: 76 out of 100




EarlyStopping counter: 77 out of 100




EarlyStopping counter: 78 out of 100




EarlyStopping counter: 79 out of 100




EarlyStopping counter: 80 out of 100




EarlyStopping counter: 81 out of 100




EarlyStopping counter: 82 out of 100




EarlyStopping counter: 83 out of 100




EarlyStopping counter: 84 out of 100




EarlyStopping counter: 85 out of 100




EarlyStopping counter: 86 out of 100




EarlyStopping counter: 87 out of 100




EarlyStopping counter: 88 out of 100




EarlyStopping counter: 89 out of 100




EarlyStopping counter: 90 out of 100




EarlyStopping counter: 91 out of 100




EarlyStopping counter: 92 out of 100




EarlyStopping counter: 93 out of 100




EarlyStopping counter: 94 out of 100




EarlyStopping counter: 95 out of 100




EarlyStopping counter: 96 out of 100




EarlyStopping counter: 97 out of 100




EarlyStopping counter: 98 out of 100




EarlyStopping counter: 99 out of 100




EarlyStopping counter: 100 out of 100
test acc: 0.6833, test macro_f1: 0.5756, test micro_f1: 0.6833, test weighted_f1: 0.6799
