## 2. Single GPU Training with PyTorch Lightning

### 2.1 Overview

We start by training a Stable Diffusion U-Net (`UNet2DConditionModel`) on a preprocessed dataset of image and caption latents.

This diagram shows the full-scale training architecture.

<img src="https://anyscale-materials.s3.us-west-2.amazonaws.com/stable-diffusion/training_architecture_v3.jpeg" width="700px">

Regardless of scale, the process consists of three main stages:
1. **Loading the preprocessed data**
2. **Training the model**
3. **Storing the model checkpoints**


### 2.2 Create a PyTorch DataLoader

Start by defining the dataset.

It reads `parquet` data generated by the same preprocessing pipeline, which stores precomputed image latents (`image_latents_256`) and caption latents (`caption_latents`).

```python
import numpy as np
import pandas as pd
import s3fs
import torch
from torch.utils.data import Dataset


class ParquetDataset(Dataset):
    """Minimal map-style PyTorch Dataset for latents stored in parquet."""

    def __init__(
        self,
        s3_path: str,
        dtype: np.dtype = np.float16,
    ):
        self.columns = ["image_latents_256", "caption_latents"]
        # Read only the two latent columns from S3 (non-anonymous access).
        loaded_df = pd.read_parquet(
            s3_path,
            columns=self.columns,
            filesystem=s3fs.S3FileSystem(anon=False),
        )
        # Restore the latents' original shapes and cast them to `dtype`.
        loaded_df["image_latents_256"] = loaded_df["image_latents_256"].apply(
            lambda x: x.reshape(4, 32, 32).astype(dtype)
        )
        loaded_df["caption_latents"] = loaded_df["caption_latents"].apply(
            lambda x: x.reshape(77, 1024).astype(dtype)
        )
        self.df = loaded_df

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        row = self.df.iloc[idx]
        # Convert to tensors so the default collate function stacks them automatically.
        return {col: torch.as_tensor(row[col]) for col in self.columns}
```

<div class="alert alert-block alert-warning">

**Note:** This `Dataset` implementation is deliberately simple: it loads the entire parquet file into memory, so it is not suitable for large datasets.

</div>
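
To see why loading everything into memory does not scale, consider the per-row footprint. `ParquetDataset.__init__` reshapes each image latent to `(4, 32, 32)` (a 256×256 image downsampled 8× by the Stable Diffusion VAE into 4 latent channels) and each caption latent to `(77, 1024)`, then casts both to `float16`. The sketch below uses random NumPy arrays as stand-ins for one row and counts only the array payload, ignoring pandas overhead; the row count of 10,000 is a hypothetical figure, not the size of the file used here.

```python
import numpy as np

# Random stand-ins for one row's flat arrays (not real latents).
flat_image = np.random.rand(4 * 32 * 32)
flat_caption = np.random.rand(77 * 1024)

# Same reshapes and float16 cast as ParquetDataset.__init__.
image = flat_image.reshape(4, 32, 32).astype(np.float16)
caption = flat_caption.reshape(77, 1024).astype(np.float16)

bytes_per_row = image.nbytes + caption.nbytes
print(image.shape, caption.shape, bytes_per_row)  # (4, 32, 32) (77, 1024) 165888

# Hypothetical row count, only to show how memory grows with dataset size.
num_rows = 10_000
print(f"{num_rows * bytes_per_row / 1e9:.2f} GB")  # 1.66 GB
```

The caption latents dominate: 77 × 1024 values versus 4 × 32 × 32 for the image latents.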

Next, instantiate the custom map-style dataset from a single parquet file in the Anyscale artifact storage.

```python
import os

# Artifact storage location provided by the Anyscale environment.
artifact_storage_path = os.environ["ANYSCALE_ARTIFACT_STORAGE"]
s3_path = f"{artifact_storage_path}/stable-diffusion/256/6_000108_000000.parquet"
dataset = ParquetDataset(s3_path)
```

Next, wrap the dataset in a PyTorch `DataLoader`, which shuffles and batches the samples for training.

```python
from torch.utils.data import DataLoader

data_loader = DataLoader(
    dataset,
    batch_size=8,  # Adjust based on memory constraints
    shuffle=True,
    num_workers=2,  # Adjust based on the number of available CPU cores
)
```

Inspect the first batch to verify the tensor shapes and dtypes.

```python
for batch in data_loader:
    print(batch["image_latents_256"].shape, batch["caption_latents"].shape)
    print(batch["image_latents_256"].dtype, batch["caption_latents"].dtype)
    break  # Only inspect the first batch
```
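
PyTorch's default collate function turns a list of per-sample dicts into a single dict of batched tensors by stacking each key along a new leading dimension. The NumPy sketch below mimics that behavior for this dataset's two keys. `collate_dicts` is a hypothetical helper, not a PyTorch API. With `batch_size=8`, you should expect image latents of shape `(8, 4, 32, 32)` and caption latents of shape `(8, 77, 1024)`, both `float16`.

```python
import numpy as np


def collate_dicts(samples):
    """Stack each key across samples, mimicking default collate for dict samples."""
    return {key: np.stack([s[key] for s in samples]) for key in samples[0]}


# Eight fake samples shaped like ParquetDataset.__getitem__ outputs.
samples = [
    {
        "image_latents_256": np.zeros((4, 32, 32), dtype=np.float16),
        "caption_latents": np.zeros((77, 1024), dtype=np.float16),
    }
    for _ in range(8)
]
batch = collate_dicts(samples)
print(batch["image_latents_256"].shape, batch["caption_latents"].shape)
# (8, 4, 32, 32) (8, 77, 1024)
```

The last batch of an epoch can be smaller than 8 unless `drop_last=True` is passed to the `DataLoader`.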

### 2.3 Define a Stable Diffusion model

The model is a standard `LightningModule`: it does not refer to Ray or Ray Train, which makes it easier to migrate the workload later.
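
The `small_unet_model_config` defined in the next cell narrows the U-Net's channel widths to `block_out_channels = [160, 320, 640, 640]`. Two quick sanity checks follow from that config: each width must be divisible by `norm_num_groups = 32` for GroupNorm, and, because diffusers' `UNet2DConditionModel` downsamples in every down block except the last, the 32×32 latents shrink to 4×4 at the bottom of the U-Net. The pure-Python sketch below performs both checks; `down_path_sizes` is a hypothetical helper.

```python
block_out_channels = [160, 320, 640, 640]  # from small_unet_model_config
norm_num_groups = 32
latent_size = 32  # 256 px images / 8 (VAE downsampling factor)


def down_path_sizes(size: int, num_blocks: int) -> list[int]:
    """Spatial size entering each down block; every block but the last halves it."""
    sizes = []
    for i in range(num_blocks):
        sizes.append(size)
        if i < num_blocks - 1:
            size //= 2
    return sizes


# GroupNorm splits channels into equal groups, so widths must divide evenly.
assert all(c % norm_num_groups == 0 for c in block_out_channels)
print(down_path_sizes(latent_size, len(block_out_channels)))  # [32, 16, 8, 4]
```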

```python
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler, UNet2DConditionModel
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from transformers import get_linear_schedule_with_warmup

small_unet_model_config = {
    "_class_name": "UNet2DConditionModel",
    "_diffusers_version": "0.2.2",
    "act_fn": "silu",
    "attention_head_dim": 8,
    "block_out_channels": [160, 320, 640, 640],
    "center_input_sample": False,
    "cross_attention_dim": 1024,
    "down_block_types": [
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ],
    "downsample_padding": 1,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "in_channels": 4,
    "layers_per_block": 2,
    "mid_block_scale_factor": 1,
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "out_channels": 4,
    "sample_size": 64,
    "up_block_types": [
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ],
}


class StableDiffusion(pl.LightningModule):
    def __init__(
        self,
        lr: float,
        resolution: int,
        weight_decay: float,
        num_warmup_steps: int,
        model_name: str,
    ) -> None:
        # Initialize the LightningModule before assigning any attributes.
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.resolution = resolution
        self.weight_decay = weight_decay
        self.num_warmup_steps = num_warmup_steps
        # Initialize the U-Net from the reduced config above. To train the
        # full-size U-Net instead, load its config from `model_name`.
        # model_config = PretrainedConfig.get_config_dict(model_name, subfolder="unet")[0]
        model_config = small_unet_model_config
        self.unet = UNet2DConditionModel(**model_config)
        # Define the training noise scheduler from the pretrained model's config.
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            model_name, subfolder="scheduler"
        )
        # Set up the loss function: MSE between predicted and true noise.
        self.loss_fn = F.mse_loss
        self.current_training_steps = 0

    def on_fit_start(self) -> None:
        """Move the cumprod tensor to the model's device once, avoiding a copy on each step."""
        self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(
            self.device
        )

    def forward(
        self, batch: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the model."""
        # Extract inputs.
        latents = batch["image_latents_256"]
        conditioning = batch["caption_latents"]
        # Sample the diffusion timesteps.
        timesteps = self._sample_timesteps(latents)
        # Add noise to the inputs (forward diffusion).
        noise = torch.randn_like(latents)
        noised_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        # Predict the added noise, conditioned on the caption latents.
        outputs = self.unet(noised_latents, timesteps, conditioning)["sample"]
        return outputs, noise

    def training_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Training step of the model."""
        outputs, targets = self.forward(batch)
        loss = self.loss_fn(outputs, targets)
        self.log(
            "train/loss_mse", loss.item(), prog_bar=False, on_step=True, sync_dist=False
        )
        self.current_training_steps += 1
        return loss

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Configure the optimizer and learning rate scheduler."""
        optimizer = torch.optim.AdamW(
            self.trainer.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        # Use a very large number of training steps so the LR stays (almost)
        # constant after warm-up.
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=100_000_000_000,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # Step the scheduler after every optimizer step
                "frequency": 1,
            },
        }

    def _sample_timesteps(self, latents: torch.Tensor) -> torch.Tensor:
        """Sample one timestep per example, uniformly from [0, num_train_timesteps)."""
        return torch.randint(
            0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device
        )
```
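
`forward` implements the standard DDPM training objective: sample a timestep `t` per example, corrupt the clean latents as `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`, and train the U-Net to predict `noise` with an MSE loss. (This is also why `on_fit_start` moves `alphas_cumprod` to the GPU: `add_noise` indexes it with on-device timesteps.) The NumPy sketch below reproduces that computation under an assumption: a `scaled_linear` beta schedule with `beta_start=0.00085`, `beta_end=0.012`, and 1000 training timesteps, the values commonly found in Stable Diffusion 2 scheduler configs. Check the `scheduler` subfolder of the checkpoint you actually load. The all-zeros "prediction" is a stand-in for the U-Net output.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed scaled_linear schedule: betas linearly spaced in sqrt space, then squared.
num_train_timesteps = 1000
betas = np.linspace(0.00085**0.5, 0.012**0.5, num_train_timesteps) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)


def add_noise(latents, noise, timesteps):
    """x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, per example."""
    a = alphas_cumprod[timesteps].reshape(-1, 1, 1, 1)  # broadcast over (C, H, W)
    return np.sqrt(a) * latents + np.sqrt(1.0 - a) * noise


latents = rng.standard_normal((8, 4, 32, 32))  # stand-in batch of clean latents
# One timestep per example in [0, 1000), like _sample_timesteps.
timesteps = rng.integers(0, num_train_timesteps, size=latents.shape[0])
noise = rng.standard_normal(latents.shape)
noised = add_noise(latents, noise, timesteps)

prediction = np.zeros_like(noise)  # stand-in for the U-Net output
loss = np.mean((prediction - noise) ** 2)  # same objective as F.mse_loss
print(noised.shape, loss)
```

Predicting all zeros yields a loss close to 1.0 (the variance of the standard normal noise), which is a useful reference point when reading the logged `train/loss_mse` values.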

### 2.4 Define a PyTorch Lightning training loop

The training loop below is specific to PyTorch Lightning.

It performs the following steps:
1. **Model initialization:**
   - Instantiate the diffusion model.
2. **Trainer setup:**
   - Instantiate the Lightning `Trainer` for GPU training with bf16 mixed precision, writing logs and checkpoints to `storage_path`.
3. **Training execution:**
   - Run the trainer using the `fit` method.

```python
def lightning_training_loop(
    train_loader: torch.utils.data.DataLoader,
    storage_path: str,
    model_name: str = "stabilityai/stable-diffusion-2-base",
    resolution: int = 256,
    lr: float = 1e-4,
    max_epochs: int = 1,
    num_warmup_steps: int = 10_000,
    weight_decay: float = 1e-2,
) -> None:
    # 1. Initialize the model.
    model = StableDiffusion(
        model_name=model_name,
        resolution=resolution,
        lr=lr,
        num_warmup_steps=num_warmup_steps,
        weight_decay=weight_decay,
    )

    # 2. Initialize the Lightning Trainer.
    trainer = pl.Trainer(
        accelerator="gpu",
        devices="auto",  # All visible GPUs; a single device on a one-GPU node
        precision="bf16-mixed",  # Requires a GPU with bfloat16 support
        max_epochs=max_epochs,
        default_root_dir=storage_path,  # Logs and checkpoints are written here
        log_every_n_steps=8,
    )

    # 3. Run the trainer.
    trainer.fit(model=model, train_dataloaders=train_loader)
```
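
`configure_optimizers` pairs AdamW with `get_linear_schedule_with_warmup` and a huge `num_training_steps`, so the learning rate ramps linearly from 0 to `lr` over `num_warmup_steps` and then stays essentially flat. The pure-Python sketch below reimplements that multiplier using the loop's defaults (`lr=1e-4`, `num_warmup_steps=10_000`). `lr_multiplier` is a hypothetical name; its formula follows the linear warm-up/decay rule that the transformers schedule applies.

```python
def lr_multiplier(step: int, num_warmup_steps: int = 10_000,
                  num_training_steps: int = 100_000_000_000) -> float:
    """Linear warm-up to 1.0, then linear decay toward 0 at num_training_steps."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


base_lr = 1e-4
for step in [0, 2_500, 5_000, 10_000, 1_000_000]:
    print(step, base_lr * lr_multiplier(step))
```

One consequence of these defaults: if an epoch over the single parquet file has fewer than 10,000 optimizer steps, a one-epoch run never leaves the warm-up phase and never trains at the full learning rate.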


Here is how to run the Lightning training loop on a single GPU.

```python
storage_path = "/mnt/local_storage/lightning/stable-diffusion-pretraining/"
lightning_training_loop(train_loader=data_loader, storage_path=storage_path)
```

Let's inspect the storage path to see what files were created.

```python
# IPython shell escape: list the files written under storage_path.
!tree {storage_path}
```
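
With `default_root_dir=storage_path` and the default logger, Lightning typically saves checkpoints as `.ckpt` files under `lightning_logs/version_<n>/checkpoints/` inside that directory. To pick up the newest checkpoint programmatically, for example to pass it as `ckpt_path` to `trainer.fit` when resuming, a small helper can search the storage path. `latest_checkpoint` below is a hypothetical sketch that assumes the checkpoints sit somewhere beneath `storage_path`.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(storage_path: str) -> Optional[Path]:
    """Return the most recently modified .ckpt file under storage_path, or None."""
    checkpoints = list(Path(storage_path).rglob("*.ckpt"))
    if not checkpoints:
        return None
    return max(checkpoints, key=lambda p: p.stat().st_mtime)


# Usage with the storage path from the training cell:
# ckpt = latest_checkpoint("/mnt/local_storage/lightning/stable-diffusion-pretraining/")
```

Selecting by modification time rather than by parsing `epoch=...-step=...` file names keeps the helper independent of the checkpoint naming scheme.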