2 changes: 1 addition & 1 deletion kmir/src/kmir/alloc.py
@@ -35,7 +35,7 @@ def from_dict(dct: dict[str, Any]) -> GlobalAlloc:
case {'Memory': _}:
return Memory.from_dict(dct)
case _:
raise ValueError('Unsupported or invalid GlobalAlloc data: {dct}')
raise ValueError(f'Unsupported or invalid GlobalAlloc data: {dct}')


@dataclass
34 changes: 16 additions & 18 deletions kmir/src/kmir/decoding.py
@@ -11,7 +11,6 @@
ArbitraryFields,
ArrayT,
BoolT,
Direct,
EnumT,
Initialized,
IntT,
@@ -42,7 +41,7 @@

from pyk.kast import KInner

from .ty import FieldsShape, LayoutShape, MachineSize, Scalar, TagEncoding, Ty, TypeMetadata, UintTy
from .ty import FieldsShape, IntegerLength, LayoutShape, MachineSize, Scalar, TagEncoding, Ty, TypeMetadata, UintTy
from .value import Metadata


@@ -241,7 +240,7 @@ def _decode_enum(
fields=fields,
offsets=offsets,
# ---
tag_index=index,
index=index,
# ---
types=types,
)
@@ -282,15 +281,15 @@ def _decode_enum_single(
discriminants: list[int],
fields: list[list[Ty]],
offsets: list[MachineSize],
tag_index: int,
index: int,
types: Mapping[Ty, TypeMetadata],
) -> Value:
assert index == 0, 'Assumed index to always be 0 for Single(index)'

assert len(fields) == 1, 'Expected a single list of field types for single-variant enum'
tys = fields[0]

assert len(discriminants) == 1, 'Expected a single discriminant for single-variant enum'
discriminant = discriminants[0]
assert tag_index == discriminant, 'Assumed tag_index to be the same as the discriminant'

field_values = _decode_fields(data=data, tys=tys, offsets=offsets, types=types)
return AggregateValue(0, field_values)
@@ -310,18 +309,16 @@ def _decode_enum_multiple(
# ---
types: Mapping[Ty, TypeMetadata],
) -> Value:
if not isinstance(tag_encoding, Direct):
raise ValueError(f'Unsupported encoding: {tag_encoding}')

assert tag_field == 0, 'Assumed tag field to be zero'
assert len(offsets) == 1, 'Assumed offsets to only contain the tag offset'
tag_offset = offsets[0]
tag_value = _extract_tag_value(data=data, tag_offset=tag_offset, tag=tag)
assert tag_field == 0, 'Assumed tag field to be zero accordingly'
tag_offset = offsets[tag_field]
tag_value, width = _extract_tag(data=data, tag_offset=tag_offset, tag=tag)
discriminant = tag_encoding.decode(tag_value, width=width)

try:
variant_idx = discriminants.index(tag_value)
variant_idx = discriminants.index(discriminant)
except ValueError as err:
raise ValueError(f'Tag not found: {tag_value}') from err
raise ValueError(f'Discriminant not found: {discriminant}') from err

tys = fields[variant_idx]
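
The reworked _decode_enum_multiple no longer treats the raw tag as the discriminant: it extracts the tag value together with its width, lets the TagEncoding decode that value into a discriminant, and only then looks the discriminant up to obtain the variant index. A minimal, self-contained sketch of that flow; read_tag and lookup_variant are illustrative helpers, not kmir APIs:

def read_tag(data: bytes, offset: int, width: int) -> int:
    # Unsigned little-endian read of `width` bytes at `offset`, like _extract_tag does.
    return int.from_bytes(data[offset : offset + width], byteorder='little', signed=False)

def lookup_variant(discriminant: int, discriminants: list[int]) -> int:
    # The decoded discriminant selects the variant index.
    try:
        return discriminants.index(discriminant)
    except ValueError as err:
        raise ValueError(f'Discriminant not found: {discriminant}') from err

# With a Direct encoding the tag value already is the discriminant.
tag_value = read_tag(b'\x01', 0, 1)
assert lookup_variant(tag_value, [0, 1]) == 1
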

@@ -350,16 +347,17 @@ def _decode_fields(
return res


def _extract_tag_value(*, data: bytes, tag_offset: MachineSize, tag: Scalar) -> int:
def _extract_tag(*, data: bytes, tag_offset: MachineSize, tag: Scalar) -> tuple[int, IntegerLength]:
match tag:
case Initialized(
value=PrimitiveInt(
length=length,
signed=signed,
signed=False,
),
valid_range=_,
):
tag_data = data[tag_offset.in_bytes : tag_offset.in_bytes + length.value]
return int.from_bytes(tag_data, byteorder='little', signed=signed)
tag_value = int.from_bytes(tag_data, byteorder='little', signed=False)
return tag_value, length
case _:
raise ValueError('Unsupported tag: {tag}')
raise ValueError(f'Unsupported tag: {tag}')
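
_extract_tag replaces _extract_tag_value and returns the tag's IntegerLength alongside its value, because the caller now has to wrap at exactly that bit width when decoding a niche tag. The read itself is a plain unsigned little-endian slice; the offset and width below are made up for illustration:

# A 2-byte unsigned tag stored little-endian at byte offset 4 (illustrative values only).
data = bytes([0xAA, 0xAA, 0xAA, 0xAA, 0x2C, 0x01])
tag_value = int.from_bytes(data[4:6], byteorder='little', signed=False)
assert tag_value == 300  # 0x012C
# The width (2 bytes, i.e. I16 here) travels with the value so that Niche.decode
# can perform its wrapping subtraction at 16 bits.
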
75 changes: 70 additions & 5 deletions kmir/src/kmir/ty.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from abc import ABC
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
@@ -151,6 +156,13 @@ def from_raw(data: Any) -> EnumT:
case _:
raise _cannot_parse_as('EnumT', data)

def nbytes(self, types: Mapping[Ty, TypeMetadata]) -> int:
match self.layout:
case None:
raise ValueError(f'Cannot determine size, layout is missing for: {self}')
case LayoutShape(size=size):
return size.in_bytes


@dataclass
class LayoutShape:
@@ -349,6 +356,11 @@ class IntegerLength(Enum):
I64 = 8
I128 = 16

def wrapping_sub(self, x: int, y: int) -> int:
bit_width = 8 * self.value
mask = (1 << bit_width) - 1
return (x - y) & mask


@dataclass
class Float(Primitive): ...
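
IntegerLength.wrapping_sub mirrors Rust's fixed-width wrapping_sub: the difference is masked to the type's bit width, so an underflowing subtraction wraps around instead of going negative. For example, at the 8-bit width of a one-byte tag:

# Wrapping subtraction at 8 bits, as IntegerLength.I8.wrapping_sub(x, y) computes it.
bit_width = 8 * 1                # IntegerLength.I8.value == 1 byte
mask = (1 << bit_width) - 1
assert (0 - 1) & mask == 255     # underflow wraps to 255
assert (200 - 50) & mask == 150  # no wrap when the result fits in 8 bits
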
@@ -364,18 +376,71 @@ def from_raw(data: Any) -> TagEncoding:
match data:
case 'Direct':
return Direct()
case {'Niche': _}:
return Niche()
case {
'Niche': {
'untagged_variant': untagged_variant,
'niche_variants': niche_variants,
'niche_start': niche_start,
},
}:
return Niche(
untagged_variant=int(untagged_variant),
niche_variants=RangeInclusive.from_raw(niche_variants),
niche_start=int(niche_start),
)
case _:
raise _cannot_parse_as('TagEncoding', data)

@abstractmethod
def decode(self, tag: int, *, width: IntegerLength) -> int: ...


@dataclass
class Direct(TagEncoding): ...
class Direct(TagEncoding):
def decode(self, tag: int, *, width: IntegerLength) -> int:
# The tag directly stores the discriminant.
return tag


@dataclass
class Niche(TagEncoding): ...
class Niche(TagEncoding):
untagged_variant: int
niche_variants: RangeInclusive
niche_start: int

def decode(self, tag: int, *, width: IntegerLength) -> int:
# For this encoding, the discriminant and variant index of each variant coincide.
# To recover the variant index i from tag:
# i = tag.wrapping_sub(niche_start) + niche_variants.start
# If i ends up outside niche_variants, the tag must have encoded the untagged_variant.
i = width.wrapping_sub(tag, self.niche_start) + self.niche_variants.start
if not i in self.niche_variants:
return self.untagged_variant
return i


class RangeInclusive(NamedTuple):
start: int
end: int

@staticmethod
def from_raw(data: Any) -> RangeInclusive:
match data:
case {
'start': start,
'end': end,
}:
return RangeInclusive(
start=int(start),
end=int(end),
)
case _:
raise _cannot_parse_as('RangeInclusive', data)

def __contains__(self, x: object) -> bool:
if isinstance(x, int):
return self.start <= x <= self.end
raise TypeError(f'Method RangeInclusive.__contains__ is only supported for int, got: {x}')


class WrappingRange(NamedTuple):
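
Taken together, Niche.decode wrapping-subtracts niche_start from the tag, shifts by niche_variants.start, and treats any result outside niche_variants as the untagged variant. Worked through with the parameters of the Option<core::num::NonZero<u8>> layout in the new test data below (untagged_variant = 1, niche_variants = 0..=0, niche_start = 0, 8-bit tag); this is a plain Python restatement, not kmir output:

# Niche decoding with the parameters of the Option<core::num::NonZero<u8>> layout below.
untagged_variant, niche_start = 1, 0
niche_lo, niche_hi = 0, 0        # niche_variants = 0..=0
mask = (1 << 8) - 1              # 8-bit tag width

def decode_niche(tag: int) -> int:
    i = ((tag - niche_start) & mask) + niche_lo
    return i if niche_lo <= i <= niche_hi else untagged_variant

assert decode_niche(0) == 0      # the niche value 0 encodes variant 0 (None)
assert decode_niche(123) == 1    # any other byte encodes the untagged variant 1 (Some)
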
@@ -0,0 +1 @@
Aggregate ( variantIdx ( 0 ) , .List )
@@ -0,0 +1,150 @@
{
"bytes": [
0
],
"types": [
[
0,
{
"PrimitiveType": {
"Uint": "U8"
}
}
]
],
"typeInfo": {
"EnumType": {
"name": "core::option::Option<core::num::NonZero<u8>>",
"adt_def": 100,
"discriminants": [
0,
1
],
"fields": [
[],
[
0
]
],
"layout": {
"fields": {
"Arbitrary": {
"offsets": [
{
"num_bits": 0
}
]
}
},
"variants": {
"Multiple": {
"tag": {
"Initialized": {
"value": {
"Int": {
"length": "I8",
"signed": false
}
},
"valid_range": {
"start": 1,
"end": 0
}
}
},
"tag_encoding": {
"Niche": {
"untagged_variant": 1,
"niche_variants": {
"start": 0,
"end": 0
},
"niche_start": 0
}
},
"tag_field": 0,
"variants": [
{
"fields": {
"Arbitrary": {
"offsets": []
}
},
"variants": {
"Single": {
"index": 0
}
},
"abi": {
"Aggregate": {
"sized": true
}
},
"abi_align": 1,
"size": {
"num_bits": 0
}
},
{
"fields": {
"Arbitrary": {
"offsets": [
{
"num_bits": 0
}
]
}
},
"variants": {
"Single": {
"index": 1
}
},
"abi": {
"Scalar": {
"Initialized": {
"value": {
"Int": {
"length": "I8",
"signed": false
}
},
"valid_range": {
"start": 1,
"end": 255
}
}
}
},
"abi_align": 1,
"size": {
"num_bits": 8
}
}
]
}
},
"abi": {
"Scalar": {
"Initialized": {
"value": {
"Int": {
"length": "I8",
"signed": false
}
},
"valid_range": {
"start": 1,
"end": 0
}
}
}
},
"abi_align": 1,
"size": {
"num_bits": 8
}
}
}
}
}
@@ -0,0 +1 @@
Aggregate ( variantIdx ( 1 ) , ListItem ( Integer ( 123 , 8 , false ) ) )
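
Read against the Option<core::num::NonZero<u8>> layout above, the two expected terms appear to correspond to the input bytes 0 and 123 (the second input file is not shown in this view): the niche value 0 decodes to variant 0 (None) with no fields, while 123 falls outside the niche range and decodes to the untagged variant 1 (Some), carrying the byte as its single u8 field. A rough Python analogue of those K terms, illustrative only and not kmir's representation:

def decode_option_nonzero_u8(byte: int) -> tuple[int, list[tuple[int, int, bool]]]:
    # Returns (variant index, fields as (value, bit_width, signed) triples).
    if byte == 0:
        return 0, []              # Aggregate ( variantIdx ( 0 ) , .List )
    return 1, [(byte, 8, False)]  # Aggregate ( variantIdx ( 1 ) , ListItem ( Integer ( byte , 8 , false ) ) )

assert decode_option_nonzero_u8(0) == (0, [])
assert decode_option_nonzero_u8(123) == (1, [(123, 8, False)])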