qr and ormqr tests and bugfix #119

Merged: 1 commit, Oct 14, 2016
test/test_cuda.py: 43 changes (32 additions & 11 deletions)
@@ -59,6 +59,18 @@ def small_3d_positive(t):
def small_3d_unique(t):
return t(S, S, S).copy_(torch.range(1, S*S*S))

def small_1d_lapack(t):
return torch.range(1, 3).view(3)

def small_2d_lapack(t):
return torch.range(1, 9).view(3, 3)

def small_2d_lapack_skinny(t):
return torch.range(1, 12).view(3, 4)

def small_2d_lapack_fat(t):
return torch.range(1, 12).view(4, 3)

def new_t(*sizes):
def tmp(t):
return t(*sizes).copy_(torch.randn(*sizes))
@@ -92,13 +104,9 @@ def tmp(t):
('addmv', medium_1d, lambda t: [medium_2d(t), medium_1d(t)], ),
('addmv', medium_1d, lambda t: [number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'scalar' ),
('addmv', medium_1d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'two_scalars' ),
-('addmv', medium_1d, lambda t: [medium_2d(t), medium_1d(t)], ),
-('addmv', medium_1d, lambda t: [number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'scalar' ),
-('addmv', medium_1d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'two_scalars' ),
('addr', medium_2d, lambda t: [medium_1d(t), medium_1d(t)], ),
('addr', medium_2d, lambda t: [number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'scalar' ),
('addr', medium_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'two_scalars' ),
-('addr', medium_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'two_scalars' ),
('atan2', medium_2d, lambda t: [medium_2d(t)], ),
('chunk', medium_2d, lambda t: [4], ),
('chunk', medium_2d, lambda t: [4, 1], 'dim' ),
@@ -195,6 +203,11 @@ def tmp(t):
('rsqrt', lambda t: small_3d(t) + 1, lambda t: [], ),
('sinh', lambda t: small_3d(t).clamp(-1, 1), lambda t: [], ),
('tan', lambda t: small_3d(t).clamp(-1, 1), lambda t: [], ),
# lapack tests
('qr', small_2d_lapack, lambda t: [], 'square' ),
('qr', small_2d_lapack_skinny, lambda t: [], 'skinny' ),
('qr', small_2d_lapack_fat, lambda t: [], 'fat' ),

]

# TODO: random functions, cat, gather, scatter, index*, masked*, resize, resizeAs, storage_offset, storage, stride, unfold
@@ -251,6 +264,11 @@ def tmp(self):
if 'unimplemented data type' in reason:
raise unittest.SkipTest('unimplemented data type')
raise
except AttributeError as e:
reason = e.args[0]
if 'object has no attribute' in reason:
raise unittest.SkipTest('unimplemented data type')
raise
# If one changes, another should change as well
self.assertEqual(cpu_tensor, gpu_tensor, precision)
self.assertEqual(cpu_args, gpu_args, precision)
@@ -482,19 +500,22 @@ def test_multigpu_serialization_remap_dict(self):
precision = custom_precision.get(name, TestCuda.precision)
for inplace in (True, False):
if inplace:
-name = name + '_'
-if not hasattr(tensor, name):
+name_inner = name + '_'
+else:
+name_inner = name
+if not hasattr(tensor, name_inner):
+print("Ignoring {}, because it's not implemented by torch.{}".format(name_inner, tensor.__class__.__name__))
continue
-if not hasattr(gpu_tensor, name):
-print("Ignoring {}, because it's not implemented by torch.cuda.{}".format(name, gpu_tensor.__class__.__name__))
+if not hasattr(gpu_tensor, name_inner):
+print("Ignoring {}, because it's not implemented by torch.cuda.{}".format(name_inner, gpu_tensor.__class__.__name__))
continue

-test_name = 'test_' + t.__name__ + '_' + name
+test_name = 'test_' + t.__name__ + '_' + name_inner
if desc:
test_name += '_' + desc

-assert not hasattr(TestCase, test_name)
-setattr(TestCuda, test_name, compare_cpu_gpu(constr, arg_constr, name, t, precision))
+assert not hasattr(TestCuda, test_name), "Duplicated test name: " + test_name
+setattr(TestCuda, test_name, compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))

if __name__ == '__main__':
unittest.main()
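For context on the harness this diff adjusts, the sketch below is not part of the PR: the types list, the trimmed tests list, and the body of compare_cpu_gpu are simplified stand-ins, and it targets the 2016-era torch API used throughout this page. It shows how each descriptor tuple is expanded into per-type test_* methods on TestCuda, and why the patch switches from reassigning name to a separate name_inner: with the old code the '_' suffix added for the in-place pass leaked into the following non-in-place pass of the loop.

```python
import unittest
import torch

# Simplified stand-ins for the real harness pieces (assumptions, not the PR's code).
types = [torch.FloatTensor, torch.DoubleTensor]

def small_2d_lapack(t):
    return torch.range(1, 9).view(3, 3)

tests = [
    # (method name, tensor constructor, argument constructor, description)
    ('qr', small_2d_lapack, lambda t: [], 'square'),
]

def compare_cpu_gpu(constr, arg_constr, name, t):
    def tmp(self):
        # The real helper builds CPU and CUDA copies and compares results;
        # this sketch only checks that the CPU call runs.
        tensor = constr(t)
        getattr(tensor, name)(*arg_constr(t))
    return tmp

class TestCuda(unittest.TestCase):
    pass

for name, constr, arg_constr, desc in tests:
    for t in types:
        for inplace in (True, False):
            # The fix: keep the base name intact and derive name_inner per
            # iteration, so the '_' added when inplace is True does not
            # leak into the inplace=False pass.
            name_inner = name + '_' if inplace else name
            if not hasattr(constr(t), name_inner):
                continue  # skip methods the tensor type does not implement
            test_name = 'test_' + t.__name__ + '_' + name_inner + '_' + desc
            assert not hasattr(TestCuda, test_name), "Duplicated test name: " + test_name
            setattr(TestCuda, test_name, compare_cpu_gpu(constr, arg_constr, name_inner, t))
```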
test/test_torch.py: 117 changes (117 additions & 0 deletions)
@@ -1113,6 +1113,123 @@ def test_gesv(self):
torch.gesv(tb, ta, b, a)[0]
self.assertEqual(res1, tb)

@skipIfNoLapack
def test_qr(self):

# Since the QR decomposition is unique only up to the signs of the rows of
# R, we must ensure these are positive before doing the comparison.
def canonicalize(q, r):
d = r.diag().sign().diag()
return torch.mm(q, d), torch.mm(d, r)

def canon_and_check(q, r, expected_q, expected_r):
q_canon, r_canon = canonicalize(q, r)
expected_q_canon, expected_r_canon = canonicalize(expected_q, expected_r)
self.assertEqual(q_canon, expected_q_canon)
self.assertEqual(r_canon, expected_r_canon)

def check_qr(a, expected_q, expected_r):
# standard invocation
q, r = torch.qr(a)
canon_and_check(q, r, expected_q, expected_r)

# in-place
q, r = torch.Tensor(), torch.Tensor()
torch.qr(q, r, a)
canon_and_check(q, r, expected_q, expected_r)

# manually calculate qr using geqrf and orgqr
m = a.size(0)
n = a.size(1)
k = min(m, n)
result, tau = torch.geqrf(a)
self.assertEqual(result.size(0), m)
self.assertEqual(result.size(1), n)
self.assertEqual(tau.size(0), k)
r = torch.triu(result.narrow(0, 0, k))
q, _ = torch.orgqr(result, tau)
q, r = q.narrow(1, 0, k), r
canon_and_check(q, r, expected_q, expected_r)

# check square case
a = torch.Tensor(((1, 2, 3), (4, 5, 6), (7, 8, 10)))

expected_q = torch.Tensor((
(-1.230914909793328e-01, 9.045340337332914e-01, 4.082482904638621e-01),
(-4.923659639173310e-01, 3.015113445777629e-01, -8.164965809277264e-01),
(-8.616404368553292e-01, -3.015113445777631e-01, 4.082482904638634e-01)))
expected_r = torch.Tensor((
(-8.124038404635959e+00, -9.601136296387955e+00, -1.193987e+01),
( 0.000000000000000e+00, 9.045340337332926e-01, 1.507557e+00),
( 0.000000000000000e+00, 0.000000000000000e+00, 4.082483e-01)))

check_qr(a, expected_q, expected_r)

# check rectangular thin
a = torch.Tensor((
( 1, 2, 3),
( 4, 5, 6),
( 7, 8, 9),
(10, 11, 13),
))
expected_q = torch.Tensor((
(-0.0776150525706334, -0.833052161400748 , 0.3651483716701106),
(-0.3104602102825332, -0.4512365874254053, -0.1825741858350556),
(-0.5433053679944331, -0.0694210134500621, -0.7302967433402217),
(-0.7761505257063329, 0.3123945605252804, 0.5477225575051663)
))
expected_r = torch.Tensor((
(-12.8840987267251261, -14.5916298832790581, -17.0753115655393231),
( 0, -1.0413152017509357, -1.770235842976589 ),
( 0, 0, 0.5477225575051664)
))

check_qr(a, expected_q, expected_r)

# check rectangular fat
a = torch.Tensor((
(1, 2, 3, 4),
(5, 6, 7, 8),
(9, 10, 11, 13)
))
expected_q = torch.Tensor((
(-0.0966736489045663, 0.907737593658436 , 0.4082482904638653),
(-0.4833682445228317, 0.3157348151855452, -0.8164965809277254),
(-0.870062840141097 , -0.2762679632873518, 0.4082482904638621)
))
expected_r = torch.Tensor((
( -1.0344080432788603e+01, -1.1794185166357092e+01,
-1.3244289899925587e+01, -1.5564457473635180e+01),
( 0.0000000000000000e+00, 9.4720444555662542e-01,
1.8944088911132546e+00, 2.5653453733825331e+00),
( 0.0000000000000000e+00, 0.0000000000000000e+00,
1.5543122344752192e-15, 4.0824829046386757e-01)
))
check_qr(a, expected_q, expected_r)

@skipIfNoLapack
def test_ormqr(self):
mat1 = torch.randn(10, 10)
mat2 = torch.randn(10, 10)
q, r = torch.qr(mat1)
m, tau = torch.geqrf(mat1)

res1 = torch.mm(q, mat2)
res2, _ = torch.ormqr(m, tau, mat2)
self.assertEqual(res1, res2)

res1 = torch.mm(mat2, q)
res2, _ = torch.ormqr(m, tau, mat2, False)
self.assertEqual(res1, res2)

res1 = torch.mm(q.t(), mat2)
res2, _ = torch.ormqr(m, tau, mat2, True, True)
self.assertEqual(res1, res2)

res1 = torch.mm(mat2, q.t())
res2, _ = torch.ormqr(m, tau, mat2, False, True)
self.assertEqual(res1, res2)

@skipIfNoLapack
def test_trtrs(self):
a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
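A note on the canonicalize step in test_qr above: a QR factorization is unique only up to flipping the sign of a column of Q together with the matching row of R, so comparing raw torch.qr output against hard-coded expected factors could fail even when both are correct. The standalone sketch below is not part of the diff; it reuses the same 2016-era torch calls as the tests to illustrate the ambiguity and how forcing the diagonal of R to be positive removes it.

```python
import torch

a = torch.Tensor(((1, 2, 3), (4, 5, 6), (7, 8, 10)))
q, r = torch.qr(a)

# Flip the sign of column i of Q and row i of R: the product Q*R is
# unchanged, so (q2, r2) is an equally valid QR factorization of a.
d = torch.diag(torch.Tensor((1, -1, 1)))
q2, r2 = torch.mm(q, d), torch.mm(d, r)
print((torch.mm(q, r) - torch.mm(q2, r2)).abs().max())   # ~0

# Same canonicalization as in test_qr: make the diagonal of R positive.
def canonicalize(q, r):
    d = r.diag().sign().diag()
    return torch.mm(q, d), torch.mm(d, r)

cq, cr = canonicalize(q, r)
cq2, cr2 = canonicalize(q2, r2)
print((cq - cq2).abs().max(), (cr - cr2).abs().max())    # both ~0
```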
torch/csrc/Module.cpp: 2 changes (1 addition & 1 deletion)
@@ -715,7 +715,7 @@ static PyMethodDef TorchMethods[] = {
{"potrs", (PyCFunction)THPModule_potrs, METH_VARARGS | METH_KEYWORDS, NULL},
{"potri", (PyCFunction)THPModule_potri, METH_VARARGS | METH_KEYWORDS, NULL},
{"pstrf", (PyCFunction)THPModule_pstrf, METH_VARARGS | METH_KEYWORDS, NULL},
-{"qe", (PyCFunction)THPModule_qr, METH_VARARGS | METH_KEYWORDS, NULL},
+{"qr", (PyCFunction)THPModule_qr, METH_VARARGS | METH_KEYWORDS, NULL},
{"geqrf", (PyCFunction)THPModule_geqrf, METH_VARARGS | METH_KEYWORDS, NULL},
{"orgqr", (PyCFunction)THPModule_orgqr, METH_VARARGS | METH_KEYWORDS, NULL},
{"ormqr", (PyCFunction)THPModule_ormqr, METH_VARARGS | METH_KEYWORDS, NULL},
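The one-line Module.cpp change is the bugfix named in the PR title: the PyMethodDef entry pointing at THPModule_qr was registered under the name "qe", so the module-level binding was exposed as torch.qe rather than torch.qr, and the new test_qr above could not call it. A minimal usage sketch of the corrected binding, illustrative only and using the same 2016-era API as the tests:

```python
import torch

a = torch.Tensor(((1, 2, 3), (4, 5, 6), (7, 8, 10)))

# Before the rename the table exposed this function as torch.qe, so a
# torch.qr call would not have resolved; with the fix the documented
# name works and matches the tests added in this PR.
q, r = torch.qr(a)
print(torch.mm(q, r))  # reconstructs a, up to floating-point error
```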