Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-845282: Allow configuring tmpdir in DSN #874

Merged
merged 1 commit into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ func (d SnowflakeDriver) Open(dsn string) (driver.Conn, error) {
}

// OpenWithConfig creates a new connection with the given Config.
func (d SnowflakeDriver) OpenWithConfig(
ctx context.Context,
config Config) (
driver.Conn, error) {
func (d SnowflakeDriver) OpenWithConfig(ctx context.Context, config Config) (driver.Conn, error) {
if err := config.Validate(); err != nil {
return nil, err
}
if config.Tracing != "" {
logger.SetLogLevel(config.Tracing)
}
Expand Down
12 changes: 12 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1575,6 +1575,18 @@ func TestOpenWithConfig(t *testing.T) {
db.Close()
}

func TestOpenWithInvalidConfig(t *testing.T) {
config, err := ParseDSN("u:p@h?tmpDirPath=%2Fnon-existing")
if err != nil {
t.Fatalf("failed to parse dsn. err: %v", err)
}
driver := SnowflakeDriver{}
_, err = driver.OpenWithConfig(context.Background(), *config)
if err == nil || !strings.Contains(err.Error(), "/non-existing") {
t.Fatalf("should fail on missing directory")
}
}

type CountingTransport struct {
requests int
}
Expand Down
18 changes: 18 additions & 0 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,25 @@ type Config struct {

Tracing string // sets logging level

TmpDirPath string // sets temporary directory used by a driver for operations like encrypting, compressing etc

MfaToken string // Internally used to cache the MFA token
IDToken string // Internally used to cache the Id Token for external browser
ClientRequestMfaToken ConfigBool // When true the MFA token is cached in the credential manager. True by default in Windows/OSX. False for Linux.
ClientStoreTemporaryCredential ConfigBool // When true the ID token is cached in the credential manager. True by default in Windows/OSX. False for Linux.
}

// Validate enables testing if config is correct.
// A driver client may call it manually, but it is also called during opening first connection.
func (c *Config) Validate() error {
if c.TmpDirPath != "" {
if _, err := os.Stat(c.TmpDirPath); err != nil {
return err
}
}
return nil
}

// ocspMode returns the OCSP mode in string INSECURE, FAIL_OPEN, FAIL_CLOSED
func (c *Config) ocspMode() string {
if c.InsecureMode {
Expand Down Expand Up @@ -214,6 +227,9 @@ func DSN(cfg *Config) (dsn string, err error) {
if cfg.Tracing != "" {
params.Add("tracing", cfg.Tracing)
}
if cfg.TmpDirPath != "" {
params.Add("tmpDirPath", cfg.TmpDirPath)
}

params.Add("ocspFailOpen", strconv.FormatBool(cfg.OCSPFailOpen != OCSPFailOpenFalse))

Expand Down Expand Up @@ -684,6 +700,8 @@ func parseDSNParams(cfg *Config, params string) (err error) {
}
case "tracing":
cfg.Tracing = value
case "tmpDirPath":
cfg.TmpDirPath = value
sfc-gh-pfus marked this conversation as resolved.
Show resolved Hide resolved
default:
if cfg.Params == nil {
cfg.Params = make(map[string]*string)
Expand Down
35 changes: 35 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,20 @@ func TestParseDSN(t *testing.T) {
ocspMode: ocspModeFailOpen,
err: nil,
},
{
dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&tmpDirPath=%2Ftmp",
config: &Config{
Account: "a", User: "u", Password: "p",
Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443,
Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue,
ClientTimeout: defaultClientTimeout,
JWTClientTimeout: defaultJWTClientTimeout,
ExternalBrowserTimeout: defaultExternalBrowserTimeout,
TmpDirPath: "/tmp",
sfc-gh-dprzybysz marked this conversation as resolved.
Show resolved Hide resolved
},
ocspMode: ocspModeFailOpen,
err: nil,
},
}

for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} {
Expand Down Expand Up @@ -727,6 +741,9 @@ func TestParseDSN(t *testing.T) {
t.Fatalf("%d: Failed to match ExternalBrowserTimeout. expected: %v, got: %v",
i, test.config.ExternalBrowserTimeout, cfg.ExternalBrowserTimeout)
}
if test.config.TmpDirPath != cfg.TmpDirPath {
t.Fatalf("%v: Failed to match TmpDirPatch. expected: %v, got: %v", i, test.config.TmpDirPath, cfg.TmpDirPath)
}
case test.err != nil:
driverErrE, okE := test.err.(*SnowflakeError)
driverErrG, okG := err.(*SnowflakeError)
Expand Down Expand Up @@ -1109,6 +1126,15 @@ func TestDSN(t *testing.T) {
},
dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=tokenaccessor&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
},
{
cfg: &Config{
User: "u",
Password: "p",
Account: "a.b.c",
TmpDirPath: "/tmp",
},
dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true&region=b.c&tmpDirPath=%2Ftmp&validateDefaultParameters=true",
},
}
for _, test := range testcases {
dsn, err := DSN(test.cfg)
Expand Down Expand Up @@ -1254,3 +1280,12 @@ func checkConfig(cfg Config, envMap map[string]configParamToValue) error {

return nil
}

func TestConfigValidateTmpDirPath(t *testing.T) {
cfg := &Config{
TmpDirPath: "/not/existing",
}
if err := cfg.Validate(); err == nil {
t.Fatalf("Should fail on not existing TmpDirPath")
}
}
2 changes: 1 addition & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (se *SnowflakeError) generateTelemetryExceptionData() *telemetryData {
}

func (se *SnowflakeError) sendExceptionTelemetry(sc *snowflakeConn, data *telemetryData) error {
if sc != nil {
if sc != nil && sc.telemetry != nil {
return sc.telemetry.addLog(data)
}
return nil // TODO oob telemetry
Expand Down
4 changes: 2 additions & 2 deletions file_transfer_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ func (sfa *snowflakeFileTransferAgent) uploadFilesSequential(fileMetas []*fileMe

func (sfa *snowflakeFileTransferAgent) uploadOneFile(meta *fileMetadata) (*fileMetadata, error) {
meta.realSrcFileName = meta.srcFileName
tmpDir, err := os.MkdirTemp("", "")
tmpDir, err := os.MkdirTemp(sfa.sc.cfg.TmpDirPath, "")
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -951,7 +951,7 @@ func (sfa *snowflakeFileTransferAgent) downloadFilesParallel(fileMetas []*fileMe
}

func (sfa *snowflakeFileTransferAgent) downloadOneFile(meta *fileMetadata) (*fileMetadata, error) {
tmpDir, err := os.MkdirTemp("", "")
tmpDir, err := os.MkdirTemp(sfa.sc.cfg.TmpDirPath, "")
if err != nil {
return nil, err
}
Expand Down
137 changes: 131 additions & 6 deletions file_transfer_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,9 +536,8 @@ func TestUpdateMetadataWithPresignedUrlError(t *testing.T) {
}

func TestUploadWhenFilesystemReadOnlyError(t *testing.T) {
// Disable the test on Windows
if isWindows {
return
t.Skip("permission model is different")
}

var err error
Expand Down Expand Up @@ -581,16 +580,16 @@ func TestUploadWhenFilesystemReadOnlyError(t *testing.T) {
}

sfa := &snowflakeFileTransferAgent{
sc: nil,
sc: &snowflakeConn{
cfg: &Config{},
},
commandType: uploadCommand,
command: "put file:///tmp/test_data/data1.txt @~",
stageLocationType: gcsClient,
fileMetadata: []*fileMetadata{&uploadMeta},
parallel: 1,
}

// Set max parallel uploads to 1
sfa.parallel = 1

err = sfa.uploadFilesParallel([]*fileMetadata{&uploadMeta})
if err == nil {
t.Fatal("should error when the filesystem is read only")
Expand Down Expand Up @@ -627,3 +626,129 @@ func TestUnitUpdateProgess(t *testing.T) {
t.Fatal("should be done after updating progess")
}
}

func TestCustomTmpDirPath(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "")
if err != nil {
t.Fatalf("cannot create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)
uploadFile := filepath.Join(tmpDir, "data.txt")
f, err := os.Create(uploadFile)
if err != nil {
t.Error(err)
}
f.WriteString("test1,test2\ntest3,test4\n")
f.Close()

uploadMeta := &fileMetadata{
name: "data.txt.gz",
stageLocationType: "local",
noSleepingTime: true,
client: local,
sha256Digest: "123456789abcdef",
stageInfo: &execResponseStageInfo{
Location: tmpDir,
LocationType: "local",
},
dstFileName: "data.txt.gz",
srcFileName: uploadFile,
overwrite: true,
options: &SnowflakeFileTransferOptions{
MultiPartThreshold: dataSizeThreshold,
},
}

downloadFile := filepath.Join(tmpDir, "download.txt")
downloadMeta := &fileMetadata{
name: "data.txt.gz",
stageLocationType: "local",
noSleepingTime: true,
client: local,
sha256Digest: "123456789abcdef",
stageInfo: &execResponseStageInfo{
Location: tmpDir,
LocationType: "local",
},
srcFileName: "data.txt.gz",
dstFileName: downloadFile,
overwrite: true,
options: &SnowflakeFileTransferOptions{
MultiPartThreshold: dataSizeThreshold,
},
}

sfa := snowflakeFileTransferAgent{
sc: &snowflakeConn{
cfg: &Config{
TmpDirPath: tmpDir,
},
},
stageLocationType: local,
}
_, err = sfa.uploadOneFile(uploadMeta)
if err != nil {
t.Fatal(err)
}
_, err = sfa.downloadOneFile(downloadMeta)
sfc-gh-dprzybysz marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
t.Fatal(err)
}
defer os.Remove("download.txt")
sfc-gh-pfus marked this conversation as resolved.
Show resolved Hide resolved
}

func TestReadonlyTmpDirPathShouldFail(t *testing.T) {
if isWindows {
t.Skip("permission model is different")
}
tmpDir, err := os.MkdirTemp("", "")
if err != nil {
t.Fatalf("cannot create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)

uploadFile := filepath.Join(tmpDir, "data.txt")
f, err := os.Create(uploadFile)
if err != nil {
t.Error(err)
}
f.WriteString("test1,test2\ntest3,test4\n")
f.Close()

err = os.Chmod(tmpDir, 0400)
if err != nil {
t.Fatalf("cannot mark directory as readonly: %v", err)
}
defer os.Chmod(tmpDir, 0600)
sfc-gh-dprzybysz marked this conversation as resolved.
Show resolved Hide resolved

uploadMeta := &fileMetadata{
name: "data.txt.gz",
stageLocationType: "local",
noSleepingTime: true,
client: local,
sha256Digest: "123456789abcdef",
stageInfo: &execResponseStageInfo{
Location: tmpDir,
LocationType: "local",
},
dstFileName: "data.txt.gz",
srcFileName: uploadFile,
overwrite: true,
options: &SnowflakeFileTransferOptions{
MultiPartThreshold: dataSizeThreshold,
},
}

sfa := snowflakeFileTransferAgent{
sc: &snowflakeConn{
cfg: &Config{
TmpDirPath: tmpDir,
},
},
stageLocationType: local,
}
_, err = sfa.uploadOneFile(uploadMeta)
if err == nil {
t.Fatalf("should not upload file as temporary directory is not readable")
}
}
6 changes: 6 additions & 0 deletions put_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ func TestPutError(t *testing.T) {
options: &SnowflakeFileTransferOptions{
RaisePutGetError: false,
},
sc: &snowflakeConn{
cfg: &Config{},
},
}
if err = fta.execute(); err != nil {
t.Fatal(err)
Expand All @@ -78,6 +81,9 @@ func TestPutError(t *testing.T) {
options: &SnowflakeFileTransferOptions{
RaisePutGetError: true,
},
sc: &snowflakeConn{
cfg: &Config{},
},
}
if err = fta.execute(); err != nil {
t.Fatal(err)
Expand Down
Loading