diff --git a/auxlib.go b/auxlib.go index cb3dd24d..15199390 100644 --- a/auxlib.go +++ b/auxlib.go @@ -3,6 +3,7 @@ package lua import ( "bufio" "fmt" + "io" "os" "strings" ) @@ -357,7 +358,7 @@ func (ls *LState) LoadFile(path string) (*LFunction, error) { reader := bufio.NewReader(file) // get the first character. c, err := reader.ReadByte() - if err != nil { + if err != nil && err != io.EOF { return nil, newApiErrorE(ApiErrorFile, err) } if c == byte('#') { @@ -368,7 +369,15 @@ func (ls *LState) LoadFile(path string) (*LFunction, error) { return nil, newApiErrorE(ApiErrorFile, err) } } - reader.UnreadByte() + + if err != io.EOF { + // if the file is not empty, + // unread the first character of the file or newline character(readBufioLine's last byte). + err = reader.UnreadByte() + if err != nil { + return nil, newApiErrorE(ApiErrorFile, err) + } + } return ls.Load(reader, path) } diff --git a/auxlib_test.go b/auxlib_test.go index 383f1c86..5dff1df1 100644 --- a/auxlib_test.go +++ b/auxlib_test.go @@ -315,3 +315,19 @@ print("hello") _, err = L.LoadFile(tmpFile.Name()) errorIfNotNil(t, err) } + +func TestLoadFileForEmptyFile(t *testing.T) { + tmpFile, err := ioutil.TempFile("", "") + errorIfNotNil(t, err) + + defer func() { + tmpFile.Close() + os.Remove(tmpFile.Name()) + }() + + L := NewState() + defer L.Close() + + _, err = L.LoadFile(tmpFile.Name()) + errorIfNotNil(t, err) +}