Permalink
Switch branches/tags
JoelMarcey-patch-1 Jorghi12-PytorchROCmRemove_Deps Jorghi12-patch-2 Jorghi12-patch-6 Jorghi12-transpiler PyRocm_removing_deps SsnL-patch-1 SsnL-patch-2 SsnL-patch-3 SsnL-patch-4 SsnL-patch-5 a5d7abe add-issue-templates always_scriptmodule anderspapitto-patch-1 batch_mm_t build_fixes build_incremental_fix cache_vars caffe2-docker-image-rebuild circleci_all_commits circleci_credentials circleci_disable_all_jobs circleci_fetch_depth circleci_fix_git circleci_flaky_tests circleci_generic_kernel circleci_test circleci_timestamp circleci cleanup_linkage cpp cudastaticfixes docker/rocm-1.8.2 docker/rocm-update export-D9526248 export-D9526650 export-D9526737 export-D9539945 export-D9540025 export-D9545248 export-D9545704 export-D9557315 export-D9561478 export-D9561802 export-D9562197 export-D9562312 export-D9562467 export-D9563464 export-D9563753 export-D9564206 export-D9564516 export-D9578397 export-D9578398 export-D9578399 export-D9578734 export-D9579371 export-D9581560 export-D9583630 export-D9583699 export-D9585627 export-D9613800 export-D9613897 export-D9614321 export-D9623916 export-D9631619 export-D9634750 export-D9634904 export-D9635105 export-D9635292 export-D9644700 export-D9644899 export-D9646190 export-D9648570 export-D9648830 export-D9652088 export-D9652089 export-D9654871 export-D9656548 export-D9657449 export-D9663476 export-D9666612 export-D9670493 export-D9676205 export-D9694097 export-D9694326 export-D9694327 export-D9694918 export-D9697878 export-D9724805 export-D9727134 export-D9727532 export-D9728631 export-D9731326 export-D9756666 export-D9757935 export-D9763422 export-D9763423 export-D9763424 export-D9771708 export-D9775191 export-D9778042 export-D9778043 export-D9779821 export-D9790187 export-D9806425 export-D9810823 export-D9811028 export-D9813544 export-D9813742 export-D9814536 export-D9823457 export-D9826209 export-D9830460 export-D9831230 export-D9831384 export-D9833361 export-D9841660 export-D9847835 export-D9847859 export-D9882726 
export-D9884177 export-D9884563 export-D9889990 export-D9924348 export-D9967509 export-D9968041 export-D9968320 export-D9977058 export-D9977505 export-D9977654 export-D9979976 export-D9980641 export-D9995559 export-D9995633 export-D9996898 export-D10001033 export-D10022853 export-D10024439 export-D10024467 export-D10024485 export-D10024554 export-D10026392 export-D10030556 export-D10030819 export-D10031072 export-D10032707 export-D10033396 export-D10034589 export-D10037265 export-D10050781 export-D10050859 export-D10050905 export-D10051005 export-D10051012 export-D10051078 export-D10051079 export-D10051126 export-D10051202 export-D10051298 export-D10051365 export-D10051424 export-D10052523 export-D10069839 export-D10073519 export-D10111759 export-D10134083 export-D10134319 export-D10139933 export-D10139934 export-D10139935 export-D10150834 export-D10184116 export-D10184117 export-D10200448 export-D10204135 export-D10207890 export-D10209620 export-D10216315 export-D10222739 export-D10227820 export-D10229684 export-D10232118 export-D10232147 export-D10232154 export-D10249293 export-D10251907 export-D10255651 export-D10359443 export-D10371541 export-D10379903 export-D10380678 export-D10392295 export-D10400927 export-D10404407 export-D10415069 export-D10415430 export-D10416051 export-D10419671 export-D10421896 export-D10450290 export-D10454455 export-D10457671 export-D10467239 export-D10467556 export-D10469310 export-D10469960 export-D10476220 export-D10476225 export-D10476226 export-D10476232 export-D10476235 export-D10488399 export-D10492071 export-D10492507 export-D10496244 export-D10513246 export-D10518499 export-D10518929 export-D10520295 export-D10520421 export-D10528061 export-D10853224 export-D10855883 export-D10858024 export-D11669870 export-D12143282 export-D12832080 export-D12848855 export-D12849620 export-D12850690 export-D12850691 export-D12850833 export-D12873145 export-D12874357 export-D12894385 export-D12894386 export-D12912235 export-D12912237 
export-D12912238 export-D12912239 export-D12912240 export-D12912241 export-D12912242 export-D12934074 export-D12936031 export-D12937090 export-D12937091 export-D12964886 export-D12985774 export-D13009482 export-D13011878 export-D13015236 export-D13015239 export-D13024368 export-D13025313 export-D13036478 export-D13046201 export-D13046500 export-D13046722 export-D13047468 export-D13053648 export-D13056152 export-D13062526 export-D13062564 export-D13062604 export-D13062631 export-D13062649 export-D13062706 export-D13066808 export-D13081602 export-D13081603 export-D13081604 export-D13081605 export-D13081606 export-D13081607 export-D13081608 export-D13081609 export-D13081610 export-D13104693 export-D13104694 export-D13105166 export-D13111509 export-D13111712 export-D13111781 export-D13112081 export-D13112298 export-D13113129 export-D13119624 export-D13121531 export-D13128077 export-D13128977 export-D13131338 export-D13141949 export-D13145293 export-D13156470 export-D13156471 export-D13156472 export-D13158474 export-D13158475 export-D13205022 export-D13218540 export-D13221302 export-D13223125 export-D13223126 export-D13223904 export-D13224015 export-D13235001 export-D13241355 export-D13241401 export-D13257847 export-D13258252 export-D13258512 export-D13258513 export-D13266063 export-D13267832 export-D13271560 export-D13272227 export-D13277246 export-D13277567 export-D13283492 export-D13283493 export-D13283494 export-D13283495 export-D13283496 export-D13283497 export-D13285370 export-D13287688 export-D13288655 export-D13304398 export-D13316078 export-D13318594 export-D13318596 export-D13318644 export-D13318645 export-D13336841 export-D13336842 export-D13336843 export-D13336856 export-D13348039 export-D13348040 export-D13348041 export-D13348042 export-D13348044 export-D13349163 export-D13349164 export-D13365817 export-D13401512 export-D13425628 export-D13427897 export-D13432800 export-D13432922 export-D13441008 export-D13445571 export-D13451550 export-D13461640 
export-D13462124 export-D13463673 export-D13473791 export-D13495323 export-D13498492 export-D13498493 export-D13500679 ext_test_fix ezyang-patch-1 ezyang-patch-2 ezyang-patch-3 ezyang-patch-4 ezyang/retry-type-id-core ezyang/rocm-docker-update fast_dp fb-config fbsync fix_inplace_double_backward gh/ezyang/1/base gh/ezyang/1/head gh/ezyang/1/orig gh/ezyang/2/base gh/ezyang/2/head gh/ezyang/2/orig gh/ezyang/3/base gh/ezyang/3/head gh/ezyang/3/orig gloo_dedup halfconv jerryzh168-patch-1-1 jerryzh168-patch-1 jit_frontend known-good legacy_remove magmatestfix master merge_variable_tensor mkl_set_dynamic nccl_fix new_symbolic_diff nn_c_port oanderso/test random_device readme_fix remove_time scalar_type simple_engine soumith-patch-1 ssnl-9348 suo/annotations suo/clang-format suo/dce2 suo/fix-expect suo/graph-equals suo/ir-parser suo/mm suo/parser suo/schematize suo/slicer tensor-merge tensorimpl_autogradmeta tensorimpl_3_variable_functions tensorimpl_4_AutogradMetaInterface test10 testing/full-caffe2 tmp_enable_scalars v0.4.1 v1.0.0 weak_tensor weak_tracing win_py2.7 windows_jit_error windows_jit_error_12378
Nothing to show
Find file Copy path
122 lines (90 sloc) 3.46 KB
import bisect
import warnings
from torch._utils import _accumulate
from torch import randperm
class Dataset(object):
    """Abstract base class for map-style datasets.

    Every concrete dataset should subclass this and override ``__len__``
    (returning the number of samples) and ``__getitem__`` (accepting integer
    indices from 0 up to, but not including, ``len(self)``).

    Adding two datasets with ``+`` produces a ``ConcatDataset`` of both.
    """

    def __len__(self):
        # Subclasses must report how many samples they hold.
        raise NotImplementedError

    def __getitem__(self, index):
        # Subclasses must implement lookup of a single sample.
        raise NotImplementedError

    def __add__(self, other):
        parts = [self, other]
        return ConcatDataset(parts)
class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.

    Raises:
        AssertionError: if the tensors do not all share the same size along
            the first dimension.
    """

    def __init__(self, *tensors):
        if tensors:
            first_dim = tensors[0].size(0)
            if any(tensor.size(0) != first_dim for tensor in tensors):
                # Raised explicitly rather than via ``assert`` so the check is
                # not stripped under ``python -O``; the AssertionError type is
                # kept so existing callers see the same exception as before.
                raise AssertionError("Size mismatch between tensors")
        self.tensors = tensors

    def __getitem__(self, index):
        # One sample is the tuple of every tensor indexed at ``index``.
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        # All tensors share the first-dimension size, so the first one suffices.
        return self.tensors[0].size(0)
class ConcatDataset(Dataset):
    """
    Dataset to concatenate multiple datasets.
    Purpose: useful to assemble different existing datasets, possibly
    large-scale datasets as the concatenation operation is done in an
    on-the-fly manner.

    Arguments:
        datasets (iterable): datasets to be concatenated; must not be empty.

    Raises:
        AssertionError: if ``datasets`` is empty.
    """

    @staticmethod
    def cumsum(sequence):
        """Return running totals of ``len(e)`` for each ``e`` in ``sequence``."""
        r, s = [], 0
        for e in sequence:
            length = len(e)
            r.append(length + s)
            s += length
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        # Materialize before checking the length so that generators and other
        # one-shot iterables are accepted, as the error message promises.
        self.datasets = list(datasets)
        if len(self.datasets) == 0:
            # Explicit raise so the check survives ``python -O``; the
            # AssertionError type matches the previous ``assert``.
            raise AssertionError('datasets should not be an empty iterable')
        # cumulative_sizes[i] == total number of samples in datasets[0..i].
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            # Resolve negative indices against the whole concatenation; without
            # this, idx=-1 would be routed to the first dataset's last sample.
            if -idx > len(self):
                raise IndexError(
                    "absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # First dataset whose cumulative size exceeds idx holds the sample.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        # Deprecated misspelled alias kept for backward compatibility.
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
class Subset(Dataset):
    """
    View of a dataset restricted to a chosen list of indices.

    Arguments:
        dataset (Dataset): the full dataset being viewed
        indices (sequence): positions in ``dataset`` that make up this subset
    """

    def __init__(self, dataset, indices):
        self.dataset, self.indices = dataset, indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Translate the subset-local position into a position in the parent.
        parent_idx = self.indices[idx]
        return self.dataset[parent_idx]
def random_split(dataset, lengths):
    """
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced

    Returns:
        list of Subset: one subset per entry of ``lengths``, in the same order.

    Raises:
        ValueError: if ``sum(lengths)`` differs from ``len(dataset)``.
    """
    # Materialize once: a one-shot iterable would otherwise be exhausted by
    # ``sum`` and yield nothing when zipped below.
    lengths = list(lengths)
    total = sum(lengths)
    if total != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    # Convert to plain Python ints so the wrapped dataset is indexed with ints
    # rather than 0-dim tensors, which some dataset implementations may reject.
    indices = randperm(total).tolist()
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(_accumulate(lengths), lengths)]