@@ -851,295 +851,6 @@ def test(self):
     self.assertEqualRel(x, xla_x.to_tensor(), rel_err=1e-3, abs_err=5)
 
 
-class TestGradients(XlaTestCase):
-
-  def checkGrad(self,
-                model,
-                inputs,
-                grad_outputs='random',
-                xla=True,
-                rel_err=1e-2,
-                abs_err=1e-5):
-    # Trace and symbolically differentiate
-    traced_model = torch.jit.trace(model, *inputs)
-    fwd = traced_model._get_method('forward')
-    xm.forward_passes(fwd.graph)
-
-    inputs_params = inputs + list(model.parameters())
-    inputs_params_buffers = inputs + list(fwd.initial_ivalues())
-
-    gradient = torch._C._jit_differentiate(fwd.graph)
-    xm.forward_passes(gradient.f)
-    xm.backward_passes(gradient.df)
-
-    ##############################################################
-    # Run forward and backward graphs via jit interpreter
-    exec_f = torch_xla._XLAC.GraphExecutor(gradient.f, False)
-    exec_df = torch_xla._XLAC.GraphExecutor(gradient.df, False)
-
-    # forward function
-    raw_outputs = exec_f(*inputs_params_buffers)
-    raw_outputs = xu.as_list(raw_outputs)
-    intermediate_outputs = [
-        raw_output for raw_output in raw_outputs[gradient.f_real_outputs:]
-        if isinstance(raw_output, torch.Tensor)
-    ]
-    outputs = raw_outputs[:gradient.f_real_outputs]
-
-    if grad_outputs == 'random':
-      grad_outputs = _random_like(outputs) + _zeros_like(intermediate_outputs)
-
-    raw_grad_outputs = []
-    raw_grad_outputs += grad_outputs
-    raw_grad_outputs += [
-        inputs_params_buffers[i] for i in gradient.df_input_captured_inputs
-    ]
-    raw_grad_outputs += [
-        raw_outputs[i] for i in gradient.df_input_captured_outputs
-    ]
-
-    grad_inputs = exec_df(*raw_grad_outputs)
-    grad_inputs = xu.as_list(grad_inputs)
-
-    ##############################################################
-    # backward with XLA
-    if xla:
-      xla_model = torch_xla._XLAC.XlaModule(traced_model)
-      inputs_xla = [torch_xla._XLAC.XLATensor(input) for input in inputs]
-      xla_model((tuple(inputs_xla)))
-      grads_output_xla = [
-          torch_xla._XLAC.XLATensor(grad_output)
-          for grad_output in grad_outputs[:gradient.f_real_outputs]
-      ]
-      xla_model.backward((tuple(grads_output_xla)))
-      grad_inputs_xla = [input_xla.grad.to_tensor() for input_xla in inputs_xla]
-      grad_inputs_xla.extend(
-          [p.grad.to_tensor() for p in xla_model.parameters()[0]])
-    ##############################################################
-    # forward + backward with regular autograd / torch
-    outputs_gt = model(*inputs)
-    outputs_gt = xu.as_list(outputs_gt)
-    grad_inputs_gt = torch.autograd.grad(
-        outputs_gt, inputs_params, grad_outputs, only_inputs=True)
-    for out_jit, out_autograd in zip(outputs, outputs_gt):
-      self.assertEqualRel(
-          out_jit, out_autograd, rel_err=rel_err, abs_err=abs_err)
-
-    for grad_input_jit, grad_input_autograd in zip(grad_inputs, grad_inputs_gt):
-      self.assertEqualRel(
-          grad_input_jit, grad_input_autograd, rel_err=rel_err, abs_err=abs_err)
-
-    # TODO: test buffers as well (running_mean, etc.)
-    if xla:
-      for i, (grad_input_jit,
-              grad_input_xla) in enumerate(zip(grad_inputs, grad_inputs_xla)):
-        self.assertEqualRel(grad_input_jit, grad_input_xla, rel_err, abs_err)
-
-  def test_avgpool(self):
-
-    class AvgPoolGrad(nn.Module):
-
-      def __init__(self, stride, padding, count_include_pad):
-        super(AvgPoolGrad, self).__init__()
-        self.stride = stride
-        self.padding = padding
-        self.count_include_pad = count_include_pad
-
-      def forward(self, x):
-        return F.avg_pool2d(x, 2, self.stride, self.padding, False,
-                            self.count_include_pad)
-
-    for stride in [1, 2]:
-      for padding in [0, 1]:
-        for count_include_pad in [False, True]:
-          model = AvgPoolGrad(stride, padding, count_include_pad)
-          inputs = [_gen_tensor(4, 1, 28, 28, requires_grad=True)]
-          self.checkGrad(model, inputs, xla=True)
-
-  def test_adaptive_avgpool(self):
-
-    class AdaptiveAvgPoolGrad(nn.Module):
-
-      def __init__(self, output_size):
-        super(AdaptiveAvgPoolGrad, self).__init__()
-        self.output_size = output_size
-
-      def forward(self, x):
-        return F.adaptive_avg_pool2d(x, self.output_size)
-
-    model = AdaptiveAvgPoolGrad((2, 3))
-    for scale in [1, 2]:
-      inputs = [_gen_tensor(10, 3, 2 * scale, 3 * scale, requires_grad=True)]
-      self.checkGrad(model, inputs, xla=True)
-
-  def test_threshold(self):
-
-    class ThresholdPoolGrad(nn.Module):
-
-      def __init__(self):
-        super(ThresholdPoolGrad, self).__init__()
-        self.threshold = nn.Threshold(0.4, 20)
-
-      def forward(self, x):
-        return self.threshold(x)
-
-    model = ThresholdPoolGrad()
-    inputs = [_gen_tensor(4, 2, requires_grad=True)]
-    self.checkGrad(model, inputs, xla=True)
-
-  def test_maxpool(self):
-
-    class MaxPoolGrad(nn.Module):
-
-      def forward(self, x):
-        return F.max_pool2d(x, 2)
-
-    model = MaxPoolGrad()
-    inputs = [_gen_tensor(4, 1, 28, 28, requires_grad=True)]
-    self.checkGrad(model, inputs, xla=True)
-
-  def test_tanh(self):
-
-    class TanhGrad(nn.Module):
-
-      def forward(self, x):
-        return torch.tanh(x)
-
-    model = TanhGrad()
-    inputs = [_gen_tensor(4, 2, requires_grad=True)]
-    self.checkGrad(model, inputs, xla=True)
-
-  def test_sigmoid(self):
-
-    class SigmoidGrad(nn.Module):
-
-      def forward(self, x):
-        return torch.sigmoid(x)
-
-    model = SigmoidGrad()
-    inputs = [_gen_tensor(4, 2, requires_grad=True)]
-    self.checkGrad(model, inputs, xla=True, rel_err=1e-2, abs_err=1e-2)
-
-  @unittest.skip(
-      'differentiation of prim::ListUnpack is not supported, or it is missing '
-      'necessary type information')
-  def test_chunk(self):
-
-    class ChunkGrad(nn.Module):
-
-      def forward(self, x):
-        return x.chunk(2, 1)
-
-    model = ChunkGrad()
-    inputs = [_gen_tensor(4, 4, requires_grad=True)]
-    self.checkGrad(model, inputs, xla=True)
-
-  @unittest.skip('bool value of Tensor with more than one value is ambiguous')
-  def test_lstm_cell(self):
-
-    class LSTMCellGrad(nn.Module):
-
-      def __init__(self):
-        super(LSTMCellGrad, self).__init__()
-        self.i2h = nn.Linear(3, 8)
-        self.h2h = nn.Linear(2, 8)
-
-      def forward(self, x, hx, cx):
-        gates = self.i2h(x) + self.h2h(hx)
-
-        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
-        ingate = torch.sigmoid(ingate)
-        forgetgate = torch.sigmoid(forgetgate)
-        cellgate = torch.tanh(cellgate)
-        outgate = torch.sigmoid(outgate)
-
-        cy = (forgetgate * cx) + (ingate * cellgate)
-        hy = outgate * torch.tanh(cy)
-        return hy, cy
-
-    model = LSTMCellGrad()
-    inputs = [
-        _gen_tensor(4, 3, requires_grad=True),
-        _gen_tensor(4, 2, requires_grad=True),
-        _gen_tensor(4, 2, requires_grad=True)
-    ]
-    self.checkGrad(model, inputs, xla=True)
-
-  def test_conv2d(self):
-    if FLAGS.long_test:
-      config = [
-          [1, 7, 15, 32],  # ichans
-          [1, 4, 21, 32],  # ochans
-          [1, 2, 3, 5],  # size
-          [1, 2],  # stride
-          [0, 1],  # padding
-          [True, False],  # bias
-      ]
-    else:
-      config = [
-          [1, 5],  # ichans
-          [1, 4],  # ochans
-          [1, 3],  # size
-          [1],  # stride
-          [0],  # padding
-          [False],  # bias
-      ]
-    for ichans, ochans, size, stride, padding, bias in (
-        itertools.product(*config)):
-      # TODO: dilation, groups, transpose
-      model = nn.Conv2d(ichans, ochans, size, stride, padding, bias=bias)
-      inputs = [_gen_tensor(4, ichans, 28, 28, requires_grad=True)]
-      self.checkGrad(model, inputs, xla=True, abs_err=1e-3)
-
-  def test_batchnorm2d(self):
-    for chans in [1, 15, 32]:
-      for eps in [1e-5, 1e-3, 1e-2]:
-        # TODO: momentum, training, affine
-        model = nn.BatchNorm2d(chans, eps=eps)
-        inputs = [_gen_tensor(4, chans, 28, 28, requires_grad=True)]
-        self.checkGrad(model, inputs, xla=True)
-
-  def test_logsoftmax(self):
-    for dim in [0, 1]:  # todo test 3d as well
-      for batch in [1, 3, 4]:
-
-        class LSMGrad(nn.Module):
-
-          def forward(self, x):
-            return F.log_softmax(x, dim)
-
-        model = LSMGrad()
-        inputs = [_gen_tensor(batch, 9, requires_grad=True)]
-        self.checkGrad(model, inputs, xla=True)
-
-  def test_nll_loss(self):
-    input = _gen_tensor(3, 5, requires_grad=True)
-    target = torch.empty(3, dtype=torch.long).random_(5)
-    model = XlaNllLoss()
-    traced_model = torch.jit.trace(model, (input, target))
-    xla_model = torch_xla._XLAC.XlaModule(traced_model)
-    xla_inputs = [
-        torch_xla._XLAC.XLATensor(input),
-        torch_xla._XLAC.XLATensor(target)
-    ]
-    output_xla = xla_model((tuple(xla_inputs)))
-    xla_model.backward(*output_xla)
-    output = model(input, target)
-    output.backward()
-    self.assertEqual(input.grad.data, xla_inputs[0].grad.data.to_tensor())
-
-  def test_mnist(self):
-    model = XlaMNIST()
-    inputs = [_gen_tensor(4, 1, 28, 28, requires_grad=True)]
-    self.checkGrad(model, inputs, xla=True)
-
-  @unittest.skip('Disable until we figure out the precision issue')
-  def test_resnet(self):
-    model = torchvision.models.resnet18()
-    inputs = [_gen_tensor(4, 3, 224, 224, requires_grad=True)]
-    self.checkGrad(model, inputs, xla=False)
-
-
 class TestOptimizer(XlaTestCase):
 
   def test_inplace_add_mul(self):