In [8]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
import torch as th
from torch import nn
from torch.nn import functional as F
import lightning as L
from ddpm import DDPMSampler
from clip import CLIP
from encoder import VAE_Encoder
from decoder import VAE_Decoder
from diffusion import Diffusion
from utils import rescale
from datasets import load_dataset
from transformers import CLIPTokenizer
from torch.utils.data import Dataset, DataLoader

In [10]:
DEVICE = "cpu"
ALLOW_CUDA = True
ALLOW_MPS = False

if th.cuda.is_available() and ALLOW_CUDA:
    DEVICE = "cuda"
elif th.backends.mps.is_available() and ALLOW_MPS:
    DEVICE = "mps"
print(f"Using Device {DEVICE}")

Using Device cuda


In [11]:
WIDTH = 512
HEIGHT = 512
LATENT_WIDTH = WIDTH // 8
LATENT_HEIGHT = HEIGHT // 8

In [12]:
class StableDiffusionModel(nn.Module):

    def __init__(self, latent_height, latent_width, device="cpu", seed=42):
        super().__init__()
        self.latent_height = latent_height
        self.latent_width = latent_width
        self.device = th.device(device)
        th.generator.manual_seed(seed)
        self.clip = CLIP().to_device(device)
        self.encoder = VAE_Encoder()
        self.decoder = VAE_Decoder()
        self.diffusion = Diffusion()

    def forward(self, image=None, captions=None, tokenizer=None, strength=0.8):
        batch_size = 1
        latents_shape = (batch_size, 4, self.latent_height, self.latent_width)

        # Tokenize the sentences
        tokens = tokenizer(
            captions,
            padding="max_length",
            max_length=50,
            return_tensors="pt",
            truncation=True,
        ).input_ids.to_device(self.device)

        # Pass the tokens through clip
        # for adding more context to the tokens
        context = self.clip(tokens)

        # Setting up the DDPM Sampler
        generator = th.Generator(device=self.device)
        sampler = DDPMSampler(generator=generator)
        sampler.set_inference_timesteps(self.n_inference_steps)

        # If input image is present
        if image:
            encoder_noise = th.randn(
                layout=latents_shape, generator=generator, device=self.device
            )

            latents = self.encoder(images, encoder_noise)
            sampler.set_strength(strength)
            latents = sampler.add_noise(latents, sampler.timesteps[0])

        # If input image in not present
        else:
            latents = th.randn(
                layout=latents_shape, generator=generator, device=self.device
            )

        for timestep in sampler.timesteps:
            time_embedding = self.get_time_embedding(timestep).to(self.device)
            model_input = latents
            model_output = self.diffusion(model_input, context, time_embedding)
            latents = sampler.step(timestep, latents, model_output)

        decoder_output = self.decoder(latents)

        return rescale(decoder_output, (-1, 1), (0, 1), clamp=True), context

In [13]:
class StableDiffusionTraining(L.LightningModule):

    def __init__(self, lr=1e-3):
        super().__init__()
        WIDTH = 512
        HEIGHT = 512
        LATENT_WIDTH = WIDTH // 8
        LATENT_HEIGHT = HEIGHT // 8
        self.lr = lr
        self.model = StableDiffusionModel(
            latent_height=LATENT_HEIGHT, latent_width=LATENT_WIDTH
        )

    def training_step(self, batch, batch_idx):
        x, y = batch
        x_hat, context = self.model(x)
        x_hat = x_hat.flatten()
        y = y.flatten()
        loss = F.mse_loss(x_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x_hat, context = self.model(x)
        x_hat = x_hat.flatten()
        y = y.flatten()
        loss = F.mse_loss(x_hat, y)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        x_hat, context = self.model(x)
        x_hat = x_hat.flatten()
        y = y.flatten()
        loss = F.mse_loss(x_hat, y)
        self.log("test_loss", loss)
        return loss

    def predict_step(self, batch, batch_idx):
        x, y = batch
        return self.model(x)[0]

    def configure_optimizers(self):
        optimizer = th.optim.Adam(self.model.paramters(), lr=self.lr)

In [20]:
class ImagePromptDataloader(L.LightningDataModule):
    def __init__(self, hugging_face_db_name="poloclub/diffusiondb", sample_size="2m_random_1k", train=True, test=True, validation=True, batch_size=4):
        super().__init__()
        self.db_name = hugging_face_db_name
        self.sample_size = sample_size
        self.train = train
        self.test  = test
        self.validation   = validation
        self.batch_size = batch_size

    def setup(self, stage=None):
        if self.train:
            self.train_dataset = load_dataset(self.db_name, trust_remote_code=True, split="train")
        if self.test:
            self.test_dataset = load_dataset(self.db_name, trust_remote_code=True, split="test")
        if self.validation:
            self.validation_dataset = load_dataset(self.db_name, trust_remote_code=True, split="validation")
    
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_worker=2
            )
    
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_worker=2
            )

    def val_dataloader(self):
        return DataLoader(
            self.validation_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_worker=2
            )

    def predict_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_worker=2
            )

In [21]:
data_module = ImagePromptDataloader()
data_module.setup()

ValueError: Unknown split "test". Should be one of ['train'].

In [None]:
model = 

In [7]:
# dataset = load_dataset("poloclub/diffusiondb", "2m_random_1k", trust_remote_code=True)

Downloading data:   0%|          | 0.00/662M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/195M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]