diff --git a/caddy/pmtiles_proxy.go b/caddy/pmtiles_proxy.go index 0c2cb6d..fc5ed01 100644 --- a/caddy/pmtiles_proxy.go +++ b/caddy/pmtiles_proxy.go @@ -2,6 +2,12 @@ package caddy import ( "fmt" + "io" + "log" + "net/http" + "strconv" + "time" + "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" @@ -12,11 +18,6 @@ import ( _ "gocloud.dev/blob/fileblob" _ "gocloud.dev/blob/gcsblob" _ "gocloud.dev/blob/s3blob" - "io" - "log" - "net/http" - "strconv" - "time" ) func init() { @@ -66,12 +67,7 @@ func (m *Middleware) Validate() error { func (m Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { start := time.Now() - statusCode, headers, body := m.server.Get(r.Context(), r.URL.Path) - for k, v := range headers { - w.Header().Set(k, v) - } - w.WriteHeader(statusCode) - w.Write(body) + statusCode := m.server.ServeHTTP(w, r) m.logger.Info("response", zap.Int("status", statusCode), zap.String("path", r.URL.Path), zap.Duration("duration", time.Since(start))) return next.ServeHTTP(w, r) diff --git a/go.mod b/go.mod index 379eef5..e5b256a 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/alecthomas/kong v0.8.0 github.com/aws/aws-sdk-go v1.45.12 github.com/caddyserver/caddy/v2 v2.7.5 + github.com/cespare/xxhash/v2 v2.2.0 github.com/dustin/go-humanize v1.0.1 github.com/paulmach/orb v0.10.0 github.com/prometheus/client_golang v1.18.0 @@ -62,7 +63,6 @@ require ( github.com/bits-and-blooms/bitset v1.2.0 // indirect github.com/caddyserver/certmagic v0.19.2 // indirect github.com/cespare/xxhash v1.1.0 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/chzyer/readline v1.5.1 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/main.go b/main.go index 524b398..e6e1a80 100644 --- a/main.go +++ b/main.go @@ -140,12 +140,7 @@ func main() { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { start := time.Now() - statusCode, headers, body := server.Get(r.Context(), r.URL.Path) - for k, v := range headers { - w.Header().Set(k, v) - } - w.WriteHeader(statusCode) - w.Write(body) + statusCode := server.ServeHTTP(w, r) logger.Printf("served %d %s in %s", statusCode, r.URL.Path, time.Since(start)) }) diff --git a/pmtiles/bucket.go b/pmtiles/bucket.go index a8cc269..e83c15d 100644 --- a/pmtiles/bucket.go +++ b/pmtiles/bucket.go @@ -3,7 +3,7 @@ package pmtiles import ( "bytes" "context" - "crypto/md5" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -17,6 +17,7 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/s3" + "github.com/cespare/xxhash/v2" "gocloud.dev/blob" ) @@ -55,8 +56,7 @@ func (m mockBucket) NewRangeReaderEtag(_ context.Context, key string, offset int return nil, "", fmt.Errorf("Not found %s", key) } - hash := md5.Sum(bs) - resultEtag := hex.EncodeToString(hash[:]) + resultEtag := generateEtag(bs) if len(etag) > 0 && resultEtag != etag { return nil, "", &RefreshRequiredError{} } @@ -77,6 +77,31 @@ func (b FileBucket) NewRangeReader(ctx context.Context, key string, offset, leng return body, err } +func uintToBytes(n uint64) []byte { + bs := make([]byte, 8) + binary.LittleEndian.PutUint64(bs, n) + return bs +} + +func hasherToEtag(hasher *xxhash.Digest) string { + sum := uintToBytes(hasher.Sum64()) + return fmt.Sprintf(`"%s"`, hex.EncodeToString(sum)) +} + +func generateEtag(data []byte) string { + hasher := xxhash.New() + hasher.Write(data) + return hasherToEtag(hasher) +} + +func generateEtagFromInts(ns ...int64) string { + hasher := xxhash.New() + for _, n := range ns { + hasher.Write(uintToBytes(uint64(n))) + } + return hasherToEtag(hasher) +} + func (b FileBucket) NewRangeReaderEtag(_ context.Context, key string, offset, length int64, etag string) (io.ReadCloser, string, error) { name := filepath.Join(b.path, key) file, err := os.Open(name) @@ -88,9 +113,7 @@ func (b FileBucket) NewRangeReaderEtag(_ context.Context, key string, offset, le if err != nil { return nil, "", err } - modInfo := fmt.Sprintf("%d %d", info.ModTime().UnixNano(), info.Size()) - hash := md5.Sum([]byte(modInfo)) - newEtag := fmt.Sprintf(`"%s"`, hex.EncodeToString(hash[:])) + newEtag := generateEtagFromInts(info.ModTime().UnixNano(), info.Size()) if len(etag) > 0 && etag != newEtag { return nil, "", &RefreshRequiredError{} } diff --git a/pmtiles/server.go b/pmtiles/server.go index 74df12f..309df28 100644 --- a/pmtiles/server.go +++ b/pmtiles/server.go @@ -9,8 +9,10 @@ import ( "errors" "io" "log" + "net/http" "regexp" "strconv" + "time" "github.com/prometheus/client_golang/prometheus" ) @@ -294,6 +296,7 @@ func (server *Server) getTileJSON(ctx context.Context, httpHeaders map[string]st } httpHeaders["Content-Type"] = "application/json" + httpHeaders["Etag"] = generateEtag(tilejsonBytes) return 200, httpHeaders, tilejsonBytes } @@ -310,6 +313,7 @@ func (server *Server) getMetadata(ctx context.Context, httpHeaders map[string]st } httpHeaders["Content-Type"] = "application/json" + httpHeaders["Etag"] = generateEtag(metadataBytes) return 200, httpHeaders, metadataBytes } func (server *Server) getTile(ctx context.Context, httpHeaders map[string]string, name string, z uint8, x uint32, y uint32, ext string) (int, map[string]string, []byte) { @@ -320,6 +324,7 @@ func (server *Server) getTile(ctx context.Context, httpHeaders map[string]string } return status, headers, data } + func (server *Server) getTileAttempt(ctx context.Context, httpHeaders map[string]string, name string, z uint8, x uint32, y uint32, ext string, purgeEtag string) (int, map[string]string, []byte, string) { rootReq := request{key: cacheKey{name: name, offset: 0, length: 0}, value: make(chan cachedValue, 1), purgeEtag: purgeEtag} server.reqs <- rootReq @@ -390,6 +395,8 @@ func (server *Server) getTileAttempt(ctx context.Context, httpHeaders map[string if err != nil { return 500, httpHeaders, []byte("I/O error"), "" } + + httpHeaders["Etag"] = generateEtag(b) if headerVal, ok := headerContentType(header); ok { httpHeaders["Content-Type"] = headerVal } @@ -465,3 +472,24 @@ func (server *Server) Get(ctx context.Context, path string) (int, map[string]str return 404, httpHeaders, []byte("Path not found") } + +// Serve an HTTP response from the archive +func (server *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) int { + statusCode, headers, body := server.Get(r.Context(), r.URL.Path) + for k, v := range headers { + w.Header().Set(k, v) + } + if statusCode == 200 { + // handle if-match, if-none-match request headers based on response etag + http.ServeContent( + w, r, + "", // name used to infer content-type, but we've already set that + time.UnixMilli(0), // ignore setting last-modified time and handling if-modified-since headers + bytes.NewReader(body), + ) + } else { + w.WriteHeader(statusCode) + w.Write(body) + } + return statusCode +} diff --git a/pmtiles/server_test.go b/pmtiles/server_test.go index b7063a7..191d362 100644 --- a/pmtiles/server_test.go +++ b/pmtiles/server_test.go @@ -370,3 +370,47 @@ func TestInvalidateCacheOnMetadataRequest(t *testing.T) { "meta": "data2" }`, string(data)) } + +func TestEtagResponsesFromTile(t *testing.T) { + mockBucket, server := newServer(t) + header := HeaderV3{ + TileType: Mvt, + } + mockBucket.items["archive.pmtiles"] = fakeArchive(t, header, map[string]interface{}{}, map[Zxy][]byte{ + {0, 0, 0}: {0, 1, 2, 3}, + {4, 1, 2}: {1, 2, 3}, + }, false) + + statusCode, headers000v1, _ := server.Get(context.Background(), "/archive/0/0/0.mvt") + assert.Equal(t, 200, statusCode) + statusCode, headers412v1, _ := server.Get(context.Background(), "/archive/4/1/2.mvt") + assert.Equal(t, 200, statusCode) + statusCode, headers311v1, _ := server.Get(context.Background(), "/archive/3/1/1.mvt") + assert.Equal(t, 204, statusCode) + + mockBucket.items["archive.pmtiles"] = fakeArchive(t, header, map[string]interface{}{}, map[Zxy][]byte{ + {0, 0, 0}: {0, 1, 2, 3}, + {4, 1, 2}: {1, 2, 3, 4}, // different + }, false) + + statusCode, headers000v2, _ := server.Get(context.Background(), "/archive/0/0/0.mvt") + assert.Equal(t, 200, statusCode) + statusCode, headers412v2, _ := server.Get(context.Background(), "/archive/4/1/2.mvt") + assert.Equal(t, 200, statusCode) + statusCode, headers311v2, _ := server.Get(context.Background(), "/archive/3/1/1.mvt") + assert.Equal(t, 204, statusCode) + + // 204's have no etag + assert.Equal(t, "", headers311v1["Etag"]) + assert.Equal(t, "", headers311v2["Etag"]) + + // 000 and 311 didn't change + assert.Equal(t, headers000v1["Etag"], headers000v2["Etag"]) + + // 412 did change + assert.NotEqual(t, headers412v1["Etag"], headers412v2["Etag"]) + + // all are different + assert.NotEqual(t, headers000v1["Etag"], headers311v1["Etag"]) + assert.NotEqual(t, headers000v1["Etag"], headers412v1["Etag"]) +}