Skip to content

Commit

Permalink
Add ability to compare downloaded content's SHA-256 sum to given value
Browse files Browse the repository at this point in the history
  • Loading branch information
voidcontext committed May 12, 2020
1 parent adcf7ec commit b1dcdf2
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docker/static-files/pre_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ DIR="$(dirname $0)"
echo "Destination dir: $DIR"

head -c 104857600 </dev/urandom > $DIR/100MB.bin
shasum $DIR/100MB.bin | awk '{print $1}' > $DIR/100MB.bin.shasum
shasum -a 256 $DIR/100MB.bin | awk '{print $1}' > $DIR/100MB.bin.sha256
32 changes: 26 additions & 6 deletions fetch-file/src/main/scala/vdx/fetchfile/Downloader.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package vdx.fetchfile

import cats.effect._
import fs2.Pipe
import cats.syntax.eq._
import cats.instances.string._
import cats.syntax.functor._
import fs2.{Pipe, Stream}
import fs2.io.writeOutputStream

import java.io.OutputStream
Expand All @@ -23,7 +26,7 @@ trait Downloader[F[_]] {
/**
* Fetches the given URL and popuplates the given output stream.
*/
def fetch(url: URL, out: Resource[F, OutputStream]): F[Unit]
def fetch(url: URL, out: Resource[F, OutputStream], sha256Sum: Option[String] = None): F[Unit]
}

object Downloader {
Expand All @@ -34,17 +37,34 @@ object Downloader {
*/
def apply[F[_]: Concurrent: ContextShift](
ec: Blocker,
progress: ContentLength => Pipe[F, Byte, Unit] = Progress.noop[F]
progress: ContentLength => Pipe[F, Byte, Unit] = Progress.noop[F],
)(implicit client: HttpClient[F]): Downloader[F] = new Downloader[F] {
def fetch(url: URL, out: Resource[F, OutputStream]): F[Unit] =
def fetch(url: URL, out: Resource[F, OutputStream], sha256Sum: Option[String] = None): F[Unit] =
out.use { outStream =>
client(url) { (contentLength, body) =>
body.observe(progress(contentLength))
.through(writeOutputStream[F](Concurrent[F].delay(outStream), ec))
// The writeOutputStream pipe returns Unit so it is safe to write the final output using observe
.observe(writeOutputStream[F](Concurrent[F].delay(outStream), ec))
.through(maybeCompareSHA(sha256Sum))
.compile
.drain

}
}

def maybeCompareSHA(sha256: Option[String]): Pipe[F, Byte, Unit] =
stream =>
sha256.map[Stream[F, Unit]] { expectedSHA =>
Stream.eval(
// We'll compute the sh256 hash of the downloaded file
stream.through(fs2.hash.sha256)
.compile
.toVector
).flatMap { hashBytes =>
val hash = hashBytes.map("%02x".format(_)).mkString
if (hash === expectedSHA.toLowerCase()) Stream.emit(()).covary[F]
else Stream.raiseError(new Exception(s"Sha256 sum doesn't match (expected: $expectedSHA, got: $hash)"))
}
}
.getOrElse(stream.void)
}
}
29 changes: 27 additions & 2 deletions fetch-file/src/test/scala/vdx/fetchfile/DownloaderSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,39 @@ class DownloaderSpec extends AnyFlatSpec with Matchers {
}).unsafeRunSync()

downloadedBytes.length should be(1024 * 1024 * 100)
val expectedShaSum = Source.fromFile("docker/static-files/100MB.bin.shasum").mkString.trim()
val expectedShaSum = Source.fromFile("docker/static-files/100MB.bin.sha256").mkString.trim()

val shaSum = MessageDigest.getInstance("SHA-1")
val shaSum = MessageDigest.getInstance("SHA-256")
.digest(downloadedBytes)
.map("%02x".format(_))
.mkString

shaSum should be(expectedShaSum)
}

it should "be successful when the given shasum is correct" in {
val expectedShaSum = Source.fromFile("docker/static-files/100MB.bin.sha256").mkString.trim()

download(expectedShaSum).attempt.unsafeRunSync() should be(Right(()))
}

it should "be fail when the given shasum is incorrect" in {
download("some-wrong-sha-sum").attempt.unsafeRunSync() should be(a[Left[_, _]])
}

def download(shaSum: String): IO[Unit] =
Blocker[IO].use { blocker =>
implicit val client = HttpURLConnectionClient[IO](blocker, 1024 * 8)
val downloader = Downloader[IO](blocker)
val out = new ByteArrayOutputStream()
for {
_ <- downloader.fetch(
new URL("http://localhost:8088/100MB.bin"),
Resource.fromAutoCloseable(IO.delay(out)),
sha256Sum = Option(shaSum)
)
_ <- IO.delay(out.toByteArray())
} yield ()
}
}

0 comments on commit b1dcdf2

Please sign in to comment.