
Training sf:1 (deblurring) #22

Closed
cyprian opened this issue Oct 8, 2023 · 12 comments

@cyprian

cyprian commented Oct 8, 2023

Thank you for providing your code. I already tested the super-resolution, and it's great. Is it possible to adapt the config to do deblurring on lq: 256 x gt: 256, i.e. without any super-resolution? What would I have to change?

@zsyOAOA
Owner

zsyOAOA commented Oct 8, 2023

You only need to rewrite the dataset class. Please refer to this bicubic dataset.

@cyprian
Author

cyprian commented Oct 21, 2023

I have changed the Dataset implementation. However, I am getting the following UNet size-mismatch error when I run it.

Traceback (most recent call last):
  File "/content/ResShift/main.py", line 53, in <module>
    trainer.train()
  File "/content/ResShift/trainer.py", line 275, in train
    self.training_step(data)
  File "/content/ResShift/trainer.py", line 638, in training_step
    losses, z_t, z0_pred = compute_losses()
  File "/content/ResShift/models/respace.py", line 47, in training_losses
    return super().training_losses(self._wrap_model(model), *args, **kwargs)
  File "/content/ResShift/models/gaussian_diffusion.py", line 537, in training_losses
    model_output = model(self._scale_input(z_t, t), t, **model_kwargs)
  File "/content/ResShift/models/respace.py", line 63, in __call__
    return self.model(x, new_ts, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/ResShift/models/unet.py", line 846, in forward
    x = th.cat([x, lq], dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 32 for tensor number 1 in the list.

I am trying to test it on input sizes GT = 64x64 and LQ = 64x64.

Here is my implementation of the Dataset:

class PairedData(Dataset):
    def __init__(
            self,
            sf,
            dir_path=None,
            dir_path_lq=None,
            txt_file_path=None,
            txt_file_path_lq=None,
            mean=0.5,
            std=0.5,
            hflip=False,
            rotation=False,
            resize_back=False,
            length=None,
            need_path=False,
            im_exts=['png', 'jpg', 'jpeg', 'JPEG', 'bmp'],
            recursive=False,
            use_sharp=False,
            rescale_gt=True,
            gt_size=256,
            ):

        if txt_file_path is None:
            assert dir_path is not None
            self.file_paths_all = util_common.scan_files_from_folder(dir_path, im_exts, recursive)
        else:
            self.file_paths_all = util_common.readline_txt(txt_file_path)

        # Load low-quality (lq) file paths
        if txt_file_path_lq is None:
            assert dir_path_lq is not None
            self.file_paths_all_lq = util_common.scan_files_from_folder(dir_path_lq, im_exts, recursive)
        else:
            self.file_paths_all_lq = util_common.readline_txt(txt_file_path_lq)

        if length is None:
            self.file_paths = self.file_paths_all
            self.file_paths_lq = self.file_paths_all_lq
        else:
            assert len(self.file_paths_all) >= length
            assert len(self.file_paths_all_lq) >= length
            self.file_paths = random.sample(self.file_paths_all, length)
            self.file_paths_lq = random.sample(self.file_paths_all_lq, length)

        self.sf = sf
        self.mean = mean
        self.std = std
        self.hflip = hflip
        self.rotation = rotation
        self.length = length
        self.need_path = need_path
        self.resize_back = resize_back
        self.use_sharp = use_sharp
        self.rescale_gt = rescale_gt
        self.gt_size = gt_size

        self.transform = get_transforms('default', {'mean': mean, 'std': std})
        if rescale_gt:
            self.smallest_rescaler = SmallestMaxSize(max_size=gt_size)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, index):
        # Load ground truth image
        im_path = self.file_paths[index]
        im_gt = util_image.imread(im_path, chn='rgb', dtype='float32')

        # Load low quality image
        im_path_lq = self.file_paths_lq[index]
        im_lq = util_image.imread(im_path_lq, chn='rgb', dtype='float32')

        # Augmentation for training
        im_gt = augment(im_gt, hflip=self.hflip, rotation=self.rotation, return_status=False)
        im_lq = augment(im_lq, hflip=self.hflip, rotation=self.rotation, return_status=False)

        # im_lq = np.clip(im_lq, 0.0, 1.0)

        out = {'lq': self.transform(im_lq), 'gt': self.transform(im_gt)}
        if self.need_path:
            out['path'] = im_path  # or you can include both im_path and im_path_lq

        return out
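
For reference, here is a minimal sketch of how I instantiate PairedData; the paths are the ones from the config below, and the printed shapes are what I expect rather than verified output.

# Hypothetical usage of the PairedData class above. Paths come from the
# config below; the expected shapes are assumptions, not verified output.
train_set = PairedData(
    sf=1,
    dir_path='datasets/clean/64/train/clean',     # GT (sharp) images
    dir_path_lq='datasets/clean/64/train/dirty',  # LQ (blurred) images
    gt_size=64,
    rescale_gt=False,
)
sample = train_set[0]
print(sample['lq'].shape, sample['gt'].shape)  # expect torch.Size([3, 64, 64]) each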

And here is the test config I am running it with:

model:
  target: models.unet.UNetModelSwin
  ckpt_path: ~
  params:
    image_size: 64
    in_channels: 6
    model_channels: 160
    out_channels: 3
    cond_lq: True
    attention_resolutions: [64,32,16,8]
    dropout: 0
    channel_mult: [1, 2, 2, 4]
    num_res_blocks: [2, 2, 2, 2]
    conv_resample: True
    dims: 2
    use_fp16: False
    num_head_channels: 32
    use_scale_shift_norm: True
    resblock_updown: False
    swin_depth: 2
    swin_embed_dim: 192
    window_size: 8
    mlp_ratio: 4

diffusion:
  target: models.script_util.create_gaussian_diffusion
  params:
    sf: 1
    schedule_name: exponential
    schedule_kwargs:
      power: 0.3
    etas_end: 0.99
    steps: 15
    min_noise_level: 0.04
    kappa: 1.0
    weighted_mse: False
    predict_type: xstart
    timestep_respacing: ~
    scale_factor: 1.0
    normalize_input: True
    latent_flag: True

autoencoder:
  target: ldm.models.autoencoder.VQModelTorch
  ckpt_path: weights/autoencoder_vq_f4.pth
  use_fp16: True
  params:
    embed_dim: 3
    n_embed: 8192
    ddconfig:
      double_z: False
      z_channels: 3
      resolution: 64
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0
      padding_mode: zeros

data:
  train:
    type: paired
    params:
      sf: 1
      dir_path: datasets/clean/64/train/clean
      dir_path_lq: datasets/clean/64/train/dirty
      txt_file_path: ~
      txt_file_path_lq: ~
      mean: 0.5
      std: 0.5
      hflip: False
      rotation: False
      resize_back: False
      length: ~
      need_path: False
      im_exts: ['png', 'jpg', 'jpeg', 'JPEG', 'bmp']
      recursive: False
      use_sharp: False
      rescale_gt: False
      gt_size: 64
  val:
    type: paired
    params:
      sf: 1
      dir_path: datasets/clean/64/val/clean
      dir_path_lq: datasets/clean/64/val/dirty
      txt_file_path: ~
      txt_file_path_lq: ~
      mean: 0.5
      std: 0.5
      hflip: False
      rotation: False
      resize_back: False
      length: 5 # number of images to be evaluated
      need_path: False
      im_exts: ['png', 'jpg', 'jpeg', 'JPEG', 'bmp']
      recursive: False
      use_sharp: False
      rescale_gt: False
      gt_size: 64

train:
  lr: 5e-5
  batch: [64, 8]   # batchsize for training and validation
  use_fp16: False
  microbatch: 16
  seed: 123456
  global_seeding: False
  prefetch_factor: 4
  num_workers: 8
  ema_rate: 0.999
  iterations: 500000
  milestones: [5000, 500000]
  weight_decay: 0
  save_freq: 100
  val_freq: 100
  log_freq: [10, 100, 1] #[training loss, training images, val images]
  save_images: True  # save the images of tensorboard logging
  use_ema_val: True

@cyprian
Author

cyprian commented Oct 26, 2023

Hi @zsyOAOA, would you be able to provide any guidance on where to look in the code to enable training with same-size inputs and outputs?

@zsyOAOA
Owner

zsyOAOA commented Oct 27, 2023

@cyprian For same-size inputs and outputs, try setting diffusion.params.sf = 1 and model.params.in_channels = 51 (51 = 3 + 48) in the config file.

Note that the VQGAN downsamples the input by 4x; we thus unshuffle the LQ input to the dimension H/4 x W/4 x 48.
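
To illustrate the channel arithmetic (my sketch, not code from the repo): torch.nn.functional.pixel_unshuffle with a factor of 4 turns the 3-channel LQ image into 48 channels at quarter resolution, which then concatenates with the 3-channel latent.

# Illustrative sketch of why in_channels = 51 when sf = 1; this is not
# code from the repo, just the shape arithmetic spelled out.
import torch
import torch.nn.functional as F

lq = torch.randn(1, 3, 64, 64)                # LQ image, B x 3 x H x W
lq_unshuffled = F.pixel_unshuffle(lq, 4)      # B x 48 x H/4 x W/4
print(lq_unshuffled.shape)                    # torch.Size([1, 48, 16, 16])

z_t = torch.randn(1, 3, 16, 16)               # 3-channel VQGAN latent at H/4 x W/4
unet_in = torch.cat([z_t, lq_unshuffled], dim=1)
print(unet_in.shape)                          # torch.Size([1, 51, 16, 16]), 51 = 3 + 48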

@cyprian
Author

cyprian commented Nov 7, 2023

@zsyOAOA the model.params.in_channels=51 still gave me tensor size mismatch on the vqgan downsampling. I went with a different approach of encoding both LQ and GT via VQGAN and that seams to work, but I am loosing some of the fidelity with this 4x downsampling. Is there a VGGAN model that does 2x downsampling? I could not find where you got the autoencoder_vq_f4.pth from.
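
Roughly, my workaround looks like the following sketch; autoencoder.encode and its return shape are assumptions on my side, since the actual VQModelTorch interface may differ.

# Hypothetical sketch of the workaround above: run both LQ and GT through the
# f4 VQGAN encoder so the UNet sees matched-size latents. The encode() call
# and its 3-channel latent output are assumptions, not the repo's exact API.
import torch

@torch.no_grad()
def to_latent(autoencoder, im):    # im: B x 3 x 64 x 64, normalized
    return autoencoder.encode(im)  # assumed: B x 3 x 16 x 16 latent

# With z_gt = to_latent(ae, gt) and z_lq = to_latent(ae, lq), the diffusion
# runs entirely on 16x16 latents, so th.cat([x, lq], dim=1) no longer hits
# a spatial mismatch.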

@zsyOAOA
Owner

zsyOAOA commented Nov 7, 2023

I don't have a VQGAN model with 2x downsampling. For the 4x model, please find it at this link.

@cyprian
Author

cyprian commented Nov 10, 2023

@zsyOAOA I would like to train my own VQGAN model to create a better image representation for my image class. Can you tell me how you trained your VQGAN?

@zsyOAOA
Owner

zsyOAOA commented Nov 10, 2023

I haven't trained a VQGAN myself. If you want to train one, please refer to this repo. @cyprian

@cyprian
Author

cyprian commented Nov 10, 2023

Thank you for the quick reply. In that repo I don't see the weights for the VQGAN f4 that you are using in your code. I am trying to find the config for that training so that I could just swap in my own Dataset. Did you download your weights from that repo? @zsyOAOA

@zsyOAOA
Owner

zsyOAOA commented Nov 10, 2023

The checkpoint I used was extracted from the latent diffusion model for image super-resolution.

@cyprian
Author

cyprian commented Nov 10, 2023

OK. So if I understand correctly, you just extracted the first_stage_model from that super-resolution checkpoint? @zsyOAOA (BTW, I really appreciate your help.)

@zsyOAOA
Owner

zsyOAOA commented Nov 10, 2023

Yes. @cyprian
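
For anyone following along, a minimal sketch of that extraction; the state_dict layout and the first_stage_model. key prefix are assumptions based on the LDM codebase, and the file names are placeholders.

# Hypothetical sketch: pull the VQGAN (first_stage_model) weights out of a
# latent-diffusion super-resolution checkpoint. The 'first_stage_model.' key
# prefix and the file names are assumptions, not verified against this repo.
import torch

ckpt = torch.load('ldm_sr_checkpoint.ckpt', map_location='cpu')
state = ckpt.get('state_dict', ckpt)  # LDM checkpoints usually nest a state_dict

prefix = 'first_stage_model.'
vq_state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}

torch.save(vq_state, 'weights/autoencoder_vq_f4.pth')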

cyprian closed this as completed Nov 10, 2023
aleksmirosh mentioned this issue Nov 16, 2023