# PyTorch DDP Speech Recognition Training Example

This example demonstrates how to train a small transformer network to classify spoken words using Google's [Speech Commands](https://huggingface.co/datasets/google/speech_commands) dataset. The dataset is small (about 2.3 GB) and consists of short recordings of single words, so a small model trains quickly.

This notebook walks you through running the example locally and then scaling PyTorch DDP across multiple nodes with a Kubeflow TrainJob.


## Prepare the Kubernetes environment using Kind

If you already have your own Kubernetes cluster, you can skip this step.

For demo purposes, we will create a Kubernetes cluster with [Kind](https://kind.sigs.k8s.io/). The folder containing this notebook includes a Kind config file called `kind-config.yaml`. It creates a cluster with 3 worker nodes and mounts the host's `/data` directory into the Kind nodes, so data you download to `/data` on your local machine is also accessible from inside the cluster.
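The contents of `kind-config.yaml` are not reproduced in this notebook. As a rough sketch of what such a file might contain (an assumption, not the actual file), the snippet below builds a Kind `Cluster` config with one control-plane node and three workers, each mounting the host's `/data` at `/data`. `build_kind_config` is a hypothetical helper. The result is printed as JSON, which is also valid YAML.

```python
import json


def build_kind_config(num_workers=3, host_path="/data", container_path="/data"):
    """Return a Kind cluster config dict (hypothetical stand-in for kind-config.yaml)."""
    nodes = [{"role": "control-plane"}]
    nodes += [{"role": "worker"} for _ in range(num_workers)]
    for node in nodes:
        # Every node gets its own mount entry pointing at the host directory.
        node["extraMounts"] = [
            {"hostPath": host_path, "containerPath": container_path}
        ]
    return {
        "kind": "Cluster",
        "apiVersion": "kind.x-k8s.io/v1alpha4",
        "nodes": nodes,
    }


config = build_kind_config()
print(json.dumps(config, indent=2))
```

One control-plane node plus three workers matches the four nodes that the cluster creation output reports.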

To create the Kind cluster, run the following command.

**Note:** This creates a Kind cluster named `ml`. You only need to run this command once.


In [7]:
!kind create cluster --name ml --config kind-config.yaml

Creating cluster "ml" ...
 ✓ Ensuring node image (kindest/node:v1.34.0) 🖼
 ✓ Preparing nodes 📦 📦 📦 📦
 ✓ Writing configuration 📜
 ✓ Starting control-plane 🕹️
 ✓ Installing CNI 🔌
 ✓ Installing StorageClass 💾
 ✓ Joining worker nodes 🚜
Set kubectl context to "kind-ml"
You can now use your cluster with:

kubectl cluster-info --context kind-ml

Have a question, bug, or feature request? Let us know! https://kind.sigs.k8s.io/#community 🙂


## Add CRD and Kubeflow Trainer operator to Kubernetes cluster

Full instructions are in the [Kubeflow Trainer installation guide](https://www.kubeflow.org/docs/components/trainer/operator-guides/installation/). In short, run this command:

In [8]:
# Each `!` command runs in its own subshell, so set VERSION and use it in the same command.
!export VERSION=v2.0.0 && kubectl apply --server-side -k "https://github.com/kubeflow/trainer.git/manifests/overlays/manager?ref=${VERSION}"

namespace/kubeflow-system serverside-applied
customresourcedefinition.apiextensions.k8s.io/clustertrainingruntimes.trainer.kubeflow.org serverside-applied
customresourcedefinition.apiextensions.k8s.io/jobsets.jobset.x-k8s.io serverside-applied
customresourcedefinition.apiextensions.k8s.io/trainingruntimes.trainer.kubeflow.org serverside-applied
customresourcedefinition.apiextensions.k8s.io/trainjobs.trainer.kubeflow.org serverside-applied
serviceaccount/jobset-controller-manager serverside-applied
serviceaccount/kubeflow-trainer-controller-manager serverside-applied
role.rbac.authorization.k8s.io/jobset-leader-election-role serverside-applied
clusterrole.rbac.authorization.k8s.io/jobset-manager-role serverside-applied
clusterrole.rbac.authorization.k8s.io/jobset-metrics-reader serverside-applied
clusterrole.rbac.authorization.k8s.io/jobset-proxy-role serverside-applied
clusterrole.rbac.authorization.k8s.io/kubeflow-trainer-controller-manager serverside-applied
rolebinding.rbac.authoriz

## Prepare Docker Image

We need a Docker image with the dependencies listed in `requirements.txt`. The `Dockerfile` and `requirements.txt` are in the same folder as this notebook.

To build:

In [9]:
!docker build -t speech-recognition-image:0.1 -f Dockerfile .

[+] Building 0.3s (1/2)                                          docker:default
 => [internal] load build definition from Dockerfile                       0.0s
 => => transferring dockerfile: 197B                                       0.0s
 => [internal] load metadata for docker.io/pytorch/pytorch:2.8.0-cuda12.8  0.3s

### Make the image available to the cluster

#### Kind cluster

If you are using a local Kind cluster, run the following command to load the Docker image into it:

In [25]:
!kind load docker-image speech-recognition-image:0.1 --name ml

Image: "speech-recognition-image:0.1" with ID "sha256:f98d06d275aa85a352ca3b4ee886fd7c10a052fef62e430c01853e6a1cffc689" not yet present on node "ml-worker2", loading...
Image: "speech-recognition-image:0.1" with ID "sha256:f98d06d275aa85a352ca3b4ee886fd7c10a052fef62e430c01853e6a1cffc689" not yet present on node "ml-worker", loading...
Image: "speech-recognition-image:0.1" with ID "sha256:f98d06d275aa85a352ca3b4ee886fd7c10a052fef62e430c01853e6a1cffc689" not yet present on node "ml-worker3", loading...
Image: "speech-recognition-image:0.1" with ID "sha256:f98d06d275aa85a352ca3b4ee886fd7c10a052fef62e430c01853e6a1cffc689" not yet present on node "ml-control-plane", loading...


#### Kubernetes cluster

If you are not using a local Kind cluster, push the image to your own Docker registry instead:

```bash
docker image push <your docker image name with registry info>
```

## Add Runtime to K8s cluster

The folder containing this Jupyter notebook includes a `kubeflow-runtime-example.yaml` file that defines the training runtime.

**Before applying it, set the image in that file to the one you just built or pushed.**


In [19]:
!kubectl apply -f kubeflow-runtime-example.yaml

clustertrainingruntime.trainer.kubeflow.org/torch-distributed-speech-recognition created


## Install the Kubeflow SDK

You need to install the Kubeflow SDK to interact with Kubeflow Trainer APIs:

In [24]:
!pip install git+https://github.com/kubeflow/sdk.git@main

Collecting git+https://github.com/kubeflow/sdk.git@main
  Cloning https://github.com/kubeflow/sdk.git (to revision main) to /tmp/pip-req-build-dhm1y012
  Running command git clone --filter=blob:none --quiet https://github.com/kubeflow/sdk.git /tmp/pip-req-build-dhm1y012
  Resolved https://github.com/kubeflow/sdk.git to commit 6709dcff0f3e68d44b37531d3154829e626f4b62
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


## Prepare Speech Command Dataset

For demo purposes, and to simplify the process, we download the data to `/data` on the host. The Kind cluster mounts the host's `/data` folder at `/data` on each node, and the Kubeflow runtime mounts that folder into the training pods at `/data` using a `hostPath` volume. Every component therefore reads the dataset from `/data`. **Please make sure the `/data` folder exists on the host.**

For other clusters, please create a volume so that the data is accessible at `/data`.

The exact path of the Speech Commands dataset is `/data/SpeechCommands/speech_commands_v0.02`.

To download data, run the code below.


In [15]:
import torchaudio

print("Downloading SpeechCommands dataset...")

# This downloads the data to /data/SpeechCommands (root="/data")
# if it is not already there.
train_dataset = torchaudio.datasets.SPEECHCOMMANDS(root="/data", download=True)

print("Download complete!")
print(f"Number of training samples: {len(train_dataset)}")



Downloading SpeechCommands dataset...
Download complete!
Number of training samples: 105829


In [6]:
!ls /data/SpeechCommands/speech_commands_v0.02

_background_noise_  five     left     README.md		tree
backward	    follow   LICENSE  right		two
bed		    forward  marvin   seven		up
bird		    four     nine     sheila		validation_list.txt
cat		    go	     no       six		visual
dog		    happy    off      stop		wow
down		    house    on       testing_list.txt	yes
eight		    learn    one      three		zero
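The training code later assumes 35 classes (`num_classes=35`) and derives its labels from the directory names, skipping entries that start with `_`. As a quick sanity check, the sketch below counts the label directories the same way. `list_label_dirs` is a hypothetical helper written for this notebook, not part of torchaudio.

```python
import os


def list_label_dirs(dataset_root):
    """Return sorted class directories, skipping files and names starting with '_'."""
    return sorted(
        name
        for name in os.listdir(dataset_root)
        if os.path.isdir(os.path.join(dataset_root, name))
        and not name.startswith("_")
    )


dataset_root = "/data/SpeechCommands/speech_commands_v0.02"
if os.path.isdir(dataset_root):
    labels = list_label_dirs(dataset_root)
    print(f"Found {len(labels)} label directories")
else:
    print(f"{dataset_root} not found; download the dataset first")
```

If the count differs from 35, the `num_classes` default of the model would need to be adjusted to match.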


## Create TrainerClient with Kubeflow Trainer SDK


In [1]:
from kubeflow.trainer import CustomTrainer, TrainerClient

client = TrainerClient()

## Get runtime from K8s cluster

After running the cell below, you should see output like the following. If the cell prints nothing, the custom Kubeflow runtime has most likely not been created. Go back to the "Add Runtime to K8s cluster" step and create it with `kubectl`.
```
Runtime(name='torch-distributed-speech-recognition', trainer=RuntimeTrainer(trainer_type=<TrainerType.CUSTOM_TRAINER: 'CustomTrainer'>, framework='torch', num_nodes=1, device='Unknown', device_count='Unknown'), pretrained_model=None)
```



In [2]:
for runtime in client.list_runtimes():
    print(runtime)
    if runtime.name == "torch-distributed-speech-recognition":
        torch_runtime = runtime


Runtime(name='torch-distributed-speech-recognition', trainer=RuntimeTrainer(trainer_type=<TrainerType.CUSTOM_TRAINER: 'CustomTrainer'>, framework='torch', num_nodes=1, device='Unknown', device_count='Unknown'), pretrained_model=None)
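The loop above leaves `torch_runtime` undefined when no runtime matches, which only surfaces later as a `NameError`. One possible way to fail earlier with a clearer message is sketched below. `find_runtime` is a hypothetical helper; it only assumes that each runtime object has a `name` attribute, as the printed output suggests, and the demo uses stand-in objects rather than a real cluster.

```python
from types import SimpleNamespace


def find_runtime(runtimes, name):
    """Return the first runtime whose .name matches, or raise a descriptive error."""
    runtimes = list(runtimes)
    for runtime in runtimes:
        if runtime.name == name:
            return runtime
    available = [r.name for r in runtimes]
    raise LookupError(f"Runtime {name!r} not found; available runtimes: {available}")


# Stand-in objects for demonstration only.
# In the notebook you would pass client.list_runtimes() instead.
fake_runtimes = [SimpleNamespace(name="torch-distributed-speech-recognition")]
found = find_runtime(fake_runtimes, "torch-distributed-speech-recognition")
print(found.name)
```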


## Start training

The training code is in `train_with_kubeflow_trainer.py`, which is in the same folder as this notebook.

In [34]:
def train_model():
    # 1. IMPORTS
    # ---
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader
    import torchaudio
    import os  # To navigate file paths
    import json
    from torch.utils.tensorboard import SummaryWriter
    import random
    from datetime import datetime
    import torch.distributed as dist
    from torch.utils.data.distributed import DistributedSampler
    from torch.nn.parallel import DistributedDataParallel as DDP

    debug = False

    # Its job: Scan the dataset directory and build the file list and the label mappings.
    def load_data(data_path):
        audio_info = []
        label_map = {}
        label_map_reverse = {}
        # Walk through the data directory to find all audio files
        # full_data_path = os.path.join(os.path.dirname(__file__), data_path)
        full_data_path = data_path

        # Get all subdirectories (word labels), excluding special directories
        labels = []
        for item in os.listdir(full_data_path):
            item_path = os.path.join(full_data_path, item)
            if (
                os.path.isdir(item_path)
                and not item.startswith("_")
                and item != "LICENSE"
            ):
                labels.append(item)

        # Sort labels for consistent mapping
        labels.sort()

        # Create label to integer mapping
        label_map = {label: idx for idx, label in enumerate(labels)}
        label_map_reverse = {idx: label for idx, label in enumerate(labels)}

        # Collect all audio files
        for label in labels:
            label_dir = os.path.join(full_data_path, label)
            for filename in os.listdir(label_dir):
                if filename.endswith(".wav"):
                    # Store the path including data_path (absolute when data_path is absolute)
                    relative_path = os.path.join(data_path, label, filename)
                    audio_info.append({"filename": relative_path, "label": label})
        return audio_info, label_map, label_map_reverse

    def split_data(audio_info):
        # Shuffle once with a fixed seed so every rank produces the same
        # train/val/test split; DistributedSampler reshuffles the training set per epoch
        random.seed(41)
        random.shuffle(audio_info)
        print(f"audio info length: {len(audio_info)}")
        train_size = int(len(audio_info) * 0.95)
        val_size = int(len(audio_info) * 0.03)
        test_size = len(audio_info) - train_size - val_size
        print(f"train size: {train_size}, val size: {val_size}, test size: {test_size}")

        audio_info_training = audio_info[:train_size]
        audio_info_validation = audio_info[train_size : train_size + val_size]
        audio_info_test = audio_info[train_size + val_size :]
        return audio_info_training, audio_info_validation, audio_info_test

    class SpeechCommandsDataset(Dataset):
        def __init__(
            self, audio_info, label_map, label_map_reverse, data_path_prefix=None
        ):
            self.data_path_prefix = data_path_prefix
            self.audio_info = audio_info
            self.label_map = label_map
            self.label_map_reverse = label_map_reverse
            self.transform = torchaudio.transforms.MelSpectrogram(n_mels=128)

            print(json.dumps(self.audio_info[0:3], indent=4))
            print(self.label_map)

        def __len__(self):
            return len(self.audio_info)

        def __getitem__(self, index):
            audio_info = self.audio_info[index]

            # Build full path to audio file
            if self.data_path_prefix is None:
                audio_path = os.path.join(
                    os.path.dirname(__file__), audio_info["filename"]
                )
            else:
                audio_path = os.path.join(self.data_path_prefix, audio_info["filename"])

            # Load the audio file
            waveform, sample_rate = torchaudio.load(audio_path)

            # Transform to spectrogram
            spectrogram = self.transform(waveform)
            # print(f"Spectrogram shape: {spectrogram.shape}")

            # Get the numerical label from the word label
            label = self.label_map[audio_info["label"]]

            return spectrogram, label

    def collate_fn_spectrogram(batch):
        # batch is a list of tuples (spectrogram, label)

        # Let's set our target length
        target_length = 81

        spectrograms = []
        labels = []

        # Loop through each item in the batch
        for spec, label in batch:
            # spec shape is (1, num_features, time)
            current_length = spec.shape[2]

            # --- Padding or Truncating ---
            if current_length < target_length:
                # Pad with zeros if it's too short
                padding_needed = target_length - current_length
                # F.pad with (pad_left, pad_right) pads only the last (time) dimension
                spec = torch.nn.functional.pad(spec, (0, padding_needed))
            elif current_length > target_length:
                # Truncate if it's too long
                spec = spec[:, :, :target_length]

            spectrograms.append(spec)
            labels.append(label)

        # Stack them into a single batch tensor
        spectrograms_batch = torch.cat(spectrograms, dim=0)
        labels_batch = torch.tensor(labels)

        return spectrograms_batch, labels_batch

    # Its job: Define the Transformer architecture.
    class AudioTransformer(nn.Module):
        def __init__(self, num_input_features=128, num_classes=35, dropout=0.1):
            super().__init__()
            # Using PyTorch's pre-built Transformer components
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=num_input_features, nhead=4, batch_first=True, dropout=dropout
            )
            self.transformer_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=4
            )
            self.output_layer = nn.Linear(num_input_features, num_classes)

        def forward(self, spectrogram_batch):
            # Input shape needs to be (batch, time, features) for batch_first=True
            # Spectrograms are often (batch, features, time), so we might need to permute
            x = spectrogram_batch.permute(0, 2, 1)

            x = self.transformer_encoder(x)
            x = x.mean(dim=1)  # Average over the time dimension
            predictions = self.output_layer(x)
            return predictions

    class Trainer:
        def __init__(
            self,
            model,
            train_loader,
            val_loader,
            test_loader,
            optimizer,
            loss_fn,
            scheduler,
            device,
            total_epochs,
            exp_path=None,
            rank=None,
        ):
            self.model = model
            self.train_loader = train_loader
            self.val_loader = val_loader
            self.test_loader = test_loader
            self.optimizer = optimizer
            self.loss_fn = loss_fn
            self.scheduler = scheduler
            self.device = device
            self.total_epochs = total_epochs
            self.best_val_accuracy = 0.0
            self.step = 0
            self.timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
            self.total_steps = len(self.train_loader) * self.total_epochs
            if exp_path is None:
                self.exp_path = f"/data/speech-recognition/runs/exp-{self.timestamp}"
            else:
                self.exp_path = exp_path
            print(f"Experiment path: {self.exp_path}")

            self.rank = rank
            self.is_main_process = self.rank == 0

            if self.is_main_process:
                self.writer = SummaryWriter(f"{self.exp_path}/logs")
            else:
                self.writer = None

        def _train_one_epoch(self, epoch):
            self.train_loader.sampler.set_epoch(epoch)
            self.model.train()  # Ensure model is in training mode
            epoch_loss = 0.0
            num_batches = 0

            for batch_idx, (spectrograms, labels) in enumerate(self.train_loader):
                spectrograms = spectrograms.to(self.device)
                labels = labels.to(self.device)

                # 1. PREDICT: Pass data through the model
                predictions = self.model(spectrograms)

                # 2. COMPARE: Calculate the error
                loss = self.loss_fn(predictions, labels)

                # 3. ADJUST: Update the model's weights
                self.optimizer.zero_grad()
                loss.backward()
                # Add gradient clipping to prevent explosion
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

                # Update tracking variables
                epoch_loss += loss.item()
                num_batches += 1
                self.step += 1

                # Log to TensorBoard every step
                if self.is_main_process:
                    self.writer.add_scalar("Loss/Train", loss.item(), self.step)
                    # Print progress every 10 steps or at the end of each epoch
                    if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == len(
                        self.train_loader
                    ):
                        avg_loss = epoch_loss / num_batches
                        print(
                            f"Epoch {epoch+1:2d}/{self.total_epochs} | Step {self.step:4d}/{self.total_steps} | "
                            f"Batch {batch_idx+1:3d}/{len(self.train_loader)} | "
                            f"Loss: {loss.item():.6f} | Avg Loss: {avg_loss:.6f}"
                        )

            # Log epoch average loss to TensorBoard
            avg_epoch_loss = epoch_loss / num_batches
            if self.is_main_process:
                self.writer.add_scalar("Loss/Epoch_Avg", avg_epoch_loss, epoch + 1)
                # Print epoch summary
                print(
                    f"Epoch {epoch+1:2d}/{self.total_epochs} completed | Avg Loss: {avg_epoch_loss:.6f}"
                )
                print("-" * 80)

        def _validate_one_epoch(self, epoch, loader=None):
            # Single-machine validation (only runs on rank 0)
            self.model.eval()  # 1. Switch to evaluation mode
            val_loss = 0
            correct = 0
            total = 0

            print("  Starting validation...")

            with torch.no_grad():  # 2. Do not compute gradients within this code block
                for batch_idx, (spectrograms, labels) in enumerate(loader):
                    spectrograms = spectrograms.to(self.device)
                    labels = labels.to(self.device)

                    # Only perform prediction and calculate loss
                    predictions = self.model(spectrograms)
                    loss = self.loss_fn(predictions, labels)

                    val_loss += loss.item()

                    # Calculate accuracy
                    _, predicted_labels = torch.max(predictions.data, 1)
                    total += labels.size(0)
                    correct += (predicted_labels == labels).sum().item()

                    # Print validation progress every 10 batches
                    if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == len(loader):
                        current_accuracy = 100 * correct / total
                        print(
                            f"    Validation Batch {batch_idx+1:3d}/{len(loader)} | "
                            f"Current Accuracy: {current_accuracy:.2f}%"
                        )

            # Simple single-machine calculation
            avg_loss = val_loss / len(loader)
            val_accuracy = 100 * correct / total

            print(
                f"  Validation completed | Loss: {avg_loss:.6f} | Accuracy: {val_accuracy:.2f}%"
            )

            if val_accuracy > self.best_val_accuracy:
                self.best_val_accuracy = val_accuracy
                print(
                    f"New best validation accuracy: {self.best_val_accuracy:.2f}%. Saving model..."
                )
                model_folder = f"{self.exp_path}/models"
                os.makedirs(model_folder, exist_ok=True)
                # Handle both DDP and non-DDP models
                if hasattr(self.model, "module"):
                    model_state = self.model.module.state_dict()
                else:
                    model_state = self.model.state_dict()
                torch.save(model_state, f"{model_folder}/best-epoch{epoch}.pth")

            # Log validation metrics every epoch, not only when a new best is reached
            if self.writer:
                self.writer.add_scalar("Loss/Val", avg_loss, epoch + 1)
                self.writer.add_scalar("Accuracy/Val", val_accuracy, epoch + 1)

        def train(self):
            print("Starting training...")
            for epoch in range(self.total_epochs):
                self._train_one_epoch(epoch)
                # Only validate if val_loader is available (single-machine validation)
                if self.val_loader is not None:
                    self._validate_one_epoch(epoch, self.val_loader)

                # Synchronize all processes after validation
                if dist.is_initialized():
                    dist.barrier()  # Wait for rank 0 to finish validation

                if self.scheduler:
                    self.scheduler.step()
                if self.is_main_process:
                    print("-" * 80)

            if self.is_main_process:
                self.writer.close()
                print("Training complete!")

        def test(self):
            # Only run test on rank 0 (single-machine testing)
            if self.test_loader is None:
                return

            print("Starting test...")
            self.model.eval()  # 1. Switch to evaluation mode
            val_loss = 0
            correct = 0
            total = 0

            with torch.no_grad():  # 2. Do not compute gradients within this code block
                for batch_idx, (spectrograms, labels) in enumerate(self.test_loader):
                    spectrograms = spectrograms.to(self.device)
                    labels = labels.to(self.device)

                    # Only perform prediction and calculate loss
                    predictions = self.model(spectrograms)
                    loss = self.loss_fn(predictions, labels)

                    val_loss += loss.item()

                    # Calculate accuracy
                    _, predicted_labels = torch.max(predictions.data, 1)
                    total += labels.size(0)
                    correct += (predicted_labels == labels).sum().item()

                    # Print test progress every 10 batches
                    if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == len(self.test_loader):
                        current_accuracy = 100 * correct / total
                        print(
                            f"    Test Batch {batch_idx+1:3d}/{len(self.test_loader)} | "
                            f"Current Accuracy: {current_accuracy:.2f}%"
                        )

            # Simple single-machine calculation
            avg_loss = val_loss / len(self.test_loader)
            val_accuracy = 100 * correct / total

            print(
                f"  Test completed | Loss: {avg_loss:.6f} | Accuracy: {val_accuracy:.2f}%"
            )
            print("Test complete!")

    def setup_ddp():
        """Initialize DDP process group"""
        # Without the torchrun-style environment variables, skip DDP setup
        if (
            "LOCAL_RANK" not in os.environ
            or "RANK" not in os.environ
            or "WORLD_SIZE" not in os.environ
        ):
            print("LOCAL_RANK, RANK, and WORLD_SIZE are not set, will skip using DDP")
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            return device, 0, 0
        print(
            f"LOCAL_RANK: {os.environ['LOCAL_RANK']}, RANK: {os.environ['RANK']}, WORLD_SIZE: {os.environ['WORLD_SIZE']}"
        )

        local_rank = int(os.environ["LOCAL_RANK"])

        if torch.cuda.is_available():
            dist.init_process_group(backend="nccl")
            device = torch.device("cuda", local_rank)
            torch.cuda.set_device(device)
            print(f"Using device: {device}")

        else:
            device = torch.device("cpu")
            dist.init_process_group(backend="gloo")

        rank = torch.distributed.get_rank()
        print(
            f"Rank(dist.get_rank()): {rank}, Rank(os.environ['RANK']): {os.environ['RANK']}, Local Rank(os.environ['LOCAL_RANK']): {local_rank}"
        )

        return device, local_rank, rank

    def cleanup_ddp():
        """Destroy process group"""
        if dist.is_initialized():
            dist.destroy_process_group()

    def train():
        # This line will automatically select GPU (if available), otherwise fall back to CPU
        device, local_rank, rank = setup_ddp()

        print(f"Using device: {device}")

        # Instantiate the Dataset and DataLoader
        print("start loading dataset")
        data_path_prefix = "/data/SpeechCommands/speech_commands_v0.02"
        audio_info, label_map, label_map_reverse = load_data(data_path_prefix)

        audio_info_training, audio_info_validation, audio_info_test = split_data(
            audio_info
        )

        if debug:
            # Use more data for debug: 4000 train, 500 val, 500 test
            audio_info_training = audio_info_training[:4000]
            audio_info_validation = audio_info_validation[:500]
            audio_info_test = audio_info_test[:500]
            print(
                f"Debug mode: using train={len(audio_info_training)}, val={len(audio_info_validation)}, test={len(audio_info_test)}"
            )

        train_dataset = SpeechCommandsDataset(
            audio_info_training, label_map, label_map_reverse, data_path_prefix
        )

        val_dataset = SpeechCommandsDataset(
            audio_info_validation, label_map, label_map_reverse, data_path_prefix
        )
        test_dataset = SpeechCommandsDataset(
            audio_info_test, label_map, label_map_reverse, data_path_prefix
        )

        print("dataset loaded")

        print("start init data loader")

        train_sampler = DistributedSampler(train_dataset, shuffle=True)
        # Adjust batch size for debug mode
        batch_size = 64 if debug else 256
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=False,
            sampler=train_sampler,
            collate_fn=collate_fn_spectrogram,
        )
        # Use single-machine validation (only rank 0)
        if rank == 0:
            val_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                collate_fn=collate_fn_spectrogram,
            )
            test_loader = DataLoader(
                test_dataset,
                batch_size=batch_size,
                shuffle=False,
                collate_fn=collate_fn_spectrogram,
            )
        else:
            val_loader = None
            test_loader = None

        print("data loader initialized")

        # Instantiate the Model, Loss Function, and Optimizer
        model = AudioTransformer().to(device)

        # Create DDP model - different parameters for CPU vs GPU
        if torch.cuda.is_available() and device.type == "cuda":
            ddp_model = DDP(model, device_ids=[local_rank])
        else:
            # For CPU training, don't specify device_ids
            ddp_model = DDP(model)

        loss_fn = nn.CrossEntropyLoss()
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        # Scale the learning rate linearly with world size, capped at 2x the base rate
        base_lr = 0.001
        lr = (
            base_lr * min(world_size, 2) if not debug else base_lr
        )  # No scaling in debug
        print(f"Using learning rate: {lr} (world_size: {world_size}, debug: {debug})")
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.9)

        # Generate timestamp for this experiment
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        print(f"Experiment timestamp: {timestamp}")

        # --- 2. Create and start Trainer ---
        total_epochs = 10 if debug else 30
        trainer = Trainer(
            ddp_model,
            train_loader,
            val_loader,
            test_loader,
            optimizer,
            loss_fn,
            scheduler,
            device,
            total_epochs,
            rank=rank,
        )
        trainer.train()

        # --- 3. (Optional) Final test ---
        trainer.test()  # Evaluates on the held-out test split (rank 0 only)

        # Synchronize all processes after test
        if dist.is_initialized():
            dist.barrier()  # Wait for rank 0 to finish test

        cleanup_ddp()

        if rank == 0:
            print("Training complete!")
            print("To view TensorBoard, run: tensorboard --logdir=/data/speech-recognition/runs")

    train()
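To get a feel for the sizes that `split_data` and `DistributedSampler` produce, the sketch below repeats their arithmetic in plain Python. It assumes the loader finds the same 105,829 clips reported by the download cell and that `DistributedSampler` runs with its default `drop_last=False`. `plan_training` is a hypothetical helper, not part of the training script.

```python
import math


def plan_training(num_clips, world_size, batch_size=256, epochs=30):
    """Mirror the split and per-rank batching arithmetic of the training function."""
    train_size = int(num_clips * 0.95)
    val_size = int(num_clips * 0.03)
    test_size = num_clips - train_size - val_size
    # With drop_last=False, DistributedSampler pads the index list so that
    # every rank receives ceil(train_size / world_size) samples.
    samples_per_rank = math.ceil(train_size / world_size)
    steps_per_epoch = math.ceil(samples_per_rank / batch_size)
    return {
        "train": train_size,
        "val": val_size,
        "test": test_size,
        "samples_per_rank": samples_per_rank,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": steps_per_epoch * epochs,
    }


plan = plan_training(num_clips=105829, world_size=2)
print(plan)
```

Under those assumptions, two ranks give a 100,537 / 3,174 / 2,118 split and 197 optimizer steps per epoch on each rank.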




## Train in local Jupyter Notebook



In [None]:
import os

# Set the Torch Distributed env variables so the training function can be run locally in the Notebook.
# See https://pytorch.org/docs/stable/elastic/run.html#environment-variables
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12345"

train_model()
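The fixed `MASTER_PORT` above fails if another process already listens on port 12345. A possible alternative, sketched here with the standard library only, asks the operating system for an unused port first. `find_free_port` and `single_process_env` are hypothetical helpers, and the port could in principle be taken by another process between the check and the start of training.

```python
import os
import socket


def find_free_port(host="localhost"):
    """Ask the OS for a currently unused TCP port on the given host."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))  # port 0 lets the OS choose a free port
        return sock.getsockname()[1]


def single_process_env(port=None):
    """Build the torch.distributed environment variables for one local process."""
    if port is None:
        port = find_free_port()
    return {
        "RANK": "0",
        "LOCAL_RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": str(port),
    }


env = single_process_env()
os.environ.update(env)  # apply before calling train_model()
print(env)
```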

## Train with Kubeflow Trainer

In [35]:
job_name = client.train(
    trainer=CustomTrainer(
        func=train_model,
        # Set how many PyTorch nodes you want to use for distributed training.
        num_nodes=2,
        # Set the resources for each PyTorch node.
        resources_per_node={
            "cpu": 5,
            "memory": "50Gi",
            # Uncomment this to distribute the TrainJob using GPU nodes.
            # "nvidia.com/gpu": 1,
        },
    ),
    runtime=torch_runtime,
)
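Changing `num_nodes` also changes the effective hyperparameters, because each rank uses a batch size of 256 and the training function scales the learning rate by `min(world_size, 2)`. The sketch below tabulates that rule. It assumes one training process per node, so that the world size equals `num_nodes`, and `effective_hyperparams` is a hypothetical helper.

```python
def effective_hyperparams(world_size, per_rank_batch=256, base_lr=0.001, lr_cap=2):
    """Mirror the scaling rule in train(): lr = base_lr * min(world_size, lr_cap)."""
    return {
        "global_batch_size": per_rank_batch * world_size,
        "learning_rate": base_lr * min(world_size, lr_cap),
    }


for world_size in (1, 2, 4):
    print(world_size, effective_hyperparams(world_size))
```

Beyond two ranks the global batch keeps growing while the learning rate stays at twice the base value, which is worth keeping in mind when comparing runs with different node counts.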

In [36]:
client.wait_for_job_status(name=job_name, status={"Running"})

TrainJob(name='r646465920c0', creation_timestamp=datetime.datetime(2025, 9, 14, 3, 44, 31, tzinfo=TzInfo(UTC)), runtime=Runtime(name='torch-distributed-speech-recognition', trainer=RuntimeTrainer(trainer_type=<TrainerType.CUSTOM_TRAINER: 'CustomTrainer'>, framework='torch', num_nodes=1, device='Unknown', device_count='Unknown'), pretrained_model=None), steps=[Step(name='node-0', status='Running', pod_name='r646465920c0-node-0-0-qplsf', device='cpu', device_count='5'), Step(name='node-1', status='Running', pod_name='r646465920c0-node-0-1-cbwr5', device='cpu', device_count='5')], num_nodes=2, status='Running')

In [37]:
# Or use kubectl to list the pods and then view their logs
!kubectl get pod
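Once the pod names are known, their logs can be followed with `kubectl logs`. The sketch below only builds the command line so it can be inspected or passed to `subprocess.run`. `kubectl_logs_command` is a hypothetical helper, and the pod name is the one shown in the `TrainJob` output above; yours will differ.

```python
import shlex


def kubectl_logs_command(pod_name, namespace=None, follow=True):
    """Build a `kubectl logs` argument list for one pod."""
    cmd = ["kubectl", "logs"]
    if follow:
        cmd.append("-f")  # stream new log lines as they arrive
    if namespace:
        cmd += ["-n", namespace]
    cmd.append(pod_name)
    return cmd


cmd = kubectl_logs_command("r646465920c0-node-0-0-qplsf")
print(shlex.join(cmd))
# To run it: import subprocess; subprocess.run(cmd, check=True)
```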