In [1]:
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
from data import *
from model.gat import *
from util.misc import CSVLogger

In [21]:
class Args:
    pass
args = Args()
args.__dict__ = {
    "batch_size":32,
    "dataset":'USPTO50K',
    "epochs":80,
    "exp_name":'USPTO50K_typed',
    "gat_layers":3, 
    "heads":4, 
    "hidden_dim":128, 
    "in_dim":714, 
    "load":False, 
    "logdir":'logs', 
    "lr":0.0005, 
    "seed":123, 
    "test_on_train":False, 
    "test_only":False, 
    "typed":True, 
    "use_cpu":True, 
    "valid_only":False
}

In [23]:
def collate(data):
    return map(list, zip(*data))
    
batch_size = args.batch_size
epochs = args.epochs
data_root = os.path.join('data', args.dataset)
args.exp_name = args.dataset
if args.typed:
    args.in_dim += 10
    args.exp_name += '_typed'
else:
    args.exp_name += '_untyped'
print(args)

test_id = '{}'.format(args.logdir)
filename = 'logs/' + test_id + '.csv'
csv_logger = CSVLogger(
    args=args,
    fieldnames=['epoch', 'train_acc', 'valid_acc', 'train_loss'],
    filename=filename,
)

GAT_model = GATNet(
    in_dim=args.in_dim,
    num_layers=args.gat_layers,
    hidden_dim=args.hidden_dim,
    heads=args.heads,
    use_gpu=(args.use_cpu == False),
)

if args.use_cpu:
    device = 'cpu'
else:
    GAT_model = GAT_model.cuda()
    device = 'cuda:0'

if args.load:
    GAT_model.load_state_dict(
        torch.load('checkpoints/{}_checkpoint.pt'.format(args.exp_name),
                    map_location=torch.device(device)), )
    args.lr *= 0.2
    milestones = []
else:
    milestones = [20, 40, 60, 80]

optimizer = torch.optim.Adam([{
    'params': GAT_model.parameters()
}],
                                lr=args.lr)
scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.2)

<__main__.Args object at 0x00000076F3040608>


In [24]:
valid_data = RetroCenterDatasets(root=data_root, data_split='valid')
valid_dataloader = DataLoader(valid_data,
                                batch_size=4 * batch_size,
                                shuffle=False,
                                num_workers=0,
                                collate_fn=collate)



Counter({1: 3482, 0: 1415, 2: 102, 9: 1, 17: 1})


In [27]:
train_data = RetroCenterDatasets(root=data_root, data_split='train')
train_dataloader = DataLoader(train_data,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=0,
                            collate_fn=collate)

Counter({1: 27851, 0: 11296, 2: 849, 3: 4, 4: 4, 10: 2, 7: 1, 13: 1})


In [32]:
progress_bar = tqdm(train_dataloader)


  0%|          | 0/1251 [00:00<?, ?it/s]

In [59]:
for i, data in enumerate(progress_bar):
    rxn_class, x_pattern_feat, x_atom, x_adj, x_graph, y_adj, disconnection_num = data
    #print(i,rxn_class,x_pattern_feat,x_atom,x_adj,x_graph,y_adj,disconnection_num)
    x_atom = list(map(lambda x: torch.from_numpy(x).float(), x_atom))
    x_pattern_feat = list(
        map(lambda x: torch.from_numpy(x).float(), x_pattern_feat))
    x_atom = list(
        map(lambda x, y: torch.cat([x, y], dim=1), x_atom,
            x_pattern_feat))

    if args.typed:
        rxn_class = list(
            map(lambda x: torch.from_numpy(x).float(), rxn_class))
        x_atom = list(
            map(lambda x, y: torch.cat([x, y], dim=1), x_atom,
                rxn_class))

    x_atom = torch.cat(x_atom, dim=0)
    disconnection_num = torch.LongTensor(disconnection_num)
    if not args.use_cpu:
        x_atom = x_atom.cuda()
        disconnection_num = disconnection_num.cuda()

    x_adj = list(map(lambda x: torch.from_numpy(np.array(x)), x_adj))
    y_adj = list(map(lambda x: torch.from_numpy(np.array(x)), y_adj))
    if not args.use_cpu:
        x_adj = [xa.cuda() for xa in x_adj]
        y_adj = [ye.cuda() for ye in y_adj]

    mask = list(map(lambda x: x.view(-1, 1).bool(), x_adj))    
    print(len(mask),len(x_adj),len(y_adj))
    print(len(mask[0]),len(x_adj[0]),len(y_adj[0]))
    print(mask[0].size(),x_adj[0].size(),y_adj[0].size())
    print(y_adj[0])
    #print(mask[1].size(),x_adj[1].size(),y_adj[1].view(-1, 1).size())
    # bond_connections = list(
    #     map(lambda x, y: torch.masked_select(x.view(-1, 1), y), y_adj,mask)
    # )
    break

32 32 32
361 19 19
torch.Size([361, 1]) torch.Size([19, 19]) torch.Size([19, 19])
tensor([[ True,  True, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False],
        [False,  True,  True,  True, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False],
        [False, False,  True,  True,  True,  True, False, False, False, False,
         False, False, False, False, False, False, False, False, False],
        [False, False, False,  True,  True, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False],
        [False, False, False,  True, False,  True,  True, False, False, False,
         False, False, False,  True, False, False, False, False, False],
      

In [64]:
y = torch.empty(19,19)
print(y.size())
y.view(-1,1)

torch.Size([19, 19])


tensor([[-7.2330e+32],
        [ 1.6535e-43],
        [-7.2331e+32],
        [ 1.6535e-43],
        [-7.2332e+32],
        [ 1.6535e-43],
        [-7.2333e+32],
        [ 1.6535e-43],
        [-7.2333e+32],
        [ 1.6535e-43],
        [-7.2334e+32],
        [ 1.6535e-43],
        [-7.2335e+32],
        [ 1.6535e-43],
        [-7.2336e+32],
        [ 1.6535e-43],
        [-7.2336e+32],
        [ 1.6535e-43],
        [-7.2337e+32],
        [ 1.6535e-43],
        [-7.2338e+32],
        [ 1.6535e-43],
        [-7.4399e+32],
        [ 1.6535e-43],
        [-7.2339e+32],
        [ 1.6535e-43],
        [-7.2339e+32],
        [ 1.6535e-43],
        [-7.2340e+32],
        [ 1.6535e-43],
        [-7.2341e+32],
        [ 1.6535e-43],
        [-7.2342e+32],
        [ 1.6535e-43],
        [-7.2342e+32],
        [ 1.6535e-43],
        [-7.2343e+32],
        [ 1.6535e-43],
        [-7.2344e+32],
        [ 1.6535e-43],
        [-7.2345e+32],
        [ 1.6535e-43],
        [-7.2345e+32],
        [ 1