@@ -69,6 +69,10 @@ struct TWideUnboxedHasher
6969 const TKeyTypes& Types;
7070};
7171
72+ bool HasMemoryForProcessing () {
73+ return !TlsAllocState->IsMemoryYellowZoneEnabled ();
74+ }
75+
7276using TEqualsPtr = bool (*)(const NUdf::TUnboxedValuePod*, const NUdf::TUnboxedValuePod*);
7377using THashPtr = NUdf::THashType(*)(const NUdf::TUnboxedValuePod*);
7478
@@ -99,21 +103,38 @@ struct TStorageWrapper
99103 }
100104};
101105
102- std::optional<size_t > EstimateUvPackSize (const TUnboxedValuePod* items, size_t width ) {
106+ std::optional<size_t > EstimateUvPackSize (const TArrayRef< const TUnboxedValuePod> items, const TArrayRef<TType* const > types ) {
103107 constexpr const size_t uvSize = sizeof (TUnboxedValuePod);
104108
105109 size_t sizeSum = 0 ;
106110
107- const TUnboxedValuePod* itemPtr = items;
108- for (size_t i = 0 ; i < width; ++i, ++itemPtr) {
109- const TUnboxedValuePod& item = *itemPtr;
111+ auto currType = types.begin ();
112+ for (const auto & item : items) {
110113 if (!item.HasValue () || item.IsEmbedded () || item.IsInvalid ()) {
111114 sizeSum += uvSize;
112115 } else if (item.IsString ()) {
113116 sizeSum += uvSize + item.AsStringRef ().Size ();
114- } else {
117+ } else if (!item. IsBoxed ()) {
115118 return {};
119+ } else {
120+ auto ty = *currType;
121+ while (ty->IsOptional ()) {
122+ ty = AS_TYPE (TOptionalType, ty)->GetItemType ();
123+ }
124+ if (ty->IsTuple ()) {
125+ auto tupleType = AS_TYPE (TTupleType, ty);
126+ auto elements = tupleType->GetElements ();
127+ auto tupleSize = EstimateUvPackSize (TArrayRef (item.GetElements (), elements.size ()), elements);
128+ if (!tupleSize.has_value ()) {
129+ return {};
130+ }
131+ // Tuple contents are generally boxed into a TDirectArrayHolderInplace instance
132+ sizeSum += uvSize + sizeof (TDirectArrayHolderInplace) + tupleSize.value ();
133+ } else {
134+ return {};
135+ }
116136 }
137+ ++currType;
117138 }
118139
119140 return sizeSum;
@@ -124,25 +145,40 @@ class TMemoryEstimationHelper
124145{
125146private:
126147 static std::optional<size_t > GetUVSizeBound (TType* type) {
127- using NYql::NUdf::EDataSlot;
148+ if (type->IsData ()) {
149+ using NYql::NUdf::EDataSlot;
128150
129- bool optional = false ;
130- auto dataSlot = UnpackOptionalData (type, optional)->GetDataSlot ();
151+ bool optional = false ;
152+ auto dataSlot = UnpackOptionalData (type, optional)->GetDataSlot ();
131153
132- if (dataSlot.Empty ()) {
133- return {};
134- }
154+ if (dataSlot.Empty ()) {
155+ return {};
156+ }
135157
136- switch (dataSlot.GetRef ()) {
137- case EDataSlot::DyNumber:
138- case EDataSlot::Json:
139- case EDataSlot::JsonDocument:
140- case EDataSlot::Yson:
141- case EDataSlot::Utf8:
142- case EDataSlot::String:
158+ switch (dataSlot.GetRef ()) {
159+ case EDataSlot::DyNumber:
160+ case EDataSlot::Json:
161+ case EDataSlot::JsonDocument:
162+ case EDataSlot::Yson:
163+ case EDataSlot::Utf8:
164+ case EDataSlot::String:
165+ return {};
166+ default :
167+ return sizeof (TUnboxedValuePod);
168+ }
169+ } else if (type->IsTuple ()) {
170+ size_t result = 0 ;
171+ const auto tupleElements = AS_TYPE (TTupleType, type)->GetElements ();
172+ for (auto * element : tupleElements) {
173+ auto sz = GetUVSizeBound (element);
174+ if (!sz.has_value ()) {
175+ return {};
176+ }
177+ result += sz.value ();
178+ }
179+ return result + sizeof (TUnboxedValuePod);
180+ } else {
143181 return {};
144- default :
145- return sizeof (TUnboxedValuePod);
146182 }
147183 }
148184
@@ -163,17 +199,21 @@ class TMemoryEstimationHelper
163199 std::optional<size_t > StateSizeBound;
164200 std::optional<size_t > KeySizeBound;
165201 const size_t KeyWidth;
202+ const std::vector<TType*> KeyItemTypes;
166203
167204 TMemoryEstimationHelper (std::vector<TType*> keyItemTypes, std::vector<TType*> stateItemTypes)
168205 : KeyWidth(keyItemTypes.size())
206+ , KeyItemTypes(keyItemTypes)
169207 {
170208 KeySizeBound = GetMultiUVSizeBound (keyItemTypes);
171209 StateSizeBound = GetMultiUVSizeBound (stateItemTypes);
172210 }
173211
174212 std::optional<size_t > EstimateKeySize (const TUnboxedValuePod* keyPack) const
175213 {
176- return EstimateUvPackSize (keyPack, KeyWidth);
214+ return EstimateUvPackSize (
215+ TArrayRef<const TUnboxedValuePod>(keyPack, KeyWidth),
216+ TArrayRef<TType* const >(KeyItemTypes.begin (), KeyItemTypes.end ()));
177217 }
178218};
179219
@@ -215,16 +255,19 @@ class TGenericAggregation: public IAggregation
215255 const NDqHashOperatorCommon::TCombinerNodes& Nodes;
216256 size_t StateWidth;
217257 size_t StateSize;
258+ const std::vector<TType*>& StateItemTypes;
218259
219260public:
220261 TGenericAggregation (
221262 TComputationContext& ctx,
222- const NDqHashOperatorCommon::TCombinerNodes& nodes
263+ const NDqHashOperatorCommon::TCombinerNodes& nodes,
264+ const std::vector<TType*>& stateItemTypes
223265 )
224266 : Ctx(ctx)
225267 , Nodes(nodes)
226268 , StateWidth(Nodes.StateNodes.size())
227269 , StateSize(StateWidth * sizeof (TUnboxedValue))
270+ , StateItemTypes(stateItemTypes)
228271 {
229272 }
230273
@@ -233,7 +276,10 @@ class TGenericAggregation: public IAggregation
233276 }
234277
235278 std::optional<size_t > GetStateMemoryUsage (void * rawState) const override {
236- return EstimateUvPackSize (static_cast <const TUnboxedValuePod*>(rawState), StateWidth);
279+ return EstimateUvPackSize (
280+ TArrayRef<const TUnboxedValuePod>(static_cast <const TUnboxedValuePod*>(rawState), StateWidth),
281+ TArrayRef<TType* const >(StateItemTypes)
282+ );
237283 }
238284
239285 // Assumes the input row and extracted keys have already been copied into the input nodes, so row isn't even used here
@@ -412,7 +458,7 @@ class TBaseAggregationState: public TComputationValue<TBaseAggregationState>
412458 }
413459
414460 if (isNew) {
415- if (Map->GetSize () >= MaxRowCount) {
461+ if (Map->GetSize () >= MaxRowCount || (! HasMemoryForProcessing () && Map-> GetSize () >= LowerFixedRowCount) ) {
416462 OpenDrain ();
417463 return EFillState::Drain;
418464 }
@@ -433,7 +479,7 @@ class TBaseAggregationState: public TComputationValue<TBaseAggregationState>
433479
434480 TBaseAggregationState (
435481 TMemoryUsageInfo* memInfo, TComputationContext& ctx, const TMemoryEstimationHelper& memoryHelper, size_t memoryLimit, size_t inputWidth,
436- const NDqHashOperatorCommon::TCombinerNodes& nodes, ui32 wideFieldsIndex, const TKeyTypes& keyTypes
482+ const NDqHashOperatorCommon::TCombinerNodes& nodes, ui32 wideFieldsIndex, const TKeyTypes& keyTypes, const std::vector<TType*>& stateItemTypes
437483 )
438484 : TBase(memInfo)
439485 , Ctx(ctx)
@@ -445,7 +491,6 @@ class TBaseAggregationState: public TComputationValue<TBaseAggregationState>
445491 , KeyTypes(keyTypes)
446492 , Hasher(TWideUnboxedHasher(KeyTypes))
447493 , Equals(TWideUnboxedEqual(KeyTypes))
448- , HasGenericAggregation(nodes.StateNodes.size() > 0 )
449494 , KeyStateBuffer(nullptr )
450495 , Draining(false )
451496 , SourceEmpty(false )
@@ -459,7 +504,7 @@ class TBaseAggregationState: public TComputationValue<TBaseAggregationState>
459504 MaxRowCount = TryAllocMapForRowCount (MaxRowCount);
460505
461506 if (HasGenericAggregation) {
462- Aggs.push_back (std::make_unique<TGenericAggregation>(Ctx, Nodes));
507+ Aggs.push_back (std::make_unique<TGenericAggregation>(Ctx, Nodes, stateItemTypes ));
463508 }
464509
465510 MKQL_ENSURE (Aggs.size (), " No aggregations defined" );
@@ -489,6 +534,7 @@ class TBaseAggregationState: public TComputationValue<TBaseAggregationState>
489534 size_t TryAllocMapForRowCount (size_t rowCount)
490535 {
491536 // Avoid reallocating the map
537+ // TODO: although Clear()-ing might be actually more expensive than reallocation
492538 if (Map) {
493539 const size_t oldCapacity = Map->GetCapacity ();
494540 size_t newCapacity = GetMapCapacity (rowCount);
@@ -503,6 +549,10 @@ class TBaseAggregationState: public TComputationValue<TBaseAggregationState>
503549 size_t newCapacity = GetMapCapacity (rows);
504550 try {
505551 Map.Reset (new TMap (Hasher, Equals, newCapacity));
552+ if (!HasMemoryForProcessing ()) {
553+ Map.Reset (nullptr );
554+ return false ;
555+ }
506556 return true ;
507557 }
508558 catch (TMemoryLimitExceededException) {
@@ -517,6 +567,7 @@ class TBaseAggregationState: public TComputationValue<TBaseAggregationState>
517567 rowCount = rowCount / 2 ;
518568 }
519569
570+ // This can emit uncaught TMemoryLimitExceededException if we can't afford even a tiny map
520571 size_t smallCapacity = GetMapCapacity (LowerFixedRowCount);
521572 Map.Reset (new TMap (Hasher, Equals, smallCapacity));
522573 return LowerFixedRowCount;
@@ -596,7 +647,7 @@ class TBaseAggregationState: public TComputationValue<TBaseAggregationState>
596647 const TKeyTypes& KeyTypes;
597648 THashFunc const Hasher;
598649 TEqualsFunc const Equals;
599- const bool HasGenericAggregation;
650+ constexpr static const bool HasGenericAggregation = true ;
600651
601652 using TStore = TStorageWrapper<char >;
602653 std::unique_ptr<TStore> Store;
@@ -631,9 +682,11 @@ class TWideAggregationState: public TBaseAggregationState
631682 size_t outputWidth,
632683 const NDqHashOperatorCommon::TCombinerNodes& nodes,
633684 ui32 wideFieldsIndex,
634- const TKeyTypes& keyTypes
685+ const TKeyTypes& keyTypes,
686+ const std::vector<TType*>& stateItemTypes
687+
635688 )
636- : TBaseAggregationState(memInfo, ctx, memoryHelper, memoryLimit, inputWidth, nodes, wideFieldsIndex, keyTypes)
689+ : TBaseAggregationState(memInfo, ctx, memoryHelper, memoryLimit, inputWidth, nodes, wideFieldsIndex, keyTypes, stateItemTypes )
637690 , StartMoment(TInstant::Now()) // Temporary. Helps correlate debug outputs with SVGs
638691 , OutputWidth(outputWidth)
639692 , DrainMapIterator(nullptr )
@@ -830,9 +883,10 @@ class TBlockAggregationState: public TBaseAggregationState
830883 size_t inputWidth,
831884 const NDqHashOperatorCommon::TCombinerNodes& nodes,
832885 ui32 wideFieldsIndex,
833- const TKeyTypes& keyTypes
886+ const TKeyTypes& keyTypes,
887+ const std::vector<TType*>& stateItemTypes
834888 )
835- : TBaseAggregationState(memInfo, ctx, memoryHelper, memoryLimit, inputWidth, nodes, wideFieldsIndex, keyTypes)
889+ : TBaseAggregationState(memInfo, ctx, memoryHelper, memoryLimit, inputWidth, nodes, wideFieldsIndex, keyTypes, stateItemTypes )
836890 , InputTypes(inputTypes)
837891 , OutputTypes(outputTypes)
838892 , InputColumns(inputTypes.size() - 1 )
@@ -1156,6 +1210,7 @@ class TDqHashCombineFlowWrapper: public TStatefulWideFlowCodegeneratorNode<TDqHa
11561210 , Source(source)
11571211 , InputTypes(inputTypes)
11581212 , OutputTypes(outputTypes)
1213+ , StateItemTypes(stateItemTypes)
11591214 , InputWidth(inputWidth)
11601215 , Nodes(std::move(nodes))
11611216 , KeyTypes(std::move(keyTypes))
@@ -1437,16 +1492,17 @@ class TDqHashCombineFlowWrapper: public TStatefulWideFlowCodegeneratorNode<TDqHa
14371492 UDF_LOG (logger, logComponent, NUdf::ELogLevel::Debug, TStringBuilder () << " State initialized" );
14381493
14391494 if (!BlockMode) {
1440- state = ctx.HolderFactory .Create <TWideAggregationState>(ctx, MemoryHelper, MemoryLimit, InputWidth, OutputTypes.size (), Nodes, WideFieldsIndex, KeyTypes);
1495+ state = ctx.HolderFactory .Create <TWideAggregationState>(ctx, MemoryHelper, MemoryLimit, InputWidth, OutputTypes.size (), Nodes, WideFieldsIndex, KeyTypes, StateItemTypes );
14411496 } else {
1442- state = ctx.HolderFactory .Create <TBlockAggregationState>(ctx, MemoryHelper, MemoryLimit, InputTypes, OutputTypes, InputWidth, Nodes, WideFieldsIndex, KeyTypes);
1497+ state = ctx.HolderFactory .Create <TBlockAggregationState>(ctx, MemoryHelper, MemoryLimit, InputTypes, OutputTypes, InputWidth, Nodes, WideFieldsIndex, KeyTypes, StateItemTypes );
14431498 }
14441499 }
14451500
14461501 const bool BlockMode;
14471502 IComputationWideFlowNode *const Source;
14481503 std::vector<TType*> InputTypes;
14491504 std::vector<TType*> OutputTypes;
1505+ const std::vector<TType*> StateItemTypes;
14501506 size_t InputWidth;
14511507 const NDqHashOperatorCommon::TCombinerNodes Nodes;
14521508 const TKeyTypes KeyTypes;
@@ -1473,6 +1529,7 @@ class TDqHashCombineStreamWrapper: public TMutableComputationNode<TDqHashCombine
14731529 , StreamSource(streamSource)
14741530 , InputTypes(inputTypes)
14751531 , OutputTypes(outputTypes)
1532+ , StateItemTypes(stateItemTypes)
14761533 , InputWidth(inputWidth)
14771534 , Nodes(std::move(nodes))
14781535 , KeyTypes(std::move(keyTypes))
@@ -1499,9 +1556,9 @@ class TDqHashCombineStreamWrapper: public TMutableComputationNode<TDqHashCombine
14991556 UDF_LOG (logger, logComponent, NUdf::ELogLevel::Debug, TStringBuilder () << " State initialized" );
15001557
15011558 if (!BlockMode) {
1502- state = ctx.HolderFactory .Create <TWideAggregationState>(ctx, MemoryHelper, MemoryLimit, InputWidth, OutputTypes.size (), Nodes, WideFieldsIndex, KeyTypes);
1559+ state = ctx.HolderFactory .Create <TWideAggregationState>(ctx, MemoryHelper, MemoryLimit, InputWidth, OutputTypes.size (), Nodes, WideFieldsIndex, KeyTypes, StateItemTypes );
15031560 } else {
1504- state = ctx.HolderFactory .Create <TBlockAggregationState>(ctx, MemoryHelper, MemoryLimit, InputTypes, OutputTypes, InputWidth, Nodes, WideFieldsIndex, KeyTypes);
1561+ state = ctx.HolderFactory .Create <TBlockAggregationState>(ctx, MemoryHelper, MemoryLimit, InputTypes, OutputTypes, InputWidth, Nodes, WideFieldsIndex, KeyTypes, StateItemTypes );
15051562 }
15061563 }
15071564
@@ -1517,6 +1574,7 @@ class TDqHashCombineStreamWrapper: public TMutableComputationNode<TDqHashCombine
15171574 IComputationNode *const StreamSource;
15181575 std::vector<TType*> InputTypes;
15191576 std::vector<TType*> OutputTypes;
1577+ const std::vector<TType*> StateItemTypes;
15201578 size_t InputWidth;
15211579 const NDqHashOperatorCommon::TCombinerNodes Nodes;
15221580 const TKeyTypes KeyTypes;
0 commit comments