Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
532f5c5
formatting
kohya-ss Jan 27, 2025
86a2f3f
Fix gradient handling when Text Encoders are trained
kohya-ss Jan 27, 2025
b6a3093
call optimizer eval/train fn before/after validation
kohya-ss Jan 27, 2025
29f31d0
add network.train()/eval() for validation
kohya-ss Jan 27, 2025
0750859
validation: Implement timestep-based validation processing
kohya-ss Jan 27, 2025
42c0a9e
Merge branch 'sd3' into val-loss-improvement
kohya-ss Jan 27, 2025
45ec02b
use same noise for every validation
kohya-ss Jan 27, 2025
de830b8
Move progress bar to account for sampling image first
rockerBOO Jan 29, 2025
4a71687
不要な警告の削除
tsukimiya Feb 3, 2025
c5b803c
rng state management: Implement functions to get and set RNG states f…
kohya-ss Feb 4, 2025
a24db1d
fix: validation timestep generation fails on SD/SDXL training
kohya-ss Feb 4, 2025
0911683
set python random state
kohya-ss Feb 9, 2025
344845b
fix: validation with block swap
kohya-ss Feb 9, 2025
1772038
fix: unpause training progress bar after validation
kohya-ss Feb 11, 2025
cd80752
fix: remove unused parameter 'accelerator' from encode_images_to_late…
kohya-ss Feb 11, 2025
76b7619
fix: simplify validation step condition in NetworkTrainer
kohya-ss Feb 11, 2025
ab88b43
Fix validation epoch divergence
rockerBOO Feb 14, 2025
ee295c7
Merge pull request #1935 from rockerBOO/validation-epoch-fix
kohya-ss Feb 15, 2025
63337d9
Merge branch 'sd3' into val-loss-improvement
kohya-ss Feb 15, 2025
4671e23
Fix validation epoch loss to check epoch average
rockerBOO Feb 16, 2025
3c7496a
Fix sizes for validation split
rockerBOO Feb 17, 2025
f3a0109
Clear sizes for validation reg images to be consistent
rockerBOO Feb 17, 2025
6051fa8
Merge pull request #1940 from rockerBOO/split-size-fix
kohya-ss Feb 17, 2025
7c22e12
Merge pull request #1938 from rockerBOO/validation-epoch-loss-recorder
kohya-ss Feb 17, 2025
9436b41
Fix validation split and add test
rockerBOO Feb 17, 2025
894037f
Merge pull request #1943 from rockerBOO/validation-split-test
kohya-ss Feb 18, 2025
dc7d5fb
Merge branch 'sd3' into val-loss-improvement
kohya-ss Feb 18, 2025
4a36996
modify log step calculation
kohya-ss Feb 18, 2025
13df475
Remove position_ids for V2
yidiq7 Feb 20, 2025
efb2a12
fix wandb val logging
kohya-ss Feb 21, 2025
59ae9ea
Merge pull request #1945 from yidiq7/dev
kohya-ss Feb 24, 2025
905f081
Merge branch 'dev' into sd3
kohya-ss Feb 24, 2025
386b733
Merge pull request #1918 from tsukimiya/fix_vperd_warning
kohya-ss Feb 24, 2025
67fde01
Merge branch 'dev' into sd3
kohya-ss Feb 24, 2025
6e90c0f
Merge pull request #1909 from rockerBOO/progress_bar
kohya-ss Feb 24, 2025
f68702f
Update IPEX libs
Disty0 Feb 25, 2025
ae409e8
fix: FLUX/SD3 network training not working without caching latents cl…
kohya-ss Feb 26, 2025
b286304
Merge pull request #1953 from Disty0/dev
kohya-ss Feb 26, 2025
1fcac98
Merge branch 'sd3' into val-loss-improvement
kohya-ss Feb 26, 2025
4965189
Merge pull request #1903 from kohya-ss/val-loss-improvement
kohya-ss Feb 26, 2025
ec350c8
Merge branch 'dev' into sd3
kohya-ss Feb 26, 2025
3d79239
docs: update README to include recent improvements in validation loss…
kohya-ss Feb 26, 2025
7b83d50
Merge branch 'sd3' into lumina
rockerBOO Feb 27, 2025
ce2610d
Change system prompt to inject Prompt Start special token
rockerBOO Feb 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ The command to install PyTorch is as follows:

### Recent Updates

Feb 26, 2025:

- Improve the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903)
- The validation loss uses fixed timestep sampling and a fixed random seed. This ensures that the validation loss does not fluctuate due to random values.

Jan 25, 2025:

- `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO!
Expand Down
57 changes: 14 additions & 43 deletions flux_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ def __init__(self):
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False

def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)

Expand Down Expand Up @@ -323,7 +328,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) ->
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
return noise_scheduler

def encode_images_to_latents(self, args, accelerator, vae, images):
def encode_images_to_latents(self, args, vae, images):
    """Encode a batch of images into latent space via the VAE.

    Returns the raw encoder output; any latent shift/scale is applied
    separately (see shift_scale_latents).
    """
    latents = vae.encode(images)
    return latents

def shift_scale_latents(self, args, latents):
Expand All @@ -341,7 +346,7 @@ def get_noise_pred_and_target(
network,
weight_dtype,
train_unet,
is_train=True
is_train=True,
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
Expand Down Expand Up @@ -376,8 +381,7 @@ def get_noise_pred_and_target(
t5_attn_mask = None

def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode:
# normal forward
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
Expand All @@ -390,44 +394,6 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
"""
else:
# split forward to reduce memory usage
assert network.train_blocks == "single", "train_blocks must be single for split mode"
with accelerator.autocast():
# move flux lower to cpu, and then move flux upper to gpu
unet.to("cpu")
clean_memory_on_device(accelerator.device)
self.flux_upper.to(accelerator.device)

# upper model does not require grad
with torch.no_grad():
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)

# move flux upper back to cpu, and then move flux lower to gpu
self.flux_upper.to("cpu")
clean_memory_on_device(accelerator.device)
unet.to(accelerator.device)

# lower model requires grad
intermediate_img.requires_grad_(True)
intermediate_txt.requires_grad_(True)
vec.requires_grad_(True)
pe.requires_grad_(True)

with torch.set_grad_enabled(is_train and train_unet):
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
"""

return model_pred

model_pred = call_dit(
Expand Down Expand Up @@ -546,6 +512,11 @@ def forward(hidden_states):
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)

def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
    """Hook executed after each validation step.

    Validation runs forward-only, so when block swapping is enabled the
    swap state for the next forward pass must be prepared here (normally
    the backward pass would take care of it).
    """
    if not self.is_swapping_blocks:
        return
    # no backward pass during validation -> prepare the next forward manually
    accelerator.unwrap_model(unet).prepare_block_swap_before_forward()

def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
Expand Down
11 changes: 8 additions & 3 deletions library/device_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
import gc

import torch
try:
# intel gpu support for pytorch older than 2.5
# ipex is not needed after pytorch 2.5
import intel_extension_for_pytorch as ipex # noqa
except Exception:
pass


try:
HAS_CUDA = torch.cuda.is_available()
Expand All @@ -14,8 +21,6 @@
HAS_MPS = False

try:
import intel_extension_for_pytorch as ipex # noqa

HAS_XPU = torch.xpu.is_available()
except Exception:
HAS_XPU = False
Expand Down Expand Up @@ -69,7 +74,7 @@ def init_ipex():

This function should run right after importing torch and before doing anything else.

If IPEX is not available, this function does nothing.
If xpu is not available, this function does nothing.
"""
try:
if HAS_XPU:
Expand Down
Loading