Skip to content

Commit

Permalink
hold state in a specific context
Browse files Browse the repository at this point in the history
  • Loading branch information
zawn committed Apr 30, 2016
1 parent aaa8861 commit 8bc79a2
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 21 deletions.
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ protected OAuth1Token getAccessToken(final OAuthCredentials credentials) throws
final String message = "Token received: " + token + " is different from saved token: " + savedToken; final String message = "Token received: " + token + " is different from saved token: " + savedToken;
throw new OAuthCredentialsException(message); throw new OAuthCredentialsException(message);
} }
final OAuth1Token accessToken = ((OAuth10aService) this.service).getAccessToken(tokenRequest, verifier); final OAuth1Token accessToken = this.service.getAccessToken(tokenRequest, verifier);
logger.debug("accessToken: {}", accessToken); logger.debug("accessToken: {}", accessToken);
return accessToken; return accessToken;
} }
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ protected OAuth2AccessToken getAccessToken(final OAuthCredentials credentials) t
// no request token saved in context and no token (OAuth v2.0) // no request token saved in context and no token (OAuth v2.0)
final String verifier = credentials.getVerifier(); final String verifier = credentials.getVerifier();
logger.debug("verifier: {}", verifier); logger.debug("verifier: {}", verifier);
final OAuth2AccessToken accessToken = ((OAuth20Service) this.service).getAccessToken(verifier); final OAuth2AccessToken accessToken = this.service.getAccessToken(verifier);
logger.debug("accessToken: {}", accessToken); logger.debug("accessToken: {}", accessToken);
return accessToken; return accessToken;
} }
Expand Down
Original file line number Original file line Diff line number Diff line change
@@ -1,5 +1,8 @@
package org.pac4j.oauth.client; package org.pac4j.oauth.client;


import com.github.scribejava.core.model.OAuthConfig;
import com.github.scribejava.core.model.SignatureType;
import com.github.scribejava.core.oauth.OAuth20Service;
import org.pac4j.core.context.WebContext; import org.pac4j.core.context.WebContext;
import org.pac4j.core.exception.RequiresHttpAction; import org.pac4j.core.exception.RequiresHttpAction;
import org.pac4j.core.util.CommonHelper; import org.pac4j.core.util.CommonHelper;
Expand All @@ -17,10 +20,11 @@ public abstract class BaseOAuth20StateClient<U extends OAuth20Profile> extends B


private static final String STATE_PARAMETER = "#oauth20StateParameter"; private static final String STATE_PARAMETER = "#oauth20StateParameter";


private String stateData;

@Override @Override
protected String getStateParameter(final WebContext context) { protected String getStateParameter(final WebContext context) {
final String stateParameter; final String stateParameter;
final String stateData = getState();
if (CommonHelper.isNotBlank(stateData)) { if (CommonHelper.isNotBlank(stateData)) {
stateParameter = stateData; stateParameter = stateData;
} else { } else {
Expand All @@ -30,13 +34,24 @@ protected String getStateParameter(final WebContext context) {
} }


@Override @Override
protected void internalInit(WebContext context) { protected OAuthConfig buildOAuthConfig(WebContext context) {
final String state = getStateParameter(context);
// the state is held in a specific context.
context.setSessionAttribute(getName() + STATE_PARAMETER, state);
return new OAuthConfig(this.getKey(), this.getSecret(), computeFinalCallbackUrl(context),
SignatureType.Header, getOAuthScope(), null, this.getConnectTimeout(), this.getReadTimeout(), hasOAuthGrantType() ? "authorization_code" : null, state, this.getResponseType());
}

@Override
protected String retrieveAuthorizationUrl(final WebContext context) throws RequiresHttpAction {
// create a specific configuration with state // create a specific configuration with state
this.setState(getStateParameter(context)); final OAuthConfig config = buildOAuthConfig(context);
CommonHelper.assertNotNull("state", this.getState());
// save state // create a specific service
context.setSessionAttribute(getName() + STATE_PARAMETER, this.getState()); final OAuth20Service newService = getApi().createService(config);
super.internalInit(context); final String authorizationUrl = newService.getAuthorizationUrl();
logger.debug("authorizationUrl: {}", authorizationUrl);
return authorizationUrl;
} }


@Override @Override
Expand All @@ -61,4 +76,11 @@ protected OAuthCredentials getOAuthCredentials(final WebContext context) throws
return super.getOAuthCredentials(context); return super.getOAuthCredentials(context);
} }


public String getStateData() {
return stateData;
}

public void setStateData(String stateData) {
this.stateData = stateData;
}
} }
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ public abstract class BaseOAuthClient<U extends OAuth20Profile, S extends OAuthS


private int readTimeout = HttpConstants.DEFAULT_READ_TIMEOUT; private int readTimeout = HttpConstants.DEFAULT_READ_TIMEOUT;


private String state = null;

private String responseType = null; private String responseType = null;


@Override @Override
Expand All @@ -67,7 +65,7 @@ protected void internalInit(final WebContext context) {
*/ */
protected OAuthConfig buildOAuthConfig(final WebContext context) { protected OAuthConfig buildOAuthConfig(final WebContext context) {
return new OAuthConfig(this.key, this.secret, computeFinalCallbackUrl(context), return new OAuthConfig(this.key, this.secret, computeFinalCallbackUrl(context),
SignatureType.Header, getOAuthScope(), null, this.connectTimeout, this.readTimeout, hasOAuthGrantType() ? "authorization_code" : null, this.state, this.responseType); SignatureType.Header, getOAuthScope(), null, this.connectTimeout, this.readTimeout, hasOAuthGrantType() ? "authorization_code" : null, null, this.responseType);
} }


/** /**
Expand Down Expand Up @@ -298,14 +296,6 @@ public void setTokenAsHeader(boolean tokenAsHeader) {
this.tokenAsHeader = tokenAsHeader; this.tokenAsHeader = tokenAsHeader;
} }


public String getState() {
return state;
}

public void setState(String state) {
this.state = state;
}

public String getResponseType() { public String getResponseType() {
return responseType; return responseType;
} }
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
import org.pac4j.core.util.TestsConstants; import org.pac4j.core.util.TestsConstants;
import org.pac4j.oauth.credentials.OAuthCredentials; import org.pac4j.oauth.credentials.OAuthCredentials;


import java.io.UnsupportedEncodingException;
import java.net.MalformedURLException; import java.net.MalformedURLException;
import java.net.URL; import java.net.URL;
import java.net.URLDecoder;
import java.util.LinkedHashMap;
import java.util.Map;


import static org.junit.Assert.*; import static org.junit.Assert.*;


Expand Down Expand Up @@ -50,11 +54,68 @@ public void testOk() throws RequiresHttpAction {
public void testState() throws MalformedURLException, RequiresHttpAction { public void testState() throws MalformedURLException, RequiresHttpAction {
BaseOAuth20StateClient client = new FacebookClient(KEY, SECRET); BaseOAuth20StateClient client = new FacebookClient(KEY, SECRET);
client.setCallbackUrl(CALLBACK_URL); client.setCallbackUrl(CALLBACK_URL);
client.setState("OK"); client.setStateData("OK");
URL url = new URL(client.getRedirectAction(MockWebContext.create()).getLocation()); URL url = new URL(client.getRedirectAction(MockWebContext.create()).getLocation());
assertTrue(url.getQuery().contains("state=OK")); assertTrue(url.getQuery().contains("state=OK"));
} }


@Test
public void testStateMatch() throws MalformedURLException, RequiresHttpAction, UnsupportedEncodingException {
BaseOAuth20StateClient client = new FacebookClient(KEY, SECRET);
client.setCallbackUrl(CALLBACK_URL);
final MockWebContext mockWebContext = MockWebContext.create();
URL url = new URL(client.getRedirectAction(mockWebContext).getLocation());
final Map<String, String> stringMap = splitQuery(url);
assertNotNull(stringMap.get("state"));
try {
client.getCredentials(MockWebContext.create());
} catch (Exception e) {
assertTrue(e.getMessage().contains("Missing state parameter"));
}
mockWebContext.addRequestParameter("state", stringMap.get("state"));
mockWebContext.addRequestParameter("code", "mockcode");
client.getCredentials(mockWebContext);
}

@Test
public void testSetState() throws MalformedURLException, RequiresHttpAction, UnsupportedEncodingException {
BaseOAuth20StateClient client = new FacebookClient(KEY, SECRET);
client.setCallbackUrl(CALLBACK_URL);
client.setStateData("oldstate");
final MockWebContext mockWebContext = MockWebContext.create();
URL url = new URL(client.getRedirectAction(mockWebContext).getLocation());
final Map<String, String> stringMap = splitQuery(url);
assertEquals(stringMap.get("state"), "oldstate");
URL url2 = new URL(client.getRedirectAction(mockWebContext).getLocation());
final Map<String, String> stringMap2 = splitQuery(url2);
assertEquals(stringMap2.get("state"), "oldstate");
}

@Test
public void testStateRandom() throws MalformedURLException, RequiresHttpAction, UnsupportedEncodingException {
BaseOAuth20StateClient client = new FacebookClient(KEY, SECRET);
client.setCallbackUrl(CALLBACK_URL);
URL url = new URL(client.getRedirectAction(MockWebContext.create()).getLocation());
final Map<String, String> stringMap = splitQuery(url);
assertNotNull(stringMap.get("state"));

URL url2 = new URL(client.getRedirectAction(MockWebContext.create()).getLocation());
final Map<String, String> stringMap2 = splitQuery(url2);
assertNotNull(stringMap2.get("state"));
assertNotEquals(stringMap.get("state"), stringMap2.get("state"));
}

public static Map<String, String> splitQuery(URL url) throws UnsupportedEncodingException {
Map<String, String> query_pairs = new LinkedHashMap<String, String>();
String query = url.getQuery();
String[] pairs = query.split("&");
for (String pair : pairs) {
int idx = pair.indexOf("=");
query_pairs.put(URLDecoder.decode(pair.substring(0, idx), "UTF-8"), URLDecoder.decode(pair.substring(idx + 1), "UTF-8"));
}
return query_pairs;
}

@Test @Test
public void testGetRedirectionGithub() throws RequiresHttpAction { public void testGetRedirectionGithub() throws RequiresHttpAction {
String url = getClient().getRedirectAction(MockWebContext.create()).getLocation(); String url = getClient().getRedirectAction(MockWebContext.create()).getLocation();
Expand Down

0 comments on commit 8bc79a2

Please sign in to comment.