
Commit 1f4006b

Remove TestGradients due to breakage on the PT side, and due to close proximity to JIT-abandoning point.
1 parent c58c087
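For context, the removed TestGradients.checkGrad traced a model with torch.jit, differentiated the traced graph, and compared the resulting gradients (run through the JIT interpreter and through XLA) against regular autograd. A minimal sketch of that pattern using only public PyTorch APIs, with no XLA module or internal JIT hooks, might look like the following; the helper name check_traced_grads and the tolerances are illustrative only and are not part of this repository:

import torch
import torch.nn as nn


def check_traced_grads(model, inputs, rel_err=1e-2, abs_err=1e-5):
  # Trace the model, then compare input gradients of the traced graph
  # against gradients computed by eager-mode autograd on the same inputs.
  traced = torch.jit.trace(model, tuple(inputs))

  traced_out = traced(*inputs)
  eager_out = model(*inputs)
  grad_out = torch.randn_like(eager_out)

  traced_grads = torch.autograd.grad(traced_out, inputs, grad_out)
  eager_grads = torch.autograd.grad(eager_out, inputs, grad_out)

  for g_traced, g_eager in zip(traced_grads, eager_grads):
    assert torch.allclose(g_traced, g_eager, rtol=rel_err, atol=abs_err)


# Usage mirroring the removed test_maxpool case.
model = nn.MaxPool2d(2)
inputs = [torch.randn(4, 1, 28, 28, requires_grad=True)]
check_traced_grads(model, inputs)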

File tree

1 file changed: +0, -289 lines


test/test_operations.py

Lines changed: 0 additions & 289 deletions
@@ -851,295 +851,6 @@ def test(self):
     self.assertEqualRel(x, xla_x.to_tensor(), rel_err=1e-3, abs_err=5)
 
 
-class TestGradients(XlaTestCase):
-
-  def checkGrad(self,
-                model,
-                inputs,
-                grad_outputs='random',
-                xla=True,
-                rel_err=1e-2,
-                abs_err=1e-5):
-    # Trace and symbolically differentiate
-    traced_model = torch.jit.trace(model, *inputs)
-    fwd = traced_model._get_method('forward')
-    xm.forward_passes(fwd.graph)
-
-    inputs_params = inputs + list(model.parameters())
-    inputs_params_buffers = inputs + list(fwd.initial_ivalues())
-
-    gradient = torch._C._jit_differentiate(fwd.graph)
-    xm.forward_passes(gradient.f)
-    xm.backward_passes(gradient.df)
-
-    ##############################################################
-    # Run forward and backward graphs via jit interpreter
-    exec_f = torch_xla._XLAC.GraphExecutor(gradient.f, False)
-    exec_df = torch_xla._XLAC.GraphExecutor(gradient.df, False)
-
-    # forward function
-    raw_outputs = exec_f(*inputs_params_buffers)
-    raw_outputs = xu.as_list(raw_outputs)
-    intermediate_outputs = [
-        raw_output for raw_output in raw_outputs[gradient.f_real_outputs:]
-        if isinstance(raw_output, torch.Tensor)
-    ]
-    outputs = raw_outputs[:gradient.f_real_outputs]
-
-    if grad_outputs == 'random':
-      grad_outputs = _random_like(outputs) + _zeros_like(intermediate_outputs)
-
-    raw_grad_outputs = []
-    raw_grad_outputs += grad_outputs
-    raw_grad_outputs += [
-        inputs_params_buffers[i] for i in gradient.df_input_captured_inputs
-    ]
-    raw_grad_outputs += [
-        raw_outputs[i] for i in gradient.df_input_captured_outputs
-    ]
-
-    grad_inputs = exec_df(*raw_grad_outputs)
-    grad_inputs = xu.as_list(grad_inputs)
-
-    ##############################################################
-    # backward with XLA
-    if xla:
-      xla_model = torch_xla._XLAC.XlaModule(traced_model)
-      inputs_xla = [torch_xla._XLAC.XLATensor(input) for input in inputs]
-      xla_model((tuple(inputs_xla)))
-      grads_output_xla = [
-          torch_xla._XLAC.XLATensor(grad_output)
-          for grad_output in grad_outputs[:gradient.f_real_outputs]
-      ]
-      xla_model.backward((tuple(grads_output_xla)))
-      grad_inputs_xla = [input_xla.grad.to_tensor() for input_xla in inputs_xla]
-      grad_inputs_xla.extend(
-          [p.grad.to_tensor() for p in xla_model.parameters()[0]])
-    ##############################################################
-    # forward + backward with regular autograd / torch
-    outputs_gt = model(*inputs)
-    outputs_gt = xu.as_list(outputs_gt)
-    grad_inputs_gt = torch.autograd.grad(
-        outputs_gt, inputs_params, grad_outputs, only_inputs=True)
-    for out_jit, out_autograd in zip(outputs, outputs_gt):
-      self.assertEqualRel(
-          out_jit, out_autograd, rel_err=rel_err, abs_err=abs_err)
-
-    for grad_input_jit, grad_input_autograd in zip(grad_inputs, grad_inputs_gt):
-      self.assertEqualRel(
-          grad_input_jit, grad_input_autograd, rel_err=rel_err, abs_err=abs_err)
-
-    # TODO: test buffers as well (running_mean, etc.)
-    if xla:
-      for i, (grad_input_jit,
-              grad_input_xla) in enumerate(zip(grad_inputs, grad_inputs_xla)):
-        self.assertEqualRel(grad_input_jit, grad_input_xla, rel_err, abs_err)
-
-  def test_avgpool(self):
-
-    class AvgPoolGrad(nn.Module):
-
-      def __init__(self, stride, padding, count_include_pad):
-        super(AvgPoolGrad, self).__init__()
-        self.stride = stride
-        self.padding = padding
-        self.count_include_pad = count_include_pad
-
-      def forward(self, x):
-        return F.avg_pool2d(x, 2, self.stride, self.padding, False,
-                            self.count_include_pad)
-
-    for stride in [1, 2]:
-      for padding in [0, 1]:
-        for count_include_pad in [False, True]:
-          model = AvgPoolGrad(stride, padding, count_include_pad)
-          inputs = [_gen_tensor(4, 1, 28, 28, requires_grad=True)]
-          self.checkGrad(model, inputs, xla=True)
-
-  def test_adaptive_avgpool(self):
-
-    class AdaptiveAvgPoolGrad(nn.Module):
-
-      def __init__(self, output_size):
-        super(AdaptiveAvgPoolGrad, self).__init__()
-        self.output_size = output_size
-
-      def forward(self, x):
-        return F.adaptive_avg_pool2d(x, self.output_size)
-
-    model = AdaptiveAvgPoolGrad((2, 3))
-    for scale in [1, 2]:
-      inputs = [_gen_tensor(10, 3, 2 * scale, 3 * scale, requires_grad=True)]
-      self.checkGrad(model, inputs, xla=True)
-
-  def test_threshold(self):
-
-    class ThresholdPoolGrad(nn.Module):
-
-      def __init__(self):
-        super(ThresholdPoolGrad, self).__init__()
-        self.threshold = nn.Threshold(0.4, 20)
-
-      def forward(self, x):
-        return self.threshold(x)
-
-    model = ThresholdPoolGrad()
-    inputs = [_gen_tensor(4, 2, requires_grad=True)]
-    self.checkGrad(model, inputs, xla=True)
-
-  def test_maxpool(self):
-
-    class MaxPoolGrad(nn.Module):
-
-      def forward(self, x):
-        return F.max_pool2d(x, 2)
-
-    model = MaxPoolGrad()
-    inputs = [_gen_tensor(4, 1, 28, 28, requires_grad=True)]
-    self.checkGrad(model, inputs, xla=True)
-
-  def test_tanh(self):
-
-    class TanhGrad(nn.Module):
-
-      def forward(self, x):
-        return torch.tanh(x)
-
-    model = TanhGrad()
-    inputs = [_gen_tensor(4, 2, requires_grad=True)]
-    self.checkGrad(model, inputs, xla=True)
-
-  def test_sigmoid(self):
-
-    class SigmoidGrad(nn.Module):
-
-      def forward(self, x):
-        return torch.sigmoid(x)
-
-    model = SigmoidGrad()
-    inputs = [_gen_tensor(4, 2, requires_grad=True)]
-    self.checkGrad(model, inputs, xla=True, rel_err=1e-2, abs_err=1e-2)
-
-  @unittest.skip(
-      'differentiation of prim::ListUnpack is not supported, or it is missing '
-      'necessary type information')
-  def test_chunk(self):
-
-    class ChunkGrad(nn.Module):
-
-      def forward(self, x):
-        return x.chunk(2, 1)
-
-    model = ChunkGrad()
-    inputs = [_gen_tensor(4, 4, requires_grad=True)]
-    self.checkGrad(model, inputs, xla=True)
-
-  @unittest.skip('bool value of Tensor with more than one value is ambiguous')
-  def test_lstm_cell(self):
-
-    class LSTMCellGrad(nn.Module):
-
-      def __init__(self):
-        super(LSTMCellGrad, self).__init__()
-        self.i2h = nn.Linear(3, 8)
-        self.h2h = nn.Linear(2, 8)
-
-      def forward(self, x, hx, cx):
-        gates = self.i2h(x) + self.h2h(hx)
-
-        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
-        ingate = torch.sigmoid(ingate)
-        forgetgate = torch.sigmoid(forgetgate)
-        cellgate = torch.tanh(cellgate)
-        outgate = torch.sigmoid(outgate)
-
-        cy = (forgetgate * cx) + (ingate * cellgate)
-        hy = outgate * torch.tanh(cy)
-        return hy, cy
-
-    model = LSTMCellGrad()
-    inputs = [
-        _gen_tensor(4, 3, requires_grad=True),
-        _gen_tensor(4, 2, requires_grad=True),
-        _gen_tensor(4, 2, requires_grad=True)
-    ]
-    self.checkGrad(model, inputs, xla=True)
-
-  def test_conv2d(self):
-    if FLAGS.long_test:
-      config = [
-          [1, 7, 15, 32],  # ichans
-          [1, 4, 21, 32],  # ochans
-          [1, 2, 3, 5],  # size
-          [1, 2],  # stride
-          [0, 1],  # padding
-          [True, False],  # bias
-      ]
-    else:
-      config = [
-          [1, 5],  # ichans
-          [1, 4],  # ochans
-          [1, 3],  # size
-          [1],  # stride
-          [0],  # padding
-          [False],  # bias
-      ]
-    for ichans, ochans, size, stride, padding, bias in (
-        itertools.product(*config)):
-      # TODO: dilation, groups, transpose
-      model = nn.Conv2d(ichans, ochans, size, stride, padding, bias=bias)
-      inputs = [_gen_tensor(4, ichans, 28, 28, requires_grad=True)]
-      self.checkGrad(model, inputs, xla=True, abs_err=1e-3)
-
-  def test_batchnorm2d(self):
-    for chans in [1, 15, 32]:
-      for eps in [1e-5, 1e-3, 1e-2]:
-        # TODO: momentum, training, affine
-        model = nn.BatchNorm2d(chans, eps=eps)
-        inputs = [_gen_tensor(4, chans, 28, 28, requires_grad=True)]
-        self.checkGrad(model, inputs, xla=True)
-
-  def test_logsoftmax(self):
-    for dim in [0, 1]:  # todo test 3d as well
-      for batch in [1, 3, 4]:
-
-        class LSMGrad(nn.Module):
-
-          def forward(self, x):
-            return F.log_softmax(x, dim)
-
-        model = LSMGrad()
-        inputs = [_gen_tensor(batch, 9, requires_grad=True)]
-        self.checkGrad(model, inputs, xla=True)
-
-  def test_nll_loss(self):
-    input = _gen_tensor(3, 5, requires_grad=True)
-    target = torch.empty(3, dtype=torch.long).random_(5)
-    model = XlaNllLoss()
-    traced_model = torch.jit.trace(model, (input, target))
-    xla_model = torch_xla._XLAC.XlaModule(traced_model)
-    xla_inputs = [
-        torch_xla._XLAC.XLATensor(input),
-        torch_xla._XLAC.XLATensor(target)
-    ]
-    output_xla = xla_model((tuple(xla_inputs)))
-    xla_model.backward(*output_xla)
-    output = model(input, target)
-    output.backward()
-    self.assertEqual(input.grad.data, xla_inputs[0].grad.data.to_tensor())
-
-  def test_mnist(self):
-    model = XlaMNIST()
-    inputs = [_gen_tensor(4, 1, 28, 28, requires_grad=True)]
-    self.checkGrad(model, inputs, xla=True)
-
-  @unittest.skip('Disable until we figure out the precision issue')
-  def test_resnet(self):
-    model = torchvision.models.resnet18()
-    inputs = [_gen_tensor(4, 3, 224, 224, requires_grad=True)]
-    self.checkGrad(model, inputs, xla=False)
-
-
 class TestOptimizer(XlaTestCase):
 
   def test_inplace_add_mul(self):
