Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#3649 - Add ZStream.fromTcpSocketServer #3677

Merged
merged 3 commits into from Jun 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -1,14 +1,26 @@
package zio.stream

import java.net.InetSocketAddress
import java.nio.channels.AsynchronousSocketChannel
import java.nio.file.{ Files, NoSuchFileException, Paths }
import java.nio.{ Buffer, ByteBuffer }

import scala.concurrent.ExecutionContext.global

import zio._
import zio.blocking.effectBlockingIO
import zio.test.Assertion._
import zio.test._

object ZStreamPlatformSpecificSpec extends ZIOBaseSpec {

def socketClient(port: Int) =
ZManaged.make(effectBlockingIO(AsynchronousSocketChannel.open()).flatMap { client =>
ZIO
.fromFutureJava(client.connect(new InetSocketAddress("localhost", port)))
.map(_ => client)
})(c => ZIO.effectTotal(c.close()))

def spec = suite("ZStream JVM")(
suite("Constructors")(
testM("effectAsync")(checkM(Gen.chunkOf(Gen.anyInt)) { chunk =>
Expand Down Expand Up @@ -191,6 +203,58 @@ object ZStreamPlatformSpecificSpec extends ZIOBaseSpec {
fails(isSubtype[NoSuchFileException](anything))
)
}
),
suite("fromSocketServer")(
testM("read data")(checkM(Gen.anyString.filter(_.nonEmpty)) { message =>
for {
refOut <- Ref.make("")

server <- ZStream
.fromSocketServer(8886)
.foreach { c =>
c.read
.transduce(ZTransducer.utf8Decode)
.runCollect
.map(_.mkString)
.flatMap(s => refOut.update(_ + s))
}
.fork

_ <- socketClient(8886)
.use(c => ZIO.fromFutureJava(c.write(ByteBuffer.wrap(message.getBytes))))
.retry(Schedule.forever)

receive <- refOut.get.doWhileM(s => ZIO.succeed(s.isEmpty))

_ <- server.interrupt
} yield assert(receive)(equalTo(message))
}),
testM("write data")(checkM(Gen.anyString.filter(_.nonEmpty)) { message =>
(for {
refOut <- Ref.make("")

server <- ZStream
.fromSocketServer(8887)
.foreach(c => ZStream.fromIterable(message.getBytes).run(c.write))
.fork

_ <- socketClient(8887).use { c =>
val buffer = ByteBuffer.allocate(message.getBytes.length)

ZIO
.fromFutureJava(c.read(buffer))
.repeat(Schedule.doUntil(_ < 1))
.flatMap { _ =>
(buffer: Buffer).flip()
refOut.update(_ => new String(buffer.array))
}
}.retry(Schedule.forever)

receive <- refOut.get.doWhileM(s => ZIO.succeed(s.isEmpty))

_ <- server.interrupt
} yield assert(receive)(equalTo(message)))
})
)
)
)
Expand Down
109 changes: 108 additions & 1 deletion streams/jvm/src/main/scala/zio/stream/platform.scala
@@ -1,10 +1,12 @@
package zio.stream

import java.io.{ IOException, InputStream, OutputStream }
import java.nio.ByteBuffer
import java.net.InetSocketAddress
import java.nio.channels.FileChannel
import java.nio.channels.{ AsynchronousServerSocketChannel, AsynchronousSocketChannel, CompletionHandler }
import java.nio.file.StandardOpenOption._
import java.nio.file.{ OpenOption, Path }
import java.nio.{ Buffer, ByteBuffer }
import java.{ util => ju }

import zio._
Expand Down Expand Up @@ -302,4 +304,109 @@ trait ZStreamPlatformSpecificConstructors { self: ZStream.type =>
*/
final def fromJavaStreamTotal[A](stream: => ju.stream.Stream[A]): ZStream[Any, Nothing, A] =
ZStream.fromJavaIteratorTotal(stream.iterator())

/**
* Create a stream of accepted connection from server socket
* Emit socket `Connection` from which you can read / write and ensure it is closed after it is used
*/
def fromSocketServer(
port: Int,
host: Option[String] = None
): ZStream[Blocking, Throwable, Connection] =
for {
server <- ZStream.managed(ZManaged.fromAutoCloseable(blocking.effectBlocking {
AsynchronousServerSocketChannel
.open()
.bind(
host.fold(new InetSocketAddress(port))(new InetSocketAddress(_, port))
)
}))

registerConnection <- ZStream.managed(ZManaged.scope)

conn <- ZStream.repeatEffect {
IO.effectAsync[Throwable, UManaged[Connection]] { callback =>
server.accept(
null,
new CompletionHandler[AsynchronousSocketChannel, Void]() {
self =>
override def completed(socket: AsynchronousSocketChannel, attachment: Void): Unit =
callback(ZIO.succeed(Connection.make(socket)))

override def failed(exc: Throwable, attachment: Void): Unit = callback(ZIO.fail(exc))
}
)
}
.flatMap(managedConn => registerConnection(managedConn).map(_._2))
}
} yield conn

/**
* Accepted connection made to a specific channel `AsynchronousServerSocketChannel`
*/
class Connection(socket: AsynchronousSocketChannel) {

/**
* Read the entire `AsynchronousSocketChannel` by emitting a `Chunk[Byte]`
* The caller of this function is NOT responsible for closing the `AsynchronousSocketChannel`.
*/
def read: Stream[Throwable, Byte] =
ZStream.unfoldChunkM(0) {
case -1 => ZIO.succeed(Option.empty)
case _ =>
val buff = ByteBuffer.allocate(ZStream.DefaultChunkSize)

IO.effectAsync[Throwable, Option[(Chunk[Byte], Int)]] { callback =>
socket.read(
buff,
null,
new CompletionHandler[Integer, Void] {
override def completed(bytesRead: Integer, attachment: Void): Unit = {
(buff: Buffer).flip()
callback(ZIO.succeed(Option(Chunk.fromByteBuffer(buff) -> bytesRead.toInt)))
}

override def failed(error: Throwable, attachment: Void): Unit = callback(ZIO.fail(error))
}
)
}
}

/**
* Write the entire Chuck[Byte] to the socket channel.
* The caller of this function is NOT responsible for closing the `AsynchronousSocketChannel`.
*
* The sink will yield the count of bytes written.
*/
def write: Sink[Throwable, Byte, Int] =
ZSink.foldLeftChunksM(0) {
case (nbBytesWritten, c) =>
IO.effectAsync[Throwable, Int] { callback =>
socket.write(
ByteBuffer.wrap(c.toArray),
null,
new CompletionHandler[Integer, Void] {
override def completed(result: Integer, attachment: Void): Unit =
callback(ZIO.succeed(nbBytesWritten + result.toInt))

override def failed(error: Throwable, attachment: Void): Unit = callback(ZIO.fail(error))
}
)
}
}

/**
* Close the underlying socket
*/
def close(): UIO[Unit] = ZIO.effectTotal(socket.close())
}

object Connection {

/**
* Create a `Managed` connection
*/
def make(socket: AsynchronousSocketChannel): UManaged[Connection] =
Managed.make(ZIO.succeed(new Connection(socket)))(_.close())
}
}