In [1]:
import hgfp
import torch
import dgl



In [2]:
element_to_idx = {b'BR': 0, b'C': 1, b'C2': 2, b'CA': 3, b'CB': 4, b'CC': 5, b'CJ': 6, b'CL': 7, b'CM': 8, b'CP': 9, b'CR': 10, b'CT': 11, b'CW': 12, b'Cstar': 13, b'F': 14, b'H': 15, b'H1': 16, b'H2': 17, b'H3': 18, b'H4': 19, b'H5': 20, b'HA': 21, b'HC': 22, b'HO': 23, b'HP': 24, b'HX': 25, b'I': 26, b'N': 27, b'N2': 28, b'N3': 29, b'NA': 30, b'NB': 31, b'NC': 32, b'NL': 33, b'Nstar': 34, b'Nu': 35, b'O': 36, b'O2': 37, b'OH': 38, b'OS': 39, b'Ou': 40, b'P': 41, b'S': 42, b'SO': 43, b'Su': 44}

In [3]:
idx_to_element = {v: k.decode("utf-8") for k, v in element_to_idx.items()}

In [4]:
ds = list(hgfp.data.parm_at_Frosst.df.batched(num=100, batch_size=16))



In [5]:
print(len(ds))

6


In [6]:
ds_tr, ds_vl, ds_te = hgfp.data.utils.split(ds, 1, 1)

In [7]:
class GN(torch.nn.Module):
    def __init__(self, model, kwargs):
        super(GN, self).__init__()
        self.gn = model(64, 64, **kwargs)
        
    def forward(self, g, x):
        g_sub = dgl.to_homo(
            g.edge_type_subgraph(['atom_neighbors_atom']))
        x = self.gn(g_sub, x)
        return x        

class Classifier(torch.nn.Module):
    def __init__(self, in_dim=128, out_dim=256, n_classes=45):
        super(Classifier, self).__init__()
        self.d = torch.nn.Linear(in_dim, out_dim)
        self.c = torch.nn.Linear(out_dim, n_classes)
        
    def forward(self, x):
        y_hat = self.c(
                torch.nn.functional.tanh(
                    self.d(
                        x)))
        
        return y_hat
        
class Net(torch.nn.Module):
    def __init__(self, model, kwargs):
        super(Net, self).__init__()
        self.f_in = torch.nn.Sequential(
            torch.nn.Linear(117, 64),
            torch.nn.Tanh())
        
        self.gn0 = GN(model, kwargs)
        self.gn1 = GN(model, kwargs)
        self.gn2 = GN(model, kwargs)
        
        self.c = Classifier(64, 64, 45)
        
    def forward(self, g):
        x = g.nodes['atom'].data['h0']
        x = self.f_in(x)
        x = self.gn0(g, x)
        x = torch.nn.functional.tanh(x)
        x = self.gn1(g, x)
        x = torch.nn.functional.tanh(x)
        x = self.gn2(g, x)
        x = self.c(x)
        return x
        

In [11]:
from dgl.nn import pytorch as dgl_pytorch
net = Net(dgl_pytorch.conv.SAGEConv,  {'aggregator_type': 'mean'})

In [12]:
# opt = torch.optim.Adam(list(net.parameters()) + list(classifier.parameters()), 1e-3)
opt = torch.optim.Adam(list(net.parameters()), 1e-3)

In [13]:
loss_fn = torch.nn.CrossEntropyLoss()

In [14]:
for _ in range(100):
    for g, y in ds_tr:
        opt.zero_grad()
        # y_hat = classifier(net(g, return_graph=True))
        y_hat = net(g)
        loss = loss_fn(y_hat, torch.where(torch.gt(y, 0))[1])
        print(loss)
        loss.backward()
        opt.step()



tensor(3.8970, grad_fn=<NllLossBackward>)
tensor(3.6739, grad_fn=<NllLossBackward>)
tensor(3.5260, grad_fn=<NllLossBackward>)
tensor(3.3626, grad_fn=<NllLossBackward>)
tensor(3.2505, grad_fn=<NllLossBackward>)
tensor(3.1976, grad_fn=<NllLossBackward>)
tensor(3.1707, grad_fn=<NllLossBackward>)
tensor(3.0109, grad_fn=<NllLossBackward>)
tensor(2.9323, grad_fn=<NllLossBackward>)
tensor(2.9530, grad_fn=<NllLossBackward>)
tensor(2.9488, grad_fn=<NllLossBackward>)
tensor(2.7790, grad_fn=<NllLossBackward>)
tensor(2.7172, grad_fn=<NllLossBackward>)
tensor(2.7632, grad_fn=<NllLossBackward>)
tensor(2.7719, grad_fn=<NllLossBackward>)
tensor(2.5985, grad_fn=<NllLossBackward>)
tensor(2.5506, grad_fn=<NllLossBackward>)
tensor(2.6156, grad_fn=<NllLossBackward>)
tensor(2.6389, grad_fn=<NllLossBackward>)
tensor(2.4621, grad_fn=<NllLossBackward>)
tensor(2.4269, grad_fn=<NllLossBackward>)
tensor(2.4934, grad_fn=<NllLossBackward>)
tensor(2.5321, grad_fn=<NllLossBackward>)
tensor(2.3539, grad_fn=<NllLossBac

tensor(0.4582, grad_fn=<NllLossBackward>)
tensor(0.5604, grad_fn=<NllLossBackward>)
tensor(0.4150, grad_fn=<NllLossBackward>)
tensor(0.4339, grad_fn=<NllLossBackward>)
tensor(0.4457, grad_fn=<NllLossBackward>)
tensor(0.5460, grad_fn=<NllLossBackward>)
tensor(0.4032, grad_fn=<NllLossBackward>)
tensor(0.4226, grad_fn=<NllLossBackward>)
tensor(0.4337, grad_fn=<NllLossBackward>)
tensor(0.5322, grad_fn=<NllLossBackward>)
tensor(0.3919, grad_fn=<NllLossBackward>)
tensor(0.4118, grad_fn=<NllLossBackward>)
tensor(0.4222, grad_fn=<NllLossBackward>)
tensor(0.5189, grad_fn=<NllLossBackward>)
tensor(0.3810, grad_fn=<NllLossBackward>)
tensor(0.4014, grad_fn=<NllLossBackward>)
tensor(0.4111, grad_fn=<NllLossBackward>)
tensor(0.5061, grad_fn=<NllLossBackward>)
tensor(0.3705, grad_fn=<NllLossBackward>)
tensor(0.3913, grad_fn=<NllLossBackward>)
tensor(0.4005, grad_fn=<NllLossBackward>)
tensor(0.4939, grad_fn=<NllLossBackward>)
tensor(0.3604, grad_fn=<NllLossBackward>)
tensor(0.3816, grad_fn=<NllLossBac

In [16]:
net.eval()
from sklearn.metrics import confusion_matrix
for g, y in ds_tr:
    y_hat = torch.argmax(net(g), dim=1)
    y = torch.argmax(y, dim=1)
    



In [17]:
import pandas as pd
df_cm = pd.DataFrame(confusion_matrix(y, y_hat, labels=list(range(1, 46))),
    [v for k, v in idx_to_element.items()],
    [v for k, v in idx_to_element.items()])

In [18]:
import seaborn as sn
from matplotlib import pyplot as plt
plt.figure(figsize=(15, 15))
sn.heatmap(df_cm, annot=True, annot_kws={"size": 16})
plt.show()

ModuleNotFoundError: No module named 'seaborn'