diff --git a/main.go b/main.go index 3eb05a8..7908b33 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ package main import ( "flag" "fmt" + "io" "log" "os" @@ -19,17 +20,17 @@ import ( var ( Version = "dev" - inline = flag.Bool("inline", false, "parse rule inlining") - _switch = flag.Bool("switch", false, "replace if-else if-else like blocks with switch blocks") - // Avoid redefinition of built-in function print. + inline = flag.Bool("inline", false, "parse rule inlining") + switchFlag = flag.Bool("switch", false, "replace if-else if-else like blocks with switch blocks") printFlag = flag.Bool("print", false, "directly dump the syntax tree") syntax = flag.Bool("syntax", false, "print out the syntax tree") noast = flag.Bool("noast", false, "disable AST") strict = flag.Bool("strict", false, "treat compiler warnings as errors") - filename = flag.String("output", "", "specify name of output file") + outputFile = flag.String("output", "", "output to `FILE` (\"-\" for stdout)") showVersion = flag.Bool("version", false, "print the version and exit") ) +// main is the entry point for the PEG compiler. func main() { flag.Parse() @@ -38,49 +39,84 @@ func main() { return } - if flag.NArg() != 1 { - flag.Usage() - log.Fatalf("FILE: the peg file to compile") - } - file := flag.Arg(0) + err := parse( + func(p *Peg[uint32], out io.Writer) error { + if *printFlag { + p.Print() + } + if *syntax { + p.PrintSyntaxTree() + } - buffer, err := os.ReadFile(file) + p.Strict = *strict + if err := p.Compile(*outputFile, os.Args, out); err != nil { + return err + } + return nil + }, + ) if err != nil { - log.Fatal(err) - } - - p := &Peg[uint32]{Tree: tree.New(*inline, *_switch, *noast), Buffer: string(buffer)} - _ = p.Init(Pretty[uint32](true), Size[uint32](1<<15)) - if err := p.Parse(); err != nil { - log.Fatal(err) + if *strict { + log.Fatal(err) + } + fmt.Fprintln(os.Stderr, "warning:", err) } +} - p.Execute() +// getIO returns input and output streams based on command-line flags. +func getIO() (in io.ReadCloser, out io.WriteCloser, err error) { + in, out = os.Stdin, os.Stdout - if *printFlag { - p.Print() - } - if *syntax { - p.PrintSyntaxTree() + if flag.NArg() > 0 && flag.Arg(0) != "-" { + in, err = os.Open(flag.Arg(0)) + if err != nil { + return nil, nil, err + } + if *outputFile == "" { + *outputFile = flag.Arg(0) + ".go" + } } - if *filename == "" { - *filename = file + ".go" + if *outputFile != "" && *outputFile != "-" { + out, err = os.OpenFile(*outputFile, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644) + if err != nil { + if in != nil && in != os.Stdin { + in.Close() + } + return nil, nil, err + } } - out, err := os.OpenFile(*filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644) + + return in, out, nil +} + +// parse reads input, parses, executes, and compiles the PEG grammar. +func parse(compile func(*Peg[uint32], io.Writer) error) error { + in, out, err := getIO() if err != nil { - fmt.Printf("%v: %v\n", *filename, err) - return + return err } defer func() { - err := out.Close() - if err != nil { - log.Fatal(err) + if in != nil && in != os.Stdin { + in.Close() + } + if out != nil && out != os.Stdout { + out.Close() } }() - p.Strict = *strict - if err = p.Compile(*filename, os.Args, out); err != nil { - log.Fatal(err) + buffer, err := io.ReadAll(in) + if err != nil { + return err + } + + p := &Peg[uint32]{Tree: tree.New(*inline, *switchFlag, *noast), Buffer: string(buffer)} + _ = p.Init(Pretty[uint32](true), Size[uint32](1<<15)) + if err = p.Parse(); err != nil { + return err } + + p.Execute() + + return compile(p, out) }