# Training CNN on Custom Dataset

Oumi is not limited to LLMs. This example shows how to train a simple ConvNet classifier on a custom dataset containing binary data in a NumPy `.npz` file. The dataset is created from the classic MNIST dataset (handwritten digit classification).

# Prerequisites
## Oumi Installation
First, let's install Oumi. You can find detailed instructions [here](https://github.com/oumi-ai/oumi/blob/main/README.md), but it should be as simple as:

```bash
pip install -e ".[gpu]"  # if you have an NVIDIA or AMD GPU
# OR
pip install -e "."  # if you don't have a GPU
```

## Environment Setup: Common Imports and Variables

In [1]:
import os
from pathlib import Path

import numpy as np
import torchvision

tutorial_dir = "cnn_mnist_example"

Path(tutorial_dir).mkdir(parents=True, exist_ok=True)
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Disable warnings from HF

# Data
## Data Preparation
First, let's convert the MNIST dataset to an `.npz` archive.

In [2]:
images = []
labels = []
splits = []
for train_split in (False, True):
    mnist_dataset = torchvision.datasets.MNIST(
        root=Path("/tmp/mnist_data"),
        train=train_split,
        download=True,
    )
    num_examples = len(mnist_dataset)
    images.extend(
        [np.asarray(mnist_dataset.data[i], dtype=np.uint8) for i in range(num_examples)]
    )
    labels.extend([int(mnist_dataset.targets[i]) for i in range(num_examples)])
    splits.extend([("train" if train_split else "test")] * num_examples)

npz_filename = (Path(tutorial_dir) / "mnist.npz").absolute()

# Normalize and convert [N,H,W] to [N,C,H,W] by adding dummy C=1 (PyTorch convention).
images = np.expand_dims((np.stack(images).astype(dtype=np.float32) / 255.0), axis=1)
np.savez_compressed(
    npz_filename, images=images, labels=np.stack(labels), split=np.stack(splits)
)
print(f"Saved {len(labels)} examples to '{npz_filename}'!")

Saved 70000 examples to '/home/user/oumi/notebooks/cnn_mnist_example/mnist.npz'!
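
Before wiring the archive into a dataset class, it can help to confirm what `np.savez_compressed` actually stores. The sketch below repeats the same normalize-and-add-channel step on a tiny synthetic array and reads the archive back with `np.load`. The random pixel values stand in for MNIST, and the file name `tiny.npz` and all variable names are hypothetical.

```python
import tempfile
from pathlib import Path

import numpy as np

# Synthetic stand-in for MNIST: 6 grayscale "images" of 28x28 pixels.
rng = np.random.default_rng(seed=0)
raw_images = rng.integers(0, 256, size=(6, 28, 28), dtype=np.uint8)
raw_labels = np.array([3, 1, 4, 1, 5, 9])
raw_splits = np.array(["test", "test", "train", "train", "train", "train"])

# Same transformation as in the cell above: scale to [0, 1], add a channel axis.
tiny_images = np.expand_dims(raw_images.astype(np.float32) / 255.0, axis=1)

with tempfile.TemporaryDirectory() as tmp_dir:
    tiny_path = Path(tmp_dir) / "tiny.npz"  # hypothetical file name
    np.savez_compressed(
        tiny_path, images=tiny_images, labels=raw_labels, split=raw_splits
    )

    # Read every array back while the temporary file still exists.
    with np.load(tiny_path) as archive:
        loaded = {name: archive[name] for name in archive.files}

for name, array in sorted(loaded.items()):
    print(name, array.shape, array.dtype)
```

Each keyword argument passed to `np.savez_compressed` becomes a named array in the archive, and every array shares the same leading dimension (one entry per example).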


Let's define a custom Oumi dataset that can load MNIST data from an `.npz` archive. For more details, refer to: https://oumi.ai/docs/latest/resources/datasets/datasets.html

In [3]:
from pathlib import Path
from typing import Optional, Union

import numpy as np
import pandas as pd
from typing_extensions import override

from oumi.core.datasets import BaseMapDataset
from oumi.core.registry import register_dataset


@register_dataset("npz_file")
class NpzDataset(BaseMapDataset):
    """Loads dataset from Numpy .npz archive."""

    default_dataset = "custom"

    def __init__(
        self,
        *,
        dataset_name: Optional[str] = None,
        dataset_path: Optional[Union[str, Path]] = None,
        split: Optional[str] = None,
        npz_split_col: Optional[str] = None,
        npz_allow_pickle: bool = False,
        **kwargs,
    ) -> None:
        """Initializes a new instance of the NpzDataset class.

        Args:
            dataset_name: Dataset name.
            dataset_path: Path to .npz file.
            split: Dataset split.
            npz_split_col: Name of '.npz' array containing dataset split info.
                If unspecified, then the name "split" is assumed by default.
            npz_allow_pickle: Whether pickle is allowed when loading data
                from the npz archive.
            **kwargs: Additional arguments to pass to the parent class.

        Raises:
            ValueError: If dataset_path is not provided, or
                if .npz file contains data in unexpected format.
        """
        if not dataset_path:
            raise ValueError("`dataset_path` must be provided")
        super().__init__(
            dataset_name=dataset_name,
            dataset_path=(str(dataset_path) if dataset_path is not None else None),
            split=split,
            **kwargs,
        )
        self._npz_allow_pickle = npz_allow_pickle
        self._npz_split_col = npz_split_col

        dataset_path = Path(dataset_path)
        if not dataset_path.is_file():
            raise ValueError(f"Path is not a file! '{dataset_path}'")
        elif dataset_path.suffix.lower() != ".npz":
            raise ValueError(f"File extension is not '.npz'! '{dataset_path}'")

        self._data = self._load_data()

    @staticmethod
    def _to_list(x: np.ndarray) -> list:
        # `pd.DataFrame` expects Python lists for columns
        # (elements can still be `ndarray`)
        if len(x.shape) > 1:
            return [x[i, ...] for i in range(x.shape[0])]
        return x.tolist()

    @override
    def _load_data(self) -> pd.DataFrame:
        data_dict = {}
        if not self.dataset_path:
            raise ValueError("dataset_path is empty!")
        with np.load(self.dataset_path, allow_pickle=self._npz_allow_pickle) as npzfile:
            feature_names = list(sorted(npzfile.files))
            if len(feature_names) == 0:
                raise ValueError(
                    f"'.npz' archive contains no data! '{self.dataset_path}'"
                )
            num_examples = None
            for feature_name in feature_names:
                col_data = npzfile[feature_name]
                assert isinstance(col_data, np.ndarray)
                if num_examples is None:
                    num_examples = col_data.shape[0]
                elif num_examples != col_data.shape[0]:
                    raise ValueError(
                        "Inconsistent number of examples for features "
                        f"'{feature_name}' and '{feature_names[0]}': "
                        f"{col_data.shape[0]} vs {num_examples}!"
                    )
                data_dict[feature_name] = self._to_list(col_data)

        dataframe: pd.DataFrame = pd.DataFrame(data_dict)

        split_feature_name = (self._npz_split_col or "split") if self.split else None
        if split_feature_name:
            if split_feature_name not in dataframe:
                raise ValueError(
                    f"'.npz' doesn't contain data split info: '{split_feature_name}'!"
                )
            dataframe = pd.DataFrame(
                dataframe[dataframe[split_feature_name] == self.split].drop(
                    split_feature_name, axis=1
                ),
                copy=True,
            )
        return dataframe

    @override
    def transform(self, sample: pd.Series) -> dict:
        """Preprocesses the inputs in the given sample."""
        return sample.to_dict()
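
The split handling in `_load_data` is ordinary pandas filtering: rows whose split column matches the requested split are kept, and the column is then dropped so that it is not passed on as a feature. A minimal standalone sketch, using a small hand-made table (the values and the `toy_*` names are illustrative, not MNIST data):

```python
import numpy as np
import pandas as pd

# Multi-dimensional arrays are stored as one ndarray per row, mirroring `_to_list`.
toy_images = np.zeros((4, 1, 2, 2), dtype=np.float32)
toy_frame = pd.DataFrame(
    {
        "images": [toy_images[i, ...] for i in range(toy_images.shape[0])],
        "labels": [7, 2, 1, 0],
        "split": ["train", "test", "train", "train"],
    }
)

# Keep only the requested split, then drop the bookkeeping column.
train_frame = toy_frame[toy_frame["split"] == "train"].drop("split", axis=1)

print(len(train_frame))                      # 3
print(list(train_frame.columns))             # ['images', 'labels']
print(train_frame.iloc[0]["images"].shape)   # (1, 2, 2)
```

Each row of the filtered frame is what `transform` later receives as a `pd.Series` and turns into a plain dictionary.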

# Training a Model

Oumi provides the sample `CnnClassifier` model [[source](https://github.com/oumi-ai/oumi/blob/main/src/oumi/models/cnn_classifier.py)]. Let's use it to train a classifier for MNIST handwritten digits.

Oumi uses [training configuration files](https://oumi.ai/docs/latest/api/oumi.core.configs.html#oumi.core.configs.TrainingConfig) to specify training parameters. The cell below writes a training config for `CnnClassifier`, so let's give it a try!

In [4]:
yaml_content = f"""
model:
  model_name: "CnnClassifier"
  torch_dtype_str: "float32"
  load_pretrained_weights: False
  model_kwargs:
      image_width: 28   # MNIST images are 28x28 single channel
      image_height: 28
      in_channels: 1
      output_dim: 10    # Number of output classes: 10 digits

data:
  train:
    experimental_use_torch_datapipes: True
    datasets:
      - dataset_name: "npz_file" # Custom dataset defined above for .npz archives
        dataset_path: "{npz_filename}"
        split: "train"

training:
  trainer_type: "OUMI"  # For non-transformers, use "OUMI" trainer
  per_device_train_batch_size: 64
  num_train_epochs: 2  # Quick "mini" training, for demo purposes only
  logging_steps: 500
  run_name: "mnist_cnn_classifier"
  output_dir: "{tutorial_dir}/output"
"""

with open(f"{tutorial_dir}/train.yaml", "w") as f:
    f.write(yaml_content)
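
The batch size and epoch count in this config determine how many optimizer steps the run takes. Here is a quick estimate, assuming a single device, no gradient accumulation, and that the final partial batch of each epoch is kept (all three are assumptions about this setup, not settings read from the config):

```python
import math

num_train_examples = 60_000        # size of the MNIST training split
per_device_train_batch_size = 64   # from the config above
num_train_epochs = 2               # from the config above

# Assumption: the last, smaller batch of each epoch still counts as a step.
steps_per_epoch = math.ceil(num_train_examples / per_device_train_batch_size)
total_steps = steps_per_epoch * num_train_epochs

print(steps_per_epoch, total_steps)  # 938 1876
```

The total agrees with the `0/1876` shown by the training progress bar in the output further down.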

In [5]:
from oumi.core.configs import TrainingConfig
from oumi.train import train

config = TrainingConfig.from_yaml(str(Path(tutorial_dir) / "train.yaml"))

train(config)

[2025-01-20 18:45:15,743][oumi][rank0][pid:2307328][MainThread][INFO]][torch_utils.py:66] Torch version: 2.4.0+cu121. NumPy version: 1.26.4
[2025-01-20 18:45:15,745][oumi][rank0][pid:2307328][MainThread][INFO]][torch_utils.py:72] CUDA version: 12.1 CuDNN version: 90.1.0
[2025-01-20 18:45:15,936][oumi][rank0][pid:2307328][MainThread][INFO]][torch_utils.py:106] CPU cores: 24 CUDA devices: 1
device(0)='NVIDIA GeForce RTX 3090' Capability: (8, 6) Memory: [Total: 24.0GiB Free: 22.76GiB Allocated: 0.0GiB Cached: 0.0GiB]
[2025-01-20 18:45:15,938][oumi][rank0][pid:2307328][MainThread][INFO]][train.py:133] Oumi version: 0.1.3.dev2+g6b9b1a1f.d20250121
[2025-01-20 18:45:15,940][oumi][rank0][pid:2307328][MainThread][INFO]][train.py:135] Git revision hash: 6b9b1a1f6bbe787515c6bc6b067ec0d479db5f96
[2025-01-20 18:45:15,943][oumi][rank0][pid:2307328][MainThread][INFO]][train.py:136] Git tag: None
[2025-01-20 18:45:15,944][oumi][rank0][pid:2307328][MainThread][INFO]][train.py:174] TrainingConfig: Train

Training:   0%|          | 0/1876 [00:00<?, ?it/s]

[2025-01-20 18:45:19,228][oumi][rank0][pid:2307328][MainThread][INFO]][oumi_trainer.py:604] {'epoch': 0,
 'global_step': 500,
 'global_steps_per_second': 183.7019607898465,
 'learning_rate': 3.670042643923241e-05,
 'tokens_per_second': 0.0,
 'tokens_per_step_per_gpu': 0.0,
 'total_tokens_seen': 0,
 'train/loss': 0.9217740297317505}
[2025-01-20 18:45:19,231][oumi.telemetry][rank0][pid:2307328][MainThread][INFO]][telemetry.py:332] Telemetry Summary (PUGET-SYSTEMS):
Total time: 2.75 seconds

CPU Timers:
	fetching batch:
		Total: 1.1436s Mean: 0.0023s Median: 0.0022s
		Min: 0.0021s Max: 0.0046s StdDev: 0.0003s
		Count: 500.0 Percentage of total time: 41.60%
	computing tokens:
		Total: 0.0003s Mean: 0.0000s Median: 0.0000s
		Min: 0.0000s Max: 0.0000s StdDev: 0.0000s
		Count: 500.0 Percentage of total time: 0.01%
	moving batch to device:
		Total: 0.0182s Mean: 0.0000s Median: 0.0000s
		Min: 0.0000s Max: 0.0001s StdDev: 0.0000s
		Count: 500.0 Percentage of total time: 0.66%
	model forward:
		

Congratulations, you've trained your first CNN on a custom dataset (NumPy arrays) with Oumi!