In [1]:
import sys

import torch
from torchvision import transforms
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from peft import LoraConfig

sys.path.append("../")
from src import create_hf_dataset, ImageTextSDTensorDataset, ImageTextSDXLTensorDataset, RandomCropWithCoords, ComposeWithCropCoords, AutoStableDiffusionModel, LoraWrapper, Params, convert_to_lora_target_names

  warn(
  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(


In [2]:
# Root directory of the raw data passed to create_hf_dataset below.
data_dir = "../data/deep_fashion"

In [3]:
# Project-level SDXL training config; its TRAINING.TEST_SIZE is passed to
# create_hf_dataset when the dataset is built.
config_path = "../configs/train_config_sdxl.json"
config = Params(config_path)

In [4]:
# Build the dataset from data_dir; the SD and SDXL datasets below both read its "train" split.
hf_dataset = create_hf_dataset(data_dir, config.TRAINING.TEST_SIZE)

# Inspect Model Outputs (Training Loss) for Stable Diffusion

In [5]:
# The Stable Diffusion runs use the first GPU, or the CPU when CUDA is unavailable.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [6]:
# Exploration config for the SD (non-XL) runs. Augmentation settings, batch size,
# base model name and LoRA rank/alpha/targets are all read from it below.
sd_train_config_path = "./configs_for_exploration/train_config_sd.json"
sd_train_config = Params(sd_train_config_path)

In [7]:
# Load the tokenizer from the same base model that the SD runs below load
# (sd_train_config.MODEL.BASE_MODEL_NAME). `config` points at the SDXL training
# config, so using it here would pair SD models with a tokenizer from a different repo.
sd_tokenizer = AutoTokenizer.from_pretrained(sd_train_config.MODEL.BASE_MODEL_NAME, subfolder="tokenizer")

In [8]:
# Horizontal flip is optional. When it is disabled, an identity transform takes its
# slot so the pipeline has the same stages either way.
if sd_train_config.DATA_AUGMENTATION.RANDOM_HORIZONTAL_FLIP:
    sd_flip_transform = transforms.RandomHorizontalFlip()
else:
    sd_flip_transform = transforms.Lambda(lambda x: x)

# Resize -> random crop -> (optional) flip -> tensor in [0, 1] -> normalize to [-1, 1].
sd_train_transforms = transforms.Compose([
    transforms.Resize(sd_train_config.DATA_AUGMENTATION.RESIZE_RESOLUTION),
    transforms.RandomCrop(sd_train_config.DATA_AUGMENTATION.TARGET_RESOLUTION),
    sd_flip_transform,
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

sd_train_dataset = ImageTextSDTensorDataset(hf_dataset["train"], sd_tokenizer, sd_train_transforms)
sd_train_dataloader = DataLoader(
    sd_train_dataset,
    batch_size=sd_train_config.TRAINING.BATCH_SIZE.TRAIN,
    shuffle=True,
)

Map: 100%|██████████| 40416/40416 [00:02<00:00, 19161.33 examples/s]


In [9]:
# Draw a single batch; every SD model below is evaluated on this same batch.
sd_train_batch = next(iter(sd_train_dataloader))

In [10]:
# Move every tensor of the batch onto the SD device (result is a list).
sd_train_batch = list(map(lambda tensor: tensor.to(device), sd_train_batch))

**Base Model Output**

In [11]:
# Fresh pretrained SD model for the base-model (no LoRA) run, placed on the SD device.
sd_base_model = AutoStableDiffusionModel.from_pretrained(
    sd_train_config.MODEL.BASE_MODEL_NAME
).to(device)

In [12]:
# Fix torch's global RNG right before the forward pass. Every model in this notebook
# then gets the same random draws, so the printed losses can be compared with each
# other. Without a seed, the differences between runs are dominated by sampling noise.
# NOTE(review): assumes the model's forward samples its noise/timesteps from torch's
# global RNG — confirm in AutoStableDiffusionModel.
torch.manual_seed(0)
sd_base_loss = sd_base_model(*sd_train_batch)

In [13]:
# Show the base-model loss.
sd_base_loss

tensor(0.0998, device='cuda:0', grad_fn=<MseLossBackward0>)

In [14]:
# Backward pass to check that gradients can be computed through the base model.
sd_base_loss.backward()

In [15]:
# Release GPU memory before the next experiment: move the weights to CPU, drop the
# Python reference, then let PyTorch return its unused cached CUDA memory.
sd_base_model.to("cpu")
del sd_base_model
torch.cuda.empty_cache()

**LoRA Model Output (LoRA on UNet and Text Encoder)**

In [16]:
# Reload a clean pretrained SD model to wrap with LoRA (UNet + text encoder).
sd_base_model = AutoStableDiffusionModel.from_pretrained(
    sd_train_config.MODEL.BASE_MODEL_NAME
).to(device)

In [17]:
# LoRA adapter settings for the UNet; rank, alpha and target layers come from the SD config.
sd_unet_lora_config = LoraConfig(
    init_lora_weights="gaussian",
    target_modules=convert_to_lora_target_names(sd_train_config.LORA.TARGET, "unet"),
    lora_alpha=sd_train_config.LORA.ALPHA,
    r=sd_train_config.LORA.RANK,
)

In [18]:
# LoRA adapter settings for the text encoder; same rank/alpha as the UNet, with the
# target names converted for the "text_encoder" component.
text_encoder_lora_config = LoraConfig(
    init_lora_weights="gaussian",
    target_modules=convert_to_lora_target_names(sd_train_config.LORA.TARGET, "text_encoder"),
    lora_alpha=sd_train_config.LORA.ALPHA,
    r=sd_train_config.LORA.RANK,
)

In [19]:
# Wrap the base model with LoRA adapters on both the UNet and the text encoder.
sd_lora_model = LoraWrapper.from_config(
    sd_base_model,
    sd_unet_lora_config,
    text_encoder_lora_config,
).to(device)

In [20]:
# Same seed as the other forward passes, so this loss is comparable to the base-model loss.
# NOTE(review): assumes the forward draws its randomness from torch's global RNG — confirm.
torch.manual_seed(0)
sd_lora_loss = sd_lora_model(*sd_train_batch)

In [21]:
# Show the loss of the LoRA model with a LoRA-adapted text encoder.
sd_lora_loss

tensor(0.1961, device='cuda:0', grad_fn=<MseLossBackward0>)

In [22]:
# Backward pass to check that gradients can be computed through the LoRA-wrapped model.
sd_lora_loss.backward()

In [23]:
# Release GPU memory before the next experiment: move both models to CPU, drop the
# references, then let PyTorch return its unused cached CUDA memory.
# NOTE(review): LoraWrapper presumably shares weights with sd_base_model — confirm.
sd_base_model.to("cpu")
sd_lora_model.to("cpu")
del sd_base_model
del sd_lora_model
torch.cuda.empty_cache()

**LoRA Model Output (LoRA on UNet Only, Frozen Text Encoder)**

In [24]:
# Reload a clean pretrained SD model to wrap with a UNet-only LoRA.
sd_base_model = AutoStableDiffusionModel.from_pretrained(
    sd_train_config.MODEL.BASE_MODEL_NAME
).to(device)

In [25]:
# UNet LoRA settings (same as the previous experiment); no text-encoder config this time.
sd_unet_lora_config = LoraConfig(
    init_lora_weights="gaussian",
    target_modules=convert_to_lora_target_names(sd_train_config.LORA.TARGET, "unet"),
    lora_alpha=sd_train_config.LORA.ALPHA,
    r=sd_train_config.LORA.RANK,
)

In [26]:
# Wrap the base model with LoRA on the UNet only (no text-encoder config passed).
sd_lora_model = LoraWrapper.from_config(
    sd_base_model,
    sd_unet_lora_config,
).to(device)

In [27]:
# Same seed as the other forward passes, so this loss is comparable to the previous runs.
# NOTE(review): assumes the forward draws its randomness from torch's global RNG — confirm.
torch.manual_seed(0)
sd_lora_loss = sd_lora_model(*sd_train_batch)

In [28]:
# Show the loss of the UNet-only LoRA model.
sd_lora_loss

tensor(0.0473, device='cuda:0', grad_fn=<MseLossBackward0>)

In [29]:
# Backward pass to check that gradients can be computed through the UNet-only LoRA model.
sd_lora_loss.backward()

In [30]:
# Free the SD models' GPU memory before moving on to the SDXL experiments.
sd_base_model.to("cpu")
sd_lora_model.to("cpu")
del sd_base_model
del sd_lora_model
torch.cuda.empty_cache()

# Inspect Model Outputs (Training Loss) for Stable Diffusion XL

In [31]:
# SDXL runs on the second GPU when one exists. Fall back to the first GPU, then to
# the CPU, so the notebook still runs on single-GPU or CPU-only machines (the original
# picked "cuda:1" whenever CUDA was available, which fails with only one GPU).
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [32]:
# Exploration config for the SDXL runs. Augmentation settings, batch size, base model
# name and LoRA rank/alpha/targets are all read from it below.
sdxl_train_config_path = "./configs_for_exploration/train_config_sdxl.json"
sdxl_train_config = Params(sdxl_train_config_path)

In [33]:
# The SDXL dataset takes two tokenizers; load them in order from the base model's
# "tokenizer" and "tokenizer_2" subfolders.
sdxl_tokenizer1, sdxl_tokenizer2 = (
    AutoTokenizer.from_pretrained(sdxl_train_config.MODEL.BASE_MODEL_NAME, subfolder=subfolder)
    for subfolder in ("tokenizer", "tokenizer_2")
)

In [34]:
# Horizontal flip is optional. When it is disabled, an identity transform takes its slot.
if sdxl_train_config.DATA_AUGMENTATION.RANDOM_HORIZONTAL_FLIP:
    sdxl_flip_transform = transforms.RandomHorizontalFlip()
else:
    sdxl_flip_transform = transforms.Lambda(lambda x: x)

# Same pipeline as SD, but composed with the project's crop-coordinate-aware
# Compose/RandomCrop variants (the SDXL dataset also receives TARGET_RESOLUTION).
sdxl_train_transforms = ComposeWithCropCoords([
    transforms.Resize(sdxl_train_config.DATA_AUGMENTATION.RESIZE_RESOLUTION),
    RandomCropWithCoords(sdxl_train_config.DATA_AUGMENTATION.TARGET_RESOLUTION),
    sdxl_flip_transform,
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

sdxl_train_dataset = ImageTextSDXLTensorDataset(
    hf_dataset["train"],
    sdxl_tokenizer1,
    sdxl_tokenizer2,
    sdxl_train_transforms,
    sdxl_train_config.DATA_AUGMENTATION.TARGET_RESOLUTION,
)
sdxl_train_dataloader = DataLoader(
    sdxl_train_dataset,
    batch_size=sdxl_train_config.TRAINING.BATCH_SIZE.TRAIN,
    shuffle=True,
)

Map: 100%|██████████| 40416/40416 [00:03<00:00, 10224.53 examples/s]


In [35]:
# Draw a single batch; every SDXL model below is evaluated on this same batch.
sdxl_train_batch = next(iter(sdxl_train_dataloader))

In [36]:
# Move every tensor of the batch onto the SDXL device (result is a list).
sdxl_train_batch = list(map(lambda tensor: tensor.to(device), sdxl_train_batch))

**Base Model Output**

In [37]:
# Fresh pretrained SDXL model for the base-model (no LoRA) run, placed on the SDXL device.
sdxl_base_model = AutoStableDiffusionModel.from_pretrained(
    sdxl_train_config.MODEL.BASE_MODEL_NAME
).to(device)

In [38]:
# Fix torch's global RNG right before the forward pass so every SDXL model gets the
# same random draws and the printed losses are comparable.
# NOTE(review): assumes the forward draws its randomness from torch's global RNG — confirm.
torch.manual_seed(0)
sdxl_base_loss = sdxl_base_model(*sdxl_train_batch)

In [39]:
# Show the SDXL base-model loss.
sdxl_base_loss

tensor(0.0995, device='cuda:1', grad_fn=<MseLossBackward0>)

In [40]:
# Backward pass to check that gradients can be computed through the SDXL base model.
sdxl_base_loss.backward()

In [41]:
# Release GPU memory before the next experiment: move the weights to CPU, drop the
# Python reference, then let PyTorch return its unused cached CUDA memory.
sdxl_base_model.to("cpu")
del sdxl_base_model
torch.cuda.empty_cache()

**LoRA Model Output (LoRA on UNet and Text Encoder)**

In [42]:
# Reload a clean pretrained SDXL model to wrap with LoRA (UNet + text encoder).
sdxl_base_model = AutoStableDiffusionModel.from_pretrained(
    sdxl_train_config.MODEL.BASE_MODEL_NAME
).to(device)

In [43]:
# LoRA adapter settings for the SDXL UNet; rank, alpha and target layers come from the SDXL config.
sdxl_unet_lora_config = LoraConfig(
    init_lora_weights="gaussian",
    target_modules=convert_to_lora_target_names(sdxl_train_config.LORA.TARGET, "unet"),
    lora_alpha=sdxl_train_config.LORA.ALPHA,
    r=sdxl_train_config.LORA.RANK,
)

In [44]:
# LoRA adapter settings for the text encoder(s); same rank/alpha as the UNet, with the
# target names converted for the "text_encoder" component.
sdxl_text_encoder_lora_config = LoraConfig(
    init_lora_weights="gaussian",
    target_modules=convert_to_lora_target_names(sdxl_train_config.LORA.TARGET, "text_encoder"),
    lora_alpha=sdxl_train_config.LORA.ALPHA,
    r=sdxl_train_config.LORA.RANK,
)

In [45]:
# Wrap the SDXL base model with LoRA adapters on both the UNet and the text encoder.
sdxl_lora_model = LoraWrapper.from_config(
    sdxl_base_model,
    sdxl_unet_lora_config,
    sdxl_text_encoder_lora_config,
).to(device)

In [46]:
# Same seed as the other SDXL forward passes, so this loss is comparable to the base-model loss.
# NOTE(review): assumes the forward draws its randomness from torch's global RNG — confirm.
torch.manual_seed(0)
sdxl_lora_loss = sdxl_lora_model(*sdxl_train_batch)

In [47]:
# Show the loss of the SDXL LoRA model with a LoRA-adapted text encoder.
sdxl_lora_loss

tensor(0.0698, device='cuda:1', grad_fn=<MseLossBackward0>)

In [48]:
# Backward pass to check that gradients can be computed through the LoRA-wrapped SDXL model.
sdxl_lora_loss.backward()

In [49]:
# Release GPU memory before the next experiment: move both models to CPU, drop the
# references, then let PyTorch return its unused cached CUDA memory.
# NOTE(review): LoraWrapper presumably shares weights with sdxl_base_model — confirm.
sdxl_base_model.to("cpu")
sdxl_lora_model.to("cpu")
del sdxl_base_model
del sdxl_lora_model
torch.cuda.empty_cache()

**LoRA Model Output (LoRA on UNet Only, Frozen Text Encoder)**

In [50]:
# Reload a clean pretrained SDXL model to wrap with a UNet-only LoRA.
sdxl_base_model = AutoStableDiffusionModel.from_pretrained(
    sdxl_train_config.MODEL.BASE_MODEL_NAME
).to(device)

In [51]:
# SDXL UNet LoRA settings (same as the previous experiment); no text-encoder config this time.
sdxl_unet_lora_config = LoraConfig(
    init_lora_weights="gaussian",
    target_modules=convert_to_lora_target_names(sdxl_train_config.LORA.TARGET, "unet"),
    lora_alpha=sdxl_train_config.LORA.ALPHA,
    r=sdxl_train_config.LORA.RANK,
)

In [52]:
# Wrap the SDXL base model with LoRA on the UNet only (no text-encoder config passed).
sdxl_lora_model = LoraWrapper.from_config(
    sdxl_base_model,
    sdxl_unet_lora_config,
).to(device)

In [53]:
# Memory-saving options, enabled only for this last SDXL run; the earlier runs in this
# notebook do not call them, so their losses were computed without these options.
# NOTE(review): enable_xformers presumably needs the xformers package and a CUDA device — confirm.
sdxl_lora_model.enable_gradient_checkpointing()
sdxl_lora_model.enable_xformers()

In [54]:
# Same seed as the other SDXL forward passes, so this loss is comparable to the previous runs.
# NOTE(review): assumes the forward draws its randomness from torch's global RNG — confirm.
torch.manual_seed(0)
sdxl_lora_loss = sdxl_lora_model(*sdxl_train_batch)

In [55]:
# Show the loss of the UNet-only SDXL LoRA model.
sdxl_lora_loss

tensor(0.1640, device='cuda:1', grad_fn=<MseLossBackward0>)

In [56]:
# Backward pass to check that gradients can be computed with gradient checkpointing and xformers enabled.
sdxl_lora_loss.backward()

In [57]:
# Final cleanup: move both SDXL models to CPU, drop the references and release the
# unused cached CUDA memory.
sdxl_base_model.to("cpu")
sdxl_lora_model.to("cpu")
del sdxl_base_model
del sdxl_lora_model
torch.cuda.empty_cache()