Skip to content
This repository has been archived by the owner on May 31, 2022. It is now read-only.

Commit

Permalink
SECOAUTH-129: Add flushing to in memory token store
Browse files Browse the repository at this point in the history
  • Loading branch information
dsyer committed Feb 2, 2012
1 parent 0648802 commit 5552861
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package org.springframework.security.oauth2.provider.token;

import java.util.Date;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import org.springframework.security.oauth2.common.ExpiringOAuth2RefreshToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
Expand All @@ -16,6 +21,8 @@
*/
public class InMemoryTokenStore implements TokenStore {

private static final int DEFAULT_FLUSH_INTERVAL = 1000;

private final ConcurrentHashMap<String, OAuth2AccessToken> accessTokenStore = new ConcurrentHashMap<String, OAuth2AccessToken>();

private final ConcurrentHashMap<String, OAuth2AccessToken> authenticationToAccessTokenStore = new ConcurrentHashMap<String, OAuth2AccessToken>();
Expand All @@ -28,24 +35,50 @@ public class InMemoryTokenStore implements TokenStore {

private final ConcurrentHashMap<String, String> refreshTokenToAcessTokenStore = new ConcurrentHashMap<String, String>();

private final DelayQueue<TokenExpiry> expiryQueue = new DelayQueue<TokenExpiry>();

private int flushInterval = DEFAULT_FLUSH_INTERVAL;

private AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();


private AtomicInteger flushCounter = new AtomicInteger(0);

/**
* The number of tokens to store before flushing expired tokens. Defaults to 1000.
*
* @param flushInterval the interval to set
*/
public void setFlushInterval(int flushInterval) {
this.flushInterval = flushInterval;
}

/**
* The interval (count of token inserts) between flushing expired tokens.
*
* @return the flushInterval the flush interval
*/
public int getFlushInterval() {
return flushInterval;
}

public void setAuthenticationKeyGenerator(AuthenticationKeyGenerator authenticationKeyGenerator) {
this.authenticationKeyGenerator = authenticationKeyGenerator;
this.authenticationKeyGenerator = authenticationKeyGenerator;
}

public int getAccessTokenCount() {
Assert.state(accessTokenStore.size()>=accessTokenToRefreshTokenStore.size(), "Too many refresh tokens");
Assert.state(accessTokenStore.size()==authenticationToAccessTokenStore.size(), "Inconsistent token store state");
Assert.state(accessTokenStore.size()<=authenticationStore.size(), "Inconsistent authentication store state");
Assert.state(accessTokenStore.size() >= accessTokenToRefreshTokenStore.size(), "Too many refresh tokens");
Assert.state(accessTokenStore.size() == authenticationToAccessTokenStore.size(),
"Inconsistent token store state");
Assert.state(accessTokenStore.size() <= authenticationStore.size(), "Inconsistent authentication store state");
return accessTokenStore.size();
}

public int getRefreshTokenCount() {
Assert.state(refreshTokenStore.size()==refreshTokenToAcessTokenStore.size(), "Inconsistent refresh token store state");
Assert.state(refreshTokenStore.size() == refreshTokenToAcessTokenStore.size(),
"Inconsistent refresh token store state");
return accessTokenStore.size();
}

public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
return authenticationToAccessTokenStore.get(authenticationKeyGenerator.extractKey(authentication));
}
Expand All @@ -55,9 +88,16 @@ public OAuth2Authentication readAuthentication(OAuth2AccessToken token) {
}

public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
if (this.flushCounter.incrementAndGet() >= this.flushInterval) {
flush();
this.flushCounter.set(0);
}
this.accessTokenStore.put(token.getValue(), token);
this.authenticationStore.put(token.getValue(), authentication);
this.authenticationToAccessTokenStore.put(authenticationKeyGenerator.extractKey(authentication), token);
if (token.getExpiration() != null) {
this.expiryQueue.put(new TokenExpiry(token.getValue(), token.getExpiration()));
}
if (token.getRefreshToken() != null && token.getRefreshToken().getValue() != null) {
this.refreshTokenToAcessTokenStore.put(token.getRefreshToken().getValue(), token.getValue());
this.accessTokenToRefreshTokenStore.put(token.getValue(), token.getRefreshToken().getValue());
Expand All @@ -71,12 +111,12 @@ public OAuth2AccessToken readAccessToken(String tokenValue) {
public void removeAccessToken(String tokenValue) {
this.accessTokenStore.remove(tokenValue);
String refresh = this.accessTokenToRefreshTokenStore.remove(tokenValue);
if (refresh!=null) {
if (refresh != null) {
this.refreshTokenStore.remove(tokenValue);
this.refreshTokenToAcessTokenStore.remove(tokenValue);
}
OAuth2Authentication authentication = this.authenticationStore.remove(tokenValue);
if (authentication!=null) {
if (authentication != null) {
this.authenticationToAccessTokenStore.remove(authenticationKeyGenerator.extractKey(authentication));
}
}
Expand Down Expand Up @@ -106,4 +146,42 @@ public void removeAccessTokenUsingRefreshToken(String refreshToken) {
this.authenticationStore.remove(accessToken);
}
}

private void flush() {
TokenExpiry expiry = expiryQueue.poll();
while (expiry != null) {
removeAccessToken(expiry.getValue());
expiry = expiryQueue.poll();
}
}

private static class TokenExpiry implements Delayed {

private final long expiry;

private final String value;

public TokenExpiry(String value, Date date) {
this.value = value;
this.expiry = date.getTime();
}

public int compareTo(Delayed other) {
if (this == other) {
return 0;
}
long diff = getDelay(TimeUnit.MILLISECONDS) - other.getDelay(TimeUnit.MILLISECONDS);
return (diff == 0 ? 0 : ((diff < 0) ? -1 : 1));
}

public long getDelay(TimeUnit unit) {
return expiry - System.currentTimeMillis();
}

public String getValue() {
return value;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -253,4 +253,5 @@ public void setSelectAccessTokenFromAuthenticationSql(String selectAccessTokenFr
public void setDeleteAccessTokenFromRefreshTokenSql(String deleteAccessTokenFromRefreshTokenSql) {
this.deleteAccessTokenFromRefreshTokenSql = deleteAccessTokenFromRefreshTokenSql;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,5 @@ public interface TokenStore {
* @return the access token or null if there was none
*/
OAuth2AccessToken getAccessToken(OAuth2Authentication authentication);

}
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
package org.springframework.security.oauth2.provider.token;

import static org.junit.Assert.assertEquals;

import java.util.Date;

import org.junit.Before;
import org.junit.Test;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.provider.AuthorizationRequest;
import org.springframework.security.oauth2.provider.OAuth2Authentication;

/**
* @author Dave Syer
*
*
*/
public class TestInMemoryTokenStore extends TestTokenStoreBase {

Expand All @@ -20,4 +28,33 @@ public void createStore() {
tokenStore = new InMemoryTokenStore();
}

@Test
public void testTokenCountConsistency() throws Exception {
for (int i = 0; i <= 10; i++) {
OAuth2Authentication expectedAuthentication = new OAuth2Authentication(new AuthorizationRequest("id"+i, null,
null, null), new TestAuthentication("test", false));
OAuth2AccessToken expectedOAuth2AccessToken = new OAuth2AccessToken("testToken"+i);
expectedOAuth2AccessToken.setExpiration(new Date(System.currentTimeMillis() - 1000));
if (i>1) {
assertEquals(i, getTokenStore().getAccessTokenCount());
}
getTokenStore().storeAccessToken(expectedOAuth2AccessToken, expectedAuthentication);
}
}

@Test
public void testAutoFlush() throws Exception {
getTokenStore().setFlushInterval(3);
for (int i = 0; i <= 10; i++) {
OAuth2Authentication expectedAuthentication = new OAuth2Authentication(new AuthorizationRequest("id"+i, null,
null, null), new TestAuthentication("test", false));
OAuth2AccessToken expectedOAuth2AccessToken = new OAuth2AccessToken("testToken"+i);
expectedOAuth2AccessToken.setExpiration(new Date(System.currentTimeMillis() - 1000));
if (i>2) {
assertEquals((i%3+1), getTokenStore().getAccessTokenCount());
}
getTokenStore().storeAccessToken(expectedOAuth2AccessToken, expectedAuthentication);
}
}

}

0 comments on commit 5552861

Please sign in to comment.