In [1]:
import os

import lightning.pytorch as pl
import numpy as np
import pandas as pd
import s3fs
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from diffusers import DDPMScheduler, UNet2DConditionModel
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from transformers import (
    get_constant_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

import ray.train
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
)
from ray.train.torch import TorchTrainer, get_device

In [2]:
class ParquetDataset(Dataset):
    """Minimal PyTorch Dataset over latent arrays stored in parquet.

    At construction time the latent columns are read from ``s3_path``. Every stored
    value is reshaped to a fixed per-column shape and cast to ``dtype``, so
    ``__getitem__`` only has to wrap the ready-made arrays as tensors.

    Args:
        s3_path: Parquet location, read through ``s3fs.S3FileSystem(anon=False)``.
        dtype: NumPy dtype the latents are cast to (default ``np.float16``).

    Items are dicts mapping ``"image_latents_256"`` to a ``(4, 32, 32)`` tensor and
    ``"caption_latents"`` to a ``(77, 1024)`` tensor.
    """

    # Column name -> shape each stored value is reshaped to. This is the single source
    # of truth for which columns are read from parquet and returned per item.
    _COLUMN_SHAPES = {
        "image_latents_256": (4, 32, 32),
        "caption_latents": (77, 1024),
    }

    def __init__(
        self,
        s3_path: str,
        dtype: np.dtype = np.float16,
    ):
        self.columns = list(self._COLUMN_SHAPES)
        loaded_df = pd.read_parquet(
            s3_path,
            columns=self.columns,
            filesystem=s3fs.S3FileSystem(anon=False),
        )
        for col, shape in self._COLUMN_SHAPES.items():
            # `shape=shape` binds the current loop value, so the lambda never sees a
            # later iteration's shape even if it were evaluated lazily.
            loaded_df[col] = loaded_df[col].apply(
                lambda x, shape=shape: x.reshape(shape).astype(dtype)
            )
        self.df = loaded_df

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return row ``idx`` as a dict of tensors keyed by column name."""
        # A column-wise positional lookup avoids materialising a whole row Series per item.
        return {col: torch.as_tensor(self.df[col].iat[idx]) for col in self.columns}

In [None]:
# Root of the artifact storage location, read from the ANYSCALE_ARTIFACT_STORAGE env var.
artifact_storage_path = os.environ["ANYSCALE_ARTIFACT_STORAGE"]
# One parquet shard under stable-diffusion/256/.
s3_path = "/".join(
    [artifact_storage_path, "stable-diffusion", "256", "6_000108_000000.parquet"]
)
dataset = ParquetDataset(s3_path=s3_path)

In [None]:
# Shuffled loader over `dataset`, used for the sanity check and the plain Lightning run.
# batch_size is bounded by accelerator memory; num_workers by the available CPU cores.
data_loader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=8,
    num_workers=2,
)

In [None]:
# Peek at a single batch to confirm the tensor shapes and dtypes the model will receive.
for batch in data_loader:
    print(*(batch[key].shape for key in ("image_latents_256", "caption_latents")))
    print(*(batch[key].dtype for key in ("caption_latents", "image_latents_256")))
    break

### Stable Diffusion model

In [None]:
# Hand-written config for a small UNet2DConditionModel. It is passed as keyword
# arguments to the constructor; the "_"-prefixed keys are config metadata.
small_unet_model_config = {
    "_class_name": "UNet2DConditionModel",
    "_diffusers_version": "0.2.2",
    "act_fn": "silu",
    "attention_head_dim": 8,
    # Channel width of each of the four resolution levels.
    "block_out_channels": [160, 320, 640, 640],
    "center_input_sample": False,
    # Must equal the last dimension of the caption latents, which are (77, 1024).
    "cross_attention_dim": 1024,
    # Three cross-attention down blocks followed by a plain one ...
    "down_block_types": ["CrossAttnDownBlock2D"] * 3 + ["DownBlock2D"],
    "downsample_padding": 1,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    # 4 latent channels in and out, matching image latents of shape (4, 32, 32).
    "in_channels": 4,
    "layers_per_block": 2,
    "mid_block_scale_factor": 1,
    "norm_eps": 1e-5,
    "norm_num_groups": 32,
    "out_channels": 4,
    "sample_size": 64,
    # ... mirrored on the way up: one plain up block, then three cross-attention ones.
    "up_block_types": ["UpBlock2D"] + ["CrossAttnUpBlock2D"] * 3,
}

class StableDiffusion(pl.LightningModule):
    """Lightning module that trains a U-Net denoiser on pre-computed latents.

    The U-Net is built from ``small_unet_model_config`` with fresh weights; only the
    DDPM noise scheduler is loaded via ``model_name``. Each training step samples one
    random timestep per example, noises the image latents, and regresses the U-Net
    output onto the added noise with an MSE loss.

    Args:
        lr: AdamW learning rate reached at the end of warm-up and then held constant.
        resolution: Image resolution; stored and saved as a hyperparameter only.
        weight_decay: AdamW weight decay.
        num_warmup_steps: Number of optimizer steps of linear LR warm-up.
        model_name: Model id or path passed to ``DDPMScheduler.from_pretrained``
            (its ``scheduler`` subfolder).
    """

    def __init__(
        self,
        lr: float,
        resolution: int,
        weight_decay: float,
        num_warmup_steps: int,
        model_name: str,
    ) -> None:
        # Initialise LightningModule / nn.Module internals before assigning any
        # attributes (the idiomatic order; required for Parameters and sub-modules).
        super().__init__()
        # Records every __init__ argument in `self.hparams` and in checkpoints, so
        # `load_from_checkpoint` can rebuild the module from the checkpoint alone.
        self.save_hyperparameters()
        self.lr = lr
        self.resolution = resolution
        self.weight_decay = weight_decay
        self.num_warmup_steps = num_warmup_steps
        # U-Net from the local small config (fresh weights, nothing downloaded).
        self.unet = UNet2DConditionModel(**small_unet_model_config)
        # Training-time noise scheduler, configured from `model_name`.
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            model_name, subfolder="scheduler"
        )
        self.loss_fn = F.mse_loss
        # Number of training_step calls made so far in this process.
        self.current_training_steps = 0

    def on_fit_start(self) -> None:
        """Move the scheduler's ``alphas_cumprod`` to this module's device once.

        Doing it up front avoids a host-to-device copy on every step. ``self.device``
        is wherever Lightning placed the module, so this is correct both inside a Ray
        Train worker and in a plain single-node Lightning run.
        """
        self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(
            self.device
        )

    def forward(
        self, batch: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Noise the image latents at random timesteps and run the U-Net on them.

        Args:
            batch: Dict with ``"image_latents_256"`` and ``"caption_latents"``
                tensors (assumes batches produced by ``ParquetDataset``).

        Returns:
            ``(unet_output, noise)``: the U-Net prediction and the noise that was
            actually added, which is the regression target.
        """
        latents = batch["image_latents_256"]
        conditioning = batch["caption_latents"]
        timesteps = self._sample_timesteps(latents)
        # Forward diffusion: mix Gaussian noise into the latents according to timesteps.
        noise = torch.randn_like(latents)
        noised_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        outputs = self.unet(noised_latents, timesteps, conditioning)["sample"]
        return outputs, noise

    def training_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Compute and log the MSE between the U-Net output and the added noise."""
        outputs, targets = self.forward(batch)
        loss = self.loss_fn(outputs, targets)
        # Log the detached tensor rather than `loss.item()`, which would force a
        # device-to-host sync on every step.
        self.log(
            "train/loss_mse",
            loss.detach(),
            prog_bar=False,
            on_step=True,
            sync_dist=False,
        )
        self.current_training_steps += 1
        return loss

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """AdamW with a linear warm-up to ``lr``, then a constant LR, stepped per batch."""
        optimizer = torch.optim.AdamW(
            self.trainer.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=self.num_warmup_steps
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }

    def _sample_timesteps(self, latents: torch.Tensor) -> torch.Tensor:
        """One random integer timestep in ``[0, len(noise_scheduler))`` per batch element."""
        return torch.randint(
            0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device
        )

### Training loop (plain PyTorch Lightning)

In [None]:
def lightning_training_loop(
    train_loader: torch.utils.data.DataLoader,
    storage_path: str,
    model_name: str = "stabilityai/stable-diffusion-2-base",
    resolution: int = 256,
    lr: float = 1e-4,
    max_epochs: int = 1,
    num_warmup_steps: int = 10_000,
    weight_decay: float = 1e-2,
) -> None:
    """Fit ``StableDiffusion`` with a plain (non-Ray) Lightning Trainer.

    Args:
        train_loader: Loader yielding ``{"image_latents_256", "caption_latents"}`` batches.
        storage_path: Trainer ``default_root_dir`` for logs and checkpoints.
        model_name: Forwarded to ``StableDiffusion`` as its ``model_name``.
        resolution: Forwarded to ``StableDiffusion``.
        lr: Forwarded to ``StableDiffusion``.
        max_epochs: Maximum number of passes over ``train_loader``.
        num_warmup_steps: Forwarded to ``StableDiffusion``.
        weight_decay: Forwarded to ``StableDiffusion``.
    """
    module = StableDiffusion(
        lr=lr,
        resolution=resolution,
        weight_decay=weight_decay,
        num_warmup_steps=num_warmup_steps,
        model_name=model_name,
    )
    trainer_kwargs = dict(
        accelerator="gpu",
        devices="auto",
        precision="bf16-mixed",
        max_epochs=max_epochs,
        default_root_dir=storage_path,
        log_every_n_steps=8,
    )
    pl.Trainer(**trainer_kwargs).fit(model=module, train_dataloaders=train_loader)

# Output directory for the plain Lightning run's logs and checkpoints.
storage_path = "/mnt/local_storage/lightning/stable-diffusion-pretraining/"
lightning_training_loop(data_loader, storage_path)

### Distributed data-parallel (DDP) training with Ray Train

In [None]:
def train_loop_per_worker(
    config: dict,  # Update the function signature to comply with Ray Train
):
    """Ray Train per-worker entry point: build the loader, model and Trainer, then fit.

    Ray Train calls this on every worker, passing ``train_loop_config`` as ``config``.

    Required keys: ``batch_size_per_worker``, ``lr``, ``resolution``, ``weight_decay``,
    ``num_warmup_steps``, ``model_name``, ``max_steps`` and ``max_epochs``.

    Optional key: ``num_workers`` (DataLoader worker processes, default 2).

    NOTE(review): ``dataset`` is the notebook-level global, not an argument. It is
    presumably serialized together with this function and shipped to every worker,
    which gets expensive for large data; consider loading the data inside the worker.
    """
    # Lightning swaps in a DistributedSampler under RayDDPStrategy, so each worker
    # iterates over its own shard of `dataset`.
    train_dataloader = DataLoader(
        dataset,
        batch_size=config["batch_size_per_worker"],
        shuffle=True,
        num_workers=config.get("num_workers", 2),
    )

    # Same model initialization as vanilla Lightning.
    model = StableDiffusion(
        lr=config["lr"],
        resolution=config["resolution"],
        weight_decay=config["weight_decay"],
        num_warmup_steps=config["num_warmup_steps"],
        model_name=config["model_name"],
    )

    # Same Trainer setup as vanilla Lightning, plus the Ray Train specific arguments.
    trainer = pl.Trainer(
        max_steps=config["max_steps"],
        max_epochs=config["max_epochs"],
        accelerator="gpu",
        precision="bf16-mixed",
        devices="auto",  # use all GPUs available to this worker
        strategy=RayDDPStrategy(),  # distributed data parallel across Ray workers
        plugins=[RayLightningEnvironment()],  # cluster environment provided by Ray Train
        callbacks=[RayTrainReportCallback()],  # report metrics and checkpoints to Ray
        enable_checkpointing=False,  # checkpoints go through the Ray callback instead
    )
    # Ray's documented preparation step for a Lightning Trainer; it checks the Ray
    # strategy/environment configuration before fitting.
    trainer = ray.train.lightning.prepare_trainer(trainer)

    trainer.fit(model, train_dataloaders=train_dataloader)

In [None]:
# Two Ray Train workers, each requesting a GPU.
scaling_config = ray.train.ScalingConfig(use_gpu=True, num_workers=2)


In [None]:
# Shared storage root and experiment name for the Ray Train run's outputs.
storage_path = "/mnt/cluster_storage/"
experiment_name = "stable-diffusion-pretraining"

# Hyperparameters handed to every worker as `config` in train_loop_per_worker.
train_loop_config = {
    "batch_size_per_worker": 8,
    "prefetch_batches": 1,
    "lr": 0.0001,
    "num_warmup_steps": 10_000,
    "weight_decay": 0.01,
    # Lightning stops at whichever limit is reached first: max_steps or max_epochs.
    "max_steps": 550_000,
    "max_epochs": 1,
    "resolution": 256,
    "model_name": "stabilityai/stable-diffusion-2-base",
}

run_config = ray.train.RunConfig(storage_path=storage_path, name=experiment_name)

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
)

result = trainer.fit()
print(f"Training completed with result: {result}")
result.metrics_dataframe


In [None]:
# Load the checkpoint reported during the Ray Train run (via RayTrainReportCallback)
# onto the CPU for inspection.
ckpt = result.checkpoint
if ckpt is None:
    # Fail with a clear message instead of an AttributeError on `None.as_directory()`.
    raise RuntimeError(
        "The Ray Train run did not report any checkpoint, so there is nothing to load."
    )
with ckpt.as_directory() as ckpt_dir:
    ckpt_path = os.path.join(ckpt_dir, "checkpoint.ckpt")
    loaded_model_ray_train = StableDiffusion.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    loaded_model_ray_train.eval()

loaded_model_ray_train