# Performing Inference

In [None]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import NamedSharding

import flax

from rich import print
import numpy as np
import os

from torch_to_flax import torch_to_flax
from model_flax_sharded import get_partition_rules, Qwen2Config, Qwen2ForCausalLM

# Fraction of GPU memory JAX/XLA preallocates (the default is 75%).
# XLA reads this when JAX creates its GPU backend, which happens lazily on the
# first device query or array allocation -- not at `import jax` time, which
# already ran above. It must therefore be set before any JAX device call below
# (in particular before init_sharded_model()); re-running this cell once the
# backend is up has no effect.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9375"


def init_sharded_model():
    """Build the Qwen2 model and load its weights sharded across all devices.

    A template parameter tree is initialized on the CPU and used only as the
    restore target for ``flax_params.msgpack``. If that file is missing,
    ``torch_to_flax()`` is run first to produce it. Each restored leaf is then
    placed with ``jax.device_put`` on a ``NamedSharding`` over a 1-D ``"mp"``
    mesh spanning ``jax.devices()``.

    The leaf's ``PartitionSpec`` comes from the first ``(pattern, spec)`` rule
    in ``get_partition_rules()`` whose pattern is a substring of the leaf's
    slash-joined key path (dict keys / attribute names / indices joined by
    ``"/"``). Unmatched leaves, and rules whose spec is ``None``, are
    replicated.

    Returns:
        ``(model, sharded_params)`` where ``sharded_params`` is
        ``{"params": <restored tree>}``.

    Raises:
        FileNotFoundError: if the checkpoint is still missing after running
            the conversion.
    """
    # Local import so the file's top-level import block stays unchanged.
    from jax.sharding import PartitionSpec

    checkpoint_path = "flax_params.msgpack"

    config = Qwen2Config()
    model = Qwen2ForCausalLM(config=config)

    # Initialize on the CPU so the template tree never occupies GPU memory.
    # The dummy token ids only drive tracing; the resulting weights are
    # replaced by the checkpoint below, so plain int32 ids are sufficient.
    rng = jax.random.PRNGKey(0)
    input_shape = (1, 32)
    with jax.default_device(jax.devices("cpu")[0]):
        template = model.init(rng, jnp.ones(input_shape, dtype=jnp.int32))

    # 1-D mesh over every available device, with a single "mp" axis.
    mesh = Mesh(np.array(jax.devices()).reshape(-1), ("mp",))
    partition_rules = get_partition_rules()
    replicated = NamedSharding(mesh, PartitionSpec())

    def read_checkpoint(target):
        """Restore the msgpack checkpoint into the pytree structure of ``target``."""
        with open(checkpoint_path, "rb") as f:
            return flax.serialization.from_bytes(target, f.read())

    try:
        restored = read_checkpoint(template["params"])
    except FileNotFoundError:
        print("File not found. Running conversion...")
        torch_to_flax()
        restored = read_checkpoint(template["params"])
    loaded_params = {"params": restored}
    # The template was only needed as a restore target; free it before the
    # weights are copied to the devices.
    del template

    def key_name(entry):
        # str() of a key-path entry gives e.g. "['q_proj']", ".name" or "[0]";
        # unwrap the raw key so rule patterns can match plain "a/b/c" paths.
        for attr in ("key", "name", "idx"):
            if hasattr(entry, attr):
                return str(getattr(entry, attr))
        return str(entry)

    def place(path, leaf):
        path_str = "/".join(key_name(entry) for entry in path)
        for rule_path, spec in partition_rules:
            if rule_path in path_str:
                sharding = replicated if spec is None else NamedSharding(mesh, spec)
                return jax.device_put(leaf, sharding)
        return jax.device_put(leaf, replicated)

    # Derive shardings from the loaded tree itself so structure always matches.
    sharded_params = jax.tree_util.tree_map_with_path(place, loaded_params)
    return model, sharded_params

# Build the model and place its checkpoint weights on the device mesh.
(model, sharded_params) = init_sharded_model()
print("Sharded model initialized.")


In [2]:
from transformers import AutoTokenizer

model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Raw prompt string; no chat template is applied.
prompt = "What is 3 + 4? <think>\n"
# Request NumPy arrays directly so PyTorch is not needed just to convert the
# token ids before handing them to JAX.
inputs = tokenizer(prompt, return_tensors="np")
input_ids = jnp.array(inputs["input_ids"])

# Generate
print("Generating tokens...")
output = model.generate(
    sharded_params,
    input_ids,
    max_new_tokens=100,
    temperature=0.7,
    do_sample=True,
    prng_key=jax.random.PRNGKey(0),  # fixed PRNG seed for sampling
)

# Decode the first (and only) sequence of the single-prompt batch.
decoded = tokenizer.decode(np.array(output[0]))
print("Decoded text:", decoded)

100%|██████████| 100/100 [00:13<00:00,  7.25it/s]
