diff --git a/__test__/common/github/RepoRestrictedGitHubClient.test.ts b/__test__/common/github/RepoRestrictedGitHubClient.test.ts new file mode 100644 index 00000000..852e6754 --- /dev/null +++ b/__test__/common/github/RepoRestrictedGitHubClient.test.ts @@ -0,0 +1,134 @@ +import { RepoRestrictedGitHubClient } from '@/common/github/RepoRestrictedGitHubClient'; +import { + IGitHubClient, + AddCommentToPullRequestRequest, + GetPullRequestCommentsRequest, + GetPullRequestFilesRequest, + GetRepositoryContentRequest, + GraphQLQueryRequest, + UpdatePullRequestCommentRequest, + GitHubClient +} from "@/common"; +import { jest } from '@jest/globals'; + +describe('RepoRestrictedGitHubClient', () => { + let client: RepoRestrictedGitHubClient; + const repositoryNameSuffix = '-suffix'; + + const gitHubClient: jest.Mocked = { + graphql: jest.fn(), + getRepositoryContent: jest.fn(), + getPullRequestFiles: jest.fn(), + getPullRequestComments: jest.fn(), + addCommentToPullRequest: jest.fn(), + updatePullRequestComment: jest.fn(), + }; + + beforeEach(() => { + client = new RepoRestrictedGitHubClient({ + repositoryNameSuffix, + gitHubClient + }); + }); + + it('should delegate graphql request to the underlying client', async () => { + const request: GraphQLQueryRequest = { query: '' }; + await client.graphql(request); + expect(gitHubClient.graphql).toHaveBeenCalledWith(request); + }); + + it('should check suffix for getRepositoryContent', async () => { + const request: GetRepositoryContentRequest = { + repositoryName: 'repo-suffix', path: '', + repositoryOwner: '', + ref: undefined + }; + await client.getRepositoryContent(request); + expect(gitHubClient.getRepositoryContent).toHaveBeenCalledWith(request); + }); + + it('should throw error if suffix is invalid for getRepositoryContent', async () => { + const request: GetRepositoryContentRequest = { + repositoryName: 'repo', path: '', + repositoryOwner: '', + ref: undefined + }; + await expect(client.getRepositoryContent(request)).rejects.toThrow("Invalid repository name"); + }); + + it('should check suffix for getPullRequestFiles', async () => { + const request: GetPullRequestFilesRequest = { + repositoryName: 'repo-suffix', pullRequestNumber: 1, + appInstallationId: 0, + repositoryOwner: '' + }; + await client.getPullRequestFiles(request); + expect(gitHubClient.getPullRequestFiles).toHaveBeenCalledWith(request); + }); + + it('should throw error if suffix is invalid for getPullRequestFiles', async () => { + const request: GetPullRequestFilesRequest = { + repositoryName: 'repo', pullRequestNumber: 1, + appInstallationId: 0, + repositoryOwner: '' + }; + await expect(client.getPullRequestFiles(request)).rejects.toThrow("Invalid repository name"); + }); + + it('should check suffix for getPullRequestComments', async () => { + const request: GetPullRequestCommentsRequest = { + repositoryName: 'repo-suffix', pullRequestNumber: 1, + appInstallationId: 0, + repositoryOwner: '' + }; + await client.getPullRequestComments(request); + expect(gitHubClient.getPullRequestComments).toHaveBeenCalledWith(request); + }); + + it('should throw error if suffix is invalid for getPullRequestComments', async () => { + const request: GetPullRequestCommentsRequest = { + repositoryName: 'repo', pullRequestNumber: 1, + appInstallationId: 0, + repositoryOwner: '' + }; + await expect(client.getPullRequestComments(request)).rejects.toThrow("Invalid repository name"); + }); + + it('should check suffix for addCommentToPullRequest', async () => { + const request: AddCommentToPullRequestRequest = { + repositoryName: 'repo-suffix', pullRequestNumber: 1, body: '', + appInstallationId: 0, + repositoryOwner: '' + }; + await client.addCommentToPullRequest(request); + expect(gitHubClient.addCommentToPullRequest).toHaveBeenCalledWith(request); + }); + + it('should throw error if suffix is invalid for addCommentToPullRequest', async () => { + const request: AddCommentToPullRequestRequest = { + repositoryName: 'repo', pullRequestNumber: 1, body: '', + appInstallationId: 0, + repositoryOwner: '' + }; + await expect(client.addCommentToPullRequest(request)).rejects.toThrow("Invalid repository name"); + }); + + it('should check suffix for updatePullRequestComment', async () => { + const request: UpdatePullRequestCommentRequest = { + repositoryName: 'repo-suffix', commentId: 1, body: '', + appInstallationId: 0, + repositoryOwner: '' + }; + await client.updatePullRequestComment(request); + expect(gitHubClient.updatePullRequestComment).toHaveBeenCalledWith(request); + }); + + it('should throw error if suffix is invalid for updatePullRequestComment', async () => { + const request: UpdatePullRequestCommentRequest = { + repositoryName: 'repo', commentId: 1, body: '', + appInstallationId: 0, + repositoryOwner: '' + }; + await expect(client.updatePullRequestComment(request)).rejects.toThrow("Invalid repository name"); + }); +}); diff --git a/src/common/github/RepoRestrictedGitHubClient.ts b/src/common/github/RepoRestrictedGitHubClient.ts new file mode 100644 index 00000000..558492f3 --- /dev/null +++ b/src/common/github/RepoRestrictedGitHubClient.ts @@ -0,0 +1,60 @@ +import { + IGitHubClient, + AddCommentToPullRequestRequest, + GetPullRequestCommentsRequest, + GetPullRequestFilesRequest, + GetRepositoryContentRequest, + GraphQLQueryRequest, + GraphQlQueryResponse, + PullRequestComment, + PullRequestFile, + RepositoryContent, + UpdatePullRequestCommentRequest +} from "@/common"; + +export class RepoRestrictedGitHubClient implements IGitHubClient { + + private gitHubClient: IGitHubClient; + private repositoryNameSuffix: string; + + constructor(config: { + repositoryNameSuffix: string; + gitHubClient: IGitHubClient + }) { + this.gitHubClient = config.gitHubClient; + this.repositoryNameSuffix = config.repositoryNameSuffix; + } + + graphql(request: GraphQLQueryRequest): Promise { + return this.gitHubClient.graphql(request); + } + + getRepositoryContent(request: GetRepositoryContentRequest): Promise { + if (!this.isRepositoryNameValid(request.repositoryName)) return Promise.reject(new Error("Invalid repository name")); + return this.gitHubClient.getRepositoryContent(request); + } + + getPullRequestFiles(request: GetPullRequestFilesRequest): Promise { + if (!this.isRepositoryNameValid(request.repositoryName)) return Promise.reject(new Error("Invalid repository name")); + return this.gitHubClient.getPullRequestFiles(request); + } + + getPullRequestComments(request: GetPullRequestCommentsRequest): Promise { + if (!this.isRepositoryNameValid(request.repositoryName)) return Promise.reject(new Error("Invalid repository name")); + return this.gitHubClient.getPullRequestComments(request); + } + + addCommentToPullRequest(request: AddCommentToPullRequestRequest): Promise { + if (!this.isRepositoryNameValid(request.repositoryName)) return Promise.reject(new Error("Invalid repository name")); + return this.gitHubClient.addCommentToPullRequest(request); + } + + updatePullRequestComment(request: UpdatePullRequestCommentRequest): Promise { + if (!this.isRepositoryNameValid(request.repositoryName)) return Promise.reject(new Error("Invalid repository name")); + return this.gitHubClient.updatePullRequestComment(request); + } + + private isRepositoryNameValid(repositoryName: string): boolean { + return repositoryName.endsWith(this.repositoryNameSuffix); + } +} diff --git a/src/composition.ts b/src/composition.ts index 7ab23475..5da65221 100644 --- a/src/composition.ts +++ b/src/composition.ts @@ -50,6 +50,7 @@ import { RepositoryNameEventFilter, PullRequestCommenter } from "@/features/hooks/domain" +import { RepoRestrictedGitHubClient } from "./common/github/RepoRestrictedGitHubClient" const gitHubAppCredentials = { appId: env.getOrThrow("GITHUB_APP_ID"), @@ -135,13 +136,18 @@ const oauthTokenRefresher = new LockingOAuthTokenRefresher({ }) }) -export const gitHubClient = new GitHubClient({ +const gitHubClient = new GitHubClient({ ...gitHubAppCredentials, oauthTokenDataSource }) +const repoRestrictedGitHubClient = new RepoRestrictedGitHubClient({ + repositoryNameSuffix: env.getOrThrow("REPOSITORY_NAME_SUFFIX"), + gitHubClient +}) + export const userGitHubClient = new OAuthTokenRefreshingGitHubClient({ - gitHubClient, + gitHubClient: repoRestrictedGitHubClient, oauthTokenDataSource, oauthTokenRefresher })