How to load data in multiple processes? #35

Closed
gfjiangly opened this issue May 18, 2020 · 12 comments

@gfjiangly

How to load data in multiple processes?

@JerryLead

JerryLead commented May 22, 2020

@gfjiangly I have the same problem. I think the current version cannot work with multiple GPUs: an IterableDataset (here, TFRecordDataset) cannot be used with a DistributedSampler in PyTorch. Maybe we can create a special DistributedSampler for IterableDataset later. However, within a single process you can still load data with multiple data loading workers by setting num_workers > 1.
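
For illustration, a minimal sketch of that single-process, multi-worker setup. The paths and the feature description are placeholders; an index file is assumed to exist so the workers can shard the records instead of duplicating them.

import torch
from tfrecord.torch.dataset import TFRecordDataset

dataset = TFRecordDataset(
    "/tmp/data.tfrecord",               # data path (placeholder)
    "/tmp/data.tfrecord.index",         # index path (placeholder)
    {"image": "byte", "label": "int"},  # feature description (placeholder)
)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4)

for batch in loader:
    pass  # training step goes here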

@DelightRun

@gfjiangly @JerryLead The way I solve this problem is to shuffle the TFRecord files and distribute them evenly across the GPUs before each epoch (see the sketch after this list). The open question is how to handle these two situations:

  1. The number of TFRecord files is not divisible by the number of GPUs, e.g. 16 TFRecord files with 10 GPUs.
  2. The number of samples is not divisible by the number of GPUs, e.g. 1000 samples with 16 GPUs. In other words, how to implement drop_last=True.
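
A sketch of the per-epoch file sharding described above, under these assumptions: every rank starts from the same file list, and rank and world_size come from the usual distributed setup (e.g. torch.distributed). The two corner cases above are not solved here; the slicing simply drops the remainder.

import random

def shard_tfrecord_files(files, rank, world_size, epoch, seed=0):
    files = sorted(files)
    rng = random.Random(seed + epoch)    # identical shuffle on every rank
    rng.shuffle(files)
    per_rank = len(files) // world_size  # remainder files are dropped
    start = rank * per_rank
    return files[start:start + per_rank]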

@linkun-1998

Is the IterableDataset the reason why using tfrecord.torch.dataset.MultiTFRecordDataset with torch_xla.distributed.parallel_loader.ParallelLoader gives the following type error?

Exception in device=TPU:0: object of type 'MultiTFRecordDataset' has no len()
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 231, in _start_fn
    fn(gindex, *args)
  File "<ipython-input-25-bd5e4111d32a>", line 182, in _mp_fn
    accuracy, data, pred, target = train_mnist()
  File "<ipython-input-25-bd5e4111d32a>", line 165, in train_mnist
    para_loader = pl.ParallelLoader(train_loader, [device])
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/parallel_loader.py", line 80, in __init__
    self._per_device_samples = len(loader) // len(devices)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 315, in __len__
    length = self._IterableDataset_len_called = len(self.dataset)
TypeError: object of type 'MultiTFRecordDataset' has no len()
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-25-bd5e4111d32a> in <module>()
    184     # Retrieve tensors that are on TPU core 0 and plot
    185     plot_results(data.cpu(), pred.cpu(), target.cpu())
--> 186 xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'], start_method='fork')

2 frames
/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py in join(self, timeout)
    111                 raise Exception(
    112                     "process %d terminated with exit code %d" %
--> 113                     (error_index, exitcode)
    114                 )
    115 

Exception: process 0 terminated with exit code 17

What can be done to resolve this error?

@DelightRun

@linkun-1998 Yes, IterableDataset does not have a __len__ method by default, so len(dataset) is unavailable for it. You must add a __len__ method yourself.
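
A minimal sketch of what "add a __len__ method yourself" can look like; the subclass name and the num_samples argument are illustrative, not part of the library.

class SizedMultiTFRecordDataset(MultiTFRecordDataset):
    """MultiTFRecordDataset plus a sample count supplied by the caller."""

    def __init__(self, *args, num_samples: int, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_samples = num_samples  # e.g. counted from the .idx files

    def __len__(self):
        return self.num_samples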

@linkun-1998

@DelightRun Awesome, but how can you define a __len__ function when you are streaming from multiple TFRecords using tfrecord.torch.dataset.MultiTFRecordDataset?

@vahidk
Owner

vahidk commented Aug 27, 2020

It's pretty straightforward when the index is available. You can just implement a new "RandomAccessMultiTFRecordDataset" class that inherits from torch.utils.data.Dataset and change the logic. PRs are welcome.
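
To illustrate the suggestion, a sketch of an index-backed, map-style dataset for a single TFRecord file (extending it to multiple files would mean concatenating the per-file offsets). It assumes the first column of the .idx file holds each record's byte offset, and it returns the raw serialized record; decoding the tf.train.Example is left out and would follow the library's reader.

import struct
import typing

import numpy as np
import torch.utils.data


class RandomAccessTFRecordDataset(torch.utils.data.Dataset):
    """Map-style dataset: random access into one TFRecord file via its index."""

    def __init__(self, data_path: str, index_path: str,
                 transform: typing.Callable = None) -> None:
        super().__init__()
        self.data_path = data_path
        # Column 0 of the index file is assumed to hold each record's byte offset.
        self.offsets = np.loadtxt(index_path, dtype=np.int64).reshape(-1, 2)[:, 0]
        self.transform = transform

    def __len__(self) -> int:
        return len(self.offsets)

    def __getitem__(self, idx: int):
        with open(self.data_path, "rb") as f:
            f.seek(int(self.offsets[idx]))
            # TFRecord frame: uint64 length, uint32 length CRC, payload, uint32 payload CRC.
            length, = struct.unpack("<Q", f.read(8))
            f.read(4)                 # skip the length CRC
            payload = f.read(length)  # serialized tf.train.Example
        return self.transform(payload) if self.transform else payload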

@linkun-1998

I tried to implement a newMultiTFRecordDataset that inherits from torch.utils.data.Dataset as follows:

class newMultiTFRecordDataset(torch.utils.data.Dataset):
      def __init__(self,
                  data_pattern: str,
                  index_pattern: typing.Union[str, None],
                  splits: typing.Dict[str, float],
                  description: typing.Union[typing.List[str], typing.Dict[str, str], None] = None,
                  shuffle_queue_size: typing.Optional[int] = None,
                  transform: typing.Callable[[dict], typing.Any] = None) -> None:
        super(newMultiTFRecordDataset, self).__init__()
        self.data_pattern = data_pattern
        self.index_pattern = index_pattern
        self.splits = splits
        self.description = description
        self.shuffle_queue_size = shuffle_queue_size
        self.transform = transform

      def __len__(self):
        index_len = 0
        for split in self.splits.keys():
          index_len += len(np.loadtxt(self.index_pattern.replace('{}', str(split)), dtype=np.float32)[:, 0])
        return index_len

      def __getitem__(self, index):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            np.random.seed(worker_info.seed % np.iinfo(np.uint32).max)
        it = tfrecord.reader.multi_tfrecord_loader(
            self.data_pattern, self.index_pattern, self.splits, self.description)
        if self.shuffle_queue_size:
            it = iterator_utils.shuffle_iterator(it, self.shuffle_queue_size)
        if self.transform:
            it = map(self.transform, it)
        data = next(it)
    
        return data

Calling next(iter(dataloader)) on this dataset works fine.
But when I use it to train a model on a TPU device in Colab, I get the error mentioned below:

Exception                                 Traceback (most recent call last)
<ipython-input-10-51f2165c039a> in <module>()
    175     # Retrieve tensors that are on TPU core 0 and plot
    176     plot_results(data.cpu(), pred.cpu(), target.cpu())
--> 177 xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'], start_method='fork')

2 frames
/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py in join(self, timeout)
    106                 raise Exception(
    107                     "process %d terminated with signal %s" %
--> 108                     (error_index, name)
    109                 )
    110             else:

Exception: process 0 terminated with signal SIGABRT

The training implementation is as follows:

SERIAL_EXEC = xmp.MpSerialExecutor()

class Network(nn.Module):

  def __init__(self):
    super(Network, self).__init__()
    self.conv1 = nn.Conv2d(3, 10, kernel_size = 5)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size = 5)
    self.bn2 = nn.BatchNorm2d(20)
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 5)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = self.bn1(x)
    x = F.relu(F.max_pool2d(self.conv2(x), 2))
    x = self.bn2(x)
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

#Only instantiate model weight once in memory
net = xmp.MpModelWrapper(Network())
print(net)

def train_flowers():
  torch.manual_seed(1)
----------------------------------------------------------------------------------------------------------------------------
  def get_datasets():

    class newMultiTFRecordDataset(torch.utils.data.Dataset):
      def __init__(self,
                  data_pattern: str,
                  index_pattern: typing.Union[str, None],
                  splits: typing.Dict[str, float],
                  description: typing.Union[typing.List[str], typing.Dict[str, str], None] = None,
                  shuffle_queue_size: typing.Optional[int] = None,
                  transform: typing.Callable[[dict], typing.Any] = None) -> None:
        super(newMultiTFRecordDataset, self).__init__()
        self.data_pattern = data_pattern
        self.index_pattern = index_pattern
        self.splits = splits
        self.description = description
        self.shuffle_queue_size = shuffle_queue_size
        self.transform = transform

      def __len__(self):
        index_len = 0
        for split in self.splits.keys():
          index_len += len(np.loadtxt(self.index_pattern.replace('{}', str(split)), dtype=np.float32)[:, 0])
        return index_len

      def __getitem__(self, index):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            np.random.seed(worker_info.seed % np.iinfo(np.uint32).max)
        it = tfrecord.reader.multi_tfrecord_loader(
            self.data_pattern, self.index_pattern, self.splits, self.description)
        if self.shuffle_queue_size:
            it = iterator_utils.shuffle_iterator(it, self.shuffle_queue_size)
        if self.transform:
            it = map(self.transform, it)
        data = next(it)
    
        return data

    filenames = os.listdir(FLAGS['datadir'])
    filenames = [file_[:-6] for file_ in filenames]
    #split training filenames
    random.seed(1)
    validation_filenames = list(random.sample(filenames, int(len(filenames)*FLAGS['test_split'])))
    training_filenames = [filename for filename in filenames if filename not in validation_filenames]
    
    #getting tfrecords pattern
    tfrec_pattern = os.path.join(FLAGS['datadir'], '{}.tfrec')
    #getting index pattern
    index_pattern = os.path.join(FLAGS['indexdir'], '{}.idx')

    def primary_transforms(features):
      features['image'] = cv2.resize(cv2.imdecode(features["image"], -1), FLAGS['image_size'])
      features["image"] = cv2.cvtColor(features["image"] , cv2.COLOR_BGR2RGB)
      features["image"] = np.moveaxis(features["image"], -1, 0)
      features['class'] = np.squeeze(np.eye(FLAGS['num_classes'])[np.array([features["class"]]).reshape(-1)])
      return features

    description = { "image": "byte",
                  "class": "int"}
    
    train_samp_split = {}
    for file_ in training_filenames:
      train_samp_split[file_] = 1/len(training_filenames)

    val_samp_split = {}
    for file_ in validation_filenames:
      val_samp_split[file_] = 1/len(validation_filenames)
    
    train_dataset = newMultiTFRecordDataset(tfrec_pattern,
                                            index_pattern,
                                            train_samp_split,
                                            description,
                                            transform=primary_transforms)
    
    val_dataset = newMultiTFRecordDataset(tfrec_pattern,
                                          index_pattern,
                                          val_samp_split,
                                          description,
                                          transform=primary_transforms)

    
    return (train_dataset, val_dataset)
-----------------------------------------------------------------------------------------------------------------------------

  train_dataset, test_dataset = get_datasets()

  train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True)
      
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=FLAGS['batch_size'],
      sampler = train_sampler,
      num_workers=FLAGS['num_workers'],
      drop_last=True
  )
  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=FLAGS['batch_size'],
      shuffle=False,
      num_workers=FLAGS['num_workers'],
      drop_last=True
  )
  
----------------------------------------------------------------------------------------------------------------------------
  #Scale learning rate to world size
  lr = FLAGS['learning_rate']*xm.xrt_world_size()

  #Get loss function, optimizer, and model
  device = xm.xla_device()
  model = net.to(device)
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS['momentum'])
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, batch in enumerate(loader):
      optimizer.zero_grad()
      data, target = batch['image'], batch['class']
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS['batch_size'])
      if x % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
            xm.get_ordinal(), x, loss.item(), tracker.rate(),
            tracker.global_rate(), time.asctime()), flush=True)

  def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    data, pred, target = None, None, None
    for batch in loader:
      data, target = batch['image'], batch['class']
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct/total_samples
    print('[xla:{}] Accuracy={:.2f}%'.format(
          xm.get_ordinal(), accuracy), flush=True)
    return accuracy, data, pred, target

  # Train and eval loops
  accuracy = 0.0
  data, pred, target = None, None, None
  for epoch in range(1, FLAGS['num_epochs']+1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))
    xm.master_print("Finished training epoch {}".format(epoch))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy, data, pred, target = test_loop_fn(para_loader.per_device_loader(device))
    if FLAGS['metrics_debug']:
      xm.master_print(met.metrics_report(), flush=True)

  return accuracy, data, pred, target


def _mp_fn(rank, flags):
  global FLAGS 
  FLAGS = flags
  torch.set_default_tensor_type('torch.FloatTensor')
  accuracy, data, pred, target = train_flowers()
  if rank == 0:
    # Retrieve tensors that are on TPU core 0 and plot
    plot_results(data.cpu(), pred.cpu(), target.cpu())
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'], start_method='fork')

What could be causing this error?

@DelightRun

@linkun-1998 Due to my company's compliance rules, I cannot upload the full code. This is the core part of my MultiTFRecordDataset:

class MultiTFRecordDataset(torch.utils.data.IterableDataset):
    """Parse multiple (generic) TFRecords datasets into an `IterableDataset`
    object, which contain `np.ndarrays`s.

    Params:
    -------
    data_pattern: str
        Input data path pattern.

    index_pattern: str or None
        Input index path pattern.

    splits: dict
        Dictionary of (key, value) pairs, where the key is used to
        construct the data and index path(s) and the value determines
        the contribution of each split to the batch.

    description: list or dict of str, optional, default=None
        List of keys or dict of (key, value) pairs to extract from each
        record. The keys represent the name of the features and the
        values ("byte", "float", or "int") correspond to the data type.
        If dtypes are provided, then they are verified against the
        inferred type for compatibility purposes. If None (default),
        then all features contained in the file are extracted.

    is_sequence: bool, optional, default=False
        TFRecord example type. Using tf.train.SequenceExample if
        is_sequence=True, else tf.train.Example.

    shuffle_queue_size: int, optional, default=None
        Length of buffer. Determines how many records are queued to
        sample from.

    transform : a callable, default = None
        A function that takes in the input `features` i.e the dict
        provided in the description, transforms it and returns a
        desirable output.

    """

    def __init__(self,
                 data_pattern: str,
                 index_pattern: typing.Union[str, None],
                 splits: typing.Dict[str, float],
                 description: typing.Union[typing.List[str], typing.Dict[str, str], None] = None,
                 is_sequence: bool = False,
                 shuffle_queue_size: typing.Optional[int] = None,
                 transform: typing.Callable[[dict], typing.Any] = None) -> None:
        super(MultiTFRecordDataset, self).__init__()
        self.data_pattern = data_pattern
        self.index_pattern = index_pattern
        self.splits = splits
        self.description = description
        self.is_sequence = is_sequence
        self.shuffle_queue_size = shuffle_queue_size
        self.transform = transform

        if self.index_pattern is not None:
            self.num_samples = sum(
                sum(1 for _ in open(self.index_pattern.format(split)))
                for split in self.splits
            )
        else:
            self.num_samples = None

    def __len__(self):
        if self.num_samples is not None:
            return self.num_samples
        else:
            raise NotImplementedError()

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            shard = worker_info.id, worker_info.num_workers
            np.random.seed(worker_info.seed % np.iinfo(np.uint32).max)
        else:
            shard = None
        it = reader.multi_tfrecord_loader(
            self.data_pattern, self.index_pattern, self.splits, self.description, self.is_sequence, shard)
        if self.shuffle_queue_size:
            it = iterator_utils.shuffle_iterator(it, self.shuffle_queue_size)
        if self.transform:
            it = map(self.transform, it)
        return it
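
A hypothetical usage sketch of the class above, with placeholder paths and splits. Because __len__ is derived from the index files, len(dataset) and len(loader) are both defined, which is what torch_xla's ParallelLoader needs.

dataset = MultiTFRecordDataset(
    data_pattern="data/{}.tfrecord",   # placeholder file layout
    index_pattern="data/{}.index",     # placeholder file layout
    splits={"part-0": 0.5, "part-1": 0.5},
    description={"image": "byte", "class": "int"},
)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=2)
print(len(dataset), len(loader))       # both work now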

@linkun-1998

@DelightRun You just implemented the __len__ method that was required. Nice, thanks! But can you tell me why my code does not work?

@linkun-1998

@DelightRun Moreover, I get the same error after adding a __len__() method to MultiTFRecordDataset.

@SK124

SK124 commented Jun 27, 2021

@linkun-1998 Were you able to solve it?

@wtmilk

wtmilk commented Jan 29, 2022

Has anyone solved this issue?

vahidk closed this as completed Apr 24, 2023