Fastest way to load TensorDict data? #374
Comments
Hey, thanks for the interest and for reporting. The easiest way is:

```python
dataloader = DataLoader(td, batch_size=32, collate_fn=lambda x: x)
```

i.e., directly indexing your TD. Here are the results I get on Colab:
With 4 being the preferred option, I dug a little bit and it seems we're spending a considerable amount of time determining the size of the tensordict. Another thing we're spending time on is converting the list index provided by the dataloader into a tensor. That should not be needed, since tensors can be indexed with the list directly... I opened a PR to check the speed of these indexing ops, and another one to improve the sampling speed. I'll post updates on these 2 here.
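To illustrate that last point, here is a minimal sketch (the array name and sizes are made up for illustration) of indexing a tensor with the list of indices the sampler yields, without converting it to a tensor first:

```python
import torch

data = torch.randn(1000, 8)
idx = [3, 1, 4, 1, 5]   # a plain Python list, as yielded by the DataLoader's sampler
batch = data[idx]       # advanced indexing accepts the list as-is; no conversion needed
print(batch.shape)      # torch.Size([5, 8])
```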
I merged both of them. LMK if you observe a speedup on your end too! If not, we can come up with different solutions :)
Thanks for the commits! I installed the new version locally and can confirm that the easiest way to load TensorDicts you posted is very close to the custom dataloader:

```python
dataloader = DataLoader(td, batch_size=32, collate_fn=lambda x: x)
```

The others are the same. Interestingly, if I index TensorDicts this way:

```python
class TensorDictDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

data = TensorDictDataset(td)
dataloader = DataLoader(data, batch_size=32, collate_fn=torch.stack)
```

it is way slower than all the options above. Why is the custom dataloader of case 3 still slightly faster than directly indexing td, even though the custom collate has to re-create the TensorDict on the fly?
The native indexing is much faster than your example because we implement a `__getitems__` (note the extra "s"), which lets the dataloader pass the whole list of batch indices at once instead of fetching items one by one and collating them. It's far from perfect that the dataloader does that IMO, but it's a compromise on the PyTorch side between building loaders in an easy way and supporting multi-indexing.
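Conceptually, the difference between the two fetch paths looks roughly like this (a simplified sketch for illustration, not the actual PyTorch source):

```python
def fetch_one_by_one(dataset, indices, collate_fn):
    # Default map-style path: one __getitem__ call per index, then collate the
    # resulting list of samples into a batch.
    return collate_fn([dataset[i] for i in indices])

def fetch_batched(dataset, indices, collate_fn):
    # Batched path: the dataset receives the whole list of indices at once,
    # so a TensorDict can slice all of its tensors in a single indexing op.
    return collate_fn(dataset.__getitems__(indices))
```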
By the way, in TensorDict
Regarding your question, I'm not 100% sure, but my guess is that we spend time figuring out what the tensordict batch size is, while your dataloader does it on its own (you explicitly tell it). I'm confident that we could rewrite our indexing and batch-size computation using custom C++ methods, but I never really took the time to do it. We tried using "meta-tensors" (tensors on a "meta", fake device) instead, which we could index at zero cost, but it is much, much slower...
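For context, a tiny illustration of what a meta tensor is (sizes are arbitrary):

```python
import torch

# A tensor on the "meta" device carries only shape/dtype metadata, no storage,
# so operations on it only propagate shapes and never touch real data.
t = torch.empty(1_000_000, 128, device="meta")
print(t.device)      # meta
print(t[:32].shape)  # torch.Size([32, 128]) -- no memory is allocated or copied
```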
I see, very interesting! :D If I instead define the dataset with `__getitems__`:

```python
class TensorDictDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitems__(self, idx):  # note the plural: receives the whole batch of indices
        return self.data[idx]

data = TensorDictDataset(td)
dataloader = DataLoader(data, batch_size=32, collate_fn=lambda x: x)
```

it's actually as fast as the easy implementation you suggested (almost exactly the same timings)!
Hi there! First of all, thanks for your work, I really enjoy tensordicts!

I have a question about performance: what is the best way to load TensorDicts? I ran some tests on possible combinations of dataloaders / collate_fn:
Simple benchmarking code (not reproduced here; a rough sketch of the three cases follows the list below):
Case 1: store data as tensors, create TensorDicts on the fly
Case 2: store data as TensorDicts and directly load them
Case 3: store TensorDict data as dictionaries and create TensorDicts on the fly with collate_fn
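Since the collapsed benchmark code isn't reproduced above, here is a rough sketch of what the three cases could look like; the shapes, keys, and batch size are assumptions, not the original benchmark:

```python
import torch
from tensordict import TensorDict
from torch.utils.data import DataLoader

N = 10_000
obs, act = torch.randn(N, 16), torch.randn(N, 4)  # hypothetical data
td = TensorDict({"obs": obs, "act": act}, batch_size=[N])

# Case 1: store plain tensors, build the TensorDict per batch in collate_fn
loader1 = DataLoader(
    list(zip(obs, act)),
    batch_size=32,
    collate_fn=lambda batch: TensorDict(
        {"obs": torch.stack([o for o, _ in batch]),
         "act": torch.stack([a for _, a in batch])},
        batch_size=[len(batch)],
    ),
)

# Case 2: store a TensorDict and load it directly (identity collate)
loader2 = DataLoader(td, batch_size=32, collate_fn=lambda x: x)

# Case 3: store per-sample dicts, build the TensorDict per batch in collate_fn
dicts = [{"obs": o, "act": a} for o, a in zip(obs, act)]
loader3 = DataLoader(
    dicts,
    batch_size=32,
    collate_fn=lambda batch: TensorDict(
        {k: torch.stack([d[k] for d in batch]) for k in batch[0]},
        batch_size=[len(batch)],
    ),
)

for loader in (loader1, loader2, loader3):
    for batch in loader:  # time these loops to compare the three cases
        pass
```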
Apparently, splitting data into dictionaries and creating TensorDicts on the fly is the fastest way to load data... but why isn't it faster to just index TensorDicts instead? And is there a better way?