In [7]:
import torch

from HNN import HNN

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Draw 3 two-component starting states around (0.0, 2.0) with std 0.1.
# (The components are treated as a (q, p) pair by the code that consumes them.)
mean = torch.tensor([0.0, 2.0])
std = 0.1
init_states = torch.normal(mean=mean.expand(3, 2), std=std)
# Move to the target device first, then enable gradient tracking on the moved
# tensor so gradients can be taken with respect to it.
init_states = init_states.to(device).requires_grad_()

hnn_model = HNN().to(device)


def get_time_derivatives(model, x):
    """
    Differentiates the model output with respect to each row of ``x``.

    Rows of ``x`` are passed to ``model`` one at a time, and the gradient of
    the resulting output with respect to that row is taken with
    ``create_graph=True`` so the gradients stay differentiable.

    Returns a list with one entry per row of ``x``; every entry is the tuple
    returned by ``torch.autograd.grad`` (a 1-tuple holding the gradient).
    """

    def _row_gradient(row):
        # One forward pass and one gradient call per sample.
        output = model(row)
        return torch.autograd.grad(output, row, create_graph=True)

    return [_row_gradient(row) for row in x]


# NOTE: despite the helper's name, each entry is a 1-tuple holding the gradient
# of the model output with respect to one row of init_states.
h_hats = get_time_derivatives(hnn_model, init_states)
print(h_hats)


def get_deriv_pairs(h_hats):
    """
    Gets the pairs of [dq_dt, dp_dt]

    Applies Hamilton's equations (dq/dt = dH/dp, dp/dt = -dH/dq) to every
    gradient in ``h_hats``.

    Args:
        h_hats: iterable of sequences whose first element is a 1-D tensor
            ``[dH/dq, dH/dp]`` (the per-sample tuples produced by
            ``torch.autograd.grad``).

    Returns:
        A list with one tensor ``[dH/dp, -dH/dq]`` per entry of ``h_hats``,
        each on the same device as the gradient it was built from.
    """
    deriv_pairs = []
    for h_hat in h_hats:
        grad = h_hat[0]
        # Create the sign mask on the gradient's own device instead of a
        # hard-coded "cuda", so this also works on CPU-only machines.
        signs = torch.tensor([-1.0, 1.0], device=grad.device)
        # [-dH/dq, dH/dp] reversed gives [dH/dp, -dH/dq] = [dq/dt, dp/dt].
        deriv_pairs.append(torch.flip(grad * signs, dims=[0]))
    return deriv_pairs


# Turn each gradient tuple into a [dq/dt, dp/dt] tensor and show the result.
deriv_pairs = get_deriv_pairs(h_hats)
print(deriv_pairs)


[(tensor([-0.0101, -0.0109], device='cuda:0', grad_fn=<ViewBackward0>),), (tensor([-0.0052, -0.0035], device='cuda:0', grad_fn=<ViewBackward0>),), (tensor([ 0.0028, -0.0085], device='cuda:0', grad_fn=<ViewBackward0>),)]
[tensor([-0.0109,  0.0101], device='cuda:0', grad_fn=<FlipBackward0>), tensor([-0.0035,  0.0052], device='cuda:0', grad_fn=<FlipBackward0>), tensor([-0.0085, -0.0028], device='cuda:0', grad_fn=<FlipBackward0>)]


In [28]:
def get_model_time_derivatives(model, x, verbose=True):
    """
    Computes time derivatives [dq/dt, dp/dt] for a batch of inputs.

    Args:
        model: callable mapping ``x`` to one predicted energy per row.
        x: Tensor of shape (Batch_Size, 2) holding [q, p] per row. It must
            require grad, otherwise ``torch.autograd.grad`` raises.
        verbose: when True (the default, matching the previous behaviour)
            the intermediate energies and gradients are printed. Pass False
            to silence the output, e.g. inside a training or integration loop.

    Returns:
        Tensor of shape (Batch_Size, 2) whose columns are
        [dH/dp, -dH/dq], i.e. [dq/dt, dp/dt]. The result stays attached to
        the autograd graph (``create_graph=True``).
    """
    # 1. Forward pass to get the energy of every sample.
    H_hat = model(x)
    if verbose:
        print('\n\n==============H_hat==================')
        print(H_hat)

    # 2. Vectorized gradient calculation.
    # Summing the energies gives the scalar autograd needs. The gradient of
    # that sum separates per row only if the model treats samples
    # independently (no cross-batch layers such as batch normalisation).
    grads = torch.autograd.grad(H_hat.sum(), x, create_graph=True)[0]
    if verbose:
        print('\n\n==============grads==================')
        print(grads)

    # grads has shape (Batch, 2) -> [dH/dq, dH/dp]

    # 3. The symplectic swap (Hamilton's equations):
    #    dq/dt =  dH/dp
    #    dp/dt = -dH/dq
    dH_dq = grads[:, 0].unsqueeze(1)
    dH_dp = grads[:, 1].unsqueeze(1)

    return torch.cat([dH_dp, -dH_dq], dim=1)

In [29]:
# Batched computation on the same inputs; the returned tensor matches the
# per-sample deriv_pairs printed earlier.
get_model_time_derivatives(hnn_model, init_states)



tensor([[0.0821],
        [0.0837],
        [0.0848]], device='cuda:0', grad_fn=<AddmmBackward0>)


tensor([[-0.0101, -0.0109],
        [-0.0052, -0.0035],
        [ 0.0028, -0.0085]], device='cuda:0', grad_fn=<MmBackward0>)


tensor([[-0.0109,  0.0101],
        [-0.0035,  0.0052],
        [-0.0085, -0.0028]], device='cuda:0', grad_fn=<CatBackward0>)

In [30]:
# Scratch copy of the gradient values printed above, as a plain tensor with no
# autograd history, for experimenting with column slicing.
a = torch.tensor([[-0.0101, -0.0109],
        [-0.0052, -0.0035],
        [ 0.0028, -0.0085]])

In [31]:
# Display the scratch tensor to confirm its values.
a

tensor([[-0.0101, -0.0109],
        [-0.0052, -0.0035],
        [ 0.0028, -0.0085]])

In [35]:
# Select column 1 and restore the column axis, giving shape (3, 1) — the same
# slicing get_model_time_derivatives applies to its gradients.
a[:, 1].unsqueeze(1)

tensor([[-0.0109],
        [-0.0035],
        [-0.0085]])