Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusdacoregio committed Jun 12, 2023
1 parent ba26597 commit 0950932
Show file tree
Hide file tree
Showing 35 changed files with 695 additions and 258 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import java.util.Set;
import java.util.UUID;

import org.springframework.util.Assert;

/**
* <p>
* A {@link Session} implementation that is backed by a {@link java.util.Map}. The
Expand Down Expand Up @@ -74,13 +76,27 @@ public final class MapSession implements Session, Serializable {
*/
private Duration maxInactiveInterval = DEFAULT_MAX_INACTIVE_INTERVAL;

private transient SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy
.getInstance();

/**
* Creates a new instance with a secure randomly generated identifier.
*/
public MapSession() {
this(generateId());
}

/**
* Creates a new instance using the specified {@link SessionIdGenerationStrategy} to
* generate the session id.
* @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use.
* @since 3.1
*/
public MapSession(SessionIdGenerationStrategy sessionIdGenerationStrategy) {
this(sessionIdGenerationStrategy.generate());
this.sessionIdGenerationStrategy = sessionIdGenerationStrategy;
}

/**
* Creates a new instance with the specified id. This is preferred to the default
* constructor when the id is known to prevent unnecessary consumption on entropy
Expand Down Expand Up @@ -141,7 +157,7 @@ public String getOriginalId() {

@Override
public String changeSessionId() {
String changedId = generateId();
String changedId = this.sessionIdGenerationStrategy.generate();
setId(changedId);
return changedId;
}
Expand Down Expand Up @@ -232,6 +248,17 @@ private static String generateId() {
return UUID.randomUUID().toString();
}

/**
* Sets the {@link SessionIdGenerationStrategy} to use when generating a new session
* id.
* @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use.
* @since 3.1
*/
public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) {
Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null");
this.sessionIdGenerationStrategy = sessionIdGenerationStrategy;
}

private static final long serialVersionUID = 7160779239673823561L;

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,11 @@
*/
public class MapSessionRepository implements SessionRepository<MapSession> {

private static final SessionIdGenerationStrategy DEFAULT_STRATEGY = UuidSessionIdGenerationStrategy.getInstance();

private Duration defaultMaxInactiveInterval = Duration.ofSeconds(MapSession.DEFAULT_MAX_INACTIVE_INTERVAL_SECONDS);

private final Map<String, Session> sessions;

private SessionIdGenerationStrategy sessionIdGenerationStrategy = DEFAULT_STRATEGY;
private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance();

/**
* Creates a new instance backed by the provided {@link java.util.Map}. This allows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,11 @@
*/
public class ReactiveMapSessionRepository implements ReactiveSessionRepository<MapSession> {

private static final SessionIdGenerationStrategy DEFAULT_STRATEGY = UuidSessionIdGenerationStrategy.getInstance();

private Duration defaultMaxInactiveInterval = Duration.ofSeconds(MapSession.DEFAULT_MAX_INACTIVE_INTERVAL_SECONDS);

private final Map<String, Session> sessions;

private SessionIdGenerationStrategy sessionIdGenerationStrategy = DEFAULT_STRATEGY;
private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance();

/**
* Creates a new instance backed by the provided {@link Map}. This allows injecting a
Expand Down Expand Up @@ -88,6 +86,7 @@ public Mono<MapSession> findById(String id) {
return Mono.defer(() -> Mono.justOrEmpty(this.sessions.get(id))
.filter((session) -> !session.isExpired())
.map(MapSession::new)
.doOnNext((session) -> session.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy))
.switchIfEmpty(deleteById(id).then(Mono.empty())));
// @formatter:on
}
Expand All @@ -100,7 +99,7 @@ public Mono<Void> deleteById(String id) {
@Override
public Mono<MapSession> createSession() {
return Mono.defer(() -> {
MapSession result = new MapSession(this.sessionIdGenerationStrategy.generate());
MapSession result = new MapSession(this.sessionIdGenerationStrategy);
result.setMaxInactiveInterval(this.defaultMaxInactiveInterval);
return Mono.just(result);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@

import org.springframework.lang.NonNull;

/**
* An interface for specifying a strategy for generating session identifiers.
*
* @author Marcus da Coregio
* @since 3.1
*/
public interface SessionIdGenerationStrategy {

@NonNull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@

import org.springframework.lang.NonNull;

/**
* A {@link SessionIdGenerationStrategy} that generates a random UUID to be used as the
* session id.
*
* @author Marcus da Coregio
* @since 3.1
*/
public final class UuidSessionIdGenerationStrategy implements SessionIdGenerationStrategy {

private static final UuidSessionIdGenerationStrategy INSTANCE = new UuidSessionIdGenerationStrategy();
Expand All @@ -33,6 +40,10 @@ public String generate() {
return UUID.randomUUID().toString();
}

/**
* Returns the singleton instance of {@link UuidSessionIdGenerationStrategy}.
* @return the singleton instance of {@link UuidSessionIdGenerationStrategy}
*/
public static UuidSessionIdGenerationStrategy getInstance() {
return INSTANCE;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.time.Duration;
import java.time.Instant;
import java.util.Set;
import java.util.UUID;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -42,6 +43,19 @@ void constructorNullSession() {
.withMessage("session cannot be null");
}

@Test
void constructorWhenSessionIdGenerationStrategyThenUsesStrategy() {
MapSession session = new MapSession(new FixedSessionIdGenerationStrategy("my-id"));
assertThat(session.getId()).isEqualTo("my-id");
}

@Test
void constructorWhenDefaultThenUuid() {
String id = this.session.getId();
UUID uuid = UUID.fromString(id);
assertThat(uuid).isNotNull();
}

@Test
void getAttributeWhenNullThenNull() {
String result = this.session.getAttribute("attrName");
Expand Down Expand Up @@ -143,6 +157,41 @@ void getAttributeNamesAndRemove() {
assertThat(this.session.getAttributeNames()).isEmpty();
}

@Test
void changeSessionIdWhenSessionIdStrategyThenUsesStrategy() {
MapSession session = new MapSession(new IncrementalSessionIdGenerationStrategy());
String idBeforeChange = session.getId();
String idAfterChange = session.changeSessionId();
assertThat(idBeforeChange).isEqualTo("1");
assertThat(idAfterChange).isEqualTo("2");
}

static class FixedSessionIdGenerationStrategy implements SessionIdGenerationStrategy {

private final String id;

FixedSessionIdGenerationStrategy(String id) {
this.id = id;
}

@Override
public String generate() {
return this.id;
}

}

static class IncrementalSessionIdGenerationStrategy implements SessionIdGenerationStrategy {

private int counter = 1;

@Override
public String generate() {
return String.valueOf(this.counter++);
}

}

static class CustomSession implements Session {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,31 @@ void getAttributeNamesAndRemove() {
assertThat(session.getAttributeNames()).isEmpty();
}

@Test
void createSessionWhenSessionIdGenerationStrategyThenUses() {
this.repository.setSessionIdGenerationStrategy(() -> "test");
MapSession session = this.repository.createSession().block();
assertThat(session.getId()).isEqualTo("test");
assertThat(session.changeSessionId()).isEqualTo("test");
}

@Test
void setSessionIdGenerationStrategyWhenNullThenThrowsException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.repository.setSessionIdGenerationStrategy(null))
.withMessage("sessionIdGenerationStrategy cannot be null");
}

@Test
void findByIdWhenChangeSessionIdThenUsesSessionIdGenerationStrategy() {
this.repository.setSessionIdGenerationStrategy(() -> "test");

MapSession session = this.repository.createSession().block();
this.repository.save(session).block();

MapSession savedSession = this.repository.findById("test").block();

assertThat(savedSession.getId()).isEqualTo("test");
assertThat(savedSession.changeSessionId()).isEqualTo("test");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ public class MongoIndexedSessionRepository

private static final Log logger = LogFactory.getLog(MongoIndexedSessionRepository.class);

private static final SessionIdGenerationStrategy DEFAULT_STRATEGY = UuidSessionIdGenerationStrategy.getInstance();

private final MongoOperations mongoOperations;

private Duration defaultMaxInactiveInterval = Duration.ofSeconds(MapSession.DEFAULT_MAX_INACTIVE_INTERVAL_SECONDS);
Expand All @@ -85,7 +83,7 @@ public class MongoIndexedSessionRepository

private ApplicationEventPublisher eventPublisher;

private SessionIdGenerationStrategy sessionIdGenerationStrategy = DEFAULT_STRATEGY;
private SessionIdGenerationStrategy sessionIdGenerationStrategy = UuidSessionIdGenerationStrategy.getInstance();

public MongoIndexedSessionRepository(MongoOperations mongoOperations) {
this.mongoOperations = mongoOperations;
Expand All @@ -94,27 +92,18 @@ public MongoIndexedSessionRepository(MongoOperations mongoOperations) {
@Override
public MongoSession createSession() {

String sessionId = this.sessionIdGenerationStrategy.generate();
MongoSession session = new MongoSession(sessionId);
MongoSession session = new MongoSession(this.sessionIdGenerationStrategy);

session.setMaxInactiveInterval(this.defaultMaxInactiveInterval);

publishEvent(new SessionCreatedEvent(this, session));

return wrapSession(session);
}

private SessionIdGenerationStrategyAwareMongoSession wrapSession(MongoSession session) {
return new SessionIdGenerationStrategyAwareMongoSession(session, this.sessionIdGenerationStrategy);
return session;
}

@Override
public void save(MongoSession session) {
MongoSession mongoSession = session;
if (session instanceof SessionIdGenerationStrategyAwareMongoSession awareMongoSession) {
mongoSession = awareMongoSession.getDelegate();
}
DBObject dbObject = MongoSessionUtils.convertToDBObject(this.mongoSessionConverter, mongoSession);
DBObject dbObject = MongoSessionUtils.convertToDBObject(this.mongoSessionConverter, session);
Assert.notNull(dbObject, "dbObject must not be null");
this.mongoOperations.save(dbObject, this.collectionName);
}
Expand All @@ -137,7 +126,7 @@ public MongoSession findById(String id) {
deleteById(id);
return null;
}
return wrapSession(session);
session.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy);
}

return session;
Expand All @@ -158,7 +147,8 @@ public Map<String, MongoSession> findByIndexNameAndIndexValue(String indexName,
.map((query) -> this.mongoOperations.find(query, Document.class, this.collectionName))
.orElse(Collections.emptyList()).stream()
.map((dbSession) -> MongoSessionUtils.convertToSession(this.mongoSessionConverter, dbSession))
.map(this::wrapSession).collect(Collectors.toMap(MongoSession::getId, (mapSession) -> mapSession));
.peek((session) -> session.setSessionIdGenerationStrategy(this.sessionIdGenerationStrategy))
.collect(Collectors.toMap(MongoSession::getId, (mapSession) -> mapSession));
}

@Override
Expand Down Expand Up @@ -234,6 +224,11 @@ public void setMongoSessionConverter(final AbstractMongoSessionConverter mongoSe
this.mongoSessionConverter = mongoSessionConverter;
}

/**
* Set the {@link SessionIdGenerationStrategy} to use to generate session ids.
* @param sessionIdGenerationStrategy the {@link SessionIdGenerationStrategy} to use
* @since 3.1
*/
public void setSessionIdGenerationStrategy(SessionIdGenerationStrategy sessionIdGenerationStrategy) {
Assert.notNull(sessionIdGenerationStrategy, "sessionIdGenerationStrategy cannot be null");
this.sessionIdGenerationStrategy = sessionIdGenerationStrategy;
Expand Down

0 comments on commit 0950932

Please sign in to comment.