In [None]:
# import jax
# import jax.numpy as jnp
# import equinox as eqx

# def generate_alpha(x: jax.Array):
#     """
#     Generate alpha used in the FINER activation.
#     FINER generates alpha as |x| + 1.
#     :param x: input array for alpha generation.
#     :return: alpha array.
#     """
#     return jnp.abs(x) + 1

# def finer_activation(x: jax.Array, omega: float):
#     """
#     FINER activation function: sin(omega * alpha * x)
#     :param x: input array for activation.
#     :param omega: frequency scaling factor (omega).
#     :return: output array after applying variable-periodic activation function.
#     """
#     alpha = generate_alpha(x)
#     return jnp.sin(omega * alpha * x)

# def init_weights(key, fan_in: int, omega: float, is_first: bool):
#     """
#     Initializes weights based on the SIREN/FINER initialization scheme.
#     :param key: random key for JAX.
#     :param fan_in: input dimension size.
#     :param omega: frequency scaling factor (omega).
#     :param is_first: boolean indicating if it's the first layer.
#     :return: initialized weight array.
#     """
#     if is_first:
#         bound = 1.0 / fan_in
#     else:
#         bound = jnp.sqrt(6.0 / fan_in) / omega
#     return jax.random.uniform(key, shape=(fan_in,), minval=-bound, maxval=bound)

# def init_bias(key, out_size: int, k: float = 20):
#     """
#     Initializes bias based on a uniform distribution with a larger range.
#     :param key: random key for JAX.
#     :param out_size: output size (number of neurons in the layer).
#     :param k: scaling factor for bias initialization.
#     :return: initialized bias array.
#     """
#     return jax.random.uniform(key, shape=(out_size,), minval=-k, maxval=k)

# class FinerLayer(eqx.Module):
#     """
#     FINER Layer using variable-periodic activation functions.
#     This layer applies a linear transformation followed by a FINER sine activation.
#     """
#     weights: jax.Array
#     biases: jax.Array
#     omega: float

#     def __init__(self, in_size: int, out_size: int, key: jax.Array, omega: float = 30.0, is_first: bool = False):
#         """
#         Initializes the FINER layer with the given parameters.
#         :param in_size: input size (number of input features).
#         :param out_size: output size (number of neurons).
#         :param key: JAX random key for initialization.
#         :param omega: frequency scaling factor (omega).
#         :param is_first: boolean indicating if this is the first layer.
#         """
#         key_w, key_b = jax.random.split(key)
#         self.weights = init_weights(key_w, in_size, omega, is_first)
#         self.biases = init_bias(key_b, out_size)
#         self.omega = omega

#     def __call__(self, x):
#         """
#         Forward pass: applies the linear transformation and the FINER activation.
#         :param x: input array.
#         :return: output after applying the FINER activation.
#         """
#         wx_b = jnp.dot(x, self.weights) + self.biases
#         return finer_activation(wx_b, self.omega)

# class Finer(eqx.Module):
#     """
#     Full FINER model with multiple hidden layers.
#     """
#     layers: list

#     def __init__(self, in_size: int, out_size: int, hidden_layers: int = 3, hidden_size: int = 256, key: jax.Array, omega: float = 30.0):
#         """
#         Initialize the FINER network.
#         :param in_size: input size (number of input features).
#         :param out_size: output size (number of output features).
#         :param hidden_layers: number of hidden layers.
#         :param hidden_size: number of neurons in each hidden layer.
#         :param key: JAX random key for initialization.
#         :param omega: frequency scaling factor for the FINER activation.
#         """
#         layers = []
#         keys = jax.random.split(key, num=hidden_layers + 2)

#         # First layer
#         layers.append(FinerLayer(in_size, hidden_size, key=keys[0], omega=omega, is_first=True))

#         # Hidden layers
#         for i in range(hidden_layers):
#             layers.append(FinerLayer(hidden_size, hidden_size, key=keys[i+1], omega=omega))

#         # Output layer (no activation)
#         layers.append(FinerLayer(hidden_size, out_size, key=keys[-1], omega=omega, is_first=False))

#         self.layers = layers

#     def __call__(self, x):
#         """
#         Forward pass through the entire FINER network.
#         :param x: input array.
#         :return: output of the model.
#         """
#         for layer in self.layers:
#             x = layer(x)
#         return x


In [1]:
import jax
import jax.numpy as jnp

def init_bias(shape, k, key):
    """
    Initializes bias values for a layer uniformly in the range [-k, k].

    :param shape: Shape of the bias vector (e.g., (out_features,)).
    :param k: Bound for uniform initialization, i.e., bias values are drawn from [-k, k].
    :param key: JAX random key for generating random values.
    :return: A JAX array with initialized biases.
    """
    return jax.random.uniform(key, shape, minval=-k, maxval=k)


In [2]:
def init_bias_cond(is_first, shape, fbs=None, key=None):
    """
    Conditionally initializes the bias based on whether it is the first layer and if an fbs value is provided.

    :param is_first: Boolean indicating if this is the first layer in the network.
    :param shape: Shape of the bias vector.
    :param fbs: Bound for uniform initialization, used only if `is_first` is True.
    :param key: JAX random key for generating random values.
    :return: Initialized bias, or None if not the first layer and no fbs is provided.
    """
    if is_first and fbs is not None and key is not None:
        return init_bias(shape, fbs, key)
    return None  # Return None if conditions are not met


In [3]:
def generate_alpha(x):
    """
    Generates a scaling factor (alpha) for input x based on its magnitude.
    
    :param x: Input tensor.
    :return: Scaling factor based on |x| + 1.
    """
    return jnp.abs(x) + 1

def finer_activation(x, omega=1):
    """
    Variable-periodic activation function for FINER.
    
    :param x: Input tensor.
    :param omega: Frequency control parameter.
    :return: Activated tensor.
    """
    alpha = generate_alpha(x)
    return jnp.sin(omega * alpha * x)


In [5]:
import jax
import jax.numpy as jnp
import equinox as eqx
from jax import random

class FINERLayer(eqx.Module):
    """
    A single FINER layer with a configurable sine activation.
    """
    weight: jnp.ndarray
    bias: jnp.ndarray
    omega: float
    is_last: bool = False

    def __init__(self, in_features, out_features, key, omega=30, is_last=False):
        self.omega = omega
        self.is_last = is_last
        
        # Initialize weights and biases
        w_key, b_key = random.split(key)
        fan_in = in_features
        bound = 1.0 / fan_in if is_last else jnp.sqrt(6.0 / fan_in) / omega
        self.weight = random.uniform(w_key, (out_features, in_features), minval=-bound, maxval=bound)
        self.bias = random.uniform(b_key, (out_features,), minval=-1.0, maxval=1.0)

    def __call__(self, x):
        # Linear transformation: xW + b
        wx_b = jnp.dot(x, self.weight.T) + self.bias
        # Apply the finer_activation unless this is the last layer
        return wx_b if self.is_last else finer_activation(wx_b, omega=self.omega)


In [6]:
class FINERModel(eqx.Module):
    """
    FINER model composed of multiple FINER layers.
    """
    layers: list

    def __init__(self, in_features=2, out_features=3, hidden_layers=3, hidden_features=256, 
                 first_omega=30, hidden_omega=30, key=None):
        keys = random.split(key, hidden_layers + 2)
        
        # Define layers
        self.layers = []
        
        # Input layer with first_omega
        self.layers.append(FINERLayer(in_features, hidden_features, keys[0], omega=first_omega, is_last=False))
        
        # Hidden layers with hidden_omega
        for i in range(1, hidden_layers + 1):
            self.layers.append(FINERLayer(hidden_features, hidden_features, keys[i], omega=hidden_omega, is_last=False))
        
        # Output layer (no activation)
        self.layers.append(FINERLayer(hidden_features, out_features, keys[-1], omega=hidden_omega, is_last=True))

    def __call__(self, x):
        # Pass the input through each layer in sequence
        for layer in self.layers:
            x = layer(x)
        return x


In [7]:
# Initialize model
key = random.PRNGKey(0)
finer_model = FINERModel(in_features=2, out_features=3, hidden_layers=3, hidden_features=256,
                         first_omega=30, hidden_omega=30, key=key)

# Example input
x = jnp.array([[0.5, -1.0], [1.0, 0.5]])  # Batch of inputs

# Forward pass through the model
output = finer_model(x)
print(output)


[[0.4858135  0.4693304  0.6433001 ]
 [0.49075848 0.45279476 0.62308896]]


In [8]:
from PIL import Image
import numpy as np

def load_image(image_path):
    # Load image and convert to RGB
    img = Image.open(image_path).convert("RGB")
    img = np.array(img) / 255.0  # Normalize pixel values to [0, 1]
    return img

def prepare_data(img):
    # Get image dimensions
    height, width, _ = img.shape

    # Create normalized coordinates and RGB values
    x_coords = np.linspace(-1, 1, width)
    y_coords = np.linspace(-1, 1, height)
    coords = np.array(np.meshgrid(x_coords, y_coords)).reshape(2, -1).T  # Shape: (width*height, 2)
    rgb_values = img.reshape(-1, 3)  # Flatten RGB values to match coords

    return coords, rgb_values


In [10]:

import optax
import jax
from jax import numpy as jnp

# Load and prepare image
image_path = "example_data/parrot.png"  # Replace with your image path
img = load_image(image_path)
coords, rgb_values = prepare_data(img)

# Convert data to JAX arrays
coords = jnp.array(coords)
rgb_values = jnp.array(rgb_values)

# Initialize model
key = jax.random.PRNGKey(0)
finer_model = FINERModel(in_features=2, out_features=3, hidden_layers=3, hidden_features=256, key=key)

# Define loss function
def loss_fn(params, coords, rgb_values):
    pred_rgb = eqx.filter_jit(finer_model)(coords)  # Predict RGB values
    return jnp.mean((pred_rgb - rgb_values) ** 2)

# Initialize optimizer
optimizer = optax.adam(learning_rate=1e-4)
opt_state = optimizer.init(finer_model)

# Training loop
num_steps = 10000

@jax.jit
def train_step(finer_model, opt_state, coords, rgb_values):
    # Compute loss and gradients
    loss, grads = jax.value_and_grad(loss_fn)(finer_model, coords, rgb_values)
    updates, opt_state = optimizer.update(grads, opt_state)
    finer_model = eqx.apply_updates(finer_model, updates)
    return finer_model, opt_state, loss

for step in range(num_steps):
    finer_model, opt_state, loss = train_step(finer_model, opt_state, coords, rgb_values)
    if step % 1000 == 0:
        print(f"Step {step}, Loss: {loss}")


TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.

Random chunks of cod ethat i have been messing with to try and make it work 


In [None]:
import jax.numpy as jnp

def generate_alpha(x):
    """
    Generates a scaling factor (alpha) for input x based on its magnitude.
    """
    return jnp.abs(x) + 1

def finer_activation(x, omega=1):
    """
    Variable-periodic activation function for FINER.
    """
    alpha = generate_alpha(x)
    return jnp.sin(omega * alpha * x)


In [None]:
import jax
import jax.numpy as jnp
import activation_functions as act  # Import finer_activation here if needed
from .inr_layers import INRLayer

class FinerLayer(INRLayer):
    """
    A single FINER layer with a configurable sine activation.
    """
    allowed_keys = frozenset({'omega'})  # omega for frequency control
    allows_multiple_weights_and_biases = False

    def __init__(self, in_features, out_features, omega=30, key=None, is_last=False):
        self.omega = omega
        self.is_last = is_last
        
        # Initialize weights and biases with special conditions
        w_key, b_key = jax.random.split(key)
        fan_in = in_features
        bound = 1.0 / fan_in if is_last else jnp.sqrt(6.0 / fan_in) / omega
        self.weight = jax.random.uniform(w_key, (out_features, in_features), minval=-bound, maxval=bound)
        self.bias = jax.random.uniform(b_key, (out_features,), minval=-1.0, maxval=1.0)

    def __call__(self, x):
        # Linear transformation: xW + b
        wx_b = jnp.dot(x, self.weight.T) + self.bias
        # Apply finer_activation unless this is the last layer
        return wx_b if self.is_last else act.finer_activation(wx_b, omega=self.omega)


In [None]:
import jax
import jax.numpy as jnp
from .inr_layers import INRLayer  # Make sure this path aligns with your file structure

class FinerLayer(INRLayer):
    """
    FINER layer with custom initialization and activation function.
    """
    allowed_keys = frozenset({'omega'})  # omega for frequency control
    allows_multiple_weights_and_biases = False

    def __init__(self, weight, bias, omega=30, is_last=False):
        """
        Initialize FinerLayer with specified weight and bias.
        
        :param weight: Weight matrix.
        :param bias: Bias vector.
        :param omega: Frequency control parameter.
        :param is_last: Whether this is the last layer (no activation if True).
        """
        self.weight = weight
        self.bias = bias
        self.omega = omega
        self.is_last = is_last

    @classmethod
    def from_config(cls, in_size, out_size, key, is_first=False, omega=30, fbs=None):
        """
        Custom from_config for FinerLayer to handle unique initialization requirements.

        :param in_size: Input size.
        :param out_size: Output size.
        :param key: Random key for initializing weights and biases.
        :param is_first: Flag to indicate if this is the first layer.
        :param omega: Frequency control parameter.
        :param fbs: Bound for initializing the bias in the first layer.
        :return: Initialized FinerLayer instance with weights and biases.
        """
        w_key, b_key = jax.random.split(key)

        # Custom weight initialization
        fan_in = in_size
        bound = 1.0 / fan_in if is_first else jnp.sqrt(6.0 / fan_in) / omega
        weight = jax.random.uniform(w_key, shape=(out_size, in_size), minval=-bound, maxval=bound)

        # Custom bias initialization for the first layer
        if is_first and fbs is not None:
            bias = jax.random.uniform(b_key, shape=(out_size,), minval=-fbs, maxval=fbs)
        else:
            bias = jax.random.uniform(b_key, shape=(out_size,), minval=-1.0, maxval=1.0)

        # Return an instance of FinerLayer with initialized weights and biases
        return cls(weight=weight, bias=bias, omega=omega, is_last=not is_first)
