Skip to content

Commit

Permalink
Introduce SessionIdGenerationStrategy
Browse files Browse the repository at this point in the history
Closes gh-11
  • Loading branch information
marcusdacoregio committed Jul 12, 2023
1 parent 2e12753 commit a4e393e
Show file tree
Hide file tree
Showing 39 changed files with 1,074 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,27 @@ public final class MapSession implements Session, Serializable {
*/
private Duration maxInactiveInterval = DEFAULT_MAX_INACTIVE_INTERVAL;

private transient SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy
.getInstance();

/**
* Creates a new instance with a secure randomly generated identifier.
*/
public MapSession() {
this(generateId());
}

/**
* Creates a new instance using the specified {@link SessionIdGenerationStrategy} to
* generate the session id.
* @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use.
* @since 3.2
*/
public MapSession(SessionIdGenerationStrategy sessionIdGenerationStrategy) {
this(sessionIdGenerationStrategy.generate());
this.sessionIdGenerationStrategy = sessionIdGenerationStrategy;
}

/**
* Creates a new instance with the specified id. This is preferred to the default
* constructor when the id is known to prevent unnecessary consumption on entropy
Expand Down Expand Up @@ -141,7 +155,7 @@ public String getOriginalId() {

@Override
public String changeSessionId() {
String changedId = generateId();
String changedId = this.sessionIdGenerationStrategy.generate();
setId(changedId);
return changedId;
}
Expand Down Expand Up @@ -232,6 +246,16 @@ private static String generateId() {
return UUID.randomUUID().toString();
}

/**
* Sets the {@link SessionIdGenerationStrategy} to use when generating a new session
* id.
* @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use.
* @since 3.2
*/
public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) {
this.sessionIdGenerationStrategy = sessionIdGenerationStrategy;
}

private static final long serialVersionUID = 7160779239673823561L;

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ public class MapSessionRepository implements SessionRepository<MapSession> {

private final Map<String, Session> sessions;

private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance();

/**
* Creates a new instance backed by the provided {@link java.util.Map}. This allows
* injecting a distributed {@link java.util.Map}.
Expand Down Expand Up @@ -71,7 +73,9 @@ public void save(MapSession session) {
if (!session.getId().equals(session.getOriginalId())) {
this.sessions.remove(session.getOriginalId());
}
this.sessions.put(session.getId(), new MapSession(session));
MapSession saved = new MapSession(session);
saved.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy);
this.sessions.put(session.getId(), saved);
}

@Override
Expand All @@ -84,7 +88,9 @@ public MapSession findById(String id) {
deleteById(saved.getId());
return null;
}
return new MapSession(saved);
MapSession result = new MapSession(saved);
result.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy);
return result;
}

@Override
Expand All @@ -94,9 +100,14 @@ public void deleteById(String id) {

@Override
public MapSession createSession() {
MapSession result = new MapSession();
MapSession result = new MapSession(this.sessionIdGenerationStrategy);
result.setMaxInactiveInterval(this.defaultMaxInactiveInterval);
return result;
}

public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) {
Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null");
this.sessionIdGenerationStrategy = sessionIdGenerationStrategy;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ public class ReactiveMapSessionRepository implements ReactiveSessionRepository<M

private final Map<String, Session> sessions;

private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance();

/**
* Creates a new instance backed by the provided {@link Map}. This allows injecting a
* distributed {@link Map}.
Expand Down Expand Up @@ -84,6 +86,7 @@ public Mono<MapSession> findById(String id) {
return Mono.defer(() -> Mono.justOrEmpty(this.sessions.get(id))
.filter((session) -> !session.isExpired())
.map(MapSession::new)
.doOnNext((session) -> session.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy))
.switchIfEmpty(deleteById(id).then(Mono.empty())));
// @formatter:on
}
Expand All @@ -96,10 +99,21 @@ public Mono<Void> deleteById(String id) {
@Override
public Mono<MapSession> createSession() {
return Mono.defer(() -> {
MapSession result = new MapSession();
MapSession result = new MapSession(this.sessionIdGenerationStrategy);
result.setMaxInactiveInterval(this.defaultMaxInactiveInterval);
return Mono.just(result);
});
}

/**
* Sets the {@link SessionIdGenerationStrategy} to use.
* @param sessionIdGenerationStrategy the non-null {@link SessionIdGenerationStrategy}
* to use
* @since 3.2
*/
public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) {
Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null");
this.sessionIdGenerationStrategy = sessionIdGenerationStrategy;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright 2014-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.session;

import org.springframework.lang.NonNull;

/**
* An interface for specifying a strategy for generating session identifiers.
*
* @author Marcus da Coregio
* @since 3.2
*/
public interface SessionIdGenerationStrategy {

@NonNull
String generate();

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright 2014-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.session;

import java.util.UUID;

import org.springframework.lang.NonNull;

/**
* A {@link SessionIdGenerationStrategy} that generates a random UUID to be used as the
* session id.
*
* @author Marcus da Coregio
* @since 3.2
*/
public final class UuidSessionIdGenerationStrategy implements SessionIdGenerationStrategy {

private static final UuidSessionIdGenerationStrategy INSTANCE = new UuidSessionIdGenerationStrategy();

private UuidSessionIdGenerationStrategy() {
}

@Override
@NonNull
public String generate() {
return UUID.randomUUID().toString();
}

/**
* Returns the singleton instance of {@link UuidSessionIdGenerationStrategy}.
* @return the singleton instance of {@link UuidSessionIdGenerationStrategy}
*/
public static UuidSessionIdGenerationStrategy getInstance() {
return INSTANCE;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.time.Duration;
import java.time.Instant;
import java.util.Set;
import java.util.UUID;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -42,6 +43,19 @@ void constructorNullSession() {
.withMessage("session cannot be null");
}

@Test
void constructorWhenSessionIdGenerationStrategyThenUsesStrategy() {
MapSession session = new MapSession(new FixedSessionIdGenerationStrategy("my-id"));
assertThat(session.getId()).isEqualTo("my-id");
}

@Test
void constructorWhenDefaultThenUuid() {
String id = this.session.getId();
UUID uuid = UUID.fromString(id);
assertThat(uuid).isNotNull();
}

@Test
void getAttributeWhenNullThenNull() {
String result = this.session.getAttribute("attrName");
Expand Down Expand Up @@ -143,6 +157,41 @@ void getAttributeNamesAndRemove() {
assertThat(this.session.getAttributeNames()).isEmpty();
}

@Test
void changeSessionIdWhenSessionIdStrategyThenUsesStrategy() {
MapSession session = new MapSession(new IncrementalSessionIdGenerationStrategy());
String idBeforeChange = session.getId();
String idAfterChange = session.changeSessionId();
assertThat(idBeforeChange).isEqualTo("1");
assertThat(idAfterChange).isEqualTo("2");
}

static class FixedSessionIdGenerationStrategy implements SessionIdGenerationStrategy {

private final String id;

FixedSessionIdGenerationStrategy(String id) {
this.id = id;
}

@Override
public String generate() {
return this.id;
}

}

static class IncrementalSessionIdGenerationStrategy implements SessionIdGenerationStrategy {

private int counter = 1;

@Override
public String generate() {
return String.valueOf(this.counter++);
}

}

static class CustomSession implements Session {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,31 @@ void getAttributeNamesAndRemove() {
assertThat(session.getAttributeNames()).isEmpty();
}

@Test
void createSessionWhenSessionIdGenerationStrategyThenUses() {
this.repository.setSessionIdGenerationStrategy(() -> "test");
MapSession session = this.repository.createSession().block();
assertThat(session.getId()).isEqualTo("test");
assertThat(session.changeSessionId()).isEqualTo("test");
}

@Test
void setSessionIdGenerationStrategyWhenNullThenThrowsException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.repository.setSessionIdGenerationStrategy(null))
.withMessage("sessionIdGenerationStrategy cannot be null");
}

@Test
void findByIdWhenChangeSessionIdThenUsesSessionIdGenerationStrategy() {
this.repository.setSessionIdGenerationStrategy(() -> "test");

MapSession session = this.repository.createSession().block();
this.repository.save(session).block();

MapSession savedSession = this.repository.findById("test").block();

assertThat(savedSession.getId()).isEqualTo("test");
assertThat(savedSession.changeSessionId()).isEqualTo("test");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import org.springframework.lang.Nullable;
import org.springframework.session.FindByIndexNameSessionRepository;
import org.springframework.session.MapSession;
import org.springframework.session.SessionIdGenerationStrategy;
import org.springframework.session.UuidSessionIdGenerationStrategy;
import org.springframework.session.events.SessionCreatedEvent;
import org.springframework.session.events.SessionDeletedEvent;
import org.springframework.session.events.SessionExpiredEvent;
Expand Down Expand Up @@ -81,14 +83,16 @@ public class MongoIndexedSessionRepository

private ApplicationEventPublisher eventPublisher;

private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance();

public MongoIndexedSessionRepository(MongoOperations mongoOperations) {
this.mongoOperations = mongoOperations;
}

@Override
public MongoSession createSession() {

MongoSession session = new MongoSession();
MongoSession session = new MongoSession(this.sessionIdGenerationStrategy);

session.setMaxInactiveInterval(this.defaultMaxInactiveInterval);

Expand Down Expand Up @@ -116,10 +120,13 @@ public MongoSession findById(String id) {

MongoSession session = MongoSessionUtils.convertToSession(this.mongoSessionConverter, sessionWrapper);

if (session != null && session.isExpired()) {
publishEvent(new SessionExpiredEvent(this, session));
deleteById(id);
return null;
if (session != null) {
if (session.isExpired()) {
publishEvent(new SessionExpiredEvent(this, session));
deleteById(id);
return null;
}
session.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy);
}

return session;
Expand All @@ -140,6 +147,7 @@ public Map<String, MongoSession> findByIndexNameAndIndexValue(String indexName,
.map((query) -> this.mongoOperations.find(query, Document.class, this.collectionName))
.orElse(Collections.emptyList()).stream()
.map((dbSession) -> MongoSessionUtils.convertToSession(this.mongoSessionConverter, dbSession))
.peek((session) -> session.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy))
.collect(Collectors.toMap(MongoSession::getId, (mapSession) -> mapSession));
}

Expand Down Expand Up @@ -216,4 +224,14 @@ public void setMongoSessionConverter(final AbstractMongoSessionConverter mongoSe
this.mongoSessionConverter = mongoSessionConverter;
}

/**
* Set the {@link SessionIdGenerationStrategy} to use to generate session ids.
* @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use
* @since 3.2
*/
public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) {
Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null");
this.sessionIdGenerationStrategy = sessionIdGenerationStrategy;
}

}

0 comments on commit a4e393e

Please sign in to comment.