In [2]:
from lbfgsb_scipy import LBFGSBScipy
import torch
import torch.nn as nn
import numpy as np
import math

In [31]:
class NotearsRKHS(nn.Module):
   """n: number of samples, d: num variables"""
   def __init__(self, n, d, x):
      super(NotearsRKHS, self).__init__()
      self.d = d
      self.n = n
      self.x = x
      # initialize alpha
      self.fc1_pos = nn.Linear(n, d, bias=False)  # fc1_pos.weight = [d ,n], fc1_pos(x) = x @ fc1_pos.weight^T
      self.fc1_neg = nn.Linear(n, d, bias=False)
      #self.fc1_pos.weight.bounds = self._bounds()
      #self.fc1_neg.weight.bounds = self._bounds()
      nn.init.zeros_(self.fc1_pos.weight)
      nn.init.zeros_(self.fc1_neg.weight)
      self.I = torch.eye(self.d)
   def gaussian_kernel(self, x, y, gamma=1): # [d, 1] * [d, 1] -> [1, 1]
    #distance_squared = torch.norm(x-y, dim=1, keepdim=True)**2
      distance_squared = torch.norm(x-y, dim=-1)**2
      return torch.exp(-distance_squared / (gamma**2))

   def forward(self,x): #[n,d] -> [n, d], forward(x)_{l,j} = estimation of x_j at lth observation 
      #K = torch.zeros((self.n, self.n))
      x1 = x.unsqueeze(-1)
      x1 = x1.repeat(1, 1, 2).transpose(1, 2)
      x2 = x.unsqueeze(0)
      x2 = x.repeat(2, 1, 1)
      K = self.gaussian_kernel(x1, x2, gamma = 1)
      #for i in range(self.n):
      #   for l in range(self.n):
      #      K[i,l] = self.kernel(x[i,:], x[l,:])
      #print(K.shape)
      output = self.fc1_pos(K) - self.fc1_neg(K)
      #output = output.t()
      return output
   
   def L_risk(self, output, x, penalty): # [1, 1]
      squared_loss = 0.5 / self.n * torch.sum((output - x) ** 2)
      x1 = x.unsqueeze(-1)
      x1 = x1.repeat(1, 1, 2).transpose(1, 2)
      x2 = x.unsqueeze(0)
      x2 = x.repeat(2, 1, 1)
      K = self.gaussian_kernel(x1, x2, gamma = 1) #[n,n]
      fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight
      A = torch.matmul(torch.matmul(fc1_weight, K), fc1_weight.t())
      diagonal = torch.sum(torch.diag(A))
      regularized = penalty*diagonal 
      loss = squared_loss + regularized
      return loss
   
   def fc1_to_adj(self) -> np.ndarray: # [d, d]
      """Get W from fc1 weights, take L2 norm of the gradient"""
      # get [n, n, d] gradient matrix of the kernel
      # Initialize a tensor to store gradients
      grad_storage = torch.zeros(self.n, self.n, self.d)
      # Compute gradients
      for i in range(self.n):
         for l in range(self.n):
            x_i = self.x[i,:]
            x_l = self.x[l,:]
            x_l.retain_grad()
            K_il = self.gaussian_kernel(x_i, x_l, gamma=1)
            # Zero out gradients in x, otherwise it accumulate the gradients
            if x_l.grad is not None:
                  x_l.grad.zero_()
            # Select individual scalar element from y
            # Perform backward pass on the scalar element
            K_il.backward()
            # Store the computed gradients
            grad_storage[i, l] = x_l.grad.clone() #[n, n, d]
      fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight # [d, n]
      grad_storage = grad_storage.transpose(0, 2)  # Shape [d, n, n]
      grad_storage = grad_storage.unsqueeze(-1) #[d ,n, 1]
      weight = torch.matmul(fc1_weight, grad_storage).squeeze(-1) # [d, n, d]
      weight = torch.sum(weight ** 2, dim = 1)/self.n # [d, d]
      return weight
    
   def h_func(self, t: float = 1.0) -> torch.Tensor: #[1, 1]
        """
        Parameters
        ----------
        t : float, optional
            Controls the domain of M-matrices, by default 1.0

        Returns
        -------
        torch.Tensor
            A scalar value of the log-det acyclicity function :math:`h(\Theta)`.
        """
        weight = self.fc1_to_adj()
        sign, logabsdet = torch.linalg.slogdet(t*self.I - weight)
        h = -sign * logabsdet + self.d * np.log(t)
        return h
   
   def dual_ascent_step(model, x, lambda1, mu, rho, h, rho_max):
       """Perform one step of dual ascent in augmented Lagrangian."""
       h_new = None
       optimizer = LBFGSBScipy(model.parameters())
       while rho < rho_max:
        def closure():
            optimizer.zero_grad()
            x_hat = model(x)
            loss = L_risk(x_hat, x, penalty = lambda1)
            h_val = model.h_func()
            penalty = 0.5 * rho * h_val * h_val + mu * h_val
            #l2_reg = 0.5 * lambda2 * model.l2_reg()
            #l1_reg = lambda1 * model.fc1_l1_reg()
            primal_obj = loss + penalty 
            primal_obj.backward()
            return primal_obj
        optimizer.step(closure)  # NOTE: updates model in-place
        with torch.no_grad():
            h_new = model.h_func().item()
        if h_new > 0.25 * h:
            rho *= 10
        else:
            break
       mu += rho * h_new
       return rho, mu, h_new

   def RKHS_nonlinear(model: nn.Module,
                        x: torch.Tensor,
                        lambda1: float = 0.,
                        mu: float = 0.,
                        max_iter: int = 100,
                        h_tol: float = 1e-8,
                        rho_max: float = 1e+16,
                        w_threshold: float = 0.3):
        rho, mu, h = 1.0, 0.0, np.inf
        for _ in range(max_iter):
            rho, mu, h = dual_ascent_step(model, x, lambda1, mu,
                                            rho, h, rho_max)
            if h <= h_tol or rho >= rho_max:
                break
        W_est = model.fc1_to_adj()
        W_est[np.abs(W_est) < w_threshold] = 0
        return W_est
# simulation
"""
   def main():
        torch.set_default_dtype(torch.double)
        np.set_printoptions(precision=3)

        import utils as ut
        ut.set_random_seed(123)

        n, d, s0, graph_type, sem_type = 200, 5, 9, 'ER', 'mim'
        B_true = ut.simulate_dag(d, s0, graph_type)
        np.savetxt('W_true.csv', B_true, delimiter=',')

        X = ut.simulate_nonlinear_sem(B_true, n, sem_type)
        np.savetxt('X.csv', X, delimiter=',')

        model = NotearsRKHS(2, 3, x)
        W_est = RKHS_nonlinear(model, x, lambda1=0.01, mu=0.01)
        assert ut.is_dag(W_est)
        np.savetxt('W_est.csv', W_est, delimiter=',')
        acc = ut.count_accuracy(B_true, W_est != 0)
        print(acc)


   if __name__ == '__main__':
        main()
"""


"\n   def main():\n        torch.set_default_dtype(torch.double)\n        np.set_printoptions(precision=3)\n\n        import utils as ut\n        ut.set_random_seed(123)\n\n        n, d, s0, graph_type, sem_type = 200, 5, 9, 'ER', 'mim'\n        B_true = ut.simulate_dag(d, s0, graph_type)\n        np.savetxt('W_true.csv', B_true, delimiter=',')\n\n        X = ut.simulate_nonlinear_sem(B_true, n, sem_type)\n        np.savetxt('X.csv', X, delimiter=',')\n\n        model = NotearsRKHS(2, 3, x)\n        W_est = RKHS_nonlinear(model, x, lambda1=0.01, mu=0.01)\n        assert ut.is_dag(W_est)\n        np.savetxt('W_est.csv', W_est, delimiter=',')\n        acc = ut.count_accuracy(B_true, W_est != 0)\n        print(acc)\n\n\n   if __name__ == '__main__':\n        main()\n"

### Test

In [11]:
def gaussian_kernel(x, y, gamma): # [d, 1] * [d, 1] -> [1, 1]
    #distance_squared = torch.norm(x-y, dim=1, keepdim=True)**2
    distance_squared = torch.norm(x-y, dim=-1)**2
    return torch.exp(-distance_squared / (gamma**2))

In [19]:
x = torch.tensor([[1, 2, 3], [4, 5, 7]], dtype=torch.float32, requires_grad=True)
RKHS = NotearsRKHS(2, 3, x)
#output = RKHS.forward(x)
#RKHS.L_risk(output, x, 1)
#RKHS.fc1_to_adj()
#RKHS.h_func()
RKHS.dual_ascent_step()

TypeError: NotearsRKHS.dual_ascent_step() missing 6 required positional arguments: 'lambda1', 'lambda2', 'rho', 'mu', 'h', and 'rho_max'

In [89]:
x1 = torch.tensor([1.0, 2.0, 3.0], requires_grad=False)
x2 = torch.tensor([1.5, 2.5, 3.5], requires_grad=True)


# Apply the gaussian_kernel function
output = gaussian_kernel(x1, x2, gamma=1)

# Perform backward pass
output.backward()

# Access the gradient with respect to x2
print(x2.grad)

tensor([-0.4724, -0.4724, -0.4724])


In [112]:
weight = torch.zeros(3,3)
I = torch.eye(3)
t = 1
sign, logabsdet = torch.linalg.slogdet(t*I - weight)
h = -sign * logabsdet + 3 * np.log(1)

In [113]:
h

tensor(0.)