In [2]:
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
import optax
import matplotlib.pyplot as plt
from pathlib import Path
import common_jax_utils as cju
from typing import Union
from tqdm import tqdm
import wandb
import ml_collections
# Notebook-global PRNG seed.
key = jax.random.PRNGKey(12398)
# Global key iterator from common_jax_utils (keys are taken with next(key_gen)).
# NOTE(review): presumably yields a fresh, independent key per next() — confirm in common_jax_utils.
key_gen = cju.key_generator(key)

# print(Path.cwd())

2024-11-21 16:11:14.775761: W external/xla/xla/service/gpu/nvptx_compiler.cc:930] The NVIDIA driver's CUDA version is 12.3 which is older than the PTX compiler version 12.6.77. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
def set_config():
    """
    Build the experiment hyper-parameters.

    :return: an ``ml_collections.ConfigDict`` holding optimisation settings
        (lr, num_epochs, batch_size), network shape (in_features, out_features,
        num_layers, hidden_features) and the WIRE activation parameters (s0, w0)
    """
    return ml_collections.ConfigDict({
        # optimisation
        'lr': 0.001,
        'num_epochs': 3,
        'batch_size': 1024,
        # network shape
        'in_features': 2,        # (row, col) pixel coordinates
        'out_features': 3,       # one output per colour channel
        'num_layers': 4,         # hidden layers after the input layer
        'hidden_features': 256,
        # WIRE activation: inverse scale of the Gaussian envelope and frequency
        's0': 12,
        'w0': 10,
    })


config = set_config()

In [4]:
def _load_example_image(filename):
    """Read ``<cwd>/../example_data/<filename>`` with matplotlib and return it as a jax array."""
    image_path = Path.cwd().parent / 'example_data' / filename
    return jnp.array(plt.imread(image_path))


def parrot():
    """Load the parrot example image as a jax array."""
    return _load_example_image('parrot.png')


def kirby():
    """Load the kirby example image as a jax array."""
    return _load_example_image('kirby.png')

def data_format(image=None):
    """
    Flatten an RGB image into (pixel coordinate, colour) training pairs.

    :parameter image: optional ``(H, W, C)`` array; defaults to the parrot example
        image (pass e.g. ``kirby()`` to use the other example). An alpha channel
        (C == 4) is dropped.
    :return: ``(x, y, dims)`` where ``x`` has shape ``(H*W, 2)`` and holds the
        integer ``(row, col)`` coordinate of each pixel, ``y`` has shape
        ``(H*W, 3)`` and holds the colours in the same row-major order (so
        ``y[k]`` is the colour at ``x[k]``), and ``dims`` is the shape of the RGB
        image, used to reshape predictions back into an image.
    """
    y = parrot() if image is None else jnp.asarray(image)
    if y.ndim == 3 and y.shape[-1] == 4:
        # RGBA PNGs: keep RGB only, otherwise reshape(-1, 3) misaligns pixels.
        y = y[..., :3]
    dims = y.shape

    rows, cols = jnp.meshgrid(jnp.arange(dims[0]), jnp.arange(dims[1]), indexing='ij')
    # Stack on the last axis (no transpose) so that x[k] == (k // W, k % W),
    # i.e. the same row-major order as y.reshape(-1, 3) below.
    x = jnp.stack([rows, cols], axis=-1).reshape(-1, 2)
    y = y.reshape(-1, 3)
    return x, y, dims



    
def data_loader(batch_size, key):
    """
    Yield shuffled mini-batches that cover every pixel exactly once.

    :parameter batch_size: maximum number of pixels per batch (the last batch may be smaller)
    :parameter key: PRNG key used to derive the shuffling permutation
    :return: a generator of ``(coordinates, colours)`` pairs, both cast to ``complex64``
    """
    coords, colours, _ = data_format()
    _, perm_key = jax.random.split(key)
    order = jax.random.permutation(perm_key, jnp.arange(coords.shape[0]))
    coords = coords[order]
    colours = colours[order]

    for start in range(0, len(coords), batch_size):
        stop = start + batch_size
        yield (
            coords[start:stop].astype(jnp.complex64),
            colours[start:stop].astype(jnp.complex64),
        )
        
        
    



In [5]:
def complex_kernel_initialization(rng, shape, dtype):
    """
    Draw a complex parameter with i.i.d. standard-normal real and imaginary parts.

    The result is scaled by ``1 / sqrt(shape[0])``. For a Dense kernel,
    ``shape[0]`` is the fan-in; for a bias, it is the number of features.

    :parameter rng: PRNG key
    :parameter shape: shape of the parameter to create
    :parameter dtype: dtype used to sample each part (flax passes its ``param_dtype``)
    :return: complex array of shape ``shape``
    """
    # One split yields two independent keys; no key is used more than once.
    real_key, imag_key = jax.random.split(rng)
    real = jax.random.normal(real_key, shape, dtype=dtype)
    imag = jax.random.normal(imag_key, shape, dtype=dtype)
    # The imaginary sample must be multiplied by 1j. Casting it to complex
    # would just add it to the real part and give a purely real kernel.
    return (real + 1j * imag) / jnp.sqrt(shape[0])

def complex_input_kernel_initialization(rng, shape, dtype):
    """
    SIREN-style first-layer initialisation for complex parameters.

    Real and imaginary parts are drawn independently from
    ``U(-1/fan_in, 1/fan_in)`` with ``fan_in = shape[0]``
    (from https://github.com/vsitzmann/siren/blob/4df34baee3f0f9c8f351630992c1fe1f69114b5f/modules.py#L630).

    :parameter rng: PRNG key
    :parameter shape: shape of the parameter to create
    :parameter dtype: dtype used to sample each part
    :return: complex array of shape ``shape``
    """
    # Split once into two independent keys instead of sampling from a key
    # and then splitting that same key again (JAX key reuse).
    real_key, imag_key = jax.random.split(rng)
    lim = 1. / shape[0]

    real = jax.random.uniform(real_key, shape=shape, minval=-lim, maxval=lim, dtype=dtype)
    imag = jax.random.uniform(imag_key, shape=shape, minval=-lim, maxval=lim, dtype=dtype)
    return real + 1j * imag
    

In [6]:

def complex_kernel_initialization2(rng, shape, dtype, w0=None):
    """
    SIREN hidden-layer initialisation for complex parameters.

    Real and imaginary parts are drawn independently from
    ``U(-lim, lim)`` with ``lim = sqrt(6 / fan_in) / w0`` and ``fan_in = shape[0]``
    (from https://arxiv.org/pdf/2006.09661.pdf subsection 3.2 and appendix 1.5 and
    https://github.com/vsitzmann/siren/blob/4df34baee3f0f9c8f351630992c1fe1f69114b5f/modules.py#L627).

    :parameter rng: PRNG key
    :parameter shape: shape of the parameter to create
    :parameter dtype: dtype used to sample each part
    :parameter w0: frequency omega_0 used to scale the bound; when ``None``
        (the default, and the only option flax uses) the notebook-global
        ``config.w0`` is read
    :return: complex array of shape ``shape``
    """
    if w0 is None:
        w0 = config.w0
    # Split once into two independent keys (never sample from a key and then split it).
    real_key, imag_key = jax.random.split(rng)
    lim = jnp.sqrt(6. / shape[0]) / w0

    real = jax.random.uniform(real_key, shape=shape, minval=-lim, maxval=lim, dtype=dtype)
    imag = jax.random.uniform(imag_key, shape=shape, minval=-lim, maxval=lim, dtype=dtype)
    return real + 1j * imag
    
            

def complex_wire(x: jax.Array, s0: Union[float, jax.Array], w0: Union[float, jax.Array]):
    """
    Complex Gabor wavelet activation from WIRE (https://arxiv.org/pdf/2301.05187):

        psi(x) = exp(j * w0 * x) * exp(-|s0 * x|^2)

    :parameter x: input array (real or complex)
    :parameter s0: inverse scale of the Gaussian (radial) envelope (s_0 in the paper)
    :parameter w0: frequency of the oscillatory (rotational) part (omega_0 in the paper)
    :return: complex `jax.Array` whose shape is the broadcast of x, s0 and w0
    """
    radial_part = jnp.exp(-jnp.square(jnp.abs(s0 * x)))
    # The imaginary unit belongs inside the exponent: exp(j*w0*x), not j*exp(w0*x).
    rotational_part = jnp.exp(1j * w0 * x)
    return rotational_part * radial_part



In [7]:

class ComplexDense(nn.Module):
    """
    Complex-valued MLP with WIRE activations.

    Layer stack: ``input_layer`` followed by ``num_layers`` hidden layers named
    ``fc1`` ... ``fc{num_layers}`` (all of width ``hidden_features`` and each
    followed by ``complex_wire``), then ``output_layer`` (width
    ``out_features``) followed by ``nn.relu``.
    All layers use ``complex_kernel_initialization`` for kernel and bias; the
    SIREN-style alternatives ``complex_input_kernel_initialization`` (input layer)
    and ``complex_kernel_initialization2`` (other layers) can be swapped in here.
    NOTE(review): confirm how ``nn.relu`` treats the complex output.
    """
    in_features: int = 2        # not read: nn.Dense infers the input width
    hidden_features: int = 256  # width of the input and hidden layers
    out_features: int = 3       # width of the output layer
    num_layers: int = 5         # hidden layers after the input layer
    s0: float = 12              # WIRE inverse scale
    w0: float = 10              # WIRE frequency

    @nn.compact
    def __call__(self, x):
        hidden_names = ['input_layer'] + [f'fc{idx}' for idx in range(1, self.num_layers + 1)]

        h = x
        for layer_name in hidden_names:
            dense = nn.Dense(
                features=self.hidden_features,
                kernel_init=complex_kernel_initialization,
                bias_init=complex_kernel_initialization,
                name=layer_name,
            )
            h = complex_wire(dense(h), self.s0, self.w0)

        head = nn.Dense(
            features=self.out_features,
            kernel_init=complex_kernel_initialization,
            bias_init=complex_kernel_initialization,
            name='output_layer',
        )
        return nn.relu(head(h))
    
    





In [8]:

def mse_loss(pred, true):
    """
    Mean over all elements of the squared residual ``(pred - true) ** 2``.

    NOTE: the residual itself is squared (not its magnitude), so for complex
    inputs the result is complex rather than ``mean(|pred - true|^2)``.
    """
    residual = pred - true
    return jnp.mean(residual * residual)

    





def update(model, params, x, y, opt, opt_state):
    """
    Compute the gradient for a batch and update the parameters.

    The loss is the MSE between prediction and target, divided by the mean
    squared deviation of the target from its per-column mean (plus ``eps``).
    Gradients use ``holomorphic=True``, so ``params`` and the loss are expected
    to be complex.

    :return: ``(new_params, new_opt_state, loss)``
    """
    def loss_fn(params, x, true_val, eps=1e-6):
        """Normalised mean squared error."""
        prediction = model.apply(params, x)
        baseline = true_val.mean(axis=0, keepdims=True)
        return mse_loss(prediction, true_val) / (mse_loss(baseline, true_val) + eps)

    grad_fn = jax.value_and_grad(loss_fn, holomorphic=True)
    loss, grads = grad_fn(params, x, y)
    updates, new_opt_state = opt.update(grads, opt_state)
    return optax.apply_updates(params, updates), new_opt_state, loss

def update_complex(model, params, x, y, opt, opt_state):
    """
    Alternative training step using the same normalised MSE loss, but taking
    gradients via ``jax.vjp`` with a unit cotangent instead of
    ``jax.value_and_grad(..., holomorphic=True)``.

    :return: ``(new_params, opt_state, loss)``
    """
    def loss_fn(params, x, true_val, eps=1e-6):
        """MSE normalised by the mean squared deviation of ``true_val`` from its column mean."""
        pred_val = model.apply(params, x)
        mse = mse_loss(pred_val, true_val)
        mean_true_val = true_val.mean(axis=0, keepdims=True)
        scaling = mse_loss(mean_true_val, true_val)
        return mse / (scaling + eps)

    loss, vjp_fn = jax.vjp(loss_fn, params, x, y)
    # vjp_fn returns cotangents for (params, x, y); only the params cotangent is needed.
    grads = vjp_fn(jnp.ones_like(loss))[0]
    updates, opt_state = opt.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss
    
    

def render_and_save_image(model, params, epoch, batch_size=1024):
    """
    Render the model's prediction for every pixel and save it as
    ``results/epoch_{epoch}.png`` under the current working directory.

    :parameter model: flax module to evaluate
    :parameter params: parameters for ``model``
    :parameter epoch: epoch index, used in the output filename
    :parameter batch_size: number of pixels per forward pass
    """
    results_dir = Path.cwd() / 'results'
    # exist_ok makes this safe to call every epoch (no check-then-create).
    results_dir.mkdir(exist_ok=True)
    path = results_dir / f'epoch_{epoch}.png'

    pred_ys, _ = test(model, params, batch_size)
    dims = data_format()[2]
    image = jnp.real(pred_ys.reshape(dims))
    # plt.imsave raises for float RGB data outside [0, 1], and the real part of
    # the prediction is not guaranteed to stay in that range.
    plt.imsave(path, np.asarray(jnp.clip(image, 0.0, 1.0)))
    

def train(model, params, opt, opt_state, num_epochs, batch_size, key):
    """
    Train the model for a number of epochs.

    Each epoch shuffles all pixels with a fresh key, runs one ``update`` step per
    mini-batch, logs the loss magnitude to wandb, and saves a rendering of the
    current prediction via ``render_and_save_image``.

    :parameter key: PRNG key; split once per epoch to shuffle the data
    :return: the trained parameters
    :raises ValueError: if the loss becomes NaN
    """
    x, _, _ = data_format()
    # data_loader also yields the final partial batch, so an epoch has
    # ceil(N / batch_size) steps. Floor division made consecutive epochs share
    # step numbers.
    size = (x.shape[0] + batch_size - 1) // batch_size

    print(f'{size=} steps per epoch')
    for epoch in range(num_epochs):
        # The fresh subkey goes to the loader; `key` is only used for further splits.
        key, subkey = jax.random.split(key)
        for i, (x_batch, y_batch) in enumerate(data_loader(batch_size, subkey)):
            params, opt_state, loss = update(model, params, x_batch, y_batch, opt, opt_state)

            if jnp.isnan(loss):
                raise ValueError('Loss is NaN')

            step = epoch * size + i
            # Plain Python float for printing/logging instead of a device array.
            loss_value = float(jnp.abs(loss))
            if i % 100 == 0:
                print(f'Epoch: {epoch}, Step: {step}, Loss: {loss_value}')
            wandb.log({'loss': loss_value, 'step': step})

        render_and_save_image(model, params, epoch, batch_size)

    return params




def test(model, params, batch_size):
    """
    Evaluate the model on every pixel of the image.

    :parameter model: flax module to evaluate
    :parameter params: parameters for ``model``
    :parameter batch_size: number of pixels per forward pass
    :return: ``(pred_ys, loss)``.
        ``pred_ys`` concatenates the predictions for all pixels along axis 0,
        in the same order as ``data_format``.
        ``loss`` is the MSE normalised by the mean squared deviation of the
        targets from their column mean.
    """
    eps = 1e-6
    x, y, _ = data_format()
    # Same input dtype as the training batches produced by data_loader.
    x = x.astype(jnp.complex64)

    # Gather batch outputs in a list and concatenate once (repeated vstack is quadratic).
    batches = [
        model.apply(params, x[start:start + batch_size])
        for start in range(0, len(x), batch_size)
    ]
    # Return predictions for ALL pixels, not just the final batch.
    pred_ys = jnp.concatenate(batches, axis=0)

    # Compute the loss from the batched predictions instead of a second
    # full-image forward pass.
    mean_true_val = y.mean(axis=0, keepdims=True)
    scaling = mse_loss(mean_true_val, y)
    loss = mse_loss(pred_ys, y) / (scaling + eps)
    return pred_ys, loss
    
        

def main():
    """
    End-to-end experiment: train ``ComplexDense`` on the example image, evaluate
    it, log to Weights & Biases and plot the prediction next to the ground truth.

    :return: the predicted image (real part of the network output) reshaped to
        the image's shape
    """
    config = set_config()
    # Derive every key from this local seed so repeated calls to main() are
    # reproducible and do not depend on the notebook-global ``key_gen`` state.
    key = jax.random.PRNGKey(0)
    key, init_key, tabulate_key = jax.random.split(key, 3)

    wandb.init(project='flax_wire')
    try:
        # Log a plain dict so the hyper-parameters are stored as key/value pairs.
        wandb.config.update(config.to_dict())

        model = ComplexDense(
            in_features=config.in_features,
            hidden_features=config.hidden_features,
            out_features=config.out_features,
            num_layers=config.num_layers,
            s0=config.s0,
            w0=config.w0
        )

        # Dummy complex input with the configured coordinate dimensionality.
        dummy_input = jnp.ones((1, config.in_features), dtype=jnp.complex64)
        params = model.init(init_key, dummy_input)

        # Model summary.
        print(model)
        print(nn.tabulate(model, tabulate_key)(dummy_input))

        opt = optax.adam(config.lr)
        opt_state = opt.init(params)

        trained_params = train(
            model=model,
            params=params,
            opt=opt,
            opt_state=opt_state,
            num_epochs=config.num_epochs,
            batch_size=config.batch_size,
            key=key
        )

        pred_ys, test_loss = test(
            model=model,
            params=trained_params,
            batch_size=config.batch_size
        )
        test_loss_value = float(jnp.abs(test_loss))
        print(f'Training Done, Test Loss: {test_loss_value}')
        wandb.log({'test_loss': test_loss_value})
    finally:
        # Close the wandb run even if training raises, instead of leaving it open.
        wandb.finish()

    _, y, dims = data_format()
    pred_ys = jnp.real(pred_ys.reshape(dims))

    fig, ax = plt.subplots(1, 2)
    # Clip only for display (imshow expects float RGB in [0, 1]); the returned
    # prediction is left unclipped.
    ax[0].imshow(np.asarray(jnp.clip(pred_ys, 0.0, 1.0)))
    ax[0].set_title('prediction')
    ax[1].imshow(y.reshape(dims))
    ax[1].set_title('ground truth')
    for panel in ax:
        panel.axis('off')
    plt.show()

    return pred_ys

In [None]:
# Run the whole experiment (train, evaluate, plot); `pys` holds the predicted image returned by main().
pys = main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmaxwell_litsios[0m ([33mbep-circle[0m). Use [1m`wandb login --relogin`[0m to force relogin


ComplexDense(
    # attributes
    in_features = 2
    hidden_features = 256
    out_features = 3
    num_layers = 4
    s0 = 12
    w0 = 10
)

[3m                              ComplexDense Summary                              [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath        [0m[1m [0m┃[1m [0m[1mmodule      [0m[1m [0m┃[1m [0m[1minputs       [0m[1m [0m┃[1m [0m[1moutputs       [0m[1m [0m┃[1m [0m[1mparams       [0m[1m [0m┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│              │ ComplexDense │ [2mcomplex64[0m[1,… │ [2mcomplex64[0m[1,3] │               │
├──────────────┼──────────────┼───────────────┼────────────────┼───────────────┤
│ input_layer  │ Dense        │ [2mcomplex64[0m[1,… │ [2mcomplex64[0m[1,2… │ bias:         │
│              │              │               │                │ [2mcomplex64[0m[25… │
│              │              │         