Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a socket to communicate with the jit runtime process #4887

Merged
merged 12 commits into from
Apr 17, 2024
Merged
31 changes: 19 additions & 12 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -471,12 +471,20 @@ jobs:
echo "jit_dist_exe=$jit_dist_exe" >> $GITHUB_ENV
echo "ucm=$ucm" >> $GITHUB_ENV

- name: cache jit binaries
- uses: actions/checkout@v4
with:
sparse-checkout: |
scripts/get-share-hash.sh
scheme-libs
unison-src/builtin-tests/jit-tests.tpl.md
unison-src/transcripts-using-base/serialized-cases/case-00.v4.ser

- name: cache/restore jit binaries
id: cache-jit-binaries
uses: actions/cache@v4
uses: actions/cache/restore@v4
with:
path: ${{ env.jit_dist }}
key: jit_dist-${{ matrix.os }}.racket_${{ env.racket_version }}.jit_${{ env.jit_version }}
key: jit_dist-${{ matrix.os }}.racket_${{ env.racket_version }}.jit_${{ env.jit_version }}-${{hashFiles('**/scheme-libs/**')}}

- name: Cache Racket dependencies
if: steps.cache-jit-binaries.outputs.cache-hit != 'true'
Expand All @@ -495,14 +503,6 @@ jobs:
variant: CS
version: ${{env.racket_version}}

- uses: actions/checkout@v4
with:
sparse-checkout: |
scripts/get-share-hash.sh
scheme-libs
unison-src/builtin-tests/jit-tests.tpl.md
unison-src/transcripts-using-base/serialized-cases/case-00.v4.ser

- name: look up hash for runtime tests
run: |
echo "runtime_tests_causalhash=$(scripts/get-share-hash.sh ${{env.runtime_tests_version}})" >> $GITHUB_ENV
Expand All @@ -512,7 +512,7 @@ jobs:
uses: actions/cache@v4
with:
path: ${{env.jit_test_results}}
key: jit-test-results.${{ matrix.os }}.racket_${{ env.racket_version }}.jit_${{ env.jit_version }}.tests_${{env.runtime_tests_causalhash}}
key: jit-test-results.${{ matrix.os }}.racket_${{ env.racket_version }}.jit_${{ env.jit_version }}-${{hashFiles('**/scheme-libs/**')}}.tests_${{env.runtime_tests_causalhash}}

- name: install libb2 (linux)
uses: awalsh128/cache-apt-pkgs-action@latest
Expand Down Expand Up @@ -547,6 +547,13 @@ jobs:
raco exe --embed-dlls "$jit_src_scheme"/unison-runtime.rkt
raco distribute "$jit_dist" "$jit_exe"

- name: cache/save jit binaries
if: steps.cache-jit-binaries.outputs.cache-hit != 'true'
uses: actions/cache/save@v4
with:
path: ${{ env.jit_dist }}
key: jit_dist-${{ matrix.os }}.racket_${{ env.racket_version }}.jit_${{ env.jit_version }}-${{hashFiles('**/scheme-libs/**')}}

- name: save jit binary
uses: actions/upload-artifact@v4
with:
Expand Down
7 changes: 5 additions & 2 deletions parser-typechecker/src/Unison/Runtime/ANF/Serialize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -962,8 +962,8 @@ serializeGroupForRehash fops (Derived h _) sg =
f _ = Nothing
refrep = Map.fromList . mapMaybe f $ groupTermLinks sg

deserializeValue :: ByteString -> Either String Value
deserializeValue bs = runGetS (getVersion >>= getValue) bs
getVersionedValue :: MonadGet m => m Value
getVersionedValue = getVersion >>= getValue
where
getVersion =
getWord32be >>= \case
Expand All @@ -973,6 +973,9 @@ deserializeValue bs = runGetS (getVersion >>= getValue) bs
| n <= 4 -> pure n
| otherwise -> fail $ "deserializeValue: unknown version: " ++ show n

deserializeValue :: ByteString -> Either String Value
deserializeValue bs = runGetS getVersionedValue bs

serializeValue :: Value -> ByteString
serializeValue v = runPutS (putVersion *> putValue v)
where
Expand Down
116 changes: 86 additions & 30 deletions parser-typechecker/src/Unison/Runtime/Interface.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import Data.Binary.Get (runGetOrFail)
-- import Data.Bits (shiftL)
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BL
import Data.Bytes.Get (MonadGet)
import Data.Bytes.Get (MonadGet, getWord8, runGetS)
import Data.Bytes.Put (MonadPut, putWord32be, runPutL, runPutS)
import Data.Bytes.Serial
import Data.Foldable
Expand All @@ -44,9 +44,11 @@ import Data.Set as Set
(\\),
)
import Data.Set qualified as Set
import Data.Text (isPrefixOf, unpack)
import Data.Text as Text (isPrefixOf, pack, unpack)
import GHC.IO.Exception (IOErrorType (NoSuchThing, OtherError, PermissionDenied), IOException (ioe_description, ioe_type))
import GHC.Stack (callStack)
import Network.Simple.TCP (Socket, acceptFork, listen, recv, send)
import Network.Socket (PortNumber, socketPort)
import System.Directory
( XdgDirectory (XdgCache),
createDirectoryIfMissing,
Expand Down Expand Up @@ -85,6 +87,7 @@ import Unison.Runtime.ANF as ANF
import Unison.Runtime.ANF.Rehash as ANF (rehashGroups)
import Unison.Runtime.ANF.Serialize as ANF
( getGroup,
getVersionedValue,
putGroup,
serializeValue,
)
Expand Down Expand Up @@ -460,7 +463,18 @@ nativeEval executable ctxVar cl ppe tm = catchInternalErrors $ do
(ctx, codes) <- loadDeps cl ppe ctx tyrs tmrs
(ctx, tcodes, base) <- prepareEvaluation ppe tm ctx
writeIORef ctxVar ctx
nativeEvalInContext executable ppe ctx (codes ++ tcodes) base
-- Note: port 0 mean choosing an arbitrary available port.
-- We then ask what port was actually chosen.
listen "127.0.0.1" "0" $ \(serv, _) ->
socketPort serv >>= \port ->
nativeEvalInContext
executable
ppe
ctx
serv
port
(codes ++ tcodes)
base

interpEval ::
ActiveThreads ->
Expand Down Expand Up @@ -790,14 +804,46 @@ backReferenceTm ws frs irs dcm c i = do
bs <- Map.lookup r dcm
Map.lookup i bs

ucrProc :: FilePath -> [String] -> CreateProcess
ucrProc executable args =
ucrEvalProc :: FilePath -> [String] -> CreateProcess
ucrEvalProc executable args =
(proc executable args)
{ std_in = Inherit,
std_out = Inherit,
std_err = Inherit
}

ucrCompileProc :: FilePath -> [String] -> CreateProcess
ucrCompileProc executable args =
(proc executable args)
{ std_in = CreatePipe,
std_out = Inherit,
std_err = Inherit
}

receiveAll :: Socket -> IO ByteString
receiveAll sock = read []
where
read acc =
recv sock 4096 >>= \case
Just chunk -> read (chunk : acc)
Nothing -> pure . BS.concat $ reverse acc

data NativeResult
= Success Value
| Bug Text Value
| Error Text

deserializeNativeResponse :: ByteString -> NativeResult
deserializeNativeResponse =
run $
getWord8 >>= \case
0 -> Success <$> getVersionedValue
1 -> Bug <$> getText <*> getVersionedValue
2 -> Error <$> getText
_ -> pure $ Error "Unexpected result bytes tag"
where
run e bs = either (Error . pack) id (runGetS e bs)

-- Note: this currently does not support yielding values; instead it
-- just produces a result appropriate for unitary `run` commands. The
-- reason is that the executed code can cause output to occur, which
Expand All @@ -813,37 +859,45 @@ nativeEvalInContext ::
FilePath ->
PrettyPrintEnv ->
EvalCtx ->
Socket ->
PortNumber ->
[(Reference, SuperGroup Symbol)] ->
Reference ->
IO (Either Error ([Error], Term Symbol))
nativeEvalInContext executable _ ctx codes base = do
nativeEvalInContext executable ppe ctx serv port codes base = do
ensureRuntimeExists executable
let cc = ccache ctx
crs <- readTVarIO $ combRefs cc
let bytes = serializeValue . compileValue base $ codes

decodeResult (Left msg) = pure . Left $ fromString msg
decodeResult (Right val) =
decodeResult (Error msg) = pure . Left $ text msg
decodeResult (Bug msg val) =
reifyValue cc val >>= \case
Left _ -> pure . Left $ "missing references from bug result"
Right cl ->
pure . Left . bugMsg ppe [] msg $ decompileCtx crs ctx cl
decodeResult (Success val) =
reifyValue cc val >>= \case
Left _ -> pure . Left $ "missing references from result"
Right cl -> case decompileCtx crs ctx cl of
(errs, dv) -> pure $ Right (listErrors errs, dv)

callout (Just pin) _ _ ph = do
BS.hPut pin . runPutS . putWord32be . fromIntegral $ BS.length bytes
BS.hPut pin bytes
UnliftIO.hClose pin
let unit = Data RF.unitRef 0 [] []
sunit = Data RF.pairRef 0 [] [unit, unit]
comm mv (sock, _) = do
send sock . runPutS . putWord32be . fromIntegral $ BS.length bytes
send sock bytes
UnliftIO.putMVar mv =<< receiveAll sock

callout _ _ _ ph = do
mv <- UnliftIO.newEmptyMVar
tid <- acceptFork serv $ comm mv
waitForProcess ph >>= \case
ExitSuccess -> decodeResult $ Right sunit
ExitFailure _ ->
ExitSuccess ->
decodeResult . deserializeNativeResponse
=<< UnliftIO.takeMVar mv
ExitFailure _ -> do
UnliftIO.killThread tid
pure . Left $ "native evaluation failed"
-- TODO: actualy receive output from subprocess
-- decodeResult . deserializeValue =<< BS.hGetContents pout
callout _ _ _ _ =
pure . Left $ "withCreateProcess didn't provide handles"
p = ucrProc executable []
p = ucrEvalProc executable ["-p", show port]
ucrError (e :: IOException) = pure $ Left (runtimeErrMsg (cmdspec p) (Right e))
withCreateProcess p callout
`UnliftIO.catch` ucrError
Expand Down Expand Up @@ -872,7 +926,7 @@ nativeCompileCodes executable codes base path = do
throwIO $ PE callStack (runtimeErrMsg (cmdspec p) (Right e))
racoError (e :: IOException) =
throwIO $ PE callStack (racoErrMsg (makeRacoCmd RawCommand) (Right e))
p = ucrProc executable ["-G", srcPath]
p = ucrCompileProc executable ["-G", srcPath]
makeRacoCmd :: (FilePath -> [String] -> a) -> a
makeRacoCmd f = f "raco" ["exe", "-o", path, srcPath]
withCreateProcess p callout
Expand Down Expand Up @@ -953,7 +1007,7 @@ bugMsg ::
Pretty ColorText
bugMsg ppe tr name (errs, tm)
| name == "blank expression" =
P.callout icon . P.lines $
P.callout icon . P.linesNonEmpty $
[ P.wrap
( "I encountered a"
<> P.red (P.text name)
Expand All @@ -965,7 +1019,7 @@ bugMsg ppe tr name (errs, tm)
stackTrace ppe tr
]
| "pattern match failure" `isPrefixOf` name =
P.callout icon . P.lines $
P.callout icon . P.linesNonEmpty $
[ P.wrap
( "I've encountered a"
<> P.red (P.text name)
Expand All @@ -980,7 +1034,7 @@ bugMsg ppe tr name (errs, tm)
stackTrace ppe tr
]
| name == "builtin.raise" =
P.callout icon . P.lines $
P.callout icon . P.linesNonEmpty $
[ P.wrap ("The program halted with an unhandled exception:"),
"",
P.indentN 2 $ pretty ppe tm,
Expand All @@ -990,7 +1044,7 @@ bugMsg ppe tr name (errs, tm)
| name == "builtin.bug",
RF.TupleTerm' [Tm.Text' msg, x] <- tm,
"pattern match failure" `isPrefixOf` msg =
P.callout icon . P.lines $
P.callout icon . P.linesNonEmpty $
[ P.wrap
( "I've encountered a"
<> P.red (P.text msg)
Expand All @@ -1005,7 +1059,7 @@ bugMsg ppe tr name (errs, tm)
stackTrace ppe tr
]
bugMsg ppe tr name (errs, tm) =
P.callout icon . P.lines $
P.callout icon . P.linesNonEmpty $
[ P.wrap
( "I've encountered a call to"
<> P.red (P.text name)
Expand All @@ -1018,7 +1072,8 @@ bugMsg ppe tr name (errs, tm) =
]

stackTrace :: PrettyPrintEnv -> [(Reference, Int)] -> Pretty ColorText
stackTrace ppe tr = "Stack trace:\n" <> P.indentN 2 (P.lines $ f <$> tr)
stackTrace _ [] = mempty
stackTrace ppe tr = "\nStack trace:\n" <> P.indentN 2 (P.lines $ f <$> tr)
where
f (rf, n) = name <> count
where
Expand Down Expand Up @@ -1165,10 +1220,11 @@ listErrors :: Set DecompError -> [Error]
listErrors = fmap (P.indentN 2 . renderDecompError) . toList

tabulateErrors :: Set DecompError -> Error
tabulateErrors errs | null errs = "\n"
tabulateErrors errs | null errs = mempty
tabulateErrors errs =
P.indentN 2 . P.lines $
P.wrap "The following errors occured while decompiling:"
""
: P.wrap "The following errors occured while decompiling:"
: (listErrors errs)

restoreCache :: StoredCache -> IO CCache
Expand Down
Loading
Loading