Skip to content

Commit

Permalink
test(server): auth tests (#6135)
Browse files Browse the repository at this point in the history
  • Loading branch information
forehalo committed Mar 26, 2024
1 parent 1c9d899 commit 1a1af83
Show file tree
Hide file tree
Showing 19 changed files with 1,058 additions and 96 deletions.
28 changes: 4 additions & 24 deletions packages/backend/server/src/core/auth/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,15 @@ import {
Controller,
Get,
Header,
HttpStatus,
Post,
Query,
Req,
Res,
} from '@nestjs/common';
import type { Request, Response } from 'express';

import {
Config,
PaymentRequiredException,
URLHelper,
} from '../../fundamentals';
import { PaymentRequiredException, URLHelper } from '../../fundamentals';
import { UserService } from '../user';
import { validators } from '../utils/validators';
import { CurrentUser } from './current-user';
Expand All @@ -33,7 +30,6 @@ class SignInCredential {
@Controller('/api/auth')
export class AuthController {
constructor(
private readonly config: Config,
private readonly url: URLHelper,
private readonly auth: AuthService,
private readonly user: UserService,
Expand Down Expand Up @@ -64,7 +60,7 @@ export class AuthController {
);

await this.auth.setCookie(req, res, user);
res.send(user);
res.status(HttpStatus.OK).send(user);
} else {
// send email magic link
const user = await this.user.findUserByEmail(credential.email);
Expand All @@ -77,7 +73,7 @@ export class AuthController {
throw new Error('Failed to send sign-in email.');
}

res.send({
res.status(HttpStatus.OK).send({
email: credential.email,
});
}
Expand Down Expand Up @@ -162,22 +158,6 @@ export class AuthController {
return this.url.safeRedirect(res, redirectUri);
}

@Get('/authorize')
async authorize(
@CurrentUser() user: CurrentUser,
@Query('redirect_uri') redirect_uri?: string
) {
const session = await this.auth.createUserSession(
user,
undefined,
this.config.auth.accessToken.ttl
);

this.url.link(redirect_uri ?? '/open-app/redirect', {
token: session.sessionId,
});
}

@Public()
@Get('/session')
async currentSessionUser(@CurrentUser() user?: CurrentUser) {
Expand Down
4 changes: 2 additions & 2 deletions packages/backend/server/src/core/auth/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { UserModule } from '../user';
import { AuthController } from './controller';
import { AuthResolver } from './resolver';
import { AuthService } from './service';
import { TokenService } from './token';
import { TokenService, TokenType } from './token';

@Module({
imports: [FeatureModule, UserModule],
Expand All @@ -17,5 +17,5 @@ export class AuthModule {}

export * from './guard';
export { ClientTokenType } from './resolver';
export { AuthService };
export { AuthService, TokenService, TokenType };
export * from './current-user';
6 changes: 4 additions & 2 deletions packages/backend/server/src/core/auth/resolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import {
import type { Request, Response } from 'express';

import { CloudThrottlerGuard, Config, Throttle } from '../../fundamentals';
import { UserService } from '../user';
import { UserType } from '../user/types';
import { validators } from '../utils/validators';
import { CurrentUser } from './current-user';
Expand Down Expand Up @@ -48,6 +49,7 @@ export class AuthResolver {
constructor(
private readonly config: Config,
private readonly auth: AuthService,
private readonly user: UserService,
private readonly token: TokenService
) {}

Expand Down Expand Up @@ -165,7 +167,7 @@ export class AuthResolver {
throw new ForbiddenException('Invalid token');
}

await this.auth.changePassword(user.email, newPassword);
await this.auth.changePassword(user.id, newPassword);

return user;
}
Expand Down Expand Up @@ -319,7 +321,7 @@ export class AuthResolver {
throw new ForbiddenException('Invalid token');
}

const hasRegistered = await this.auth.getUserByEmail(email);
const hasRegistered = await this.user.findUserByEmail(email);

if (hasRegistered) {
if (hasRegistered.id !== user.id) {
Expand Down
76 changes: 30 additions & 46 deletions packages/backend/server/src/core/auth/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,39 @@ import {
BadRequestException,
Injectable,
NotAcceptableException,
NotFoundException,
OnApplicationBootstrap,
} from '@nestjs/common';
import type { User } from '@prisma/client';
import { PrismaClient } from '@prisma/client';
import type { CookieOptions, Request, Response } from 'express';
import { assign, omit } from 'lodash-es';

import {
Config,
CryptoHelper,
MailService,
SessionCache,
} from '../../fundamentals';
import { Config, CryptoHelper, MailService } from '../../fundamentals';
import { FeatureManagementService } from '../features/management';
import { UserService } from '../user/service';
import type { CurrentUser } from './current-user';

export function parseAuthUserSeqNum(value: any) {
let seq: number = 0;
switch (typeof value) {
case 'number': {
return value;
seq = value;
break;
}
case 'string': {
value = Number.parseInt(value);
return Number.isNaN(value) ? 0 : value;
const result = value.match(/^([\d{0, 10}])$/);
if (result?.[1]) {
seq = Number(result[1]);
}
break;
}

default: {
return 0;
seq = 0;
}
}

return Math.max(0, seq);
}

export function sessionUser(
Expand All @@ -57,7 +58,6 @@ export class AuthService implements OnApplicationBootstrap {
sameSite: 'lax',
httpOnly: true,
path: '/',
domain: this.config.host,
secure: this.config.https,
};
static readonly sessionCookieName = 'sid';
Expand All @@ -69,8 +69,7 @@ export class AuthService implements OnApplicationBootstrap {
private readonly mailer: MailService,
private readonly feature: FeatureManagementService,
private readonly user: UserService,
private readonly crypto: CryptoHelper,
private readonly cache: SessionCache
private readonly crypto: CryptoHelper
) {}

async onApplicationBootstrap() {
Expand All @@ -90,7 +89,7 @@ export class AuthService implements OnApplicationBootstrap {
email: string,
password: string
): Promise<CurrentUser> {
const user = await this.getUserByEmail(email);
const user = await this.user.findUserByEmail(email);

if (user) {
throw new BadRequestException('Email was taken');
Expand All @@ -111,12 +110,12 @@ export class AuthService implements OnApplicationBootstrap {
const user = await this.user.findUserWithHashedPasswordByEmail(email);

if (!user) {
throw new NotFoundException('User Not Found');
throw new NotAcceptableException('Invalid sign in credentials');
}

if (!user.password) {
throw new NotAcceptableException(
'User Password is not set. Should login throw email link.'
'User Password is not set. Should login through email link.'
);
}

Expand All @@ -126,28 +125,12 @@ export class AuthService implements OnApplicationBootstrap {
);

if (!passwordMatches) {
throw new NotAcceptableException('Incorrect Password');
throw new NotAcceptableException('Invalid sign in credentials');
}

return sessionUser(user);
}

async getUserWithCache(token: string, seq = 0) {
const cacheKey = `session:${token}:${seq}`;
let user = await this.cache.get<CurrentUser | null>(cacheKey);
if (user) {
return user;
}

user = await this.getUser(token, seq);

if (user) {
await this.cache.set(cacheKey, user);
}

return user;
}

async getUser(token: string, seq = 0): Promise<CurrentUser | null> {
const session = await this.getSession(token);

Expand Down Expand Up @@ -198,7 +181,16 @@ export class AuthService implements OnApplicationBootstrap {
// Session
// | { user: LimitedUser { email, avatarUrl }, expired: true }
// | { user: User, expired: false }
return users.map(sessionUser);
return session.userSessions
.map(userSession => {
// keep users in the same order as userSessions
const user = users.find(({ id }) => id === userSession.userId);
if (!user) {
return null;
}
return sessionUser(user);
})
.filter(Boolean) as CurrentUser[];
}

async signOut(token: string, seq = 0) {
Expand Down Expand Up @@ -319,12 +311,8 @@ export class AuthService implements OnApplicationBootstrap {
});
}

async getUserByEmail(email: string) {
return this.user.findUserByEmail(email);
}

async changePassword(email: string, newPassword: string): Promise<User> {
const user = await this.getUserByEmail(email);
async changePassword(id: string, newPassword: string): Promise<User> {
const user = await this.user.findUserById(id);

if (!user) {
throw new BadRequestException('Invalid email');
Expand All @@ -343,11 +331,7 @@ export class AuthService implements OnApplicationBootstrap {
}

async changeEmail(id: string, newEmail: string): Promise<User> {
const user = await this.db.user.findUnique({
where: {
id,
},
});
const user = await this.user.findUserById(id);

if (!user) {
throw new BadRequestException('Invalid email');
Expand Down
12 changes: 12 additions & 0 deletions packages/backend/server/src/core/auth/token.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { randomUUID } from 'node:crypto';

import { Injectable } from '@nestjs/common';
import { Cron, CronExpression } from '@nestjs/schedule';
import { PrismaClient } from '@prisma/client';

import { CryptoHelper } from '../../fundamentals/helpers';
Expand Down Expand Up @@ -81,4 +82,15 @@ export class TokenService {

return valid ? record : null;
}

@Cron(CronExpression.EVERY_DAY_AT_MIDNIGHT)
cleanExpiredTokens() {
return this.db.verificationToken.deleteMany({
where: {
expiresAt: {
lte: new Date(),
},
},
});
}
}
6 changes: 0 additions & 6 deletions packages/backend/server/src/fundamentals/config/def.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ declare global {
}
}

export enum ExternalAccount {
github = 'github',
google = 'google',
firebase = 'firebase',
}

export type ServerFlavor = 'allinone' | 'graphql' | 'sync';
export type AFFINE_ENV = 'dev' | 'beta' | 'production';
export type NODE_ENV = 'development' | 'test' | 'production';
Expand Down
8 changes: 7 additions & 1 deletion packages/backend/server/src/fundamentals/prisma/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@ import { PrismaService } from './service';
// only `PrismaClient` can be injected
const clientProvider: Provider = {
provide: PrismaClient,
useClass: PrismaService,
useFactory: () => {
if (PrismaService.INSTANCE) {
return PrismaService.INSTANCE;
}

return new PrismaService();
},
};

@Global()
Expand Down
5 changes: 4 additions & 1 deletion packages/backend/server/src/fundamentals/prisma/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ export class PrismaService
}

async onModuleDestroy(): Promise<void> {
await this.$disconnect();
if (!AFFiNE.node.test) {
await this.$disconnect();
PrismaService.INSTANCE = null;
}
}
}
2 changes: 1 addition & 1 deletion packages/backend/server/src/plugins/oauth/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ export class OAuthController {
const provider = this.providerFactory.get(providerName);

if (!provider) {
throw new BadRequestException('Invalid provider');
throw new BadRequestException('Invalid OAuth provider');
}

const state = await this.oauth.saveOAuthState({
Expand Down
2 changes: 1 addition & 1 deletion packages/backend/server/src/plugins/oauth/register.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export function registerOAuthProvider(
@Injectable()
export class OAuthProviderFactory {
get providers() {
return PROVIDERS.keys();
return Array.from(PROVIDERS.keys());
}

get(name: OAuthProviderName): OAuthProvider | undefined {
Expand Down
Loading

0 comments on commit 1a1af83

Please sign in to comment.