Skip to content

Commit

Permalink
Change blob datatype to text when saving Oauth2Authorization to the d…
Browse files Browse the repository at this point in the history
…atabase

Closes spring-projectsgh-480
  • Loading branch information
ovidiupopa07 committed Jan 24, 2022
1 parent a1e513b commit da6a10e
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.security.oauth2.server.authorization;

import java.nio.charset.StandardCharsets;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
Expand All @@ -35,6 +36,7 @@

import org.springframework.dao.DataRetrievalFailureException;
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
import org.springframework.jdbc.core.ConnectionCallback;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper;
Expand Down Expand Up @@ -141,6 +143,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic

private final JdbcOperations jdbcOperations;
private final LobHandler lobHandler;
private static int tokenColumnType;
private RowMapper<OAuth2Authorization> authorizationRowMapper;
private Function<OAuth2Authorization, List<SqlParameterValue>> authorizationParametersMapper;

Expand Down Expand Up @@ -169,12 +172,15 @@ public JdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations,
Assert.notNull(lobHandler, "lobHandler cannot be null");
this.jdbcOperations = jdbcOperations;
this.lobHandler = lobHandler;
tokenColumnType = getColumnDataType(jdbcOperations, "access_token_value");
OAuth2AuthorizationRowMapper authorizationRowMapper = new OAuth2AuthorizationRowMapper(registeredClientRepository);
authorizationRowMapper.setLobHandler(lobHandler);
this.authorizationRowMapper = authorizationRowMapper;
this.authorizationParametersMapper = new OAuth2AuthorizationParametersMapper();
OAuth2AuthorizationParametersMapper authorizationParametersMapper = new OAuth2AuthorizationParametersMapper();
this.authorizationParametersMapper = authorizationParametersMapper;
}


@Override
public void save(OAuth2Authorization authorization) {
Assert.notNull(authorization, "authorization cannot be null");
Expand Down Expand Up @@ -232,26 +238,33 @@ public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType t
List<SqlParameterValue> parameters = new ArrayList<>();
if (tokenType == null) {
parameters.add(new SqlParameterValue(Types.VARCHAR, token));
parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
parameters.add(mapTokenToSqlParameter(token));
parameters.add(mapTokenToSqlParameter(token));
parameters.add(mapTokenToSqlParameter(token));
return findBy(UNKNOWN_TOKEN_TYPE_FILTER, parameters);
} else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) {
parameters.add(new SqlParameterValue(Types.VARCHAR, token));
return findBy(STATE_FILTER, parameters);
} else if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) {
parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
parameters.add(mapTokenToSqlParameter(token));
return findBy(AUTHORIZATION_CODE_FILTER, parameters);
} else if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) {
parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
parameters.add(mapTokenToSqlParameter(token));
return findBy(ACCESS_TOKEN_FILTER, parameters);
} else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) {
parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
parameters.add(mapTokenToSqlParameter(token));
return findBy(REFRESH_TOKEN_FILTER, parameters);
}
return null;
}

private SqlParameterValue mapTokenToSqlParameter(String token) {
if (Types.BLOB == tokenColumnType) {
return new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8));
}
return new SqlParameterValue(tokenColumnType, token);
}

private OAuth2Authorization findBy(String filter, List<SqlParameterValue> parameters) {
try (LobCreator lobCreator = getLobHandler().getLobCreator()) {
PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator,
Expand Down Expand Up @@ -349,25 +362,22 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException
builder.attribute(OAuth2ParameterNames.STATE, state);
}

String tokenValue;
Instant tokenIssuedAt;
Instant tokenExpiresAt;
byte[] authorizationCodeValue = this.lobHandler.getBlobAsBytes(rs, "authorization_code_value");
String authorizationCodeValue = getTokenValue(rs, "authorization_code_value");

if (authorizationCodeValue != null) {
tokenValue = new String(authorizationCodeValue, StandardCharsets.UTF_8);
if (StringUtils.hasText(authorizationCodeValue)) {
tokenIssuedAt = rs.getTimestamp("authorization_code_issued_at").toInstant();
tokenExpiresAt = rs.getTimestamp("authorization_code_expires_at").toInstant();
Map<String, Object> authorizationCodeMetadata = parseMap(rs.getString("authorization_code_metadata"));

OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
tokenValue, tokenIssuedAt, tokenExpiresAt);
authorizationCodeValue, tokenIssuedAt, tokenExpiresAt);
builder.token(authorizationCode, (metadata) -> metadata.putAll(authorizationCodeMetadata));
}

byte[] accessTokenValue = this.lobHandler.getBlobAsBytes(rs, "access_token_value");
if (accessTokenValue != null) {
tokenValue = new String(accessTokenValue, StandardCharsets.UTF_8);
String accessTokenValue = getTokenValue(rs, "access_token_value");
if (StringUtils.hasText(accessTokenValue)) {
tokenIssuedAt = rs.getTimestamp("access_token_issued_at").toInstant();
tokenExpiresAt = rs.getTimestamp("access_token_expires_at").toInstant();
Map<String, Object> accessTokenMetadata = parseMap(rs.getString("access_token_metadata"));
Expand All @@ -381,25 +391,23 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException
if (accessTokenScopes != null) {
scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes);
}
OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, tokenIssuedAt, tokenExpiresAt, scopes);
OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, accessTokenValue, tokenIssuedAt, tokenExpiresAt, scopes);
builder.token(accessToken, (metadata) -> metadata.putAll(accessTokenMetadata));
}

byte[] oidcIdTokenValue = this.lobHandler.getBlobAsBytes(rs, "oidc_id_token_value");
if (oidcIdTokenValue != null) {
tokenValue = new String(oidcIdTokenValue, StandardCharsets.UTF_8);
String oidcIdTokenValue = getTokenValue(rs, "oidc_id_token_value");
if (StringUtils.hasText(oidcIdTokenValue)) {
tokenIssuedAt = rs.getTimestamp("oidc_id_token_issued_at").toInstant();
tokenExpiresAt = rs.getTimestamp("oidc_id_token_expires_at").toInstant();
Map<String, Object> oidcTokenMetadata = parseMap(rs.getString("oidc_id_token_metadata"));

OidcIdToken oidcToken = new OidcIdToken(
tokenValue, tokenIssuedAt, tokenExpiresAt, (Map<String, Object>) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME));
oidcIdTokenValue, tokenIssuedAt, tokenExpiresAt, (Map<String, Object>) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME));
builder.token(oidcToken, (metadata) -> metadata.putAll(oidcTokenMetadata));
}

byte[] refreshTokenValue = this.lobHandler.getBlobAsBytes(rs, "refresh_token_value");
if (refreshTokenValue != null) {
tokenValue = new String(refreshTokenValue, StandardCharsets.UTF_8);
String refreshTokenValue = getTokenValue(rs, "refresh_token_value");
if (StringUtils.hasText(refreshTokenValue)) {
tokenIssuedAt = rs.getTimestamp("refresh_token_issued_at").toInstant();
tokenExpiresAt = null;
Timestamp refreshTokenExpiresAt = rs.getTimestamp("refresh_token_expires_at");
Expand All @@ -409,12 +417,29 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException
Map<String, Object> refreshTokenMetadata = parseMap(rs.getString("refresh_token_metadata"));

OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(
tokenValue, tokenIssuedAt, tokenExpiresAt);
refreshTokenValue, tokenIssuedAt, tokenExpiresAt);
builder.token(refreshToken, (metadata) -> metadata.putAll(refreshTokenMetadata));
}
return builder.build();
}

private String getTokenValue(ResultSet rs, String tokenColumn) throws SQLException {
String tokenValue = null;
if (Types.CLOB == tokenColumnType) {
tokenValue = this.lobHandler.getClobAsString(rs, tokenColumn);
}
if (Types.VARCHAR == tokenColumnType) {
tokenValue = rs.getString(tokenColumn);
}
if (Types.BLOB == tokenColumnType) {
byte[] tokenValueByte = this.lobHandler.getBlobAsBytes(rs, tokenColumn);
if (tokenValueByte != null) {
tokenValue = new String(tokenValueByte, StandardCharsets.UTF_8);
}
}
return tokenValue;
}

public final void setLobHandler(LobHandler lobHandler) {
Assert.notNull(lobHandler, "lobHandler cannot be null");
this.lobHandler = lobHandler;
Expand Down Expand Up @@ -520,12 +545,12 @@ protected final ObjectMapper getObjectMapper() {

private <T extends AbstractOAuth2Token> List<SqlParameterValue> toSqlParameterList(OAuth2Authorization.Token<T> token) {
List<SqlParameterValue> parameters = new ArrayList<>();
byte[] tokenValue = null;
String tokenValue = null;
Timestamp tokenIssuedAt = null;
Timestamp tokenExpiresAt = null;
String metadata = null;
if (token != null) {
tokenValue = token.getToken().getTokenValue().getBytes(StandardCharsets.UTF_8);
tokenValue = token.getToken().getTokenValue();
if (token.getToken().getIssuedAt() != null) {
tokenIssuedAt = Timestamp.from(token.getToken().getIssuedAt());
}
Expand All @@ -534,7 +559,13 @@ private <T extends AbstractOAuth2Token> List<SqlParameterValue> toSqlParameterLi
}
metadata = writeMap(token.getMetadata());
}
parameters.add(new SqlParameterValue(Types.BLOB, tokenValue));
if (Types.BLOB == tokenColumnType && StringUtils.hasText(tokenValue)) {
byte[] tokenValueAsBytes = tokenValue.getBytes(StandardCharsets.UTF_8);
parameters.add(new SqlParameterValue(tokenColumnType, tokenValueAsBytes));
} else {
parameters.add(new SqlParameterValue(tokenColumnType, tokenValue));
}

parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenIssuedAt));
parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenExpiresAt));
parameters.add(new SqlParameterValue(Types.VARCHAR, metadata));
Expand All @@ -551,6 +582,23 @@ private String writeMap(Map<String, Object> data) {

}

private static int getColumnDataType(JdbcOperations jdbcOperations, String columnName){
return jdbcOperations.execute((ConnectionCallback<Integer>) con -> {
DatabaseMetaData databaseMetaData = con.getMetaData();
ResultSet rs = databaseMetaData.getColumns(null, null, TABLE_NAME, columnName);
if (rs.next()) {
return rs.getInt("DATA_TYPE");
}
// NOTE: When using HSQL: When a database object is created with one of the CREATE statements if the name is enclosed in double quotes, the exact name is used as the case-normal form.
// But if it is not enclosed in double quotes, the name is converted to uppercase and this uppercase version is stored in the database as the case-normal form
rs = databaseMetaData.getColumns(null, null, TABLE_NAME.toUpperCase(), columnName.toUpperCase());
if (rs.next()) {
return rs.getInt("DATA_TYPE");
}
return Types.NULL;
});
}

private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
private final LobCreator lobCreator;

Expand All @@ -572,6 +620,15 @@ protected void doSetValue(PreparedStatement ps, int parameterPosition, Object ar
this.lobCreator.setBlobAsBytes(ps, parameterPosition, valueBytes);
return;
}
if (paramValue.getSqlType() == Types.CLOB) {
if (paramValue.getValue() != null) {
Assert.isInstanceOf(String.class, paramValue.getValue(),
"Value of clob parameter must be String");
}
String valueString = (String) paramValue.getValue();
this.lobCreator.setClobAsString(ps, parameterPosition, valueString);
return;
}
}
super.doSetValue(ps, parameterPosition, argValue);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2020-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE oauth2_authorization (
id varchar(100) NOT NULL,
registered_client_id varchar(100) NOT NULL,
principal_name varchar(200) NOT NULL,
authorization_grant_type varchar(100) NOT NULL,
attributes varchar(15000) DEFAULT NULL,
state varchar(500) DEFAULT NULL,
authorization_code_value text DEFAULT NULL,
authorization_code_issued_at timestamp DEFAULT NULL,
authorization_code_expires_at timestamp DEFAULT NULL,
authorization_code_metadata varchar(2000) DEFAULT NULL,
access_token_value text DEFAULT NULL,
access_token_issued_at timestamp DEFAULT NULL,
access_token_expires_at timestamp DEFAULT NULL,
access_token_metadata varchar(2000) DEFAULT NULL,
access_token_type varchar(100) DEFAULT NULL,
access_token_scopes varchar(1000) DEFAULT NULL,
oidc_id_token_value text DEFAULT NULL,
oidc_id_token_issued_at timestamp DEFAULT NULL,
oidc_id_token_expires_at timestamp DEFAULT NULL,
oidc_id_token_metadata varchar(2000) DEFAULT NULL,
refresh_token_value text DEFAULT NULL,
refresh_token_issued_at timestamp DEFAULT NULL,
refresh_token_expires_at timestamp DEFAULT NULL,
refresh_token_metadata varchar(2000) DEFAULT NULL,
PRIMARY KEY (id)
);
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import org.springframework.jdbc.support.lob.DefaultLobHandler;
import org.springframework.security.oauth2.core.AbstractOAuth2Token;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
Expand Down Expand Up @@ -75,6 +76,7 @@
public class JdbcOAuth2AuthorizationServiceTests {
private static final String OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql";
private static final String CUSTOM_OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-schema.sql";
private static final String OAUTH2_AUTHORIZATION_SCHEMA_CLOB_COLUMN_TYPE_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-schema-clob-data-type.sql";
private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE);
private static final String ID = "id";
Expand Down Expand Up @@ -414,6 +416,37 @@ public void tableDefinitionWhenCustomThenAbleToOverride() {
db.shutdown();
}

@Test
public void tableDefinitionWhenClobSqlTypeThenUpdateAuthorization() {
EmbeddedDatabase db = createDb(OAUTH2_AUTHORIZATION_SCHEMA_CLOB_COLUMN_TYPE_SQL_RESOURCE);
OAuth2AuthorizationService authorizationService =
new JdbcOAuth2AuthorizationService(new JdbcTemplate(db), this.registeredClientRepository);
when(this.registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId())))
.thenReturn(REGISTERED_CLIENT);
OAuth2Authorization originalAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
.id(ID)
.principalName(PRINCIPAL_NAME)
.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
.token(AUTHORIZATION_CODE)
.build();
authorizationService.save(originalAuthorization);

OAuth2Authorization authorization = authorizationService.findById(
originalAuthorization.getId());
assertThat(authorization).isEqualTo(originalAuthorization);

OAuth2Authorization updatedAuthorization = OAuth2Authorization.from(authorization)
.attribute("custom-name-1", "custom-value-1")
.build();
authorizationService.save(updatedAuthorization);

authorization = authorizationService.findById(
updatedAuthorization.getId());
assertThat(authorization).isEqualTo(updatedAuthorization);
assertThat(authorization).isNotEqualTo(originalAuthorization);
db.shutdown();
}

private static EmbeddedDatabase createDb() {
return createDb(OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE);
}
Expand Down Expand Up @@ -479,11 +512,14 @@ private static final class CustomJdbcOAuth2AuthorizationService extends JdbcOAut

private CustomJdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations,
RegisteredClientRepository registeredClientRepository) {
super(jdbcOperations, registeredClientRepository);
super(jdbcOperations, registeredClientRepository, new DefaultLobHandler());
setAuthorizationRowMapper(new CustomOAuth2AuthorizationRowMapper(registeredClientRepository));
setAuthorizationParametersMapper(new CustomOAuth2AuthorizationParametersMapper());

}



@Override
public void save(OAuth2Authorization authorization) {
List<SqlParameterValue> parameters = getAuthorizationParametersMapper().apply(authorization);
Expand Down

0 comments on commit da6a10e

Please sign in to comment.