In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable

# Define the PINN model
class QuantumPINN(nn.Module):
    def __init__(self, dim):
        super(QuantumPINN, self).__init__()
        self.hidden_layer1 = nn.Linear(1, 128)
        self.hidden_layer2 = nn.Linear(128, 128)
        self.output_layer = nn.Linear(128, dim * dim * 2)  # Output is the real and imaginary parts of rho

    def forward(self, t):
        t = t.reshape(-1, 1)  # Input is time
        x = torch.relu(self.hidden_layer1(t))
        x = torch.relu(self.hidden_layer2(x))
        output = self.output_layer(x)
        real_part = output[:, :dim * dim].reshape(-1, dim, dim)
        imag_part = output[:, dim * dim:].reshape(-1, dim, dim)
        rho = torch.complex(real_part, imag_part)
        return rho


In [31]:
def generate_loss(pinn_model, Hamiltonian):
    def loss_fn(t):
        rho_pred = pinn_model(t)  # Predict rho(t)
        
        # Separate real and imaginary parts of the predicted rho
        rho_real = rho_pred.real
        rho_imag = rho_pred.imag

        # Compute time derivatives using autograd (for both real and imaginary parts)
        rho_t_real = torch.autograd.grad(rho_real.sum(), t, create_graph=True)[0]
        rho_t_imag = torch.autograd.grad(rho_imag.sum(), t, create_graph=True)[0]
        
        # Compute the commutator: [H, rho]
        commutator_real = - (Hamiltonian @ rho_real - rho_real @ Hamiltonian) + 1j * (Hamiltonian @ rho_imag - rho_imag @ Hamiltonian)
        commutator_imag = - (Hamiltonian @ rho_imag - rho_imag @ Hamiltonian) - 1j * (Hamiltonian @ rho_real - rho_real @ Hamiltonian)

        # Loss is the MSE between the time derivative and the commutator (real and imaginary parts separately)
        loss_real = torch.mean((rho_t_real - commutator_real.real)**2)
        loss_imag = torch.mean((rho_t_imag - commutator_imag.imag)**2)
        
        return loss_real + loss_imag
    
    return loss_fn

In [9]:
import math
import functools
import numpy as np
import scipy as sci
import matplotlib.pyplot as plt
import qutip as qt
import time

In [10]:
# This dictionary maps string keys ('x', 'y', 'z', 'p', 'm', 'i') to functions that generate spin operators for a given dimension dim.
opstr2fun = {'x': lambda dim: qt.spin_Jx((dim-1)/2),
             'y': lambda dim: qt.spin_Jy((dim-1)/2),
             'z': lambda dim: qt.spin_Jz((dim-1)/2),
             'p': lambda dim: qt.spin_Jp((dim-1)/2),
             'm': lambda dim: qt.spin_Jm((dim-1)/2),
             'i': qt.identity}
# Initializes ops as a list of identity matrices for each dimension in dims. Iterates over specs to replace the identity matrix at the specified index with the corresponding spin operator. Returns the tensor product of the operators in ops using qt.tensor.
def mkSpinOp(dims, specs):
    ops = [qt.identity(d) for d in dims]
    for ind, opstr in specs:
        ops[ind] = ops[ind] * opstr2fun[opstr](dims[ind])
    return qt.tensor(ops)
# Constructs a Hamiltonian for a single spin system with interactions along the x, y, and z axes.
def mkH1(dims, ind, parvec):
    axes = ['x', 'y', 'z']
    # Creates a list of spin operators weighted by the corresponding parameters in parvec (ignores zero parameters). Uses functools.reduce to sum these weighted spin operators.
    return functools.reduce(lambda a, b: a + b, 
               [v * mkSpinOp(dims, [(ind,ax)]) for v, ax in zip(parvec, axes) if v!=0])
# Constructs a Hamiltonian for the interaction between two spin systems with interaction terms along all combinations of x, y, and z axes.
def mkH12(dims, ind1, ind2, parmat):
    axes = ['x', 'y', 'z']
    ops = []
    # Iterates over all combinations of the x, y, and z axes for the two spins. For each non-zero element in parmat, adds the corresponding spin-spin interaction term to the empty list ops.
    for i in range(3):
        for j in range(3):
            if parmat[i,j] != 0:
                ops.append(parmat[i,j] * mkSpinOp(dims, [(ind1,axes[i]), (ind2,axes[j])]))
    return functools.reduce(lambda a, b: a + b, ops) # Uses functools.reduce to sum these interaction terms.

In [37]:
import numpy as np 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

b0 = 1.4 * 2*math.pi # Zeeman field strength in radians per microsecond
A = np.diag([-2.6, -2.6, 49.2]) * 2*math.pi # Hyperfine coupling matrix (Mrad/s)
kr = 1. # Rate constant 1/us
tmax = 12. / kr # Maximum time us
tlist = np.linspace(0, tmax, math.ceil(1000*tmax)) # Time points for simulation
B0 = b0 * np.array([1,0,0]) # Magnetic field vector along x-axis

dims = [2, 2, 3] # Dimensions of the system components (2 qubits, 1 spin-1 nucleus)
dim = np.prod(dims) # Total dimension of the composite system
Hzee = mkH1(dims, 0, B0) + mkH1(dims, 1, B0) # Zeeman Hamiltonian for two spins
Hhfc = mkH12(dims, 0, 2, A) # Hyperfine coupling Hamiltonian
H0 = Hzee + Hhfc # Total Hamiltonian

Ps = 1/4 * mkSpinOp(dims,[]) - mkH12(dims, 0, 1, np.identity(3)) # Singlet projection operator

rho0 = (Ps / Ps.tr()).full().flatten() # Initial density matrix, normalized projection operator for the singlet state.
# Convert the Hamiltonian H0 to a scipy sparse matrix and then to a PyTorch sparse tensor
H_scipy = H0.data.tocsc()  # Convert to a SciPy sparse CSC matrix

# Extract the indices and values for the sparse tensor
H_indices = torch.LongTensor(np.vstack([H_scipy.nonzero()[0], H_scipy.nonzero()[1]]))  # Non-zero element indices
H_values = torch.FloatTensor(H_scipy.data)  # Non-zero values
H_shape = H_scipy.shape  # Shape of the matrix

# Convert to a PyTorch sparse tensor
H_torch_sparse = torch.sparse_coo_tensor(H_indices, H_values, H_shape)..to(device)
Ps = Ps.data.toarray()
print(H_scipy)


  (0, 0)	(154.56635855661784+0j)
  (3, 0)	(4.39822971502571+0j)
  (6, 0)	(4.39822971502571+0j)
  (4, 1)	(4.39822971502571+0j)
  (6, 1)	(-11.551495639211753+0j)
  (7, 1)	(4.39822971502571+0j)
  (2, 2)	(-154.56635855661784+0j)
  (5, 2)	(4.39822971502571+0j)
  (7, 2)	(-11.551495639211753+0j)
  (8, 2)	(4.39822971502571+0j)
  (0, 3)	(4.39822971502571+0j)
  (3, 3)	(154.56635855661784+0j)
  (9, 3)	(4.39822971502571+0j)
  (1, 4)	(4.39822971502571+0j)
  (9, 4)	(-11.551495639211753+0j)
  (10, 4)	(4.39822971502571+0j)
  (2, 5)	(4.39822971502571+0j)
  (5, 5)	(-154.56635855661784+0j)
  (10, 5)	(-11.551495639211753+0j)
  (11, 5)	(4.39822971502571+0j)
  (0, 6)	(4.39822971502571+0j)
  (1, 6)	(-11.551495639211753+0j)
  (6, 6)	(-154.56635855661784+0j)
  (9, 6)	(4.39822971502571+0j)
  (1, 7)	(4.39822971502571+0j)
  (2, 7)	(-11.551495639211753+0j)
  (10, 7)	(4.39822971502571+0j)
  (2, 8)	(4.39822971502571+0j)
  (8, 8)	(154.56635855661784+0j)
  (11, 8)	(4.39822971502571+0j)
  (3, 9)	(4.39822971502571+0j)
 

In [32]:
# Initialize model and optimizer
pinn_model = QuantumPINN(dim).to(device)
optimizer = torch.optim.Adam(pinn_model.parameters(), lr=0.001)
loss_fn = generate_loss(pinn_model, H)

# Time points for training
t_collocation = np.linspace(0, tmax, 1000).reshape(-1, 1)
pt_t_collocation = Variable(torch.from_numpy(t_collocation).float(), requires_grad=True).to(device)

# Training loop
iterations = 5000
for epoch in range(iterations):
    optimizer.zero_grad()
    loss = loss_fn(pt_t_collocation)
    loss.backward()
    optimizer.step()
    
    if epoch % 500 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")


RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

In [None]:
import matplotlib.pyplot as plt

# Predict the evolution
with torch.no_grad():
    pt_t_eval = Variable(torch.from_numpy(tlist).float(), requires_grad=True).to(device)
    rho_pred = pinn_model(pt_t_eval)

# Extract the predicted singlet probability
ps_pred = []
for rho_t in rho_pred:
    ps_pred.append(np.real(np.trace(Ps @ rho_t.cpu().numpy())))

# Plot results
plt.plot(tlist, ps_pred, label="PINN Prediction")
plt.plot(tlist[:1000], ps[:1000], label="Original ODE Solution")
plt.xlabel('time')
plt.ylabel('Singlet Probability')
plt.title('Evolution of singlet probability over time for 1 spin-1 nucleus')
plt.legend()
plt.show()
