diff --git a/main.go b/main.go index 9c76db0..0e97a9a 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "runtime/debug" + "strconv" "time" "github.com/getsentry/sentry-go" @@ -19,11 +20,14 @@ import ( ) var ( - InitialRunFinished atomic.Bool - FlagRunOnce bool - FlagStatusAddr = ":8087" - FlagExclude []string - FlagScratch bool + InitialRunFinished atomic.Bool + FlagRunOnce bool + FlagStatusAddr = ":8087" + FlagExclude []string + FlagScratch bool + FlagDefaultFileMode = "0664" + FlagS3Endpoint = "" + FlagDisableSSL = false metricsSyncTime = prometheus.NewGauge(prometheus.GaugeOpts{ Namespace: "objinsync", @@ -100,12 +104,21 @@ func main() { if err != nil { log.Fatal(err) } + puller.DisableSSL = FlagDisableSSL + puller.S3Endpoint = FlagS3Endpoint if FlagExclude != nil { puller.AddExcludePatterns(FlagExclude) } if !FlagScratch { puller.PopulateChecksum() } + if FlagDefaultFileMode != "" { + mode, err := strconv.ParseInt(FlagDefaultFileMode, 8, 64) + if err != nil { + log.Fatal("invalid default file mode", err) + } + puller.SetDefaultFileMode(os.FileMode(mode)) + } pull := func() { start := time.Now() @@ -147,6 +160,8 @@ func main() { pullCmd.PersistentFlags().BoolVarP( &FlagRunOnce, "once", "o", false, "run action once and then exit") + pullCmd.PersistentFlags().BoolVarP( + &FlagDisableSSL, "disable-ssl", "", false, "disable SSL for object storage connection") pullCmd.PersistentFlags().StringVarP( &FlagStatusAddr, "status-addr", "s", ":8087", "binding address for status endpoint") pullCmd.PersistentFlags().StringSliceVarP( @@ -158,6 +173,10 @@ func main() { false, "skip checksums calculation and override all files during the initial sync", ) + pullCmd.PersistentFlags().StringVarP( + &FlagDefaultFileMode, "default-file-mode", "m", "0664", "default mode to use for creating local file") + pullCmd.PersistentFlags().StringVarP( + &FlagS3Endpoint, "s3-endpoint", "", "", "override endpoint to use for remote object store (e.g. minio)") rootCmd.AddCommand(pullCmd) rootCmd.Execute() diff --git a/pkg/sync/pull.go b/pkg/sync/pull.go index 771251b..c58b439 100644 --- a/pkg/sync/pull.go +++ b/pkg/sync/pull.go @@ -5,7 +5,6 @@ import ( "encoding/hex" "fmt" "io" - "io/ioutil" "os" "path/filepath" "strings" @@ -100,10 +99,13 @@ func uidFromLocalPath(localPath string) (string, error) { } type Puller struct { - RemoteUri string - LocalDir string + RemoteUri string + LocalDir string + DisableSSL bool + S3Endpoint string workingDir string + defaultMode os.FileMode exclude []string workerCnt int uidCache map[string]string @@ -151,9 +153,11 @@ func (self *Puller) downloadHandler(task DownloadTask, downloader GenericDownloa } // create file - tmpfile, err := ioutil.TempFile(self.workingDir, filepath.Base(task.LocalPath)) + tmpfileName := fmt.Sprintf("%x", md5.Sum([]byte(task.LocalPath))) + tmpfilePath := filepath.Join(self.workingDir, tmpfileName) + tmpfile, err := os.OpenFile(tmpfilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, self.defaultMode) if err != nil { - self.errMsgQueue <- fmt.Sprintf("Failed to create file for download: %v", err) + self.errMsgQueue <- fmt.Sprintf("Failed to create temp file for download: %v", err) return } defer tmpfile.Close() @@ -164,7 +168,7 @@ func (self *Puller) downloadHandler(task DownloadTask, downloader GenericDownloa }) // use rename to make file update atomic - err = os.Rename(tmpfile.Name(), task.LocalPath) + err = os.Rename(tmpfilePath, task.LocalPath) if err != nil { self.errMsgQueue <- fmt.Sprintf("Failed to replace file %s for download: %v", task.LocalPath, err) return @@ -306,7 +310,16 @@ func (self *Puller) Pull() string { } } - svc := s3.New(sess, aws.NewConfig().WithRegion(region)) + s3Config := &aws.Config{Region: aws.String(region)} + if self.DisableSSL { + s3Config.DisableSSL = aws.Bool(true) + } + if self.S3Endpoint != "" { + s3Config.Endpoint = aws.String(self.S3Endpoint) + s3Config.S3ForcePathStyle = aws.Bool(true) + } + svc := s3.New(sess, s3Config) + downloader := s3manager.NewDownloaderWithClient(svc) if err := self.SetupWorkingDir(); err != nil { @@ -446,17 +459,23 @@ func (self *Puller) PopulateChecksum() { } } +func (self *Puller) SetDefaultFileMode(mode os.FileMode) { + self.defaultMode = mode +} + func NewPuller(remoteUri string, localDir string) (*Puller, error) { if _, err := os.Stat(localDir); os.IsNotExist(err) { return nil, fmt.Errorf("local directory `%s` does not exist: %v", localDir, err) } return &Puller{ - RemoteUri: remoteUri, - LocalDir: localDir, - workingDir: filepath.Join(localDir, ".objinsync"), - workerCnt: 5, - uidCache: map[string]string{}, - uidLock: &sync.Mutex{}, + RemoteUri: remoteUri, + LocalDir: localDir, + DisableSSL: false, + workingDir: filepath.Join(localDir, ".objinsync"), + defaultMode: 0664, + workerCnt: 5, + uidCache: map[string]string{}, + uidLock: &sync.Mutex{}, }, nil }