Skip to content

Commit

Permalink
use mutex
Browse files Browse the repository at this point in the history
  • Loading branch information
armanbilge committed Apr 23, 2023
1 parent d733616 commit df26d21
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 24 deletions.
14 changes: 7 additions & 7 deletions io/jvm-native/src/main/scala/fs2/io/net/SocketPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ package net

import com.comcast.ip4s.{IpAddress, SocketAddress}
import cats.effect.{Async, Resource}
import cats.effect.std.Semaphore
import cats.effect.std.Mutex
import cats.syntax.all._

import java.net.InetSocketAddress
Expand All @@ -37,20 +37,20 @@ private[net] trait SocketCompanionPlatform {
ch: AsynchronousSocketChannel
): Resource[F, Socket[F]] =
Resource.make {
(Semaphore[F](1), Semaphore[F](1)).mapN { (readSemaphore, writeSemaphore) =>
(Mutex[F], Mutex[F]).mapN { (readSemaphore, writeSemaphore) =>
new AsyncSocket[F](ch, readSemaphore, writeSemaphore)
}
}(_ => Async[F].delay(if (ch.isOpen) ch.close else ()))

private[net] abstract class BufferedReads[F[_]](
readSemaphore: Semaphore[F]
readSemaphore: Mutex[F]
)(implicit F: Async[F])
extends Socket[F] {
private[this] final val defaultReadSize = 8192
private[this] var readBuffer: ByteBuffer = ByteBuffer.allocateDirect(defaultReadSize)

private def withReadBuffer[A](size: Int)(f: ByteBuffer => F[A]): F[A] =
readSemaphore.permit.use { _ =>
readSemaphore.lock.use { _ =>
F.delay {
if (readBuffer.capacity() < size)
readBuffer = ByteBuffer.allocateDirect(size)
Expand Down Expand Up @@ -107,8 +107,8 @@ private[net] trait SocketCompanionPlatform {

private final class AsyncSocket[F[_]](
ch: AsynchronousSocketChannel,
readSemaphore: Semaphore[F],
writeSemaphore: Semaphore[F]
readSemaphore: Mutex[F],
writeSemaphore: Mutex[F]
)(implicit F: Async[F])
extends BufferedReads[F](readSemaphore) {

Expand Down Expand Up @@ -142,7 +142,7 @@ private[net] trait SocketCompanionPlatform {
go(buff)
else F.unit
}
writeSemaphore.permit.use { _ =>
writeSemaphore.lock.use { _ =>
go(bytes.toByteBuffer)
}
}
Expand Down
16 changes: 8 additions & 8 deletions io/jvm/src/main/scala/fs2/io/net/tls/TLSEngine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import javax.net.ssl.{SSLEngine, SSLEngineResult}

import cats.Applicative
import cats.effect.kernel.{Async, Sync}
import cats.effect.std.Semaphore
import cats.effect.std.Mutex
import cats.syntax.all._

/** Provides the ability to establish and communicate over a TLS session.
Expand Down Expand Up @@ -65,9 +65,9 @@ private[tls] object TLSEngine {
engine.getSession.getPacketBufferSize,
engine.getSession.getApplicationBufferSize
)
readSemaphore <- Semaphore[F](1)
writeSemaphore <- Semaphore[F](1)
handshakeSemaphore <- Semaphore[F](1)
readSemaphore <- Mutex[F]
writeSemaphore <- Mutex[F]
handshakeSemaphore <- Mutex[F]
sslEngineTaskRunner = SSLEngineTaskRunner[F](engine)
} yield new TLSEngine[F] {
private val doLog: (() => String) => F[Unit] =
Expand All @@ -85,7 +85,7 @@ private[tls] object TLSEngine {
def stopUnwrap = Sync[F].delay(engine.closeInbound()).attempt.void

def write(data: Chunk[Byte]): F[Unit] =
writeSemaphore.permit.use(_ => write0(data))
writeSemaphore.lock.use(_ => write0(data))

private def write0(data: Chunk[Byte]): F[Unit] =
wrapBuffer.input(data) >> wrap
Expand All @@ -104,7 +104,7 @@ private[tls] object TLSEngine {
wrapBuffer.inputRemains
.flatMap(x => wrap.whenA(x > 0 && result.bytesConsumed > 0))
case _ =>
handshakeSemaphore.permit
handshakeSemaphore.lock
.use(_ => stepHandshake(result, true)) >> wrap
}
}
Expand All @@ -124,7 +124,7 @@ private[tls] object TLSEngine {
}

def read(maxBytes: Int): F[Option[Chunk[Byte]]] =
readSemaphore.permit.use(_ => read0(maxBytes))
readSemaphore.lock.use(_ => read0(maxBytes))

private def initialHandshakeDone: F[Boolean] =
Sync[F].delay(engine.getSession.getCipherSuite != "SSL_NULL_WITH_NULL_NULL")
Expand Down Expand Up @@ -168,7 +168,7 @@ private[tls] object TLSEngine {
case SSLEngineResult.HandshakeStatus.FINISHED =>
unwrap(maxBytes)
case _ =>
handshakeSemaphore.permit
handshakeSemaphore.lock
.use(_ => stepHandshake(result, false)) >> unwrap(
maxBytes
)
Expand Down
8 changes: 4 additions & 4 deletions io/jvm/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ package net
package tls

import cats.Applicative
import cats.effect.std.Semaphore
import cats.effect.std.Mutex
import cats.effect.kernel._
import cats.syntax.all._

Expand Down Expand Up @@ -53,7 +53,7 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type =>
engine: TLSEngine[F]
): F[TLSSocket[F]] =
for {
readSem <- Semaphore(1)
readSem <- Mutex[F]
} yield new UnsealedTLSSocket[F] {
def write(bytes: Chunk[Byte]): F[Unit] =
engine.write(bytes)
Expand All @@ -62,7 +62,7 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type =>
engine.read(maxBytes)

def readN(numBytes: Int): F[Chunk[Byte]] =
readSem.permit.use { _ =>
readSem.lock.use { _ =>
def go(acc: Chunk[Byte]): F[Chunk[Byte]] = {
val toRead = numBytes - acc.size
if (toRead <= 0) Applicative[F].pure(acc)
Expand All @@ -76,7 +76,7 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type =>
}

def read(maxBytes: Int): F[Option[Chunk[Byte]]] =
readSem.permit.use(_ => read0(maxBytes))
readSem.lock.use(_ => read0(maxBytes))

def reads: Stream[F, Byte] =
Stream.repeatEval(read(8192)).unNoneTerminate.unchunks
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
package fs2.io.net.unixsocket

import cats.effect.kernel.{Async, Resource}
import cats.effect.std.Semaphore
import cats.effect.std.Mutex
import cats.syntax.all._
import com.comcast.ip4s.{IpAddress, SocketAddress}
import fs2.{Chunk, Stream}
Expand Down Expand Up @@ -89,15 +89,15 @@ private[unixsocket] trait UnixSocketsCompanionPlatform {
ch: SocketChannel
): Resource[F, Socket[F]] =
Resource.make {
(Semaphore[F](1), Semaphore[F](1)).mapN { (readSemaphore, writeSemaphore) =>
(Mutex[F], Mutex[F]).mapN { (readSemaphore, writeSemaphore) =>
new AsyncSocket[F](ch, readSemaphore, writeSemaphore)
}
}(_ => Async[F].delay(if (ch.isOpen) ch.close else ()))

private final class AsyncSocket[F[_]](
ch: SocketChannel,
readSemaphore: Semaphore[F],
writeSemaphore: Semaphore[F]
readSemaphore: Mutex[F],
writeSemaphore: Mutex[F]
)(implicit F: Async[F])
extends Socket.BufferedReads[F](readSemaphore) {

Expand All @@ -110,7 +110,7 @@ private[unixsocket] trait UnixSocketsCompanionPlatform {
if (buff.remaining <= 0) F.unit
else go(buff)
}
writeSemaphore.permit.use { _ =>
writeSemaphore.lock.use { _ =>
go(bytes.toByteBuffer)
}
}
Expand Down

0 comments on commit df26d21

Please sign in to comment.