
Vb/imagesource #221

Merged
merged 68 commits into from Jun 6, 2023

Conversation

vineetbansal
Collaborator

@vineetbansal vineetbansal commented Feb 10, 2023

I'm starting this new PR (not quite there yet, but close) that demonstrates a new cryodrgn.source.ImageSource class that should simplify calling code quite a bit.

tests/test_source.py is a good place to see how to use this. The basic idea is:

src = ImageSource.from_file(<mrc/star/cs/txt>, lazy=True)
im = src.images(<slice>)  # To get torch.Tensor, or
im = src[<slice>]         # To get torch.Tensor

# Exactly the same usage as above when lazy=False
src = ImageSource.from_file(<mrc/star/cs/txt>, lazy=False)
im = src.images(<slice>)  # To get torch.Tensor, or
im = src[<slice>]         # To get torch.Tensor
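To make the indexing contract above concrete, here is a hedged pure-Python sketch (ToyImageSource is a hypothetical stand-in, not the real cryodrgn.source.ImageSource): `src[indices]` is just sugar for `src.images(indices)`.

```python
class ToyImageSource:
    """Hypothetical stand-in for cryodrgn.source.ImageSource, showing only
    the indexing contract: src[indices] and src.images(indices) agree."""

    def __init__(self, data):
        self._data = data  # pretend these are images

    def images(self, indices):
        # The real ImageSource would return a torch.Tensor of images here.
        return [self._data[i] for i in indices]

    def __getitem__(self, indices):
        return self.images(indices)

src = ToyImageSource(["img0", "img1", "img2"])
assert src[[0, 2]] == src.images([0, 2]) == ["img0", "img2"]
```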

For utilizing this during training:

    data = dataset.ImageDataset(
        mrcfile=args.particles,
        tilt_mrcfile=args.tilt,
        lazy=args.lazy,
        ...
    )
    ...
    data_generator = DataLoader(
        data,
        num_workers=num_workers_per_gpu,
        sampler=BatchSampler(
            RandomSampler(data), batch_size=args.batch_size, drop_last=False
        ),
        batch_size=None,
    )
    ...
    for epoch in range(...):
        for minibatch in data_generator:
            ...

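The DataLoader settings above follow a standard PyTorch pattern: `BatchSampler(RandomSampler(...))` yields whole lists of indices, and `batch_size=None` disables the DataLoader's own batching so each list reaches the dataset as a single `__getitem__` call. A minimal pure-Python sketch of what the sampler produces (not the torch implementation):

```python
import random

def batched_random_indices(n, batch_size, drop_last=False):
    # Pure-Python sketch of BatchSampler(RandomSampler(range(n)), ...):
    # shuffle all indices once, then yield them in batch-sized lists.
    order = list(range(n))
    random.shuffle(order)
    for start in range(0, n, batch_size):
        batch = order[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            return
        yield batch

batches = list(batched_random_indices(10, batch_size=4))
assert [len(b) for b in batches] == [4, 4, 2]  # drop_last=False keeps the short tail
assert sorted(i for b in batches for i in b) == list(range(10))  # each index exactly once
```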
Quite a few refactorings have already been done in cryodrgn to use this new functionality, but I'm certain there are more. Until test coverage gets close to 100% and I'm sure I haven't overlooked any scripts/utilities, I will keep this PR in draft mode.

We're now also using torch.Tensor for all image data throughout the code (except in a few places where we save to .pkl or .mrc files), and torch.fft for all FFT functions.

Many of the changes in this PR exist because switching to torch.Tensor/torch.fft by default would not have worked otherwise: the codebase assumed np.ndarrays in lots of places.

There is hardly any documentation, and I need to address some newly introduced pyright complaints too.

Performance on lazy/eager datasets is comparable to the old implementation (actually just slightly faster). This is unsurprising because while the new API makes it easier for us to implement chunking, we haven't (yet) changed any logic to do so. I'll add some performance graphs to this PR as we proceed through the chunking logic.

Todo before merge:

  • Add DataShuffler support to all scripts that use a DataLoader
  • Add unit test for data shuffler (test that it returns all the elements; test that if you iterate twice it gives different orders)
  • Refactor TxtFileSource to inherit from _MRCDataFrameSource
  • Decide on whether to merge require_adjacent argument for DataShuffler's use #287 [YES]
  • Decide on whether to merge not supporting getitem on ImageSource anymore #288 [NO]
  • Set --num-workers 1 default everywhere
  • Clean up --max-threads argument in train_vae - decide whether to always use 1 thread, or keep the argument and change comment from FFT threads to data-loading threads.
  • Merge master
  • Test it on real data [blocked on cluster down]
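The shuffler unit test in the to-do list above could look roughly like this (ToyShuffler is a hypothetical stand-in for DataShuffler, sketched only to show the two properties being tested: every element is returned, and two iterations give different orders):

```python
import random

class ToyShuffler:
    # Hypothetical stand-in for DataShuffler: every iteration yields all
    # indices exactly once, in a fresh random order.
    def __init__(self, n, seed=0):
        self.n = n
        self.rng = random.Random(seed)

    def __iter__(self):
        order = list(range(self.n))
        self.rng.shuffle(order)
        return iter(order)

shuffler = ToyShuffler(1000)
first, second = list(shuffler), list(shuffler)
assert sorted(first) == list(range(1000))  # all elements are returned
assert first != second  # iterating twice gives a different order (overwhelmingly likely)
```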

@vineetbansal vineetbansal marked this pull request as draft February 10, 2023 19:43
@adamlerer
Collaborator

Hey @vineetbansal , this looks wild! Please let me know when I should start reviewing this PR, I expect it will take some time to review :)

@vineetbansal vineetbansal marked this pull request as ready for review February 20, 2023 16:23
@vineetbansal
Collaborator Author

Hi @adamlerer. I think the introduction of an ImageSource and its incorporation into the rest of cryodrgn passes basic sanity checks so that you and @zhonge can start looking at this. Yes, I suspect there will be some back-and-forth so I don't want to hide this PR anymore!

Collaborator

@adamlerer adamlerer left a comment


This is super cool! So many useful improvements being worked on here. I'm going to split the review into smaller chunks since there's so much to get through.

I'm seeing a lot of complexity and dangerous code in here that looks like the result of a premature optimization around doing the FFT in-place on the CPU. I think you just want to do the FFT on the GPU on the current batch, without worrying about in-place operations. If there's another way to accomplish this goal, I'd be interested in discussing it.
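For reference, the batched, out-of-place version of a centered 2-D FFT is short; this NumPy sketch (cryodrgn itself uses torch.fft, and the helper name here is assumed) transforms a whole batch along the last two axes without mutating the input:

```python
import numpy as np

def fft2_center(imgs):
    # Centered 2-D FFT over the last two axes of a (batch, D, D) stack.
    # Out-of-place: the input array is left untouched.
    shifted = np.fft.ifftshift(imgs, axes=(-2, -1))
    return np.fft.fftshift(np.fft.fft2(shifted, axes=(-2, -1)), axes=(-2, -1))

batch = np.ones((2, 4, 4))
spec = fft2_center(batch)
# A constant image concentrates all spectral energy in the center bin.
assert abs(spec[0, 2, 2] - 16) < 1e-9
assert abs(spec[0, 0, 0]) < 1e-9
assert (batch == 1).all()  # input was not modified in place
```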

@adamlerer
Collaborator

Btw, the use of max_threads in this PR is a bit concerning. If you look at what --max-threads previously meant (e.g. the help string in train_vae), it was the number of threads used for the FFT; now it's the number of threads used for data loading. --max-threads and --num-workers don't interact well, because --num-workers splits the creation of each batch across N processes, and then --max-threads tries to split each of those into multiple threads, each of which wants as large a sub-batch (from a single file) as possible.

I think we should default --num-workers to 1 everywhere.

@vineetbansal
Collaborator Author

Fair point. num_workers is now set to 1 (and decoupled from --max-threads). There's value in setting num_workers > 1, but nowhere near the gains from the upcoming DataShuffler.

@zhonge zhonge merged commit dcb30eb into master Jun 6, 2023
4 checks passed
@michal-g michal-g deleted the vb/imagesource branch February 22, 2024 19:02