Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

SECOAUTH-129: Add flushing to in memory token store

  • Loading branch information...
commit 55528610be5c887c049e9a85c76c7c798e4a7a87 1 parent 0648802
@dsyer dsyer authored
View
98 .../src/main/java/org/springframework/security/oauth2/provider/token/InMemoryTokenStore.java
@@ -1,6 +1,11 @@
package org.springframework.security.oauth2.provider.token;
+import java.util.Date;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.DelayQueue;
+import java.util.concurrent.Delayed;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import org.springframework.security.oauth2.common.ExpiringOAuth2RefreshToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
@@ -16,6 +21,8 @@
*/
public class InMemoryTokenStore implements TokenStore {
+ private static final int DEFAULT_FLUSH_INTERVAL = 1000;
+
private final ConcurrentHashMap<String, OAuth2AccessToken> accessTokenStore = new ConcurrentHashMap<String, OAuth2AccessToken>();
private final ConcurrentHashMap<String, OAuth2AccessToken> authenticationToAccessTokenStore = new ConcurrentHashMap<String, OAuth2AccessToken>();
@@ -28,24 +35,50 @@
private final ConcurrentHashMap<String, String> refreshTokenToAcessTokenStore = new ConcurrentHashMap<String, String>();
+ private final DelayQueue<TokenExpiry> expiryQueue = new DelayQueue<TokenExpiry>();
+
+ private int flushInterval = DEFAULT_FLUSH_INTERVAL;
+
private AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();
-
+
+ private AtomicInteger flushCounter = new AtomicInteger(0);
+
+ /**
+ * The number of tokens to store before flushing expired tokens. Defaults to 1000.
+ *
+ * @param flushInterval the interval to set
+ */
+ public void setFlushInterval(int flushInterval) {
+ this.flushInterval = flushInterval;
+ }
+
+ /**
+ * The interval (count of token inserts) between flushing expired tokens.
+ *
+ * @return the flushInterval the flush interval
+ */
+ public int getFlushInterval() {
+ return flushInterval;
+ }
+
public void setAuthenticationKeyGenerator(AuthenticationKeyGenerator authenticationKeyGenerator) {
- this.authenticationKeyGenerator = authenticationKeyGenerator;
+ this.authenticationKeyGenerator = authenticationKeyGenerator;
}
public int getAccessTokenCount() {
- Assert.state(accessTokenStore.size()>=accessTokenToRefreshTokenStore.size(), "Too many refresh tokens");
- Assert.state(accessTokenStore.size()==authenticationToAccessTokenStore.size(), "Inconsistent token store state");
- Assert.state(accessTokenStore.size()<=authenticationStore.size(), "Inconsistent authentication store state");
+ Assert.state(accessTokenStore.size() >= accessTokenToRefreshTokenStore.size(), "Too many refresh tokens");
+ Assert.state(accessTokenStore.size() == authenticationToAccessTokenStore.size(),
+ "Inconsistent token store state");
+ Assert.state(accessTokenStore.size() <= authenticationStore.size(), "Inconsistent authentication store state");
return accessTokenStore.size();
}
-
+
public int getRefreshTokenCount() {
- Assert.state(refreshTokenStore.size()==refreshTokenToAcessTokenStore.size(), "Inconsistent refresh token store state");
+ Assert.state(refreshTokenStore.size() == refreshTokenToAcessTokenStore.size(),
+ "Inconsistent refresh token store state");
return accessTokenStore.size();
}
-
+
public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
return authenticationToAccessTokenStore.get(authenticationKeyGenerator.extractKey(authentication));
}
@@ -55,9 +88,16 @@ public OAuth2Authentication readAuthentication(OAuth2AccessToken token) {
}
public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
+ if (this.flushCounter.incrementAndGet() >= this.flushInterval) {
+ flush();
+ this.flushCounter.set(0);
+ }
this.accessTokenStore.put(token.getValue(), token);
this.authenticationStore.put(token.getValue(), authentication);
this.authenticationToAccessTokenStore.put(authenticationKeyGenerator.extractKey(authentication), token);
+ if (token.getExpiration() != null) {
+ this.expiryQueue.put(new TokenExpiry(token.getValue(), token.getExpiration()));
+ }
if (token.getRefreshToken() != null && token.getRefreshToken().getValue() != null) {
this.refreshTokenToAcessTokenStore.put(token.getRefreshToken().getValue(), token.getValue());
this.accessTokenToRefreshTokenStore.put(token.getValue(), token.getRefreshToken().getValue());
@@ -71,12 +111,12 @@ public OAuth2AccessToken readAccessToken(String tokenValue) {
public void removeAccessToken(String tokenValue) {
this.accessTokenStore.remove(tokenValue);
String refresh = this.accessTokenToRefreshTokenStore.remove(tokenValue);
- if (refresh!=null) {
+ if (refresh != null) {
this.refreshTokenStore.remove(tokenValue);
this.refreshTokenToAcessTokenStore.remove(tokenValue);
}
OAuth2Authentication authentication = this.authenticationStore.remove(tokenValue);
- if (authentication!=null) {
+ if (authentication != null) {
this.authenticationToAccessTokenStore.remove(authenticationKeyGenerator.extractKey(authentication));
}
}
@@ -106,4 +146,42 @@ public void removeAccessTokenUsingRefreshToken(String refreshToken) {
this.authenticationStore.remove(accessToken);
}
}
+
+ private void flush() {
+ TokenExpiry expiry = expiryQueue.poll();
+ while (expiry != null) {
+ removeAccessToken(expiry.getValue());
+ expiry = expiryQueue.poll();
+ }
+ }
+
+ private static class TokenExpiry implements Delayed {
+
+ private final long expiry;
+
+ private final String value;
+
+ public TokenExpiry(String value, Date date) {
+ this.value = value;
+ this.expiry = date.getTime();
+ }
+
+ public int compareTo(Delayed other) {
+ if (this == other) {
+ return 0;
+ }
+ long diff = getDelay(TimeUnit.MILLISECONDS) - other.getDelay(TimeUnit.MILLISECONDS);
+ return (diff == 0 ? 0 : ((diff < 0) ? -1 : 1));
+ }
+
+ public long getDelay(TimeUnit unit) {
+ return expiry - System.currentTimeMillis();
+ }
+
+ public String getValue() {
+ return value;
+ }
+
+ }
+
}
View
1  ...uth2/src/main/java/org/springframework/security/oauth2/provider/token/JdbcTokenStore.java
@@ -253,4 +253,5 @@ public void setSelectAccessTokenFromAuthenticationSql(String selectAccessTokenFr
public void setDeleteAccessTokenFromRefreshTokenSql(String deleteAccessTokenFromRefreshTokenSql) {
this.deleteAccessTokenFromRefreshTokenSql = deleteAccessTokenFromRefreshTokenSql;
}
+
}
View
1  ...y-oauth2/src/main/java/org/springframework/security/oauth2/provider/token/TokenStore.java
@@ -86,4 +86,5 @@
* @return the access token or null if there was none
*/
OAuth2AccessToken getAccessToken(OAuth2Authentication authentication);
+
}
View
39 .../test/java/org/springframework/security/oauth2/provider/token/TestInMemoryTokenStore.java
@@ -1,10 +1,18 @@
package org.springframework.security.oauth2.provider.token;
+import static org.junit.Assert.assertEquals;
+
+import java.util.Date;
+
import org.junit.Before;
+import org.junit.Test;
+import org.springframework.security.oauth2.common.OAuth2AccessToken;
+import org.springframework.security.oauth2.provider.AuthorizationRequest;
+import org.springframework.security.oauth2.provider.OAuth2Authentication;
/**
* @author Dave Syer
- *
+ *
*/
public class TestInMemoryTokenStore extends TestTokenStoreBase {
@@ -20,4 +28,33 @@ public void createStore() {
tokenStore = new InMemoryTokenStore();
}
+ @Test
+ public void testTokenCountConsistency() throws Exception {
+ for (int i = 0; i <= 10; i++) {
+ OAuth2Authentication expectedAuthentication = new OAuth2Authentication(new AuthorizationRequest("id"+i, null,
+ null, null), new TestAuthentication("test", false));
+ OAuth2AccessToken expectedOAuth2AccessToken = new OAuth2AccessToken("testToken"+i);
+ expectedOAuth2AccessToken.setExpiration(new Date(System.currentTimeMillis() - 1000));
+ if (i>1) {
+ assertEquals(i, getTokenStore().getAccessTokenCount());
+ }
+ getTokenStore().storeAccessToken(expectedOAuth2AccessToken, expectedAuthentication);
+ }
+ }
+
+ @Test
+ public void testAutoFlush() throws Exception {
+ getTokenStore().setFlushInterval(3);
+ for (int i = 0; i <= 10; i++) {
+ OAuth2Authentication expectedAuthentication = new OAuth2Authentication(new AuthorizationRequest("id"+i, null,
+ null, null), new TestAuthentication("test", false));
+ OAuth2AccessToken expectedOAuth2AccessToken = new OAuth2AccessToken("testToken"+i);
+ expectedOAuth2AccessToken.setExpiration(new Date(System.currentTimeMillis() - 1000));
+ if (i>2) {
+ assertEquals((i%3+1), getTokenStore().getAccessTokenCount());
+ }
+ getTokenStore().storeAccessToken(expectedOAuth2AccessToken, expectedAuthentication);
+ }
+ }
+
}
Please sign in to comment.
Something went wrong with that request. Please try again.