Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Asynchronous uploads (MSC2246) #364

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 18 additions & 1 deletion api/r0/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,38 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserI
downloadRemote = parsedFlag
}

var asyncWaitMs *int = nil
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ideally this is a non-pointer int throughout, for code clarity more than anything. A wait time of zero is still semantically possible.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having it as a nillable pointer was to make it distinct it's not set vs it being set to zero. Not setting it from the request makes the code pick up the default from the configuration file.

if rctx.Config.Features.MSC2246Async.Enabled {
// request default wait time if feature enabled
var parsedInt int = -1
maxStallMs := r.URL.Query().Get("fi.mau.msc2246.max_stall_ms")
if maxStallMs != "" {
var err error
parsedInt, err = strconv.Atoi(maxStallMs)
if err != nil {
return api.InternalServerError("fi.mau.msc2246.max_stall_ms does not appear to be a number")
}
}
asyncWaitMs = &parsedInt
}

rctx = rctx.LogWithFields(logrus.Fields{
"mediaId": mediaId,
"server": server,
"filename": filename,
"allowRemote": downloadRemote,
})

streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, false, rctx)
streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, false, asyncWaitMs, rctx)
if err != nil {
if err == common.ErrMediaNotFound {
return api.NotFoundError()
} else if err == common.ErrMediaTooLarge {
return api.RequestTooLarge()
} else if err == common.ErrMediaQuarantined {
return api.NotFoundError() // We lie for security
} else if err == common.ErrNotYetUploaded {
return api.NotYetUploaded()
}
rctx.Log.Error("Unexpected error locating media: " + err.Error())
sentry.CaptureException(err)
Expand Down
19 changes: 18 additions & 1 deletion api/r0/thumbnail.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,21 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user api.User
downloadRemote = parsedFlag
}

var asyncWaitMs *int = nil
if rctx.Config.Features.MSC2246Async.Enabled {
// request default wait time if feature enabled
var parsedInt int = -1
maxStallMs := r.URL.Query().Get("fi.mau.msc2246.max_stall_ms")
if maxStallMs != "" {
var err error
parsedInt, err = strconv.Atoi(maxStallMs)
if err != nil {
return api.InternalServerError("fi.mau.msc2246.max_stall_ms does not appear to be a number")
}
}
asyncWaitMs = &parsedInt
}

rctx = rctx.LogWithFields(logrus.Fields{
"mediaId": mediaId,
"server": server,
Expand Down Expand Up @@ -87,12 +102,14 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user api.User
return api.BadRequest("Width and height must be greater than zero")
}

streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, rctx)
streamedThumbnail, err := thumbnail_controller.GetThumbnail(server, mediaId, width, height, animated, method, downloadRemote, asyncWaitMs, rctx)
if err != nil {
if err == common.ErrMediaNotFound {
return api.NotFoundError()
} else if err == common.ErrMediaTooLarge {
return api.RequestTooLarge()
} else if err == common.ErrNotYetUploaded {
return api.NotYetUploaded()
}
rctx.Log.Error("Unexpected error locating media: " + err.Error())
sentry.CaptureException(err)
Expand Down
47 changes: 46 additions & 1 deletion api/r0/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@ import (
"io/ioutil"
"net/http"
"path/filepath"
"time"

"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/api"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/controllers/info_controller"
"github.com/turt2live/matrix-media-repo/controllers/upload_controller"
"github.com/turt2live/matrix-media-repo/quota"
"github.com/turt2live/matrix-media-repo/util"
"github.com/turt2live/matrix-media-repo/util/cleanup"
)

Expand All @@ -22,14 +25,52 @@ type MediaUploadedResponse struct {
Blurhash string `json:"xyz.amorgan.blurhash,omitempty"`
}

type MediaCreatedResponse struct {
ContentUri string `json:"content_uri"`
UnusedExpiresAt int64 `json:"unused_expires_at"`
}

func CreateMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't r0, so should be in the unstable directory instead (alongside the media info and local copy endpoints).

media, _, err := upload_controller.CreateMedia(r.Host, rctx)
if err != nil {
rctx.Log.Error("Unexpected error creating media reference: " + err.Error())
return api.InternalServerError("Unexpected Error")
}

if err = upload_controller.PersistMedia(media, user.UserId, rctx); err != nil {
rctx.Log.Error("Unexpected error persisting media reference: " + err.Error())
return api.InternalServerError("Unexpected Error")
}

return &MediaCreatedResponse{
ContentUri: media.MxcUri(),
UnusedExpiresAt: time.Now().Unix() + int64(rctx.Config.Features.MSC2246Async.AsyncUploadExpirySecs),
}
}

func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo) interface{} {
var server = ""
var mediaId = ""

filename := filepath.Base(r.URL.Query().Get("filename"))
defer cleanup.DumpAndCloseStream(r.Body)

if rctx.Config.Features.MSC2246Async.Enabled {
params := mux.Vars(r)
server = params["server"]
mediaId = params["mediaId"]
}

rctx = rctx.LogWithFields(logrus.Fields{
"server": server,
"mediaId": mediaId,
"filename": filename,
})

if server != "" && (!util.IsServerOurs(server) || server != r.Host) {
return api.NotFoundError()
}

contentType := r.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/octet-stream" // binary
Expand Down Expand Up @@ -59,12 +100,16 @@ func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInf

contentLength := upload_controller.EstimateContentLength(r.ContentLength, r.Header.Get("Content-Length"))

media, err := upload_controller.UploadMedia(r.Body, contentLength, contentType, filename, user.UserId, r.Host, rctx)
media, err := upload_controller.UploadMedia(r.Body, contentLength, contentType, filename, user.UserId, r.Host, mediaId, rctx)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

StoreDirect should probably be used instead of UploadMedia when the media ID is known, similar to an import. Though, the upload size check might need to be moved up to somewhere. Possibly a new function which calls StoreDirect but is named OverwriteExisting or something?

if err != nil {
io.Copy(ioutil.Discard, r.Body) // Ditch the entire request

if err == common.ErrMediaQuarantined {
return api.BadRequest("This file is not permitted on this server")
} else if err == common.ErrCannotOverwriteMedia {
return api.CannotOverwriteMedia()
} else if err == common.ErrMediaNotFound {
return api.NotFoundError()
}

rctx.Log.Error("Unexpected error storing media: " + err.Error())
Expand Down
8 changes: 8 additions & 0 deletions api/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,11 @@ func BadRequest(message string) *ErrorResponse {
func QuotaExceeded() *ErrorResponse {
return &ErrorResponse{common.ErrCodeForbidden, "Quota Exceeded", common.ErrCodeQuotaExceeded}
}

func CannotOverwriteMedia() *ErrorResponse {
return &ErrorResponse{common.ErrCodeCannotOverwriteMedia, "Cannot overwrite media", common.ErrCodeCannotOverwriteMedia}
}

func NotYetUploaded() *ErrorResponse {
return &ErrorResponse{common.ErrCodeNotYetUploaded, "Media not yet uploaded", common.ErrCodeNotYetUploaded}
}
2 changes: 1 addition & 1 deletion api/unstable/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func MediaInfo(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo)
"allowRemote": downloadRemote,
})

streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, true, rctx)
streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, true, nil, rctx)
if err != nil {
if err == common.ErrMediaNotFound {
return api.NotFoundError()
Expand Down
4 changes: 2 additions & 2 deletions api/unstable/local_copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func LocalCopy(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo)

// TODO: There's a lot of room for improvement here. Instead of re-uploading media, we should just update the DB.

streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, true, rctx)
streamedMedia, err := download_controller.GetMedia(server, mediaId, downloadRemote, true, nil, rctx)
if err != nil {
if err == common.ErrMediaNotFound {
return api.NotFoundError()
Expand All @@ -60,7 +60,7 @@ func LocalCopy(r *http.Request, rctx rcontext.RequestContext, user api.UserInfo)
return &r0.MediaUploadedResponse{ContentUri: streamedMedia.KnownMedia.MxcUri()}
}

newMedia, err := upload_controller.UploadMedia(streamedMedia.Stream, streamedMedia.KnownMedia.SizeBytes, streamedMedia.KnownMedia.ContentType, streamedMedia.KnownMedia.UploadName, user.UserId, r.Host, rctx)
newMedia, err := upload_controller.UploadMedia(streamedMedia.Stream, streamedMedia.KnownMedia.SizeBytes, streamedMedia.KnownMedia.ContentType, streamedMedia.KnownMedia.UploadName, user.UserId, r.Host, "", rctx)
if err != nil {
rctx.Log.Error("Unexpected error storing media: " + err.Error())
sentry.CaptureException(err)
Expand Down
5 changes: 4 additions & 1 deletion api/webserver/route_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case common.ErrCodeUnknownToken:
statusCode = http.StatusUnauthorized
break
case common.ErrCodeNotFound:
case common.ErrCodeNotYetUploaded, common.ErrCodeNotFound:
statusCode = http.StatusNotFound
break
case common.ErrCodeMediaTooLarge:
Expand All @@ -161,6 +161,9 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case common.ErrCodeForbidden:
statusCode = http.StatusForbidden
break
case common.ErrCodeCannotOverwriteMedia:
statusCode = http.StatusConflict
break
default: // Treat as unknown (a generic server error)
statusCode = http.StatusInternalServerError
break
Expand Down
7 changes: 7 additions & 0 deletions api/webserver/webserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func Init() *sync.WaitGroup {
counter := &requestCounter{}

optionsHandler := handler{api.EmptyResponseHandler, "options_request", counter, false}
createHandler := handler{api.AccessTokenRequiredRoute(r0.CreateMedia), "create", counter, false}
uploadHandler := handler{api.AccessTokenRequiredRoute(r0.UploadMedia), "upload", counter, false}
downloadHandler := handler{api.AccessTokenOptionalRoute(r0.DownloadMedia), "download", counter, false}
thumbnailHandler := handler{api.AccessTokenOptionalRoute(r0.ThumbnailMedia), "thumbnail", counter, false}
Expand Down Expand Up @@ -158,6 +159,12 @@ func Init() *sync.WaitGroup {
}
}

if config.Get().Features.MSC2246Async.Enabled {
logrus.Info("Asynchronous uploads (MSC2246) enabled")
routes = append(routes, definedRoute{"/_matrix/media/unstable/fi.mau.msc2246/create", route{"POST", createHandler}})
routes = append(routes, definedRoute{"/_matrix/media/unstable/fi.mau.msc2246/upload/{server:[a-zA-Z0-9.:\\-_]+}/{mediaId:[^/]+}", route{"PUT", uploadHandler}})
}

if config.Get().Features.IPFS.Enabled {
routes = append(routes, definedRoute{features.IPFSDownloadRoute, route{"GET", ipfsDownloadHandler}})
routes = append(routes, definedRoute{features.IPFSLiveDownloadRouteR0, route{"GET", ipfsDownloadHandler}})
Expand Down
5 changes: 5 additions & 0 deletions common/config/conf_min_shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ func NewDefaultMinimumRepoConfig() MinimumRepoConfig {
YComponents: 3,
Punch: 1,
},
MSC2246Async: MSC2246Config{
Enabled: false,
AsyncUploadExpirySecs: 60,
AsyncDownloadDefaultWaitSecs: 20,
},
IPFS: IPFSConfig{
Enabled: false,
Daemon: IPFSDaemonConfig{
Expand Down
7 changes: 7 additions & 0 deletions common/config/models_domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ type TimeoutsConfig struct {

type FeatureConfig struct {
MSC2448Blurhash MSC2448Config `yaml:"MSC2448"`
MSC2246Async MSC2246Config `yaml:"MSC2246"`
IPFS IPFSConfig `yaml:"IPFS"`
Redis RedisConfig `yaml:"redis"`
}
Expand All @@ -103,6 +104,12 @@ type MSC2448Config struct {
Punch int `yaml:"punch"`
}

type MSC2246Config struct {
Enabled bool `yaml:"enabled"`
AsyncUploadExpirySecs int `yaml:"asyncUploadExpirySecs"`
AsyncDownloadDefaultWaitSecs int `yaml:"asyncDownloadDefaultWaitSecs"`
}

type IPFSConfig struct {
Enabled bool `yaml:"enabled"`
Daemon IPFSDaemonConfig `yaml:"builtInDaemon"`
Expand Down
2 changes: 2 additions & 0 deletions common/errorcodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ const ErrCodeRateLimitExceeded = "M_LIMIT_EXCEEDED"
const ErrCodeUnknown = "M_UNKNOWN"
const ErrCodeForbidden = "M_FORBIDDEN"
const ErrCodeQuotaExceeded = "M_QUOTA_EXCEEDED"
const ErrCodeCannotOverwriteMedia = "FI.MAU.MSC2246_CANNOT_OVEWRITE_MEDIA"
const ErrCodeNotYetUploaded = "FI.MAU.MSC2246_NOT_YET_UPLOADED"
2 changes: 2 additions & 0 deletions common/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ var ErrInvalidHost = errors.New("invalid host")
var ErrHostNotFound = errors.New("host not found")
var ErrHostBlacklisted = errors.New("host not allowed")
var ErrMediaQuarantined = errors.New("media quarantined")
var ErrCannotOverwriteMedia = errors.New("cannot overwrite media")
var ErrNotYetUploaded = errors.New("not yet uploaded")
14 changes: 14 additions & 0 deletions config.sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,20 @@ featureSupport:
# make the effect more subtle, larger values make it stronger.
punch: 1

# MSC2246 - Asynchronous uploads
MSC2246:
# Whether or not this MSC is enabled for use in the media repo
enabled: false

# The number of seconds an asynchronous upload is valid to be started after requesting a media
# id. After expiring the upload endpoint will return an error for the client.
asyncUploadExpirySecs: 60

# The number of seconds a download request for an asynchronous upload will stall before
# returning an error. This affects clients that do not support async uploads by making them
# wait by default. Setting to zero will disable this behavior unless the client requests it.
asyncDownloadDefaultWaitSecs: 20

# IPFS Support
# This is currently experimental and might not work at all.
IPFS:
Expand Down