In [49]:
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
import torch
from diffusers.models.attention_processor import LoRAAttnProcessor
import xformers
from diffusers.loaders import AttnProcsLayers
from datasets import load_dataset
from IPython.display import display
from torchvision import transforms
from diffusers.optimization import get_scheduler

In [56]:
model_name='runwayml/stable-diffusion-v1-5'
train_data_dir='~/Pictures/lora_datasets/potions'
output_dir=''
rank=4
optimizer_cls = torch.optim.AdamW
# Initial learning rate (after the potential warmup period) to use
learning_rate=1e-4
adam_beta1=0.9
adam_beta2=0.999
adam_weight_decay=1e-2
adam_epsilon=1e-08
image_column_name='image'
text_column_name='text'
resolution=512
center_crop=True
random_flip=False
prediction_type=None
max_grad_norm=1.0
lr_scheduler='constant' #["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
lr_warmup_steps=500
# TODO scaling LR

# TODO
num_train_epochs=100
# Batch size (per device) for the training dataloader
train_batch_size=16
# Number of updates steps to accumulate before performing a backward/update pass
gradient_accumulation_steps=1
# Total number of training steps to perform.  If provided, overrides num_train_epochs.
max_train_steps=num_train_epochs * 1 # TODO 

In [13]:
noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")

In [14]:
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

if torch.cuda.is_available():
    unet.enable_xformers_memory_efficient_attention()

In [16]:
%%capture
weight_dtype = torch.float16
device='cuda'
# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
text_encoder.to(device, dtype=weight_dtype)

In [17]:
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRAAttnProcessor(
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
        rank=rank,
    )

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)

In [18]:
optimizer = optimizer_cls(
    lora_layers.parameters(),
    lr=learning_rate,
    betas=(adam_beta1, adam_beta2),
    weight_decay=adam_weight_decay,
    eps=adam_epsilon,
)

In [42]:
dataset = load_dataset(
    "imagefolder",
    data_dir=train_data_dir,
    #cache_dir=args.cache_dir,
)
#display(dataset)
column_names = dataset["train"].column_names
#print(column_names)
if image_column_name not in column_names:
    raise f"Image column {image_column_name} not in {','.join(column_names)}"
if text_column_name not in column_names:
    raise f"Text column {text_column_name} not in {','.join(column_names)}"

Resolving data files:   0%|          | 0/38 [00:00<?, ?it/s]

In [43]:
train_transforms = transforms.Compose([
    transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(resolution) if center_crop else transforms.RandomCrop(args.resolution),
    transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

def preprocess_train(items):
    display(items)
    print('here')
    images = [image.convert("RGB") for image in items[image_column]]
    items["pixel_values"] = [train_transforms(image) for image in images]
    items["input_ids"] = tokenize_captions(examples)
    return items

train_dataset = dataset["train"].with_transform(preprocess_train)

def collate_fn(items):
    pixel_values = torch.stack([item["pixel_values"] for item in items])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    input_ids = torch.stack([item["input_ids"] for item in items])
    return {"pixel_values": pixel_values, "input_ids": input_ids}

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=train_batch_size,
    # num_workers=args.dataloader_num_workers,
)
len(train_dataloader)

3

In [50]:
lr_scheduler = get_scheduler(
    lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=lr_warmup_steps,
    num_training_steps=max_train_steps,
)

In [52]:
first_epoch=0
global_step=0

In [57]:
for epoch in range(first_epoch, num_train_epochs):
    unet.train()
    train_loss=0.0
    for step, batch in enumerate(train_dataloader):
        # TODO skip steps until we reach the resumed step

        # TODO accelerate.accumulate

        # Convert images to latent space
        latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)

        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        
        # Get the text embedding for conditioning
        encoder_hidden_states = text_encoder(batch["input_ids"])[0]

        
        # Get the target for loss depending on the prediction type
        if prediction_type is not None:
            # set prediction_type of scheduler if defined
            noise_scheduler.register_to_config(prediction_type=prediction_type)

        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        # Predict the noise residual and compute loss
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        loss.backward()
        torch.nn.utils.clip_grad_norm_(lora_layers.parameters, max_norm=max_grad_norm)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        
        # TODO

{'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB s

here


NameError: name 'image_column' is not defined

In [None]:
unet = unet.to(torch.float32)
unet.save_attn_procs(args.output_dir)