Skip to content

Commit

Permalink
Merge pull request #7552 from dwijnand/PF-andThen-upcast-PF
Browse files Browse the repository at this point in the history
Make PartialFunction#andThen properly handle upcast PFs
  • Loading branch information
lrytz committed Dec 19, 2018
2 parents f3c0b0f + 893a064 commit c7838e1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
13 changes: 10 additions & 3 deletions src/library/scala/PartialFunction.scala
Expand Up @@ -82,13 +82,20 @@ trait PartialFunction[-A, +B] extends (A => B) { self =>

/** Composes this partial function with a transformation function that
* gets applied to results of this partial function.
*
* If the runtime type of the function is a `PartialFunction` then the
* other `andThen` method is used (note its cautions).
*
* @param k the transformation function
* @tparam C the result type of the transformation function.
* @return a partial function with the same domain as this partial function, which maps
* @return a partial function with the domain of this partial function,
* possibly narrowed by the specified function, which maps
* arguments `x` to `k(this(x))`.
*/
override def andThen[C](k: B => C): PartialFunction[A, C] =
new AndThen[A, B, C] (this, k)
override def andThen[C](k: B => C): PartialFunction[A, C] = k match {
case pf: PartialFunction[B, C] => andThen(pf)
case _ => new AndThen[A, B, C](this, k)
}

/**
* Composes this partial function with another partial function that
Expand Down
12 changes: 11 additions & 1 deletion test/junit/scala/PartialFunctionCompositionTest.scala
Expand Up @@ -121,6 +121,16 @@ class PartialFunctionCompositionTest {
assertEquals((pf andThen f).applyOrElse("passpass", fallbackFun), "fallback")
}

@Test
def andThenWithUpcastPartialFunctionTests(): Unit = {
val f: PartialFunction[Int, Int] = {case x if x % 2 == 0 => x + 2}
val g: PartialFunction[Int, Int] = {case x if x % 2 == 1 => x - 2}
val c1 = f andThen g
val c2 = f andThen (g: Int => Int)
assertEquals(8, c1.applyOrElse(2, (_: Int) => 8))
assertEquals(8, c2.applyOrElse(2, (_: Int) => 8))
}

@Test
def inferenceTests(): Unit = {
val fb = (_: Int) => 42
Expand All @@ -136,4 +146,4 @@ class PartialFunctionCompositionTest {
assertFalse(pf3.isDefinedAt(15))
assertTrue(pf3.isDefinedAt(21))
}
}
}

0 comments on commit c7838e1

Please sign in to comment.