In [460]:
import tensorflow as tf
import numpy as np
from julia import Main, LightPropagation

"""
has properties to access absorption and scattering

inherits from Layer for easy access to trainable_variables with further composition
"""
class Layer(tf.keras.layers.Layer):
    def __init__(self, infinite=False, **kwargs):
        super().__init__(**kwargs)

        self.n = tf.constant(1.4, dtype=tf.float32)

        if infinite:
            self.infinite = True
        else:
            self.h = tf.Variable(10.0)
    
    def mu_a(self):
        pass
    
    def mu_s(self):
        pass

class ConstLayer(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self._mu_s = tf.Variable(0.1)
        self._mu_a = tf.Variable(0.1)

    def mu_s(self):
        return self._mu_s

    def mu_a(self):
        return self._mu_a

class VariableLayer(Layer):
    def __init__(self, wavelengths, coeffs, **kwargs):
        super().__init__(**kwargs)

        self.wavelengths = tf.constant(wavelengths, dtype=tf.float32)

        # Scattering
        self.a = tf.Variable(1.0)
        self.b = tf.Variable(1.0)

        # Absorption
        self.coeffs = tf.constant(coeffs, dtype=tf.float32)
        self.n_concs = self.coeffs.shape[1] # maybe the wrong one
        self.concs = tf.Variable(tf.ones(self.n_concs), dtype=tf.float32)

    def mu_s(self):
        return self.a*(self.wavelengths / 500)**(-self.b)

    def mu_a(self):
        return tf.transpose(self.coeffs @ self.concs[:, None])[0, ...]

In [461]:
one = VariableLayer(np.array([1, 2]), np.array([[1, 2, 3], [1,2,3]])).mu_s()
two = ConstLayer().mu_s()

one, two

(<tf.Tensor: shape=(2,), dtype=float32, numpy=array([499.99997, 249.99998], dtype=float32)>,
 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.1>)

In [462]:
"""
Material has a forward model which spits out the reflectance
"""
class Material(tf.keras.layers.Layer):
    def __init__(self, layers, n_ext=1.4):
        super().__init__()

        self.layers = layers
        self.n_ext = tf.constant(n_ext)
    def reflectance(self, distance):
        mu_as, mu_ss = self.build_coeffs()

        size = mu_as.shape[0]

        hs = [l.h for l in self.layers]
        hs = self._build_constant_scalar_part(size, hs)

        ns = [l.n for l in self.layers]
        ns = self._build_constant_scalar_part(size, ns)

        

        print(hs, ns)

    def build_coeffs(self):
        mu_as = [l.mu_a() for l in self.layers]
        mu_ss = [l.mu_s() for l in self.layers]

        max_size = tf.reduce_max([tf.size(e) for e in mu_as])

        mu_as = [tf.broadcast_to(e[..., None], (max_size, 1)) for e in mu_as]
        mu_ss = [tf.broadcast_to(e[..., None], (max_size, 1)) for e in mu_ss]

        mu_as = tf.concat(mu_as, -1)
        mu_ss = tf.concat(mu_ss, -1)

        return mu_as, mu_ss
        
    @staticmethod
    def _build_constant_scalar_part(size, vals):
        result = [tf.broadcast_to(e[..., None, None], (size, 1)) for e in vals]
        result = tf.concat(result, -1)

        return result

    def grad_reflectance():
        pass

    def call(self, i):
        return self.reflectance(1)




mat = Material([ConstLayer(), ConstLayer(), VariableLayer(np.array([1, 2]), np.array([[1, 2, 3], [1,2,3]]))])

mat(10)

tf.Tensor(
[[10. 10. 10.]
 [10. 10. 10.]], shape=(2, 3), dtype=float32) tf.Tensor(
[[1.4 1.4 1.4]
 [1.4 1.4 1.4]], shape=(2, 3), dtype=float32)


In [463]:
"""
accepts numpy arguments

n layers
mu_a    [bulk, n]
mu_s    [bulk, n]
n       [bulk, n]
h       [bulk, n]
z       [bulk, 1]
rho     [bulk, 1]
"""

pack_call = lambda mu_a, mu_s, n, h, z, rho: mu_a + mu_s + n + h + [z] + [rho]
unpack_result = lambda e, n: (e[..., :n], e[..., n:2*n], e[..., 2*n:3*n], e[..., 3*n:4*n], e[..., -2, None], e[..., -1, None])

# julia indexing starts at 1
julia_unpack_call = """
function unpack_call(x)
    len = length(x)
    n_layers = (len - 2) ÷ 4

    mu_a = x[1:n_layers]
    mu_s = x[n_layers+1:2*n_layers]
    n = x[2*n_layers+1:3*n_layers]
    h = x[3*n_layers+1:4*n_layers]
    z = x[end-1]
    rho = x[end]

    LightPropagation.flux_DA_Nlay_cylinder_CW(
        rho, mu_a, mu_s; n_ext=n_ext, n_med=n, l=h, a=a, z=z, MaxIter=MaxIter
    )
end
"""

Main.eval(julia_unpack_call)
Main.eval("using LightPropagation; using ForwardDiff")

def reflectance(mu_as, mu_ss, ns, hs, zs, rhos, n_ext=1.4, a=100.0, MaxIter=10000):
    result = np.zeros(rhos.shape)
    for i in np.ndindex(mu_as.shape[:-1]):
        z, rho = [e[i][0].item() for e in [zs, rhos]]
        mu_a, mu_s, n, h = [e[i].tolist() for e in [mu_as, mu_ss, ns, hs]]
        result[i][0] = LightPropagation.flux_DA_Nlay_cylinder_CW(
            rho, mu_a, mu_s, n_ext=n_ext, n_med=n, l=h, a=a, z=z, MaxIter=MaxIter
        )

    return result

def reflectance_gradient(mu_as, mu_ss, ns, hs, zs, rhos, n_ext=1.4, a=100.0, MaxIter=10000):
    Main.n_ext = n_ext
    Main.a = a
    Main.MaxIter = MaxIter

    n_layers = mu_as.shape[-1]
    grad_dim = n_layers*4 + 2
    grad_shape = np.array(rhos.shape)
    grad_shape[-1] = grad_dim

    gradient_result = np.zeros(grad_shape)
    for i in np.ndindex(mu_as.shape[:-1]):
        z, rho = [e[i][0].item() for e in [zs, rhos]]
        mu_a, mu_s, n, h = [e[i].tolist() for e in [mu_as, mu_ss, ns, hs]]

        combined = pack_call(mu_a, mu_s, n, h, z, rho)
        print(combined)
        Main.combined = combined

        Main.eval("result = ForwardDiff.gradient(unpack_call, combined)")
        gradient_result[i] = Main.result
    
    return unpack_result(gradient_result, n_layers)

mu_a = np.array([[0.1, 0.2], [0.2, 0.3]])
mu_s = np.array([[10, 20], [30, 40]])
h = np.array([[10, 20], [30, 40]])
n = np.array([[1.4, 1.4], [1.4, 1.4]])
z = np.array([[0.0], [0.0]])
rho = np.array([[1.0], [2.0]])

reflectance(mu_a, mu_s, n, h, z, rho)
reflectance_gradient(mu_a, mu_s, n, h, z, rho)

[0.1, 0.2, 10, 20, 1.4, 1.4, 10, 20, 0.0, 1.0]
[0.2, 0.3, 30, 40, 1.4, 1.4, 30, 40, 0.0, 2.0]


(array([[-6.66386605e-002, -7.20620591e-018],
        [-4.08142192e-005, -9.02531059e-114]]),
 array([[-1.66068124e-003,  7.41718173e-020],
        [-3.45197390e-006,  6.81486915e-116]]),
 array([[ 4.17811075e-018, -4.17811075e-018],
        [ 7.76220120e-114, -7.76220120e-114]]),
 array([[-2.22501153e-019,  1.06435096e-077],
        [ 2.64771874e-114,  2.03060980e-321]]),
 array([[-0.00903799],
        [-0.0028275 ]]),
 array([[-0.04690193],
        [ 0.00212668]]))

In [464]:
@tf.custom_gradient
def tf_reflectance(mu_as, mu_ss, ns, hs, zs, rhos, n_ext=1.4, a=100.0, MaxIter=10000):
    kwargs = dict(n_ext=1.4, a=100.0, MaxIter=10000)
    wrap_ref = lambda *args: reflectance(*args, **kwargs)
    wrap_grad_ref = lambda *args: reflectance_gradient(*args, **kwargs)

    args = [mu_as, mu_ss, ns, hs, zs, rhos]
    result = tf.numpy_function(wrap_ref, args, Tout=tf.float32)

    def gradient(dy):
        result = tf.numpy_function(wrap_grad_ref, args, Tout=tf.float32)
        print(len(result))
        print(result)
        return [e*dy for e in result]

    return result, gradient

In [465]:
mu_a = tf.Variable([[0.1, 0.2], [0.2, 0.3]])
mu_s = tf.Variable([[10, 20], [30, 40]])
h = tf.Variable([[10, 20], [30, 40]])
n = tf.Variable([[1.4, 1.4], [1.4, 1.4]])
z = tf.Variable([[0.0], [0.0]])
rho = tf.Variable([[1.0], [2.0]])

tf_reflectance(mu_a, mu_s, n, h, z, rho)

with tf.GradientTape() as t:
    reflectance = tf_reflectance(mu_a, mu_s, n, h, z, rho)

t.gradient(reflectance, [mu_a, mu_s, n, h, z, rho])

[0.10000000149011612, 0.20000000298023224, 10, 20, 1.399999976158142, 1.399999976158142, 10, 20, 0.0, 1.0]
[0.20000000298023224, 0.30000001192092896, 30, 40, 1.399999976158142, 1.399999976158142, 30, 40, 0.0, 2.0]
6
[<tf.Tensor: shape=(2, 2), dtype=float64, numpy=
array([[-3.69918203e-002, -4.97115660e-018],
       [-2.14565423e-005, -5.92068604e-114]])>, <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
array([[-1.01748331e-003,  5.11769996e-020],
       [-2.57609833e-005,  4.47063619e-116]])>, <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
array([[-1.78808937e-002, -2.88253052e-018],
       [-4.22836912e-003, -5.09208632e-114]])>, <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
array([[-1.54564335e-019,  9.60060962e-078],
       [ 1.73690227e-114,  1.68476385e-321]])>, <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
array([[0.00092179],
       [0.02415864]])>, <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
array([[-0.02727397],
       [ 0.03075881]])>]


[<tf.Tensor: shape=(2, 2), dtype=float64, numpy=
 array([[-3.69918203e-002, -4.97115660e-018],
        [-2.14565423e-005, -5.92068604e-114]])>,
 <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
 array([[-1.01748331e-003,  5.11769996e-020],
        [-2.57609833e-005,  4.47063619e-116]])>,
 <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
 array([[-1.78808937e-002, -2.88253052e-018],
        [-4.22836912e-003, -5.09208632e-114]])>,
 <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
 array([[-1.54564335e-019,  9.60060962e-078],
        [ 1.73690227e-114,  0.00000000e+000]])>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[0.00092179],
        [0.02415864]])>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[-0.02727397],
        [ 0.03075881]])>]