Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WFLY-18066 ByteBufferMarshalledValue generates duplicate buffers during a single marshalling operation #16886

Merged
merged 6 commits into from
May 31, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public Object objectFromByteBuffer(byte[] buf, int offset, int length) throws IO
}

@Override
public byte[] objectToByteBuffer(Object object) throws IOException, InterruptedException {
public byte[] objectToByteBuffer(Object object) throws IOException {
if (object == null) {
return this.objectToByteBuffer(null, 1);
}
Expand Down Expand Up @@ -97,7 +97,7 @@ public ByteBuffer objectToBuffer(Object object) throws IOException {
}

@Override
public byte[] objectToByteBuffer(Object obj, int estimatedSize) throws IOException, InterruptedException {
public byte[] objectToByteBuffer(Object obj, int estimatedSize) throws IOException {
ByteBuffer b = this.objectToBuffer(obj, estimatedSize);
return trim(b);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,17 @@ public Any readFrom(ProtoStreamReader reader) throws IOException {
public void writeTo(ProtoStreamWriter writer, Any value) throws IOException {
Object object = value.get();
if (object != null) {
ProtoStreamWriterContext context = ProtoStreamWriterContext.FACTORY.get().apply(writer);
Integer referenceId = context.getReference(object);
try (ProtoStreamWriterContext context = ProtoStreamWriterContext.FACTORY.get().apply(writer)) {
Integer referenceId = context.getReference(object);

// If we already wrote this object to the stream, write the object reference intead
AnyField field = (referenceId == null) ? getField(writer, object) : AnyField.REFERENCE;
writer.writeTag(field.getIndex(), field.getMarshaller().getWireType());
field.getMarshaller().writeTo(writer, (referenceId == null) ? object : referenceId);
// If we already wrote this object to the stream, write the object reference instead
AnyField field = (referenceId == null) ? getField(writer, object) : AnyField.REFERENCE;
writer.writeTag(field.getIndex(), field.getMarshaller().getWireType());
field.getMarshaller().writeTo(writer, (referenceId == null) ? object : referenceId);

if (referenceId == null) {
context.addReference(object);
if (referenceId == null) {
context.addReference(object);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.OptionalInt;

import org.infinispan.protostream.descriptors.WireType;
import org.wildfly.clustering.marshalling.spi.ByteBufferMarshalledKey;
Expand Down Expand Up @@ -61,15 +62,22 @@ public ByteBufferMarshalledKey<Object> readFrom(ProtoStreamReader reader) throws
public void writeTo(ProtoStreamWriter writer, ByteBufferMarshalledKey<Object> key) throws IOException {
ByteBuffer buffer = key.getBuffer();
if (buffer != null) {
writer.writeBytes(BUFFER_INDEX, buffer.mark());
buffer.reset();
writer.writeBytes(BUFFER_INDEX, buffer);
}
int hashCode = key.hashCode();
if (hashCode != 0) {
writer.writeSFixed32(HASH_CODE_INDEX, hashCode);
}
}

@Override
public OptionalInt size(ProtoStreamOperation context, ByteBufferMarshalledKey<Object> key) {
if (key.isEmpty()) return OptionalInt.of(0);
int hashCodeSize = WireType.FIXED_32_SIZE + context.tagSize(HASH_CODE_INDEX, WireType.FIXED32);
OptionalInt size = key.size();
return size.isPresent() ? OptionalInt.of(context.tagSize(BUFFER_INDEX, WireType.LENGTH_DELIMITED) + context.varIntSize(size.getAsInt()) + size.getAsInt() + hashCodeSize) : OptionalInt.empty();
}

@SuppressWarnings("unchecked")
@Override
public Class<? extends ByteBufferMarshalledKey<Object>> getJavaClass() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* JBoss, Home of Professional Open Source.
* Copyright 2023, Red Hat, Inc., and individual contributors
* as indicated by the @author tags. See the copyright.txt file in the
* distribution for a full listing of individual contributors.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*/

package org.wildfly.clustering.marshalling.protostream;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.OptionalInt;

import org.infinispan.protostream.descriptors.WireType;
import org.wildfly.clustering.marshalling.spi.ByteBufferMarshalledValue;

/**
* {@link ProtoStreamMarshaller} for a {@link ByteBufferMarshalledValue}.
* @author Paul Ferraro
*/
public class ByteBufferMarshalledValueMarshaller implements ProtoStreamMarshaller<ByteBufferMarshalledValue<Object>> {

private static final int BUFFER_INDEX = 1;

@Override
public ByteBufferMarshalledValue<Object> readFrom(ProtoStreamReader reader) throws IOException {
ByteBuffer buffer = null;
while (!reader.isAtEnd()) {
int tag = reader.readTag();
switch (WireType.getTagFieldNumber(tag)) {
case BUFFER_INDEX:
buffer = reader.readByteBuffer();
break;
default:
reader.skipField(tag);
}
}
return new ByteBufferMarshalledValue<>(buffer);
}

@Override
public void writeTo(ProtoStreamWriter writer, ByteBufferMarshalledValue<Object> key) throws IOException {
ByteBuffer buffer = key.getBuffer();
if (buffer != null) {
writer.writeBytes(BUFFER_INDEX, buffer);
}
}

@Override
public OptionalInt size(ProtoStreamOperation context, ByteBufferMarshalledValue<Object> value) {
if (value.isEmpty()) return OptionalInt.of(0);
OptionalInt size = value.size();
return size.isPresent() ? OptionalInt.of(context.tagSize(BUFFER_INDEX, WireType.LENGTH_DELIMITED) + context.varIntSize(size.getAsInt()) + size.getAsInt()) : OptionalInt.empty();
}

@SuppressWarnings("unchecked")
@Override
public Class<? extends ByteBufferMarshalledValue<Object>> getJavaClass() {
return (Class<ByteBufferMarshalledValue<Object>>) (Class<?>) ByteBufferMarshalledValue.class;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import org.infinispan.protostream.ImmutableSerializationContext;
import org.infinispan.protostream.ProtobufTagMarshaller.OperationContext;
import org.infinispan.protostream.impl.TagWriterImpl;

/**
* @author Paul Ferraro
Expand All @@ -32,6 +33,10 @@ public class DefaultProtoStreamOperation implements ProtoStreamOperation, Operat

private final OperationContext context;

public DefaultProtoStreamOperation(ImmutableSerializationContext context) {
this(TagWriterImpl.newInstance(context));
}

public DefaultProtoStreamOperation(OperationContext context) {
this.context = context;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,16 @@ public DefaultProtoStreamWriter(WriteContext context) {
public void writeObjectNoTag(Object value) throws IOException {
ImmutableSerializationContext context = this.getSerializationContext();
ProtoStreamMarshaller<Object> marshaller = this.findMarshaller(value.getClass());
OptionalInt size = this.size(marshaller, value);
OptionalInt size = marshaller.size(this, value);
try (ByteBufferOutputStream output = new ByteBufferOutputStream(size)) {
TagWriterImpl writer = size.isPresent() ? TagWriterImpl.newInstance(context, output, size.getAsInt()) : TagWriterImpl.newInstance(context, output);
TagWriterImpl writer = size.isPresent() ? TagWriterImpl.newInstance(context, output, size.getAsInt()) : TagWriterImpl.newInstance(context, output);
marshaller.writeTo(new DefaultProtoStreamWriter(writer), value);
writer.flush();
ByteBuffer buffer = output.getBuffer();
ByteBuffer buffer = output.getBuffer(); // Buffer is array backed
int offset = buffer.arrayOffset();
int length = buffer.limit() - offset;
this.writeVarint32(length);
this.writeRawBytes(buffer.array(), offset, length);
}
}

private OptionalInt size(ProtoStreamMarshaller<Object> marshaller, Object value) {
SizeComputingProtoStreamWriter writer = new SizeComputingProtoStreamWriter(this.getSerializationContext());
try (ProtoStreamWriterContext context = ProtoStreamWriterContext.FACTORY.get().apply(writer)) {
marshaller.writeTo(writer, value);
return writer.get();
} catch (IOException e) {
return OptionalInt.empty();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
package org.wildfly.clustering.marshalling.protostream;

import java.io.IOException;
import java.util.OptionalInt;

/**
* Interface inherited by marshallable components.
Expand All @@ -47,6 +48,22 @@ public interface Marshallable<T> {
*/
void writeTo(ProtoStreamWriter writer, T value) throws IOException;

/**
* Computes the size of the specified object.
* @param context the marshalling operation
* @param value the value whose size is to be calculated
* @return an optional buffer size, only present if the buffer size could be computed
*/
default OptionalInt size(ProtoStreamOperation operation, T value) {
SizeComputingProtoStreamWriter writer = new SizeComputingProtoStreamWriter(operation.getSerializationContext());
try (ProtoStreamWriterContext ctx = ProtoStreamWriterContext.FACTORY.get().apply(writer)) {
this.writeTo(writer, value);
return writer.get();
} catch (IOException e) {
return OptionalInt.empty();
}
}

/**
* Returns the type of object handled by this marshallable instance.
* @return the type of object handled by this marshallable instance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,12 @@

package org.wildfly.clustering.marshalling.protostream;

import java.nio.ByteBuffer;

import org.wildfly.clustering.marshalling.spi.ByteBufferMarshalledValue;

/**
* @author Paul Ferraro
*/
public enum MarshallingMarshallerProvider implements ProtoStreamMarshallerProvider {
BYTE_BUFFER_MARSHALLED_KEY(new ByteBufferMarshalledKeyMarshaller()),
BYTE_BUFFER_MARSHALLED_VALUE(new FunctionalScalarMarshaller<>(Scalar.BYTE_BUFFER.cast(ByteBuffer.class), () -> new ByteBufferMarshalledValue<>(null), ByteBufferMarshalledValue::isEmpty, ByteBufferMarshalledValue::getBuffer, ByteBufferMarshalledValue::new)),
BYTE_BUFFER_MARSHALLED_VALUE(new ByteBufferMarshalledValueMarshaller()),
;
private final ProtoStreamMarshaller<?> marshaller;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,8 @@ public ProtoStreamByteBufferMarshaller(ImmutableSerializationContext context) {
public OptionalInt size(Object value) {
try (ProtoStreamWriterContext.Factory factory = ProtoStreamWriterContext.FACTORY.get()) {
ProtoStreamMarshaller<Any> marshaller = (ProtoStreamMarshaller<Any>) this.context.getMarshaller(Any.class);
SizeComputingProtoStreamWriter writer = new SizeComputingProtoStreamWriter(this.context);
try {
marshaller.writeTo(writer, new Any(value));
return writer.get();
} catch (IOException e) {
return OptionalInt.empty();
}
ProtoStreamOperation operation = new DefaultProtoStreamOperation(this.context);
return marshaller.size(operation, new Any(value));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
package org.wildfly.clustering.marshalling.protostream;

import java.io.IOException;
import java.util.OptionalInt;

/**
* Provides a {@link ProtoStreamMarshaller}.
Expand Down Expand Up @@ -51,6 +52,11 @@ default void write(WriteContext context, Object value) throws IOException {
this.cast(Object.class).write(context, value);
}

@Override
default OptionalInt size(ProtoStreamOperation operation, Object value) {
return this.cast(Object.class).size(operation, value);
}

@Override
default Class<? extends Object> getJavaClass() {
return this.getMarshaller().getJavaClass();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@

package org.wildfly.clustering.marshalling.protostream;

import java.io.IOException;

import org.infinispan.protostream.ImmutableSerializationContext;
import org.infinispan.protostream.descriptors.WireType;
import org.infinispan.protostream.impl.TagWriterImpl;

/**
* Common interface of {@link ProtoStreamReader} and {@link ProtoStreamWriter}.
Expand Down Expand Up @@ -71,4 +75,29 @@ default <T, V extends T> ProtoStreamMarshaller<T> findMarshaller(Class<V> javaCl
}
throw exception;
}

/**
* Returns the marshalled size of the protobuf tag containing the specified field index and wire type.
* @param index a field index
* @param type a wire type
* @return the marshalled size of the protobuf tag
*/
default int tagSize(int index, WireType type) {
return this.varIntSize(WireType.makeTag(index, type));
}

/**
* Returns the marshalled size of the specified variable-width integer.
* @param index a variable-width integer
* @return the marshalled size of the specified variable-width integer.
*/
default int varIntSize(int value) {
TagWriterImpl writer = TagWriterImpl.newInstance(this.getSerializationContext());
try {
writer.writeVarint32(value);
return writer.getWrittenBytes();
} catch (IOException e) {
return WireType.MAX_VARINT_SIZE;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
public interface ProtoStreamWriter extends ProtoStreamOperation, TagWriter {

default Context getContext() {
return ProtoStreamWriterContext.FACTORY.get().apply(this);
try (ProtoStreamWriterContext context = ProtoStreamWriterContext.FACTORY.get().apply(this)) {
return context;
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,18 @@ protected Factory initialValue() {
}

class DefaultFactory implements Factory {
final Map<Class<?>, ProtoStreamWriterContext> contexts = new IdentityHashMap<>(2);
final Map<Class<?>, DefaultProtoStreamWriterContext> contexts = new IdentityHashMap<>(2);

@Override
public ProtoStreamWriterContext apply(ProtoStreamWriter writer) {
return this.contexts.computeIfAbsent(writer.getClass(), DefaultProtoStreamWriterContext::new);
return this.contexts.computeIfAbsent(writer.getClass(), DefaultProtoStreamWriterContext::new).open();
}

class DefaultProtoStreamWriterContext implements ProtoStreamWriterContext, Function<Object, Integer> {
private final Class<?> writerClass;
private final Map<Object, Integer> references = new IdentityHashMap<>(64);
private int index = 0;
private int reference = 0; // Enumerates object references
private int index = 0; // Tracks context lifecycle

DefaultProtoStreamWriterContext(Class<?> targetClass) {
this.writerClass = targetClass;
Expand All @@ -73,12 +74,19 @@ public void addReference(Object object) {

@Override
public Integer apply(Object key) {
return this.index++;
return this.reference++;
}

ProtoStreamWriterContext open() {
this.index += 1;
return this;
}

@Override
public void close() {
DefaultFactory.this.contexts.remove(this.writerClass);
if (--this.index <= 0) {
DefaultFactory.this.contexts.remove(this.writerClass);
}
}
}
}
Expand Down