Skip to content

Commit

Permalink
add implicit cast from enum to union
Browse files Browse the repository at this point in the history
when the enum is the tag type of the union and is comptime known
to be of a void field of the union

See #642
  • Loading branch information
andrewrk committed Dec 6, 2017
1 parent 63a2f9a commit 960914a
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 8 deletions.
15 changes: 10 additions & 5 deletions src/analyze.cpp
Expand Up @@ -2244,12 +2244,10 @@ static void resolve_union_zero_bits(CodeGen *g, TypeTableEntry *union_type) {
TypeTableEntry *enum_type = analyze_type_expr(g, scope, enum_type_node);
if (type_is_invalid(enum_type)) {
union_type->data.unionation.is_invalid = true;
union_type->data.unionation.embedded_in_current = false;
return;
}
if (enum_type->id != TypeTableEntryIdEnum) {
union_type->data.unionation.is_invalid = true;
union_type->data.unionation.embedded_in_current = false;
add_node_error(g, enum_type_node,
buf_sprintf("expected enum tag type, found '%s'", buf_ptr(&enum_type->name)));
return;
Expand Down Expand Up @@ -3319,7 +3317,7 @@ TypeStructField *find_struct_type_field(TypeTableEntry *type_entry, Buf *name) {

TypeUnionField *find_union_type_field(TypeTableEntry *type_entry, Buf *name) {
assert(type_entry->id == TypeTableEntryIdUnion);
assert(type_entry->data.unionation.complete);
assert(type_entry->data.unionation.zero_bits_known);
for (uint32_t i = 0; i < type_entry->data.unionation.src_field_count; i += 1) {
TypeUnionField *field = &type_entry->data.unionation.fields[i];
if (buf_eql_buf(field->enum_field->name, name)) {
Expand All @@ -3331,7 +3329,7 @@ TypeUnionField *find_union_type_field(TypeTableEntry *type_entry, Buf *name) {

TypeUnionField *find_union_field_by_tag(TypeTableEntry *type_entry, const BigInt *tag) {
assert(type_entry->id == TypeTableEntryIdUnion);
assert(type_entry->data.unionation.complete);
assert(type_entry->data.unionation.zero_bits_known);
assert(type_entry->data.unionation.gen_tag_index != SIZE_MAX);
for (uint32_t i = 0; i < type_entry->data.unionation.src_field_count; i += 1) {
TypeUnionField *field = &type_entry->data.unionation.fields[i];
Expand Down Expand Up @@ -3888,14 +3886,21 @@ bool handle_is_ptr(TypeTableEntry *type_entry) {
return false;
case TypeTableEntryIdArray:
case TypeTableEntryIdStruct:
case TypeTableEntryIdUnion:
return type_has_bits(type_entry);
case TypeTableEntryIdErrorUnion:
return type_has_bits(type_entry->data.error.child_type);
case TypeTableEntryIdMaybe:
return type_has_bits(type_entry->data.maybe.child_type) &&
type_entry->data.maybe.child_type->id != TypeTableEntryIdPointer &&
type_entry->data.maybe.child_type->id != TypeTableEntryIdFn;
case TypeTableEntryIdUnion:
assert(type_entry->data.unionation.complete);
if (type_entry->data.unionation.gen_field_count == 0)
return false;
if (!type_has_bits(type_entry))
return false;
return true;

}
zig_unreachable();
}
Expand Down
5 changes: 2 additions & 3 deletions src/codegen.cpp
Expand Up @@ -3946,8 +3946,6 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) {
case TypeTableEntryIdUnion:
{
LLVMTypeRef union_type_ref = type_entry->data.unionation.union_type_ref;
ConstExprValue *payload_value = const_val->data.x_union.payload;
assert(payload_value != nullptr);

if (type_entry->data.unionation.gen_field_count == 0) {
if (type_entry->data.unionation.gen_tag_index == SIZE_MAX) {
Expand All @@ -3960,7 +3958,8 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) {

LLVMValueRef union_value_ref;
bool make_unnamed_struct;
if (!type_has_bits(payload_value->type)) {
ConstExprValue *payload_value = const_val->data.x_union.payload;
if (payload_value == nullptr || !type_has_bits(payload_value->type)) {
if (type_entry->data.unionation.gen_tag_index == SIZE_MAX)
return LLVMGetUndef(type_entry->type_ref);

Expand Down
79 changes: 79 additions & 0 deletions src/ir.cpp
Expand Up @@ -7468,6 +7468,17 @@ static ImplicitCastMatchResult ir_types_match_with_implicit_cast(IrAnalyze *ira,
}
}

// implicit enum to union which has the enum as the tag type
if (expected_type->id == TypeTableEntryIdUnion && actual_type->id == TypeTableEntryIdEnum &&
(expected_type->data.unionation.decl_node->data.container_decl.auto_enum ||
expected_type->data.unionation.decl_node->data.container_decl.init_arg_expr != nullptr))
{
type_ensure_zero_bits_known(ira->codegen, expected_type);
if (expected_type->data.unionation.tag_type == actual_type) {
return ImplicitCastMatchResultYes;
}
}

// implicit undefined literal to anything
if (actual_type->id == TypeTableEntryIdUndefLit) {
return ImplicitCastMatchResultYes;
Expand Down Expand Up @@ -8370,6 +8381,63 @@ static IrInstruction *ir_analyze_undefined_to_anything(IrAnalyze *ira, IrInstruc
return result;
}

static IrInstruction *ir_analyze_enum_to_union(IrAnalyze *ira, IrInstruction *source_instr,
IrInstruction *target, TypeTableEntry *wanted_type)
{
assert(wanted_type->id == TypeTableEntryIdUnion);
assert(target->value.type->id == TypeTableEntryIdEnum);

if (instr_is_comptime(target)) {
ConstExprValue *val = ir_resolve_const(ira, target, UndefBad);
if (!val)
return ira->codegen->invalid_instruction;
TypeUnionField *union_field = find_union_field_by_tag(wanted_type, &val->data.x_enum_tag);
assert(union_field != nullptr);
type_ensure_zero_bits_known(ira->codegen, union_field->type_entry);
if (!union_field->type_entry->zero_bits) {
AstNode *field_node = wanted_type->data.unionation.decl_node->data.container_decl.fields.at(
union_field->enum_field->decl_index);
ErrorMsg *msg = ir_add_error(ira, source_instr,
buf_sprintf("cast to union '%s' must initialize '%s' field '%s'",
buf_ptr(&wanted_type->name),
buf_ptr(&union_field->type_entry->name),
buf_ptr(union_field->name)));
add_error_note(ira->codegen, msg, field_node,
buf_sprintf("field '%s' declared here", buf_ptr(union_field->name)));
return ira->codegen->invalid_instruction;
}
IrInstruction *result = ir_create_const(&ira->new_irb, source_instr->scope,
source_instr->source_node, wanted_type);
result->value.special = ConstValSpecialStatic;
result->value.type = wanted_type;
bigint_init_bigint(&result->value.data.x_union.tag, &val->data.x_enum_tag);
return result;
}

// if the union has all fields 0 bits, we can do it
// and in fact it's a noop cast because the union value is just the enum value
if (wanted_type->data.unionation.gen_field_count == 0) {
IrInstruction *result = ir_build_cast(&ira->new_irb, target->scope, target->source_node, wanted_type, target, CastOpNoop);
result->value.type = wanted_type;
return result;
}

ErrorMsg *msg = ir_add_error(ira, source_instr,
buf_sprintf("runtime cast to union '%s' which has non-void fields",
buf_ptr(&wanted_type->name)));
for (uint32_t i = 0; i < wanted_type->data.unionation.src_field_count; i += 1) {
TypeUnionField *union_field = &wanted_type->data.unionation.fields[i];
if (type_has_bits(union_field->type_entry)) {
AstNode *field_node = wanted_type->data.unionation.decl_node->data.container_decl.fields.at(i);
add_error_note(ira->codegen, msg, field_node,
buf_sprintf("field '%s' has type '%s'",
buf_ptr(union_field->name),
buf_ptr(&union_field->type_entry->name)));
}
}
return ira->codegen->invalid_instruction;
}

static IrInstruction *ir_analyze_widen_or_shorten(IrAnalyze *ira, IrInstruction *source_instr,
IrInstruction *target, TypeTableEntry *wanted_type)
{
Expand Down Expand Up @@ -8919,6 +8987,17 @@ static IrInstruction *ir_analyze_cast(IrAnalyze *ira, IrInstruction *source_inst
}
}

// explicit enum to union which has the enum as the tag type
if (wanted_type->id == TypeTableEntryIdUnion && actual_type->id == TypeTableEntryIdEnum &&
(wanted_type->data.unionation.decl_node->data.container_decl.auto_enum ||
wanted_type->data.unionation.decl_node->data.container_decl.init_arg_expr != nullptr))
{
type_ensure_zero_bits_known(ira->codegen, wanted_type);
if (wanted_type->data.unionation.tag_type == actual_type) {
return ir_analyze_enum_to_union(ira, source_instr, value, wanted_type);
}
}

// explicit cast from undefined to anything
if (actual_type->id == TypeTableEntryIdUndefLit) {
return ir_analyze_undefined_to_anything(ira, source_instr, value, wanted_type);
Expand Down
7 changes: 7 additions & 0 deletions test/cases/union.zig
Expand Up @@ -190,3 +190,10 @@ test "cast union to tag type of union" {
fn testCastUnionToTagType(x: &const TheUnion) {
assert(TheTag(*x) == TheTag.B);
}

test "cast tag type of union to union" {
var x: Value2 = Letter2.B;
assert(Letter2(x) == Letter2.B);
}
const Letter2 = enum { A, B, C };
const Value2 = union(Letter2) { A: i32, B, C, };
31 changes: 31 additions & 0 deletions test/compile_errors.zig
Expand Up @@ -2696,4 +2696,35 @@ pub fn addCases(cases: &tests.CompileErrorContext) {
,
".tmp_source.zig:6:16: error: enum 'Foo' has no tag matching integer value 0",
".tmp_source.zig:1:13: note: 'Foo' declared here");

cases.add("comptime cast enum to union but field has payload",
\\const Letter = enum { A, B, C };
\\const Value = union(Letter) {
\\ A: i32,
\\ B,
\\ C,
\\};
\\export fn entry() {
\\ var x: Value = Letter.A;
\\}
,
".tmp_source.zig:8:26: error: cast to union 'Value' must initialize 'i32' field 'A'",
".tmp_source.zig:3:5: note: field 'A' declared here");

cases.add("runtime cast to union which has non-void fields",
\\const Letter = enum { A, B, C };
\\const Value = union(Letter) {
\\ A: i32,
\\ B,
\\ C,
\\};
\\export fn entry() {
\\ foo(Letter.A);
\\}
\\fn foo(l: Letter) {
\\ var x: Value = l;
\\}
,
".tmp_source.zig:11:20: error: runtime cast to union 'Value' which has non-void fields",
".tmp_source.zig:3:5: note: field 'A' has type 'i32'");
}

0 comments on commit 960914a

Please sign in to comment.