diff --git a/http/client.go b/http/client.go index 2906ec3..c8d5280 100644 --- a/http/client.go +++ b/http/client.go @@ -27,7 +27,7 @@ func NewClient() *Client { } // Stream returns a snapshot and continuous stream of WAL updates. -func (c *Client) Stream(ctx context.Context, rawurl string, posMap map[uint32]litefs.Pos) (io.ReadCloser, error) { +func (c *Client) Stream(ctx context.Context, rawurl string, nodeID string, posMap map[uint32]litefs.Pos) (io.ReadCloser, error) { u, err := url.Parse(rawurl) if err != nil { return nil, fmt.Errorf("invalid client URL: %w", err) @@ -55,6 +55,8 @@ func (c *Client) Stream(ctx context.Context, rawurl string, posMap map[uint32]li } req = req.WithContext(ctx) + req.Header.Set("Litefs-Id", nodeID) + resp, err := c.HTTPClient.Do(req) if err != nil { return nil, err diff --git a/http/server.go b/http/server.go index 7fa0688..ba0042a 100644 --- a/http/server.go +++ b/http/server.go @@ -145,6 +145,12 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) { } func (s *Server) handlePostStream(w http.ResponseWriter, r *http.Request) { + // Prevent nodes from connecting to themselves. + if id := r.Header.Get("Litefs-Id"); id == s.store.ID() { + Error(w, r, fmt.Errorf("cannot connect to self"), http.StatusBadRequest) + return + } + log.Printf("stream connected") defer log.Printf("stream disconnected") diff --git a/litefs.go b/litefs.go index 9146413..39462f0 100644 --- a/litefs.go +++ b/litefs.go @@ -101,7 +101,7 @@ func (p Pos) IsZero() bool { // Client represents a client for connecting to other LiteFS nodes. type Client interface { // Stream starts a long-running connection to stream changes from another node. - Stream(ctx context.Context, rawurl string, posMap map[uint32]Pos) (io.ReadCloser, error) + Stream(ctx context.Context, rawurl string, id string, posMap map[uint32]Pos) (io.ReadCloser, error) } type StreamFrameType uint32 diff --git a/store.go b/store.go index 4e16add..3a8e724 100644 --- a/store.go +++ b/store.go @@ -1,7 +1,9 @@ package litefs import ( + "bytes" "context" + "crypto/rand" "encoding/json" "expvar" "fmt" @@ -19,6 +21,9 @@ import ( "golang.org/x/sync/errgroup" ) +// IDLength is the length of a node ID, in bytes. +const IDLength = 24 + // Default store settings. const ( DefaultRetentionDuration = 1 * time.Minute @@ -30,6 +35,7 @@ type Store struct { mu sync.Mutex path string + id string // unique node id nextDBID uint32 dbsByID map[uint32]*DB dbsByName map[string]*DB @@ -80,9 +86,20 @@ func NewStore(path string, candidate bool) *Store { // Path returns underlying data directory. func (s *Store) Path() string { return s.path } -// DBDir returns the folder that stores a single database. -func (s *Store) DBDir(id uint32) string { - return filepath.Join(s.path, ltx.FormatDBID(id)) +// DBDir returns the folder that stores all databases. +func (s *Store) DBDir() string { + return filepath.Join(s.path, "dbs") +} + +// DBPath returns the folder that stores a single database. +func (s *Store) DBPath(id uint32) string { + return filepath.Join(s.path, "dbs", ltx.FormatDBID(id)) +} + +// ID returns the unique identifier for this instance. Available after Open(). +// Persistent across restarts if underlying storage is persistent. +func (s *Store) ID() string { + return s.id } // Open initializes the store based on files in the data directory. @@ -91,6 +108,10 @@ func (s *Store) Open() error { return err } + if err := s.initID(); err != nil { + return fmt.Errorf("init node id: %w", err) + } + if err := s.openDatabases(); err != nil { return fmt.Errorf("open databases: %w", err) } @@ -111,10 +132,52 @@ func (s *Store) Open() error { return nil } +// initID initializes an identifier that is unique to this node. +func (s *Store) initID() error { + filename := filepath.Join(s.path, "id") + + // Read existing ID from file, if it exists. + if buf, err := os.ReadFile(filename); err != nil && !os.IsNotExist(err) { + return err + } else if err == nil { + s.id = string(bytes.TrimSpace(buf)) + return nil // existing ID + } + + // Generate a new node ID if file doesn't exist. + b := make([]byte, IDLength/2) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return fmt.Errorf("generate id: %w", err) + } + id := fmt.Sprintf("%x", b) + + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + + if _, err := f.Write([]byte(id + "\n")); err != nil { + return err + } else if err := f.Sync(); err != nil { + return err + } else if err := f.Close(); err != nil { + return err + } + + s.id = id + + return nil +} + func (s *Store) openDatabases() error { - f, err := os.Open(s.path) + if err := os.MkdirAll(s.DBDir(), 0777); err != nil { + return err + } + + f, err := os.Open(s.DBDir()) if err != nil { - return fmt.Errorf("open data dir: %w", err) + return fmt.Errorf("open databases dir: %w", err) } defer f.Close() @@ -140,7 +203,7 @@ func (s *Store) openDatabases() error { func (s *Store) openDatabase(id uint32) error { // Instantiate and open database. - db := NewDB(s, id, s.DBDir(id)) + db := NewDB(s, id, s.DBPath(id)) if err := db.Open(); err != nil { return err } @@ -236,20 +299,20 @@ func (s *Store) CreateDB(name string) (*DB, *os.File, error) { s.nextDBID++ // Generate database directory with name file & empty database file. - dbDir := s.DBDir(id) - if err := os.MkdirAll(dbDir, 0777); err != nil { + dbPath := s.DBPath(id) + if err := os.MkdirAll(dbPath, 0777); err != nil { return nil, nil, err - } else if err := os.WriteFile(filepath.Join(dbDir, "name"), []byte(name), 0666); err != nil { + } else if err := os.WriteFile(filepath.Join(dbPath, "name"), []byte(name), 0666); err != nil { return nil, nil, err } - f, err := os.OpenFile(filepath.Join(dbDir, "database"), os.O_RDWR|os.O_CREATE|os.O_EXCL|os.O_TRUNC, 0666) + f, err := os.OpenFile(filepath.Join(dbPath, "database"), os.O_RDWR|os.O_CREATE|os.O_EXCL|os.O_TRUNC, 0666) if err != nil { return nil, nil, err } // Create new database instance and add to maps. - db := NewDB(s, id, dbDir) + db := NewDB(s, id, dbPath) if err := db.Open(); err != nil { f.Close() return nil, nil, err @@ -280,19 +343,19 @@ func (s *Store) ForceCreateDB(id uint32, name string) (*DB, error) { // TODO: Handle conflict if another database exists with the same name. // Generate database directory with name file & empty database file. - dbDir := s.DBDir(id) - if err := os.MkdirAll(dbDir, 0777); err != nil { + dbPath := s.DBPath(id) + if err := os.MkdirAll(dbPath, 0777); err != nil { return nil, err - } else if err := os.WriteFile(filepath.Join(dbDir, "name"), []byte(name), 0666); err != nil { + } else if err := os.WriteFile(filepath.Join(dbPath, "name"), []byte(name), 0666); err != nil { return nil, err } - if err := os.WriteFile(filepath.Join(dbDir, "database"), nil, 0666); err != nil { + if err := os.WriteFile(filepath.Join(dbPath, "database"), nil, 0666); err != nil { return nil, err } // Create new database instance and add to maps. - db := NewDB(s, id, dbDir) + db := NewDB(s, id, dbPath) if err := db.Open(); err != nil { return nil, err } @@ -493,7 +556,7 @@ func (s *Store) monitorLeaseAsReplica(ctx context.Context, info *PrimaryInfo) er }() posMap := s.PosMap() - st, err := s.Client.Stream(ctx, info.AdvertiseURL, posMap) + st, err := s.Client.Stream(ctx, info.AdvertiseURL, s.id, posMap) if err != nil { return fmt.Errorf("connect to primary: %s ('%s')", err, info.AdvertiseURL) } diff --git a/store_test.go b/store_test.go index 72f1310..fef2173 100644 --- a/store_test.go +++ b/store_test.go @@ -33,13 +33,13 @@ func TestStore_CreateDB(t *testing.T) { if got, want := db.TXID(), uint64(0); !reflect.DeepEqual(got, want) { t.Fatalf("TXID=%#v, want %#v", got, want) } - if got, want := db.Path(), filepath.Join(store.Path(), "00000001"); got != want { + if got, want := db.Path(), filepath.Join(store.Path(), "dbs", "00000001"); got != want { t.Fatalf("Path=%s, want %s", got, want) } - if got, want := db.LTXDir(), filepath.Join(store.Path(), "00000001", "ltx"); got != want { + if got, want := db.LTXDir(), filepath.Join(store.Path(), "dbs", "00000001", "ltx"); got != want { t.Fatalf("LTXDir=%s, want %s", got, want) } - if got, want := db.LTXPath(1, 2), filepath.Join(store.Path(), "00000001", "ltx", "0000000000000001-0000000000000002.ltx"); got != want { + if got, want := db.LTXPath(1, 2), filepath.Join(store.Path(), "dbs", "00000001", "ltx", "0000000000000001-0000000000000002.ltx"); got != want { t.Fatalf("LTXPath=%s, want %s", got, want) } @@ -104,6 +104,35 @@ func TestPrimaryInfo_Clone(t *testing.T) { }) } +// Ensure store generates a unique ID that is persistent across restarts. +func TestStore_ID(t *testing.T) { + store := newStore(t) + if err := store.Open(); err != nil { + t.Fatal(err) + } else if err := store.Close(); err != nil { + t.Fatal(err) + } + + id := store.ID() + if id == "" { + t.Fatal("expected id") + } else if got, want := len(id), litefs.IDLength; got != want { + t.Fatalf("len(id)=%d, want %d", got, want) + } + + // Reopen as a new instance. + store = litefs.NewStore(store.Path(), true) + if err := store.Open(); err != nil { + t.Fatal(err) + } + defer store.Close() + + // Ensure ID is the same. + if got, want := store.ID(), id; got != want { + t.Fatalf("id=%q, want %q", got, want) + } +} + // newStore returns a new instance of a Store on a temporary directory. // This store will automatically close when the test ends. func newStore(tb testing.TB) *litefs.Store { diff --git a/testdata/store/open-name-only/00000001/database b/testdata/store/open-name-only/dbs/00000001/database similarity index 100% rename from testdata/store/open-name-only/00000001/database rename to testdata/store/open-name-only/dbs/00000001/database diff --git a/testdata/store/open-name-only/00000001/name b/testdata/store/open-name-only/dbs/00000001/name similarity index 100% rename from testdata/store/open-name-only/00000001/name rename to testdata/store/open-name-only/dbs/00000001/name