Skip to content

Commit

Permalink
Update tool
Browse files Browse the repository at this point in the history
  • Loading branch information
fasmat committed Mar 12, 2024
1 parent 71f64a7 commit a666b7e
Show file tree
Hide file tree
Showing 4 changed files with 301 additions and 73 deletions.
17 changes: 17 additions & 0 deletions cmd/merge-nodes/internal/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package internal

import (
"errors"
"fmt"
)

var ErrSupervisedNode = errors.New("merging of supervised smeshing nodes is not supported")

type ErrInvalidSchemaVersion struct {
Expected int
Actual int
}

func (e ErrInvalidSchemaVersion) Error() string {
return fmt.Sprintf("invalid schema version: expected %d got %d", e.Expected, e.Actual)
}
137 changes: 108 additions & 29 deletions cmd/merge-nodes/internal/merge_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@ package internal

import (
"context"
"encoding/hex"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"slices"

"github.com/natefinch/atomic"
"go.uber.org/zap"

"github.com/spacemeshos/go-spacemesh/signing"
"github.com/spacemeshos/go-spacemesh/sql"
"github.com/spacemeshos/go-spacemesh/sql/localsql"
)
Expand All @@ -22,18 +24,11 @@ const (
supervisedIDKeyFileName = "local.key"
)

type ErrInvalidSchemaVersion struct {
Expected int
Actual int
}

func (e ErrInvalidSchemaVersion) Error() string {
return fmt.Sprintf("invalid schema version: expected %d got %d", e.Expected, e.Actual)
}

func MergeDBs(ctx context.Context, dbLog *zap.Logger, name, from, to string) error {
func MergeDBs(ctx context.Context, dbLog *zap.Logger, from, to string) error {
// Open the target database
dstDB, err := openDB(dbLog, to)
var dstDB *localsql.Database
var err error
dstDB, err = openDB(dbLog, to)
var schemaErr ErrInvalidSchemaVersion
switch {
case errors.As(err, &schemaErr):
Expand All @@ -42,8 +37,44 @@ func MergeDBs(ctx context.Context, dbLog *zap.Logger, name, from, to string) err
zap.Int("expected version", schemaErr.Expected),
)
return err
case errors.Is(err, fs.ErrNotExist):
// target database does not exist, create it
dbLog.Info("target database does not exist, creating it", zap.String("path", to))
if err := os.MkdirAll(to, 0o700); err != nil {
return fmt.Errorf("create target directory: %w", err)
}
if err := os.MkdirAll(filepath.Join(to, keyDir), 0o700); err != nil {
return fmt.Errorf("create target key directory: %w", err)
}

dstDB, err = localsql.Open("file:"+filepath.Join(to, localDbFile),
sql.WithLogger(dbLog),
)
if err != nil {
return err
}
// -------------
// this code should not be needed but without it, Test_MergeDBs_Successful_Empty_Dir fails on INSERT initial_post
// with sqlite.SQLITE_CONSTRAINT_PRIMARYKEY error
if err := dstDB.Close(); err != nil {
return fmt.Errorf("close target database: %w", err)
}
dstDB, err = localsql.Open("file:"+filepath.Join(to, localDbFile),
sql.WithLogger(dbLog),
)
if err != nil {
return err
}
// -------------
case err != nil:
return err
default:
// target database exists, check if there is at least one key in the target key directory
// not named supervisedIDKeyFileName
if err := checkIdentities(dbLog, to); err != nil {
dbLog.Error("target appears to be a supervised node - only merging of remote smeshers is supported")
return err
}
}
defer dstDB.Close()

Expand All @@ -63,17 +94,55 @@ func MergeDBs(ctx context.Context, dbLog *zap.Logger, name, from, to string) err
return fmt.Errorf("close source database: %w", err)
}

dstKey := filepath.Join(to, keyDir, name+".key")
if _, err := os.Stat(dstKey); err == nil {
return fmt.Errorf("destination key file %s: %w", dstKey, fs.ErrExist)
if err := checkIdentities(dbLog, from); err != nil {
dbLog.Error("source appears to be a supervised node - only merging of remote smeshers is supported")
return err
}

srcKey := filepath.Join(from, keyDir, supervisedIDKeyFileName)
f, err := os.Open(srcKey)
// copy files from `from` to `to`
dir := filepath.Join(from, keyDir)
err = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return fmt.Errorf("failed to walk directory at %s: %w", path, err)
}

// skip subdirectories and files in them
if d.IsDir() && path != dir {
return fs.SkipDir
}

// skip files that are not identity files
if filepath.Ext(path) != ".key" {
return nil
}

signer, err := signing.NewEdSigner(
signing.FromFile(path),
)
if err != nil {
return fmt.Errorf("not a valid key file %s: %w", d.Name(), err)
}

dstPath := filepath.Join(to, keyDir, d.Name())
if _, err := os.Stat(dstPath); err == nil {
return fmt.Errorf("identity file %s already exists: %w", d.Name(), fs.ErrExist)
}

dst := make([]byte, hex.EncodedLen(len(signer.PrivateKey())))
hex.Encode(dst, signer.PrivateKey())
err = os.WriteFile(dstPath, dst, 0o600)
if err != nil {
return fmt.Errorf("failed to write identity file: %w", err)
}

dbLog.Info("copied identity",
zap.String("name", d.Name()),
)
return nil
})
if err != nil {
return fmt.Errorf("open source key file %s: %w", srcKey, err)
return err
}
defer f.Close()

dbLog.Info("merging databases", zap.String("from", from), zap.String("to", to))
err = dstDB.WithTx(ctx, func(tx *sql.Tx) error {
Expand All @@ -95,18 +164,12 @@ func MergeDBs(ctx context.Context, dbLog *zap.Logger, name, from, to string) err
if _, err := tx.Exec("INSERT INTO nipost SELECT * FROM srcDB.nipost;", nil, nil); err != nil {
return fmt.Errorf("merge nipost: %w", err)
}
if err := atomic.WriteFile(dstKey, f); err != nil {
return fmt.Errorf("write destination key file %s: %w", dstKey, err)
}
return nil
})
if err != nil {
return fmt.Errorf("start transaction: %w", err)
}

if err := f.Close(); err != nil {
return fmt.Errorf("close source key file %s: %w", srcKey, err)
}
if err := dstDB.Close(); err != nil {
return fmt.Errorf("close target database: %w", err)
}
Expand All @@ -116,15 +179,15 @@ func MergeDBs(ctx context.Context, dbLog *zap.Logger, name, from, to string) err
func openDB(dbLog *zap.Logger, path string) (*localsql.Database, error) {
dbPath := filepath.Join(path, localDbFile)
if _, err := os.Stat(dbPath); err != nil {
return nil, fmt.Errorf("open source database %s: %w", dbPath, err)
return nil, fmt.Errorf("open database %s: %w", dbPath, err)
}

migrations, err := sql.LocalMigrations()
if err != nil {
return nil, fmt.Errorf("get local migrations: %w", err)
}

srcDB, err := localsql.Open("file:"+dbPath,
db, err := localsql.Open("file:"+dbPath,
sql.WithLogger(dbLog),
sql.WithMigrations([]sql.Migration{}), // do not migrate database when opening
)
Expand All @@ -134,7 +197,7 @@ func openDB(dbLog *zap.Logger, path string) (*localsql.Database, error) {

// check if the source database has the right schema
var version int
_, err = srcDB.Exec("PRAGMA user_version;", nil, func(stmt *sql.Statement) bool {
_, err = db.Exec("PRAGMA user_version;", nil, func(stmt *sql.Statement) bool {
version = stmt.ColumnInt(0)
return true
})
Expand All @@ -147,5 +210,21 @@ func openDB(dbLog *zap.Logger, path string) (*localsql.Database, error) {
Actual: version,
}
}
return srcDB, nil
return db, nil
}

func checkIdentities(dbLog *zap.Logger, path string) error {
dir := filepath.Join(path, keyDir)
if err := os.MkdirAll(dir, 0o700); err != nil {
return err
}

files, err := os.ReadDir(dir)
if err != nil {
return fmt.Errorf("read target key directory: %w", err)
}
if slices.ContainsFunc(files, func(e fs.DirEntry) bool { return e.Name() == supervisedIDKeyFileName }) {
return ErrSupervisedNode
}
return nil
}
Loading

0 comments on commit a666b7e

Please sign in to comment.