# Fine-tunning diffsuion model using ControlLora

### Installing dependencies

In [3]:
""" !pip install accelerate>=0.16.0
!pip install bitsandbytes>=0.37.0
!pip install datasets>=2.9.0
!pip install git+https://github.com/huggingface/diffusers
!pip install gradio>=3.18.0
!pip install opencv-python>=4.7.0
!pip install safetensors>=0.2.8
!pip install torch>=1.12.1
!pip install torchaudio>=0.12.1
!pip install torchvision>=0.13.1
!pip install transformers>=4.26.1
!pip install wandb>=0.13.1
!pip install ipywidgets>=7.7.1 """

' !pip install accelerate>=0.16.0\n!pip install bitsandbytes>=0.37.0\n!pip install datasets>=2.9.0\n!pip install git+https://github.com/huggingface/diffusers\n!pip install gradio>=3.18.0\n!pip install opencv-python>=4.7.0\n!pip install safetensors>=0.2.8\n!pip install torch>=1.12.1\n!pip install torchaudio>=0.12.1\n!pip install torchvision>=0.13.1\n!pip install transformers>=4.26.1\n!pip install wandb>=0.13.1\n!pip install ipywidgets>=7.7.1 '

### Imports

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List, Tuple, Union
from dataclasses import dataclass
from diffusers.utils.outputs import BaseOutput
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import get_down_block as get_down_block_default
from diffusers.models.resnet import Mish, Upsample2D, Downsample2D, upsample_2d, downsample_2d, partial
from diffusers.models.cross_attention import CrossAttention # , LoRACrossAttnProcessor


### Dataset class

In [5]:
import torch
import numpy as np

from PIL import Image
from torch.utils import data


class dataset_cls(data.Dataset):
    DATASET_TYPE_DICT = {}

    def __init__(self, tokenizer, resolution=512, use_crop=True, **kwargs):
        pass

    @classmethod
    def register_cls(cls, name: str):
        dataset_cls.DATASET_TYPE_DICT['process/' + name] = cls

    @staticmethod
    def from_name(name: str):
        dataset_cls: dataset_cls = dataset_cls.DATASET_TYPE_DICT[name]
        return dataset_cls

    @staticmethod
    def control_channel():
        return 3

    @staticmethod
    def cat_input(image: Image.Image, target: torch.Tensor, guide: torch.Tensor):
        target = np.uint8(((target + 1) * 127.5)[0].permute(1,2,0).cpu().numpy().clip(0,255))
        guide = np.uint8(((guide + 1) * 127.5)[0].permute(1,2,0).cpu().numpy().clip(0,255))
        target = Image.fromarray(target).convert('RGB').resize(image.size)
        guide = Image.fromarray(guide).convert('RGB').resize(image.size)
        image_cat = Image.new('RGB', (image.size[0]*3,image.size[1]), (0,0,0))
        image_cat.paste(target,(0,0))
        image_cat.paste(guide,(image.size[0], 0))
        image_cat.paste(image,(image.size[0]*2, 0))

        return image_cat

### Define conv blocks

In [6]:
class ConvBlock2D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_kernel_size=3,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        kernel=None,
        output_scale_factor=1.0,
        up=False,
        down=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=conv_kernel_size, stride=1, padding=conv_kernel_size//2)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == "scale_shift":
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")

            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        output_tensor = self.dropout(hidden_states)

        return output_tensor


class SimpleDownEncoderBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        convnet_eps: float = 1e-6,
        convnet_time_scale_shift: str = "default",
        convnet_act_fn: str = "swish",
        convnet_groups: int = 32,
        convnet_pre_norm: bool = True,
        convnet_kernel_size: int = 3,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()
        convnets = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            convnets.append(
                ConvBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=convnet_eps,
                    groups=convnet_groups,
                    dropout=dropout,
                    time_embedding_norm=convnet_time_scale_shift,
                    non_linearity=convnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=convnet_pre_norm,
                    conv_kernel_size=convnet_kernel_size,
                )
            )
        in_channels = in_channels if num_layers == 0 else out_channels

        self.convnets = nn.ModuleList(convnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

    def forward(self, hidden_states):
        for convnet in self.convnets:
            hidden_states = convnet(hidden_states, temb=None)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states

In [7]:
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_kernel_size=3,
):
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "SimpleDownEncoderBlock2D":
        return SimpleDownEncoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_downsample=add_downsample,
            convnet_eps=resnet_eps,
            convnet_act_fn=resnet_act_fn,
            convnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            convnet_time_scale_shift=resnet_time_scale_shift,
            convnet_kernel_size=resnet_kernel_size
        )
    else:
        return get_down_block_default(
            down_block_type,
            num_layers,
            in_channels,
            out_channels,
            temb_channels,
            add_downsample,
            resnet_eps,
            resnet_act_fn,
            attn_num_head_channels,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            downsample_padding=downsample_padding,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            # resnet_kernel_size=resnet_kernel_size
        )

### Define LoRA layers

In [8]:
class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        return up_hidden_states.to(orig_dtype)

In [9]:
class LoRACrossAttnProcessor(nn.Module):
    def __init__(
            self, 
            hidden_size, 
            cross_attention_dim=None, 
            rank=4, 
            post_add=False,
            key_states_skipped=False,
            value_states_skipped=False,
            output_states_skipped=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank
        self.post_add = post_add

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        if not key_states_skipped:
            self.to_k_lora = LoRALinearLayer(
                hidden_size if post_add else (cross_attention_dim or hidden_size), hidden_size, rank)
        if not value_states_skipped:
            self.to_v_lora = LoRALinearLayer(
                hidden_size if post_add else (cross_attention_dim or hidden_size), hidden_size, rank)
        if not output_states_skipped:
            self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

        self.key_states_skipped: bool = key_states_skipped
        self.value_states_skipped: bool = value_states_skipped
        self.output_states_skipped: bool = output_states_skipped

    def skip_key_states(self, is_skipped: bool = True):
        if is_skipped == False:
            assert hasattr(self, 'to_k_lora')
        self.key_states_skipped = is_skipped

    def skip_value_states(self, is_skipped: bool = True):
        if is_skipped == False:
            assert hasattr(self, 'to_q_lora')
        self.value_states_skipped = is_skipped

    def skip_output_states(self, is_skipped: bool = True):
        if is_skipped == False:
            assert hasattr(self, 'to_out_lora')
        self.output_states_skipped = is_skipped

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states) 
        query = query + scale * self.to_q_lora(query if self.post_add else hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) 
        if not self.key_states_skipped:
            key = key + scale * self.to_k_lora(key if self.post_add else encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        if not self.value_states_skipped:
            value = value + scale * self.to_v_lora(value if self.post_add else encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        out = attn.to_out[0](hidden_states)
        if not self.output_states_skipped:
            out = out + scale * self.to_out_lora(out if self.post_add else hidden_states)
        hidden_states = out
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states

In [10]:
class ControlLoRACrossAttnProcessor(LoRACrossAttnProcessor):
    def __init__(
            self, 
            hidden_size, 
            cross_attention_dim=None, 
            rank=4, 
            control_rank=None, 
            post_add=False, 
            concat_hidden=False,
            control_channels=None,
            control_self_add=True,
            key_states_skipped=False,
            value_states_skipped=False,
            output_states_skipped=False,
            **kwargs):
        super().__init__(
            hidden_size, 
            cross_attention_dim, 
            rank, 
            post_add=post_add,
            key_states_skipped=key_states_skipped,
            value_states_skipped=value_states_skipped,
            output_states_skipped=output_states_skipped)

        control_rank = rank if control_rank is None else control_rank
        control_channels = hidden_size if control_channels is None else control_channels
        self.concat_hidden = concat_hidden
        self.control_self_add = control_self_add if control_channels is None else False
        self.control_states: torch.Tensor = None

        self.to_control = LoRALinearLayer(
            control_channels + (hidden_size if concat_hidden else 0), 
            hidden_size, 
            control_rank)
        self.pre_loras: List[LoRACrossAttnProcessor] = []
        self.post_loras: List[LoRACrossAttnProcessor] = []

    def inject_pre_lora(self, lora_layer):
        self.pre_loras.append(lora_layer)
    
    def inject_post_lora(self, lora_layer):
        self.post_loras.append(lora_layer)

    def inject_control_states(self, control_states):
        self.control_states = control_states

    def process_control_states(self, hidden_states, scale=1.0):
        control_states = self.control_states.to(hidden_states.dtype)
        if hidden_states.ndim == 3 and control_states.ndim == 4:
            batch, _, height, width = control_states.shape
            control_states = control_states.permute(0, 2, 3, 1).reshape(batch, height * width, -1)
            self.control_states = control_states
        _control_states = control_states
        if self.concat_hidden:
            b1, b2 = control_states.shape[0], hidden_states.shape[0]
            if b1 != b2:
                control_states = control_states[:,None].repeat(1, b2//b1, *([1]*(len(control_states.shape)-1)))
                control_states = control_states.view(-1, *control_states.shape[2:])
            _control_states = torch.cat([hidden_states, control_states], -1)
        _control_states = scale * self.to_control(_control_states)
        if self.control_self_add:
            control_states = control_states + _control_states
        else:
            control_states = _control_states

        return control_states

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        pre_lora: LoRACrossAttnProcessor
        post_lora: LoRACrossAttnProcessor
        assert self.control_states is not None

        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)
        for pre_lora in self.pre_loras:
            lora_in = query if pre_lora.post_add else hidden_states
            if isinstance(pre_lora, ControlLoRACrossAttnProcessor):
                lora_in = lora_in + pre_lora.process_control_states(hidden_states, scale)
            query = query + scale * pre_lora.to_q_lora(lora_in)
        query = query + scale * self.to_q_lora((
            query if self.post_add else hidden_states) + self.process_control_states(hidden_states, scale))
        for post_lora in self.post_loras:
            lora_in = query if post_lora.post_add else hidden_states
            if isinstance(post_lora, ControlLoRACrossAttnProcessor):
                lora_in = lora_in + post_lora.process_control_states(hidden_states, scale)
            query = query + scale * post_lora.to_q_lora(lora_in)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states)
        for pre_lora in self.pre_loras:
            if not pre_lora.key_states_skipped:
                key = key + scale * pre_lora.to_k_lora(key if pre_lora.post_add else encoder_hidden_states)
        if not self.key_states_skipped:
            key = key + scale * self.to_k_lora(key if self.post_add else encoder_hidden_states)
        for post_lora in self.post_loras:
            if not post_lora.key_states_skipped:
                key = key + scale * post_lora.to_k_lora(key if post_lora.post_add else encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        for pre_lora in self.pre_loras:
            if not pre_lora.value_states_skipped:
                value = value + pre_lora.to_v_lora(value if pre_lora.post_add else encoder_hidden_states)
        if not self.value_states_skipped:
            value = value + scale * self.to_v_lora(value if self.post_add else encoder_hidden_states)
        for post_lora in self.post_loras:
            if not post_lora.value_states_skipped:
                value = value + post_lora.to_v_lora(value if post_lora.post_add else encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        out = attn.to_out[0](hidden_states)
        for pre_lora in self.pre_loras:
            if not pre_lora.output_states_skipped:
                out = out + scale * pre_lora.to_out_lora(out if pre_lora.post_add else hidden_states)
        out = out + scale * self.to_out_lora(out if self.post_add else hidden_states)
        for post_lora in self.post_loras:
            if not post_lora.output_states_skipped:
                out = out + scale * post_lora.to_out_lora(out if post_lora.post_add else hidden_states)
        hidden_states = out
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states

### ControlLora Class

In [11]:
@dataclass
class ControlLoRAOutput(BaseOutput):
    control_states: Tuple[torch.FloatTensor]


class ControlLoRA(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        down_block_types: Tuple[str] = (
            "SimpleDownEncoderBlock2D",
            "SimpleDownEncoderBlock2D",
            "SimpleDownEncoderBlock2D",
            "SimpleDownEncoderBlock2D",
        ),
        block_out_channels: Tuple[int] = (32, 64, 128, 256),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        lora_pre_down_block_types: Tuple[str] = (
            None,
            "SimpleDownEncoderBlock2D",
            "SimpleDownEncoderBlock2D",
            "SimpleDownEncoderBlock2D",
        ),
        lora_pre_down_layers_per_block: int = 1,
        lora_pre_conv_skipped: bool = False,
        lora_pre_conv_types: Tuple[str] = (
            "SimpleDownEncoderBlock2D",
            "SimpleDownEncoderBlock2D",
            "SimpleDownEncoderBlock2D",
            "SimpleDownEncoderBlock2D",
        ),
        lora_pre_conv_layers_per_block: int = 1,
        lora_pre_conv_layers_kernel_size: int = 1,
        lora_block_in_channels: Tuple[int] = (256, 256, 256, 256),
        lora_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        lora_cross_attention_dims: Tuple[List[int]] = (
            [None, 768, None, 768, None, 768, None, 768, None, 768], 
            [None, 768, None, 768, None, 768, None, 768, None, 768], 
            [None, 768, None, 768, None, 768, None, 768, None, 768], 
            [None, 768]
        ),
        lora_rank: int = 4,
        lora_control_rank: int = None,
        lora_post_add: bool = False,
        lora_concat_hidden: bool = False,
        lora_control_channels: Tuple[int] = (None, None, None, None),
        lora_control_self_add: bool = True,
        lora_key_states_skipped: bool = False,
        lora_value_states_skipped: bool = False,
        lora_output_states_skipped: bool = False,
    ):
        super().__init__()

        lora_control_cls = ControlLoRACrossAttnProcessor

        assert lora_block_in_channels[0] == block_out_channels[-1]
        
        if lora_pre_conv_skipped:
            lora_control_channels = lora_block_in_channels
            lora_control_self_add = False

        self.layers_per_block = layers_per_block
        self.lora_pre_down_layers_per_block = lora_pre_down_layers_per_block
        self.lora_pre_conv_layers_per_block = lora_pre_conv_layers_per_block

        self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        self.down_blocks = nn.ModuleList([])
        self.pre_lora_layers = nn.ModuleList([])
        self.lora_layers = nn.ModuleList([])

        # pre_down
        pre_down_blocks = []
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            pre_down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attn_num_head_channels=None,
                temb_channels=None,
            )
            pre_down_blocks.append(pre_down_block)
        self.down_blocks.append(nn.Sequential(*pre_down_blocks))
        self.pre_lora_layers.append(
            get_down_block(
                lora_pre_conv_types[0],
                num_layers=self.lora_pre_conv_layers_per_block,
                in_channels=lora_block_in_channels[0],
                out_channels=(
                    lora_block_out_channels[0] 
                    if lora_control_channels[0] is None 
                    else lora_control_channels[0]),
                add_downsample=False,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attn_num_head_channels=None,
                temb_channels=None,
                resnet_kernel_size=lora_pre_conv_layers_kernel_size,
            ) if not lora_pre_conv_skipped else nn.Identity()
        )
        self.lora_layers.append(
            nn.ModuleList([
                lora_control_cls(
                    lora_block_out_channels[0], 
                    cross_attention_dim=cross_attention_dim, 
                    rank=lora_rank, 
                    control_rank=lora_control_rank,
                    post_add=lora_post_add,
                    concat_hidden=lora_concat_hidden,
                    control_channels=lora_control_channels[0],
                    control_self_add=lora_control_self_add,
                    key_states_skipped=lora_key_states_skipped,
                    value_states_skipped=lora_value_states_skipped,
                    output_states_skipped=lora_output_states_skipped)
                for cross_attention_dim in lora_cross_attention_dims[0]
            ])
        )
        
        # down
        output_channel = lora_block_in_channels[0]
        for i, down_block_type in enumerate(lora_pre_down_block_types):
            if i == 0:
                continue
            input_channel = output_channel
            output_channel = lora_block_in_channels[i]

            down_block = get_down_block(
                down_block_type,
                num_layers=self.lora_pre_down_layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=True,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attn_num_head_channels=None,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

            self.pre_lora_layers.append(
                get_down_block(
                    lora_pre_conv_types[i],
                    num_layers=self.lora_pre_conv_layers_per_block,
                    in_channels=output_channel,
                    out_channels=(
                        lora_block_out_channels[i] 
                        if lora_control_channels[i] is None 
                        else lora_control_channels[i]),
                    add_downsample=False,
                    resnet_eps=1e-6,
                    downsample_padding=0,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    attn_num_head_channels=None,
                    temb_channels=None,
                    resnet_kernel_size=lora_pre_conv_layers_kernel_size,
                ) if not lora_pre_conv_skipped else nn.Identity()
            )
            self.lora_layers.append(
                nn.ModuleList([
                    lora_control_cls(
                        lora_block_out_channels[i], 
                        cross_attention_dim=cross_attention_dim, 
                        rank=lora_rank, 
                        control_rank=lora_control_rank,
                        post_add=lora_post_add,
                        concat_hidden=lora_concat_hidden,
                        control_channels=lora_control_channels[i],
                        control_self_add=lora_control_self_add,
                        key_states_skipped=lora_key_states_skipped,
                        value_states_skipped=lora_value_states_skipped,
                        output_states_skipped=lora_output_states_skipped)
                    for cross_attention_dim in lora_cross_attention_dims[i]
                ])
            )

    def forward(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[ControlLoRAOutput, Tuple]:
        lora_layer: ControlLoRACrossAttnProcessor
        
        orig_dtype = x.dtype
        dtype = self.conv_in.weight.dtype

        h = x.to(dtype)
        h = self.conv_in(h)
        control_states_list = []

        # down
        for down_block, pre_lora_layer, lora_layer_list in zip(
            self.down_blocks, self.pre_lora_layers, self.lora_layers):
            h = down_block(h)
            control_states = pre_lora_layer(h)
            if isinstance(control_states, tuple):
                control_states = control_states[0]
            control_states = control_states.to(orig_dtype)
            for lora_layer in lora_layer_list:
                lora_layer.inject_control_states(control_states)
            control_states_list.append(control_states)

        if not return_dict:
            return tuple(control_states_list)

        return ControlLoRAOutput(control_states=tuple(control_states_list))

### Specifying Arguments

In [12]:
class Args:
    def __init__(
        self,
        pretrained_model_name_or_path,
        revision=None,
        dataset_name=None,
        dataset_config_name=None,
        train_data_dir=None,
        image_column="image",
        guide_column="guide",
        caption_column="text",
        validation_prompt=None,
        num_validation_images=4,
        validation_epochs=1,
        max_train_samples=None,
        output_dir="sd-fill50k-model-control-lora",
        cache_dir=None,
        seed=None,
        resolution=512,
        train_batch_size=16,
        num_train_epochs=100,
        max_train_steps=None,
        gradient_accumulation_steps=1,
        gradient_checkpointing=False,
        learning_rate=1e-4,
        scale_lr=False,
        lr_scheduler="constant",
        lr_warmup_steps=500,
        use_8bit_adam=False,
        allow_tf32=False,
        dataloader_num_workers=0,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_weight_decay=1e-2,
        adam_epsilon=1e-08,
        max_grad_norm=1.0,
        push_to_hub=False,
        hub_token=None,
        hub_model_id=None,
        logging_dir="logs",
        mixed_precision=None,
        report_to="tensorboard",
        local_rank=-1,
        checkpointing_steps=500,
        resume_from_checkpoint=None,
        enable_xformers_memory_efficient_attention=False,
        control_lora_config=None,
        wandb_project_name=None,
    ):
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.revision = revision
        self.dataset_name = dataset_name
        self.dataset_config_name = dataset_config_name
        self.train_data_dir = train_data_dir
        self.image_column = image_column
        self.guide_column = guide_column
        self.caption_column = caption_column
        self.validation_prompt = validation_prompt
        self.num_validation_images = num_validation_images
        self.validation_epochs = validation_epochs
        self.max_train_samples = max_train_samples
        self.output_dir = output_dir
        self.cache_dir = cache_dir
        self.seed = seed
        self.resolution = resolution
        self.train_batch_size = train_batch_size
        self.num_train_epochs = num_train_epochs
        self.max_train_steps = max_train_steps
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.gradient_checkpointing = gradient_checkpointing
        self.learning_rate = learning_rate
        self.scale_lr = scale_lr
        self.lr_scheduler = lr_scheduler
        self.lr_warmup_steps = lr_warmup_steps
        self.use_8bit_adam = use_8bit_adam
        self.allow_tf32 = allow_tf32
        self.dataloader_num_workers = dataloader_num_workers
        self.adam_beta1 = adam_beta1
        self.adam_beta2 = adam_beta2
        self.adam_weight_decay = adam_weight_decay
        self.adam_epsilon = adam_epsilon
        self.max_grad_norm = max_grad_norm
        self.push_to_hub = push_to_hub
        self.hub_token = hub_token
        self.hub_model_id = hub_model_id
        self.logging_dir = logging_dir
        self.mixed_precision = mixed_precision
        self.report_to = report_to
        self.local_rank = local_rank
        self.checkpointing_steps = checkpointing_steps
        self.resume_from_checkpoint = resume_from_checkpoint
        self.enable_xformers_memory_efficient_attention = enable_xformers_memory_efficient_attention
        self.wandb_project_name = wandb_project_name

In [13]:
model_name = "sd-unsplash_10k_blur_rand_KS-model-control-lora"
dataset_name = "unsplash_10k_blur_rand_KS"

In [14]:
#@title Arguments
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5" #@param {type:"string"}
revision = None #@param {type:"string"}
dataset_name = "wtcherr/"+dataset_name #@param {type:"string"}
dataset_config_name = None #@param {type:"string"}
train_data_dir = None #@param {type:"string"}
image_column = "image" #@param {type:"string"}
guide_column = "guide" #@param {type:"string"}
caption_column = "text" #@param {type:"string"}
validation_prompt = "a high-quality, detailed, and professional image" #@param {type:"string"}
num_validation_images = 3 #@param {type:"integer"}
validation_epochs = 1 #@param {type:"integer"}
max_train_samples = -1 #@param {type:"integer"}
output_dir = "ckpts/"+model_name #@param {type:"string"}
cache_dir = None #@param {type:"string"}
seed = 42 #@param {type:"integer"}
resolution = 512 #@param {type:"integer"}
train_batch_size = 1 #@param {type:"integer"}
num_train_epochs = 6 #@param {type:"integer"}
max_train_steps = None #@param {type:"string"}
gradient_accumulation_steps = 1 #@param {type:"integer"}
gradient_checkpointing = False #@param {type:"boolean"}
learning_rate = 1e-4 #@param {type:"number"}
scale_lr = False #@param {type:"boolean"}
lr_scheduler = "constant" #@param {type:"string"}
lr_warmup_steps = 0 #@param {type:"integer"}
use_8bit_adam = False #@param {type:"boolean"}
allow_tf32 = False #@param {type:"boolean"}
dataloader_num_workers = 0 #@param {type:"integer"}
adam_beta1 = 0.9 #@param {type:"number"}
adam_beta2 = 0.999 #@param {type:"number"}
adam_weight_decay = 1e-2 #@param {type:"number"}
adam_epsilon = 1e-08 #@param {type:"number"}
max_grad_norm = 1.0 #@param {type:"number"}
push_to_hub = True #@param {type:"boolean"}
hub_token = "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" #@param {type:"string"}
hub_model_id = None #@param {type:"string"}
logging_dir = "logs" #@param {type:"string"}
mixed_precision = "fp16" #@param {type:"string"}
report_to = "wandb" #@param {type:"string"}
local_rank = -1 #@param {type:"integer"}
checkpointing_steps = 5000 #@param {type:"integer"}
resume_from_checkpoint = "latest" #@param {type:"string"}
enable_xformers_memory_efficient_attention = False #@param {type:"boolean"}
wandb_project_name = "ControlLora_unsplash_10k_blur_rand_KS" #@param {type:"string"}

In [15]:
args = Args(
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    revision=revision,
    dataset_name=dataset_name,
    dataset_config_name=dataset_config_name,
    train_data_dir=train_data_dir,
    image_column=image_column,
    guide_column=guide_column,
    caption_column=caption_column,
    validation_prompt=validation_prompt,
    num_validation_images=num_validation_images,
    validation_epochs=validation_epochs,
    max_train_samples=max_train_samples,
    output_dir=output_dir,
    cache_dir=cache_dir,
    seed=seed,
    resolution=resolution,
    train_batch_size=train_batch_size,
    num_train_epochs=num_train_epochs,
    max_train_steps=max_train_steps,
    gradient_accumulation_steps=gradient_accumulation_steps,
    gradient_checkpointing=gradient_checkpointing,
    learning_rate=learning_rate,
    scale_lr=scale_lr,
    lr_scheduler=lr_scheduler,
    lr_warmup_steps=lr_warmup_steps,
    use_8bit_adam=use_8bit_adam,
    allow_tf32=allow_tf32,
    dataloader_num_workers=dataloader_num_workers,
    adam_beta1=adam_beta1,
    adam_beta2=adam_beta2,
    adam_weight_decay=adam_weight_decay,
    adam_epsilon=adam_epsilon,
    max_grad_norm=max_grad_norm,
    push_to_hub=push_to_hub,
    hub_token=hub_token,
    hub_model_id=hub_model_id,
    logging_dir=logging_dir,
    mixed_precision=mixed_precision,
    report_to=report_to,
    local_rank=local_rank,
    checkpointing_steps=checkpointing_steps,
    resume_from_checkpoint=resume_from_checkpoint,
    enable_xformers_memory_efficient_attention=enable_xformers_memory_efficient_attention,
    wandb_project_name=wandb_project_name
)

### ControlLoRA Configuration

In [16]:
control_lora_config = {
  "_class_name": "ControlLoRA",
  "_diffusers_version": "0.13.0.dev0",
  "act_fn": "silu",
  "block_out_channels": [
    32,
    64,
    128,
    256
  ],
  "down_block_types": [
    "SimpleDownEncoderBlock2D",
    "SimpleDownEncoderBlock2D",
    "SimpleDownEncoderBlock2D",
    "SimpleDownEncoderBlock2D"
  ],
  "in_channels": 3,
  "layers_per_block": 1,
  "lora_block_in_channels": [
    256,
    256,
    256,
    256
  ],
  "lora_block_out_channels": [
    320,
    640,
    1280,
    1280
  ],
  "lora_control_rank": None,
  "lora_cross_attention_dims": [
    [
      None,
      768,
      None,
      768,
      None,
      768,
      None,
      768,
      None,
      768
    ],
    [
      None,
      768,
      None,
      768,
      None,
      768,
      None,
      768,
      None,
      768
    ],
    [
      None,
      768,
      None,
      768,
      None,
      768,
      None,
      768,
      None,
      768
    ],
    [
      None,
      768
    ]
  ],
  "lora_post_add": False,
  "lora_pre_conv_layers_kernel_size": 1,
  "lora_pre_conv_layers_per_block": 1,
  "lora_pre_conv_types": [
    "SimpleDownEncoderBlock2D",
    "SimpleDownEncoderBlock2D",
    "SimpleDownEncoderBlock2D",
    "SimpleDownEncoderBlock2D"
  ],
  "lora_pre_down_block_types": [
    None,
    "SimpleDownEncoderBlock2D",
    "SimpleDownEncoderBlock2D",
    "SimpleDownEncoderBlock2D"
  ],
  "lora_pre_down_layers_per_block": 1,
  "lora_rank": 4,
  "norm_num_groups": 32
}


### Defining helper functions

In [17]:
"""Fine-tuning script for Stable Diffusion for text2image with support for ControlLoRA."""
"""Code refer to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py"""

from diffusers import utils
from diffusers.utils import deprecation_utils
from diffusers.models import cross_attention
utils.deprecate = lambda *arg, **kwargs: None
deprecation_utils.deprecate = lambda *arg, **kwargs: None
cross_attention.deprecate = lambda *arg, **kwargs: None

import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Optional

import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import (
    AutoencoderKL, 
    DDPMScheduler, 
    DPMSolverMultistepScheduler, 
    DiffusionPipeline, 
    UNet2DConditionModel)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
#check_min_version("0.13.0.dev0")

logger = get_logger(__name__, log_level="INFO")


def save_model_card(repo_name, images=None, base_model=str, dataset_name=str, repo_folder=None):
    img_str = ""
    for i, image in enumerate(images):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        img_str += f"![img_{i}](./image_{i}.png)\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- lora
- controlnet
- control-lora
inference: true
---
    """
    model_card = f"""
# ControlLoRA text2image fine-tuning - {repo_name}
These are ControlLoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
{img_str}
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)


def parse_args():
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

    return args


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


DATASET_NAME_MAPPING = ("image", "guide", "text");

### Initializing the accelerator and logging

In [18]:
args = parse_args()
logging_dir = os.path.join(args.output_dir, args.logging_dir)

accelerator = Accelerator(
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    mixed_precision=args.mixed_precision,
    log_with=args.report_to,
    logging_dir=logging_dir,
)
if args.report_to == "wandb":
    if not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
    import wandb

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()
else:
    datasets.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()

# If passed along, set the training seed now.
if args.seed is not None:
    set_seed(args.seed)

# Handle the repository creation
if accelerator.is_main_process:
    if args.push_to_hub:
        if args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
        else:
            repo_name = args.hub_model_id
        repo_name = create_repo(repo_name, exist_ok=True, token=args.hub_token)
        repo = Repository(args.output_dir, clone_from=repo_name)

        with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
            if "step_*" not in gitignore:
                gitignore.write("step_*\n")
            if "epoch_*" not in gitignore:
                gitignore.write("epoch_*\n")
    elif args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)


06/09/2023 00:40:00 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16

d:\BUE\Graduation Project\MyWork\ckpts/sd-unsplash_10k_blur_rand_KS-model-control-lora is already a clone of https://huggingface.co/wtcherr/sd-unsplash_10k_blur_rand_KS-model-control-lora. Make sure you pull the latest changes with `repo.git_pull()`.


### Loading models

In [19]:
# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = CLIPTextModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)

{'prediction_type', 'sample_max_value', 'thresholding', 'clip_sample_range', 'dynamic_thresholding_ratio', 'variance_type'} was not found in config. Values will be initialized to default values.
{'scaling_factor'} was not found in config. Values will be initialized to default values.
{'projection_class_embeddings_input_dim', 'class_embeddings_concat', 'conv_in_kernel', 'cross_attention_norm', 'encoder_hid_dim', 'conv_out_kernel', 'class_embed_type', 'resnet_time_scale_shift', 'time_embedding_dim', 'time_embedding_act_fn', 'dual_cross_attention', 'time_cond_proj_dim', 'mid_block_type', 'addition_embed_type_num_heads', 'resnet_out_scale_factor', 'mid_block_only_cross_attention', 'num_class_embeds', 'time_embedding_type', 'addition_embed_type', 'resnet_skip_time_act', 'use_linear_projection', 'only_cross_attention', 'upcast_attention', 'timestep_post_act'} was not found in config. Values will be initialized to default values.


In [20]:
print(unet)

UNet2DConditionModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): CrossAttnDownBlock2D(
      (attentions): ModuleList(
        (0-1): 2 x Transformer2DModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=320, out_features=320, bias=False)
                (to_v): Linear(in_features=320, out_fe

### Injecting ControlLoRA for cross-attention

In [21]:
n_ch = len(unet.config.block_out_channels)
control_ids = [i for i in range(n_ch)]
cross_attention_dims = {i: [] for i in range(n_ch)}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        control_id = control_ids[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        control_id = list(reversed(control_ids))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        control_id = control_ids[block_id]
    cross_attention_dims[control_id].append(cross_attention_dim)
cross_attention_dims = tuple([cross_attention_dims[control_id] for control_id in control_ids])

control_lora = ControlLoRA.from_config(control_lora_config)

# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)

text_encoder.requires_grad_(False)

# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
control_lora.to(accelerator.device) # control_lora.to(accelerator.device), dtype=weight_dtype)

if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")

# now we will add new LoRA weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.

# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
# => 32 layers

# Set correct lora layers
lora_attn_procs = {}
lora_layers_list = list([list(layer_list) for layer_list in control_lora.lora_layers])
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        control_id = control_ids[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        control_id = list(reversed(control_ids))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        control_id = control_ids[block_id]

    lora_layers = lora_layers_list[control_id]
    if len(lora_layers) != 0:
        lora_layer = lora_layers.pop(0) 
        lora_attn_procs[name] = lora_layer

unet.set_attn_processor(lora_attn_procs)

{'lora_key_states_skipped', 'lora_control_self_add', 'lora_output_states_skipped', 'lora_value_states_skipped', 'lora_pre_conv_skipped', 'lora_control_channels', 'lora_concat_hidden'} was not found in config. Values will be initialized to default values.


In [22]:
print(unet)

UNet2DConditionModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): CrossAttnDownBlock2D(
      (attentions): ModuleList(
        (0-1): 2 x Transformer2DModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=320, out_features=320, bias=False)
                (to_v): Linear(in_features=320, out_fe

In [23]:
def print_trainable_parameters(models):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for model in models:
        for _, param in model.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )
  
print_trainable_parameters([text_encoder, vae, unet])

trainable params: 996864 || all params: 1067232171 || trainable%: 0.0934064795915902


In [24]:
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True

if args.scale_lr:
    args.learning_rate = (
        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
    )

# Initialize the optimizer
if args.use_8bit_adam:
    try:
        import bitsandbytes as bnb
    except ImportError:
        raise ImportError(
            "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
        )

    optimizer_cls = bnb.optim.AdamW8bit
else:
    optimizer_cls = torch.optim.AdamW

optimizer = optimizer_cls(
    control_lora.parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
def tokenize_captions(examples, is_train=True):
    captions = []
    for caption in examples[caption_column]:
        if isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])
        else:
            raise ValueError(
                f"Caption column `{caption_column}` should contain either strings or lists of strings."
            )
    inputs = tokenizer(
        captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    )
    return inputs.input_ids

# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
dataset_cls = dataset_cls
use_custom_dataset = False
if args.dataset_name.startswith('process/'):
    # Use custom dataset define in process
    use_custom_dataset = True
    dataset_cls = dataset_cls.from_name(args.dataset_name)
    dataset = dataset_cls(tokenize_captions, resolution=args.resolution, use_crop=True)
elif args.dataset_name is not None:
    # Downloading and loading a dataset from the hub.
    dataset = load_dataset(
        args.dataset_name,
        args.dataset_config_name,
        cache_dir=args.cache_dir,
    )
else:
    data_files = {}
    if args.train_data_dir is not None:
        data_files["train"] = os.path.join(args.train_data_dir, "**")
    dataset = load_dataset(
        "imagefolder",
        data_files=data_files,
        cache_dir=args.cache_dir,
    )
    # See more about loading custom images at
    # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

if use_custom_dataset:
    collate_fn = None
    train_dataset = dataset
    caption_column = args.caption_column
else:
    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    dataset_columns = DATASET_NAME_MAPPING
    if args.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.guide_column is None:
        guide_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        guide_column = args.guide_column
        if guide_column not in column_names:
            raise ValueError(
                f"--guide_column' value '{args.guide_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.caption_column is None:
        caption_column = dataset_columns[2] if dataset_columns is not None else column_names[2]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Preprocessing the datasets.
    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def preprocess_train(examples):
        images, guides = [], []
        for image, guide in zip(examples[image_column], examples[guide_column]):
            image, guide = image.convert("RGB"), guide.convert("RGB")
            image, guide = train_transforms(image), train_transforms(guide)
            c, h, w = image.shape
            y1, x1 = 0, 0
            if h != args.resolution:
                y1 = torch.randint(0, h - args.resolution, (1, )).item()
            elif w != args.resolution:
                x1 = torch.randint(0, w - args.resolution, (1, )).item()
            y2, x2 = y1 + args.resolution, x1 + args.resolution
            image = image[:,y1:y2,x1:x2]
            guide = guide[:,y1:y2,x1:x2]
            images.append(image)
            guides.append(guide)
            
        examples["pixel_values"] = images
        examples["guide_values"] = guides
        examples["input_ids"] = tokenize_captions(examples)
        return examples

    with accelerator.main_process_first():
        if args.max_train_samples != -1 and args.max_train_samples != None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        guide_values = torch.stack([example["guide_values"] for example in examples])
        guide_values = guide_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {"pixel_values": pixel_values, "guide_values": guide_values, "input_ids": input_ids}


Downloading readme:   0%|          | 0.00/447 [00:00<?, ?B/s]



  0%|          | 0/1 [00:00<?, ?it/s]

### Preparing Everything

In [25]:
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=args.train_batch_size,
    num_workers=args.dataloader_num_workers,
)
val_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=False,
    collate_fn=collate_fn,
    batch_size=1,
    num_workers=0,
)
val_iter = iter(val_dataloader)


In [26]:
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    overrode_max_train_steps = True

lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
    num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)


In [27]:
# Prepare everything with our `accelerator`.
control_lora, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    control_lora, optimizer, train_dataloader, lr_scheduler
)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
    accelerator.init_trackers(args.wandb_project_name, config=vars(args))

06/09/2023 00:40:31 - ERROR - wandb.jupyter - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwtcherr[0m ([33mwtcherrr[0m). Use [1m`wandb login --relogin`[0m to force relogin


### Training

In [28]:
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(train_dataset)}")
logger.info(f"  Num Epochs = {args.num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
    if args.resume_from_checkpoint != "latest":
        path = os.path.basename(args.resume_from_checkpoint)
    else:
        # Get the most recent checkpoint
        dirs = os.listdir(args.output_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1] if len(dirs) > 0 else None

    if path is None:
        accelerator.print(
            f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
        )
        args.resume_from_checkpoint = None
    else:
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))
        global_step = int(path.split("-")[1])

        resume_global_step = global_step * args.gradient_accumulation_steps
        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")

for epoch in range(first_epoch, args.num_train_epochs):
    unet.train()
    train_loss = 0.0
    for step, batch in enumerate(train_dataloader):
        # Skip steps until we reach the resumed step
        if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
            if step % args.gradient_accumulation_steps == 0:
                progress_bar.update(1)
            continue

        with accelerator.accumulate(unet):
            # Convert images to latent space
            latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            # Inject control states to unet
            _ = control_lora(batch["guide_values"].to(dtype=weight_dtype)).control_states

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Predict the noise residual and compute loss
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            # Gather the losses across all processes for logging (if we use distributed training).
            avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
            train_loss += avg_loss.item() / args.gradient_accumulation_steps

            # Backpropagate
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                params_to_clip = control_lora.parameters()
                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1
            accelerator.log({"train_loss": train_loss}, step=global_step)
            train_loss = 0.0

            if global_step % args.checkpointing_steps == 0:
                if accelerator.is_main_process:
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    logger.info(f"Saved state to {save_path}")

                    if args.validation_prompt is not None:
                        logger.info(
                            f"Running sampling... \n Generating {args.num_validation_images} images with prompt:"
                            f" {args.validation_prompt}."
                        )
                        # create pipeline
                        pipeline = DiffusionPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=accelerator.unwrap_model(unet),
                            revision=args.revision,
                            torch_dtype=weight_dtype,
                            safety_checker=None
                        )
                        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
                        pipeline = pipeline.to(accelerator.device)
                        pipeline.set_progress_bar_config(disable=True)

                        # run inference
                        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
                        images = []
                        for _ in range(args.num_validation_images):
                            with torch.no_grad():
                                try:
                                    batch = next(val_iter)
                                except:
                                    val_iter = iter(val_dataloader)
                                    batch = next(val_iter)
                                target = batch["pixel_values"].to(dtype=weight_dtype)
                                guide = batch["guide_values"].to(accelerator.device)
                                _ = control_lora(guide).control_states
                                image = pipeline(
                                    args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
                                image = dataset_cls.cat_input(image, target, guide)
                            images.append(image)

                        for tracker in accelerator.trackers:
                            if tracker.name == "tensorboard":
                                np_images = np.stack([np.asarray(img) for img in images])
                                tracker.writer.add_images("sampling", np_images, epoch, dataformats="NHWC")
                            if tracker.name == "wandb":
                                tracker.log(
                                    {
                                        "sampling": [
                                            wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                            for i, image in enumerate(images)
                                        ]
                                    }
                                )

                        del pipeline
                        torch.cuda.empty_cache()

        logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)

        if global_step >= args.max_train_steps:
            break
    
    if accelerator.is_main_process:
        if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
            logger.info(
                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                f" {args.validation_prompt}."
            )
            # create pipeline
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=accelerator.unwrap_model(unet),
                revision=args.revision,
                torch_dtype=weight_dtype,
                safety_checker=None
            )
            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
            pipeline = pipeline.to(accelerator.device)
            pipeline.set_progress_bar_config(disable=True)

            # run inference
            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
            images = []
            for _ in range(args.num_validation_images):
                with torch.no_grad():
                    try:
                        batch = next(val_iter)
                    except:
                        val_iter = iter(val_dataloader)
                        batch = next(val_iter)
                    target = batch["pixel_values"].to(dtype=weight_dtype)
                    guide = batch["guide_values"].to(accelerator.device)
                    _ = control_lora(guide).control_states
                    image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
                    image = dataset_cls.cat_input(image, target, guide)
                images.append(image)

            if accelerator.is_main_process:
                for tracker in accelerator.trackers:
                    if tracker.name == "tensorboard":
                        np_images = np.stack([np.asarray(img) for img in images])
                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
                    if tracker.name == "wandb":
                        tracker.log(
                            {
                                "validation": [
                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                    for i, image in enumerate(images)
                                ]
                            }
                        )

            del pipeline
            torch.cuda.empty_cache()

06/09/2023 00:40:44 - INFO - __main__ - ***** Running training *****
06/09/2023 00:40:44 - INFO - __main__ -   Num examples = 10000
06/09/2023 00:40:44 - INFO - __main__ -   Num Epochs = 6
06/09/2023 00:40:44 - INFO - __main__ -   Instantaneous batch size per device = 1
06/09/2023 00:40:44 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 1
06/09/2023 00:40:44 - INFO - __main__ -   Gradient Accumulation steps = 1
06/09/2023 00:40:44 - INFO - __main__ -   Total optimization steps = 60000
06/09/2023 00:40:44 - INFO - accelerate.accelerator - Loading states from ckpts/sd-unsplash_10k_blur_rand_KS-model-control-lora\checkpoint-5000


Resuming from checkpoint checkpoint-5000


06/09/2023 00:40:45 - INFO - accelerate.checkpointing - All model weights loaded successfully
06/09/2023 00:40:45 - INFO - accelerate.checkpointing - All optimizer states loaded successfully
06/09/2023 00:40:45 - INFO - accelerate.checkpointing - All scheduler states loaded successfully
06/09/2023 00:40:45 - INFO - accelerate.checkpointing - GradScaler state loaded successfully
06/09/2023 00:41:11 - INFO - accelerate.checkpointing - All random states loaded successfully
06/09/2023 00:41:11 - INFO - accelerate.accelerator - Loading in 0 custom states


  0%|          | 0/55000 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
06/09/2023 01:10:43 - INFO - accelerate.accelerator - Saving current state to ckpts/sd-unsplash_10k_blur_rand_KS-model-control-lora\checkpoint-10000
06/09/2023 01:10:43 - INFO - accelerate.checkpointing - Model weights saved in ckpts/sd-unsplash_10k_blur_rand_KS-model-control-lora\checkpoint-10000\pytorch_model.bin
06/09/2023 01:10:43 - INFO - accelerate.checkpointing - Optimizer state saved in ckpts/sd-unsplash_10k_blur_rand_KS-model-control-lora\checkpoint-10000\optimizer.bin
06/09/2023 01:10:43 - INFO - accelerate.checkpointing - Scheduler state saved in ckpts/sd-unsplash_10k_blur_rand_KS-model-control-lora\checkpoint-10000\scheduler.bin
06/09/2023 01:10:43 - INFO - accelerate.checkpointing - Gradient scaler state saved in ckpts/sd-unsplash_10k_blur_rand_KS-model-control-lora\checkpoint-10000\scaler.pt
06/09/2023 01:10:43 - INFO - accelerate.checkpointing - Random states saved in ckpts/sd-uns

### Saving the model to huggingface

In [27]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [29]:
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    repo = Repository(args.output_dir, clone_from=repo_name)
    unet = unet.to(torch.float32)
    # unet.save_attn_procs(args.output_dir)
    control_lora.save_config(args.output_dir)
    control_lora.save_pretrained(args.output_dir, safe_serialization=False)
    control_lora.save_pretrained(args.output_dir, safe_serialization=True)

    if args.push_to_hub:
        save_model_card(
            repo_name,
            images=images,
            base_model=args.pretrained_model_name_or_path,
            dataset_name=args.dataset_name,
            repo_folder=args.output_dir,
        )
        repo.git_add(auto_lfs_track=True)
        repo.git_commit(commit_message='End of training')
        repo.git_push()

d:\BUE\Graduation Project\MyWork\ckpts/sd-unsplash_10k_blur_rand_KS-model-control-lora is already a clone of https://huggingface.co/wtcherr/sd-unsplash_10k_blur_rand_KS-model-control-lora. Make sure you pull the latest changes with `repo.git_pull()`.
Configuration saved in ckpts/sd-unsplash_10k_blur_rand_KS-model-control-lora\config.json
Configuration saved in ckpts/sd-unsplash_10k_blur_rand_KS-model-control-lora\config.json
Model weights saved in ckpts/sd-unsplash_10k_blur_rand_KS-model-control-lora\diffusion_pytorch_model.bin
Configuration saved in ckpts/sd-unsplash_10k_blur_rand_KS-model-control-lora\config.json
Model weights saved in ckpts/sd-unsplash_10k_blur_rand_KS-model-control-lora\diffusion_pytorch_model.safetensors
Adding files tracked by Git LFS: ['image_0.png', 'image_1.png', 'image_2.png']. This may take a bit of time if the files are large.


Upload file checkpoint-40000/optimizer.bin:   0%|          | 1.00/46.5M [00:00<?, ?B/s]

Upload file checkpoint-30000/optimizer.bin:   0%|          | 1.00/46.5M [00:00<?, ?B/s]

Upload file checkpoint-10000/optimizer.bin:   0%|          | 1.00/46.5M [00:00<?, ?B/s]

Upload file checkpoint-35000/optimizer.bin:   0%|          | 1.00/46.5M [00:00<?, ?B/s]

Upload file checkpoint-5000/optimizer.bin:   0%|          | 1.00/46.5M [00:00<?, ?B/s]

Upload file checkpoint-45000/optimizer.bin:   0%|          | 1.00/46.5M [00:00<?, ?B/s]

Upload file checkpoint-15000/optimizer.bin:   0%|          | 1.00/46.5M [00:00<?, ?B/s]

Upload file checkpoint-55000/optimizer.bin:   0%|          | 1.00/46.5M [00:00<?, ?B/s]

Upload file checkpoint-50000/optimizer.bin:   0%|          | 1.00/46.5M [00:00<?, ?B/s]

Upload file checkpoint-20000/optimizer.bin:   0%|          | 1.00/46.5M [00:00<?, ?B/s]

Upload file checkpoint-25000/optimizer.bin:   0%|          | 1.00/46.5M [00:00<?, ?B/s]

Upload file checkpoint-60000/optimizer.bin:   0%|          | 1.00/46.5M [00:00<?, ?B/s]

Upload file diffusion_pytorch_model.bin:   0%|          | 1.00/23.2M [00:00<?, ?B/s]

Upload file checkpoint-30000/pytorch_model.bin:   0%|          | 1.00/23.2M [00:00<?, ?B/s]

Upload file checkpoint-45000/pytorch_model.bin:   0%|          | 1.00/23.2M [00:00<?, ?B/s]

Upload file checkpoint-20000/pytorch_model.bin:   0%|          | 1.00/23.2M [00:00<?, ?B/s]

Upload file checkpoint-25000/pytorch_model.bin:   0%|          | 1.00/23.2M [00:00<?, ?B/s]

Upload file checkpoint-50000/pytorch_model.bin:   0%|          | 1.00/23.2M [00:00<?, ?B/s]

Upload file checkpoint-60000/pytorch_model.bin:   0%|          | 1.00/23.2M [00:00<?, ?B/s]

Upload file checkpoint-5000/pytorch_model.bin:   0%|          | 1.00/23.2M [00:00<?, ?B/s]

Upload file checkpoint-40000/pytorch_model.bin:   0%|          | 1.00/23.2M [00:00<?, ?B/s]

Upload file checkpoint-15000/pytorch_model.bin:   0%|          | 1.00/23.2M [00:00<?, ?B/s]

Upload file checkpoint-10000/pytorch_model.bin:   0%|          | 1.00/23.2M [00:00<?, ?B/s]

Upload file checkpoint-55000/pytorch_model.bin:   0%|          | 1.00/23.2M [00:00<?, ?B/s]

Upload file checkpoint-35000/pytorch_model.bin:   0%|          | 1.00/23.2M [00:00<?, ?B/s]

Upload file diffusion_pytorch_model.safetensors:   0%|          | 1.00/23.1M [00:00<?, ?B/s]

Upload file image_1.png:   0%|          | 1.00/1.04M [00:00<?, ?B/s]

Upload file image_0.png:   0%|          | 1.00/915k [00:00<?, ?B/s]

Upload file image_2.png:   0%|          | 1.00/822k [00:00<?, ?B/s]

Upload file checkpoint-40000/random_states_0.pkl:   0%|          | 1.00/14.3k [00:00<?, ?B/s]

Upload file checkpoint-5000/random_states_0.pkl:   0%|          | 1.00/14.3k [00:00<?, ?B/s]

Upload file checkpoint-10000/random_states_0.pkl:   0%|          | 1.00/14.3k [00:00<?, ?B/s]

Upload file checkpoint-35000/random_states_0.pkl:   0%|          | 1.00/14.3k [00:00<?, ?B/s]

Upload file checkpoint-15000/random_states_0.pkl:   0%|          | 1.00/14.3k [00:00<?, ?B/s]

Upload file checkpoint-60000/random_states_0.pkl:   0%|          | 1.00/14.3k [00:00<?, ?B/s]

Upload file checkpoint-20000/random_states_0.pkl:   0%|          | 1.00/14.3k [00:00<?, ?B/s]

Upload file checkpoint-55000/random_states_0.pkl:   0%|          | 1.00/14.3k [00:00<?, ?B/s]

Upload file checkpoint-45000/random_states_0.pkl:   0%|          | 1.00/14.3k [00:00<?, ?B/s]

Upload file checkpoint-50000/random_states_0.pkl:   0%|          | 1.00/14.3k [00:00<?, ?B/s]

Upload file checkpoint-25000/random_states_0.pkl:   0%|          | 1.00/14.3k [00:00<?, ?B/s]

Upload file checkpoint-30000/random_states_0.pkl:   0%|          | 1.00/14.3k [00:00<?, ?B/s]

Upload file checkpoint-55000/scheduler.bin:   0%|          | 1.00/563 [00:00<?, ?B/s]

Upload file checkpoint-35000/scheduler.bin:   0%|          | 1.00/563 [00:00<?, ?B/s]

Upload file checkpoint-10000/scheduler.bin:   0%|          | 1.00/563 [00:00<?, ?B/s]

Upload file checkpoint-5000/scheduler.bin:   0%|          | 1.00/563 [00:00<?, ?B/s]

Upload file checkpoint-60000/scheduler.bin:   0%|          | 1.00/563 [00:00<?, ?B/s]

Upload file checkpoint-45000/scheduler.bin:   0%|          | 1.00/563 [00:00<?, ?B/s]

Upload file checkpoint-25000/scheduler.bin:   0%|          | 1.00/563 [00:00<?, ?B/s]

Upload file checkpoint-40000/scheduler.bin:   0%|          | 1.00/563 [00:00<?, ?B/s]

Upload file checkpoint-50000/scheduler.bin:   0%|          | 1.00/563 [00:00<?, ?B/s]

Upload file checkpoint-20000/scheduler.bin:   0%|          | 1.00/563 [00:00<?, ?B/s]

Upload file checkpoint-15000/scheduler.bin:   0%|          | 1.00/563 [00:00<?, ?B/s]

Upload file checkpoint-30000/scheduler.bin:   0%|          | 1.00/563 [00:00<?, ?B/s]

Upload file checkpoint-45000/scaler.pt:   0%|          | 1.00/557 [00:00<?, ?B/s]

Upload file checkpoint-55000/scaler.pt:   0%|          | 1.00/557 [00:00<?, ?B/s]

Upload file checkpoint-30000/scaler.pt:   0%|          | 1.00/557 [00:00<?, ?B/s]

Upload file checkpoint-20000/scaler.pt:   0%|          | 1.00/557 [00:00<?, ?B/s]

Upload file checkpoint-50000/scaler.pt:   0%|          | 1.00/557 [00:00<?, ?B/s]

Upload file checkpoint-60000/scaler.pt:   0%|          | 1.00/557 [00:00<?, ?B/s]

Upload file checkpoint-25000/scaler.pt:   0%|          | 1.00/557 [00:00<?, ?B/s]

Upload file checkpoint-15000/scaler.pt:   0%|          | 1.00/557 [00:00<?, ?B/s]

Upload file checkpoint-40000/scaler.pt:   0%|          | 1.00/557 [00:00<?, ?B/s]

Upload file checkpoint-5000/scaler.pt:   0%|          | 1.00/557 [00:00<?, ?B/s]

Upload file checkpoint-10000/scaler.pt:   0%|          | 1.00/557 [00:00<?, ?B/s]

Upload file checkpoint-35000/scaler.pt:   0%|          | 1.00/557 [00:00<?, ?B/s]

To https://huggingface.co/wtcherr/sd-unsplash_10k_blur_rand_KS-model-control-lora
   3e18b08..23b198d  main -> main

   3e18b08..23b198d  main -> main



### Evaluating the model

In [30]:
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype, safety_checker=None
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)

# load attention processors
lora_attn_procs = {}
lora_layers_list = list([list(layer_list) for layer_list in control_lora.lora_layers])
for name in pipeline.unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        control_id = control_ids[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        control_id = list(reversed(control_ids))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        control_id = control_ids[block_id]

    lora_layers = lora_layers_list[control_id]
    if len(lora_layers) != 0:
        lora_layer = lora_layers.pop(0)
        lora_attn_procs[name] = lora_layer

pipeline.unet.set_attn_processor(lora_attn_procs)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
    with torch.no_grad():
        try:
            batch = next(val_iter)
        except:
            val_iter = iter(val_dataloader)
            batch = next(val_iter)
        target = batch["pixel_values"].to(dtype=weight_dtype)
        guide = batch["guide_values"].to(accelerator.device)
        _ = control_lora(guide).control_states
        image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
        image = dataset_cls.cat_input(image, target, guide)
    images.append(image)

if accelerator.is_main_process:
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            tracker.log(
                {
                    "test": [
                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                        for i, image in enumerate(images)
                    ]
                }
            )

accelerator.end_training()

vae\diffusion_pytorch_model.safetensors not found
{'requires_safety_checker'} was not found in config. Values will be initialized to default values.
{'prediction_type'} was not found in config. Values will be initialized to default values.
{'scaling_factor'} was not found in config. Values will be initialized to default values.
{'projection_class_embeddings_input_dim', 'class_embeddings_concat', 'conv_in_kernel', 'cross_attention_norm', 'encoder_hid_dim', 'conv_out_kernel', 'class_embed_type', 'resnet_time_scale_shift', 'time_embedding_dim', 'time_embedding_act_fn', 'dual_cross_attention', 'time_cond_proj_dim', 'mid_block_type', 'addition_embed_type_num_heads', 'resnet_out_scale_factor', 'mid_block_only_cross_attention', 'num_class_embeds', 'time_embedding_type', 'addition_embed_type', 'resnet_skip_time_act', 'use_linear_projection', 'only_cross_attention', 'upcast_attention', 'timestep_post_act'} was not found in config. Values will be initialized to default values.
You have disabled 

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

0,1
train_loss,▂▁▁▃█▂█▃▁▁▁▁▁▃▁▁▂▂▂▁▅▁▄▂▄▃▁▃▁▅▁▂▃▁▁▁▄▁▁▁

0,1
train_loss,0.15092


## Testing the models

### Helper functions

In [30]:
import numpy as np
import cv2


def HWC3(x):
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


def resize_image(input_image, resolution):
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img


### Testing Canny Model

In [55]:
class CannyDetector:
    def __call__(self, img, low_threshold, high_threshold):
        return cv2.Canny(img, low_threshold, high_threshold)

In [56]:
import os
import shutil
from datetime import datetime
from PIL import Image

from diffusers import utils
from diffusers.utils import deprecation_utils
from diffusers.models import cross_attention
utils.deprecate = lambda *arg, **kwargs: None
deprecation_utils.deprecate = lambda *arg, **kwargs: None
cross_attention.deprecate = lambda *arg, **kwargs: None



import gradio as gr
import numpy as np
import torch
import random

from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines import DiffusionPipeline
from diffusers.schedulers import DPMSolverMultistepScheduler

apply_canny = CannyDetector()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

pipeline = DiffusionPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5', safety_checker=None
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(device)
unet: UNet2DConditionModel = pipeline.unet

control_lora = ControlLoRA.from_pretrained('ckpts/sd-unsplash_5k_canny-model-control-lora')
control_lora = control_lora.to(device)


# load control lora attention processors
lora_attn_procs = {}
lora_layers_list = list([list(layer_list) for layer_list in control_lora.lora_layers])
n_ch = len(unet.config.block_out_channels)
control_ids = [i for i in range(n_ch)]
for name in pipeline.unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        control_id = control_ids[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        control_id = list(reversed(control_ids))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        control_id = control_ids[block_id]

    lora_layers = lora_layers_list[control_id]
    if len(lora_layers) != 0:
        lora_layer: ControlLoRACrossAttnProcessor = lora_layers.pop(0)
        lora_attn_procs[name] = lora_layer

unet.set_attn_processor(lora_attn_procs)

def save(*args):
    # unpack arguments
    input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold, result_images = args

    # check if result_images is empty
    if not result_images:
        return

    # create save directory
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    save_folder = os.path.join(model_name, "saves", timestamp)
    os.makedirs(save_folder, exist_ok=True)

    # Saving parameters
    with open(os.path.join(save_folder, "parameters.txt"), "w") as f:
        params = [prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold]
        param_names = ["prompt", "a_prompt", "n_prompt", "num_samples", "image_resolution", "sample_steps", "scale", "seed", "eta", "low_threshold", "high_threshold"]
        for name, value in zip(param_names, params):
            f.write(f"{name}: {value}\n")

    # save original image
    input_image_pil = Image.fromarray(input_image)
    input_image_pil.save(os.path.join(save_folder, "original_image.png"))

    # Saving images
    for i, result_image_dict in enumerate(result_images):
        # get the filename of the result_image
        result_image_filename = result_image_dict['name']
        # read the image file as a numpy array
        result_image_np = cv2.imread(result_image_filename)
        # convert the numpy array to a PIL image
        result_image_pil = Image.fromarray(cv2.cvtColor(result_image_np, cv2.COLOR_BGR2RGB))
        if i == 0:
            # save the guide image
            result_image_pil.save(os.path.join(save_folder, "guide_image.png"))
        else:
            # save the generated image
            result_image_pil.save(os.path.join(save_folder, f"generated_image_{i}.png"))

def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold):
    with torch.no_grad():
        img = resize_image(HWC3(input_image), image_resolution)
        H, W, C = img.shape

        detected_map = apply_canny(img, low_threshold, high_threshold)
        detected_map = HWC3(detected_map)

        control = torch.from_numpy(detected_map[...,::-1].copy().transpose([2,0,1])).float().to(device)[None] / 127.5 - 1
        _ = control_lora(control).control_states

        if seed == -1:
            seed = random.randint(0, 65535)

        # run inference
        generator = torch.Generator(device=device).manual_seed(seed)
        images = []
        for i in range(num_samples):
            _ = control_lora(control).control_states
            image = pipeline(
                prompt + ', ' + a_prompt, negative_prompt=n_prompt, 
                num_inference_steps=sample_steps, guidance_scale=scale, eta=eta,
                generator=generator, height=H, width=W).images[0]
            images.append(np.asarray(image))
        
        results = images
    return [255 - detected_map] + results


block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with Canny Edge Maps")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(label="Run")
            save_button = gr.Button(label="Save")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=10, value=5, step=1)
                high_threshold = gr.Slider(label="Canny high threshold", minimum=130, maximum=150, value=140, step=1)
                sample_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='a high-quality, detailed, and professional image')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold]
    ips_save = ips + [result_gallery]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
    save_button.click(fn=save, inputs=ips_save, outputs=[])


block.launch(server_name='0.0.0.0', server_port=7860)


vae\diffusion_pytorch_model.safetensors not found
{'requires_safety_checker'} was not found in config. Values will be initialized to default values.
{'scaling_factor'} was not found in config. Values will be initialized to default values.
{'timestep_post_act', 'cross_attention_norm', 'mid_block_type', 'num_class_embeds', 'mid_block_only_cross_attention', 'use_linear_projection', 'time_embedding_act_fn', 'addition_embed_type', 'resnet_out_scale_factor', 'addition_embed_type_num_heads', 'time_embedding_dim', 'dual_cross_attention', 'conv_in_kernel', 'time_embedding_type', 'projection_class_embeddings_input_dim', 'class_embeddings_concat', 'resnet_time_scale_shift', 'only_cross_attention', 'time_cond_proj_dim', 'encoder_hid_dim', 'conv_out_kernel', 'class_embed_type', 'upcast_attention', 'resnet_skip_time_act'} was not found in config. Values will be initialized to default values.
{'prediction_type'} was not found in config. Values will be initialized to default values.
You have disabled 

Running on local URL:  http://0.0.0.0:7860

To create a public link, set `share=True` in `launch()`.




  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

In [57]:
block.close()

Closing server running on port: 7860


### Test Blur

In [31]:
class GaussianBlur:
    def __call__(self, img, kernel_size, sigmaX):
        return cv2.GaussianBlur(img, (kernel_size, kernel_size), sigmaX)

In [None]:
import os
import shutil
from datetime import datetime
from PIL import Image

from diffusers import utils
from diffusers.utils import deprecation_utils
from diffusers.models import cross_attention
utils.deprecate = lambda *arg, **kwargs: None
deprecation_utils.deprecate = lambda *arg, **kwargs: None
cross_attention.deprecate = lambda *arg, **kwargs: None

from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines import DiffusionPipeline
from diffusers.schedulers import DPMSolverMultistepScheduler


import gradio as gr


apply_gaussian_blur = GaussianBlur()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

pipeline = DiffusionPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5', safety_checker=None
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(device)
unet: UNet2DConditionModel = pipeline.unet

control_lora = ControlLoRA.from_pretrained('ckpts\sd-unsplash_5k_blur_61KS-model-control-lora')
control_lora = control_lora.to(device)


# load control lora attention processors
lora_attn_procs = {}
lora_layers_list = list([list(layer_list) for layer_list in control_lora.lora_layers])
n_ch = len(unet.config.block_out_channels)
control_ids = [i for i in range(n_ch)]
for name in pipeline.unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        control_id = control_ids[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        control_id = list(reversed(control_ids))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        control_id = control_ids[block_id]

    lora_layers = lora_layers_list[control_id]
    if len(lora_layers) != 0:
        lora_layer: ControlLoRACrossAttnProcessor = lora_layers.pop(0)
        lora_attn_procs[name] = lora_layer

unet.set_attn_processor(lora_attn_procs)


def save(*args):
    # unpack arguments
    input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, kernel_size, sigmaX, result_images = args

    # check if result_images is empty
    if not result_images:
        return

    # create save directory
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    save_folder = os.path.join(model_name, "saves", timestamp)
    os.makedirs(save_folder, exist_ok=True)

    # Saving parameters
    with open(os.path.join(save_folder, "parameters.txt"), "w") as f:
        params = [prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, kernel_size, sigmaX]
        param_names = ["prompt", "a_prompt", "n_prompt", "num_samples", "image_resolution", "sample_steps", "scale", "seed", "eta", "kernel_size", "sigmaX"]
        for name, value in zip(param_names, params):
            f.write(f"{name}: {value}\n")

    # save original image
    input_image_pil = Image.fromarray(input_image)
    input_image_pil.save(os.path.join(save_folder, "original_image.png"))

    # Saving images
    for i, result_image_dict in enumerate(result_images):
        # get the filename of the result_image
        result_image_filename = result_image_dict['name']
        # read the image file as a numpy array
        result_image_np = cv2.imread(result_image_filename)
        # convert the numpy array to a PIL image
        result_image_pil = Image.fromarray(cv2.cvtColor(result_image_np, cv2.COLOR_BGR2RGB))
        if i == 0:
            # save the guide image
            result_image_pil.save(os.path.join(save_folder, "guide_image.png"))
        else:
            # save the generated image
            result_image_pil.save(os.path.join(save_folder, f"generated_image_{i}.png"))

def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, kernel_size, sigmaX):
    with torch.no_grad():
        img = resize_image(HWC3(input_image), image_resolution)
        H, W, C = img.shape

        blur_map = apply_gaussian_blur(img, kernel_size, sigmaX)
        blur_map = HWC3(blur_map)

        control = torch.from_numpy(blur_map[...,::-1].copy().transpose([2,0,1])).float().to(device)[None] / 127.5 - 1
        _ = control_lora(control).control_states

        if seed == -1:
            seed = random.randint(0, 65535)

        # run inference
        generator = torch.Generator(device=device).manual_seed(seed)
        images = []
        for i in range(num_samples):
            _ = control_lora(control).control_states
            image = pipeline(
                prompt + ', ' + a_prompt, negative_prompt=n_prompt, 
                num_inference_steps=sample_steps, guidance_scale=scale, eta=eta,
                generator=generator, height=H, width=W).images[0]
            images.append(np.asarray(image))

        results = images
    return [blur_map] + results



block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with Blur")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button("Run")
            save_button = gr.Button("Save")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                kernel_size = gr.Slider(label="Kernel size", minimum=1, maximum=101, value=61, step=2)
                sigmaX = gr.Slider(label="SigmaX", minimum=0, maximum=100, value=10, step=0.5)
                sample_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='a high-quality, detailed, and professional image')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, kernel_size, sigmaX]
    ips_save = ips + [result_gallery]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
    save_button.click(fn=save, inputs=ips_save, outputs=[])



block.launch(server_name='0.0.0.0', server_port= 7861)


unet\diffusion_pytorch_model.safetensors not found
{'requires_safety_checker'} was not found in config. Values will be initialized to default values.
{'scaling_factor'} was not found in config. Values will be initialized to default values.
{'addition_embed_type', 'num_class_embeds', 'cross_attention_norm', 'use_linear_projection', 'conv_out_kernel', 'class_embeddings_concat', 'time_embedding_dim', 'only_cross_attention', 'upcast_attention', 'resnet_time_scale_shift', 'time_embedding_act_fn', 'resnet_out_scale_factor', 'mid_block_only_cross_attention', 'dual_cross_attention', 'addition_embed_type_num_heads', 'resnet_skip_time_act', 'timestep_post_act', 'time_embedding_type', 'mid_block_type', 'conv_in_kernel', 'class_embed_type', 'encoder_hid_dim', 'projection_class_embeddings_input_dim', 'time_cond_proj_dim'} was not found in config. Values will be initialized to default values.
{'prediction_type'} was not found in config. Values will be initialized to default values.
You have disabled

Running on local URL:  http://0.0.0.0:7861

To create a public link, set `share=True` in `launch()`.




  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

In [18]:
block.close()

Closing server running on port: 7861
