diff --git a/dsn.go b/dsn.go index f5a34298f..d95893c5f 100644 --- a/dsn.go +++ b/dsn.go @@ -91,6 +91,8 @@ type Config struct { Tracing string // sets logging level + TmpDirPath string // sets temporary directory used by a driver for operations like encrypting, compressing etc + MfaToken string // Internally used to cache the MFA token IDToken string // Internally used to cache the Id Token for external browser ClientRequestMfaToken ConfigBool // When true the MFA token is cached in the credential manager. True by default in Windows/OSX. False for Linux. @@ -214,6 +216,9 @@ func DSN(cfg *Config) (dsn string, err error) { if cfg.Tracing != "" { params.Add("tracing", cfg.Tracing) } + if cfg.TmpDirPath != "" { + params.Add("tmpDirPath", cfg.TmpDirPath) + } params.Add("ocspFailOpen", strconv.FormatBool(cfg.OCSPFailOpen != OCSPFailOpenFalse)) @@ -684,6 +689,8 @@ func parseDSNParams(cfg *Config, params string) (err error) { } case "tracing": cfg.Tracing = value + case "tmpDirPath": + cfg.TmpDirPath = value default: if cfg.Params == nil { cfg.Params = make(map[string]*string) diff --git a/dsn_test.go b/dsn_test.go index 437297f0a..d764285be 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -580,6 +580,20 @@ func TestParseDSN(t *testing.T) { ocspMode: ocspModeFailOpen, err: nil, }, + { + dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&tmpDirPath=%2Ftmp", + config: &Config{ + Account: "a", User: "u", Password: "p", + Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, + Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + TmpDirPath: "/tmp", + }, + ocspMode: ocspModeFailOpen, + err: nil, + }, } for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} { @@ -727,6 +741,9 @@ func TestParseDSN(t *testing.T) { t.Fatalf("%d: Failed to match ExternalBrowserTimeout. expected: %v, got: %v", i, test.config.ExternalBrowserTimeout, cfg.ExternalBrowserTimeout) } + if test.config.TmpDirPath != cfg.TmpDirPath { + t.Fatalf("%v: Failed to match TmpDirPatch. expected: %v, got: %v", i, test.config.TmpDirPath, cfg.TmpDirPath) + } case test.err != nil: driverErrE, okE := test.err.(*SnowflakeError) driverErrG, okG := err.(*SnowflakeError) @@ -1100,6 +1117,15 @@ func TestDSN(t *testing.T) { }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&token=t&validateDefaultParameters=true", }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + TmpDirPath: "/tmp", + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&tmpDirPath=%2Ftmp&validateDefaultParameters=true", + }, } for _, test := range testcases { dsn, err := DSN(test.cfg) diff --git a/errors.go b/errors.go index c49fc921c..459850751 100644 --- a/errors.go +++ b/errors.go @@ -66,7 +66,7 @@ func (se *SnowflakeError) generateTelemetryExceptionData() *telemetryData { } func (se *SnowflakeError) sendExceptionTelemetry(sc *snowflakeConn, data *telemetryData) error { - if sc != nil { + if sc != nil && sc.telemetry != nil { return sc.telemetry.addLog(data) } return nil // TODO oob telemetry diff --git a/file_transfer_agent.go b/file_transfer_agent.go index 99bec9b23..243f9403e 100644 --- a/file_transfer_agent.go +++ b/file_transfer_agent.go @@ -848,7 +848,7 @@ func (sfa *snowflakeFileTransferAgent) uploadFilesSequential(fileMetas []*fileMe func (sfa *snowflakeFileTransferAgent) uploadOneFile(meta *fileMetadata) (*fileMetadata, error) { meta.realSrcFileName = meta.srcFileName - tmpDir, err := os.MkdirTemp("", "") + tmpDir, err := os.MkdirTemp(sfa.sc.cfg.TmpDirPath, "") if err != nil { return nil, err } @@ -995,7 +995,7 @@ func (sfa *snowflakeFileTransferAgent) downloadFilesSequential(fileMetas []*file } func (sfa *snowflakeFileTransferAgent) downloadOneFile(meta *fileMetadata) (*fileMetadata, error) { - tmpDir, err := os.MkdirTemp("", "") + tmpDir, err := os.MkdirTemp(sfa.sc.cfg.TmpDirPath, "") if err != nil { return nil, err } diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go index d4ad624b8..617f0a33c 100644 --- a/file_transfer_agent_test.go +++ b/file_transfer_agent_test.go @@ -534,9 +534,8 @@ func TestUpdateMetadataWithPresignedUrlError(t *testing.T) { } func TestUploadWhenFilesystemReadOnlyError(t *testing.T) { - // Disable the test on Windows if isWindows { - return + t.Skip("permission model is different") } var err error @@ -579,16 +578,16 @@ func TestUploadWhenFilesystemReadOnlyError(t *testing.T) { } sfa := &snowflakeFileTransferAgent{ - sc: nil, + sc: &snowflakeConn{ + cfg: &Config{}, + }, commandType: uploadCommand, command: "put file:///tmp/test_data/data1.txt @~", stageLocationType: gcsClient, fileMetadata: []*fileMetadata{&uploadMeta}, + parallel: 1, } - // Set max parallel uploads to 1 - sfa.parallel = 1 - err = sfa.uploadFilesParallel([]*fileMetadata{&uploadMeta}) if err == nil { t.Fatal("should error when the filesystem is read only") @@ -597,3 +596,129 @@ func TestUploadWhenFilesystemReadOnlyError(t *testing.T) { t.Fatalf("should error when creating the temporary directory. Instead errored with: %v", err) } } + +func TestCustomTmpDirPath(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "") + if err != nil { + t.Fatalf("cannot create temp directory: %v", err) + } + defer os.RemoveAll(tmpDir) + uploadFile := filepath.Join(tmpDir, "data.txt") + f, err := os.Create(uploadFile) + if err != nil { + t.Error(err) + } + f.WriteString("test1,test2\ntest3,test4\n") + f.Close() + + uploadMeta := &fileMetadata{ + name: "data.txt.gz", + stageLocationType: "local", + noSleepingTime: true, + client: local, + sha256Digest: "123456789abcdef", + stageInfo: &execResponseStageInfo{ + Location: tmpDir, + LocationType: "local", + }, + dstFileName: "data.txt.gz", + srcFileName: uploadFile, + overwrite: true, + options: &SnowflakeFileTransferOptions{ + MultiPartThreshold: dataSizeThreshold, + }, + } + + downloadFile := filepath.Join(tmpDir, "download.txt") + downloadMeta := &fileMetadata{ + name: "data.txt.gz", + stageLocationType: "local", + noSleepingTime: true, + client: local, + sha256Digest: "123456789abcdef", + stageInfo: &execResponseStageInfo{ + Location: tmpDir, + LocationType: "local", + }, + srcFileName: "data.txt.gz", + dstFileName: downloadFile, + overwrite: true, + options: &SnowflakeFileTransferOptions{ + MultiPartThreshold: dataSizeThreshold, + }, + } + + sfa := snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{ + TmpDirPath: tmpDir, + }, + }, + stageLocationType: local, + } + _, err = sfa.uploadOneFile(uploadMeta) + if err != nil { + t.Fatal(err) + } + _, err = sfa.downloadOneFile(downloadMeta) + if err != nil { + t.Fatal(err) + } + defer os.Remove("download.txt") +} + +func TestReadonlyTmpDirPathShouldFail(t *testing.T) { + if isWindows { + t.Skip("permission model is different") + } + tmpDir, err := os.MkdirTemp("", "") + if err != nil { + t.Fatalf("cannot create temp directory: %v", err) + } + defer os.RemoveAll(tmpDir) + + uploadFile := filepath.Join(tmpDir, "data.txt") + f, err := os.Create(uploadFile) + if err != nil { + t.Error(err) + } + f.WriteString("test1,test2\ntest3,test4\n") + f.Close() + + err = os.Chmod(tmpDir, 0400) + if err != nil { + t.Fatalf("cannot mark directory as readonly: %v", err) + } + defer os.Chmod(tmpDir, 0600) + + uploadMeta := &fileMetadata{ + name: "data.txt.gz", + stageLocationType: "local", + noSleepingTime: true, + client: local, + sha256Digest: "123456789abcdef", + stageInfo: &execResponseStageInfo{ + Location: tmpDir, + LocationType: "local", + }, + dstFileName: "data.txt.gz", + srcFileName: uploadFile, + overwrite: true, + options: &SnowflakeFileTransferOptions{ + MultiPartThreshold: dataSizeThreshold, + }, + } + + sfa := snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{ + TmpDirPath: tmpDir, + }, + }, + stageLocationType: local, + } + _, err = sfa.uploadOneFile(uploadMeta) + if err == nil { + t.Fatalf("should not upload file as temporary directory is not readable") + } +} diff --git a/put_get_test.go b/put_get_test.go index e6ec5e0db..6ea2762f3 100644 --- a/put_get_test.go +++ b/put_get_test.go @@ -65,6 +65,9 @@ func TestPutError(t *testing.T) { options: &SnowflakeFileTransferOptions{ RaisePutGetError: false, }, + sc: &snowflakeConn{ + cfg: &Config{}, + }, } if err = fta.execute(); err != nil { t.Fatal(err) @@ -78,6 +81,9 @@ func TestPutError(t *testing.T) { options: &SnowflakeFileTransferOptions{ RaisePutGetError: true, }, + sc: &snowflakeConn{ + cfg: &Config{}, + }, } if err = fta.execute(); err != nil { t.Fatal(err)