In [4]:
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

class LSUNBedroomDataset(Dataset):
    """Dataset of every image file found (recursively) under ``root_dir``.

    Args:
        root_dir: Directory that is walked recursively for image files.
        transform: Optional callable applied to each loaded RGB ``PIL.Image``
            before it is returned.

    Attributes:
        image_paths: Sorted list of the paths of all discovered image files.
    """

    # Accepted file extensions, lower-case and including the leading dot so
    # that a name such as "fakepng" is not mistaken for an image.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []

        # Traverse the directory to get image paths
        for subdir, _, files in os.walk(self.root_dir):
            for file in files:
                # Compare case-insensitively so "IMG.JPG" is picked up too.
                if file.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(subdir, file))

        # os.walk yields entries in a filesystem-dependent order; sorting keeps
        # index -> file mapping stable between runs and machines.
        self.image_paths.sort()

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` as RGB and apply ``transform`` if set."""
        img_path = self.image_paths[idx]
        # The context manager closes the underlying file right after decoding;
        # convert() returns an independent in-memory image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

# Dataset configuration: named once here instead of being repeated inline.
IMAGE_SIZE = 64   # images are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 64
path = '../Dataset/bedroom/'

# Define transforms: resize to a fixed square, then convert the PIL image to a tensor.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

# Create dataset and a shuffled loader over it
dataset = LSUNBedroomDataset(root_dir=path, transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


In [6]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
from tqdm import tqdm


# Custom Noise Scheduler for SMLD
class SMLDSchedulerOutput:
    """Result of ``SMLDScheduler.step``; exposes the updated sample as ``prev_sample``."""

    def __init__(self, prev_sample):
        self.prev_sample = prev_sample


class SMLDScheduler:
    """Additive-noise scheduler: ``x_t = x_0 + sigma_t * noise``.

    The noise scale ``sigma_t`` grows linearly from ``sigma_min`` (t = 0) to
    ``sigma_max`` (t = num_train_timesteps - 1).

    Args:
        num_train_timesteps: Number of discrete noise levels.
        sigma_min: Noise scale at timestep 0 (default keeps the original 1e-4).
        sigma_max: Noise scale at the last timestep (default keeps the original 0.02).

    NOTE(review): with the default range the largest noise scale is 0.02, far
    smaller than unit-variance noise; confirm this is the intended schedule if
    sampling is started from ``torch.randn``.
    """

    def __init__(self, num_train_timesteps, sigma_min=1e-4, sigma_max=0.02):
        self.num_train_timesteps = num_train_timesteps
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        # Build the schedule once instead of on every get_alpha() call.
        self.sigmas = torch.linspace(sigma_min, sigma_max, num_train_timesteps)

    @staticmethod
    def _expand_like(values, reference):
        """Reshape per-sample scalars to (N, 1, ..., 1) so they broadcast over ``reference``."""
        return values.reshape(-1, *([1] * (reference.dim() - 1)))

    def add_noise(self, images, noise, timesteps):
        """Return ``images + sigma[timesteps] * noise``.

        Args:
            images: Clean samples; the first dimension is the batch.
            noise: Tensor broadcastable to ``images``.
            timesteps: Long tensor of schedule indices, one per batch element.
        """
        alphas = self._expand_like(self.get_alpha(timesteps), images)
        return images + alphas * noise

    def get_alpha(self, t):
        """Return the noise scale for index tensor ``t`` (on ``t``'s device).

        The name is kept for compatibility; the value is the additive noise
        scale (sigma) of the linear schedule.
        """
        # Move the cached schedule only when the device actually changes.
        if self.sigmas.device != t.device:
            self.sigmas = self.sigmas.to(t.device)
        return self.sigmas[t]

    def step(self, model_output, timestep, sample):
        """Move ``sample`` from noise level ``timestep`` to ``timestep - 1``.

        Deterministic update obtained by inverting ``add_noise`` with the
        predicted noise: ``x0_hat = sample - sigma_t * model_output`` and
        ``prev = x0_hat + sigma_{t-1} * model_output``, where the level below
        timestep 0 is treated as noise-free (sigma = 0). No fresh noise is
        injected, so this is not a Langevin-dynamics sampler.

        Args:
            model_output: Predicted noise, same shape as ``sample``.
            timestep: Int or long tensor (scalar or one entry per batch element).
            sample: Current noisy sample.

        Returns:
            SMLDSchedulerOutput whose ``prev_sample`` holds the updated sample.
        """
        t = torch.as_tensor(timestep, device=sample.device, dtype=torch.long)
        sigma_t = self.get_alpha(t)
        # Clamp keeps the lookup in range at t == 0; torch.where then replaces
        # that looked-up value with zero.
        sigma_prev = torch.where(
            t > 0,
            self.get_alpha((t - 1).clamp(min=0)),
            torch.zeros_like(sigma_t),
        )
        delta = self._expand_like(sigma_t - sigma_prev, sample)
        return SMLDSchedulerOutput(prev_sample=sample - delta * model_output)

# Instantiate the model: a small 3-level UNet operating on 3-channel 64x64 inputs
model = UNet2DModel(
    sample_size=64,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64, 128),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D")
)

# Set up the noise scheduler
noise_scheduler = SMLDScheduler(num_train_timesteps=1000)

# Training parameters
optimizer = Adam(model.parameters(), lr=1e-4)
num_epochs = 20

# Pick the best available accelerator instead of assuming Apple MPS, so the
# notebook also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Kept as the last expression so the cell still displays the model summary.
model.to(device)

UNet2DModel(
  (conv_in): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=32, out_features=128, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=128, out_features=128, bias=True)
  )
  (down_blocks): ModuleList(
    (0): DownBlock2D(
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsamplers): ModuleList(
        (0): Downsample2D(
          (conv): Conv2d(32, 32, kernel_si

In [9]:
# Train the model to predict the noise that the scheduler added to each batch.
for epoch in range(num_epochs):
    model.train()
    # Accumulate on-device so the loss is not synchronised to the host every step.
    running_loss = torch.zeros((), device=device)
    num_batches = 0

    for batch in tqdm(dataloader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
        batch = batch.to(device)

        # One random noise level per image in the batch.
        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (batch.size(0),), device=device)
        noise = torch.randn_like(batch)
        noisy_images = noise_scheduler.add_noise(batch, noise, timesteps)

        noise_pred = model(noisy_images, timesteps).sample
        loss = F.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.detach()
        num_batches += 1

    # Report the mean of the per-batch losses rather than only the last batch,
    # and avoid referencing an undefined loss when the loader yields nothing.
    if num_batches == 0:
        print(f"Epoch {epoch+1} skipped: dataloader produced no batches.")
    else:
        print(f"Epoch {epoch+1} completed. Average loss: {(running_loss / num_batches).item()}")

Training Epoch 1/20:   0%|          | 11/4737 [00:05<41:06,  1.92it/s] 


KeyboardInterrupt: 

In [None]:
# Persist the trained weights (state dict only) to disk
checkpoint_path = "smld_bedroom_64x64_20_epoch.pth"
model_weights = model.state_dict()
torch.save(model_weights, checkpoint_path)

In [None]:
import os
import torch
from tqdm import tqdm
from torchvision.utils import save_image

# Generate images by iterating the scheduler from the last timestep down to 0.
# NOTE(review): this relies on noise_scheduler providing a diffusers-style
# step(model_output, timestep, sample) whose result has a .prev_sample attribute.
model.eval()
num_images = 50
sample_batch_size = 10          # images denoised together per pass
image_channels, image_size = 3, 64   # must match the model's input configuration
generated_images = []

with torch.no_grad():
    # Denoise several images per forward pass instead of one at a time.
    for start in tqdm(range(0, num_images, sample_batch_size), desc="Generating Images"):
        current_batch = min(sample_batch_size, num_images - start)
        noisy_image = torch.randn(current_batch, image_channels, image_size, image_size, device=device)

        for t in reversed(range(noise_scheduler.num_train_timesteps)):
            timesteps = torch.full((current_batch,), t, device=device, dtype=torch.long)
            model_output = model(noisy_image, timesteps)

            step_result = noise_scheduler.step(model_output.sample, t, noisy_image)
            noisy_image = step_result.prev_sample

        # Store each image separately as a (C, H, W) CPU tensor.
        generated_images.extend(noisy_image.cpu().unbind(0))

# Saving generated images
output_dir = f"generated_images_{num_epochs}_epochs"
os.makedirs(output_dir, exist_ok=True)

for idx, image in enumerate(generated_images):
    save_image(image, os.path.join(output_dir, f"generated_image_{idx+1}.png"))

print(f"{num_images} images generated and saved in {output_dir}")

In [None]:
# Plot every generated image in a grid, 10 per row
grid_cols = 10
num_to_plot = len(generated_images)
# Ceiling division; at least one row so an empty list still yields a valid figure.
grid_rows = max(1, -(-num_to_plot // grid_cols))

fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(10, 10), squeeze=False)
for ax in axes.flat:
    ax.axis('off')   # also hides unused cells in a partially filled last row
for ax, image in zip(axes.flat, generated_images):
    # (C, H, W) -> (H, W, C); clamp to the [0, 1] range imshow expects for floats.
    ax.imshow(image.permute(1, 2, 0).clamp(0, 1))
plt.show()