Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 97 additions & 14 deletions connector/src/main/java/tech/ydb/spark/connector/YdbContext.java
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
package tech.ydb.spark.connector;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import com.google.common.io.ByteStreams;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -21,6 +26,7 @@
import tech.ydb.core.grpc.GrpcTransport;
import tech.ydb.core.grpc.GrpcTransportBuilder;
import tech.ydb.core.impl.auth.GrpcAuthRpc;
import tech.ydb.core.utils.URITools;
import tech.ydb.spark.connector.common.ConnectionOption;
import tech.ydb.spark.connector.impl.YdbExecutor;
import tech.ydb.table.TableClient;
Expand All @@ -34,6 +40,13 @@ public class YdbContext implements Serializable {

private static final long serialVersionUID = 6522842483896983993L;

// copy URL paramter names from JDBC
private static final String JDBC_TOKEN_FILE = "tokenFile";
private static final String JDBC_SECURE_CONNECTION_CERTIFICATE = "secureConnectionCertificate";
private static final String JDBC_SA_KEY_FILE = "saKeyFile";
private static final String JDBC_USE_METADATA = "useMetadata";
private static final String HOME_REF = "~/";

private final String connectionString;

private final byte[] caCertBytes;
Expand All @@ -55,17 +68,21 @@ public YdbContext(CaseInsensitiveStringMap options) {
throw new IllegalArgumentException("Incorrect value for property " + ConnectionOption.URL);
}

this.caCertBytes = readCaCertificate(options);
this.useMetadata = ConnectionOption.AUTH_METADATA.readBoolean(options, false);
this.useEnv = ConnectionOption.AUTH_ENV.readBoolean(options, false);
Map<String, String> parameters = new HashMap<>();
parameters.putAll(parseJdbcParams(connectionString));
parameters.putAll(options);

this.caCertBytes = readCaCertificate(parameters);
this.useMetadata = ConnectionOption.AUTH_METADATA.readBoolean(parameters, false);
this.useEnv = ConnectionOption.AUTH_ENV.readBoolean(parameters, false);

this.token = ConnectionOption.AUTH_TOKEN.read(options);
this.saKey = readSaKey(options);
this.token = readToken(parameters);
this.saKey = readSaKey(parameters);

this.username = ConnectionOption.AUTH_LOGIN.read(options);
this.password = ConnectionOption.AUTH_PASSWORD.read(options);
this.username = ConnectionOption.AUTH_LOGIN.read(parameters);
this.password = ConnectionOption.AUTH_PASSWORD.read(parameters);

this.sessionPoolSize = ConnectionOption.POOL_SIZE.readInt(options, 0);
this.sessionPoolSize = ConnectionOption.POOL_SIZE.readInt(parameters, 0);
}

@Override
Expand Down Expand Up @@ -159,11 +176,11 @@ private AuthRpcProvider<? super GrpcAuthRpc> createAuthProvider() {
return NopAuthProvider.INSTANCE;
}

private static byte[] readCaCertificate(CaseInsensitiveStringMap options) {
private static byte[] readCaCertificate(Map<String, String> options) {
String caFile = ConnectionOption.CA_FILE.read(options);
if (caFile != null) {
try {
return Files.readAllBytes(Paths.get(caFile));
return readFileAsBytes(caFile);
} catch (IOException ix) {
throw new IllegalArgumentException("Failed to read CA certificate file " + caFile, ix);
}
Expand All @@ -178,11 +195,12 @@ private static byte[] readCaCertificate(CaseInsensitiveStringMap options) {
return null;
}

private static String readSaKey(CaseInsensitiveStringMap options) {
private static String readSaKey(Map<String, String> options) {
String saKeyPath = ConnectionOption.AUTH_SAKEY_FILE.read(options);
if (saKeyPath != null && !saKeyPath.trim().isEmpty()) {
try {
return new String(Files.readAllBytes(Paths.get(saKeyPath)), StandardCharsets.UTF_8);
byte[] content = readFileAsBytes(saKeyPath);
return new String(content, StandardCharsets.UTF_8);
} catch (IOException ix) {
throw new IllegalArgumentException("Failed to read service account key file " + saKeyPath, ix);
}
Expand All @@ -195,4 +213,69 @@ private static String readSaKey(CaseInsensitiveStringMap options) {

return null;
}

private static String readToken(Map<String, String> options) {
String tokenFile = ConnectionOption.AUTH_TOKEN_FILE.read(options);
if (tokenFile != null && !tokenFile.trim().isEmpty()) {
try {
byte[] content = readFileAsBytes(tokenFile);
return firstLine(content);
} catch (IOException ix) {
throw new IllegalArgumentException("Failed to read token file " + tokenFile, ix);
}
}

String tokenValue = ConnectionOption.AUTH_TOKEN.read(options);
if (tokenValue != null && !tokenValue.trim().isEmpty()) {
return tokenValue.trim();
}

return null;
}

private static Map<String, String> parseJdbcParams(String url) {
Map<String, String> params = new HashMap<>();
try {
URI uri = new URI(url.contains("://") ? url : "grpc://" + url);
URITools.splitQuery(uri).forEach((key, values) -> {
if (key.equalsIgnoreCase(JDBC_TOKEN_FILE)) {
params.put(ConnectionOption.AUTH_TOKEN_FILE.getCode(), values.get(0));
}
if (key.equalsIgnoreCase(JDBC_SECURE_CONNECTION_CERTIFICATE)) {
params.put(ConnectionOption.CA_FILE.getCode(), values.get(0));
}
if (key.equalsIgnoreCase(JDBC_SA_KEY_FILE)) {
params.put(ConnectionOption.AUTH_SAKEY_FILE.getCode(), values.get(0));
}
if (key.equalsIgnoreCase(JDBC_USE_METADATA)) {
params.put(ConnectionOption.AUTH_METADATA.getCode(), values.get(0));
}
});
} catch (URISyntaxException ex) {
// nothing
}
return params;
}

public static byte[] readFileAsBytes(String filePath) throws IOException {
String path = filePath.trim();

if (path.startsWith(HOME_REF)) {
String home = System.getProperty("user.home");
path = home + path.substring(HOME_REF.length() - 1);
}

try (InputStream is = new FileInputStream(path)) {
return ByteStreams.toByteArray(is);
}
}

private static String firstLine(byte[] bytes) throws IOException {
for (int idx = 0; idx < bytes.length; idx++) {
if (bytes[idx] == '\n' || bytes[idx] == '\r') {
return new String(bytes, 0, idx, StandardCharsets.UTF_8);
}
}
return new String(bytes, StandardCharsets.UTF_8);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ public enum ConnectionOption implements SparkOption {
*/
AUTH_TOKEN("auth.token"),

/**
* Token value for the TOKEN authentication mode.
*/
AUTH_TOKEN_FILE("auth.token.file"),

/**
* Session pool size limit. Default is 4x number of cores available.
*/
Expand Down