Fastest way to load TensorDict data? #374

Closed
fedebotu opened this issue May 7, 2023 · 7 comments

Comments


fedebotu commented May 7, 2023

Hi there! First of all, thanks for your work, I really enjoy TensorDicts!

I have a question about performance. What is the best way to load TensorDicts?
I ran some tests on possible combinations of dataloaders / collate_fn:


Simple benchmarking code

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from tensordict import TensorDict

a = torch.rand(1000, 50, 2)
td = TensorDict({"a": a}, batch_size=1000)

Case 1: store data as tensors, create TensorDicts on the fly

class SimpleDataset(Dataset):
    def __init__(self, data):
        # We split into a list of per-sample tensors since per-item loading is faster (fair comparison vs the other cases)
        self.data = [d for d in data]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
dataset = SimpleDataset(td['a'])
dataloader = DataLoader(dataset, batch_size=32, collate_fn=torch.stack)
x = TensorDict({'a': next(iter(dataloader))}, batch_size=32)
%timeit for x in dataloader: TensorDict({'a': x}, batch_size=x.shape[0])

520 µs ± 833 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Case 2: store data as TensorDicts and directly load them

class TensorDictDataset(Dataset):
    def __init__(self, data):
        self.data = [d for d in data]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
data  = TensorDictDataset(td)
# use collate_fn=torch.stack to avoid StopIteration error
dataloader = DataLoader(data, batch_size=32, collate_fn=torch.stack)
x = next(iter(dataloader))
%timeit for x in dataloader: pass

1.72 ms ± 5.57 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Case 3: store TensorDict data as dictionaries and create TensorDicts on the fly with a collate_fn

class CustomTensorDictDataset(Dataset):
    def __init__(self, data):
        self.data = [
            {key: value[i] for key, value in data.items()}
            for i in range(data.shape[0])
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class CustomTensorDictCollate(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, batch):
        return TensorDict(
            {key: torch.stack([b[key] for b in batch]) for key in batch[0].keys()},
            batch_size=len(batch),
        )
    
data = CustomTensorDictDataset(td)
dataloader = DataLoader(data, batch_size=32, collate_fn=CustomTensorDictCollate())
x = next(iter(dataloader))
%timeit for x in dataloader: pass

567 µs ± 924 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Apparently, splitting the data into dictionaries and creating TensorDicts on the fly is the fastest way to load data... but why isn't it faster to just index TensorDicts directly? And is there a better way?


vmoens commented May 7, 2023

Hey, thanks for the interest and for reporting.
The "easiest way" to build a dataloader should be this, IMO:

dataloader = DataLoader(td, batch_size=32, collate_fn=lambda x: x)

i.e. directly indexing your TD.
You should get an execution time slightly higher than with your custom dataloader (the last example).
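
Here's a slightly fuller sketch of that pattern, reusing the toy td from the benchmark above (the identity collate just returns the indexed sub-TensorDict unchanged):

import torch
from torch.utils.data.dataloader import DataLoader
from tensordict import TensorDict

# Same toy data as in the benchmark above
td = TensorDict({"a": torch.rand(1000, 50, 2)}, batch_size=1000)

# The TensorDict itself acts as the dataset; the identity collate keeps the
# indexed sub-TensorDict as-is instead of trying to collate it
dataloader = DataLoader(td, batch_size=32, collate_fn=lambda x: x)

for batch in dataloader:
    print(batch.batch_size)  # torch.Size([32]); the last batch may be smaller
    break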

Here are the results I get on Colab:

  1. Building TDs on the fly:
     4.79 ms ± 1 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
  2. Stacking tensordicts (expensive as we need to look at the tensordict attributes, + we return lazy-stacks):
     13.3 ms ± 4.49 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
  3. Stacking data in the collate:
     2.56 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
  4. Plain indexing of the tensordict:
     3.16 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Since 4 is the preferred option, I dug in a little, and it seems we're spending a considerable amount of time determining the size of the tensordict with _getitem_batch_size. That could probably be optimized!
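
If you want to reproduce this on your end, a quick profiling sketch like the one below should do (the exact breakdown you'll see depends on the installed tensordict version):

import cProfile
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.rand(1000, 50, 2)}, batch_size=1000)
idx = list(range(32))

# Profile repeated batched indexing to see where the time goes
cProfile.runctx("for _ in range(1000): td[idx]", globals(), locals(), sort="cumtime")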

Another thing we're spending time on is converting the list index provided by the dataloader into a tensor. That shouldn't be needed, since tensors can be indexed with lists directly...
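
For instance, plain tensors already accept a Python list of integers as an index, so no conversion is required:

import torch

a = torch.rand(1000, 50, 2)
idx = [3, 1, 4, 1, 5]

# Advanced indexing takes the list directly, no intermediate index tensor needed
print(a[idx].shape)  # torch.Size([5, 50, 2])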

Here's a PR to check the speed of these indexing ops:
#375

Here's another one to improve the sampling speed:
#376

I'll post updates on these 2 here.


vmoens commented May 7, 2023

I merged both of them.
Indexing a tensordict with a dataloader (which uses lists of integers) should be considerably faster now.

LMK if you observe a speedup on your end too! If not we can come up with different solutions :)


fedebotu commented May 8, 2023

Thanks for the commits! I installed the new version locally and can confirm that the "easiest way" to load TensorDicts you posted is now very close to the custom dataloader:

dataloader = DataLoader(td, batch_size=32, collate_fn=lambda x: x)

668 µs ± 1.02 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

The others are the same.

Interestingly, if I try to index TensorDicts in this way:

class TensorDictDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
data = TensorDictDataset(td)
dataloader = DataLoader(data, batch_size=32, collate_fn=torch.stack)
%timeit for x in dataloader: pass

6.67 ms ± 12.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

It is way slower than all options above.

Why is the custom dataloader of case 3 still slightly faster than directly indexing the td, even though the custom collate has to re-create the TensorDict on the fly?


vmoens commented May 8, 2023

The native indexing is much faster than your example because we implement a __getitems__ (notice the s). The dataloader will look for __getitems__:

  • if it finds it, it passes the whole list of indices to __getitems__ in a single call and considers that the class takes care of batching there (so an identity collate_fn suffices);
  • otherwise it calls __getitem__ (without the s) for each index independently and then calls the collate_fn.

It's far from perfect that the dataloader does that, IMO, but it's a compromise on the PyTorch side between keeping loaders easy to build and supporting multi-indexing.
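
Here's a minimal sketch of the two code paths with a plain-tensor dataset (hypothetical class name, just for illustration):

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

class BatchedDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Fallback path: called once per index; the results then go through collate_fn
        return self.data[idx]

    def __getitems__(self, indices):
        # Fast path: called once per batch with the full list of indices,
        # so the result is already batched
        return self.data[indices]

data = BatchedDataset(torch.rand(1000, 50, 2))
dataloader = DataLoader(data, batch_size=32, collate_fn=lambda x: x)
print(next(iter(dataloader)).shape)  # torch.Size([32, 50, 2])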


vmoens commented May 8, 2023

By the way, in TensorDict __getitems__ is the same as __getitem__ :)
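
A quick check (assuming a tensordict version where __getitems__ is defined):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.rand(10, 2)}, batch_size=10)
print(td[[0, 3, 7]].batch_size)               # torch.Size([3])
print(td.__getitems__([0, 3, 7]).batch_size)  # torch.Size([3]), same result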


vmoens commented May 8, 2023

Regarding your question, I'm not 100% sure, but my guess is that we spend time figuring out what the tensordict batch size is, whereas your custom collate doesn't have to infer it (you pass batch_size explicitly).
When you index a td, we check your index against multiple data types until we find what it is, then infer the size of the indexed tensordict.

I'm confident that we could rewrite our indexing and batch-size computation using custom C++ methods, but I never really took the time to do it.

We tried using "meta-tensors" (tensors on a "meta", fake device) instead, which we could index at zero cost, but it is much, much slower...
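
A rough way to see that overhead is to compare indexing the TensorDict against indexing the underlying tensor directly (a minimal sketch; numbers will vary by machine and version):

import timeit
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.rand(1000, 50, 2)}, batch_size=1000)
idx = list(range(32))

t_td = timeit.timeit(lambda: td[idx], number=10_000)
t_tensor = timeit.timeit(lambda: td["a"][idx], number=10_000)
# The gap between the two is roughly the index-checking / batch-size-inference overhead
print(f"TensorDict indexing: {t_td:.3f}s, raw tensor indexing: {t_tensor:.3f}s")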


fedebotu commented May 8, 2023

I see, very interesting! :D
Also, I didn't know about __getitems__ as a method, and now it works way faster:

class TensorDictDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitems__(self, idx):
        return self.data[idx]
    
data = TensorDictDataset(td)
dataloader = DataLoader(data, batch_size=32, collate_fn=lambda x: x)
%timeit for x in dataloader: pass

678 µs ± 1.91 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Actually, it's as fast as the easy implementation you suggested (almost exactly the same timings)!
