Skip to content
This repository has been archived by the owner on May 31, 2022. It is now read-only.

Commit

Permalink
RedisTokenStore expiration fix #1657
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasVyhlidka committed Jun 21, 2019
1 parent bbae002 commit 09eceaa
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.connection.RedisZSetCommands;
import org.springframework.data.redis.core.Cursor;
import org.springframework.data.redis.core.ScanOptions;
import org.springframework.security.oauth2.common.ExpiringOAuth2RefreshToken;
Expand Down Expand Up @@ -30,11 +31,10 @@ public class RedisTokenStore implements TokenStore {
private static final String AUTH_TO_ACCESS = "auth_to_access:";
private static final String AUTH = "auth:";
private static final String REFRESH_AUTH = "refresh_auth:";
private static final String ACCESS_TO_REFRESH = "access_to_refresh:";
private static final String REFRESH = "refresh:";
private static final String REFRESH_TO_ACCESS = "refresh_to_access:";
private static final String CLIENT_ID_TO_ACCESS = "client_id_to_access:";
private static final String UNAME_TO_ACCESS = "uname_to_access:";
private static final String CLIENT_ID_TO_ACCESS = "client_id_to_access_z:";
private static final String UNAME_TO_ACCESS = "uname_to_access_z:";

private static final boolean springDataRedis_2_0 = ClassUtils.isPresent(
"org.springframework.data.redis.connection.RedisStandaloneConfiguration",
Expand Down Expand Up @@ -189,34 +189,40 @@ public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authe
conn.set(authKey, serializedAuth);
conn.set(authToAccessKey, serializedAccessToken);
}
if (!authentication.isClientOnly()) {
conn.sAdd(approvalKey, serializedAccessToken);
}
conn.sAdd(clientId, serializedAccessToken);

if (token.getExpiration() != null) {
int seconds = token.getExpiresIn();
long expirationTime = token.getExpiration().getTime();

if (!authentication.isClientOnly()) {
conn.zAdd(approvalKey, expirationTime, serializedAccessToken);
}
conn.zAdd(clientId, expirationTime, serializedAccessToken);

conn.expire(accessKey, seconds);
conn.expire(authKey, seconds);
conn.expire(authToAccessKey, seconds);
conn.expire(clientId, seconds);
conn.expire(approvalKey, seconds);
} else {
conn.zAdd(clientId, -1, serializedAccessToken); // -1 don't expire
if (!authentication.isClientOnly()) {
conn.zAdd(approvalKey, -1, serializedAccessToken);
}
}
OAuth2RefreshToken refreshToken = token.getRefreshToken();
if (refreshToken != null && refreshToken.getValue() != null) {
byte[] refresh = serialize(token.getRefreshToken().getValue());
byte[] auth = serialize(token.getValue());
byte[] refreshToAccessKey = serializeKey(REFRESH_TO_ACCESS + token.getRefreshToken().getValue());
byte[] accessToRefreshKey = serializeKey(ACCESS_TO_REFRESH + token.getValue());
if (springDataRedis_2_0) {
try {
this.redisConnectionSet_2_0.invoke(conn, refreshToAccessKey, auth);
this.redisConnectionSet_2_0.invoke(conn, accessToRefreshKey, refresh);
} catch (Exception ex) {
throw new RuntimeException(ex);
}
} else {
conn.set(refreshToAccessKey, auth);
conn.set(accessToRefreshKey, refresh);
}
if (refreshToken instanceof ExpiringOAuth2RefreshToken) {
ExpiringOAuth2RefreshToken expiringRefreshToken = (ExpiringOAuth2RefreshToken) refreshToken;
Expand All @@ -225,7 +231,6 @@ public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authe
int seconds = Long.valueOf((expiration.getTime() - System.currentTimeMillis()) / 1000L)
.intValue();
conn.expire(refreshToAccessKey, seconds);
conn.expire(accessToRefreshKey, seconds);
}
}
}
Expand Down Expand Up @@ -267,14 +272,12 @@ public OAuth2AccessToken readAccessToken(String tokenValue) {
public void removeAccessToken(String tokenValue) {
byte[] accessKey = serializeKey(ACCESS + tokenValue);
byte[] authKey = serializeKey(AUTH + tokenValue);
byte[] accessToRefreshKey = serializeKey(ACCESS_TO_REFRESH + tokenValue);
RedisConnection conn = getConnection();
try {
conn.openPipeline();
conn.get(accessKey);
conn.get(authKey);
conn.del(accessKey);
conn.del(accessToRefreshKey);
// Don't remove the refresh token - it's up to the caller to do that
conn.del(authKey);
List<Object> results = conn.closePipeline();
Expand All @@ -289,8 +292,8 @@ public void removeAccessToken(String tokenValue) {
byte[] clientId = serializeKey(CLIENT_ID_TO_ACCESS + authentication.getOAuth2Request().getClientId());
conn.openPipeline();
conn.del(authToAccessKey);
conn.sRem(unameKey, access);
conn.sRem(clientId, access);
conn.zRem(unameKey, access);
conn.zRem(clientId, access);
conn.del(serialize(ACCESS + key));
conn.closePipeline();
}
Expand Down Expand Up @@ -357,14 +360,12 @@ public void removeRefreshToken(String tokenValue) {
byte[] refreshKey = serializeKey(REFRESH + tokenValue);
byte[] refreshAuthKey = serializeKey(REFRESH_AUTH + tokenValue);
byte[] refresh2AccessKey = serializeKey(REFRESH_TO_ACCESS + tokenValue);
byte[] access2RefreshKey = serializeKey(ACCESS_TO_REFRESH + tokenValue);
RedisConnection conn = getConnection();
try {
conn.openPipeline();
conn.del(refreshKey);
conn.del(refreshAuthKey);
conn.del(refresh2AccessKey);
conn.del(access2RefreshKey);
conn.closePipeline();
} finally {
conn.close();
Expand Down Expand Up @@ -398,24 +399,68 @@ private void removeAccessTokenUsingRefreshToken(String refreshToken) {
}
}

private List<byte[]> getByteLists(byte[] approvalKey, RedisConnection conn) {
private List<byte[]> getZByteLists(byte[] key, RedisConnection conn) {
// Sorted Set expiration maintenance
long currentTime = System.currentTimeMillis();
conn.zRemRangeByScore(key, 0, currentTime);

List<byte[]> byteList;
Long size = conn.sCard(approvalKey);
Long size = conn.zCard(key);
byteList = new ArrayList<byte[]>(size.intValue());
Cursor<byte[]> cursor = conn.sScan(approvalKey, ScanOptions.NONE);
Cursor<RedisZSetCommands.Tuple> cursor = conn.zScan(key, ScanOptions.NONE);

while(cursor.hasNext()) {
byteList.add(cursor.next());
RedisZSetCommands.Tuple t = cursor.next();

// Probably not necessary because of the maintenance at the beginning but why not...
if (t.getScore() == -1 || t.getScore() > currentTime) {
byteList.add(t.getValue());
}
}
return byteList;
}

/**
* Runs a maintenance of the RedisTokenStore.
*
* SortedSets UNAME_TO_ACCESS and CLIENT_ID_TO_ACCESS contains access tokens that can expire.
* This expiration is set as a score of the Redis SortedSet data structure. Redis does not support expiration of items in a container data structure.
* It supports only expiration of whole key. In case there is still new access tokens being stored into the RedisTokenStore before whole key gets expired,
* the expiration is prolonged and the key is not effectively deleted. To do "garbage collection" this method should be called once upon a time.
* @return how many items were removed
*/
public long doMaintenance() {
long removed = 0;
RedisConnection conn = getConnection();
try {
//client_id_to_acccess maintenance
Cursor<byte[]> clientToAccessKeys = conn.scan(ScanOptions.scanOptions().match(prefix + CLIENT_ID_TO_ACCESS + "*").build());
while (clientToAccessKeys.hasNext()) {
byte[] clientIdToAccessKey = clientToAccessKeys.next();

removed += conn.zRemRangeByScore(clientIdToAccessKey, 0, System.currentTimeMillis());
}

//uname_to_access maintenance
Cursor<byte[]> unameToAccessKeys = conn.scan(ScanOptions.scanOptions().match(prefix + UNAME_TO_ACCESS + "*").build());
while (unameToAccessKeys.hasNext()) {
byte[] unameToAccessKey = unameToAccessKeys.next();

removed += conn.zRemRangeByScore(unameToAccessKey, 0, System.currentTimeMillis());
}
} finally {
conn.close();
}
return removed;
}

@Override
public Collection<OAuth2AccessToken> findTokensByClientIdAndUserName(String clientId, String userName) {
byte[] approvalKey = serializeKey(UNAME_TO_ACCESS + getApprovalKey(clientId, userName));
List<byte[]> byteList = null;
RedisConnection conn = getConnection();
try {
byteList = getByteLists(approvalKey, conn);
byteList = getZByteLists(approvalKey, conn);
} finally {
conn.close();
}
Expand All @@ -436,7 +481,7 @@ public Collection<OAuth2AccessToken> findTokensByClientId(String clientId) {
List<byte[]> byteList = null;
RedisConnection conn = getConnection();
try {
byteList = getByteLists(key, conn);
byteList = getZByteLists(key, conn);
} finally {
conn.close();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,16 @@ public void storeAccessTokenWithoutRefreshTokenRemoveAccessTokenVerifyKeysRemove
ArgumentCaptor<byte[]> setKeyArgs = ArgumentCaptor.forClass(byte[].class);
verify(connection, times(3)).set(setKeyArgs.capture(), any(byte[].class));

ArgumentCaptor<byte[]> sAddKeyArgs = ArgumentCaptor.forClass(byte[].class);
verify(connection, times(2)).sAdd(sAddKeyArgs.capture(), any(byte[].class));
ArgumentCaptor<byte[]> zAddKeyArgs = ArgumentCaptor.forClass(byte[].class);
verify(connection, times(2)).zAdd(zAddKeyArgs.capture(), anyDouble(), any(byte[].class));

tokenStore.removeAccessToken(oauth2AccessToken);

for (byte[] key : setKeyArgs.getAllValues()) {
verify(connection).del(key);
}
for (byte[] key : sAddKeyArgs.getAllValues()) {
verify(connection).sRem(eq(key), any(byte[].class));
for (byte[] key : zAddKeyArgs.getAllValues()) {
verify(connection).zRem(eq(key), any(byte[].class));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
import java.util.Date;
import java.util.UUID;

import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
* @author efenderbosch
Expand Down Expand Up @@ -109,4 +113,88 @@ public void storeAccessTokenWithoutRefreshTokenRemoveAccessTokenVerifyTokenRemov
assertTrue(oauth2AccessTokens.isEmpty());
}

@Test
public void tokenExpirationWithParallelTokenStoring() throws InterruptedException {
final Crate<Boolean> clientOn = new Crate<Boolean>(true);
Runnable clientsSimulation = new Runnable() {
@Override
public void run() {
while (clientOn.getObj()) {
OAuth2Authentication auth = new OAuth2Authentication(RequestTokenFactory.createOAuth2Request(
"client_X", false), new TestAuthentication("user42", false));
DefaultOAuth2AccessToken accessToken = new DefaultOAuth2AccessToken(UUID.randomUUID().toString());
accessToken.setExpiration(new Date(System.currentTimeMillis() + 1500));
getTokenStore().storeAccessToken(accessToken, auth);

// There is new token stored every half a second - our app is very used
try {
Thread.sleep(500);
} catch (InterruptedException e) {
e.printStackTrace();
fail();
}
}
}
};

// Start client simulation
Thread clientsThread = new Thread(clientsSimulation);
clientsThread.start();

// Create and store our own token
String accessToken = UUID.randomUUID().toString();
OAuth2Authentication expectedAuthentication = new OAuth2Authentication(RequestTokenFactory.createOAuth2Request(
"client_X", false), new TestAuthentication("user1", false));
DefaultOAuth2AccessToken expectedOAuth2AccessToken = new DefaultOAuth2AccessToken(accessToken);
expectedOAuth2AccessToken.setExpiration(new Date(System.currentTimeMillis() + 1500));
getTokenStore().storeAccessToken(expectedOAuth2AccessToken, expectedAuthentication);

// Should be possible to read the token
OAuth2AccessToken actualOAuth2AccessToken = getTokenStore().readAccessToken(accessToken);
assertEquals(expectedOAuth2AccessToken, actualOAuth2AccessToken);
assertEquals(expectedAuthentication, getTokenStore().readAuthentication(expectedOAuth2AccessToken));
assertNotNull(findAccessToken(getTokenStore().findTokensByClientId("client_X"), accessToken));
assertNotNull(findAccessToken(getTokenStore().findTokensByClientIdAndUserName("client_X", "user1"), accessToken));

// let the token expire
Thread.sleep(1500);

// now it should be gone
assertNull(getTokenStore().readAccessToken(accessToken));
assertNull(getTokenStore().readAuthentication(expectedOAuth2AccessToken));
assertNull(findAccessToken(getTokenStore().findTokensByClientId("client_X"), accessToken));
assertNull(findAccessToken(getTokenStore().findTokensByClientIdAndUserName("client_X", "user1"), accessToken));

// Stop the client
clientOn.setObj(false);
clientsThread.join();

// let clients token expire
Thread.sleep(2000);
}

private OAuth2AccessToken findAccessToken(Collection<OAuth2AccessToken> tokens, String tokenValue) {
for (OAuth2AccessToken token : tokens) {
if (tokenValue.equals(token.getValue())) {
return token;
}
}
return null;
}

private static class Crate<E> {
E obj;

public Crate(final E obj) {
this.obj = obj;
}

public E getObj() {
return obj;
}

public void setObj(final E obj) {
this.obj = obj;
}
}
}

0 comments on commit 09eceaa

Please sign in to comment.