In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
import optax
import matplotlib.pyplot as plt
from pathlib import Path
import common_jax_utils as cju
from typing import Union
from tqdm import tqdm
import wandb
import ml_collections
# Module-level PRNG state for the notebook.
SEED = 12398
key = jax.random.PRNGKey(SEED)
# Project-local helper (common_jax_utils); presumably an iterator of fresh keys
# consumed with next(...) — its state advances every time it is used, so
# results depend on how often cells are re-run. TODO confirm.
key_gen = cju.key_generator(key)

# Data paths are resolved relative to the working directory (Path.cwd()).
print(Path.cwd())

/home/ovindar/PycharmProjects/INR_BEP/flax_exp


2024-11-20 14:37:20.536771: W external/xla/xla/service/gpu/nvptx_compiler.cc:930] The NVIDIA driver's CUDA version is 12.3 which is older than the PTX compiler version 12.6.77. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
def set_config():
    """
    Build the experiment hyper-parameters as an ``ml_collections.ConfigDict``.

    Groups: optimisation (lr, num_epochs, batch_size), network shape
    (in_features, out_features, num_layers, hidden_features) and the WIRE
    activation parameters (s0, w0).
    """
    return ml_collections.ConfigDict({
        # optimisation
        'lr': 1e-3,
        'num_epochs': 10,
        'batch_size': 256,
        # network shape
        'in_features': 2,
        'out_features': 3,
        'num_layers': 4,
        'hidden_features': 256,
        # WIRE activation: s0 is the inverse scale of the Gaussian envelope,
        # w0 the frequency of the rotational/oscillating factor
        's0': 12,
        'w0': 1,
    })

In [3]:
def image_loader():
    """
    Read the parrot example image and return it as a JAX array.

    The file is looked up at ``<cwd>/../example_data/parrot.png``; the result is
    whatever ``plt.imread`` returns for it, converted with ``jnp.array``.
    """
    image_path = Path.cwd().parent.joinpath('example_data', 'parrot.png')
    return jnp.array(plt.imread(image_path))

def data_format(image=None):
    """
    Flatten an RGB image into (coordinate, colour) training pairs.

    Row k of ``x`` is the integer pixel coordinate ``(row, col)`` and row k of
    ``y`` is the colour of exactly that pixel. Both are in row-major order, so
    ``y`` (or a prediction for ``x``) reshapes back to ``(H, W, 3)``.

    :param image: optional ``(H, W, C)`` image array; when ``None`` the image is
        read with ``image_loader()``.
    :return: ``(x, y)`` with ``x`` of shape ``(H*W, 2)`` (integer coordinates)
        and ``y`` of shape ``(H*W, 3)``.
    """
    y = image_loader() if image is None else jnp.asarray(image)
    if y.ndim == 3 and y.shape[-1] == 4:
        # PNGs with an alpha channel load as RGBA; keep the colour channels only,
        # otherwise reshape(-1, 3) below would mix channels across pixels.
        y = y[..., :3]
    height, width = y.shape[0], y.shape[1]
    # Stack coordinates on the LAST axis so they enumerate pixels in the same
    # row-major order as y.reshape(-1, 3). Stacking on axis 0 and transposing
    # enumerated them column-major, pairing coordinate (h, w) with the pixel at
    # flat index w*H + h (a transposed / scrambled target).
    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
    x = jnp.array(np.stack([rows, cols], axis=-1).reshape(-1, 2))
    y = y.reshape(-1, 3)
    return x, y


def data_loader(batch_size, key, data=None):
    """
    Yield shuffled ``(x, y)`` mini-batches that visit every sample once.

    :param batch_size: samples per batch; the final batch may be smaller.
    :param key: JAX PRNG key that determines the shuffle order.
    :param data: optional ``(x, y)`` pair; defaults to ``data_format()``.
    :yield: ``(x_batch, y_batch)`` with ``x_batch`` cast to ``int16``.
    """
    x, y = data_format() if data is None else data
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    _, subkey = jax.random.split(key)
    # Shuffle an index vector and gather both arrays with it. Concatenating x
    # and y first (the old approach) promoted the integer coordinates to y's
    # dtype before casting back, which is lossy for low-precision floats
    # (e.g. float16 cannot represent every integer above 2048).
    # NOTE: for a 2-D array jax.random.permutation also shuffles row indices,
    # so the order is expected to match the previous implementation.
    order = jax.random.permutation(subkey, x.shape[0])
    x = x[order].astype(jnp.int16)
    y = y[order]
    for start in range(0, x.shape[0], batch_size):
        yield x[start:start + batch_size], y[start:start + batch_size]
    


# Shape of the flattened colour targets: (number of pixels, 3).
sh = data_format()[1].shape
# Total number of scalar target values (pixels x channels).
length = int(np.prod(sh))


In [4]:
def complex_kernel_initialization(rng, shape, dtype):
    """
    Initializer with the Flax ``init(key, shape, dtype)`` signature, used for
    both kernels and biases of the Dense layers.

    Draws a standard-normal "real" part and a second normal draw scaled by
    ``1/sqrt(shape[0])`` for the "imaginary" part.

    NOTE(review): the second draw is passed through ``jnp.imag``. For a real
    ``dtype`` that draw is a real array, so ``jnp.imag`` returns zeros and the
    result is just the real standard-normal draw — no complex parameters are
    produced. The likely intent is ``real + 1j * imag / sqrt(fan_in)``, but that
    changes the whole model (complex params; together with ``complex_wire`` as
    written, large imaginary pre-activations may overflow to inf*0 = NaN on
    unnormalised pixel coordinates), so the behaviour is kept here. Confirm.
    """
    real_key, _ = jax.random.split(rng)
    imag_key, _ = jax.random.split(real_key)
    real_draw = jax.random.normal(real_key, shape, dtype=dtype)
    imag_draw = jax.random.normal(imag_key, shape, dtype=dtype)
    return real_draw + jnp.imag(imag_draw / np.sqrt(shape[0]))

def complex_wire(x: jax.Array, s0:Union[float, jax.Array], w0:Union[float, jax.Array]):
    """
    Complex WIRE activation, after https://arxiv.org/pdf/2301.05187.

    Intended form: exp(j*w0*x) * exp(-|s0*x|^2).

    :parameter x: pre-activation array (real or complex)
    :parameter s0: inverse scale of the Gaussian (radial) envelope, s_0 in the paper
    :parameter w0: frequency of the rotational factor, omega_0 in the paper
    :return: array with the broadcast shape of x, s0 and w0

    NOTE(review): the rotational factor is computed as exp(imag(w0*x)), NOT
    exp(1j*w0*x). For real x this factor is exactly 1, so the activation
    reduces to the Gaussian envelope and the output is real. Switching to
    exp(1j*w0*x) would make the network output complex, which the real-valued
    loss/plotting code downstream likely does not handle — confirm first.
    """
    envelope = jnp.exp(-jnp.square(jnp.abs(s0 * x)))
    rotation = jnp.exp(jnp.imag(w0 * x))
    return rotation * envelope

def unscaled_gaussian_bump(*x:jax.Array, inverse_scale:Union[float, jax.Array]):
    """
    Product of unnormalised Gaussian bumps: exp(-sum_{x' in x} |inverse_scale*x'|^2).

    :param x: one or more arrays; they are stacked on a new leading axis, so
        they must share a shape
    :param inverse_scale: multiplier applied to every input before squaring
    :returns: exp of minus the summed squared magnitudes, i.e. the product of
        the individual bumps computed as a sum in log space
    """
    stacked = jnp.stack(x, axis=0)
    scaled = inverse_scale * stacked
    # Complex inputs contribute their squared modulus.
    magnitude = scaled if jnp.isrealobj(stacked) else jnp.abs(scaled)
    total = jnp.sum(jnp.square(magnitude), axis=0)
    return jnp.exp(-total)


def real_wire(*x: jax.Array, s0:Union[float, jax.Array], w0:Union[float, jax.Array]):
    """
    Real-valued WIRE-nD activation, after https://arxiv.org/pdf/2301.05187:
    sin(w0*x[0]) * exp(-sum_{x' in x} |s0*x'|^2).

    :parameter x: one or more arrays (var positional); only the first one
        drives the sine, all of them enter the Gaussian envelope
    :parameter s0: inverse scale of the Gaussian envelope, s_0 in the paper
        (keyword only)
    :parameter w0: frequency of the sine, omega_0 in the paper (keyword only)
    :return: array with the broadcast shape of the inputs and parameters
    """
    oscillation = jnp.sin(w0 * x[0])
    envelope = unscaled_gaussian_bump(*x, inverse_scale=s0)
    return oscillation * envelope

In [5]:

class ComplexDense(nn.Module):
    """
    MLP in which every Dense layer, including the output layer, is followed by
    the ``complex_wire`` activation.

    Layers are named ``input_layer``, ``fc0`` .. ``fc{num_layers-2}`` and
    ``output_layer``; all kernels and biases use
    ``complex_kernel_initialization``.
    """

    # Not read by __call__: nn.Dense infers the input width from x.
    in_features: int = 2
    hidden_features: int = 32
    out_features: int = 3
    # Hidden Dense+activation pairs, counting input_layer as the first one.
    num_layers: int = 1
    s0: float = 1
    w0: float = 1

    @nn.compact
    def __call__(self, x):
        def wire_dense(features, layer_name):
            return nn.Dense(
                features=features,
                kernel_init=complex_kernel_initialization,
                bias_init=complex_kernel_initialization,
                name=layer_name,
            )

        layer_specs = [(self.hidden_features, 'input_layer')]
        layer_specs += [(self.hidden_features, f'fc{i}') for i in range(self.num_layers - 1)]
        layer_specs.append((self.out_features, 'output_layer'))

        for features, layer_name in layer_specs:
            x = wire_dense(features, layer_name)(x)
            x = complex_wire(x, s0=self.s0, w0=self.w0)
        return x
    
    
config = set_config()
# Build the network from the config entries that share ComplexDense's field names.
model = ComplexDense(**{
    field: getattr(config, field)
    for field in ('in_features', 'hidden_features', 'out_features', 'num_layers', 's0', 'w0')
})



In [6]:
def loss_fn(params, x, y):
    """
    Mean squared error between the targets and the model's predictions.

    Uses the squared magnitude |y - y_pred|^2, so the loss stays a real scalar
    even if the model returns complex values (``jax.value_and_grad`` needs a
    real-valued output). For real predictions it equals the plain squared
    difference.

    :param params: Flax variables for the module-level ``model``.
    :param x: batch of input coordinates.
    :param y: matching batch of target colours.
    :return: scalar loss.
    """
    y_pred = model.apply(params, x)
    # No `+ 1e-8` offset: a constant has zero gradient and only made the
    # reported value differ from the true MSE.
    return jnp.mean(jnp.square(jnp.abs(y - y_pred)))


def update(params, x, y, opt, opt_state):
    """
    Take one optimiser step on a single batch.

    :param opt: optax gradient transformation.
    :param opt_state: its current state.
    :return: ``(new_params, new_opt_state, batch_loss)``
    """
    batch_loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    step, new_opt_state = opt.update(grads, opt_state)
    return optax.apply_updates(params, step), new_opt_state, batch_loss

def render_and_save_image(params, epoch, image_shape=None):
    """
    Predict the full image with ``params`` and save it as ``epoch_{epoch}.png``
    in the current working directory.

    :param params: model variables to evaluate.
    :param epoch: epoch index used in the file name.
    :param image_shape: optional ``(H, W, ...)``; when omitted, the height and
        width of ``image_loader()`` are used.
    :return: the predictions reshaped to ``(H, W, C)`` (not clipped).
    """
    path = Path.cwd() / f'epoch_{epoch}.png'
    pred_ys, _ = test(params, 512)
    if image_shape is None:
        image_shape = image_loader().shape
    height, width = image_shape[0], image_shape[1]
    # `test` returns one row per pixel; fold it back into an (H, W, C) image.
    # Reshaping to the flat target shape saved an (H*W) x 3 strip instead.
    pred_image = pred_ys.reshape(height, width, -1)
    # plt.imsave rejects float RGB values outside [0, 1].
    plt.imsave(path, np.clip(np.asarray(pred_image), 0.0, 1.0))
    return pred_image
    

def train(params, opt, opt_state, num_epochs, batch_size, key):
    """
    Train the model for ``num_epochs`` epochs, reshuffling the data each epoch.

    After every epoch the current reconstruction is rendered to disk, and the
    loss of the epoch's last batch is printed and logged to wandb.

    :raises ValueError: if a batch loss is NaN or the data loader yields no
        batches.
    :return: the trained parameters.
    """
    for epoch in tqdm(range(num_epochs)):
        # Give the loader the fresh subkey. Passing `key` itself meant the loader
        # re-split the same key that the next epoch splits again (key reuse).
        key, subkey = jax.random.split(key)
        loss = None
        for x, y in data_loader(batch_size, subkey):
            params, opt_state, loss = update(params, x, y, opt, opt_state)

            if jnp.isnan(loss):
                raise ValueError('Loss is NaN')
        if loss is None:
            raise ValueError('data_loader produced no batches')

        render_and_save_image(params, epoch)

        print(f'Epoch: {epoch}, Loss: {loss}')
        wandb.log({'loss': loss, 'epoch': epoch})
    return params


def test(params, batch_size):
    """
    Predict every sample of ``data_format()`` in order (unshuffled) and compute
    the mean squared error against the targets.

    :param params: model variables to evaluate.
    :param batch_size: number of samples passed to ``model.apply`` at once.
    :return: ``(pred_ys, loss)`` where ``pred_ys`` has one row per sample.
    """
    x, y = data_format()
    chunks = [
        model.apply(params, x[start:start + batch_size])
        for start in range(0, len(x), batch_size)
    ]
    # Concatenate once; growing the array inside the loop re-copied every
    # previous prediction on each iteration (quadratic in the number of batches).
    pred_ys = jnp.concatenate(chunks, axis=0) if chunks else jnp.array([])
    loss = jnp.mean(jnp.square(y - pred_ys))
    return pred_ys, loss
    
        

def main():
    """
    Run the full experiment: start a wandb run, train, evaluate, and plot the
    original image next to the reconstruction.

    :return: the per-pixel predictions returned by ``test``.
    """
    wandb.init(project='flax_wire')

    config = set_config()
    key = jax.random.PRNGKey(0)
    wandb.config.update(config)

    # Derive the init key from the local seed instead of the module-level
    # `key_gen`, so calling main() again in the same kernel reproduces the run.
    key, init_key = jax.random.split(key)
    params = model.init(init_key, jnp.ones((1, config.in_features)))

    opt = optax.adam(config.lr)
    opt_state = opt.init(params)

    trained_params = train(
        params=params,
        opt=opt,
        opt_state=opt_state,
        num_epochs=config.num_epochs,
        batch_size=config.batch_size,
        key=key,
    )

    pred_ys, test_loss = test(
        params=trained_params,
        batch_size=config.batch_size,
    )
    print(f'Training Done, Test Loss: {test_loss}')
    wandb.log({'test_loss': test_loss})

    # Compare against the whole image; indexing with [1] selected only row 1,
    # which the flat predictions cannot be reshaped to.
    img = image_loader()
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img)
    ax[0].set_title('Original Image')
    ax[1].imshow(np.asarray(pred_ys).reshape(img.shape[0], img.shape[1], -1))
    ax[1].set_title('Reconstructed Image')
    return pred_ys

In [None]:
# Run the full experiment (wandb run, training, evaluation, comparison plot).
main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmaxwell_litsios[0m ([33mbep-circle[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/10 [00:00<?, ?it/s]