From 8800ab42828ef11a11c75d362e6ad4b6857f104d Mon Sep 17 00:00:00 2001 From: Ben Morrison Date: Sun, 2 Jun 2019 02:42:35 -0400 Subject: [PATCH] concurrency protection / precautions --- handlers.go | 20 ++++++++++++++++---- handlers_test.go | 4 ++++ init.go | 11 ++++++++--- main.go | 18 ++++++++++++------ md.go | 7 ++++++- pages.go | 36 ++++++++++++++++++++++++++++++++++++ types.go | 1 + 7 files changed, 83 insertions(+), 14 deletions(-) diff --git a/handlers.go b/handlers.go index c7a349b..bf8779c 100644 --- a/handlers.go +++ b/handlers.go @@ -67,9 +67,13 @@ func indexHandler(w http.ResponseWriter, r *http.Request) { // This is due to the default behavior of // not serving naked paths but virtual ones. func iconHandler(w http.ResponseWriter, r *http.Request) { + confVars.mu.RLock() + assetsDir := confVars.assetsDir + iconPath := confVars.iconPath + confVars.mu.RUnlock() // read the raw bytes of the image - longname := confVars.assetsDir + "/" + confVars.iconPath + longname := assetsDir + "/" + iconPath icon, err := ioutil.ReadFile(longname) if err != nil { if os.IsNotExist(err) { @@ -108,16 +112,20 @@ func iconHandler(w http.ResponseWriter, r *http.Request) { // not serving naked paths but virtual ones. func cssHandler(w http.ResponseWriter, r *http.Request) { + confVars.mu.RLock() + cssPath := confVars.cssPath + confVars.mu.RUnlock() + // check if using local or remote CSS. // if remote, don't bother doing anything // and redirect requests to / - if !cssLocal([]byte(confVars.cssPath)) { + if !cssLocal([]byte(cssPath)) { http.Redirect(w, r, "/", http.StatusFound) return } // read the raw bytes of the stylesheet - css, err := ioutil.ReadFile(confVars.cssPath) + css, err := ioutil.ReadFile(cssPath) if err != nil { if os.IsNotExist(err) { log.Printf("CSS file specified in config does not exist: /css request 404\n") @@ -130,7 +138,7 @@ func cssHandler(w http.ResponseWriter, r *http.Request) { } // stat to get the mod time for the etag header - stat, err := os.Stat(confVars.cssPath) + stat, err := os.Stat(cssPath) if err != nil { log.Printf("Couldn't stat CSS file to send ETag header: %v\n", err) } @@ -168,7 +176,9 @@ func validatePath(fn func(http.ResponseWriter, *http.Request, string)) http.Hand // if the markdown doc can't be read, default to // net/http's error handling func error500(w http.ResponseWriter, _ *http.Request) { + confVars.mu.RLock() e500 := confVars.assetsDir + "/500.md" + confVars.mu.RUnlock() file, err := ioutil.ReadFile(e500) if err != nil { @@ -190,7 +200,9 @@ func error500(w http.ResponseWriter, _ *http.Request) { // if the markdown doc can't be read, default to // net/http's error handling func error404(w http.ResponseWriter, r *http.Request) { + confVars.mu.RLock() e404 := confVars.assetsDir + "/404.md" + confVars.mu.RUnlock() file, err := ioutil.ReadFile(e404) if err != nil { diff --git a/handlers_test.go b/handlers_test.go index 4a19fa7..e686bff 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -68,7 +68,11 @@ func Test_indexHandler(t *testing.T) { func Test_iconHandler(t *testing.T) { name := "Icon Handler Test" initConfigParams() + + confVars.mu.RLock() icon, _ := ioutil.ReadFile(confVars.assetsDir + "/" + confVars.iconPath) + confVars.mu.RUnlock() + w := httptest.NewRecorder() r := httptest.NewRequest("GET", "localhost:8080/icon", nil) t.Run(name, func(t *testing.T) { diff --git a/init.go b/init.go index 47efe61..9fb145c 100644 --- a/init.go +++ b/init.go @@ -15,8 +15,13 @@ func init() { // set up logging if the config file params // are set - if confVars.fileLogging && !confVars.quietLogging { - if llogfile, err := os.OpenFile(confVars.logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600); err == nil { + confVars.mu.RLock() + filog := confVars.fileLogging + qlog := confVars.quietLogging + logfi := confVars.logFile + confVars.mu.RUnlock() + if filog && !qlog { + if llogfile, err := os.OpenFile(logfi, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600); err == nil { log.SetOutput(llogfile) go func() { @@ -34,7 +39,7 @@ func init() { } // Tell TildeWiki to be quiet, - if confVars.quietLogging { + if qlog { if llogfile, err := os.Open("/dev/null"); err == nil { log.SetOutput(llogfile) diff --git a/main.go b/main.go index 829d8dc..2613309 100644 --- a/main.go +++ b/main.go @@ -19,6 +19,13 @@ var closelog = make(chan bool, 1) func main() { + confVars.mu.RLock() + filog := confVars.fileLogging + qlog := confVars.quietLogging + reversed := confVars.reverseTally + portnum := confVars.port + confVars.mu.RUnlock() + // watch for SIGINT aka ^C // close the log file then exit c := make(chan os.Signal, 1) @@ -28,7 +35,7 @@ func main() { for sigint := range c { log.Printf("\n\nCaught %v. Cleaning up ...\n", sigint) - if confVars.fileLogging { + if filog { // signal to close the log file closelog <- true time.Sleep(50 * time.Millisecond) @@ -52,14 +59,13 @@ func main() { serv.HandleFunc("/500", error500) serv.HandleFunc("/404", error404) - log.Println("**NOTICE** Binding to " + confVars.port) - // let the user know if using reversed page listings - if confVars.reverseTally { + if reversed { log.Printf("**NOTICE** Using reversed page listings on index ... \n") } - portnum := confVars.port + log.Println("**NOTICE** Binding to " + portnum) + server := &http.Server{ Handler: handlers.CompressHandler(ipMiddleware(serv)), Addr: portnum, @@ -73,7 +79,7 @@ func main() { } // signal to close the log file - if confVars.fileLogging || confVars.quietLogging { + if filog || qlog { closelog <- true close(closelog) } diff --git a/md.go b/md.go index 1a3786f..74d6a9b 100644 --- a/md.go +++ b/md.go @@ -8,9 +8,11 @@ import ( func setupMarkdown(css, title string) *bf.HTMLRenderer { // if using local CSS file, use the virtually-served css // path rather than the actual file name + confVars.mu.RLock() if cssLocal([]byte(confVars.cssPath)) { css = "/css" } + confVars.mu.RUnlock() // return the parameters used for the rendering // of markdown to html. @@ -30,5 +32,8 @@ func setupMarkdown(css, title string) *bf.HTMLRenderer { // Wrapper function to generate the parameters above and // pass them to the blackfriday library's parsing function func render(data []byte, title string) []byte { - return bf.Run(data, bf.WithRenderer(setupMarkdown(confVars.cssPath, title))) + confVars.mu.RLock() + cssPath := confVars.cssPath + confVars.mu.RUnlock() + return bf.Run(data, bf.WithRenderer(setupMarkdown(cssPath, title))) } diff --git a/pages.go b/pages.go index 2c22637..2b6439e 100644 --- a/pages.go +++ b/pages.go @@ -56,14 +56,18 @@ func buildPage(filename string) (*Page, error) { title = shortname } if desc != "" { + confVars.mu.RLock() desc = confVars.descSep + " " + desc + confVars.mu.RUnlock() } if author != "" { author = "`by " + author + "`" } // longtitle is used in the tags of the output html + confVars.mu.RLock() longtitle := title + " " + confVars.titleSep + " " + confVars.wikiName + confVars.mu.RUnlock() // store the raw bytes of the document after parsing // from markdown to HTML. @@ -118,28 +122,40 @@ func (indexCache *indexPage) checkCache() bool { // if the last tally time is past the // interval in the config file, re-cache if interval, err := time.ParseDuration(viper.GetString("IndexRefreshInterval")); err == nil { + imutex.RLock() if time.Since(indexCache.LastTally) > interval { + imutex.RUnlock() return true } + imutex.RUnlock() } else { log.Printf("Couldn't parse index refresh interval: %v\n", err) } // if the stored mod time is different // from the file's modtime, re-cache + confVars.mu.RLock() if stat, err := os.Stat(confVars.assetsDir + "/" + confVars.indexFile); err == nil { + imutex.RLock() if stat.ModTime() != indexCache.Modtime { + imutex.RUnlock() + confVars.mu.RUnlock() return true } + imutex.RUnlock() } else { log.Printf("Couldn't stat index page: %v\n", err) } + confVars.mu.RUnlock() // if the last tally time or stored mod time is zero, signal // to re-cache the index + imutex.RLock() if indexCache.LastTally.IsZero() || indexCache.Modtime.IsZero() { + imutex.RUnlock() return true } + imutex.RUnlock() return false } @@ -147,7 +163,9 @@ func (indexCache *indexPage) checkCache() bool { // Re-caches the index page. // This method helps satisfy the cacher interface. func (indexCache *indexPage) cache() error { + confVars.mu.RLock() body := render(genIndex(), confVars.wikiName+" "+confVars.titleSep+" "+confVars.wikiDesc) + confVars.mu.RUnlock() if body == nil { return errors.New("indexPage.cache(): getting nil bytes") } @@ -161,7 +179,9 @@ func (indexCache *indexPage) cache() error { func genIndex() []byte { var err error + confVars.mu.RLock() indexpath := confVars.assetsDir + "/" + confVars.indexFile + confVars.mu.RUnlock() // stat to check mod time stat, err := os.Stat(indexpath) @@ -171,7 +191,10 @@ func genIndex() []byte { // if the index file has been modified, // vaccuum up those bytes into the cache + + imutex.RLock() if indexCache.Modtime != stat.ModTime() { + imutex.RUnlock() imutex.Lock() indexCache.Raw, err = ioutil.ReadFile(indexpath) imutex.Unlock() @@ -179,6 +202,8 @@ func genIndex() []byte { return []byte("Could not open \"" + indexpath + "\"") } + } else { + imutex.RUnlock() } // body holds the bytes of the generated index page being sent to the client. @@ -222,6 +247,7 @@ func tallyPages(buf *bytes.Buffer) { // get a list of files in the directory specified // in the config file parameter "PageDir" + confVars.mu.RLock() if files, err := ioutil.ReadDir(confVars.pageDir); err == nil { // entry is used in the loop to construct the markdown @@ -231,6 +257,7 @@ func tallyPages(buf *bytes.Buffer) { if err != nil || n == 0 { log.Printf("Error writing to buffer: %v\n", err) } + confVars.mu.RUnlock() return } @@ -255,6 +282,7 @@ func tallyPages(buf *bytes.Buffer) { if err != nil { log.Printf("Error writing to buffer: %v\n", err) } + confVars.mu.RUnlock() } // Takes in a file and outputs a markdown link to it. @@ -272,7 +300,9 @@ func writeIndexLinks(f os.FileInfo, buf *bytes.Buffer) { } else { // if it hasn't been cached, cache it. // usually means the page is new. + confVars.mu.RLock() newpage := newBarePage(confVars.pageDir+"/"+f.Name(), f.Name()) + confVars.mu.RUnlock() if err := newpage.cache(); err != nil { log.Printf("While caching page %v during the index generation, caught an error: %v\n", f.Name(), err) } @@ -285,7 +315,9 @@ func writeIndexLinks(f os.FileInfo, buf *bytes.Buffer) { // and write the formatted link to the // bytes.Buffer linkname := bytes.TrimSuffix([]byte(page.Shortname), []byte(".md")) + confVars.mu.RLock() n, err := buf.WriteString("* [" + page.Title + "](" + confVars.viewPath + string(linkname) + ") " + page.Desc + " " + page.Author + "\n") + confVars.mu.RUnlock() if err != nil || n == 0 { log.Printf("Error writing to buffer: %v\n", err) } @@ -334,6 +366,7 @@ func genPageCache() { // spawn a new goroutine for each entry, to cache // everything as quickly as possible + confVars.mu.RLock() if wikipages, err := ioutil.ReadDir(confVars.pageDir); err == nil { var wg sync.WaitGroup for _, f := range wikipages { @@ -341,7 +374,9 @@ func genPageCache() { wg.Add(1) go func(f os.FileInfo) { + confVars.mu.RLock() page := newBarePage(confVars.pageDir+"/"+f.Name(), f.Name()) + confVars.mu.RLock() if err := page.cache(); err != nil { log.Printf("While generating initial cache, caught error for %v: %v\n", f.Name(), err) } @@ -358,6 +393,7 @@ func genPageCache() { log.Printf("**NOTICE** TildeWiki's cache may not function correctly until this is resolved.\n") log.Printf("\tPlease verify the directory in tildewiki.yml is correct and restart TildeWiki\n") } + confVars.mu.RLock() } // Wrapper function to check the cache diff --git a/types.go b/types.go index 55e9737..8278af1 100644 --- a/types.go +++ b/types.go @@ -30,6 +30,7 @@ type ipCtxKey int const ctxKey ipCtxKey = iota type confParams struct { + mu sync.RWMutex port string pageDir string assetsDir string