Skip to content
This repository has been archived by the owner on Apr 17, 2024. It is now read-only.

Commit

Permalink
Fixing ciphertext malleability issue in Java caused by storing the ci…
Browse files Browse the repository at this point in the history
…phertext prefix in a hashmap keyed by UTF8 encoded strings, instead of byte arrays, leading to the ability to retrieve keys with IDs that happen to be invalid Unicode strings with a changed ID.

PiperOrigin-RevId: 336763863
  • Loading branch information
sophieschmieg authored and Copybara-Service committed Oct 12, 2020
1 parent ac94479 commit 93d839a
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 20 deletions.
2 changes: 2 additions & 0 deletions java_src/src/main/java/com/google/crypto/tink/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ java_library(
deps = [
":crypto_format",
"//proto:tink_java_proto",
"//src/main/java/com/google/crypto/tink/subtle:hex",
],
)

Expand All @@ -487,6 +488,7 @@ android_library(
deps = [
":crypto_format-android",
"//proto:tink_java_proto_lite",
"//src/main/java/com/google/crypto/tink/subtle:hex",
],
)

Expand Down
54 changes: 47 additions & 7 deletions java_src/src/main/java/com/google/crypto/tink/PrimitiveSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import com.google.crypto.tink.proto.KeyStatusType;
import com.google.crypto.tink.proto.Keyset;
import com.google.crypto.tink.proto.OutputPrefixType;
import java.nio.charset.Charset;
import com.google.crypto.tink.subtle.Hex;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -49,7 +49,6 @@
* @since 1.0.0
*/
public final class PrimitiveSet<P> {
private static final Charset UTF_8 = Charset.forName("UTF-8");
/**
* A single entry in the set. In addition to the actual primitive it holds also some extra
* information about the primitive.
Expand Down Expand Up @@ -117,7 +116,7 @@ public List<Entry<P>> getRawPrimitives() {

/** @return the entries with primitive identifed by {@code identifier}. */
public List<Entry<P>> getPrimitive(final byte[] identifier) {
List<Entry<P>> found = primitives.get(new String(identifier, UTF_8));
List<Entry<P>> found = primitives.get(new Prefix(identifier));
return found != null ? found : Collections.<Entry<P>>emptyList();
}

Expand All @@ -136,8 +135,8 @@ public Collection<List<Entry<P>>> getAll() {
* prefix). This allows quickly retrieving the list of primitives sharing some particular prefix.
* Because all RAW keys are using an empty prefix, this also quickly allows retrieving them.
*/
private ConcurrentMap<java.lang.String, List<Entry<P>>> primitives =
new ConcurrentHashMap<java.lang.String, List<Entry<P>>>();
private final ConcurrentMap<Prefix, List<Entry<P>>> primitives =
new ConcurrentHashMap<Prefix, List<Entry<P>>>();

private Entry<P> primary;
private final Class<P> primitiveClass;
Expand Down Expand Up @@ -185,8 +184,8 @@ public Entry<P> addPrimitive(final P primitive, Keyset.Key key)
key.getKeyId());
List<Entry<P>> list = new ArrayList<Entry<P>>();
list.add(entry);
// Cannot use [] as keys in hash map, convert to string.
String identifier = new String(entry.getIdentifier(), UTF_8);
// Cannot use [] as keys in hash map, convert to Prefix wrapper class.
Prefix identifier = new Prefix(entry.getIdentifier());
List<Entry<P>> existing = primitives.put(identifier, Collections.unmodifiableList(list));
if (existing != null) {
List<Entry<P>> newList = new ArrayList<Entry<P>>();
Expand All @@ -200,4 +199,45 @@ public Entry<P> addPrimitive(final P primitive, Keyset.Key key)
public Class<P> getPrimitiveClass() {
return primitiveClass;
}

private static class Prefix implements Comparable<Prefix> {
private final byte[] prefix;

private Prefix(byte[] prefix) {
this.prefix = Arrays.copyOf(prefix, prefix.length);
}

@Override
public int hashCode() {
return Arrays.hashCode(prefix);
}

@Override
public boolean equals(Object o) {
if (!(o instanceof Prefix)) {
return false;
}
Prefix other = (Prefix) o;
return Arrays.equals(prefix, other.prefix);
}

@Override
public int compareTo(Prefix o) {
if (prefix.length != o.prefix.length) {
return prefix.length - o.prefix.length;
}
for (int i = 0; i < prefix.length; i++) {
if (prefix[i] != o.prefix[i]) {
return prefix[i] - o.prefix[i];
}
}
return 0;
}

@Override
public String toString() {
return Hex.encode(prefix);
}
}

}
1 change: 1 addition & 0 deletions java_src/src/test/java/com/google/crypto/tink/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ java_test(
"//src/main/java/com/google/crypto/tink:crypto_format",
"//src/main/java/com/google/crypto/tink:mac",
"//src/main/java/com/google/crypto/tink:primitive_set",
"//src/main/java/com/google/crypto/tink/subtle:hex",
"//src/main/java/com/google/crypto/tink/testing:test_util",
"@maven//:com_google_truth_truth",
"@maven//:junit_junit",
Expand Down
57 changes: 44 additions & 13 deletions java_src/src/test/java/com/google/crypto/tink/PrimitiveSetTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@

import static com.google.common.truth.Truth.assertThat;
import static com.google.crypto.tink.testing.TestUtil.assertExceptionContains;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

import com.google.crypto.tink.proto.KeyStatusType;
import com.google.crypto.tink.proto.Keyset.Key;
import com.google.crypto.tink.proto.OutputPrefixType;
import java.nio.charset.Charset;
import com.google.crypto.tink.subtle.Hex;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -36,7 +37,6 @@
/** Tests for PrimitiveSet. */
@RunWith(JUnit4.class)
public class PrimitiveSetTest {
private static final Charset UTF_8 = Charset.forName("UTF-8");

private static class DummyMac1 implements Mac {
public DummyMac1() {}
Expand Down Expand Up @@ -91,36 +91,34 @@ public void testBasicFunctionality() throws Exception {
.build();
pset.addPrimitive(new DummyMac1(), key3);

assertEquals(3, pset.getAll().size());
assertThat(pset.getAll()).hasSize(3);

List<PrimitiveSet.Entry<Mac>> entries = pset.getPrimitive(key1);
assertEquals(1, entries.size());
assertThat(entries).hasSize(1);
PrimitiveSet.Entry<Mac> entry = entries.get(0);
assertEquals(
DummyMac1.class.getSimpleName(),
new String(entry.getPrimitive().computeMac(null), "UTF-8"));
DummyMac1.class.getSimpleName(), new String(entry.getPrimitive().computeMac(null), UTF_8));
assertEquals(KeyStatusType.ENABLED, entry.getStatus());
assertEquals(CryptoFormat.TINK_START_BYTE, entry.getIdentifier()[0]);
assertArrayEquals(CryptoFormat.getOutputPrefix(key1), entry.getIdentifier());
assertEquals(entry.getKeyId(), 1);

entries = pset.getPrimitive(key2);
assertEquals(1, entries.size());
assertThat(entries).hasSize(1);
entry = entries.get(0);
assertEquals(
DummyMac2.class.getSimpleName(),
new String(entry.getPrimitive().computeMac(null), "UTF-8"));
new String(entry.getPrimitive().computeMac(null), UTF_8));
assertEquals(KeyStatusType.ENABLED, entry.getStatus());
assertEquals(0, entry.getIdentifier().length);
assertThat(entry.getIdentifier()).isEmpty();
assertArrayEquals(CryptoFormat.getOutputPrefix(key2), entry.getIdentifier());
assertEquals(2, entry.getKeyId());

entries = pset.getPrimitive(key3);
assertEquals(1, entries.size());
assertThat(entries).hasSize(1);
entry = entries.get(0);
assertEquals(
DummyMac1.class.getSimpleName(),
new String(entry.getPrimitive().computeMac(null), "UTF-8"));
DummyMac1.class.getSimpleName(), new String(entry.getPrimitive().computeMac(null), UTF_8));
assertEquals(KeyStatusType.ENABLED, entry.getStatus());
assertEquals(CryptoFormat.LEGACY_START_BYTE, entry.getIdentifier()[0]);
assertArrayEquals(CryptoFormat.getOutputPrefix(key3), entry.getIdentifier());
Expand All @@ -129,7 +127,7 @@ public void testBasicFunctionality() throws Exception {
entry = pset.getPrimary();
assertEquals(
DummyMac2.class.getSimpleName(),
new String(entry.getPrimitive().computeMac(null), "UTF-8"));
new String(entry.getPrimitive().computeMac(null), UTF_8));
assertEquals(KeyStatusType.ENABLED, entry.getStatus());
assertArrayEquals(CryptoFormat.getOutputPrefix(key2), entry.getIdentifier());
assertEquals(2, entry.getKeyId());
Expand Down Expand Up @@ -276,4 +274,37 @@ public void testAddPrimive_WithDisabledKey_shouldFail() throws Exception {
assertExceptionContains(e, "only ENABLED key is allowed");
}
}

@Test
public void testPrefix_isUnique() throws Exception {
PrimitiveSet<Mac> pset = PrimitiveSet.newPrimitiveSet(Mac.class);
Key key1 =
Key.newBuilder()
.setKeyId(0xffffffff)
.setStatus(KeyStatusType.ENABLED)
.setOutputPrefixType(OutputPrefixType.TINK)
.build();
pset.addPrimitive(new DummyMac1(), key1);
Key key2 =
Key.newBuilder()
.setKeyId(0xffffffdf)
.setStatus(KeyStatusType.ENABLED)
.setOutputPrefixType(OutputPrefixType.RAW)
.build();
pset.setPrimary(pset.addPrimitive(new DummyMac2(), key2));
Key key3 =
Key.newBuilder()
.setKeyId(0xffffffef)
.setStatus(KeyStatusType.ENABLED)
.setOutputPrefixType(OutputPrefixType.LEGACY)
.build();
pset.addPrimitive(new DummyMac1(), key3);

assertThat(pset.getAll()).hasSize(3);

assertThat(pset.getPrimitive(Hex.decode("01ffffffff"))).hasSize(1);
assertThat(pset.getPrimitive(Hex.decode("01ffffffef"))).isEmpty();
assertThat(pset.getPrimitive(Hex.decode("00ffffffff"))).isEmpty();
assertThat(pset.getPrimitive(Hex.decode("00ffffffef"))).hasSize(1);
}
}

0 comments on commit 93d839a

Please sign in to comment.