Skip to content

Commit

Permalink
chore: module-sig-verify cleanup
Browse files Browse the repository at this point in the history
Make module-sig-verify code more idiomatic.

Co-authored-by: Andrey Smirnov <andrey.smirnov@talos-systems.com>
Signed-off-by: Noel Georgi <git@frezbo.dev>
(cherry picked from commit 07bb61e)
  • Loading branch information
frezbo authored and smira committed Apr 11, 2023
1 parent be87b65 commit 69045b7
Showing 1 changed file with 28 additions and 39 deletions.
67 changes: 28 additions & 39 deletions hack/module-sig-verify/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ import (
"crypto/rsa"
"crypto/sha512"
"crypto/x509"
"encoding/hex"
"encoding/binary"
"flag"
"fmt"
"io"
"os"
"strconv"

"go.mozilla.org/pkcs7"
)
Expand All @@ -36,8 +35,10 @@ import (
const (
// SignedModuleMagic is the magic string appended to the end of a signed module.
SignedModuleMagic = "~Module signature appended~\n"
// SignedModuleMagicLength is the length of the magic string.
SignedModuleMagicLength = int64(len(SignedModuleMagic))
// ModuleSignatureInfoLength is the length of the signature info.
ModuleSignatureInfoLength = 12
ModuleSignatureInfoLength int64 = 12
)

var (
Expand Down Expand Up @@ -72,56 +73,58 @@ func main() {
fmt.Println(err)
os.Exit(1)
}
defer moduleData.Close() //nolint:errcheck

if err := verifyModule(crt, moduleData); err != nil {
fmt.Println(err)
os.Exit(1)
os.Exit(1) //nolint:gocritic
}
}

func parseModuleInput(module string) (*bytes.Reader, error) {
type noOPCloser struct {
io.ReadSeeker
}

func (noOPCloser) Close() error { return nil }

func parseModuleInput(module string) (io.ReadSeekCloser, error) {
if module == "-" {
moduleData, err := io.ReadAll(os.Stdin)
if err != nil {
return nil, fmt.Errorf("failed to read module from stdin: %w", err)
}

return bytes.NewReader(moduleData), nil
return noOPCloser{bytes.NewReader(moduleData)}, nil
}

moduleData, err := os.ReadFile(module)
moduleData, err := os.Open(module)
if err != nil {
return nil, fmt.Errorf("failed to open file %s: %w", module, err)
}

return bytes.NewReader(moduleData), nil
return moduleData, nil
}

func verifyModule(crt *x509.Certificate, moduleData *bytes.Reader) error {
fileLen := moduleData.Size()
signedModuleMagicStart := fileLen - int64(len(SignedModuleMagic))

_, err := moduleData.Seek(signedModuleMagicStart, 0)
func verifyModule(crt *x509.Certificate, moduleData io.ReadSeeker) error {
_, err := moduleData.Seek(-SignedModuleMagicLength, io.SeekEnd)
if err != nil {
return fmt.Errorf("failed to seek to %d in file %s: %w", signedModuleMagicStart, module, err)
return fmt.Errorf("failed to seek to %d in file %s: %w", -SignedModuleMagicLength, module, err)
}

magicBytes := make([]byte, len(SignedModuleMagic))

_, err = moduleData.Read(magicBytes)
if err != nil {
return fmt.Errorf("failed to read %d bytes from file %s: %w", len(SignedModuleMagic), module, err)
return fmt.Errorf("failed to read %d bytes from file %s: %w", SignedModuleMagicLength, module, err)
}

if string(magicBytes) != SignedModuleMagic {
return fmt.Errorf("file %s is not a signed module", module)
}

signatureInfoStart := signedModuleMagicStart - ModuleSignatureInfoLength

_, err = moduleData.Seek(signatureInfoStart, 0)
_, err = moduleData.Seek(-SignedModuleMagicLength-ModuleSignatureInfoLength, io.SeekCurrent)
if err != nil {
return fmt.Errorf("failed to seek to %d in file %s: %w", signatureInfoStart, module, err)
return fmt.Errorf("failed to seek to %d in file %s: %w", -SignedModuleMagicLength-ModuleSignatureInfoLength, module, err)
}

signatureBytes := make([]byte, ModuleSignatureInfoLength)
Expand All @@ -131,17 +134,13 @@ func verifyModule(crt *x509.Certificate, moduleData *bytes.Reader) error {
return fmt.Errorf("failed to read %d bytes from file %s: %w", ModuleSignatureInfoLength, module, err)
}

// The signature length is encoded in the last 2 bytes of the signature info.
signatureLength, err := strconv.ParseInt(hex.EncodeToString(signatureBytes[(len(signatureBytes)-2):]), 16, 64)
if err != nil {
return fmt.Errorf("failed to parse signature length %w", err)
}
// The signature length is encoded in the last 4 bytes of the signature info.
// https://github.com/torvalds/linux/blob/master/scripts/sign-file.c#L62-L70
signatureLength := int64(binary.BigEndian.Uint32(signatureBytes[(len(signatureBytes) - 4):]))

signatureStart := signatureInfoStart - signatureLength

_, err = moduleData.Seek(signatureStart, 0)
_, err = moduleData.Seek(-ModuleSignatureInfoLength-signatureLength, io.SeekCurrent)
if err != nil {
return fmt.Errorf("failed to seek to %d in file %s: %w", signatureStart, module, err)
return fmt.Errorf("failed to seek to %d in file %s: %w", -ModuleSignatureInfoLength-signatureLength, module, err)
}

signature := make([]byte, signatureLength)
Expand All @@ -151,26 +150,16 @@ func verifyModule(crt *x509.Certificate, moduleData *bytes.Reader) error {
return fmt.Errorf("failed to read %d bytes from file %s: %w", signatureLength, module, err)
}

moduleWithSignatureLength := fileLen - int64(len(SignedModuleMagic))

_, err = moduleData.Seek(0, 0)
unsignedModuleLength, err := moduleData.Seek(-signatureLength, io.SeekCurrent)
if err != nil {
return fmt.Errorf("failed to seek to %d in file %s: %w", 0, module, err)
}

moduleWithSignature := make([]byte, moduleWithSignatureLength)

_, err = moduleData.Read(moduleWithSignature)
if err != nil {
return fmt.Errorf("failed to read %d bytes from file %s: %w", moduleWithSignatureLength, module, err)
}

_, err = moduleData.Seek(0, 0)
if err != nil {
return fmt.Errorf("failed to seek to %d in file %s: %w", 0, module, err)
}

unsignedModuleLength := fileLen - signatureLength - ModuleSignatureInfoLength - int64(len(SignedModuleMagic))
unsignedModuleData := make([]byte, unsignedModuleLength)

_, err = moduleData.Read(unsignedModuleData)
Expand Down

0 comments on commit 69045b7

Please sign in to comment.