Updated Weights & Biases logging + Made codebase pip installable #22
Conversation
@soumik12345 thanks for your PR! I think we would want to keep the codebase a bit more lightweight for people to be able to navigate around easily and customize as per their needs. So, with that in mind, I would request to introduce the changes only pertaining to WandB logging (images). Also, please keep the code-lints unchanged. Here's what we can do I think (@deep-diver feel free to chime in):
Hi @sayakpaul, I will edit the README to include the changes once you approve them.
I still don't see the prompt in the validation panel (not the table but the panel).
Hi @sayakpaul, I updated the logging of the images generated from the validation prompts to the W&B panel. Here's a sample run: https://wandb.ai/geekyrakshit/dreambooth-keras/runs/8shvurfp
src/utils.py
Outdated
    for prompt in tqdm(prompts):
        images_dreamboothed = sd_model.text_to_image(prompt, batch_size=num_imgs_to_gen)
        wandb.log(
            {
                f"validation/Prompt: {prompt}": [
                    wandb.Image(PIL.Image.fromarray(image), caption=f"{i}: {prompt}")
                    for i, image in enumerate(images_dreamboothed)
                ]
            }
        )
Nice!
train_dreambooth.py
Outdated
            num_imgs_to_gen=args.num_images_to_generate,
        ),
    ]
    if args.log_wandb
Please refactor it such that the else block becomes more pronounced. No comprehensions here, please.
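A minimal sketch of the requested shape, with the actual callback constructors elided behind zero-argument factories (the function and variable names here are illustrative, not the script's real identifiers):

```python
def build_callbacks(log_wandb, wandb_factories, default_factories):
    """Build the callback list with an explicit if/else instead of a
    conditional expression, so the non-W&B path stands out."""
    callbacks = []
    if log_wandb:
        # W&B path: metrics logger, checkpoint upload, validation panel, etc.
        for make in wandb_factories:
            callbacks.append(make())
    else:
        # Plain path: only the local checkpoint callback.
        for make in default_factories:
            callbacks.append(make())
    return callbacks
```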
Not addressed.
On it
@sayakpaul made the change.
train_dreambooth.py
Outdated
    else [
        tf.keras.callbacks.ModelCheckpoint(
            ckpt_path_prefix, save_weights_only=True, monitor="loss", mode="min"
        )
    ]
)
The problem with only saving the low-loss checkpoints in DreamBooth is that the loss curves are not that informative if you take a look at them (think of vanilla GAN training loss curves).
That is why I just logged the weights at the end of training. In my experience, the low-loss checkpoints often lead to inferior results.
I'll remove the monitor in that case.
train_dreambooth.py
Outdated
DreamBoothCheckpointCallback(
    ckpt_path_prefix, save_weights_only=True, monitor="loss", mode="min"
),
Thanks for the changes.
One change remains.
train_dreambooth.py
Outdated
if args.log_wandb:
    # log training metrics to Weights & Biases
    callbacks.append(WandbMetricsLogger(log_freq="batch"))
    # log model checkpoints to Weights & Biases as artifacts
    callbacks.append(
        DreamBoothCheckpointCallback(ckpt_path_prefix, save_weights_only=True)
    )
    # perform inference on validation prompts at the end of every epoch and
    # log the results to a Weights & Biases table
    callbacks.append(
        QualitativeValidationCallback(
            img_heigth=args.img_resolution,
            img_width=args.img_resolution,
            prompts=validation_prompts,
            num_imgs_to_gen=args.num_images_to_generate,
Perfect 💯
LGTM! @deep-diver what do you think?
train_dreambooth.py
Outdated
else:
    callbacks.append(
        tf.keras.callbacks.ModelCheckpoint(ckpt_path_prefix, save_weights_only=True)
    )
DreamBoothCheckpointCallback is a subclass of WandbModelCheckpoint, which ultimately inherits from tf.keras.callbacks.ModelCheckpoint. I think we don't need this separate else branch. Instead, we could simply pass args.log_wandb to DreamBoothCheckpointCallback to decide whether checkpoints should be uploaded to W&B.
train_dreambooth.py
Outdated
validation_prompts = [f"A photo of {args.unique_id} {args.class_category} in a bucket"]
if args.validation_prompts is not None:
    validation_prompts += args.validation_prompts
nit: if validation prompts are provided, the default prompt might not make sense, since prompts depend on the situation (maybe = instead of +=?)
train(dreambooth_trainer, train_dataset, args.max_train_steps, callbacks)

if args.log_wandb:
nit: maybe it is better to reuse the StableDiffusion instance from QualitativeValidationCallback, since it is always there under the args.log_wandb=True condition.
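A minimal sketch of that reuse pattern, with stand-in classes instead of the real ones (the sd_model attribute name and FakeModel are assumptions for illustration only):

```python
class QualitativeValidationCallback:
    """Stand-in for the real callback: it owns the diffusion model, so
    other code can reuse it instead of constructing a second instance."""

    def __init__(self, model):
        self.sd_model = model  # attribute name is an assumption


class FakeModel:
    """Stand-in for StableDiffusion, used only to show the wiring."""

    def generate(self, prompt):
        return f"image for: {prompt}"


def final_inference(callback, prompt):
    # Reuse the model held by the validation callback rather than
    # building a new StableDiffusion for post-training inference.
    return callback.sd_model.generate(prompt)
```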
…ualitativeValidationCallback + updated readme to reflect changes
LGTM! Thanks for your contribution!
Feel free to add a note of these changes to the README, @soumik12345! Thank you.
@sayakpaul updated the README to reflect the changes.
Changes included in this PR:
- WandbMetricsLogger callback for logging training metrics to Weights & Biases.
- WandbModelCheckpoint for logging model checkpoints as Weights & Biases artifacts.
- QualitativeValidationCallback for visualizing the generations corresponding to multiple prompts in an epoch-wise manner using a Weights & Biases Table.
- validation_prompt argument in the training script to accept multiple prompts for visualizing the generations during training.

Todos:
Sample Weights & Biases run: https://wandb.ai/geekyrakshit/dreambooth-keras/runs/h0kdq2dw