diff --git a/pkg/bundle/bundle.go b/pkg/bundle/bundle.go index 376895c..a5294cf 100644 --- a/pkg/bundle/bundle.go +++ b/pkg/bundle/bundle.go @@ -42,6 +42,8 @@ var ErrMissingEnvelope = fmt.Errorf("%w: missing envelope", ErrInvalidAttestatio var ErrDecodingJSON = fmt.Errorf("%w: decoding json", ErrInvalidAttestation) var ErrDecodingB64 = fmt.Errorf("%w: decoding base64", ErrInvalidAttestation) +const mediaTypeBase = "application/vnd.dev.sigstore.bundle" + func ErrValidationError(err error) error { return fmt.Errorf("%w: %w", ErrValidation, err) } @@ -113,17 +115,40 @@ func (b *ProtobufBundle) validate() error { return nil } +// MediaTypeString returns a mediatype string for the specified bundle version. +// The function returns an error if the resulting string does validate. +func MediaTypeString(version string) (string, error) { + if version == "" { + return "", fmt.Errorf("unable to build media type string, no version defined") + } + + var mtString string + + version = strings.TrimPrefix(version, "v") + mtString = fmt.Sprintf("%s.v%s+json", mediaTypeBase, strings.TrimPrefix(version, "v")) + + if version == "0.1" || version == "0.2" { + mtString = fmt.Sprintf("%s+json;version=%s", mediaTypeBase, strings.TrimPrefix(version, "v")) + } + + if _, err := getBundleVersion(mtString); err != nil { + return "", fmt.Errorf("unable to build mediatype: %w", err) + } + + return mtString, nil +} + func getBundleVersion(mediaType string) (string, error) { switch mediaType { - case "application/vnd.dev.sigstore.bundle+json;version=0.1": + case mediaTypeBase + "+json;version=0.1": return "v0.1", nil - case "application/vnd.dev.sigstore.bundle+json;version=0.2": + case mediaTypeBase + "+json;version=0.2": return "v0.2", nil - case "application/vnd.dev.sigstore.bundle+json;version=0.3": + case mediaTypeBase + "+json;version=0.3": return "v0.3", nil } - if strings.HasPrefix(mediaType, "application/vnd.dev.sigstore.bundle.v") && strings.HasSuffix(mediaType, "+json") { - version := strings.TrimPrefix(mediaType, "application/vnd.dev.sigstore.bundle.") + if strings.HasPrefix(mediaType, mediaTypeBase+".v") && strings.HasSuffix(mediaType, "+json") { + version := strings.TrimPrefix(mediaType, mediaTypeBase+".") version = strings.TrimSuffix(version, "+json") if semver.IsValid(version) { return version, nil diff --git a/pkg/bundle/bundle_test.go b/pkg/bundle/bundle_test.go index f1d9419..9704041 100644 --- a/pkg/bundle/bundle_test.go +++ b/pkg/bundle/bundle_test.go @@ -17,6 +17,8 @@ package bundle import ( "fmt" "testing" + + "github.com/stretchr/testify/require" ) func Test_getBundleVersion(t *testing.T) { @@ -94,3 +96,31 @@ func Test_getBundleVersion(t *testing.T) { }) } } + +func TestMediaTypeString(t *testing.T) { + t.Parallel() + for _, tc := range []struct { + name string + ver string + expected string + mustErr bool + }{ + {"normal-semver", "v0.3", "application/vnd.dev.sigstore.bundle.v0.3+json", false}, + {"old-semver1", "v0.1", "application/vnd.dev.sigstore.bundle+json;version=0.1", false}, + {"old-semver2", "v0.2", "application/vnd.dev.sigstore.bundle+json;version=0.2", false}, + {"blank", "", "", true}, + {"invalid", "garbage", "", true}, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + res, err := MediaTypeString(tc.ver) + if tc.mustErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tc.expected, res) + }) + } +}