Skip to content
This repository has been archived by the owner on May 31, 2022. It is now read-only.

Commit

Permalink
SECOAUTH-349: add explicit catch blocks for deserialization problems
Browse files Browse the repository at this point in the history
  • Loading branch information
dsyer committed Oct 30, 2012
1 parent c0c1b39 commit 2112c4e
Showing 1 changed file with 49 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package org.springframework.security.oauth2.provider.token;

import java.io.UnsupportedEncodingException;
import java.math.BigInteger;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
Expand Down Expand Up @@ -107,6 +111,8 @@ public OAuth2AccessToken mapRow(ResultSet rs, int rowNum) throws SQLException {
if (LOG.isInfoEnabled()) {
LOG.debug("Failed to find access token for authentication " + authentication);
}
} catch (IllegalArgumentException e) {
LOG.error("Could not extract access token for authentication " + authentication);
}

if (accessToken != null && !authentication.equals(readAuthentication(accessToken.getValue()))) {
Expand All @@ -124,11 +130,11 @@ public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authe
refreshToken = token.getRefreshToken().getValue();
}

jdbcTemplate.update(insertAccessTokenSql, new Object[] { token.getValue(),
jdbcTemplate.update(insertAccessTokenSql, new Object[] { extractTokenKey(token.getValue()),
new SqlLobValue(serializeAccessToken(token)), authenticationKeyGenerator.extractKey(authentication),
authentication.isClientOnly() ? null : authentication.getName(),
authentication.getAuthorizationRequest().getClientId(),
new SqlLobValue(serializeAuthentication(authentication)), refreshToken }, new int[] { Types.VARCHAR,
new SqlLobValue(serializeAuthentication(authentication)), extractTokenKey(refreshToken) }, new int[] { Types.VARCHAR,
Types.BLOB, Types.VARCHAR, Types.VARCHAR, Types.VARCHAR, Types.BLOB, Types.VARCHAR });
}

Expand All @@ -140,12 +146,15 @@ public OAuth2AccessToken readAccessToken(String tokenValue) {
public OAuth2AccessToken mapRow(ResultSet rs, int rowNum) throws SQLException {
return deserializeAccessToken(rs.getBytes(2));
}
}, tokenValue);
}, extractTokenKey(tokenValue));
}
catch (EmptyResultDataAccessException e) {
if (LOG.isInfoEnabled()) {
LOG.info("Failed to find access token for token " + tokenValue);
}
} catch (IllegalArgumentException e) {
LOG.warn("Failed to deserialize access token for " + tokenValue);
removeAccessToken(tokenValue);
}

return accessToken;
Expand All @@ -156,7 +165,7 @@ public void removeAccessToken(OAuth2AccessToken token) {
}

public void removeAccessToken(String tokenValue) {
jdbcTemplate.update(deleteAccessTokenSql, tokenValue);
jdbcTemplate.update(deleteAccessTokenSql, extractTokenKey(tokenValue));
}

public OAuth2Authentication readAuthentication(OAuth2AccessToken token) {
Expand All @@ -172,19 +181,22 @@ public OAuth2Authentication readAuthentication(String token) {
public OAuth2Authentication mapRow(ResultSet rs, int rowNum) throws SQLException {
return deserializeAuthentication(rs.getBytes(2));
}
}, token);
}, extractTokenKey(token));
}
catch (EmptyResultDataAccessException e) {
if (LOG.isInfoEnabled()) {
LOG.info("Failed to find access token for token " + token);
}
} catch (IllegalArgumentException e) {
LOG.warn("Failed to deserialize authentication for " + token);
removeAccessToken(token);
}

return authentication;
}

public void storeRefreshToken(OAuth2RefreshToken refreshToken, OAuth2Authentication authentication) {
jdbcTemplate.update(insertRefreshTokenSql, new Object[] { refreshToken.getValue(),
jdbcTemplate.update(insertRefreshTokenSql, new Object[] { extractTokenKey(refreshToken.getValue()),
new SqlLobValue(serializeRefreshToken(refreshToken)),
new SqlLobValue(serializeAuthentication(authentication)) }, new int[] { Types.VARCHAR, Types.BLOB,
Types.BLOB });
Expand All @@ -198,12 +210,15 @@ public OAuth2RefreshToken readRefreshToken(String token) {
public OAuth2RefreshToken mapRow(ResultSet rs, int rowNum) throws SQLException {
return deserializeRefreshToken(rs.getBytes(2));
}
}, token);
}, extractTokenKey(token));
}
catch (EmptyResultDataAccessException e) {
if (LOG.isInfoEnabled()) {
LOG.info("Failed to find refresh token for token " + token);
}
} catch (IllegalArgumentException e) {
LOG.warn("Failed to deserialize refresh token for token " + token);
removeRefreshToken(token);
}

return refreshToken;
Expand All @@ -214,7 +229,7 @@ public void removeRefreshToken(OAuth2RefreshToken token) {
}

public void removeRefreshToken(String token) {
jdbcTemplate.update(deleteRefreshTokenSql, token);
jdbcTemplate.update(deleteRefreshTokenSql, extractTokenKey(token));
}

public OAuth2Authentication readAuthenticationForRefreshToken(OAuth2RefreshToken token) {
Expand All @@ -230,12 +245,15 @@ public OAuth2Authentication readAuthenticationForRefreshToken(String value) {
public OAuth2Authentication mapRow(ResultSet rs, int rowNum) throws SQLException {
return deserializeAuthentication(rs.getBytes(2));
}
}, value);
}, extractTokenKey(value));
}
catch (EmptyResultDataAccessException e) {
if (LOG.isInfoEnabled()) {
LOG.info("Failed to find access token for token " + value);
}
} catch (IllegalArgumentException e) {
LOG.warn("Failed to deserialize access token for " + value);
removeRefreshToken(value);
}

return authentication;
Expand All @@ -246,7 +264,7 @@ public void removeAccessTokenUsingRefreshToken(OAuth2RefreshToken refreshToken)
}

public void removeAccessTokenUsingRefreshToken(String refreshToken) {
jdbcTemplate.update(deleteAccessTokenFromRefreshTokenSql, new Object[] { refreshToken },
jdbcTemplate.update(deleteAccessTokenFromRefreshTokenSql, new Object[] { extractTokenKey(refreshToken) },
new int[] { Types.VARCHAR });
}

Expand Down Expand Up @@ -288,6 +306,27 @@ public OAuth2AccessToken mapRow(ResultSet rs, int rowNum) throws SQLException {
return accessTokens;
}

protected String extractTokenKey(String value) {
if (value==null) {
return null;
}
MessageDigest digest;
try {
digest = MessageDigest.getInstance("MD5");
}
catch (NoSuchAlgorithmException e) {
throw new IllegalStateException("MD5 algorithm not available. Fatal (should be in the JDK).");
}

try {
byte[] bytes = digest.digest(value.getBytes("UTF-8"));
return String.format("%032x", new BigInteger(1, bytes));
}
catch (UnsupportedEncodingException e) {
throw new IllegalStateException("UTF-8 encoding not available. Fatal (should be in the JDK).");
}
}

protected byte[] serializeAccessToken(OAuth2AccessToken token) {
return SerializationUtils.serialize(token);
}
Expand Down

0 comments on commit 2112c4e

Please sign in to comment.