trn scripts added
radimspetlik committed Jan 6, 2024
1 parent 47bf010 commit 8620bbb
Showing 8 changed files with 138 additions and 80 deletions.
22 changes: 21 additions & 1 deletion README.md
@@ -22,6 +22,8 @@ Note that we only support pipenv installation. If you do not have pipenv install

## Evaluation

Downloading the pretrained models and benchmark datasets is somewhat cumbersome, so we will eventually put everything in one script (please open an issue if you are interested).

#### 1. Download the pre-trained models

The pre-trained SI-DDPM-FMO models as reported in the paper are available [here](https://drive.google.com/drive/folders/1sS67PAuaKzffSOw6h0pwhKE-Wsvz6nA8?usp=drive_link).
@@ -47,7 +49,13 @@ To evaluate the baseline model on the FMO benchmark dataset, run:

## Training

The training scripts will be published soon.
For training, first generate the training dataset as described below and place it in the `datasets` directory. Then run

```./trn_siddpmfmo.sh```

The training parameters are set inside the script. Training is expected to run on multiple GPUs using PyTorch DistributedDataParallel (DDP); modify the script according to your needs.

The baseline training is run in the same way, but using the `trn_baseline.sh` script.
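
The README expects multi-GPU training, while the `trn_*.sh` scripts added in this commit pass `--nproc_per_node=1`. The following is a minimal sketch of how the launch command could be assembled for a different number of GPUs on a single node. The helper name `build_launch_command` is hypothetical and not part of the repository; only the flags already used in the scripts appear here.

```python
import sys


def build_launch_command(conf_path, gpus=1, python=sys.executable):
    """Build a single-node torch.distributed.run command (hypothetical helper)."""
    if gpus < 1:
        raise ValueError("gpus must be at least 1")
    return [
        python, "-m", "torch.distributed.run",
        "--standalone",              # single-node setup, as in the trn_*.sh scripts
        "--nnodes=1",
        f"--nproc_per_node={gpus}",  # one worker process per GPU
        "scripts/image_train.py",
        "--conf_path", conf_path,
    ]


if __name__ == "__main__":
    # Print the command instead of running it, so the sketch has no side effects.
    print(" ".join(build_launch_command("confs/siddpmfmo.yml", gpus=4)))
```

The printed command can be pasted into a copy of the training script in place of its last line.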

### Synthetic dataset generation
For the dataset generation, please download:
@@ -82,3 +90,15 @@ If you use this repository, please cite the following [publication](https://arxi
year = {2024}
}
```

# Code Sources

We use a wild mix of the following repositories in our implementation:

* [RePaint](https://github.com/andreas128/RePaint.git)
* [GuidedDiffusion](https://github.com/openai/guided-diffusion)
* [ImprovedDiffusion](https://github.com/openai/improved-diffusion)
* [DeFMO](https://github.com/rozumden/DeFMO)
* [FMO-Deblurring-Benchmark](https://github.com/rozumden/fmo-deblurring-benchmark)

Many thanks to the authors of these repositories!
24 changes: 3 additions & 21 deletions confs/baseline.yml
@@ -93,7 +93,8 @@ data:
max_len: 8
paths:
default:
data_filepath: ./data/ShapeNetv2/ShapeBlur1000STL.hdf5
data_dir: ./datasets/ShapeBlur1000STL/
# data_filepath: ./datasets/ShapeBlur1000STL.hdf5
srs: ./experiments/mask/srs/
gts: ./experiments/mask/gts/
eval:
@@ -104,25 +105,6 @@ data:
paths:
default:
data_dir: ./datasets/ShapeBlur1000STL/
paper_face_mask:
mask_loader: true
gt_path: ./datasets/gts/face
mask_path: ./datasets/gt_keep_masks/face
image_size: 256
class_cond: false
deterministic: true
random_crop: false
random_flip: false
return_dict: true
drop_last: false
return_dataloader: true
offset: 0
max_len: 8
paths:
srs: ./log/face_example/inpainted
lrs: ./log/face_example/gt_masked
gts: ./log/face_example/gt
gt_keep_masks: ./log/face_example/gt_keep_mask
test:
seq_24_sanity:
mask_loader: false
@@ -139,6 +121,6 @@ data:
max_len: 8
paths:
default:
data_dir: ./data/ShapeNetv2/ShapeBlur20STL/
data_dir: ./datasets/ShapeBlur20STL/
srs: ./experiments/mask/srs/
gts: ./experiments/mask/gts/
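
Since this commit moves the dataset locations from `./data/ShapeNetv2/` to `./datasets/`, it can help to check that the configured paths exist before starting a run. The following is a minimal sketch, assuming the YAML file has already been parsed into a nested dict. The helper names and the `shapeblur_train` entry are hypothetical; only the `data_dir` and `data_filepath` keys and the path values come from the config.

```python
import os


def collect_paths(node, keys=("data_dir", "data_filepath")):
    """Recursively collect string values stored under the given keys."""
    found = []
    if isinstance(node, dict):
        for key, value in node.items():
            if key in keys and isinstance(value, str):
                found.append(value)
            else:
                found.extend(collect_paths(value, keys))
    return found


def missing_paths(conf, root="."):
    """Return the configured dataset paths that do not exist under root."""
    return [p for p in collect_paths(conf)
            if not os.path.exists(os.path.join(root, p))]


# Hand-written stand-in for the parsed `data` section; the nesting and the
# dataset name `shapeblur_train` are assumptions made for illustration.
example_conf = {
    "data": {
        "train": {"shapeblur_train": {"paths": {"default": {
            "data_dir": "./datasets/ShapeBlur1000STL/"}}}},
        "test": {"seq_24_sanity": {"paths": {"default": {
            "data_dir": "./datasets/ShapeBlur20STL/"}}}},
    }
}

if __name__ == "__main__":
    for path in missing_paths(example_conf):
        print(f"missing dataset path: {path}")
```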
27 changes: 5 additions & 22 deletions confs/siddpmfmo.yml
@@ -14,6 +14,7 @@
#
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license

mode: sidefmo
seed: 0
attention_resolutions: 32,16,8
class_cond: false
@@ -29,7 +30,7 @@ schedule_sampler: loss-second-moment
lr: 1e-4
weight_decay: 0.0
lr_anneal_steps: 0
batch_size: 5
batch_size: 1
microbatch: 1 # -1 disables microbatches
ema_rate: "0.9999" # comma-separated list of EMA values
log_interval: 10
@@ -91,7 +92,8 @@ data:
max_len: 8
paths:
default:
data_filepath: ./data/ShapeNetv2/ShapeBlur1000STL.hdf5
data_dir: ./datasets/ShapeBlur1000STL/
# data_filepath: ./datasets/ShapeBlur1000STL.hdf5
srs: ./experiments/mask/srs/
gts: ./experiments/mask/gts/
eval:
@@ -102,25 +104,6 @@ data:
paths:
default:
data_dir: ./datasets/ShapeBlur1000STL/
paper_face_mask:
mask_loader: true
gt_path: ./datasets/gts/face
mask_path: ./datasets/gt_keep_masks/face
image_size: 256
class_cond: false
deterministic: true
random_crop: false
random_flip: false
return_dict: true
drop_last: false
return_dataloader: true
offset: 0
max_len: 8
paths:
srs: ./log/face_example/inpainted
lrs: ./log/face_example/gt_masked
gts: ./log/face_example/gt
gt_keep_masks: ./log/face_example/gt_keep_mask
test:
seq_24_sanity:
mask_loader: false
@@ -137,6 +120,6 @@ data:
max_len: 8
paths:
default:
data_dir: ./data/ShapeNetv2/ShapeBlur20STL/
data_dir: ./datasets/ShapeNetv2/ShapeBlur20STL/
srs: ./experiments/mask/srs/
gts: ./experiments/mask/gts/
59 changes: 24 additions & 35 deletions guided_diffusion/gaussian_diffusion.py
@@ -560,9 +560,7 @@ def training_losses(self, model, x, t, t_alphas, model_kwargs=None, noise=None):
:return: a dict with the key "loss" containing a tensor of shape [N].
Some mean or variance settings may also have other keys.
"""
xs_blurry_m1p1, bg_m1p1, mask_bool_0_1, x_start = x

x1_blurry_m1p1, x2_blurry_m1p1 = th.split(xs_blurry_m1p1, 3, 1)
x_blurry_m1p1, bg_m1p1, mask_bool_0_1, x_start = x

if model_kwargs is None:
model_kwargs = {}
@@ -573,24 +571,22 @@ def training_losses(self, model, x, t, t_alphas, model_kwargs=None, noise=None):
terms = {}

if self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
model_output1, features1 = model(th.cat((x_t, x1_blurry_m1p1), dim=1),
self._scale_timesteps(t), **model_kwargs)
_, features2 = model(th.cat((x_t, x2_blurry_m1p1), dim=1),
self._scale_timesteps(t), **model_kwargs)
model_output, _ = model(th.cat((x_t, x_blurry_m1p1), dim=1),
self._scale_timesteps(t), **model_kwargs)

if self.model_var_type in [
ModelVarType.LEARNED,
ModelVarType.LEARNED_RANGE,
]:
B, C = x_t.shape[:2]
C = self.conf.out_channels
assert model_output1.shape == (B, C * 2, *x_t.shape[2:])
model_output1, model_var_values = th.split(model_output1, C, dim=1)
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
model_output, model_var_values = th.split(model_output, C, dim=1)
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out = th.cat([model_output1.detach(), model_var_values], dim=1)
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
terms["vb"] = self._vb_terms_bpd(model=lambda *args, r=frozen_out: r,
x_start=x_start, x_blurry=x1_blurry_m1p1,
x_start=x_start, x_blurry=x_blurry_m1p1,
usr_mask_m1p1=mask_bool_0_1, x_t=x_t,
bg_m1p1=bg_m1p1, t=t_alphas, clip_denoised=False)["output"]
if self.loss_type == LossType.RESCALED_MSE:
@@ -602,25 +598,20 @@ def training_losses(self, model, x, t, t_alphas, model_kwargs=None, noise=None):
ModelMeanType.START_X: x_start,
ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
assert model_output1.shape == target.shape == x_start.shape
assert model_output.shape == target.shape == x_start.shape

c = 0.3
terms["mse"] = \
c * mean_flat((mask_bool_0_1 * (target[:, 1::4] - model_output1[:, 1::4])) ** 2) \
+ c * mean_flat((mask_bool_0_1 * (target[:, 2::4] - model_output1[:, 2::4])) ** 2) \
+ c * mean_flat((mask_bool_0_1 * (target[:, 3::4] - model_output1[:, 3::4])) ** 2)
terms["mse_mask"] = 2.0 * mean_flat((target[:, 0::4] - model_output1[:, 0::4]) ** 2)

res = 0
for f1, f2 in zip(features1, features2):
res = res + mean_flat((f1 - f2) ** 2)
terms["feature_loss"] = self.conf.feature_loss_weight * res
c * mean_flat((mask_bool_0_1 * (target[:, 1::4] - model_output[:, 1::4])) ** 2) \
+ c * mean_flat((mask_bool_0_1 * (target[:, 2::4] - model_output[:, 2::4])) ** 2) \
+ c * mean_flat((mask_bool_0_1 * (target[:, 3::4] - model_output[:, 3::4])) ** 2)
terms["mse_mask"] = 2.0 * mean_flat((target[:, 0::4] - model_output[:, 0::4]) ** 2)

if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else:
terms["loss"] = terms["mse"]
terms["loss"] = terms["loss"] + terms["mse_mask"] + terms["feature_loss"]
terms["loss"] = terms["loss"] + terms["mse_mask"]
else:
raise NotImplementedError(self.loss_type)
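
After this change, the loss in `training_losses` no longer contains the `feature_loss` term: it is the masked MSE over the channel groups `1::4`, `2::4`, and `3::4` (each weighted by `c = 0.3`), plus the `mse_mask` term over `0::4` weighted by 2.0, plus the optional `vb` term. The following is a minimal pure-Python sketch of that composition for a single sample, with the `vb` term left out. It assumes each channel is a flat list of pixel values and that one mask is shared by all channels; the function names are hypothetical, and the original code operates on PyTorch tensors instead.

```python
def mean_sq(channels_a, channels_b, mask=None):
    """Mean of (mask * (a - b)) ** 2 over all selected channels and pixels."""
    total, count = 0.0, 0
    for a, b in zip(channels_a, channels_b):
        for i, (x, y) in enumerate(zip(a, b)):
            m = 1.0 if mask is None else mask[i]
            total += (m * (x - y)) ** 2
            count += 1
    return total / count


def combined_loss(target, output, mask, c=0.3, mask_weight=2.0):
    """Combine the masked per-group MSE with the unmasked term on channels 0::4."""
    mse = sum(c * mean_sq(target[k::4], output[k::4], mask) for k in (1, 2, 3))
    mse_mask = mask_weight * mean_sq(target[0::4], output[0::4])
    return {"mse": mse, "mse_mask": mse_mask, "loss": mse + mse_mask}


if __name__ == "__main__":
    # Four channels with two pixels each; values chosen by hand for illustration.
    target = [[1.0, 1.0], [1.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
    output = [[0.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0]]
    mask = [1.0, 0.0]
    print(combined_loss(target, output, mask))
```

Errors outside the mask do not contribute to `mse`, whereas the term on channels `0::4` is computed over all pixels.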

@@ -639,9 +630,7 @@ def background_training_losses(self, model, x, t, t_alphas, model_kwargs=None, n
:return: a dict with the key "loss" containing a tensor of shape [N].
Some mean or variance settings may also have other keys.
"""
xs_blurry_m1p1, bg_m1p1, mask_bool_0_1, x_tsr_m1p1 = x

x1_blurry_m1p1, x2_blurry_m1p1 = th.split(xs_blurry_m1p1, 3, 1)
x_blurry_m1p1, bg_m1p1, mask_bool_0_1, x_tsr_m1p1 = x

if model_kwargs is None:
model_kwargs = {}
@@ -652,21 +641,21 @@ def background_training_losses(self, model, x, t, t_alphas, model_kwargs=None, n
terms = {}

if self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
model_output1, features1 = model(th.cat((x_t, x1_blurry_m1p1), dim=1),
self._scale_timesteps(t), **model_kwargs)
model_output, _ = model(th.cat((x_t, x_blurry_m1p1), dim=1),
self._scale_timesteps(t), **model_kwargs)
if self.model_var_type in [
ModelVarType.LEARNED,
ModelVarType.LEARNED_RANGE,
]:
B, C = x_t.shape[:2]
C = self.conf.out_channels
assert model_output1.shape == (B, C * 2, *x_t.shape[2:])
model_output1, model_var_values = th.split(model_output1, C, dim=1)
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
model_output, model_var_values = th.split(model_output, C, dim=1)
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out = th.cat([model_output1.detach(), model_var_values], dim=1)
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
terms["vb"] = self._vb_terms_bpd(model=lambda *args, r=frozen_out: r,
x_start=bg_m1p1, x_blurry=x1_blurry_m1p1,
x_start=bg_m1p1, x_blurry=x_blurry_m1p1,
usr_mask_m1p1=mask_bool_0_1, x_t=x_t,
bg_m1p1=bg_m1p1, t=t_alphas, clip_denoised=False)["output"]
if self.loss_type == LossType.RESCALED_MSE:
Expand All @@ -678,13 +667,13 @@ def background_training_losses(self, model, x, t, t_alphas, model_kwargs=None, n
ModelMeanType.START_X: bg_m1p1,
ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
assert model_output1.shape == target.shape == bg_m1p1.shape
assert model_output.shape == target.shape == bg_m1p1.shape

c = 0.3
terms["mse"] = \
c * mean_flat((target[:, 1::4] - model_output1[:, 1::4]) ** 2) \
+ c * mean_flat((target[:, 2::4] - model_output1[:, 2::4]) ** 2) \
+ c * mean_flat((target[:, 0::4] - model_output1[:, 0::4]) ** 2)
c * mean_flat((target[:, 1::4] - model_output[:, 1::4]) ** 2) \
+ c * mean_flat((target[:, 2::4] - model_output[:, 2::4]) ** 2) \
+ c * mean_flat((target[:, 0::4] - model_output[:, 0::4]) ** 2)

if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
2 changes: 1 addition & 1 deletion guided_diffusion/train_util.py
@@ -174,7 +174,7 @@ def forward_backward(self, batch, cond):
temporal_sr_steps,
dist_util.dev())

micro_img_blurry_m1p1 = th.cat(list(imgs_m1p1[k]
micro_img_blurry_m1p1 = th.cat(list(imgs_m1p1[k:k+1]
for k in range(len(imgs_m1p1))),
dim=1).to(dist_util.dev())
micro_bg_m1p1 = th.repeat_interleave(bg_m1p1,
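
The change from `imgs_m1p1[k]` to `imgs_m1p1[k:k+1]` matters because an integer index drops the leading axis, while a length-one slice keeps it. That in turn changes which axis `dim=1` refers to in the following `th.cat`. The same indexing rule holds for nested Python lists, so the difference can be sketched without PyTorch. The layout of `imgs_m1p1` is not visible in this hunk; the shape used below is an assumption, and `shape_of` is a hypothetical helper.

```python
def shape_of(nested):
    """Return the shape of a regularly nested, non-empty list as a tuple."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


# Stand-in for 2 images with 3 channels and 2x2 pixels (assumed layout).
imgs = [[[[0.0, 0.0], [0.0, 0.0]] for _ in range(3)] for _ in range(2)]

if __name__ == "__main__":
    print(shape_of(imgs[0]))    # integer index: the leading axis is dropped
    print(shape_of(imgs[0:1]))  # length-one slice: the leading axis is kept
```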
66 changes: 66 additions & 0 deletions scripts/image_train.py
@@ -0,0 +1,66 @@
"""
Train a diffusion model on images.
"""

import argparse
import os
from multiprocessing import set_start_method
from shutil import copyfile

import conf_mgt
from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    select_args,
)
from guided_diffusion.train_util import TrainLoop
from utils import yamlread


def main(conf: conf_mgt.Default_Conf):
    global_rank = dist_util.setup_dist()
    experiment_dir = os.getenv("EXPERIMENT_DIR") if os.getenv("EXPERIMENT_DIR") is not None else './experiments/debug/'

    logger.configure(dir=experiment_dir, global_rank=global_rank)
    logger.log(f'experiment_dir: {experiment_dir}')
    logger.log("creating model and diffusion...")

    conf_path = args.get('conf_path')
    if os.path.isdir(experiment_dir) and os.path.isfile(conf_path):
        _, conf_filename = os.path.split(conf_path)
        copyfile(conf_path, os.path.join(experiment_dir, conf_filename))

    model, diffusion = create_model_and_diffusion(conf=conf,
                                                  **select_args(
                                                      conf,
                                                      model_and_diffusion_defaults().keys()
                                                  ))
    model.to(dist_util.dev())
    schedule_sampler = create_named_schedule_sampler(conf.schedule_sampler, diffusion)

    logger.log("creating data loader...")
    train_name = conf.get_default_train_name()
    data = conf.get_dataloader(dsName=train_name)

    logger.log("training...")
    TrainLoop(model=model,
              diffusion=diffusion,
              data=data,
              schedule_sampler=schedule_sampler,
              conf=conf
              ).run_loop()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--conf_path', type=str, required=False, default=None)
    args = vars(parser.parse_args())

    conf_arg = conf_mgt.conf_base.Default_Conf()
    conf_arg.update(yamlread(args.get('conf_path')))

    set_start_method('forkserver', force=True)

    main(conf_arg)
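
The script reads the experiment directory from the `EXPERIMENT_DIR` environment variable, falls back to `./experiments/debug/`, and copies the config file next to the logs. The following is a minimal standalone sketch of those two steps. The function names are hypothetical, and the explicit `None` check is an addition of this sketch (in the script, `--conf_path` defaults to `None`), not something the original code does.

```python
import os
import tempfile
from shutil import copyfile


def resolve_experiment_dir(environ=os.environ, default="./experiments/debug/"):
    """Return EXPERIMENT_DIR if it is set, else the default used in the script."""
    return environ.get("EXPERIMENT_DIR", default)


def snapshot_config(conf_path, experiment_dir):
    """Copy the config into the experiment directory; return the copy's path or None."""
    if conf_path is None or not os.path.isfile(conf_path):
        return None
    if not os.path.isdir(experiment_dir):
        return None
    target = os.path.join(experiment_dir, os.path.basename(conf_path))
    copyfile(conf_path, target)
    return target


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        conf = os.path.join(tmp, "siddpmfmo.yml")
        with open(conf, "w") as handle:
            handle.write("seed: 0\n")
        run_dir = os.path.join(tmp, "run")
        os.makedirs(run_dir)
        print(snapshot_config(conf, run_dir))  # path of the copied config
        print(snapshot_config(None, run_dir))  # nothing is copied
```

Keeping a copy of the config in each experiment directory makes it easier to tell later which settings produced a given checkpoint.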
9 changes: 9 additions & 0 deletions trn_baseline.sh
@@ -0,0 +1,9 @@
#!/usr/bin/bash

SCRIPTS_DIR="${HOME}/SI-DDPM-FMO/"

cd "${SCRIPTS_DIR}" || exit

source "$(pipenv --venv)/bin/activate"

python -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=1 scripts/image_train.py --conf_path confs/baseline.yml
9 changes: 9 additions & 0 deletions trn_siddmpfmo.sh
@@ -0,0 +1,9 @@
#!/usr/bin/bash

SCRIPTS_DIR="${HOME}/SI-DDPM-FMO/"

cd "${SCRIPTS_DIR}" || exit

source "$(pipenv --venv)/bin/activate"

python -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=1 scripts/image_train.py --conf_path confs/siddpmfmo.yml
