# Deep Learning HPC Demo

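This notebook walks through the HPC-ready deep learning framework: loading the training configuration, building the model, initializing the training state, running a sample forward pass, inspecting the JAX device topology, and starting Ray Serve.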

In [None]:
import sys
# Add the parent directory to the module search path so the `src.*` imports
# below can resolve.
# NOTE(review): '..' is resolved against the kernel's working directory —
# presumably the notebook is run from a folder directly under the repo root; confirm.
sys.path.append('..')

import jax
import jax.numpy as jnp
import numpy as np
from src.models.flax_cnn import create_model
from src.training.train_hpc import create_train_state
import ml_collections
import yaml

In [None]:
# Load the training configuration from YAML and wrap it in a ConfigDict so
# nested settings can be read with attribute access (e.g. config.model.num_classes).
# The encoding is pinned so parsing does not depend on the platform's locale.
with open('../config/train_config.yaml', 'r', encoding='utf-8') as f:
    config_dict = yaml.safe_load(f)

# safe_load returns None for an empty document (and a scalar/list for
# non-mapping documents); fail here with a clear message instead of with an
# AttributeError in a later cell.
if not isinstance(config_dict, dict):
    raise ValueError(
        "../config/train_config.yaml must contain a top-level mapping, "
        f"got {type(config_dict).__name__}"
    )

config = ml_collections.ConfigDict(config_dict)

In [None]:
# Fixed seed so the initialization is reproducible across runs.
rng = jax.random.PRNGKey(0)

# Build the network for the configured number of output classes.
model_cfg = config.model
model = create_model(num_classes=model_cfg.num_classes)

# Initialize the model variables from a single all-ones example: a leading
# batch dimension of 1 is prepended to the configured input shape.
init_batch = jnp.ones([1, *model_cfg.input_shape])
variables = model.init(rng, init_batch)
params = variables['params']

print(f"Model created with input shape: {model_cfg.input_shape}")
print(f"Number of classes: {model_cfg.num_classes}")

In [None]:
# Build the training state from the PRNG key and the full configuration.
# NOTE(review): `rng` is the same key already passed to `model.init` in the
# previous cell; if independent randomness is wanted here, derive a fresh key
# with `jax.random.split` — confirm whether the reuse is intentional.
state = create_train_state(rng, config)

# Echo the optimizer settings that were read from the configuration.
training_cfg = config.training
print(f"Training state initialized with optimizer: {training_cfg.optimizer}")
print(f"Learning rate: {training_cfg.learning_rate}")

In [None]:
# Run one forward pass on a dummy all-ones batch of size 1 to sanity-check
# the output shape. Only the 'params' collection is passed to `apply`.
sample_input = jnp.ones([1, *config.model.input_shape])
eval_variables = {'params': params}
output = model.apply(eval_variables, sample_input, train=False)

# Show the shape and the first five values of the first example's output.
print(f"Model output shape: {output.shape}")
print(f"Output sample: {output[0][:5]}")

## Distributed Training Setup

The following cell reports the devices and processes visible to JAX, which is the first thing to verify when setting up distributed training across multiple nodes.

In [None]:
# `multihost_utils` lives under `jax.experimental`; `jax.tools` has no such
# submodule, so importing it from there raises ModuleNotFoundError.
from jax.experimental import multihost_utils

# Report the topology JAX sees: the global device count, the devices attached
# to this process, the number of participating processes, and this process's
# index among them.
print(f"Number of devices: {jax.device_count()}")
print(f"Number of local devices: {jax.local_device_count()}")
print(f"Process count: {jax.process_count()}")
print(f"Process index: {jax.process_index()}")

## Model Serving with Ray

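The following cell initializes Ray and starts a Ray Serve instance, which is the prerequisite for deploying the model with Ray Serve; the deployment itself is not performed in this cell.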

In [None]:
import ray
from ray import serve

# Initialize Ray first; Serve is started on top of it in the next statement.
# NOTE(review): `ignore_reinit_error=True` presumably lets this cell be re-run
# without raising when Ray is already initialized — confirm against the Ray docs.
ray.init(ignore_reinit_error=True)
# NOTE(review): `detached=True` presumably keeps the Serve instance running
# after this notebook's session ends, so it may need an explicit shutdown — confirm.
serve.start(detached=True)

print("Ray cluster initialized for model serving")