Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ function run_xla_op_tests {
run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_async_closures.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/test_xla_dist.py"
run_test "$CDIR/test_profiler.py"
run_test "$CDIR/test_ops.py"
Expand Down
273 changes: 41 additions & 232 deletions test/test_autocast.py

Large diffs are not rendered by default.

24 changes: 9 additions & 15 deletions test/test_train_mp_imagenet_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
'--test_only_at_end': {
'action': 'store_true',
},
# AMP only works with XLA:GPU
'--amp': {
'action': 'store_true',
},
Expand Down Expand Up @@ -196,7 +197,6 @@ def train_imagenet():
torch.manual_seed(42)

device = xm.xla_device()
device_hw = xm.xla_device_hw(device)
model = get_model_property('model_fn')().to(device)
writer = None
if xm.is_master_ordinal():
Expand All @@ -219,29 +219,23 @@ def train_imagenet():
summary_writer=writer)
loss_fn = nn.CrossEntropyLoss()
if FLAGS.amp:
if device_hw == 'TPU':
scaler = None
elif device_hw == 'GPU':
scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)

def train_loop_fn(loader, epoch):
tracker = xm.RateTracker()
model.train()
for step, (data, target) in enumerate(loader):
optimizer.zero_grad()
if FLAGS.amp:
with autocast(xm.xla_device()):
with autocast():
output = model(data)
loss = loss_fn(output, target)
if scaler:
scaler.scale(loss).backward()
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
xm.optimizer_step(optimizer)

scaler.scale(loss).backward()
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
scaler.step(optimizer)
scaler.update()
else:
output = model(data)
loss = loss_fn(output, target)
Expand Down
32 changes: 9 additions & 23 deletions test/test_train_mp_mnist_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def forward(self, x):
return F.log_softmax(x, dim=1)


def _train_update(device, step, loss, tracker, writer):
def _train_update(device, x, loss, tracker, writer):
test_utils.print_training_update(
device,
step,
x,
loss.item(),
tracker.rate(),
tracker.global_rate(),
Expand Down Expand Up @@ -130,42 +130,28 @@ def train_mnist(flags, **kwargs):
lr = flags.lr * xm.xrt_world_size()

device = xm.xla_device()
device_hw = xm.xla_device_hw(device)
model = MNIST().to(device)

writer = None
if xm.is_master_ordinal():
writer = test_utils.get_summary_writer(flags.logdir)
optim_cls = syncfree.SGD if FLAGS.use_syncfree_optim else optim.SGD
optimizer = optim_cls(model.parameters(), lr=lr, momentum=flags.momentum)
loss_fn = nn.NLLLoss()

if device_hw == 'TPU':
scaler = None
elif device_hw == 'GPU':
# GradScaler only used for GPU
scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
else:
print("Only TPU or GPU supported for AMP.")
sys.exit(1)
scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)

def train_loop_fn(loader):
tracker = xm.RateTracker()
model.train()
for step, (data, target) in enumerate(loader):
optimizer.zero_grad()
with autocast(device):
with autocast():
output = model(data)
loss = loss_fn(output, target)
if scaler:
scaler.scale(loss).backward()
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
xm.optimizer_step(optimizer)
scaler.scale(loss).backward()
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
scaler.step(optimizer)
scaler.update()
tracker.add(flags.batch_size)
if step % flags.log_steps == 0:
xm.add_step_closure(
Expand Down
1 change: 0 additions & 1 deletion torch_patches/.torch_pin

This file was deleted.

2 changes: 1 addition & 1 deletion torch_xla/amp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .autocast_mode import autocast # noqa: F401
from .autocast_mode import autocast, custom_fwd, custom_bwd # noqa: F401
from .grad_scaler import GradScaler # noqa: F401
41 changes: 3 additions & 38 deletions torch_xla/amp/autocast_mode.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,5 @@
import torch
import torch_xla.core.xla_model as xm
from typing import Any


class autocast(torch.amp.autocast_mode.autocast):
r"""
See :class:`torch.autocast`.
``torch_xla.amp.autocast(device, args...)`` is equivalent to ``torch.autocast("xla", args...)`` for TPUs
``torch.autocast("cuda", args...)`` for GPUs.
"""

def __init__(self,
device,
enabled: bool = True,
dtype: torch.dtype = torch.bfloat16,
cache_enabled: bool = True):
if xm.xla_device_hw(device) == 'GPU':
super().__init__(
"cuda",
enabled=enabled,
dtype=torch.float16,
cache_enabled=cache_enabled)
elif xm.xla_device_hw(device) == 'TPU':
super().__init__(
"xla", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
else:
print(
'Warning: AMP only supported for XLA:TPU and XLA:GPU. Ignoring autocast.'
)

def __enter__(self):
return super().__enter__()

def __exit__(self, exc_type: Any, exc_val: Any,
exc_tb: Any): # type: ignore[override]
return super().__exit__(exc_type, exc_val, exc_tb)

def __call__(self, func):
return super().__call__(func)
autocast = torch.cuda.amp.autocast
custom_fwd = torch.cuda.amp.custom_fwd
custom_bwd = torch.cuda.amp.custom_bwd
1 change: 0 additions & 1 deletion torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ ptxla_cc_library(
"aten_autograd_ops.cpp",
"aten_xla_bridge.cpp",
"aten_xla_type.cpp",
"autocast_mode.cpp",
"batch_norm.cpp",
"convert_ops.cpp",
"convolution.cpp",
Expand Down
8 changes: 8 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,10 @@ void XLANativeFunctions::_amp_foreach_non_finite_check_and_unscale_(
at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale) {
TORCH_LAZY_FN_COUNTER("xla::");
XLATensorPtr found_inf_tensor = bridge::GetXlaTensor(found_inf);
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(found_inf_tensor->GetDevice().type());
XLA_CHECK(hw_type == XlaDeviceType::GPU || hw_type == XlaDeviceType::CPU)
<< "AMP should be used with XLA:GPU";
tensor_methods::_amp_foreach_non_finite_check_and_unscale_(
bridge::GetXlaTensors(self), found_inf_tensor,
bridge::GetXlaTensor(inv_scale));
Expand All @@ -451,6 +455,10 @@ at::Tensor& XLANativeFunctions::_amp_update_scale_(at::Tensor& current_scale,
TORCH_LAZY_FN_COUNTER("xla::");
XLATensorPtr growth_tracker_tensor = bridge::GetXlaTensor(growth_tracker);
XLATensorPtr current_scale_tensor = bridge::GetXlaTensor(current_scale);
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(growth_tracker_tensor->GetDevice().type());
XLA_CHECK(hw_type == XlaDeviceType::GPU || hw_type == XlaDeviceType::CPU)
<< "AMP should be used with XLA:GPU";
tensor_methods::_amp_update_scale_(
growth_tracker_tensor, current_scale_tensor,
bridge::GetXlaTensor(found_inf), scale_growth_factor,
Expand Down
160 changes: 0 additions & 160 deletions torch_xla/csrc/autocast_mode.cpp

This file was deleted.

10 changes: 0 additions & 10 deletions torch_xla/csrc/tensor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include <c10/core/ScalarType.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/macros/Macros.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/tensor.h>
Expand Down Expand Up @@ -63,15 +62,6 @@ XLATensorImpl::XLATensorImpl(XLATensor&& tensor)
GetTypeMeta(tensor),
bridge::XlaDeviceToAtenDevice(tensor.GetDevice())),
tensor_(c10::make_intrusive<XLATensor>(std::move(tensor))) {
// Update the Autocast key based off the backend device.
// Upstream TensorImpl cannot differentiate between XLA:TPU and XLA:GPU
// so we must manually update Autocast to AutocastCUDA on XLA:GPU.
torch::lazy::BackendDevice current_device = GetCurrentDevice();
if (static_cast<XlaDeviceType>(current_device.type()) == XlaDeviceType::GPU) {
auto autocast_cuda_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastCUDA);
auto autocast_xla_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastXLA);
key_set_ = (key_set_ - autocast_xla_ks) | autocast_cuda_ks;
}
is_non_overlapping_and_dense_ = false;
set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
}
Expand Down