Skip to content

Commit

Permalink
Fixed bug that Saturating operations on Index 1 were unsafe (clash-…
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanG077 committed Jul 26, 2019
1 parent d859a93 commit 38add0c
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 21 deletions.
43 changes: 24 additions & 19 deletions clash-prelude/src/Clash/Sized/Internal/Index.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ License : BSD2 (see the file LICENSE)
Maintainer : Christiaan Baaij <christiaan.baaij@gmail.com>
-}

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
Expand Down Expand Up @@ -101,7 +102,7 @@ import Clash.Class.Resize (Resize (..))
import Clash.Prelude.BitIndex (replaceBit)
import {-# SOURCE #-} Clash.Sized.Internal.BitVector (BitVector (BV), high, low, undefError)
import qualified Clash.Sized.Internal.BitVector as BV
import Clash.Promoted.Nat (SNat, snatToNum, leToPlusKN)
import Clash.Promoted.Nat (SNat(..), snatToNum, leToPlusKN)
import Clash.XException
(ShowX (..), Undefined (..), errorX, showsPrecXWith, rwhnfX)

Expand Down Expand Up @@ -278,23 +279,25 @@ times# :: Index m -> Index n -> Index (((m - 1) * (n - 1)) + 1)
times# (I a) (I b) = I (a * b)

instance (KnownNat n, 1 <= n) => SaturatingNum (Index n) where
satAdd SatWrap a b =
leToPlusKN @1 @n $
case plus# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> resize# (z - m)
z -> resize# z
satAdd SatWrap !a !b =
case snatToNum @Int (SNat @n) of
1 -> fromInteger# 0
_ -> leToPlusKN @1 @n $
case plus# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> resize# (z - m)
z -> resize# z
satAdd SatZero a b =
leToPlusKN @1 @n $
case plus# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> fromInteger# 0
z | let m = fromInteger# (natVal (Proxy @ (n - 1)))
, z > m -> fromInteger# 0
z -> resize# z
satAdd _ a b =
leToPlusKN @1 @n $
case plus# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> maxBound#
z | let m = fromInteger# (natVal (Proxy @ (n - 1)))
, z > m -> maxBound#
z -> resize# z

satSub SatWrap a b =
Expand All @@ -308,21 +311,23 @@ instance (KnownNat n, 1 <= n) => SaturatingNum (Index n) where
else a -# b

satMul SatWrap a b =
leToPlusKN @1 @n $
case times# a b of
z -> let m = fromInteger# (natVal (Proxy @ n))
in resize# (z `mod` m)
case snatToNum @Int (SNat @n) of
1 -> fromInteger# 0
_ -> leToPlusKN @1 @n $
case times# a b of
z -> let m = fromInteger# (natVal (Proxy @ n))
in resize# (z `mod` m)
satMul SatZero a b =
leToPlusKN @1 @n $
case times# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> fromInteger# 0
z | let m = fromInteger# (natVal (Proxy @ (n - 1)))
, z > m -> fromInteger# 0
z -> resize# z
satMul _ a b =
leToPlusKN @1 @n $
case times# a b of
z | let m = fromInteger# (natVal (Proxy @ n))
, z >= m -> maxBound#
z | let m = fromInteger# (natVal (Proxy @ (n - 1)))
, z > m -> maxBound#
z -> resize# z

instance KnownNat n => Real (Index n) where
Expand Down
2 changes: 1 addition & 1 deletion tests/shouldwork/Numbers/ConstantFoldingUtil.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ lit = fromIntegral
---------------
-- output tests
---------------
-- Any number found in the HDL between [22000,22999] is considered unfolded
-- Any number found in the HDL that is between [22000,22999] is considered unfolded
-- and is reported as an error.
mainVHDL = checkForUnfolded vhdlNr
mainVerilog = checkForUnfolded verilogNr
Expand Down
19 changes: 18 additions & 1 deletion tests/shouldwork/Numbers/NumConstantFolding.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

-- | Test constant folding of most primitives on "Num" types
--
-- Any number found in the HDL between [22000,22999] is considered unfolded
-- Any number found in the HDL that is between [22000,22999] is considered unfolded
-- and is reported as an error.

module NumConstantFolding (topEntity, module ConstantFoldingUtil) where
Expand Down Expand Up @@ -113,6 +113,22 @@ cExtendingNum = (r1,r2,r3)
r2 = sub @a @b (lit 22203) (lit 22202)
r3 = mul @a @b (lit 22204) (lit 2)

cIndex1 :: _
cIndex1 = (r1a,r1b,r1c,r1d, r2a,r2b,r2c,r2d, r3a,r3b,r3c,r3d)
where
r1a = satAdd @(Index 1) SatWrap (lit 0) (lit 0)
r2a = satSub @(Index 1) SatWrap (lit 0) (lit 0)
r3a = satMul @(Index 1) SatWrap (lit 0) (lit 0)
r1b = satAdd @(Index 1) SatBound (lit 0) (lit 0)
r2b = satSub @(Index 1) SatBound (lit 0) (lit 0)
r3b = satMul @(Index 1) SatBound (lit 0) (lit 0)
r1c = satAdd @(Index 1) SatZero (lit 0) (lit 0)
r2c = satSub @(Index 1) SatZero (lit 0) (lit 0)
r3c = satMul @(Index 1) SatZero (lit 0) (lit 0)
r1d = satAdd @(Index 1) SatSymmetric (lit 0) (lit 0)
r2d = satSub @(Index 1) SatSymmetric (lit 0) (lit 0)
r3d = satMul @(Index 1) SatSymmetric (lit 0) (lit 0)

cSaturatingNum :: forall n. (Num n, SaturatingNum n) => _
cSaturatingNum = (r1a,r1b,r1c,r1d, r2a,r2b,r2c,r2d, r3a,r3b,r3c,r3d)
where
Expand Down Expand Up @@ -213,6 +229,7 @@ tIndex
, cIntegral @(Index 50000)
, cBits @(Index 65536)
-- , cFiniteBits @(Index 50000) -- broken
, cIndex1 -- ensure special case for index 1 is verified
, csClashSpecific @(Index 50000)
, cResize @(Index 50000)
)
Expand Down

0 comments on commit 38add0c

Please sign in to comment.