Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0df4743
tpu amp
cowanmeg Mar 8, 2023
7997da0
Add torch pin
cowanmeg Mar 16, 2023
6ac8262
Merge branch 'master' of https://github.com/cowanmeg/xla into amp
cowanmeg Mar 22, 2023
042a631
Merge branch 'master' of https://github.com/cowanmeg/xla into amp
cowanmeg Mar 22, 2023
401eaf6
updates
cowanmeg Apr 11, 2023
041fce2
Clean up
cowanmeg Apr 12, 2023
0aa5011
Delete test_fsdp_auto_wrap_amp.py
cowanmeg Apr 12, 2023
2157e2c
updates
cowanmeg Apr 13, 2023
e247b43
Update autocast key to pick between Cuda and Xla. Unit tests
cowanmeg Apr 14, 2023
7d215d1
lint
cowanmeg Apr 14, 2023
661759a
moving code from pt to ptxla
cowanmeg Apr 14, 2023
e339506
fixes
cowanmeg Apr 17, 2023
03da7ec
lint
cowanmeg Apr 17, 2023
e4a0daa
move autocast+test from common.sh to run_tests.sh
cowanmeg Apr 17, 2023
ff8f10b
updates
cowanmeg Apr 24, 2023
a0437a5
Merge branch 'master' of https://github.com/cowanmeg/xla into amp
cowanmeg Apr 24, 2023
410912e
updates with pytorch
cowanmeg Apr 24, 2023
aabbd12
lint
cowanmeg Apr 24, 2023
1ce912c
Merge branch 'master' of https://github.com/cowanmeg/xla into amp
cowanmeg Apr 26, 2023
dfe1f27
Merge branch 'master' of https://github.com/cowanmeg/xla into amp
cowanmeg Apr 26, 2023
5275334
build autocast_mode
cowanmeg Apr 26, 2023
0652dd6
Merge branch 'master' of https://github.com/cowanmeg/xla into amp
cowanmeg May 8, 2023
089f296
Merge branch 'master' of https://github.com/cowanmeg/xla into amp
cowanmeg Jun 1, 2023
e1243d8
Merge branch 'master' of https://github.com/pytorch/xla into amp
cowanmeg Jun 1, 2023
85d607a
Merge branch 'master' of https://github.com/pytorch/xla into amp
cowanmeg Jun 2, 2023
9241b0a
Merge branch 'master' of https://github.com/pytorch/xla into amp
cowanmeg Jun 6, 2023
643691c
Disable bazel remote cache
cowanmeg Jun 7, 2023
4eba81f
experiment with no new files
cowanmeg Jun 8, 2023
ffac43f
revert back
cowanmeg Jun 8, 2023
4097324
lint
cowanmeg Jun 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .circleci/common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ function run_torch_xla_tests() {
chmod -R 755 ~/htmlcov
else
./test/run_tests.sh
# only run test_autocast for cpu and gpu on circleCI.
python test/test_autocast.py

# GPU tests
if [ -x "$(command -v nvidia-smi)" ]; then
Expand Down
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ function run_xla_op_tests {
run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_async_closures.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/test_xla_dist.py"
run_test "$CDIR/test_profiler.py"
run_test "$CDIR/test_ops.py"
Expand Down
273 changes: 232 additions & 41 deletions test/test_autocast.py

Large diffs are not rendered by default.

24 changes: 15 additions & 9 deletions test/test_train_mp_imagenet_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
'--test_only_at_end': {
'action': 'store_true',
},
# AMP only works with XLA:GPU
'--amp': {
'action': 'store_true',
},
Expand Down Expand Up @@ -197,6 +196,7 @@ def train_imagenet():
torch.manual_seed(42)

device = xm.xla_device()
device_hw = xm.xla_device_hw(device)
model = get_model_property('model_fn')().to(device)
writer = None
if xm.is_master_ordinal():
Expand All @@ -219,23 +219,29 @@ def train_imagenet():
summary_writer=writer)
loss_fn = nn.CrossEntropyLoss()
if FLAGS.amp:
scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
if device_hw == 'TPU':
scaler = None
elif device_hw == 'GPU':
scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)

def train_loop_fn(loader, epoch):
tracker = xm.RateTracker()
model.train()
for step, (data, target) in enumerate(loader):
optimizer.zero_grad()
if FLAGS.amp:
with autocast():
with autocast(xm.xla_device()):
output = model(data)
loss = loss_fn(output, target)

scaler.scale(loss).backward()
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
scaler.step(optimizer)
scaler.update()
if scaler:
scaler.scale(loss).backward()
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
xm.optimizer_step(optimizer)
else:
output = model(data)
loss = loss_fn(output, target)
Expand Down
32 changes: 23 additions & 9 deletions test/test_train_mp_mnist_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def forward(self, x):
return F.log_softmax(x, dim=1)


def _train_update(device, x, loss, tracker, writer):
def _train_update(device, step, loss, tracker, writer):
test_utils.print_training_update(
device,
x,
step,
loss.item(),
tracker.rate(),
tracker.global_rate(),
Expand Down Expand Up @@ -130,28 +130,42 @@ def train_mnist(flags, **kwargs):
lr = flags.lr * xm.xrt_world_size()

device = xm.xla_device()
device_hw = xm.xla_device_hw(device)
model = MNIST().to(device)

writer = None
if xm.is_master_ordinal():
writer = test_utils.get_summary_writer(flags.logdir)
optim_cls = syncfree.SGD if FLAGS.use_syncfree_optim else optim.SGD
optimizer = optim_cls(model.parameters(), lr=lr, momentum=flags.momentum)
loss_fn = nn.NLLLoss()
scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)

if device_hw == 'TPU':
scaler = None
elif device_hw == 'GPU':
# GradScaler only used for GPU
scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
else:
print("Only TPU or GPU supported for AMP.")
sys.exit(1)

def train_loop_fn(loader):
tracker = xm.RateTracker()
model.train()
for step, (data, target) in enumerate(loader):
optimizer.zero_grad()
with autocast():
with autocast(device):
output = model(data)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
scaler.step(optimizer)
scaler.update()
if scaler:
scaler.scale(loss).backward()
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
xm.optimizer_step(optimizer)
tracker.add(flags.batch_size)
if step % flags.log_steps == 0:
xm.add_step_closure(
Expand Down
1 change: 1 addition & 0 deletions torch_patches/.torch_pin
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#96370
2 changes: 1 addition & 1 deletion torch_xla/amp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .autocast_mode import autocast, custom_fwd, custom_bwd # noqa: F401
from .autocast_mode import autocast # noqa: F401
from .grad_scaler import GradScaler # noqa: F401
41 changes: 38 additions & 3 deletions torch_xla/amp/autocast_mode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,40 @@
import torch
import torch_xla.core.xla_model as xm
from typing import Any

autocast = torch.cuda.amp.autocast
custom_fwd = torch.cuda.amp.custom_fwd
custom_bwd = torch.cuda.amp.custom_bwd

class autocast(torch.amp.autocast_mode.autocast):
r"""
See :class:`torch.autocast`.
``torch_xla.amp.autocast(device, args...)`` is equivalent to ``torch.autocast("xla", args...)`` for TPUs
``torch.autocast("cuda", args...)`` for GPUs.
"""

def __init__(self,
device,
enabled: bool = True,
dtype: torch.dtype = torch.bfloat16,
cache_enabled: bool = True):
if xm.xla_device_hw(device) == 'GPU':
super().__init__(
"cuda",
enabled=enabled,
dtype=torch.float16,
cache_enabled=cache_enabled)
elif xm.xla_device_hw(device) == 'TPU':
super().__init__(
"xla", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
else:
print(
'Warning: AMP only supported for XLA:TPU and XLA:GPU. Ignoring autocast.'
)

def __enter__(self):
return super().__enter__()

def __exit__(self, exc_type: Any, exc_val: Any,
exc_tb: Any): # type: ignore[override]
return super().__exit__(exc_type, exc_val, exc_tb)

def __call__(self, func):
return super().__call__(func)
1 change: 1 addition & 0 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ ptxla_cc_library(
"aten_autograd_ops.cpp",
"aten_xla_bridge.cpp",
"aten_xla_type.cpp",
"autocast_mode.cpp",
"batch_norm.cpp",
"convert_ops.cpp",
"convolution.cpp",
Expand Down
8 changes: 0 additions & 8 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,10 +437,6 @@ void XLANativeFunctions::_amp_foreach_non_finite_check_and_unscale_(
at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale) {
TORCH_LAZY_FN_COUNTER("xla::");
XLATensorPtr found_inf_tensor = bridge::GetXlaTensor(found_inf);
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(found_inf_tensor->GetDevice().type());
XLA_CHECK(hw_type == XlaDeviceType::GPU || hw_type == XlaDeviceType::CPU)
<< "AMP should be used with XLA:GPU";
tensor_methods::_amp_foreach_non_finite_check_and_unscale_(
bridge::GetXlaTensors(self), found_inf_tensor,
bridge::GetXlaTensor(inv_scale));
Expand All @@ -455,10 +451,6 @@ at::Tensor& XLANativeFunctions::_amp_update_scale_(at::Tensor& current_scale,
TORCH_LAZY_FN_COUNTER("xla::");
XLATensorPtr growth_tracker_tensor = bridge::GetXlaTensor(growth_tracker);
XLATensorPtr current_scale_tensor = bridge::GetXlaTensor(current_scale);
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(growth_tracker_tensor->GetDevice().type());
XLA_CHECK(hw_type == XlaDeviceType::GPU || hw_type == XlaDeviceType::CPU)
<< "AMP should be used with XLA:GPU";
tensor_methods::_amp_update_scale_(
growth_tracker_tensor, current_scale_tensor,
bridge::GetXlaTensor(found_inf), scale_growth_factor,
Expand Down
160 changes: 160 additions & 0 deletions torch_xla/csrc/autocast_mode.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#include <ATen/autocast_mode.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/library.h>

namespace at {
namespace autocast {
namespace {

#define KERNEL_XLA(OP, POLICY) KERNEL(c10::DeviceType::XLA, OP, POLICY)

#define KERNEL_XLA2(OP, OVERLOAD, POLICY) \
KERNEL2(c10::DeviceType::XLA, OP, OVERLOAD, POLICY)

#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XLA( \
REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, \
POLICY) \
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(c10::DeviceType::XLA, REDISPATCH_FUNC, \
REGISTER_NAME, REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, POLICY)

TORCH_LIBRARY_IMPL(_, AutocastXLA, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
// lower_precision_fp cast policy
KERNEL_XLA(conv1d, lower_precision_fp)
KERNEL_XLA2(conv1d, padding, lower_precision_fp)
KERNEL_XLA(conv2d, lower_precision_fp)
KERNEL_XLA2(conv2d, padding, lower_precision_fp)
KERNEL_XLA(conv3d, lower_precision_fp)
KERNEL_XLA2(conv3d, padding, lower_precision_fp)
KERNEL_XLA(bmm, lower_precision_fp)
KERNEL_XLA(mm, lower_precision_fp)
KERNEL_XLA(baddbmm, lower_precision_fp)
KERNEL_XLA(addmm, lower_precision_fp)
KERNEL_XLA(addbmm, lower_precision_fp)
KERNEL_XLA(linear, lower_precision_fp)
KERNEL_XLA(matmul, lower_precision_fp)
KERNEL_XLA(conv_tbc, lower_precision_fp)
KERNEL_XLA(conv_transpose1d, lower_precision_fp)
KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp)
KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp)
KERNEL_XLA(prelu, lower_precision_fp)
KERNEL_XLA(relu, lower_precision_fp)
KERNEL_XLA(max_pool2d, lower_precision_fp)

// fp32 cast policy
// Commented out ops are included in the AutoCastCPU Policy,
// but not lowered. Enable if op is lowered.
KERNEL_XLA(batch_norm, fp32)
KERNEL_XLA2(log_softmax, int, fp32)
KERNEL_XLA2(log_softmax, Dimname, fp32)
KERNEL_XLA(binary_cross_entropy, fp32)
// KERNEL_XLA(grid_sampler, fp32)
// KERNEL_XLA(polar, fp32)
KERNEL_XLA(prod, fp32)
KERNEL_XLA2(prod, dim_int, fp32)
KERNEL_XLA2(prod, dim_Dimname, fp32)
// KERNEL_XLA(quantile, fp32)
// KERNEL_XLA2(quantile, scalar, fp32)
// KERNEL_XLA(nanquantile, fp32)
// KERNEL_XLA2(nanquantile, scalar, fp32)
// KERNEL_XLA(stft, fp32)
// KERNEL_XLA2(stft, center, fp32)
KERNEL_XLA(cdist, fp32)
// KERNEL_XLA(grid_sampler_2d, fp32)
// KERNEL_XLA(grid_sampler_3d, fp32)
KERNEL_XLA(trace, fp32)
// KERNEL_XLA(view_as_complex, fp32)
KERNEL_XLA(cholesky, fp32)
KERNEL_XLA(cholesky_inverse, fp32)
KERNEL_XLA(cholesky_solve, fp32)
KERNEL_XLA(inverse, fp32)
// KERNEL_XLA(lu_solve, fp32)
// KERNEL_XLA(orgqr, fp32)
// KERNEL_XLA(ormqr, fp32)
// KERNEL_XLA(pinverse, fp32)
KERNEL_XLA(reflection_pad1d, fp32)
KERNEL_XLA(reflection_pad2d, fp32)
KERNEL_XLA(replication_pad1d, fp32)
KERNEL_XLA(replication_pad2d, fp32)
KERNEL_XLA(replication_pad3d, fp32)
KERNEL_XLA(mse_loss, fp32)
KERNEL_XLA(cosine_embedding_loss, fp32)
KERNEL_XLA(nll_loss, fp32)
KERNEL_XLA(nll_loss2d, fp32)
KERNEL_XLA(hinge_embedding_loss, fp32)
// KERNEL_XLA(poisson_nll_loss, fp32)
KERNEL_XLA(smooth_l1_loss, fp32)
// KERNEL_XLA(cross_entropy_loss, fp32)
KERNEL_XLA(l1_loss, fp32)
// KERNEL_XLA(huber_loss, fp32)
KERNEL_XLA(margin_ranking_loss, fp32)
KERNEL_XLA(soft_margin_loss, fp32)
KERNEL_XLA(triplet_margin_loss, fp32)
KERNEL_XLA(multi_margin_loss, fp32)
KERNEL_XLA2(ctc_loss, IntList, fp32)
KERNEL_XLA2(ctc_loss, Tensor, fp32)
KERNEL_XLA(kl_div, fp32)
KERNEL_XLA(multilabel_margin_loss, fp32)
KERNEL_XLA(binary_cross_entropy_with_logits, fp32)
// KERNEL_XLA(fft_fft, fp32)
// KERNEL_XLA(fft_ifft, fp32)
// KERNEL_XLA(fft_fft2, fp32)
// KERNEL_XLA(fft_ifft2, fp32)
// KERNEL_XLA(fft_fftn, fp32)
// KERNEL_XLA(fft_ifftn, fp32)
// KERNEL_XLA(fft_rfft, fp32)
// KERNEL_XLA(fft_irfft, fp32)
// KERNEL_XLA(fft_rfft2, fp32)
// KERNEL_XLA(fft_irfft2, fp32)
// KERNEL_XLA(fft_rfftn, fp32)
// KERNEL_XLA(fft_irfftn, fp32)
// KERNEL_XLA(fft_hfft, fp32)
// KERNEL_XLA(fft_ihfft, fp32)
// KERNEL_XLA(linalg_cond, fp32)
// KERNEL_XLA2(linalg_cond, p_str, fp32)
// KERNEL_XLA(linalg_matrix_rank, fp32)
// KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32)
// KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32)
// KERNEL_XLA2(linalg_matrix_rank, atol_rtol_float, fp32)
// KERNEL_XLA(linalg_solve, fp32)
// KERNEL_XLA(linalg_cholesky, fp32)
// KERNEL_XLA(linalg_svdvals, fp32)
// KERNEL_XLA(linalg_eigvals, fp32)
// KERNEL_XLA(linalg_eigvalsh, fp32)
// KERNEL_XLA(linalg_inv, fp32)
// KERNEL_XLA(linalg_householder_product, fp32)
// KERNEL_XLA(linalg_tensorinv, fp32)
// KERNEL_XLA(linalg_tensorsolve, fp32)
// KERNEL_XLA(fake_quantize_per_tensor_affine, fp32)
// KERNEL_XLA(geqrf, fp32)
// KERNEL_XLA(_lu_with_info, fp32)
KERNEL_XLA(qr, fp32)
KERNEL_XLA(svd, fp32)
KERNEL_XLA(triangular_solve, fp32)
KERNEL_XLA(multilabel_margin_loss_forward, fp32)
// KERNEL_XLA(linalg_qr, fp32)
// KERNEL_XLA(linalg_cholesky_ex, fp32)
KERNEL_XLA(linalg_svd, fp32)
// KERNEL_XLA(linalg_eig, fp32)
// KERNEL_XLA(linalg_eigh, fp32)
// KERNEL_XLA(linalg_lstsq, fp32)
KERNEL_XLA(linalg_inv_ex, fp32)

// promote
KERNEL_XLA(stack, promote)
KERNEL_XLA(cat, promote)
KERNEL_XLA(index_copy, promote)
KERNEL_XLA2(index_copy, dimname, promote)
}

} // namespace
} // namespace autocast
} // namespace at
10 changes: 10 additions & 0 deletions torch_xla/csrc/tensor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <c10/core/ScalarType.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/macros/Macros.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/tensor.h>
Expand Down Expand Up @@ -62,6 +63,15 @@ XLATensorImpl::XLATensorImpl(XLATensor&& tensor)
GetTypeMeta(tensor),
bridge::XlaDeviceToAtenDevice(tensor.GetDevice())),
tensor_(c10::make_intrusive<XLATensor>(std::move(tensor))) {
// Update the Autocast key based off the backend device.
// Upstream TensorImpl cannot differentiate between XLA:TPU and XLA:GPU
// so we must manually update Autocast to AutocastCUDA on XLA:GPU.
torch::lazy::BackendDevice current_device = GetCurrentDevice();
if (static_cast<XlaDeviceType>(current_device.type()) == XlaDeviceType::GPU) {
auto autocast_cuda_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastCUDA);
auto autocast_xla_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastXLA);
key_set_ = (key_set_ - autocast_xla_ks) | autocast_cuda_ks;
}
is_non_overlapping_and_dense_ = false;
set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
}
Expand Down