Skip to content

Commit

Permalink
feat(core,common,testing): support overriding middleware for testing
Browse files Browse the repository at this point in the history
closes: nestjs#8777
  • Loading branch information
schiemon committed Oct 9, 2023
1 parent 88c8cf8 commit 334dff7
Show file tree
Hide file tree
Showing 11 changed files with 480 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import {
Injectable,
MiddlewareConsumer,
Module,
NestMiddleware,
} from '@nestjs/common';
import { Test } from '@nestjs/testing';
import * as request from 'supertest';
import { expect } from 'chai';

describe('Middleware overriding', () => {
@Injectable()
class MiddlewareA implements NestMiddleware {
use(req, res, next) {
middlewareAApplied = true;
next();
}
}

function MiddlewareAOverride(req, res, next) {
middlewareAOverrideApplied = true;
next();
}

function MiddlewareB(req, res, next) {
middlewareBApplied = true;
next();
}

@Injectable()
class MiddlewareBOverride implements NestMiddleware {
use(req, res, next) {
middlewareBOverrideApplied = true;
next();
}
}

@Injectable()
class MiddlewareC implements NestMiddleware {
use(req, res, next) {
middlewareCApplied = true;
next();
}
}

@Injectable()
class MiddlewareC1Override implements NestMiddleware {
use(req, res, next) {
middlewareC1OverrideApplied = true;
next();
}
}

function MiddlewareC2Override(req, res, next) {
middlewareC2OverrideApplied = true;
next();
}

@Module({})
class AppModule {
configure(consumer: MiddlewareConsumer) {
return consumer
.apply(MiddlewareA)
.forRoutes('a')
.apply(MiddlewareB)
.forRoutes('b')
.apply(MiddlewareC)
.forRoutes('c');
}
}

let middlewareAApplied: boolean;
let middlewareAOverrideApplied: boolean;

let middlewareBApplied: boolean;
let middlewareBOverrideApplied: boolean;

let middlewareCApplied: boolean;
let middlewareC1OverrideApplied: boolean;
let middlewareC2OverrideApplied: boolean;

const resetMiddlewareApplicationFlags = () => {
middlewareAApplied =
middlewareAOverrideApplied =
middlewareBApplied =
middlewareBOverrideApplied =
middlewareCApplied =
middlewareC1OverrideApplied =
middlewareC2OverrideApplied =
false;
};

beforeEach(() => {
resetMiddlewareApplicationFlags();
});
it('should override class middleware', async () => {
const testingModule = await Test.createTestingModule({
imports: [AppModule],
})
.overrideMiddleware(MiddlewareA)
.useMiddleware(MiddlewareAOverride)
.overrideMiddleware(MiddlewareC)
.useMiddleware(MiddlewareC1Override, MiddlewareC2Override)
.compile();

const app = testingModule.createNestApplication();
await app.init();

await request(app.getHttpServer()).get('/a');

expect(middlewareAApplied).to.be.false;
expect(middlewareAOverrideApplied).to.be.true;
expect(middlewareBApplied).to.be.false;
expect(middlewareBOverrideApplied).to.be.false;
expect(middlewareCApplied).to.be.false;
expect(middlewareC1OverrideApplied).to.be.false;
expect(middlewareC2OverrideApplied).to.be.false;
resetMiddlewareApplicationFlags();

await request(app.getHttpServer()).get('/b');

expect(middlewareAApplied).to.be.false;
expect(middlewareAOverrideApplied).to.be.false;
expect(middlewareBApplied).to.be.true;
expect(middlewareBOverrideApplied).to.be.false;
expect(middlewareCApplied).to.be.false;
expect(middlewareC1OverrideApplied).to.be.false;
expect(middlewareC2OverrideApplied).to.be.false;
resetMiddlewareApplicationFlags();

await request(app.getHttpServer()).get('/c');

expect(middlewareAApplied).to.be.false;
expect(middlewareAOverrideApplied).to.be.false;
expect(middlewareBApplied).to.be.false;
expect(middlewareBOverrideApplied).to.be.false;
expect(middlewareCApplied).to.be.false;
expect(middlewareC1OverrideApplied).to.be.true;
expect(middlewareC2OverrideApplied).to.be.true;
resetMiddlewareApplicationFlags();

await app.close();
});

it('should override functional middleware', async () => {
const testingModule = await Test.createTestingModule({
imports: [AppModule],
})
.overrideMiddleware(MiddlewareB)
.useMiddleware(MiddlewareBOverride)
.compile();

const app = testingModule.createNestApplication();
await app.init();

await request(app.getHttpServer()).get('/a');

expect(middlewareAApplied).to.be.true;
expect(middlewareAOverrideApplied).to.be.false;
expect(middlewareBApplied).to.be.false;
expect(middlewareBOverrideApplied).to.be.false;
expect(middlewareCApplied).to.be.false;
expect(middlewareC1OverrideApplied).to.be.false;
expect(middlewareC2OverrideApplied).to.be.false;
resetMiddlewareApplicationFlags();

await request(app.getHttpServer()).get('/b');

expect(middlewareAApplied).to.be.false;
expect(middlewareAOverrideApplied).to.be.false;
expect(middlewareBApplied).to.be.false;
expect(middlewareBOverrideApplied).to.be.true;
expect(middlewareCApplied).to.be.false;
expect(middlewareC1OverrideApplied).to.be.false;
expect(middlewareC2OverrideApplied).to.be.false;
resetMiddlewareApplicationFlags();

await request(app.getHttpServer()).get('/c');

expect(middlewareAApplied).to.be.false;
expect(middlewareAOverrideApplied).to.be.false;
expect(middlewareBApplied).to.be.false;
expect(middlewareBOverrideApplied).to.be.false;
expect(middlewareCApplied).to.be.true;
expect(middlewareC1OverrideApplied).to.be.false;
expect(middlewareC2OverrideApplied).to.be.false;
resetMiddlewareApplicationFlags();

await app.close();
});
});
17 changes: 17 additions & 0 deletions integration/testing/module-override/tsconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"compilerOptions": {
"module": "commonjs",
"declaration": true,
"removeComments": true,
"emitDecoratorMetadata": true,
"experimentalDecorators": true,
"allowSyntheticDefaultImports": true,
"target": "ES2021",
"sourceMap": true,
"outDir": "./dist",
"baseUrl": "./",
"incremental": true,
"skipLibCheck": true
},
"include": ["src/**/*"]
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,17 @@ export interface MiddlewareConsumer {
* @returns {MiddlewareConfigProxy}
*/
apply(...middleware: (Type<any> | Function)[]): MiddlewareConfigProxy;

/**
* Replaces the currently applied middleware with a new (set of) middleware.
*
* @param {Type | Function} middlewareToReplace middleware class/function to be replaced.
* @param {(Type | Function)[]} middlewareReplacement middleware class/function(s) that serve as a replacement for {@link middlewareToReplace}.
*
* @returns {MiddlewareConsumer}
*/
replace(
middlewareToReplace: Type<any> | Function,
...middlewareReplacement: (Type<any> | Function)[]
): MiddlewareConsumer;
}
60 changes: 48 additions & 12 deletions packages/core/middleware/builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@ import { RouteInfoPathExtractor } from './route-info-path-extractor';
import { RoutesMapper } from './routes-mapper';
import { filterMiddleware } from './utils';

type MiddlewareConfigurationContext = {
middleware: (Type<any> | Function)[];
routes: RouteInfo[];
excludedRoutes: RouteInfo[];
};

export class MiddlewareBuilder implements MiddlewareConsumer {
private readonly middlewareCollection = new Set<MiddlewareConfiguration>();
private readonly middlewareConfigurationContexts: MiddlewareConfigurationContext[] =
[];

constructor(
private readonly routesMapper: RoutesMapper,
Expand All @@ -34,8 +41,39 @@ export class MiddlewareBuilder implements MiddlewareConsumer {
);
}

public replace(
middlewareToReplace: Type<any> | Function,
...middlewareReplacements: Array<Type<any> | Function>
): MiddlewareBuilder {
for (const currentConfigurationContext of this
.middlewareConfigurationContexts) {
currentConfigurationContext.middleware = flatten(
currentConfigurationContext.middleware.map(middleware =>
middleware === middlewareToReplace
? middlewareReplacements
: middleware,
),
) as (Type<any> | Function)[];
}

return this;
}

public getMiddlewareConfigurationContexts(): MiddlewareConfigurationContext[] {
return this.middlewareConfigurationContexts;
}

public build(): MiddlewareConfiguration[] {
return [...this.middlewareCollection];
return this.middlewareConfigurationContexts.map(
({ middleware, routes, excludedRoutes }) => ({
middleware: filterMiddleware(
middleware,
excludedRoutes,
this.getHttpAdapter(),
),
forRoutes: routes,
}),
);
}

public getHttpAdapter(): HttpServer {
Expand Down Expand Up @@ -68,19 +106,17 @@ export class MiddlewareBuilder implements MiddlewareConsumer {
public forRoutes(
...routes: Array<string | Type<any> | RouteInfo>
): MiddlewareConsumer {
const { middlewareCollection } = this.builder;
const { middlewareConfigurationContexts } = this.builder;

const flattedRoutes = this.getRoutesFlatList(routes);
const forRoutes = this.removeOverlappedRoutes(flattedRoutes);
const configuration = {
middleware: filterMiddleware(
this.middleware,
this.excludedRoutes,
this.builder.getHttpAdapter(),
),
forRoutes,
};
middlewareCollection.add(configuration);

middlewareConfigurationContexts.push({
middleware: this.middleware,
routes: forRoutes,
excludedRoutes: this.excludedRoutes,
});

return this.builder;
}

Expand Down
Loading

0 comments on commit 334dff7

Please sign in to comment.