-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
HibernateMultiTenantConnectionProvider.java
127 lines (112 loc) · 5.81 KB
/
HibernateMultiTenantConnectionProvider.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package io.quarkus.hibernate.orm.runtime.tenant;
import java.lang.annotation.Annotation;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.enterprise.context.RequestScoped;
import javax.enterprise.context.SessionScoped;
import javax.enterprise.inject.Default;
import org.hibernate.engine.jdbc.connections.spi.AbstractMultiTenantConnectionProvider;
import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.jboss.logging.Logger;
import io.quarkus.arc.Arc;
import io.quarkus.arc.InstanceHandle;
import io.quarkus.arc.ManagedContext;
import io.quarkus.hibernate.orm.PersistenceUnit.PersistenceUnitLiteral;
import io.quarkus.hibernate.orm.runtime.PersistenceUnitUtil;
/**
 * Maps from the Quarkus {@link TenantConnectionResolver} to the Hibernate
 * {@link org.hibernate.engine.jdbc.connections.spi.MultiTenantConnectionProvider} model.
 * <p>
 * Connection providers are resolved lazily, once per tenant identifier, and cached in this instance.
 *
 * @author Michael Schnell
 */
public final class HibernateMultiTenantConnectionProvider extends AbstractMultiTenantConnectionProvider {

    private static final Logger LOG = Logger.getLogger(HibernateMultiTenantConnectionProvider.class);

    private final String persistenceUnitName;

    /** Resolved connection providers, keyed by tenant identifier. */
    private final Map<String, ConnectionProvider> providerMap = new ConcurrentHashMap<>();

    /**
     * Creates a multi-tenant connection provider for one persistence unit.
     *
     * @param persistenceUnitName Name of the persistence unit whose resolver beans are looked up.
     */
    public HibernateMultiTenantConnectionProvider(String persistenceUnitName) {
        this.persistenceUnitName = persistenceUnitName;
    }

    /**
     * Returns the connection provider of the tenant reported by {@code TenantResolver.getDefaultTenantId()}.
     * <p>
     * If the tenant resolver is {@code @RequestScoped} or {@code @SessionScoped} and no request context is
     * active, a request context is activated for the duration of the call and terminated afterwards.
     *
     * @return Connection provider for the default tenant.
     * @throws IllegalStateException if no tenant resolver is available for the persistence unit, or if it
     *         returns a {@code null} default tenant id.
     */
    @Override
    protected ConnectionProvider getAnyConnectionProvider() {
        String tenantId;
        // Closing the handle destroys a @Dependent resolver; it is a no-op for normal-scoped beans.
        try (InstanceHandle<TenantResolver> tenantResolver = tenantResolver(persistenceUnitName)) {
            // Activate RequestScope if the TenantResolver is @RequestScoped or @SessionScoped
            ManagedContext requestContext = Arc.container().requestContext();
            Class<? extends Annotation> tenantScope = tenantResolver.getBean().getScope();
            boolean requiresRequestScope = (tenantScope == RequestScoped.class || tenantScope == SessionScoped.class);
            boolean forceRequestActivation = (!requestContext.isActive() && requiresRequestScope);
            try {
                if (forceRequestActivation) {
                    requestContext.activate();
                }
                tenantId = tenantResolver.get().getDefaultTenantId();
            } finally {
                if (forceRequestActivation) {
                    // terminate() destroys the request-scoped instances created while the context was active
                    // and then deactivates it; a bare deactivate() would leave them alive and skip @PreDestroy.
                    requestContext.terminate();
                }
            }
        }
        if (tenantId == null) {
            throw new IllegalStateException("Method 'TenantResolver.getDefaultTenantId()' returned a null value. "
                    + "This violates the contract of the interface!");
        }
        return selectConnectionProvider(tenantId);
    }

    /**
     * Returns the cached connection provider for a tenant, resolving and caching it on first use.
     *
     * @param tenantIdentifier Tenant whose connection provider is requested.
     * @return Connection provider for the tenant, never {@code null}.
     * @throws IllegalStateException if no {@link TenantConnectionResolver} is available or it returns {@code null}.
     */
    @Override
    protected ConnectionProvider selectConnectionProvider(final String tenantIdentifier) {
        LOG.debugv("selectConnectionProvider(persistenceUnitName={0}, tenantIdentifier={1})", persistenceUnitName,
                tenantIdentifier);
        // computeIfAbsent makes lookup-and-insert atomic, so concurrent first requests for the same tenant
        // resolve a single ConnectionProvider instead of creating (and discarding) duplicates.
        return providerMap.computeIfAbsent(tenantIdentifier,
                tenantId -> resolveConnectionProvider(persistenceUnitName, tenantId));
    }

    /**
     * Asks the persistence unit's {@link TenantConnectionResolver} bean for a tenant's connection provider.
     *
     * @param persistenceUnitName Persistence unit used to select the resolver bean.
     * @param tenantIdentifier Tenant to resolve.
     * @return Connection provider returned by the resolver, never {@code null}.
     * @throws IllegalStateException if no resolver bean is available or it returns {@code null}.
     */
    private static ConnectionProvider resolveConnectionProvider(String persistenceUnitName, String tenantIdentifier) {
        LOG.debugv("resolveConnectionProvider(persistenceUnitName={0}, tenantIdentifier={1})", persistenceUnitName,
                tenantIdentifier);
        InstanceHandle<TenantConnectionResolver> instance = lookup(TenantConnectionResolver.class,
                persistenceUnitName);
        if (!instance.isAvailable()) {
            throw new IllegalStateException(
                    String.format(
                            Locale.ROOT, "No instance of %1$s was found for persistence unit %2$s. "
                                    + "You need to create an implementation for this interface to allow resolving the current tenant connection.",
                            TenantConnectionResolver.class.getSimpleName(), persistenceUnitName));
        }
        TenantConnectionResolver resolver = instance.get();
        ConnectionProvider cp = resolver.resolve(tenantIdentifier);
        if (cp == null) {
            throw new IllegalStateException("Method 'TenantConnectionResolver."
                    + "resolve(String)' returned a null value. This violates the contract of the interface!");
        }
        return cp;
    }

    /**
     * Retrieves the tenant resolver or fails if it is not available.
     *
     * @param persistenceUnitName Persistence unit used to select the resolver bean.
     * @return Current tenant resolver.
     * @throws IllegalStateException if no tenant resolver bean is available for the persistence unit.
     */
    private static InstanceHandle<TenantResolver> tenantResolver(String persistenceUnitName) {
        InstanceHandle<TenantResolver> resolverInstance = lookup(TenantResolver.class, persistenceUnitName);
        if (!resolverInstance.isAvailable()) {
            throw new IllegalStateException(String.format(Locale.ROOT,
                    "No instance of %1$s was found for persistence unit %2$s. "
                            + "You need to create an implementation for this interface to allow resolving the current tenant identifier.",
                    TenantResolver.class.getSimpleName(), persistenceUnitName));
        }
        return resolverInstance;
    }

    /**
     * Looks up a bean of the given type using the qualifier for the persistence unit.
     * The qualifier is {@code @Default} for the default persistence unit and {@code @PersistenceUnit(name)}
     * otherwise.
     *
     * @param type Bean type to look up.
     * @param persistenceUnitName Persistence unit whose qualifier is used.
     * @return Instance handle; callers must check {@code isAvailable()}.
     */
    private static <T> InstanceHandle<T> lookup(Class<T> type, String persistenceUnitName) {
        Annotation qualifier = PersistenceUnitUtil.isDefaultPersistenceUnit(persistenceUnitName)
                ? Default.Literal.INSTANCE
                : new PersistenceUnitLiteral(persistenceUnitName);
        return Arc.container().instance(type, qualifier);
    }
}