@@ -21,17 +21,42 @@ public abstract class Aggregate extends SingleInputRel implements HasExtension {
2121
2222 @ Override
2323 protected Type .Struct deriveRecordType () {
24- return TypeCreator .REQUIRED .struct (
25- Stream .concat (
26- // unique grouping expressions
27- getGroupings ().stream ()
28- .flatMap (g -> g .getExpressions ().stream ())
29- .collect (Collectors .toCollection (LinkedHashSet ::new ))
30- .stream ()
31- .map (Expression ::getType ),
32-
33- // measures
34- getMeasures ().stream ().map (t -> t .getFunction ().getType ())));
24+ // If there's only one grouping set (or none), the nullability rule doesn't apply.
25+ if (getGroupings ().size () <= 1 ) {
26+ final Stream <Type > groupingTypes =
27+ getGroupings ().stream ()
28+ .flatMap (g -> g .getExpressions ().stream ())
29+ .map (Expression ::getType );
30+ final Stream <Type > measureTypes = getMeasures ().stream ().map (t -> t .getFunction ().getType ());
31+ return TypeCreator .REQUIRED .struct (Stream .concat (groupingTypes , measureTypes ));
32+ }
33+
34+ final LinkedHashSet <Expression > uniqueGroupingExpressions =
35+ getGroupings ().stream ()
36+ .flatMap (g -> g .getExpressions ().stream ())
37+ .collect (Collectors .toCollection (LinkedHashSet ::new ));
38+
39+ // For each unique grouping expression, determine its final nullability based on the spec.
40+ final Stream <Type > groupingTypes =
41+ uniqueGroupingExpressions .stream ()
42+ .map (
43+ expr -> {
44+ // the code below implements the following statement from the spec
45+ // (https://substrait.io/relations/logical_relations/#aggregate-operation):
46+ // "The values for the grouping expression columns that are not
47+ // part of the grouping set for a particular record will be set to null."
48+ final boolean appearsInAllSets =
49+ getGroupings ().stream ().allMatch (g -> g .getExpressions ().contains (expr ));
50+ if (appearsInAllSets ) {
51+ return expr .getType ();
52+ } else {
53+ return TypeCreator .asNullable (expr .getType ());
54+ }
55+ });
56+
57+ final Stream <Type > measureTypes = getMeasures ().stream ().map (t -> t .getFunction ().getType ());
58+
59+ return TypeCreator .REQUIRED .struct (Stream .concat (groupingTypes , measureTypes ));
3560 }
3661
3762 @ Override
0 commit comments