diff --git a/clash-prelude/src/Clash/Sized/Internal/Index.hs b/clash-prelude/src/Clash/Sized/Internal/Index.hs index 2a8a5c153d..3d01ba04d1 100644 --- a/clash-prelude/src/Clash/Sized/Internal/Index.hs +++ b/clash-prelude/src/Clash/Sized/Internal/Index.hs @@ -5,6 +5,7 @@ License : BSD2 (see the file LICENSE) Maintainer : Christiaan Baaij -} +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} @@ -101,7 +102,7 @@ import Clash.Class.Resize (Resize (..)) import Clash.Prelude.BitIndex (replaceBit) import {-# SOURCE #-} Clash.Sized.Internal.BitVector (BitVector (BV), high, low, undefError) import qualified Clash.Sized.Internal.BitVector as BV -import Clash.Promoted.Nat (SNat, snatToNum, leToPlusKN) +import Clash.Promoted.Nat (SNat(..), snatToNum, leToPlusKN) import Clash.XException (ShowX (..), Undefined (..), errorX, showsPrecXWith, rwhnfX) @@ -278,23 +279,25 @@ times# :: Index m -> Index n -> Index (((m - 1) * (n - 1)) + 1) times# (I a) (I b) = I (a * b) instance (KnownNat n, 1 <= n) => SaturatingNum (Index n) where - satAdd SatWrap a b = - leToPlusKN @1 @n $ - case plus# a b of - z | let m = fromInteger# (natVal (Proxy @ n)) - , z >= m -> resize# (z - m) - z -> resize# z + satAdd SatWrap !a !b = + case snatToNum @Int (SNat @n) of + 1 -> fromInteger# 0 + _ -> leToPlusKN @1 @n $ + case plus# a b of + z | let m = fromInteger# (natVal (Proxy @ n)) + , z >= m -> resize# (z - m) + z -> resize# z satAdd SatZero a b = leToPlusKN @1 @n $ case plus# a b of - z | let m = fromInteger# (natVal (Proxy @ n)) - , z >= m -> fromInteger# 0 + z | let m = fromInteger# (natVal (Proxy @ (n - 1))) + , z > m -> fromInteger# 0 z -> resize# z satAdd _ a b = leToPlusKN @1 @n $ case plus# a b of - z | let m = fromInteger# (natVal (Proxy @ n)) - , z >= m -> maxBound# + z | let m = fromInteger# (natVal (Proxy @ (n - 1))) + , z > m -> maxBound# z -> resize# z satSub SatWrap a b = @@ -307,22 +310,24 @@ instance (KnownNat n, 1 <= n) => SaturatingNum (Index n) where then fromInteger# 0 else a -# b - satMul SatWrap a b = - leToPlusKN @1 @n $ - case times# a b of - z -> let m = fromInteger# (natVal (Proxy @ n)) - in resize# (z `mod` m) + satMul SatWrap !a !b = + case snatToNum @Int (SNat @n) of + 1 -> fromInteger# 0 + _ -> leToPlusKN @1 @n $ + case times# a b of + z -> let m = fromInteger# (natVal (Proxy @ n)) + in resize# (z `mod` m) satMul SatZero a b = leToPlusKN @1 @n $ case times# a b of - z | let m = fromInteger# (natVal (Proxy @ n)) - , z >= m -> fromInteger# 0 + z | let m = fromInteger# (natVal (Proxy @ (n - 1))) + , z > m -> fromInteger# 0 z -> resize# z satMul _ a b = leToPlusKN @1 @n $ case times# a b of - z | let m = fromInteger# (natVal (Proxy @ n)) - , z >= m -> maxBound# + z | let m = fromInteger# (natVal (Proxy @ (n - 1))) + , z > m -> maxBound# z -> resize# z instance KnownNat n => Real (Index n) where diff --git a/tests/shouldwork/Numbers/NumConstantFolding.hs b/tests/shouldwork/Numbers/NumConstantFolding.hs index 572b0a770f..45b5cc475a 100644 --- a/tests/shouldwork/Numbers/NumConstantFolding.hs +++ b/tests/shouldwork/Numbers/NumConstantFolding.hs @@ -113,6 +113,24 @@ cExtendingNum = (r1,r2,r3) r2 = sub @a @b (lit 22203) (lit 22202) r3 = mul @a @b (lit 22204) (lit 2) +cIndex1 :: _ +cIndex1 = (r1a,r1b,r1c,r1d, r2a,r2b,r2c,r2d, r3a,r3b,r3c,r3d) + where + f :: Int -> Index 1 -> Int + f a = \x -> case x of {0 -> 0; 1 -> a} + r1a = f 22270 (satAdd SatWrap (lit 0) (lit 0)) + r2a = f 22271 (satSub SatWrap (lit 0) (lit 0)) + r3a = f 22272 (satMul SatWrap (lit 0) (lit 0)) + r1b = f 22273 (satAdd SatBound (lit 0) (lit 0)) + r2b = f 22274 (satSub SatBound (lit 0) (lit 0)) + r3b = f 22275 (satMul SatBound (lit 0) (lit 0)) + r1c = f 22276 (satAdd SatZero (lit 0) (lit 0)) + r2c = f 22277 (satSub SatZero (lit 0) (lit 0)) + r3c = f 22278 (satMul SatZero (lit 0) (lit 0)) + r1d = f 22279 (satAdd SatSymmetric (lit 0) (lit 0)) + r2d = f 22280 (satSub SatSymmetric (lit 0) (lit 0)) + r3d = f 22281 (satMul SatSymmetric (lit 0) (lit 0)) + cSaturatingNum :: forall n. (Num n, SaturatingNum n) => _ cSaturatingNum = (r1a,r1b,r1c,r1d, r2a,r2b,r2c,r2d, r3a,r3b,r3c,r3d) where @@ -213,6 +231,7 @@ tIndex , cIntegral @(Index 50000) , cBits @(Index 65536) -- , cFiniteBits @(Index 50000) -- broken + , cIndex1 -- ensure special case for index 1 is verified , csClashSpecific @(Index 50000) , cResize @(Index 50000) )