From 532f5c58a6e83a3400f82103f5854ff3f63d77d7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 20:50:42 +0900 Subject: [PATCH 01/27] formatting --- train_network.py | 229 ++++++++++++++++++++++------------------------- 1 file changed, 108 insertions(+), 121 deletions(-) diff --git a/train_network.py b/train_network.py index 2c3bb2aae..cc54be7cc 100644 --- a/train_network.py +++ b/train_network.py @@ -100,9 +100,7 @@ def generate_step_logs( if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None ): # tracking d*lr value of unet. - logs["lr/d*lr"] = ( - optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] - ) + logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] else: idx = 0 if not args.network_train_unet_only: @@ -115,16 +113,17 @@ def generate_step_logs( logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) - if ( - args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None - ): - logs[f"lr/d*lr/group{i}"] = ( - optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] - ) + if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None: + logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] return logs - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): train_dataset_group.verify_bucket_reso_steps(64) if val_dataset_group is not None: val_dataset_group.verify_bucket_reso_steps(64) @@ -219,7 +218,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, 
- is_train=True + is_train=True, ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -315,22 +314,22 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, # endregion def process_batch( - self, - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy: strategy_base.TextEncodingStrategy, - tokenize_strategy: strategy_base.TokenizeStrategy, - is_train=True, - train_text_encoder=True, - train_unet=True + self, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy: strategy_base.TextEncodingStrategy, + tokenize_strategy: strategy_base.TokenizeStrategy, + is_train=True, + train_text_encoder=True, + train_unet=True, ) -> torch.Tensor: """ Process a batch for the network @@ -397,7 +396,7 @@ def process_batch( network, weight_dtype, train_unet, - is_train=is_train + is_train=is_train, ) huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) @@ -484,7 +483,7 @@ def train(self, args): else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) - val_dataset_group = None # placeholder until validation dataset supported for arbitrary + val_dataset_group = None # placeholder until validation dataset supported for arbitrary current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -701,7 +700,7 @@ def train(self, args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) - + val_dataloader = torch.utils.data.DataLoader( val_dataset_group if val_dataset_group is not None else [], shuffle=False, @@ -900,7 +899,9 @@ def load_model_hook(models, input_dir): accelerator.print("running training / 学習開始") accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: 
{train_dataset_group.num_train_images}") - accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}") + accelerator.print( + f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}" + ) accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") @@ -968,11 +969,11 @@ def load_model_hook(models, input_dir): "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), - "ss_validation_seed": args.validation_seed, - "ss_validation_split": args.validation_split, - "ss_max_validation_steps": args.max_validation_steps, - "ss_validate_every_n_epochs": args.validate_every_n_epochs, - "ss_validate_every_n_steps": args.validate_every_n_steps, + "ss_validation_seed": args.validation_seed, + "ss_validation_split": args.validation_split, + "ss_max_validation_steps": args.max_validation_steps, + "ss_validate_every_n_epochs": args.validate_every_n_epochs, + "ss_validate_every_n_steps": args.validate_every_n_steps, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1248,9 +1249,7 @@ def remove_model(old_ckpt_name): accelerator.log({}, step=0) validation_steps = ( - min(args.max_validation_steps, len(val_dataloader)) - if args.max_validation_steps is not None - else len(val_dataloader) + min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) ) # training loop @@ -1298,21 +1297,21 @@ def remove_model(old_ckpt_name): self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - 
weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=True, - train_text_encoder=train_text_encoder, - train_unet=train_unet + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=True, + train_text_encoder=train_text_encoder, + train_unet=train_unet, ) accelerator.backward(loss) @@ -1369,32 +1368,21 @@ def remove_model(old_ckpt_name): if args.scale_weight_norms: progress_bar.set_postfix(**{**max_mean_logs, **logs}) - if is_tracking: logs = self.generate_step_logs( - args, - current_loss, - avr_loss, - lr_scheduler, - lr_descriptions, - optimizer, - keys_scaled, - mean_norm, - maximum_norm + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) accelerator.log(logs, step=global_step) # VALIDATION PER STEP should_validate_step = ( - args.validate_every_n_steps is not None - and global_step != 0 # Skip first step + args.validate_every_n_steps is not None + and global_step != 0 # Skip first step and global_step % args.validate_every_n_steps == 0 ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: val_progress_bar = tqdm( - range(validation_steps), smoothing=0, - disable=not accelerator.is_local_main_process, - desc="validation steps" + range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" ) for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: @@ -1404,27 +1392,27 @@ def remove_model(old_ckpt_name): self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, + batch, + text_encoders, + unet, + network, + vae, + 
noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False + train_text_encoder=False, + train_unet=False, ) current_loss = loss.detach().item() val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) + val_progress_bar.set_postfix({"val_avg_loss": val_step_loss_recorder.moving_average}) if is_tracking: logs = { @@ -1436,26 +1424,25 @@ def remove_model(old_ckpt_name): if is_tracking: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average logs = { - "loss/validation/step_average": val_step_loss_recorder.moving_average, - "loss/validation/step_divergence": loss_validation_divergence, + "loss/validation/step_average": val_step_loss_recorder.moving_average, + "loss/validation/step_divergence": loss_validation_divergence, } accelerator.log(logs, step=global_step) - + if global_step >= args.max_train_steps: break # EPOCH VALIDATION should_validate_epoch = ( - (epoch + 1) % args.validate_every_n_epochs == 0 - if args.validate_every_n_epochs is not None - else True + (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True ) if should_validate_epoch and len(val_dataloader) > 0: val_progress_bar = tqdm( - range(validation_steps), smoothing=0, - disable=not accelerator.is_local_main_process, - desc="epoch validation steps" + range(validation_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="epoch validation steps", ) for val_step, batch in enumerate(val_dataloader): @@ -1466,43 +1453,43 @@ def remove_model(old_ckpt_name): self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - 
vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False + train_text_encoder=False, + train_unet=False, ) current_loss = loss.detach().item() val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) + val_progress_bar.set_postfix({"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average}) if is_tracking: logs = { - "loss/validation/epoch_current": current_loss, - "epoch": epoch + 1, - "val_step": (epoch * validation_steps) + val_step + "loss/validation/epoch_current": current_loss, + "epoch": epoch + 1, + "val_step": (epoch * validation_steps) + val_step, } accelerator.log(logs, step=global_step) if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss + loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss logs = { - "loss/validation/epoch_average": avr_loss, - "loss/validation/epoch_divergence": loss_validation_divergence, - "epoch": epoch + 1 + "loss/validation/epoch_average": avr_loss, + "loss/validation/epoch_divergence": loss_validation_divergence, + "epoch": epoch + 1, } accelerator.log(logs, step=global_step) @@ -1510,7 +1497,7 @@ def remove_model(old_ckpt_name): if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} accelerator.log(logs, step=global_step) - + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 @@ -1696,31 +1683,31 @@ def setup_parser() -> argparse.ArgumentParser: "--validation_seed", type=int, default=None, - help="Validation seed for shuffling validation dataset, 
training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する" + help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する", ) parser.add_argument( "--validation_split", type=float, default=0.0, - help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" + help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合", ) parser.add_argument( "--validate_every_n_steps", type=int, default=None, - help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます" + help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます", ) parser.add_argument( "--validate_every_n_epochs", type=int, default=None, - help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます" + help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます", ) parser.add_argument( "--max_validation_steps", type=int, default=None, - help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します" + help="Max number of validation dataset items processed. 
By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します", ) return parser From 86a2f3fd262e52b3249d9f5508efe4774f1fa3ed Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:10:52 +0900 Subject: [PATCH 02/27] Fix gradient handling when Text Encoders are trained --- flux_train_network.py | 43 ++----------------------------------------- sd3_train_network.py | 2 +- train_network.py | 10 +++++----- 3 files changed, 8 insertions(+), 47 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 5cd1b9d51..475bd751b 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -376,9 +376,8 @@ def get_noise_pred_and_target( t5_attn_mask = None def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): - # if not args.split_mode: - # normal forward - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode + with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( img=img, @@ -390,44 +389,6 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) - """ - else: - # split forward to reduce memory usage - assert network.train_blocks == "single", "train_blocks must be single for split mode" - with accelerator.autocast(): - # move flux lower to cpu, and then move flux upper to gpu - unet.to("cpu") - clean_memory_on_device(accelerator.device) - self.flux_upper.to(accelerator.device) - - # upper model does not require grad - with torch.no_grad(): - intermediate_img, intermediate_txt, vec, pe = self.flux_upper( - img=packed_noisy_model_input, - img_ids=img_ids, - 
txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - - # move flux upper back to cpu, and then move flux lower to gpu - self.flux_upper.to("cpu") - clean_memory_on_device(accelerator.device) - unet.to(accelerator.device) - - # lower model requires grad - intermediate_img.requires_grad_(True) - intermediate_txt.requires_grad_(True) - vec.requires_grad_(True) - pe.requires_grad_(True) - - with torch.set_grad_enabled(is_train and train_unet): - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) - """ - return model_pred model_pred = call_dit( diff --git a/sd3_train_network.py b/sd3_train_network.py index dcf497f53..2f4579492 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -345,7 +345,7 @@ def get_noise_pred_and_target( t5_attn_mask = None # call model - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + with torch.set_grad_enabled(is_train), accelerator.autocast(): # TODO support attention mask model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) diff --git a/train_network.py b/train_network.py index cc54be7cc..6f1652fd9 100644 --- a/train_network.py +++ b/train_network.py @@ -232,7 +232,7 @@ def get_noise_pred_and_target( t.requires_grad_(True) # Predict the noise residual - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): + with torch.set_grad_enabled(is_train), accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -1405,8 +1405,8 @@ def remove_model(old_ckpt_name): text_encoding_strategy, tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False, + train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True + train_unet=train_unet, ) current_loss = loss.detach().item() @@ -1466,8 +1466,8 @@ def 
remove_model(old_ckpt_name): text_encoding_strategy, tokenize_strategy, is_train=False, - train_text_encoder=False, - train_unet=False, + train_text_encoder=train_text_encoder, + train_unet=train_unet, ) current_loss = loss.detach().item() From b6a309321675b5d0a59b776ffb4d0ecdd3d28ec2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:22:11 +0900 Subject: [PATCH 03/27] call optimizer eval/train fn before/after validation --- train_network.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/train_network.py b/train_network.py index 6f1652fd9..e735c582d 100644 --- a/train_network.py +++ b/train_network.py @@ -1381,6 +1381,8 @@ def remove_model(old_ckpt_name): and global_step % args.validate_every_n_steps == 0 ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: + optimizer_eval_fn() + val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" ) @@ -1429,6 +1431,8 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + optimizer_train_fn() + if global_step >= args.max_train_steps: break @@ -1438,6 +1442,8 @@ def remove_model(old_ckpt_name): ) if should_validate_epoch and len(val_dataloader) > 0: + optimizer_eval_fn() + val_progress_bar = tqdm( range(validation_steps), smoothing=0, @@ -1493,6 +1499,8 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + optimizer_train_fn() + # END OF EPOCH if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} From 29f31d005f12a08650389164fa9c60504928d451 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:35:43 +0900 Subject: [PATCH 04/27] add network.train()/eval() for validation --- train_network.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index e735c582d..9b8036f8b 100644 --- a/train_network.py +++ b/train_network.py @@ -1276,7 +1276,7 @@ 
def remove_model(old_ckpt_name): metadata["ss_epoch"] = str(epoch + 1) - accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) + accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here # TRAINING skipped_dataloader = None @@ -1382,6 +1382,7 @@ def remove_model(old_ckpt_name): ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() + accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" @@ -1432,6 +1433,7 @@ def remove_model(old_ckpt_name): accelerator.log(logs, step=global_step) optimizer_train_fn() + accelerator.unwrap_model(network).train() if global_step >= args.max_train_steps: break @@ -1443,6 +1445,7 @@ def remove_model(old_ckpt_name): if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() + accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( range(validation_steps), @@ -1500,6 +1503,7 @@ def remove_model(old_ckpt_name): accelerator.log(logs, step=global_step) optimizer_train_fn() + accelerator.unwrap_model(network).train() # END OF EPOCH if is_tracking: From 0750859133eec7858052cd3f79106113fa786e94 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 21:56:59 +0900 Subject: [PATCH 05/27] validation: Implement timestep-based validation processing --- sd3_train_network.py | 1 + train_network.py | 167 +++++++++++++++++++++++++------------------ 2 files changed, 100 insertions(+), 68 deletions(-) diff --git a/sd3_train_network.py b/sd3_train_network.py index 2f4579492..d4f131252 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -446,6 +446,7 @@ def forward(hidden_states): prepare_fp8(text_encoder, weight_dtype) def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + # TODO consider validation # drop cached text encoder outputs 
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: diff --git a/train_network.py b/train_network.py index 9b8036f8b..a63e9d1e9 100644 --- a/train_network.py +++ b/train_network.py @@ -9,6 +9,7 @@ import time import json from multiprocessing import Value +import numpy as np import toml from tqdm import tqdm @@ -1248,10 +1249,6 @@ def remove_model(old_ckpt_name): # log empty object to commit the sample images to wandb accelerator.log({}, step=0) - validation_steps = ( - min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - ) - # training loop if initial_step > 0: # only if skip_until_initial_step is specified for skip_epoch in range(epoch_to_start): # skip epochs @@ -1270,6 +1267,17 @@ def remove_model(old_ckpt_name): clean_memory_on_device(accelerator.device) + validation_steps = ( + min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + ) + NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable + min_timestep = 0 if args.min_timestep is None else args.min_timestep + max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep + validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1] + validation_total_steps = validation_steps * len(validation_timesteps) + original_args_min_timestep = args.min_timestep + original_args_max_timestep = args.max_timestep + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 @@ -1385,44 +1393,55 @@ def remove_model(old_ckpt_name): accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( - range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps" + 
range(validation_total_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="validation steps", ) + val_ts_step = 0 for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - - loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=False, - train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True - train_unet=train_unet, - ) - - current_loss = loss.detach().item() - val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) - val_progress_bar.update(1) - val_progress_bar.set_postfix({"val_avg_loss": val_step_loss_recorder.moving_average}) - - if is_tracking: - logs = { - "loss/validation/step_current": current_loss, - "val_step": (epoch * validation_steps) + val_step, - } - accelerator.log(logs, step=global_step) + for timestep in validation_timesteps: + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + + args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep + + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False, + train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True + train_unet=train_unet, + ) + + current_loss = loss.detach().item() + val_step_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_progress_bar.update(1) + val_progress_bar.set_postfix( + {"val_avg_loss": 
val_step_loss_recorder.moving_average, "timestep": timestep} + ) + + if is_tracking: + logs = { + "loss/validation/step_current": current_loss, + "val_step": (epoch * validation_total_steps) + val_ts_step, + } + accelerator.log(logs, step=global_step) + + val_ts_step += 1 if is_tracking: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average @@ -1432,6 +1451,8 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + args.min_timestep = original_args_min_timestep + args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() @@ -1448,49 +1469,57 @@ def remove_model(old_ckpt_name): accelerator.unwrap_model(network).eval() val_progress_bar = tqdm( - range(validation_steps), + range(validation_total_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="epoch validation steps", ) + val_ts_step = 0 for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + for timestep in validation_timesteps: + args.min_timestep = args.max_timestep = timestep - loss = self.process_batch( - batch, - text_encoders, - unet, - network, - vae, - noise_scheduler, - vae_dtype, - weight_dtype, - accelerator, - args, - text_encoding_strategy, - tokenize_strategy, - is_train=False, - train_text_encoder=train_text_encoder, - train_unet=train_unet, - ) + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False, + train_text_encoder=train_text_encoder, + train_unet=train_unet, + ) - current_loss = loss.detach().item() - 
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) - val_progress_bar.update(1) - val_progress_bar.set_postfix({"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average}) + current_loss = loss.detach().item() + val_epoch_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_progress_bar.update(1) + val_progress_bar.set_postfix( + {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep} + ) - if is_tracking: - logs = { - "loss/validation/epoch_current": current_loss, - "epoch": epoch + 1, - "val_step": (epoch * validation_steps) + val_step, - } - accelerator.log(logs, step=global_step) + if is_tracking: + logs = { + "loss/validation/epoch_current": current_loss, + "epoch": epoch + 1, + "val_step": (epoch * validation_total_steps) + val_ts_step, + } + accelerator.log(logs, step=global_step) + + val_ts_step += 1 if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average @@ -1502,6 +1531,8 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + args.min_timestep = original_args_min_timestep + args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() From 45ec02b2a8b5eb5af8f5b4877381dc4dcc596cb9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 Jan 2025 22:10:38 +0900 Subject: [PATCH 06/27] use same noise for every validation --- flux_train_network.py | 1 - train_network.py | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index aab025735..475bd751b 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -377,7 +377,6 @@ def get_noise_pred_and_target( def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode - with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 
for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( diff --git a/train_network.py b/train_network.py index a63e9d1e9..f0deb67ab 100644 --- a/train_network.py +++ b/train_network.py @@ -1391,6 +1391,8 @@ def remove_model(old_ckpt_name): if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() + rng_state = torch.get_rng_state() + torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1451,6 +1453,7 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + torch.set_rng_state(rng_state) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() @@ -1467,6 +1470,8 @@ def remove_model(old_ckpt_name): if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() accelerator.unwrap_model(network).eval() + rng_state = torch.get_rng_state() + torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1531,6 +1536,7 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) + torch.set_rng_state(rng_state) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() From de830b89416f0671d7a1364a9262fa850c0669df Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 29 Jan 2025 00:02:45 -0500 Subject: [PATCH 07/27] Move progress bar to account for sampling image first --- train_network.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index c3879531d..2deb736d6 100644 --- a/train_network.py +++ b/train_network.py @@ -1163,10 +1163,6 @@ def load_model_hook(models, input_dir): 
args.max_train_steps > initial_step ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}" - progress_bar = tqdm( - range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" - ) - epoch_to_start = 0 if initial_step > 0: if args.skip_until_initial_step: @@ -1271,6 +1267,10 @@ def remove_model(old_ckpt_name): clean_memory_on_device(accelerator.device) + progress_bar = tqdm( + range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" + ) + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 From 4a71687d20787d78a30b7a0df327067f5c402999 Mon Sep 17 00:00:00 2001 From: tsukimiya <71832+tsukimiya@users.noreply.github.com> Date: Tue, 4 Feb 2025 00:42:27 +0900 Subject: [PATCH 08/27] =?UTF-8?q?=E4=B8=8D=E8=A6=81=E3=81=AA=E8=AD=A6?= =?UTF-8?q?=E5=91=8A=E3=81=AE=E5=89=8A=E9=99=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (おそらく https://github.com/kohya-ss/sd-scripts/commit/be14c062674973d0e4fee1eb4527e04707bb72b8 の修正漏れ ) --- library/sdxl_train_util.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index b74bea91a..f78d94244 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -345,8 +345,6 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" - if args.v_parameterization: - logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") if args.clip_skip is not None: logger.warning("clip_skip will be unexpected / 
SDXL学習ではclip_skipは動作しません") From c5b803ce94bd70812e6979ac7b986a769659b14e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Feb 2025 21:59:09 +0900 Subject: [PATCH 09/27] rng state management: Implement functions to get and set RNG states for consistent validation --- train_network.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index f0deb67ab..b3c7ff524 100644 --- a/train_network.py +++ b/train_network.py @@ -1278,6 +1278,31 @@ def remove_model(old_ckpt_name): original_args_min_timestep = args.min_timestep original_args_max_timestep = args.max_timestep + def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + cpu_rng_state = torch.get_rng_state() + if accelerator.device.type == "cuda": + gpu_rng_state = torch.cuda.get_rng_state() + elif accelerator.device.type == "xpu": + gpu_rng_state = torch.xpu.get_rng_state() + elif accelerator.device.type == "mps": + gpu_rng_state = torch.cuda.get_rng_state() + else: + gpu_rng_state = None + python_rng_state = random.getstate() + return (cpu_rng_state, gpu_rng_state, python_rng_state) + + def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): + cpu_rng_state, gpu_rng_state, python_rng_state = rng_states + torch.set_rng_state(cpu_rng_state) + if gpu_rng_state is not None: + if accelerator.device.type == "cuda": + torch.cuda.set_rng_state(gpu_rng_state) + elif accelerator.device.type == "xpu": + torch.xpu.set_rng_state(gpu_rng_state) + elif accelerator.device.type == "mps": + torch.cuda.set_rng_state(gpu_rng_state) + random.setstate(python_rng_state) + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 @@ -1391,7 +1416,7 @@ def remove_model(old_ckpt_name): if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() 
accelerator.unwrap_model(network).eval() - rng_state = torch.get_rng_state() + rng_states = get_rng_state() torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( @@ -1453,7 +1478,7 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) - torch.set_rng_state(rng_state) + set_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() @@ -1470,7 +1495,7 @@ def remove_model(old_ckpt_name): if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_state = torch.get_rng_state() + rng_states = get_rng_state() torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( @@ -1536,7 +1561,7 @@ def remove_model(old_ckpt_name): } accelerator.log(logs, step=global_step) - torch.set_rng_state(rng_state) + set_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() From a24db1d532a95cc9dd91aba25a06b8eb58db5cff Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Feb 2025 22:02:42 +0900 Subject: [PATCH 10/27] fix: validation timestep generation fails on SD/SDXL training --- library/train_util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 37ed0a994..01fa64674 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5935,7 +5935,10 @@ def save_sd_model_on_train_end_common( def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor: - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + if min_timestep < max_timestep: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + else: + timesteps = torch.full((b_size,), max_timestep, device="cpu") 
timesteps = timesteps.long().to(device) return timesteps From 0911683717e439676bba758a5f7a29356984966c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Feb 2025 20:53:49 +0900 Subject: [PATCH 11/27] set python random state --- train_network.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/train_network.py b/train_network.py index b3c7ff524..083e5993d 100644 --- a/train_network.py +++ b/train_network.py @@ -1278,7 +1278,7 @@ def remove_model(old_ckpt_name): original_args_min_timestep = args.min_timestep original_args_max_timestep = args.max_timestep - def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: cpu_rng_state = torch.get_rng_state() if accelerator.device.type == "cuda": gpu_rng_state = torch.cuda.get_rng_state() @@ -1289,9 +1289,13 @@ def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple else: gpu_rng_state = None python_rng_state = random.getstate() + + torch.manual_seed(seed) + random.seed(seed) + return (cpu_rng_state, gpu_rng_state, python_rng_state) - def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): + def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): cpu_rng_state, gpu_rng_state, python_rng_state = rng_states torch.set_rng_state(cpu_rng_state) if gpu_rng_state is not None: @@ -1416,8 +1420,7 @@ def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor] if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_states = get_rng_state() - torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) + rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( 
range(validation_total_steps), @@ -1478,7 +1481,7 @@ def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor] } accelerator.log(logs, step=global_step) - set_rng_state(rng_states) + restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() @@ -1495,8 +1498,7 @@ def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor] if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_states = get_rng_state() - torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) + rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1561,7 +1563,7 @@ def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor] } accelerator.log(logs, step=global_step) - set_rng_state(rng_states) + restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() From 344845b42941b48956dce94d614fbf32e900c70e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Feb 2025 21:25:40 +0900 Subject: [PATCH 12/27] fix: validation with block swap --- flux_train_network.py | 14 ++++++++++++-- sd3_train_network.py | 19 ++++++++++++++----- train_network.py | 18 +++++++++++------- 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 475bd751b..e97dfc5b8 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -36,7 +36,12 @@ def __init__(self): self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + 
self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): super().assert_extra_args(args, train_dataset_group, val_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -341,7 +346,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -507,6 +512,11 @@ def forward(hidden_states): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + def prepare_unet_with_accelerator( self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module ) -> torch.nn.Module: diff --git a/sd3_train_network.py b/sd3_train_network.py index d4f131252..216d93c58 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -26,7 +26,12 @@ def __init__(self): super().__init__() self.sample_prompts_te_outputs = None - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): # super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -317,7 +322,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -445,15 +450,19 @@ def forward(hidden_states): 
text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) - def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): - # TODO consider validation - # drop cached text encoder outputs + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True): + # drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) batch["text_encoder_outputs_list"] = text_encoder_outputs_list + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + def prepare_unet_with_accelerator( self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module ) -> torch.nn.Module: diff --git a/train_network.py b/train_network.py index 083e5993d..49013c708 100644 --- a/train_network.py +++ b/train_network.py @@ -309,7 +309,10 @@ def prepare_unet_with_accelerator( ) -> torch.nn.Module: return accelerator.prepare(unet) - def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True): + pass + + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): pass # endregion @@ -1278,7 +1281,7 @@ def remove_model(old_ckpt_name): original_args_min_timestep = args.min_timestep 
original_args_max_timestep = args.max_timestep - def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: cpu_rng_state = torch.get_rng_state() if accelerator.device.type == "cuda": gpu_rng_state = torch.cuda.get_rng_state() @@ -1330,8 +1333,8 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen with accelerator.accumulate(training_model): on_step_start_for_network(text_encoder, unet) - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + # preprocess batch for each model + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) loss = self.process_batch( batch, @@ -1434,8 +1437,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen break for timestep in validation_timesteps: - # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep @@ -1471,6 +1473,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen } accelerator.log(logs, step=global_step) + self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 if is_tracking: @@ -1516,7 +1519,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen args.min_timestep = args.max_timestep = timestep # temporary, for batch processing - self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) loss = 
self.process_batch( batch, @@ -1551,6 +1554,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen } accelerator.log(logs, step=global_step) + self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 if is_tracking: From 177203818a024329efa74640a588674323363373 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Feb 2025 21:42:46 +0900 Subject: [PATCH 13/27] fix: unpause training progress bar after vaidation --- train_network.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train_network.py b/train_network.py index 49013c708..8bfb19258 100644 --- a/train_network.py +++ b/train_network.py @@ -1489,6 +1489,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() + progress_bar.unpause() if global_step >= args.max_train_steps: break @@ -1572,6 +1573,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen args.max_timestep = original_args_max_timestep optimizer_train_fn() accelerator.unwrap_model(network).train() + progress_bar.unpause() # END OF EPOCH if is_tracking: From cd80752175c663ede2cb7995da652ed5f5f7f749 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Feb 2025 21:42:58 +0900 Subject: [PATCH 14/27] fix: remove unused parameter 'accelerator' from encode_images_to_latents method --- flux_train_network.py | 2 +- sd3_train_network.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index e97dfc5b8..def441559 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -328,7 +328,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def 
encode_images_to_latents(self, args, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): diff --git a/sd3_train_network.py b/sd3_train_network.py index 216d93c58..cdb7aa4e3 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -304,7 +304,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): From 76b761943b5166f496aa1cb8ffbcc2d04469346a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Feb 2025 21:53:57 +0900 Subject: [PATCH 15/27] fix: simplify validation step condition in NetworkTrainer --- train_network.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/train_network.py b/train_network.py index 8bfb19258..99c58f49f 100644 --- a/train_network.py +++ b/train_network.py @@ -1414,12 +1414,9 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) accelerator.log(logs, step=global_step) - # VALIDATION PER STEP - should_validate_step = ( - args.validate_every_n_steps is not None - and global_step != 0 # Skip first step - and global_step % args.validate_every_n_steps == 0 - ) + # VALIDATION PER STEP: global_step is already incremented + # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ... 
+ should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0 if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() From ab88b431b0c903f7a60ae59e22fbb8a7cf9d78a1 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 14 Feb 2025 11:14:38 -0500 Subject: [PATCH 16/27] Fix validation epoch divergence --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index c3879531d..b5f92e06b 100644 --- a/train_network.py +++ b/train_network.py @@ -1498,7 +1498,7 @@ def remove_model(old_ckpt_name): if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss + loss_validation_divergence = val_epoch_loss_recorder.moving_average - avr_loss logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, From 4671e237781dcfe9a16e90f5343afd57586a1df6 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 01:42:44 -0500 Subject: [PATCH 17/27] Fix validation epoch loss to check epoch average --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index b5f92e06b..674f1cb66 100644 --- a/train_network.py +++ b/train_network.py @@ -1498,7 +1498,7 @@ def remove_model(old_ckpt_name): if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_epoch_loss_recorder.moving_average - avr_loss + loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, From 3c7496ae3f2736a8283a881f49698d3e8f3a4291 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 22:18:14 -0500 
Subject: [PATCH 18/27] Fix sizes for validation split --- library/train_util.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 37ed0a994..6c782ea1c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -148,10 +148,11 @@ def split_train_val( paths: List[str], + sizes: List[Optional[Tuple[int, int]]], is_training_dataset: bool, validation_split: float, validation_seed: int | None -) -> List[str]: +) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: """ Split the dataset into train and validation @@ -172,10 +173,12 @@ def split_train_val( # Split the dataset between training and validation if is_training_dataset: # Training dataset we split to the first part - return paths[0:math.ceil(len(paths) * (1 - validation_split))] + split = math.ceil(len(paths) * (1 - validation_split)) + return paths[0:split], sizes[0:split] else: # Validation dataset we split to the second part - return paths[len(paths) - round(len(paths) * validation_split):] + split = len(paths) - round(len(paths) * validation_split) + return paths[split:], sizes[split:] class ImageInfo: @@ -1931,12 +1934,12 @@ def load_dreambooth_dir(subset: DreamBoothSubset): with open(info_cache_file, "r", encoding="utf-8") as f: metas = json.load(f) img_paths = list(metas.keys()) - sizes = [meta["resolution"] for meta in metas.values()] + sizes: List[Optional[Tuple[int, int]]] = [meta["resolution"] for meta in metas.values()] # we may need to check image size and existence of image files, but it takes time, so user should check it before training else: img_paths = glob_images(subset.image_dir, "*") - sizes = [None] * len(img_paths) + sizes: List[Optional[Tuple[int, int]]] = [None] * len(img_paths) # new caching: get image size from cache files strategy = LatentsCachingStrategy.get_strategy() @@ -1969,7 +1972,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): w, h = None, None if w is not None and h is not 
None: - sizes[i] = [w, h] + sizes[i] = (w, h) size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") @@ -1990,8 +1993,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # Otherwise the img_paths remain as original img_paths and no split # required for training images dataset of regularization images else: - img_paths = split_train_val( + img_paths, sizes = split_train_val( img_paths, + sizes, self.is_training_dataset, self.validation_split, self.validation_seed From f3a010978c0e4b88c4839b3a81400b8973f52158 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 16 Feb 2025 22:28:34 -0500 Subject: [PATCH 19/27] Clear sizes for validation reg images to be consistent --- library/train_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/train_util.py b/library/train_util.py index 6c782ea1c..39b4af856 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1990,6 +1990,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # Skip any validation dataset for regularization images if self.is_training_dataset is False: img_paths = [] + sizes = [] # Otherwise the img_paths remain as original img_paths and no split # required for training images dataset of regularization images else: From 9436b410617f22716eac64f7c604c8f53fa8c1a8 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 17 Feb 2025 14:28:41 -0500 Subject: [PATCH 20/27] Fix validation split and add test --- library/train_util.py | 8 ++++++-- tests/test_validation.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) create mode 100644 tests/test_validation.py diff --git a/library/train_util.py b/library/train_util.py index 39b4af856..b23290663 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -161,15 +161,19 @@ def split_train_val( [0:80] = 80 training images [80:] = 20 validation images """ + dataset = list(zip(paths, sizes)) if validation_seed is not None: logging.info(f"Using validation seed: 
{validation_seed}") prevstate = random.getstate() random.seed(validation_seed) - random.shuffle(paths) + random.shuffle(dataset) random.setstate(prevstate) else: - random.shuffle(paths) + random.shuffle(dataset) + paths, sizes = zip(*dataset) + paths = list(paths) + sizes = list(sizes) # Split the dataset between training and validation if is_training_dataset: # Training dataset we split to the first part diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 000000000..f80686d8c --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,17 @@ +from library.train_util import split_train_val + + +def test_split_train_val(): + paths = ["path1", "path2", "path3", "path4", "path5", "path6", "path7"] + sizes = [(1, 1), (2, 2), None, (4, 4), (5, 5), (6, 6), None] + result_paths, result_sizes = split_train_val(paths, sizes, True, 0.2, 1234) + assert result_paths == ["path2", "path3", "path6", "path5", "path1", "path4"], result_paths + assert result_sizes == [(2, 2), None, (6, 6), (5, 5), (1, 1), (4, 4)], result_sizes + + result_paths, result_sizes = split_train_val(paths, sizes, False, 0.2, 1234) + assert result_paths == ["path7"], result_paths + assert result_sizes == [None], result_sizes + + +if __name__ == "__main__": + test_split_train_val() From 4a369961346ca153a370728247449978d8a33415 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 18 Feb 2025 22:05:08 +0900 Subject: [PATCH 21/27] modify log step calculation --- train_network.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/train_network.py b/train_network.py index 47c4bb56e..93558da45 100644 --- a/train_network.py +++ b/train_network.py @@ -1464,11 +1464,10 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) if is_tracking: - logs = { - "loss/validation/step_current": current_loss, - "val_step": (epoch * validation_total_steps) + val_ts_step, - } - accelerator.log(logs, step=global_step) + logs = 
{"loss/validation/step_current": current_loss} + accelerator.log( + logs, step=global_step + val_ts_step + ) # a bit weird to log with global_step + val_ts_step self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 @@ -1545,25 +1544,20 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) if is_tracking: - logs = { - "loss/validation/epoch_current": current_loss, - "epoch": epoch + 1, - "val_step": (epoch * validation_total_steps) + val_ts_step, - } - accelerator.log(logs, step=global_step) + logs = {"loss/validation/epoch_current": current_loss} + accelerator.log(logs, step=global_step + val_ts_step) self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) val_ts_step += 1 if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average + loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, - "epoch": epoch + 1, } - accelerator.log(logs, step=global_step) + accelerator.log(logs, step=epoch + 1) restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep @@ -1574,8 +1568,8 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen # END OF EPOCH if is_tracking: - logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} - accelerator.log(logs, step=global_step) + logs = {"loss/epoch_average": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() From 13df47516dda6e350b6aa79373b5a0e7287648b5 Mon Sep 17 00:00:00 2001 From: Yidi Date: Thu, 20 Feb 2025 04:49:51 -0500 Subject: [PATCH 22/27] Remove position_ids for V2 The postions_ids cause errors for the newer 
version of transformer. This has already been fixed in convert_ldm_clip_checkpoint_v1() but not in v2. The new code applies the same fix to convert_ldm_clip_checkpoint_v2(). --- library/model_util.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/library/model_util.py b/library/model_util.py index be410a026..9918c7b2a 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -643,16 +643,15 @@ def convert_key(key): new_sd[key_pfx + "k_proj" + key_suffix] = values[1] new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - # rename or add position_ids + # remove position_ids for newer transformer, which causes error :( ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" if ANOTHER_POSITION_IDS_KEY in new_sd: # waifu diffusion v1.4 - position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] del new_sd[ANOTHER_POSITION_IDS_KEY] - else: - position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) - new_sd["text_model.embeddings.position_ids"] = position_ids + if "text_model.embeddings.position_ids" in new_sd: + del new_sd["text_model.embeddings.position_ids"] + return new_sd From efb2a128cd0d2c6340a21bf544e77853a20b3453 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 21 Feb 2025 22:07:35 +0900 Subject: [PATCH 23/27] fix wandb val logging --- library/train_util.py | 57 +++++++++++++++------------------ train_network.py | 73 ++++++++++++++++++++++++++++++++----------- 2 files changed, 80 insertions(+), 50 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 258701982..1f591c422 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -13,17 +13,7 @@ import shutil import time import typing -from typing import ( - Any, - Callable, - Dict, - List, - NamedTuple, - Optional, - Sequence, - Tuple, - Union -) +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union from accelerate import Accelerator, InitProcessGroupKwargs, 
DistributedDataParallelKwargs, PartialState import glob import math @@ -146,12 +136,13 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" + def split_train_val( - paths: List[str], + paths: List[str], sizes: List[Optional[Tuple[int, int]]], - is_training_dataset: bool, - validation_split: float, - validation_seed: int | None + is_training_dataset: bool, + validation_split: float, + validation_seed: int | None, ) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: """ Split the dataset into train and validation @@ -1842,7 +1833,7 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): class DreamBoothDataset(BaseDataset): IMAGE_INFO_CACHE_FILE = "metadata_cache.json" - # The is_training_dataset defines the type of dataset, training or validation + # The is_training_dataset defines the type of dataset, training or validation # if is_training_dataset is True -> training dataset # if is_training_dataset is False -> validation dataset def __init__( @@ -1981,29 +1972,25 @@ def load_dreambooth_dir(subset: DreamBoothSubset): logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") # We want to create a training and validation split. This should be improved in the future - # to allow a clearer distinction between training and validation. This can be seen as a + # to allow a clearer distinction between training and validation. 
This can be seen as a # short-term solution to limit what is necessary to implement validation datasets - # + # # We split the dataset for the subset based on if we are doing a validation split - # The self.is_training_dataset defines the type of dataset, training or validation + # The self.is_training_dataset defines the type of dataset, training or validation # if self.is_training_dataset is True -> training dataset # if self.is_training_dataset is False -> validation dataset if self.validation_split > 0.0: - # For regularization images we do not want to split this dataset. + # For regularization images we do not want to split this dataset. if subset.is_reg is True: # Skip any validation dataset for regularization images if self.is_training_dataset is False: img_paths = [] sizes = [] - # Otherwise the img_paths remain as original img_paths and no split + # Otherwise the img_paths remain as original img_paths and no split # required for training images dataset of regularization images else: img_paths, sizes = split_train_val( - img_paths, - sizes, - self.is_training_dataset, - self.validation_split, - self.validation_seed + img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed ) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") @@ -2373,7 +2360,7 @@ def __init__( bucket_no_upscale: bool, debug_dataset: bool, validation_split: float, - validation_seed: Optional[int], + validation_seed: Optional[int], ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2431,9 +2418,9 @@ def __init__( self.image_data = self.dreambooth_dataset_delegate.image_data self.batch_size = batch_size self.num_train_images = self.dreambooth_dataset_delegate.num_train_images - self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images self.validation_split = validation_split - self.validation_seed = validation_seed 
+ self.validation_seed = validation_seed # assert all conditioning data exists missing_imgs = [] @@ -5952,7 +5939,9 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor return timesteps -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: +def get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents: torch.FloatTensor +) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: @@ -6444,7 +6433,7 @@ def sample_image_inference( wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption -def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): +def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): """ Initialize experiment trackers with tracker specific behaviors """ @@ -6461,13 +6450,17 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr ) if "wandb" in [tracker.name for tracker in accelerator.trackers]: - import wandb + import wandb + wandb_tracker = accelerator.get_tracker("wandb", unwrap=True) # Define specific metrics to handle validation and epochs "steps" wandb_tracker.define_metric("epoch", hidden=True) wandb_tracker.define_metric("val_step", hidden=True) + wandb_tracker.define_metric("global_step", hidden=True) + + # endregion diff --git a/train_network.py b/train_network.py index 93558da45..ab5483deb 100644 --- a/train_network.py +++ b/train_network.py @@ -119,6 +119,45 @@ def generate_step_logs( return logs + def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int): + self.accelerator_logging(accelerator, logs, global_step, global_step, epoch) + + def epoch_logging(self, 
accelerator: Accelerator, logs: dict, global_step: int, epoch: int): + self.accelerator_logging(accelerator, logs, epoch, global_step, epoch) + + def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int): + self.accelerator_logging(accelerator, logs, global_step + val_step, global_step, epoch, val_step) + + def accelerator_logging( + self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None + ): + """ + step_value is for tensorboard, other values are for wandb + """ + tensorboard_tracker = None + wandb_tracker = None + other_trackers = [] + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + tensorboard_tracker = accelerator.get_tracker("tensorboard") + elif tracker.name == "wandb": + wandb_tracker = accelerator.get_tracker("wandb") + else: + other_trackers.append(accelerator.get_tracker(tracker.name)) + + if tensorboard_tracker is not None: + tensorboard_tracker.log(logs, step=step_value) + + if wandb_tracker is not None: + logs["global_step"] = global_step + logs["epoch"] = epoch + if val_step is not None: + logs["val_step"] = val_step + wandb_tracker.log(logs) + + for tracker in other_trackers: + tracker.log(logs, step=step_value) + def assert_extra_args( self, args, @@ -1412,7 +1451,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen logs = self.generate_step_logs( args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) - accelerator.log(logs, step=global_step) + self.step_logging(accelerator, logs, global_step, epoch + 1) # VALIDATION PER STEP: global_step is already incremented # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ... 
@@ -1428,7 +1467,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen disable=not accelerator.is_local_main_process, desc="validation steps", ) - val_ts_step = 0 + val_timesteps_step = 0 for val_step, batch in enumerate(val_dataloader): if val_step >= validation_steps: break @@ -1457,20 +1496,18 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) current_loss = loss.detach().item() - val_step_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep} ) - if is_tracking: - logs = {"loss/validation/step_current": current_loss} - accelerator.log( - logs, step=global_step + val_ts_step - ) # a bit weird to log with global_step + val_ts_step + # if is_tracking: + # logs = {f"loss/validation/step_current_{timestep}": current_loss} + # self.val_logging(accelerator, logs, global_step, epoch + 1, val_step) self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - val_ts_step += 1 + val_timesteps_step += 1 if is_tracking: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average @@ -1478,7 +1515,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen "loss/validation/step_average": val_step_loss_recorder.moving_average, "loss/validation/step_divergence": loss_validation_divergence, } - accelerator.log(logs, step=global_step) + self.step_logging(accelerator, logs, global_step, epoch=epoch + 1) restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep @@ -1507,7 +1544,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen desc="epoch validation steps", ) - val_ts_step = 0 + val_timesteps_step = 0 for val_step, batch in 
enumerate(val_dataloader): if val_step >= validation_steps: break @@ -1537,18 +1574,18 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) current_loss = loss.detach().item() - val_epoch_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss) + val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep} ) - if is_tracking: - logs = {"loss/validation/epoch_current": current_loss} - accelerator.log(logs, step=global_step + val_ts_step) + # if is_tracking: + # logs = {f"loss/validation/epoch_current_{timestep}": current_loss} + # self.val_logging(accelerator, logs, global_step, epoch + 1, val_step) self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - val_ts_step += 1 + val_timesteps_step += 1 if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average @@ -1557,7 +1594,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, } - accelerator.log(logs, step=epoch + 1) + self.epoch_logging(accelerator, logs, global_step, epoch + 1) restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep @@ -1569,7 +1606,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen # END OF EPOCH if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) + self.epoch_logging(accelerator, logs, global_step, epoch + 1) accelerator.wait_for_everyone() From f68702f71c16719d0f85820a2a4585f19b96552f Mon Sep 17 00:00:00 2001 From: Disty0 Date: Tue, 25 Feb 2025 21:27:41 +0300 Subject: [PATCH 24/27] Update IPEX libs --- library/device_utils.py | 11 +- library/ipex/__init__.py | 170 
+++++++++++------- library/ipex/attention.py | 236 ++++++++++--------------- library/ipex/diffusers.py | 349 +++++-------------------------------- library/ipex/gradscaler.py | 2 +- library/ipex/hijacks.py | 132 +++++++++----- 6 files changed, 337 insertions(+), 563 deletions(-) diff --git a/library/device_utils.py b/library/device_utils.py index 8823c5d9a..d2e197450 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -2,6 +2,13 @@ import gc import torch +try: + # intel gpu support for pytorch older than 2.5 + # ipex is not needed after pytorch 2.5 + import intel_extension_for_pytorch as ipex # noqa +except Exception: + pass + try: HAS_CUDA = torch.cuda.is_available() @@ -14,8 +21,6 @@ HAS_MPS = False try: - import intel_extension_for_pytorch as ipex # noqa - HAS_XPU = torch.xpu.is_available() except Exception: HAS_XPU = False @@ -69,7 +74,7 @@ def init_ipex(): This function should run right after importing torch and before doing anything else. - If IPEX is not available, this function does nothing. + If xpu is not available, this function does nothing. 
""" try: if HAS_XPU: diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index e5aba693c..a36664bb3 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -2,7 +2,11 @@ import sys import contextlib import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +try: + import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + legacy = True +except Exception: + legacy = False from .hijacks import ipex_hijacks # pylint: disable=protected-access, missing-function-docstring, line-too-long @@ -12,6 +16,13 @@ def ipex_init(): # pylint: disable=too-many-statements if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked: return True, "Skipping IPEX hijack" else: + try: # force xpu device on torch compile and triton + torch._inductor.utils.GPU_TYPES = ["xpu"] + torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu" + from triton import backends as triton_backends # pylint: disable=import-error + triton_backends.backends["nvidia"].driver.is_active = lambda *args, **kwargs: False + except Exception: + pass # Replace cuda with xpu: torch.cuda.current_device = torch.xpu.current_device torch.cuda.current_stream = torch.xpu.current_stream @@ -26,84 +37,99 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.is_current_stream_capturing = lambda: False torch.cuda.set_device = torch.xpu.set_device torch.cuda.stream = torch.xpu.stream - torch.cuda.synchronize = torch.xpu.synchronize torch.cuda.Event = torch.xpu.Event torch.cuda.Stream = torch.xpu.Stream - torch.cuda.FloatTensor = torch.xpu.FloatTensor torch.Tensor.cuda = torch.Tensor.xpu torch.Tensor.is_cuda = torch.Tensor.is_xpu torch.nn.Module.cuda = torch.nn.Module.xpu - torch.UntypedStorage.cuda = torch.UntypedStorage.xpu - torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock - torch.cuda._initialized = torch.xpu.lazy_init._initialized - 
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker - torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls - torch.cuda._tls = torch.xpu.lazy_init._tls - torch.cuda.threading = torch.xpu.lazy_init.threading - torch.cuda.traceback = torch.xpu.lazy_init.traceback torch.cuda.Optional = torch.xpu.Optional torch.cuda.__cached__ = torch.xpu.__cached__ torch.cuda.__loader__ = torch.xpu.__loader__ - torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage torch.cuda.Tuple = torch.xpu.Tuple torch.cuda.streams = torch.xpu.streams - torch.cuda._lazy_new = torch.xpu._lazy_new - torch.cuda.FloatStorage = torch.xpu.FloatStorage torch.cuda.Any = torch.xpu.Any torch.cuda.__doc__ = torch.xpu.__doc__ torch.cuda.default_generators = torch.xpu.default_generators - torch.cuda.HalfTensor = torch.xpu.HalfTensor torch.cuda._get_device_index = torch.xpu._get_device_index torch.cuda.__path__ = torch.xpu.__path__ - torch.cuda.Device = torch.xpu.Device - torch.cuda.IntTensor = torch.xpu.IntTensor - torch.cuda.ByteStorage = torch.xpu.ByteStorage torch.cuda.set_stream = torch.xpu.set_stream - torch.cuda.BoolStorage = torch.xpu.BoolStorage - torch.cuda.os = torch.xpu.os torch.cuda.torch = torch.xpu.torch - torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage torch.cuda.Union = torch.xpu.Union - torch.cuda.DoubleTensor = torch.xpu.DoubleTensor - torch.cuda.ShortTensor = torch.xpu.ShortTensor - torch.cuda.LongTensor = torch.xpu.LongTensor - torch.cuda.IntStorage = torch.xpu.IntStorage - torch.cuda.LongStorage = torch.xpu.LongStorage torch.cuda.__annotations__ = torch.xpu.__annotations__ torch.cuda.__package__ = torch.xpu.__package__ torch.cuda.__builtins__ = torch.xpu.__builtins__ - torch.cuda.CharTensor = torch.xpu.CharTensor torch.cuda.List = torch.xpu.List torch.cuda._lazy_init = torch.xpu._lazy_init - torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor - torch.cuda.DoubleStorage = torch.xpu.DoubleStorage - torch.cuda.ByteTensor = torch.xpu.ByteTensor 
torch.cuda.StreamContext = torch.xpu.StreamContext - torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage - torch.cuda.ShortStorage = torch.xpu.ShortStorage torch.cuda._lazy_call = torch.xpu._lazy_call - torch.cuda.HalfStorage = torch.xpu.HalfStorage torch.cuda.random = torch.xpu.random torch.cuda._device = torch.xpu._device - torch.cuda.classproperty = torch.xpu.classproperty torch.cuda.__name__ = torch.xpu.__name__ torch.cuda._device_t = torch.xpu._device_t - torch.cuda.warnings = torch.xpu.warnings torch.cuda.__spec__ = torch.xpu.__spec__ - torch.cuda.BoolTensor = torch.xpu.BoolTensor - torch.cuda.CharStorage = torch.xpu.CharStorage torch.cuda.__file__ = torch.xpu.__file__ - torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + if legacy: + torch.cuda.os = torch.xpu.os + torch.cuda.Device = torch.xpu.Device + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.classproperty = torch.xpu.classproperty + torch.UntypedStorage.cuda = torch.UntypedStorage.xpu + if float(ipex.__version__[:3]) < 2.3: + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda._lazy_new = torch.xpu._lazy_new + + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.ByteTensor = 
torch.xpu.ByteTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + + if not legacy or float(ipex.__version__[:3]) >= 2.3: + torch.cuda._initialization_lock = torch.xpu._initialization_lock + torch.cuda._initialized = torch.xpu._initialized + torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork + torch.cuda._lazy_seed_tracker = torch.xpu._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu._queued_calls + torch.cuda._tls = torch.xpu._tls + torch.cuda.threading = torch.xpu.threading + torch.cuda.traceback = torch.xpu.traceback + # Memory: - torch.cuda.memory = torch.xpu.memory if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): torch.xpu.empty_cache = lambda: None torch.cuda.empty_cache = torch.xpu.empty_cache + + if legacy: + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory = torch.xpu.memory torch.cuda.memory_stats = torch.xpu.memory_stats - torch.cuda.memory_summary = torch.xpu.memory_summary - torch.cuda.memory_snapshot = torch.xpu.memory_snapshot torch.cuda.memory_allocated = torch.xpu.memory_allocated torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated torch.cuda.memory_reserved = torch.xpu.memory_reserved @@ -128,32 +154,44 @@ def 
ipex_init(): # pylint: disable=too-many-statements torch.cuda.initial_seed = torch.xpu.initial_seed # AMP: - torch.cuda.amp = torch.xpu.amp - torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled - torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype + if legacy: + torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd + torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd + torch.cuda.amp = torch.xpu.amp + if float(ipex.__version__[:3]) < 2.3: + torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled + torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype - if not hasattr(torch.cuda.amp, "common"): - torch.cuda.amp.common = contextlib.nullcontext() - torch.cuda.amp.common.amp_definitely_not_available = lambda: False + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False - try: - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught try: - from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error - gradscaler_init() torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler except Exception: # pylint: disable=broad-exception-caught - torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler # C - torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream - ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count - ipex._C._DeviceProperties.major = 2024 - ipex._C._DeviceProperties.minor = 0 + if legacy and float(ipex.__version__[:3]) < 2.3: + torch._C._cuda_getCurrentRawStream 
= ipex._C._getCurrentRawStream + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count + ipex._C._DeviceProperties.major = 12 + ipex._C._DeviceProperties.minor = 1 + else: + torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream + torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count + torch._C._XpuDeviceProperties.major = 12 + torch._C._XpuDeviceProperties.minor = 1 # Fix functions with ipex: - torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] + # torch.xpu.mem_get_info always returns the total memory as free memory + torch.xpu.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] + torch.cuda.mem_get_info = torch.xpu.mem_get_info torch._utils._get_available_device_type = lambda: "xpu" torch.has_cuda = True torch.cuda.has_half = True @@ -161,19 +199,19 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.is_fp16_supported = lambda *args, **kwargs: True torch.backends.cuda.is_built = lambda *args, **kwargs: True torch.version.cuda = "12.1" - torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] + torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"] + torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1) torch.cuda.get_device_properties.major = 12 torch.cuda.get_device_properties.minor = 1 torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.utilization = lambda *args, **kwargs: 0 - ipex_hijacks() - if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: - try: - from .diffusers import ipex_diffusers - ipex_diffusers() - except Exception: # pylint: disable=broad-exception-caught - pass + 
device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks(legacy=legacy) + try: + from .diffusers import ipex_diffusers + ipex_diffusers(device_supports_fp64=device_supports_fp64, can_allocate_plus_4gb=can_allocate_plus_4gb) + except Exception: # pylint: disable=broad-exception-caught + pass torch.cuda.is_xpu_hijacked = True except Exception as e: return False, e diff --git a/library/ipex/attention.py b/library/ipex/attention.py index 2bc62f65c..400b59b66 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -1,177 +1,119 @@ import os import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -from functools import cache +from functools import cache, wraps # pylint: disable=protected-access, missing-function-docstring, line-too-long # ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers -sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) -attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) +sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 1)) +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 0.5)) # Find something divisible with the input_tokens @cache -def find_slice_size(slice_size, slice_block_size): - while (slice_size * slice_block_size) > attention_slice_rate: - slice_size = slice_size // 2 - if slice_size <= 1: - slice_size = 1 - break - return slice_size +def find_split_size(original_size, slice_block_size, slice_rate=2): + split_size = original_size + while True: + if (split_size * slice_block_size) <= slice_rate and original_size % split_size == 0: + return split_size + split_size = split_size - 1 + if split_size <= 1: + return 1 + return split_size + # Find slice sizes for SDPA @cache -def find_sdpa_slice_sizes(query_shape, query_element_size): - if len(query_shape) == 3: - batch_size_attention, query_tokens, shape_three = query_shape - 
shape_four = 1 - else: - batch_size_attention, query_tokens, shape_three, shape_four = query_shape - - slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size - block_size = batch_size_attention * slice_block_size - - split_slice_size = batch_size_attention - split_2_slice_size = query_tokens - split_3_slice_size = shape_three - - do_split = False - do_split_2 = False - do_split_3 = False - - if block_size > sdpa_slice_trigger_rate: - do_split = True - split_slice_size = find_slice_size(split_slice_size, slice_block_size) - if split_slice_size * slice_block_size > attention_slice_rate: - slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size - do_split_2 = True - split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) - if split_2_slice_size * slice_2_block_size > attention_slice_rate: - slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size - do_split_3 = True - split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) - - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size - -# Find slice sizes for BMM -@cache -def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape): - batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2] - slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size - block_size = batch_size_attention * slice_block_size - - split_slice_size = batch_size_attention - split_2_slice_size = input_tokens - split_3_slice_size = mat2_atten_shape - - do_split = False - do_split_2 = False - do_split_3 = False - - if block_size > attention_slice_rate: - do_split = True - split_slice_size = find_slice_size(split_slice_size, slice_block_size) - if split_slice_size * slice_block_size > attention_slice_rate: - slice_2_block_size = split_slice_size * 
mat2_atten_shape / 1024 / 1024 * input_element_size - do_split_2 = True - split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) - if split_2_slice_size * slice_2_block_size > attention_slice_rate: - slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size - do_split_3 = True - split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) - - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size - - -original_torch_bmm = torch.bmm -def torch_bmm_32_bit(input, mat2, *, out=None): - if input.device.type != "xpu": - return original_torch_bmm(input, mat2, out=out) - do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape) - - # Slice BMM - if do_split: - batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2] - hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm( - input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - out=out - ) - else: - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( - input[start_idx:end_idx, 
start_idx_2:end_idx_2], - mat2[start_idx:end_idx, start_idx_2:end_idx_2], - out=out - ) - else: - hidden_states[start_idx:end_idx] = original_torch_bmm( - input[start_idx:end_idx], - mat2[start_idx:end_idx], - out=out - ) - torch.xpu.synchronize(input.device) - else: - return original_torch_bmm(input, mat2, out=out) - return hidden_states +def find_sdpa_slice_sizes(query_shape, key_shape, query_element_size, slice_rate=2, trigger_rate=3): + batch_size, attn_heads, query_len, _ = query_shape + _, _, key_len, _ = key_shape + + slice_batch_size = attn_heads * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024 + + split_batch_size = batch_size + split_head_size = attn_heads + split_query_size = query_len + + do_batch_split = False + do_head_split = False + do_query_split = False + + if batch_size * slice_batch_size >= trigger_rate: + do_batch_split = True + split_batch_size = find_split_size(batch_size, slice_batch_size, slice_rate=slice_rate) + + if split_batch_size * slice_batch_size > slice_rate: + slice_head_size = split_batch_size * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024 + do_head_split = True + split_head_size = find_split_size(attn_heads, slice_head_size, slice_rate=slice_rate) + + if split_head_size * slice_head_size > slice_rate: + slice_query_size = split_batch_size * split_head_size * (key_len) * query_element_size / 1024 / 1024 / 1024 + do_query_split = True + split_query_size = find_split_size(query_len, slice_query_size, slice_rate=slice_rate) + + return do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention -def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): +@wraps(torch.nn.functional.scaled_dot_product_attention) +def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, 
is_causal=False, **kwargs): if query.device.type != "xpu": return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) - do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) + is_unsqueezed = False + if len(query.shape) == 3: + query = query.unsqueeze(0) + is_unsqueezed = True + if len(key.shape) == 3: + key = key.unsqueeze(0) + if len(value.shape) == 3: + value = value.unsqueeze(0) + do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate) # Slice SDPA - if do_split: - batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] - hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention( - query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask, + if 
do_batch_split: + batch_size, attn_heads, query_len, _ = query.shape + _, _, _, head_dim = value.shape + hidden_states = torch.zeros((batch_size, attn_heads, query_len, head_dim), device=query.device, dtype=query.dtype) + if attn_mask is not None: + attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2])) + for ib in range(batch_size // split_batch_size): + start_idx = ib * split_batch_size + end_idx = (ib + 1) * split_batch_size + if do_head_split: + for ih in range(attn_heads // split_head_size): # pylint: disable=invalid-name + start_idx_h = ih * split_head_size + end_idx_h = (ih + 1) * split_head_size + if do_query_split: + for iq in range(query_len // split_query_size): # pylint: disable=invalid-name + start_idx_q = iq * split_query_size + end_idx_q = (iq + 1) * split_query_size + hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :], + key[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + value[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] if attn_mask is not None else attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs ) else: - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( - query[start_idx:end_idx, start_idx_2:end_idx_2], - key[start_idx:end_idx, start_idx_2:end_idx_2], - value[start_idx:end_idx, start_idx_2:end_idx_2], - attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, + hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, :, :] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + key[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + value[start_idx:end_idx, start_idx_h:end_idx_h, :, :], + 
attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, :, :] if attn_mask is not None else attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs ) else: - hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( - query[start_idx:end_idx], - key[start_idx:end_idx], - value[start_idx:end_idx], - attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, + hidden_states[start_idx:end_idx, :, :, :] = original_scaled_dot_product_attention( + query[start_idx:end_idx, :, :, :], + key[start_idx:end_idx, :, :, :], + value[start_idx:end_idx, :, :, :], + attn_mask=attn_mask[start_idx:end_idx, :, :, :] if attn_mask is not None else attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs ) torch.xpu.synchronize(query.device) else: - return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + if is_unsqueezed: + hidden_states.squeeze(0) return hidden_states diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index 732a18568..75715d161 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -1,312 +1,47 @@ -import os +from functools import wraps import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import diffusers #0.24.0 # pylint: disable=import-error -from diffusers.models.attention_processor import Attention -from diffusers.utils import USE_PEFT_BACKEND -from functools import cache +import diffusers # pylint: disable=import-error # pylint: disable=protected-access, missing-function-docstring, line-too-long -attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) -@cache -def find_slice_size(slice_size, slice_block_size): - while (slice_size * slice_block_size) > attention_slice_rate: - 
slice_size = slice_size // 2 - if slice_size <= 1: - slice_size = 1 - break - return slice_size - -@cache -def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None): - if len(query_shape) == 3: - batch_size_attention, query_tokens, shape_three = query_shape - shape_four = 1 - else: - batch_size_attention, query_tokens, shape_three, shape_four = query_shape - if slice_size is not None: - batch_size_attention = slice_size - - slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size - block_size = batch_size_attention * slice_block_size - - split_slice_size = batch_size_attention - split_2_slice_size = query_tokens - split_3_slice_size = shape_three - - do_split = False - do_split_2 = False - do_split_3 = False - - if query_device_type != "xpu": - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size - - if block_size > attention_slice_rate: - do_split = True - split_slice_size = find_slice_size(split_slice_size, slice_block_size) - if split_slice_size * slice_block_size > attention_slice_rate: - slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size - do_split_2 = True - split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) - if split_2_slice_size * slice_2_block_size > attention_slice_rate: - slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size - do_split_3 = True - split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) - - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size - -class SlicedAttnProcessor: # pylint: disable=too-few-public-methods - r""" - Processor for implementing sliced attention. - - Args: - slice_size (`int`, *optional*): - The number of steps to compute attention. 
Uses as many slices as `attention_head_dim // slice_size`, and - `attention_head_dim` must be a multiple of the `slice_size`. - """ - - def __init__(self, slice_size): - self.slice_size = slice_size - - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, - encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches - - residual = hidden_states - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - batch_size_attention, query_tokens, shape_three = query.shape - hidden_states = torch.zeros( - (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype - ) - - #################################################################### - # ARC GPUs can't allocate more than 4GB to a single block, Slice it: - _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size) - - for i in 
range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice - del attn_slice - else: - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice - del attn_slice - torch.xpu.synchronize(query.device) - else: - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - attn_slice = 
attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - del attn_slice - #################################################################### - - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class AttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, - encoder_hidden_states=None, attention_mask=None, - temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches - - residual = hidden_states - - args = () if USE_PEFT_BACKEND else (scale,) - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states, *args) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif 
attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - #################################################################### - # ARC GPUs can't allocate more than 4GB to a single block, Slice it: - batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] - hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) - do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type) - - if do_split: - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice 
- del attn_slice - else: - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice - del attn_slice - else: - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - del attn_slice - torch.xpu.synchronize(query.device) - else: - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - #################################################################### - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - -def ipex_diffusers(): - #ARC GPUs can't allocate more than 4GB to a single block: - diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor - diffusers.models.attention_processor.AttnProcessor = AttnProcessor +# Diffusers FreeU 
+original_fourier_filter = diffusers.utils.torch_utils.fourier_filter +@wraps(diffusers.utils.torch_utils.fourier_filter) +def fourier_filter(x_in, threshold, scale): + return_dtype = x_in.dtype + return original_fourier_filter(x_in.to(dtype=torch.float32), threshold, scale).to(dtype=return_dtype) + + +# fp64 error +class FluxPosEmbed(torch.nn.Module): + def __init__(self, theta: int, axes_dim): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + for i in range(n_axes): + cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=torch.float32, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False): + diffusers.utils.torch_utils.fourier_filter = fourier_filter + if not device_supports_fp64: + diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed diff --git a/library/ipex/gradscaler.py b/library/ipex/gradscaler.py index 6eb56bc2b..0a8610095 100644 --- a/library/ipex/gradscaler.py +++ b/library/ipex/gradscaler.py @@ -5,7 +5,7 @@ # pylint: disable=protected-access, missing-function-docstring, line-too-long -device_supports_fp64 = torch.xpu.has_fp64_dtype() +device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64 OptState = ipex.cpu.autocast._grad_scaler.OptState _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 
d3cef8276..91569746a 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -2,10 +2,19 @@ from functools import wraps from contextlib import nullcontext import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import numpy as np -device_supports_fp64 = torch.xpu.has_fp64_dtype() +device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64 +if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1: + try: + x = torch.ones((33000,33000), dtype=torch.float32, device="xpu") + del x + torch.xpu.empty_cache() + can_allocate_plus_4gb = True + except Exception: + can_allocate_plus_4gb = False +else: + can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1') # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return @@ -26,7 +35,7 @@ def check_device(device): return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) def return_xpu(device): - return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" + return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu" # Autocast @@ -42,7 +51,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=Non original_interpolate = torch.nn.functional.interpolate @wraps(torch.nn.functional.interpolate) def interpolate(tensor, size=None, scale_factor=None, mode='nearest', 
align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments - if antialias or align_corners is not None or mode == 'bicubic': + if mode in {'bicubic', 'bilinear'}: return_device = tensor.device return_dtype = tensor.dtype return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, @@ -73,35 +82,46 @@ def as_tensor(data, dtype=None, device=None): return original_as_tensor(data, dtype=dtype, device=device) -if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: - original_torch_bmm = torch.bmm +if can_allocate_plus_4gb: original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention else: # 32 bit attention workarounds for Alchemist: try: - from .attention import torch_bmm_32_bit as original_torch_bmm - from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention + from .attention import dynamic_scaled_dot_product_attention as original_scaled_dot_product_attention except Exception: # pylint: disable=broad-exception-caught - original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention +@wraps(torch.nn.functional.scaled_dot_product_attention) +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): + if query.dtype != key.dtype: + key = key.to(dtype=query.dtype) + if query.dtype != value.dtype: + value = value.to(dtype=query.dtype) + if attn_mask is not None and query.dtype != attn_mask.dtype: + attn_mask = attn_mask.to(dtype=query.dtype) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) # Data Type Errors: +original_torch_bmm = torch.bmm @wraps(torch.bmm) def torch_bmm(input, mat2, *, out=None): if input.dtype != mat2.dtype: mat2 = mat2.to(input.dtype) return 
original_torch_bmm(input, mat2, out=out) -@wraps(torch.nn.functional.scaled_dot_product_attention) -def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): - if query.dtype != key.dtype: - key = key.to(dtype=query.dtype) - if query.dtype != value.dtype: - value = value.to(dtype=query.dtype) - if attn_mask is not None and query.dtype != attn_mask.dtype: - attn_mask = attn_mask.to(dtype=query.dtype) - return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) +# Diffusers FreeU +original_fft_fftn = torch.fft.fftn +@wraps(torch.fft.fftn) +def fft_fftn(input, s=None, dim=None, norm=None, *, out=None): + return_dtype = input.dtype + return original_fft_fftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype) + +# Diffusers FreeU +original_fft_ifftn = torch.fft.ifftn +@wraps(torch.fft.ifftn) +def fft_ifftn(input, s=None, dim=None, norm=None, *, out=None): + return_dtype = input.dtype + return original_fft_ifftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype) # A1111 FP16 original_functional_group_norm = torch.nn.functional.group_norm @@ -133,6 +153,15 @@ def functional_linear(input, weight, bias=None): bias.data = bias.data.to(dtype=weight.data.dtype) return original_functional_linear(input, weight, bias=bias) +original_functional_conv1d = torch.nn.functional.conv1d +@wraps(torch.nn.functional.conv1d) +def functional_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + original_functional_conv2d = torch.nn.functional.conv2d 
@wraps(torch.nn.functional.conv2d) def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): @@ -142,14 +171,15 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, bias.data = bias.data.to(dtype=weight.data.dtype) return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) -# A1111 Embedding BF16 -original_torch_cat = torch.cat -@wraps(torch.cat) -def torch_cat(tensor, *args, **kwargs): - if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): - return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) - else: - return original_torch_cat(tensor, *args, **kwargs) +# LTX Video +original_functional_conv3d = torch.nn.functional.conv3d +@wraps(torch.nn.functional.conv3d) +def functional_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_conv3d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) # SwinIR BF16: original_functional_pad = torch.nn.functional.pad @@ -164,6 +194,7 @@ def functional_pad(input, pad, mode='constant', value=None): original_torch_tensor = torch.tensor @wraps(torch.tensor) def torch_tensor(data, *args, dtype=None, device=None, **kwargs): + global device_supports_fp64 if check_device(device): device = return_xpu(device) if not device_supports_fp64: @@ -227,7 +258,7 @@ def torch_empty(*args, device=None, **kwargs): original_torch_randn = torch.randn @wraps(torch.randn) def torch_randn(*args, device=None, dtype=None, **kwargs): - if dtype == bytes: + if dtype is bytes: dtype = None if check_device(device): return 
original_torch_randn(*args, device=return_xpu(device), **kwargs) @@ -250,6 +281,14 @@ def torch_zeros(*args, device=None, **kwargs): else: return original_torch_zeros(*args, device=device, **kwargs) +original_torch_full = torch.full +@wraps(torch.full) +def torch_full(*args, device=None, **kwargs): + if check_device(device): + return original_torch_full(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_full(*args, device=device, **kwargs) + original_torch_linspace = torch.linspace @wraps(torch.linspace) def torch_linspace(*args, device=None, **kwargs): @@ -258,14 +297,6 @@ def torch_linspace(*args, device=None, **kwargs): else: return original_torch_linspace(*args, device=device, **kwargs) -original_torch_Generator = torch.Generator -@wraps(torch.Generator) -def torch_Generator(device=None): - if check_device(device): - return original_torch_Generator(return_xpu(device)) - else: - return original_torch_Generator(device) - original_torch_load = torch.load @wraps(torch.load) def torch_load(f, map_location=None, *args, **kwargs): @@ -276,9 +307,27 @@ def torch_load(f, map_location=None, *args, **kwargs): else: return original_torch_load(f, *args, map_location=map_location, **kwargs) +original_torch_Generator = torch.Generator +@wraps(torch.Generator) +def torch_Generator(device=None): + if check_device(device): + return original_torch_Generator(return_xpu(device)) + else: + return original_torch_Generator(device) + +@wraps(torch.cuda.synchronize) +def torch_cuda_synchronize(device=None): + if check_device(device): + return torch.xpu.synchronize(return_xpu(device)) + else: + return torch.xpu.synchronize(device) + # Hijack Functions: -def ipex_hijacks(): +def ipex_hijacks(legacy=True): + global device_supports_fp64, can_allocate_plus_4gb + if legacy and float(torch.__version__[:3]) < 2.5: + torch.nn.functional.interpolate = interpolate torch.tensor = torch_tensor torch.Tensor.to = Tensor_to torch.Tensor.cuda = Tensor_cuda @@ -289,9 +338,11 @@ 
def ipex_hijacks(): torch.randn = torch_randn torch.ones = torch_ones torch.zeros = torch_zeros + torch.full = torch_full torch.linspace = torch_linspace - torch.Generator = torch_Generator torch.load = torch_load + torch.Generator = torch_Generator + torch.cuda.synchronize = torch_cuda_synchronize torch.backends.cuda.sdp_kernel = return_null_context torch.nn.DataParallel = DummyDataParallel @@ -302,12 +353,15 @@ def ipex_hijacks(): torch.nn.functional.group_norm = functional_group_norm torch.nn.functional.layer_norm = functional_layer_norm torch.nn.functional.linear = functional_linear + torch.nn.functional.conv1d = functional_conv1d torch.nn.functional.conv2d = functional_conv2d - torch.nn.functional.interpolate = interpolate + torch.nn.functional.conv3d = functional_conv3d torch.nn.functional.pad = functional_pad torch.bmm = torch_bmm - torch.cat = torch_cat + torch.fft.fftn = fft_fftn + torch.fft.ifftn = fft_ifftn if not device_supports_fp64: torch.from_numpy = from_numpy torch.as_tensor = as_tensor + return device_supports_fp64, can_allocate_plus_4gb From ae409e83c939f2c4a997cfb1679bd7cd364baf7e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Feb 2025 20:56:32 +0900 Subject: [PATCH 25/27] fix: FLUX/SD3 network training not working without caching latents closes #1954 --- flux_train_network.py | 11 ++++++++--- sd3_train_network.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index ae4b62f5c..26503df1f 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -36,7 +36,12 @@ def __init__(self): self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + 
val_dataset_group: Optional[train_util.DatasetGroup], + ): super().assert_extra_args(args, train_dataset_group, val_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -323,7 +328,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): @@ -341,7 +346,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, - is_train=True + is_train=True, ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) diff --git a/sd3_train_network.py b/sd3_train_network.py index 2f4579492..9438bc7bc 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -26,7 +26,12 @@ def __init__(self): super().__init__() self.sample_prompts_te_outputs = None - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): # super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -299,7 +304,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae, images): return vae.encode(images) def shift_scale_latents(self, args, latents): @@ -317,7 +322,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, - 
is_train=True + is_train=True, ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) From 3d79239be4b20d67faed67c47f693396342e3af4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Feb 2025 21:21:04 +0900 Subject: [PATCH 26/27] docs: update README to include recent improvements in validation loss calculation --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 4bbd7617e..3c6993075 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,11 @@ The command to install PyTorch is as follows: ### Recent Updates +Feb 26, 2025: + +- Improve the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903) + - The validation loss now uses fixed timestep sampling and a fixed random seed, so that it does not fluctuate with random values. + Jan 25, 2025: - `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO!
From ce2610d29b399c8353686f50bf1973457a133153 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 27 Feb 2025 02:47:04 -0500 Subject: [PATCH 27/27] Change system prompt to inject Prompt Start special token --- library/lumina_train_util.py | 5 +++-- library/strategy_lumina.py | 3 ++- library/train_util.py | 6 ++++-- lumina_train_network.py | 9 ++++++--- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 11dd3febc..bfc470a93 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -330,11 +330,12 @@ def sample_image_inference( logger.info(f"renorm: {renorm_cfg}") # logger.info(f"sample_sampler: {sampler_name}") - system_prompt = args.system_prompt or "" + system_prompt_special_token = "" + system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" # Apply system prompt to prompts prompt = system_prompt + prompt - negative_prompt = system_prompt + negative_prompt + negative_prompt = negative_prompt # Get sample prompts from cache if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index c9e654236..275e290f6 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -216,7 +216,8 @@ def cache_batch_outputs( assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy) assert isinstance(tokenize_strategy, LuminaTokenizeStrategy) - captions = [info.system_prompt or "" + info.caption for info in batch] + system_prompt_special_token = "" + captions = [f"{info.system_prompt} {system_prompt_special_token} " if info.system_prompt else "" + info.caption for info in batch] if self.is_weighted: tokens, attention_masks, weights_list = ( diff --git a/library/train_util.py b/library/train_util.py index 0c057bd1a..34b98f89f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1692,7 +1692,8 @@ def 
__getitem__(self, index): text_encoder_outputs_list.append(text_encoder_outputs) if tokenization_required: - system_prompt = subset.system_prompt or "" + system_prompt_special_token = "" + system_prompt = f"{subset.system_prompt} {system_prompt_special_token} " if subset.system_prompt else "" caption = self.process_caption(subset, image_info.caption) input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)] # remove batch dimension # if self.XTI_layers: @@ -2091,7 +2092,8 @@ def load_dreambooth_dir(subset: DreamBoothSubset): else: num_train_images += num_repeats * len(img_paths) - system_prompt = self.system_prompt or subset.system_prompt or "" + system_prompt_special_token = "" + system_prompt = f"{self.system_prompt or subset.system_prompt} {system_prompt_special_token} " if self.system_prompt or subset.system_prompt else "" for img_path, caption, size in zip(img_paths, captions, sizes): info = ImageInfo(img_path, num_repeats, system_prompt + caption, subset.is_reg, img_path) if size is not None: diff --git a/lumina_train_network.py b/lumina_train_network.py index 5f20c0146..c9ef5f02c 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -155,7 +155,8 @@ def cache_text_encoder_outputs_if_needed( assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) - system_prompt = args.system_prompt or "" + system_prompt_special_token = "" + system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" sample_prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): @@ -164,8 +165,10 @@ def cache_text_encoder_outputs_if_needed( prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", ""), ] - for prompt in prompts: - prompt = system_prompt + prompt + for i, 
prompt in enumerate(prompts): + # Add system prompt only to positive prompt + if i == 0: + prompt = system_prompt + prompt if prompt in sample_prompts_te_outputs: continue