In [1]:
import torch 
import torch.nn as nn
from lib.solvers import anderson, broyden
from lib.jacobian import jac_loss_estimate

In [2]:
class ConcatConditioning(nn.Module):
    """Concatenate two channel-first tensors and apply Linear + ReLU per position.

    ``x`` has ``xdim`` channels on axis 1 and ``y`` has ``ydim`` channels on
    axis 1; all other axes (batch first, then any trailing axes) must agree.
    The two are concatenated along the channel axis and projected to ``zdim``
    channels independently at every trailing position.

    Args:
        xdim: number of channels (axis 1) in ``x``.
        ydim: number of channels (axis 1) in ``y``.
        zdim: number of output channels.
    """

    def __init__(self, xdim, ydim, zdim):
        super(ConcatConditioning, self).__init__()

        self.xdim = xdim
        self.ydim = ydim
        self.zdim = zdim

        self.linear = nn.Linear(xdim + ydim, zdim)
        self.act = torch.nn.functional.relu

    def forward(self, x, y):
        """Return ``relu(linear(cat([x, y], channels)))`` with ``zdim`` channels on axis 1.

        Args:
            x: tensor with ``x.size(1) == xdim``.
            y: tensor with ``y.size(1) == ydim`` and otherwise the same shape as ``x``.

        Returns:
            Tensor shaped like ``x`` except axis 1 has size ``zdim``.
        """
        assert x.size(1) == self.xdim
        assert y.size(1) == self.ydim

        # nn.Linear acts on the last axis, so swap the channel axis to the end,
        # project, then swap back. For 3-D input this equals permute(0, 2, 1);
        # it also works for 2-D (batch, channels) and higher-rank inputs.
        x_last = x.transpose(1, -1)
        y_last = y.transpose(1, -1)

        merged = torch.cat([x_last, y_last], dim=-1)
        z = self.act(self.linear(merged))
        return z.transpose(1, -1)

In [7]:
class DEQConditioning(nn.Module):
    """Deep-equilibrium conditioning layer built on ``ConcatConditioning``.

    The forward pass finds a fixed point ``z* = f(z*, x)`` of
    ``f = ConcatConditioning(dim, dim, dim)`` with a root solver, without
    recording a graph. In training mode it then applies ``f`` once more with
    gradients enabled and registers a backward hook. The hook replaces the
    incoming gradient ``grad`` with the fixed point of ``y = y J_f + grad``,
    where ``J_f`` is the Jacobian of ``f`` at ``z*``.

    Args:
        dim: channel size of both the input ``x`` and the state ``z``.
        solver: ``'anderson'`` or ``'broyden'`` (imported from ``lib.solvers``).
        f_thres: ``threshold`` passed to the solver for the forward solve.
        b_thres: ``threshold`` passed to the solver for the backward solve.
        jac_loss: if True, ``forward`` also returns the value of
            ``jac_loss_estimate`` (a zero scalar in eval mode).

    Raises:
        ValueError: if ``solver`` is not one of the supported names.
    """

    def __init__(self, dim, solver='anderson', f_thres=40, b_thres=40, jac_loss=False):
        super(DEQConditioning, self).__init__()
        self.xdim = dim
        self.zdim = dim
        solvers = {'anderson': anderson, 'broyden': broyden}
        if solver not in solvers:
            raise ValueError(f"unknown solver {solver!r}; expected one of {sorted(solvers)}")
        self.solver = solvers[solver]
        self.f_thres = f_thres
        self.b_thres = b_thres
        self.jac_loss = jac_loss
        # Handle of the most recently registered backward hook (None if none pending).
        self.hook = None
        self.f = ConcatConditioning(dim, dim, dim)

    def forward(self, x, z):
        """Return the equilibrium state for input ``x`` starting from guess ``z``.

        Args:
            x: conditioning input with ``x.size(1) == dim``; presumably
                ``(batch, dim)`` given the ``unsqueeze(-1)`` below — TODO confirm.
            z: initial guess for the fixed point, same layout as ``x``.

        Returns:
            The fixed point with the added trailing axis removed, or
            ``(fixed_point, jac_loss)`` when ``self.jac_loss`` is True.
        """
        assert x.size(1) == self.xdim
        assert z.size(1) == self.zdim
        # The solvers are given 3-D tensors, so append a length-1 trailing axis.
        x = x.unsqueeze(-1)
        z = z.unsqueeze(-1)

        # Forward solve for z* = f(z*, x) without building an autograd graph.
        with torch.no_grad():
            z_star = self.solver(lambda zz: self.f(zz, x), z, threshold=self.f_thres)['result']
        new_z_star = z_star
        jac_loss = z_star.new_zeros(())

        if self.training:
            # One differentiable application of f re-attaches z* to the graph.
            z_star = z_star.requires_grad_()
            new_z_star = self.f(z_star, x)

            if self.jac_loss:
                jac_loss = jac_loss_estimate(new_z_star, z_star, vecs=1)

            # Bind the 3-D output under its own name: the hook runs later, and a
            # closure over `new_z_star` would see it after it is re-bound below.
            f_out = new_z_star

            def backward_hook(grad):
                # Remove this hook before calling autograd again so the
                # vector-Jacobian products below cannot re-enter it
                # (avoids infinite recursion).
                handle.remove()
                if self.hook is handle:
                    self.hook = None
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                # Compute the fixed point of yJ + grad, where J=J_f is the Jacobian of f at z_star
                new_grad = self.solver(
                    lambda y: torch.autograd.grad(f_out, z_star, y, retain_graph=True)[0] + grad,
                    torch.zeros_like(grad),
                    threshold=self.b_thres,
                )['result']
                return new_grad

            handle = new_z_star.register_hook(backward_hook)
            self.hook = handle

        out = new_z_star[..., 0]

        if self.jac_loss:
            return out, jac_loss
        return out

In [8]:
# Toy batch: `x` is the conditioning input and `y` is the initial guess passed
# to the DEQ layer as `z`. Seed first so Restart & Run All is reproducible.
torch.manual_seed(0)
bs, xdim, ydim = 16, 32, 32
x = torch.rand(bs, xdim)
y = torch.rand(bs, ydim)

In [9]:
# DEQConditioning uses one width for both the input and the state, so derive it
# from the data cell instead of repeating the literal.
assert xdim == ydim, "DEQConditioning needs matching input/state widths"
cond = DEQConditioning(xdim)

In [10]:
z_eq = cond(x, y)
assert z_eq.shape == (bs, xdim)
z_eq.shape

torch.Size([16, 32])