In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import copy

import torch
import torch.nn as nn
import logging


from diffusers import SD3Transformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory


from text_embed import encode_prompt, get_precomputed_tensors
from datasets import FillDataset
from sd3 import SD3CNModel

2024-10-31 17:08:43.685652: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-31 17:08:43.685698: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-31 17:08:43.686846: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-31 17:08:43.692672: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.resnet import Downsample2D, ResnetBlock2D
    

class ControlNeXtModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        time_embed_dim = 256,
        in_channels = [128, 128],
        out_channels = [128, 256],
        groups = [4, 8],
        controlnext_scale=1.
    ):
        super().__init__()

        self.time_proj = Timesteps(128, True, downscale_freq_shift=0)
        self.time_embedding = TimestepEmbedding(128, time_embed_dim)
        self.embedding = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(2, 64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(2, 64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(2, 128),
            nn.ReLU(),
        )

        self.down_res = nn.ModuleList()
        self.down_sample = nn.ModuleList()
        for i in range(len(in_channels)):
            self.down_res.append(
                ResnetBlock2D(
                    in_channels=in_channels[i],
                    out_channels=out_channels[i],
                    temb_channels=time_embed_dim,
                    groups=groups[i]
                ),
            )
            self.down_sample.append(
                Downsample2D(
                    out_channels[i],
                    use_conv=True,
                    out_channels=out_channels[i],
                    padding=1,
                    name="op",
                )
            )
        
        self.mid_convs = nn.ModuleList()
        self.mid_convs.append(nn.Sequential(
            nn.Conv2d(
                in_channels=out_channels[-1],
                out_channels=out_channels[-1],
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.ReLU(),
            nn.GroupNorm(8, out_channels[-1]),
            nn.Conv2d(
                in_channels=out_channels[-1],
                out_channels=out_channels[-1],
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.GroupNorm(8, out_channels[-1]),
        ))
        self.mid_convs.append(
            nn.Conv2d(
            in_channels=out_channels[-1],
            out_channels=320,
            kernel_size=1,
            stride=1,
        ))

        self.scale = controlnext_scale

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
    ):
        
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        batch_size = sample.shape[0]
        timesteps = timesteps.expand(batch_size)

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb)

        sample = self.embedding(sample)

        for res, downsample in zip(self.down_res, self.down_sample):
            sample = res(sample, emb)
            sample = downsample(sample, emb)
        
        sample = self.mid_convs[0](sample) + sample
        sample = self.mid_convs[1](sample)
        
        return {
            'output': sample,
            'scale': self.scale,
        }


def get_sigmas(timesteps, n_dim=4, dtype=torch.float32, device="cuda"):
    sigmas = noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler_copy.timesteps.to(device)
    # timesteps = timesteps.to(accelerator.device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma

In [4]:
device = 'cuda:1'

In [43]:
# transformer = SD3CNModel.from_pretrained(
#     "stabilityai/stable-diffusion-3-medium-diffusers",
#     subfolder="transformer",
#     torch_dtype=torch.float16).to(device)
transformer = SD3CNModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",).to(device)

In [6]:
control_next_model = ControlNeXtModel().to(device)

vae = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-3-medium",
            subfolder="vae",
            revision="refs/pr/26").to(device)

In [7]:
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="scheduler"
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)

In [8]:
tensor_list = get_precomputed_tensors()
for data in tensor_list:
    for key, value in data.items():
        if key != 'prompt' and not isinstance(value, torch.Tensor):
            data[key] = torch.tensor(data[key]).to(device)
        if key in ['img', 'hint']:
            data[key] = data[key].permute(2, 0, 1).unsqueeze(dim=0)
        
        
pixel_list = [x['img'] for x in tensor_list]
hint_list = [x['hint'] for x in tensor_list]
prompt_embed_list = [x['prompt_embeds'] for x in tensor_list]
pooled_prompt_embed_list = [x['pooled_prompt_embeds'] for x in tensor_list]

In [9]:
model_input = vae.encode(pixel_list[0]).latent_dist.sample()
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor

In [15]:
control_input = vae.encode(hint_list[0]).latent_dist.sample()
control_input = (control_input - vae.config.shift_factor) * vae.config.scaling_factor

In [10]:
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal",
    batch_size=1,
    logit_mean=0,
    logit_std=1,
    mode_scale=1.29,
)

In [11]:
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

# Add noise according to flow matching.
# zt = (1 - texp) * x + texp * z1
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype, device=device)
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

In [16]:
control_input.shape

torch.Size([1, 16, 64, 64])

In [29]:
# res = control_next_model(control_input, timesteps)
res = control_next_model(hint_list[0], timesteps)
control_out = res['output']

In [31]:
control_out.shape

torch.Size([1, 320, 64, 64])

In [38]:
timesteps.half()

tensor([851.], device='cuda:1', dtype=torch.float16)

In [42]:
print(noisy_model_input.dtype)
print(timesteps.dtype)
print(prompt_embeds.dtype)
print(pooled_prompt_embeds.dtype)


torch.float32
torch.float32
torch.float32
torch.float32


In [46]:
transformer.device

device(type='cuda', index=1)

In [45]:
# Get the text embedding for conditioning
prompt_embeds = prompt_embed_list[0]
pooled_prompt_embeds = pooled_prompt_embed_list[0]

# Predict the noise residual
model_pred = transformer(
    hidden_states=noisy_model_input,
    timestep=timesteps,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled_prompt_embeds,
    block_controlnet_hidden_states=pooled_prompt_embeds,
    return_dict=False,
)[0]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

In [18]:
res['output'].shape

torch.Size([1, 320, 64, 64])

In [14]:
res['scale']

1.0