Skip to content

Commit

Permalink
Fixed bug that Saturating operations on Index 1 were unsafe (clash-…
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanG077 committed Jul 26, 2019
1 parent d859a93 commit 3ba4c75
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
43 changes: 24 additions & 19 deletions clash-prelude/src/Clash/Sized/Internal/Index.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ License : BSD2 (see the file LICENSE)
Maintainer : Christiaan Baaij <christiaan.baaij@gmail.com>
-}

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
Expand Down Expand Up @@ -101,7 +102,7 @@ import Clash.Class.Resize (Resize (..))
import Clash.Prelude.BitIndex (replaceBit)
import {-# SOURCE #-} Clash.Sized.Internal.BitVector (BitVector (BV), high, low, undefError)
import qualified Clash.Sized.Internal.BitVector as BV
import Clash.Promoted.Nat (SNat, snatToNum, leToPlusKN)
import Clash.Promoted.Nat (SNat(..), snatToNum, leToPlusKN)
import Clash.XException
(ShowX (..), Undefined (..), errorX, showsPrecXWith, rwhnfX)

Expand Down Expand Up @@ -278,23 +279,25 @@ times# :: Index m -> Index n -> Index (((m - 1) * (n - 1)) + 1)
times# (I a) (I b) = I (a * b)

instance (KnownNat n, 1 <= n) => SaturatingNum (Index n) where
satAdd SatWrap a b =
leToPlusKN @1 @n $
case plus# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> resize# (z - m)
z -> resize# z
satAdd SatWrap !a !b =
case snatToNum @Int (SNat @n) of
1 -> fromInteger# 0
_ -> leToPlusKN @1 @n $
case plus# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> resize# (z - m)
z -> resize# z
satAdd SatZero a b =
leToPlusKN @1 @n $
case plus# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> fromInteger# 0
z | let m = fromInteger# (natVal (Proxy @ (n - 1)))
, z > m -> fromInteger# 0
z -> resize# z
satAdd _ a b =
leToPlusKN @1 @n $
case plus# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> maxBound#
z | let m = fromInteger# (natVal (Proxy @ (n - 1)))
, z > m -> maxBound#
z -> resize# z

satSub SatWrap a b =
Expand All @@ -308,21 +311,23 @@ instance (KnownNat n, 1 <= n) => SaturatingNum (Index n) where
else a -# b

satMul SatWrap a b =
leToPlusKN @1 @n $
case times# a b of
z -> let m = fromInteger# (natVal (Proxy @ n))
in resize# (z `mod` m)
case snatToNum @Int (SNat @n) of
1 -> fromInteger# 0
_ -> leToPlusKN @1 @n $
case times# a b of
z -> let m = fromInteger# (natVal (Proxy @ n))
in resize# (z `mod` m)
satMul SatZero a b =
leToPlusKN @1 @n $
case times# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> fromInteger# 0
z | let m = fromInteger# (natVal (Proxy @ (n - 1)))
, z > m -> fromInteger# 0
z -> resize# z
satMul _ a b =
leToPlusKN @1 @n $
case times# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> maxBound#
z | let m = fromInteger# (natVal (Proxy @ (n - 1)))
, z > m -> maxBound#
z -> resize# z

instance KnownNat n => Real (Index n) where
Expand Down
17 changes: 17 additions & 0 deletions tests/shouldwork/Numbers/NumConstantFolding.hs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,22 @@ cExtendingNum = (r1,r2,r3)
r2 = sub @a @b (lit 22203) (lit 22202)
r3 = mul @a @b (lit 22204) (lit 2)

cIndex1 :: _
cIndex1 = (r1a,r1b,r1c,r1d, r2a,r2b,r2c,r2d, r3a,r3b,r3c,r3d)
where
r1a = satAdd @(Index 1) SatWrap (lit 0) (lit 0)
r2a = satSub @(Index 1) SatWrap (lit 0) (lit 0)
r3a = satMul @(Index 1) SatWrap (lit 0) (lit 0)
r1b = satAdd @(Index 1) SatBound (lit 0) (lit 0)
r2b = satSub @(Index 1) SatBound (lit 0) (lit 0)
r3b = satMul @(Index 1) SatBound (lit 0) (lit 0)
r1c = satAdd @(Index 1) SatZero (lit 0) (lit 0)
r2c = satSub @(Index 1) SatZero (lit 0) (lit 0)
r3c = satMul @(Index 1) SatZero (lit 0) (lit 0)
r1d = satAdd @(Index 1) SatSymmetric (lit 0) (lit 0)
r2d = satSub @(Index 1) SatSymmetric (lit 0) (lit 0)
r3d = satMul @(Index 1) SatSymmetric (lit 0) (lit 0)

cSaturatingNum :: forall n. (Num n, SaturatingNum n) => _
cSaturatingNum = (r1a,r1b,r1c,r1d, r2a,r2b,r2c,r2d, r3a,r3b,r3c,r3d)
where
Expand Down Expand Up @@ -213,6 +229,7 @@ tIndex
, cIntegral @(Index 50000)
, cBits @(Index 65536)
-- , cFiniteBits @(Index 50000) -- broken
, cIndex1 -- ensure special case for index 1 is verified
, csClashSpecific @(Index 50000)
, cResize @(Index 50000)
)
Expand Down

0 comments on commit 3ba4c75

Please sign in to comment.