Skip to content

Commit

Permalink
feat: improve the ReadSchema and FindSQLDir functions to not require …
Browse files Browse the repository at this point in the history
…any input parameters and look for the directory that contains a `schema.sql` file
  • Loading branch information
bradub committed Nov 13, 2023
1 parent 121e985 commit bf107cd
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 59 deletions.
56 changes: 35 additions & 21 deletions psqlutil/find_sql_directory.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,51 @@ package psqlutil

import (
"fmt"
"io/fs"
"os"
"path/filepath"
)

// FindSQLDir attempts to compose the path to the sql directory in the project.
func FindSQLDir(projectDirectoryName string) (string, error) {
p, err := getDirectoryPath(projectDirectoryName)
// nolint: gocognit // allow high cog complexity.
func FindSQLDir() (string, error) {
projectRoot, err := getProjectRoot()
if err != nil {
return "", fmt.Errorf("get project directory: %w", err)
return "", fmt.Errorf("get project root: %w", err)
}

sqlPath := filepath.Clean(
filepath.Join(
string(os.PathSeparator),
filepath.Join(
p,
"sql",
),
),
)

_, err = os.Stat(
sqlPath,
)
if err != nil {
if os.IsNotExist(err) {
return "", err
var sqlDirPath string

if err := filepath.WalkDir(projectRoot, func(path string, d os.DirEntry, err error) error {
if err != nil {
return err // Propagate any error encountered.
}

if d.IsDir() {
// Check for 'schema.sql' inside this directory.
schemaFilePath := filepath.Join(path, "schema.sql")

if _, err := os.Stat(schemaFilePath); err != nil {
if os.IsNotExist(err) {
return nil // File not found, continue walking.
}

return fmt.Errorf("stat %q: %w", schemaFilePath, err)
}

sqlDirPath = path // Found the directory containing 'schema.sql'.

return fs.SkipAll // Throw an error to stop the walk early
}

return "", fmt.Errorf("err while checking if sql dir exists: %w", err)
return nil
}); err != nil {
return "", fmt.Errorf("walk project root: %w", err)
}

if sqlDirPath == "" {
return "", fs.ErrNotExist
}

return sqlPath, nil
return sqlDirPath, nil
}
68 changes: 30 additions & 38 deletions psqlutil/read_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,67 +6,59 @@ import (
"os"
"path/filepath"
"runtime"
"strings"
)

// ErrUnableToResolveCaller is returned when the caller CWD cannot be retrieved.
var ErrUnableToResolveCaller = errors.New("unable to resolve caller")

// ReadSchema reads schema dynamically based on the CWD of the caller.
func ReadSchema(projectDirectoryName string) (string, error) {
path, err := getDirectoryPath(projectDirectoryName)
func ReadSchema() (string, error) {
sqlDir, err := FindSQLDir()
if err != nil {
return "", fmt.Errorf("get project directory: %w", err)
return "", fmt.Errorf("find sql dir: %w", err)
}

schemaPath := filepath.Clean(
filepath.Join(
string(os.PathSeparator),
filepath.Join(
path,
"sql",
"schema.sql",
),
),
)
const schemaFile = "schema.sql"

schemaB, err := os.ReadFile(schemaPath)
schemaB, err := os.ReadFile(filepath.Join(filepath.Clean(sqlDir), schemaFile))
if err != nil {
return "", fmt.Errorf("err while reading schema: %w", err)
return "", fmt.Errorf("read schema: %w", err)
}

return string(schemaB), nil
}

// ErrDirectoryNotFound is returned when the
// project directory is not found.
var ErrDirectoryNotFound = errors.New("directory not found")
var (
// ErrUnableToResolveCaller is returned when the caller CWD cannot be retrieved.
ErrUnableToResolveCaller = errors.New("unable to resolve caller")

// ErrProjectRootNotFound is returned when the
// project root is not found.
ErrProjectRootNotFound = errors.New("project root not found")
)

func getProjectRoot() (string, error) {
rootIndicators := []string{"go.mod"}

func getDirectoryPath(directoryName string) (string, error) {
_, filename, _, ok := runtime.Caller(0)
if !ok {
return "", ErrUnableToResolveCaller
}

pathParts := strings.Split(filename, string(os.PathSeparator))
dir := filepath.Dir(filename)

var directoryPath string

// reverse range over path parts to find the directory
// absolute path
for directoryPath == "" && len(pathParts) > 0 {
if pathParts[len(pathParts)-1] != directoryName {
pathParts = pathParts[:len(pathParts)-1]

continue
// Walk up the directory tree until we find the project root.
for {
for _, indicator := range rootIndicators {
if _, err := os.Stat(filepath.Join(dir, indicator)); err == nil {
return dir, nil
}
}

directoryPath = filepath.Join(pathParts...)
}
parentDir := filepath.Dir(dir)
if parentDir == dir {
break // we've reached the root of the filesystem and didn't find the project root
}

if directoryPath == "" {
return "", ErrDirectoryNotFound
dir = parentDir
}

return directoryPath, nil
return "", ErrProjectRootNotFound
}

0 comments on commit bf107cd

Please sign in to comment.