Skip to content

Commit

Permalink
Call Cipher APIs with non-direct ByteBuffers and perform copies in th…
Browse files Browse the repository at this point in the history
…e ALTS code. (cl/308901367)
  • Loading branch information
veblush authored and ejona86 committed May 27, 2020
1 parent d711614 commit a7bca23
Showing 1 changed file with 55 additions and 44 deletions.
99 changes: 55 additions & 44 deletions alts/src/main/java/io/grpc/alts/internal/AltsChannelCrypter.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
package io.grpc.alts.internal;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;

import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.List;
Expand Down Expand Up @@ -56,61 +56,72 @@ static int getCounterLength() {

@Override
public void encrypt(ByteBuf outBuf, List<ByteBuf> plainBufs) throws GeneralSecurityException {
checkArgument(outBuf.nioBufferCount() == 1);
// Copy plaintext buffers into outBuf for in-place encryption on single direct buffer.
ByteBuf plainBuf = outBuf.slice(outBuf.writerIndex(), outBuf.writableBytes());
plainBuf.writerIndex(0);
for (ByteBuf inBuf : plainBufs) {
plainBuf.writeBytes(inBuf);
byte[] tempArr = new byte[outBuf.writableBytes()];

// Copy plaintext into tempArr.
{
ByteBuf tempBuf = Unpooled.wrappedBuffer(tempArr, 0, tempArr.length - TAG_LENGTH);
tempBuf.resetWriterIndex();
for (ByteBuf plainBuf : plainBufs) {
tempBuf.writeBytes(plainBuf);
}
}

verify(outBuf.writableBytes() == plainBuf.readableBytes() + TAG_LENGTH);
ByteBuffer out = outBuf.internalNioBuffer(outBuf.writerIndex(), outBuf.writableBytes());
ByteBuffer plain = out.duplicate();
plain.limit(out.limit() - TAG_LENGTH);

byte[] counter = incrementOutCounter();
int outPosition = out.position();
aeadCrypter.encrypt(out, plain, counter);
int bytesWritten = out.position() - outPosition;
outBuf.writerIndex(outBuf.writerIndex() + bytesWritten);
verify(!outBuf.isWritable());
// Encrypt into tempArr.
{
ByteBuffer out = ByteBuffer.wrap(tempArr);
ByteBuffer plain = ByteBuffer.wrap(tempArr, 0, tempArr.length - TAG_LENGTH);

byte[] counter = incrementOutCounter();
aeadCrypter.encrypt(out, plain, counter);
}
outBuf.writeBytes(tempArr);
}

@Override
public void decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertextBufs)
public void decrypt(ByteBuf outBuf, ByteBuf tagBuf, List<ByteBuf> ciphertextBufs)
throws GeneralSecurityException {
// There is enough space for the ciphertext including the tag in outBuf.
byte[] tempArr = new byte[outBuf.writableBytes()];

// Copy ciphertext and tag into tempArr.
{
ByteBuf tempBuf = Unpooled.wrappedBuffer(tempArr);
tempBuf.resetWriterIndex();
for (ByteBuf ciphertextBuf : ciphertextBufs) {
tempBuf.writeBytes(ciphertextBuf);
}
tempBuf.writeBytes(tagBuf);
}

ByteBuf cipherTextAndTag = out.slice(out.writerIndex(), out.writableBytes());
cipherTextAndTag.writerIndex(0);
decryptInternal(outBuf, tempArr);
}

for (ByteBuf inBuf : ciphertextBufs) {
cipherTextAndTag.writeBytes(inBuf);
@Override
public void decrypt(
ByteBuf outBuf, ByteBuf ciphertextAndTagDirect) throws GeneralSecurityException {
byte[] tempArr = new byte[ciphertextAndTagDirect.readableBytes()];

// Copy ciphertext and tag into tempArr.
{
ByteBuf tempBuf = Unpooled.wrappedBuffer(tempArr);
tempBuf.resetWriterIndex();
tempBuf.writeBytes(ciphertextAndTagDirect);
}
cipherTextAndTag.writeBytes(tag);

decrypt(out, cipherTextAndTag);
decryptInternal(outBuf, tempArr);
}

@Override
public void decrypt(ByteBuf out, ByteBuf ciphertextAndTag) throws GeneralSecurityException {
int bytesRead = ciphertextAndTag.readableBytes();
checkArgument(bytesRead == out.writableBytes());

checkArgument(out.nioBufferCount() == 1);
ByteBuffer outBuffer = out.internalNioBuffer(out.writerIndex(), out.writableBytes());

checkArgument(ciphertextAndTag.nioBufferCount() == 1);
ByteBuffer ciphertextAndTagBuffer =
ciphertextAndTag.nioBuffer(ciphertextAndTag.readerIndex(), bytesRead);

byte[] counter = incrementInCounter();
int outPosition = outBuffer.position();
aeadCrypter.decrypt(outBuffer, ciphertextAndTagBuffer, counter);
int bytesWritten = outBuffer.position() - outPosition;
out.writerIndex(out.writerIndex() + bytesWritten);
ciphertextAndTag.readerIndex(out.readerIndex() + bytesRead);
verify(out.writableBytes() == TAG_LENGTH);
private void decryptInternal(ByteBuf outBuf, byte[] tempArr) throws GeneralSecurityException {
// Perform in-place decryption on tempArr.
{
ByteBuffer ciphertextAndTag = ByteBuffer.wrap(tempArr);
ByteBuffer out = ByteBuffer.wrap(tempArr);
byte[] counter = incrementInCounter();
aeadCrypter.decrypt(out, ciphertextAndTag, counter);
}

outBuf.writeBytes(tempArr, 0, tempArr.length - TAG_LENGTH);
}

@Override
Expand Down

0 comments on commit a7bca23

Please sign in to comment.