Skip to content

Commit

Permalink
feat(types): allow generics to aid in CryptoKey or KeyObject narrowin…
Browse files Browse the repository at this point in the history
…g of KeyLike
  • Loading branch information
panva committed Feb 27, 2023
1 parent 20610a9 commit 6effa4d
Show file tree
Hide file tree
Showing 15 changed files with 79 additions and 53 deletions.
4 changes: 2 additions & 2 deletions src/jwe/compact/decrypt.ts
Expand Up @@ -47,11 +47,11 @@ export async function compactDecrypt(
* @param getKey Function resolving Private Key or Secret to decrypt the JWE with.
* @param options JWE Decryption options.
*/
export async function compactDecrypt(
export async function compactDecrypt<T extends KeyLike = KeyLike>(
jwe: string | Uint8Array,
getKey: CompactDecryptGetKey,
options?: DecryptOptions,
): Promise<CompactDecryptResult & ResolvedKey>
): Promise<CompactDecryptResult & ResolvedKey<T>>
export async function compactDecrypt(
jwe: string | Uint8Array,
key: KeyLike | Uint8Array | CompactDecryptGetKey,
Expand Down
4 changes: 2 additions & 2 deletions src/jwe/flattened/decrypt.ts
Expand Up @@ -66,11 +66,11 @@ export function flattenedDecrypt(
* @param getKey Function resolving Private Key or Secret to decrypt the JWE with.
* @param options JWE Decryption options.
*/
export function flattenedDecrypt(
export function flattenedDecrypt<T extends KeyLike = KeyLike>(
jwe: FlattenedJWE,
getKey: FlattenedDecryptGetKey,
options?: DecryptOptions,
): Promise<FlattenedDecryptResult & ResolvedKey>
): Promise<FlattenedDecryptResult & ResolvedKey<T>>
export async function flattenedDecrypt(
jwe: FlattenedJWE,
key: KeyLike | Uint8Array | FlattenedDecryptGetKey,
Expand Down
4 changes: 2 additions & 2 deletions src/jwe/general/decrypt.ts
Expand Up @@ -61,11 +61,11 @@ export function generalDecrypt(
* @param getKey Function resolving Private Key or Secret to decrypt the JWE with.
* @param options JWE Decryption options.
*/
export function generalDecrypt(
export function generalDecrypt<T extends KeyLike = KeyLike>(
jwe: GeneralJWE,
getKey: GeneralDecryptGetKey,
options?: DecryptOptions,
): Promise<GeneralDecryptResult & ResolvedKey>
): Promise<GeneralDecryptResult & ResolvedKey<T>>
export async function generalDecrypt(
jwe: GeneralJWE,
key: KeyLike | Uint8Array | GeneralDecryptGetKey,
Expand Down
6 changes: 3 additions & 3 deletions src/jwk/embedded.ts
Expand Up @@ -24,10 +24,10 @@ import { JWSInvalid } from '../util/errors.js'
* console.log(payload)
* ```
*/
export async function EmbeddedJWK(
): Promise<KeyLike> {
export async function EmbeddedJWK<T extends KeyLike = KeyLike>(
protectedHeader?: JWSHeaderParameters,
token?: FlattenedJWSInput,
): Promise<T> {
const joseHeader = {
...protectedHeader,
...token?.header,
Expand All @@ -36,7 +36,7 @@ export async function EmbeddedJWK(
throw new JWSInvalid('"jwk" (JSON Web Key) Header Parameter must be a JSON object')
}

const key = await importJWK({ ...joseHeader.jwk, ext: true }, joseHeader.alg!, true)
const key = await importJWK<T>({ ...joseHeader.jwk, ext: true }, joseHeader.alg!, true)

if (key instanceof Uint8Array || key.type !== 'public') {
throw new JWSInvalid('"jwk" (JSON Web Key) Header Parameter must be a public key')
Expand Down
36 changes: 23 additions & 13 deletions src/jwks/local.ts
Expand Up @@ -28,8 +28,8 @@ function getKtyFromAlg(alg: unknown) {
}
}

interface Cache {
[alg: string]: KeyLike
interface Cache<T extends KeyLike = KeyLike> {
[alg: string]: T
}

/** @private */
Expand Down Expand Up @@ -59,10 +59,10 @@ function clone<T>(obj: T): T {
}

/** @private */
export class LocalJWKSet {
export class LocalJWKSet<T extends KeyLike = KeyLike> {
protected _jwks?: JSONWebKeySet

private _cached: WeakMap<JWK, Cache> = new WeakMap()
private _cached: WeakMap<JWK, Cache<T>> = new WeakMap()

constructor(jwks: unknown) {
if (!isJWKSLike(jwks)) {
Expand All @@ -72,7 +72,7 @@ export class LocalJWKSet {
this._jwks = clone<JSONWebKeySet>(jwks)
}

async getKey(protectedHeader?: JWSHeaderParameters, token?: FlattenedJWSInput): Promise<KeyLike> {
async getKey(protectedHeader?: JWSHeaderParameters, token?: FlattenedJWSInput): Promise<T> {
const { alg, kid } = { ...protectedHeader, ...token?.header }
const kty = getKtyFromAlg(alg)

Expand Down Expand Up @@ -137,7 +137,7 @@ export class LocalJWKSet {
error[Symbol.asyncIterator] = async function* () {
for (const jwk of candidates) {
try {
yield await importWithAlgCache(_cached, jwk, alg!)
yield await importWithAlgCache<T>(_cached, jwk, alg!)
} catch {
continue
}
Expand All @@ -147,20 +147,24 @@ export class LocalJWKSet {
throw error
}

return importWithAlgCache(this._cached, jwk, alg!)
return importWithAlgCache<T>(this._cached, jwk, alg!)
}
}

async function importWithAlgCache(cache: WeakMap<JWK, Cache>, jwk: JWK, alg: string) {
async function importWithAlgCache<T extends KeyLike = KeyLike>(
cache: WeakMap<JWK, Cache<T>>,
jwk: JWK,
alg: string,
) {
const cached = cache.get(jwk) || cache.set(jwk, {}).get(jwk)!
if (cached[alg] === undefined) {
const keyObject = <KeyLike>await importJWK({ ...jwk, ext: true }, alg)
const key = await importJWK<T>({ ...jwk, ext: true }, alg)

if (keyObject.type !== 'public') {
if (key instanceof Uint8Array || key.type !== 'public') {
throw new JWKSInvalid('JSON Web Key Set members must be public keys')
}

cached[alg] = keyObject
cached[alg] = key
}

return cached[alg]
Expand Down Expand Up @@ -240,6 +244,12 @@ async function importWithAlgCache(cache: WeakMap<JWK, Cache>, jwk: JWK, alg: str
*
* @param jwks JSON Web Key Set formatted object.
*/
export function createLocalJWKSet(jwks: JSONWebKeySet) {
return LocalJWKSet.prototype.getKey.bind(new LocalJWKSet(jwks))
export function createLocalJWKSet<T extends KeyLike = KeyLike>(jwks: JSONWebKeySet) {
const set = new LocalJWKSet<T>(jwks)
return async function (
protectedHeader?: JWSHeaderParameters,
token?: FlattenedJWSInput,
): Promise<T> {
return set.getKey(protectedHeader, token)
}
}
17 changes: 13 additions & 4 deletions src/jwks/remote.ts
Expand Up @@ -40,7 +40,7 @@ export interface RemoteJWKSetOptions {
headers?: Record<string, string>
}

class RemoteJWKSet extends LocalJWKSet {
class RemoteJWKSet<T extends KeyLike = KeyLike> extends LocalJWKSet<T> {
private _url: URL

private _timeoutDuration: number
Expand Down Expand Up @@ -84,7 +84,7 @@ class RemoteJWKSet extends LocalJWKSet {
: false
}

async getKey(protectedHeader?: JWSHeaderParameters, token?: FlattenedJWSInput): Promise<KeyLike> {
async getKey(protectedHeader?: JWSHeaderParameters, token?: FlattenedJWSInput): Promise<T> {
if (!this._jwks || !this.fresh()) {
await this.reload()
}
Expand Down Expand Up @@ -199,6 +199,15 @@ class RemoteJWKSet extends LocalJWKSet {
* @param url URL to fetch the JSON Web Key Set from.
* @param options Options for the remote JSON Web Key Set.
*/
export function createRemoteJWKSet(url: URL, options?: RemoteJWKSetOptions) {
return RemoteJWKSet.prototype.getKey.bind(new RemoteJWKSet(url, options))
export function createRemoteJWKSet<T extends KeyLike = KeyLike>(
url: URL,
options?: RemoteJWKSetOptions,
) {
const set = new RemoteJWKSet<T>(url, options)
return async function (
protectedHeader?: JWSHeaderParameters,
token?: FlattenedJWSInput,
): Promise<T> {
return set.getKey(protectedHeader, token)
}
}
4 changes: 2 additions & 2 deletions src/jws/compact/verify.ts
Expand Up @@ -51,11 +51,11 @@ export function compactVerify(
* @param getKey Function resolving a key to verify the JWS with.
* @param options JWS Verify options.
*/
export function compactVerify(
export function compactVerify<T extends KeyLike = KeyLike>(
jws: string | Uint8Array,
getKey: CompactVerifyGetKey,
options?: VerifyOptions,
): Promise<CompactVerifyResult & ResolvedKey>
): Promise<CompactVerifyResult & ResolvedKey<T>>
export async function compactVerify(
jws: string | Uint8Array,
key: KeyLike | Uint8Array | CompactVerifyGetKey,
Expand Down
4 changes: 2 additions & 2 deletions src/jws/flattened/verify.ts
Expand Up @@ -64,11 +64,11 @@ export function flattenedVerify(
* @param getKey Function resolving a key to verify the JWS with.
* @param options JWS Verify options.
*/
export function flattenedVerify(
export function flattenedVerify<T extends KeyLike = KeyLike>(
jws: FlattenedJWSInput,
getKey: FlattenedVerifyGetKey,
options?: VerifyOptions,
): Promise<FlattenedVerifyResult & ResolvedKey>
): Promise<FlattenedVerifyResult & ResolvedKey<T>>
export async function flattenedVerify(
jws: FlattenedJWSInput,
key: KeyLike | Uint8Array | FlattenedVerifyGetKey,
Expand Down
4 changes: 2 additions & 2 deletions src/jws/general/verify.ts
Expand Up @@ -60,11 +60,11 @@ export function generalVerify(
* @param getKey Function resolving a key to verify the JWS with.
* @param options JWS Verify options.
*/
export function generalVerify(
export function generalVerify<T extends KeyLike = KeyLike>(
jws: GeneralJWSInput,
getKey: GeneralVerifyGetKey,
options?: VerifyOptions,
): Promise<GeneralVerifyResult & ResolvedKey>
): Promise<GeneralVerifyResult & ResolvedKey<T>>
export async function generalVerify(
jws: GeneralJWSInput,
key: KeyLike | Uint8Array | GeneralVerifyGetKey,
Expand Down
4 changes: 2 additions & 2 deletions src/jwt/decrypt.ts
Expand Up @@ -56,11 +56,11 @@ export async function jwtDecrypt(
* @param getKey Function resolving Private Key or Secret to decrypt and verify the JWT with.
* @param options JWT Decryption and JWT Claims Set validation options.
*/
export async function jwtDecrypt(
export async function jwtDecrypt<T extends KeyLike = KeyLike>(
jwt: string | Uint8Array,
getKey: JWTDecryptGetKey,
options?: JWTDecryptOptions,
): Promise<JWTDecryptResult & ResolvedKey>
): Promise<JWTDecryptResult & ResolvedKey<T>>
export async function jwtDecrypt(
jwt: string | Uint8Array,
key: KeyLike | Uint8Array | JWTDecryptGetKey,
Expand Down
4 changes: 2 additions & 2 deletions src/jwt/verify.ts
Expand Up @@ -123,11 +123,11 @@ export async function jwtVerify(
* @param getKey Function resolving a key to verify the JWT with.
* @param options JWT Decryption and JWT Claims Set validation options.
*/
export async function jwtVerify(
export async function jwtVerify<T extends KeyLike = KeyLike>(
jwt: string | Uint8Array,
getKey: JWTVerifyGetKey,
options?: JWTVerifyOptions,
): Promise<JWTVerifyResult & ResolvedKey>
): Promise<JWTVerifyResult & ResolvedKey<T>>

export async function jwtVerify(
jwt: string | Uint8Array,
Expand Down
11 changes: 6 additions & 5 deletions src/key/generate_key_pair.ts
Expand Up @@ -2,12 +2,12 @@ import { generateKeyPair as generate } from '../runtime/generate.js'

import type { KeyLike } from '../types.d'

export interface GenerateKeyPairResult {
export interface GenerateKeyPairResult<T extends KeyLike = KeyLike> {
/** The generated Private Key. */
privateKey: KeyLike
privateKey: T

/** Public Key corresponding to the generated Private Key. */
publicKey: KeyLike
publicKey: T
}

export interface GenerateKeyPairOptions {
Expand Down Expand Up @@ -49,9 +49,10 @@ export interface GenerateKeyPairOptions {
* @param alg JWA Algorithm Identifier to be used with the generated key pair.
* @param options Additional options passed down to the key pair generation.
*/
export async function generateKeyPair(
export async function generateKeyPair<T extends KeyLike = KeyLike>(
alg: string,
options?: GenerateKeyPairOptions,
): Promise<GenerateKeyPairResult> {
): Promise<GenerateKeyPairResult<T>> {
// @ts-ignore
return generate(alg, options)
}
5 changes: 3 additions & 2 deletions src/key/generate_secret.ts
Expand Up @@ -27,9 +27,10 @@ export interface GenerateSecretOptions {
* @param alg JWA Algorithm Identifier to be used with the generated secret.
* @param options Additional options passed down to the secret generation.
*/
export async function generateSecret(
export async function generateSecret<T extends KeyLike = KeyLike>(
alg: string,
options?: GenerateSecretOptions,
): Promise<KeyLike | Uint8Array> {
): Promise<T | Uint8Array> {
// @ts-ignore
return generate(alg, options)
}
21 changes: 13 additions & 8 deletions src/key/import.ts
Expand Up @@ -35,14 +35,15 @@ export interface PEMImportOptions {
* @param alg (Only effective in Web Crypto API runtimes) JSON Web Algorithm identifier to be used
* with the imported key, its presence is only enforced in Web Crypto API runtimes.
*/
export async function importSPKI(
export async function importSPKI<T extends KeyLike = KeyLike>(
spki: string,
alg: string,
options?: PEMImportOptions,
): Promise<KeyLike> {
): Promise<T> {
if (typeof spki !== 'string' || spki.indexOf('-----BEGIN PUBLIC KEY-----') !== 0) {
throw new TypeError('"spki" must be SPKI formatted string')
}
// @ts-ignore
return fromSPKI(spki, alg, options)
}

Expand Down Expand Up @@ -73,14 +74,15 @@ export async function importSPKI(
* @param alg (Only effective in Web Crypto API runtimes) JSON Web Algorithm identifier to be used
* with the imported key, its presence is only enforced in Web Crypto API runtimes.
*/
export async function importX509(
export async function importX509<T extends KeyLike = KeyLike>(
x509: string,
alg: string,
options?: PEMImportOptions,
): Promise<KeyLike> {
): Promise<T> {
if (typeof x509 !== 'string' || x509.indexOf('-----BEGIN CERTIFICATE-----') !== 0) {
throw new TypeError('"x509" must be X.509 formatted string')
}
// @ts-ignore
return fromX509(x509, alg, options)
}

Expand All @@ -105,14 +107,15 @@ export async function importX509(
* @param alg (Only effective in Web Crypto API runtimes) JSON Web Algorithm identifier to be used
* with the imported key, its presence is only enforced in Web Crypto API runtimes.
*/
export async function importPKCS8(
export async function importPKCS8<T extends KeyLike = KeyLike>(
pkcs8: string,
alg: string,
options?: PEMImportOptions,
): Promise<KeyLike> {
): Promise<T> {
if (typeof pkcs8 !== 'string' || pkcs8.indexOf('-----BEGIN PRIVATE KEY-----') !== 0) {
throw new TypeError('"pkcs8" must be PKCS#8 formatted string')
}
// @ts-ignore
return fromPKCS8(pkcs8, alg, options)
}

Expand Down Expand Up @@ -154,11 +157,11 @@ export async function importPKCS8(
* @param octAsKeyObject Forces a symmetric key to be imported to a KeyObject or CryptoKey. Default
* is true unless JWK "ext" (Extractable) is true.
*/
export async function importJWK(
export async function importJWK<T extends KeyLike = KeyLike>(
jwk: JWK,
alg?: string,
octAsKeyObject?: boolean,
): Promise<KeyLike | Uint8Array> {
): Promise<T | Uint8Array> {
if (!isObject(jwk)) {
throw new TypeError('JWK must be an object')
}
Expand All @@ -174,6 +177,7 @@ export async function importJWK(
octAsKeyObject ??= jwk.ext !== true

if (octAsKeyObject) {
// @ts-ignore
return asKeyObject({ ...jwk, alg, ext: jwk.ext ?? false })
}

Expand All @@ -186,6 +190,7 @@ export async function importJWK(
}
case 'EC':
case 'OKP':
// @ts-ignore
return asKeyObject({ ...jwk, alg })
default:
throw new JOSENotSupported('Unsupported "kty" (Key Type) Parameter value')
Expand Down
4 changes: 2 additions & 2 deletions src/types.d.ts
Expand Up @@ -601,9 +601,9 @@ export interface JWTDecryptResult {
protectedHeader: CompactJWEHeaderParameters
}

export interface ResolvedKey {
export interface ResolvedKey<T extends KeyLike = KeyLike> {
/** Key resolved from the key resolver function. */
key: KeyLike | Uint8Array
key: T | Uint8Array
}

/** Recognized Compact JWS Header Parameters, any other Header Members may also be present. */
Expand Down

0 comments on commit 6effa4d

Please sign in to comment.