Skip to content

Commit

Permalink
Fixed bug that Saturating operations on Index 1 were unsafe (clash-…
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanG077 committed Jul 28, 2019
1 parent b81b5a5 commit 9daf2b8
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 20 deletions.
45 changes: 25 additions & 20 deletions clash-prelude/src/Clash/Sized/Internal/Index.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ License : BSD2 (see the file LICENSE)
Maintainer : Christiaan Baaij <christiaan.baaij@gmail.com>
-}

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
Expand Down Expand Up @@ -101,7 +102,7 @@ import Clash.Class.Resize (Resize (..))
import Clash.Prelude.BitIndex (replaceBit)
import {-# SOURCE #-} Clash.Sized.Internal.BitVector (BitVector (BV), high, low, undefError)
import qualified Clash.Sized.Internal.BitVector as BV
import Clash.Promoted.Nat (SNat, snatToNum, leToPlusKN)
import Clash.Promoted.Nat (SNat(..), snatToNum, leToPlusKN)
import Clash.XException
(ShowX (..), Undefined (..), errorX, showsPrecXWith, rwhnfX)

Expand Down Expand Up @@ -278,23 +279,25 @@ times# :: Index m -> Index n -> Index (((m - 1) * (n - 1)) + 1)
times# (I a) (I b) = I (a * b)

instance (KnownNat n, 1 <= n) => SaturatingNum (Index n) where
satAdd SatWrap a b =
leToPlusKN @1 @n $
case plus# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> resize# (z - m)
z -> resize# z
satAdd SatWrap !a !b =
case snatToNum @Int (SNat @n) of
1 -> fromInteger# 0
_ -> leToPlusKN @1 @n $
case plus# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> resize# (z - m)
z -> resize# z
satAdd SatZero a b =
leToPlusKN @1 @n $
case plus# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> fromInteger# 0
z | let m = fromInteger# (natVal (Proxy @ (n - 1)))
, z > m -> fromInteger# 0
z -> resize# z
satAdd _ a b =
leToPlusKN @1 @n $
case plus# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> maxBound#
z | let m = fromInteger# (natVal (Proxy @ (n - 1)))
, z > m -> maxBound#
z -> resize# z

satSub SatWrap a b =
Expand All @@ -307,22 +310,24 @@ instance (KnownNat n, 1 <= n) => SaturatingNum (Index n) where
then fromInteger# 0
else a -# b

satMul SatWrap a b =
leToPlusKN @1 @n $
case times# a b of
z -> let m = fromInteger# (natVal (Proxy @ n))
in resize# (z `mod` m)
satMul SatWrap !a !b =
case snatToNum @Int (SNat @n) of
1 -> fromInteger# 0
_ -> leToPlusKN @1 @n $
case times# a b of
z -> let m = fromInteger# (natVal (Proxy @ n))
in resize# (z `mod` m)
satMul SatZero a b =
leToPlusKN @1 @n $
case times# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> fromInteger# 0
z | let m = fromInteger# (natVal (Proxy @ (n - 1)))
, z > m -> fromInteger# 0
z -> resize# z
satMul _ a b =
leToPlusKN @1 @n $
case times# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> maxBound#
z | let m = fromInteger# (natVal (Proxy @ (n - 1)))
, z > m -> maxBound#
z -> resize# z

instance KnownNat n => Real (Index n) where
Expand Down
19 changes: 19 additions & 0 deletions tests/shouldwork/Numbers/NumConstantFolding.hs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,24 @@ cExtendingNum = (r1,r2,r3)
r2 = sub @a @b (lit 22203) (lit 22202)
r3 = mul @a @b (lit 22204) (lit 2)

cIndex1 :: _
cIndex1 = (r1a,r1b,r1c,r1d, r2a,r2b,r2c,r2d, r3a,r3b,r3c,r3d)
where
f :: Int -> Index 1 -> Int
f a = \x -> case x of {0 -> 0; 1 -> a}
r1a = f 22270 (satAdd SatWrap (lit 0) (lit 0))
r2a = f 22271 (satSub SatWrap (lit 0) (lit 0))
r3a = f 22272 (satMul SatWrap (lit 0) (lit 0))
r1b = f 22273 (satAdd SatBound (lit 0) (lit 0))
r2b = f 22274 (satSub SatBound (lit 0) (lit 0))
r3b = f 22275 (satMul SatBound (lit 0) (lit 0))
r1c = f 22276 (satAdd SatZero (lit 0) (lit 0))
r2c = f 22277 (satSub SatZero (lit 0) (lit 0))
r3c = f 22278 (satMul SatZero (lit 0) (lit 0))
r1d = f 22279 (satAdd SatSymmetric (lit 0) (lit 0))
r2d = f 22280 (satSub SatSymmetric (lit 0) (lit 0))
r3d = f 22281 (satMul SatSymmetric (lit 0) (lit 0))

cSaturatingNum :: forall n. (Num n, SaturatingNum n) => _
cSaturatingNum = (r1a,r1b,r1c,r1d, r2a,r2b,r2c,r2d, r3a,r3b,r3c,r3d)
where
Expand Down Expand Up @@ -213,6 +231,7 @@ tIndex
, cIntegral @(Index 50000)
, cBits @(Index 65536)
-- , cFiniteBits @(Index 50000) -- broken
, cIndex1 -- ensure special case for index 1 is verified
, csClashSpecific @(Index 50000)
, cResize @(Index 50000)
)
Expand Down

0 comments on commit 9daf2b8

Please sign in to comment.