diff --git a/package.json b/package.json index 1450f114..fd8eed12 100644 --- a/package.json +++ b/package.json @@ -13,7 +13,7 @@ "/types/**/*.d.ts" ], "scripts": { - "test": "tsc -noEmit -p tsconfig-test.json && jest --useStderr --runInBand --detectOpenHandles --forceExit", + "test": "tsc -noEmit -p tsconfig-test.json && jest --useStderr --runInBand --detectOpenHandles", "build": "npm run lint && tsc --emitDeclarationOnly && ./build.js", "prepack": "npm run build", "lint": "eslint --ext .ts,.js .", diff --git a/src/classifications/scheduler.ts b/src/classifications/scheduler.ts index 194a6f1d..27298bef 100644 --- a/src/classifications/scheduler.ts +++ b/src/classifications/scheduler.ts @@ -95,23 +95,27 @@ export default class ClassificationsScheduler extends CommandBase { pollForCompletion = (id: any): Promise => { return new Promise((resolve, reject) => { - setTimeout( - () => - reject( - new Error( - "classification didn't finish within configured timeout, " + - 'set larger timeout with .withWaitTimeout(timeout)' - ) - ), - this.waitTimeout - ); - - setInterval(() => { + const timeout = setTimeout(() => { + clearInterval(interval); + clearTimeout(timeout); + reject( + new Error( + "classification didn't finish within configured timeout, " + + 'set larger timeout with .withWaitTimeout(timeout)' + ) + ); + }, this.waitTimeout); + + const interval = setInterval(() => { new ClassificationsGetter(this.client) .withId(id) .do() .then((res: Classification) => { - if (res.status === 'completed') resolve(res); + if (res.status === 'completed') { + clearInterval(interval); + clearTimeout(timeout); + resolve(res); + } }); }, 500); }); diff --git a/src/connection/auth.ts b/src/connection/auth.ts index 71bcf2a5..afba8755 100644 --- a/src/connection/auth.ts +++ b/src/connection/auth.ts @@ -6,17 +6,22 @@ interface AuthenticatorResult { refreshToken: string; } +interface OidcCredentials { + silentRefresh: boolean; +} + export interface OidcAuthFlow { refresh: () => Promise; } export class OidcAuthenticator { private readonly http: HttpClient; - private readonly creds: any; + private readonly creds: OidcCredentials; private accessToken: string; private refreshToken?: string; private expiresAt: number; private refreshRunning: boolean; + private refreshInterval!: NodeJS.Timeout; constructor(http: HttpClient, creds: any) { this.http = http; @@ -57,10 +62,7 @@ export class OidcAuthenticator { this.accessToken = resp.accessToken; this.expiresAt = resp.expiresAt; this.refreshToken = resp.refreshToken; - if (!this.refreshRunning && this.refreshTokenProvided()) { - this.runBackgroundTokenRefresh(authenticator); - this.refreshRunning = true; - } + this.startTokenRefresh(authenticator); }); }; @@ -75,17 +77,25 @@ export class OidcAuthenticator { }); }; - runBackgroundTokenRefresh = (authenticator: { refresh: () => any }) => { - setInterval(async () => { - // check every 30s if the token will expire in <= 1m, - // if so, refresh - if (this.expiresAt - Date.now() <= 60_000) { - const resp = await authenticator.refresh(); - this.accessToken = resp.accessToken; - this.expiresAt = resp.expiresAt; - this.refreshToken = resp.refreshToken; - } - }, 30_000); + startTokenRefresh = (authenticator: { refresh: () => any }) => { + if (this.creds.silentRefresh && !this.refreshRunning && this.refreshTokenProvided()) { + this.refreshInterval = setInterval(async () => { + // check every 30s if the token will expire in <= 1m, + // if so, refresh + if (this.expiresAt - Date.now() <= 60_000) { + const resp = await authenticator.refresh(); + this.accessToken = resp.accessToken; + this.expiresAt = resp.expiresAt; + this.refreshToken = resp.refreshToken; + } + }, 30_000); + this.refreshRunning = true; + } + }; + + stopTokenRefresh = () => { + clearInterval(this.refreshInterval); + this.refreshRunning = false; }; refreshTokenProvided = () => { @@ -109,16 +119,19 @@ export interface UserPasswordCredentialsInput { username: string; password?: string; scopes?: any[]; + silentRefresh?: boolean; } -export class AuthUserPasswordCredentials { +export class AuthUserPasswordCredentials implements OidcCredentials { private username: string; private password?: string; private scopes?: any[]; + public readonly silentRefresh: boolean; constructor(creds: UserPasswordCredentialsInput) { this.username = creds.username; this.password = creds.password; this.scopes = creds.scopes; + this.silentRefresh = parseSilentRefresh(creds.silentRefresh); } } @@ -190,18 +203,21 @@ export interface AccessTokenCredentialsInput { accessToken: string; expiresIn: number; refreshToken?: string; + silentRefresh?: boolean; } -export class AuthAccessTokenCredentials { +export class AuthAccessTokenCredentials implements OidcCredentials { public readonly accessToken: string; public readonly expiresAt: number; public readonly refreshToken?: string; + public readonly silentRefresh: boolean; constructor(creds: AccessTokenCredentialsInput) { this.validate(creds); this.accessToken = creds.accessToken; this.expiresAt = calcExpirationEpoch(creds.expiresIn); this.refreshToken = creds.refreshToken; + this.silentRefresh = parseSilentRefresh(creds.silentRefresh); } validate = (creds: AccessTokenCredentialsInput) => { @@ -270,15 +286,18 @@ class AccessTokenAuthenticator implements OidcAuthFlow { export interface ClientCredentialsInput { clientSecret: string; scopes?: any[]; + silentRefresh?: boolean; } -export class AuthClientCredentials { +export class AuthClientCredentials implements OidcCredentials { private clientSecret: any; private scopes?: any[]; + public readonly silentRefresh: boolean; constructor(creds: ClientCredentialsInput) { this.clientSecret = creds.clientSecret; this.scopes = creds.scopes; + this.silentRefresh = parseSilentRefresh(creds.silentRefresh); } } @@ -345,3 +364,12 @@ export class ApiKey { function calcExpirationEpoch(expiresIn: number): number { return Date.now() + (expiresIn - 2) * 1000; // -2 for some lag } + +function parseSilentRefresh(silentRefresh: boolean | undefined): boolean { + // Silent token refresh by default + if (silentRefresh === undefined) { + return true; + } else { + return silentRefresh; + } +} diff --git a/src/connection/index.ts b/src/connection/index.ts index 27ba2adc..6a1edffd 100644 --- a/src/connection/index.ts +++ b/src/connection/index.ts @@ -8,10 +8,10 @@ import { Variables } from 'graphql-request'; export default class Connection { private apiKey?: string; - private oidcAuth?: OidcAuthenticator; private authEnabled: boolean; private gql: GraphQLClient; public readonly http: HttpClient; + public oidcAuth?: OidcAuthenticator; constructor(params: ConnectionParams) { params = this.sanitizeParams(params); diff --git a/src/connection/journey.test.ts b/src/connection/journey.test.ts index 5974ee3f..e25080ff 100644 --- a/src/connection/journey.test.ts +++ b/src/connection/journey.test.ts @@ -21,6 +21,7 @@ describe('connection', () => { authClientSecret: new AuthUserPasswordCredentials({ username: 'ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net', password: process.env.WCS_DUMMY_CI_PW, + silentRefresh: false, }), }); @@ -46,6 +47,7 @@ describe('connection', () => { host: 'localhost:8081', authClientSecret: new AuthClientCredentials({ clientSecret: process.env.AZURE_CLIENT_SECRET, + silentRefresh: false, }), }); @@ -72,6 +74,7 @@ describe('connection', () => { authClientSecret: new AuthClientCredentials({ clientSecret: process.env.OKTA_CLIENT_SECRET, scopes: ['some_scope'], + silentRefresh: false, }), }); @@ -98,6 +101,7 @@ describe('connection', () => { authClientSecret: new AuthUserPasswordCredentials({ username: 'test@test.de', password: process.env.OKTA_DUMMY_CI_PW, + silentRefresh: false, }), }); @@ -124,6 +128,7 @@ describe('connection', () => { authClientSecret: new AuthUserPasswordCredentials({ username: 'ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net', password: process.env.WCS_DUMMY_CI_PW, + silentRefresh: false, }), }); @@ -168,6 +173,7 @@ describe('connection', () => { authClientSecret: new AuthUserPasswordCredentials({ username: 'ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net', password: process.env.WCS_DUMMY_CI_PW, + silentRefresh: false, }), }); // obtain access token with user/pass so we can @@ -189,6 +195,7 @@ describe('connection', () => { .do() .then((res: any) => { expect(res.version).toBeDefined(); + client.oidcAuth?.stopTokenRefresh(); }) .catch((e: any) => { throw new Error('it should not have errord: ' + e); @@ -207,6 +214,7 @@ describe('connection', () => { authClientSecret: new AuthUserPasswordCredentials({ username: 'ms_2d0e007e7136de11d5f29fce7a53dae219a51458@existiert.net', password: process.env.WCS_DUMMY_CI_PW, + silentRefresh: false, }), }); // obtain access token with user/pass so we can @@ -231,6 +239,7 @@ describe('connection', () => { .then((resp) => { expect(resp).toBeDefined(); expect(resp != '').toBeTruthy(); + conn.oidcAuth?.stopTokenRefresh(); }) .catch((e: any) => { throw new Error('it should not have errord: ' + e); diff --git a/src/connection/unit.test.ts b/src/connection/unit.test.ts index 09268eac..c061f764 100644 --- a/src/connection/unit.test.ts +++ b/src/connection/unit.test.ts @@ -17,6 +17,7 @@ describe('mock server auth tests', () => { authClientSecret: new AuthClientCredentials({ clientSecret: 'supersecret', scopes: ['some_scope'], + silentRefresh: false, }), }); @@ -58,6 +59,7 @@ describe('mock server auth tests', () => { expect(token).toEqual('access_token_000'); expect((conn as any).oidcAuth?.refreshToken).toEqual('refresh_token_000'); expect((conn as any).oidcAuth?.expiresAt).toBeGreaterThan(Date.now()); + conn.oidcAuth?.stopTokenRefresh(); }) .catch((e) => { throw new Error('it should not have failed: ' + e); @@ -94,6 +96,7 @@ describe('mock server auth tests', () => { expect(token).toEqual('access_token_000'); expect((conn as any).oidcAuth?.refreshToken).toEqual('refresh_token_000'); expect((conn as any).oidcAuth?.expiresAt).toBeGreaterThan(Date.now()); + conn.oidcAuth?.stopTokenRefresh(); }) .catch((e) => { throw new Error('it should not have failed: ' + e); diff --git a/src/index.ts b/src/index.ts index c91d53a6..c2f369a9 100644 --- a/src/index.ts +++ b/src/index.ts @@ -14,6 +14,7 @@ import { AuthAccessTokenCredentials, AuthClientCredentials, AuthUserPasswordCredentials, + OidcAuthenticator, } from './connection/auth'; import MetaGetter from './misc/metaGetter'; import { EmbeddedDB, EmbeddedOptions } from './embedded'; @@ -38,6 +39,7 @@ export interface WeaviateClient { backup: Backup; cluster: Cluster; embedded?: EmbeddedDB; + oidcAuth?: OidcAuthenticator; } const app = { @@ -67,9 +69,8 @@ const app = { cluster: cluster(conn), }; - if (params.embedded) { - ifc.embedded = new EmbeddedDB(params.embedded); - } + if (params.embedded) ifc.embedded = new EmbeddedDB(params.embedded); + if (conn.oidcAuth) ifc.oidcAuth = conn.oidcAuth; return ifc; },