In [1]:
import hgfp
import torch
import dgl
import numpy as np
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
from matplotlib import pyplot as plt




In [2]:
ds = list(hgfp.data.parm_at_Frosst.df.batched(num=1000, batch_size=10, use_fp=False))
ds_tr, ds_vl, ds_te = hgfp.data.utils.split(ds, 10, 10)



In [3]:
from dgl.nn import pytorch as dgl_pytorch

In [4]:
class GN(torch.nn.Module):
    def __init__(self, model, kwargs):
        super(GN, self).__init__()
        self.gn = model(64, 64, **kwargs)
        
    def forward(self, g, x):
        g_sub = dgl.to_homo(
            g.edge_type_subgraph(['atom_neighbors_atom']))
        x = self.gn(g_sub, x)
        return x        

In [5]:
class Classifier(torch.nn.Module):
    def __init__(self, in_dim=128, out_dim=256, n_classes=45):
        super(Classifier, self).__init__()
        self.d = torch.nn.Linear(in_dim, out_dim)
        self.c = torch.nn.Linear(out_dim, n_classes)
        
    def forward(self, x):
        y_hat = self.c(
                torch.nn.functional.tanh(
                    self.d(
                        x)))
        
        return y_hat
        

In [8]:
class Net(torch.nn.Module):
    def __init__(self, model, kwargs):
        super(Net, self).__init__()
        self.f_in = torch.nn.Sequential(
            torch.nn.Linear(117, 64),
            torch.nn.Tanh())
        
        self.gn0 = GN(model, kwargs)
        self.gn1 = GN(model, kwargs)
        self.gn2 = GN(model, kwargs)
        
        self.c = Classifier(64, 64, 45)
        
    def forward(self, g):
        x = g.nodes['atom'].data['h0']
        x = self.f_in(x)
        x = self.gn0(g, x)
        x = torch.nn.functional.tanh(x)
        x = self.gn1(g, x)
        x = torch.nn.functional.tanh(x)
        x = self.gn2(g, x)
        x = self.c(x)
        return x
        

In [9]:
for model_name, model in {
        # 'GraphConv': [dgl_pytorch.conv.GraphConv, {}],
        # 'TAGConv': [dgl_pytorch.conv.TAGConv, {}],
        # 'EdgeConv': [dgl_pytorch.conv.EdgeConv, {}],
        'SAGEConvMean': [dgl_pytorch.conv.SAGEConv, {'aggregator_type': 'mean'}],
        # 'SAGEConvGCN': [dgl_pytorch.conv.SAGEConv, {'aggregator_type': 'gcn'}],
        # 'SAGEConvPool': [dgl_pytorch.conv.SAGEConv, {'aggregator_type': 'pool'}],
        # 'SAGEConvLSTM': [dgl_pytorch.conv.SAGEConv, {'aggregator_type': 'lstm'}],
        # 'SGConv': [dgl_pytorch.conv.SGConv, {}]
}.items():

    print(model_name)
    net=Net(model[0], model[1])
    opt = torch.optim.Adam(list(net.parameters()), 1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()
    namespace = __import__(__name__)

    for part in ['tr', 'vl', 'te']:
        exec('accuracy_' + part + '= []')

    for _ in range(500):
        for g, y in ds_tr:
            opt.zero_grad()
            y_hat = net(g)
            loss = loss_fn(y_hat, torch.where(torch.gt(y, 0))[1])
            loss.backward()
            opt.step()

        net.eval()
        for part in ['tr', 'vl', 'te']:
            y_hat = torch.cat([torch.argmax(net(g), dim=1) for g, y in getattr(
                namespace, 'ds_' + part)], dim=0).detach().numpy()

            y = torch.cat([torch.argmax(y, dim=1) for g, y in getattr(
                namespace, 'ds_' + part)], dim=0).detach().numpy()

            getattr(namespace, 'accuracy_' + part).append(
                1 - np.divide(
                    np.count_nonzero(y_hat - y),
                y_hat.shape[0]))

        net.train()

    plt.style.use('fivethirtyeight')
    plt.figure()
    plt.plot(accuracy_tr, label='training')
    plt.plot(accuracy_vl, label='validation')
    plt.ylim(0, 1)
    plt.legend()
    plt.ylabel('accuracy')
    plt.xlabel('n_epochs')
    plt.title(model_name)
    plt.tight_layout()
    plt.savefig(model_name + '.png', dpi=500)
    plt.close()

SAGEConvMean




In [10]:
net.eval()
from sklearn.metrics import confusion_matrix

y_hat = torch.cat([torch.argmax(net(g), dim=1) for g, y in ds_vl+ds_te], axis=0)
y = torch.cat([torch.argmax(y, dim=1) for g, y in ds_vl+ds_te], axis=0)

print(np.count_nonzero(y - y_hat) / y_hat.shape[0])

0.015371621621621622




In [12]:
element_to_idx = {b'BR': 0, b'C': 1, b'C2': 2, b'CA': 3, b'CB': 4, b'CC': 5, b'CJ': 6, b'CL': 7, b'CM': 8, b'CP': 9, b'CR': 10, b'CT': 11, b'CW': 12, b'Cstar': 13, b'F': 14, b'H': 15, b'H1': 16, b'H2': 17, b'H3': 18, b'H4': 19, b'H5': 20, b'HA': 21, b'HC': 22, b'HO': 23, b'HP': 24, b'HX': 25, b'I': 26, b'N': 27, b'N2': 28, b'N3': 29, b'NA': 30, b'NB': 31, b'NC': 32, b'NL': 33, b'Nstar': 34, b'Nu': 35, b'O': 36, b'O2': 37, b'OH': 38, b'OS': 39, b'Ou': 40, b'P': 41, b'S': 42, b'SO': 43, b'Su': 44}

In [13]:
idx_to_element = {v: k.decode("utf-8") for k, v in element_to_idx.items()}

In [14]:
import pandas as pd
df_cm = pd.DataFrame(confusion_matrix(y, y_hat, labels=list(range(1, 46))),
    [v for k, v in idx_to_element.items()],
    [v for k, v in idx_to_element.items()])

In [15]:
count_matrix = df_cm.values
wrong_idxs = np.stack(np.where(np.greater(count_matrix, 0)), axis=1)
wrong_idxs = wrong_idxs[wrong_idxs[:, 0] != wrong_idxs[:, 1]]
wrong_count = np.array([count_matrix[idxs[0]][idxs[1]] for idxs in wrong_idxs])
wrong_count_argsort = np.flip(np.argsort(wrong_count))
for idx in wrong_count_argsort:
    print('%s -> %s : %s'%(
        idx_to_element[wrong_idxs[idx][0]],
        idx_to_element[wrong_idxs[idx][1]],
        wrong_count[idx]))

CB -> CA : 7
CM -> CA : 7
CT -> CB : 6
CA -> CB : 5
CA -> CT : 5
CB -> CT : 5
C2 -> CL : 4
CA -> C2 : 4
H3 -> H5 : 4
CA -> CM : 3
CT -> CA : 3
CP -> CB : 2
C2 -> CA : 2
CL -> C2 : 2
CP -> BR : 2
NL -> N : 2
BR -> CP : 2
CT -> CL : 2
HO -> H : 2
N2 -> NL : 2
N -> I : 1
C2 -> BR : 1
NB -> NA : 1
NA -> N : 1
C2 -> CP : 1
N2 -> N : 1
N2 -> I : 1
N -> N2 : 1
CA -> CW : 1
CB -> CL : 1
CW -> CP : 1
N -> Cstar : 1
CJ -> P : 1
I -> N3 : 1
CL -> CB : 1
I -> N : 1
CM -> CB : 1
H2 -> H1 : 1
CP -> CW : 1
NL -> NA : 1
