Skip to content

Commit

Permalink
Make TLS shutdown Unidirectional (thanks, Eric Wong!)
Browse files Browse the repository at this point in the history
  • Loading branch information
gregorycollins committed Jan 14, 2012
1 parent aca6a78 commit 7caaf23
Showing 1 changed file with 51 additions and 30 deletions.
81 changes: 51 additions & 30 deletions src/Snap/Internal/Http/Server/TLS.hs
@@ -1,10 +1,9 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}

------------------------------------------------------------------------------
module Snap.Internal.Http.Server.TLS
( TLSException
, initTLS
Expand All @@ -17,15 +16,12 @@ module Snap.Internal.Http.Server.TLS
, send
) where


------------------------------------------------------------------------------
import Control.Exception
import Data.ByteString.Char8 (ByteString)
import Data.Dynamic
import Foreign.C

import Snap.Internal.Http.Server.Backend

------------------------------------------------------------------------------
#ifdef OPENSSL
import Control.Monad
import qualified Data.ByteString.Char8 as S
Expand All @@ -41,95 +37,121 @@ import OpenSSL
import OpenSSL.Session
import qualified OpenSSL.Session as SSL
import Unsafe.Coerce

import Snap.Internal.Http.Server.Address
#endif
------------------------------------------------------------------------------
import Snap.Internal.Http.Server.Backend


------------------------------------------------------------------------------
data TLSException = TLSException String
deriving (Show, Typeable)
deriving (Show, Typeable)
instance Exception TLSException

#ifndef OPENSSL

#ifndef OPENSSL
------------------------------------------------------------------------------
initTLS :: IO ()
initTLS = throwIO $ TLSException "TLS is not supported"


------------------------------------------------------------------------------
stopTLS :: IO ()
stopTLS = return ()


------------------------------------------------------------------------------
bindHttps :: ByteString -> Int -> FilePath -> FilePath -> IO ListenSocket
bindHttps _ _ _ _ = throwIO $ TLSException "TLS is not supported"


------------------------------------------------------------------------------
freePort :: ListenSocket -> IO ()
freePort _ = return ()


------------------------------------------------------------------------------
createSession :: ListenSocket -> Int -> CInt -> IO () -> IO NetworkSession
createSession _ _ _ _ = throwIO $ TLSException "TLS is not supported"


------------------------------------------------------------------------------
endSession :: NetworkSession -> IO ()
endSession _ = return ()


------------------------------------------------------------------------------
send :: IO () -> IO () -> NetworkSession -> ByteString -> IO ()
send _ _ _ _ = return ()


------------------------------------------------------------------------------
recv :: IO b -> NetworkSession -> IO (Maybe ByteString)
recv _ _ = throwIO $ TLSException "TLS is not supported"

------------------------------------------------------------------------------
#else

#else
------------------------------------------------------------------------------
initTLS :: IO ()
initTLS = withOpenSSL $ return ()


------------------------------------------------------------------------------
stopTLS :: IO ()
stopTLS = return ()


------------------------------------------------------------------------------
bindHttps :: ByteString
-> Int
-> FilePath
-> FilePath
-> IO ListenSocket
bindHttps bindAddress bindPort cert key = do
(family, addr) <- getSockAddr bindPort bindAddress
sock <- Socket.socket family Socket.Stream 0
sock <- Socket.socket family Socket.Stream 0

Socket.setSocketOption sock Socket.ReuseAddr 1
Socket.bindSocket sock addr
Socket.listen sock 150

ctx <- context
contextSetPrivateKeyFile ctx key
contextSetPrivateKeyFile ctx key
contextSetCertificateFile ctx cert
contextSetDefaultCiphers ctx
contextSetDefaultCiphers ctx

certOK <- contextCheckPrivateKey ctx
when (not certOK) $ do
throwIO $ TLSException $ "OpenSSL says that the certificate "
++ "doesn't match the private key!"
when (not certOK) $ throwIO $ TLSException certificateError
return $! ListenHttps sock ctx

return $ ListenHttps sock ctx
where
certificateError = "OpenSSL says that the certificate " ++
"doesn't match the private key!"


------------------------------------------------------------------------------
freePort :: ListenSocket -> IO ()
freePort (ListenHttps sock _) = Socket.sClose sock
freePort _ = return ()


------------------------------------------------------------------------------
createSession :: ListenSocket -> Int -> CInt -> IO () -> IO NetworkSession
createSession (ListenHttps _ ctx) recvSize socket _ = do
csock <- mkSocket socket AF_INET Stream defaultProtocol Connected
ssl <- connection ctx csock

accept ssl
return $ NetworkSession socket (unsafeCoerce ssl) recvSize
return $! NetworkSession socket (unsafeCoerce ssl) recvSize
createSession _ _ _ _ = error "can't call createSession on a ListenHttp"


------------------------------------------------------------------------------
endSession :: NetworkSession -> IO ()
endSession (NetworkSession _ aSSL _) = shutdown ssl Bidirectional
where
ssl = unsafeCoerce aSSL
endSession (NetworkSession _ aSSL _) =
shutdown (unsafeCoerce aSSL) Unidirectional


------------------------------------------------------------------------------
send :: IO () -> IO () -> NetworkSession -> ByteString -> IO ()
send tickleTimeout _ (NetworkSession _ aSSL sz) bs = go bs
where
Expand All @@ -141,20 +163,19 @@ send tickleTimeout _ (NetworkSession _ aSSL sz) bs = go bs
go !s = if S.null s
then return ()
else do
SSL.write ssl a
tickleTimeout
go b

SSL.write ssl a
tickleTimeout
go b
where
(a,b) = S.splitAt sz s


------------------------------------------------------------------------------
recv :: IO b -> NetworkSession -> IO (Maybe ByteString)
recv _ (NetworkSession _ aSSL recvLen) = do
b <- SSL.read ssl recvLen
if S.null b then return Nothing else return $ Just b
return $! if S.null b then Nothing else Just b
where
ssl = unsafeCoerce aSSL


#endif

0 comments on commit 7caaf23

Please sign in to comment.