From a4c76aacd79f886ba0c8bef6790ceebd328cb0c0 Mon Sep 17 00:00:00 2001
From: Antoine Wehenkel
Date: Tue, 21 Feb 2023 12:00:14 +0100
Subject: [PATCH] Address #252 - Make soft ranking notebook self contained
 (#301)

* Test commit on personal fork.

* Test commit on personal fork.

* Update the soft sort notebook to make it self-contained. Remove the `examples` folder.

* pre-commit the soft sort notebook

* runnable on GPUs + Michaal comments on PR

* Clean the NB.

* Clean the NB.

---------

Co-authored-by: antoinewehenkel
---
 docs/tutorials/index.rst      |   1 -
 examples/fairness/config.py   |  38 ------
 examples/fairness/data.py     | 168 -----------------
 examples/fairness/losses.py   |  80 --------------
 examples/fairness/main.py     |  59 ----------
 examples/fairness/models.py   |  53 ---------
 examples/fairness/train.py    | 193 ---------------------------------
 examples/soft_error/config.py |  27 -----
 examples/soft_error/data.py   |  92 ----------------
 examples/soft_error/losses.py |  49 ---------
 examples/soft_error/main.py   |  59 ----------
 examples/soft_error/model.py  |  52 ---------
 examples/soft_error/train.py  | 198 ----------------------------------
 13 files changed, 1069 deletions(-)
 delete mode 100644 examples/fairness/config.py
 delete mode 100644 examples/fairness/data.py
 delete mode 100644 examples/fairness/losses.py
 delete mode 100644 examples/fairness/main.py
 delete mode 100644 examples/fairness/models.py
 delete mode 100644 examples/fairness/train.py
 delete mode 100644 examples/soft_error/config.py
 delete mode 100644 examples/soft_error/data.py
 delete mode 100644 examples/soft_error/losses.py
 delete mode 100644 examples/soft_error/main.py
 delete mode 100644 examples/soft_error/model.py
 delete mode 100644 examples/soft_error/train.py

diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst
index b0acdcef8..ef62e807c 100644
--- a/docs/tutorials/index.rst
+++ b/docs/tutorials/index.rst
@@ -33,7 +33,6 @@ Miscellaneous
    :maxdepth: 1
 
    notebooks/soft_sort
-   notebooks/fairness
    notebooks/application_biology
 
 Quadratic Optimal Transport
diff --git a/examples/fairness/config.py b/examples/fairness/config.py
deleted file mode 100644
index 3618eeb9b..000000000
--- a/examples/fairness/config.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright OTT-JAX
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Configuration to train a fairness aware classifier on the adult dataset.""" - -import ml_collections - - -def get_config(): - """Return a ConfigDict.""" - config = ml_collections.ConfigDict() - config.folder = '/tmp/adult_dataset/' - config.training_filename = 'adult.data' - config.test_filename = 'adult.test' - config.info_filename = 'adult.names' - config.protected = 'sex' - - config.batch_size = 256 - config.num_epochs = 20 - config.embed_dim = 16 - config.hidden_layers = (64, 64) - config.learning_rate = 1e-4 - - config.epsilon = 1e-3 - config.quantization = 16 - config.num_groups = 2 - config.fair_weight = 1.0 - return config diff --git a/examples/fairness/data.py b/examples/fairness/data.py deleted file mode 100644 index 72ca0e969..000000000 --- a/examples/fairness/data.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Loads the adult dataset data.""" - -import os - -import jax -import numpy as np -import pandas as pd - - -def load_df( - data_path: str, - info_path: str, - protected: str, - strip_target: bool = True, - **kwargs -): - """Load a pandas dataframe from two filenames.""" - with open(data_path) as fp: - df = pd.read_csv(fp, skipinitialspace=True, header=None, **kwargs) - - headers = [] - targets = [] - with open(info_path) as fp: - for line in fp: - if line.startswith('|') or not line.strip(): - continue - - parts = line.split(':') - if len(parts) > 1: - headers.append(parts[0]) - else: - pattern = '\n\t.' if strip_target else '\n\t' - targets = [x.strip() for x in line.strip(pattern).split(',')] - - # Finds the index of the column target. 
-  df2 = (df == targets[1]).any(axis=0)
-  target_idx = df2.index[df2][0]
-  if len(headers) < len(df.columns):
-    headers.insert(target_idx, 'target')
-  df.columns = headers
-  target = df.columns[target_idx]
-
-  # Change the target and protected columns to integers.
-  for col in [protected, target]:
-    vs = sorted(df[col].unique())
-    df[col] = df[col].map(vs.index)
-
-  return df
-
-
-def categoricals_to_onehots(df):
-  """Turn string features into one-hot vectors."""
-  categoricals = {
-      k: df[k].unique().tolist()
-      for k in df.columns
-      if not pd.api.types.is_numeric_dtype(df[k])
-  }
-
-  def onehots(row):
-    result = {}
-    for col in row.keys():
-      category = categoricals.get(col, None)
-      if category is not None:
-        result[col] = np.zeros(len(category))
-        result[col][category.index(row[col])] = 1.0
-    return pd.Series(result)
-
-  return df.apply(onehots, axis=1)
-
-
-def whiten(df, reference_df=None, target='target'):
-  """Make the numerical data have zero mean and unit variance."""
-  df_ref = df if reference_df is None else reference_df
-  cols = [
-      k for k in df.columns
-      if pd.api.types.is_numeric_dtype(df[k]) and k != target
-  ]
-  df_num = df[cols].astype(np.float32)
-  df_ref = df_ref[cols].astype(np.float32)
-  return (df_num - df_ref.mean()) / df_ref.std()
-
-
-def get_dims(data):
-  """Given a record array, extract the dimensions of each column."""
-  x, _ = data
-  dims = [x[0][name].shape for name in x.dtype.names]
-  return [1 if not d else d[0] for d in dims]
-
-
-def load_train_test(config):
-  """Load the training data, the test data and the dimensions of the input."""
-  train_path = os.path.join(config.folder, config.training_filename)
-  test_path = os.path.join(config.folder, config.test_filename)
-  info_path = os.path.join(config.folder, config.info_filename)
-
-  train_df = load_df(train_path, info_path, config.protected, strip_target=True)
-  test_df = load_df(
-      test_path, info_path, config.protected, strip_target=False, skiprows=1
-  )
-
-  result = []
-  for df, ref_df in zip((train_df, test_df), (None, train_df)):
-    target_df = df['target']
-    protected_df = df[config.protected]
-    protected_df.name = 'protected'
-    num_df = whiten(df, reference_df=ref_df, target='target')
-    cat_df = categoricals_to_onehots(df)
-    x = pd.concat([num_df, cat_df], axis=1).to_records(index=False)
-    y_true = pd.concat([protected_df, target_df],
-                       axis=1).to_records(index=False)
-    result.append((x, y_true))
-
-  dims = [x[0][name].shape for name in result[0][0].dtype.names]
-  dims = [1 if not d else d[0] for d in dims]
-  return tuple(result) + (dims,)
-
-
-def flatten(record):
-  """Turn the record array into a flat numpy array."""
-  result = [np.stack(record[name]) for name in record.dtype.names]
-  result = [e[:, np.newaxis] if len(e.shape) == 1 else e for e in result]
-  return np.concatenate(result, axis=1)
-
-
-def prepare_batch_for_pmap(batch):
-  """Prepare the batch with the proper shapes for multi-device setups."""
-  local_device_count = jax.local_device_count()
-
-  def _prepare(x):
-    return x.reshape((local_device_count, -1) + x.shape[1:])
-
-  return jax.tree_map(_prepare, batch)
-
-
-def generate(data, batch_size: int = 256, num_epochs: int = 1):
-  """Generate batches of examples, shuffling after each 'epoch'."""
-  x, y_true = data
-  size = x.shape[0]
-  round_num_examples = (size // batch_size) * batch_size
-  num_epochs = round_num_examples if num_epochs is None else num_epochs
-
-  count = 0
-  while count < num_epochs:
-    order = np.arange(size)
-    np.random.shuffle(order)
-    x = x[order]
-    y_true = y_true[order]
-    for i in range(0, round_num_examples, batch_size):
-      end = i + batch_size
-      yield prepare_batch_for_pmap({
-          'features': flatten(x[i:end]),
-          'label': y_true[i:end].target,
-          'protected': y_true[i:end].protected,
-      })
-    count += 1
diff --git a/examples/fairness/losses.py b/examples/fairness/losses.py
deleted file mode 100644
index 63487a9ec..000000000
--- a/examples/fairness/losses.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# Copyright OTT-JAX
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Losses for the fairness example."""
-
-import functools
-
-import jax
-import jax.numpy as jnp
-import ott
-
-
-def binary_cross_entropy(logits, labels):
-  return jnp.sum(-labels * jnp.log(logits) - (1 - labels) * jnp.log(1 - logits))
-
-
-def compute_metrics(logits, labels):
-  loss = binary_cross_entropy(logits, labels)
-  accuracy = jnp.mean((logits > 0.5) == labels)
-  metrics = {
-      'loss': loss,
-      'accuracy': accuracy,
-  }
-  metrics = jax.lax.pmean(metrics, axis_name='batch')
-  return metrics
-
-
-@functools.partial(jax.jit, static_argnums=(2, 3))
-def sort_group(
-    inputs: jnp.ndarray, in_group: jnp.ndarray, quantization: int,
-    epsilon: float
-) -> jnp.ndarray:
-  """Sorts and quantizes only the members of the given group.
-
-  Args:
-    inputs: 1D array to be sorted.
-    in_group: a 1D array of 0s and 1s indicating if the element is part of the
-      group or not.
-    quantization: the number of values the sorted output should be mapped
-      onto.
-    epsilon: Sinkhorn entropic regularization.
-
-  Returns:
-    A sorted array of size `quantization`.
-  """
-  a = in_group / jnp.sum(in_group)
-  b = jnp.ones(quantization) / quantization
-  ot = ott.tools.soft_sort.transport_for_sort(inputs, a, b, epsilon=epsilon)
-  return 1.0 / b * ot.apply(inputs, axis=0)
-
-
-def fairness_regularizer(
-    inputs: jnp.ndarray,
-    groups: jnp.ndarray,
-    quantization: int = 16,
-    epsilon: float = 1e-2,
-    num_groups: int = 2
-):
-  """Approximation of the Wasserstein distance between per-group distributions."""
-  quantiles = jnp.stack([
-      sort_group(inputs, groups == g, quantization, epsilon)
-      for g in range(num_groups)
-  ])
-  weights = jnp.stack([jnp.sum(groups == g) for g in range(num_groups)]
-                     ) / groups.shape[0]  # noqa: E124
-  mean_quantile = jnp.sum(weights[:, None] * quantiles, axis=0)
-  delta = jnp.where(
-      quantiles, quantiles - mean_quantile, jnp.zeros_like(mean_quantile)
-  )
-  return jnp.mean(delta ** 2)
diff --git a/examples/fairness/main.py b/examples/fairness/main.py
deleted file mode 100644
index c4eafb1e4..000000000
--- a/examples/fairness/main.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright OTT-JAX
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Runs the training of the fairness-aware network on the adult dataset."""
-
-from typing import Sequence
-
-import jax
-from absl import app, flags, logging
-from clu import platform
-from ml_collections import config_flags
-from ott.examples.fairness import train
-
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string(
-    'workdir', '/tmp/soft_error/', 'Directory to store model data.'
-)
-config_flags.DEFINE_config_file(
-    'config',
-    None,
-    'File path to the training hyperparameter configuration.',
-    lock_config=True
-)
-flags.DEFINE_integer('seed', 0, 'Random seed')
-FLAGS = flags.FLAGS
-
-
-def main(argv: Sequence[str]) -> None:
-  if len(argv) > 1:
-    raise app.UsageError('Too many command-line arguments.')
-
-  logging.info('JAX host: %d / %d', jax.host_id(), jax.host_count())
-  logging.info('JAX local devices: %r', jax.local_devices())
-
-  # Add a note so that we can tell which task is which JAX host.
-  # (Depending on the platform task 0 is not guaranteed to be host 0)
-  platform.work_unit().set_task_status(
-      f'host_id: {jax.host_id()}, host_count: {jax.host_count()}'
-  )
-  platform.work_unit().create_artifact(
-      platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
-  )
-
-  train.train_and_evaluate(FLAGS.workdir, FLAGS.config, FLAGS.seed)
-
-
-if __name__ == '__main__':
-  app.run(main)
diff --git a/examples/fairness/models.py b/examples/fairness/models.py
deleted file mode 100644
index 3a62098da..000000000
--- a/examples/fairness/models.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright OTT-JAX
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""A model for to embed structured features.""" - -from typing import Any, Tuple - -import flax.linen as nn -import jax.numpy as jnp - - -class FeaturesEncoder(nn.Module): - """Encodes structured features.""" - - input_dims: Tuple[int] - embed_dim: int = 32 - - @nn.compact - def __call__(self, x): - result = [] - index = 0 - for d in self.input_dims: - arr = x[..., index:index + d] - result.append(arr if d == 1 else nn.Dense(self.embed_dim)(arr)) - index += d - return jnp.concatenate(result, axis=-1) - - -class AdultModel(nn.Module): - """A model to predict if the income is above 50k (adult dataset).""" - - encoder_cls: Any - hidden: Tuple[int] = (64, 64) - - @nn.compact - def __call__(self, x, train: bool = True): - x = self.encoder_cls()(x) - for h in self.hidden: - x = nn.Dense(h)(x) - x = nn.relu(x) - x = nn.Dense(1)(x) - x = nn.sigmoid(x) - return x[..., 0] diff --git a/examples/fairness/train.py b/examples/fairness/train.py deleted file mode 100644 index 50146ad96..000000000 --- a/examples/fairness/train.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Training a network on the adult dataset with fairnes constraints.""" - -import collections -import functools -from typing import Any - -import flax -import jax -import jax.numpy as jnp -import ml_collections -from flax import jax_utils -from flax.metrics import tensorboard -from flax.training import checkpoints, common_utils -from ott.examples.fairness import data, losses, models - - -@flax.struct.dataclass -class TrainState: - step: int - optimizer: flax.optim.Optimizer - model_state: Any - - -def initialized(key, model, size): - """Initialize the model.""" - - @jax.jit - def init(*args): - return model.init(*args) - - variables = init({'params': key}, jnp.ones((1, size))) - model_state, params = variables.pop('params') - return params, model_state - - -def create_train_state(rng, config, model, size): - """Create initial training state.""" - params, model_state = initialized(rng, model, size) - optimizer = flax.optim.Adam(learning_rate=config.learning_rate).create(params) - state = TrainState(step=0, optimizer=optimizer, model_state=model_state) - return state - - -def train_step(apply_fn, config, state, batch): - """Perform a single training step.""" - regularizer = functools.partial( - losses.fairness_regularizer, - quantization=config.quantization, - num_groups=config.num_groups, - epsilon=config.epsilon - ) - - def compute_loss(params): - variables = {'params': params, **state.model_state} - logits = apply_fn(variables, batch['features'], train=True) - loss = losses.binary_cross_entropy(logits, batch['label']) - reg = ( - regularizer(logits, batch['protected']) if config.fair_weight > 0 else 0 - ) - return loss + config.fair_weight * reg, logits - - grad_fn = jax.value_and_grad(compute_loss, has_aux=True) - aux, grad = grad_fn(state.optimizer.target) - # Re-use same axis_name as in the call to `pmap(...train_step...)` below. 
-  grad = jax.lax.pmean(grad, axis_name='batch')
-  logits = aux[1]
-  new_optimizer = state.optimizer.apply_gradient(grad)
-  metrics = losses.compute_metrics(logits, batch['label'])
-  new_state = state.replace(step=state.step + 1, optimizer=new_optimizer)
-  return new_state, metrics
-
-
-def eval_step(apply_fn, state, batch):
-  params = state.optimizer.target
-  variables = {'params': params, **state.model_state}
-  logits = apply_fn(variables, batch['features'], train=False, mutable=False)
-  return losses.compute_metrics(logits, batch['label'])
-
-
-def log(results, epoch, summary, train=True, summary_writer=None):
-  """Log the metrics to stderr and tensorboard."""
-  if jax.host_id() != 0:
-    return
-
-  phase = 'train' if train else 'eval'
-  for key in ('loss', 'accuracy'):
-    results[f'{phase}_{key}'].append((epoch + 1, summary[key]))
-  print(
-      '{} epoch: {}, loss: {:.3f}, accuracy: {:.2%}'.format(
-          phase, epoch + 1, summary['loss'], summary['accuracy']
-      )
-  )
-
-  if summary_writer is None:
-    return
-
-  for key, val in summary.items():
-    summary_writer.scalar(f'{phase}_{key}', val, epoch)
-
-
-def restore_checkpoint(state, workdir):
-  return checkpoints.restore_checkpoint(workdir, state)
-
-
-def save_checkpoint(state, workdir):
-  if jax.host_id() != 0:
-    return
-  # Gets train state from the first replica.
-  state = jax.device_get(jax.tree_map(lambda x: x[0], state))
-  step = int(state.step)
-  checkpoints.save_checkpoint(workdir, state, step, keep=3)
-
-
-def train_and_evaluate(
-    workdir: str, config: ml_collections.ConfigDict, seed: int = 0
-):
-  """Execute model training and evaluation loop."""
-  rng = jax.random.PRNGKey(seed)
-
-  if config.batch_size % jax.device_count() > 0:
-    raise ValueError('Batch size must be divisible by the number of devices')
-
-  if jax.host_id() == 0:
-    summary_writer = tensorboard.SummaryWriter(workdir)
-    summary_writer.hparams(dict(config))
-
-  local_batch_size = config.batch_size // jax.host_count()
-  train_ds, test_ds, dims = data.load_train_test(config)
-  train_iter = data.generate(
-      train_ds, batch_size=local_batch_size, num_epochs=config.num_epochs
-  )
-  train_iter = jax_utils.prefetch_to_device(train_iter, 8)
-
-  model = models.AdultModel(
-      encoder_cls=functools.partial(
-          models.FeaturesEncoder, input_dims=dims, embed_dim=config.embed_dim
-      ),
-      hidden=config.hidden_layers
-  )
-
-  state = create_train_state(rng, config, model, sum(dims))
-  state = restore_checkpoint(state, workdir)
-  step_offset = int(state.step)
-  state = jax_utils.replicate(state)
-
-  p_train_step = jax.pmap(
-      functools.partial(train_step, model.apply, config), axis_name='batch'
-  )
-  p_eval_step = jax.pmap(
-      functools.partial(eval_step, model.apply), axis_name='batch'
-  )
-
-  steps_per_epoch = train_ds[0].shape[0] // config.batch_size
-  num_steps = steps_per_epoch * config.num_epochs
-
-  results = collections.defaultdict(list)
-  epoch_metrics = []
-  for step, batch in zip(range(step_offset, num_steps), train_iter):
-    state, metrics = p_train_step(state=state, batch=batch)
-    epoch_metrics.append(metrics)
-
-    if (step + 1) % steps_per_epoch == 0:
-      epoch = step // steps_per_epoch
-      epoch_metrics = common_utils.get_metrics(epoch_metrics)
-      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
-      log(results, epoch, summary, train=True, summary_writer=summary_writer)
-
-      epoch_metrics = []
-      eval_metrics = []
-      for eval_batch in data.generate(test_ds, batch_size=local_batch_size):
-        metrics = p_eval_step(state, eval_batch)
-        eval_metrics.append(metrics)
-      eval_metrics = common_utils.get_metrics(eval_metrics)
-      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
-      log(results, epoch, summary, train=False, summary_writer=summary_writer)
-      save_checkpoint(state, workdir)
-
-  # Wait until computations are done before exiting
-  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
-  return results, state
diff --git a/examples/soft_error/config.py b/examples/soft_error/config.py
deleted file mode 100644
index 25700a8b8..000000000
--- a/examples/soft_error/config.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright OTT-JAX
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Default hyperparameter configuration."""
-
-import ml_collections
-
-
-def get_config():
-  """Get the default hyperparameter configuration."""
-  config = ml_collections.ConfigDict()
-  config.dataset = 'cifar10'
-  config.learning_rate = 1e-4
-  config.batch_size = 64
-  config.num_epochs = 100
-  config.loss = 'soft_error'
-  return config
diff --git a/examples/soft_error/data.py b/examples/soft_error/data.py
deleted file mode 100644
index f45f12431..000000000
--- a/examples/soft_error/data.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# Copyright OTT-JAX
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Data loading and data augmentation."""
-
-import jax
-import tensorflow as tf
-import tensorflow_datasets as tfds
-from flax import jax_utils
-
-
-def random_shift(img, ratio=0.1):
-  """Apply a random shift on an input image."""
-  height, _ = img.shape[:2]  # Assumes square images.
-  size = tf.random.uniform(
-      shape=(2,),
-      minval=int((1 - ratio) * height),
-      maxval=height,
-      dtype=tf.int32
-  )
-  size = tf.concat((size, [3]), axis=0)
-  img = tf.image.random_crop(img, size)
-
-  deltas = tf.constant([32, 32, 3]) - size
-  for _ in tf.range(deltas[0]):
-    img = tf.pad(
-        img, [tf.random.shuffle([1, 0]), [0, 0], [0, 0]], mode='SYMMETRIC'
-    )
-  for _ in tf.range(deltas[1]):
-    img = tf.pad(
-        img, [[0, 0], tf.random.shuffle([1, 0]), [0, 0]], mode='SYMMETRIC'
-    )
-  return img
-
-
-def prepare_tf_data(xs):
-  """Convert an input batch from TF tensors to NumPy arrays."""
-  local_device_count = jax.local_device_count()
-
-  def _prepare(x):
-    # Use _numpy() for zero-copy conversion between TF and NumPy.
-    x = x._numpy()  # pylint: disable=protected-access
-    return x.reshape((local_device_count, -1) + x.shape[1:])
-
-  return jax.tree_map(_prepare, xs)
-
-
-def create_input_iter(dataset_builder, batch_size: int, train: bool):
-  """Create an iterator over the training / test set."""
-  split = tfds.Split.TRAIN if train else tfds.Split.TEST
-  ds = dataset_builder.as_dataset(split=split)
-  if train:
-    ds = ds.repeat()
-    ds = ds.shuffle(16 * batch_size, seed=0)
-
-  def augment(inputs):
-    im = inputs['image']
-    im = tf.image.random_flip_left_right(im)
-    im = random_shift(im, ratio=0.1)
-    inputs['image'] = im
-    return inputs
-
-  def prepare(inputs):
-    im = inputs['image']
-    inputs['image'] = tf.cast(im, tf.float32) / 255.0
-    inputs['label'] = tf.one_hot(inputs['label'], 10)
-    inputs.pop('id')
-    return inputs
-
-  ds = ds.map(prepare, num_parallel_calls=tf.data.AUTOTUNE)
-  ds = ds.cache()
-  if train:
-    ds = ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
-  ds = ds.batch(batch_size, drop_remainder=True)
-
-  if not train:
-    ds = ds.repeat()
-
-  ds = ds.prefetch(tf.data.AUTOTUNE)
-  it = map(prepare_tf_data, ds)
-  it = jax_utils.prefetch_to_device(it, 8)
-  return it
diff --git a/examples/soft_error/losses.py b/examples/soft_error/losses.py
deleted file mode 100644
index 2f1052be4..000000000
--- a/examples/soft_error/losses.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright OTT-JAX
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Defines classification losses."""
-
-import functools
-
-import flax.linen as nn
-import jax
-import jax.numpy as jnp
-from ott.tools import soft_sort
-
-
-def cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray):
-  logits = nn.log_softmax(logits)
-  return -jnp.sum(labels * logits) / labels.shape[0]
-
-
-def soft_error_loss(
-    logits: jnp.ndarray, labels: jnp.ndarray, epsilon: float = 1e-2
-):
-  """Average distance between the top rank and the rank of the true class."""
-  ranks_fn = functools.partial(soft_sort.ranks, axis=-1, epsilon=epsilon)
-  ranks_fn = jax.jit(ranks_fn)
-  soft_ranks = ranks_fn(logits)
-  return jnp.mean(
-      nn.relu(labels.shape[-1] - 1 - jnp.sum(labels * soft_ranks, axis=1))
-  )
-
-
-def get(name: str = 'cross_entropy'):
-  """Return the loss function corresponding to the input name."""
-  losses = {'soft_error': soft_error_loss, 'cross_entropy': cross_entropy_loss}
-  result = losses.get(name, None)
-  if result is None:
-    raise ValueError(
-        f'Unknown loss {name}. Possible values: {",".join(losses)}'
-    )
-  return result
diff --git a/examples/soft_error/main.py b/examples/soft_error/main.py
deleted file mode 100644
index 596fe18ee..000000000
--- a/examples/soft_error/main.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright OTT-JAX
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Runs the training of the network on CIFAR10."""
-
-from typing import Sequence
-
-import jax
-from absl import app, flags, logging
-from clu import platform
-from ml_collections import config_flags
-from ott.examples.soft_error import train
-
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string(
-    'workdir', '/tmp/soft_error/', 'Directory to store model data.'
-)
-config_flags.DEFINE_config_file(
-    'config',
-    None,
-    'File path to the training hyperparameter configuration.',
-    lock_config=True
-)
-flags.DEFINE_integer('seed', 0, 'Random seed')
-FLAGS = flags.FLAGS
-
-
-def main(argv: Sequence[str]) -> None:
-  if len(argv) > 1:
-    raise app.UsageError('Too many command-line arguments.')
-
-  logging.info('JAX host: %d / %d', jax.host_id(), jax.host_count())
-  logging.info('JAX local devices: %r', jax.local_devices())
-
-  # Add a note so that we can tell which task is which JAX host.
-  # (Depending on the platform task 0 is not guaranteed to be host 0)
-  platform.work_unit().set_task_status(
-      f'host_id: {jax.host_id()}, host_count: {jax.host_count()}'
-  )
-  platform.work_unit().create_artifact(
-      platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
-  )
-
-  train.train_and_evaluate(FLAGS.workdir, FLAGS.config, FLAGS.seed)
-
-
-if __name__ == '__main__':
-  app.run(main)
diff --git a/examples/soft_error/model.py b/examples/soft_error/model.py
deleted file mode 100644
index 97c685e72..000000000
--- a/examples/soft_error/model.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# Copyright OTT-JAX
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Flax CNN model.""" - -from typing import Any - -import flax.linen as nn -import jax.numpy as jnp - - -class ConvBlock(nn.Module): - """A simple CNN blockl.""" - - features: int = 32 - dtype: Any = jnp.float32 - - @nn.compact - def __call__(self, x, train: bool = True): - x = nn.Conv(features=self.features, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.Conv(features=self.features, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) - return x - - -class CNN(nn.Module): - """A simple CNN model.""" - - num_classes: int = 10 - dtype: Any = jnp.float32 - - @nn.compact - def __call__(self, x, train: bool = True): - x = ConvBlock(features=32)(x) - x = ConvBlock(features=64)(x) - x = x.reshape((x.shape[0], -1)) # flatten - x = nn.Dense(features=512)(x) - x = nn.relu(x) - x = nn.Dense(features=self.num_classes)(x) - return x diff --git a/examples/soft_error/train.py b/examples/soft_error/train.py deleted file mode 100644 index 1a2668af5..000000000 --- a/examples/soft_error/train.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Train for the soft-error loss.""" - -import collections -import functools -from typing import Any - -import flax -import jax -import jax.numpy as jnp -import ml_collections -import tensorflow_datasets as tfds -from flax import jax_utils -from flax.metrics import tensorboard -from flax.training import checkpoints, common_utils -from ott.examples.soft_error import data, losses -from ott.examples.soft_error import model as model_lib - - -def initialized(key, height, width, model): - """Initialize the model parameters.""" - input_shape = (1, height, width, 3) - - @jax.jit - def init(*args): - return model.init(*args) - - variables = init({'params': key}, jnp.ones(input_shape, jnp.float32)) - model_state, params = variables.pop('params') - return params, model_state - - -def compute_metrics(logits, labels, loss_fn): - loss = loss_fn(logits, labels) - accuracy = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(labels, -1)) - metrics = { - 'loss': loss, - 'accuracy': accuracy, - } - metrics = jax.lax.pmean(metrics, axis_name='batch') - return metrics - - -def train_step(apply_fn, loss_fn, state, batch): - """Perform a single training step.""" - - def compute_loss(params): - variables = {'params': params, **state.model_state} - logits = apply_fn(variables, batch['image']) - loss = loss_fn(logits, batch['label']) - return loss, logits - - grad_fn = jax.value_and_grad(compute_loss, has_aux=True) - aux, grad = grad_fn(state.optimizer.target) - # Re-use same axis_name as in the call to `pmap(...train_step...)` below. 
-  grad = jax.lax.pmean(grad, axis_name='batch')
-  logits = aux[1]
-  new_optimizer = state.optimizer.apply_gradient(grad)
-  metrics = compute_metrics(logits, batch['label'], loss_fn=loss_fn)
-  new_state = state.replace(step=state.step + 1, optimizer=new_optimizer)
-  return new_state, metrics
-
-
-def eval_step(apply_fn, loss_fn, state, batch):
-  params = state.optimizer.target
-  variables = {'params': params, **state.model_state}
-  logits = apply_fn(variables, batch['image'], train=False, mutable=False)
-  return compute_metrics(logits, batch['label'], loss_fn=loss_fn)
-
-
-@flax.struct.dataclass
-class TrainState:
-  step: int
-  optimizer: flax.optim.Optimizer
-  model_state: Any
-
-
-def create_train_state(rng, config, model, height, width):
-  """Create initial training state."""
-  params, model_state = initialized(rng, height, width, model)
-  optimizer = flax.optim.Adam(learning_rate=config.learning_rate).create(params)
-  state = TrainState(step=0, optimizer=optimizer, model_state=model_state)
-  return state
-
-
-def log(results, epoch, summary, train=True, summary_writer=None):
-  """Log the metrics to stderr and tensorboard."""
-  if jax.host_id() != 0:
-    return
-
-  phase = 'train' if train else 'eval'
-  for key in ('loss', 'accuracy'):
-    results[f'{phase}_{key}'].append((epoch + 1, summary[key]))
-  print(
-      '{} epoch: {}, loss: {:.3f}, accuracy: {:.2%}'.format(
-          phase, epoch + 1, summary['loss'], summary['accuracy']
-      )
-  )
-
-  for key, val in summary.items():
-    summary_writer.scalar(f'{phase}_{key}', val, epoch)
-
-
-def restore_checkpoint(state, workdir):
-  return checkpoints.restore_checkpoint(workdir, state)
-
-
-def save_checkpoint(state, workdir):
-  if jax.host_id() == 0:
-    # get train state from the first replica
-    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
-    step = int(state.step)
-    checkpoints.save_checkpoint(workdir, state, step, keep=3)
-
-
-def train_and_evaluate(
-    workdir: str, config: ml_collections.ConfigDict, seed: int = 0
-):
-  """Execute model training and evaluation loop."""
-  rng = jax.random.PRNGKey(seed)
-
-  if config.batch_size % jax.device_count() > 0:
-    raise ValueError('Batch size must be divisible by the number of devices')
-
-  if jax.host_id() == 0:
-    summary_writer = tensorboard.SummaryWriter(workdir)
-    summary_writer.hparams(dict(config))
-
-  loss_fn = losses.get(config.loss)
-  local_batch_size = config.batch_size // jax.host_count()
-  dataset_builder = tfds.builder(config.dataset)
-  info = dataset_builder.info
-  height, width = info.features['image'].shape[:2]
-  train_iter = data.create_input_iter(
-      dataset_builder, local_batch_size, train=True
-  )
-  eval_iter = data.create_input_iter(
-      dataset_builder, local_batch_size, train=False
-  )
-  steps_per_epoch = info.splits['train'].num_examples // config.batch_size
-  num_steps = int(steps_per_epoch * config.num_epochs)
-  num_validation_examples = info.splits['test'].num_examples
-  steps_per_eval = num_validation_examples // config.batch_size
-
-  num_classes = info.features['label'].num_classes
-  model = model_lib.CNN(num_classes=num_classes, dtype=jnp.float32)
-  state = create_train_state(rng, config, model, height, width)
-  state = restore_checkpoint(state, workdir)
-  # step_offset > 0 if restarting from checkpoint
-  step_offset = int(state.step)
-  state = jax_utils.replicate(state)
-
-  p_train_step = jax.pmap(
-      functools.partial(train_step, model.apply, loss_fn), axis_name='batch'
-  )
-  p_eval_step = jax.pmap(
-      functools.partial(eval_step, model.apply, loss_fn), axis_name='batch'
-  )
-
-  results = collections.defaultdict(list)
-  epoch_metrics = []
-  for step, batch in zip(range(step_offset, num_steps), train_iter):
-    state, metrics = p_train_step(state=state, batch=batch)
-    epoch_metrics.append(metrics)
-
-    if (step + 1) % steps_per_epoch == 0:
-      epoch = step // steps_per_epoch
-      epoch_metrics = common_utils.get_metrics(epoch_metrics)
-      summary = jax.tree_map(lambda x: x.mean(), epoch_metrics)
-      log(results, epoch, summary, train=True, summary_writer=summary_writer)
-
-      epoch_metrics = []
-      eval_metrics = []
-      for _ in range(steps_per_eval):
-        eval_batch = next(eval_iter)
-        metrics = p_eval_step(state, eval_batch)
-        eval_metrics.append(metrics)
-      eval_metrics = common_utils.get_metrics(eval_metrics)
-      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
-      log(results, epoch, summary, train=False, summary_writer=summary_writer)
-
-      save_checkpoint(state, workdir)
-
-  # Wait until computations are done before exiting
-  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
-  return results, state
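
For reference, the soft-sorting primitive both removed examples were built on is unchanged in the library and is what the now self-contained soft_sort notebook walks through. Below is a minimal sketch of the logic of the deleted `soft_error_loss`, assuming only that `ott-jax` is installed and using `ott.tools.soft_sort.ranks`, the same helper the removed code called; the array values are purely illustrative:

    import jax.numpy as jnp
    from ott.tools import soft_sort

    # Class scores for a batch of two examples, and one-hot labels.
    logits = jnp.array([[0.1, 2.0, -0.5, 1.2], [1.5, -0.2, 0.3, 0.0]])
    labels = jnp.array([[0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]])

    # Differentiable ranks along the class axis; `epsilon` is the entropic
    # regularization of the underlying optimal transport problem.
    soft_ranks = soft_sort.ranks(logits, axis=-1, epsilon=1e-2)

    # Soft error, as in the deleted loss: penalize the gap between the best
    # achievable rank (n - 1) and the soft rank of the true class.
    n = labels.shape[-1]
    loss = jnp.mean(jnp.maximum(n - 1 - jnp.sum(labels * soft_ranks, axis=1), 0.0))

The deleted fairness regularizer relied on the same mechanism, calling `ott.tools.soft_sort.transport_for_sort` with weights restricted to one group at a time in order to compare per-group quantile profiles.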