In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported automatically before each cell is executed (edits to local .py
# files are picked up without restarting the kernel).
%load_ext autoreload
%autoreload 2

In [2]:

from tqdm.notebook import tqdm
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers import DiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTextModel

import os
import numpy as np
from PIL import Image
from pathlib import Path

In [3]:

from pathlib import Path
from torchvision import transforms
from torch.utils.data import Dataset

class DreamBoothDataset(Dataset):
    """Dataset of instance images (and optionally class images) for DreamBooth-style work.

    Every item is built from the instance image at ``index`` (wrapping around
    when the dataset length exceeds the number of instance images). Images are
    converted to RGB, resized, cropped to ``size`` and turned into tensors.

    Args:
        instance_data_root: Directory holding the instance (subject) images.
        instance_prompt: Prompt describing the instance images.
        tokenizer: Callable tokenizer exposing ``model_max_length``; only used
            when ``return_dict`` is True.
        class_data_root: Optional directory of class images. It is created if
            missing. When given, the dataset length is the larger of the two
            image counts.
        class_prompt: Prompt describing the class images.
        size: Target size passed to the resize and crop transforms.
        center_crop: Use a center crop when True, otherwise a random crop.
        return_dict: When False (default) ``__getitem__`` returns only the
            transformed instance image tensor. When True it returns a dict with
            ``instance_images`` / ``instance_prompt_ids`` and, if class data is
            configured, ``class_images`` / ``class_prompt_ids``.

    Raises:
        ValueError: If ``instance_data_root`` does not exist.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
        return_dict=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.return_dict = return_dict

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.instance_images_path = self._list_image_files(self.instance_data_root)
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = self._list_image_files(self.class_data_root)
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
            ]
        )

    @staticmethod
    def _list_image_files(root):
        """Return the regular, non-hidden files directly under ``root``, sorted by path.

        Sub-directories and dot-entries (e.g. ``.ipynb_checkpoints``,
        ``.DS_Store``) are skipped because they cannot be opened as images.
        Sorting makes the index -> file mapping deterministic across runs.
        """
        return sorted(
            entry
            for entry in Path(root).iterdir()
            if entry.is_file() and not entry.name.startswith(".")
        )

    def _load_image(self, path):
        """Open ``path``, force RGB mode and apply the image transforms."""
        # The context manager closes the underlying file once the pixel data
        # has been consumed by the transforms, so no file handles are leaked.
        with Image.open(path) as image:
            rgb_image = image if image.mode == "RGB" else image.convert("RGB")
            return self.image_transforms(rgb_image)

    def _tokenize(self, prompt):
        """Tokenize ``prompt`` without padding, truncated to the tokenizer's max length."""
        return self.tokenizer(
            prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        instance_path = self.instance_images_path[index % self.num_instance_images]
        instance_tensor = self._load_image(instance_path)

        if not self.return_dict:
            # Default mode: callers only need the image tensor, so skip the
            # tokenization and class-image loading whose results would be discarded.
            return instance_tensor

        example = {
            "instance_images": instance_tensor,
            "instance_prompt_ids": self._tokenize(self.instance_prompt),
        }

        if self.class_data_root:
            class_path = self.class_images_path[index % self.num_class_images]
            example["class_images"] = self._load_image(class_path)
            example["class_prompt_ids"] = self._tokenize(self.class_prompt)

        return example


class PromptDataset(Dataset):
    """Dataset that yields the same prompt ``num_samples`` times, tagged with the item index."""

    def __init__(self, prompt, num_samples):
        self.prompt, self.num_samples = prompt, num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # Each item pairs the fixed prompt with the position it was requested at.
        return {"prompt": self.prompt, "index": index}

In [4]:
# --- Run configuration -------------------------------------------------------
LOW_RESOURCE = False
NUM_DIFFUSION_STEPS = 50
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 77

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# --- Pipeline under evaluation: a full pipeline checkpoint stored locally ----
model_id = "./logs/dog/prior_high_lr_again"
ldm_stable = StableDiffusionPipeline.from_pretrained(model_id).to(device)
tokenizer = ldm_stable.tokenizer

# --- Real (reference) images ---------------------------------------------------
instance_data_dir = "./data/dog/original"
instance_prompt = "a photo of <sks-dog>"

train_dataset = DreamBoothDataset(
    instance_data_root=instance_data_dir,
    instance_prompt=instance_prompt,
    tokenizer=tokenizer,
    class_data_root=None,
    class_prompt=None,
    size=512,
    center_crop=True,
)

# A single batch holds the whole dataset, so the evaluation step runs once.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=len(train_dataset), shuffle=True
)

#####
import PIL.Image as Image
from ignite.engine import Engine, Events
from ignite.metrics import FID, InceptionScore

# FID receives the evaluation step's output unchanged, i.e. the (fake, real) pair.
fid_metric = FID(device=device)
# Inception Score only looks at element 0 of the step output (the generated images).
is_metric = InceptionScore(
    device=device,
    output_transform=lambda step_output: step_output[0],
)

def interpolate(batch):
    """Resize every image in ``batch`` to 299x299 and stack the results into one tensor.

    Each image is moved to the CPU, converted to a PIL image, resized with
    bilinear resampling and converted back to a tensor.
    (299x299 is presumably the input size of the Inception network behind the
    metrics — confirm against the ignite documentation.)
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    return torch.stack(
        [to_tensor(to_pil(img.cpu()).resize((299,299), Image.BILINEAR)) for img in batch]
    )


def evaluation_step(engine, batch):
    """Generate one image per real image in ``batch`` and return ``(fake, real)``.

    Args:
        engine: The ignite engine invoking this step (unused).
        batch: Tensor of real images as produced by the data loader; its first
            dimension is the number of images.

    Returns:
        Tuple ``(fake, real)`` of image batches, both resized by ``interpolate``.

    NOTE(review): if the pipeline's safety checker blanks an image, that frame
    is still included in the metrics.
    """
    negative_prompt = "low quality, blurry, unfinished"
    # Size the generated batch from the batch actually received rather than the
    # global dataset length, so the step stays correct for any loader batch size.
    num_fake_images = batch.shape[0]
    with torch.no_grad():
        fake_images = ldm_stable(
            instance_prompt,
            num_images_per_prompt=num_fake_images,
            # Pass the configured sampling settings explicitly instead of
            # silently relying on the pipeline defaults.
            num_inference_steps=NUM_DIFFUSION_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            negative_prompt=negative_prompt,
            output_type="pil",
        ).images
        fake_batch = torch.stack([transforms.ToTensor()(img) for img in fake_images])
        print(fake_batch.shape, batch.shape)
        fake = interpolate(fake_batch)
        real = interpolate(batch)
    return fake, real

# Wrap the evaluation step in an ignite engine and register both metrics on it.
evaluator = Engine(evaluation_step)
fid_metric.attach(evaluator, "fid")
is_metric.attach(evaluator, "is")

# One pass over the train data loader, then read the computed metrics back.
evaluator.run(train_dataloader, max_epochs=1)
metrics = evaluator.state.metrics
fid_score, is_score = metrics['fid'], metrics['is']
print("====> For the train data loader:")
print(f"FID score: {fid_score}")
print(f"Inception score: {is_score}")

  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


torch.Size([18, 3, 512, 512]) torch.Size([18, 3, 512, 512])
====> For the train data loader:
FID score: 0.030175975406464967
Inception score: 1.7770968250203503


In [4]:
# --- Run configuration -------------------------------------------------------
LOW_RESOURCE = False
NUM_DIFFUSION_STEPS = 50
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 77

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# --- Pipeline under evaluation: base model plus textual-inversion embedding --
model_id = "CompVis/stable-diffusion-v1-4"
ldm_stable = StableDiffusionPipeline.from_pretrained(model_id).to(device)
ldm_stable.load_textual_inversion("./logs/dog/text_inv")
tokenizer = ldm_stable.tokenizer

# --- Real (reference) images ---------------------------------------------------
instance_data_dir = "./data/dog/original"
instance_prompt = "a photo of <sks-dog>"

train_dataset = DreamBoothDataset(
    instance_data_root=instance_data_dir,
    instance_prompt=instance_prompt,
    tokenizer=tokenizer,
    class_data_root=None,
    class_prompt=None,
    size=512,
    center_crop=True,
)

# A single batch holds the whole dataset, so the evaluation step runs once.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=len(train_dataset), shuffle=True
)

#####
import PIL.Image as Image
from ignite.engine import Engine, Events
from ignite.metrics import FID, InceptionScore

# FID receives the evaluation step's output unchanged, i.e. the (fake, real) pair.
fid_metric = FID(device=device)
# Inception Score only looks at element 0 of the step output (the generated images).
is_metric = InceptionScore(
    device=device,
    output_transform=lambda step_output: step_output[0],
)

def interpolate(batch):
    """Resize every image in ``batch`` to 299x299 and stack the results into one tensor.

    Each image is moved to the CPU, converted to a PIL image, resized with
    bilinear resampling and converted back to a tensor.
    (299x299 is presumably the input size of the Inception network behind the
    metrics — confirm against the ignite documentation.)
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    return torch.stack(
        [to_tensor(to_pil(img.cpu()).resize((299,299), Image.BILINEAR)) for img in batch]
    )


def evaluation_step(engine, batch):
    """Generate one image per real image in ``batch`` and return ``(fake, real)``.

    Args:
        engine: The ignite engine invoking this step (unused).
        batch: Tensor of real images as produced by the data loader; its first
            dimension is the number of images.

    Returns:
        Tuple ``(fake, real)`` of image batches, both resized by ``interpolate``.

    NOTE(review): if the pipeline's safety checker blanks an image, that frame
    is still included in the metrics.
    """
    negative_prompt = "low quality, blurry, unfinished"
    # Size the generated batch from the batch actually received rather than the
    # global dataset length, so the step stays correct for any loader batch size.
    num_fake_images = batch.shape[0]
    with torch.no_grad():
        fake_images = ldm_stable(
            instance_prompt,
            num_images_per_prompt=num_fake_images,
            # Pass the configured sampling settings explicitly instead of
            # silently relying on the pipeline defaults.
            num_inference_steps=NUM_DIFFUSION_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            negative_prompt=negative_prompt,
            output_type="pil",
        ).images
        fake_batch = torch.stack([transforms.ToTensor()(img) for img in fake_images])
        print(fake_batch.shape, batch.shape)
        fake = interpolate(fake_batch)
        real = interpolate(batch)
    return fake, real

# Wrap the evaluation step in an ignite engine and register both metrics on it.
evaluator = Engine(evaluation_step)
fid_metric.attach(evaluator, "fid")
is_metric.attach(evaluator, "is")

# One pass over the train data loader, then read the computed metrics back.
evaluator.run(train_dataloader, max_epochs=1)
metrics = evaluator.state.metrics
fid_score, is_score = metrics['fid'], metrics['is']
print("====> For the train data loader:")
print(f"FID score: {fid_score}")
print(f"Inception score: {is_score}")

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


  0%|          | 0/50 [00:00<?, ?it/s]

torch.Size([18, 3, 512, 512]) torch.Size([18, 3, 512, 512])
====> For the train data loader:
FID score: 0.1030309235339662
Inception score: 1.5083089183497504


In [4]:
# --- Run configuration -------------------------------------------------------
LOW_RESOURCE = False
NUM_DIFFUSION_STEPS = 50
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 77

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# --- Pipeline under evaluation: base model with LoRA weights loaded on top ---
model_id = "./logs/dog/lora"
ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
ldm_stable.load_lora_weights(model_id)
tokenizer = ldm_stable.tokenizer

# --- Real (reference) images ---------------------------------------------------
instance_data_dir = "./data/dog/original"
instance_prompt = "a photo of <sks-dog>"

train_dataset = DreamBoothDataset(
    instance_data_root=instance_data_dir,
    instance_prompt=instance_prompt,
    tokenizer=tokenizer,
    class_data_root=None,
    class_prompt=None,
    size=512,
    center_crop=True,
)

# A single batch holds the whole dataset, so the evaluation step runs once.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=len(train_dataset), shuffle=True
)

#####
import PIL.Image as Image
from ignite.engine import Engine, Events
from ignite.metrics import FID, InceptionScore

# FID receives the evaluation step's output unchanged, i.e. the (fake, real) pair.
fid_metric = FID(device=device)
# Inception Score only looks at element 0 of the step output (the generated images).
is_metric = InceptionScore(
    device=device,
    output_transform=lambda step_output: step_output[0],
)

def interpolate(batch):
    """Resize every image in ``batch`` to 299x299 and stack the results into one tensor.

    Each image is moved to the CPU, converted to a PIL image, resized with
    bilinear resampling and converted back to a tensor.
    (299x299 is presumably the input size of the Inception network behind the
    metrics — confirm against the ignite documentation.)
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    return torch.stack(
        [to_tensor(to_pil(img.cpu()).resize((299,299), Image.BILINEAR)) for img in batch]
    )


def evaluation_step(engine, batch):
    """Generate one image per real image in ``batch`` and return ``(fake, real)``.

    Args:
        engine: The ignite engine invoking this step (unused).
        batch: Tensor of real images as produced by the data loader; its first
            dimension is the number of images.

    Returns:
        Tuple ``(fake, real)`` of image batches, both resized by ``interpolate``.

    NOTE(review): if the pipeline's safety checker blanks an image, that frame
    is still included in the metrics.
    """
    negative_prompt = "low quality, blurry, unfinished"
    # Size the generated batch from the batch actually received rather than the
    # global dataset length, so the step stays correct for any loader batch size.
    num_fake_images = batch.shape[0]
    with torch.no_grad():
        fake_images = ldm_stable(
            instance_prompt,
            num_images_per_prompt=num_fake_images,
            # Pass the configured sampling settings explicitly instead of
            # silently relying on the pipeline defaults.
            num_inference_steps=NUM_DIFFUSION_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            negative_prompt=negative_prompt,
            output_type="pil",
        ).images
        fake_batch = torch.stack([transforms.ToTensor()(img) for img in fake_images])
        print(fake_batch.shape, batch.shape)
        fake = interpolate(fake_batch)
        real = interpolate(batch)
    return fake, real

# Wrap the evaluation step in an ignite engine and register both metrics on it.
evaluator = Engine(evaluation_step)
fid_metric.attach(evaluator, "fid")
is_metric.attach(evaluator, "is")

# One pass over the train data loader, then read the computed metrics back.
evaluator.run(train_dataloader, max_epochs=1)
metrics = evaluator.state.metrics
fid_score, is_score = metrics['fid'], metrics['is']
print("====> For the train data loader:")
print(f"FID score: {fid_score}")
print(f"Inception score: {is_score}")

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


torch.Size([18, 3, 512, 512]) torch.Size([18, 3, 512, 512])
====> For the train data loader:
FID score: 0.0519351787889373
Inception score: 1.5925326350414324
