diff --git a/core/src/main/scala/cats/data/EitherT.scala b/core/src/main/scala/cats/data/EitherT.scala index 2bdc04655e..66915a7404 100644 --- a/core/src/main/scala/cats/data/EitherT.scala +++ b/core/src/main/scala/cats/data/EitherT.scala @@ -3,7 +3,6 @@ package data import cats.functor.Bifunctor import cats.instances.either._ -import cats.syntax.EitherUtil import cats.syntax.either._ /** @@ -73,7 +72,7 @@ final case class EitherT[F[_], A, B](value: F[Either[A, B]]) { def flatMap[D](f: B => EitherT[F, A, D])(implicit F: Monad[F]): EitherT[F, A, D] = EitherT(F.flatMap(value) { - case l @ Left(_) => F.pure(EitherUtil.leftCast(l)) + case l @ Left(_) => F.pure(l.rightCast) case Right(b) => f(b).value }) diff --git a/core/src/main/scala/cats/instances/either.scala b/core/src/main/scala/cats/instances/either.scala index 5acaa3795d..d5b3cd628f 100644 --- a/core/src/main/scala/cats/instances/either.scala +++ b/core/src/main/scala/cats/instances/either.scala @@ -55,7 +55,7 @@ trait EitherInstances extends EitherInstances1 { override def map2Eval[B, C, Z](fb: Either[A, B], fc: Eval[Either[A, C]])(f: (B, C) => Z): Eval[Either[A, Z]] = fb match { - case l @ Left(_) => Now(EitherUtil.leftCast(l)) + case l @ Left(_) => Now(EitherUtil.rightCast(l)) case Right(b) => fc.map(_.right.map(f(b, _))) } diff --git a/core/src/main/scala/cats/syntax/either.scala b/core/src/main/scala/cats/syntax/either.scala index cbac3404b8..f4809f410a 100644 --- a/core/src/main/scala/cats/syntax/either.scala +++ b/core/src/main/scala/cats/syntax/either.scala @@ -9,6 +9,10 @@ trait EitherSyntax { implicit def catsSyntaxEither[A, B](eab: Either[A, B]): EitherOps[A, B] = new EitherOps(eab) implicit def catsSyntaxEitherObject(either: Either.type): EitherObjectOps = new EitherObjectOps(either) // scalastyle:off ensure.single.space.after.token + + implicit def catsSyntaxLeft[A, B](left: Left[A, B]): LeftOps[A, B] = new LeftOps(left) + + implicit def catsSyntaxRight[A, B](right: Right[A, B]): RightOps[A, B] = new RightOps(right) } final class EitherOps[A, B](val eab: Either[A, B]) extends AnyVal { @@ -24,7 +28,7 @@ final class EitherOps[A, B](val eab: Either[A, B]) extends AnyVal { def orElse[C](fallback: => Either[C, B]): Either[C, B] = eab match { case Left(_) => fallback - case r @ Right(_) => EitherUtil.rightCast(r) + case r @ Right(_) => EitherUtil.leftCast(r) } def recover(pf: PartialFunction[A, B]): Either[A, B] = eab match { @@ -103,23 +107,23 @@ final class EitherOps[A, B](val eab: Either[A, B]) extends AnyVal { } def map[C](f: B => C): Either[A, C] = eab match { - case l @ Left(_) => EitherUtil.leftCast(l) + case l @ Left(_) => EitherUtil.rightCast(l) case Right(b) => Right(f(b)) } def map2Eval[C, Z](fc: Eval[Either[A, C]])(f: (B, C) => Z): Eval[Either[A, Z]] = eab match { - case l @ Left(_) => Now(EitherUtil.leftCast(l)) + case l @ Left(_) => Now(EitherUtil.rightCast(l)) case Right(b) => fc.map(either => new EitherOps(either).map(f(b, _))) } def leftMap[C](f: A => C): Either[C, B] = eab match { case Left(a) => Left(f(a)) - case r @ Right(_) => EitherUtil.rightCast(r) + case r @ Right(_) => EitherUtil.leftCast(r) } def flatMap[D](f: B => Either[A, D]): Either[A, D] = eab match { - case l @ Left(_) => EitherUtil.leftCast(l) + case l @ Left(_) => EitherUtil.rightCast(l) case Right(b) => f(b) } @@ -163,7 +167,7 @@ final class EitherOps[A, B](val eab: Either[A, B]) extends AnyVal { } def traverse[F[_], D](f: B => F[D])(implicit F: Applicative[F]): F[Either[A, D]] = eab match { - case l @ Left(_) => F.pure(EitherUtil.leftCast(l)) + case l @ Left(_) => F.pure(EitherUtil.rightCast(l)) case Right(b) => F.map(f(b))(Right(_)) } @@ -311,10 +315,19 @@ final class CatchOnlyPartiallyApplied[T] private[syntax] { } } +final class LeftOps[A, B](val left: Left[A, B]) extends AnyVal { + /** Cast the right type parameter of the `Left`. */ + def rightCast[C]: Either[A, C] = left.asInstanceOf[Either[A, C]] +} + +final class RightOps[A, B](val right: Right[A, B]) extends AnyVal { + /** Cast the left type parameter of the `Right`. */ + def leftCast[C]: Either[C, B] = right.asInstanceOf[Either[C, B]] +} + +/** Convenience methods to use `Either` syntax inside `Either` syntax definitions. */ private[cats] object EitherUtil { - /** Cast the *right* type parameter of a `Left`. */ - def leftCast[A, B, C](l: Left[A, B]): Either[A, C] = l.asInstanceOf[Left[A, C]] + def leftCast[A, B, C](r: Right[A, B]): Either[C, B] = new RightOps(r).leftCast[C] - /** Cast the *left* type parameter of a `Right` */ - def rightCast[A, B, C](r: Right[A, B]): Either[C, B] = r.asInstanceOf[Right[C, B]] + def rightCast[A, B, C](l: Left[A, B]): Either[A, C] = new LeftOps(l).rightCast[C] } diff --git a/tests/src/test/scala/cats/tests/EitherTests.scala b/tests/src/test/scala/cats/tests/EitherTests.scala index 3f91e21f1f..5db908b543 100644 --- a/tests/src/test/scala/cats/tests/EitherTests.scala +++ b/tests/src/test/scala/cats/tests/EitherTests.scala @@ -44,6 +44,18 @@ class EitherTests extends CatsSuite { checkAll("Either[Int, String]", orderLaws.partialOrder(partialOrder)) checkAll("Either[Int, String]", orderLaws.order(order)) + test("Left/Right cast syntax") { + forAll { (e: Either[Int, String]) => + e match { + case l @ Left(_) => + l.rightCast[Double]: Either[Int, Double] + assert(true) + case r @ Right(_) => + r.leftCast[List[Byte]]: Either[List[Byte], String] + assert(true) + } + } + } test("implicit instances resolve specifically") { val eq = catsStdEqForEither[Int, String]