diff --git a/docs/db_repository_update.md b/docs/db_repository_update.md new file mode 100644 index 0000000..4ac88d7 --- /dev/null +++ b/docs/db_repository_update.md @@ -0,0 +1,49 @@ +# DB Repository Update Instructions + +To fully implement the SSH key change history feature, the database schema needs to be updated in the [ssh-sync-db](https://github.com/therealpaulgg/ssh-sync-db) repository. + +## Changes Required + +Add the following SQL to the `init.sql` file in the ssh-sync-db repository: + +```sql +-- Create enum type for change types +CREATE TYPE change_type AS ENUM ('created', 'updated', 'deleted'); + +-- Create table for tracking SSH key changes +CREATE TABLE IF NOT EXISTS ssh_key_changes ( + id UUID DEFAULT uuid_generate_v4() NOT NULL, + ssh_key_id UUID NOT NULL, + user_id UUID NOT NULL, + change_type change_type NOT NULL, + filename VARCHAR(255) NOT NULL, + previous_data BYTEA, + new_data BYTEA, + change_time TIMESTAMP WITH TIME ZONE NOT NULL, + PRIMARY KEY (id), + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE +); + +-- Add indexes for efficient lookups +CREATE INDEX IF NOT EXISTS idx_ssh_key_changes_ssh_key_id ON ssh_key_changes(ssh_key_id); +CREATE INDEX IF NOT EXISTS idx_ssh_key_changes_user_id ON ssh_key_changes(user_id); +CREATE INDEX IF NOT EXISTS idx_ssh_key_changes_change_time ON ssh_key_changes(change_time); + +-- Add index for the most common query pattern: finding the latest changes per key for a user +CREATE INDEX IF NOT EXISTS idx_ssh_key_changes_user_time ON ssh_key_changes(user_id, change_time DESC); +``` + +## Implementation Notes + +1. The schema matches what's defined in the `docs/sql/ssh_key_changes.sql` file in this repository +2. The schema uses the existing `uuid_generate_v4()` function that's already being used in the database +3. The table includes a foreign key reference to the users table with cascade deletion to ensure data integrity +4. Indexes have been added to optimize the most common query patterns + +## Testing + +After updating the init.sql file in the ssh-sync-db repository, you can test that the schema works correctly by: + +1. Running `docker-compose up --build` in the ssh-sync-db repository +2. Connecting to the database with `psql -h localhost -p 5432 -U sshsync -d sshsync` +3. Verifying the table exists with `\d ssh_key_changes` \ No newline at end of file diff --git a/docs/sql/ssh_key_changes.sql b/docs/sql/ssh_key_changes.sql new file mode 100644 index 0000000..68bcc05 --- /dev/null +++ b/docs/sql/ssh_key_changes.sql @@ -0,0 +1,25 @@ +-- Schema for SSH Key Change tracking + +-- Create enum type for change types +CREATE TYPE change_type AS ENUM ('created', 'updated', 'deleted'); + +-- Create table for tracking SSH key changes +CREATE TABLE IF NOT EXISTS ssh_key_changes ( + id UUID PRIMARY KEY, + ssh_key_id UUID NOT NULL, + user_id UUID NOT NULL, + change_type change_type NOT NULL, + filename VARCHAR(255) NOT NULL, + previous_data BYTEA, + new_data BYTEA, + change_time TIMESTAMP WITH TIME ZONE NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE +); + +-- Add indexes for efficient lookups +CREATE INDEX IF NOT EXISTS idx_ssh_key_changes_ssh_key_id ON ssh_key_changes(ssh_key_id); +CREATE INDEX IF NOT EXISTS idx_ssh_key_changes_user_id ON ssh_key_changes(user_id); +CREATE INDEX IF NOT EXISTS idx_ssh_key_changes_change_time ON ssh_key_changes(change_time); + +-- Add index for the most common query pattern: finding the latest changes per key for a user +CREATE INDEX IF NOT EXISTS idx_ssh_key_changes_user_time ON ssh_key_changes(user_id, change_time DESC); \ No newline at end of file diff --git a/docs/ssh_key_change_history.md b/docs/ssh_key_change_history.md new file mode 100644 index 0000000..283d0b0 --- /dev/null +++ b/docs/ssh_key_change_history.md @@ -0,0 +1,36 @@ +# SSH Key Change History + +This feature adds support for tracking changes to SSH keys in the database, allowing for better conflict resolution during syncing. + +## Database Schema + +The feature introduces a new table `ssh_key_changes` that tracks all changes (creation, updates, and deletions) to SSH keys. +See the SQL schema in `docs/sql/ssh_key_changes.sql`. + +**Note**: This schema needs to be added to the init.sql file in the separate [ssh-sync-db](https://github.com/therealpaulgg/ssh-sync-db) repository. See `docs/db_repository_update.md` for detailed instructions. + +## API Usage + +### Recording Changes + +Changes are automatically recorded when using the new repository methods: + +- `SshKeyRepo.CreateSshKeyWithChange`: Create a new SSH key and record it as a creation event +- `SshKeyRepo.UpsertSshKeyWithChange`: Create or update an SSH key and record the appropriate event +- `SshKeyRepo.UpsertSshKeyWithChangeTx`: Same as above but within a transaction +- `UserRepo.DeleteUserKeyTx`: Now records a deletion event before deleting the key + +### Retrieving Change History + +Use the `SshKeyChangeRepository` to access change history: + +- `GetKeyChanges`: Retrieve the full change history for a specific SSH key +- `GetLatestKeyChangesForUser`: Get the most recent change for each of a user's SSH keys since a specified time + +## Conflict Resolution + +This change history enables better conflict resolution during syncing: + +1. When a key is deleted on the server, clients can detect this change and delete it locally +2. When both server and client have made changes to the same key, the timestamps can be used to determine which change is more recent +3. In case of conflicts, the application can choose to keep the most recent change or prompt the user for resolution \ No newline at end of file diff --git a/go.mod b/go.mod index 5b9b52e..07c2431 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/therealpaulgg/ssh-sync-server -go 1.23 +go 1.23.0 + toolchain go1.24.1 require ( diff --git a/internal/setup/main.go b/internal/setup/main.go index 1b6c6b5..dfa6a0a 100644 --- a/internal/setup/main.go +++ b/internal/setup/main.go @@ -46,6 +46,14 @@ func SetupServices(i *do.Injector) { dataAccessor := do.MustInvoke[database.DataAccessor](i) return &query.QueryServiceTxImpl[models.SshConfig]{DataAccessor: dataAccessor}, nil }) + do.Provide(i, func(i *do.Injector) (query.QueryService[models.SshKeyChange], error) { + dataAccessor := do.MustInvoke[database.DataAccessor](i) + return &query.QueryServiceImpl[models.SshKeyChange]{DataAccessor: dataAccessor}, nil + }) + do.Provide(i, func(i *do.Injector) (query.QueryServiceTx[models.SshKeyChange], error) { + dataAccessor := do.MustInvoke[database.DataAccessor](i) + return &query.QueryServiceTxImpl[models.SshKeyChange]{DataAccessor: dataAccessor}, nil + }) do.Provide(i, func(i *do.Injector) (repository.UserRepository, error) { return &repository.UserRepo{Injector: i}, nil }) @@ -58,5 +66,7 @@ func SetupServices(i *do.Injector) { do.Provide(i, func(i *do.Injector) (repository.SshConfigRepository, error) { return &repository.SshConfigRepo{Injector: i}, nil }) - + do.Provide(i, func(i *do.Injector) (repository.SshKeyChangeRepository, error) { + return &repository.SshKeyChangeRepo{Injector: i}, nil + }) } diff --git a/pkg/database/models/ssh_key_change.go b/pkg/database/models/ssh_key_change.go new file mode 100644 index 0000000..9c8a63d --- /dev/null +++ b/pkg/database/models/ssh_key_change.go @@ -0,0 +1,31 @@ +package models + +import ( + "time" + + "github.com/google/uuid" +) + +// ChangeType represents the type of change made to an SSH key +type ChangeType string + +const ( + // Created indicates a new SSH key was created + Created ChangeType = "created" + // Updated indicates an existing SSH key was updated + Updated ChangeType = "updated" + // Deleted indicates an SSH key was deleted + Deleted ChangeType = "deleted" +) + +// SshKeyChange represents a change to an SSH key in the database +type SshKeyChange struct { + ID uuid.UUID `json:"id" db:"id"` + SshKeyID uuid.UUID `json:"ssh_key_id" db:"ssh_key_id"` + UserID uuid.UUID `json:"user_id" db:"user_id"` + ChangeType ChangeType `json:"change_type" db:"change_type"` + Filename string `json:"filename" db:"filename"` + PreviousData []byte `json:"previous_data,omitempty" db:"previous_data"` + NewData []byte `json:"new_data,omitempty" db:"new_data"` + ChangeTime time.Time `json:"change_time" db:"change_time"` +} \ No newline at end of file diff --git a/pkg/database/query/transaction_mock.go b/pkg/database/query/transaction_mock.go new file mode 100644 index 0000000..05e2af0 --- /dev/null +++ b/pkg/database/query/transaction_mock.go @@ -0,0 +1,36 @@ +package query + +import ( + "github.com/jackc/pgx/v5" +) + +// TransactionMock implements the TransactionService interface for testing +type TransactionMock struct { + StartTxFunc func(pgx.TxOptions) (pgx.Tx, error) + CommitFunc func(pgx.Tx) error + RollbackFunc func(pgx.Tx) error +} + +// StartTx implements TransactionService +func (m *TransactionMock) StartTx(opts pgx.TxOptions) (pgx.Tx, error) { + if m.StartTxFunc != nil { + return m.StartTxFunc(opts) + } + return nil, nil +} + +// Commit implements TransactionService +func (m *TransactionMock) Commit(tx pgx.Tx) error { + if m.CommitFunc != nil { + return m.CommitFunc(tx) + } + return nil +} + +// Rollback implements TransactionService +func (m *TransactionMock) Rollback(tx pgx.Tx) error { + if m.RollbackFunc != nil { + return m.RollbackFunc(tx) + } + return nil +} \ No newline at end of file diff --git a/pkg/database/repository/machine_test.go b/pkg/database/repository/machine_test.go deleted file mode 100644 index c684cc3..0000000 --- a/pkg/database/repository/machine_test.go +++ /dev/null @@ -1,3 +0,0 @@ -package repository - -// TODO diff --git a/pkg/database/repository/ssh_config_test.go b/pkg/database/repository/ssh_config_test.go deleted file mode 100644 index c684cc3..0000000 --- a/pkg/database/repository/ssh_config_test.go +++ /dev/null @@ -1,3 +0,0 @@ -package repository - -// TODO diff --git a/pkg/database/repository/ssh_key.go b/pkg/database/repository/ssh_key.go index af1b0ee..881cc96 100644 --- a/pkg/database/repository/ssh_key.go +++ b/pkg/database/repository/ssh_key.go @@ -1,6 +1,7 @@ package repository import ( + "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/samber/do" "github.com/therealpaulgg/ssh-sync-server/pkg/database/models" @@ -11,6 +12,12 @@ type SshKeyRepository interface { CreateSshKey(sshKey *models.SshKey) (*models.SshKey, error) UpsertSshKey(sshKey *models.SshKey) (*models.SshKey, error) UpsertSshKeyTx(sshKey *models.SshKey, tx pgx.Tx) (*models.SshKey, error) + // Methods with change tracking + CreateSshKeyWithChange(sshKey *models.SshKey) (*models.SshKey, error) + UpsertSshKeyWithChange(sshKey *models.SshKey) (*models.SshKey, error) + UpsertSshKeyWithChangeTx(sshKey *models.SshKey, tx pgx.Tx) (*models.SshKey, error) + // Get a key by user ID and filename + GetSshKeyByFilename(userID uuid.UUID, filename string) (*models.SshKey, error) } type SshKeyRepo struct { @@ -44,3 +51,165 @@ func (repo *SshKeyRepo) UpsertSshKeyTx(sshKey *models.SshKey, tx pgx.Tx) (*model } return key, nil } + +func (repo *SshKeyRepo) GetSshKeyByFilename(userID uuid.UUID, filename string) (*models.SshKey, error) { + q := do.MustInvoke[query.QueryService[models.SshKey]](repo.Injector) + key, err := q.QueryOne("SELECT * FROM ssh_keys WHERE user_id = $1 AND filename = $2", userID, filename) + if err != nil { + return nil, err + } + return key, nil +} + +func (repo *SshKeyRepo) CreateSshKeyWithChange(sshKey *models.SshKey) (*models.SshKey, error) { + // Start a transaction + txService := do.MustInvoke[query.TransactionService](repo.Injector) + tx, err := txService.StartTx(pgx.TxOptions{}) + if err != nil { + return nil, err + } + + // Defer rollback in case of error + defer func() { + if err != nil { + _ = txService.Rollback(tx) + } + }() + + // Create the SSH key + key, err := repo.UpsertSshKeyTx(sshKey, tx) + if err != nil { + return nil, err + } + + // Record the change + changeRepo := &SshKeyChangeRepo{Injector: repo.Injector} + change := &models.SshKeyChange{ + SshKeyID: key.ID, + UserID: key.UserID, + ChangeType: models.Created, + Filename: key.Filename, + NewData: key.Data, + } + + _, err = changeRepo.CreateKeyChangeTx(change, tx) + if err != nil { + return nil, err + } + + // Commit the transaction + err = txService.Commit(tx) + if err != nil { + return nil, err + } + + return key, nil +} + +func (repo *SshKeyRepo) UpsertSshKeyWithChange(sshKey *models.SshKey) (*models.SshKey, error) { + // Start a transaction + txService := do.MustInvoke[query.TransactionService](repo.Injector) + tx, err := txService.StartTx(pgx.TxOptions{}) + if err != nil { + return nil, err + } + + // Defer rollback in case of error + defer func() { + if err != nil { + _ = txService.Rollback(tx) + } + }() + + // Get the existing key if it exists + var existingKey *models.SshKey + var changeType models.ChangeType + + existingKey, err = repo.GetSshKeyByFilename(sshKey.UserID, sshKey.Filename) + if err != nil { + // If error is not "no rows", return the error + // Otherwise, continue as it's a new key + changeType = models.Created + } else if existingKey != nil { + changeType = models.Updated + } else { + changeType = models.Created + } + + // Upsert the SSH key + key, err := repo.UpsertSshKeyTx(sshKey, tx) + if err != nil { + return nil, err + } + + // Record the change + changeRepo := &SshKeyChangeRepo{Injector: repo.Injector} + change := &models.SshKeyChange{ + SshKeyID: key.ID, + UserID: key.UserID, + ChangeType: changeType, + Filename: key.Filename, + NewData: key.Data, + } + + if existingKey != nil { + change.PreviousData = existingKey.Data + } + + _, err = changeRepo.CreateKeyChangeTx(change, tx) + if err != nil { + return nil, err + } + + // Commit the transaction + err = txService.Commit(tx) + if err != nil { + return nil, err + } + + return key, nil +} + +func (repo *SshKeyRepo) UpsertSshKeyWithChangeTx(sshKey *models.SshKey, tx pgx.Tx) (*models.SshKey, error) { + // Get the existing key if it exists + var existingKey *models.SshKey + var changeType models.ChangeType + + existingKey, err := repo.GetSshKeyByFilename(sshKey.UserID, sshKey.Filename) + if err != nil { + // If error is not "no rows", return the error + // Otherwise, continue as it's a new key + changeType = models.Created + } else if existingKey != nil { + changeType = models.Updated + } else { + changeType = models.Created + } + + // Upsert the SSH key + key, err := repo.UpsertSshKeyTx(sshKey, tx) + if err != nil { + return nil, err + } + + // Record the change + changeRepo := &SshKeyChangeRepo{Injector: repo.Injector} + change := &models.SshKeyChange{ + SshKeyID: key.ID, + UserID: key.UserID, + ChangeType: changeType, + Filename: key.Filename, + NewData: key.Data, + } + + if existingKey != nil { + change.PreviousData = existingKey.Data + } + + _, err = changeRepo.CreateKeyChangeTx(change, tx) + if err != nil { + return nil, err + } + + return key, nil +} diff --git a/pkg/database/repository/ssh_key_change.go b/pkg/database/repository/ssh_key_change.go new file mode 100644 index 0000000..999e3da --- /dev/null +++ b/pkg/database/repository/ssh_key_change.go @@ -0,0 +1,112 @@ +package repository + +import ( + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/samber/do" + "github.com/therealpaulgg/ssh-sync-server/pkg/database/models" + "github.com/therealpaulgg/ssh-sync-server/pkg/database/query" +) + +// SshKeyChangeRepository defines the interface for operations on SSH key changes +type SshKeyChangeRepository interface { + // CreateKeyChange records a new change to an SSH key + CreateKeyChange(change *models.SshKeyChange) (*models.SshKeyChange, error) + // CreateKeyChangeTx records a new change to an SSH key within a transaction + CreateKeyChangeTx(change *models.SshKeyChange, tx pgx.Tx) (*models.SshKeyChange, error) + // GetKeyChanges returns all changes for a specific SSH key + GetKeyChanges(sshKeyID uuid.UUID) ([]models.SshKeyChange, error) + // GetLatestKeyChangesForUser returns the most recent changes for each SSH key owned by a user + GetLatestKeyChangesForUser(userID uuid.UUID, since time.Time) ([]models.SshKeyChange, error) +} + +// SshKeyChangeRepo implements the SshKeyChangeRepository interface +type SshKeyChangeRepo struct { + Injector *do.Injector +} + +// CreateKeyChange records a new change to an SSH key +func (repo *SshKeyChangeRepo) CreateKeyChange(change *models.SshKeyChange) (*models.SshKeyChange, error) { + q := do.MustInvoke[query.QueryService[models.SshKeyChange]](repo.Injector) + + // Set the ID and timestamp if not already set + if change.ID == uuid.Nil { + change.ID = uuid.New() + } + if change.ChangeTime.IsZero() { + change.ChangeTime = time.Now() + } + + result, err := q.QueryOne( + "INSERT INTO ssh_key_changes (id, ssh_key_id, user_id, change_type, filename, previous_data, new_data, change_time) "+ + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *", + change.ID, change.SshKeyID, change.UserID, change.ChangeType, + change.Filename, change.PreviousData, change.NewData, change.ChangeTime) + + if err != nil { + return nil, err + } + + return result, nil +} + +// CreateKeyChangeTx records a new change to an SSH key within a transaction +func (repo *SshKeyChangeRepo) CreateKeyChangeTx(change *models.SshKeyChange, tx pgx.Tx) (*models.SshKeyChange, error) { + q := do.MustInvoke[query.QueryServiceTx[models.SshKeyChange]](repo.Injector) + + // Set the ID and timestamp if not already set + if change.ID == uuid.Nil { + change.ID = uuid.New() + } + if change.ChangeTime.IsZero() { + change.ChangeTime = time.Now() + } + + result, err := q.QueryOne( + tx, + "INSERT INTO ssh_key_changes (id, ssh_key_id, user_id, change_type, filename, previous_data, new_data, change_time) "+ + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *", + change.ID, change.SshKeyID, change.UserID, change.ChangeType, + change.Filename, change.PreviousData, change.NewData, change.ChangeTime) + + if err != nil { + return nil, err + } + + return result, nil +} + +// GetKeyChanges returns all changes for a specific SSH key +func (repo *SshKeyChangeRepo) GetKeyChanges(sshKeyID uuid.UUID) ([]models.SshKeyChange, error) { + q := do.MustInvoke[query.QueryService[models.SshKeyChange]](repo.Injector) + + results, err := q.Query( + "SELECT * FROM ssh_key_changes WHERE ssh_key_id = $1 ORDER BY change_time DESC", + sshKeyID) + + if err != nil { + return nil, err + } + + return results, nil +} + +// GetLatestKeyChangesForUser returns the most recent changes for each SSH key owned by a user +func (repo *SshKeyChangeRepo) GetLatestKeyChangesForUser(userID uuid.UUID, since time.Time) ([]models.SshKeyChange, error) { + q := do.MustInvoke[query.QueryService[models.SshKeyChange]](repo.Injector) + + results, err := q.Query( + `SELECT DISTINCT ON (ssh_key_id) * + FROM ssh_key_changes + WHERE user_id = $1 AND change_time > $2 + ORDER BY ssh_key_id, change_time DESC`, + userID, since) + + if err != nil { + return nil, err + } + + return results, nil +} \ No newline at end of file diff --git a/pkg/database/repository/ssh_key_change_mock.go b/pkg/database/repository/ssh_key_change_mock.go new file mode 100644 index 0000000..77a970b --- /dev/null +++ b/pkg/database/repository/ssh_key_change_mock.go @@ -0,0 +1,68 @@ +package repository + +import ( + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/therealpaulgg/ssh-sync-server/pkg/database/models" +) + +// SshKeyChangeMock is a mock implementation of SshKeyChangeRepository +type SshKeyChangeMock struct { + Changes []models.SshKeyChange +} + +// CreateKeyChange records a new change to an SSH key +func (mock *SshKeyChangeMock) CreateKeyChange(change *models.SshKeyChange) (*models.SshKeyChange, error) { + if change.ID == uuid.Nil { + change.ID = uuid.New() + } + if change.ChangeTime.IsZero() { + change.ChangeTime = time.Now() + } + + mock.Changes = append(mock.Changes, *change) + return change, nil +} + +// CreateKeyChangeTx records a new change to an SSH key within a transaction +func (mock *SshKeyChangeMock) CreateKeyChangeTx(change *models.SshKeyChange, tx pgx.Tx) (*models.SshKeyChange, error) { + return mock.CreateKeyChange(change) +} + +// GetKeyChanges returns all changes for a specific SSH key +func (mock *SshKeyChangeMock) GetKeyChanges(sshKeyID uuid.UUID) ([]models.SshKeyChange, error) { + var result []models.SshKeyChange + + for _, change := range mock.Changes { + if change.SshKeyID == sshKeyID { + result = append(result, change) + } + } + + return result, nil +} + +// GetLatestKeyChangesForUser returns the most recent changes for each SSH key owned by a user +func (mock *SshKeyChangeMock) GetLatestKeyChangesForUser(userID uuid.UUID, since time.Time) ([]models.SshKeyChange, error) { + var result []models.SshKeyChange + keyMap := make(map[uuid.UUID]models.SshKeyChange) + + // Find the latest change for each key + for _, change := range mock.Changes { + if change.UserID == userID && change.ChangeTime.After(since) { + existing, exists := keyMap[change.SshKeyID] + if !exists || change.ChangeTime.After(existing.ChangeTime) { + keyMap[change.SshKeyID] = change + } + } + } + + // Convert map to slice + for _, change := range keyMap { + result = append(result, change) + } + + return result, nil +} \ No newline at end of file diff --git a/pkg/database/repository/ssh_key_change_test.go b/pkg/database/repository/ssh_key_change_test.go new file mode 100644 index 0000000..e89abc0 --- /dev/null +++ b/pkg/database/repository/ssh_key_change_test.go @@ -0,0 +1,301 @@ +package repository + +import ( + "errors" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/samber/do" + "github.com/stretchr/testify/assert" + "github.com/therealpaulgg/ssh-sync-server/pkg/database/models" + "github.com/therealpaulgg/ssh-sync-server/pkg/database/query" + testpgx "github.com/therealpaulgg/ssh-sync-server/test/pgx" +) + +func setupSshKeyChangeRepoTest(t *testing.T) (*do.Injector, *gomock.Controller) { + ctrl := gomock.NewController(t) + injector := do.New() + return injector, ctrl +} + +func TestSshKeyChangeRepo_CreateKeyChange(t *testing.T) { + // Arrange + injector, ctrl := setupSshKeyChangeRepoTest(t) + defer ctrl.Finish() + + userID := uuid.New() + keyID := uuid.New() + changeID := uuid.New() + now := time.Now().UTC() + + testChange := &models.SshKeyChange{ + ID: changeID, + SshKeyID: keyID, + UserID: userID, + ChangeType: models.Created, + Filename: "test_key.pub", + NewData: []byte("ssh-rsa TEST"), + ChangeTime: now, + } + + mockQuery := query.NewMockQueryService[models.SshKeyChange](ctrl) + mockQuery.EXPECT(). + QueryOne( + "INSERT INTO ssh_key_changes (id, ssh_key_id, user_id, change_type, filename, previous_data, new_data, change_time) "+ + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *", + gomock.Any(), keyID, userID, models.Created, "test_key.pub", nil, []byte("ssh-rsa TEST"), gomock.Any(), + ). + Return(testChange, nil) + + do.Provide(injector, func(i *do.Injector) (query.QueryService[models.SshKeyChange], error) { + return mockQuery, nil + }) + + repo := &SshKeyChangeRepo{Injector: injector} + + // Act + change, err := repo.CreateKeyChange(testChange) + + // Assert + assert.NoError(t, err) + assert.Equal(t, testChange, change) +} + +func TestSshKeyChangeRepo_CreateKeyChange_Error(t *testing.T) { + // Arrange + injector, ctrl := setupSshKeyChangeRepoTest(t) + defer ctrl.Finish() + + userID := uuid.New() + keyID := uuid.New() + now := time.Now().UTC() + + testChange := &models.SshKeyChange{ + SshKeyID: keyID, + UserID: userID, + ChangeType: models.Created, + Filename: "test_key.pub", + NewData: []byte("ssh-rsa TEST"), + ChangeTime: now, + } + + mockQuery := query.NewMockQueryService[models.SshKeyChange](ctrl) + mockQuery.EXPECT(). + QueryOne( + "INSERT INTO ssh_key_changes (id, ssh_key_id, user_id, change_type, filename, previous_data, new_data, change_time) "+ + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *", + gomock.Any(), keyID, userID, models.Created, "test_key.pub", nil, []byte("ssh-rsa TEST"), gomock.Any(), + ). + Return(nil, errors.New("database error")) + + do.Provide(injector, func(i *do.Injector) (query.QueryService[models.SshKeyChange], error) { + return mockQuery, nil + }) + + repo := &SshKeyChangeRepo{Injector: injector} + + // Act + change, err := repo.CreateKeyChange(testChange) + + // Assert + assert.Error(t, err) + assert.Nil(t, change) +} + +func TestSshKeyChangeRepo_CreateKeyChangeTx(t *testing.T) { + // Arrange + injector, ctrl := setupSshKeyChangeRepoTest(t) + defer ctrl.Finish() + + userID := uuid.New() + keyID := uuid.New() + changeID := uuid.New() + now := time.Now().UTC() + + testChange := &models.SshKeyChange{ + ID: changeID, + SshKeyID: keyID, + UserID: userID, + ChangeType: models.Created, + Filename: "test_key.pub", + NewData: []byte("ssh-rsa TEST"), + ChangeTime: now, + } + + tx := testpgx.NewMockTx(ctrl) + mockQueryTx := query.NewMockQueryServiceTx[models.SshKeyChange](ctrl) + mockQueryTx.EXPECT(). + QueryOne( + tx, + "INSERT INTO ssh_key_changes (id, ssh_key_id, user_id, change_type, filename, previous_data, new_data, change_time) "+ + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *", + gomock.Any(), // id + keyID, // ssh_key_id + userID, // user_id + models.Created, // change_type + "test_key.pub", // filename + nil, // previous_data + []byte("ssh-rsa TEST"), // new_data + gomock.Any(), // change_time + ). + Return(testChange, nil) + + do.Provide(injector, func(i *do.Injector) (query.QueryServiceTx[models.SshKeyChange], error) { + return mockQueryTx, nil + }) + + repo := &SshKeyChangeRepo{Injector: injector} + + // Act + change, err := repo.CreateKeyChangeTx(testChange, tx) + + // Assert + assert.NoError(t, err) + assert.Equal(t, testChange, change) +} + +func TestSshKeyChangeRepo_GetKeyChanges(t *testing.T) { + // Arrange + injector, ctrl := setupSshKeyChangeRepoTest(t) + defer ctrl.Finish() + + keyID := uuid.New() + userID := uuid.New() + + changes := []models.SshKeyChange{ + { + ID: uuid.New(), + SshKeyID: keyID, + UserID: userID, + ChangeType: models.Created, + Filename: "test_key.pub", + NewData: []byte("ssh-rsa TEST1"), + ChangeTime: time.Now().Add(-2 * time.Hour), + }, + { + ID: uuid.New(), + SshKeyID: keyID, + UserID: userID, + ChangeType: models.Updated, + Filename: "test_key.pub", + PreviousData: []byte("ssh-rsa TEST1"), + NewData: []byte("ssh-rsa TEST2"), + ChangeTime: time.Now().Add(-1 * time.Hour), + }, + } + + mockQuery := query.NewMockQueryService[models.SshKeyChange](ctrl) + mockQuery.EXPECT(). + Query( + "SELECT * FROM ssh_key_changes WHERE ssh_key_id = $1 ORDER BY change_time DESC", + keyID, + ). + Return(changes, nil) + + do.Provide(injector, func(i *do.Injector) (query.QueryService[models.SshKeyChange], error) { + return mockQuery, nil + }) + + repo := &SshKeyChangeRepo{Injector: injector} + + // Act + result, err := repo.GetKeyChanges(keyID) + + // Assert + assert.NoError(t, err) + assert.Equal(t, changes, result) +} + +func TestSshKeyChangeRepo_GetLatestKeyChangesForUser(t *testing.T) { + // Arrange + injector, ctrl := setupSshKeyChangeRepoTest(t) + defer ctrl.Finish() + + userID := uuid.New() + since := time.Now().Add(-24 * time.Hour) + + changes := []models.SshKeyChange{ + { + ID: uuid.New(), + SshKeyID: uuid.New(), + UserID: userID, + ChangeType: models.Created, + Filename: "key1.pub", + NewData: []byte("ssh-rsa KEY1"), + ChangeTime: time.Now().Add(-2 * time.Hour), + }, + { + ID: uuid.New(), + SshKeyID: uuid.New(), + UserID: userID, + ChangeType: models.Updated, + Filename: "key2.pub", + PreviousData: []byte("ssh-rsa OLD"), + NewData: []byte("ssh-rsa KEY2"), + ChangeTime: time.Now().Add(-1 * time.Hour), + }, + } + + mockQuery := query.NewMockQueryService[models.SshKeyChange](ctrl) + mockQuery.EXPECT(). + Query( + `SELECT DISTINCT ON (ssh_key_id) * + FROM ssh_key_changes + WHERE user_id = $1 AND change_time > $2 + ORDER BY ssh_key_id, change_time DESC`, + userID, since, + ). + Return(changes, nil) + + do.Provide(injector, func(i *do.Injector) (query.QueryService[models.SshKeyChange], error) { + return mockQuery, nil + }) + + repo := &SshKeyChangeRepo{Injector: injector} + + // Act + result, err := repo.GetLatestKeyChangesForUser(userID, since) + + // Assert + assert.NoError(t, err) + assert.Equal(t, changes, result) +} + +func TestSshKeyChangeMock(t *testing.T) { + // Test the mock implementation + mock := SshKeyChangeMock{} + + // Test CreateKeyChange + userID := uuid.New() + keyID := uuid.New() + change := &models.SshKeyChange{ + SshKeyID: keyID, + UserID: userID, + ChangeType: models.Created, + Filename: "test.pub", + NewData: []byte("test data"), + } + + result, err := mock.CreateKeyChange(change) + assert.NoError(t, err) + assert.NotEqual(t, uuid.Nil, result.ID) // ID should be set + assert.False(t, result.ChangeTime.IsZero()) // Time should be set + + // Test CreateKeyChangeTx - we just pass nil for the tx since it's not used in the mock + result, err = mock.CreateKeyChangeTx(change, nil) + assert.NoError(t, err) + assert.Equal(t, 2, len(mock.Changes)) // Should have added another change + + // Test GetKeyChanges + changes, err := mock.GetKeyChanges(keyID) + assert.NoError(t, err) + assert.Equal(t, 2, len(changes)) // Should return both changes from above + + // Test GetLatestKeyChangesForUser + since := time.Now().Add(-1 * time.Hour) + latestChanges, err := mock.GetLatestKeyChangesForUser(userID, since) + assert.NoError(t, err) + assert.Equal(t, 1, len(latestChanges)) // Should return one change per key +} \ No newline at end of file diff --git a/pkg/database/repository/ssh_key_test.go b/pkg/database/repository/ssh_key_test.go index c684cc3..2254b71 100644 --- a/pkg/database/repository/ssh_key_test.go +++ b/pkg/database/repository/ssh_key_test.go @@ -1,3 +1,438 @@ package repository -// TODO +import ( + "errors" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/samber/do" + "github.com/stretchr/testify/assert" + "github.com/therealpaulgg/ssh-sync-server/pkg/database/models" + "github.com/therealpaulgg/ssh-sync-server/pkg/database/query" + testpgx "github.com/therealpaulgg/ssh-sync-server/test/pgx" +) + +func setupSshKeyRepoTest(t *testing.T) (*do.Injector, *gomock.Controller) { + ctrl := gomock.NewController(t) + injector := do.New() + return injector, ctrl +} + +func TestSshKeyRepo_CreateSshKey(t *testing.T) { + // Arrange + injector, ctrl := setupSshKeyRepoTest(t) + defer ctrl.Finish() + + userId := uuid.New() + keyId := uuid.New() + testKey := &models.SshKey{ + ID: keyId, + UserID: userId, + Filename: "test_key.pub", + Data: []byte("ssh-rsa TEST"), + } + + mockQuery := query.NewMockQueryService[models.SshKey](ctrl) + mockQuery.EXPECT(). + QueryOne("INSERT INTO ssh_keys (user_id, filename, data) VALUES ($1, $2, $3) RETURNING *", + userId, "test_key.pub", []byte("ssh-rsa TEST")). + Return(testKey, nil) + + do.Provide(injector, func(i *do.Injector) (query.QueryService[models.SshKey], error) { + return mockQuery, nil + }) + + repo := &SshKeyRepo{Injector: injector} + + // Act + key, err := repo.CreateSshKey(testKey) + + // Assert + assert.NoError(t, err) + assert.Equal(t, testKey, key) +} + +func TestSshKeyRepo_CreateSshKey_Error(t *testing.T) { + // Arrange + injector, ctrl := setupSshKeyRepoTest(t) + defer ctrl.Finish() + + userId := uuid.New() + testKey := &models.SshKey{ + UserID: userId, + Filename: "test_key.pub", + Data: []byte("ssh-rsa TEST"), + } + + mockQuery := query.NewMockQueryService[models.SshKey](ctrl) + mockQuery.EXPECT(). + QueryOne("INSERT INTO ssh_keys (user_id, filename, data) VALUES ($1, $2, $3) RETURNING *", + userId, "test_key.pub", []byte("ssh-rsa TEST")). + Return(nil, errors.New("database error")) + + do.Provide(injector, func(i *do.Injector) (query.QueryService[models.SshKey], error) { + return mockQuery, nil + }) + + repo := &SshKeyRepo{Injector: injector} + + // Act + key, err := repo.CreateSshKey(testKey) + + // Assert + assert.Error(t, err) + assert.Nil(t, key) +} + +func TestSshKeyRepo_UpsertSshKey(t *testing.T) { + // Arrange + injector, ctrl := setupSshKeyRepoTest(t) + defer ctrl.Finish() + + userId := uuid.New() + keyId := uuid.New() + testKey := &models.SshKey{ + ID: keyId, + UserID: userId, + Filename: "test_key.pub", + Data: []byte("ssh-rsa TEST"), + } + + mockQuery := query.NewMockQueryService[models.SshKey](ctrl) + mockQuery.EXPECT(). + QueryOne("INSERT INTO ssh_keys (user_id, filename, data) VALUES ($1, $2, $3) ON CONFLICT (user_id, filename) DO UPDATE SET data = $3 RETURNING *", + userId, "test_key.pub", []byte("ssh-rsa TEST")). + Return(testKey, nil) + + do.Provide(injector, func(i *do.Injector) (query.QueryService[models.SshKey], error) { + return mockQuery, nil + }) + + repo := &SshKeyRepo{Injector: injector} + + // Act + key, err := repo.UpsertSshKey(testKey) + + // Assert + assert.NoError(t, err) + assert.Equal(t, testKey, key) +} + +func TestSshKeyRepo_UpsertSshKeyTx(t *testing.T) { + // Arrange + injector, ctrl := setupSshKeyRepoTest(t) + defer ctrl.Finish() + + userId := uuid.New() + keyId := uuid.New() + testKey := &models.SshKey{ + ID: keyId, + UserID: userId, + Filename: "test_key.pub", + Data: []byte("ssh-rsa TEST"), + } + + tx := testpgx.NewMockTx(ctrl) + mockQueryTx := query.NewMockQueryServiceTx[models.SshKey](ctrl) + mockQueryTx.EXPECT(). + QueryOne(tx, "INSERT INTO ssh_keys (user_id, filename, data) VALUES ($1, $2, $3) ON CONFLICT (user_id, filename) DO UPDATE SET data = $3 RETURNING *", + userId, "test_key.pub", []byte("ssh-rsa TEST")). + Return(testKey, nil) + + do.Provide(injector, func(i *do.Injector) (query.QueryServiceTx[models.SshKey], error) { + return mockQueryTx, nil + }) + + repo := &SshKeyRepo{Injector: injector} + + // Act + key, err := repo.UpsertSshKeyTx(testKey, tx) + + // Assert + assert.NoError(t, err) + assert.Equal(t, testKey, key) +} + +func TestSshKeyRepo_GetSshKeyByFilename(t *testing.T) { + // Arrange + injector, ctrl := setupSshKeyRepoTest(t) + defer ctrl.Finish() + + userId := uuid.New() + keyId := uuid.New() + filename := "test_key.pub" + testKey := &models.SshKey{ + ID: keyId, + UserID: userId, + Filename: filename, + Data: []byte("ssh-rsa TEST"), + } + + mockQuery := query.NewMockQueryService[models.SshKey](ctrl) + mockQuery.EXPECT(). + QueryOne("SELECT * FROM ssh_keys WHERE user_id = $1 AND filename = $2", userId, filename). + Return(testKey, nil) + + do.Provide(injector, func(i *do.Injector) (query.QueryService[models.SshKey], error) { + return mockQuery, nil + }) + + repo := &SshKeyRepo{Injector: injector} + + // Act + key, err := repo.GetSshKeyByFilename(userId, filename) + + // Assert + assert.NoError(t, err) + assert.Equal(t, testKey, key) +} + +func TestSshKeyRepo_CreateSshKeyWithChange(t *testing.T) { + // Arrange + injector, ctrl := setupSshKeyRepoTest(t) + defer ctrl.Finish() + + userId := uuid.New() + keyId := uuid.New() + changeId := uuid.New() + testKey := &models.SshKey{ + ID: keyId, + UserID: userId, + Filename: "test_key.pub", + Data: []byte("ssh-rsa TEST"), + } + + testChange := &models.SshKeyChange{ + ID: changeId, + SshKeyID: keyId, + UserID: userId, + ChangeType: models.Created, + Filename: "test_key.pub", + NewData: []byte("ssh-rsa TEST"), + ChangeTime: time.Now(), + } + + tx := testpgx.NewMockTx(ctrl) + + // Mock transaction service + mockTxService := query.NewMockTransactionService(ctrl) + mockTxService.EXPECT().StartTx(gomock.Any()).Return(tx, nil) + mockTxService.EXPECT().Commit(tx).Return(nil) + do.Provide(injector, func(i *do.Injector) (query.TransactionService, error) { + return mockTxService, nil + }) + + // Mock query services for UpsertSshKeyTx + mockQueryTx := query.NewMockQueryServiceTx[models.SshKey](ctrl) + mockQueryTx.EXPECT(). + QueryOne(gomock.Eq(tx), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(testKey, nil) + do.Provide(injector, func(i *do.Injector) (query.QueryServiceTx[models.SshKey], error) { + return mockQueryTx, nil + }) + + // Mock query services for CreateKeyChangeTx + mockChangeTx := query.NewMockQueryServiceTx[models.SshKeyChange](ctrl) + mockChangeTx.EXPECT(). + QueryOne( + gomock.Eq(tx), + gomock.Any(), + gomock.Any(), // id + gomock.Any(), // ssh_key_id + gomock.Any(), // user_id + gomock.Any(), // change_type + gomock.Any(), // filename + gomock.Any(), // previous_data + gomock.Any(), // new_data + gomock.Any(), // change_time + ). + Return(testChange, nil) + do.Provide(injector, func(i *do.Injector) (query.QueryServiceTx[models.SshKeyChange], error) { + return mockChangeTx, nil + }) + + repo := &SshKeyRepo{Injector: injector} + + // Act + key, err := repo.CreateSshKeyWithChange(testKey) + + // Assert + assert.NoError(t, err) + assert.Equal(t, testKey, key) +} + +func TestSshKeyRepo_UpsertSshKeyWithChange(t *testing.T) { + // Arrange + injector, ctrl := setupSshKeyRepoTest(t) + defer ctrl.Finish() + + userId := uuid.New() + keyId := uuid.New() + changeId := uuid.New() + testKey := &models.SshKey{ + ID: keyId, + UserID: userId, + Filename: "test_key.pub", + Data: []byte("ssh-rsa TEST"), + } + + existingKey := &models.SshKey{ + ID: keyId, + UserID: userId, + Filename: "test_key.pub", + Data: []byte("ssh-rsa OLD"), + } + + testChange := &models.SshKeyChange{ + ID: changeId, + SshKeyID: keyId, + UserID: userId, + ChangeType: models.Updated, + Filename: "test_key.pub", + PreviousData: []byte("ssh-rsa OLD"), + NewData: []byte("ssh-rsa TEST"), + ChangeTime: time.Now(), + } + + tx := testpgx.NewMockTx(ctrl) + + // Mock transaction service + mockTxService := query.NewMockTransactionService(ctrl) + mockTxService.EXPECT().StartTx(gomock.Any()).Return(tx, nil) + mockTxService.EXPECT().Commit(tx).Return(nil) + do.Provide(injector, func(i *do.Injector) (query.TransactionService, error) { + return mockTxService, nil + }) + + // Mock query service for GetSshKeyByFilename + mockQuery := query.NewMockQueryService[models.SshKey](ctrl) + mockQuery.EXPECT(). + QueryOne("SELECT * FROM ssh_keys WHERE user_id = $1 AND filename = $2", userId, "test_key.pub"). + Return(existingKey, nil) + do.Provide(injector, func(i *do.Injector) (query.QueryService[models.SshKey], error) { + return mockQuery, nil + }) + + // Mock query services for UpsertSshKeyTx + mockQueryTx := query.NewMockQueryServiceTx[models.SshKey](ctrl) + mockQueryTx.EXPECT(). + QueryOne(gomock.Eq(tx), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(testKey, nil) + do.Provide(injector, func(i *do.Injector) (query.QueryServiceTx[models.SshKey], error) { + return mockQueryTx, nil + }) + + // Mock query services for CreateKeyChangeTx + mockChangeTx := query.NewMockQueryServiceTx[models.SshKeyChange](ctrl) + mockChangeTx.EXPECT(). + QueryOne( + gomock.Eq(tx), + gomock.Any(), + gomock.Any(), // id + gomock.Any(), // ssh_key_id + gomock.Any(), // user_id + gomock.Any(), // change_type + gomock.Any(), // filename + gomock.Any(), // previous_data + gomock.Any(), // new_data + gomock.Any(), // change_time + ). + Return(testChange, nil) + do.Provide(injector, func(i *do.Injector) (query.QueryServiceTx[models.SshKeyChange], error) { + return mockChangeTx, nil + }) + + repo := &SshKeyRepo{Injector: injector} + + // Act + key, err := repo.UpsertSshKeyWithChange(testKey) + + // Assert + assert.NoError(t, err) + assert.Equal(t, testKey, key) +} + +func TestSshKeyRepo_UpsertSshKeyWithChangeTx(t *testing.T) { + // Arrange + injector, ctrl := setupSshKeyRepoTest(t) + defer ctrl.Finish() + + userId := uuid.New() + keyId := uuid.New() + changeId := uuid.New() + testKey := &models.SshKey{ + ID: keyId, + UserID: userId, + Filename: "test_key.pub", + Data: []byte("ssh-rsa TEST"), + } + + existingKey := &models.SshKey{ + ID: keyId, + UserID: userId, + Filename: "test_key.pub", + Data: []byte("ssh-rsa OLD"), + } + + testChange := &models.SshKeyChange{ + ID: changeId, + SshKeyID: keyId, + UserID: userId, + ChangeType: models.Updated, + Filename: "test_key.pub", + PreviousData: []byte("ssh-rsa OLD"), + NewData: []byte("ssh-rsa TEST"), + ChangeTime: time.Now(), + } + + tx := testpgx.NewMockTx(ctrl) + + // Mock query service for GetSshKeyByFilename + mockQuery := query.NewMockQueryService[models.SshKey](ctrl) + mockQuery.EXPECT(). + QueryOne("SELECT * FROM ssh_keys WHERE user_id = $1 AND filename = $2", userId, "test_key.pub"). + Return(existingKey, nil) + do.Provide(injector, func(i *do.Injector) (query.QueryService[models.SshKey], error) { + return mockQuery, nil + }) + + // Mock query services for UpsertSshKeyTx + mockQueryTx := query.NewMockQueryServiceTx[models.SshKey](ctrl) + mockQueryTx.EXPECT(). + QueryOne(gomock.Eq(tx), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(testKey, nil) + do.Provide(injector, func(i *do.Injector) (query.QueryServiceTx[models.SshKey], error) { + return mockQueryTx, nil + }) + + // Mock query services for CreateKeyChangeTx + mockChangeTx := query.NewMockQueryServiceTx[models.SshKeyChange](ctrl) + mockChangeTx.EXPECT(). + QueryOne( + gomock.Eq(tx), + gomock.Any(), + gomock.Any(), // id + gomock.Any(), // ssh_key_id + gomock.Any(), // user_id + gomock.Any(), // change_type + gomock.Any(), // filename + gomock.Any(), // previous_data + gomock.Any(), // new_data + gomock.Any(), // change_time + ). + Return(testChange, nil) + do.Provide(injector, func(i *do.Injector) (query.QueryServiceTx[models.SshKeyChange], error) { + return mockChangeTx, nil + }) + + repo := &SshKeyRepo{Injector: injector} + + // Act + key, err := repo.UpsertSshKeyWithChangeTx(testKey, tx) + + // Assert + assert.NoError(t, err) + assert.Equal(t, testKey, key) +} diff --git a/pkg/database/repository/user.go b/pkg/database/repository/user.go index a2036a0..1f0b14d 100644 --- a/pkg/database/repository/user.go +++ b/pkg/database/repository/user.go @@ -211,6 +211,34 @@ func (repo *UserRepo) GetUserKey(userId uuid.UUID, keyId uuid.UUID) (*models.Ssh } func (repo *UserRepo) DeleteUserKeyTx(user *models.User, id uuid.UUID, tx pgx.Tx) error { - _, err := tx.Exec(context.TODO(), "delete from ssh_keys where user_id = $1 and id = $2", user.ID, id) + // Get the key first so we can record its information in the change history + q := do.MustInvoke[query.QueryServiceTx[models.SshKey]](repo.Injector) + key, err := q.QueryOne(tx, "SELECT * FROM ssh_keys WHERE user_id = $1 AND id = $2", user.ID, id) + if err != nil { + return err + } + + if key == nil { + // Key doesn't exist, nothing to delete + return nil + } + + // Record the change + changeRepo := &SshKeyChangeRepo{Injector: repo.Injector} + change := &models.SshKeyChange{ + SshKeyID: key.ID, + UserID: key.UserID, + ChangeType: models.Deleted, + Filename: key.Filename, + PreviousData: key.Data, + } + + _, err = changeRepo.CreateKeyChangeTx(change, tx) + if err != nil { + return err + } + + // Now delete the key + _, err = tx.Exec(context.TODO(), "DELETE FROM ssh_keys WHERE user_id = $1 AND id = $2", user.ID, id) return err } diff --git a/pkg/database/repository/user_test.go b/pkg/database/repository/user_test.go deleted file mode 100644 index c684cc3..0000000 --- a/pkg/database/repository/user_test.go +++ /dev/null @@ -1,3 +0,0 @@ -package repository - -// TODO diff --git a/test/pgx/mock.go b/test/pgx/mock.go index df1b158..96fdf09 100644 --- a/test/pgx/mock.go +++ b/test/pgx/mock.go @@ -208,4 +208,4 @@ func (m *MockTx) SendBatch(arg0 context.Context, arg1 *pgx.Batch) pgx.BatchResul func (mr *MockTxMockRecorder) SendBatch(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendBatch", reflect.TypeOf((*MockTx)(nil).SendBatch), arg0, arg1) -} +} \ No newline at end of file diff --git a/test/pgx/mock_database.go b/test/pgx/mock_database.go new file mode 100644 index 0000000..bbb4608 --- /dev/null +++ b/test/pgx/mock_database.go @@ -0,0 +1,27 @@ +package pgx + +import ( + "github.com/jackc/pgx/v5" +) + +// MockDatabase is a mock implementation of database for testing +type MockDatabase struct { + MockQuery func(sql string, args ...interface{}) [][]interface{} + MockQueryRow func(sql string, args ...interface{}) []interface{} + MockTxQueryRow func(tx pgx.Tx, sql string, args ...interface{}) []interface{} +} + +// NewMockDatabase creates a new MockDatabase with default implementations +func NewMockDatabase() *MockDatabase { + return &MockDatabase{ + MockQuery: func(sql string, args ...interface{}) [][]interface{} { + return [][]interface{}{} + }, + MockQueryRow: func(sql string, args ...interface{}) []interface{} { + return nil + }, + MockTxQueryRow: func(tx pgx.Tx, sql string, args ...interface{}) []interface{} { + return nil + }, + } +} \ No newline at end of file