Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

fix bugs in QR

  • Loading branch information...
commit 346bc6c5229e6c503d35ce24185ce538b4e48d24 1 parent 3fb8d78
Patrick Perry authored
View
3  cbits/double.c
@@ -2,7 +2,8 @@
#include "LAPACK.h"
#include "config.h"
-static char *BLAS_TRANS_CODES[] = { "N", "T", "C" };
+
+static char *BLAS_TRANS_CODES[] = { "N", "T", "T" };
#define TRANS(x) BLAS_TRANS_CODES[(int) (x) - (int) BlasNoTrans]
static char *BLAS_UPLO_CODES[] = { "U", "L" };
View
2  lib/Data/Elem/LAPACK/C.hs
@@ -37,7 +37,7 @@ callWithWork call =
alloca $ \pQuery -> do
call pQuery (-1)
ldWork <- peek (castPtr pQuery) :: IO Double
- let lWork = ceiling ldWork
+ let lWork = max 1 $ ceiling ldWork
allocaArray lWork $ \pWork -> do
call pWork lWork
View
4 lib/Data/Elem/LAPACK/Zomplex.hs
@@ -23,11 +23,11 @@ foreign import ccall unsafe "LAPACK.h lapack_zgeqrf"
foreign import ccall unsafe "LAPACK.h lapack_zgelqf"
zgelqf :: Int -> Int -> Ptr Zomplex -> Int -> Ptr Zomplex -> Ptr Zomplex -> Int -> IO Int
-foreign import ccall unsafe "LAPACK.h lapack_dormqr"
+foreign import ccall unsafe "LAPACK.h lapack_zunmqr"
zunmqr :: CBLASSide -> CBLASTrans -> Int -> Int -> Int -> Ptr Zomplex -> Int -> Ptr Zomplex
-> Ptr Zomplex -> Int -> Ptr Zomplex -> Int -> IO Int
-foreign import ccall unsafe "LAPACK.h lapack_dormlq"
+foreign import ccall unsafe "LAPACK.h lapack_zunmlq"
zunmlq :: CBLASSide -> CBLASTrans -> Int -> Int -> Int -> Ptr Zomplex -> Int -> Ptr Zomplex
-> Ptr Zomplex -> Int -> Ptr Zomplex -> Int -> IO Int
View
8 lib/Data/Matrix/QR.hs
@@ -96,7 +96,9 @@ qrFactor a = runST $ getQRFactor a
-- | Get the QR factorization of a dense matrix.
getQRFactor :: (ReadMatrix a m, LAPACK e) => a (n,p) e -> m (QR (n,p) e)
-getQRFactor a = unsafePerformIOWithMatrix a qrFactorize
+getQRFactor a = unsafePerformIOWithMatrix a $ \a' -> do
+ a'' <- newCopyMatrix a'
+ qrFactorize a''
{-# INLINE getQRFactor #-}
-- | Compute the QR factorization of a matrix in-place and return the
@@ -135,9 +137,9 @@ qrFactorize a
in return $ QR q r
else let q = fromJust $ flip maybeHouseFromCols tau' $ lowerU a''
r = upper a''
- in return $ QR q r
+ in return $ QR q r
where (n,p) = shape a
(n',p') = if isHermMatrix a then (p,n) else (n,p)
- np = max 1 (min n p)
+ np = min n p
ldA = ldaMatrix a
{-# INLINE qrFactorize #-}
View
2  tests/Main.hs
@@ -9,7 +9,7 @@ import Orthogonal( tests_Orthogonal )
main :: IO ()
main = do
args <- getArgs
- let n = if null args then 100 else read (head args)
+ let n = if null args then 1000 else read (head args)
printf "\nRunnings tests for field `%s'\n" field
View
28 tests/Orthogonal.hs
@@ -5,7 +5,7 @@ module Orthogonal
import Driver
import Monadic
import Test.QuickCheck hiding ( vector )
-import Test.QuickCheck.BLAS( Pos(..) )
+import Test.QuickCheck.BLAS( Pos(..), Nat2(..) )
import qualified Test.QuickCheck.BLAS as Test
import Control.Monad
@@ -14,7 +14,11 @@ import Data.Elem.BLAS
import Data.Vector.Dense
import Data.Vector.Dense.ST
import Data.Matrix.Dense
+import Data.Matrix.Dense.ST
import Data.Matrix.House
+import Data.Matrix.QR
+
+import Debug.Trace
prop_setReflector_snd (Pos n) =
monadicST $ do
@@ -42,9 +46,31 @@ prop_reflector_matrix (Pos n) =
ra = colsMatrix (n,p) [ (k*beta) *> basisVector n 0 | k <- es ]
in herm r <**> a ~== ra
+prop_qrFactor (Nat2 mn) =
+ forAll (Test.matrix mn) $ \(a :: M) ->
+ let qr = qrFactor a
+ (q,r) = (qrQ qr, qrR qr)
+ i = identityMatrix (numCols r, numCols r)
+ a' = q <**> r <**> i
+ in a' ~== a
+
+prop_qrFactor_solveVector (Nat2 mn) =
+ forAll (Test.matrix mn) $ \(a :: M) ->
+ let a' = runSTMatrix (do
+ ma <- unsafeThawMatrix a
+ setConstant 1 (diagView ma 0)
+ return ma) in
+ forAll (Test.vector $ snd mn) $ \x ->
+ let y = a' <*> x
+ x' = qrFactor a' <\> y
+ y' = a' <*> x'
+ in y' ~== y
+
tests_Orthogonal =
[ ("snd . setReflector", mytest prop_setReflector_snd)
, ("fst . reflector", mytest prop_reflector_fst)
, ("reflector <*>", mytest prop_reflector_vector)
, ("reflector <**>", mytest prop_reflector_matrix)
+ , ("qrFactor", mytest prop_qrFactor)
+ , ("qrFactor solveVector", mytest prop_qrFactor_solveVector)
]
Please sign in to comment.
Something went wrong with that request. Please try again.