✨ [Core] Add FreeU mechanism (huggingface#5164)
* ✨ Added Fourier filter function to upsample blocks

* 🔧 Update Fourier_filter for float16 support

* ✨ Added UNetFreeUConfig to UNet model for FreeU adaptation 🛠️

* move unet to its original form and add fourier_filter to torch_utils.

* implement FreeU enable mechanism

* implement disable mechanism

* resolution index.

* correct resolution idx condition.

* fix copies.

* no need to use resolution_idx in vae.

* spell out the kwargs

* proper config property

* fix attribute setting

* place unet hasattr properly.

* fix: attribute access.

* proper disable

* remove validation method.

* debug

* debug

* debug

* debug

* debug

* debug

* potential fix.

* add: doc.

* fix copies

* add: tests.

* add: support freeU in SDXL.

* set default value of resolution idx.

* set default values for resolution_idx.

* fix copies

* fix rest.

* fix copies

* address PR comments.

* run fix-copies

* move apply_free_u to utils and other minors.

* introduce support for video (unet3D)

* minor ups

* consistent fix-copies.

* consistent stuff

* fix-copies

* add: rest

* add: docs.

* fix: tests

* fix: doc path

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* style up

* move to techniques.

* add: slow test for sd freeu.

* add: slow test for sd freeu.

* add: slow test for sd freeu.

* add: slow test for sd freeu.

* add: slow test for sd freeu.

* add: slow test for sd freeu.

* add: slow test for video with freeu

* add: slow test for video with freeu

* add: slow test for video with freeu

* style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
4 people committed Oct 5, 2023
1 parent c4c9bf5 commit dddc62f
Showing 10 changed files with 438 additions and 0 deletions.
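
Before the per-file diffs, a minimal usage sketch of the API this commit introduces on the UNet (`enable_freeu` / `disable_freeu`, shown in the `models/unet_2d_condition.py` hunk below). The checkpoint id and the scaling-factor values are illustrative assumptions — the values follow what the FreeU repository suggested for Stable Diffusion v1.x — and are not defined by this commit:

```python
import torch
from diffusers import StableDiffusionPipeline

# Hypothetical checkpoint; any Stable Diffusion v1.x pipeline is assumed to behave the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Enable FreeU on the UNet: b1/b2 amplify backbone features, s1/s2 attenuate skip
# features in the first two up-block stages. Values follow the FreeU repository's
# SD v1.x suggestions; treat them as a starting point, not a guarantee.
pipe.unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)
image = pipe("an astronaut riding a horse on the moon").images[0]

# Restore the default behaviour without reloading the model.
pipe.unet.disable_freeu()
```

The toggles live on the UNet's up blocks, so they persist across pipeline calls until `disable_freeu` is invoked.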
79 changes: 79 additions & 0 deletions models/unet_2d_blocks.py
@@ -19,6 +19,7 @@
from torch import nn

from ..utils import is_torch_version, logging
from ..utils.torch_utils import apply_freeu
from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
@@ -249,6 +250,7 @@ def get_up_block(
add_upsample,
resnet_eps,
resnet_act_fn,
resolution_idx=None,
transformer_layers_per_block=1,
num_attention_heads=None,
resnet_groups=None,
@@ -281,6 +283,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -295,6 +298,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -314,6 +318,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -337,6 +342,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -362,6 +368,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -377,6 +384,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -390,6 +398,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -402,6 +411,7 @@ def get_up_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -415,6 +425,7 @@ def get_up_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -430,6 +441,7 @@ def get_up_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -441,6 +453,7 @@ def get_up_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -1993,6 +2006,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2075,6 +2089,8 @@ def __init__(
else:
self.upsamplers = None

self.resolution_idx = resolution_idx

def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
@@ -2103,6 +2119,7 @@ def __init__(
out_channels: int,
prev_output_channel: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
@@ -2181,6 +2198,7 @@ def __init__(
self.upsamplers = None

self.gradient_checkpointing = False
self.resolution_idx = resolution_idx

def forward(
self,
@@ -2194,11 +2212,30 @@ def forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)

for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]

# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)

hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

if self.training and self.gradient_checkpointing:
@@ -2252,6 +2289,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2292,12 +2330,33 @@ def __init__(
self.upsamplers = None

self.gradient_checkpointing = False
self.resolution_idx = resolution_idx

def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)

for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]

# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)

hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

if self.training and self.gradient_checkpointing:
@@ -2331,6 +2390,7 @@ def __init__(
self,
in_channels: int,
out_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2370,6 +2430,8 @@ def __init__(
else:
self.upsamplers = None

self.resolution_idx = resolution_idx

def forward(self, hidden_states, temb=None, scale: float = 1.0):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
@@ -2386,6 +2448,7 @@ def __init__(
self,
in_channels: int,
out_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2449,6 +2512,8 @@ def __init__(
else:
self.upsamplers = None

self.resolution_idx = resolution_idx

def forward(self, hidden_states, temb=None, scale: float = 1.0):
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
@@ -2469,6 +2534,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2553,6 +2619,8 @@ def __init__(
self.skip_norm = None
self.act = None

self.resolution_idx = resolution_idx

def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
@@ -2589,6 +2657,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2651,6 +2720,8 @@ def __init__(
self.skip_norm = None
self.act = None

self.resolution_idx = resolution_idx

def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
@@ -2684,6 +2755,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2743,6 +2815,7 @@ def __init__(
self.upsamplers = None

self.gradient_checkpointing = False
self.resolution_idx = resolution_idx

def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
for resnet in self.resnets:
@@ -2784,6 +2857,7 @@ def __init__(
out_channels: int,
prev_output_channel: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2873,6 +2947,7 @@ def __init__(
self.upsamplers = None

self.gradient_checkpointing = False
self.resolution_idx = resolution_idx

def forward(
self,
@@ -2947,6 +3022,7 @@ def __init__(
in_channels: int,
out_channels: int,
temb_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 5,
resnet_eps: float = 1e-5,
@@ -2988,6 +3064,7 @@ def __init__(
self.upsamplers = None

self.gradient_checkpointing = False
self.resolution_idx = resolution_idx

def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
res_hidden_states_tuple = res_hidden_states_tuple[-1]
@@ -3027,6 +3104,7 @@ def __init__(
in_channels: int,
out_channels: int,
temb_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 4,
resnet_eps: float = 1e-5,
@@ -3104,6 +3182,7 @@ def __init__(
self.upsamplers = None

self.gradient_checkpointing = False
self.resolution_idx = resolution_idx

def forward(
self,
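
The FreeU branches above call `apply_freeu`, which this commit imports from `..utils.torch_utils` (together with the Fourier filter mentioned in the commit message); the diff for that utility file is not shown on this page. The sketch below is a reconstruction of what such a helper can look like, following the FreeU paper — scale half of the backbone channels by `b1`/`b2` and low-pass-filter the skip features with `s1`/`s2` in resolution stages 0 and 1. Names, signatures, and details are assumptions, not the verbatim code added by the commit:

```python
import torch
import torch.fft as fft


def fourier_filter(x_in: torch.Tensor, threshold: int, scale: float) -> torch.Tensor:
    """Rescale the low-frequency band of a (B, C, H, W) feature map in Fourier space."""
    x = x_in
    B, C, H, W = x.shape
    # Half-precision FFTs can require power-of-two sizes; fall back to fp32 otherwise.
    if (H & (H - 1)) != 0 or (W & (W - 1)) != 0:
        x = x.to(torch.float32)

    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))

    # Mask that rescales only the centred low-frequency square of side 2 * threshold.
    mask = torch.ones((B, C, H, W), device=x.device)
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale

    x_freq = fft.ifftshift(x_freq * mask, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
    return x_filtered.to(x_in.dtype)


def apply_freeu(resolution_idx, hidden_states, res_hidden_states, **freeu_kwargs):
    """Apply FreeU to the first two up-block stages (resolution_idx 0 and 1)."""
    if resolution_idx == 0:
        num_half_channels = hidden_states.shape[1] // 2
        hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"]
        res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"])
    if resolution_idx == 1:
        num_half_channels = hidden_states.shape[1] // 2
        hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"]
        res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])
    return hidden_states, res_hidden_states
```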
33 changes: 33 additions & 0 deletions models/unet_2d_condition.py
@@ -542,6 +542,7 @@ def __init__(
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resolution_idx=i,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
@@ -733,6 +734,38 @@ def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value

def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
for i, upsample_block in enumerate(self.up_blocks):
setattr(upsample_block, "s1", s1)
setattr(upsample_block, "s2", s2)
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)

def disable_freeu(self):
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)

def forward(
self,
sample: torch.FloatTensor,
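
As a sanity check on the mechanism in this hunk: `enable_freeu` stamps the four scaling factors onto every up block, the blocks apply `apply_freeu` only while those attributes are set, and `disable_freeu` resets them to `None`. A hypothetical snippet, assuming `unet` is a loaded `UNet2DConditionModel`:

```python
# `unet` is assumed to be a loaded UNet2DConditionModel instance.
unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)
assert all(getattr(block, "b1", None) is not None for block in unet.up_blocks)

unet.disable_freeu()
assert all(getattr(block, "b1", None) is None for block in unet.up_blocks)
```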