
Updated Weights & Biases logging + Made codebase pip installable #22

Merged: 27 commits merged into sayakpaul:main on Feb 15, 2023

Conversation

@soumik12345 (Contributor) commented Feb 10, 2023

Changes included in this PR (a wiring sketch follows the list):

  • Added the WandbMetricsLogger callback for logging training metrics to Weights & Biases.
  • Added WandbModelCheckpoint for logging model checkpoints as Weights & Biases artifacts.
  • Added a new QualitativeValidationCallback that visualizes the generations for multiple prompts at the end of every epoch in a Weights & Biases Table.
  • Updated the validation_prompt argument in the training script to accept multiple prompts for visualizing generations during training.
  • Made the codebase pip installable.
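A minimal sketch of how these pieces could be wired together, assuming the callback added in this PR lives under src/utils.py; the import path, project name, prompt, and image sizes are illustrative, while WandbMetricsLogger and WandbModelCheckpoint come from wandb's Keras integration:

```python
import wandb
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint

from src.utils import QualitativeValidationCallback  # assumed import path

wandb.init(project="dreambooth-keras")  # illustrative project name

callbacks = [
    # Stream per-batch training metrics to W&B.
    WandbMetricsLogger(log_freq="batch"),
    # Save checkpoints and upload them to W&B as artifacts.
    WandbModelCheckpoint("checkpoints/dreambooth", save_weights_only=True),
    # Generate images for the validation prompts every epoch and log
    # them to a W&B Table (argument names follow this PR's diff).
    QualitativeValidationCallback(
        img_heigth=512,  # spelled as in the PR diff
        img_width=512,
        prompts=["A photo of sks dog in a bucket"],  # illustrative prompt
        num_imgs_to_gen=4,
    ),
]
```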

Todos:

  • Make sure to log the results corresponding to all prompts in both table and media panel.
  • Update README to reflect all the changes.

Sample Weights & Biases run: https://wandb.ai/geekyrakshit/dreambooth-keras/runs/h0kdq2dw

@soumik12345 soumik12345 marked this pull request as draft February 10, 2023 16:35
@sayakpaul (Owner) commented Feb 10, 2023

@soumik12345 thanks for your PR!

I think we want to keep the codebase fairly lightweight so that people can navigate it easily and customize it to their needs.

So, with that in mind, I would ask that you introduce only the changes pertaining to WandB logging (images). Also, please keep the code lints unchanged.

Here's what I think we can do (@deep-diver feel free to chime in):

  • Keep your logging (image) related changes. We will mention you in the repository, crediting your work (of course).
  • You package this repository from a separate repository (with due credits).

train_dreambooth.py (outdated review thread, resolved)
@soumik12345 (Contributor, Author) commented Feb 10, 2023

Hi @sayakpaul,
I updated the PR as per your request. Here's a sample wandb run: https://wandb.ai/geekyrakshit/dreambooth-keras/runs/4wydmf8w

I will edit the README to include the changes once you approve them.

@soumik12345 soumik12345 marked this pull request as ready for review February 10, 2023 20:11
@sayakpaul (Owner) commented Feb 11, 2023

> Hi @sayakpaul, I updated the PR as per your request. Here's a sample wandb run: https://wandb.ai/geekyrakshit/dreambooth-keras/runs/4wydmf8w
>
> I will edit the README to include the changes once you approve them.

I still don't see the prompt in the validation panel (not the table, but the panel). @soumik12345

@soumik12345 (Contributor, Author) commented

Hi @sayakpaul, I updated the logging of images generated from the validation prompts to the W&B panel. Here's a sample run: https://wandb.ai/geekyrakshit/dreambooth-keras/runs/8shvurfp

src/utils.py (outdated), comment on lines 25 to 34:

```python
for prompt in tqdm(prompts):
    images_dreamboothed = sd_model.text_to_image(prompt, batch_size=num_imgs_to_gen)
    wandb.log(
        {
            f"validation/Prompt: {prompt}": [
                wandb.Image(PIL.Image.fromarray(image), caption=f"{i}: {prompt}")
                for i, image in enumerate(images_dreamboothed)
            ]
        }
    )
```
sayakpaul (Owner): Nice!
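For context, a self-contained version of the snippet above; the model construction, project name, prompt, and image count are assumptions here, while StableDiffusion and text_to_image are the actual KerasCV API:

```python
import PIL.Image
import wandb
from keras_cv.models import StableDiffusion
from tqdm import tqdm

wandb.init(project="dreambooth-keras")  # illustrative project name

sd_model = StableDiffusion(img_height=512, img_width=512)
prompts = ["A photo of sks dog in a bucket"]  # illustrative prompt
num_imgs_to_gen = 4

# text_to_image returns a batch of uint8 arrays, one per generated image.
for prompt in tqdm(prompts):
    images_dreamboothed = sd_model.text_to_image(prompt, batch_size=num_imgs_to_gen)
    wandb.log(
        {
            f"validation/Prompt: {prompt}": [
                wandb.Image(PIL.Image.fromarray(image), caption=f"{i}: {prompt}")
                for i, image in enumerate(images_dreamboothed)
            ]
        }
    )
```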

```python
        num_imgs_to_gen=args.num_images_to_generate,
    ),
]
if args.log_wandb
```
sayakpaul (Owner): Please refactor this so that the else block becomes more pronounced. No comprehensions here, please.
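A hedged sketch of the requested shape, using the variables from the diff above; the final diff later in this thread is the authoritative version:

```python
# Build the callback list with an explicit if/else instead of a
# conditional expression over two list literals.
callbacks = []
if args.log_wandb:
    callbacks.append(WandbMetricsLogger(log_freq="batch"))
    # ... the other W&B callbacks from this PR ...
else:
    callbacks.append(
        tf.keras.callbacks.ModelCheckpoint(ckpt_path_prefix, save_weights_only=True)
    )
```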

sayakpaul (Owner): Not addressed.

soumik12345 (Contributor, Author): On it.

soumik12345 (Contributor, Author): @sayakpaul made the change.

Comment on lines 188 to 193:

```python
    else [
        tf.keras.callbacks.ModelCheckpoint(
            ckpt_path_prefix, save_weights_only=True, monitor="loss", mode="min"
        )
    ]
)
```
sayakpaul (Owner): The problem with only saving the low-loss checkpoints in DreamBooth is that the loss curves are not that informative if you take a look at them (think of vanilla GAN training loss curves).

sayakpaul (Owner): So this is why I just logged weights after the end of training.

In my experience, the low-loss checkpoints often lead to inferior results.
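A minimal sketch of that alternative, assuming a Keras-style trainer; the trainer, dataset, and path names are illustrative:

```python
# Skip loss-based checkpoint selection entirely: train, then save the
# weights once at the end of training.
dreambooth_trainer.fit(train_dataset, epochs=num_epochs, callbacks=callbacks)
dreambooth_trainer.save_weights("dreambooth-final.h5")
```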

soumik12345 (Contributor, Author): I'll remove the monitor in that case.

Comment on lines 177 to 179:

```python
DreamBoothCheckpointCallback(
    ckpt_path_prefix, save_weights_only=True, monitor="loss", mode="min"
),
```
@sayakpaul (Owner) left a review comment: Thanks for the changes.

@sayakpaul (Owner) left a review comment: One change remains.

Comment on lines 175 to 189:

```python
if args.log_wandb:
    # log training metrics to Weights & Biases
    callbacks.append(WandbMetricsLogger(log_freq="batch"))
    # log model checkpoints to Weights & Biases as artifacts
    callbacks.append(
        DreamBoothCheckpointCallback(ckpt_path_prefix, save_weights_only=True)
    )
    # perform inference on validation prompts at the end of every epoch and
    # log the results to a Weights & Biases table
    callbacks.append(
        QualitativeValidationCallback(
            img_heigth=args.img_resolution,
            img_width=args.img_resolution,
            prompts=validation_prompts,
            num_imgs_to_gen=args.num_images_to_generate,
```
sayakpaul (Owner): Perfect 💯

@sayakpaul (Owner) commented: LGTM! @deep-diver what do you think?

Comment on lines 192 to 195:

```python
else:
    callbacks.append(
        tf.keras.callbacks.ModelCheckpoint(ckpt_path_prefix, save_weights_only=True)
    )
```
deep-diver (Collaborator): DreamBoothCheckpointCallback is a subclass of WandbModelCheckpoint, which ultimately inherits from tf.keras.callbacks.ModelCheckpoint.

I think we don't need this separate else branch. Instead, we could simply pass args.log_wandb to DreamBoothCheckpointCallback to decide whether checkpoints should be uploaded to W&B (sketched below).
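A hedged sketch of that suggestion; the log_wandb flag and where it gates the upload are assumptions, not the PR's actual implementation:

```python
from wandb.keras import WandbModelCheckpoint


class DreamBoothCheckpointCallback(WandbModelCheckpoint):
    """One checkpoint callback for both cases (sketch only)."""

    def __init__(self, filepath, log_wandb=True, **kwargs):
        super().__init__(filepath, **kwargs)
        # Hypothetical flag: whether to also upload checkpoints to W&B.
        self.log_wandb = log_wandb

    # Wherever the real callback triggers the W&B artifact upload, that
    # step would be skipped when self.log_wandb is False, while the
    # inherited tf.keras.callbacks.ModelCheckpoint behavior keeps
    # saving checkpoints locally either way.


# The call site then needs no if/else branch:
# callbacks.append(
#     DreamBoothCheckpointCallback(
#         ckpt_path_prefix, save_weights_only=True, log_wandb=args.log_wandb
#     )
# )
```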

Comment on lines 142 to 144:

```python
validation_prompts = [f"A photo of {args.unique_id} {args.class_category} in a bucket"]
if args.validation_prompts is not None:
    validation_prompts += args.validation_prompts
```
deep-diver (Collaborator): nit: if validation prompts are provided, the default prompt might not make sense, since prompts vary with the situation (maybe = instead of +=? see the sketch below).
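The nit, sketched with the same variables as the diff above:

```python
# Let user-provided prompts replace the default instead of being
# appended to it.
if args.validation_prompts is not None:
    validation_prompts = args.validation_prompts
else:
    validation_prompts = [
        f"A photo of {args.unique_id} {args.class_category} in a bucket"
    ]
```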


```python
train(dreambooth_trainer, train_dataset, args.max_train_steps, callbacks)

if args.log_wandb:
```
deep-diver (Collaborator): nit: it may be better to reuse the StableDiffusion instance from QualitativeValidationCallback, since that callback is always present when args.log_wandb=True (a sketch follows).
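A sketch of that reuse; the sd_model attribute name is hypothetical, since the callback's internals are not shown in this thread:

```python
qualitative_callback = QualitativeValidationCallback(
    img_heigth=args.img_resolution,  # spelled as in the PR diff
    img_width=args.img_resolution,
    prompts=validation_prompts,
    num_imgs_to_gen=args.num_images_to_generate,
)

# ... training runs with the callback attached ...

# Reuse the StableDiffusion instance the callback already built instead
# of constructing a second one for post-training inference.
sd_model = qualitative_callback.sd_model  # hypothetical attribute
images = sd_model.text_to_image(validation_prompts[0], batch_size=4)
```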

@deep-diver (Collaborator) commented: LGTM! Thanks for your contribution!

@sayakpaul (Owner) commented: Feel free to add a note about these changes to the README, @soumik12345! Thank you.

@soumik12345 (Contributor, Author) commented:

> Feel free to add a note about these changes to the README, @soumik12345! Thank you.

@sayakpaul updated the README to reflect the changes.

@sayakpaul sayakpaul merged commit f7be949 into sayakpaul:main Feb 15, 2023