Skip to content

Commit f998fc7

Browse files
clean up notifications PR a little
1 parent d8d9500 commit f998fc7

File tree

3 files changed

+77
-114
lines changed

3 files changed

+77
-114
lines changed

pkg/github/notifications.go

Lines changed: 53 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,25 @@ import (
1515
"github.com/mark3labs/mcp-go/server"
1616
)
1717

18-
// getNotifications creates a tool to list notifications for the current user.
19-
func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
20-
return mcp.NewTool("get_notifications",
21-
mcp.WithDescription(t("TOOL_GET_NOTIFICATIONS_DESCRIPTION", "Get notifications for the authenticated GitHub user")),
22-
mcp.WithBoolean("all",
23-
mcp.Description("If true, show notifications marked as read. Default: false"),
24-
),
25-
mcp.WithBoolean("participating",
26-
mcp.Description("If true, only shows notifications in which the user is directly participating or mentioned. Default: false"),
18+
// ListNotifications creates a tool to list notifications for the current user.
19+
func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
20+
return mcp.NewTool("list_notifications",
21+
mcp.WithDescription(t("TOOL_LIST_NOTIFICATIONS_DESCRIPTION", "List current notifications for the authenticated GitHub user")),
22+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
23+
Title: t("TOOL_LIST_NOTIFICATIONS_USER_TITLE", "List notifications"),
24+
ReadOnlyHint: toBoolPtr(true),
25+
}),
26+
mcp.WithString("filter",
27+
mcp.Description("Filter notifications to, use default unless specified. Read notifications are ones that have already been acknowledged by the user. Participating notifications are those that the user is directly involved in, such as issues or pull requests they have commented on or created."),
28+
mcp.Enum("default", "include_read_notifications", "only_participating"),
2729
),
2830
mcp.WithString("since",
2931
mcp.Description("Only show notifications updated after the given time (ISO 8601 format)"),
3032
),
3133
mcp.WithString("before",
3234
mcp.Description("Only show notifications updated before the given time (ISO 8601 format)"),
3335
),
34-
mcp.WithNumber("per_page",
35-
mcp.Description("Results per page (max 100). Default: 30"),
36-
),
37-
mcp.WithNumber("page",
38-
mcp.Description("Page number of the results to fetch. Default: 1"),
39-
),
36+
WithPagination(),
4037
),
4138
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
4239
client, err := getClient(ctx)
@@ -65,6 +62,7 @@ func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFun
6562
return mcp.NewToolResultError(err.Error()), nil
6663
}
6764

65+
// TODO pagination params from tool
6866
perPage, err := OptionalIntParamWithDefault(request, "per_page", 30)
6967
if err != nil {
7068
return mcp.NewToolResultError(err.Error()), nil
@@ -127,14 +125,19 @@ func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFun
127125
}
128126
}
129127

130-
// markNotificationRead creates a tool to mark a notification as read.
131-
func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
132-
return mcp.NewTool("mark_notification_read",
133-
mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_READ_DESCRIPTION", "Mark a notification as read")),
128+
// dismiss notification creates a tool to mark a notification as read/done.
129+
func DismissNotification(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
130+
return mcp.NewTool("dismiss_notification",
131+
mcp.WithDescription(t("TOOL_DISMISS_NOTIFICATION_DESCRIPTION", "Dismiss a notification by marking it as read or done")),
132+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
133+
Title: t("TOOL_DISMISS_NOTIFICATION_USER_TITLE", "Dismiss notification"),
134+
ReadOnlyHint: toBoolPtr(false),
135+
}),
134136
mcp.WithString("threadID",
135137
mcp.Required(),
136138
mcp.Description("The ID of the notification thread"),
137139
),
140+
mcp.WithString("state", mcp.Description("The new state of the notification (read/done)"), mcp.Enum("read", "done")),
138141
),
139142
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
140143
client, err := getclient(ctx)
@@ -147,9 +150,27 @@ func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelpe
147150
return mcp.NewToolResultError(err.Error()), nil
148151
}
149152

150-
resp, err := client.Activity.MarkThreadRead(ctx, threadID)
153+
state, err := requiredParam[string](request, "state")
154+
if err != nil {
155+
return mcp.NewToolResultError(err.Error()), nil
156+
}
157+
158+
var resp *github.Response
159+
160+
if state == "done" {
161+
// for some inexplicable reason, the API seems to have threadID as int64 and string depending on the endpoint
162+
var threadIDInt int64
163+
threadIDInt, err = strconv.ParseInt(threadID, 10, 64)
164+
if err != nil {
165+
return mcp.NewToolResultError(fmt.Sprintf("invalid threadID format: %v", err)), nil
166+
}
167+
resp, err = client.Activity.MarkThreadDone(ctx, threadIDInt)
168+
} else {
169+
resp, err = client.Activity.MarkThreadRead(ctx, threadID)
170+
}
171+
151172
if err != nil {
152-
return nil, fmt.Errorf("failed to mark notification as read: %w", err)
173+
return nil, fmt.Errorf("failed to mark notification as %s: %w", state, err)
153174
}
154175
defer func() { _ = resp.Body.Close() }()
155176

@@ -158,17 +179,21 @@ func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelpe
158179
if err != nil {
159180
return nil, fmt.Errorf("failed to read response body: %w", err)
160181
}
161-
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil
182+
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as %s: %s", state, string(body))), nil
162183
}
163184

164-
return mcp.NewToolResultText("Notification marked as read"), nil
185+
return mcp.NewToolResultText(fmt.Sprintf("Notification marked as %s", state)), nil
165186
}
166187
}
167188

168189
// MarkAllNotificationsRead creates a tool to mark all notifications as read.
169190
func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
170191
return mcp.NewTool("mark_all_notifications_read",
171192
mcp.WithDescription(t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read")),
193+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
194+
Title: t("TOOL_MARK_ALL_NOTIFICATIONS_READ_USER_TITLE", "Mark all notifications as read"),
195+
ReadOnlyHint: toBoolPtr(false),
196+
}),
172197
mcp.WithString("lastReadAt",
173198
mcp.Description("Describes the last point that notifications were checked (optional). Default: Now"),
174199
),
@@ -179,7 +204,7 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH
179204
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
180205
}
181206

182-
lastReadAt, err := OptionalParam(request, "lastReadAt")
207+
lastReadAt, err := OptionalParam[string](request, "lastReadAt")
183208
if err != nil {
184209
return mcp.NewToolResultError(err.Error()), nil
185210
}
@@ -217,6 +242,10 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH
217242
func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
218243
return mcp.NewTool("get_notification_thread",
219244
mcp.WithDescription(t("TOOL_GET_NOTIFICATION_THREAD_DESCRIPTION", "Get a specific notification thread")),
245+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
246+
Title: t("TOOL_GET_NOTIFICATION_THREAD_USER_TITLE", "Get notification thread"),
247+
ReadOnlyHint: toBoolPtr(true),
248+
}),
220249
mcp.WithString("threadID",
221250
mcp.Required(),
222251
mcp.Description("The ID of the notification thread"),
@@ -255,46 +284,3 @@ func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelp
255284
return mcp.NewToolResultText(string(r)), nil
256285
}
257286
}
258-
259-
// markNotificationDone creates a tool to mark a notification as done.
260-
func MarkNotificationDone(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
261-
return mcp.NewTool("mark_notification_done",
262-
mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_DONE_DESCRIPTION", "Mark a notification as done")),
263-
mcp.WithString("threadID",
264-
mcp.Required(),
265-
mcp.Description("The ID of the notification thread"),
266-
),
267-
),
268-
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
269-
client, err := getclient(ctx)
270-
if err != nil {
271-
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
272-
}
273-
274-
threadIDStr, err := requiredParam[string](request, "threadID")
275-
if err != nil {
276-
return mcp.NewToolResultError(err.Error()), nil
277-
}
278-
279-
threadID, err := strconv.ParseInt(threadIDStr, 10, 64)
280-
if err != nil {
281-
return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil
282-
}
283-
284-
resp, err := client.Activity.MarkThreadDone(ctx, threadID)
285-
if err != nil {
286-
return nil, fmt.Errorf("failed to mark notification as done: %w", err)
287-
}
288-
defer func() { _ = resp.Body.Close() }()
289-
290-
if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
291-
body, err := io.ReadAll(resp.Body)
292-
if err != nil {
293-
return nil, fmt.Errorf("failed to read response body: %w", err)
294-
}
295-
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as done: %s", string(body))), nil
296-
}
297-
298-
return mcp.NewToolResultText("Notification marked as done"), nil
299-
}
300-
}

pkg/github/server.go

Lines changed: 20 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,26 @@ func OptionalParam[T any](r mcp.CallToolRequest, p string) (T, error) {
119119
return r.Params.Arguments[p].(T), nil
120120
}
121121

122+
// OptionalParam is a helper function that can be used to fetch a requested parameter from the request.
123+
// It does the following checks:
124+
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
125+
// 2. If it is present, it checks if the parameter is of the expected type and returns it
126+
func OptionalParamWithDefault[T any](r mcp.CallToolRequest, p string, d T) (T, error) {
127+
var zero T
128+
129+
// Check if the parameter is present in the request
130+
if _, ok := r.Params.Arguments[p]; !ok {
131+
return d, nil
132+
}
133+
134+
// Check if the parameter is of the expected type
135+
if _, ok := r.Params.Arguments[p].(T); !ok {
136+
return zero, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, r.Params.Arguments[p])
137+
}
138+
139+
return r.Params.Arguments[p].(T), nil
140+
}
141+
122142
// OptionalIntParam is a helper function that can be used to fetch a requested parameter from the request.
123143
// It does the following checks:
124144
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
@@ -144,47 +164,6 @@ func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, e
144164
return v, nil
145165
}
146166

147-
// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
148-
// similar to optionalParam, but it also takes a default value.
149-
func OptionalBoolParamWithDefault(r mcp.CallToolRequest, p string, d bool) (bool, error) {
150-
v, err := OptionalParam[bool](r, p)
151-
if err != nil {
152-
return false, err
153-
}
154-
if !v {
155-
return d, nil
156-
}
157-
return v, nil
158-
}
159-
160-
// OptionalStringParam is a helper function that can be used to fetch a requested parameter from the request.
161-
// It does the following checks:
162-
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
163-
// 2. If it is present, it checks if the parameter is of the expected type and returns it
164-
func OptionalStringParam(r mcp.CallToolRequest, p string) (string, error) {
165-
v, err := OptionalParam[string](r, p)
166-
if err != nil {
167-
return "", err
168-
}
169-
if v == "" {
170-
return "", nil
171-
}
172-
return v, nil
173-
}
174-
175-
// OptionalStringParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
176-
// similar to optionalParam, but it also takes a default value.
177-
func OptionalStringParamWithDefault(r mcp.CallToolRequest, p string, d string) (string, error) {
178-
v, err := OptionalParam[string](r, p)
179-
if err != nil {
180-
return "", err
181-
}
182-
if v == "" {
183-
return d, nil
184-
}
185-
return v, nil
186-
}
187-
188167
// OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request.
189168
// It does the following checks:
190169
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value

pkg/github/tools.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,12 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
9393

9494
notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools").
9595
AddReadTools(
96-
97-
toolsets.NewServerTool(MarkNotificationRead(getClient, t)),
98-
toolsets.NewServerTool(MarkAllNotificationsRead(getClient, t)),
99-
toolsets.NewServerTool(MarkNotificationDone(getClient, t)),
96+
toolsets.NewServerTool(ListNotifications(getClient, t)),
97+
toolsets.NewServerTool(GetNotificationThread(getClient, t)),
10098
).
10199
AddWriteTools(
102-
toolsets.NewServerTool(GetNotifications(getClient, t)),
103-
toolsets.NewServerTool(GetNotificationThread(getClient, t)),
100+
toolsets.NewServerTool(DismissNotification(getClient, t)),
101+
toolsets.NewServerTool(MarkAllNotificationsRead(getClient, t)),
104102
)
105103

106104
// Keep experiments alive so the system doesn't error out when it's always enabled

0 commit comments

Comments
 (0)