# Modeling with GNN-ePC-SAFT

Model combining graph neural network with ePC-SAFT

In [1]:
from torch_geometric.nn.models import PNA
import torch
import torch.nn.functional as F

from pcsaft import pcsaft_fugcoef, pcsaft_den, dielc_water


In [2]:
from pcsaft_electrolyte import pcsaft_fugcoef as pcfugcoef, pcsaft_den as pcden
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, random
import jax
from jax import dlpack as jdlpack
from torch.utils import dlpack as tdlpack


In [3]:
torch.cuda.is_available()

True

In [4]:
jax.devices()[0].platform

'gpu'

In [34]:
paratorch = torch.tensor([[3.8176, 3.8373, 242.78, 0],
                          [1., 3.7039, 150.03, 0]], dtype=float, device='cuda:0', requires_grad=True)


In [35]:
statetorch = torch.asarray([0, 1., 320, 101325], dtype=float, device='cuda:0')


In [9]:
def ActivityCoefficient(parameters, state):
    x = jnp.asarray(
        (state[0],
          state[1])
    )
    t = state[2]
    P = state[3]
    m = parameters[:, 0]
    s = parameters[:, 1]
    e = parameters[:, 2]
    k_ij = jnp.asarray(
        ([1, parameters[0, 3]],
         [parameters[1, 3], 1])
    )
    e_assoc = None
    vol_a = None
    dipm = None
    dip_num = None
    z = None
    dielc = None

    jAC = jit(pcfugcoef)

    rho = jnp.asarray([state[4],])
    fungcoef = jAC(x, m, s, e, t, rho, k_ij, e_assoc, vol_a, dipm,
                   dip_num, z, dielc)
    
    rho01 = jnp.asarray([state[5],])
    fungcoef01 = jAC(jnp.asarray((.99, .01)), m, s, e, t, rho01, k_ij, e_assoc, vol_a, dipm,
                     dip_num, z, dielc)

    gamma1 = (fungcoef[0]/fungcoef01[0])
    return gamma1


In [21]:
def rho(parameters, state):
    x = jnp.asarray(
        (state[0],
          state[1])
    )
    t = state[2]
    P = state[3]
    m = parameters[:, 0]
    s = parameters[:, 1]
    e = parameters[:, 2]
    k_ij = jnp.asarray(
        ([1, parameters[0, 3]],
         [parameters[1, 3], 1])
    )
    e_assoc = None
    vol_a = None
    dipm = None
    dip_num = None
    z = None
    dielc = None
    phase = 'liq'

    jrho = jit(pcden,static_argnums=6)
    rho = jrho(x, m, s, e, t, P, phase, k_ij = k_ij, e_assoc= e_assoc, vol_a = vol_a, dipm = dipm,
               dip_num = dip_num, z = z, dielc = dielc)

    rho01 = jrho(jnp.asarray((0.99, .01)), m, s, e, t, P, phase, k_ij = k_ij, e_assoc= e_assoc, vol_a = vol_a, dipm = dipm,
               dip_num = dip_num, z = z, dielc = dielc)  # mol / mÂ³

    return rho, rho01


In [22]:
parameters = tdlpack.to_dlpack(paratorch)
parameters = jdlpack.from_dlpack(parameters)

state = tdlpack.to_dlpack(statetorch)
state = jdlpack.from_dlpack(state)
rhon = rho(parameters, state)
state = np.append(state, rhon)
state


array([0.000000e+00, 1.000000e+00, 3.200000e+02, 1.013250e+05,
       5.978534e+01, 3.784439e+01], dtype=float32)

In [23]:
ActivityCoefficient(parameters,state)

Array(0.938268, dtype=float32)

In [13]:
grad_fn = jit(grad(ActivityCoefficient))

In [25]:
grad_fn(parameters,state)

Array([[-0.01900522, -0.02862282, -0.00014551,  0.06925365],
       [-0.03995297, -0.01894344, -0.00022705,  0.00022159]],      dtype=float32)

In [26]:
class PCSAFT(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, state):
        parameters = tdlpack.to_dlpack(input)
        parameters = jdlpack.from_dlpack(parameters)

        state = tdlpack.to_dlpack(state)
        state = jdlpack.from_dlpack(state)
        rhon = rho(parameters, state)
        state = np.append(state, rhon)
        ctx.parameters = parameters
        ctx.state = state
        gamma1 = ActivityCoefficient(parameters, state)
        gamma1 = jdlpack.to_dlpack(gamma1)

        gamma1 = tdlpack.from_dlpack(gamma1)
        return gamma1

    @staticmethod
    def backward(ctx, dg1):

        grad_gamma1 = grad_fn(ctx.parameters, ctx.state)
        grad_gamma1 = jdlpack.to_dlpack(grad_gamma1)
        grad_gamma1 = dg1 * tdlpack.from_dlpack(grad_gamma1)
        return grad_gamma1, None


In [27]:
pcsaft = PCSAFT.apply

In [43]:
gamma = pcsaft(paratorch,statetorch)

In [44]:
gamma

tensor(0.9383, device='cuda:0', grad_fn=<PCSAFTBackward>)

In [53]:
gamma.backward()

In [58]:
paratorch.grad

tensor([[-0.0760, -0.1145, -0.0006,  0.2770],
        [-0.1598, -0.0758, -0.0009,  0.0009]], device='cuda:0',
       dtype=torch.float64)