In [None]:
# CMT-PINN baseline: 1-compartment IV bolus simulation
# Author: Engin Yapici (open-pk-pinn project)
# Run in Google Colab

!pip install torchdiffeq --quiet

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchdiffeq import odeint
import numpy as np

# Simulate 1-compartment IV bolus PK data.
# Ground truth is the closed-form solution of dC/dt = -k*C with
# C(0) = Dose/Vd, i.e. C(t) = (Dose/Vd) * exp(-k*t).
Vd = 1.0  # L/kg
CL = 0.1  # L/h/kg
Dose = 1.0  # mg/kg
k = CL / Vd  # rate constant of the ODE, in 1/h (L/h/kg divided by L/kg)

# Time points (in hours): 50 evenly spaced samples from 0 to 24
t = torch.linspace(0, 24, steps=50)
C_true = (Dose / Vd) * torch.exp(-k * t)

# Add small noise to simulate measurement (standard normal draws scaled by 0.02).
# No seed is set, so the noise differs from run to run.
noise = 0.02 * torch.randn_like(C_true)
C_obs = C_true + noise

# Plot ground truth. The tensors are converted with .numpy() explicitly, the
# same way the later prediction plot does, instead of relying on matplotlib's
# implicit conversion of a torch tensor.
plt.plot(t.numpy(), C_obs.numpy(), 'o', label='Observed')
plt.plot(t.numpy(), C_true.numpy(), '-', label='True')
plt.xlabel('Time (h)')
plt.ylabel('Concentration (mg/L)')
plt.legend()
plt.title('Simulated 1-Compartment IV Bolus PK Profile')
plt.show()

# PINN model: input t -> predicts C(t)
class PINN(nn.Module):
    """Fully connected network that maps a time input to a single output.

    Architecture: Linear(1, 64) -> Tanh -> Linear(64, 64) -> Tanh -> Linear(64, 1).
    """

    def __init__(self):
        super().__init__()
        width = 64
        # Layers are listed in forward order; the Sequential is stored as
        # ``net`` so parameter names stay ``net.<index>.weight`` / ``.bias``.
        layers = [
            nn.Linear(1, width),
            nn.Tanh(),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, t):
        """Return the network output for ``t`` (last dimension of ``t`` must be 1)."""
        prediction = self.net(t)
        return prediction

# Derivative: dC/dt = -k*C, enforce this during training
def pinn_loss(model, t, C_obs, k, physics_weight=1.0):
    """Compute the combined data + physics loss for the ODE dC/dt = -k*C.

    Args:
        model: Callable mapping an (N, 1) time tensor to an (N, 1) prediction.
        t: Time points; reshaped to a column of shape (N, 1).
        C_obs: Observed values, reshaped to (N, 1) and compared element-wise
            with the model prediction.
        k: Rate constant used in the ODE residual ``dC/dt + k*C``.
        physics_weight: Multiplier applied to the physics term. The default
            of 1.0 gives the plain sum ``data_loss + physics_loss``.

    Returns:
        Tuple ``(total_loss, data_loss_value, physics_loss_value)``. The total
        is a tensor that can be back-propagated; the other two are Python
        floats intended for logging.
    """
    # reshape (unlike view) also accepts inputs whose memory layout cannot be
    # viewed as a column. The result is a new tensor object, so switching on
    # requires_grad here does not touch the caller's tensor.
    t = t.reshape(-1, 1).requires_grad_()
    C_pred = model(t)

    # dC/dt via autograd. create_graph=True keeps the derivative itself
    # differentiable so the physics term can be back-propagated into the model
    # parameters; retain_graph defaults to the value of create_graph, so it
    # does not need to be passed separately.
    (dCdt,) = torch.autograd.grad(
        C_pred, t,
        grad_outputs=torch.ones_like(C_pred),
        create_graph=True,
    )

    # Residual of dC/dt = -k*C; it is zero wherever the prediction obeys the ODE.
    ode_residual = dCdt + k * C_pred
    data_loss = torch.mean((C_pred - C_obs.reshape(-1, 1)) ** 2)
    physics_loss = torch.mean(ode_residual ** 2)

    total_loss = data_loss + physics_weight * physics_loss
    return total_loss, data_loss.item(), physics_loss.item()

# Train: minimise the combined data + physics loss with Adam
model = PINN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# The full set of simulated observations is used for training
t_train, C_train = t, C_obs

for epoch in range(2000):
    optimizer.zero_grad()
    loss, data_l, phys_l = pinn_loss(model, t_train, C_train, k)
    loss.backward()
    optimizer.step()

    # Log only every 200th epoch; the values shown were computed before this
    # epoch's optimizer step.
    if epoch % 200 != 0:
        continue
    print(f"Epoch {epoch}, Total Loss: {loss.item():.5f}, Data: {data_l:.5f}, Physics: {phys_l:.5f}")

# Predict on a denser grid: 100 time points over 0-24 h as a (100, 1) column
t_test = torch.linspace(0, 24, 100).unsqueeze(1)
with torch.no_grad():  # inference only, so no autograd graph is recorded
    C_pred = model(t_test).squeeze()

# Plot prediction together with the noisy observations and the true curve
plt.plot(t.numpy(), C_obs.numpy(), 'o', label='Observed')
plt.plot(t_test.numpy(), C_pred.numpy(), label='PINN Prediction')
plt.plot(t.numpy(), C_true.numpy(), '--', label='True')
plt.title('PINN vs Observed vs True')
plt.xlabel('Time (h)')
plt.ylabel('Concentration (mg/L)')
plt.legend()
plt.show()
