diff --git a/.lintrunner.toml b/.lintrunner.toml index 1d7c00a2c772d..874a553ee9bc4 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1532,28 +1532,6 @@ exclude_patterns = [ 'torch/distributed/optim/post_localSGD_optimizer.py', 'torch/distributed/optim/utils.py', 'torch/distributed/optim/zero_redundancy_optimizer.py', - 'torch/distributed/pipeline/__init__.py', - 'torch/distributed/pipeline/sync/__init__.py', - 'torch/distributed/pipeline/sync/_balance/__init__.py', - 'torch/distributed/pipeline/sync/_balance/blockpartition.py', - 'torch/distributed/pipeline/sync/_balance/profile.py', - 'torch/distributed/pipeline/sync/batchnorm.py', - 'torch/distributed/pipeline/sync/checkpoint.py', - 'torch/distributed/pipeline/sync/copy.py', - 'torch/distributed/pipeline/sync/dependency.py', - 'torch/distributed/pipeline/sync/microbatch.py', - 'torch/distributed/pipeline/sync/phony.py', - 'torch/distributed/pipeline/sync/pipe.py', - 'torch/distributed/pipeline/sync/pipeline.py', - 'torch/distributed/pipeline/sync/skip/__init__.py', - 'torch/distributed/pipeline/sync/skip/layout.py', - 'torch/distributed/pipeline/sync/skip/namespace.py', - 'torch/distributed/pipeline/sync/skip/portal.py', - 'torch/distributed/pipeline/sync/skip/skippable.py', - 'torch/distributed/pipeline/sync/skip/tracker.py', - 'torch/distributed/pipeline/sync/stream.py', - 'torch/distributed/pipeline/sync/utils.py', - 'torch/distributed/pipeline/sync/worker.py', 'torch/distributed/remote_device.py', 'torch/distributed/rendezvous.py', 'torch/distributed/rpc/__init__.py', @@ -1847,8 +1825,6 @@ exclude_patterns = [ 'torch/testing/_internal/distributed/nn/__init__.py', 'torch/testing/_internal/distributed/nn/api/__init__.py', 'torch/testing/_internal/distributed/nn/api/remote_module_test.py', - 'torch/testing/_internal/distributed/pipe_with_ddp_test.py', - 'torch/testing/_internal/distributed/pipeline/__init__.py', 'torch/testing/_internal/distributed/rpc/__init__.py', 'torch/testing/_internal/distributed/rpc/dist_autograd_test.py', 'torch/testing/_internal/distributed/rpc/dist_optimizer_test.py', diff --git a/benchmarks/distributed/pipeline/benchmark_dataset.py b/benchmarks/distributed/pipeline/benchmark_dataset.py deleted file mode 100644 index 3cd22e9a468d1..0000000000000 --- a/benchmarks/distributed/pipeline/benchmark_dataset.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -from torch.utils.data import Dataset - - -def collate_sentences_lm(samples): - if len(samples) == 0: - return {} - - id = torch.LongTensor([s["id"] for s in samples]) - src_tokens = torch.stack([s["source"] for s in samples], 0) - tgt_tokens = torch.stack([s["target"] for s in samples], 0) - ntokens = len(samples) * len(samples[0]["target"]) - src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples)) - - batch = { - "id": id, - "nsentences": len(samples), - "ntokens": ntokens, - "input": src_tokens, - "target": tgt_tokens, - } - return batch - - -class BenchmarkLMDataset(Dataset): - """ - Dataset to benchmark a translation like seq2seq task. - Args: - vocab_size (int, optional): size of the vocabulary (default 10000). - max_source_positions (int, optional): max number of tokens in the - source sentence (default: 1024). - total_samples (int, optional): the total number of rows in the - dataset (default: 10000). 
- """ - - def __init__( - self, - vocab_size=10000, - max_source_positions=1024, - total_samples=10000, - ): - self.vocab_size = vocab_size - self.max_source_positions = max_source_positions - self.total_samples = total_samples - self.sizes = [self.max_source_positions] * self.total_samples - - def __getitem__(self, index): - length = self.sizes[index] - source = torch.randint(1, self.vocab_size, (length,)) - target = source.clone() - return { - "id": index, - "source": source, - "target": target, - } - - def __len__(self): - return self.total_samples diff --git a/benchmarks/distributed/pipeline/pipe.py b/benchmarks/distributed/pipeline/pipe.py deleted file mode 100644 index c465c2488565d..0000000000000 --- a/benchmarks/distributed/pipeline/pipe.py +++ /dev/null @@ -1,296 +0,0 @@ -import argparse -import math -import os -import time - -from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm - -import torch -import torch.nn as nn -from torch.distributed import rpc - -from torch.distributed.pipeline.sync import Pipe -from torch.distributed.pipeline.sync.utils import partition_model -from torch.optim import Adam -from torch.utils.data import DataLoader - - -def sizeof_fmt(num, suffix="B"): - for unit in ["", "Ki", "Mi", "Gi", "Ti"]: - if abs(num) < 1024.0: - return f"{num:3.2f}{unit}B" - num /= 1024.0 - - -def init_random_seed(seed: int): - import numpy - - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - numpy.random.seed(seed) - - -iteration_count = 0 - - -class EmbeddingLayer(nn.Embedding): - def __init__(self, ntoken, ninp, initrange): - super().__init__(ntoken, ninp) - self.ninp = ninp - nn.init.uniform_(self.weight, -initrange, initrange) - - def forward(self, src): - return super().forward(src) * math.sqrt(self.ninp) - - -class PositionalEncodingLayer(nn.Module): - def __init__(self, d_model, dropout=0.1, max_len=5000): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0).transpose(0, 1) - self.register_buffer("pe", pe) - - def forward(self, x): - x = x + self.pe[: x.size(0), :] - return self.dropout(x) - - -class TransformerDecoderLayer(nn.TransformerEncoderLayer): - """Though this class inherits from torch.nn.TransformerEncoderLayer, - it functions as a decoder in this model""" - - def __init__(self, ninp, nhead, nhid, droupout): - super().__init__(ninp, nhead, nhid, droupout) - self.src_mask = None - - def forward(self, src): - global iteration_count - iteration_count += 1 - - if self.src_mask is None or self.src_mask.size(0) != len(src): - device = src.device - mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device) - self.src_mask = mask - - return super().forward(src, self.src_mask) - - -class LinearLayer(nn.Linear): - def __init__(self, ninp, ntoken, initrange): - super().__init__(ninp, ntoken) - nn.init.zeros_(self.bias) - nn.init.uniform_(self.weight, -initrange, initrange) - - -class TransformerLMSequential(nn.Sequential): - """A small language model based on the design of GPT-2 using nn.Sequential - for compatibility with Pipe""" - - def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder): - layers = [ - EmbeddingLayer(ntokens, ninp, initrange), - PositionalEncodingLayer(ninp, dropout), 
- ] - for _ in range(ndecoder): - layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout)) - - layers.append(LinearLayer(ninp, ntokens, initrange)) - super().__init__(*layers) - - -def make_model(args, device, ntokens): - ninp = 2048 # embedding dimension - nhid = ( - 2048 # the dimension of the feedforward network model in nn.TransformerEncoder - ) - nhead = 32 # the number of heads in the multiheadattention models - dropout = 0 - initrange = 0.1 - ndecoder = args.num_decoder_layers - - model = TransformerLMSequential( - ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder - ).to(device) - - criterion = nn.CrossEntropyLoss() - lr = 0.01 # learning rate - - def make_adam(model): - return Adam(model.parameters(), lr=lr) - - optimizer = make_adam - - return model, criterion, optimizer - - -def train(lm_dataloader, model, criterion, optimizer, vocab_size, args): - model.train() - - vocab_size = 10000 - total_loss = 0.0 - start_time = time.time() - word_counter = 0 - - optimizer = optimizer(model) - - def get_first_device(model): - if model.devices: - return model.devices[0] - else: - return torch.cuda.current_device() - - def get_last_device(model): - if model.devices: - return model.devices[-1] - else: - return torch.cuda.current_device() - - print( - f"Number of parameters for model: {sum(p.numel() for p in model.parameters())}" - ) - for i, batch in enumerate(lm_dataloader): - bi = batch["input"] - if args.max_batch and i > args.max_batch: - break - optimizer.zero_grad() - try: - tmp = batch["input"].to(get_first_device(model)) - output = model(tmp).local_value() - except Exception as e: - raise RuntimeError( - f"training failed on {torch.distributed.get_rank()}" - ) from e - - target = batch["target"].to(get_last_device(model)) - output = output.to(target.device) - - loss = criterion(output.view(-1, vocab_size), target.view(-1)) - loss.backward() - del target - del output - - torch.nn.utils.clip_grad_value_(model.parameters(), 0.05) - optimizer.step() - - total_loss += loss.item() - log_interval = 1 - word_counter += batch["ntokens"] - if i % log_interval == 0 and i > 0: - cur_loss = total_loss / log_interval - elapsed = time.time() - start_time - print( - f"| batch {i:5d} | wps {word_counter / elapsed:5.2f} | loss {cur_loss:5.2f} | ppl {math.exp(cur_loss):8.2f}" - ) - word_counter = 0 - total_loss = 0 - start_time = time.time() - - print("Peak memory usage for GPUs: ", end="") - for i in range(len(model.devices)): - print( - f"cuda:{i}: {sizeof_fmt(torch.cuda.memory_stats(i)['allocated_bytes.all.peak'])}, ", - end="", - ) - print() - - -def generate_balance(num_devices, num_layers): - balance = [] - layers_assigned = 0 - for i in range(num_devices): - x = (num_layers - layers_assigned) / (num_devices - i) - if x.is_integer(): - balance.append(int(x)) - layers_assigned += x - else: - balance.append(math.ceil(x)) - layers_assigned += math.ceil(x) - return balance - - -def make_model_and_data(args, device): - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - vocab_size = 10000 - model, criterion, optimizer = make_model(args, device, vocab_size) - lm_dataset = BenchmarkLMDataset() - lm_dataloader = DataLoader( - lm_dataset, - batch_size=args.batch_size, - shuffle=True, - num_workers=0, - collate_fn=collate_sentences_lm, - ) - return { - "model": model, - "criterion": criterion, - "optimizer": optimizer, - "data": lm_dataloader, - "vocab_size": vocab_size, - } - - -def bench_single_process(args): - os.environ.update({"MASTER_ADDR": args.host}) 
- os.environ.update({"MASTER_PORT": "10638"}) - - rpc.init_rpc( - "worker", - rank=0, - world_size=1, - ) - - num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 - num_devices = min(args.num_devices, num_devices) - assert num_devices > 0 - init_random_seed(0) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - - blob = make_model_and_data(args, None) - model = blob["model"] - - balance = generate_balance(num_devices, len(model)) - model = partition_model(model, balance) - p = Pipe(model, chunks=args.chunks, checkpoint=args.checkpoint) - del model - del blob["model"] - - train( - blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args - ) - - -parser = argparse.ArgumentParser(description="benchmark") -parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname") -parser.add_argument( - "--chunks", type=int, default=4, help="number of microbatches per batch" -) -parser.add_argument("--batch-size", type=int, default=8, help="size of a batch") -parser.add_argument("--max-batch", type=int, default=10, help="Max number of batches") -parser.add_argument( - "--num-decoder-layers", - type=int, - default=10, - help="Number of decoder layers in the model", -) -parser.add_argument( - "--checkpoint", - default="except_last", - choices=["always", "except_last", "never"], - help="Checkpointing strategy for pipe", -) -parser.add_argument( - "--num-devices", type=int, default=4, help="Number of GPU devices to use" -) - -if __name__ == "__main__": - args = parser.parse_args() - print(f"Running benchmark with args: {args}") - bench_single_process(args) diff --git a/docs/source/conf.py b/docs/source/conf.py index ef492f17c5060..4f73c111cb235 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -606,47 +606,6 @@ # torch.distributed.optim.utils "as_functional_optim", "register_functional_optim", - # torch.distributed.pipeline.sync.checkpoint - "checkpoint", - "enable_checkpointing", - "enable_recomputing", - "is_checkpointing", - "is_recomputing", - "restore_rng_states", - "save_rng_states", - # torch.distributed.pipeline.sync.dependency - "fork", - "join", - # torch.distributed.pipeline.sync.microbatch - "check", - "gather", - "scatter", - # torch.distributed.pipeline.sync.phony - "get_phony", - # torch.distributed.pipeline.sync.skip.layout - "inspect_skip_layout", - # torch.distributed.pipeline.sync.skip.tracker - "current_skip_tracker", - "use_skip_tracker", - # torch.distributed.pipeline.sync.stream - "as_cuda", - "current_stream", - "default_stream", - "get_device", - "is_cuda", - "new_stream", - "record_stream", - "use_device", - "use_stream", - "wait_stream", - # torch.distributed.pipeline.sync.utils - "partition_model", - # torch.distributed.pipeline.sync.worker - "create_workers", - "spawn_workers", - "worker", - # torch.distributed.pipelining.PipelineSchedule - "step", # torch.distributed.rendezvous "register_rendezvous_handler", "rendezvous", @@ -2650,52 +2609,6 @@ "PostLocalSGDOptimizer", # torch.distributed.optim.zero_redundancy_optimizer "ZeroRedundancyOptimizer", - # torch.distributed.pipeline.sync.batchnorm - "DeferredBatchNorm", - # torch.distributed.pipeline.sync.checkpoint - "Checkpoint", - "Checkpointing", - "Context", - "Function", - "Recompute", - "ThreadLocal", - # torch.distributed.pipeline.sync.copy - "Context", - "Copy", - "Wait", - # torch.distributed.pipeline.sync.dependency - "Fork", - "Join", - # torch.distributed.pipeline.sync.microbatch - "Batch", - "NoChunk", - # 
torch.distributed.pipeline.sync.pipe - "BalanceError", - "Pipe", - "PipeSequential", - "WithDevice", - # torch.distributed.pipeline.sync.pipeline - "Pipeline", - # torch.distributed.pipeline.sync.skip.layout - "SkipLayout", - # torch.distributed.pipeline.sync.skip.namespace - "Namespace", - # torch.distributed.pipeline.sync.skip.portal - "Context", - "Portal", - "PortalBlue", - "PortalCopy", - "PortalOrange", - # torch.distributed.pipeline.sync.skip.skippable - "Skippable", - # torch.distributed.pipeline.sync.skip.tracker - "SkipTracker", - "SkipTrackerThroughPotals", - "ThreadLocal", - # torch.distributed.pipeline.sync.stream - "CPUStreamType", - # torch.distributed.pipeline.sync.worker - "Task", # torch.distributed.rpc.api "AllGatherStates", "RRef", diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index 48f66b5d3276c..a8203a5f3b2ce 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -299,12 +299,6 @@ You can implement your own pipeline schedule by extending one of the following t For example, ``ScheduleGPipe`` and ``Schedule1F1B`` are subclasses of ``PipelineScheduleSingle``. Whereas, ``ScheduleInterleaved1F1B`` and ``ScheduleLoopedBFS`` are subclasses of ``PipelineScheduleMulti``. -.. currentmodule:: torch.distributed.pipelining.PipelineSchedule - -.. autoclass:: PipelineScheduleSingle - -.. autoclass:: PipelineScheduleMulti - API Reference ************* @@ -370,3 +364,9 @@ Pipeline Schedules .. autoclass:: ScheduleInterleaved1F1B .. autoclass:: ScheduleLoopedBFS + +.. autoclass:: PipelineScheduleSingle + :members: + +.. autoclass:: PipelineScheduleMulti + :members: diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index 0b091d5670312..f4c73b9381e59 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -876,9 +876,6 @@ If you are running single node training, it may be convenient to interactively b .. py:module:: torch.distributed.nn.api .. py:module:: torch.distributed.nn.jit .. py:module:: torch.distributed.nn.jit.templates -.. py:module:: torch.distributed.pipeline -.. py:module:: torch.distributed.pipeline.sync -.. py:module:: torch.distributed.pipeline.sync.skip .. py:module:: torch.distributed.tensor .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks @@ -964,22 +961,6 @@ If you are running single node training, it may be convenient to interactively b .. py:module:: torch.distributed.optim.post_localSGD_optimizer .. py:module:: torch.distributed.optim.utils .. py:module:: torch.distributed.optim.zero_redundancy_optimizer -.. py:module:: torch.distributed.pipeline.sync.batchnorm -.. py:module:: torch.distributed.pipeline.sync.checkpoint -.. py:module:: torch.distributed.pipeline.sync.copy -.. py:module:: torch.distributed.pipeline.sync.dependency -.. py:module:: torch.distributed.pipeline.sync.microbatch -.. py:module:: torch.distributed.pipeline.sync.phony -.. py:module:: torch.distributed.pipeline.sync.pipe -.. py:module:: torch.distributed.pipeline.sync.pipeline -.. py:module:: torch.distributed.pipeline.sync.skip.layout -.. py:module:: torch.distributed.pipeline.sync.skip.namespace -.. py:module:: torch.distributed.pipeline.sync.skip.portal -.. py:module:: torch.distributed.pipeline.sync.skip.skippable -.. py:module:: torch.distributed.pipeline.sync.skip.tracker -.. py:module:: torch.distributed.pipeline.sync.stream -.. 
py:module:: torch.distributed.pipeline.sync.utils -.. py:module:: torch.distributed.pipeline.sync.worker .. py:module:: torch.distributed.remote_device .. py:module:: torch.distributed.rendezvous .. py:module:: torch.distributed.rpc.api diff --git a/docs/source/index.rst b/docs/source/index.rst index ea704f20c3af7..dcaadcbb63edc 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -103,7 +103,6 @@ Features described in this documentation are classified by release status: optim complex_numbers ddp_comm_hooks - pipeline quantization rpc torch.random diff --git a/docs/source/pipeline.rst b/docs/source/pipeline.rst deleted file mode 100644 index 94d730ee223d3..0000000000000 --- a/docs/source/pipeline.rst +++ /dev/null @@ -1,85 +0,0 @@ -.. _pipeline-parallelism: - -Pipeline Parallelism -==================== - -Pipeline parallelism was original introduced in the -`Gpipe `__ paper and is an efficient -technique to train large models on multiple GPUs. - -.. warning :: - torch.distributed.pipeline is deprecated, so is this document. For - up-to-date pipeline parallel implementation, please refer to the - `PiPPy `__ library under the PyTorch - organization (Pipeline Parallelism for PyTorch). - -Model Parallelism using multiple GPUs -------------------------------------- - -Typically for large models which don't fit on a single GPU, model parallelism -is employed where certain parts of the model are placed on different GPUs. -Although, if this is done naively for sequential models, the training process -suffers from GPU under utilization since only one GPU is active at one time as -shown in the figure below: - -.. figure:: _static/img/pipeline_parallelism/no_pipe.png - - The figure represents a model with 4 layers placed on 4 different GPUs - (vertical axis). The horizontal axis represents training this model through - time demonstrating that only 1 GPU is utilized at a time - (`image source `__). - -Pipelined Execution -------------------- - -To alleviate this problem, pipeline parallelism splits the input minibatch into -multiple microbatches and pipelines the execution of these microbatches across -multiple GPUs. This is outlined in the figure below: - -.. figure:: _static/img/pipeline_parallelism/pipe.png - - The figure represents a model with 4 layers placed on 4 different GPUs - (vertical axis). The horizontal axis represents training this model through - time demonstrating that the GPUs are utilized much more efficiently. - However, there still exists a bubble (as demonstrated in the figure) where - certain GPUs are not utilized. - (`image source `__). - -Pipe APIs in PyTorch --------------------- -.. autoclass:: torch.distributed.pipeline.sync.Pipe - :members: forward - -Skip connections -^^^^^^^^^^^^^^^^ - -Certain models like `ResNeXt `__ -are not completely sequential and have skip connections between layers. -Naively implementing as part of pipeline parallelism would imply that -we need to copy outputs for certain layers through multiple GPUs till -we eventually reach the GPU where the layer for the skip connection resides. -To avoid this copy overhead, we provide APIs below to stash and pop Tensors -in different layers of the model. - -.. autofunction:: torch.distributed.pipeline.sync.skip.skippable.skippable -.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.stash -.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.pop -.. 
autofunction:: torch.distributed.pipeline.sync.skip.skippable.verify_skippables - -Tutorials ---------- - -The following tutorials give a good overview of how to use the -:class:`~torch.distributed.pipeline.sync.Pipe` API to train your models with the -rest of the components that PyTorch provides: - -- `Training Transformer models using Pipeline Parallelism `__ -- `Training Transformer models using Distributed Data Parallel and Pipeline Parallelism `__ - -Acknowledgements ----------------- - -The implementation for pipeline parallelism is based on `fairscale's pipe implementation `__ and -`torchgpipe `__. We would like to -thank both teams for their contributions and guidance towards bringing pipeline -parallelism into PyTorch. diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 0ead16868f2f7..8bedc00723002 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -211,30 +211,6 @@ "torch.distributed.optim.utils": [ "Type" ], - "torch.distributed.pipeline.sync.pipe": [ - "Pipeline" - ], - "torch.distributed.pipeline.sync.skip.layout": [ - "SkipLayout", - "inspect_skip_layout" - ], - "torch.distributed.pipeline.sync.skip.portal": [ - "Context", - "Portal", - "PortalBlue", - "PortalCopy", - "PortalOrange" - ], - "torch.distributed.pipeline.sync.skip.skippable": [ - "Skippable" - ], - "torch.distributed.pipeline.sync.skip.tracker": [ - "SkipTracker", - "SkipTrackerThroughPotals", - "ThreadLocal", - "current_skip_tracker", - "use_skip_tracker" - ], "torch.distributed.remote_device": [ "Optional", "Union" @@ -1695,10 +1671,6 @@ "get_args_parser", "run" ], - "torch.distributed.pipeline.sync": [ - "NoChunk", - "WithDevice" - ], "torch.distributed.rpc.rref_proxy": [ "Future", "partial", diff --git a/test/distributed/pipeline/sync/LICENSE b/test/distributed/pipeline/sync/LICENSE deleted file mode 100644 index e52be240fdc98..0000000000000 --- a/test/distributed/pipeline/sync/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright 2019-2020 Kakao Brain - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from this - software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. 
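For reference, the skip-connection ``stash``/``pop`` API described in the removed ``docs/source/pipeline.rst`` above can be summarised by the minimal sketch below. It is reconstructed from the deleted docs and the deleted tests later in this diff (e.g. ``test_leak.py`` and ``test_stash_pop.py``); it targets the deprecated ``torch.distributed.pipeline.sync`` package that this PR removes, so it only runs on releases that still ship that package::

    import torch
    from torch import nn

    # Deprecated package removed by this PR; the import only resolves on older releases.
    from torch.distributed.pipeline.sync.skip import pop, skippable, stash


    @skippable(stash=["skip"])
    class Stash(nn.Module):
        def forward(self, input):
            # Stash the activation so a later layer (possibly on another
            # partition) can pop it without routing it through every stage.
            yield stash("skip", input)
            return input * 2  # noqa: B901


    @skippable(pop=["skip"])
    class Pop(nn.Module):
        def forward(self, input):
            # Retrieve the tensor stashed by Stash.
            skip = yield pop("skip")
            return input + skip


    model = nn.Sequential(Stash(), Pop())
    # Outside a Pipe the default thread-local skip tracker is used: 42 * 2 + 42 = 126.
    out = model(torch.tensor(42.0))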
diff --git a/test/distributed/pipeline/sync/__init__.py b/test/distributed/pipeline/sync/__init__.py deleted file mode 100644 index 94cd5bcb415e0..0000000000000 --- a/test/distributed/pipeline/sync/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -# tests/__init__.py makes pytest can import the application without custom sys.path or PYTHONPATH. -# See also: https://docs.pytest.org/en/latest/goodpractices.html diff --git a/test/distributed/pipeline/sync/conftest.py b/test/distributed/pipeline/sync/conftest.py deleted file mode 100644 index 4f2479b27b29d..0000000000000 --- a/test/distributed/pipeline/sync/conftest.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import tempfile - -import pytest - -import torch -import torch.distributed as dist - - -@pytest.fixture(autouse=True) -def manual_seed_zero(): - torch.manual_seed(0) - - -@pytest.fixture(scope="session") -def cuda_sleep(): - # Warm-up CUDA. - torch.empty(1, device="cuda") - - # From test/test_cuda.py in PyTorch. - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - torch.cuda._sleep(1000000) - end.record() - end.synchronize() - cycles_per_ms = 1000000 / start.elapsed_time(end) - - def cuda_sleep(seconds): - torch.cuda._sleep(int(seconds * cycles_per_ms * 1000)) - - return cuda_sleep - - -def pytest_report_header(): - return f"torch: {torch.__version__}" - - -@pytest.fixture -def setup_rpc(scope="session"): - file = tempfile.NamedTemporaryFile() - dist.rpc.init_rpc( - name="worker0", - rank=0, - world_size=1, - rpc_backend_options=dist.rpc.TensorPipeRpcBackendOptions( - init_method=f"file://{file.name}", - ), - ) - yield - dist.rpc.shutdown() - - -def pytest_ignore_collect(path, config): - "Skip this directory if distributed modules are not enabled." - return not dist.is_available() diff --git a/test/distributed/pipeline/sync/skip/__init__.py b/test/distributed/pipeline/sync/skip/__init__.py deleted file mode 100644 index ab03724cafbf5..0000000000000 --- a/test/distributed/pipeline/sync/skip/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. diff --git a/test/distributed/pipeline/sync/skip/test_api.py b/test/distributed/pipeline/sync/skip/test_api.py deleted file mode 100644 index be38d6d83dace..0000000000000 --- a/test/distributed/pipeline/sync/skip/test_api.py +++ /dev/null @@ -1,52 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
-import copy - -from torch import nn - -from torch.distributed.pipeline.sync.skip import Namespace, skippable, stash -from torch.testing._internal.common_utils import run_tests - - -def test_namespace_difference(): - ns1 = Namespace() - ns2 = Namespace() - assert ns1 != ns2 - - -def test_namespace_copy(): - ns = Namespace() - assert copy.copy(ns) == ns - assert copy.copy(ns) is not ns - - -def test_skippable_repr(): - @skippable(stash=["hello"]) - class Hello(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(1, 1, 1) - - def forward(self, x): - yield stash("hello", x) - return self.conv(x) # noqa: B901 - - m = Hello() - assert ( - repr(m) - == """ -@skippable(Hello( - (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1)) -)) -""".strip() - ) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_gpipe.py b/test/distributed/pipeline/sync/skip/test_gpipe.py deleted file mode 100644 index 4f433ab38941c..0000000000000 --- a/test/distributed/pipeline/sync/skip/test_gpipe.py +++ /dev/null @@ -1,126 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.portal import ( - PortalBlue, - PortalCopy, - PortalOrange, -) -from torch.distributed.pipeline.sync.utils import partition_model -from torch.testing._internal.common_utils import run_tests - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -@pytest.mark.parametrize( - "balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"] -) -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_1to3(balance, checkpoint, setup_rpc): - if torch.cuda.device_count() < len(balance): - pytest.skip("at least %d cuda devices required" % len(balance)) - - @skippable(stash=["1to3"]) - class Layer1(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - yield stash("1to3", input) - output = self.conv(input) - return output # noqa: B901 - - class Layer2(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - output = self.conv(input) - return output - - @skippable(pop=["1to3"]) - class Layer3(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - skip_1to3 = yield pop("1to3") - output = self.conv(input) + skip_1to3 - return output - - model = nn.Sequential(Layer1(), Layer2(), Layer3()) - model = partition_model(model, balance) - model = Pipe(model, chunks=3, checkpoint=checkpoint) - - in_device = model.devices[0] - out_device = model.devices[-1] - - input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True) - output = model(input) - loss = output.local_value().mean() - loss.backward() - - assert torch.allclose( - output.local_value().norm(), torch.tensor(1039.0, device=out_device), atol=6e-1 - ) - assert torch.allclose( - input.grad.norm(), torch.tensor(0.0004533053, device=in_device) - ) - - -def test_none_skip(setup_rpc): - @skippable(stash=["none"]) - class 
Stash(nn.Module): - def forward(self, input): - yield stash("none", None) - return input # noqa: B901 - - @skippable(pop=["none"]) - class Pop(nn.Module): - def forward(self, input): - none = yield pop("none") - assert none is None - return input - - model = nn.Sequential(Stash(), Pop()) - model = Pipe(model, chunks=5) - - input = torch.rand(10, requires_grad=True) - output = model(input) - - def assert_grad_fn_is_not_portal(grad_fn, visited=None): - if visited is None: - visited = set() - if grad_fn in visited or grad_fn is None: - return - - assert not isinstance(grad_fn, PortalBlue._backward_cls) - assert not isinstance(grad_fn, PortalCopy._backward_cls) - assert not isinstance(grad_fn, PortalOrange._backward_cls) - - visited.add(grad_fn) - for next_grad_fn, _ in grad_fn.next_functions: - assert_grad_fn_is_not_portal(next_grad_fn, visited) - - assert_grad_fn_is_not_portal(output.local_value().grad_fn) - - output.local_value().sum().backward() - assert input.grad.mean().item() == 1 - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py b/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py deleted file mode 100644 index 4d542285cd5af..0000000000000 --- a/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py +++ /dev/null @@ -1,118 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -from torch import nn - -from torch.distributed.pipeline.sync.skip import Namespace, pop, skippable, stash -from torch.distributed.pipeline.sync.skip.layout import inspect_skip_layout -from torch.testing._internal.common_utils import run_tests - - -class Pass(nn.Module): - def forward(self, input): - return input - - -@skippable(stash=["foo"]) -class StashFoo(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input # noqa: B901 - - -@skippable(pop=["foo"]) -class PopFoo(nn.Module): - def forward(self, input): - foo = yield stash("foo") - return input + foo - - -@skippable(stash=["bar"]) -class StashBar(nn.Module): - def forward(self, input): - yield stash("bar", input) - return input # noqa: B901 - - -@skippable(pop=["bar"]) -class PopBar(nn.Module): - def forward(self, input): - bar = yield pop("bar") - return input + bar - - -def test_no_skippables(): - p1 = nn.Sequential(Pass()) - p2 = nn.Sequential(Pass()) - - layout = inspect_skip_layout([p1, p2]) - policy = [list(layout.copy_policy(i)) for i in range(2)] - - assert policy == [[], []] - - -def test_inner_partition(): - p1 = nn.Sequential(StashFoo(), PopFoo()) - p2 = nn.Sequential(Pass()) - - layout = inspect_skip_layout([p1, p2]) - policy = [list(layout.copy_policy(i)) for i in range(2)] - - assert policy == [[], []] - - -def test_adjoining_partitions(): - p1 = nn.Sequential(StashFoo()) - p2 = nn.Sequential(PopFoo()) - - layout = inspect_skip_layout([p1, p2]) - policy = [list(layout.copy_policy(i)) for i in range(2)] - - assert policy == [[], [(0, None, "foo")]] - - -def test_far_partitions(): - p1 = nn.Sequential(StashFoo()) - p2 = nn.Sequential(Pass()) - p3 = nn.Sequential(PopFoo()) - - layout = inspect_skip_layout([p1, p2, p3]) - policy = [list(layout.copy_policy(i)) for i in range(3)] - - assert policy == [[], [], [(0, None, "foo")]] - - -def test_pop_2_from_different_partitions(): - p1 = 
nn.Sequential(StashFoo()) - p2 = nn.Sequential(StashBar()) - p3 = nn.Sequential(PopBar(), PopFoo()) - - layout = inspect_skip_layout([p1, p2, p3]) - policy = [list(layout.copy_policy(i)) for i in range(3)] - - # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index. - assert policy == [[], [], [(0, None, "foo"), (1, None, "bar")]] - - -def test_namespace(): - ns1 = Namespace() - ns2 = Namespace() - - p1 = nn.Sequential(StashFoo().isolate(ns1)) - p2 = nn.Sequential(StashFoo().isolate(ns2)) - p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1)) - - layout = inspect_skip_layout([p1, p2, p3]) - policy = [list(layout.copy_policy(i)) for i in range(3)] - - # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index. - assert policy == [[], [], [(0, ns1, "foo"), (1, ns2, "foo")]] - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_leak.py b/test/distributed/pipeline/sync/skip/test_leak.py deleted file mode 100644 index f4d1043e05498..0000000000000 --- a/test/distributed/pipeline/sync/skip/test_leak.py +++ /dev/null @@ -1,136 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync import is_checkpointing, is_recomputing, Pipe -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.tracker import current_skip_tracker -from torch.testing._internal.common_utils import run_tests - - -@skippable(stash=["skip"]) -class Stash(nn.Module): - def forward(self, input): - yield stash("skip", input) - return input # noqa: B901 - - -@skippable(pop=["skip"]) -class Pop(nn.Module): - def forward(self, input): - skip = yield pop("skip") - return input + skip - - -@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) -@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) -def test_delete_portal_tensor(train, checkpoint, setup_rpc): - # Without checkpointing: - # +- Stash --+ +--- Pop ----+ - - - layers - # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function - # +----------+ +------------+ - # - # With checkpointing: - # +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+ - # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 | - # +----------+ +------------+ +------------+ +----------+ - - def portal_tensor_life_is(tensor_life, skip_tracker=None): - if skip_tracker is None: - skip_tracker = current_skip_tracker() - - # Get the current portal. - portal = next(iter(skip_tracker.portals.values())) - - if tensor_life == 0: - return portal.tensor_life == 0 and portal.tensor is None - else: - return portal.tensor_life == tensor_life and portal.tensor is not None - - # Check the portal tensor after 'Stash'. 
- stash_ = Stash() - - @stash_.register_forward_hook - def check_portal_tensor_after_stash(*_): - if is_checkpointing(): - assert portal_tensor_life_is(2) - elif is_recomputing(): - assert portal_tensor_life_is(0) - else: - assert portal_tensor_life_is(1) - - pop_ = Pop() - - @pop_.register_forward_hook - def check_portal_tensor_after_pop(*_): - if is_checkpointing(): - assert portal_tensor_life_is(1) - elif is_recomputing(): - assert portal_tensor_life_is(0) - else: - assert portal_tensor_life_is(0) - - class NoPortalTensorAtBackward(nn.Module): - class F(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - ctx.skip_tracker = current_skip_tracker() - return input.detach() - - @staticmethod - def backward(ctx, grad): - assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker) - return grad - - def forward(self, input): - return self.F.apply(input) - - model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_) - model = Pipe(model, chunks=2, checkpoint=checkpoint) - - input = torch.rand(10, requires_grad=True) - - if train: - model.train() - output = model(input).local_value() - output.norm().backward() - else: - model.eval() - with torch.no_grad(): - model(input) - - -@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) -def test_no_portal_without_pipe(train, monkeypatch, setup_rpc): - def deny(*args, **kwargs): - raise AssertionError("tried to create Portal without Pipe") - - monkeypatch.setattr( - "torch.distributed.pipeline.sync.skip.portal.Portal.__init__", deny - ) - - model = nn.Sequential(Stash(), Pop()) - - input = torch.rand(10, requires_grad=True) - - if train: - model.train() - output = model(input) - output.norm().backward() - else: - model.eval() - with torch.no_grad(): - model(input) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_portal.py b/test/distributed/pipeline/sync/skip/test_portal.py deleted file mode 100644 index 5ad180b6f9c84..0000000000000 --- a/test/distributed/pipeline/sync/skip/test_portal.py +++ /dev/null @@ -1,163 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
-import pytest - -import torch - -from torch.distributed.pipeline.sync.dependency import fork, join -from torch.distributed.pipeline.sync.skip.portal import Portal -from torch.distributed.pipeline.sync.stream import default_stream -from torch.testing._internal.common_utils import run_tests - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def test_copy_returns_on_next_device(): - portal = Portal(torch.rand(1), tensor_life=1) - - prev_stream = default_stream(torch.device("cpu")) - next_stream = default_stream(torch.device("cuda")) - - phony = torch.zeros(0, requires_grad=True) - assert phony.device.type == "cpu" - - phony = portal.copy(prev_stream, next_stream, phony) - assert phony.device.type == "cuda" - - -def test_blue_orange(): - tensor1 = torch.rand(1, requires_grad=True) - tensor2 = torch.rand(1, requires_grad=True) - - # Same with: output = tensor1*2 + tensor2 - # - # +----------------------+ - # | | - # tensor2 -- PortalBlue -+ +- PortalOrange -+ - # | | | - # tensor1 ------------ Join -- Fork --- Mul --- Add -- output - # - main = tensor1 - portal = Portal(tensor2, tensor_life=2) - phony = portal.blue() - main = join(main, phony) - main, phony = fork(main) - sub = portal.orange(phony) - output = main * 2 + sub - - output.backward() - - assert torch.allclose(tensor1.grad, torch.tensor([2.0])) - assert torch.allclose(tensor2.grad, torch.tensor([1.0])) - - -def test_blue_orange_not_requires_grad(): - tensor1 = torch.rand(1, requires_grad=True) - tensor2 = torch.rand(1) - - # Same with: output = tensor1*2 + tensor2 - # - # +----------------------+ - # | | - # tensor2 -- PortalBlue -+ +- PortalOrange -+ - # | | | - # tensor1 ------------ Join -- Fork --- Mul --- Add -- output - # - main = tensor1 - portal = Portal(tensor2, tensor_life=2) - phony = portal.blue() - main = join(main, phony) - main, phony = fork(main) - sub = portal.orange(phony) - output = main * 2 + sub - - output.backward() - - assert torch.allclose(tensor1.grad, torch.tensor([2.0])) - assert tensor2.grad is None - - -def test_use_grad(): - tensor = torch.rand(1, requires_grad=True) - portal = Portal(tensor, tensor_life=1) - - portal.put_grad(tensor) - assert portal.use_grad() is tensor - - # Gradient in a portal is ephemeral. - with pytest.raises(RuntimeError): - portal.use_grad() - - -class TestTensorLife: - @pytest.fixture - def new_portal(self): - portal = None - - def new_portal(tensor_life): - nonlocal portal - tensor = torch.rand(1, requires_grad=True) - portal = Portal(tensor, tensor_life) - return portal, tensor - - yield new_portal - - # A test using this fixture must exhaust the tensor in the portal. 
- with pytest.raises(RuntimeError): - portal.check_tensor_life() - assert portal.tensor is None - - def test_tensor_life_0(self, new_portal): - portal, tensor = new_portal(0) - assert portal.tensor is None - - def test_tensor_life_1(self, new_portal): - portal, tensor = new_portal(1) - assert portal.tensor is tensor - - portal.blue() - - def test_tensor_life_2(self, new_portal): - portal, tensor = new_portal(2) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - - def test_tensor_life_3(self, new_portal): - portal, tensor = new_portal(3) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - - def test_tensor_life_4(self, new_portal): - portal, tensor = new_portal(4) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - portal.blue() - - def test_tensor_life_3_plus_1(self, new_portal): - portal, tensor = new_portal(3) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - - another_tensor = torch.rand(1, requires_grad=True) - portal.put_tensor(another_tensor, tensor_life=1) - portal.blue() - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_stash_pop.py b/test/distributed/pipeline/sync/skip/test_stash_pop.py deleted file mode 100644 index 5d273860f6a6c..0000000000000 --- a/test/distributed/pipeline/sync/skip/test_stash_pop.py +++ /dev/null @@ -1,144 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
-import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.tracker import SkipTracker, use_skip_tracker -from torch.testing._internal.common_utils import run_tests - - -@pytest.fixture(autouse=True) -def skip_tracker(): - skip_tracker = SkipTracker() - with use_skip_tracker(skip_tracker): - yield skip_tracker - - -def test_stash(skip_tracker): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - l1 = Stash() - - assert len(skip_tracker.tensors) == 0 - - with use_skip_tracker(skip_tracker): - l1(torch.tensor(42)) - - assert len(skip_tracker.tensors) == 1 - - -def test_pop(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - foo = yield pop("foo") - return foo - - l1 = Stash() - l2 = Pop() - - output = l2(l1(torch.tensor(42))) - - assert output.item() == 42 - - -def test_declare_but_not_use(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - return input * 2 - - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - return input * 3 - - l1 = Stash() - l2 = Pop() - - with pytest.raises(RuntimeError): - l1(torch.tensor(42)) - - with pytest.raises(RuntimeError): - l2(torch.tensor(42)) - - -def test_stash_not_declared(): - @skippable() - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - l1 = Stash() - - with pytest.raises(RuntimeError): - l1(torch.tensor(42)) - - -def test_pop_not_declared(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - @skippable() - class Pop(nn.Module): - def forward(self, input): - foo = yield pop("foo") - return foo - - l1 = Stash() - l2 = Pop() - - latent = l1(torch.tensor(42)) - - with pytest.raises(RuntimeError): - l2(latent) - - -def test_pop_not_stashed(): - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - yield pop("foo") - - l1 = Pop() - - with pytest.raises(RuntimeError): - l1(torch.tensor(42)) - - -def test_stash_none(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", None) - return input * 2 # noqa: B901 - - l1 = Stash() - l1(torch.tensor(42)) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_tracker.py b/test/distributed/pipeline/sync/skip/test_tracker.py deleted file mode 100644 index 9c3a970f7574e..0000000000000 --- a/test/distributed/pipeline/sync/skip/test_tracker.py +++ /dev/null @@ -1,145 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
-import threading -from queue import Queue - -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync.checkpoint import ( - enable_checkpointing, - enable_recomputing, -) -from torch.distributed.pipeline.sync.microbatch import Batch -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.layout import SkipLayout -from torch.distributed.pipeline.sync.skip.tracker import ( - current_skip_tracker, - SkipTracker, - SkipTrackerThroughPotals, -) -from torch.testing._internal.common_utils import run_tests - - -def test_default_skip_tracker(): - q = Queue() - - def f(): - q.put(current_skip_tracker()) - - t = threading.Thread(target=f) - t.start() - t.join() - - skip_tracker = q.get() - - assert type(skip_tracker) is SkipTracker - assert type(skip_tracker) is not SkipTrackerThroughPotals - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def test_default_skip_tracker_by_data_parallel(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - foo = yield pop("foo") - return foo - - model = nn.Sequential(Stash(), Pop()) - model = nn.DataParallel(model, device_ids=[0, 0], output_device=0) - - input = torch.rand(10, device=0) - output = model(input) - - assert torch.allclose(output, input) - - -def test_reuse_portal(): - skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - a = torch.tensor([2.0]) - b = torch.tensor([2.0]) - - skip_tracker.save(batch, None, "test", a) - portal = skip_tracker.portals[(None, "test")] - - skip_tracker.save(batch, None, "test", b) - assert portal is skip_tracker.portals[(None, "test")] - - -def test_no_copy_no_portal(): - skip_layout = SkipLayout( - num_partitions=2, - skip_routes={(None, "copy"): (0, 1), (None, "not_copy"): (0, 0)}, - ) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - a = torch.tensor([2.0]) - b = torch.tensor([2.0]) - - skip_tracker.save(batch, None, "copy", a) - skip_tracker.save(batch, None, "not_copy", b) - - assert (None, "copy") in skip_tracker.portals - assert (None, "copy") not in skip_tracker.tensors - assert (None, "not_copy") in skip_tracker.tensors - assert (None, "not_copy") not in skip_tracker.portals - - -def test_tensor_life_without_checkpointing(): - skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - tensor = torch.tensor([2.0]) - - skip_tracker.save(batch, None, "test", tensor) - assert skip_tracker.portals[(None, "test")].tensor_life == 1 - - skip_tracker.load(batch, None, "test") - assert skip_tracker.portals[(None, "test")].tensor_life == 0 - - -def test_tensor_life_with_checkpointing(): - skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - tensor = torch.tensor([2.0]) - - with enable_checkpointing(): - skip_tracker.save(batch, None, "test", tensor) - assert skip_tracker.portals[(None, "test")].tensor_life == 2 - - with enable_checkpointing(): - skip_tracker.load(batch, None, "test") - assert skip_tracker.portals[(None, 
"test")].tensor_life == 1 - - with enable_recomputing(): - skip_tracker.load(batch, None, "test") - assert skip_tracker.portals[(None, "test")].tensor_life == 0 - - with enable_recomputing(): - skip_tracker.save(batch, None, "test", tensor) - assert skip_tracker.portals[(None, "test")].tensor_life == 0 - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_verify_skippables.py b/test/distributed/pipeline/sync/skip/test_verify_skippables.py deleted file mode 100644 index 1d5941487da87..0000000000000 --- a/test/distributed/pipeline/sync/skip/test_verify_skippables.py +++ /dev/null @@ -1,165 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -from torch import nn - -from torch.distributed.pipeline.sync.skip import Namespace, skippable, verify_skippables -from torch.testing._internal.common_utils import run_tests - - -def test_matching(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - verify_skippables(nn.Sequential(Layer1(), Layer2())) - - -def test_stash_not_pop(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1())) - assert "no module declared 'foo' as poppable but stashed" in str(e.value) - - -def test_pop_unknown(): - @skippable(pop=["foo"]) - class Layer1(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1())) - assert "'0' declared 'foo' as poppable but it was not stashed" in str(e.value) - - -def test_stash_again(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(stash=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer3(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) - assert "'1' redeclared 'foo' as stashable" in str(e.value) - - -def test_pop_again(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer3(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) - assert "'2' redeclared 'foo' as poppable" in str(e.value) - - -def test_stash_pop_together_different_names(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"], stash=["bar"]) - class Layer2(nn.Module): - pass - - @skippable(pop=["bar"]) - class Layer3(nn.Module): - pass - - verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) - - -def test_stash_pop_together_same_name(): - @skippable(stash=["foo"], pop=["foo"]) - class Layer1(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1())) - assert "'0' declared 'foo' both as stashable and as poppable" in str(e.value) - - -def test_double_stash_pop(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(stash=["foo"]) - class Layer3(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer4(nn.Module): - pass - - with pytest.raises(TypeError) as e: - 
verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3(), Layer4())) - assert "'2' redeclared 'foo' as stashable" in str(e.value) - assert "'3' redeclared 'foo' as poppable" in str(e.value) - - -def test_double_stash_pop_but_isolated(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(stash=["foo"]) - class Layer3(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer4(nn.Module): - pass - - ns1 = Namespace() - ns2 = Namespace() - - verify_skippables( - nn.Sequential( - Layer1().isolate(ns1), - Layer2().isolate(ns1), - Layer3().isolate(ns2), - Layer4().isolate(ns2), - ) - ) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_balance.py b/test/distributed/pipeline/sync/test_balance.py deleted file mode 100644 index faf09f4581ae7..0000000000000 --- a/test/distributed/pipeline/sync/test_balance.py +++ /dev/null @@ -1,240 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import time - -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync._balance import ( - balance_by_size, - balance_by_time, - blockpartition, -) -from torch.distributed.pipeline.sync._balance.profile import layerwise_sandbox -from torch.testing._internal.common_utils import run_tests - -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="cuda required" -) - -devices = ["cpu"] -if torch.cuda.is_available(): - devices.append("cuda") - - -def test_blockpartition(): - assert blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2) == [ - [1, 2, 3, 4], - [5, 6], - ] - - -def test_blockpartition_zeros(): - assert blockpartition.solve([0, 0], partitions=2) == [[0], [0]] - - -def test_blockpartition_non_positive_partitions(): - with pytest.raises(ValueError): - blockpartition.solve([42], partitions=0) - with pytest.raises(ValueError): - blockpartition.solve([42], partitions=-1) - - -def test_blockpartition_short_sequence(): - with pytest.raises(ValueError): - blockpartition.solve([], partitions=1) - with pytest.raises(ValueError): - blockpartition.solve([42], partitions=2) - - -@pytest.mark.parametrize("device", devices) -@pytest.mark.skip(reason="Flaky due to time.sleep()") -def test_balance_by_time(device): - class Delay(nn.Module): - def __init__(self, seconds): - super().__init__() - self.seconds = seconds - - def forward(self, x): - time.sleep(self.seconds) - return x - - model = nn.Sequential(*[Delay(i / 10) for i in [1, 2, 3, 4, 5, 6]]) - sample = torch.rand(1) - balance = balance_by_time(2, model, sample, device=device) - assert balance == [4, 2] - - -def test_balance_by_time_loop_resets_input(): - # nn.Flatten was introduced at PyTorch 1.2.0. 
- class Flatten(nn.Module): - def forward(self, x): - return x.flatten(1) - - model = nn.Sequential(nn.Conv2d(3, 2, 1), Flatten(), nn.Linear(128, 10)) - sample = torch.rand(10, 3, 8, 8) - balance = balance_by_time(2, model, sample, device="cpu") - assert balance == [1, 2] - - -@skip_if_no_cuda -def test_balance_by_size_latent(): - class Expand(nn.Module): - def __init__(self, times): - super().__init__() - self.times = times - - def forward(self, x): - for i in range(self.times): - x = x + torch.rand_like(x, requires_grad=True) - return x - - sample = torch.rand(10, 100, 100) - - model = nn.Sequential(*[Expand(i) for i in [1, 2, 3, 4, 5, 6]]) - balance = balance_by_size(2, model, sample) - assert balance == [4, 2] - - model = nn.Sequential(*[Expand(i) for i in [6, 5, 4, 3, 2, 1]]) - balance = balance_by_size(2, model, sample) - assert balance == [2, 4] - - -@skip_if_no_cuda -def test_balance_by_size_param(): - model = nn.Sequential(*[nn.Linear(i + 1, i + 2) for i in range(6)]) - sample = torch.rand(7, 1) - balance = balance_by_size(2, model, sample, param_scale=100) - assert balance == [4, 2] - - model = nn.Sequential(*[nn.Linear(i + 2, i + 1) for i in reversed(range(6))]) - sample = torch.rand(1, 7) - balance = balance_by_size(2, model, sample, param_scale=100) - assert balance == [2, 4] - - -@skip_if_no_cuda -def test_balance_by_size_param_scale(): - class Tradeoff(nn.Module): - def __init__(self, param_size, latent_size): - super().__init__() - self.fc = nn.Linear(param_size, param_size) - self.latent_size = latent_size - - def forward(self, x): - for i in range(self.latent_size): - x = x + torch.rand_like(x, requires_grad=True) - return x - - model = nn.Sequential( - Tradeoff(param_size=1, latent_size=6), - Tradeoff(param_size=2, latent_size=5), - Tradeoff(param_size=3, latent_size=4), - Tradeoff(param_size=4, latent_size=3), - Tradeoff(param_size=5, latent_size=2), - Tradeoff(param_size=6, latent_size=1), - ) - - sample = torch.rand(1, requires_grad=True) - - balance = balance_by_size(2, model, sample, param_scale=0) - assert balance == [2, 4] - - balance = balance_by_size(2, model, sample, param_scale=100) - assert balance == [4, 2] - - -@pytest.mark.parametrize("device", devices) -def test_layerwise_sandbox(device): - model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) - model.eval() - - for layer in layerwise_sandbox(model, torch.device(device)): - assert layer.training - assert all(p.device.type == device for p in layer.parameters()) - - assert all(not l.training for l in model) - assert all(p.device.type == "cpu" for p in model.parameters()) - - -@pytest.mark.parametrize("device", devices) -def test_sandbox_during_profiling(device): - model = nn.Sequential(nn.BatchNorm2d(3)) - - before = {k: v.clone() for k, v in model.state_dict().items()} - - sample = torch.rand(1, 3, 10, 10) - balance_by_time(1, model, sample, device=device) - - after = model.state_dict() - - assert before.keys() == after.keys() - for key, value in before.items(): - assert torch.allclose(after[key], value), key - - -def test_not_training(): - class AssertTraining(nn.Module): - def forward(self, x): - assert self.training - return x - - model = nn.Sequential(AssertTraining()) - - model.eval() - assert not model.training - - sample = torch.rand(1) - balance_by_time(1, model, sample, device="cpu") - - assert not model.training - - -def test_balance_by_time_tuple(): - class Twin(nn.Module): - def forward(self, x): - return x, x.detach() - - class Add(nn.Module): - def forward(self, a, b): - return a + b - 
- model = nn.Sequential(Twin(), Add()) - sample = torch.rand(1, requires_grad=True) - balance_by_time(1, model, sample, device="cpu") - - -@skip_if_no_cuda -def test_balance_by_size_tuple(): - class Twin(nn.Module): - def forward(self, x): - return x, x.detach() - - class Add(nn.Module): - def forward(self, a, b): - return a + b - - model = nn.Sequential(Twin(), Add()) - sample = torch.rand(1, requires_grad=True) - balance_by_size(1, model, sample) - - -def test_already_has_grad(): - model = nn.Sequential(nn.Conv2d(3, 3, 1)) - sample = torch.rand(1, 3, 32, 32) - model(sample).norm().backward() - - with pytest.raises(ValueError, match="some parameter already has gradient"): - balance_by_time(1, model, sample, device="cpu") - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_bugs.py b/test/distributed/pipeline/sync/test_bugs.py deleted file mode 100644 index 928a78db6e325..0000000000000 --- a/test/distributed/pipeline/sync/test_bugs.py +++ /dev/null @@ -1,146 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -import torch.nn.functional as F -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.testing._internal.common_cuda import TEST_MULTIGPU -from torch.testing._internal.common_utils import run_tests - - -def test_python_autograd_function(setup_rpc): - # A Python autograd function might fail with this error: - # - # RuntimeError: Returning Variables sharing storage with other Variables - # that require grad is not supported in Python functions. Please submit a - # feature request if you hit this error. - # - # It doesn't look like an essential restriction. But it happens on the - # current PyTorch version. To avoid it, we should detach the tensor before - # returning by identity autograd functions, such as Wait, Fork, and Join. - # - class Identity(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - return input - - @staticmethod - def backward(ctx, grad): - return grad - - class M(nn.Module): - def forward(self, input): - return Identity.apply(input) - - model = nn.Sequential(M(), M()) - model = Pipe(model, checkpoint="always") - - x = torch.rand(42) - y = model(x) - assert torch.allclose(x, y.local_value()) - - -def test_exception_no_hang(setup_rpc): - # In v0.0.2, once a failed partition receives a normal message - # (non-closing) for the next micro-batch, a hang occurred. The reason was - # that a failed partition didn't call in_queue.task_done() on a normal - # message. So the former partition was blocked at out_queue.join() for the - # next of next micro-batch. - class ExpectedException(Exception): - pass - - class Pass(nn.Module): - def forward(self, x): - return x - - class Raise(nn.Module): - def forward(self, x): - raise ExpectedException - - model = nn.Sequential(Pass(), Pass(), Raise()) - model = Pipe(model, chunks=3) - - with pytest.raises(ExpectedException): - model(torch.rand(3)) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="2 cuda devices required") -def test_tuple_wait(cuda_sleep, setup_rpc): - # In v0.0.3, Wait is applied to only the first tensor on a micro-batch. 
- # Under this behavior, if checkpointing was disabled, there's a possibility - # that gradient accumulations on other tensors are not synchronized - # properly to the copy stream. - class Sleep(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - return x.detach() - - @staticmethod - def backward(ctx, grad): - with torch.cuda.device(grad.device): - cuda_sleep(0.05) - return grad - - class Layer1(nn.Module): - def __init__(self): - super().__init__() - self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) - - def forward(self, a, b): - a = a * self.ones - return a * 1, b * 2, b * 3 - - class Layer2(nn.Module): - def __init__(self): - super().__init__() - self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) - - def forward(self, a, b, c): - a = a * self.ones - b = Sleep.apply(b) - return a + b + c - - model = nn.Sequential(Layer1().cuda(0), Layer2().cuda(1)) - model = Pipe(model, chunks=32, checkpoint="never") - - a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) - b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) - - y = model(a, b) - y.local_value().norm().backward() - - torch.cuda.synchronize(0) - torch.cuda.synchronize(1) - - assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000)) - - -def test_parallel_randoms(setup_rpc): - class Dropouts(nn.Module): - def forward(self, x): - for _ in range(100): - x = F.dropout(x, p=0.001) - return x - - model = nn.Sequential(Dropouts(), Dropouts()) - - x = torch.rand(10, 10, requires_grad=True) - model = Pipe(model, chunks=10, checkpoint="always") - y = model(x) - y = y.local_value() - y.norm().backward() - - assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist() - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_checkpoint.py b/test/distributed/pipeline/sync/test_checkpoint.py deleted file mode 100644 index 7be8ddefafe9e..0000000000000 --- a/test/distributed/pipeline/sync/test_checkpoint.py +++ /dev/null @@ -1,178 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -from functools import partial - -import pytest - -import torch -import torch.cuda -from torch import nn - -from torch.distributed.pipeline.sync.checkpoint import ( - checkpoint, - Checkpointing, - is_checkpointing, - is_recomputing, -) -from torch.distributed.pipeline.sync.dependency import fork, join -from torch.distributed.pipeline.sync.microbatch import Batch -from torch.testing._internal.common_utils import run_tests - -devices = ["cpu"] -if torch.cuda.is_available(): - devices.append("cuda") - - -@pytest.mark.parametrize("device", devices) -def test_serial_checkpoints(device): - # Copied from https://github.com/pytorch/pytorch/pull/18568. - timeline = [] - - class Log(torch.autograd.Function): - @staticmethod - def forward(ctx, name, x): - ctx.name = name - timeline.append(f"{name}:forward") - return x.detach() - - @staticmethod - def backward(ctx, grad_output): - name = ctx.name - timeline.append(f"{name}:backward") - return None, grad_output - - a = torch.rand(1, device=device, requires_grad=True) - b = torch.rand(1, device=device, requires_grad=True) - - # Increase the next function sequence number. 
- _ = a + 1 + 2 + 3 + 4 + 5 - - a = checkpoint(partial(Log.apply, "a"), a) - - a, phony = fork(a) - b = join(b, phony) - - b = checkpoint(partial(Log.apply, "b"), b) - - c = torch.cat((a, b)) - - out = c.sum() - - # +--> {a} --Checkpoint(Log)--> {a} - # {out} --Sum--> {c} --Cat ^-----------------------------+ - # +--> {b} --Checkpoint(Log)--> {b} --First--> {b} - out.backward() - - assert timeline == [ - "a:forward", - "b:forward", - "b:forward", - "b:backward", - "a:forward", - "a:backward", - ] - # |----------------------| |-----------------------| |-----------------------| - # forward pass Checkpoint(Log[b]) Checkpoint(Log[a]) - - -def test_not_requires_grad(): - x = Batch(torch.rand(1, requires_grad=False)) - assert not x[0].requires_grad - - def f(x): - return x * 2 - - chk = Checkpointing(f, x) - x = chk.checkpoint() - assert x[0].requires_grad - - chk.recompute(x) - assert x[0].requires_grad - - x.tensor.backward() - - -def test_not_requires_grad_with_parameter(): - x = torch.rand(1, requires_grad=False) - a = torch.rand(1, requires_grad=True) - - def f(x): - return x * a - - y = checkpoint(f, x) - y.backward() - - assert a.grad is not None - - -@pytest.mark.parametrize("device", devices) -def test_random_in_checkpoint(device): - dropout = nn.Dropout(p=0.5) - - torch.manual_seed(0) - x = torch.randn(3, 3, device=device, requires_grad=True) - y = dropout(x) - y.norm().backward() - - torch.manual_seed(0) - chk_x = torch.randn(3, 3, device=device, requires_grad=True) - chk_y = checkpoint(dropout, chk_x) - chk_y.norm().backward() - - assert torch.allclose(x.grad, chk_x.grad) - - -def test_detect_checkpointing_recomputing(): - logs = [] - - class Detect(nn.Module): - def forward(self, input): - logs.append((is_checkpointing(), is_recomputing())) - return input - - model = Detect() - input = torch.rand(1, requires_grad=True) - - output = checkpoint(model, input) - output.backward() - - assert logs == [(True, False), (False, True)] - - -def test_detect_checkpointing_recomputing_without_checkpoint(): - logs = [] - - class Detect(nn.Module): - def forward(self, input): - logs.append((is_checkpointing(), is_recomputing())) - return input - - model = Detect() - input = torch.rand(1, requires_grad=True) - - output = model(input) - output.backward() - - assert logs == [(False, False)] - - -def test_non_grad_output(): - class ForkNonGrad(nn.Module): - def forward(self, input): - return (input * 2, torch.rand(1)) - - model = ForkNonGrad() - input = torch.rand(1, requires_grad=True) - - output = checkpoint(model, input) - output[0].backward() - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_copy.py b/test/distributed/pipeline/sync/test_copy.py deleted file mode 100644 index 302c3d25d53f4..0000000000000 --- a/test/distributed/pipeline/sync/test_copy.py +++ /dev/null @@ -1,85 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
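A note on the checkpoint tests deleted above: the pipeline-specific helpers (checkpoint, is_checkpointing, is_recomputing) go away with this diff, but the recompute-on-backward pattern they exercised is the same one provided by core torch.utils.checkpoint, which this removal does not touch. A minimal sketch of that surviving API follows; the Probe module and the calls counter are illustrative, not taken from the deleted tests.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint  # core PyTorch, unaffected by this removal

calls = {"forward": 0}

class Probe(nn.Module):
    def forward(self, x):
        calls["forward"] += 1
        return x * 2

probe = Probe()
x = torch.rand(1, requires_grad=True)
y = checkpoint(probe, x, use_reentrant=False)  # activations are dropped in the forward pass
y.backward()                                   # ...and the forward is re-run during backward
assert calls["forward"] == 2                   # one real pass plus one recomputation
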
-import pytest - -import torch - -from torch.distributed.pipeline.sync.copy import Copy, Wait -from torch.distributed.pipeline.sync.stream import ( - CPUStream, - current_stream, - get_device, - is_cuda, - new_stream, - use_stream, -) -from torch.testing._internal.common_utils import run_tests - -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="cuda required" -) - - -def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None): - device = get_device(prev_stream) - - with use_stream(prev_stream): - if is_cuda(prev_stream): - cuda_sleep(0.5) - x = torch.ones(100, device=device, requires_grad=True) - - (y,) = Copy.apply(prev_stream, next_stream, x) - (y,) = Wait.apply(prev_stream, next_stream, x) - - with use_stream(next_stream): - assert torch.allclose(y.sum(), torch.tensor(100.0, device=device)) - y.norm().backward() - with use_stream(prev_stream): - assert torch.allclose(x.grad.sum(), torch.tensor(10.0, device=device)) - - -def test_copy_wait_cpu_cpu(): - prev_stream = CPUStream - next_stream = CPUStream - _test_copy_wait(prev_stream, next_stream) - - -@skip_if_no_cuda -def test_copy_wait_cpu_cuda(cuda_sleep): - prev_stream = CPUStream - next_stream = current_stream(torch.device("cuda")) - _test_copy_wait(prev_stream, next_stream, cuda_sleep) - - -@skip_if_no_cuda -def test_copy_wait_cuda_cpu(cuda_sleep): - prev_stream = current_stream(torch.device("cuda")) - next_stream = CPUStream - _test_copy_wait(prev_stream, next_stream, cuda_sleep) - - -@skip_if_no_cuda -def test_copy_wait_cuda_cuda(cuda_sleep): - prev_stream = current_stream(torch.device("cuda")) - next_stream = new_stream(torch.device("cuda")) - _test_copy_wait(prev_stream, next_stream, cuda_sleep) - - -def test_wait_multiple_tensors(): - a = torch.rand(1, requires_grad=True) - b = torch.rand(1, requires_grad=True) - - a, b = Wait.apply(CPUStream, CPUStream, a, b) - - assert a.grad_fn is b.grad_fn - assert a.grad_fn.__class__ is Wait._backward_cls - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_deferred_batch_norm.py b/test/distributed/pipeline/sync/test_deferred_batch_norm.py deleted file mode 100644 index c3807c57d612e..0000000000000 --- a/test/distributed/pipeline/sync/test_deferred_batch_norm.py +++ /dev/null @@ -1,200 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -from copy import deepcopy -from itertools import chain - -import pytest - -import torch -from torch import nn, optim - -from torch.distributed.pipeline.sync.batchnorm import DeferredBatchNorm -from torch.testing._internal.common_utils import run_tests - -CHUNKS = 4 - - -def tilt_dist(input): - # Tilt variance by channel. - rgb = input.transpose(0, 1) - rgb[0] *= 1 - rgb[1] *= 10 - rgb[2] *= 100 - - # Tilt mean by single batch. 
- for i, single in enumerate(input): - single += 2**i - - return input - - -def chunked_forward(model, input, chunks=CHUNKS): - output_chunks = [] - - for chunk in input.chunk(chunks): - output_chunks.append(model(chunk)) - - return torch.cat(output_chunks) - - -@pytest.mark.parametrize("chunks", [1, 4]) -@pytest.mark.parametrize("input_requires_grad", [True, False]) -def test_transparency(chunks, input_requires_grad): - bn = nn.BatchNorm2d(3) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=chunks) - - input1 = torch.rand(16, 3, 224, 224) - input1 = tilt_dist(input1) - input2 = input1.clone() - input1.requires_grad = input_requires_grad - input2.requires_grad = input_requires_grad - - output1 = chunked_forward(bn, input1, chunks=chunks) - output2 = chunked_forward(dbn, input2, chunks=chunks) - - assert torch.allclose(output1, output2, atol=1e-4) - - output1.mean().backward() - output2.mean().backward() - - assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4) - - if input_requires_grad: - assert input1.grad is not None - assert input2.grad is not None - assert torch.allclose(input1.grad, input2.grad, atol=1e-4) - - -@pytest.mark.parametrize("momentum", [0.1, None]) -def test_running_stats(momentum): - bn = nn.BatchNorm2d(3, momentum=momentum) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - bn(input) - chunked_forward(dbn, input) - - assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4) - assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4) - - -def test_convert_deferred_batch_norm(): - bn = nn.BatchNorm2d(3, track_running_stats=False) - bn = DeferredBatchNorm.convert_deferred_batch_norm(bn, chunks=CHUNKS) - assert type(bn) is nn.BatchNorm2d # because of track_running_stats=False - - dbn = DeferredBatchNorm(3, chunks=CHUNKS) - dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS) - assert dbn is dbn_again - - dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS + 1) - assert dbn is not dbn_again # because of different chunks - - -def test_eval(): - bn = nn.BatchNorm2d(3) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - bn(input) - chunked_forward(dbn, input) - - bn.eval() - dbn.eval() - - assert torch.allclose(bn(input), dbn(input), atol=1e-4) - - -def test_optimize(): - bn = nn.BatchNorm2d(3) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0) - - for i in range(5): - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - # train - y = bn(input) - a = y.sum() - a.backward() - - y = chunked_forward(dbn, input) - b = y.sum() - b.backward() - - opt.step() - - # eval - bn.eval() - dbn.eval() - - with torch.no_grad(): - assert torch.allclose(bn(input), dbn(input), atol=1e-1 * (10**i)) - - -def test_conv_bn(): - bn = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1) - - # 1st step - a = bn(input) - b = chunked_forward(dbn, input) - - # Outputs are different. (per-mini-batch vs. 
per-micro-batch) - assert not torch.allclose(a, b) - - a.sum().backward() - b.sum().backward() - opt.step() - opt.zero_grad() - - # Conv layers are also trained differently because of their different outputs. - assert not torch.allclose(bn[0].weight, dbn[0].weight) - - # But BNs track identical running stats. - assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) - assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3) - - # 2nd step - a = bn(input) - b = chunked_forward(dbn, input) - a.sum().backward() - b.sum().backward() - - # BNs can't track identical running stats due to the different conv layers. - assert not torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) - assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3) - - -def test_input_requiring_grad(): - dbn = DeferredBatchNorm(3, chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - input.requires_grad = True - - chunked_forward(dbn, input) - - assert not dbn.sum.requires_grad - assert dbn.sum.grad_fn is None - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_dependency.py b/test/distributed/pipeline/sync/test_dependency.py deleted file mode 100644 index e966d6541bf59..0000000000000 --- a/test/distributed/pipeline/sync/test_dependency.py +++ /dev/null @@ -1,152 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import weakref - -import pytest - -import torch - -from torch.distributed.pipeline.sync.dependency import Fork, fork, Join, join -from torch.testing._internal.common_utils import run_tests - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def test_fork_join(): - logs = [] - - class Log(torch.autograd.Function): - @staticmethod - def forward(ctx, number, tensor): - ctx.number = number - return tensor.detach() - - @staticmethod - def backward(ctx, grad): - logs.append(ctx.number) - return None, grad - - a = torch.rand(1, device="cpu", requires_grad=True) - b = torch.rand(1, device="cuda", requires_grad=True) - - a = Log.apply(1, a) - - a, phony = fork(a) - b = join(a, phony) - - b = Log.apply(2, b) - b = b.to("cpu") - - (a + b).backward() - - assert logs == [2, 1] - - -def test_fork_join_enable_grad(): - x = torch.rand(1, requires_grad=True) - - with torch.enable_grad(): - x2, p = fork(x) - - assert p.requires_grad - assert x2 is not x - x = x2 - - assert x.requires_grad - assert p.requires_grad - assert x.grad_fn.__class__ is Fork._backward_cls - assert p.grad_fn.__class__ is Fork._backward_cls - - with torch.enable_grad(): - x2 = join(x, p) - - assert x2 is not x - x = x2 - - assert x.requires_grad - assert x.grad_fn.__class__ is Join._backward_cls - - -def test_fork_join_no_grad(monkeypatch): - def do_not_apply(*args): - raise AssertionError("Function.apply called") - - monkeypatch.setattr("torch.autograd.Function.apply", do_not_apply) - - x = torch.rand(1, requires_grad=True) - - with torch.no_grad(): - x2, p = fork(x) - - assert not p.requires_grad - assert x2 is x - x = x2 - - with torch.no_grad(): - x2 = join(x, p) - - assert x2 is x - x = x2 - - -def test_fork_leak(): - leak = None - - class F(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - return input - - @staticmethod - def backward(ctx, 
grad): - nonlocal leak - leak = weakref.ref(ctx) - return grad - - x = torch.rand(1, requires_grad=True) - x = F.apply(x) - x, phony = fork(x) - x = join(x, phony) - - x.backward() - del x, phony - - assert leak() is None - - -def test_join_when_fork_not_requires_grad(): - x = torch.rand(2, 1) - a, b = x.chunk(2) - - assert not a.requires_grad - a, p = fork(a) - assert not a.requires_grad - assert not p.requires_grad - - assert not b.requires_grad - b = join(b, p) - assert not b.requires_grad - - -def test_join_when_fork_requires_grad(): - x = torch.rand(2, 1) - a, b = x.chunk(2) - - a.requires_grad_() - assert a.requires_grad - a, p = fork(a) - assert a.requires_grad - assert p.requires_grad - - assert not b.requires_grad - b = join(b, p) - assert b.requires_grad - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_inplace.py b/test/distributed/pipeline/sync/test_inplace.py deleted file mode 100644 index 33f31b2a52bb8..0000000000000 --- a/test/distributed/pipeline/sync/test_inplace.py +++ /dev/null @@ -1,79 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.testing._internal.common_utils import run_tests - - -def test_inplace_on_requires_grad(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) - model = Pipe(model, checkpoint="always") - - x = torch.rand(1) - y = model(x).local_value() - - message = r"a leaf Variable that requires grad .* used in an in-place operation." - with pytest.raises(RuntimeError, match=message): - y.backward() - - -@pytest.mark.xfail(strict=True) -def test_inplace_on_not_requires_grad(setup_rpc): - # In-place operation on a tensor not requiring grad doesn't cause a - # RuntimeError. Currently, we cannot detect this case. - model = nn.Sequential(nn.ReLU(inplace=True)) - model = Pipe(model, [1], devices=["cpu"], checkpoint="always") - - x = torch.rand(1) - y = model(x).local_value() - del model - - message = r"a leaf Variable that requires grad .* used in an in-place operation." - with pytest.raises(RuntimeError, match=message): - y.backward() - - -@pytest.mark.xfail(strict=True) -def test_inplace_incorrect_grad(setup_rpc): - class M(nn.Module): - def forward(self, foo_bar): - # 'foo' requires grad but 'bar' does not. In-place operation on - # 'bar' won't cause a RuntimeError. - foo, bar = foo_bar - - # add_(1) is not idempotent, in contrast to relu_(). If it is - # executed multiple times, it will accumulates each difference onto - # 'bar'. - bar.add_(1) - - # 'bar' is still captured by checkpointing. 'foo' will get - # incorrect grad. - return foo * bar - - model = nn.Sequential(M()) - model = Pipe(model, [1], devices=["cpu"], checkpoint="always") - - foo = torch.tensor([1.0], requires_grad=True) - bar = torch.tensor([1.0]) - - output = model((foo, bar)).local_value() - del model - output.backward() - - # The gradient of 'foo' should be 2, but it is 3 actually because - # bar.add_(1) was executed twice due to checkpointing. 
- assert foo.grad.item() == 2.0 - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_microbatch.py b/test/distributed/pipeline/sync/test_microbatch.py deleted file mode 100644 index b5e44aa73a8d2..0000000000000 --- a/test/distributed/pipeline/sync/test_microbatch.py +++ /dev/null @@ -1,148 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -import torch.cuda - -from torch.distributed.pipeline.sync.microbatch import Batch, check, gather, scatter -from torch.testing._internal.common_utils import run_tests - - -def test_batch_atomic(): - x = torch.tensor(42) - b = Batch(x) - - assert b.atomic - - assert b.tensor is x - with pytest.raises(AttributeError): - b.tensors - - assert list(b) == [x] - assert len(b) == 1 - assert b[0] is x - - -def test_batch_non_atomic(): - x, y = torch.tensor(42), torch.tensor(21) - b = Batch((x, y)) - - assert not b.atomic - - with pytest.raises(AttributeError): - b.tensor - - assert list(b) == [x, y] - assert len(b) == 2 - assert b[0] is x - assert b[1] is y - - -def test_batch_call(): - a = Batch(torch.tensor(42)) - b = Batch((torch.tensor(42), torch.tensor(21))) - - def f(x): - return x - - def g(x, y): - return x, y - - assert a.call(f).atomic - assert not b.call(g).atomic - - -def test_batch_setitem_by_index(): - a = Batch(torch.tensor(42)) - b = Batch((torch.tensor(42), torch.tensor(21))) - - a[0] = torch.tensor(0) - b[0] = torch.tensor(0) - - assert a.atomic - assert a[0].item() == 0 - - assert not b.atomic - assert len(b) == 2 - assert b[0].item() == 0 - assert b[1].item() == 21 - - -def test_batch_setitem_by_slice(): - a = Batch(torch.tensor(42)) - b = Batch((torch.tensor(42), torch.tensor(21))) - - a[:] = (torch.tensor(0),) - b[:] = (torch.tensor(0),) - - assert a.atomic - assert a[0].item() == 0 - - assert not b.atomic - assert len(b) == 1 - assert b[0].item() == 0 - - -def test_check(): - check(torch.device("cpu"), torch.tensor(42)) - check(torch.device("cpu"), torch.tensor(4), torch.tensor(2)) - - with pytest.raises(TypeError): - check(torch.device("cpu"), 42) - - with pytest.raises(TypeError): - check(torch.device("cpu"), "str") - - with pytest.raises(TypeError): - check(torch.device("cpu"), (torch.tensor(4), 2)) - - -def test_gather_tensors(): - a = torch.zeros(1, 1) - b = torch.zeros(1, 1) - - ab = gather([Batch(a), Batch(b)]) - - assert ab.size() == (2, 1) - - -def test_gather_tuples(): - a = (torch.zeros(1, 1), torch.zeros(2, 2)) - b = (torch.zeros(1, 1), torch.zeros(2, 2)) - - ab = gather([Batch(a), Batch(b)]) - - assert isinstance(ab, tuple) - assert ab[0].size() == (2, 1) - assert ab[1].size() == (4, 2) - - -def test_scatter_tensor(): - ab = torch.zeros(2, 1) - - a, b = scatter(ab, chunks=2) - - assert a.tensor.size() == (1, 1) - assert b.tensor.size() == (1, 1) - - -def test_scatter_multiple_tensors(): - ab = (torch.zeros(2, 1), torch.zeros(4, 2)) - - a, b = scatter(*ab, chunks=2) - - assert next(iter(a)).size() == (1, 1) - assert next(iter(b)).size() == (1, 1) - assert list(a)[1].size() == (2, 2) - assert list(b)[1].size() == (2, 2) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_phony.py b/test/distributed/pipeline/sync/test_phony.py deleted file mode 100644 index 6aeb873b30b2b..0000000000000 --- 
a/test/distributed/pipeline/sync/test_phony.py +++ /dev/null @@ -1,57 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import torch - -from torch.distributed.pipeline.sync.phony import get_phony -from torch.testing._internal.common_utils import run_tests - - -def test_phony_size(): - p = get_phony(torch.device("cpu"), requires_grad=False) - assert p.size() == (0,) - - -def test_phony_requires_grad(): - p1 = get_phony(torch.device("cpu"), requires_grad=True) - p2 = get_phony(torch.device("cpu"), requires_grad=False) - assert p1.requires_grad - assert not p2.requires_grad - - -def test_cached_phony(): - p1 = get_phony(torch.device("cpu"), requires_grad=True) - p2 = get_phony(torch.device("cpu"), requires_grad=True) - assert p1 is p2 - - p3 = get_phony(torch.device("cpu"), requires_grad=False) - p4 = get_phony(torch.device("cpu"), requires_grad=False) - assert p3 is p4 - - assert p1 is not p3 - - -def test_phony_in_autograd_function(): - class Phonify(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - phony = get_phony(input.device, requires_grad=False) - return phony.detach() - - x = torch.rand(1, requires_grad=True) - - p1 = Phonify.apply(x) - p2 = get_phony(torch.device("cpu"), requires_grad=True) - - assert p1 is not p2 - assert p1.grad_fn is not None - assert p2.grad_fn is None - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_pipe.py b/test/distributed/pipeline/sync/test_pipe.py deleted file mode 100644 index e493b1d5a03e7..0000000000000 --- a/test/distributed/pipeline/sync/test_pipe.py +++ /dev/null @@ -1,858 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
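For orientation before the large test_pipe.py deletion that follows: the basic usage pattern those tests exercise is sketched below, using the torch.distributed.pipeline.sync API that this diff removes, so it only runs on releases that still ship the package. The single-process RPC bootstrap stands in for the setup_rpc fixture; the worker name and the MASTER_ADDR/MASTER_PORT values are placeholders.

import os
import torch
from torch import nn
from torch.distributed import rpc
from torch.distributed.pipeline.sync import Pipe

# Pipe refuses to run until the RPC framework is initialized
# (see test_pipe_without_rpc in the deleted file below).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker", rank=0, world_size=1)

model = nn.Sequential(nn.Linear(1, 8), nn.Linear(8, 1))
pipe = Pipe(model, chunks=4, checkpoint="except_last")  # 4 micro-batches per mini-batch

out = pipe(torch.rand(16, 1))        # forward returns an RRef
loss = out.local_value().mean()      # unwrap it locally, as the tests do
loss.backward()

rpc.shutdown()
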
-import random -import time -from collections import OrderedDict -from copy import deepcopy - -import pytest - -import torch -from torch import nn, Tensor - -from torch.distributed.pipeline.sync import NoChunk, Pipe, WithDevice -from torch.distributed.pipeline.sync.pipe import PipeSequential -from torch.testing._internal.common_cuda import TEST_MULTIGPU -from torch.testing._internal.common_utils import run_tests, TEST_CUDA - -skip_if_no_cuda = pytest.mark.skipif(not TEST_CUDA, reason="cuda required") - - -def test_pipe_without_rpc(): - model = nn.Sequential(nn.Linear(1, 1)) - with pytest.raises(RuntimeError, match="Please initialize RPC framework"): - pipe = Pipe(model, chunks=1) - - -def test_parameters(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - pipe = Pipe(model, chunks=1) - assert list(pipe.parameters()) != [] - - -def test_public_attrs(setup_rpc): - class MyString: - def __init__(self, value): - self.value = value - - def __str__(self): - return self.value - - model = nn.Sequential(nn.Linear(1, 1)) - pipe = Pipe(model, chunks=42.000, checkpoint=MyString("always")) - - assert pipe.devices == [torch.device("cpu")] - assert pipe.chunks == 42 - assert isinstance(pipe.chunks, int) - assert pipe.checkpoint == "always" - assert isinstance(pipe.checkpoint, str) - - -def test_sequential_like(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(a, b) - model = Pipe(model) - - assert len(model) == 2 - assert list(model) == [a, b] - - assert model[0] is a - assert model[1] is b - with pytest.raises(IndexError): - _ = model[2] - - assert model[-1] is b - assert model[-2] is a - - -def test_chunks_less_than_1(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - - with pytest.raises(ValueError): - Pipe(model, chunks=0) - - with pytest.raises(ValueError): - Pipe(model, chunks=-1) - - -def test_batch_size_indivisible(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=4) - - with pytest.warns(None) as record: - model(torch.rand(7, 1)) - - # Indivisible batch size is legal. - assert not record - - -def test_batch_size_small(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=4) - - with pytest.warns(None) as record: - model(torch.rand(2, 1)) - - # Batch size smaller than chunks is legal. 
- assert not record - - -def test_checkpoint_mode(setup_rpc): - def count_grad_fn(grad_fn, name, visited=None): - if visited is None: - visited = set() - if grad_fn in visited: - return 0 - visited.add(grad_fn) - - if grad_fn is None: - return 0 - if grad_fn.__class__.__name__ == name: - return 1 - - counter = 0 - for next_grad_fn, _ in grad_fn.next_functions: - counter += count_grad_fn(next_grad_fn, name, visited=visited) - return counter - - model = nn.Sequential(nn.Linear(1, 1)) - input = torch.rand(2, 1) - - always = Pipe(model, chunks=2, checkpoint="always") - except_last = Pipe(model, chunks=2, checkpoint="except_last") - never = Pipe(model, chunks=2, checkpoint="never") - - always_output = always(input) - except_last_output = except_last(input) - never_output = never(input) - - assert count_grad_fn(always_output.local_value().grad_fn, "CheckpointBackward") == 2 - assert ( - count_grad_fn(except_last_output.local_value().grad_fn, "CheckpointBackward") - == 1 - ) - assert count_grad_fn(never_output.local_value().grad_fn, "CheckpointBackward") == 0 - - -def test_checkpoint_mode_invalid(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - - with pytest.raises( - ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'" - ): - Pipe(model, chunks=2, checkpoint="INVALID_CHECKPOINT") - - -def test_checkpoint_mode_when_chunks_1(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - - # All checkpoint modes are fine. - Pipe(model, chunks=1, checkpoint="except_last") - Pipe(model, chunks=1, checkpoint="always") - Pipe(model, chunks=1, checkpoint="never") - - -def test_checkpoint_eval(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=2) - input = torch.rand(2, 1) - - def find_grad_fn(grad_fn, name): - if grad_fn is None: - return False - if grad_fn.__class__.__name__ == name: - return True - for next_grad_fn, _ in grad_fn.next_functions: - if find_grad_fn(next_grad_fn, name): - return True - return False - - model.train() - train_output = model(input) - assert find_grad_fn(train_output.local_value().grad_fn, "CheckpointBackward") - assert find_grad_fn(train_output.local_value().grad_fn, "RecomputeBackward") - - model.eval() - eval_output = model(input) - assert not find_grad_fn(eval_output.local_value().grad_fn, "CheckpointBackward") - assert not find_grad_fn(eval_output.local_value().grad_fn, "RecomputeBackward") - - -def test_checkpoint_non_float_input(setup_rpc): - class ForkNonFloat(nn.Module): - def forward(self, input): - return (input * 2, torch.tensor([False])) - - class JoinNonFloat(nn.Module): - def forward(self, input, non_float): - return input * 2 - - model = nn.Sequential(ForkNonFloat(), JoinNonFloat()) - model = Pipe(model, chunks=1, checkpoint="always") - - input = torch.rand(1, requires_grad=True) - output = model(input) - output.backward() - - -def test_no_grad(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=2) - input = torch.rand(2, 1) - - latent = None - - def hook(module, input, output): - _ = module - _ = input - - nonlocal latent - latent = output - - partition = model.partitions[0] - partition.register_forward_hook(hook) - - with torch.no_grad(): - model(input) - - assert latent.grad_fn is None - - -def test_exception(setup_rpc): - class ExpectedException(Exception): - pass - - class Raise(nn.Module): - def forward(self, *_): - raise ExpectedException - - model = nn.Sequential(Raise()) - model = Pipe(model, chunks=1) - - with pytest.raises(ExpectedException): - 
model(torch.rand(1)) - - -def test_exception_early_stop_asap(setup_rpc): - """Even the first partitions have finished to process, the partition before - the failed partition should be killed as soon as possible. - """ - - class ExpectedException(Exception): - pass - - class Pass(nn.Module): - def forward(self, x): - return x - - counter = 0 - - class Counter(nn.Module): - def forward(self, x): - time.sleep(0.1) - - nonlocal counter - counter += 1 - - return x - - class Raise(nn.Module): - def forward(self, x): - raise ExpectedException - - model = nn.Sequential(Pass(), Pass(), Counter(), Raise()) - model = Pipe(model, chunks=3) - - with pytest.raises(ExpectedException): - model(torch.rand(3)) - - # If the early stop doesn't work, it would be 3 instead. - assert counter == 2 - - -def test_nested_input(setup_rpc): - class NestedInput(nn.Module): - def __init__(self): - super().__init__() - self.fc_a = nn.Linear(1, 1) - self.fc_b = nn.Linear(1, 1) - - def forward(self, inp): - return inp - - model = nn.Sequential(NestedInput()) - model = Pipe(model, chunks=2) - - a = torch.rand(10, 1, requires_grad=True) - b = torch.rand(10, 1, requires_grad=True) - - # TypeError: expected Tensor, but got tuple - with pytest.raises(TypeError): - model((a, (a, b))).local_value() - - # TypeError: expected Tensor, but got list - with pytest.raises(TypeError): - model((a, [a, b])).local_value() - - -def test_input_pair(setup_rpc): - class Two(nn.Module): - def __init__(self): - super().__init__() - self.fc_a = nn.Linear(1, 1) - self.fc_b = nn.Linear(1, 1) - - def forward(self, a, b): - return (self.fc_a(a), self.fc_b(b)) - - model = nn.Sequential(Two()) - model = Pipe(model, chunks=2) - - a = torch.rand(10, 1, requires_grad=True) - b = torch.rand(10, 1, requires_grad=True) - - a_out, b_out = model(a, b).local_value() - loss = (a_out + b_out).mean() - loss.backward() - - assert a.grad is not None - assert b.grad is not None - - -def test_multi_sequence_input(setup_rpc): - class MultiSeq(nn.Module): - def forward(self, tup1, tup2): - return tup1, tup2 - - model = Pipe(nn.Sequential(MultiSeq())) - with pytest.raises(TypeError): - model([torch.rand(10), torch.rand(10)], [torch.rand(10), torch.rand(10)]) - - -def test_input_singleton(setup_rpc): - class One(nn.Module): - def __init__(self): - super().__init__() - self.fc = nn.Linear(1, 1) - - def forward(self, a): - return (self.fc(a),) - - model = nn.Sequential(One()) - model = Pipe(model, chunks=2) - - a = torch.rand(10, 1, requires_grad=True) - - (a_out,) = model(a).local_value() - loss = a_out.mean() - loss.backward() - - assert all(p.grad is not None for p in model.parameters()) - assert a.grad is not None - - -def test_input_varargs(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model) - - a = torch.rand(1) - b = torch.rand(1) - - # TypeError: forward() takes 2 positional arguments but 3 were given - with pytest.raises(TypeError): - model(a, b) - - -def test_non_tensor(setup_rpc): - class NonTensor(nn.Module): - def forward(self, _): - return "hello" - - model = nn.Sequential(NonTensor()) - model = Pipe(model) - x = torch.rand(1) - - with pytest.raises(TypeError): - model(x) - - with pytest.raises(TypeError): - model("hello") - - -def test_non_tensor_sequence(setup_rpc): - class NonTensorTuple(nn.Module): - def forward(self, x): - return (x, "hello") - - class NonTensorArgs(nn.Module): - def forward(self, x: str, y: bool): - return x, y - - model = nn.Sequential(NonTensorTuple()) - model = Pipe(model) - x = torch.rand(1) - - with 
pytest.raises(TypeError): - model((x, "hello")) - - with pytest.raises(TypeError): - model([x, "hello"]) - - model = nn.Sequential(NonTensorArgs()) - model = Pipe(model) - - with pytest.raises(TypeError): - # Need atleast one Tensor. - model("hello", True) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_valid_non_tensor(checkpoint, setup_rpc): - class NonTensor1(nn.Module): - def forward(self, a: int, b: Tensor, c: bool, d: Tensor): - res = b + a if c else b * a - if d is not None: - res += d - return res, c, a, b, "hello", d - - class NonTensor2(nn.Module): - def forward(self, a: Tensor, b: bool, c: int, d: Tensor, e: str, f: Tensor): - res = a * c if b else a + c - res += d - return c, res, a, d + f if f is not None else d, b, e, f - - model = Pipe( - nn.Sequential(NonTensor1(), NonTensor2()), chunks=5, checkpoint=checkpoint - ) - a = random.randint(0, 10) - b = torch.rand(10, 10) - c = random.randint(0, 1) == 0 - d = torch.rand(10, 10) - res = model(a, b, c, d).local_value() - assert 7 == len(res) - assert [a] * 5 == res[0] - if c: - assert torch.allclose(((b + a + d) * a) + b, res[1]) - assert torch.allclose(b + a + d, res[2]) - else: - assert torch.allclose(((b * a) + d + a) + b, res[1]) - assert torch.allclose(b * a + d, res[2]) - assert torch.allclose(b + d, res[3]) - assert [c] * 5 == res[4] - assert ["hello"] * 5 == res[5] - assert torch.allclose(d, res[6]) - - # Test one of the tensors can be None - res = model(a, b, c, None).local_value() - assert 7 == len(res) - assert [a] * 5 == res[0] - if c: - assert torch.allclose(((b + a) * a) + b, res[1]) - assert torch.allclose(b + a, res[2]) - else: - assert torch.allclose(((b * a) + a) + b, res[1]) - assert torch.allclose(b * a, res[2]) - assert torch.allclose(b, res[3]) - assert [c] * 5 == res[4] - assert ["hello"] * 5 == res[5] - assert [None] * 5 == res[6] - - # Need atleast one tensor. - with pytest.raises(TypeError): - model(a, None, c, None) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_no_tensor_output(checkpoint, setup_rpc): - class Model1(nn.Module): - def forward(self, a: int, b: Tensor, c: bool): - return a, c, "hello" - - class Model2(nn.Module): - def forward(self, a: int, b: bool, c: str): - return a, c, b - - model = Pipe(nn.Sequential(Model1(), Model2()), chunks=5) - a = random.randint(0, 10) - b = torch.rand(10, 10) - c = random.randint(0, 1) == 0 - - # Need atleast one tensor across partitions too. - with pytest.raises(TypeError): - res = model(a, b, c).local_value() - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_uneven_batch_size(checkpoint, setup_rpc): - class Model(nn.Module): - def forward(self, a: Tensor, b: int, c: Tensor): - return a, b, c - - model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) - a = torch.rand(3, 10) - b = random.randint(0, 10) - c = torch.rand(6, 10) - res = model(a, b, c).local_value() - assert torch.allclose(a, res[0]) - assert [b] * 3 == res[1] # 3 chunks - assert torch.allclose(c, res[2]) - - # Two tensors producing uneven chunks would fail. 
- model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) - a = torch.rand(3, 10) - b = random.randint(0, 10) - c = torch.rand(4, 10) - - with pytest.raises(RuntimeError, match="Found different number of chunks"): - model(a, b, c) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_no_chunk(checkpoint, setup_rpc): - class Model(nn.Module): - def forward(self, a: Tensor, b: int, c: Tensor): - return a, b, c - - model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) - a = torch.rand(10, 10) - b = random.randint(0, 10) - c = torch.rand(10, 10) - res = model(a, b, NoChunk(c)).local_value() - assert torch.allclose(a, res[0]) - assert [b] * 5 == res[1] - # c gets replicated due to NoChunk and the same tensor gets concatenated 5 - # times in the output. - assert torch.allclose(torch.cat((c, c, c, c, c)), res[2]) - - # Test invalid type for NoChunk - with pytest.raises(TypeError, match="NoChunk only supported for tensors"): - NoChunk(b) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_deferred_batch_norm(checkpoint, setup_rpc): - bn = nn.BatchNorm2d(3) - pipe_bn = deepcopy(bn) - pipe = Pipe( - nn.Sequential(pipe_bn), - chunks=2, - checkpoint=checkpoint, - deferred_batch_norm=True, - ) - - x = torch.rand(4, 3, 10, 10) - pipe(x).local_value().mean().backward() - bn(x).mean().backward() - - assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4) - assert torch.allclose(pipe[0].running_var, bn.running_var, atol=1e-4) - - -@pytest.mark.parametrize("checkpoint", ["never", "always"]) -def test_deferred_batch_norm_params(checkpoint, setup_rpc): - bn = nn.BatchNorm2d(3) - pipe_bn = deepcopy(bn) - pipe = Pipe( - nn.Sequential(pipe_bn), - chunks=1, - checkpoint=checkpoint, - deferred_batch_norm=True, - ) - - x = torch.rand(4, 3, 10, 10) - pipe(x).local_value().mean().backward() - bn(x).mean().backward() - - assert pipe[0].weight.grad is not None - assert pipe[0].bias.grad is not None - - assert torch.allclose(pipe[0].weight.grad, bn.weight.grad, atol=1e-4) - assert torch.allclose(pipe[0].bias.grad, bn.bias.grad, atol=1e-4) - - -def test_devices(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - c = nn.Linear(1, 1) - - # There are extra two devices. - model = nn.Sequential(a, b, c) - model = Pipe(model) - - cpu = torch.device("cpu") - # Extra devices must be discarded. - assert model.devices == [cpu, cpu, cpu] - - -def test_partitions(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(a, b) - model = Pipe(model) - - assert isinstance(model.partitions, nn.ModuleList) - assert isinstance(model.partitions[0], nn.Sequential) - assert isinstance(model.partitions[1], nn.Sequential) - - assert "partitions.0.0.weight" in model.state_dict() - - -@skip_if_no_cuda -def test_merged_partitions(setup_rpc): - a = nn.Linear(1, 1).to(0) - b = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 2)).to(0) - c = nn.Linear(1, 1) - d = nn.Linear(1, 2) - - model = nn.Sequential(a, b, c, d) - model = Pipe(model) - - assert isinstance(model.partitions, nn.ModuleList) - assert isinstance(model.partitions[0], PipeSequential) - assert isinstance(model.partitions[1], PipeSequential) - assert list(model.partitions[0]) == [a, b[0], b[1]] - assert list(model.partitions[1]) == [c] - assert list(model.partitions[2]) == [d] - - -def test_deny_moving(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(a, b) - model = Pipe(model) - - # Moving is denied. 
- with pytest.raises(TypeError): - model.cuda() - - with pytest.raises(TypeError): - model.cpu() - - with pytest.raises(TypeError): - model.to(torch.device("cuda")) - - with pytest.raises(TypeError): - model.to(0) - - with pytest.raises(TypeError): - model.to("cuda") - - with pytest.raises(TypeError): - model.to(device=0) - - with pytest.raises(TypeError): - model.to(torch.rand(1)) - - with pytest.raises(TypeError): - model.to(tensor=torch.rand(1)) - - # Casting is allowed. - model.half() - model.to(torch.double) - model.to(dtype=torch.float) - - -def test_empty_module(setup_rpc): - # Empty sequential module is not illegal. - model = nn.Sequential() - model = Pipe(model) - - assert model(torch.tensor(42)).local_value() == torch.tensor(42) - - # But only tensor or tensors is legal in Pipe. - with pytest.raises(TypeError): - model(42) - - -def test_named_children(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) - model = Pipe(model) - - names = {n for n, _ in model.named_modules()} - assert "partitions.0.0" in names - assert "partitions.1.0" in names - - # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires - # several methods in its namespace. - with pytest.raises(AttributeError): - model.a - - -def test_verify_module_non_sequential(setup_rpc): - with pytest.raises( - TypeError, match="module must be nn.Sequential to be partitioned" - ): - Pipe(nn.Module()) - - -def test_verify_module_duplicate_children(setup_rpc): - conv = nn.Conv2d(3, 3, 1) - model = nn.Sequential(conv, conv) - - with pytest.raises( - ValueError, match="module with duplicate children is not supported" - ): - Pipe(model) - - -@skip_if_no_cuda -def test_verify_module_params_on_same_device(setup_rpc): - class Surrogate(nn.Module): - def __init__(self, param1, param2): - super().__init__() - self.param1 = param1 - self.param2 = param2 - - conv1 = nn.Conv2d(3, 3, 1) - conv2 = nn.Conv2d(3, 3, 1) - model = nn.Sequential(Surrogate(conv1, conv2.cuda())) - - with pytest.raises( - ValueError, - match=r"should have all parameters on a single device, please use .to\(\)" - " to place the module on a single device", - ): - Pipe(model) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") -def test_verify_nested_modules(setup_rpc): - model = nn.Sequential( - nn.Sequential(nn.Linear(32, 16).cuda(0), nn.Linear(16, 8).cuda(0)), - nn.Sequential(nn.Linear(8, 4).cuda(1), nn.Linear(4, 2).cuda(1)), - ) - - pipe = Pipe(model) - out = pipe(torch.rand(10, 32).cuda(0)) - assert out.local_value().device == torch.device("cuda:1") - assert out.local_value().size() == torch.Size([10, 2]) - - -def test_verify_module_duplicate_parameters_on_same_device(setup_rpc): - class Surrogate(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - conv = nn.Conv2d(3, 3, 1) - model = nn.Sequential(Surrogate(conv), Surrogate(conv)) - - Pipe(model) - - -def test_forward_lockstep(setup_rpc): - timeline = [] - - class DelayedLog(nn.Module): - def __init__(self, j, seconds): - super().__init__() - self.i = 0 - self.j = j - self.seconds = seconds - - def forward(self, x): - time.sleep(self.seconds) - - timeline.append((self.i, self.j)) - self.i += 1 - - return x - - model = nn.Sequential(DelayedLog(0, seconds=0), DelayedLog(1, seconds=0.1)) - model = Pipe(model, chunks=3) - model(torch.rand(3, 1)) - - # Expected timeline: (Logs are recorded at !) - # - # Partition #0: 0! 1! 2! - # Partition #1: 000! 111! 222! 
- # - assert timeline == [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)] - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -@skip_if_no_cuda -def test_multiple_inputs(checkpoint, setup_rpc): - class Module1(nn.Module): - def forward(self, a, b, c): - return a + b + c, a * b * c - - class Module2(nn.Module): - def forward(self, a, b): - return a + b - - model = Pipe( - nn.Sequential(Module1().cuda(0), Module2().cuda(0)), - chunks=2, - checkpoint=checkpoint, - ) - t = torch.rand(10) - res = model(t, t, t).local_value() - assert torch.equal(res, (t + t + t) + (t * t * t)) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") -def test_inputs_wrong_device(setup_rpc): - class Module1(nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(5)) - - def forward(self, a, b): - return a + b + self.param, b - - # Start inputs on wrong device and ensure Pipe moves them correctly. - a = torch.rand(10).cuda(1) - b = torch.rand(10).cuda(1) - model = Pipe(nn.Sequential(Module1().cuda(0), Module1().cuda(1)), chunks=2) - with pytest.raises( - ValueError, - match="All inputs should be on the same device as the first partition", - ): - model(a, b) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") -def test_with_device_wrapper(setup_rpc): - fc1 = nn.Linear(16, 8).cuda(0) - fc2 = nn.Linear(8, 4).cuda(1) - dropout = nn.Dropout() - - model = nn.Sequential(fc1, fc2, WithDevice(dropout, "cuda:1")) - model = Pipe(model, chunks=8) - assert ( - torch.device("cuda:1") == model(torch.rand(16, 16).cuda(0)).local_value().device - ) - assert [torch.device("cuda:0"), torch.device("cuda:1")] == model.devices - - model = nn.Sequential(fc1, WithDevice(dropout, "cuda:1")) - model = Pipe(model, chunks=8) - assert ( - torch.device("cuda:1") == model(torch.rand(16, 16).cuda(0)).local_value().device - ) - assert [torch.device("cuda:0"), torch.device("cuda:1")] == model.devices - - model = nn.Sequential(fc1, WithDevice(fc2, "cuda:0")) - model = Pipe(model, chunks=8) - assert ( - torch.device("cuda:0") == model(torch.rand(16, 16).cuda(0)).local_value().device - ) - assert [torch.device("cuda:0")] == model.devices - assert torch.device("cuda:0") == fc2.weight.device - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_pipeline.py b/test/distributed/pipeline/sync/test_pipeline.py deleted file mode 100644 index 9548cb959db1c..0000000000000 --- a/test/distributed/pipeline/sync/test_pipeline.py +++ /dev/null @@ -1,36 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
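The small test_pipeline.py file deleted below pins down the pipeline's clock-cycle schedule. As a reference for anyone reimplementing it, here is a standalone sketch consistent with the asserted schedules: with m micro-batches and n partitions, micro-batch i may run on partition j at clock tick k exactly when i + j == k. The helper and its name are illustrative, not the deleted _clock_cycles function itself.

def clock_cycles(m, n):
    """Yield, per clock tick, the (micro-batch, partition) pairs that can run in parallel."""
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(0, k - m + 1), min(k + 1, n))]

# Matches the fill-and-drain schedule asserted in the deleted test below.
assert list(clock_cycles(3, 3)) == [
    [(0, 0)],
    [(1, 0), (0, 1)],
    [(2, 0), (1, 1), (0, 2)],
    [(2, 1), (1, 2)],
    [(2, 2)],
]
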
-from torch.distributed.pipeline.sync.pipeline import _clock_cycles -from torch.testing._internal.common_utils import run_tests - - -def test_clock_cycles(): - assert list(_clock_cycles(1, 1)) == [[(0, 0)]] - assert list(_clock_cycles(1, 3)) == [[(0, 0)], [(0, 1)], [(0, 2)]] - assert list(_clock_cycles(3, 1)) == [[(0, 0)], [(1, 0)], [(2, 0)]] - - assert list(_clock_cycles(3, 3)) == [ - [(0, 0)], - [(1, 0), (0, 1)], - [(2, 0), (1, 1), (0, 2)], - [(2, 1), (1, 2)], - [(2, 2)], - ] - - assert list(_clock_cycles(4, 2)) == [ - [(0, 0)], - [(1, 0), (0, 1)], - [(2, 0), (1, 1)], - [(3, 0), (2, 1)], - [(3, 1)], - ] - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_stream.py b/test/distributed/pipeline/sync/test_stream.py deleted file mode 100644 index f9702c8e41525..0000000000000 --- a/test/distributed/pipeline/sync/test_stream.py +++ /dev/null @@ -1,198 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch - -from torch.distributed.pipeline.sync.stream import ( - CPUStream, - current_stream, - default_stream, - get_device, - is_cuda, - new_stream, - record_stream, - use_device, - use_stream, - wait_stream, -) -from torch.testing._internal.common_utils import run_tests - -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="cuda required" -) - - -class TestNewStream: - def test_new_stream_cpu(self): - stream = new_stream(torch.device("cpu")) - assert stream is CPUStream - - @skip_if_no_cuda - def test_new_stream_cuda(self): - stream = new_stream(torch.device("cuda")) - assert isinstance(stream, torch.cuda.Stream) - assert stream != torch.cuda.default_stream() - - -class TestCurrentStream: - def test_current_stream_cpu(self): - stream = current_stream(torch.device("cpu")) - assert stream is CPUStream - - @skip_if_no_cuda - def test_current_stream_cuda(self): - stream = current_stream(torch.device("cuda")) - assert isinstance(stream, torch.cuda.Stream) - assert stream == torch.cuda.current_stream() - - -class TestDefaultStream: - def test_default_stream_cpu(self): - stream = default_stream(torch.device("cpu")) - assert stream is CPUStream - - @skip_if_no_cuda - def test_default_stream_cuda(self): - stream = default_stream(torch.device("cuda")) - assert isinstance(stream, torch.cuda.Stream) - assert stream == torch.cuda.default_stream() - - -class TestUseDevice: - def test_use_device_cpu(self): - with use_device(torch.device("cpu")): - pass - - @skip_if_no_cuda - def test_use_device_cuda(self): - with use_device(torch.device("cuda")): - pass - - -class TestUseStream: - def test_use_stream_cpu(self): - with use_stream(CPUStream): - pass - - @skip_if_no_cuda - def test_use_stream_cuda(self): - stream = new_stream(torch.device("cuda")) - with use_stream(stream): - assert current_stream(torch.device("cuda")) == stream - - -class TestGetDevice: - def test_get_device_cpu(self): - assert get_device(CPUStream).type == "cpu" - - @skip_if_no_cuda - def test_get_device_cuda(self): - stream = current_stream(torch.device("cuda")) - assert get_device(stream).type == "cuda" - - -class TestWaitStream: - def _test_wait_stream(self, source, target, cuda_sleep=None): - with use_stream(target): - if is_cuda(target): - cuda_sleep(0.5) - x = torch.ones(100, 100, device=get_device(target)) - - 
wait_stream(source, target) - - with use_stream(source): - assert x.sum().item() == 10000 - - def test_wait_stream_cpu_cpu(self): - source = CPUStream - target = CPUStream - self._test_wait_stream(source, target) - - @skip_if_no_cuda - def test_wait_stream_cpu_cuda(self, cuda_sleep): - source = CPUStream - target = new_stream(torch.device("cuda")) - self._test_wait_stream(source, target, cuda_sleep) - - @skip_if_no_cuda - def test_wait_stream_cuda_cpu(self, cuda_sleep): - source = new_stream(torch.device("cuda")) - target = CPUStream - self._test_wait_stream(source, target, cuda_sleep) - - @skip_if_no_cuda - def test_wait_stream_cuda_cuda(self, cuda_sleep): - source = current_stream(torch.device("cuda")) - target = new_stream(torch.device("cuda")) - self._test_wait_stream(source, target, cuda_sleep) - - -class TestRecordStream: - def test_record_stream_cpu(self): - # It should silently ignore CPU tensors. - x = torch.rand(1, device=torch.device("cpu")) - record_stream(x, CPUStream) - - @skip_if_no_cuda - def test_record_stream_cuda(self, cuda_sleep): - # This test detects unexpected block reallocation. For reliable test, - # the stream to allocate tensors is isolated. The allocator will not - # reuse free blocks which were allocated from another stream. - stream_alloc = new_stream(torch.device("cuda")) - with torch.cuda.stream(stream_alloc): - x = torch.rand(1, device=torch.device("cuda")) - - stream = new_stream(torch.device("cuda")) - record_stream(x, stream) - with use_stream(stream): - cuda_sleep(0.5) - - # 'x' is deleted at Python's perspective. But the block of 'x' is still - # required for 'stream'. 'y' shouldn't be allocated to the block. - data_ptr = x.data_ptr() - del x - stream_alloc.synchronize() - with torch.cuda.stream(stream_alloc): - y = torch.rand(1, device=torch.device("cuda")) - assert y.data_ptr() != data_ptr - - # Pause Python until 'stream' finishes tasks queued. Now the block of - # 'x' is free to be reallocated. - wait_stream(CPUStream, stream) - with torch.cuda.stream(stream_alloc): - z = torch.rand(1, device=torch.device("cuda")) - assert z.data_ptr() == data_ptr - - @skip_if_no_cuda - def test_record_stream_shifted_view(self, cuda_sleep): - # Issue: https://github.com/pytorch/pytorch/issues/27366 - stream_alloc = new_stream(torch.device("cuda")) - with torch.cuda.stream(stream_alloc): - x = torch.rand(2, device=torch.device("cuda")) - - y = x[1:] - assert y.data_ptr() > x.data_ptr() - - stream = new_stream(torch.device("cuda")) - with use_stream(stream): - cuda_sleep(0.5) - record_stream(y, stream) - - data_ptr = x.data_ptr() - del x, y - - stream_alloc.synchronize() - with torch.cuda.stream(stream_alloc): - z = torch.rand(2, device=torch.device("cuda")) - assert z.data_ptr() != data_ptr - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_transparency.py b/test/distributed/pipeline/sync/test_transparency.py deleted file mode 100644 index a87a04150fdc3..0000000000000 --- a/test/distributed/pipeline/sync/test_transparency.py +++ /dev/null @@ -1,55 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
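The deleted test_stream.py above covers thin wrappers (new_stream, use_stream, wait_stream, record_stream, plus the CPUStream placeholder that turns them into no-ops on CPU) around PyTorch's stream primitives. A rough sketch of the underlying torch.cuda calls those helpers delegated to, assuming a CUDA device is available::

    import torch

    if torch.cuda.is_available():
        dev = torch.device("cuda")
        side = torch.cuda.Stream(device=dev)               # roughly new_stream(dev)
        with torch.cuda.stream(side):                      # roughly use_stream(side)
            x = torch.ones(100, 100, device=dev)
        torch.cuda.current_stream(dev).wait_stream(side)   # roughly wait_stream(current, side)
        x.record_stream(torch.cuda.current_stream(dev))    # roughly record_stream(x, current)
        assert x.sum().item() == 10000                     # safe to read after the wait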
-import torch -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.testing._internal.common_utils import run_tests - - -def test_simple_linears(setup_rpc): - def sum_grad(parameters): - return sum(p.grad.sum() for p in parameters if p.grad is not None) - - def zero_grad(parameters): - for p in parameters: - p.grad = None - - inputs = torch.rand(8, 1) - model = nn.Sequential( - nn.Linear(1, 2), - nn.Linear(2, 4), - nn.Linear(4, 2), - nn.Linear(2, 1), - ) - - # Without Pipe - outputs = model(inputs) - loss = outputs.mean() - loss.backward() - - grad_without_pipe = sum_grad(model.parameters()) - - zero_grad(model.parameters()) - - # With Pipe - model = Pipe(model, chunks=4) - - outputs = model(inputs).local_value() - loss = outputs.mean() - loss.backward() - - grad_with_pipe = sum_grad(model.parameters()) - - # Both grads should be identical. - assert torch.allclose(grad_with_pipe, grad_without_pipe) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_worker.py b/test/distributed/pipeline/sync/test_worker.py deleted file mode 100644 index f82af2ea00679..0000000000000 --- a/test/distributed/pipeline/sync/test_worker.py +++ /dev/null @@ -1,118 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import threading - -import pytest - -import torch - -from torch.distributed.pipeline.sync.microbatch import Batch -from torch.distributed.pipeline.sync.stream import CPUStream -from torch.distributed.pipeline.sync.worker import spawn_workers, Task -from torch.testing._internal.common_utils import run_tests - - -class fake_device: - """A test double for :class:`torch.device`. Every fake device is different - with each other. 
- """ - - type = "fake" - index = None - - -def test_compute_multithreading(): - """Task.compute should be executed on multiple threads.""" - thread_ids = set() - - def log_thread_id(): - thread_id = threading.current_thread().ident - thread_ids.add(thread_id) - return Batch(()) - - with spawn_workers([fake_device() for _ in range(2)]) as (in_queues, out_queues): - for i in range(2): - t = Task(CPUStream, compute=log_thread_id, finalize=None) - in_queues[i].put(t) - for i in range(2): - out_queues[i].get() - - assert len(thread_ids) == 2 - - -def test_compute_success(): - """Task.compute returns (True, (task, batch)) on success.""" - - def _42(): - return Batch(torch.tensor(42)) - - with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): - t = Task(CPUStream, compute=_42, finalize=None) - in_queues[0].put(t) - ok, (task, batch) = out_queues[0].get() - - assert ok - assert task is t - assert isinstance(batch, Batch) - assert batch[0].item() == 42 - - -def test_compute_exception(): - """Task.compute returns (False, exc_info) on failure.""" - - def zero_div(): - 0 / 0 - - with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): - t = Task(CPUStream, compute=zero_div, finalize=None) - in_queues[0].put(t) - ok, exc_info = out_queues[0].get() - - assert not ok - assert isinstance(exc_info, tuple) - assert issubclass(exc_info[0], ZeroDivisionError) - - -@pytest.mark.parametrize("grad_mode", [True, False]) -def test_grad_mode(grad_mode): - def detect_grad_enabled(): - x = torch.rand(1, requires_grad=torch.is_grad_enabled()) - return Batch(x) - - with torch.set_grad_enabled(grad_mode): - with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): - task = Task(CPUStream, compute=detect_grad_enabled, finalize=None) - in_queues[0].put(task) - - ok, (_, batch) = out_queues[0].get() - - assert ok - assert batch[0].requires_grad == grad_mode - - -def test_worker_per_device(): - cpu = torch.device("cpu") - cpu0 = torch.device("cpu", index=0) - fake1 = fake_device() - fake2 = fake_device() - - with spawn_workers([cpu, cpu, cpu0, fake1, fake2]) as (in_queues, out_queues): - assert len(in_queues) == len(out_queues) == 5 - - # 0: cpu, 1: cpu, 2: cpu0 - assert in_queues[0] is in_queues[1] is in_queues[2] - assert out_queues[0] is out_queues[1] is out_queues[2] - - # 3: fake1, 4: fake2 - assert in_queues[3] is not in_queues[4] - assert out_queues[3] is not out_queues[4] - - -if __name__ == "__main__": - run_tests() diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 1db0e5718ce69..8ab2ac1f511f0 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -329,7 +329,6 @@ def test_modules_can_be_imported(self): "torch.testing._internal.distributed.fake_pg", "torch.testing._internal.distributed.multi_threaded_pg", "torch.testing._internal.distributed.nn.api.remote_module_test", - "torch.testing._internal.distributed.pipe_with_ddp_test", "torch.testing._internal.distributed.rpc.dist_autograd_test", "torch.testing._internal.distributed.rpc.dist_optimizer_test", "torch.testing._internal.distributed.rpc.examples.parameter_server_test", @@ -408,7 +407,6 @@ def test_modules_can_be_imported(self): "torch.distributed.nn.api.remote_module", "torch.distributed.optim", "torch.distributed.optim.optimizer", - "torch.distributed.pipeline.sync", "torch.distributed.rendezvous", "torch.distributed.rpc.api", "torch.distributed.rpc.backend_registry", diff --git a/test/test_testing.py b/test/test_testing.py index ba9558a3ddd14..1e1dce59a32e7 100644 
--- a/test/test_testing.py +++ b/test/test_testing.py @@ -2245,7 +2245,6 @@ def test_circular_dependencies(self) -> None: else: ignored_modules.append("torch.distributed.nn.api.") ignored_modules.append("torch.distributed.optim.") - ignored_modules.append("torch.distributed.pipeline.") ignored_modules.append("torch.distributed.rpc.") ignored_modules.append("torch.testing._internal.dist_utils") # And these both end up with transitive dependencies on distributed diff --git a/torch/distributed/pipeline/__init__.py b/torch/distributed/pipeline/__init__.py deleted file mode 100644 index eacd2bc99d046..0000000000000 --- a/torch/distributed/pipeline/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -import warnings - - -with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "`torch.distributed.pipeline` is deprecated. For up-to-date pipeline parallel " - "implementation, please refer to the PiPPy library under the PyTorch " - "organization (Pipeline Parallelism for PyTorch): " - "https://github.com/pytorch/PiPPy", - DeprecationWarning, - stacklevel=2, - ) diff --git a/torch/distributed/pipeline/sync/LICENSE b/torch/distributed/pipeline/sync/LICENSE deleted file mode 100644 index e52be240fdc98..0000000000000 --- a/torch/distributed/pipeline/sync/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright 2019-2020 Kakao Brain - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from this - software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. diff --git a/torch/distributed/pipeline/sync/__init__.py b/torch/distributed/pipeline/sync/__init__.py deleted file mode 100644 index 75a80c5db0f9f..0000000000000 --- a/torch/distributed/pipeline/sync/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
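The deleted torch/distributed/pipeline/__init__.py above emitted a DeprecationWarning on import, pointing users at the PiPPy project. On an older release that still ships the package, that warning can be observed as follows (a fresh interpreter is assumed, since repeated imports are cached and only the first one warns)::

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        import torch.distributed.pipeline  # noqa: F401  (pre-removal releases only)

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)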
-"""A Pipe implementation in PyTorch.""" -from .checkpoint import is_checkpointing, is_recomputing -from .pipe import Pipe, WithDevice -from .microbatch import NoChunk - -__all__ = ["Pipe", "is_checkpointing", "is_recomputing"] diff --git a/torch/distributed/pipeline/sync/_balance/__init__.py b/torch/distributed/pipeline/sync/_balance/__init__.py deleted file mode 100644 index 8ffc657896d87..0000000000000 --- a/torch/distributed/pipeline/sync/_balance/__init__.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""A helper to roughly balance a sequential module. - -Usage:: - - import torch - from torch.distributed.pipeline.sync import Pipe - from torch.distributed.pipeline.sync.balance import balance_by_time - - sample = torch.empty(128, 3, 224, 224) - balance = balance_by_time(torch.cuda.device_count(), model, sample) - - pipe = Pipe(model, balance, chunks=8) - -""" -from typing import Any, List, Union, Sequence - -import torch -from torch import Tensor -import torch.nn as nn - -from . import blockpartition -from .profile import profile_sizes, profile_times - -__all__ = ["balance_by_time", "balance_by_size"] - - -Device = Union[torch.device, int, str] - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - - -def balance_cost(cost: List[int], partitions: int) -> List[int]: - partitioned = blockpartition.solve(cost, partitions) - return [len(p) for p in partitioned] - - -def balance_by_time( - partitions: int, - module: nn.Sequential, - sample: Union[List[Any], Tensor], - *, - timeout: float = 1.0, - device: Device = torch.device("cuda"), -) -> List[int]: - """Naive automatic balancing by elapsed time per layer. - :: - - sample = torch.empty(128, 3, 224, 224) - balance = balance_by_time(torch.cuda.device_count(), model, sample) - pipe = Pipe(model, balance, chunks=8) - - Args: - partitions (int): - intended number of partitions - module (torch.nn.Sequential): - sequential module to be partitioned - sample (torch.Tensor): - example input with arbitrary batch size - - Keyword Args: - timeout (float): - profiling iterates again if the timeout (in second) is not exceeded - (default: ``1.0``) - device ('cpu' or 'cuda' device): - CPU or CUDA device where each layer is profiled (default: the - current CUDA device) - - Returns: - A list of number of layers in each partition. Use it for the `balance` - parameter of :class:`~torchpipe.Pipe`. - - .. note:: - `module` and `sample` must be placed on the same device. - - """ - times = profile_times(module, sample, timeout, torch.device(device)) - return balance_cost(times, partitions) - - -def balance_by_size( - partitions: int, - module: nn.Sequential, - input: Union[List[Any], Tensor], - *, - chunks: int = 1, - param_scale: float = 2.0, - device: Device = torch.device("cuda"), -) -> List[int]: - """Naive automatic balancing by CUDA memory usage per layer. - - During training, required memory for parameters depends on which optimizer - is used. Optimizers may use buffers for each parameter to track - optimization statistics internally, such as momentum buffer in SGD. - - To get more reliable size based balance, you should specify `param_scale` - with regard to your optimizer. The default `param_scale` is 2 instead of 1 - due to gradient accumulation which is necessary for every optimizer. 
- - Follow this guide to choose correct `param_scale` for typical optimizers: - - ========= ============= ========================================= - Optimizer `param_scale` Internal State - ========= ============= ========================================= - SGD 2--3 (momentum_buffer) - Adam 4--5 exp_avg, exp_avg_sq, (max_exp_avg_sq) - Adadelta 4 square_avg, acc_delta - Adagrad 3 sum - RMSprop 3--5 square_avg, (momentum_buffer), (grad_avg) - ========= ============= ========================================= - - Here's a simple example with the Adam optimizer:: - - balance = balance_by_size( - torch.cuda.device_count(), - model, - - # Same size with mini-batch to train - torch.empty(1024, 3, 224, 224), - - # Number of micro-batches to train with Pipe - chunks=8, - - # 4 for Adam - param_scale=4.0, - ) - - pipe = Pipe(model, balance, chunks=8) - adam = Adam(pipe.parameters()) - - Args: - partitions (int): - intended number of partitions - module (torch.nn.Sequential): - sequential module to be partitioned - input (torch.Tensor): - example mini-batch with the same size to train - - Keyword Args: - chunks (int): - number of micro-batches will be used to train (default: ``1``) - param_scale (float): - how many copies of parameters would be allocated for training. It - depends on optimizer. See the above guide. (default: ``2.0``) - device ('cuda' device): - CUDA device where each layer is profiled (default: the current CUDA - device) - - Returns: - A list of number of layers in each partition. Use it for the `balance` - parameter of :class:`~torchpipe.Pipe`. - - .. note:: - `module` and `input` must be placed on the same CUDA device. - - """ - sizes = profile_sizes(module, input, chunks, param_scale, torch.device(device)) - return balance_cost(sizes, partitions) diff --git a/torch/distributed/pipeline/sync/_balance/blockpartition.py b/torch/distributed/pipeline/sync/_balance/blockpartition.py deleted file mode 100644 index ccdf5fe4df990..0000000000000 --- a/torch/distributed/pipeline/sync/_balance/blockpartition.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Implements "Block Partitions of Sequences" by Imre B\u00e1r\u00e1ny et al. - -Paper: https://arxiv.org/pdf/1308.2452.pdf - -""" -from typing import Iterator, List, Tuple - -__all__ = ["solve"] - - -def solve(sequence: List[int], partitions: int = 1) -> List[List[int]]: - """Splits a sequence into several partitions to minimize variance for each - partition. - - The result might not be optimal. However, it can be done only in O(kn\u00b3), - where k is the number of partitions and n is the length of the sequence. - - """ - if partitions < 1: - raise ValueError(f"partitions must be a positive integer ({partitions} < 1)") - - n = len(sequence) - if n < partitions: - raise ValueError(f"sequence is shorter than intended partitions ({n} < {partitions})") - - # Normalize the sequence in [0, 1]. 
- minimum = min(sequence) - maximum = max(sequence) - minimum - - normal_sequence: List[float] - if maximum == 0: - normal_sequence = [0 for _ in sequence] - else: - normal_sequence = [(x - minimum) / maximum for x in sequence] - - splits = [n // partitions * (x + 1) for x in range(partitions - 1)] + [n] - - def block_size(i: int) -> float: - start = splits[i - 1] if i > 0 else 0 - stop = splits[i] - return sum(normal_sequence[start:stop]) - - def leaderboard() -> Iterator[Tuple[float, int]]: - return ((block_size(i), i) for i in range(partitions)) - - while True: - """ - (1) Fix p element-of [k] with M(P) = bp. So Bp is a maximal block of P. - """ - # max_size: M(P) - max_size, p = max(leaderboard()) - - while True: - """ - (2) If M(P) <= m(P) + 1, then stop. - """ - # min_size: m(P) - min_size, q = min(leaderboard()) - - if max_size <= min_size + 1: - return [sequence[i:j] for i, j in zip([0] + splits[:-1], splits)] - - """ - (3) If M(P) > m(P) + 1, then let m(P) = bq for the q element-of [k] which is - closest to p (ties broken arbitrarily). Thus Bq is a minimal block - of P. Let Bh be the block next to Bq between Bp and Bq. (Note that - Bh is a non-empty block: if it were, then m(P) = 0 and we should - have chosen Bh instead of Bq.) - """ - if p < q: - """ - So either p < q and then h = q-1 and we define P * by moving - the last element from Bh = Bq-1 to Bq, - """ - h = q - 1 - splits[h] -= 1 - else: - """ - or q < p, and then h = q + 1 and P * is obtained by moving the - first element of Bh = Bq+1 to Bq. - """ - h = q + 1 - splits[q] += 1 - - """ - Set P = P * . If p = h, then go to (1), else go to (2). - """ - if p == h: - break diff --git a/torch/distributed/pipeline/sync/_balance/profile.py b/torch/distributed/pipeline/sync/_balance/profile.py deleted file mode 100644 index fa1a0c06a8e3a..0000000000000 --- a/torch/distributed/pipeline/sync/_balance/profile.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Per-layer profilers.""" -import copy -import time -from typing import Any, Generator, List, Union, Sequence - -import torch -from torch import Tensor -import torch.nn as nn - -from ..microbatch import Batch - -__all__: List[str] = [] - - -Device = Union[torch.device, int, str] - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - - -def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]: - """Copies layers for ease to profile. It doesn't modify the given - module. 
- """ - for layer in module: - layer_copy = copy.deepcopy(layer) - layer_copy.to(device) - layer_copy.train() - yield layer_copy - - -def detach(batch: Batch) -> None: - """Detaches from autograd graph.""" - for i, x in enumerate(batch): - batch[i] = x.detach().requires_grad_(x.requires_grad) - - -def profile_times(module: nn.Sequential, sample: Union[List[Any], Tensor], timeout: float, device: torch.device,) -> List[int]: - """Profiles elapsed times per layer.""" - if any(p.grad is not None for p in module.parameters()): - raise ValueError("some parameter already has gradient") - - _batch = Batch(sample) - for i, x in enumerate(_batch): - _batch[i] = x.detach().to(device).requires_grad_(x.requires_grad) - - time_bufs: List[List[float]] = [[] for _ in module] - begun_at = time.time() - - while time.time() - begun_at < timeout: - batch = _batch - - for i, layer in enumerate(layerwise_sandbox(module, device)): - detach(batch) - - if device.type == "cuda": - torch.cuda.synchronize(device) - tick = time.time() - - # Forward - batch = batch.call(layer) - - # Backward - backward_tensors = tuple(y for y in batch if y.requires_grad) - if backward_tensors: - torch.autograd.backward(backward_tensors, backward_tensors) - - if device.type == "cuda": - torch.cuda.synchronize(device) - tock = time.time() - - time_bufs[i].append(tock - tick) - - us = 1_000_000 - return [sum(int(t * us) for t in buf) for buf in time_bufs] - - -def profile_sizes( - module: nn.Sequential, input: Union[List[Any], Tensor], chunks: int, param_scale: float, device: torch.device, -) -> List[int]: - """Profiles CUDA memory usage per layer.""" - if device.type != "cuda": - raise ValueError("size profiler supports only CUDA device") - - batch = Batch(input) - sizes: List[int] = [] - - latent_scale = batch[0].size(0) / chunks - for i, x in enumerate(batch): - batch[i] = x[:1].detach().to(device).requires_grad_(x.requires_grad) - - for layer in layerwise_sandbox(module, device): - detach(batch) - - # Detect memory usage at forward. - torch._C._cuda_clearCublasWorkspaces() - memory_before = torch.cuda.memory_allocated(device) - batch = batch.call(layer) - torch._C._cuda_clearCublasWorkspaces() - memory_after = torch.cuda.memory_allocated(device) - latent_size = memory_after - memory_before - - # Analyze size of parameters. - param_size = sum(p._typed_storage()._nbytes() for p in layer.parameters()) - - # Combine size of parameters and activations with normalize scales. - size = latent_size * latent_scale + param_size * param_scale - sizes.append(int(size)) - - return sizes diff --git a/torch/distributed/pipeline/sync/_balance/py.typed b/torch/distributed/pipeline/sync/_balance/py.typed deleted file mode 100644 index ab03724cafbf5..0000000000000 --- a/torch/distributed/pipeline/sync/_balance/py.typed +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torch/distributed/pipeline/sync/batchnorm.py b/torch/distributed/pipeline/sync/batchnorm.py deleted file mode 100644 index 868ad50cf3fcf..0000000000000 --- a/torch/distributed/pipeline/sync/batchnorm.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
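The _balance helpers deleted above follow a simple flow: profile a per-layer cost (elapsed time via profile_times or CUDA memory via profile_sizes), then call blockpartition.solve to split those costs into contiguous blocks with near-equal sums, and use the resulting per-block layer counts to decide how layers are spread across pipeline partitions. A small sketch of that flow with hypothetical costs, again assuming a release that still ships the deprecated package::

    from torch.distributed.pipeline.sync._balance import blockpartition

    # Hypothetical per-layer costs, e.g. microseconds reported by profile_times().
    costs = [10, 30, 20, 40, 25, 15]

    blocks = blockpartition.solve(costs, partitions=2)
    balance = [len(block) for block in blocks]   # what balance_cost() returned

    # 'blocks' holds contiguous slices of 'costs' with roughly equal sums;
    # 'balance' is the corresponding number of layers per pipeline partition.
    print(blocks, balance)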
-"""Tracks the running statistics per mini-batch instead of micro-batch.""" -from typing import TypeVar, Optional, cast - -import torch -from torch import Tensor, nn -from torch.nn.functional import batch_norm -from torch.nn.modules.batchnorm import _BatchNorm - -from .checkpoint import is_recomputing - -__all__ = ["DeferredBatchNorm"] - - -TModule = TypeVar("TModule", bound=nn.Module) - - -class DeferredBatchNorm(_BatchNorm): - """A BatchNorm layer tracks multiple micro-batches to update running statistics per mini-batch.""" - - sum: Tensor - sum_squares: Tensor - running_mean: Tensor - running_var: Tensor - num_batches_tracked: Tensor - - def __init__( - self, - num_features: int, - eps: float = 1e-5, - momentum: Optional[float] = 0.1, - affine: bool = True, - chunks: int = 1, - ) -> None: - super().__init__(num_features, eps, momentum, affine, track_running_stats=True) - - self.register_buffer("sum", torch.zeros_like(self.running_mean)) - self.register_buffer("sum_squares", torch.zeros_like(self.running_var)) - - self.counter = 0 - self.tracked = 0 - self.chunks = chunks - - def _check_input_dim(self, input: Tensor) -> None: - # It's the typical _check_input_dim() implementation in PyTorch. - if input.dim() <= 2: - raise ValueError("expected at least 3D input (got %dD input)" % input.dim()) - - def _track(self, input: Tensor) -> bool: - """Tracks statistics of a micro-batch.""" - # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d. - dim = [0] - dim.extend(range(2, input.dim())) - - with torch.no_grad(): - self.sum += input.sum(dim) - self.sum_squares += (input ** 2).sum(dim) - - size = input.size().numel() // input.size(1) - self.counter += size - self.tracked += 1 - - return self.tracked == self.chunks - - def _commit(self) -> None: - """Update the running statistics of a mini-batch.""" - exponential_average_factor = 0.0 - self.num_batches_tracked += 1 - if self.momentum is None: # use cumulative moving average - exponential_average_factor = 1.0 / float(self.num_batches_tracked) - else: # use exponential moving average - exponential_average_factor = self.momentum - - mean = self.sum / self.counter - var = self.sum_squares / self.counter - mean ** 2 - - # Calculate the exponential moving average here. - m = exponential_average_factor - - self.running_mean *= 1 - m - self.running_mean += mean * m - - self.running_var *= 1 - m - self.running_var += var * m - - self.sum.zero_() - self.sum_squares.zero_() - self.counter = 0 - self.tracked = 0 - - def forward(self, input: Tensor) -> Tensor: - if not self.training: - # Don't train parameters on the evaluation mode. - return batch_norm( - input, - running_mean=self.running_mean, - running_var=self.running_var, - weight=self.weight, - bias=self.bias, - training=False, - momentum=0.0, - eps=self.eps, - ) - - if not is_recomputing(): - # Track a micro-batch on the training mode - # but not under a recomputation. - tracked_enough = self._track(input) - - # Update the running statistics for a mini-batch - # if it has tracked enough micro-batches. - if tracked_enough: - self._commit() - - # Normalize a micro-batch and train the parameters. 
- return batch_norm( - input, - running_mean=None, - running_var=None, - weight=self.weight, - bias=self.bias, - training=True, - momentum=0.0, - eps=self.eps, - ) - - @classmethod - def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule: - """Converts a :class:`nn.BatchNorm` or underlying :class:`nn.BatchNorm`s into :class:`DeferredBatchNorm`:: - - from torchvision.models.resnet import resnet101 - from torchpipe.batchnorm import DeferredBatchNorm - model = resnet101() - model = DeferredBatchNorm.convert_deferred_batch_norm(model) - - """ - if isinstance(module, DeferredBatchNorm) and module.chunks is chunks: - return cast(TModule, module) - - module_output: nn.Module = module - - if isinstance(module, _BatchNorm) and module.track_running_stats: - module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks) - if module.affine: - module_output.register_parameter("weight", module.weight) - module_output.register_parameter("bias", module.bias) - module_output.register_buffer("running_mean", module.running_mean) - module_output.register_buffer("running_var", module.running_var) - module_output.register_buffer("num_batches_tracked", module.num_batches_tracked) - - for name, child in module.named_children(): - module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks)) - - return cast(TModule, module_output) diff --git a/torch/distributed/pipeline/sync/checkpoint.py b/torch/distributed/pipeline/sync/checkpoint.py deleted file mode 100644 index e67da2499d573..0000000000000 --- a/torch/distributed/pipeline/sync/checkpoint.py +++ /dev/null @@ -1,364 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Checkpointing with preceding recomputation. - -PyTorch already provides the official checkpointing utilities in -:mod:`torch.utils.checkpoint`. The official checkpointing combines -recomputation and recursive backpropagation into one autograd function named -``CheckpointFunction``. Hence, the recomputation can be started only when the -gradients arrive to the function. In Pipe, the recomputation needs to precede -the gradient arrival to minimize the GPU idle time. - -We solve this problem by introducing separate autograd functions named -:class:`Recompute` and :class:`Checkpoint`. Each function represents -recomputation and recursive backpropagation, respectively. We can manipulate -the control flow in aspect of both the autograd engine and CUDA with a pair of -the functions. - -Specifically, we place CUDA stream synchronization between :class:`Recompute` -and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is -copied entirely. 
- -""" -from collections import deque -from contextlib import contextmanager -import threading -from typing import ( - Any, - Deque, - Generator, - List, - Optional, - Protocol, - Union, - Sequence, - Tuple -) - -import torch -from torch import Tensor -import torch.autograd - -from .dependency import fork, join -from .microbatch import Batch -from .phony import get_phony - -__all__ = ["Function", "checkpoint", "Checkpointing", "ThreadLocal", "enable_checkpointing", - "enable_recomputing", "is_checkpointing", "is_recomputing", "Context", "save_rng_states", - "restore_rng_states", "Checkpoint", "Recompute"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -# Types for shared memory between Checkpoint and Recompute. -Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf) -RNGStates = Tuple[Tensor, Optional[Tensor]] # (cpu_rng_state, gpu_rng_state) - - -# Protocol with __call__ instead of Callable can be used as an attribute type. -# See: https://github.com/python/mypy/issues/708#issuecomment-561735949 -class Function(Protocol): - def __call__(self, input: TensorOrTensors) -> TensorOrTensors: - ... - - -def checkpoint(function: Function, input): - """Make a checkpoint with a simple interface like - :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug - :class:`Checkpoint` and :class:`Recompute` without boilerplate. - """ - batch = Batch(input) - - chk = Checkpointing(function, batch) - batch = chk.checkpoint() - chk.recompute(batch) - - return batch.values - - -class Checkpointing: - """Generates a pair of :class:`Checkpoint` and :class:`Recompute`.""" - - def __init__(self, function: Function, batch: Batch) -> None: - self.function = function - self.batch = batch - - # Shared memory between Checkpoint and Recompute. 1-length deque is - # used for mutability and length limitation. - self.recomputed: Deque[Recomputed] = deque(maxlen=1) - self.rng_states: Deque[RNGStates] = deque(maxlen=1) - - def checkpoint(self) -> Batch: - """Return a batch applied by :class:`Checkpoint`.""" - input_atomic = self.batch.atomic - inputs = tuple(self.batch) - - # Use a phony which requires grad to ensure that Checkpoint can be - # tracked by the autograd engine even when none of the input tensors - # require grad. - phony = get_phony(self.batch.get_device(), requires_grad=True) - - output = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs) - - # Gradients are only supported for float Tensors. - if isinstance(output, tuple): - output = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in output]) - - return Batch(output) - - def recompute(self, batch: Batch) -> None: - """Apply :class:`Recompute` to the batch in place.""" - input_atomic = self.batch.atomic - inputs = tuple(self.batch) - - # Use a tensor in the batch to tie together fork-join - tensor_idx = batch.find_tensor_idx() - # batch[tensor_idx] is always requiring grad, because it has been passed - # checkpoint with a phony requiring grad. 
- batch[tensor_idx], phony = fork(batch[tensor_idx]) - phony = Recompute.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs) - batch[tensor_idx] = join(batch[tensor_idx], phony) - - -class ThreadLocal(threading.local): - def __init__(self) -> None: - self.is_checkpointing = False - self.is_recomputing = False - - -thread_local = ThreadLocal() - - -@contextmanager -def enable_checkpointing() -> Generator[None, None, None]: - """Make :func:`is_checkpointing` return :data:`True` within a context.""" - orig = thread_local.is_checkpointing - thread_local.is_checkpointing = True - try: - yield - finally: - thread_local.is_checkpointing = orig - - -@contextmanager -def enable_recomputing() -> Generator[None, None, None]: - """Makes :func:`is_recomputing` return :data:`True` within a context.""" - orig = thread_local.is_recomputing - thread_local.is_recomputing = True - try: - yield - finally: - thread_local.is_recomputing = orig - - -def is_checkpointing() -> bool: - """Whether the current forward propagation is under checkpointing. - - Returns: - bool: :data:`True` if it's under checkpointing. - - """ - return thread_local.is_checkpointing - - -def is_recomputing() -> bool: - """Whether the current forward propagation is under checkpoint recomputation. - - Use this to prevent duplicated side-effects at forward - propagation:: - - class Counter(nn.Module): - def __init__(self): - super().__init__() - self.counter = 0 - - def forward(self, input): - if not is_recomputing(): - self.counter += 1 - return input - - Returns: - bool: :data:`True` if it's under checkpoint recomputation. - - .. seealso:: :ref:`Detecting Recomputation` - - """ - return thread_local.is_recomputing - - -class Context: - """The common interface between the :class:`Checkpoint` and :class:`Recompute` context.""" - - recomputed: Deque[Recomputed] - rng_states: Deque[RNGStates] - function: Function - input_atomic: bool - inputs: Sequence[Any] - - saved_tensors: Tuple[Tensor, ...] - - def save_for_backward(self, *tensors: Tensor) -> None: # pragma: no cover - pass - - -def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None: - """: - Capture the current random number generator states. - - meth:`Checkpoint.forward` captures the current PyTorch's random number - generator states at CPU and GPU to reuse in :meth:`Recompute.backward`. - - .. seealso:: :ref:`Referential Transparency` - - """ - cpu_rng_state = torch.get_rng_state() - - gpu_rng_state: Optional[Tensor] - if device.type == "cuda": - gpu_rng_state = torch.cuda.get_rng_state(device) - else: - gpu_rng_state = None - - rng_states.append((cpu_rng_state, gpu_rng_state)) - - -@contextmanager -def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]: - """: - Restore the random number generator state. - - meth:`Recompute.backward` restores the random number generator states - captured by :func:`save_rng_states` within its context. - - .. 
seealso:: :ref:`Referential Transparency` - - """ - cpu_rng_state, gpu_rng_state = rng_states.pop() - - gpu_devices: List[torch.device] = [] - if device.type == "cuda": - gpu_devices.append(device) - - with torch.random.fork_rng(gpu_devices): - torch.set_rng_state(cpu_rng_state) - if gpu_rng_state is not None: - torch.cuda.set_rng_state(gpu_rng_state, device) - yield - - -class Checkpoint(torch.autograd.Function): - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, - phony: Tensor, - recomputed: Deque[Recomputed], - rng_states: Deque[RNGStates], - function: Function, - input_atomic: bool, - *inputs, - ): - ctx.recomputed = recomputed - ctx.rng_states = rng_states - - save_rng_states(phony.device, ctx.rng_states) - - ctx.function = function - ctx.input_atomic = input_atomic - if input_atomic: - tensors = [inputs[0]] - else: - tensors = [] - for input in inputs: - if torch.is_tensor(input): - tensors.append(input) - - ctx.save_for_backward(*tensors) - - with torch.no_grad(), enable_checkpointing(): - if input_atomic: - assert len(inputs) == 1 - output = function(inputs[0]) - else: - output = function(*inputs) - return output - - @staticmethod - def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover - output, input_leaf = ctx.recomputed.pop() - - if isinstance(output, tuple): - outputs = output - else: - outputs = (output,) - if any(torch.is_tensor(y) and y.requires_grad for y in outputs): - tensors = tuple([x for x in outputs if torch.is_tensor(x) and x.requires_grad]) - torch.autograd.backward(tensors, grad_output) - - grad_input: List[Optional[Tensor]] = [None, None, None, None, None] - grad_input.extend(x.grad if torch.is_tensor(x) else None for x in input_leaf) - return tuple(grad_input) - - -class Recompute(torch.autograd.Function): - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, - phony: Tensor, - recomputed: Deque[Recomputed], - rng_states: Deque[RNGStates], - function: Function, - input_atomic: bool, - *inputs, - ) -> Tensor: - ctx.recomputed = recomputed - ctx.rng_states = rng_states - - ctx.function = function - ctx.input_atomic = input_atomic - ctx.inputs = inputs - if input_atomic: - tensors = [inputs[0]] - else: - tensors = [] - for input in inputs: - if torch.is_tensor(input): - tensors.append(input) - ctx.save_for_backward(*tensors) - - return phony - - @staticmethod - def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]: # pragma: no cover - inputs = ctx.inputs - inputs_leaf = tuple(x.detach().requires_grad_(x.requires_grad) if torch.is_tensor(x) else x for x in inputs) - - # Get the device for the inputs from a tensor - device = None - for input in inputs: - if torch.is_tensor(input): - device = input.device - break - - if device is None: - raise RuntimeError(f'No tensors found in {inputs}') - - with restore_rng_states(device, ctx.rng_states): - with torch.enable_grad(), enable_recomputing(): - if ctx.input_atomic: - assert len(inputs_leaf) == 1 - output = ctx.function(inputs_leaf[0]) - else: - output = ctx.function(*inputs_leaf) - - ctx.recomputed.append((output, inputs_leaf)) - - grad_input: List[None] = [None, None, None, None, None] - grad_input.extend(None for _ in ctx.inputs) - return tuple(grad_input) diff --git a/torch/distributed/pipeline/sync/copy.py b/torch/distributed/pipeline/sync/copy.py deleted file mode 100644 index b717f0c2932c6..0000000000000 --- a/torch/distributed/pipeline/sync/copy.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2019 Kakao Brain 
-# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Autograd functions for stream-aware CUDA copy. - -It is used to overlap copy and computation on the same GPU. -""" -from collections import deque -from typing import Deque, List, Optional, Tuple, Sequence - -import torch -from torch import Tensor - -from .stream import AbstractStream, current_stream, get_device, record_stream, use_stream, wait_stream - -__all__: List[str] = ["Context", "Copy", "Wait"] - - -Tensors = Sequence[Tensor] - - -# Common interface between :class:`Copy` and :class:`Wait`. -class Context: - prev_stream: AbstractStream - next_stream: AbstractStream - - -class Copy(torch.autograd.Function): - """Copies tensors on specific streams.""" - - @staticmethod - # type: ignore[override] - def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input,) -> Tensors: - ctx.prev_stream = prev_stream - ctx.next_stream = next_stream - - output = [] - output_stream = current_stream(get_device(next_stream)) - - with use_stream(prev_stream), use_stream(next_stream): - for x in input: - if torch.is_tensor(x): - y = x.to(get_device(next_stream), non_blocking=True) - output.append(y) - - # 'prev_stream' is not where 'x' has been allocated. - record_stream(x, prev_stream) - # 'y' has been allocated on 'next_stream'. - # It might be used on the current stream captured as 'output_stream'. - record_stream(y, output_stream) - else: - output.append(x) - - return tuple(output) - - @staticmethod - def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: - prev_stream = ctx.prev_stream - next_stream = ctx.next_stream - - grad_input: Deque[Tensor] = deque(maxlen=len(grad_output)) - input_stream = current_stream(get_device(prev_stream)) - - with use_stream(prev_stream), use_stream(next_stream): - for x in reversed(grad_output): - y = x.to(get_device(prev_stream), non_blocking=True) - grad_input.appendleft(y) - - # 'next_stream' is not where 'x' has been allocated. - record_stream(x, next_stream) - # 'y' has been allocated on 'prev_stream'. - # It might be used on the current stream captured as 'input_stream'. - record_stream(y, input_stream) - - grad_streams: Tuple[Optional[Tensor], ...] = (None, None) - return grad_streams + tuple(grad_input) - - -class Wait(torch.autograd.Function): - """Synchronizes a stream to another stream. - - Place it just before you want to start an operation on the next stream, - provided that all operations on the previous stream are done. - - """ - - @staticmethod - # type: ignore[override] - def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input) -> Tensors: - ctx.prev_stream = prev_stream - ctx.next_stream = next_stream - - wait_stream(next_stream, prev_stream) - - return tuple(x.detach() if torch.is_tensor(x) else x for x in input) - - @staticmethod - def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]: - prev_stream = ctx.prev_stream - next_stream = ctx.next_stream - - wait_stream(prev_stream, next_stream) - - grad_streams: Tuple[Optional[Tensor], ...] 
= (None, None) - return grad_streams + grad_input diff --git a/torch/distributed/pipeline/sync/dependency.py b/torch/distributed/pipeline/sync/dependency.py deleted file mode 100644 index ca5c69e388fe4..0000000000000 --- a/torch/distributed/pipeline/sync/dependency.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Arbitrary dependency between two autograd lanes.""" -from typing import List, Tuple - -import torch -from torch import Tensor - -from .phony import get_phony - -__all__: List[str] = ["fork", "Fork", "join", "Join"] - - -def fork(input: Tensor) -> Tuple[Tensor, Tensor]: - """Branches out from an autograd lane of the given tensor.""" - if torch.is_grad_enabled() and input.requires_grad: - input, phony = Fork.apply(input) - else: - phony = get_phony(input.device, requires_grad=False) - - return input, phony - - -class Fork(torch.autograd.Function): - @staticmethod - def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore[override] - phony = get_phony(input.device, requires_grad=False) - return input.detach(), phony.detach() - - @staticmethod - def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore[override] - return grad_input - - -def join(input: Tensor, phony: Tensor) -> Tensor: - """Merge two autograd lanes.""" - if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad): - input = Join.apply(input, phony) - - return input - - -class Join(torch.autograd.Function): - @staticmethod - def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore[override] - return input.detach() - - @staticmethod - def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore[override] - return grad_input, None diff --git a/torch/distributed/pipeline/sync/microbatch.py b/torch/distributed/pipeline/sync/microbatch.py deleted file mode 100644 index 5b8aca2575480..0000000000000 --- a/torch/distributed/pipeline/sync/microbatch.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Manipulation of micro-batches.""" -import typing -from typing import Any, Callable, List, Union, cast, Sequence - -import torch -from torch import Tensor -import torch.cuda.comm - -__all__: List[str] = ["NoChunk", "Batch", "check", "scatter", "gather"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] -Function = Callable[[TensorOrTensors], Union[List[Any], Tensor]] - - -class NoChunk: - """ - Wrapper for a Tensor in :meth:`Pipe.forward` indicating that the tensor - should not be chunked on the batch dimension and instead be replicated - as-is across all micro-batches. This is useful for tensors which might - not have any 'batch' semantics for the model. - """ - def __init__(self, inp: Tensor): - if not torch.is_tensor(inp): - raise TypeError(f'NoChunk only supported for tensors, found: {inp}') - self._tensor = inp - - @property - def tensor(self): - return self._tensor - - -class Batch: - """ - An abstraction representing a microbatch in the pipeline. 
- """ - - def __init__(self, values: Union[List[Any], Tensor]) -> None: - self._values = values - self.atomic = torch.is_tensor(values) - - # Verify at least on tensor - if not self.atomic: - if not any(torch.is_tensor(value) for value in self._values): - raise TypeError(f'No tensors found in batch: {self._values}') - - @property - def tensor(self) -> Tensor: - """Retrieves the underlying tensor.""" - if not self.atomic: - raise AttributeError("not atomic batch") - return cast(Tensor, self._values) - - @property - def values(self): - """Retrieves the underlying values for the batch""" - return self._values - - def find_tensor_idx(self): - """ - Retrieves the index of first tensor found. - """ - if self.atomic: - return 0 - for i, value in enumerate(self._values): - if torch.is_tensor(value): - return i - - raise TypeError("No tensor found!") - - def get_device(self): - """ - Retrieves the device for this microbatch. - """ - if self.atomic: - return self._values.device # type: ignore[union-attr] - - for value in self._values: - if torch.is_tensor(value): - return value.device - - def call(self, function: Function) -> "Batch": - """Calls a function on the microbatch. It also wraps - the output with :class:`Batch`. - """ - if self.atomic: - return Batch(function(self._values)) - else: - return Batch(function(*self._values)) - - def __repr__(self) -> str: - return f"Batch[atomic={self.atomic!r}]({self._values!r})" - - def __iter__(self): - if self.atomic: - yield self._values - else: - yield from self._values - - def __len__(self) -> int: - return 1 if self.atomic else len(self._values) - - def __getitem__(self, index: int): - if not self.atomic: - return self._values[index] - - if index != 0: - raise IndexError("atomic batch allows index 0 only") - - return self._values - - # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload". - @typing.overload - def __setitem__(self, index: int, value: Tensor) -> None: - ... - - @typing.overload - def __setitem__(self, index: slice, value: Tensors) -> None: - ... - - def __setitem__(self, index: Union[int, slice], value) -> None: - if isinstance(index, int): - self._setitem_by_index(index, value) - else: - self._setitem_by_slice(index, value) - - def _setitem_by_index(self, index: int, value) -> None: - if not self.atomic: - i = index - self._values = self._values[:i] + (value,) + self._values[i + 1 :] # type: ignore[operator] - return - - if index != 0: - raise IndexError("atomic batch allows index 0 only") - - self._values = value - - def _setitem_by_slice(self, index: slice, value) -> None: - if not (index.start is index.stop is index.step is None): # noqa: E714 - raise NotImplementedError("only slice [:] supported") - - if not self.atomic: - self._values = value - return - - if len(value) != 1: - raise IndexError("atomic batch cannot be replaced with multiple tensors") - - self._values = value[0] - - -def check(first_device, *inputs) -> None: - """ - Checks whether the input contains at least one tensor and each tensor is - on the same device as the first partition. 
- - Raises: - ValueError: input does not contain at least one tensor - - """ - - if not any(torch.is_tensor(input) for input in inputs): - raise TypeError(f'inputs do not have any tensors: {inputs}') - if any(torch.is_tensor(input) and input.device != first_device for input in inputs): - raise ValueError('All inputs should be on the same device as the first partition') - - -def scatter(*inputs, chunks: int) -> List[Batch]: - """Splits an input mini-batch into multiple micro-batches.""" - if len(inputs) == 1 and isinstance(inputs[0], Tensor): - return [Batch(x) for x in inputs[0].chunk(chunks)] - - batches: List[Any] = [[] for _ in range(chunks)] - # Actual number of chunks produced - num_chunks = -1 - for input in inputs: - if torch.is_tensor(input): - # Chunk only tensors. - tensors = input.chunk(chunks) - - # Validate number of chunks equal across all inputs. - if num_chunks != -1 and num_chunks != len(tensors): - raise RuntimeError(f'Found different number of chunks produced for inputs: {num_chunks} and {len(tensors)}') - num_chunks = len(tensors) - - for i, tensor in enumerate(tensors): - batches[i].append(tensor) - else: - # Replicate non-tensors or tensors wrapped with 'NoChunk'. - for i in range(chunks): - if isinstance(input, NoChunk): - # Extract the tensor out. - batches[i].append(input.tensor) - else: - batches[i].append(input) - - # Truncate to actual number of chunks - batches = batches[:num_chunks] - - return [Batch(x) for x in batches] - - -def gather(outputs: List[Batch]): - """Concatenates output micro-batches into a mini-batch.""" - output: Any - - if outputs[0].atomic: - tensors = tuple(b.tensor for b in outputs) - output = torch.cat(tensors) - else: - output_buf: List[Any] = [] - for i in range(len(outputs[0])): - output_type = type(outputs[0][i]) - current_outputs = [] - for batch in outputs: - if output_type != type(batch[i]): - raise TypeError(f'Types for microbatch outputs do not match, found: {output_type} and {type(batch[i])}') - current_outputs.append(batch[i]) - - if torch.is_tensor(outputs[0][i]): - output_buf.append(torch.cat(current_outputs)) - else: - output_buf.append(current_outputs) - - output = tuple(output_buf) - - return output diff --git a/torch/distributed/pipeline/sync/phony.py b/torch/distributed/pipeline/sync/phony.py deleted file mode 100644 index 012926699cfbc..0000000000000 --- a/torch/distributed/pipeline/sync/phony.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Provides phony for arbitrary dependency in a autograd graph.""" -from typing import Dict, List, Tuple - -import torch -from torch import Tensor - -from .stream import default_stream, use_stream - -__all__: List[str] = ["get_phony"] - - -_phonies: Dict[Tuple[torch.device, bool], Tensor] = {} - - -def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor: - """Get a phony. Phony is tensor without space. - - It is useful to make arbitrary dependency in a autograd graph because it doesn't require any - gradient accumulation. - - .. note:: - - Phonies for each device are cached. If an autograd function gets a phony - internally, the phony must be detached to be returned. 
Otherwise, the - autograd engine will mutate the cached phony in-place:: - - class Phonify(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - phony = get_phony(input.device, requires_grad=False) - return phony.detach() # detach() is necessary. - - """ - key = (device, requires_grad) - - try: - phony = _phonies[key] - except KeyError: - with use_stream(default_stream(device)): - phony = torch.empty(0, device=device, requires_grad=requires_grad) - - _phonies[key] = phony - - return phony diff --git a/torch/distributed/pipeline/sync/pipe.py b/torch/distributed/pipeline/sync/pipe.py deleted file mode 100644 index 5e61341d9ad9f..0000000000000 --- a/torch/distributed/pipeline/sync/pipe.py +++ /dev/null @@ -1,490 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""The Pipe interface.""" -from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Union, Sequence, Tuple, cast - -import torch -from torch import Tensor, nn -from torch.distributed.rpc import RRef -import torch.autograd -import torch.cuda - -from . import microbatch -from .batchnorm import DeferredBatchNorm -from .pipeline import Pipeline -from .skip.layout import inspect_skip_layout -from .skip.skippable import verify_skippables -from .stream import AbstractStream, new_stream - -__all__ = ["Pipe", "BalanceError", "PipeSequential", "WithDevice"] - - -Device = Union[torch.device, int, str] -Devices = Union[Iterable[Device], List[Device]] - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -if TYPE_CHECKING: - # Typechecking: nn.Module is not a Generic - Module = nn.Module[TensorOrTensors] # type: ignore[type-arg] - NamedModules = OrderedDict[str, Module] -else: - Module = nn.Module - NamedModules = OrderedDict - - -def _recommend_auto_balance(message: str) -> str: - """Expands a message with recommendation to :mod:`torchpipe.balance`.""" - return f"""{message} - -If your model is still under development, its optimal balance would change -frequently. In this case, we highly recommend 'torch.distributed.pipeline.sync.balance' for -naive automatic balancing: - - from torch.distributed.pipeline.sync import Pipe - from torch.distributed.pipeline.sync.balance import balance_by_time - - partitions = torch.cuda.device_count() - sample = torch.empty(...) - balance = balance_by_time(partitions, model, sample) - - model = Pipe(model, balance, ...) 
-""" - - -def _verify_module(module: nn.Sequential) -> None: - if not isinstance(module, nn.Sequential): - raise TypeError("module must be nn.Sequential to be partitioned") - - named_children = list(module.named_children()) - if len(named_children) != len(module): - raise ValueError("module with duplicate children is not supported") - - -def _verify_splitting( - module: nn.Sequential, partitions: List[nn.Sequential], devices: List[torch.device] -) -> None: - num_parameters = len(list(module.parameters())) - num_child_parameters = sum(len(list(child.parameters())) for child in module.children()) - if num_parameters == num_child_parameters: - return - - for i in range(len(partitions)): - for j in range(i + 1, len(partitions)): - parti = partitions[i] - partj = partitions[j] - if devices[i] == devices[j]: - continue - for p in parti.parameters(): - for q in partj.parameters(): - if p is q: - raise ValueError("module with duplicate parameters on distinct devices is not supported") - - -class BalanceError(ValueError): - pass - - -def _retrieve_device(module: nn.Module) -> torch.device: - """Validates all parameters in the Module have the same device and returns - the appropriate device. - - Args: - An ``nn.Module`` to process. - - Returns: - ``torch.Device`` for the entire module. - - Raises: - ValueError: - If devices for ``nn.Module`` parameters are not all same. - """ - - device = None - for parameter in module.parameters(): - if device is None: - device = parameter.device - elif device != parameter.device: - raise ValueError( - f'nn.Module: {module}, should have all parameters on a single device,' - ' please use .to() to place the module on a single device') - - return device if device is not None else torch.device("cpu") - - -class PipeSequential(nn.Sequential): - """ - Pipe variant of ``nn.Sequential`` which supports multiple inputs. - """ - - def forward(self, *inputs): - for module in self: - if isinstance(inputs, Tuple): # type: ignore[arg-type] - inputs = module(*inputs) - else: - # Don't expand single variables (ex: lists/Tensor) - inputs = module(inputs) - return inputs - - -class WithDevice(nn.Module): - """ - Wraps an ``nn.Module`` which is part of ``nn.Sequential`` passed into :class:`Pipe` - that overrides the device for that module. In cases where :class:`Pipe` - can't implicitly determine the device for the module and places it on CPU, - this wrapper can be used to override the implicit behavior and explicitly - specify which device a module should run on. - - The provided module is also moved to the given device via ``.to(device)`` - by :class:`Pipe` - - Args: - module(:class:`torch.nn.Module`): The module to be wrapped. - device(:class:`torch.device`): The device to run the module on. - - Example:: - >>> # xdoctest: +SKIP("distributed") - >>> fc1 = nn.Linear(16, 8).cuda(0) - >>> fc2 = nn.Linear(8, 4).cuda(1) - >>> dropout = nn.Dropout() - >>> - >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) - >>> # Dropout does not have any parameters/buffers, but we want to - >>> # run it on cuda:1 to avoid any GPU to CPU transfers. 
- >>> model = nn.Sequential(fc1, fc2, WithDevice(dropout, 'cuda:1')) - >>> # xdoctest: +SKIP("Needs RPC framework init") - >>> model = Pipe(model, chunks=8) - """ - def __init__(self, module: nn.Module, device: torch.device): - super().__init__() - self._module = module - self._device = torch.device(device) - - def forward(self, *args, **kwargs): - return self._module(*args, **kwargs) - - @property - def module(self): - return self._module - - @property - def device(self): - return self._device - - -def _assemble_partition(modules: List[nn.Module]): - modules_list: List[nn.Module] = [] - for module in modules: - if isinstance(module, nn.Sequential): - modules_list.extend(module.children()) - else: - modules_list.append(module) - return PipeSequential(*modules_list) - - -def _split_module(modules: nn.Sequential) -> Tuple[List[nn.Sequential], List[torch.device]]: - partitions = [] - devices = [] - - current_partition = [] - current_device = None - for name, module in modules.named_children(): - if isinstance(module, WithDevice): - # Process device override and move module to appropriate device. - device = module.device - module = module.module - module.to(device) - else: - device = _retrieve_device(module) - if current_device is not None and (current_device != device or device.type == 'cpu'): - partitions.append(_assemble_partition(current_partition)) - devices.append(current_device) - current_partition = [] - current_device = device - current_partition.append(module) - - if current_device is not None: - partitions.append(_assemble_partition(current_partition)) - devices.append(current_device) - - partitions = cast(List[nn.Sequential], nn.ModuleList(partitions)) - - return partitions, devices - - -MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement") - - -class Pipe(Module): - """Wraps an arbitrary :class:`nn.Sequential ` module - to train on using synchronous pipeline parallelism. If the module requires - lots of memory and doesn't fit on a single GPU, pipeline parallelism is a - useful technique to employ for training. - - The implementation is based on the torchgpipe_ paper. - - .. _torchgpipe: https://arxiv.org/abs/2004.09910 - - Pipe combines pipeline parallelism with checkpointing to reduce peak - memory required to train while minimizing device under-utilization. - - You should place all the modules on the appropriate devices and wrap them - into an :class:`nn.Sequential ` module defining the - desired order of execution. If a module does not contain any - parameters/buffers, it is assumed this module should be executed on CPU - and appropriate input tensors to the module are moved to CPU before - execution. This behavior can be overridden by the :class:`WithDevice` - wrapper which can be used to explicitly specify which device a module - should run on. - - Args: - module (:class:`nn.Sequential `): - sequential module to be parallelized using pipelining. Each module - in the sequence has to have all of its parameters on a single - device. Each module in the sequence has to either be an nn.Module - or :class:`nn.Sequential ` (to combine multiple - sequential modules on a single device) - chunks (int): - number of micro-batches (default: ``1``) - checkpoint (str): - when to enable checkpointing, one of ``'always'``, - ``'except_last'``, or ``'never'`` (default: ``'except_last'``). 
- ``'never'`` disables checkpointing completely, ``'except_last'`` - enables checkpointing for all micro-batches except the last one - and ``'always'`` enables checkpointing for all micro-batches. - deferred_batch_norm (bool): - whether to use deferred ``BatchNorm`` moving statistics (default: - :data:`False`). If set to :data:`True`, we track statistics across - multiple micro-batches to update the running statistics per - mini-batch. - - Raises: - TypeError: - the module is not a :class:`nn.Sequential `. - ValueError: - invalid arguments - - Example:: - Pipeline of two FC layers across GPUs 0 and 1. - - >>> # Need to initialize RPC framework first. - >>> # xdoctest: +SKIP - >>> os.environ['MASTER_ADDR'] = 'localhost' - >>> os.environ['MASTER_PORT'] = '29500' - >>> torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1) - >>> - >>> # Build pipe. - >>> fc1 = nn.Linear(16, 8).cuda(0) - >>> fc2 = nn.Linear(8, 4).cuda(1) - >>> model = nn.Sequential(fc1, fc2) - >>> model = Pipe(model, chunks=8) - >>> input = torch.rand(16, 16).cuda(0) - >>> output_rref = model(input) - - .. note:: - You can wrap a :class:`Pipe` model with - :class:`torch.nn.parallel.DistributedDataParallel` only when the - checkpoint parameter of :class:`Pipe` is ``'never'``. - - .. note:: - :class:`Pipe` only supports intra-node pipelining currently, but - will be expanded to support inter-node pipelining in the future. - The forward function returns an :class:`~torch.distributed.rpc.RRef` - to allow for inter-node pipelining in the future, where the output - might be on a remote host. For intra-node pipelining you can use - :meth:`~torch.distributed.rpc.RRef.local_value` to retrieve the - output locally. - - .. warning:: - :class:`Pipe` is experimental and subject to change. - """ - - def __init__( - self, - module: nn.Sequential, - chunks: int = 1, - checkpoint: str = "except_last", - deferred_batch_norm: bool = False, - ) -> None: - super().__init__() - - # Check if RPC framework is initialized. - if not torch.distributed.rpc._is_current_rpc_agent_set(): - raise RuntimeError( - 'Please initialize RPC framework for Pipe using ' - 'torch.distributed.rpc.init_rpc') - - chunks = int(chunks) - checkpoint = str(checkpoint) - - if chunks <= 0: - raise ValueError("number of chunks must be positive integer") - if checkpoint not in ["always", "except_last", "never"]: - raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'") - - _verify_module(module) - - # Verify if the underlying skippable modules satisfy integrity. The - # integrity can be verified before forward() because it is static. - verify_skippables(module) - - self.chunks = chunks - self.checkpoint = checkpoint - - if deferred_batch_norm: - module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks) - - self.partitions, self.devices = _split_module(module) - _verify_splitting(module, self.partitions, self.devices) - - self._copy_streams: List[List[AbstractStream]] = [] - self._skip_layout = inspect_skip_layout(self.partitions) - - # Separate CUDA streams for copy. - copy_streams = self._ensure_copy_streams() - - # The micro-batch index where the checkpointing stops. 
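As an aside on the `checkpoint_stop` value computed on the next line: micro-batches whose index falls below it run under activation checkpointing, the rest run normally. The deleted code uses its own `Checkpointing` helper; the snippet below is only a rough, self-contained sketch of the same per-micro-batch policy, expressed with the public `torch.utils.checkpoint` API (the `stage` module and the sizes are made up for illustration):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

stage = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

chunks = 4
mode = "except_last"  # same choices as Pipe(checkpoint=...)
checkpoint_stop = {"always": chunks, "except_last": chunks - 1, "never": 0}[mode]

batch = torch.randn(8, 16)
outputs = []
for i, micro in enumerate(batch.chunk(chunks)):
    if i < checkpoint_stop:
        # recompute this micro-batch's activations during backward to save memory
        outputs.append(checkpoint(stage, micro, use_reentrant=False))
    else:
        outputs.append(stage(micro))
torch.cat(outputs).sum().backward()
```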
- checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] - - self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop) - - def __len__(self) -> int: - """Counts the length of the underlying sequential module.""" - return sum(len(p) for p in self.partitions) - - def __getitem__(self, index: int) -> nn.Module: - """Gets a layer in the underlying sequential module.""" - partitions = self.partitions - if index < 0: - partitions = partitions[::-1] - - for partition in partitions: - try: - return partition[index] - except IndexError: - pass - - shift = len(partition) - - if index < 0: - index += shift - else: - index -= shift - - raise IndexError - - def __iter__(self) -> Iterator[nn.Module]: - """Iterates over children of the underlying sequential module.""" - for partition in self.partitions: - yield from partition - - # Pipe should manage the device of each partition. - # Deny cuda(), cpu(), and to() with device, by TypeError. - def cuda(self, device: Optional[Device] = None) -> "Pipe": - raise MOVING_DENIED - - def cpu(self) -> "Pipe": - raise MOVING_DENIED - - def to(self, *args: Any, **kwargs: Any) -> "Pipe": - # Deny these usages: - # - # - to(device[, dtype, non_blocking]) - # - to(tensor[, non_blocking]) - # - # But allow this: - # - # - to(dtype[, non_blocking]) - # - if "device" in kwargs or "tensor" in kwargs: - raise MOVING_DENIED - - if args: - if isinstance(args[0], (torch.device, int, str)): - raise MOVING_DENIED - if torch.is_tensor(args[0]): - raise MOVING_DENIED - - return super().to(*args, **kwargs) - - def _ensure_copy_streams(self) -> List[List[AbstractStream]]: - """Ensures that :class:`Pipe` caches CUDA streams for copy. - - It's worth to cache CUDA streams although PyTorch already manages a - pool of pre-allocated CUDA streams, because it may reduce GPU memory - fragmentation when the number of micro-batches is small. - - """ - if not self._copy_streams: - for device in self.devices: - self._copy_streams.append([new_stream(device) for _ in range(self.chunks)]) - - return self._copy_streams - - def forward(self, *inputs) -> RRef: - """ - Processes a single input mini-batch through the pipe and returns an - :class:`~torch.distributed.rpc.RRef` pointing to the output. - :class:`Pipe` is a fairly transparent module wrapper. It doesn't - modify the input and output signature of the underlying module. But - there's type restriction. Input and output have to contain at least one - tensor. This restriction is applied at partition boundaries too. - - The sequence of inputs are fed into the first stage of the pipeline as - ``*inputs``. As a result the positional args for this function should - match the positional args for the first stage of the pipeline. The same - condition applies for output of one stage of the pipeline which is the - input for the next stage. - - The input tensor is split into multiple micro-batches based on the - ``chunks`` parameter used to initialize :class:`Pipe`. The batch size - is assumed to be the first dimension of the tensor and if the batch - size is less than ``chunks``, the number of micro-batches is equal to - the batch size. - - Only tensors are split into multiple micro-batches, non-Tensor inputs - are just replicated as-is in each micro-batch. For non-Tensor outputs - in the last stage of the pipeline, they are aggregated as a ``List`` - and returned the user. 
For example, if you have 2 micro-batches - returning the integer 5, the user would receive the consolidated - output of `[5, 5]` - - All the input tensors need to be on the same device as the first - partition of the pipeline. - - If a tensor is wrapped with the :class:`NoChunk` wrapper, the tensor - is not split across micro-batches and is replicated as-is similar to - non-tensors. - - Args: - inputs: input mini-batch - - Returns: - :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch - - Raises: - TypeError: input doesn't contain at least one tensor - - """ - first_partition_device = self.devices[0] if len(self.devices) != 0 else torch.device("cpu") - microbatch.check(first_partition_device, *inputs) - - if not self.devices: - # Empty sequential module is not illegal. - return RRef(*inputs) - - # Divide a mini-batch into micro-batches. - batches = microbatch.scatter(*inputs, chunks=self.chunks) - - # Run pipeline parallelism. - self.pipeline.run(batches) - - # Merge the micro-batches into one mini-batch. - output = microbatch.gather(batches) - return RRef(output) diff --git a/torch/distributed/pipeline/sync/pipeline.py b/torch/distributed/pipeline/sync/pipeline.py deleted file mode 100644 index 7cd5e58311697..0000000000000 --- a/torch/distributed/pipeline/sync/pipeline.py +++ /dev/null @@ -1,255 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""The pipeline parallelism of Pipe.""" -from queue import Queue -from types import TracebackType -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence - -import torch -from torch import Tensor, nn -from torch.autograd.profiler import record_function - -from .checkpoint import Checkpointing -from .copy import Copy, Wait -from .dependency import fork, join -from .microbatch import Batch -from .skip.layout import SkipLayout -from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker -from .stream import AbstractStream, current_stream, use_device -from .worker import Task, create_workers - -__all__: List[str] = ["Pipeline"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] - -# Queue is generic only in stubs. -# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime -if TYPE_CHECKING: - InQueue = Queue[Optional["Task"]] - OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] -else: - InQueue = Queue - OutQueue = Queue - - -def _depend(fork_from: Batch, join_to: Batch) -> None: - fork_from_idx = fork_from.find_tensor_idx() - join_to_idx = join_to.find_tensor_idx() - - fork_from[fork_from_idx], phony = fork(fork_from[fork_from_idx]) - join_to[join_to_idx] = join(join_to[join_to_idx], phony) - - -def _copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: - batch[:] = Copy.apply(prev_stream, next_stream, *batch) - # Gradients are only supported for float Tensors. - batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch]) - - -def _wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: - batch[:] = Wait.apply(prev_stream, next_stream, *batch) - # Gradients are only supported for float Tensors. 
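To make the calling convention described above concrete, here is a minimal end-to-end sketch of the API this diff removes. It is illustrative only: it assumes a PyTorch build that still ships `torch.distributed.pipeline.sync`, a host with at least two GPUs, and a free port for the RPC agent.

```python
import os
import torch
import torch.nn as nn
import torch.distributed.rpc as rpc
from torch.distributed.pipeline.sync import Pipe  # removed by this diff

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
rpc.init_rpc("worker", rank=0, world_size=1)  # Pipe requires an initialized RPC agent

fc1 = nn.Linear(16, 8).cuda(0)
fc2 = nn.Linear(8, 4).cuda(1)
model = Pipe(nn.Sequential(fc1, fc2), chunks=8, checkpoint="except_last")

x = torch.rand(16, 16).cuda(0)   # inputs live on the first partition's device
out = model(x).local_value()     # forward returns an RRef; intra-node -> local_value()
out.sum().backward()             # gradients flow back through both partitions

rpc.shutdown()
```

Because `forward()` hands back an `RRef`, `local_value()` is needed before the output can be fed to a local loss, as the docstring above notes.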
- batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch]) - - -def _clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]: - """Generate schedules for each clock cycle.""" - # m: number of micro-batches - # n: number of partitions - # i: index of micro-batch - # j: index of partition - # k: clock number - # - # k (i,j) (i,j) (i,j) - # - ----- ----- ----- - # 0 (0,0) - # 1 (1,0) (0,1) - # 2 (2,0) (1,1) (0,2) - # 3 (2,1) (1,2) - # 4 (2,2) - for k in range(m + n - 1): - yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))] - - -class Pipeline: - """The pipeline parallelism for Pipe.""" - - def __init__( - self, - partitions: List[nn.Sequential], - devices: List[torch.device], - copy_streams: List[List[AbstractStream]], - skip_layout: SkipLayout, - checkpoint_stop: int, - ) -> None: - self.partitions = partitions - self.devices = devices - self.copy_streams = copy_streams - self.skip_layout = skip_layout - self.checkpoint_stop = checkpoint_stop - (self.in_queues, self.out_queues) = create_workers(devices) - - def run(self, batches: List[Batch]) -> None: - """Runs pipeline parallelism. - - It modifies the given batches in place. - - """ - partitions = self.partitions - devices = self.devices - skip_layout = self.skip_layout - - m = len(batches) - n = len(partitions) - - skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches] - - for schedule in _clock_cycles(m, n): - self.fence(batches, schedule, skip_trackers) - self.compute(batches, schedule, skip_trackers) - - def fence( - self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], - ) -> None: - """Copy micro-batches after computation for the previous micro-batches.""" - copy_streams = self.copy_streams - skip_layout = self.skip_layout - - for i, j in schedule: - # Ensure that batches[i-1] is executed after batches[i] in - # backpropagation by an explicit dependency. - if i != 0 and j != 0: - _depend(batches[i - 1], batches[i]) - - next_stream = copy_streams[j][i] - - for prev_j, ns, name in skip_layout.copy_policy(j): - prev_stream = copy_streams[prev_j][i] - skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name) - - if j != 0: - prev_stream = copy_streams[j - 1][i] - _copy(batches[i], prev_stream, next_stream) - - def compute( - self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], - ) -> None: - """Run tasks with synchronization to copy streams.""" - partitions = self.partitions - devices = self.devices - copy_streams = self.copy_streams - checkpoint_stop = self.checkpoint_stop - - # Disable checkpointing if in eval mode. - if not self.partitions[0].training: - checkpoint_stop = 0 - - n = len(partitions) - streams = [current_stream(d) for d in devices] - exc_info: Optional[ExcInfo] = None - - # With checkpointing, the autograd graph looks like this diagram: - # +-----+------+ - # | Copy | - # +-----+------+ (fence) - # - - - + - - - - - - - - - - # | (compute) - # +-----+------+ - # | Wait | [1] Synchronize the current stream with the copy stream. - # +-----+------+ - # +-----+------+ - # | Checkpoint | [2] Compute a partition within checkpointing. - # +-----+------+ - # +-----+------+ - # | Wait | [3] Synchronize the copy stream with the current stream. - # +-----+------+ - # + - - - + - # | +-----+-----+ - # | | Recompute | [4] Schedule the recomputation at backpropagation. 
- # | +-----+-----+ - # + - - - + - # | - # - - - + - - - - - - - - - - # +-----+------+ (fence) - # | Copy | - # +-----+------+ - for i, j in schedule: - batch = batches[i] - partition = partitions[j] - - # Synchronize with the copied input. ([1] in the diagram) - if j != 0: - _wait(batch, copy_streams[j][i], streams[j]) - - # Determine whether checkpointing or not. - checkpoint = i < checkpoint_stop - if checkpoint: - - def function( - *inputs, - partition: nn.Module = partition, - skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], - chunk_id: int = i, - part_id: int = j, - ) -> TensorOrTensors: - with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): - return partition(*inputs) - - chk = Checkpointing(function, batch) # type: ignore[arg-type] - task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) - del function, chk - - else: - - def compute( - batch: Batch = batch, - partition: nn.Module = partition, - skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], - chunk_id: int = i, - part_id: int = j, - ) -> Batch: - with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): - return batch.call(partition) - - task = Task(streams[j], compute=compute, finalize=None) - del compute - - # Compute tasks in parallel. ([2] in the diagram) - self.in_queues[j].put(task) - - for i, j in schedule: - ok, payload = self.out_queues[j].get() - - # Hold the first exception. - if exc_info is not None: - continue - elif not ok: - exc_info = cast(ExcInfo, payload) - continue - - task, batch = cast(Tuple[Task, Batch], payload) - - # The copy stream synchronizes to copy the output. ([3] in the - # diagram) - if j != n - 1: - _wait(batch, streams[j], copy_streams[j][i]) - - # Finalize tasks. If checkpointing is enabled, here the - # recomputation is scheduled at backpropagation. ([4] in the - # diagram) - with use_device(devices[j]): - task.finalize(batch) - - batches[i] = batch - - # Fail at the first exception. - if exc_info is not None: - raise exc_info[0].with_traceback(exc_info[1], exc_info[2]) diff --git a/torch/distributed/pipeline/sync/py.typed b/torch/distributed/pipeline/sync/py.typed deleted file mode 100644 index ab03724cafbf5..0000000000000 --- a/torch/distributed/pipeline/sync/py.typed +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torch/distributed/pipeline/sync/skip/__init__.py b/torch/distributed/pipeline/sync/skip/__init__.py deleted file mode 100644 index bdcb913867a73..0000000000000 --- a/torch/distributed/pipeline/sync/skip/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
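The `_clock_cycles` generator above is the entire pipeline schedule: at clock tick `k`, micro-batch `i` runs on partition `j` whenever `i + j == k`, and `fence()`/`compute()` simply iterate over those pairs. A self-contained reproduction (no dependency on the removed package) that prints the same table as the comment:

```python
from typing import Iterable, List, Tuple

def clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
    # m micro-batches, n partitions; yields [(i, j), ...] for each clock tick k,
    # covering every pair with i + j == k (same logic as the deleted _clock_cycles)
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]

for k, schedule in enumerate(clock_cycles(3, 3)):
    print(k, schedule)
# 0 [(0, 0)]
# 1 [(1, 0), (0, 1)]
# 2 [(2, 0), (1, 1), (0, 2)]
# 3 [(2, 1), (1, 2)]
# 4 [(2, 2)]
```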
-"""Supports efficiency with skip connections.""" -from .namespace import Namespace -from .skippable import pop, skippable, stash, verify_skippables - -__all__ = ["skippable", "stash", "pop", "verify_skippables", "Namespace"] diff --git a/torch/distributed/pipeline/sync/skip/layout.py b/torch/distributed/pipeline/sync/skip/layout.py deleted file mode 100644 index 04d76d34ea166..0000000000000 --- a/torch/distributed/pipeline/sync/skip/layout.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Static skip connection layout of ``@skippable`` modules.""" -from typing import Dict, Iterable, List, Tuple - -from torch import nn - -from .namespace import Namespace - -__all__: List[str] = [] - - -class SkipLayout: - """Represents a skip connection layout across partitions.""" - - # Skip routes indexed by 'ns, name': {(ns, name): (prev_j, next_j), ...} - by_ns_name: Dict[Tuple[Namespace, str], Tuple[int, int]] - - # Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...] - by_partition: List[List[Tuple[int, Namespace, str]]] - - def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None: - # The skip routes are already indexed by 'ns, name'. - self.by_ns_name = skip_routes - - # Index skip routes by partition number 'j'. - self.by_partition = [[] for _ in range(num_partitions)] - - for (ns, name), (prev_j, next_j) in skip_routes.items(): - self.by_partition[next_j].append((prev_j, ns, name)) - - for p in self.by_partition: - p.sort() - - def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]: - """Generates skip routes for the given destination partition number. - The skip routes are sorted by source partition number in ascending - order. - - Yields: - Each tuple of (source partition number, namespace, name). - - """ - for prev_j, ns, name in self.by_partition[next_j]: - if prev_j == next_j: - # This skip tensor will be popped at the same partition where - # it is stashed. In this case, copy is not required. - continue - - yield (prev_j, ns, name) - - def requires_copy(self, ns: Namespace, name: str) -> bool: - """Whether the given namespace and name requires partition-to-partition - copy or not. - """ - prev_j, next_j = self.by_ns_name.get((ns, name), (-1, -1)) - return prev_j != next_j - - -def inspect_skip_layout(partitions: List[nn.Sequential]) -> SkipLayout: - """Inspects the skip connection layout in the given partitions.""" - # NOTE(sublee): Hide circular import inside this subroutine. Circular - # import is not ideal but placing this logic near to SkipLayout may - # increase cohesion of code. 
- from .skippable import Skippable - - skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]] = {} - stashed_at: Dict[Tuple[Namespace, str], int] = {} - - for j, partition in enumerate(partitions): - def inspect_layer(layer): - if not isinstance(layer, Skippable): - return - - for ns, name in layer.stashable(): - stashed_at[(ns, name)] = j - - for ns, name in layer.poppable(): - prev_j = stashed_at.pop((ns, name)) - skip_routes[(ns, name)] = (prev_j, j) - - if isinstance(partition, nn.Sequential): - for layer in partition: - inspect_layer(layer) - else: - inspect_layer(partition) - - return SkipLayout(len(partitions), skip_routes) diff --git a/torch/distributed/pipeline/sync/skip/namespace.py b/torch/distributed/pipeline/sync/skip/namespace.py deleted file mode 100644 index 7d9c0d9b7d842..0000000000000 --- a/torch/distributed/pipeline/sync/skip/namespace.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Provides isolated namespace of skip tensors.""" -import abc -from functools import total_ordering -from typing import Any -import uuid - -__all__ = ["Namespace"] - - -@total_ordering -class Namespace(metaclass=abc.ABCMeta): # noqa: B024 - """Namespace for isolating skip tensors used by :meth:`isolate() - `. - """ - - __slots__ = ("id",) - - def __init__(self) -> None: - self.id = uuid.uuid4() - - def __repr__(self) -> str: - return f"" - - def __hash__(self) -> int: - return hash(self.id) - - # Namespaces should support ordering, since SkipLayout will sort tuples - # including a namespace. But actual order between namespaces is not - # important. That's why they are ordered by version 4 UUID which generates - # random numbers. - def __lt__(self, other: Any) -> bool: - if isinstance(other, Namespace): - return self.id < other.id - return False - - def __eq__(self, other: object) -> bool: - if isinstance(other, Namespace): - return self.id == other.id - return False - - -# 'None' is the default namespace, -# which means that 'isinstance(None, Namespace)' is 'True'. -Namespace.register(type(None)) diff --git a/torch/distributed/pipeline/sync/skip/portal.py b/torch/distributed/pipeline/sync/skip/portal.py deleted file mode 100644 index 335793f4cc137..0000000000000 --- a/torch/distributed/pipeline/sync/skip/portal.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the -autograd engine. The shared context of three functions (:class:`PortalBlue`, -:class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is -one of the most important feature of :mod:`torchpipe.skip`. - -The metaphor is inspired by Portal(tm) from Valve. 
- -""" -from typing import List, Optional, Tuple - -import torch -from torch import Tensor - -from ..copy import Context as CopyContext -from ..copy import Copy -from ..phony import get_phony -from ..stream import AbstractStream, get_device - -__all__: List[str] = [] - - -class Portal: - """A portal for a tensor.""" - - def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None: - self.put_tensor(tensor, tensor_life) - self.grad: Optional[Tensor] = None - - def blue(self) -> Tensor: - """Creates a :class:`PortalBlue` which hides the underlying tensor from - the autograd engine. - - Join the returning phony to the main lane of the autograd graph to - assure the correct backpropagation:: - - PortalBlue --+ - | - ---------- Join -- - - """ - tensor = self.use_tensor() - - if tensor is None: - return get_phony(torch.device("cpu"), requires_grad=False) - - return PortalBlue.apply(self, tensor) - - def orange(self, phony: Tensor) -> Optional[Tensor]: - """Creates a :class:`PortalOrange` which retrieves the hidden tensor - without losing ability of backpropagation. - - Give a phony forked from the main lane of an autograd graph:: - - +-- PortalOrange --+ - | | - -- Fork --------- f(a, b) -- - - """ - self.check_tensor_life() - - if self.tensor is None: - return self.use_tensor() - - return PortalOrange.apply(self, phony) - - def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,) -> Tensor: - """Copies the hidden tensor by a :class:`PortalCopy`. - - Give a phony and use the returning phony to keep backpropagation:: - - +-- PortalCopy --+ - | | - -- Fork ---------- Join -- - - """ - if self.tensor is None: - return get_phony(torch.device("cpu"), requires_grad=False) - - return PortalCopy.apply(self, prev_stream, next_stream, phony) - - def check_tensor_life(self) -> None: - if self.tensor_life <= 0: - raise RuntimeError("tensor in portal has been removed") - - def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None: - """Stores a tensor into this portal.""" - # [Life of Tensor through Portal] - # - # The tensor can be retrieved by use_tensor() up to 'tensor_life' - # times. When the life becomes 0, the tensor will be deleted for - # deallocation in CUDA memory. - # - # The below events participate in a tensor through a portal. - # Note that [x] denotes the events which call use_tensor(): - # - # 1. [x] blue() - # 2. [ ] PortalBlue.forward - # 3. [ ] copy() - # 4. [ ] PortalCopy.forward - # 5. [ ] orange() - # 6. [x] PortalOrange.forward - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 7. [ ] orange() (recomputed) - # 8. [x] PortalOrange.forward (recomputed) - # 9. [ ] PortalOrange.backward - # 10. [ ] PortalCopy.backward - # 11. [x] blue() (recomputed) - # 12. [ ] PortalBlue.forward (recomputed) - # 13. [ ] PortalBlue.backward - # - self.tensor_life = tensor_life - - if tensor_life > 0: - self.tensor = tensor - else: - self.tensor = None - - def use_tensor(self) -> Optional[Tensor]: - """Retrieves the underlying tensor and decreases the tensor life. When - the life becomes 0, it the tensor will be removed. - """ - self.check_tensor_life() - - tensor = self.tensor - - self.tensor_life -= 1 - - if self.tensor_life <= 0: - self.tensor = None - - return tensor - - def put_grad(self, grad: Tensor) -> None: - """Stores a gradient into this portal.""" - self.grad = grad - - def use_grad(self) -> Tensor: - """Retrieves and removes the underlying gradient. The gradient is - always ephemeral. 
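The blue/orange pair above hides a tensor from autograd on the forward path and smuggles its gradient back through the shared `Portal` object on the backward path. The snippet below is not the removed implementation, just a minimal reconstruction of that idea with hypothetical `Hide`/`Reveal` functions and a bare holder object, so the trick can be studied in isolation:

```python
import torch

class Pocket:
    """Hypothetical stand-in for Portal: holds a tensor and, later, its gradient."""
    def __init__(self):
        self.tensor = None
        self.grad = None

class Hide(torch.autograd.Function):
    # analogous to PortalBlue: stash the tensor, emit a 0-element phony
    @staticmethod
    def forward(ctx, pocket, tensor):
        ctx.pocket = pocket
        pocket.tensor = tensor
        return torch.empty(0, device=tensor.device)

    @staticmethod
    def backward(ctx, grad_phony):
        # hand back the gradient that Reveal.backward deposited earlier
        return None, ctx.pocket.grad

class Reveal(torch.autograd.Function):
    # analogous to PortalOrange: trade the phony for the hidden tensor
    @staticmethod
    def forward(ctx, pocket, phony):
        ctx.pocket = pocket
        return pocket.tensor.detach()

    @staticmethod
    def backward(ctx, grad):
        ctx.pocket.grad = grad
        return None, None

x = torch.randn(3, requires_grad=True)
pocket = Pocket()
phony = Hide.apply(pocket, x)    # x is invisible to autograd from here on
y = Reveal.apply(pocket, phony)  # ...and visible again here
(y * 2).sum().backward()
print(x.grad)                    # tensor([2., 2., 2.])
```

The real `Portal` additionally ref-counts the tensor's lifetime (`tensor_life`) and supports a stream-to-stream copy step, which this sketch omits.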
- """ - if self.grad is None: - raise RuntimeError("grad in portal has been removed or never set") - - grad = self.grad - self.grad = None - return grad - - -# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and -# :class:`PortalCopy`. -class Context(CopyContext): - portal: Portal - - -class PortalBlue(torch.autograd.Function): - """Hides a tensor from the autograd engine by a :class:`Portal`.""" - - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, - portal: Portal, - # This tensor must be retrieved by portal.use_tensor(). - tensor: Tensor, - ) -> Tensor: - ctx.portal = portal - - phony = get_phony(tensor.device, requires_grad=False) - return phony.detach() - - @staticmethod - # type: ignore[override] - def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]: - # The paired PortalOrange should keep the gradient. - grad = ctx.portal.use_grad() - return None, grad - - -class PortalOrange(torch.autograd.Function): - """Retrieves the hidden tensor from a :class:`Portal`.""" - - @staticmethod - # type: ignore[override] - def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor: - ctx.portal = portal - - tensor = portal.use_tensor() - assert tensor is not None - - return tensor.detach() - - @staticmethod - def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]: # type: ignore[override] - # The paired PortalBlue will use the gradient. - ctx.portal.put_grad(grad) - return None, None - - -class PortalCopy(torch.autograd.Function): - """Copies the hidden tensor in a :class:`Portal`. It replaces the hidden - tensor with copied one. - """ - - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor, - ) -> Tensor: - ctx.portal = portal - - assert portal.tensor is not None - (portal.tensor,) = Copy.forward(ctx, prev_stream, next_stream, portal.tensor) - - phony = get_phony(get_device(next_stream), requires_grad=False) - return phony.detach() - - @staticmethod - # type: ignore[override] - def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]: - portal = ctx.portal - - assert portal.grad is not None - _, _, portal.grad = Copy.backward(ctx, portal.grad) - - return None, None, None, None diff --git a/torch/distributed/pipeline/sync/skip/skippable.py b/torch/distributed/pipeline/sync/skip/skippable.py deleted file mode 100644 index 9d4db76c6b670..0000000000000 --- a/torch/distributed/pipeline/sync/skip/skippable.py +++ /dev/null @@ -1,431 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
-"""The user interface to define skip connections.""" -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Dict, - FrozenSet, - Generator, - Iterable, - List, - Optional, - Set, - Sequence, - Tuple, - Type, - TypeVar, - Union, - cast, -) - -from torch import Tensor, nn - -from ..microbatch import Batch -from .namespace import Namespace -from .tracker import current_skip_tracker - -__all__ = ["skippable", "stash", "pop", "verify_skippables"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -StashPop = Union["stash", "pop"] -StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors] -if TYPE_CHECKING: - # Typechecking: nn.Module is not a Generic - SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]] # type: ignore[type-arg] -else: - SkippableModule = nn.Module - -T = TypeVar("T", bound="Skippable") - - -class Skippable(nn.Module): - """The base class for skippable modules. - - Do not use this class directly. Define a subclass by :func:`skippable` - instead. - - """ - - module_cls: ClassVar[Type[SkippableModule]] - stashable_names: ClassVar[FrozenSet[str]] - poppable_names: ClassVar[FrozenSet[str]] - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__() - self.module = self.module_cls(*args, **kwargs) # type: ignore[call-arg] - self.namespaces: Dict[str, Namespace] = {} - - def __repr__(self) -> str: - return f"@skippable({self.module})" - - def namespaced(self, name: str) -> Tuple[Namespace, str]: - """Prepend namespace for the given skip name.""" - ns = self.namespaces.get(name) - ns = cast(Namespace, ns) - return (ns, name) - - def stashable(self) -> Iterable[Tuple[Namespace, str]]: - """Iterate over namespaced skip names to be stashed.""" - for name in self.stashable_names: - yield self.namespaced(name) - - def poppable(self) -> Iterable[Tuple[Namespace, str]]: - """Iterate over namespaced skip names to be popped.""" - for name in self.poppable_names: - yield self.namespaced(name) - - def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T: - r"""Isolate a specified subset or the whole set of skip tensors. - - In a single sequential module, skip tensors with the same - name are not allowed unless they are isolated by different namespaces. - - Here's an example using the same name for skip tensors twice. Each pair - of ``Layer1`` and ``Layer2`` is isolated with its own namespace ``ns1`` - and ``ns2``. There is no conflict anymore:: - - ns1 = Namespace() - ns2 = Namespace() - - model = nn.Sequential( - Layer1().isolate(ns1), - Layer1().isolate(ns2), - Layer2(), - Layer3().isolate(ns2), - Layer3().isolate(ns1), - ) - - When `only` parameter is omitted, all skip tensors are isolated. You - can isolate a subset of skip tensors by passing `only` parameter:: - - ns_alice = Namespace() - ns_bob = Namespace() - - model = nn.Sequential( - ... - StashStashPop().isolate(ns_alice, only=['alice']) \ - .isolate(ns_bob, only=['bob']), - ... 
- ) - - Args: - ns (Namespace): - namespace for isolation - - Keyword Args: - only (iterable of strs): - names of specific skip tensors to be isolated (omit this option - to isolate all skip tensors declared in this module) - - Returns: - this module itself - - """ - names: Iterable[str] - - if only is None: - names = self.stashable_names | self.poppable_names - else: - names = set(only) - - for name in names: - self.namespaces[name] = ns - - return self - - def dispatch( - self, - input, - handle_stash: Callable[[str, Optional[Tensor]], None], - handle_pop: Callable[[str], Optional[Tensor]], - ): - """Dispatch :class:`stash` or :class:`pop` commands. - - The commands are generated by the module's ``forward()``. - """ - generator = self.module(input) - - if not isinstance(generator, Generator): - # The underlying module returned output without any yield. - output = generator - return output - - try: - op = next(generator) - - while True: - if isinstance(op, stash): - handle_stash(op.name, op.tensor) - op = next(generator) - continue - - if isinstance(op, pop): - tensor = handle_pop(op.name) - op = generator.send(tensor) - continue - - raise TypeError(f"{op!r} is not a command from @skippable") - - except StopIteration as stop: - output = stop.args[0] - return output - - def forward(self, input: Union[List[Any], Tensor]) -> TensorOrTensors: - """Perform the forward propagation. - - :class:`stash` or :class:`pop` commands will be handled by portals - silently. The portals won't be exposed to users. - - Raises: - RuntimeError: - illegal 'stash' or 'pop' is found. - - """ - skip_tracker = current_skip_tracker() - stashed_tensors: Dict[str, Optional[Tensor]] = {} - - # Load skip tensors that might be popped. - poppable_tensors = {} - batch = Batch(input) - for ns, name in self.poppable(): - try: - poppable_tensors[name] = skip_tracker.load(batch, ns, name) - except KeyError as e: - raise RuntimeError(f"'{name}' has not been stashed") from e - input = batch.values - - # Handle skip commands. - def handle_stash(name: str, tensor: Optional[Tensor]) -> None: - if name not in self.stashable_names: - raise RuntimeError(f"'{name}' has not been declared as stashable") - stashed_tensors[name] = tensor - - def handle_pop(name: str) -> Optional[Tensor]: - if name not in self.poppable_names: - raise RuntimeError(f"'{name}' has not been declared as poppable") - return poppable_tensors.pop(name) - - output = self.dispatch(input, handle_stash, handle_pop) - - # All declared skips must be stashed or popped. - not_stashed = self.stashable_names - stashed_tensors.keys() - if not_stashed: - comma_names = ", ".join(f"'{n}'" for n in not_stashed) - raise RuntimeError(f"{comma_names} must be stashed but have not") - - not_popped = poppable_tensors.keys() - if not_popped: - comma_names = ", ".join(f"'{n}'" for n in not_popped) - raise RuntimeError(f"{comma_names} must be popped but have not") - - # Save stashed skip tensors. - batch = Batch(output) - for ns, name in self.stashable(): - tensor = stashed_tensors[name] - skip_tracker.save(batch, ns, name, tensor) - output = batch.values - - return output - - -# TODO(sublee): Move to above of Skippable class for better read flow. -def skippable( - stash: Iterable[str] = (), pop: Iterable[str] = (), -) -> Callable[[Type[SkippableModule]], Type[Skippable]]: - """Define a decorator to create :class:`nn.Module ` with skip connections. - - These decorated modules are called "skippable". 
This functionality works perfectly - fine even when the module is not wrapped by :class:`~torch.distributed.pipeline.sync.Pipe`. - - Each skip tensor is managed by its name. Before manipulating skip tensors, - a skippable module must statically declare the names for skip tensors by - `stash` and/or `pop` parameters. Skip tensors with pre-declared name can be - stashed by ``yield stash(name, tensor)`` or popped by ``tensor = yield - pop(name)``. - - Here is an example with three layers. A skip tensor named "1to3" is stashed - and popped at the first and last layer, respectively:: - - @skippable(stash=['1to3']) - class Layer1(nn.Module): - def forward(self, input): - yield stash('1to3', input) - return f1(input) - - class Layer2(nn.Module): - def forward(self, input): - return f2(input) - - @skippable(pop=['1to3']) - class Layer3(nn.Module): - def forward(self, input): - skip_1to3 = yield pop('1to3') - return f3(input) + skip_1to3 - - model = nn.Sequential(Layer1(), Layer2(), Layer3()) - - One skippable module can stash or pop multiple skip tensors:: - - @skippable(stash=['alice', 'bob'], pop=['carol']) - class StashStashPop(nn.Module): - def forward(self, input): - yield stash('alice', f_alice(input)) - yield stash('bob', f_bob(input)) - carol = yield pop('carol') - return input + carol - - Every skip tensor must be associated with exactly one pair of `stash` and - `pop`. :class:`~torch.distributed.pipeline.sync.Pipe` checks this - restriction automatically when wrapping a module. You can also check the - restriction by :func:`verify_skippables` - without :class:`~torch.distributed.pipeline.sync.Pipe`. - - """ - stashable_names = frozenset(stash) - poppable_names = frozenset(pop) - - def extend_skippable(module_cls: Type[SkippableModule]) -> Type[Skippable]: - name = module_cls.__name__ - bases = (Skippable,) - attrs = {"module_cls": module_cls, "stashable_names": stashable_names, "poppable_names": poppable_names} - return type(name, bases, attrs) - - return extend_skippable - - -class stash: - """The command to stash a skip tensor. - - :: - - def forward(self, input): - yield stash('name', input) - return f(input) - - Args: - name (str): name of skip tensor - input (torch.Tensor or None): tensor to pass to the skip connection - - """ - - __slots__ = ("name", "tensor") - - def __init__(self, name: str, tensor: Optional[Tensor]) -> None: - self.name = name - self.tensor = tensor - - -class pop: - """The command to pop a skip tensor. - - :: - - def forward(self, input): - skip = yield pop('name') - return f(input) + skip - - Args: - name (str): name of skip tensor - - Returns: - the skip tensor previously stashed by another layer under the same name - - """ - - __slots__ = ("name",) - - def __init__(self, name: str) -> None: - self.name = name - - -def verify_skippables(module: nn.Sequential) -> None: - """Verify if the underlying skippable modules satisfy integrity. - - Every skip tensor must have only one pair of `stash` and `pop`. If there - are one or more unmatched pairs, it will raise :exc:`TypeError` with the - detailed messages. - - Here are a few failure cases. :func:`verify_skippables` will report failure - for these cases:: - - # Layer1 stashes "1to3". - # Layer3 pops "1to3". - - nn.Sequential(Layer1(), Layer2()) - # +---- ? - - nn.Sequential(Layer2(), Layer3()) - # ? 
----+ - - nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3()) - # +-------------------+ ^^^^^^ - - nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3()) - # ^^^^^^ +-------------------+ - - To use the same name for multiple skip tensors, they must be isolated by - different namespaces. See :meth:`isolate() - `. - - Raises: - TypeError: - one or more pairs of `stash` and `pop` are not matched. - - """ - stashed: Set[Tuple[Namespace, str]] = set() - popped: Set[Tuple[Namespace, str]] = set() - msgs: List[str] = [] - - for layer_name, layer in module.named_children(): - if not isinstance(layer, Skippable): - continue - - for name in layer.stashable_names & layer.poppable_names: - msg = f"'{layer_name}' declared '{name}' both as stashable and as poppable" - msgs.append(msg) - - for ns, name in layer.stashable(): - if name in layer.poppable_names: - continue - - if (ns, name) in stashed: - msg = f"'{layer_name}' redeclared '{name}' as stashable but not isolated by namespace" - msgs.append(msg) - continue - - stashed.add((ns, name)) - - for ns, name in layer.poppable(): - if name in layer.stashable_names: - continue - - if (ns, name) in popped: - msg = f"'{layer_name}' redeclared '{name}' as poppable but not isolated by namespace" - msgs.append(msg) - continue - - if (ns, name) not in stashed: - msg = f"'{layer_name}' declared '{name}' as poppable but it was not stashed" - msgs.append(msg) - continue - - popped.add((ns, name)) - - for (_, name) in stashed - popped: - msg = f"no module declared '{name}' as poppable but stashed" - msgs.append(msg) - - if msgs: - raise TypeError( - "one or more pairs of stash and pop do not match:\n\n{}" "".format("\n".join(f"* {x}" for x in msgs)) - ) diff --git a/torch/distributed/pipeline/sync/skip/tracker.py b/torch/distributed/pipeline/sync/skip/tracker.py deleted file mode 100644 index 8ac82bc05dc94..0000000000000 --- a/torch/distributed/pipeline/sync/skip/tracker.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Tracks skip tensors on a thread.""" -from contextlib import contextmanager -import threading -from typing import Dict, Generator, List, Optional, Tuple - -from torch import Tensor - -from ..checkpoint import is_checkpointing -from ..dependency import fork, join -from ..microbatch import Batch -from ..stream import AbstractStream -from .layout import SkipLayout -from .namespace import Namespace -from .portal import Portal - -__all__: List[str] = [] - - -class SkipTracker: - """Tracks saved skip tensors. - - It will update the given micro-batch in place. This is because when it - manipulates the underlying skip tensors, the current micro-batch also has - to be connected with the skip tensors. - - One thread has one skip tracker. Call :func:`current_skip_tracker` to get - the skip tracker on the current thread. 
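Since the per-thread default `SkipTracker` described above is what backs `stash`/`pop` when no `Pipe` is involved, a skippable module can be exercised directly. A small usage sketch of the decorator deleted in this file (again, it only runs on a PyTorch build that still ships `torch.distributed.pipeline.sync`):

```python
import torch
import torch.nn as nn
from torch.distributed.pipeline.sync.skip import skippable, stash, pop  # removed by this diff

@skippable(stash=["residual"])
class Stash(nn.Module):
    def forward(self, x):
        yield stash("residual", x)      # hand x to whoever pops "residual" later
        return torch.relu(x)

@skippable(pop=["residual"])
class PopAdd(nn.Module):
    def forward(self, x):
        residual = yield pop("residual")
        return x + residual             # long skip connection around the middle layer

model = nn.Sequential(Stash(), nn.Linear(4, 4), PopAdd())
out = model(torch.randn(2, 4))
print(out.shape)  # torch.Size([2, 4])
```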
- - """ - - def __init__(self) -> None: - self.tensors: Dict[Tuple[Namespace, str], Optional[Tensor]] = {} - - def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None: - self.tensors[(ns, name)] = tensor - - def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]: - return self.tensors.pop((ns, name)) - - def copy( - self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str, - ) -> None: - raise TypeError("copy is not supported for non-portal skip tensors") - - -class SkipTrackerThroughPotals(SkipTracker): - """Tracks saved skip tensors through portals. The skip tensors will be - hidden in portals so that the autograd engine does not need to track them. - - This tracker is only used when the training or evaluating module is wrapped - with :class:`torchpipe.Pipe`. - - """ - - def __init__(self, skip_layout: SkipLayout) -> None: - super().__init__() - self.skip_layout = skip_layout - self.portals: Dict[Tuple[Namespace, str], Portal] = {} - - def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None: - """Saves the stashed skip tensor in a portal. The portal is then - connected to the given micro-batch with :class:`Join`. - """ - if not self.skip_layout.requires_copy(ns, name): - super().save(batch, ns, name, tensor) - return - - # See [Tensor Life of Portal] at Portal.put_tensor() to understand the - # below tensor_life values. Here are the selected events which retrieve - # the tensor in portal: - # - # 1. [x] blue() - # ... - # 6. [x] PortalOrange.forward - # ... - # 8. [x] PortalOrange.forward (recomputed) - # ... - # 11. [x] blue() (recomputed) - # - if (ns, name) not in self.portals: - if is_checkpointing(): - # Under checkpointing, the tensor used by the first - # PortalOrange should be alive in the portal. This tensor will - # be used again by the second PortalOrange during the - # recomputation. - tensor_life = 3 # Delete at [8. PortalOrange.forward (recomputed)] - else: - tensor_life = 2 # Delete at [6. PortalOrange.forward] - - portal = Portal(tensor, tensor_life) - self.portals[(ns, name)] = portal - - else: - # Under recomputation, the portal already exists. - portal = self.portals[(ns, name)] - - # The existing tensor life already became 0. It should be reset as - # 1 to delete the tensor after the second PortalBlue immediately. - tensor_life = 1 # Delete at [11. blue() (recomputed)] - - portal.put_tensor(tensor, tensor_life) - - phony = portal.blue() - tensor_idx = batch.find_tensor_idx() - batch[tensor_idx] = join(batch[tensor_idx], phony) - - def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]: - """Loads a skip tensor from the corresponding portal to pop. The given - micro-batch is connected to the portal with :class:`Fork`. - """ - if not self.skip_layout.requires_copy(ns, name): - tensor = super().load(batch, ns, name) - return tensor - - portal = self.portals[(ns, name)] - tensor_idx = batch.find_tensor_idx() - batch[tensor_idx], phony = fork(batch[tensor_idx]) - tensor = portal.orange(phony) - return tensor - - def copy( - self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str, - ) -> None: - """Copies the skip tensor in the corresponding portal. The given - micro-batch and the portal will be tied with :class:`Fork` and - :class:`Join`. 
- """ - assert self.skip_layout.requires_copy(ns, name) - - tensor_idx = batch.find_tensor_idx() - batch[tensor_idx], phony = fork(batch[tensor_idx]) - - portal = self.portals[(ns, name)] - phony = portal.copy(prev_stream, next_stream, phony) - - batch[tensor_idx] = join(batch[tensor_idx], phony) - - -class ThreadLocal(threading.local): - def __init__(self) -> None: - self.skip_tracker: Optional[SkipTracker] = None - - -thread_local = ThreadLocal() - - -@contextmanager -def use_skip_tracker(skip_tracker: SkipTracker) -> Generator[None, None, None]: - """Registers the given skip tracker on the current thread within a - context:: - - with use_skip_tracker(my_skip_tracker): - ... - - """ - orig = thread_local.skip_tracker - - thread_local.skip_tracker = skip_tracker - - try: - yield - finally: - thread_local.skip_tracker = orig - - -def current_skip_tracker() -> SkipTracker: - """Gets the skip tracker on the current thread.""" - skip_tracker = thread_local.skip_tracker - - if skip_tracker is None: - skip_tracker = SkipTracker() - thread_local.skip_tracker = skip_tracker - - return skip_tracker diff --git a/torch/distributed/pipeline/sync/stream.py b/torch/distributed/pipeline/sync/stream.py deleted file mode 100644 index 59fedf865a42b..0000000000000 --- a/torch/distributed/pipeline/sync/stream.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Utilities for eliminating boilerplate code to handle abstract streams with -CPU device. -""" -from contextlib import contextmanager -from typing import Generator, List, Union, cast - -import torch - -__all__: List[str] = ["CPUStreamType", "new_stream", "current_stream", "default_stream", - "use_device", "use_stream", "get_device", "wait_stream", "record_stream", - "is_cuda", "as_cuda"] - - -class CPUStreamType: - pass - - -# The placeholder on place of streams for the CPU device instead of CUDA. -CPUStream = CPUStreamType() - -# It represents both CUDA streams and the CPU stream. 
-AbstractStream = Union[torch.cuda.Stream, CPUStreamType] - - -def new_stream(device: torch.device) -> AbstractStream: - """Creates a new stream for either CPU or CUDA device.""" - if device.type != "cuda": - return CPUStream - return torch.cuda.Stream(device) - - -def current_stream(device: torch.device) -> AbstractStream: - """:func:`torch.cuda.current_stream` for either CPU or CUDA device.""" - if device.type != "cuda": - return CPUStream - return torch.cuda.current_stream(device) - - -def default_stream(device: torch.device) -> AbstractStream: - """:func:`torch.cuda.default_stream` for either CPU or CUDA device.""" - if device.type != "cuda": - return CPUStream - return torch.cuda.default_stream(device) - - -@contextmanager -def use_device(device: torch.device) -> Generator[None, None, None]: - """:func:`torch.cuda.device` for either CPU or CUDA device.""" - if device.type != "cuda": - yield - return - - with torch.cuda.device(device): - yield - - -@contextmanager -def use_stream(stream: AbstractStream) -> Generator[None, None, None]: - """:func:`torch.cuda.stream` for either CPU or CUDA stream.""" - if not is_cuda(stream): - yield - return - - with torch.cuda.stream(as_cuda(stream)): - yield - - -def get_device(stream: AbstractStream) -> torch.device: - """Gets the device from CPU or CUDA stream.""" - if is_cuda(stream): - return as_cuda(stream).device - return torch.device("cpu") - - -def wait_stream(source: AbstractStream, target: AbstractStream) -> None: - """:meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It - makes the source stream wait until the target stream completes work queued. - """ - if is_cuda(target): - if is_cuda(source): - # A CUDA stream waits another CUDA stream. - as_cuda(source).wait_stream(as_cuda(target)) - else: - # CPU waits a CUDA stream. - as_cuda(target).synchronize() - - # If the target is CPU, synchronization is not required. - - -def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None: - """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream.""" - if is_cuda(stream): - # NOTE(sublee): record_stream() on a shifted view tensor throws - # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely - # protect the tensor against unexpected reallocation, here we use a - # temporal tensor associated with the same storage without shifting as - # a workaround. - # - # Issue: https://github.com/pytorch/pytorch/issues/27366 - # - tensor = tensor.new_empty([0]).set_(tensor._typed_storage()) - - # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream - tensor.record_stream(as_cuda(stream)) # type: ignore[arg-type] - - -def is_cuda(stream: AbstractStream) -> bool: - """Returns ``True`` if the given stream is a valid CUDA stream.""" - return stream is not CPUStream - - -def as_cuda(stream: AbstractStream) -> torch.cuda.Stream: - """Casts the given stream as :class:`torch.cuda.Stream`.""" - return cast(torch.cuda.Stream, stream) diff --git a/torch/distributed/pipeline/sync/utils.py b/torch/distributed/pipeline/sync/utils.py deleted file mode 100644 index 210c475317e2c..0000000000000 --- a/torch/distributed/pipeline/sync/utils.py +++ /dev/null @@ -1,38 +0,0 @@ -from torch import nn -from typing import List, Optional - -__all__ = ["partition_model"] - -def partition_model( - module: nn.Sequential, - balance: List[int], - devices: Optional[List[int]] = None): - """ - Partions the model accross multiple GPU devices. 
- - Given an :class:`nn.Sequential ` module, partitions - the model across multiple GPU devices according the provided ``balance`` - and ``devices``. - - Args: - module (:class:`nn.Sequential `): - Sequential model representing the pipe. - balance (List[int]): - List indicating the number of layers in each partition. - devices (List[int], optional): - List indicating the device to use for each partition. Defaults to - ``range(len(balance))`` - """ - device_idx = 0 - pipe_idx = 0 - balanced_pipe = [] - for num_layers in balance: - layers = [] - for i in range(num_layers): - layers.append(module[pipe_idx]) - pipe_idx += 1 - device = device_idx if devices is None else devices[device_idx] - balanced_pipe.append(nn.Sequential(*layers).to(device)) - device_idx += 1 - - return nn.Sequential(*balanced_pipe) diff --git a/torch/distributed/pipeline/sync/worker.py b/torch/distributed/pipeline/sync/worker.py deleted file mode 100644 index 87b20c4a55519..0000000000000 --- a/torch/distributed/pipeline/sync/worker.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Multithreading in pipeline parallelism.""" -from contextlib import contextmanager -from queue import Queue -import sys -from threading import Thread -from types import TracebackType -from typing import TYPE_CHECKING, Callable, Dict, Generator, List, Optional, Tuple, Type, Union, cast - -import torch - -from .microbatch import Batch -from .stream import AbstractStream, use_device, use_stream - -__all__: List[str] = ["Task", "worker", "create_workers", "spawn_workers"] - - -ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] - -# Queue is generic only in stubs. -# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime -if TYPE_CHECKING: - InQueue = Queue[Optional["Task"]] - OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] -else: - InQueue = Queue - OutQueue = Queue - - -class Task: - """A task represents how to compute a micro-batch on a partition. - - It consists of two parts: :meth:`compute` and :meth:`finalize`. - :meth:`compute` should be executed in worker threads concurrently. - :meth:`finalize` should be executed after when worker threads complete to - execute :meth:`compute`. - - :meth:`compute` might be boosted by worker threads. Because it produces - several CUDA API calls by user code. In PyTorch, parallel CUDA API calls - are not serialized through GIL. So more than one CUDA API call can be - produced at the same time. 
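`partition_model` above slices a flat `nn.Sequential` into per-device partitions according to `balance`. The snippet below reproduces only the slicing logic (the real helper also moves each partition with `.to(device)`, which is omitted here so the sketch runs on a CPU-only box); `split_by_balance` is a hypothetical name used for illustration:

```python
import torch.nn as nn

def split_by_balance(module: nn.Sequential, balance):
    # mirror the slicing behaviour of the removed partition_model(), minus device moves
    assert sum(balance) == len(module), "balance must account for every layer"
    parts, idx = [], 0
    for num_layers in balance:
        parts.append(nn.Sequential(*[module[idx + i] for i in range(num_layers)]))
        idx += num_layers
    return nn.Sequential(*parts)

model = nn.Sequential(*[nn.Linear(8, 8) for _ in range(6)])
print(split_by_balance(model, [2, 4]))  # two partitions: 2 layers + 4 layers
```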
- - """ - - def __init__( - self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]], - ) -> None: - self.stream = stream - self._compute = compute - self._finalize = finalize - self._grad_enabled = torch.is_grad_enabled() - - def compute(self) -> Batch: - with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): - return self._compute() - - def finalize(self, batch: Batch) -> None: - if self._finalize is None: - return - with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): - self._finalize(batch) - - -def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device) -> None: - """Main loop of a worker thread.""" - with use_device(device): - while True: - task = in_queue.get() - - if task is None: - break - - try: - batch = task.compute() - except Exception: - exc_info = cast(ExcInfo, sys.exc_info()) - out_queue.put((False, exc_info)) - continue - - out_queue.put((True, (task, batch))) - - done = (False, None) - out_queue.put(done) - - -def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[OutQueue]]: - """Spawns worker threads. A worker thread is bound to a device.""" - in_queues: List[InQueue] = [] - out_queues: List[OutQueue] = [] - - # Spawn workers. - workers: Dict[torch.device, Tuple[InQueue, OutQueue]] = {} - - def normalize_device(device: torch.device) -> torch.device: - if device.type == "cuda" and device.index is None: - return torch.device("cuda", index=torch.cuda.current_device()) - - if device.type == "cpu" and device.index is not None: - return torch.device("cpu") - - return device - - for device in devices: - device = normalize_device(device) - - try: - in_queue, out_queue = workers[device] - except KeyError: - in_queue = Queue() - out_queue = Queue() - workers[device] = (in_queue, out_queue) - - t = Thread(target=worker, args=(in_queue, out_queue, device), daemon=True,) - t.start() - - in_queues.append(in_queue) - out_queues.append(out_queue) - - return (in_queues, out_queues) - -@contextmanager -def spawn_workers(devices: List[torch.device],) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]: - try: - (in_queues, out_queues) = create_workers(devices) - yield (in_queues, out_queues) - finally: - pass diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py index fabc9377277a5..28b7514ab16f4 100644 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ b/torch/distributed/pipelining/PipelineSchedule.py @@ -303,6 +303,17 @@ def __init__( self._stage.has_backward = self._has_backward def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + """ + # Clean per iteration self._stage.clear_runtime_states() @@ -583,6 +594,17 @@ def __init__( ) def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. 
+ + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + """ + # Clean per iteration for stage in self._stages: stage.clear_runtime_states() diff --git a/torch/testing/_internal/distributed/pipe_with_ddp_test.py b/torch/testing/_internal/distributed/pipe_with_ddp_test.py deleted file mode 100644 index 1ed9f3cc96dfc..0000000000000 --- a/torch/testing/_internal/distributed/pipe_with_ddp_test.py +++ /dev/null @@ -1,149 +0,0 @@ -# mypy: ignore-errors - -import torch -import torch.distributed as dist - -from torch import nn -from torch.nn.parallel import DistributedDataParallel -from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init -from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( - RpcAgentTestFixture, -) -from torch.testing._internal.common_distributed import ( - requires_gloo, - requires_nccl, - skip_if_lt_x_gpu, - skip_if_rocm, -) -from torch.distributed.pipeline.sync import Pipe - -class PipeWithDDPTest(RpcAgentTestFixture): - @property - def world_size(self) -> int: - return 2 - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_never(self): - self._run_basic_test("nccl", "never") - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_never_find_unused(self): - self._run_basic_test("nccl", "never", find_unused_parameters=True) - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_always(self): - self._run_basic_test("nccl", "always", static_graph=True) - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_except_last(self): - self._run_basic_test("nccl", "except_last", static_graph=True) - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_never(self): - self._run_basic_test("gloo", "never") - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_never_find_unused(self): - self._run_basic_test("gloo", "never", find_unused_parameters=True) - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_always(self): - self._run_basic_test("gloo", "always", static_graph=True) - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_except_last(self): - self._run_basic_test("gloo", "except_last", static_graph=True) - - def _run_basic_test(self, backend, checkpoint, find_unused_parameters=False, static_graph=False): - dist.init_process_group( - backend=backend, - init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), - world_size=self.world_size, - rank=self.rank, - ) - - # Use 4 GPUs, two replicas of a pipe across GPU 0 and 1 and another - # pipe between GPU 2 and 3. Both replicas are replicated via DDP. 
- fc1 = nn.Linear(16, 8, bias=False).cuda(2 * self.rank) - - class MyModule(nn.Module): - def __init__(self, device): - super().__init__() - self.fc2 = nn.Linear(8, 4, bias=False).cuda(device) - self.fc3 = nn.Linear(4, 2, bias=False).cuda(device) - - def forward(self, inp): - if find_unused_parameters: - return self.fc2(inp) - else: - return self.fc3(self.fc2(inp)) - - layer2 = MyModule(2 * self.rank + 1) - model = nn.Sequential( - fc1, - layer2 - ) - model = Pipe(model, chunks=2, checkpoint=checkpoint) - model = DistributedDataParallel( - model, - find_unused_parameters=find_unused_parameters, - static_graph=static_graph, - ) - - # Ensure inputs are different across ranks to verify that gradient - # sync indeed occurs. - model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) - out = model(model_input).local_value() - out.sum().backward() - - # Run forward again for find_unused_parameters to trigger any potential errors. - if find_unused_parameters: - # Ensure inputs are different across ranks to verify that gradient - # sync indeed occurs. - unused_param_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) - model(unused_param_input).local_value().sum().backward() - - # Run a few more iterations of fwd + bwd to ensure gradient synchronization - # occurs properly across iterations via delay_all_reduce/bucketized allreduce. - for _ in range(3): - model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) - out = model(model_input).local_value() - out.sum().backward() - - # Check grads - output = [torch.empty_like(fc1.weight.grad), torch.empty_like(fc1.weight.grad)] - dist.all_gather(output, fc1.weight.grad) - self.assertEqual(output[0], output[1]) - - output = [torch.empty_like(layer2.fc2.weight.grad), torch.empty_like(layer2.fc2.weight.grad)] - dist.all_gather(output, layer2.fc2.weight.grad) - self.assertEqual(output[0], output[1]) - - if not find_unused_parameters: - output = [torch.empty_like(layer2.fc3.weight.grad), torch.empty_like(layer2.fc3.weight.grad)] - dist.all_gather(output, layer2.fc3.weight.grad) - self.assertEqual(output[0], output[1]) diff --git a/torch/testing/_internal/distributed/pipeline/__init__.py b/torch/testing/_internal/distributed/pipeline/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/testing/_internal/distributed/rpc_utils.py b/torch/testing/_internal/distributed/rpc_utils.py index cdbbdcfd06814..5b6e2c90770f4 100644 --- a/torch/testing/_internal/distributed/rpc_utils.py +++ b/torch/testing/_internal/distributed/rpc_utils.py @@ -16,9 +16,6 @@ DdpComparisonTest, DdpUnderDistAutogradTest, ) -from torch.testing._internal.distributed.pipe_with_ddp_test import ( - PipeWithDDPTest, -) from torch.testing._internal.distributed.nn.api.remote_module_test import ( CudaRemoteModuleTest, RemoteModuleTest, @@ -121,7 +118,6 @@ def tearDown(self): CudaDistAutogradTest, CudaRemoteModuleTest, CudaDdpComparisonTest, - PipeWithDDPTest, ]
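With the legacy torch.distributed.pipeline.sync stack (Pipe, partition_model, the worker threads, and the PipeWithDDPTest fixture) removed above, the schedules in torch.distributed.pipelining are the surviving entry point, and their step() methods carry the whole-batch contract spelled out in the docstrings added in the two PipelineSchedule.py hunks. The sketch below is a minimal, hypothetical per-rank driver for one of those schedules: only the step(*args, target=..., losses=...) call shape comes from this diff, while the stage/schedule construction (PipelineStage, ScheduleGPipe, their constructor arguments, n_microbatches=4) is an assumption about the surrounding API and may differ from the actual module.

# Hypothetical per-rank sketch of the replacement API. The step() signature
# matches the docstrings added in this diff; class names and constructor
# arguments are assumptions, not taken from this change.
import torch
import torch.nn as nn
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe  # assumed names


def run_pipeline_step(stage_module: nn.Module, stage_index: int, num_stages: int,
                      device: torch.device, batch: torch.Tensor, target: torch.Tensor):
    # Wrap this rank's slice of the model as a pipeline stage (assumed ctor).
    stage = PipelineStage(stage_module, stage_index, num_stages, device)
    # The schedule owns the microbatch plan; step() chunks the whole batch itself.
    schedule = ScheduleGPipe(stage, n_microbatches=4, loss_fn=nn.CrossEntropyLoss())

    losses: list = []
    if stage_index == 0:
        # The first stage feeds the whole-batch input; downstream stages receive
        # activations from their peers rather than positional arguments.
        schedule.step(batch)
    elif stage_index == num_stages - 1:
        # The last stage supplies the target and collects one loss per microbatch.
        schedule.step(target=target, losses=losses)
    else:
        schedule.step()
    return losses

Unlike the removed Pipe wrapper, which took a chunks= argument at construction and returned a value unwrapped via local_value() (as in the deleted pipe_with_ddp_test.py above), chunking here happens inside step() on the whole-batch input, per the added docstrings.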