### For the following several blocks, you just need to run them, no need to change anything.

In [None]:
# Clone diffusers from source and install it; the imports cell below checks for a 0.24.0.dev0+ version.
!git clone https://github.com/huggingface/diffusers
%cd diffusers
!pip install .

In [None]:
%cd examples/text_to_image

In [None]:
# Example requirements, plus experiment tracking (wandb) and evaluation metrics (torchmetrics image extras, torch-fidelity).
!pip install -r requirements.txt
! pip install wandb
!pip install torchmetrics[image]
!pip install torch-fidelity

### Log in to your wandb account here.

In [6]:
# Authenticate with Weights & Biases; the training and evaluation cells below log to wandb.
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33msalzfischyi[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

### For the following several blocks, you just need to run them, no need to change anything.

In [7]:
import argparse
import logging
import math
import os
import random
import shutil
from pathlib import Path

import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
# get_logger: Accelerate's logging helper, used to create `logger` at the end of this cell.
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torchvision import transforms
# tqdm shows progress bars in Python; it is usually wrapped around loops so you can watch execution progress in real time.
from tqdm.auto import tqdm
# CLIPTextModel: pretrained text model that takes input text and produces a text representation (embeddings); used here as the text encoder.
# CLIPTokenizer: the CLIP tokenizer that tokenizes input text into token ids.
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")

# Module-level logger at INFO level.
logger = get_logger(__name__, log_level="INFO")

In [None]:
from datasets import load_dataset
from diffusers import StableDiffusionPipeline
import torch
from torchmetrics.functional.multimodal import clip_score
from functools import partial
from PIL import Image
import os
import numpy as np
import torch.nn as nn
# Imported as TF (not F) so it does not shadow `torch.nn.functional as F` from the
# imports cell; the training loop relies on that name for `F.mse_loss`.
from torchvision.transforms import functional as TF
from torchmetrics.image.fid import FrechetInceptionDistance

device = 'cuda'
weight_dtype = torch.float16
# FID metric; with normalize=True it expects float images in [0, 1].
fid = FrechetInceptionDistance(normalize=True)
# CLIP score with a fixed CLIP backbone (prompt/image agreement).
clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")


def calculate_clip_score(images, prompts):
    """Return the CLIP score of generated images against their prompts, rounded to 4 decimals.

    Args:
        images: numpy array indexed as (N, H, W, C) with values expected in [0, 1]
            (e.g. a pipeline called with ``output_type="np"``); they are scaled by 255
            and cast to uint8 before scoring.
        prompts: list of N prompt strings, aligned with ``images``.
    """
    images_int = (images * 255).astype("uint8")
    # torchmetrics wants channels-first image tensors.
    score = clip_score_fn(torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts).detach()
    return round(float(score), 4)


def preprocess_image(image, size=(256, 256)):
    """Turn one HWC image array into a 1xCxHxW float tensor in [0, 1], center-cropped to ``size``.

    The division by 255 assumes 0-255 pixel values (e.g. ``np.array(PIL_image)``).
    """
    image = torch.tensor(image).unsqueeze(0)
    image = image.permute(0, 3, 1, 2) / 255.0
    return TF.center_crop(image, size)

In [None]:
# Build the evaluation prompt list (and the matching image file names) from a CSV.
import csv
# TODO: test_csv_path: the CSV to evaluate on (currently the validation split's validation_emoji.csv).
test_csv_path = "/content/drive/MyDrive/huggingface_diffuser/data/validation_emoji.csv"
test_prompts = []
png_path = []
with open(test_csv_path, newline="") as test_file:
    csvreader = csv.reader(test_file)
    header = next(csvreader)
    # Locate columns by header name when present ("text" = prompt, "image" = file name, as in
    # the training CSVs); otherwise fall back to prompt in column 0 and image in column 1.
    prompt_idx = header.index("text") if "text" in header else 0
    image_idx = header.index("image") if "image" in header else 1
    for row in csvreader:
        test_prompts.append(row[prompt_idx])
        png_path.append(row[image_idx])
print(test_prompts)
print(png_path)

In [8]:
# Temporary helper: remove once the training scripts are rewritten on top of PEFT.
def text_encoder_lora_state_dict(text_encoder):
    """Gather the LoRA weights attached to a CLIP text encoder's self-attention projections.

    For every encoder layer ``i`` and each of ``q_proj``, ``k_proj``, ``v_proj`` and
    ``out_proj``, the projection's ``lora_linear_layer`` state dict is copied in under
    ``text_model.encoder.layers.{i}.self_attn.{proj}.lora_linear_layer.{key}``.
    Encoders that are not CLIPTextModel / CLIPTextModelWithProjection yield ``{}``.
    """
    from transformers import CLIPTextModel, CLIPTextModelWithProjection

    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
        attention_blocks = [
            (f"text_model.encoder.layers.{idx}.self_attn", encoder_layer.self_attn)
            for idx, encoder_layer in enumerate(text_encoder.text_model.encoder.layers)
        ]
    else:
        attention_blocks = []

    lora_weights = {}
    for prefix, self_attn in attention_blocks:
        for proj_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
            lora_layer = getattr(self_attn, proj_name).lora_linear_layer
            for param_key, tensor in lora_layer.state_dict().items():
                lora_weights[f"{prefix}.{proj_name}.lora_linear_layer.{param_key}"] = tensor
    return lora_weights


def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
    """Write README.md (a Hugging Face model card) for the LoRA weights into ``repo_folder``.

    Args:
        repo_id: Repository id shown in the card title.
        images: Optional sequence of objects with a ``.save(path)`` method (e.g. PIL images).
            Each one is written to ``repo_folder/image_{i}.png`` and embedded in the card.
            ``None`` (the default) produces a card without example images.
        base_model: Base model name written to the YAML header and the card text.
        dataset_name: Dataset name mentioned in the card text.
        repo_folder: Existing directory that receives the images and README.md.

    NOTE(review): the defaults of ``base_model`` / ``dataset_name`` are the ``str`` type
    itself, so omitting them renders "<class 'str'>" in the card; pass them explicitly.
    """
    # Treat the default None as "no images" (enumerate(None) would raise TypeError).
    if images is None:
        images = []

    img_str = ""
    for i, image in enumerate(images):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        img_str += f"![img_{i}](./image_{i}.png)\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- lora
inference: true
---
    """
    model_card = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
{img_str}
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)

### The following cell connects (mounts) your Google Drive.

In [9]:
# Mount Google Drive at /content/drive; the data, CSV and checkpoint paths below point into it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### For the TODOs here, change `train_data_dir` to the corresponding path in your own Google Drive.
### You can also change the hyperparameters and `validation_prompt` here.

In [10]:
# Training configuration. Several of these appear to mirror the arguments of diffusers'
# LoRA text-to-image example script; not all of them are read anywhere in this notebook
# (e.g. max_grad_norm, snr_gamma, use_8bit_adam, mixed_precision, num_train_epochs).
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
#dataset_name = "lambdalabs/pokemon-blip-captions"
dataset_name = None #our own dataset
revision = None
#train_data_dir = None
### TODO: change it to the directory where you save your images and csv ###
train_data_dir = "/content/drive/MyDrive/project182/Apple/train_set"
variant = None
dataset_config_name = None
# Column names of the image file name and the caption in the CSVs.
image_column = "image"
caption_column = "text"
#validation_prompt = "Totoro"  # Could several prompts be used here? They can be stored in a list (see below).
# validation_prompt = ["grinning face with sweat in Apple style",
#           "grinning face with sweat in Facebook style",
#           "grinning face with sweat in Google style",
#           "grinning face with sweat in JoyPixels style",
#          "grinning face with sweat in Samsung style",
#         "grinning face with sweat in Twitter style",
#         "grinning face with sweat in Windows style"]
### TODO: change it to whatever you want
# The training loop renders validation_prompt[i] for i in range(num_validation_images),
# so num_validation_images must not exceed len(validation_prompt).
validation_prompt = ["grinning face with sweat emoji"]
num_validation_images = 1
validation_epochs = 1
max_train_samples = None
### TODO: change it to the directory where you save your output model ###
# Must end with "/": checkpoint paths are built by string concatenation (output_dir + "latest").
output_dir = "/content/drive/MyDrive/project182/Apple/model/checkpoints/"
cache_dir = None
resolution = 256
center_crop = False
random_flip = False
train_batch_size = 1
gradient_accumulation_steps = 1
### TODO: can change those two ###
# NOTE(review): passed to the LR scheduler as num_training_steps; the training cell later
# reassigns max_train_steps = 10000 (after the scheduler is built).
max_train_steps = 15000
num_train_epochs = 100
gradient_checkpointing = False
learning_rate = 1e-4
scale_lr = False
# Name of the LR schedule. NOTE(review): `lr_scheduler` is later rebound to the scheduler object itself.
lr_scheduler = "cosine"
lr_warmup_steps = 0
snr_gamma = None
use_8bit_adam = False
allow_tf32 = False
dataloader_num_workers = 2
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_weight_decay = 1e-2
adam_epsilon = 1e-08
max_grad_norm = 1
push_to_hub = False
hub_model_id = None
# None keeps the scheduler's own prediction type ("epsilon" or "v_prediction" are handled).
prediction_type = None
mixed_precision = "fp16"
report_to = "wandb" # or "tensorboard"
local_rank = -1
### TODO: you can change it ###
checkpointing_steps = 500
checkpoints_total_limit = None
# "Whether training should be resumed from a previous checkpoint. Use a path saved by"
# ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
resume_from_checkpoint = None
enable_xformers_memory_efficient_attention = False
noise_offset = 0
logging_dir = "logs"
# TODO: lora rank ; we need to run different experiments on that ###
rank = 4
seed = 1337
# DATASET_NAME_MAPPING = {
#     "lambdalabs/pokemon-blip-captions": ("image", "text"),
# }
# (image column, caption column) for our local dataset; not read elsewhere in this notebook.
DATASET_NAME_MAPPING = {
    train_data_dir: ("image", "text"),
}

### For the following several blocks, you just need to run them without any modification.

In [None]:
# Seed everything before any randomly initialised weights (the LoRA layers in the next cell) are created.
set_seed(seed)
# Load scheduler, tokenizer and models.
# Load every component of the pretrained diffusion model from its own subfolder.
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae", revision=revision, variant=variant
)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", revision=revision, variant=variant
)
# NOTE(review): none of these models is moved to a GPU here; confirm training is meant to
# run on whatever device from_pretrained placed them on.
# freeze parameters of models to save more memory
# i.e. the original model parameters stay fixed; only the LoRA layers added next are trained.
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)


### Here is the Lora layer setting.

In [12]:
# Start fine-tuning: the weights being adapted live in the UNet.
# Every attention module in the UNet gets LoRA adapters (rank `rank`) on its query, key,
# value and output projections; each adapter's size follows the projection it wraps.
#
# For Stable Diffusion this should cover:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks  (2x attention layers) * (1x transformer layers) * (1x mid blocks)  = 2
# - up blocks   (2x attention layers) * (3x transformer layers) * (3x up blocks)   = 18
# => 32 layers

unet_lora_parameters = []
for attn_processor_name in unet.attn_processors:
    # The processor name is a dotted attribute path; dropping its last component and
    # walking the rest from the UNet reaches the attention module that owns it.
    attn_module = unet
    for attr in attn_processor_name.split(".")[:-1]:
        attn_module = getattr(attn_module, attr)

    # Query, key, value and (first element of) output projections.
    lora_targets = (attn_module.to_q, attn_module.to_k, attn_module.to_v, attn_module.to_out[0])

    # Attach a fresh LoRA layer to each projection, sized to match it.
    for target in lora_targets:
        target.set_lora_layer(
            LoRALinearLayer(in_features=target.in_features, out_features=target.out_features, rank=rank)
        )

    # Only the LoRA parameters are optimized.
    for target in lora_targets:
        unet_lora_parameters.extend(target.lora_layer.parameters())

optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
    unet_lora_parameters,
    lr=learning_rate,
    betas=(adam_beta1, adam_beta2),
    weight_decay=adam_weight_decay,
    eps=adam_epsilon,
)

In [13]:
from datasets import Dataset
import pandas as pd
from PIL import Image
class ImageCaptionDataset(Dataset):
    """Image/caption pairs described by a CSV file.

    Column 0 of the CSV holds image file names (relative to ``image_folder``) and
    column 1 holds the captions.

    NOTE(review): ``Dataset`` here is ``datasets.Dataset`` (imported in this cell), but
    ``__init__`` never calls its constructor; in this notebook the object is only used for
    its ``df`` attribute (see ``convert_to_hf_dataset``). A ``torch.utils.data.Dataset`` base
    may have been intended; confirm.
    """

    def __init__(self, csv_file, image_folder, transform=None):
        self.df = pd.read_csv(csv_file)
        self.image_folder = image_folder
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"image": image, "text": caption}`` for row ``idx``; ``transform`` is applied to the RGB image if set."""
        img_name = os.path.join(self.image_folder, self.df.iloc[idx, 0])
        # Close the file handle once the RGB copy has been made.
        with Image.open(img_name) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        text = self.df.iloc[idx, 1]
        return {"image": image, "text": text}

In [14]:
def convert_to_hf_dataset(dataset):
    """Wrap the pandas frame behind an ImageCaptionDataset in a Hugging Face ``Dataset``.

    Only ``dataset.df`` is used; images are opened later by the per-split transforms.
    """
    return Dataset.from_pandas(dataset.df)

# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
def tokenize_captions(examples, is_train=True):
    """Tokenize the batch's captions to fixed-length CLIP token ids.

    Each entry of ``examples[caption_column]`` may be a string or a list/array of strings;
    for the latter a random caption is used when ``is_train`` and the first one otherwise.
    Captions are padded/truncated to ``tokenizer.model_max_length``.

    Raises:
        ValueError: if a caption is neither a string nor a list/array of strings.
    """
    def _select(caption):
        if isinstance(caption, str):
            return caption
        if isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            return random.choice(caption) if is_train else caption[0]
        raise ValueError(
            f"Caption column `{caption_column}` should contain either strings or lists of strings."
        )

    chosen_captions = [_select(caption) for caption in examples[caption_column]]
    encoded = tokenizer(
        chosen_captions,
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids

def collate_fn(examples):
    """Stack per-example ``pixel_values`` and ``input_ids`` into batch tensors.

    ``pixel_values`` is returned contiguous and as float32; ``input_ids`` is stacked unchanged.
    """
    batch = {
        key: torch.stack([example[key] for example in examples])
        for key in ("pixel_values", "input_ids")
    }
    batch["pixel_values"] = batch["pixel_values"].to(memory_format=torch.contiguous_format).float()
    return batch

### For the TODO here, change the paths to the corresponding paths in your own Google Drive.

In [15]:
# Construct the train / validation / test datasets.
# TODO: point these at your own folders. Each CSV lists image file names (column 0) and
# captions (column 1); the images are looked up inside the matching image folder
# (here each CSV sits next to its images).
csv_path = "/content/drive/MyDrive/project182/Apple/train_set/train_emoji.csv"
image_folder = "/content/drive/MyDrive/project182/Apple/train_set/"
csv_path_valid = "/content/drive/MyDrive/project182/Apple/validation_set/validation_emoji.csv"
image_folder_valid = "/content/drive/MyDrive/project182/Apple/validation_set/"
csv_path_test = "/content/drive/MyDrive/project182/Apple/test_set/test_emoji.csv"
image_folder_test = "/content/drive/MyDrive/project182/Apple/test_set/"

# Read all three CSVs first, then convert each to a Hugging Face Dataset.
raw_splits = [
    ImageCaptionDataset(csv_file=split_csv, image_folder=split_folder, transform=transforms.ToTensor())
    for split_csv, split_folder in (
        (csv_path, image_folder),
        (csv_path_valid, image_folder_valid),
        (csv_path_test, image_folder_test),
    )
]
dataset, valid_dataset, test_dataset = [convert_to_hf_dataset(split) for split in raw_splits]

# Preprocessing the datasets: resize the shorter side to `resolution`, crop a square of that
# size (center or random), optionally flip, then convert to a tensor and rescale [0, 1] -> [-1, 1].
train_transforms = transforms.Compose([
    transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
    (transforms.CenterCrop if center_crop else transforms.RandomCrop)(resolution),
    transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


def _load_split_examples(examples, folder):
    """Shared batch transform: load images from ``folder``, apply ``train_transforms``, tokenize captions.

    Mutates and returns the batch dict handed over by ``Dataset.with_transform``, adding
    ``pixel_values`` (one tensor per example) and ``input_ids``.
    """
    pixel_values = []
    for file_name in examples[image_column]:
        # os.path.join works whether or not `folder` ends with a separator, and the
        # `with` block closes each file once its RGB copy has been made.
        with Image.open(os.path.join(folder, file_name)) as img:
            pixel_values.append(train_transforms(img.convert("RGB")))
    examples["pixel_values"] = pixel_values
    examples["input_ids"] = tokenize_captions(examples)
    return examples


def preprocess_train(examples):
    """Batch transform for the training split (images under ``image_folder``)."""
    return _load_split_examples(examples, image_folder)


def preprocess_valid(examples):
    """Batch transform for the validation split (images under ``image_folder_valid``)."""
    return _load_split_examples(examples, image_folder_valid)


def preprocess_test(examples):
    """Batch transform for the test split (images under ``image_folder_test``)."""
    return _load_split_examples(examples, image_folder_test)


# Apply the transforms lazily, whenever a split is indexed.
train_dataset = dataset.with_transform(preprocess_train)
valid_dataset = valid_dataset.with_transform(preprocess_valid)
test_dataset = test_dataset.with_transform(preprocess_test)

In [16]:
# DataLoaders creation: shuffled batches for training; one example per batch, in order,
# for validation and test.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=train_batch_size,
    num_workers=dataloader_num_workers,
)

valid_dataloader = torch.utils.data.DataLoader(
    valid_dataset,
    shuffle=False,
    collate_fn=collate_fn,
    batch_size=1,
    num_workers=dataloader_num_workers,
)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=False,
    collate_fn=collate_fn,
    batch_size=1,
    num_workers=dataloader_num_workers,
)

# Set the learning-rate scheduler.
# `lr_scheduler` holds the schedule *name* from the config cell and is rebound to the
# scheduler *object* below; remember the name so re-running this cell still works.
if isinstance(lr_scheduler, str):
    lr_scheduler_name = lr_scheduler
# NOTE(review): num_training_steps uses the current max_train_steps; the training cell
# later sets max_train_steps = 10000 without rebuilding this scheduler.
lr_scheduler = get_scheduler(
    lr_scheduler_name,
    optimizer=optimizer,
    num_warmup_steps=lr_warmup_steps,
    num_training_steps=max_train_steps,
)


### If you are running this for the first time, or want to record metrics from the beginning, run only the following block (`wandb.init()`) and then skip to the training block. Otherwise, skip this block.

In [None]:
# Start a new wandb run (fresh training); the resume cell below is the alternative.
wandb.init()

### If your training stopped unexpectedly, run the blocks here.

In [None]:
# Print the id of the current wandb run; it can be used as RunID to resume this exact run later.
print(wandb.run.id)

### Reconnect to the wandb run you used before and reload the last checkpoint you stored.

In [None]:
# Reattach to a previous wandb run. If it was resumed, restore the LoRA weights and optimizer
# state from the checkpoint written by the training loop.
NAME = f"rank:{rank}"
PROJECT_NAME = "diffusers-examples_text_to_image"
CHECKPOINT_PATH = "/content/drive/MyDrive/project182/Apple/model/checkpoints/checkpoint.tar"
RunID = "pj6sgt7z"
# NOTE(review): RunID is only used by the commented-out variant below; without `id`, wandb
# chooses which run to resume. Confirm that it picks up the intended run.
run = wandb.init(project=PROJECT_NAME, name=NAME, resume=True)
#run = wandb.init(project=PROJECT_NAME, name=NAME, id = RunID, resume=True)

begin_batch = 1
if wandb.run.resumed:
    #checkpoint = torch.load(wandb.restore(CHECKPOINT_PATH))
    # NOTE: torch.load unpickles the file; only load checkpoints you wrote yourself.
    with open(CHECKPOINT_PATH, 'rb') as checkpoint_file:
        checkpoint = torch.load(checkpoint_file, encoding='utf-8')
    # The checkpoint stores the directory of the saved LoRA attention processors.
    unet_path = checkpoint["unet"]
    unet.load_attn_procs(unet_path)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch, loss = checkpoint["epoch"], checkpoint["loss"]
    batch = begin_batch = checkpoint["batch"]
    print(epoch, begin_batch)



## Training Block
### Train the model here. If you are resuming the run loaded above, use the TODO check inside the loop to skip the steps you have already trained (the `epoch` and `begin_batch` values are printed by the previous block). Otherwise, leave that check commented out.
### You can also change how checkpoints are recorded here. By default, a checkpoint is saved at every step.

In [None]:
# Train!
# Single-process loop (no Accelerator), so the effective batch size has no num_processes factor:
# total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
total_batch_size = train_batch_size * gradient_accumulation_steps

max_train_steps = 10000
# NOTE(review): this only sizes the progress bar. The LR scheduler was built earlier from the
# previous max_train_steps, and the loop below runs `max_epoch` full epochs without stopping here.

# These two are normally set by the resume cell; fall back to defaults when it was skipped
# (fresh run via `wandb.init()`), so this cell does not fail with a NameError.
begin_batch = globals().get("begin_batch", 0)
CHECKPOINT_PATH = globals().get("CHECKPOINT_PATH", output_dir + "checkpoint.tar")

progress_bar = tqdm(
    range(0, max_train_steps),
    initial=begin_batch,
    desc="Steps",
    # Only show the progress bar once on each machine.
    #disable=not accelerator.is_local_main_process,
)

# Separate inference pipeline, used to render validation images from the latest saved LoRA weights.
pipeline = DiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    safety_checker=None
).to("cuda")

# Configure the scheduler's prediction type once instead of on every step.
if prediction_type is not None:
    noise_scheduler.register_to_config(prediction_type=prediction_type)

# Save LoRA weights + checkpoint and render validation images every N steps (1 = every step).
# TODO: raise this (e.g. to `checkpointing_steps`) to make training much faster.
save_every_steps = 1


def _loss_target(clean_latents, added_noise, noise_timesteps):
    """Return the regression target for the scheduler's configured prediction type.

    Raises:
        ValueError: if the prediction type is neither "epsilon" nor "v_prediction".
    """
    if noise_scheduler.config.prediction_type == "epsilon":
        return added_noise
    if noise_scheduler.config.prediction_type == "v_prediction":
        return noise_scheduler.get_velocity(clean_latents, added_noise, noise_timesteps)
    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")


def _validation_loss():
    """Return the mean denoising MSE over all batches of `valid_dataloader` as a float.

    Runs under torch.no_grad() (no autograd graph is kept) and puts the UNet back into
    train mode afterwards.
    """
    unet.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for valid_batch in valid_dataloader:
            # Convert images to latent space
            valid_latents = vae.encode(valid_batch["pixel_values"]).latent_dist.sample()
            valid_latents = valid_latents * vae.config.scaling_factor
            valid_noise = torch.randn_like(valid_latents)
            valid_bsz = valid_latents.shape[0]
            # Sample a random timestep for each validation image.
            valid_timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (valid_bsz,), device=valid_latents.device
            ).long()
            # Forward diffusion: noise the latents according to each timestep.
            valid_noisy_latents = noise_scheduler.add_noise(valid_latents, valid_noise, valid_timesteps)
            # Text embedding for conditioning
            valid_encoder_hidden_states = text_encoder(valid_batch["input_ids"])[0]
            valid_target = _loss_target(valid_latents, valid_noise, valid_timesteps)
            valid_model_pred = unet(valid_noisy_latents, valid_timesteps, valid_encoder_hidden_states).sample
            total_loss += torch.nn.functional.mse_loss(
                valid_model_pred.float(), valid_target.float(), reduction="mean"
            ).item()
            num_batches += 1
    unet.train()
    return total_loss / max(num_batches, 1)


max_epoch = 10 #5
for epoch in range(max_epoch):
    unet.train()
    for step, batch in enumerate(train_dataloader):
        # TODO (resuming): skip the steps already trained, using the epoch/begin_batch printed
        # by the resume cell, e.g.
        # if epoch == 0 and step < 246:
        #     continue

        # NOTE(review): unet/vae/text_encoder are not moved to `device` anywhere in this
        # notebook, so this runs on the device they were loaded on; confirm that is intended.
        # Convert images to latent space
        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
        latents = latents * vae.config.scaling_factor
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        # Get the text embedding for conditioning
        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
        # Get the target for loss depending on the prediction type
        target = _loss_target(latents, noise, timesteps)

        # Periodic validation loss (every 49 steps, starting at step 1).
        if step % 49 == 1:
            valid_loss = _validation_loss()
            print("valid_loss = " + str(valid_loss))
            wandb.log({"valid_loss": valid_loss})

        # Predict the noise residual and compute loss
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
        # Backpropagate
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        loss_value = loss.item()
        progress_bar.set_postfix(loss=loss_value)
        wandb.log({"loss": loss_value})

        if (step + 1) % save_every_steps == 0:
            unet.save_attn_procs(output_dir + f"epoch_{epoch+1}_batch_{step+1}")
            unet.save_attn_procs(output_dir + "latest")
            # Save our checkpoint location plus optimizer state (read back by the resume cell).
            torch.save(
                {
                    "epoch": epoch,
                    "unet": output_dir + "latest",
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": loss_value,
                    "batch": step,
                },
                CHECKPOINT_PATH,
            )
            wandb.save(CHECKPOINT_PATH)
            # load attention processors
            pipeline.unet.load_attn_procs(output_dir + "latest")

            # run inference
            generator = torch.Generator()
            if seed is not None:
                generator = generator.manual_seed(seed)
            images = [
                pipeline(validation_prompt[i], num_inference_steps=30, generator=generator).images[0]
                for i in range(num_validation_images)
            ]
            wandb.log({"validation": [wandb.Image(image, caption=f"{i}: {validation_prompt[i]}") for i, image in enumerate(images)]})


unet = unet.to(torch.float32)
unet.save_attn_procs(output_dir)



In [None]:
# # our based pretrained model is "runwayml/stable-diffusion-v1-5"
# pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
# Load the base pipeline in half precision (weight_dtype is float16, set in the metrics cell);
# `torch_dtype` is the from_pretrained argument that selects the weights' dtype.
pipeline = DiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path, torch_dtype=weight_dtype
).to(device)
# TODO output_dir: the directory where we save the finetuned weight of lora
output_dir = "/content/drive/MyDrive/huggingface_diffuser/model/checkpoints/latest"
seed = 1337
pipeline.unet.load_attn_procs(output_dir)
# One seeded generator shared by all prompts, so the whole image set is reproducible.
generator = torch.Generator().manual_seed(seed)
# One float image per prompt, stacked into an (N, H, W, C) array in [0, 1].
test_images = np.array([
    pipeline(prompt, num_inference_steps=30, generator=generator, output_type="np").images[0]
    for prompt in test_prompts
])
# calculate CLIP score
CLIP_score = calculate_clip_score(test_images, test_prompts)
print(f"CLIP Score : {CLIP_score}")
# write CLIP_score to wandb
wandb.log({"CLIP Score": CLIP_score})

In [None]:
# calculate FID score between the real evaluation images and the generated `test_images`.
# TODO dataset_path: the path to the test dataset
dataset_path = "/content/drive/MyDrive/huggingface_diffuser/data/validation_set"
#image_paths = sorted([os.path.join(dataset_path, x) for x in os.listdir(dataset_path)])
# FID compares distributions, so the order of the images does not matter.
image_paths = sorted([os.path.join(dataset_path, x) for x in png_path])
real_test_images = torch.cat(
    [preprocess_image(np.array(Image.open(path).convert("RGB"))) for path in image_paths]
)
# Generated images come as (N, H, W, C) floats in [0, 1]; FID wants (N, C, H, W).
generated_images = torch.as_tensor(test_images).permute(0, 3, 1, 2)
# Start from an empty metric so re-running this cell does not double-count images.
fid.reset()
fid.update(real_test_images, real=True)
fid.update(generated_images, real=False)
fid_score = float(fid.compute())
print(f"FID: {fid_score}")
# write FID_score to wandb
wandb.log({"FID Score": fid_score})