@@ -61,16 +61,28 @@ class StringNGramsOp : public tensorflow::OpKernel {
6161 OP_REQUIRES_OK (context, context->input (" data_splits" , &splits));
6262 const auto & splits_vec = splits->flat <SPLITS_TYPE>();
6363
64- // Validate that the splits are valid indices into data
64+ // Validate that the splits are valid indices into data, only if there are
65+ // splits specified.
6566 const int input_data_size = data->flat <tstring>().size ();
6667 const int splits_vec_size = splits_vec.size ();
67- for (int i = 0 ; i < splits_vec_size; ++i) {
68- bool valid_splits = splits_vec (i) >= 0 ;
69- valid_splits = valid_splits && (splits_vec (i) <= input_data_size);
70- OP_REQUIRES (
71- context, valid_splits,
72- errors::InvalidArgument (" Invalid split value " , splits_vec (i),
73- " , must be in [0," , input_data_size, " ]" ));
68+ if (splits_vec_size > 0 ) {
69+ int prev_split = splits_vec (0 );
70+ OP_REQUIRES (context, prev_split == 0 ,
71+ errors::InvalidArgument (" First split value must be 0, got " ,
72+ prev_split));
73+ for (int i = 1 ; i < splits_vec_size; ++i) {
74+ bool valid_splits = splits_vec (i) >= prev_split;
75+ valid_splits = valid_splits && (splits_vec (i) <= input_data_size);
76+ OP_REQUIRES (context, valid_splits,
77+ errors::InvalidArgument (
78+ " Invalid split value " , splits_vec (i), " , must be in [" ,
79+ prev_split, " , " , input_data_size, " ]" ));
80+ prev_split = splits_vec (i);
81+ }
82+ OP_REQUIRES (context, prev_split == input_data_size,
83+ errors::InvalidArgument (
84+ " Last split value must be data size. Expected " ,
85+ input_data_size, " , got " , prev_split));
7486 }
7587
7688 int num_batch_items = splits_vec.size () - 1 ;
@@ -174,13 +186,31 @@ class StringNGramsOp : public tensorflow::OpKernel {
174186 ngram->append (left_pad_);
175187 ngram->append (separator_);
176188 }
189+ // Only output first num_tokens - 1 pairs of data and separator
177190 for (int n = 0 ; n < num_tokens - 1 ; ++n) {
178191 ngram->append (data[data_start_index + n]);
179192 ngram->append (separator_);
180193 }
181- ngram->append (data[data_start_index + num_tokens - 1 ]);
182- for (int n = 0 ; n < right_padding; ++n) {
183- ngram->append (separator_);
194+ // Handle case when there are no tokens or no right padding as these can
195+ // result in consecutive separators.
196+ if (num_tokens > 0 ) {
197+ // If we have tokens, then output last and then pair each separator with
198+ // the right padding that follows, to ensure ngram ends either with the
199+ // token or with the right pad.
200+ ngram->append (data[data_start_index + num_tokens - 1 ]);
201+ for (int n = 0 ; n < right_padding; ++n) {
202+ ngram->append (separator_);
203+ ngram->append (right_pad_);
204+ }
205+ } else {
206+ // If we don't have tokens, then the last item inserted into the ngram
207+ // has been the separator from the left padding loop above. Hence,
208+ // output right pad and separator and make sure to finish with a
209+ // padding, not a separator.
210+ for (int n = 0 ; n < right_padding - 1 ; ++n) {
211+ ngram->append (right_pad_);
212+ ngram->append (separator_);
213+ }
184214 ngram->append (right_pad_);
185215 }
186216
0 commit comments