Skip to content
Permalink
Browse files Browse the repository at this point in the history
fix some security bug (#103)
* fix: use hard-code secret

* feat: add driver class validate

* feat: optimize drvier resource code

* fix:ut failed
  • Loading branch information
vran-dev committed Apr 18, 2022
1 parent 6b6a7f4 commit ca22a8f
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 61 deletions.
3 changes: 2 additions & 1 deletion api/src/main/resources/application-local.properties
Expand Up @@ -9,4 +9,5 @@ spring.flyway.locations=classpath:db/migration
databasir.db.url=localhost:3306
databasir.db.username=root
databasir.db.password=123456
databasir.db.driver-directory=drivers
databasir.db.driver-directory=drivers
databasir.jwt.secret=DatabasirJwtSecret
3 changes: 2 additions & 1 deletion api/src/main/resources/application.properties
Expand Up @@ -11,4 +11,5 @@ spring.flyway.enabled=true
spring.flyway.baseline-on-migrate=true
spring.flyway.locations=classpath:db/migration
# driver directory
databasir.db.driver-directory=drivers
databasir.db.driver-directory=drivers
databasir.jwt.secret=${random.uuid}
Expand Up @@ -44,7 +44,7 @@ public enum DomainErrors implements DatabasirErrors {
DUPLICATE_COLUMN("A_10028", "重复的列"),
INVALID_MOCK_DATA_SCRIPT("A_10029", "不合法的表达式"),
CANNOT_DELETE_SELF("A_10030", "无法对自己执行删除账号操作"),
DRIVER_CLASS_NAME_OBTAIN_ERROR("A_10031", "获取驱动类名失败"),
DRIVER_CLASS_NOT_FOUND("A_10031", "获取驱动类名失败"),
;

private final String errCode;
Expand Down
Expand Up @@ -36,6 +36,7 @@ public class DatabaseTypeService {
private final DatabaseTypePojoConverter databaseTypePojoConverter;

public Integer create(DatabaseTypeCreateRequest request) {
driverResources.validateJar(request.getJdbcDriverFileUrl(), request.getJdbcDriverClassName());
DatabaseTypePojo pojo = databaseTypePojoConverter.of(request);
try {
return databaseTypeDao.insertAndReturnId(pojo);
Expand All @@ -50,7 +51,7 @@ public void update(DatabaseTypeUpdateRequest request) {
if (DatabaseTypes.has(data.getDatabaseType())) {
throw DomainErrors.MUST_NOT_MODIFY_SYSTEM_DEFAULT_DATABASE_TYPE.exception();
}

driverResources.validateJar(request.getJdbcDriverFileUrl(), request.getJdbcDriverClassName());
DatabaseTypePojo pojo = databaseTypePojoConverter.of(request);
try {
databaseTypeDao.updateById(pojo);
Expand All @@ -61,7 +62,7 @@ public void update(DatabaseTypeUpdateRequest request) {
// 名称修改,下载地址修改需要删除原有的 driver
if (!Objects.equals(request.getDatabaseType(), data.getDatabaseType())
|| !Objects.equals(request.getJdbcDriverFileUrl(), data.getJdbcDriverFileUrl())) {
driverResources.delete(data.getDatabaseType());
driverResources.deleteByDatabaseType(data.getDatabaseType());
}
});

Expand All @@ -73,7 +74,7 @@ public void deleteById(Integer id) {
throw DomainErrors.MUST_NOT_MODIFY_SYSTEM_DEFAULT_DATABASE_TYPE.exception();
}
databaseTypeDao.deleteById(id);
driverResources.delete(data.getDatabaseType());
driverResources.deleteByDatabaseType(data.getDatabaseType());
});
}

Expand Down Expand Up @@ -109,7 +110,7 @@ public Optional<DatabaseTypeDetailResponse> selectOne(Integer id) {
}

public String resolveDriverClassName(DriverClassNameResolveRequest request) {
return driverResources.resolveSqlDriverNameFromJar(request.getJdbcDriverFileUrl());
return driverResources.resolveDriverClassName(request.getJdbcDriverFileUrl());
}

}
Expand Up @@ -36,8 +36,10 @@ public boolean support(String databaseType) {

@Override
public Connection getConnection(Context context) throws SQLException {
DatabaseTypePojo type = databaseTypeDao.selectByDatabaseType(context.getDatabaseType());
File driverFile = driverResources.loadOrDownload(context.getDatabaseType(), type.getJdbcDriverFileUrl());
String databaseType = context.getDatabaseType();
DatabaseTypePojo type = databaseTypeDao.selectByDatabaseType(databaseType);
File driverFile = driverResources.loadOrDownloadByDatabaseType(databaseType, type.getJdbcDriverFileUrl());

URLClassLoader loader = null;
try {
loader = new URLClassLoader(
Expand All @@ -55,11 +57,11 @@ public Connection getConnection(Context context) throws SQLException {
Class<?> clazz = null;
Driver driver = null;
try {
clazz = Class.forName(type.getJdbcDriverClassName(), true, loader);
clazz = Class.forName(type.getJdbcDriverClassName(), false, loader);
driver = (Driver) clazz.getConstructor().newInstance();
} catch (ClassNotFoundException e) {
log.error("init driver error", e);
throw DomainErrors.CONNECT_DATABASE_FAILED.exception("驱动初始化异常, 请检查 Driver name:" + e.getMessage());
throw DomainErrors.CONNECT_DATABASE_FAILED.exception("驱动初始化异常, 请检查驱动类名:" + e.getMessage());
} catch (InvocationTargetException
| InstantiationException
| IllegalAccessException
Expand Down
Expand Up @@ -4,16 +4,22 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ClassUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpMethod;
import org.springframework.stereotype.Component;
import org.springframework.util.StreamUtils;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestTemplate;

import java.io.*;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Optional;
import java.util.UUID;
import java.util.jar.JarFile;

Expand All @@ -27,7 +33,7 @@ public class DriverResources {

private final RestTemplate restTemplate;

public void delete(String databaseType) {
public void deleteByDatabaseType(String databaseType) {
Path path = Paths.get(driverFilePath(driverBaseDirectory, databaseType));
try {
Files.deleteIfExists(path);
Expand All @@ -36,10 +42,24 @@ public void delete(String databaseType) {
}
}

public String resolveSqlDriverNameFromJar(String driverFileUrl) {
public Optional<File> loadByDatabaseType(String databaseType) {
Path path = Paths.get(driverFilePath(driverBaseDirectory, databaseType));
if (Files.exists(path)) {
return Optional.of(path.toFile());
} else {
return Optional.empty();
}
}

public File loadOrDownloadByDatabaseType(String databaseType, String driverFileUrl) {
return loadByDatabaseType(databaseType)
.orElseGet(() -> download(driverFileUrl, driverFilePath(driverBaseDirectory, databaseType)));
}

public String resolveDriverClassName(String driverFileUrl) {
String tempFilePath = "temp/" + UUID.randomUUID() + ".jar";
File driverFile = doDownload(driverFileUrl, tempFilePath);
String className = doResolveSqlDriverNameFromJar(driverFile);
File driverFile = download(driverFileUrl, tempFilePath);
String className = resolveDriverClassName(driverFile);
try {
Files.deleteIfExists(driverFile.toPath());
} catch (IOException e) {
Expand All @@ -48,27 +68,48 @@ public String resolveSqlDriverNameFromJar(String driverFileUrl) {
return className;
}

public File loadOrDownload(String databaseType, String driverFileUrl) {
String filePath = driverFilePath(driverBaseDirectory, databaseType);
Path path = Path.of(filePath);
if (Files.exists(path)) {
// ignore
log.debug("{} already exists, ignore download from {}", filePath, driverFileUrl);
return path.toFile();
public String resolveDriverClassName(File driverFile) {
JarFile jarFile = null;
try {
jarFile = new JarFile(driverFile);
} catch (IOException e) {
log.error("resolve driver class name error", e);
throw DomainErrors.DRIVER_CLASS_NOT_FOUND.exception(e.getMessage());
}
return this.doDownload(driverFileUrl, filePath);

final JarFile driverJar = jarFile;
String driverClassName = jarFile.stream()
.filter(entry -> entry.getName().contains("META-INF/services/java.sql.Driver"))
.findFirst()
.map(entry -> {
InputStream stream = null;
BufferedReader reader = null;
try {
stream = driverJar.getInputStream(entry);
reader = new BufferedReader(new InputStreamReader(stream));
return reader.readLine();
} catch (IOException e) {
log.error("resolve driver class name error", e);
throw DomainErrors.DRIVER_CLASS_NOT_FOUND.exception(e.getMessage());
} finally {
IOUtils.closeQuietly(reader, ex -> log.error("close reader error", ex));
}
})
.orElseThrow(DomainErrors.DRIVER_CLASS_NOT_FOUND::exception);
IOUtils.closeQuietly(jarFile, ex -> log.error("close jar file error", ex));
return driverClassName;
}

private File doDownload(String driverFileUrl, String filePath) {
Path path = Path.of(filePath);
private File download(String driverFileUrl, String targetFile) {
Path path = Path.of(targetFile);

// create parent directory
if (Files.notExists(path)) {
path.getParent().toFile().mkdirs();
try {
Files.createFile(path);
} catch (IOException e) {
log.error("create file error " + filePath, e);
log.error("create file error " + targetFile, e);
throw DomainErrors.DOWNLOAD_DRIVER_ERROR.exception(e.getMessage());
}
}
Expand All @@ -81,52 +122,55 @@ private File doDownload(String driverFileUrl, String filePath) {
FileOutputStream out = new FileOutputStream(file);
StreamUtils.copy(response.getBody(), out);
IOUtils.closeQuietly(out, ex -> log.error("close file error", ex));
log.info("{} download success ", filePath);
log.info("{} download success ", targetFile);
return file;
} else {
log.error("{} download error from {}: {} ", filePath, driverFileUrl, response);
log.error("{} download error from {}: {} ", targetFile, driverFileUrl, response);
throw DomainErrors.DOWNLOAD_DRIVER_ERROR.exception("驱动下载失败:"
+ response.getStatusCode()
+ ", "
+ response.getStatusText());
}
});
} catch (IllegalArgumentException e) {
log.error(filePath + " download driver error", e);
} catch (RestClientException e) {
log.error(targetFile + " download driver error", e);
throw DomainErrors.DOWNLOAD_DRIVER_ERROR.exception(e.getMessage());
}
}

private String doResolveSqlDriverNameFromJar(File driverFile) {
JarFile jarFile = null;
public void validateJar(String driverFileUrl, String className) {
String tempFilePath = "temp/" + UUID.randomUUID() + ".jar";
File driverFile = download(driverFileUrl, tempFilePath);
URLClassLoader loader = null;
try {
jarFile = new JarFile(driverFile);
} catch (IOException e) {
log.error("resolve driver class name error", e);
throw DomainErrors.DRIVER_CLASS_NAME_OBTAIN_ERROR.exception(e.getMessage());
loader = new URLClassLoader(
new URL[]{driverFile.toURI().toURL()},
this.getClass().getClassLoader()
);
} catch (MalformedURLException e) {
log.error("load driver jar error ", e);
throw DomainErrors.DOWNLOAD_DRIVER_ERROR.exception(e.getMessage());
}

final JarFile driverJar = jarFile;
String driverClassName = jarFile.stream()
.filter(entry -> entry.getName().contains("META-INF/services/java.sql.Driver"))
.findFirst()
.map(entry -> {
InputStream stream = null;
BufferedReader reader = null;
try {
stream = driverJar.getInputStream(entry);
reader = new BufferedReader(new InputStreamReader(stream));
return reader.readLine();
} catch (IOException e) {
log.error("resolve driver class name error", e);
throw DomainErrors.DRIVER_CLASS_NAME_OBTAIN_ERROR.exception(e.getMessage());
} finally {
IOUtils.closeQuietly(reader, ex -> log.error("close reader error", ex));
}
})
.orElseThrow(DomainErrors.DRIVER_CLASS_NAME_OBTAIN_ERROR::exception);
IOUtils.closeQuietly(jarFile, ex -> log.error("close jar file error", ex));
return driverClassName;
try {
Class clazz = Class.forName(className, false, loader);
boolean isValid = ClassUtils.getAllInterfaces(clazz)
.stream()
.anyMatch(cls -> cls.getName().equals("java.sql.Driver"));
if (!isValid) {
throw DomainErrors.DRIVER_CLASS_NOT_FOUND.exception("不合法的驱动类,请重新指定");
}
} catch (ClassNotFoundException e) {
log.error("init driver error", e);
throw DomainErrors.DRIVER_CLASS_NOT_FOUND.exception("驱动初始化异常, 请检查驱动类名:" + e.getMessage());
} finally {
IOUtils.closeQuietly(loader);
try {
Files.deleteIfExists(driverFile.toPath());
} catch (IOException e) {
log.error("delete driver error " + tempFilePath, e);
}
}
}

private String driverFilePath(String baseDir, String databaseType) {
Expand Down
Expand Up @@ -5,6 +5,7 @@
import com.auth0.jwt.exceptions.JWTVerificationException;
import com.auth0.jwt.interfaces.JWTVerifier;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import java.time.Instant;
Expand All @@ -23,10 +24,11 @@ public class JwtTokens {

private static final String ISSUER = "Databasir";

private static final String SECRET = "Databasir2022";
@Value("${databasir.jwt.secret}")
private String tokenSecret;

public String accessToken(String username) {
Algorithm algorithm = Algorithm.HMAC256(SECRET);
Algorithm algorithm = Algorithm.HMAC256(tokenSecret);

return JWT.create()
.withExpiresAt(new Date(new Date().getTime() + ACCESS_EXPIRE_TIME))
Expand All @@ -36,7 +38,7 @@ public String accessToken(String username) {
}

public boolean verify(String token) {
JWTVerifier verifier = JWT.require(Algorithm.HMAC256(SECRET))
JWTVerifier verifier = JWT.require(Algorithm.HMAC256(tokenSecret))
.withIssuer(ISSUER)
.build();
try {
Expand Down
Expand Up @@ -5,14 +5,20 @@
import com.databasir.core.domain.DomainErrors;
import com.databasir.core.domain.database.data.DatabaseTypeCreateRequest;
import com.databasir.core.domain.database.data.DatabaseTypeUpdateRequest;
import com.databasir.core.infrastructure.driver.DriverResources;
import com.databasir.dao.impl.DatabaseTypeDao;
import com.databasir.dao.tables.pojos.DatabaseTypePojo;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.test.context.jdbc.Sql;
import org.springframework.transaction.annotation.Transactional;

import static org.mockito.ArgumentMatchers.anyString;

@Transactional
class DatabaseTypeServiceTest extends BaseTest {

Expand All @@ -22,6 +28,14 @@ class DatabaseTypeServiceTest extends BaseTest {
@Autowired
private DatabaseTypeDao databaseTypeDao;

@MockBean
private DriverResources driverResources;

@BeforeEach
public void setUp() {
Mockito.doNothing().when(driverResources).validateJar(anyString(), anyString());
}

@Test
void create() {
DatabaseTypeCreateRequest request = new DatabaseTypeCreateRequest();
Expand Down
3 changes: 2 additions & 1 deletion core/src/test/resources/application-ut.properties
Expand Up @@ -15,4 +15,5 @@ spring.flyway.locations=classpath:db/migration
databasir.db.url=localhost:3306
databasir.db.username=root
databasir.db.password=123456
databasir.db.driver-directory=drivers
databasir.db.driver-directory=drivers
databasir.jwt.secret=DatabasirJwtSecret

0 comments on commit ca22a8f

Please sign in to comment.