Skip to content

Commit

Permalink
SNOW-845282: Allow configuring tmpdir in DSN
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Aug 4, 2023
1 parent 5e1158c commit 0ab83ed
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 9 deletions.
7 changes: 7 additions & 0 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ type Config struct {

Tracing string // sets logging level

TmpDirPath string // sets temporary directory used by a driver for operations like encrypting, compressing etc

MfaToken string // Internally used to cache the MFA token
IDToken string // Internally used to cache the Id Token for external browser
ClientRequestMfaToken ConfigBool // When true the MFA token is cached in the credential manager. True by default in Windows/OSX. False for Linux.
Expand Down Expand Up @@ -214,6 +216,9 @@ func DSN(cfg *Config) (dsn string, err error) {
if cfg.Tracing != "" {
params.Add("tracing", cfg.Tracing)
}
if cfg.TmpDirPath != "" {
params.Add("tmpDirPath", cfg.TmpDirPath)
}

params.Add("ocspFailOpen", strconv.FormatBool(cfg.OCSPFailOpen != OCSPFailOpenFalse))

Expand Down Expand Up @@ -684,6 +689,8 @@ func parseDSNParams(cfg *Config, params string) (err error) {
}
case "tracing":
cfg.Tracing = value
case "tmpDirPath":
cfg.TmpDirPath = value
default:
if cfg.Params == nil {
cfg.Params = make(map[string]*string)
Expand Down
26 changes: 26 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,20 @@ func TestParseDSN(t *testing.T) {
ocspMode: ocspModeFailOpen,
err: nil,
},
{
dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&tmpDirPath=%2Ftmp",
config: &Config{
Account: "a", User: "u", Password: "p",
Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443,
Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue,
ClientTimeout: defaultClientTimeout,
JWTClientTimeout: defaultJWTClientTimeout,
ExternalBrowserTimeout: defaultExternalBrowserTimeout,
TmpDirPath: "/tmp",
},
ocspMode: ocspModeFailOpen,
err: nil,
},
}

for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} {
Expand Down Expand Up @@ -727,6 +741,9 @@ func TestParseDSN(t *testing.T) {
t.Fatalf("%d: Failed to match ExternalBrowserTimeout. expected: %v, got: %v",
i, test.config.ExternalBrowserTimeout, cfg.ExternalBrowserTimeout)
}
if test.config.TmpDirPath != cfg.TmpDirPath {
t.Fatalf("%v: Failed to match TmpDirPatch. expected: %v, got: %v", i, test.config.TmpDirPath, cfg.TmpDirPath)
}
case test.err != nil:
driverErrE, okE := test.err.(*SnowflakeError)
driverErrG, okG := err.(*SnowflakeError)
Expand Down Expand Up @@ -1100,6 +1117,15 @@ func TestDSN(t *testing.T) {
},
dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true&region=b.c&token=t&validateDefaultParameters=true",
},
{
cfg: &Config{
User: "u",
Password: "p",
Account: "a.b.c",
TmpDirPath: "/tmp",
},
dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true&region=b.c&tmpDirPath=%2Ftmp&validateDefaultParameters=true",
},
}
for _, test := range testcases {
dsn, err := DSN(test.cfg)
Expand Down
2 changes: 1 addition & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (se *SnowflakeError) generateTelemetryExceptionData() *telemetryData {
}

func (se *SnowflakeError) sendExceptionTelemetry(sc *snowflakeConn, data *telemetryData) error {
if sc != nil {
if sc != nil && sc.telemetry != nil {
return sc.telemetry.addLog(data)
}
return nil // TODO oob telemetry
Expand Down
4 changes: 2 additions & 2 deletions file_transfer_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,7 @@ func (sfa *snowflakeFileTransferAgent) uploadFilesSequential(fileMetas []*fileMe

func (sfa *snowflakeFileTransferAgent) uploadOneFile(meta *fileMetadata) (*fileMetadata, error) {
meta.realSrcFileName = meta.srcFileName
tmpDir, err := os.MkdirTemp("", "")
tmpDir, err := os.MkdirTemp(sfa.sc.cfg.TmpDirPath, "")
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -995,7 +995,7 @@ func (sfa *snowflakeFileTransferAgent) downloadFilesSequential(fileMetas []*file
}

func (sfa *snowflakeFileTransferAgent) downloadOneFile(meta *fileMetadata) (*fileMetadata, error) {
tmpDir, err := os.MkdirTemp("", "")
tmpDir, err := os.MkdirTemp(sfa.sc.cfg.TmpDirPath, "")
if err != nil {
return nil, err
}
Expand Down
137 changes: 131 additions & 6 deletions file_transfer_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,8 @@ func TestUpdateMetadataWithPresignedUrlError(t *testing.T) {
}

func TestUploadWhenFilesystemReadOnlyError(t *testing.T) {
// Disable the test on Windows
if isWindows {
return
t.Skip("permission model is different")
}

var err error
Expand Down Expand Up @@ -579,16 +578,16 @@ func TestUploadWhenFilesystemReadOnlyError(t *testing.T) {
}

sfa := &snowflakeFileTransferAgent{
sc: nil,
sc: &snowflakeConn{
cfg: &Config{},
},
commandType: uploadCommand,
command: "put file:///tmp/test_data/data1.txt @~",
stageLocationType: gcsClient,
fileMetadata: []*fileMetadata{&uploadMeta},
parallel: 1,
}

// Set max parallel uploads to 1
sfa.parallel = 1

err = sfa.uploadFilesParallel([]*fileMetadata{&uploadMeta})
if err == nil {
t.Fatal("should error when the filesystem is read only")
Expand All @@ -597,3 +596,129 @@ func TestUploadWhenFilesystemReadOnlyError(t *testing.T) {
t.Fatalf("should error when creating the temporary directory. Instead errored with: %v", err)
}
}

func TestCustomTmpDirPath(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "")
if err != nil {
t.Fatalf("cannot create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)
uploadFile := filepath.Join(tmpDir, "data.txt")
f, err := os.Create(uploadFile)
if err != nil {
t.Error(err)
}
f.WriteString("test1,test2\ntest3,test4\n")
f.Close()

uploadMeta := &fileMetadata{
name: "data.txt.gz",
stageLocationType: "local",
noSleepingTime: true,
client: local,
sha256Digest: "123456789abcdef",
stageInfo: &execResponseStageInfo{
Location: tmpDir,
LocationType: "local",
},
dstFileName: "data.txt.gz",
srcFileName: uploadFile,
overwrite: true,
options: &SnowflakeFileTransferOptions{
MultiPartThreshold: dataSizeThreshold,
},
}

downloadFile := filepath.Join(tmpDir, "download.txt")
downloadMeta := &fileMetadata{
name: "data.txt.gz",
stageLocationType: "local",
noSleepingTime: true,
client: local,
sha256Digest: "123456789abcdef",
stageInfo: &execResponseStageInfo{
Location: tmpDir,
LocationType: "local",
},
srcFileName: "data.txt.gz",
dstFileName: downloadFile,
overwrite: true,
options: &SnowflakeFileTransferOptions{
MultiPartThreshold: dataSizeThreshold,
},
}

sfa := snowflakeFileTransferAgent{
sc: &snowflakeConn{
cfg: &Config{
TmpDirPath: tmpDir,
},
},
stageLocationType: local,
}
_, err = sfa.uploadOneFile(uploadMeta)
if err != nil {
t.Fatal(err)
}
_, err = sfa.downloadOneFile(downloadMeta)
if err != nil {
t.Fatal(err)
}
defer os.Remove("download.txt")
}

func TestReadonlyTmpDirPathShouldFail(t *testing.T) {
if isWindows {
t.Skip("permission model is different")
}
tmpDir, err := os.MkdirTemp("", "")
if err != nil {
t.Fatalf("cannot create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)

uploadFile := filepath.Join(tmpDir, "data.txt")
f, err := os.Create(uploadFile)
if err != nil {
t.Error(err)
}
f.WriteString("test1,test2\ntest3,test4\n")
f.Close()

err = os.Chmod(tmpDir, 0400)
if err != nil {
t.Fatalf("cannot mark directory as readonly: %v", err)
}
defer os.Chmod(tmpDir, 0600)

uploadMeta := &fileMetadata{
name: "data.txt.gz",
stageLocationType: "local",
noSleepingTime: true,
client: local,
sha256Digest: "123456789abcdef",
stageInfo: &execResponseStageInfo{
Location: tmpDir,
LocationType: "local",
},
dstFileName: "data.txt.gz",
srcFileName: uploadFile,
overwrite: true,
options: &SnowflakeFileTransferOptions{
MultiPartThreshold: dataSizeThreshold,
},
}

sfa := snowflakeFileTransferAgent{
sc: &snowflakeConn{
cfg: &Config{
TmpDirPath: tmpDir,
},
},
stageLocationType: local,
}
_, err = sfa.uploadOneFile(uploadMeta)
if err == nil {
t.Fatalf("should not upload file as temporary directory is not readable")
}
}
6 changes: 6 additions & 0 deletions put_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ func TestPutError(t *testing.T) {
options: &SnowflakeFileTransferOptions{
RaisePutGetError: false,
},
sc: &snowflakeConn{
cfg: &Config{},
},
}
if err = fta.execute(); err != nil {
t.Fatal(err)
Expand All @@ -78,6 +81,9 @@ func TestPutError(t *testing.T) {
options: &SnowflakeFileTransferOptions{
RaisePutGetError: true,
},
sc: &snowflakeConn{
cfg: &Config{},
},
}
if err = fta.execute(); err != nil {
t.Fatal(err)
Expand Down

0 comments on commit 0ab83ed

Please sign in to comment.