Skip to content

Commit

Permalink
refactor for better error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
salrashid123 committed May 16, 2024
1 parent f823436 commit 897bb87
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 83 deletions.
58 changes: 25 additions & 33 deletions kms/kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,67 +61,59 @@ func NewKMSCrypto(conf *KMS) (KMS, error) {
if conf.ProjectId == "" {
return KMS{}, fmt.Errorf("ProjectID cannot be null")
}
return *conf, nil
}

func (t KMS) Public() crypto.PublicKey {
if t.publicKey == nil {
ctx := context.Background()
parentName := fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s/cryptoKeyVersions/%s", t.ProjectId, t.LocationId, t.KeyRing, t.Key, t.KeyVersion)
ctx := context.Background()
parentName := fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s/cryptoKeyVersions/%s", conf.ProjectId, conf.LocationId, conf.KeyRing, conf.Key, conf.KeyVersion)

kmsClient, err := cloudkms.NewKeyManagementClient(ctx)
if err != nil {
fmt.Printf("Error getting kms client %v", err)
return nil
}
defer kmsClient.Close()
kmsClient, err := cloudkms.NewKeyManagementClient(ctx)
if err != nil {
return KMS{}, fmt.Errorf("Error getting kms client %v", err)
}
defer kmsClient.Close()

dresp, err := kmsClient.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{Name: parentName})
if err != nil {
fmt.Printf("Error getting GetPublicKey %v", err)
return nil
}
pubKeyBlock, _ := pem.Decode([]byte(dresp.Pem))
dresp, err := kmsClient.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{Name: parentName})
if err != nil {
return KMS{}, fmt.Errorf("Error getting GetPublicKey %v", err)
}
pubKeyBlock, _ := pem.Decode([]byte(dresp.Pem))

t.publicKey, err = x509.ParsePKIXPublicKey(pubKeyBlock.Bytes)
if err != nil {
fmt.Printf("Error parsing PublicKey %v", err)
return nil
}
conf.publicKey, err = x509.ParsePKIXPublicKey(pubKeyBlock.Bytes)
if err != nil {
return KMS{}, fmt.Errorf("Error parsing PublicKey %v", err)
}

return *conf, nil
}

func (t KMS) Public() crypto.PublicKey {
return t.publicKey
}

func (t KMS) TLSCertificate() tls.Certificate {
func (t KMS) TLSCertificate() (tls.Certificate, error) {

if t.PublicKeyFile == "" {
fmt.Printf("Public X509 certificate not specified")
return tls.Certificate{}
return tls.Certificate{}, fmt.Errorf("public X509 certificate not specified")
}

pubPEM, err := os.ReadFile(t.PublicKeyFile)
if err != nil {
fmt.Printf("Unable to read keys %v", err)
return tls.Certificate{}
return tls.Certificate{}, fmt.Errorf("unable to read keys %v", err)
}
block, _ := pem.Decode([]byte(pubPEM))
if block == nil {
fmt.Printf("failed to parse PEM block containing the public key")
return tls.Certificate{}
return tls.Certificate{}, fmt.Errorf("failed to parse PEM block containing the public key")
}
pub, err := x509.ParseCertificate(block.Bytes)
if err != nil {
fmt.Printf("failed to parse public key: " + err.Error())
return tls.Certificate{}
return tls.Certificate{}, fmt.Errorf("failed to parse public key: %v ", err)
}
x509Certificate = *pub
var privKey crypto.PrivateKey = t
return tls.Certificate{
PrivateKey: privKey,
Leaf: &x509Certificate,
Certificate: [][]byte{x509Certificate.Raw},
}
}, nil
}

func (t KMS) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
Expand Down
88 changes: 38 additions & 50 deletions tpm/tpm.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,53 +69,45 @@ func NewTPMCrypto(conf *TPM) (TPM, error) {
if conf.TpmPath != "" && conf.KeyHandle == 0 {
return TPM{}, fmt.Errorf("salrashid123/x/oauth2/google: if TPMTokenConfig.TPMPath is specified, a KeyHandle must be set")
}
return *conf, nil
}

func (t TPM) Public() crypto.PublicKey {
if t.publicKey == nil {
t.refreshMutex.Lock()
defer t.refreshMutex.Unlock()

var rwc io.ReadWriteCloser
var k *client.Key
if t.TpmDevice == nil {
var err error
rwc, err = tpm2.OpenTPM(t.TpmPath)
if err != nil {
fmt.Printf("google: Unable to Read Public data from TPM: %v", err)
return nil
}
defer rwc.Close()
pcrsession, err := client.NewPCRSession(rwc, tpm2.PCRSelection{tpm2.AlgSHA256, t.PCRs})
if err != nil {
fmt.Printf("google: Unable to Read Public data from TPM: %v", err)
return nil
}
k, err = client.LoadCachedKey(rwc, tpmutil.Handle(t.KeyHandle), pcrsession)
if err != nil {
fmt.Printf("google: Unable to Read Public data from TPM: %v", err)
return nil
}
defer pcrsession.Close()
defer k.Close()
} else {
rwc = t.TpmDevice
k = t.Key
var rwc io.ReadWriteCloser
var k *client.Key
if conf.TpmDevice == nil {
var err error
rwc, err = tpm2.OpenTPM(conf.TpmPath)
if err != nil {
return TPM{}, fmt.Errorf("google: Unable to Read Public data from TPM: %v", err)
}

pub, _, _, err := tpm2.ReadPublic(rwc, k.Handle())
defer rwc.Close()
pcrsession, err := client.NewPCRSession(rwc, tpm2.PCRSelection{tpm2.AlgSHA256, conf.PCRs})
if err != nil {
fmt.Printf("google: Unable to Read Public data from TPM: %v", err)
return nil
return TPM{}, fmt.Errorf("google: Unable to Read Public data from TPM: %v", err)
}
pubKey, err := pub.Key()
k, err = client.LoadCachedKey(rwc, tpmutil.Handle(conf.KeyHandle), pcrsession)
if err != nil {
fmt.Printf("google: Unable to Read Public data from TPM: %v", err)
return nil
return TPM{}, fmt.Errorf("google: Unable to Read Public data from TPM: %v", err)
}
t.publicKey = pubKey
defer pcrsession.Close()
defer k.Close()
} else {
rwc = conf.TpmDevice
k = conf.Key
}

pub, _, _, err := tpm2.ReadPublic(rwc, k.Handle())
if err != nil {
return TPM{}, fmt.Errorf("google: Unable to Read Public data from TPM: %v", err)
}
pubKey, err := pub.Key()
if err != nil {
return TPM{}, fmt.Errorf("google: Unable to Read Public data from TPM: %v", err)
}
conf.publicKey = pubKey

return *conf, nil
}

func (t TPM) Public() crypto.PublicKey {
return t.publicKey
}

Expand Down Expand Up @@ -196,27 +188,23 @@ func (t TPM) Sign(rr io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte,
}
}

func (t TPM) TLSCertificate() tls.Certificate {
func (t TPM) TLSCertificate() (tls.Certificate, error) {

if t.PublicCertFile == "" {
fmt.Printf("Public X509 certificate not specified")
return tls.Certificate{}
return tls.Certificate{}, fmt.Errorf("Public X509 certificate not specified")
}

pubPEM, err := os.ReadFile(t.PublicCertFile)
if err != nil {
fmt.Printf("Unable to read keys %v", err)
return tls.Certificate{}
return tls.Certificate{}, fmt.Errorf("Unable to read public certificate file %v", err)
}
block, _ := pem.Decode([]byte(pubPEM))
if block == nil {
fmt.Printf("failed to parse PEM block containing the public key")
return tls.Certificate{}
return tls.Certificate{}, fmt.Errorf("failed to parse PEM block containing the public key")
}
pub, err := x509.ParseCertificate(block.Bytes)
if err != nil {
fmt.Printf("failed to parse public key: " + err.Error())
return tls.Certificate{}
return tls.Certificate{}, fmt.Errorf("Unable to read public certificate file %v", err)
}

t.x509Certificate = *pub
Expand All @@ -225,5 +213,5 @@ func (t TPM) TLSCertificate() tls.Certificate {
PrivateKey: privKey,
Leaf: &t.x509Certificate,
Certificate: [][]byte{t.x509Certificate.Raw},
}
}, nil
}

0 comments on commit 897bb87

Please sign in to comment.