Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 42 additions & 28 deletions apps/sim/app/api/auth/oauth/utils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,10 @@
* @vitest-environment node
*/

import { loggerMock } from '@sim/testing'
import { databaseMock, loggerMock } from '@sim/testing'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'

vi.mock('@sim/db', () => ({
db: {
select: vi.fn().mockReturnThis(),
from: vi.fn().mockReturnThis(),
where: vi.fn().mockReturnThis(),
limit: vi.fn().mockReturnValue([]),
update: vi.fn().mockReturnThis(),
set: vi.fn().mockReturnThis(),
orderBy: vi.fn().mockReturnThis(),
},
}))
vi.mock('@sim/db', () => databaseMock)

vi.mock('@/lib/oauth/oauth', () => ({
refreshOAuthToken: vi.fn(),
Expand All @@ -34,13 +24,36 @@ import {
refreshTokenIfNeeded,
} from '@/app/api/auth/oauth/utils'

const mockDbTyped = db as any
const mockDb = db as any
const mockRefreshOAuthToken = refreshOAuthToken as any

/**
* Creates a chainable mock for db.select() calls.
* Returns a nested chain: select() -> from() -> where() -> limit() / orderBy()
*/
function mockSelectChain(limitResult: unknown[]) {
const mockLimit = vi.fn().mockReturnValue(limitResult)
const mockOrderBy = vi.fn().mockReturnValue(limitResult)
const mockWhere = vi.fn().mockReturnValue({ limit: mockLimit, orderBy: mockOrderBy })
const mockFrom = vi.fn().mockReturnValue({ where: mockWhere })
mockDb.select.mockReturnValueOnce({ from: mockFrom })
return { mockFrom, mockWhere, mockLimit }
}

/**
* Creates a chainable mock for db.update() calls.
* Returns a nested chain: update() -> set() -> where()
*/
function mockUpdateChain() {
const mockWhere = vi.fn().mockResolvedValue({})
const mockSet = vi.fn().mockReturnValue({ where: mockWhere })
mockDb.update.mockReturnValueOnce({ set: mockSet })
return { mockSet, mockWhere }
}

describe('OAuth Utils', () => {
beforeEach(() => {
vi.clearAllMocks()
mockDbTyped.limit.mockReturnValue([])
})

afterEach(() => {
Expand All @@ -50,20 +63,20 @@ describe('OAuth Utils', () => {
describe('getCredential', () => {
it('should return credential when found', async () => {
const mockCredential = { id: 'credential-id', userId: 'test-user-id' }
mockDbTyped.limit.mockReturnValueOnce([mockCredential])
const { mockFrom, mockWhere, mockLimit } = mockSelectChain([mockCredential])

const credential = await getCredential('request-id', 'credential-id', 'test-user-id')

expect(mockDbTyped.select).toHaveBeenCalled()
expect(mockDbTyped.from).toHaveBeenCalled()
expect(mockDbTyped.where).toHaveBeenCalled()
expect(mockDbTyped.limit).toHaveBeenCalledWith(1)
expect(mockDb.select).toHaveBeenCalled()
expect(mockFrom).toHaveBeenCalled()
expect(mockWhere).toHaveBeenCalled()
expect(mockLimit).toHaveBeenCalledWith(1)

expect(credential).toEqual(mockCredential)
})

it('should return undefined when credential is not found', async () => {
mockDbTyped.limit.mockReturnValueOnce([])
mockSelectChain([])

const credential = await getCredential('request-id', 'nonexistent-id', 'test-user-id')

Expand Down Expand Up @@ -102,11 +115,12 @@ describe('OAuth Utils', () => {
refreshToken: 'new-refresh-token',
})

mockUpdateChain()

const result = await refreshTokenIfNeeded('request-id', mockCredential, 'credential-id')

expect(mockRefreshOAuthToken).toHaveBeenCalledWith('google', 'refresh-token')
expect(mockDbTyped.update).toHaveBeenCalled()
expect(mockDbTyped.set).toHaveBeenCalled()
expect(mockDb.update).toHaveBeenCalled()
expect(result).toEqual({ accessToken: 'new-token', refreshed: true })
})

Expand Down Expand Up @@ -152,7 +166,7 @@ describe('OAuth Utils', () => {
providerId: 'google',
userId: 'test-user-id',
}
mockDbTyped.limit.mockReturnValueOnce([mockCredential])
mockSelectChain([mockCredential])

const token = await refreshAccessTokenIfNeeded('credential-id', 'test-user-id', 'request-id')

Expand All @@ -169,7 +183,8 @@ describe('OAuth Utils', () => {
providerId: 'google',
userId: 'test-user-id',
}
mockDbTyped.limit.mockReturnValueOnce([mockCredential])
mockSelectChain([mockCredential])
mockUpdateChain()

mockRefreshOAuthToken.mockResolvedValueOnce({
accessToken: 'new-token',
Expand All @@ -180,13 +195,12 @@ describe('OAuth Utils', () => {
const token = await refreshAccessTokenIfNeeded('credential-id', 'test-user-id', 'request-id')

expect(mockRefreshOAuthToken).toHaveBeenCalledWith('google', 'refresh-token')
expect(mockDbTyped.update).toHaveBeenCalled()
expect(mockDbTyped.set).toHaveBeenCalled()
expect(mockDb.update).toHaveBeenCalled()
expect(token).toBe('new-token')
})

it('should return null if credential not found', async () => {
mockDbTyped.limit.mockReturnValueOnce([])
mockSelectChain([])

const token = await refreshAccessTokenIfNeeded('nonexistent-id', 'test-user-id', 'request-id')

Expand All @@ -202,7 +216,7 @@ describe('OAuth Utils', () => {
providerId: 'google',
userId: 'test-user-id',
}
mockDbTyped.limit.mockReturnValueOnce([mockCredential])
mockSelectChain([mockCredential])

mockRefreshOAuthToken.mockResolvedValueOnce(null)

Expand Down
8 changes: 2 additions & 6 deletions apps/sim/app/api/knowledge/search/utils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,12 @@
*
* @vitest-environment node
*/
import { createEnvMock, createMockLogger } from '@sim/testing'
import { createEnvMock, databaseMock, loggerMock } from '@sim/testing'
import { beforeEach, describe, expect, it, vi } from 'vitest'

const loggerMock = vi.hoisted(() => ({
createLogger: () => createMockLogger(),
}))

vi.mock('drizzle-orm')
vi.mock('@sim/logger', () => loggerMock)
vi.mock('@sim/db')
vi.mock('@sim/db', () => databaseMock)
vi.mock('@/lib/knowledge/documents/utils', () => ({
retryWithExponentialBackoff: (fn: any) => fn(),
}))
Expand Down
23 changes: 9 additions & 14 deletions apps/sim/app/api/schedules/[id]/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@
*
* @vitest-environment node
*/
import { loggerMock } from '@sim/testing'
import { databaseMock, loggerMock } from '@sim/testing'
import { NextRequest } from 'next/server'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'

const { mockGetSession, mockAuthorizeWorkflowByWorkspacePermission, mockDbSelect, mockDbUpdate } =
vi.hoisted(() => ({
mockGetSession: vi.fn(),
mockAuthorizeWorkflowByWorkspacePermission: vi.fn(),
mockDbSelect: vi.fn(),
mockDbUpdate: vi.fn(),
}))
const { mockGetSession, mockAuthorizeWorkflowByWorkspacePermission } = vi.hoisted(() => ({
mockGetSession: vi.fn(),
mockAuthorizeWorkflowByWorkspacePermission: vi.fn(),
}))

vi.mock('@/lib/auth', () => ({
getSession: mockGetSession,
Expand All @@ -23,12 +20,7 @@ vi.mock('@/lib/workflows/utils', () => ({
authorizeWorkflowByWorkspacePermission: mockAuthorizeWorkflowByWorkspacePermission,
}))

vi.mock('@sim/db', () => ({
db: {
select: mockDbSelect,
update: mockDbUpdate,
},
}))
vi.mock('@sim/db', () => databaseMock)

vi.mock('@sim/db/schema', () => ({
workflow: { id: 'id', userId: 'userId', workspaceId: 'workspaceId' },
Expand Down Expand Up @@ -59,6 +51,9 @@ function createParams(id: string): { params: Promise<{ id: string }> } {
return { params: Promise.resolve({ id }) }
}

const mockDbSelect = databaseMock.db.select as ReturnType<typeof vi.fn>
const mockDbUpdate = databaseMock.db.update as ReturnType<typeof vi.fn>

function mockDbChain(selectResults: unknown[][]) {
let selectCallIndex = 0
mockDbSelect.mockImplementation(() => ({
Expand Down
21 changes: 8 additions & 13 deletions apps/sim/app/api/schedules/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@
*
* @vitest-environment node
*/
import { loggerMock } from '@sim/testing'
import { databaseMock, loggerMock } from '@sim/testing'
import { NextRequest } from 'next/server'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'

const { mockGetSession, mockAuthorizeWorkflowByWorkspacePermission, mockDbSelect } = vi.hoisted(
() => ({
mockGetSession: vi.fn(),
mockAuthorizeWorkflowByWorkspacePermission: vi.fn(),
mockDbSelect: vi.fn(),
})
)
const { mockGetSession, mockAuthorizeWorkflowByWorkspacePermission } = vi.hoisted(() => ({
mockGetSession: vi.fn(),
mockAuthorizeWorkflowByWorkspacePermission: vi.fn(),
}))

vi.mock('@/lib/auth', () => ({
getSession: mockGetSession,
Expand All @@ -23,11 +20,7 @@ vi.mock('@/lib/workflows/utils', () => ({
authorizeWorkflowByWorkspacePermission: mockAuthorizeWorkflowByWorkspacePermission,
}))

vi.mock('@sim/db', () => ({
db: {
select: mockDbSelect,
},
}))
vi.mock('@sim/db', () => databaseMock)

vi.mock('@sim/db/schema', () => ({
workflow: { id: 'id', userId: 'userId', workspaceId: 'workspaceId' },
Expand Down Expand Up @@ -62,6 +55,8 @@ function createRequest(url: string): NextRequest {
return new NextRequest(new URL(url), { method: 'GET' })
}

const mockDbSelect = databaseMock.db.select as ReturnType<typeof vi.fn>

function mockDbChain(results: any[]) {
let callIndex = 0
mockDbSelect.mockImplementation(() => ({
Expand Down
10 changes: 3 additions & 7 deletions apps/sim/app/api/workflows/[id]/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* @vitest-environment node
*/

import { loggerMock } from '@sim/testing'
import { loggerMock, setupGlobalFetchMock } from '@sim/testing'
import { NextRequest } from 'next/server'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'

Expand Down Expand Up @@ -284,9 +284,7 @@ describe('Workflow By ID API Route', () => {
where: vi.fn().mockResolvedValue([{ id: 'workflow-123' }]),
})

global.fetch = vi.fn().mockResolvedValue({
ok: true,
})
setupGlobalFetchMock({ ok: true })

const req = new NextRequest('http://localhost:3000/api/workflows/workflow-123', {
method: 'DELETE',
Expand Down Expand Up @@ -331,9 +329,7 @@ describe('Workflow By ID API Route', () => {
where: vi.fn().mockResolvedValue([{ id: 'workflow-123' }]),
})

global.fetch = vi.fn().mockResolvedValue({
ok: true,
})
setupGlobalFetchMock({ ok: true })

const req = new NextRequest('http://localhost:3000/api/workflows/workflow-123', {
method: 'DELETE',
Expand Down
Loading