In [None]:
!pip install -q diffusers transformers accelerate

import os
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
from tqdm import tqdm
from diffusers import AutoencoderKL, UNet2DModel, DDIMScheduler
from accelerate import Accelerator

# ============== 1. Dataset ==============
class MangoDataset(Dataset):
    """Flat image-folder dataset yielding normalized RGB tensors.

    Every file directly inside ``root_dir`` whose name ends (case-insensitively)
    in ``.jpg``, ``.jpeg`` or ``.png`` is one sample; subdirectories are not
    searched.

    Args:
        root_dir: Directory containing the images.
        image_size: Side length in pixels that every image is resized to
            (square). Defaults to 512, the size previously hard-coded.
    """

    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, image_size=512):
        # os.listdir returns entries in arbitrary order; sort so the
        # idx -> file mapping is reproducible across runs/machines.
        self.paths = sorted(
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        )
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            # Map ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ])

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it as a normalized tensor."""
        # Context manager closes the file handle immediately instead of
        # leaving it open until garbage collection.
        with Image.open(self.paths[idx]) as img:
            return self.transform(img.convert("RGB"))

data_path = "/kaggle/input/rganhealthy512x512/HEALTHY"
dataset = MangoDataset(data_path)
# Worker processes decode/resize images in parallel with GPU work instead of
# the main process doing it inline before every step.
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, pin_memory=True)

# ============== 2. Load Models ==============
# Create the Accelerator first so every model is placed on the device it
# manages, instead of a hard-coded "cuda".
accelerator = Accelerator()

# Pretrained VAE, used only (under no_grad) to encode images to latents and
# decode samples; freeze it explicitly so no gradients are ever tracked.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(accelerator.device).eval()
vae.requires_grad_(False)

unet = UNet2DModel(
    sample_size=64,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(128, 256, 512, 512),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D")
)
# clip_sample=False: DDIM's default (True) clips the predicted x0 to [-1, 1]
# during sampling, a pixel-space assumption that does not hold for scaled VAE
# latents. Training (add_noise) is unaffected by this flag.
scheduler = DDIMScheduler(num_train_timesteps=1000, clip_sample=False)

# ============== 3. Optimizer & Accelerator ==============
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)
unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)

# ============== 4. Training ==============
unet.train()
num_epochs = 100
save_every = 10            # epochs between checkpoints / preview images
num_inference_steps = 50   # DDIM steps used when rendering previews
# Read the latent scale once from the VAE config so encode and decode can
# never disagree (the original hard-coded 0.18215 in both places).
latent_scale = vae.config.scaling_factor

for epoch in range(1, num_epochs + 1):
    accelerator.print(f"🔥 Epoch {epoch}/{num_epochs}")
    epoch_loss, num_batches = 0.0, 0
    progress = tqdm(dataloader, desc=f"Epoch {epoch}",
                    disable=not accelerator.is_local_main_process)
    for batch in progress:
        with torch.no_grad():
            latents = vae.encode(batch.to(accelerator.device)).latent_dist.sample() * latent_scale
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.size(0),), device=latents.device).long()
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        noise_pred = unet(noisy_latents, timesteps).sample
        loss = nn.functional.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        accelerator.backward(loss)
        optimizer.step()

        step_loss = loss.item()
        epoch_loss += step_loss
        num_batches += 1
        progress.set_postfix(loss=f"{step_loss:.4f}")

    if num_batches:
        accelerator.print(f"   mean loss: {epoch_loss / num_batches:.4f}")

    if epoch % save_every == 0:
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            # Unwrap so save_pretrained/config are reachable even if
            # accelerate wrapped the model (e.g. DDP).
            model = accelerator.unwrap_model(unet)

            # ========== Save Checkpoint ==========
            # Saved BEFORE sampling so a failure while rendering the preview
            # cannot cost the weights of this epoch.
            model.save_pretrained(f"ldm_unet_epoch_{epoch}")

            # ========== Save Sample ==========
            unet.eval()
            # DDIMScheduler.step() raises ValueError unless set_timesteps()
            # was called. This only changes the inference timesteps;
            # add_noise() used for training is unaffected.
            scheduler.set_timesteps(num_inference_steps)
            # Fixed seed so previews from different epochs start from the
            # same noise and are directly comparable.
            generator = torch.Generator(device=accelerator.device).manual_seed(0)
            with torch.no_grad():
                size = model.config.sample_size
                sample = torch.randn(
                    (1, model.config.in_channels, size, size),
                    generator=generator,
                    device=accelerator.device,
                ) * scheduler.init_noise_sigma
                for t in scheduler.timesteps:
                    noise_pred = model(sample, t).sample
                    sample = scheduler.step(noise_pred, t, sample).prev_sample
                decoded = vae.decode(sample / latent_scale).sample
            # Invert the dataset's Normalize(0.5, 0.5): [-1, 1] -> [0, 1].
            # (normalize=True would min/max-stretch each image instead.)
            save_image((decoded.clamp(-1, 1) + 1) / 2, f"sample_epoch_{epoch}.png")
            unet.train()


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m31.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.5/207.5 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.1/21.1 MB[0m [31m91.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25h

2025-06-19 05:31:20.954447: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750311081.197606      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750311081.314131      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

🔥 Epoch 1/100


Epoch 1: 100%|██████████| 2988/2988 [30:06<00:00,  1.65it/s]


🔥 Epoch 2/100


Epoch 2: 100%|██████████| 2988/2988 [28:23<00:00,  1.75it/s]


🔥 Epoch 3/100


Epoch 3: 100%|██████████| 2988/2988 [28:27<00:00,  1.75it/s]


🔥 Epoch 4/100


Epoch 4: 100%|██████████| 2988/2988 [28:26<00:00,  1.75it/s]


🔥 Epoch 5/100


Epoch 5: 100%|██████████| 2988/2988 [28:29<00:00,  1.75it/s]


🔥 Epoch 6/100


Epoch 6:  11%|█         | 317/2988 [03:01<25:24,  1.75it/s]