Skip to content

Commit

Permalink
Use least privilege service account for the runtime pod (LangStream#403)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicoloboschi committed Sep 13, 2023
1 parent 821ec65 commit b603580
Show file tree
Hide file tree
Showing 19 changed files with 433 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ public void initialize(Map<String, Object> configuration) {
tokenProperties.publicAlg(),
tokenProperties.audienceClaim(),
tokenProperties.audience(),
tokenProperties.jwksHostsAllowlist());
tokenProperties.jwksHostsAllowlist(),
false,
null);
this.authenticationProviderToken = new AuthenticationProviderToken(jwtProperties);
}

Expand Down
9 changes: 5 additions & 4 deletions langstream-auth-jwt/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-jackson</artifactId>
</dependency>
<dependency>
<groupId>com.github.tomakehurst</groupId>
<artifactId>wiremock</artifactId>
<scope>test</scope>
</dependency>
</dependencies>




</project>
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.RequiredTypeException;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.io.Decoders;
import io.jsonwebtoken.io.DecodingException;
import io.jsonwebtoken.security.Keys;
Expand Down Expand Up @@ -54,21 +55,28 @@ public AuthenticationException(String message) {

private final JwtParser parser;
private final String roleClaim;
private final SignatureAlgorithm publicKeyAlg;
private final String audienceClaim;
private final String audience;
private final boolean allowK8sServiceAccountAuth;
private final String k8sNamespacePrefix;

public AuthenticationProviderToken(JwtProperties tokenProperties)
throws IOException, IllegalArgumentException {
this.publicKeyAlg = getPublicKeyAlgType(tokenProperties);
parser =
Jwts.parserBuilder()
.setSigningKeyResolver(
new JwksUriSigningKeyResolver(
publicKeyAlg.getValue(),
tokenProperties.jwksHostsAllowlist(),
getValidationKeyFromConfig(tokenProperties)))
.build();
this(tokenProperties, null);
}

public AuthenticationProviderToken(
JwtProperties tokenProperties, SigningKeyResolver signingKeyResolver)
throws IOException, IllegalArgumentException {
if (signingKeyResolver == null) {
final SignatureAlgorithm publicKeyAlgType = getPublicKeyAlgType(tokenProperties);
signingKeyResolver =
new JwksUriSigningKeyResolver(
publicKeyAlgType.getValue(),
tokenProperties.jwksHostsAllowlist(),
getValidationKeyFromConfig(tokenProperties, publicKeyAlgType));
}
parser = Jwts.parserBuilder().setSigningKeyResolver(signingKeyResolver).build();
this.roleClaim = getTokenRoleClaim(tokenProperties);
this.audienceClaim = getTokenAudienceClaim(tokenProperties);
this.audience = getTokenAudience(tokenProperties);
Expand All @@ -79,11 +87,17 @@ public AuthenticationProviderToken(JwtProperties tokenProperties)
+ this.audienceClaim
+ "] configured, but Audience stands for this broker not.");
}
this.allowK8sServiceAccountAuth = tokenProperties.allowKubernetesServiceAccounts();
this.k8sNamespacePrefix = tokenProperties.kubernetesNamespacePrefix();
}

public String authenticate(String token) throws AuthenticationException {
final Jwt<?, Claims> jwt = authenticateToken(token);
return getPrincipal(jwt);
final String principal = getPrincipal(jwt);
if (principal == null) {
throw new AuthenticationException("Token was valid, however no principal found.");
}
return principal;
}

private Jwt<?, Claims> authenticateToken(final String token) throws AuthenticationException {
Expand Down Expand Up @@ -132,14 +146,9 @@ private String getPrincipal(Jwt<?, Claims> jwt) {
final Claims body = jwt.getBody();
try {
log.debug("Token body: {}", body);
if (body.containsKey("kubernetes.io")) {
final Map map = body.get("kubernetes.io", Map.class);
if (map.containsKey("serviceaccount")) {
final Map serviceAccount = (Map) map.get("serviceaccount");
if (serviceAccount.containsKey("name")) {
return (String) serviceAccount.get("name");
}
}
String principal = getPrincipalFromKubernetesClaim(body);
if (principal != null) {
return principal;
}
return body.get(this.roleClaim, String.class);
} catch (RequiredTypeException var4) {
Expand All @@ -150,7 +159,23 @@ private String getPrincipal(Jwt<?, Claims> jwt) {
}
}

private Key getValidationKeyFromConfig(JwtProperties tokenProperties) throws IOException {
private String getPrincipalFromKubernetesClaim(Claims body) {
if (allowK8sServiceAccountAuth && body.containsKey("kubernetes.io")) {
final Map map = body.get("kubernetes.io", Map.class);
if (map.containsKey("namespace")) {
final String namespace = (String) map.get("namespace");
if (namespace != null) {
if (namespace.startsWith(k8sNamespacePrefix)) {
return namespace.substring(k8sNamespacePrefix.length());
}
}
}
}
return null;
}

private static Key getValidationKeyFromConfig(
JwtProperties tokenProperties, SignatureAlgorithm algType) throws IOException {
String tokenSecretKey = tokenProperties.secretKey();
String tokenPublicKey = tokenProperties.publicKey();
byte[] validationKey;
Expand All @@ -159,7 +184,7 @@ private Key getValidationKeyFromConfig(JwtProperties tokenProperties) throws IOE
return decodeSecretKey(validationKey);
} else if (StringUtils.isNotBlank(tokenPublicKey)) {
validationKey = readKeyFromUrl(tokenPublicKey);
return decodePublicKey(validationKey, this.publicKeyAlg);
return decodePublicKey(validationKey, algType);
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package ai.langstream.auth.jwt;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwsHeader;
Expand Down Expand Up @@ -59,6 +60,14 @@ public record JwksUriCacheKey(JwksUri uri, String keyId) {}
private Map<JwksUriCacheKey, Key> keyMap = new ConcurrentHashMap<>();

public JwksUriSigningKeyResolver(String algorithm, String hostsAllowlist, Key fallbackKey) {
this(algorithm, hostsAllowlist, fallbackKey, null);
}

public JwksUriSigningKeyResolver(
String algorithm,
String hostsAllowlist,
Key fallbackKey,
LocalKubernetesJwksUriSigningKeyResolver localKubernetesJwksUriSigningKeyResolver) {
this.algorithm = algorithm;
if (StringUtils.isBlank(hostsAllowlist)) {
this.hostsAllowlist = null;
Expand All @@ -71,8 +80,13 @@ public JwksUriSigningKeyResolver(String algorithm, String hostsAllowlist, Key fa
.connectTimeout(Duration.ofSeconds(30))
.followRedirects(HttpClient.Redirect.ALWAYS)
.build();
this.localKubernetesJwksUriSigningKeyResolver =
new LocalKubernetesJwksUriSigningKeyResolver(httpClient);
if (localKubernetesJwksUriSigningKeyResolver == null) {
this.localKubernetesJwksUriSigningKeyResolver =
new LocalKubernetesJwksUriSigningKeyResolver(httpClient);
} else {
this.localKubernetesJwksUriSigningKeyResolver =
localKubernetesJwksUriSigningKeyResolver;
}
}

@Override
Expand Down Expand Up @@ -135,7 +149,8 @@ private Key fetchKey(JwksUriCacheKey jwksKey) {
KeyFactory keyFactory = KeyFactory.getInstance("RSA");
return keyFactory.generatePublic(rsaPublicKeySpec);
} catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
throw new IllegalStateException("Failed to parse public key");
throw new JwtException(
"Failed to parse public key '" + key.kid() + "' from " + jwksUri.uri());
}
}
throw new JwtException(
Expand Down Expand Up @@ -166,7 +181,7 @@ private JwkKeys getKeys(JwksUri jwksUri) throws IOException {
+ " "
+ responseWithToken.body());
} else {
body = httpResponse.body();
body = responseWithToken.body();
}
} else {
log.warn(
Expand All @@ -185,7 +200,12 @@ private JwkKeys getKeys(JwksUri jwksUri) throws IOException {
} else {
body = httpResponse.body();
}
return MAPPER.readValue(body, JwkKeys.class);
try {
return MAPPER.readValue(body, JwkKeys.class);
} catch (JsonProcessingException ex) {
log.warn("Failed to parse keys from URL: {}, got {}", jwksUri.uri(), body, ex);
throw ex;
}
}

private HttpResponse<String> sendGetKeysRequest(JwksUri jwksUri, boolean useToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ public record JwtProperties(
String publicAlg,
String audienceClaim,
String audience,
String jwksHostsAllowlist) {}
String jwksHostsAllowlist,
boolean allowKubernetesServiceAccounts,
String kubernetesNamespacePrefix) {}
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,36 @@
@Slf4j
public class LocalKubernetesJwksUriSigningKeyResolver {

public static final String DEFAULT_TOKEN_PATH =
"/var/run/secrets/kubernetes.io/serviceaccount/token";
public static final String DEFAULT_K8S_BASE_URL =
"https://kubernetes.default.svc.cluster.local";
private final HttpClient httpClient;
private final String token;
private final String localK8sIssuer;
private final Map<String, JwksUriSigningKeyResolver.JwksUri> cache = new ConcurrentHashMap<>();
private static final ObjectMapper MAPPER = new ObjectMapper();

public LocalKubernetesJwksUriSigningKeyResolver(HttpClient httpClient) {
public LocalKubernetesJwksUriSigningKeyResolver(
HttpClient httpClient, String tokenPath, String localIssuerBaseUrl) {
this.httpClient = httpClient;
token = loadToken();
localK8sIssuer = loadLocalIssuer();
token = loadToken(tokenPath);
localK8sIssuer = loadLocalIssuer(localIssuerBaseUrl);
log.info("Loaded local Kubernetes issuer: {}", localK8sIssuer);
}

public LocalKubernetesJwksUriSigningKeyResolver(HttpClient httpClient) {
this(httpClient, DEFAULT_TOKEN_PATH, DEFAULT_K8S_BASE_URL);
}

@SneakyThrows
private String loadToken() {
final Path defaultPath = Path.of("/var/run/secrets/kubernetes.io/serviceaccount/token");
private static String loadToken(String path) {
if (path == null) {
log.info(
"No token path specified. Kubernetes Service account authentication might not work.");
return null;
}
final Path defaultPath = Path.of(path);
if (Files.isRegularFile(defaultPath)) {
log.info("Loading token from {}", defaultPath);
return Files.readString(defaultPath);
Expand All @@ -59,9 +73,13 @@ private String loadToken() {
}

@SneakyThrows
private String loadLocalIssuer() {
final String endpoint =
composeWellKnownEndpoint("https://kubernetes.default.svc.cluster.local");
private String loadLocalIssuer(String baseUrl) {
if (baseUrl == null) {
log.info(
"Base url not configured for local Kubernetes API. It's ok if not running in a kubernetes pod.");
return null;
}
final String endpoint = composeWellKnownEndpoint(baseUrl);
final Map<String, ?> response;
try {
response = getResponseFromWellKnownOpenIdConfiguration(endpoint);
Expand Down
Loading

0 comments on commit b603580

Please sign in to comment.