In [1]:
import numpy as np
import torch
import os

In [2]:
from models.modules_causal_vel import *
from data.AL_sampler import RandomPytorchSampler
from data.datasets import *
from data.dataset_utils import *
import argparse
from torch.utils.data import DataLoader
from utils.functions import *

parser = argparse.ArgumentParser()
parser.add_argument("-f",help="Input image, directory, or npy.")
args = parser.parse_args()
args.dims = 9
args.edge_types = 2
args.decoder_hidden = 512
args.skip_first = True
args.decoder_dropout = 0.0
args.self_loop = True
args.suffix = 'valid_causal_vel_interpolation'
args.input_atoms = 6
args.variations = 4
args.train_size = None
args.temp = 0.01
args.hard = True
args.prediction_steps = 19
args.control_constraint = 1.0
args.grouped=True
args.train_bs=4
args.target_atoms =2
args.gt_A=True

args.num_atoms = args.input_atoms+args.target_atoms        
decoder = MLPDecoder_Causal(args).cuda()

if args.grouped:
    assert args.train_bs % args.variations == 0, "Grouping training set requires args.traing-bs integer times of args.variations"

    train_data = load_one_graph_data(
        args.suffix, size=args.train_size, self_loop=args.self_loop, control=True, control_nodes=args.input_atoms, variations=args.variations)
    train_sampler = RandomPytorchSampler(train_data)
    train_data_loader = DataLoader(
        train_data, batch_size=args.train_bs, shuffle=False, sampler=train_sampler)

else:
    train_data = load_one_graph_data(
        args.suffix, size=args.train_size, self_loop=args.self_loop, control=False)
    train_data_loader = DataLoader(
        train_data, batch_size=args.train_bs, shuffle=True)

off_diag = np.ones([args.num_atoms, args.num_atoms])
rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
rel_rec = torch.FloatTensor(rel_rec).cuda()
rel_send = torch.FloatTensor(rel_send).cuda()


Using ground truth A and the softmax result is tensor([[[[10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
          [10.,  0.],
       

In [6]:
def denormalize(target, output, train_data):
    for i in range(target.size(1)):
        output[:,i,:,0] = ((output[:,i,:,0]+1)*(train_data.maxs[i]-train_data.mins[i]))/2+train_data.mins[i]
        target[:,i,:,0] = ((target[:,i,:,0]+1)*(train_data.maxs[i]-train_data.mins[i]))/2+train_data.mins[i]
    return target, output

In [9]:
def load_predict(weight_path, stop_ind=10, gt_A=False):
    decoder.load_state_dict(torch.load(weight_path)[0])
    decoder.eval()  
    decoder.rel_graph = torch.load(weight_path)[1].cuda()
    loss=[]
    truth = []
    pred = []
    condition =[]
    msgs=[]
    for batch_idx, all_data in enumerate(train_data_loader):
        if batch_idx<stop_ind:
            if args.grouped:
                # edge is only for calculating edge accuracy. Since we have not included that, edge is not used.
                data, which_node, edge = all_data[0].cuda(
                ), all_data[1].cuda(), all_data[2].cuda()
                output, logits, msg_hook = decoder(data, rel_rec, rel_send,
                                                   args.temp, args.hard, args.prediction_steps, [])
                control_constraint_loss = control_loss(
                    msg_hook, which_node, args.input_atoms, args.variations)*args.control_constraint
                # if batch_idx == 20:
                #     print('start3', time.time()-start3)

            else:
                data, edge = all_data[0].cuda(), all_data[1].cuda()
                output, logits, msg_hook = decoder(data, rel_rec, rel_send,
                                                   args.temp, args.hard, args.prediction_steps, [])
                control_constraint_loss = torch.zeros(1).cuda()
            msgs.append(msg_hook)
            print('batch_size', data.size(0))
            target = data[:, :, 1:, :] 
            loss_nll = nll_gaussian(output[:, -2:, :, :], target[:, -2:, :, :], 5e-5)
            print('Nll', loss_nll)
            loss.append(loss_nll.item())
            
            target, output = denormalize(target, output, train_data)
            print('Setup [shapes,colors,mus,thetas,masses,x0s]', batch_idx, target[:,:-2,0,0])
            print('Velocity',batch_idx, target[0,-2,:,0], '\n',output[0,-2,:,0])
            print('Position',batch_idx, target[0,-1,:,0], '\n',output[0,-1,:,0])
            condition.append(target[:,:-2,0,0])
            truth.append(target)
            pred.append(output)
        else:
            print('Avg nll loss', np.mean(loss))
            return loss, truth, pred, condition, msgs

In [10]:
loss, truth, pred, condition,msgs = load_predict('logs/exp2020-08-17T00:36:32.387949_train-bs_128_suffix_causal_vel_x0s_val-suffix_causal_vel_interpolation_input-atoms_6_dims_9_decoder-hidden_512_gt-A/best_decoder.pt')

batch_size 4
Nll tensor(169.1943, device='cuda:0', grad_fn=<DivBackward0>)
Setup [shapes,colors,mus,thetas,masses,x0s] 0 tensor([[ 7.0000,  8.0000,  0.1800,  0.8491,  7.0000,  8.0000],
        [ 7.0000,  8.0000,  0.1800,  0.8491,  7.0000, 12.0000],
        [ 7.0000,  8.0000,  0.1800,  0.8491,  7.0000,  5.9000],
        [ 7.0000,  8.0000,  0.1800,  0.8491,  7.0000, 14.0000]],
       device='cuda:0')
Velocity 0 tensor([  6.1912,  12.3823,  18.5735,  24.7646,  30.9558,  37.1470,  43.3381,
         49.5293,  55.7204,  61.9116,  68.1027,  74.2939,  80.4851,  86.6762,
         92.8674,  99.0585, 105.2497, 111.4409], device='cuda:0') 
 tensor([  5.9530,  11.9150,  17.8773,  23.8314,  29.7680,  35.6987,  41.6248,
         47.5510,  53.4932,  59.4491,  65.4016,  71.3592,  77.3248,  83.2847,
         89.2358,  95.1681, 101.0694, 106.9403], device='cuda:0',
       grad_fn=<SelectBackward>)
Position 0 tensor([  11.0956,   20.3823,   35.8602,   57.5293,   85.3895,  119.4408,
         159.6834,  206

batch_size 4
Nll tensor(126.5953, device='cuda:0', grad_fn=<DivBackward0>)
Setup [shapes,colors,mus,thetas,masses,x0s] 6 tensor([[8.0000, 8.0000, 0.4800, 0.5712, 7.0000, 5.9000],
        [4.0000, 8.0000, 0.4800, 0.5712, 7.0000, 5.9000],
        [5.0000, 8.0000, 0.4800, 0.5712, 7.0000, 5.9000],
        [7.0000, 8.0000, 0.4800, 0.5712, 7.0000, 5.9000]], device='cuda:0')
Velocity 6 tensor([ 1.3410,  2.6820,  4.0231,  5.3641,  6.7051,  8.0461,  9.3872, 10.7282,
        12.0692, 13.4102, 14.7513, 16.0923, 17.4333, 18.7743, 20.1154, 21.4564,
        22.7974, 24.1384], device='cuda:0') 
 tensor([ 1.5461,  3.0927,  4.6376,  6.1794,  7.7170,  9.2504, 10.7788, 12.3027,
        13.8232, 15.3406, 16.8559, 18.3693, 19.8793, 21.3842, 22.8828, 24.3741,
        25.8620, 27.3488], device='cuda:0', grad_fn=<SelectBackward>)
Position 6 tensor([  6.5705,   8.5821,  11.9346,  16.6282,  22.6628,  30.0384,  38.7551,
         48.8128,  60.2115,  72.9512,  87.0319, 102.4537, 119.2165, 137.3203,
        156.765

In [12]:
msgs[0].size()

torch.Size([19, 4, 64, 512])

In [38]:
rows=msgs[0][0,0,:,:].nonzero()[:,0].unique()
rows

tensor([50, 51, 58, 59, 62], device='cuda:0')

In [43]:
for i in rows:
    print(msgs[0][0,0,i,:].nonzero().size(0))


60
67
61
67
58


In [18]:
for j in range(4):
    print(msgs[1][0,j,51,:])

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.1817, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.1071, 0.0395, 0.0897, 0.0000, 0.0000,
        0.3180, 0.1110, 0.0000, 0.0000, 0.0000, 0.1820, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0483, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0809, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1457, 0.1454,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0197, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.1918, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0282, 0.1662, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 

In [11]:
loss, truth, pred, condition,msgs_1 = load_predict('logs/exp2020-08-18T02:44:01.911828_train-bs_144_suffix_causal_vel_grouped_46656_x0s_manual_val-suffix_causal_vel_interpolation_gt-A_decoder-hidden_512_control-constraint_1.0_grouped/best_decoder.pt')

batch_size 4
Nll tensor(185.6675, device='cuda:0', grad_fn=<DivBackward0>)
Setup [shapes,colors,mus,thetas,masses,x0s] 0 tensor([[ 7.0000,  8.0000,  0.1800,  0.8491,  7.0000,  8.0000],
        [ 7.0000,  8.0000,  0.1800,  0.8491,  7.0000, 12.0000],
        [ 7.0000,  8.0000,  0.1800,  0.8491,  7.0000,  5.9000],
        [ 7.0000,  8.0000,  0.1800,  0.8491,  7.0000, 14.0000]],
       device='cuda:0')
Velocity 0 tensor([  6.1912,  12.3823,  18.5735,  24.7646,  30.9558,  37.1470,  43.3381,
         49.5293,  55.7204,  61.9116,  68.1027,  74.2939,  80.4851,  86.6762,
         92.8674,  99.0585, 105.2497, 111.4409], device='cuda:0') 
 tensor([  6.1945,  12.3634,  18.4846,  24.5624,  30.5935,  36.5756,  42.5305,
         48.4565,  54.3446,  60.1880,  65.9961,  71.7787,  77.5257,  83.2437,
         88.9442,  94.6556, 100.3763, 106.0948], device='cuda:0',
       grad_fn=<SelectBackward>)
Position 0 tensor([  11.0956,   20.3823,   35.8602,   57.5293,   85.3895,  119.4408,
         159.6834,  206

batch_size 4
Nll tensor(77.6774, device='cuda:0', grad_fn=<DivBackward0>)
Setup [shapes,colors,mus,thetas,masses,x0s] 7 tensor([[5.0000, 6.0000, 0.5000, 0.6981, 2.7000, 8.0000],
        [5.0000, 6.0000, 0.5000, 0.6981, 3.6000, 8.0000],
        [5.0000, 6.0000, 0.5000, 0.6981, 2.3000, 8.0000],
        [5.0000, 6.0000, 0.5000, 0.6981, 7.0000, 8.0000]], device='cuda:0')
Velocity 7 tensor([ 2.5457,  5.0914,  7.6371, 10.1828, 12.7285, 15.2742, 17.8199, 20.3656,
        22.9113, 25.4570, 28.0027, 30.5484, 33.0941, 35.6398, 38.1855, 40.7312,
        43.2769, 45.8226], device='cuda:0') 
 tensor([ 2.8095,  5.6069,  8.3909, 11.1605, 13.9179, 16.6643, 19.3989, 22.1190,
        24.8204, 27.5089, 30.1964, 32.8868, 35.5738, 38.2537, 40.9256, 43.5919,
        46.2627, 48.9379], device='cuda:0', grad_fn=<SelectBackward>)
Position 7 tensor([  9.2729,  13.0914,  19.4556,  28.3656,  39.8213,  53.8226,  70.3697,
         89.4624, 111.1009, 135.2850, 162.0149, 191.2905, 223.1117, 257.4787,
        294.3913

In [42]:
for i in rows:
    print(msgs_1[0][0,0,i,:].nonzero().size(0))

14
17
14
17
20


In [19]:
for j in range(4):
    print(msgs_1[1][0,j,58,:])

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3690, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0722, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0914, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3627,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 