In [1]:
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
from data import *
from model.gat import *
from util.misc import CSVLogger

In [2]:
class Args:
    pass
args = Args()
args.__dict__ = {
    "batch_size":32,
    "dataset":'USPTO50K',
    "epochs":80,
    "exp_name":'USPTO50K_typed',
    "gat_layers":3, 
    "heads":4, 
    "hidden_dim":128, 
    "in_dim":714, 
    "load":False, 
    "logdir":'logs', 
    "lr":0.0005, 
    "seed":123, 
    "test_on_train":False, 
    "test_only":False, 
    "typed":True, 
    "use_cpu":True, 
    "valid_only":False
}

In [3]:
def collate(data):
    return map(list, zip(*data))
    
batch_size = args.batch_size
epochs = args.epochs
data_root = os.path.join('data', args.dataset)
args.exp_name = args.dataset
if args.typed:
    args.in_dim += 10
    args.exp_name += '_typed'
else:
    args.exp_name += '_untyped'
print(args)

test_id = '{}'.format(args.logdir)
filename = 'logs/' + test_id + '.csv'
csv_logger = CSVLogger(
    args=args,
    fieldnames=['epoch', 'train_acc', 'valid_acc', 'train_loss'],
    filename=filename,
)

GAT_model = GATNet(
    in_dim=args.in_dim,
    num_layers=args.gat_layers,
    hidden_dim=args.hidden_dim,
    heads=args.heads,
    use_gpu=(args.use_cpu == False),
)

if args.use_cpu:
    device = 'cpu'
else:
    GAT_model = GAT_model.cuda()
    device = 'cuda:0'

if args.load:
    GAT_model.load_state_dict(
        torch.load('checkpoints/{}_checkpoint.pt'.format(args.exp_name),
                    map_location=torch.device(device)), )
    args.lr *= 0.2
    milestones = []
else:
    milestones = [20, 40, 60, 80]

optimizer = torch.optim.Adam([{
    'params': GAT_model.parameters()
}],
                                lr=args.lr)
scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.2)

<__main__.Args object at 0x000000C77E657CC8>


In [4]:
valid_data = RetroCenterDatasets(root=data_root, data_split='valid')
valid_dataloader = DataLoader(valid_data,
                                batch_size=4 * batch_size,
                                shuffle=False,
                                num_workers=0,
                                collate_fn=collate)



Counter({1: 3482, 0: 1415, 2: 102, 9: 1, 17: 1})


In [5]:
train_data = RetroCenterDatasets(root=data_root, data_split='train')
train_dataloader = DataLoader(train_data,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=0,
                            collate_fn=collate)

Counter({1: 27851, 0: 11296, 2: 849, 3: 4, 4: 4, 10: 2, 7: 1, 13: 1})


In [71]:
progress_bar = tqdm(train_dataloader)


  0%|          | 0/1251 [00:00<?, ?it/s]

In [7]:
for i, data in enumerate(progress_bar):
    rxn_class, x_pattern_feat, x_atom, x_adj, x_graph, y_adj, disconnection_num = data
    #print(i,rxn_class,x_pattern_feat,x_atom,x_adj,x_graph,y_adj,disconnection_num)
    x_atom = list(map(lambda x: torch.from_numpy(x).float(), x_atom))
    x_pattern_feat = list(
        map(lambda x: torch.from_numpy(x).float(), x_pattern_feat))
    x_atom = list(
        map(lambda x, y: torch.cat([x, y], dim=1), x_atom,
            x_pattern_feat))

    if args.typed:
        rxn_class = list(
            map(lambda x: torch.from_numpy(x).float(), rxn_class))
        x_atom = list(
            map(lambda x, y: torch.cat([x, y], dim=1), x_atom,
                rxn_class))

    x_atom = torch.cat(x_atom, dim=0)
    disconnection_num = torch.LongTensor(disconnection_num)
    if not args.use_cpu:
        x_atom = x_atom.cuda()
        disconnection_num = disconnection_num.cuda()

    x_adj = list(map(lambda x: torch.from_numpy(np.array(x)), x_adj))
    y_adj = list(map(lambda x: torch.from_numpy(np.array(x)), y_adj))
    if not args.use_cpu:
        x_adj = [xa.cuda() for xa in x_adj]
        y_adj = [ye.cuda() for ye in y_adj]

    mask = list(map(lambda x: x.view(-1, 1).bool(), x_adj))    
    print(len(mask),len(x_adj),len(y_adj))
    print(len(mask[0]),len(x_adj[0]),len(y_adj[0]))
    print(mask[0].size(),x_adj[0].size(),y_adj[0].size())
    print(y_adj[0])
    #print(mask[1].size(),x_adj[1].size(),y_adj[1].view(-1, 1).size())
    # bond_connections = list(
    #     map(lambda x, y: torch.masked_select(x.view(-1, 1), y), y_adj,mask)
    # )
    break

32 32 32
729 27 27
torch.Size([729, 1]) torch.Size([27, 27]) torch.Size([27, 27])
tensor([[ True,  True, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False,  True, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False],
        [False,  True,  True,  True, False, False, False,  True, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False],
        [False, False,  True,  True,  True, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False],
        [False, False, False,  True,  True,  True,

In [8]:
GAT_model.zero_grad()
# batch graph
g_dgl = dgl.batch(x_graph)

#h_pred, e_pred = GAT_model(g_dgl, x_atom)

In [30]:
GAT_model = GATNet(
    in_dim=args.in_dim,
    num_layers=args.gat_layers,
    hidden_dim=args.hidden_dim,
    heads=args.heads,
    use_gpu=(args.use_cpu == False),
)
GAT_model(g_dgl, x_atom)

forward


RuntimeError: size mismatch, m1: [800 x 587], m2: [724 x 128] at C:\w\1\s\tmp_conda_3.7_183424\conda\conda-bld\pytorch_1570818936694\work\aten\src\TH/generic/THTensorMath.cpp:197

In [None]:
#https://xinhaoli74.github.io/posts/2019/12/DGL-Basic01-Data/

#https://dsgiitr.com/work/graph_nets/
import matplotlib.pyplot as plt
import networkx as nx

print("g_dgl = ",g_dgl)
print("x_graph len=",len(x_graph))
print("sum nodes=",sum(c.number_of_nodes() for c in x_graph),"sum edges=",sum(c.number_of_edges() for c in x_graph))
print(x_graph[0],g_dgl.number_of_nodes())
g = x_graph[0]
# print(g)
# print(g.nodes())
# print(g.edges())
nx.draw(g.to_networkx(), with_labels=True)
plt.show()


In [98]:
x_atom.size()

torch.Size([800, 587])

In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl


class GATLayer(nn.Module):
    def __init__(self, in_dim, out_dim, e_in_dim, e_out_dim):
        super(GATLayer, self).__init__()
        self.embed_node = nn.Linear(in_dim, out_dim, bias=False)
        self.attn_fc = nn.Linear(2 * out_dim + e_in_dim, 1, bias=False)
        self.to_node_fc = nn.Linear(out_dim + e_in_dim, out_dim, bias=False)
        self.edge_linear = nn.Linear(2 * out_dim + e_in_dim,
                                     e_out_dim,
                                     bias=False)

    def edge_attention(self, edges):
        z2 = torch.cat([edges.src['h'], edges.dst['h'], edges.data['w']],
                       dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a, negative_slope=0.1)}

    def message_func(self, edges):
        return {
            'h': edges.src['h'],
            'e': edges.data['e'],
            'w': edges.data['w']
        }

    def reduce_func(self, nodes):
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        t = torch.cat([nodes.mailbox['h'], nodes.mailbox['w']], dim=-1)
        t = self.to_node_fc(t)
        h = torch.sum(alpha * t, dim=1)
        return {'h': h}

    def edge_calc(self, edges):
        z2 = torch.cat([edges.src['h'], edges.dst['h'], edges.data['w']],
                       dim=1)
        w = self.edge_linear(z2)
        return {'w': w}

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = self.embed_node(h)
            g.apply_edges(self.edge_attention)
            g.update_all(self.message_func, self.reduce_func)
            g.apply_edges(self.edge_calc)
            # h_readout = dgl.mean_nodes(g, 'h')
            # gh = dgl.broadcast_nodes(g, h_readout)
            # return torch.cat((g.ndata['h'], gh), dim=1), g.edata['w']
            return g.ndata['h'], g.edata['w']


class MultiHeadGATLayer(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 e_in_dim,
                 e_out_dim,
                 num_heads,
                 use_gpu=True):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        self.use_gpu = use_gpu
        for i in range(num_heads):
            self.heads.append(GATLayer(in_dim, out_dim, e_in_dim, e_out_dim))

    def forward(self, g, h, merge):
        if self.use_gpu:
            g.edata['w'] = g.edata['w'].cuda()
        outs = list(map(lambda x: x(g, h), self.heads))
        outs = list(map(list, zip(*outs)))
        head_outs = outs[0]
        edge_outs = outs[1]
        if merge == 'flatten':
            head_outs = torch.cat(head_outs, dim=1)
            edge_outs = torch.cat(edge_outs, dim=1)
        elif merge == 'mean':
            head_outs = torch.mean(torch.stack(head_outs), dim=0)
            edge_outs = torch.mean(torch.stack(edge_outs), dim=0)
        g.edata['w'] = edge_outs
        return head_outs, edge_outs


class GATNet(nn.Module):
    def __init__(self, in_dim, hidden_dim, num_layers, heads, use_gpu=True):
        super(GATNet, self).__init__()
        self.num_layers = num_layers
        self.gat = nn.ModuleList()

        self.gat.append(
            MultiHeadGATLayer(in_dim, hidden_dim, 12, 128, heads, use_gpu))
        for l in range(1, num_layers):
            self.gat.append(
                MultiHeadGATLayer(
                    hidden_dim * heads,
                    hidden_dim,
                    128 * heads,
                    128,
                    heads,
                    use_gpu,
                ))

        self.linear_e = nn.Sequential(
            nn.Linear(128 * 2, 32),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(32, 1),
        )

        self.linear_h = nn.Sequential(
            nn.Linear(128, 32),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(32, 3),
        )

    def forward(self, g, h):
        print("forward",g.number_of_nodes(),g.number_of_edges(),h)
        print("gat",self.gat)
        
        for l in range(self.num_layers - 1):
            #print("forward","l=",l,self.gat)
            h, _ = self.gat[l](g, h, merge='flatten')
            h = F.elu(h)
        h, e = self.gat[-1](g, h, merge='mean')

        # Graph level prediction
        g.ndata['h'] = h
        h_readout = dgl.mean_nodes(g, 'h')
        h_pred = self.linear_h(h_readout)

        # Edge prediction
        eh = dgl.broadcast_edges(g, h_readout)
        e_fused = torch.cat((eh, e), dim=1)
        e_pred = self.linear_e(e_fused)

        return h_pred, e_pred

GAT_model = GATNet(
    in_dim=args.in_dim,
    num_layers=args.gat_layers,
    hidden_dim=args.hidden_dim,
    heads=args.heads,
    use_gpu=(args.use_cpu == False),
)
#print(args.in_dim)
x_atom.size()
#GAT_model(g_dgl, x_atom)

#m1: [800 x 587], m2: [724 x 128]
# [batchsize x feature]  [input x output]
# make sure feature == input
#800x587 came from   x_atom

torch.Size([800, 587])

In [74]:
for i, data in enumerate(progress_bar):
    rxn_class, x_pattern_feat, x_atom, x_adj, x_graph, y_adj, disconnection_num = data
    print(y_adj)
    break
#next(iter(progress_bar))
#print(list(item))
#item = next(iter(train_dataloader))

[array([[ True,  True, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False],
       [ True,  True,  True,  True, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False],
       [False,  True,  True, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False],
       [False,  True, False,  True,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False],
       [False, False, False,  True,  True,  True, False, False, False,
        False, False, False, False, False, False, False, False, False,
         True, False, False],
       [False, False, False, False,  True,  True, False, False, False,
        False, False, False, False, False, False, False, False, Fals

In [75]:
len(x_atom)

32

16

32 32 32 32
