diff --git a/lib/IRGen/GenCast.cpp b/lib/IRGen/GenCast.cpp index f5d350f075ca3..6a2fa61080b5e 100644 --- a/lib/IRGen/GenCast.cpp +++ b/lib/IRGen/GenCast.cpp @@ -803,8 +803,35 @@ void irgen::emitScalarCheckedCast(IRGenFunction &IGF, assert(sourceType.isObject()); assert(targetType.isObject()); - if (auto sourceOptObjectType = sourceType.getOptionalObjectType()) { + llvm::BasicBlock *nilCheckBB = nullptr; + llvm::BasicBlock *nilMergeBB = nullptr; + + // Merge the nil check and return the merged result: either nil or the value. + auto returnNilCheckedResult = [&](IRBuilder &Builder, + Explosion &nonNilResult) { + if (nilCheckBB) { + auto notNilBB = Builder.GetInsertBlock(); + Builder.CreateBr(nilMergeBB); + + Builder.emitBlock(nilMergeBB); + // Insert result phi. + Explosion result; + while (!nonNilResult.empty()) { + auto val = nonNilResult.claimNext(); + auto valTy = cast(val->getType()); + auto nil = llvm::ConstantPointerNull::get(valTy); + auto phi = Builder.CreatePHI(valTy, 2); + phi->addIncoming(nil, nilCheckBB); + phi->addIncoming(val, notNilBB); + result.add(phi); + } + out = std::move(result); + } else { + out = std::move(nonNilResult); + } + }; + if (auto sourceOptObjectType = sourceType.getOptionalObjectType()) { // Translate the value from an enum representation to a possibly-null // representation. Note that we assume that this projection is safe // for the particular case of an optional class-reference or metatype @@ -816,6 +843,22 @@ void irgen::emitScalarCheckedCast(IRGenFunction &IGF, assert(value.empty()); value = std::move(optValue); sourceType = sourceOptObjectType; + + // We need a null-check because the runtime function can't handle null in + // some of the cases. + if (targetType.isExistentialType()) { + auto &Builder = IGF.Builder; + auto val = value.getAll()[0]; + auto isNotNil = Builder.CreateICmpNE( + val, llvm::ConstantPointerNull::get( + cast(val->getType()))); + auto *isNotNilContBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext()); + nilMergeBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext()); + nilCheckBB = Builder.GetInsertBlock(); + Builder.CreateCondBr(isNotNil, isNotNilContBB, nilMergeBB); + + Builder.emitBlock(isNotNilContBB); + } } // If the source value is a metatype, either do a metatype-to-metatype @@ -880,11 +923,14 @@ void irgen::emitScalarCheckedCast(IRGenFunction &IGF, } if (targetType.isExistentialType()) { + Explosion outRes; emitScalarExistentialDowncast(IGF, instance, sourceType, targetType, - mode, /*not a metatype*/ None, out); + mode, /*not a metatype*/ None, outRes); + returnNilCheckedResult(IGF.Builder, outRes); return; } + Explosion outRes; llvm::Value *result = emitClassDowncast(IGF, instance, targetType, mode); out.add(result); } diff --git a/test/IRGen/casts.sil b/test/IRGen/casts.sil index 1d8b0de90a438..65a63f21f5870 100644 --- a/test/IRGen/casts.sil +++ b/test/IRGen/casts.sil @@ -271,6 +271,31 @@ nay: unreachable } +// CHECK: define swiftcc {{.*}} @checked_downcast_optional_class_to_ex([[INT]]) +// CHECK: entry: +// CHECK: [[V1:%.*]] = inttoptr [[INT]] %0 to %T5casts1AC* +// CHECK: [[V2:%.*]] = icmp ne %T5casts1AC* [[V1]], null +// CHECK: br i1 [[V2]], label %[[LBL:.*]], label +// CHECK: [[LBL]]: +// CHECK: [[V4:%.*]] = bitcast %T5casts1AC* [[V1]] to %swift.type** +// CHECK: load %swift.type*, %swift.type** [[V4]] +sil @checked_downcast_optional_class_to_ex : $@convention(thin) (@guaranteed Optional) -> @owned Optional { +bb0(%0 : $Optional): + checked_cast_br %0 : $Optional to $CP, bb1, bb2 + +bb1(%3 : $CP): + %4 = enum $Optional, #Optional.some!enumelt.1, %3 : $CP + retain_value %0 : $Optional + br bb3(%4 : $Optional) + +bb2: + %7 = enum $Optional, #Optional.none!enumelt + br bb3(%7 : $Optional) + +bb3(%9 : $Optional): + return %9 : $Optional +} + // CHECK-LABEL: define{{( protected)?}} swiftcc void @checked_metatype_to_object_casts sil @checked_metatype_to_object_casts : $@convention(thin) (@thick Any.Type) -> () { entry(%e : $@thick Any.Type): diff --git a/test/Interpreter/casts.swift b/test/Interpreter/casts.swift new file mode 100644 index 0000000000000..c063c8e8ed81a --- /dev/null +++ b/test/Interpreter/casts.swift @@ -0,0 +1,78 @@ +// RUN: %target-run-simple-swift +// REQUIRES: executable_test +// REQUIRES: objc_interop + +import StdlibUnittest +import Foundation + +protocol P : class { } +protocol C : class { } + +class Foo : NSObject {} +var Casts = TestSuite("Casts") + + +@inline(never) +func castit(_ o: NSObject?, _ t: ObjectType.Type) -> ObjectType? { + return o as? ObjectType +} + +@inline(never) +func castitExistential(_ o: C?, _ t: ObjectType.Type) -> ObjectType? { + return o as? ObjectType +} + +Casts.test("cast optional to protocol") { + if let obj = castit(nil, P.self) { + print("fail") + expectUnreachable() + } else { + print("success") + } +} + +Casts.test("cast optional to protocol meta") { + if let obj = castit(nil, P.Type.self) { + print("fail") + expectUnreachable() + } else { + print("success") + } +} +Casts.test("cast optional to protocol") { + if let obj = castitExistential(nil, P.self) { + print("fail") + expectUnreachable() + } else { + print("success") + } +} + +Casts.test("cast optional to class") { + if let obj = castitExistential(nil, Foo.self) { + print("fail") + expectUnreachable() + } else { + print("success") + } +} + +Casts.test("cast optional to protocol meta") { + if let obj = castitExistential(nil, P.Type.self) { + expectUnreachable() + print("fail") + } else { + print("success") + } +} + +Casts.test("cast optional to class meta") { + if let obj = castitExistential(nil, Foo.Type.self) { + expectUnreachable() + print("fail") + } else { + print("success") + } +} + +runAllTests()