Skip to content

Commit

Permalink
Update to latest JWT apis
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Nov 2, 2021
1 parent 835695a commit 37c7d22
Show file tree
Hide file tree
Showing 18 changed files with 147 additions and 78 deletions.
14 changes: 13 additions & 1 deletion client/trino-jdbc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,19 @@

<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt</artifactId>
<artifactId>jjwt-api</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-impl</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-jackson</artifactId>
<scope>test</scope>
</dependency>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import io.airlift.log.Logging;
import io.airlift.security.pem.PemReader;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.security.Keys;
import io.trino.plugin.tpch.TpchPlugin;
import io.trino.server.testing.TestingTrinoServer;
import org.testng.annotations.AfterClass;
Expand All @@ -26,21 +26,23 @@

import java.io.File;
import java.net.URL;
import java.security.Key;
import java.security.PrivateKey;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.time.Duration;
import java.util.Base64;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;

import static com.google.common.io.Files.asCharSource;
import static com.google.common.io.Resources.getResource;
import static io.jsonwebtoken.JwsHeader.KEY_ID;
import static io.jsonwebtoken.SignatureAlgorithm.HS512;
import static io.jsonwebtoken.security.Keys.hmacShaKeyFor;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.util.Base64.getMimeDecoder;
Expand All @@ -54,8 +56,8 @@ public class TestTrinoDriverAuth
{
private static final String TEST_CATALOG = "test_catalog";
private TestingTrinoServer server;
private byte[] defaultKey;
private byte[] hmac222;
private Key defaultKey;
private Key hmac222;
private PrivateKey privateKey33;

@BeforeClass
Expand All @@ -68,8 +70,8 @@ public void setup()
assertNotNull(resource, "key directory not found");
File keyDir = new File(resource.getFile()).getAbsoluteFile().getParentFile();

defaultKey = getMimeDecoder().decode(asCharSource(new File(keyDir, "default-key.key"), US_ASCII).read().getBytes(US_ASCII));
hmac222 = getMimeDecoder().decode(asCharSource(new File(keyDir, "222.key"), US_ASCII).read().getBytes(US_ASCII));
defaultKey = hmacShaKeyFor(getMimeDecoder().decode(asCharSource(new File(keyDir, "default-key.key"), US_ASCII).read().getBytes(US_ASCII)));
hmac222 = hmacShaKeyFor(getMimeDecoder().decode(asCharSource(new File(keyDir, "222.key"), US_ASCII).read().getBytes(US_ASCII)));
privateKey33 = PemReader.loadPrivateKey(new File(keyDir, "33.privateKey"), Optional.empty());

server = TestingTrinoServer.builder()
Expand Down Expand Up @@ -99,7 +101,7 @@ public void testSuccessDefaultKey()
{
String accessToken = Jwts.builder()
.setSubject("test")
.signWith(SignatureAlgorithm.HS512, defaultKey)
.signWith(defaultKey)
.compact();

try (Connection connection = createConnection(ImmutableMap.of("accessToken", accessToken))) {
Expand All @@ -120,7 +122,7 @@ public void testSuccessHmac()
String accessToken = Jwts.builder()
.setSubject("test")
.setHeaderParam(KEY_ID, "222")
.signWith(SignatureAlgorithm.HS512, hmac222)
.signWith(hmac222)
.compact();

try (Connection connection = createConnection(ImmutableMap.of("accessToken", accessToken))) {
Expand All @@ -141,7 +143,7 @@ public void testSuccessPublicKey()
String accessToken = Jwts.builder()
.setSubject("test")
.setHeaderParam(KEY_ID, "33")
.signWith(SignatureAlgorithm.RS256, privateKey33)
.signWith(privateKey33)
.compact();

try (Connection connection = createConnection(ImmutableMap.of("accessToken", accessToken))) {
Expand Down Expand Up @@ -185,9 +187,10 @@ public void testFailedUnsigned()
public void testFailedBadHmacSignature()
throws Exception
{
Key badKey = Keys.secretKeyFor(HS512);
String accessToken = Jwts.builder()
.setSubject("test")
.signWith(SignatureAlgorithm.HS512, Base64.getEncoder().encodeToString("bad-key".getBytes(US_ASCII)))
.signWith(badKey)
.compact();

try (Connection connection = createConnection(ImmutableMap.of("accessToken", accessToken))) {
Expand All @@ -204,7 +207,7 @@ public void testFailedWrongPublicKey()
String accessToken = Jwts.builder()
.setSubject("test")
.setHeaderParam(KEY_ID, "42")
.signWith(SignatureAlgorithm.RS256, privateKey33)
.signWith(privateKey33)
.compact();

try (Connection connection = createConnection(ImmutableMap.of("accessToken", accessToken))) {
Expand All @@ -221,7 +224,7 @@ public void testFailedUnknownPublicKey()
String accessToken = Jwts.builder()
.setSubject("test")
.setHeaderParam(KEY_ID, "unknown")
.signWith(SignatureAlgorithm.RS256, privateKey33)
.signWith(privateKey33)
.compact();

try (Connection connection = createConnection(ImmutableMap.of("accessToken", accessToken))) {
Expand All @@ -238,7 +241,7 @@ public void testSuccessFullSslVerification()
String accessToken = Jwts.builder()
.setSubject("test")
.setHeaderParam(KEY_ID, "33")
.signWith(SignatureAlgorithm.RS256, privateKey33)
.signWith(privateKey33)
.compact();

try (Connection connection = createConnection(ImmutableMap.of("accessToken", accessToken, "SSLVerification", "FULL"))) {
Expand All @@ -259,7 +262,7 @@ public void testSuccessCaSslVerification()
String accessToken = Jwts.builder()
.setSubject("test")
.setHeaderParam(KEY_ID, "33")
.signWith(SignatureAlgorithm.RS256, privateKey33)
.signWith(privateKey33)
.compact();

try (Connection connection = createConnection(ImmutableMap.of("accessToken", accessToken, "SSLVerification", "CA"))) {
Expand Down
13 changes: 12 additions & 1 deletion core/trino-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,12 @@

<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt</artifactId>
<artifactId>jjwt-api</artifactId>
</dependency>

<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-impl</artifactId>
</dependency>

<dependency>
Expand Down Expand Up @@ -332,6 +337,12 @@
<scope>runtime</scope>
</dependency>

<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-jackson</artifactId>
<scope>runtime</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@
import io.airlift.node.NodeInfo;
import io.jsonwebtoken.JwtException;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import io.trino.server.security.InternalPrincipal;
import io.trino.spi.security.Identity;

import javax.inject.Inject;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.core.Response;

import java.security.Key;
import java.time.ZonedDateTime;
import java.util.Date;

import static io.airlift.http.client.Request.Builder.fromRequest;
import static io.jsonwebtoken.security.Keys.hmacShaKeyFor;
import static io.trino.server.ServletSecurityUtils.setAuthenticatedIdentity;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;
Expand All @@ -45,7 +46,7 @@ public class InternalAuthenticationManager

private static final String TRINO_INTERNAL_BEARER = "X-Trino-Internal-Bearer";

private final byte[] hmac;
private final Key hmac;
private final String nodeId;

@Inject
Expand All @@ -72,7 +73,7 @@ public InternalAuthenticationManager(String sharedSecret, String nodeId)
{
requireNonNull(sharedSecret, "sharedSecret is null");
requireNonNull(nodeId, "nodeId is null");
this.hmac = Hashing.sha256().hashString(sharedSecret, UTF_8).asBytes();
this.hmac = hmacShaKeyFor(Hashing.sha256().hashString(sharedSecret, UTF_8).asBytes());
this.nodeId = nodeId;
}

Expand Down Expand Up @@ -115,16 +116,17 @@ public Request filterRequest(Request request)
private String generateJwt()
{
return Jwts.builder()
.signWith(SignatureAlgorithm.HS256, hmac)
.signWith(hmac)
.setSubject(nodeId)
.setExpiration(Date.from(ZonedDateTime.now().plusMinutes(5).toInstant()))
.compact();
}

private String parseJwt(String jwt)
{
return Jwts.parser()
return Jwts.parserBuilder()
.setSigningKey(hmac)
.build()
.parseClaimsJws(jwt)
.getBody()
.getSubject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwsHeader;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.SignatureException;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.UnsupportedJwtException;
import io.jsonwebtoken.security.SecurityException;

import javax.crypto.spec.SecretKeySpec;
import javax.inject.Inject;
Expand Down Expand Up @@ -112,15 +112,15 @@ private LoadedKey loadKey(String keyId)
public static LoadedKey loadKeyFile(File file)
{
if (!file.canRead()) {
throw new SignatureException("Unknown signing key ID");
throw new SecurityException("Unknown signing key ID");
}

String data;
try {
data = asCharSource(file, US_ASCII).read();
}
catch (IOException e) {
throw new SignatureException("Unable to read signing key", e);
throw new SecurityException("Unable to read signing key", e);
}

// try to load the key as a PEM encoded public key
Expand All @@ -129,7 +129,7 @@ public static LoadedKey loadKeyFile(File file)
return new LoadedKey(PemReader.loadPublicKey(data));
}
catch (RuntimeException | GeneralSecurityException e) {
throw new SignatureException("Unable to decode PEM signing key id", e);
throw new SecurityException("Unable to decode PEM signing key id", e);
}
}

Expand All @@ -139,7 +139,7 @@ public static LoadedKey loadKeyFile(File file)
return new LoadedKey(rawKey);
}
catch (RuntimeException e) {
throw new SignatureException("Unable to decode HMAC signing key", e);
throw new SecurityException("Unable to decode HMAC signing key", e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwsHeader;
import io.jsonwebtoken.SignatureException;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.security.SecurityException;

import javax.inject.Inject;

Expand Down Expand Up @@ -51,9 +51,9 @@ private Key getKey(JwsHeader<?> header)
{
String keyId = header.getKeyId();
if (keyId == null) {
throw new SignatureException("Key ID is required");
throw new SecurityException("Key ID is required");
}
return keys.getKey(keyId)
.orElseThrow(() -> new SignatureException("Unknown signing key ID"));
.orElseThrow(() -> new SecurityException("Unknown signing key ID"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.server.security.jwt;

import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.JwtParserBuilder;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SigningKeyResolver;
import io.trino.server.security.AbstractBearerAuthenticator;
Expand All @@ -40,7 +41,7 @@ public JwtAuthenticator(JwtAuthenticatorConfig config, SigningKeyResolver signin
super(createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile()));
principalField = config.getPrincipalField();

JwtParser jwtParser = Jwts.parser()
JwtParserBuilder jwtParser = Jwts.parserBuilder()
.setSigningKeyResolver(signingKeyResolver);

if (config.getRequiredIssuer() != null) {
Expand All @@ -49,7 +50,7 @@ public JwtAuthenticator(JwtAuthenticatorConfig config, SigningKeyResolver signin
if (config.getRequiredAudience() != null) {
jwtParser.requireAudience(config.getRequiredAudience());
}
this.jwtParser = jwtParser;
this.jwtParser = jwtParser.build();
}

@Override
Expand Down
Loading

0 comments on commit 37c7d22

Please sign in to comment.