|
16 | 16 |
|
17 | 17 | package com.mongodb.internal.connection;
|
18 | 18 |
|
| 19 | +import com.mongodb.internal.connection.netty.NettyByteBuf; |
| 20 | +import org.bson.BsonSerializationException; |
19 | 21 | import org.bson.ByteBuf;
|
20 | 22 | import org.bson.io.OutputBuffer;
|
21 | 23 |
|
|
27 | 29 |
|
28 | 30 | import static com.mongodb.assertions.Assertions.assertTrue;
|
29 | 31 | import static com.mongodb.assertions.Assertions.notNull;
|
| 32 | +import static java.lang.String.format; |
30 | 33 |
|
31 | 34 | /**
|
32 | 35 | * <p>This class is not part of the public API and may be removed or changed at any time</p>
|
@@ -273,6 +276,161 @@ public void close() {
|
273 | 276 | }
|
274 | 277 | }
|
275 | 278 |
|
| 279 | + @Override |
| 280 | + protected int writeCharacters(final String str, final boolean checkForNullCharacters) { |
| 281 | + ensureOpen(); |
| 282 | + ByteBuf buf = getCurrentByteBuffer(); |
| 283 | + if ((buf.remaining() >= str.length() + 1)) { |
| 284 | + if (buf.hasArray()) { |
| 285 | + return writeCharactersOnArray(str, checkForNullCharacters, buf); |
| 286 | + } else if (buf instanceof NettyByteBuf) { |
| 287 | + return writeCharactersOnNettyByteBuf(str, checkForNullCharacters, buf); |
| 288 | + } |
| 289 | + } |
| 290 | + return super.writeCharacters(str, 0, checkForNullCharacters); |
| 291 | + } |
| 292 | + |
| 293 | + private static void validateNoNullSingleByteChars(String str, long chars, int i) { |
| 294 | + long tmp = (chars & 0x7F7F7F7F7F7F7F7FL) + 0x7F7F7F7F7F7F7F7FL; |
| 295 | + tmp = ~(tmp | chars | 0x7F7F7F7F7F7F7F7FL); |
| 296 | + if (tmp != 0) { |
| 297 | + int firstZero = Long.numberOfTrailingZeros(tmp) >>> 3; |
| 298 | + throw new BsonSerializationException(format("BSON cstring '%s' is not valid because it contains a null character " |
| 299 | + + "at index %d", str, i + firstZero)); |
| 300 | + } |
| 301 | + } |
| 302 | + |
| 303 | + private static void validateNoNullAsciiCharacters(String str, long asciiChars, int i) { |
| 304 | + // simplified Hacker's delight search for zero with ASCII chars i.e. which doesn't use the MSB |
| 305 | + long tmp = asciiChars + 0x7F7F7F7F7F7F7F7FL; |
| 306 | + // MSB is 0 iff the byte is 0x00, 1 otherwise |
| 307 | + tmp = ~tmp & 0x8080808080808080L; |
| 308 | + // MSB is 1 iff the byte is 0x00, 0 otherwise |
| 309 | + if (tmp != 0) { |
| 310 | + // there's some 0x00 in the word |
| 311 | + int firstZero = Long.numberOfTrailingZeros(tmp) >> 3; |
| 312 | + throw new BsonSerializationException(format("BSON cstring '%s' is not valid because it contains a null character " |
| 313 | + + "at index %d", str, i + firstZero)); |
| 314 | + } |
| 315 | + } |
| 316 | + |
| 317 | + private int writeCharactersOnNettyByteBuf(String str, boolean checkForNullCharacters, ByteBuf buf) { |
| 318 | + int i = 0; |
| 319 | + io.netty.buffer.ByteBuf nettyBuffer = ((NettyByteBuf) buf).asByteBuf(); |
| 320 | + // readonly buffers, netty buffers and off-heap NIO ByteBuffer |
| 321 | + boolean slowPath = false; |
| 322 | + int batches = str.length() / 8; |
| 323 | + final int writerIndex = nettyBuffer.writerIndex(); |
| 324 | + // this would avoid resizing the buffer while appending: ASCII length + delimiter required space |
| 325 | + nettyBuffer.ensureWritable(str.length() + 1); |
| 326 | + for (int b = 0; b < batches; b++) { |
| 327 | + i = b * 8; |
| 328 | + // read 4 chars at time to preserve the 0x0100 cases |
| 329 | + long evenChars = str.charAt(i) | |
| 330 | + str.charAt(i + 2) << 16 | |
| 331 | + (long) str.charAt(i + 4) << 32 | |
| 332 | + (long) str.charAt(i + 6) << 48; |
| 333 | + long oddChars = str.charAt(i + 1) | |
| 334 | + str.charAt(i + 3) << 16 | |
| 335 | + (long) str.charAt(i + 5) << 32 | |
| 336 | + (long) str.charAt(i + 7) << 48; |
| 337 | + // check that both the second byte and the MSB of the first byte of each pair is 0 |
| 338 | + // needed for cases like \u0100 and \u0080 |
| 339 | + long mergedChars = evenChars | oddChars; |
| 340 | + if ((mergedChars & 0xFF80FF80FF80FF80L) != 0) { |
| 341 | + if (allSingleByteChars(mergedChars)) { |
| 342 | + i = tryWriteAsciiChars(str, checkForNullCharacters, oddChars, evenChars, nettyBuffer, writerIndex, i); |
| 343 | + } |
| 344 | + slowPath = true; |
| 345 | + break; |
| 346 | + } |
| 347 | + // all ASCII - compose them into a single long |
| 348 | + long asciiChars = oddChars << 8 | evenChars; |
| 349 | + if (checkForNullCharacters) { |
| 350 | + validateNoNullAsciiCharacters(str, asciiChars, i); |
| 351 | + } |
| 352 | + nettyBuffer.setLongLE(writerIndex + i, asciiChars); |
| 353 | + } |
| 354 | + if (!slowPath) { |
| 355 | + i = batches * 8; |
| 356 | + // do the rest, if any |
| 357 | + for (; i < str.length(); i++) { |
| 358 | + char c = str.charAt(i); |
| 359 | + if (checkForNullCharacters && c == 0x0) { |
| 360 | + throw new BsonSerializationException(format("BSON cstring '%s' is not valid because it contains a null character " |
| 361 | + + "at index %d", str, i)); |
| 362 | + } |
| 363 | + if (c >= 0x80) { |
| 364 | + slowPath = true; |
| 365 | + break; |
| 366 | + } |
| 367 | + nettyBuffer.setByte(writerIndex + i, c); |
| 368 | + } |
| 369 | + } |
| 370 | + if (slowPath) { |
| 371 | + // ith char is not ASCII: |
| 372 | + position += i; |
| 373 | + buf.position(writerIndex + i); |
| 374 | + return i + super.writeCharacters(str, i, checkForNullCharacters); |
| 375 | + } else { |
| 376 | + nettyBuffer.setByte(writerIndex + str.length(), 0); |
| 377 | + int totalWritten = str.length() + 1; |
| 378 | + position += totalWritten; |
| 379 | + buf.position(writerIndex + totalWritten); |
| 380 | + return totalWritten; |
| 381 | + } |
| 382 | + } |
| 383 | + |
| 384 | + private static boolean allSingleByteChars(long fourChars) { |
| 385 | + return (fourChars & 0xFF00FF00FF00FF00L) == 0; |
| 386 | + } |
| 387 | + |
| 388 | + private static int tryWriteAsciiChars(String str, boolean checkForNullCharacters, |
| 389 | + long oddChars, long evenChars, io.netty.buffer.ByteBuf nettyByteBuf, int writerIndex, int i) { |
| 390 | + // all single byte chars |
| 391 | + long latinChars = oddChars << 8 | evenChars; |
| 392 | + if (checkForNullCharacters) { |
| 393 | + validateNoNullSingleByteChars(str, latinChars, i); |
| 394 | + } |
| 395 | + long msbSetForNonAscii = latinChars & 0x8080808080808080L; |
| 396 | + int firstNonAsciiOffset = Long.numberOfTrailingZeros(msbSetForNonAscii) >> 3; |
| 397 | + // that's a bit cheating :P but later phases will patch the wrongly encoded ones |
| 398 | + nettyByteBuf.setLongLE(writerIndex + i, latinChars); |
| 399 | + i += firstNonAsciiOffset; |
| 400 | + return i; |
| 401 | + } |
| 402 | + |
| 403 | + private int writeCharactersOnArray(String str, boolean checkForNullCharacters, ByteBuf buf) { |
| 404 | + int i = 0; |
| 405 | + byte[] array = buf.array(); |
| 406 | + int pos = buf.position(); |
| 407 | + int len = str.length(); |
| 408 | + for (; i < len; i++) { |
| 409 | + char c = str.charAt(i); |
| 410 | + if (checkForNullCharacters && c == 0x0) { |
| 411 | + throw new BsonSerializationException(format("BSON cstring '%s' is not valid because it contains a null character " |
| 412 | + + "at index %d", str, i)); |
| 413 | + } |
| 414 | + if (c >= 0x80) { |
| 415 | + break; |
| 416 | + } |
| 417 | + array[pos + i] = (byte) c; |
| 418 | + } |
| 419 | + if (i == len) { |
| 420 | + int total = len + 1; |
| 421 | + array[pos + len] = 0; |
| 422 | + position += total; |
| 423 | + buf.position(pos + total); |
| 424 | + return len + 1; |
| 425 | + } |
| 426 | + // ith character is not ASCII |
| 427 | + if (i > 0) { |
| 428 | + position += i; |
| 429 | + buf.position(pos + i); |
| 430 | + } |
| 431 | + return i + super.writeCharacters(str, i, checkForNullCharacters); |
| 432 | + } |
| 433 | + |
276 | 434 | private static final class BufferPositionPair {
|
277 | 435 | private final int bufferIndex;
|
278 | 436 | private int position;
|
|
0 commit comments