Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 48 additions & 2 deletions lib/IRGen/GenCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -800,8 +800,35 @@ void irgen::emitScalarCheckedCast(IRGenFunction &IGF,
assert(sourceType.isObject());
assert(targetType.isObject());

if (auto sourceOptObjectType = sourceType.getOptionalObjectType()) {
llvm::BasicBlock *nilCheckBB = nullptr;
llvm::BasicBlock *nilMergeBB = nullptr;

// Merge the nil check and return the merged result: either nil or the value.
auto returnNilCheckedResult = [&](IRBuilder &Builder,
Explosion &nonNilResult) {
if (nilCheckBB) {
auto notNilBB = Builder.GetInsertBlock();
Builder.CreateBr(nilMergeBB);

Builder.emitBlock(nilMergeBB);
// Insert result phi.
Explosion result;
while (!nonNilResult.empty()) {
auto val = nonNilResult.claimNext();
auto valTy = cast<llvm::PointerType>(val->getType());
auto nil = llvm::ConstantPointerNull::get(valTy);
auto phi = Builder.CreatePHI(valTy, 2);
phi->addIncoming(nil, nilCheckBB);
phi->addIncoming(val, notNilBB);
result.add(phi);
}
out = std::move(result);
} else {
out = std::move(nonNilResult);
}
};

if (auto sourceOptObjectType = sourceType.getOptionalObjectType()) {
// Translate the value from an enum representation to a possibly-null
// representation. Note that we assume that this projection is safe
// for the particular case of an optional class-reference or metatype
Expand All @@ -813,6 +840,22 @@ void irgen::emitScalarCheckedCast(IRGenFunction &IGF,
assert(value.empty());
value = std::move(optValue);
sourceType = sourceOptObjectType;

// We need a null-check because the runtime function can't handle null in
// some of the cases.
if (targetType.isExistentialType()) {
auto &Builder = IGF.Builder;
auto val = value.getAll()[0];
auto isNotNil = Builder.CreateICmpNE(
val, llvm::ConstantPointerNull::get(
cast<llvm::PointerType>(val->getType())));
auto *isNotNilContBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext());
nilMergeBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext());
nilCheckBB = Builder.GetInsertBlock();
Builder.CreateCondBr(isNotNil, isNotNilContBB, nilMergeBB);

Builder.emitBlock(isNotNilContBB);
}
}

// If the source value is a metatype, either do a metatype-to-metatype
Expand Down Expand Up @@ -877,11 +920,14 @@ void irgen::emitScalarCheckedCast(IRGenFunction &IGF,
}

if (targetType.isExistentialType()) {
Explosion outRes;
emitScalarExistentialDowncast(IGF, instance, sourceType, targetType,
mode, /*not a metatype*/ None, out);
mode, /*not a metatype*/ None, outRes);
returnNilCheckedResult(IGF.Builder, outRes);
return;
}

Explosion outRes;
llvm::Value *result = emitClassDowncast(IGF, instance, targetType, mode);
out.add(result);
}
27 changes: 26 additions & 1 deletion test/IRGen/casts.sil
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-swift-frontend -assume-parsing-unqualified-ownership-sil %s -emit-ir -disable-objc-attr-requires-foundation-module | %FileCheck %s
// RUN: %target-swift-frontend -assume-parsing-unqualified-ownership-sil %s -emit-ir -disable-objc-attr-requires-foundation-module | %FileCheck %s -DINT=i%target-ptrsize

// REQUIRES: CPU=i386 || CPU=x86_64
// XFAIL: linux
Expand Down Expand Up @@ -268,6 +268,31 @@ nay:
unreachable
}

// CHECK: define swiftcc {{.*}} @checked_downcast_optional_class_to_ex([[INT]])
// CHECK: entry:
// CHECK: [[V1:%.*]] = inttoptr [[INT]] %0 to %T5casts1AC*
// CHECK: [[V2:%.*]] = icmp ne %T5casts1AC* [[V1]], null
// CHECK: br i1 [[V2]], label %[[LBL:.*]], label
// CHECK: [[LBL]]:
// CHECK: [[V4:%.*]] = bitcast %T5casts1AC* [[V1]] to %swift.type**
// CHECK: load %swift.type*, %swift.type** [[V4]]
sil @checked_downcast_optional_class_to_ex : $@convention(thin) (@guaranteed Optional<A>) -> @owned Optional<CP> {
bb0(%0 : $Optional<A>):
checked_cast_br %0 : $Optional<A> to $CP, bb1, bb2

bb1(%3 : $CP):
%4 = enum $Optional<CP>, #Optional.some!enumelt.1, %3 : $CP
retain_value %0 : $Optional<A>
br bb3(%4 : $Optional<CP>)

bb2:
%7 = enum $Optional<CP>, #Optional.none!enumelt
br bb3(%7 : $Optional<CP>)

bb3(%9 : $Optional<CP>):
return %9 : $Optional<CP>
}

// CHECK-LABEL: define{{( protected)?}} swiftcc void @checked_metatype_to_object_casts
sil @checked_metatype_to_object_casts : $@convention(thin) <T> (@thick Any.Type) -> () {
entry(%e : $@thick Any.Type):
Expand Down
78 changes: 78 additions & 0 deletions test/Interpreter/casts.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// RUN: %target-run-simple-swift
// REQUIRES: executable_test
// REQUIRES: objc_interop

import StdlibUnittest
import Foundation

protocol P : class { }
protocol C : class { }

class Foo : NSObject {}
var Casts = TestSuite("Casts")


@inline(never)
func castit<ObjectType>(_ o: NSObject?, _ t: ObjectType.Type) -> ObjectType? {
return o as? ObjectType
}

@inline(never)
func castitExistential<ObjectType>(_ o: C?, _ t: ObjectType.Type) -> ObjectType? {
return o as? ObjectType
}

Casts.test("cast optional<nsobject> to protocol") {
if let obj = castit(nil, P.self) {
print("fail")
expectUnreachable()
} else {
print("success")
}
}

Casts.test("cast optional<nsobject> to protocol meta") {
if let obj = castit(nil, P.Type.self) {
print("fail")
expectUnreachable()
} else {
print("success")
}
}
Casts.test("cast optional<protocol> to protocol") {
if let obj = castitExistential(nil, P.self) {
print("fail")
expectUnreachable()
} else {
print("success")
}
}

Casts.test("cast optional<protocol> to class") {
if let obj = castitExistential(nil, Foo.self) {
print("fail")
expectUnreachable()
} else {
print("success")
}
}

Casts.test("cast optional<protocol> to protocol meta") {
if let obj = castitExistential(nil, P.Type.self) {
expectUnreachable()
print("fail")
} else {
print("success")
}
}

Casts.test("cast optional<protocol> to class meta") {
if let obj = castitExistential(nil, Foo.Type.self) {
expectUnreachable()
print("fail")
} else {
print("success")
}
}

runAllTests()