Skip to content

Commit 40a45fd

Browse files
feat: implement the Schema (#6)
* Implement the Schema. * Address feedback * chore: remove unused pyarrow_fields.mojo --------- Co-authored-by: Krisztian Szucs <szucs.krisztian@gmail.com>
1 parent a44aa46 commit 40a45fd

File tree

9 files changed

+549
-226
lines changed

9 files changed

+549
-226
lines changed

firebolt/arrays/tests/test_nested.mojo

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ from testing import assert_equal, assert_true, assert_false
33

44
from firebolt.arrays import *
55
from firebolt.dtypes import *
6-
from firebolt.arrays.tests.utils import as_bool_array_scalar
6+
from firebolt.test_fixtures.bool_array import as_bool_array_scalar
77

88

99
def test_list_int_array():

firebolt/arrays/tests/test_primitive.mojo

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ from testing import assert_equal, assert_true, assert_false
22

33

44
from firebolt.arrays import *
5-
from firebolt.arrays.tests.utils import as_bool_array_scalar
5+
from firebolt.test_fixtures.bool_array import as_bool_array_scalar
66

77

88
def test_boolean_array():

firebolt/schema.mojo

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Define the Mojo representation of the Arrow Schema.
2+
3+
[Reference](https://arrow.apache.org/docs/python/generated/pyarrow.Schema.html#pyarrow.Schema)
4+
"""
5+
from .dtypes import Field
6+
from .c_data import CArrowSchema
7+
from collections import Dict
8+
from collections.string import StringSlice
9+
10+
11+
@value
struct Schema(Movable):
    """An ordered collection of fields plus string key/value metadata.

    This is the Mojo counterpart of an Arrow schema: `fields` keeps the
    fields in declaration order and `metadata` holds free-form string pairs.
    """

    # Fields of the schema, in declaration order.
    var fields: List[Field]
    # String key/value metadata attached to the schema.
    var metadata: Dict[String, String]

    fn __init__(
        out self,
        *,
        fields: List[Field] = List[Field](),
        metadata: Dict[String, String] = Dict[String, String](),
    ):
        """Initializes a schema with the given fields and metadata.

        Args:
            fields: The fields of the schema. Defaults to an empty list.
            metadata: The schema metadata. Defaults to an empty dict.
        """
        self.fields = fields
        self.metadata = metadata

    @staticmethod
    fn from_c(c_arrow_schema: CArrowSchema) raises -> Schema:
        """Builds a schema from the children of a CArrowSchema.

        Each of the `n_children` children is converted with `to_field()` and
        appended in order. The metadata of the returned schema is left empty.

        Args:
            c_arrow_schema: The C schema whose children become the fields.

        Returns:
            A schema holding one field per child.

        Raises:
            Error: If converting a child to a field fails.
        """
        var fields = List[Field]()
        for i in range(c_arrow_schema.n_children):
            var child = c_arrow_schema.children[i]
            var field = child[].to_field()
            fields.append(field)

        return Schema(fields=fields)

    fn append(mut self, owned field: Field):
        """Appends a field to the end of the schema.

        Args:
            field: The field to add; ownership is transferred to the schema.
        """
        self.fields.append(field^)

    fn names(self) -> List[String]:
        """Returns the names of the fields, in schema order."""
        var names = List[String]()
        for field in self.fields:
            names.append(field[].name)
        return names

    fn field(
        self,
        *,
        index: Optional[Int] = None,
        name: Optional[
            StringSlice[mut=False, origin=ImmutableAnyOrigin]
        ] = None,
    ) raises -> ref [self.fields] Field:
        """Returns the field at the given index or with the given name.

        Exactly one of `index` and `name` must be provided.

        Args:
            index: Position of the field in the schema.
            name: Name of the field; the first field with this name is
                returned.

        Returns:
            A reference to the matching field.

        Raises:
            Error: If both or neither of `index` and `name` are given, if
                `index` is out of range, or if no field is called `name`.
        """
        # `Optional` is truthy when it holds a value, so `index=0` still
        # counts as "provided" here.
        if index and name:
            raise Error("Either an index or a name must be provided, not both.")
        if index:
            var idx = index.value()
            var count = len(self.fields)
            # Reject indices the list cannot serve instead of forwarding them
            # to unchecked list indexing. Indices in [-count, count) are passed
            # through untouched, exactly as before.
            if idx < -count or idx >= count:
                raise Error(
                    "Field index "
                    + String(idx)
                    + " is out of range for a schema with "
                    + String(count)
                    + " fields."
                )
            return self.fields[idx]
        if not name:
            raise Error("Either an index or a name must be provided.")
        for field in self.fields:
            if field[].name.as_string_slice() == name.value():
                return field[]
        raise Error(
            StringSlice("Field with name `{}` not found.").format(name.value())
        )

firebolt/test_fixtures/__init__.mojo

Whitespace-only changes.

firebolt/tests/test_schema.mojo

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""Test the schema.mojo file."""
2+
from testing import assert_equal, assert_true
3+
from python import Python, PythonObject
4+
from firebolt.schema import Schema
5+
from firebolt.dtypes import (
6+
int8,
7+
int16,
8+
int32,
9+
int64,
10+
uint8,
11+
uint16,
12+
uint32,
13+
uint64,
14+
)
15+
from firebolt.dtypes import float16, float32, float64, binary, string, list_
16+
from firebolt.c_data import Field, CArrowSchema
17+
from firebolt.test_fixtures.pyarrow_fields import (
18+
build_list_of_list_of_ints,
19+
build_struct,
20+
)
21+
22+
23+
def test_schema_primitive_fields():
    """Check that a schema of primitive fields keeps their count and names."""

    # One field per primitive dtype, named field1 .. field13 in order.
    var expected = List[Field](
        Field("field1", int8),
        Field("field2", int16),
        Field("field3", int32),
        Field("field4", int64),
        Field("field5", uint8),
        Field("field6", uint16),
        Field("field7", uint32),
        Field("field8", uint64),
        Field("field9", float16),
        Field("field10", float32),
        Field("field11", float64),
        Field("field12", binary),
        Field("field13", string),
    )
    var expected_count = len(expected)

    var schema = Schema(fields=expected)

    # The schema must hold exactly as many fields as it was given.
    assert_equal(len(schema.fields), expected_count)

    # Positional lookup must return the field with the matching 1-based name.
    var position = 0
    while position < expected_count:
        assert_equal(
            schema.field(index=position).name, "field" + String(position + 1)
        )
        position += 1
51+
52+
53+
def test_from_c_schema() -> None:
    """Convert a pyarrow schema with nested fields through CArrowSchema."""
    var pa = Python.import_module("pyarrow")

    # A list-of-int32 field and a two-member struct field.
    var list_field = pa.field("field1", pa.list_(pa.int32()))
    var struct_field = pa.field(
        "field2",
        pa.`struct`(
            [
                pa.field("field_a", pa.int32()),
                pa.field("field_b", pa.float64()),
            ]
        ),
    )
    var pa_schema = pa.schema([list_field, struct_field])

    var c_arrow_schema = CArrowSchema.from_pyarrow(pa_schema)
    var converted = Schema.from_c(c_arrow_schema)

    assert_equal(len(converted.fields), 2)

    # First field: a list whose element dtype is an integer.
    var first = converted.field(index=0)
    assert_true(first.dtype.is_list())
    assert_true(first.dtype.fields[0].dtype.is_integer())

    # Second field: a struct whose members keep their names and order.
    var second = converted.field(index=1)
    assert_true(second.dtype.is_struct())
    assert_equal(second.dtype.fields[0].name, "field_a")
    assert_equal(second.dtype.fields[1].name, "field_b")

firebolt/tests/test_utils.mojo

Lines changed: 0 additions & 2 deletions
This file was deleted.

0 commit comments

Comments
 (0)