# 📘 Notebook 2c: Distributed Data Parallel (DDP) Training with Snowflake

## Native PyTorch DDP using `PyTorchDistributor`

Snowflake **natively supports PyTorch Distributed Data Parallel (DDP)** training through the `PyTorchDistributor` API. This allows you to run distributed training at scale without managing your own cluster or orchestration layer.

### When to Consider DDP

DDP is beneficial when:

- **Training time is a bottleneck**: epochs take too long on a single GPU
- **Multiple GPUs are available**: you have 2+ GPUs (on one node or across nodes) that you want to use
- **GPU utilization is already high**: a single GPU is near 100% busy but training is still slow
- **Data loading isn't the bottleneck**: if the GPU is waiting on data, adding more GPUs won't help

DDP may be overkill when:

- **Training completes quickly**: a few minutes per epoch on a single GPU
- **GPU utilization is low**: this usually points to a data-loading or CPU bottleneck
- **Only 1 GPU is available**: DDP speeds up training by running one model replica per GPU, so a single GPU gains nothing

> ⚠️ **Note:** There are no universal dataset-size thresholds. Whether DDP helps depends on your specific model, batch size, hardware, and data pipeline. Profile first, then decide.
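Following the advice to profile first, the sketch below times how long a training loop waits for the next batch versus how long it spends computing on it. It is a minimal, framework-agnostic sketch under stated assumptions: `profile_loop`, `data_wait_fraction`, and `slow_batches` are hypothetical helpers, and `step_fn` stands in for your own forward/backward/optimizer step. On a GPU, `step_fn` should finish with `torch.cuda.synchronize()` so that asynchronous CUDA work is counted as compute rather than leaking into the next data wait.

```python
import time
from typing import Callable, Iterable, Tuple


def profile_loop(batches: Iterable, step_fn: Callable[[object], None]) -> Tuple[float, float]:
    """Return (seconds waiting for data, seconds inside step_fn) for one pass.

    `step_fn` is a hypothetical stand-in for forward + backward + optimizer.step.
    On a GPU it should call torch.cuda.synchronize() before returning, otherwise
    asynchronous kernels make compute look cheaper than it is.
    """
    data_time = 0.0
    compute_time = 0.0
    iterator = iter(batches)
    while True:
        t0 = time.perf_counter()
        try:
            batch = next(iterator)  # time spent producing the next batch
        except StopIteration:
            break
        t1 = time.perf_counter()
        step_fn(batch)  # time spent training on that batch
        t2 = time.perf_counter()
        data_time += t1 - t0
        compute_time += t2 - t1
    return data_time, compute_time


def data_wait_fraction(data_time: float, compute_time: float) -> float:
    """Share of loop time spent waiting on data (0.0 to 1.0)."""
    total = data_time + compute_time
    return data_time / total if total > 0 else 0.0


def slow_batches(n: int):
    """Hypothetical loader that takes about 10 ms to produce each batch."""
    for i in range(n):
        time.sleep(0.01)
        yield i


d, c = profile_loop(slow_batches(5), step_fn=lambda batch: None)
print(f"data wait: {data_wait_fraction(d, c):.0%} of loop time")
```

A high data-wait fraction suggests fixing the input pipeline before adding GPUs; a low one means the GPU is the busy resource and DDP is worth trying.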

### How Snowflake DDP Works

```
┌────────────────────────────────────────────────────────────┐
│                     PyTorchDistributor                     │
│                                                            │
│   ┌────────────────────────────────────────────────────┐   │
│   │               Your Training Function               │   │
│   │  • Uses standard PyTorch DDP APIs                  │   │
│   │  • Gets rank/world_size from get_context()         │   │
│   │  • Wraps model with DistributedDataParallel        │   │
│   └────────────────────────────────────────────────────┘   │
│                             ↓                              │
│   ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐   │
│   │  GPU 0   │  │  GPU 1   │  │  GPU 2   │  │  GPU 3   │   │
│   │  Rank 0  │  │  Rank 1  │  │  Rank 2  │  │  Rank 3  │   │
│   │  Model   │  │  Model   │  │  Model   │  │  Model   │   │
│   │  Copy    │  │  Copy    │  │  Copy    │  │  Copy    │   │
│   └────┬─────┘  └────┬─────┘  └────┬─────┘  └────┬─────┘   │
│        │             │             │             │         │
│        └─────────────┴──────┬──────┴─────────────┘         │
│                             ↓                              │
│                     Gradient AllReduce                     │
│                  (Handled automatically)                   │
└────────────────────────────────────────────────────────────┘
```
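The AllReduce step in the diagram averages gradients across replicas. The plain-Python sketch below (no PyTorch needed) illustrates why that matches training on one large batch: for a one-parameter model with mean-squared-error loss, averaging the gradients computed on equal-sized shards gives the same value as the gradient over all rows. The data, the weight `w`, and the contiguous shard layout are made up for illustration; real DDP applies the same averaging to every parameter's gradient tensor using a collective all-reduce (for example via NCCL).

```python
from typing import List


def mse_grad(w: float, xs: List[float], ys: List[float]) -> float:
    """Gradient of mean((w*x - y)^2) with respect to w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def simulated_allreduce_mean(grads: List[float]) -> float:
    """What every rank holds after DDP's all-reduce: the mean of per-rank gradients."""
    return sum(grads) / len(grads)


# Made-up data and model weight
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.2, 13.8, 16.1]
w = 0.5
world_size = 4

# Equal-sized contiguous shards: rank r gets rows [r*k, (r+1)*k)
k = len(xs) // world_size
shard_grads = [
    mse_grad(w, xs[r * k:(r + 1) * k], ys[r * k:(r + 1) * k]) for r in range(world_size)
]

full_batch_grad = mse_grad(w, xs, ys)
ddp_grad = simulated_allreduce_mean(shard_grads)
print(f"full batch: {full_batch_grad:.4f}, averaged shards: {ddp_grad:.4f}")
```

The equality relies on equal shard sizes; with unequal shards a plain average over-weights rows from the smaller shards, which is one reason balanced sharding matters.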

### Key Components

| Component | Purpose |
|-----------|---------|
| `ShardedDataConnector` | Automatically partitions data across workers |
| `PyTorchDistributor` | Manages distributed training orchestration |
| `PyTorchScalingConfig` | Configures nodes, workers, and resources |
| `get_context()` | Provides rank, local_rank, dataset_map inside training function |
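Inside the training function, `get_context()` reports two indices per worker: a global `rank` (unique across all nodes) and a `local_rank` (unique within one node, used to pick the GPU as `cuda:{local_rank}`). The sketch below enumerates workers using the common torchrun-style convention `rank = node_rank * workers_per_node + local_rank`. That numbering is an assumption for illustration only; on Snowflake, rely on the values `get_context()` actually returns. `WorkerId` and `enumerate_workers` are hypothetical names.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class WorkerId:
    node_rank: int   # which node the worker runs on
    local_rank: int  # index of the worker (and its GPU) on that node
    rank: int        # global index across all nodes, 0..world_size-1


def enumerate_workers(num_nodes: int, workers_per_node: int) -> List[WorkerId]:
    """Assumed torchrun-style numbering: rank = node_rank * workers_per_node + local_rank."""
    return [
        WorkerId(node, local, node * workers_per_node + local)
        for node in range(num_nodes)
        for local in range(workers_per_node)
    ]


# The configuration used later in this notebook: 2 nodes x 1 worker each
workers = enumerate_workers(num_nodes=2, workers_per_node=1)
print(f"world_size = {len(workers)}")
for worker in workers:
    print(worker)
```

Whatever the numbering, exactly one worker has `rank == 0`, which is why rank-0-only work (logging, saving) runs once per job.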

---

**References:**
- [Snowflake PyTorchDistributor Documentation](https://docs.snowflake.com/en/developer-guide/snowpark-ml/reference/latest/container-runtime/distributors.pytorch_distributor)


In [None]:
-- Run this SQL to increase your pool capacity to 3 nodes
ALTER COMPUTE POOL WAFER_TRAINING_POOL SET MAX_NODES = 3;

In [None]:
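from snowflake.snowpark.context import get_active_session

session = get_active_session()  # `session` is otherwise first defined in a later cell
session.sql("DESCRIBE COMPUTE POOL WAFER_TRAINING_POOL").show()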

In [None]:
# ============================================================================
# IMPORTS
# ============================================================================

import os
import numpy as np
import pandas as pd
from datetime import datetime

# PyTorch imports
import torch
import torch.nn as nn
import torch.optim as optim

# Snowpark imports
from snowflake.snowpark.context import get_active_session

# Snowflake ML Dataset imports
from snowflake.ml import dataset

# Snowflake ML Data imports
from snowflake.ml.data import DataConnector
from snowflake.ml.data.sharded_data_connector import ShardedDataConnector

# Snowflake Distributed Training imports
from snowflake.ml.modeling.distributors.pytorch import (
    PyTorchDistributor,
    PyTorchScalingConfig,
    WorkerResourceConfig
)

print("✅ Imports complete")
print(f"   PyTorch: {torch.__version__}")
print(f"   CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"   GPU count: {torch.cuda.device_count()}")


In [None]:
# ============================================================================
# SNOWFLAKE SESSION
# ============================================================================

session = get_active_session()

session.sql("USE DATABASE WAFER_YIELD_DEMO").collect()
session.sql("USE SCHEMA RAW_DATA").collect()

print("✅ Session active")
print(f"   Database: {session.get_current_database()}")
print(f"   Schema: {session.get_current_schema()}")


---

## 📘 Section 1: Load Data with ShardedDataConnector

The `ShardedDataConnector` partitions data across distributed workers so that each worker receives a unique shard of the data.

Why use `ShardedDataConnector`:

- Each GPU worker receives a unique subset of your XX rows, so no two workers train on the same data
- Each worker loads only its own shard, which keeps per-worker memory use below that of loading the full dataset on every worker
- DDP assumes each worker trains on different data; without sharding, every replica would see identical batches and the extra GPUs would add no new information
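To make the idea concrete, the sketch below splits row indices into one contiguous, disjoint shard per worker, with shard sizes differing by at most one row. It illustrates the property sharding provides (every row is seen by exactly one worker per epoch); it is not a description of how `ShardedDataConnector` actually assigns rows, which the library handles internally. `contiguous_shards` is a hypothetical helper.

```python
from typing import List


def contiguous_shards(num_rows: int, world_size: int) -> List[range]:
    """Split row indices 0..num_rows-1 into `world_size` contiguous, disjoint shards.

    Shard sizes differ by at most one row. Illustrative only; not necessarily
    how ShardedDataConnector assigns rows.
    """
    base, extra = divmod(num_rows, world_size)
    shards = []
    start = 0
    for rank in range(world_size):
        size = base + (1 if rank < extra else 0)  # first `extra` ranks get one more row
        shards.append(range(start, start + size))
        start += size
    return shards


shards = contiguous_shards(num_rows=10, world_size=3)
print([list(s) for s in shards])  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```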


In [None]:
# ============================================================================
# LOAD ML DATASET
# ============================================================================

# Fully qualified dataset name (created in Notebook 01)
DATASET_NAME = "WAFER_YIELD_DEMO.RAW_DATA.WAFER_YIELD_TRAINING_DATASET"
DATASET_VERSION = "v1"

# Load ML Dataset
print(f"📦 Loading ML Dataset: {DATASET_NAME}")
wafer_dataset = dataset.load_dataset(session, DATASET_NAME, DATASET_VERSION)

# Create a ShardedDataConnector from the dataset. Inside the training function,
# each worker calls get_shard() on it to receive its own partition of the rows.
sharded_data_connector = ShardedDataConnector.from_dataset(wafer_dataset)

# Get the Snowpark DataFrame to inspect columns
training_df = wafer_dataset.read.to_snowpark_dataframe()

# Define feature and label columns
EXCLUDE_COLS = ['WAFER_ID', 'YIELD_GOOD', 'YIELD_SCORE', 'DOMINANT_DEFECT_TYPE']
exclude_upper = {c.upper() for c in EXCLUDE_COLS}

# Get column names from the dataframe (case-insensitive exclusion)
all_cols = [f.name for f in training_df.schema.fields]
input_cols = [c for c in all_cols if c.upper() not in exclude_upper]
label_col = 'YIELD_GOOD'

# Scale the runtime cluster to two nodes for multi-node training
from snowflake.ml.runtime_cluster import scale_cluster, get_nodes

print("Scaling cluster to 2 nodes for 2-GPU training...")
scale_cluster(2)  # This may take a few minutes

nodes = get_nodes()
print(f"Active nodes: {len(nodes)}")
print(f"Node details: {nodes}")

In [None]:
print("ML Dataset loaded:", DATASET_NAME)
print("Total columns:", len(all_cols))
print("Feature columns:", len(input_cols))
print("Label column:", label_col)
print("Features (first 5):", input_cols[:5])

---

## 📘 Section 2: Define Model Architecture

Define the DNN model that will be trained with DDP. The model itself is standard PyTorch; DDP wrapping happens inside the training function.
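As a quick sanity check on the architecture, the trainable-parameter count can be worked out by hand: each `nn.Linear(n_in, n_out)` has `n_in * n_out` weights plus `n_out` biases, each `BatchNorm1d(c)` has `2 * c` learnable parameters (its running mean and variance are buffers, not parameters), and ReLU, Dropout, and Sigmoid have none. The hypothetical helper below applies that arithmetic to this notebook's default layer sizes; its result should match the `Parameters:` count printed when the trained model is reloaded from the stage.

```python
def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out  # weight matrix + bias vector


def batchnorm_params(num_features: int) -> int:
    return 2 * num_features  # learnable scale (weight) and shift (bias)


def wafer_dnn_param_count(input_size: int, hidden_size: int = 128, output_size: int = 1) -> int:
    """Trainable parameters of Linear -> BN -> Linear -> BN -> Linear (activations/dropout add none)."""
    h2 = hidden_size // 2
    return (
        linear_params(input_size, hidden_size)
        + batchnorm_params(hidden_size)
        + linear_params(hidden_size, h2)
        + batchnorm_params(h2)
        + linear_params(h2, output_size)
    )


# Hypothetical feature count of 10
print(wafer_dnn_param_count(10))  # 10113
```

Only the first layer depends on the number of features, so the total grows by 128 per added input column.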


In [None]:
# ============================================================================
# DEFINE MODEL ARCHITECTURE
# ============================================================================

class WaferYieldDNN(nn.Module):
    """
    Deep Neural Network for wafer yield classification.

    Architecture:
        Input → 128 → ReLU → BN → Dropout → 64 → ReLU → BN → Dropout → 1 → Sigmoid
    """

    def __init__(self, input_size, hidden_size=128, output_size=1, dropout_p=0.3):
        super().__init__()

        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Dropout(dropout_p * 0.67),  # Slightly less dropout in the second block
            nn.Linear(hidden_size // 2, output_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.network(x)

print("Model architecture defined") 
print("Input size:", len(input_cols))       
print("Hidden layers: 128 -> 64") 
print("Output: Binary classification (Sigmoid)")      


---

## 📘 Section 3: Define DDP Training Function

The training function runs on each distributed worker. Key patterns:

1. **Import inside the function**: the function is serialized and run in separate worker processes, so importing inside it keeps its dependencies explicit and available on every worker
2. **`get_context()`**: provides rank, local_rank, dataset_map, and model_dir
3. **`init_process_group`**: initializes DDP communication between workers
4. **`DDP(model)`**: wraps the model so gradients are synchronized on every backward pass
5. **`get_shard()`**: gives each worker its unique data partition
6. **Save on rank 0 only**: after synchronization all replicas hold identical weights, so one save is enough and avoids duplicate writes

Together with the next section, this covers:

- `PyTorchDistributor` setup
- `PyTorchScalingConfig` with `num_nodes` and `num_workers_per_node`
- A training function with `dist.init_process_group()` and the DDP wrapper
- `DataLoader` creation from sharded data
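The training function also converts the model's BatchNorm layers with `nn.SyncBatchNorm.convert_sync_batchnorm`. Plain BatchNorm in each DDP replica normalizes with statistics from that rank's local batch only; SyncBatchNorm instead computes the mean and variance over the global batch across all ranks. The plain-Python sketch below shows the combining arithmetic for a single feature: each rank contributes its (count, mean, variance), and the merged result equals the statistics of all rows together. It mirrors the idea, not PyTorch's CUDA implementation; `local_stats` and `combine_stats` are hypothetical helpers and the values are made up.

```python
import math
from typing import List, Tuple

Stats = Tuple[int, float, float]  # (count, mean, biased variance)


def local_stats(values: List[float]) -> Stats:
    """Statistics of one rank's slice of a single feature."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return n, mean, var


def combine_stats(per_rank: List[Stats]) -> Stats:
    """Merge per-rank statistics into global-batch statistics."""
    total = sum(n for n, _, _ in per_rank)
    mean = sum(n * m for n, m, _ in per_rank) / total
    # Each rank contributes its own spread plus the offset of its mean from the global mean
    var = sum(n * (v + (m - mean) ** 2) for n, m, v in per_rank) / total
    return total, mean, var


rank_batches = [[0.2, 1.4, 0.9], [3.1, 2.7], [0.5, 0.8, 1.1, 4.0]]  # made-up values
merged = combine_stats([local_stats(b) for b in rank_batches])
direct = local_stats([v for b in rank_batches for v in b])
print("merged:", merged)
print("direct:", direct)
```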


In [None]:
# ============================================================================
# DEFINE DDP TRAINING FUNCTION
# ============================================================================

# Store column info for access inside training function
INPUT_COLS = input_cols
LABEL_COL = label_col
INPUT_SIZE = len(input_cols)

def train_ddp_func():
    """DDP training function that runs on each worker."""
    import os
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import DataLoader
    from snowflake.ml.modeling.distributors.pytorch import get_context

    # Get distributed context
    context = get_context()
    rank = context.get_rank()
    local_rank = context.get_local_rank()
    print(f"[Rank {rank}] Starting training...")

    # Initialize process group (NCCL for GPUs, Gloo as a CPU fallback)
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(local_rank)  # bind this process to its own GPU before NCCL init
    dist.init_process_group(backend='nccl' if use_cuda else 'gloo')
    device = torch.device(f"cuda:{local_rank}" if use_cuda else "cpu")
    print(f"[Rank {rank}] Using device: {device}")

    # Define model (same layer order as the notebook's WaferYieldDNN,
    # so the saved state_dict loads into that class later)
    class WaferYieldDNN(nn.Module):
        def __init__(self, input_size, hidden_size=128):
            super().__init__()
            self.network = nn.Sequential(
                nn.Linear(input_size, hidden_size),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_size),
                nn.Dropout(0.3),
                nn.Linear(hidden_size, hidden_size // 2),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_size // 2),
                nn.Dropout(0.2),
                nn.Linear(hidden_size // 2, 1),
                nn.Sigmoid()
            )
        def forward(self, x):
            return self.network(x)

    # Create model, sync BatchNorm statistics across ranks, then wrap with DDP
    model = WaferYieldDNN(input_size=INPUT_SIZE)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(device)
    model = DDP(model, device_ids=[local_rank] if use_cuda else None)

    # Setup optimizer and loss
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Get this worker's data shard; batch_size here is per worker
    dataset_map = context.get_dataset_map()
    torch_dataset = dataset_map['train'].get_shard().to_torch_dataset(batch_size=1024)
    dataloader = DataLoader(torch_dataset, batch_size=None)  # batches are already formed above

    # Training loop
    EPOCHS = 25
    model.train()

    for epoch in range(EPOCHS):
        epoch_loss = 0.0
        num_batches = 0

        for batch_dict in dataloader:

            if len(batch_dict) == 0:
                continue

            # Flatten each column to shape (B,) and stack into a (B, num_features) matrix.
            # reshape(-1) (unlike squeeze()) keeps a 1-row batch one-dimensional.
            features = torch.stack(
                [batch_dict[col].reshape(-1) for col in INPUT_COLS], dim=1
            ).float().to(device)
            labels = batch_dict[LABEL_COL].reshape(-1).float().to(device)

            optimizer.zero_grad()
            outputs = model(features).squeeze(1)  # (B, 1) -> (B,) to match labels
            loss = criterion(outputs, labels)
            loss.backward()  # DDP all-reduces gradients during backward
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        if rank == 0 and (epoch + 1) % 5 == 0:
            avg_loss = epoch_loss / max(num_batches, 1)
            print(f"   Epoch [{epoch+1}/{EPOCHS}], Loss (rank 0 shard): {avg_loss:.4f}")

    # Save model (only rank 0); model.module is the unwrapped model
    if rank == 0:
        model_dir = context.get_model_dir()
        model_path = os.path.join(model_dir, "wafer_yield_ddp_model.pt")
        torch.save(model.module.state_dict(), model_path)
        print(f"\n✅ Model saved to: {model_path}")

    dist.destroy_process_group()
    print(f"[Rank {rank}] Training complete!")

print("✅ Training function defined")

---

## 📘 Section 4: Configure and Launch Distributed Training

Use `PyTorchDistributor` with `PyTorchScalingConfig` to configure the distributed training job.
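Before launching, it helps to work out what the scaling config implies. The world size is `num_nodes × num_workers_per_node`, and because `to_torch_dataset(batch_size=1024)` batches per worker, each optimizer step processes `1024 × world_size` rows in total. DDP also expects every rank to run the same number of forward/backward passes; if shards yield different batch counts, ranks can block waiting on each other (PyTorch's `DistributedDataParallel.join()` context manager exists for uneven inputs). The sketch below computes these numbers. `plan_ddp_run` is a hypothetical helper, the row count is made up, and the near-even split of rows across workers is an assumption, not documented `ShardedDataConnector` behavior.

```python
import math


def plan_ddp_run(num_rows: int, num_nodes: int, workers_per_node: int,
                 per_worker_batch: int, drop_last: bool = False) -> dict:
    """Back-of-the-envelope numbers for a DDP run (assumes a near-even row split)."""
    world_size = num_nodes * workers_per_node
    base, extra = divmod(num_rows, world_size)
    shard_sizes = [base + (1 if r < extra else 0) for r in range(world_size)]
    if drop_last:
        batches = [s // per_worker_batch for s in shard_sizes]
    else:
        batches = [math.ceil(s / per_worker_batch) for s in shard_sizes]
    return {
        "world_size": world_size,
        "global_batch": per_worker_batch * world_size,  # rows per optimizer step
        "shard_sizes": shard_sizes,
        "batches_per_rank": batches,
        "uneven_batches": len(set(batches)) > 1,  # True -> ranks may wait on each other
    }


# Hypothetical row count; substitute the size of your own dataset
plan = plan_ddp_run(num_rows=100_000, num_nodes=2, workers_per_node=1, per_worker_batch=1024)
print(plan)
```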


In [None]:
# ============================================================================
# CONFIGURE PYTORCHDISTRIBUTOR
# ============================================================================

NUM_NODES = 2               # Use both nodes started by scale_cluster(2)
NUM_WORKERS_PER_NODE = 1    # 1 worker per node (1 GPU per node)
NUM_CPUS_PER_WORKER = 4     # CPUs per worker
NUM_GPUS_PER_WORKER = 1     # 1 GPU per worker

scaling_config = PyTorchScalingConfig(
    num_nodes=NUM_NODES,
    num_workers_per_node=NUM_WORKERS_PER_NODE,
    resource_requirements_per_worker=WorkerResourceConfig(
        num_cpus=NUM_CPUS_PER_WORKER,
        num_gpus=NUM_GPUS_PER_WORKER,
    ),
)

pytorch_trainer = PyTorchDistributor(
    train_func=train_ddp_func,
    scaling_config=scaling_config,
)

print("PyTorchDistributor configured")
print(f"   Nodes: {NUM_NODES}")
print(f"   Workers per node: {NUM_WORKERS_PER_NODE}")
print(f"   GPUs per worker: {NUM_GPUS_PER_WORKER}")
print(f"   Total GPUs: {NUM_NODES * NUM_WORKERS_PER_NODE * NUM_GPUS_PER_WORKER}")

In [None]:
# ============================================================================
# RUN DISTRIBUTED TRAINING
# ============================================================================

print("🚀 Starting distributed DDP training...")
print("=" * 60)

# Run the distributed training job.
# Pass the sharded connector via dataset_map; the 'train' key is what
# train_ddp_func reads with context.get_dataset_map()['train'].
response = pytorch_trainer.run(
    dataset_map={'train': sharded_data_connector}
)

print("=" * 60)
print("✅ Distributed training complete!")


---

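## 📘 Section 5: Retrieve Trained Model

For multi-node DDP, the saved model has to end up somewhere the notebook can reach, such as a Snowflake stage. Use `get_model_dir()` on the training response to locate it; Section 6 shows how to choose the stage explicitly with `artifact_stage_location`.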
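If `get_model_dir()` returns a stage location without a leading `@` (Section 6 shows the form `DB_NAME.SCHEMA_NAME.STAGE_NAME/model/{request_id}/`), it needs an `@` before being passed to `session.file.get`, and appending a filename to a path that ends in `/` should not produce `//`. The small helper below handles both. `stage_file_path` is a hypothetical name, `abc123` is a made-up request ID, and the exact string `get_model_dir()` returns is an assumption worth checking against your own run.

```python
def stage_file_path(model_dir: str, filename: str) -> str:
    """Build an '@'-prefixed stage path to `filename` inside `model_dir`.

    Assumes `model_dir` is a stage location such as 'DB.SCHEMA.STAGE/model/<request_id>/',
    with or without a leading '@' and trailing '/'.
    """
    base = model_dir.rstrip("/")
    if not base.startswith("@"):
        base = "@" + base
    return f"{base}/{filename}"


print(stage_file_path("DB.SCHEMA.STAGE/model/abc123/", "wafer_yield_ddp_model.pt"))
# @DB.SCHEMA.STAGE/model/abc123/wafer_yield_ddp_model.pt
```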


In [None]:
# ============================================================================
# RETRIEVE MODEL FROM RESPONSE
# ============================================================================

# Get the model directory from the training response (expected to be a stage path)
model_dir = response.get_model_dir()
print(f"📁 Model stage location: {model_dir}")
print("✅ Model directory retrieved and ready for registration")


---

## 📘 Section 6: Multi-Node Model Persistence with Stages

For multi-node training, specify an `artifact_stage_location` to persist the model to a Snowflake stage:

```python
response = pytorch_trainer.run(
    dataset_map={'train': sharded_data_connector},
    artifact_stage_location="DB_NAME.SCHEMA_NAME.STAGE_NAME"
)

# Model saved at: DB_NAME.SCHEMA_NAME.STAGE_NAME/model/{request_id}/
stage_location = response.get_model_dir()
```

This ensures the model is accessible across nodes and persisted beyond the training session.


In [None]:
# ============================================================================
# REGISTER MODEL TO SNOWFLAKE MODEL REGISTRY
# ============================================================================

import glob
import shutil
import tempfile

import pandas as pd
from snowflake.ml.registry import Registry

# Build the stage path of the saved weights from the training response
model_dir = response.get_model_dir().rstrip("/")
stage_model_path = f"{model_dir}/wafer_yield_ddp_model.pt"
if not stage_model_path.startswith("@"):
    stage_model_path = "@" + stage_model_path  # session.file.get expects a stage path

# Recreate the model architecture (same layer order as in the training function)
trained_model = WaferYieldDNN(input_size=len(input_cols))

# Download temporarily just to load the state dict
local_temp_dir = tempfile.mkdtemp()
try:
    session.file.get(stage_model_path, local_temp_dir)

    downloaded_files = glob.glob(os.path.join(local_temp_dir, "*wafer_yield_ddp_model.pt*"))
    if not downloaded_files:
        downloaded_files = glob.glob(os.path.join(local_temp_dir, "*.pt"))
    if not downloaded_files:
        raise FileNotFoundError(f"No model file downloaded from {stage_model_path}")

    # Weights were saved from a GPU worker; map_location lets them load on a CPU-only notebook
    state_dict = torch.load(downloaded_files[0], map_location="cpu")
    trained_model.load_state_dict(state_dict)
    trained_model.eval()
finally:
    # Clean up temp directory
    shutil.rmtree(local_temp_dir, ignore_errors=True)

print(f"✅ Model loaded from stage: {stage_model_path}")
print(f"   Parameters: {sum(p.numel() for p in trained_model.parameters()):,}")

# Create registry
registry = Registry(session=session)

# Create sample input for signature inference
sample_input = pd.DataFrame({col: [0.0] for col in input_cols}).astype('float32')

# Register the model
mv = registry.log_model(
    model=trained_model,
    model_name="WAFER_YIELD_DDP_MODEL",
    version_name=f"v_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    sample_input_data=sample_input,
    options={
        "embed_local_ml_library": True
    }
)

print("✅ Model registered to Snowflake Model Registry")
print(f"   Name: {mv.model_name}")
print(f"   Version: {mv.version_name}")


---

## 📘 Summary

### What We Covered

| Topic | Snowflake API |
|-------|---------------|
| **Data Loading** | `ShardedDataConnector`: partitions data across workers |
| **Training Function** | Standard PyTorch DDP with `get_context()` for rank/device info |
| **Orchestration** | `PyTorchDistributor`: manages distributed job execution |
| **Scaling** | `PyTorchScalingConfig`: configures nodes, workers, GPUs |
| **Model Persistence** | `artifact_stage_location`: syncs models to a Snowflake stage |
| **Model Registry** | `registry.log_model()`: registers the model for deployment |

### Key Takeaways

1. **No cluster management**: Snowflake handles the orchestration
2. **Standard PyTorch code**: your DDP logic is portable
3. **Automatic data sharding**: each worker gets a unique data partition
4. **Gradient sync handled**: the DDP wrapper synchronizes gradients automatically
5. **Integrated persistence**: models are saved to Snowflake stages

### When to Use DDP

| ✅ Consider DDP | ❌ DDP May Not Help |
|-----------------|---------------------|
| Training epochs take too long | Training already fast |
| Multiple GPUs available | Only 1 GPU available |
| GPU utilization is high | GPU utilization is low (data bottleneck) |
| Model fits on single GPU | Model too large for single GPU (use FSDP) |
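The last row of the table comes down to memory. DDP keeps a full copy of the weights, gradients, and optimizer state on every GPU, so a common rough estimate for fp32 training with Adam is 16 bytes per parameter (4 for weights, 4 for gradients, 8 for Adam's two moment buffers), before counting activations and framework overhead. The hypothetical helper below is a back-of-the-envelope sketch of that rule of thumb, not a measurement; the parameter counts are illustrative. When the estimate approaches a single GPU's memory, FSDP, which shards these states across GPUs, becomes the better fit.

```python
def adam_fp32_state_bytes(num_params: int) -> int:
    """Rough per-GPU bytes for fp32 weights + gradients + Adam moments (no activations)."""
    bytes_per_param = 4 + 4 + 8  # weights, gradients, exp_avg + exp_avg_sq
    return bytes_per_param * num_params


# Illustrative sizes: this notebook's DNN with a hypothetical 10 features, and a 7B-parameter model
for name, n in [("wafer DNN (10 features)", 10_113), ("7B-parameter model", 7_000_000_000)]:
    print(f"{name}: {adam_fp32_state_bytes(n) / 1e6:,.1f} MB")
```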

> **Tip:** Profile your workload first. Use `nvidia-smi` or Snowflake's resource monitoring to identify bottlenecks before adding distributed training complexity.

---

**Next Steps:**
- Register model to Model Registry (Notebook 02)
- Deploy model via SPCS (Notebook 03)
- Set up ML Jobs and CI/CD (Notebook 04)


In [None]:
# ============================================================================
# END OF NOTEBOOK 2c
# ============================================================================

print("=" * 60)
print("✅ Notebook 2c Complete: Distributed Data Parallel (DDP)")
print("=" * 60)
print()
print("📊 Key APIs Used:")
print("   • ShardedDataConnector.from_dataset(): data loading and sharding")
print("   • PyTorchDistributor: distributed training orchestration")
print("   • PyTorchScalingConfig: resource configuration")
print("   • get_context(): rank, device, and dataset access in workers")
print()
print("🚀 DDP can scale training throughput across GPUs for large workloads")
print()
print("➡️  For even larger models, explore FSDP (Fully Sharded Data Parallel)")
