3 changes: 3 additions & 0 deletions .circleci/test.sh
@@ -20,6 +20,9 @@ echo "Running Python Tests"

echo "Running MNIST Test"
python test/test_train_mnist.py --tidy
if [ -x "$(command -v nvidia-smi)" ]; then
python test/test_train_mp_mnist_amp.py --fake_data
fi

echo "Running C++ Tests"
pushd test/cpp
102 changes: 102 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -10007,5 +10007,107 @@ TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) {
}
}

TEST_F(AtenXlaTensorTest, TestAmpForeachNonFiniteCheckAndUnscale) {
torch::Tensor grads0 =
torch::tensor({1, 2, 3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor grads1 = torch::tensor({1.0, 2.0, std::nan("1"), 4.0},
torch::TensorOptions(torch::kFloat));
torch::Tensor inv_scale =
torch::scalar_tensor(0.2, torch::TensorOptions(torch::kFloat));
torch::Tensor found_inf =
torch::scalar_tensor(0, torch::TensorOptions(torch::kFloat));
torch::Tensor grads_output0 = grads0 * inv_scale;
torch::Tensor found_inf_output0 =
torch::scalar_tensor(0, torch::TensorOptions(torch::kFloat));
torch::Tensor found_inf_output1 =
torch::scalar_tensor(1, torch::TensorOptions(torch::kFloat));
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_grads0 = CopyToDevice(grads0, device);
torch::Tensor xla_inv_scale = CopyToDevice(inv_scale, device);
torch::Tensor xla_found_inf = CopyToDevice(found_inf, device);
torch::_amp_foreach_non_finite_check_and_unscale_(xla_grads0, xla_found_inf,
xla_inv_scale);
AllClose(grads_output0, xla_grads0, /*rtol=*/1e-2, /*atol=*/1e-4);
AllEqual(found_inf_output0, xla_found_inf);

torch::Tensor xla_grads1 = CopyToDevice(grads1, device);
torch::_amp_foreach_non_finite_check_and_unscale_(xla_grads1, xla_found_inf,
xla_inv_scale);
AllEqual(found_inf_output1, xla_found_inf);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::_amp_foreach_non_finite_check_and_unscale_",
cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestAmpUpdateScale) {
torch::Tensor growth_tracker =
torch::scalar_tensor(0, torch::TensorOptions(torch::kInt32));
torch::Tensor current_scale =
torch::scalar_tensor(4, torch::TensorOptions(torch::kFloat));
torch::Tensor found_inf =
torch::scalar_tensor(1, torch::TensorOptions(torch::kFloat));
torch::Tensor not_found_inf =
torch::scalar_tensor(0, torch::TensorOptions(torch::kFloat));
float scale_growth_factor = 2.0;
float scale_backoff_factor = 0.5;
int growth_interval = 3;

torch::Tensor growth_tracker_result0 =
torch::scalar_tensor(1, torch::TensorOptions(torch::kInt32));
torch::Tensor current_scale_result0 =
torch::scalar_tensor(4, torch::TensorOptions(torch::kFloat));
torch::Tensor growth_tracker_result1 =
torch::scalar_tensor(2, torch::TensorOptions(torch::kInt32));
torch::Tensor current_scale_result1 =
torch::scalar_tensor(4, torch::TensorOptions(torch::kFloat));
torch::Tensor growth_tracker_result2 =
torch::scalar_tensor(0, torch::TensorOptions(torch::kInt32));
torch::Tensor current_scale_result2 =
torch::scalar_tensor(8, torch::TensorOptions(torch::kFloat));
torch::Tensor growth_tracker_result3 =
torch::scalar_tensor(0, torch::TensorOptions(torch::kInt32));
torch::Tensor current_scale_result3 =
torch::scalar_tensor(4, torch::TensorOptions(torch::kFloat));

ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_growth_tracker = CopyToDevice(growth_tracker, device);
torch::Tensor xla_current_scale = CopyToDevice(current_scale, device);
torch::Tensor xla_found_inf = CopyToDevice(found_inf, device);
torch::Tensor xla_not_found_inf = CopyToDevice(not_found_inf, device);

xla_current_scale = torch::_amp_update_scale(
xla_growth_tracker, xla_current_scale, xla_not_found_inf,
scale_growth_factor, scale_backoff_factor, growth_interval);
AllClose(current_scale_result0, xla_current_scale, /*rtol=*/1e-2,
/*atol=*/1e-4);
AllEqual(growth_tracker_result0, xla_growth_tracker);

xla_current_scale = torch::_amp_update_scale(
xla_growth_tracker, xla_current_scale, xla_not_found_inf,
scale_growth_factor, scale_backoff_factor, growth_interval);
AllClose(current_scale_result1, xla_current_scale, /*rtol=*/1e-2,
/*atol=*/1e-4);
AllEqual(growth_tracker_result1, xla_growth_tracker);

xla_current_scale = torch::_amp_update_scale(
xla_growth_tracker, xla_current_scale, xla_not_found_inf,
scale_growth_factor, scale_backoff_factor, growth_interval);
AllClose(current_scale_result2, xla_current_scale, /*rtol=*/1e-2,
/*atol=*/1e-4);
AllEqual(growth_tracker_result2, xla_growth_tracker);

xla_current_scale = torch::_amp_update_scale(
xla_growth_tracker, xla_current_scale, xla_found_inf,
scale_growth_factor, scale_backoff_factor, growth_interval);
AllClose(current_scale_result3, xla_current_scale, /*rtol=*/1e-2,
/*atol=*/1e-4);
AllEqual(growth_tracker_result3, xla_growth_tracker);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::_amp_update_scale",
cpp_test::GetIgnoredCounters());
}

} // namespace cpp_test
} // namespace torch_xla
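
For reference, a minimal Python sketch of the scale-update behavior that TestAmpUpdateScale asserts above (an illustrative paraphrase of the expected semantics, not the XLA lowering; the helper name is hypothetical):

def amp_update_scale_reference(scale, growth_tracker, found_inf,
                               growth_factor=2.0, backoff_factor=0.5,
                               growth_interval=3):
    # On overflow: back off the scale and restart the growth count.
    if found_inf:
        return scale * backoff_factor, 0
    # Otherwise count another finite step.
    growth_tracker += 1
    if growth_tracker == growth_interval:
        # Enough consecutive finite steps: grow the scale, reset the count.
        return scale * growth_factor, 0
    return scale, growth_tracker

Starting from scale=4 and growth_tracker=0 with growth_interval=3, three finite steps yield (4, 1), (4, 2), (8, 0), and a subsequent step with found_inf yields (4, 0), matching the expected tensors in the test above.
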
194 changes: 194 additions & 0 deletions test/test_train_mp_mnist_amp.py
@@ -0,0 +1,194 @@
import args_parse

FLAGS = args_parse.parse_common_options(
datadir='/tmp/mnist-data',
batch_size=128,
momentum=0.5,
lr=0.01,
target_accuracy=98.0,
num_epochs=18)

import os
import shutil
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
from torch_xla.amp import autocast, GradScaler


class MNIST(nn.Module):

def __init__(self):
super(MNIST, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.bn1 = nn.BatchNorm2d(10)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.bn2 = nn.BatchNorm2d(20)
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)

def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = self.bn1(x)
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = self.bn2(x)
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)


def _train_update(device, x, loss, tracker, writer):
test_utils.print_training_update(
device,
x,
loss.item(),
tracker.rate(),
tracker.global_rate(),
summary_writer=writer)


def train_mnist(flags, **kwargs):
torch.manual_seed(1)

if flags.fake_data:
train_loader = xu.SampleGenerator(
data=(torch.zeros(flags.batch_size, 1, 28,
28), torch.zeros(flags.batch_size,
dtype=torch.int64)),
sample_count=60000 // flags.batch_size // xm.xrt_world_size())
test_loader = xu.SampleGenerator(
data=(torch.zeros(flags.batch_size, 1, 28,
28), torch.zeros(flags.batch_size,
dtype=torch.int64)),
sample_count=10000 // flags.batch_size // xm.xrt_world_size())
else:
train_dataset = datasets.MNIST(
os.path.join(flags.datadir, str(xm.get_ordinal())),
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))]))
test_dataset = datasets.MNIST(
os.path.join(flags.datadir, str(xm.get_ordinal())),
train=False,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))]))
train_sampler = None
if xm.xrt_world_size() > 1:
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=True)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=flags.batch_size,
sampler=train_sampler,
drop_last=flags.drop_last,
shuffle=False if train_sampler else True,
num_workers=flags.num_workers)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=flags.batch_size,
drop_last=flags.drop_last,
shuffle=False,
num_workers=flags.num_workers)

# Scale learning rate to num cores
lr = flags.lr * xm.xrt_world_size()

device = xm.xla_device()
model = MNIST().to(device)
writer = None
if xm.is_master_ordinal():
writer = test_utils.get_summary_writer(flags.logdir)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
loss_fn = nn.NLLLoss()
scaler = GradScaler()

def train_loop_fn(loader):
tracker = xm.RateTracker()
model.train()
for step, (data, target) in enumerate(loader):
optimizer.zero_grad()
with autocast():
output = model(data)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
scaler.step(optimizer)
scaler.update()
xm.mark_step()
tracker.add(flags.batch_size)
if step % flags.log_steps == 0:
xm.add_step_closure(
_train_update, args=(device, step, loss, tracker, writer))

def test_loop_fn(loader):
total_samples = 0
correct = 0
model.eval()
for data, target in loader:
output = model(data)
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(target.view_as(pred)).sum()
total_samples += data.size()[0]

accuracy = 100.0 * correct.item() / total_samples
accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
return accuracy

train_device_loader = pl.MpDeviceLoader(train_loader, device)
test_device_loader = pl.MpDeviceLoader(test_loader, device)
accuracy, max_accuracy = 0.0, 0.0
for epoch in range(1, flags.num_epochs + 1):
xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
train_loop_fn(train_device_loader)
xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

accuracy = test_loop_fn(test_device_loader)
xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
epoch, test_utils.now(), accuracy))
max_accuracy = max(accuracy, max_accuracy)
test_utils.write_to_summary(
writer,
epoch,
dict_to_write={'Accuracy/test': accuracy},
write_xla_metrics=True)
if flags.metrics_debug:
xm.master_print(met.metrics_report())

test_utils.close_summary_writer(writer)
xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
return max_accuracy


def _mp_fn(index, flags):
torch.set_default_tensor_type('torch.FloatTensor')
accuracy = train_mnist(flags)
if flags.tidy and os.path.isdir(flags.datadir):
shutil.rmtree(flags.datadir)
if accuracy < flags.target_accuracy:
print('Accuracy {} is below target {}'.format(accuracy,
flags.target_accuracy))
sys.exit(21)


if __name__ == '__main__':
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)
2 changes: 2 additions & 0 deletions torch_xla/amp/__init__.py
@@ -0,0 +1,2 @@
from .autocast_mode import autocast, custom_fwd, custom_bwd # noqa: F401
from .grad_scaler import GradScaler # noqa: F401
5 changes: 5 additions & 0 deletions torch_xla/amp/autocast_mode.py
@@ -0,0 +1,5 @@
import torch

autocast = torch.cuda.amp.autocast
custom_fwd = torch.cuda.amp.custom_fwd
custom_bwd = torch.cuda.amp.custom_bwd
13 changes: 13 additions & 0 deletions torch_xla/amp/grad_scaler.py
@@ -0,0 +1,13 @@
import torch
import torch_xla.core.xla_model as xm


class GradScaler(torch.cuda.amp.GradScaler):

def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
retval = None
xm.mark_step()
if not sum(
v.item() for v in optimizer_state["found_inf_per_device"].values()):
retval = optimizer.step(*args, **kwargs)
return retval
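
A minimal usage sketch of the torch_xla.amp API added above, condensed from test/test_train_mp_mnist_amp.py in this PR (model, optimizer, loss_fn, and the input tensors are assumed to exist):

import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast, GradScaler

scaler = GradScaler()

def train_step(model, optimizer, loss_fn, data, target):
    optimizer.zero_grad()
    with autocast():  # run the forward pass in mixed precision
        output = model(data)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()  # backward on the scaled loss
    gradients = xm._fetch_gradients(optimizer)
    xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
    scaler.step(optimizer)  # skipped internally if non-finite gradients were found
    scaler.update()  # grows or backs off the loss scale
    xm.mark_step()
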
22 changes: 22 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -302,6 +302,28 @@ at::Tensor AtenXlaType::_adaptive_avg_pool2d_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self)));
}

void AtenXlaType::_amp_foreach_non_finite_check_and_unscale_(
at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale) {
XLA_FN_COUNTER("xla::");
XLATensor found_inf_tensor = bridge::GetXlaTensor(found_inf);
XLATensor::_amp_foreach_non_finite_check_and_unscale_(
bridge::GetXlaTensors(self), found_inf_tensor,
bridge::GetXlaTensor(inv_scale));
}

at::Tensor AtenXlaType::_amp_update_scale(at::Tensor& growth_tracker,
const at::Tensor& current_scale,
const at::Tensor& found_inf,
double scale_growth_factor,
double scale_backoff_factor,
int64_t growth_interval) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::_amp_update_scale(
bridge::GetXlaTensor(growth_tracker), bridge::GetXlaTensor(current_scale),
bridge::GetXlaTensor(found_inf), scale_growth_factor,
scale_backoff_factor, growth_interval));
}

at::Tensor AtenXlaType::_copy_from(const at::Tensor& self,
const at::Tensor& dst, bool non_blocking) {
XLA_FN_COUNTER("xla::");
10 changes: 10 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -43,6 +43,16 @@ class AtenXlaType {
static at::Tensor _adaptive_avg_pool2d_backward(const at::Tensor& grad_output,
const at::Tensor& self);

static void _amp_foreach_non_finite_check_and_unscale_(
at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale);

static at::Tensor _amp_update_scale(at::Tensor& growth_tracker,
const at::Tensor& current_scale,
const at::Tensor& found_inf,
double scale_growth_factor,
double scale_backoff_factor,
int64_t growth_interval);

static at::Tensor _copy_from(const at::Tensor& self, const at::Tensor& dst,
bool non_blocking);
