Skip to content

Commit

Permalink
Support MSC3860 download redirection behaviour
Browse files Browse the repository at this point in the history
Fixes #540
  • Loading branch information
turt2live committed Jan 15, 2024
1 parent fdf44d4 commit cbc3677
Show file tree
Hide file tree
Showing 12 changed files with 157 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Exporting MMR's data to Synapse is now possible with `import_to_synapse`. To use it, first run `gdpr_export` or similar.
* Errors encountered during a background task, such as an API-induced export, are exposed as `error_message` in the admin API.
* MMR will follow redirects on federated downloads up to 5 hops.
* S3-backed datastores can have download requests redirected to a public-facing CDN rather than being proxied through MMR. See `publicBaseUrl` under the S3 datastore config.

### Changed

Expand Down
9 changes: 9 additions & 0 deletions api/_responses/redirect.go
@@ -0,0 +1,9 @@
package _responses

type RedirectResponse struct {
ToUrl string
}

func Redirect(url string) *RedirectResponse {
return &RedirectResponse{ToUrl: url}
}
8 changes: 8 additions & 0 deletions api/_routers/98-use-rcontext.go
Expand Up @@ -59,6 +59,14 @@ func (c *RContextRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {

headers := w.Header()

// Check for redirection early
if redirect, isRedirect := res.(*_responses.RedirectResponse); isRedirect {
log.Infof("Replying with result: %T <%s>", res, redirect.ToUrl)
headers.Set("Location", redirect.ToUrl)
r = writeStatusCode(w, r, http.StatusTemporaryRedirect)
return // we're done here
}

// Check for HTML response and reply accordingly
if htmlRes, isHtml := res.(*_responses.HtmlResponse); isHtml {
log.Infof("Replying with result: %T <%d chars of html>", res, len(htmlRes.HTML))
Expand Down
24 changes: 20 additions & 4 deletions api/r0/download.go
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/api/_routers"
"github.com/turt2live/matrix-media-repo/datastores"
"github.com/turt2live/matrix-media-repo/pipelines/pipeline_download"
"github.com/turt2live/matrix-media-repo/util"

Expand All @@ -22,6 +23,7 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
mediaId := _routers.GetParam("mediaId", r)
filename := _routers.GetParam("filename", r)
allowRemote := r.URL.Query().Get("allow_remote")
allowRedirect := r.URL.Query().Get("allow_redirect")
timeoutMs := r.URL.Query().Get("timeout_ms")

if !_routers.ServerNameRegex.MatchString(server) {
Expand All @@ -37,16 +39,26 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
downloadRemote = parsedFlag
}

canRedirect := false
if allowRedirect != "" {
parsedFlag, err := strconv.ParseBool(allowRedirect)
if err != nil {
return _responses.BadRequest("allow_redirect flag does not appear to be a boolean")
}
canRedirect = parsedFlag
}

blockFor, err := util.CalcBlockForDuration(timeoutMs)
if err != nil {
return _responses.BadRequest("timeout_ms does not appear to be an integer")
}

rctx = rctx.LogWithFields(logrus.Fields{
"mediaId": mediaId,
"server": server,
"filename": filename,
"allowRemote": downloadRemote,
"mediaId": mediaId,
"server": server,
"filename": filename,
"allowRemote": downloadRemote,
"allowRedirect": canRedirect,
})

if !util.IsGlobalAdmin(user.UserId) && util.IsHostIgnored(server) {
Expand All @@ -57,8 +69,10 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
media, stream, err := pipeline_download.Execute(rctx, server, mediaId, pipeline_download.DownloadOpts{
FetchRemoteIfNeeded: downloadRemote,
BlockForReadUntil: blockFor,
CanRedirect: canRedirect,
})
if err != nil {
var redirect datastores.RedirectError
if errors.Is(err, common.ErrMediaNotFound) {
return _responses.NotFoundError()
} else if errors.Is(err, common.ErrMediaTooLarge) {
Expand All @@ -72,6 +86,8 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
}
} else if errors.Is(err, common.ErrMediaNotYetUploaded) {
return _responses.NotYetUploaded()
} else if errors.As(err, &redirect) {
return _responses.Redirect(redirect.RedirectUrl)
}
rctx.Log.Error("Unexpected error locating media: ", err)
sentry.CaptureException(err)
Expand Down
22 changes: 19 additions & 3 deletions api/r0/thumbnail.go
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/api/_routers"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/datastores"
"github.com/turt2live/matrix-media-repo/pipelines/pipeline_download"
"github.com/turt2live/matrix-media-repo/pipelines/pipeline_thumbnail"
"github.com/turt2live/matrix-media-repo/util"
Expand All @@ -23,6 +24,7 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
server := _routers.GetParam("server", r)
mediaId := _routers.GetParam("mediaId", r)
allowRemote := r.URL.Query().Get("allow_remote")
allowRedirect := r.URL.Query().Get("allow_redirect")
timeoutMs := r.URL.Query().Get("timeout_ms")

if !_routers.ServerNameRegex.MatchString(server) {
Expand All @@ -38,15 +40,25 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
downloadRemote = parsedFlag
}

canRedirect := false
if allowRedirect != "" {
parsedFlag, err := strconv.ParseBool(allowRedirect)
if err != nil {
return _responses.BadRequest("allow_redirect flag does not appear to be a boolean")
}
canRedirect = parsedFlag
}

blockFor, err := util.CalcBlockForDuration(timeoutMs)
if err != nil {
return _responses.BadRequest("timeout_ms does not appear to be an integer")
}

rctx = rctx.LogWithFields(logrus.Fields{
"mediaId": mediaId,
"server": server,
"allowRemote": downloadRemote,
"mediaId": mediaId,
"server": server,
"allowRemote": downloadRemote,
"allowRedirect": canRedirect,
})

if !util.IsGlobalAdmin(user.UserId) && util.IsHostIgnored(server) {
Expand Down Expand Up @@ -111,13 +123,15 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
FetchRemoteIfNeeded: downloadRemote,
BlockForReadUntil: blockFor,
RecordOnly: false, // overridden
CanRedirect: canRedirect,
},
Width: width,
Height: height,
Method: method,
Animated: animated,
})
if err != nil {
var redirect datastores.RedirectError
if errors.Is(err, common.ErrMediaNotFound) {
return _responses.NotFoundError()
} else if errors.Is(err, common.ErrMediaTooLarge) {
Expand Down Expand Up @@ -152,6 +166,8 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
TargetDisposition: "infer",
}
}
} else if errors.As(err, &redirect) {
return _responses.Redirect(redirect.RedirectUrl)
}
rctx.Log.Error("Unexpected error locating media: ", err)
sentry.CaptureException(err)
Expand Down
9 changes: 9 additions & 0 deletions config.sample.yaml
Expand Up @@ -199,6 +199,15 @@ datastores:
# An optional storage class for tuning how the media is stored at s3.
# See https://aws.amazon.com/s3/storage-classes/ for details; uncomment to use.
#storageClass: STANDARD
# When set, if the requesting user/server supports being redirected, and MMR is capable
# of performing that redirection, they will be redirected to the given object location.
# The object ID used in S3 is assumed to be the file name, and will simply be appended.
# It is therefore important to include any trailing slashes or path information. For
# example, an object with ID "hello/world" will get converted to "https://mycdn.example.org/hello/world".
# Note that MMR may not redirect in all cases, even if the client/server requests the
# capability. MMR may still be responsible for bandwidth charges incurred from going to
# the bucket directly.
#publicBaseUrl: "https://mycdn.example.org/"

# Options for controlling archives. Archives are exports of a particular user's content for
# the purpose of GDPR or moving media to a different server.
Expand Down
19 changes: 19 additions & 0 deletions datastores/download.go
Expand Up @@ -2,6 +2,7 @@ package datastores

import (
"errors"
"fmt"
"io"
"os"
"path"
Expand Down Expand Up @@ -35,3 +36,21 @@ func Download(ctx rcontext.RequestContext, ds config.DatastoreConfig, dsFileName

return rsc, err
}

func DownloadOrRedirect(ctx rcontext.RequestContext, ds config.DatastoreConfig, dsFileName string) (io.ReadSeekCloser, error) {
if ds.Type != "s3" {
return Download(ctx, ds, dsFileName)
}

s3c, err := getS3(ds)
if err != nil {
return nil, err
}

if s3c.publicBaseUrl != "" {
metrics.S3Operations.With(prometheus.Labels{"operation": "RedirectGetObject"}).Inc()
return nil, redirect(fmt.Sprintf("%s%s", s3c.publicBaseUrl, dsFileName))
}

return Download(ctx, ds, dsFileName)
}
15 changes: 15 additions & 0 deletions datastores/redirect.go
@@ -0,0 +1,15 @@
package datastores

import "errors"

type RedirectError struct {
error
RedirectUrl string
}

func redirect(url string) RedirectError {
return RedirectError{
error: errors.New("redirection"),
RedirectUrl: url,
}
}
15 changes: 9 additions & 6 deletions datastores/s3.go
Expand Up @@ -16,9 +16,10 @@ import (
var s3clients = &sync.Map{}

type s3 struct {
client *minio.Client
storageClass string
bucket string
client *minio.Client
storageClass string
bucket string
publicBaseUrl string
}

func ResetS3Clients() {
Expand All @@ -37,6 +38,7 @@ func getS3(ds config.DatastoreConfig) (*s3, error) {
region := ds.Options["region"]
storageClass, hasStorageClass := ds.Options["storageClass"]
useSslStr, hasSsl := ds.Options["ssl"]
publicBaseUrl := ds.Options["publicBaseUrl"]

if !hasStorageClass {
storageClass = "STANDARD"
Expand All @@ -59,9 +61,10 @@ func getS3(ds config.DatastoreConfig) (*s3, error) {
}

s3c := &s3{
client: client,
storageClass: storageClass,
bucket: bucket,
client: client,
storageClass: storageClass,
bucket: bucket,
publicBaseUrl: publicBaseUrl,
}
s3clients.Store(ds.Id, s3c)
return s3c, nil
Expand Down
33 changes: 30 additions & 3 deletions pipelines/_steps/download/open_stream.go
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"io"

"github.com/turt2live/matrix-media-repo/common/config"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/datastores"
Expand All @@ -12,16 +13,42 @@ import (
)

func OpenStream(ctx rcontext.RequestContext, media *database.Locatable) (io.ReadSeekCloser, error) {
reader, ds, err := doOpenStream(ctx, media)
if err != nil {
return nil, err
}
if reader != nil {
ctx.Log.Debugf("Got %s from cache", media.Sha256Hash)
return readers.NopSeekCloser(reader), nil
}

return datastores.Download(ctx, ds, media.Location)
}

func OpenOrRedirect(ctx rcontext.RequestContext, media *database.Locatable) (io.ReadSeekCloser, error) {
reader, ds, err := doOpenStream(ctx, media)
if err != nil {
return nil, err
}
if reader != nil {
ctx.Log.Debugf("Got %s from cache", media.Sha256Hash)
return readers.NopSeekCloser(reader), nil
}

return datastores.DownloadOrRedirect(ctx, ds, media.Location)
}

func doOpenStream(ctx rcontext.RequestContext, media *database.Locatable) (io.ReadSeekCloser, config.DatastoreConfig, error) {
reader, err := redislib.TryGetMedia(ctx, media.Sha256Hash)
if err != nil || reader != nil {
ctx.Log.Debugf("Got %s from cache", media.Sha256Hash)
return readers.NopSeekCloser(reader), err
return readers.NopSeekCloser(reader), config.DatastoreConfig{}, err
}

ds, ok := datastores.Get(ctx, media.DatastoreId)
if !ok {
return nil, errors.New("unable to locate datastore for media")
return nil, ds, errors.New("unable to locate datastore for media")
}

return datastores.Download(ctx, ds, media.Location)
return nil, ds, nil
}
9 changes: 7 additions & 2 deletions pipelines/pipeline_download/pipeline.go
Expand Up @@ -30,10 +30,11 @@ type DownloadOpts struct {
FetchRemoteIfNeeded bool
BlockForReadUntil time.Duration
RecordOnly bool
CanRedirect bool
}

func (o DownloadOpts) String() string {
return fmt.Sprintf("f=%t,b=%s,r=%t", o.FetchRemoteIfNeeded, o.BlockForReadUntil.String(), o.RecordOnly)
return fmt.Sprintf("f=%t,b=%s,r=%t,d=%t", o.FetchRemoteIfNeeded, o.BlockForReadUntil.String(), o.RecordOnly, o.CanRedirect)
}

func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts DownloadOpts) (*database.DbMedia, io.ReadCloser, error) {
Expand Down Expand Up @@ -71,7 +72,11 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
if opts.RecordOnly {
return nil, nil
}
return download.OpenStream(ctx, record.Locatable)
if opts.CanRedirect {
return download.OpenOrRedirect(ctx, record.Locatable)
} else {
return download.OpenStream(ctx, record.Locatable)
}
}

// Step 4: Media record unknown - download it (if possible)
Expand Down
13 changes: 11 additions & 2 deletions pipelines/pipeline_thumbnail/pipeline.go
Expand Up @@ -105,14 +105,23 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th
if opts.RecordOnly {
return nil, nil
}
return download.OpenStream(ctx, record.Locatable)
if opts.CanRedirect {
return download.OpenOrRedirect(ctx, record.Locatable)
} else {
return download.OpenStream(ctx, record.Locatable)
}
}

// Step 6: Generate the thumbnail and return that
record, r, err := thumbnails.Generate(ctx, mediaRecord, opts.Width, opts.Height, opts.Method, opts.Animated)
if err != nil {
if !opts.RecordOnly && errors.Is(err, common.ErrMediaDimensionsTooSmall) {
d, err := download.OpenStream(ctx, mediaRecord.Locatable)
var d io.ReadSeekCloser
if opts.CanRedirect {
d, err = download.OpenOrRedirect(ctx, mediaRecord.Locatable)
} else {
d, err = download.OpenStream(ctx, mediaRecord.Locatable)
}
if err != nil {
return nil, err
} else {
Expand Down

0 comments on commit cbc3677

Please sign in to comment.