In [1]:
import time
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  

import jax
import jax.numpy as jnp
from jax import pmap
from functools import partial
import numpy as np
import optax
from flax import nnx

from modules.params_utils import save_params
from modules.training_utils import data_loader, print_generated, update_and_check_grads, clip_gradients, plot_learning_curves, choose_schedule

import jax
import jax.numpy as jnp
from flax import nnx
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax import device_put
import numpy as np

from models.peds import PEDS
from modules.params_utils import initialize_or_restore_params

from modules.params_utils import initialize_or_restore_params
from modules.training import train_model

from models.mlp import mlp
from models.cnn import cnn
from solvers.low_fidelity_solvers.lowfidsolver_class import lowfid


In [None]:
# --- Device mesh -------------------------------------------------------------
# A one-dimensional mesh named 'devices' over every device JAX reports.
# `data_sharding` partitions an array's leading axis along that mesh axis.
n_devices = len(jax.devices())
print(n_devices)

mesh = Mesh(np.array(jax.devices()), ('devices',))
data_sharding = NamedSharding(mesh, PartitionSpec('devices'))

# --- Data ingest (this is where active learning will plug in) ---------------
# NOTE(review): allow_pickle=True permits unpickling object arrays stored in the
# archive; only load files from a trusted source.
full_data = jnp.load(
    "data/highfidelity/high_fidelity_10012_20steps.npz",
    allow_pickle=True,
)

# Cast every field that is used downstream to float32.
pores = jnp.asarray(full_data['pores'], dtype=jnp.float32)
kappas = jnp.asarray(full_data['kappas'], dtype=jnp.float32)
base_conductivities = jnp.asarray(full_data['conductivity'], dtype=jnp.float32)

# --- Train / validation split ------------------------------------------------
# The first 8000 samples are used for training, everything after that for
# validation. Each dataset is ordered as [pores, conductivities, kappas].
dataset_train = [field[:8000] for field in (pores, base_conductivities, kappas)]
dataset_valid = [field[8000:] for field in (pores, base_conductivities, kappas)]

# Shard the dataset
def shard_dataset(dataset, n_device, sharding):
    """Reshape a dataset to a per-device layout and place it with ``sharding``.

    Args:
        dataset: Sequence ``(pores, conductivities, kappas)`` of arrays that
            share the same leading (sample) dimension.
        n_device: Number of devices to split the sample dimension across.
        sharding: Sharding passed to ``device_put`` for each reshaped array.

    Returns:
        Tuple ``(pores, conductivities, kappas)`` where every array has shape
        ``(n_device, n_samples // n_device, *original_trailing_dims)``.

    Raises:
        AssertionError: If the number of samples is not divisible by
            ``n_device``.
    """
    pores, conductivities, kappas = dataset

    # Validate first, using the argument rather than a notebook-level global,
    # so the function behaves the same wherever it is called from.
    n_samples = pores.shape[0]
    assert n_samples % n_device == 0, "Dataset size must be divisible by the number of devices"
    shard_size = n_samples // n_device

    def _to_device_layout(array):
        # (n_samples, ...) -> (n_device, shard_size, ...), then place on devices.
        per_device = array.reshape(n_device, shard_size, *array.shape[1:])
        return device_put(per_device, sharding)

    return (
        _to_device_layout(pores),
        _to_device_layout(conductivities),
        _to_device_layout(kappas),
    )


# Shard the training dataset across devices (dataset_valid is not sharded here).
# NOTE(review): this rebinds dataset_train to its sharded form, so re-running
# this cell without re-running the split above reshapes already-sharded data.
dataset_train = shard_dataset(dataset_train, n_devices, data_sharding)

8


In [3]:
def data_loader(*arrays, batch_size):
    """Yield mini-batches taken along axis 1 of device-sharded arrays.

    This definition replaces the ``data_loader`` imported from
    ``modules.training_utils`` at the top of the notebook.

    Args:
        *arrays: One or more arrays whose axis 1 is the per-device sample
            axis (axis 0 is left untouched in every batch).
        batch_size: Keyword-only, positive number of samples per batch along
            axis 1. The last batch is smaller when the sample count is not a
            multiple of ``batch_size``.

    Yields:
        A tuple with one slice ``array[:, start:start + batch_size]`` per
        input array, in the order the arrays were passed.

    Raises:
        ValueError: If no arrays are given or ``batch_size`` is not positive.
        AssertionError: If the arrays disagree on the size of axis 1.
    """
    if not arrays:
        raise ValueError("data_loader requires at least one array")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    # Take the reference size from the first array so a single-array call works.
    n_samples = arrays[0].shape[1]
    for array in arrays:
        assert array.shape[1] == n_samples, "All input arrays must have the same first dimension or second in the case of parallelized."

    # Contiguous batches: a plain slice gives the same values as indexing with
    # an arange-based index array, without building the index array.
    for start_idx in range(0, n_samples, batch_size):
        stop_idx = start_idx + batch_size
        yield tuple(array[:, start_idx:stop_idx] for array in arrays)

In [4]:
# --- Training hyperparameters -------------------------------------------------
epochs = 10
batch_size = 200

# With a "constant" schedule the min and max learning rates are set equal.
schedule = "constant"
learn_rate_min = 5e-5
learn_rate_max = 5e-5

# --- Generator network ---------------------------------------------------------
key = nnx.Rngs(42)
generator = mlp(
    input_size=25,
    hidden_sizes=[32, 64, 128],
    step_size=5,
    rngs=key,
)
# Alternative architecture: generator = cnn(rngs=key)

# Restore parameters from a checkpoint when one exists, otherwise keep the
# freshly initialised ones.
generator, checkpointer, ckpt_dir = initialize_or_restore_params(
    generator,
    model_name='peds_PI',
)

# --- Low-fidelity solver ------------------------------------------------------
lowfidsolver = lowfid(solver='gauss', iterations=1000)



No checkpoints found. Initializing new parameters.


In [5]:
def predict(generator, lowfidsolver, pores, conductivities):
    """Run the generator and feed its corrected conductivity to the solver.

    Args:
        generator: Callable model applied to ``pores`` (wrapped in ``nnx.jit``).
        lowfidsolver: Callable mapping a conductivity field to ``kappa``.
        pores: Input passed to the generator.
        conductivities: Baseline conductivity that the generator output is
            added to.

    Returns:
        Tuple ``(kappa, conductivity_res)``: the solver output and the raw
        generator output (before the baseline is added and clamped).
    """
    jitted_generator = nnx.jit(generator)
    conductivity_res = jitted_generator(pores)

    # Add the generator output to the baseline and floor the result at 1e-5
    # (presumably so the solver never sees a non-positive conductivity).
    clamped_conductivity = jnp.maximum(conductivity_res + conductivities, 1e-5)

    return lowfidsolver(clamped_conductivity), conductivity_res


In [6]:
from jax.tree_util import tree_map

def simplify_grad_structure(grads):
    """Strip the leading axis from every leaf of ``grads``.

    Each leaf is reduced to its first entry along axis 0, so a leaf of shape
    ``(k, *rest)`` becomes a leaf of shape ``rest``. The tree structure is
    preserved.
    """
    # Slice to length one along axis 0, then squeeze that axis away.
    return tree_map(lambda leaf: jnp.squeeze(leaf[:1], axis=0), grads)

In [None]:
# Build the learning-rate schedule from the hyperparameters configured above,
# then wrap the generator in an optimizer that applies Adam with that schedule.
lr_schedule = choose_schedule(schedule, learn_rate_min, learn_rate_max, epochs)
optimizer = nnx.Optimizer(generator, optax.adam(lr_schedule))

def train_step(pores, conductivities, kappas, batch_n, epoch): # sharded pores and kappas
    """Compute the summed squared error and its gradient for one batch.

    Uses the notebook-level ``generator`` and ``lowfidsolver``. ``batch_n`` and
    ``epoch`` are accepted but not used in the computation.

    Returns:
        Tuple ``(loss, grads)`` where ``loss`` is the sum of squared
        differences between predicted and target ``kappas`` and ``grads`` is
        the gradient of that loss with respect to the generator.
    """
    def loss_fn(model):
        # Only the solver output enters the loss; the raw generator output
        # returned alongside it is discarded.
        kappa_pred, _ = predict(model, lowfidsolver, pores, conductivities)
        return jnp.sum((kappa_pred - kappas) ** 2)

    return nnx.value_and_grad(loss_fn)(generator)


@partial(pmap, axis_name='devices', static_broadcasted_argnums=(3, 4))
def parallel_train_step(pores, conductivities, kappas, batch_n, epoch):
    """Run ``train_step`` on every device and sum the results across devices.

    ``pores``, ``conductivities`` and ``kappas`` are mapped over their leading
    axis; ``batch_n`` and ``epoch`` (positions 3 and 4) are static broadcast
    arguments.

    Returns:
        Tuple ``(loss_tot, grads_tot)``. Both are already summed over the
        'devices' axis, so every device's copy holds the same values.
    """
    # NOTE(review): a new (batch_n, epoch) pair is a new static value, which
    # makes pmap retrace. train_step reads the notebook-level `generator`, so
    # that retrace looks like what picks up updated parameters -- confirm
    # before making these arguments non-static.
    local_loss, local_grads = train_step(pores, conductivities, kappas, batch_n, epoch)

    summed_grads = jax.lax.psum(local_grads, axis_name='devices')
    summed_loss = jax.lax.psum(local_loss, axis_name='devices')

    return summed_loss, summed_grads

# Function to accumulate gradients
def accumulate_gradients(total_grads, new_grads):
    """Add ``new_grads`` onto a running gradient total, leaf by leaf.

    When no total exists yet (``total_grads`` is ``None``), ``new_grads`` is
    returned unchanged and becomes the starting total.
    """
    if total_grads is not None:
        return jax.tree_util.tree_map(
            lambda running, incoming: running + incoming,
            total_grads,
            new_grads,
        )
    return new_grads

print("Training...")

# Per-epoch history. Validation is currently disabled, so the two validation
# arrays are filled with the placeholder zeros assigned in the loop.
epoch_losses = np.zeros(epochs)
valid_losses = np.zeros(epochs)
valid_perc_losses = np.zeros(epochs)

# The sharded training arrays are laid out as (device, samples-per-device, ...),
# so the number of training samples is the product of the first two axes.
n_train_samples = dataset_train[0].shape[0] * dataset_train[0].shape[1]

for epoch in range(epochs):

    epoch_time = time.time()

    grads = None       # running sum of gradients over all batches of this epoch
    total_loss = 0.0   # running sum of squared errors over this epoch

    for en, batch in enumerate(data_loader(*dataset_train, batch_size=batch_size)):

        pores_sharded, conductivities_sharded, kappas_sharded = batch

        # Loss and gradients for this batch, computed on all devices at once.
        losses, new_grads = parallel_train_step(pores_sharded, conductivities_sharded, kappas_sharded, en, epoch)

        # Accumulate gradients across batches.
        grads = accumulate_gradients(grads, new_grads)

        # parallel_train_step already psum-reduces the loss, so every entry of
        # `losses` is the same cross-device total. Take one copy; summing the
        # entries would count the batch loss once per device.
        total_loss += losses[0]

    # Summed squared error per training sample.
    avg_loss = total_loss / n_train_samples

    # TODO: validation is disabled; re-enable once `valid` is available:
    # avg_val_loss, total_loss_perc = valid(dataset_valid, batch_size, generator, lowfidsolver)
    avg_val_loss, total_loss_perc = 0.0, 0.0

    # Record the history so it can be inspected or plotted after training.
    epoch_losses[epoch] = float(avg_loss)
    valid_losses[epoch] = avg_val_loss
    valid_perc_losses[epoch] = total_loss_perc

    # Print the average loss at the end of each epoch
    print(f"Epoch {epoch+1}/{epochs}, Training Loss: {avg_loss:.2f}, Validation Losses: [{avg_val_loss:.2f}, {total_loss_perc:.2f}%], Epoch time: {time.time() - epoch_time:.2f}s")

    # The accumulated gradients still carry the leading per-device axis, and
    # every device holds the same psum-reduced values: for each variable
    # (['layers'] -> ['kernel'] and ['bias']) keep the first copy and drop the
    # others.
    grads = simplify_grad_structure(grads)

    # One optimizer step per epoch, using the gradient summed over all batches.
    optimizer.update(grads)
    


Training...
