diff --git a/.changeset/tender-falcons-design.md b/.changeset/tender-falcons-design.md new file mode 100644 index 00000000..502bef86 --- /dev/null +++ b/.changeset/tender-falcons-design.md @@ -0,0 +1,5 @@ +--- +"@chat-adapter/discord": minor +--- + +Add support for slash command interactions — `onSlashCommand` handlers are now invoked when a Discord user triggers an application command, and the deferred "thinking" response is resolved automatically via the interaction token. diff --git a/packages/adapter-discord/src/index.test.ts b/packages/adapter-discord/src/index.test.ts index c6b114bc..5ccaed29 100644 --- a/packages/adapter-discord/src/index.test.ts +++ b/packages/adapter-discord/src/index.test.ts @@ -4,9 +4,9 @@ import { generateKeyPairSync, sign } from "node:crypto"; import { ValidationError } from "@chat-adapter/shared"; -import type { Logger } from "chat"; +import type { ChatInstance, Logger, StateAdapter } from "chat"; import { InteractionType } from "discord-api-types/v10"; -import { describe, expect, it, vi } from "vitest"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import { createDiscordAdapter, DiscordAdapter } from "./index"; import { DiscordFormatConverter } from "./markdown"; @@ -67,6 +67,52 @@ function createWebhookRequest( }); } +function createMockState(): StateAdapter & { cache: Map } { + const cache = new Map(); + return { + cache, + connect: vi.fn().mockResolvedValue(undefined), + disconnect: vi.fn().mockResolvedValue(undefined), + subscribe: vi.fn().mockResolvedValue(undefined), + unsubscribe: vi.fn().mockResolvedValue(undefined), + isSubscribed: vi.fn().mockResolvedValue(false), + acquireLock: vi.fn().mockResolvedValue(null), + releaseLock: vi.fn().mockResolvedValue(undefined), + extendLock: vi.fn().mockResolvedValue(true), + get: vi + .fn() + .mockImplementation((key: string) => + Promise.resolve(cache.get(key) ?? null) + ), + set: vi.fn().mockImplementation((key: string, value: unknown) => { + cache.set(key, value); + return Promise.resolve(); + }), + delete: vi.fn().mockImplementation((key: string) => { + cache.delete(key); + return Promise.resolve(); + }), + }; +} + +function createMockChatInstance(state: StateAdapter): ChatInstance { + return { + processMessage: vi.fn(), + handleIncomingMessage: vi.fn().mockResolvedValue(undefined), + processReaction: vi.fn(), + processAction: vi.fn(), + processAppHomeOpened: vi.fn(), + processAssistantContextChanged: vi.fn(), + processAssistantThreadStarted: vi.fn(), + processModalSubmit: vi.fn().mockResolvedValue(undefined), + processModalClose: vi.fn(), + processSlashCommand: vi.fn(), + getState: () => state, + getUserName: () => "test-bot", + getLogger: () => mockLogger, + }; +} + // ============================================================================ // Factory Function Tests // ============================================================================ @@ -374,38 +420,146 @@ describe("handleWebhook - APPLICATION_COMMAND", () => { logger: mockLogger, }); - it("handles slash command interaction", async () => { - const body = JSON.stringify({ - type: InteractionType.ApplicationCommand, - id: "interaction123", - application_id: "test-app-id", - token: "interaction-token", - version: 1, - guild_id: "guild123", - channel_id: "channel456", - member: { - user: { - id: "user789", - username: "testuser", - discriminator: "0001", - }, - roles: [], - joined_at: "2021-01-01T00:00:00.000Z", - }, - data: { - id: "cmd123", - name: "test", - type: 1, + const mockState = createMockState(); + const mockChat = createMockChatInstance(mockState); + + adapter.initialize(mockChat); + + const slashCommandBody = JSON.stringify({ + type: InteractionType.ApplicationCommand, + id: "interaction123", + application_id: "test-app-id", + token: "interaction-token", + version: 1, + guild_id: "guild123", + channel_id: "channel456", + member: { + user: { + id: "user789", + username: "testuser", + discriminator: "0001", }, - }); - const request = createWebhookRequest(body); + roles: [], + joined_at: "2021-01-01T00:00:00.000Z", + }, + data: { + id: "cmd123", + name: "test", + type: 1, + }, + }); + it("ACKs with DeferredChannelMessageWithSource", async () => { + const request = createWebhookRequest(slashCommandBody); const response = await adapter.handleWebhook(request); expect(response.status).toBe(200); - const responseBody = await response.json(); expect(responseBody).toEqual({ type: 5 }); // DeferredChannelMessageWithSource }); + + it("invokes processSlashCommand with correct event", async () => { + const processSlashCommand = mockChat.processSlashCommand as ReturnType< + typeof vi.fn + >; + processSlashCommand.mockClear(); + const request = createWebhookRequest(slashCommandBody); + await adapter.handleWebhook(request); + + expect(processSlashCommand).toHaveBeenCalledOnce(); + const [event] = processSlashCommand.mock.calls[0] as [ + Record, + ]; + expect(event.command).toBe("/test"); + expect(event.channelId).toBe("discord:guild123:channel456"); + expect(event.triggerId).toBe("interaction-token"); + expect((event.user as { userId: string }).userId).toBe("user789"); + expect(event.raw).toMatchObject({ id: "interaction123" }); + }); +}); + +// ============================================================================ +// postChannelMessage - Interaction Token Resolution Tests +// ============================================================================ + +describe("postChannelMessage - interaction token resolution", () => { + const channelId = "discord:guild123:channel456"; + const stateKey = `discord:interaction-token:${channelId}`; + const msgResponse = { id: "msg-1", channel_id: "channel456" }; + + let adapter: InstanceType; + let mockState: ReturnType; + + beforeEach(() => { + adapter = createDiscordAdapter({ + botToken: "test-bot-token", + publicKey: testPublicKey, + applicationId: "test-app-id", + logger: mockLogger, + }); + mockState = createMockState(); + adapter.initialize(createMockChatInstance(mockState)); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it("PATCHes @original with JSON when token is in state", async () => { + await mockState.set(stateKey, "my-token"); + vi.spyOn(global, "fetch").mockResolvedValue( + new Response(JSON.stringify(msgResponse), { status: 200 }) + ); + + const result = await adapter.postChannelMessage(channelId, "Hello!"); + + const [url, init] = vi.mocked(fetch).mock.calls[0] as [string, RequestInit]; + expect(url).toContain("/webhooks/test-app-id/my-token/messages/@original"); + expect(init.method).toBe("PATCH"); + expect(init.body).not.toBeInstanceOf(FormData); + expect(result.id).toBe("msg-1"); + expect(await mockState.get(stateKey)).toBeNull(); + }); + + it("PATCHes @original with multipart when token is in state and files are present", async () => { + await mockState.set(stateKey, "my-token"); + vi.spyOn(global, "fetch").mockResolvedValue( + new Response(JSON.stringify(msgResponse), { status: 200 }) + ); + + await adapter.postChannelMessage(channelId, { + raw: "Here is a file", + files: [ + { + filename: "test.txt", + data: Buffer.from("hello"), + mimeType: "text/plain", + }, + ], + }); + + const [url, init] = vi.mocked(fetch).mock.calls[0] as [string, RequestInit]; + expect(url).toContain("/webhooks/test-app-id/my-token/messages/@original"); + expect(init.method).toBe("PATCH"); + expect(init.body).toBeInstanceOf(FormData); + expect(await mockState.get(stateKey)).toBeNull(); + }); + + it("deletes token and falls back to channel POST when PATCH fails", async () => { + await mockState.set(stateKey, "bad-token"); + vi.spyOn(global, "fetch") + .mockResolvedValueOnce(new Response("Server Error", { status: 500 })) + .mockResolvedValueOnce( + new Response(JSON.stringify(msgResponse), { status: 200 }) + ); + + const result = await adapter.postChannelMessage(channelId, "Hello!"); + + expect(vi.mocked(fetch)).toHaveBeenCalledTimes(2); + const [secondUrl] = vi.mocked(fetch).mock.calls[1] as [string, RequestInit]; + expect(secondUrl).toContain("/channels/channel456/messages"); + expect(result.id).toBe("msg-1"); + expect(await mockState.get(stateKey)).toBeNull(); + }); }); // ============================================================================ diff --git a/packages/adapter-discord/src/index.ts b/packages/adapter-discord/src/index.ts index f328d8c5..0a29bbe8 100644 --- a/packages/adapter-discord/src/index.ts +++ b/packages/adapter-discord/src/index.ts @@ -74,6 +74,11 @@ const DISCORD_API_BASE = "https://discord.com/api/v10"; const DISCORD_MAX_CONTENT_LENGTH = 2000; const HEX_64_PATTERN = /^[0-9a-f]{64}$/; const HEX_PATTERN = /^[0-9a-f]+$/; +/** Discord interaction tokens are valid for 15 minutes. + * @see https://docs.discord.com/developers/interactions/receiving-and-responding#interaction-callback + */ +const INTERACTION_TOKEN_TTL_MS = 15 * 60 * 1000; +const INTERACTION_TOKEN_KEY_PREFIX = "discord:interaction-token:"; export class DiscordAdapter implements Adapter { readonly name = "discord"; @@ -206,9 +211,9 @@ export class DiscordAdapter implements Adapter { }); } - // Handle APPLICATION_COMMAND (slash commands - not implemented yet) + // Handle APPLICATION_COMMAND (slash commands) if (interaction.type === InteractionType.ApplicationCommand) { - // For now, just ACK + await this.handleSlashCommandInteraction(interaction, options); return this.respondToInteraction({ type: InteractionResponseType.DeferredChannelMessageWithSource, }); @@ -366,6 +371,102 @@ export class DiscordAdapter implements Adapter { this.chat.processAction(actionEvent, options); } + /** + * Handle APPLICATION_COMMAND interactions (slash commands). + */ + private async handleSlashCommandInteraction( + interaction: DiscordInteraction, + options?: WebhookOptions + ): Promise { + if (!this.chat) { + this.logger.warn("Chat instance not initialized, ignoring slash command"); + return; + } + + const commandName = interaction.data?.name; + if (!commandName) { + this.logger.warn("No command name in slash command interaction"); + return; + } + + const user = interaction.member?.user || interaction.user; + if (!user) { + this.logger.warn("No user in slash command interaction"); + return; + } + + const interactionChannelId = interaction.channel_id; + if (!interactionChannelId) { + this.logger.warn("No channel_id in slash command interaction"); + return; + } + + const guildId = interaction.guild_id || "@me"; + + // Detect if the command was invoked inside a thread channel + // Discord channel types: 11 = public thread, 12 = private thread + const channel = interaction.channel; + const isThread = channel?.type === 11 || channel?.type === 12; + const parentChannelId = + isThread && channel?.parent_id ? channel.parent_id : interactionChannelId; + + const channelId = isThread + ? this.encodeThreadId({ + guildId, + channelId: parentChannelId, + threadId: interactionChannelId, + }) + : this.encodeThreadId({ guildId, channelId: interactionChannelId }); + + // Join top-level option values into a text string (simple v1 approach) + const text = + interaction.data?.options + ?.map((opt) => String(opt.value ?? "")) + .join(" ") ?? ""; + + this.logger.debug("Processing Discord slash command", { + command: `/${commandName}`, + channelId, + triggerId: interaction.token, + }); + + // Store the interaction token in central state so postChannelMessage can resolve + // the deferred "thinking" response, even across serverless invocations. + try { + await this.chat + .getState() + .set( + `${INTERACTION_TOKEN_KEY_PREFIX}${channelId}`, + interaction.token, + INTERACTION_TOKEN_TTL_MS + ); + } catch (error) { + this.logger.warn("Failed to store interaction token", { + error: String(error), + channelId, + }); + // Continue processing — postChannelMessage will fall back to a regular channel message + } + + const event = { + command: `/${commandName}`, + text, + user: { + userId: user.id, + userName: user.username, + fullName: user.global_name || user.username, + isBot: user.bot ?? false, + isMe: false, + }, + adapter: this, + raw: interaction, + triggerId: interaction.token, + channelId, + }; + + this.chat.processSlashCommand(event, options); + } + /** * Handle a forwarded Gateway event received via webhook. */ @@ -734,7 +835,7 @@ export class DiscordAdapter implements Adapter { continue; } const buffer = await toBuffer(file.data, { - platform: "discord" as "slack", + platform: "discord", }); if (!buffer) { continue; @@ -2011,6 +2112,82 @@ export class DiscordAdapter implements Adapter { } const files = extractFiles(message); + + // If there is a pending slash command interaction for this channel, resolve the + // deferred "thinking" message by PATCHing the original response via the interaction + // webhook instead of posting a new channel message. This check must come before the + // files branch — otherwise file responses exit early and the token is never consumed. + const { chat } = this; + if (chat) { + const stateKey = `${INTERACTION_TOKEN_KEY_PREFIX}${channelId}`; + const interactionToken = await chat.getState().get(stateKey); + if (interactionToken) { + const patchPath = `/webhooks/${this.applicationId}/${interactionToken}/messages/@original`; + this.logger.debug("Discord API: PATCH interaction original message", { + channelId: discordChannelId, + contentLength: payload.content?.length || 0, + hasFiles: files.length > 0, + }); + try { + let result: APIMessage; + if (files.length > 0) { + const formData = new FormData(); + formData.append("payload_json", JSON.stringify(payload)); + for (let i = 0; i < files.length; i++) { + const file = files[i]; + if (!file) { + continue; + } + const buffer = await toBuffer(file.data, { + platform: "discord", + }); + if (!buffer) { + continue; + } + formData.append( + `files[${i}]`, + new Blob([new Uint8Array(buffer)], { + type: file.mimeType || "application/octet-stream", + }), + file.filename + ); + } + const response = await fetch(`${DISCORD_API_BASE}${patchPath}`, { + method: "PATCH", + headers: { Authorization: `Bot ${this.botToken}` }, + body: formData, + }); + if (!response.ok) { + const errorText = await response.text(); + throw new NetworkError( + "discord", + `Failed to PATCH interaction: ${response.status} ${errorText}` + ); + } + result = (await response.json()) as APIMessage; + } else { + const response = await this.discordFetch( + patchPath, + "PATCH", + payload + ); + result = (await response.json()) as APIMessage; + } + await chat.getState().delete(stateKey); + return { id: result.id, threadId: channelId, raw: result }; + } catch (error) { + this.logger.warn( + "Failed to PATCH interaction response, falling back to channel message", + { error: String(error), channelId: discordChannelId } + ); + await chat.getState().delete(stateKey); + // Fall through to post a regular channel message + } + } + } + + // No pending interaction token (or PATCH failed above) — post directly to the + // channel. Files require multipart; plain messages use JSON. if (files.length > 0) { return this.postMessageWithFiles( discordChannelId,