diff --git a/extension/src/experiments/workspace.ts b/extension/src/experiments/workspace.ts index 945fce327b..f5d2b836c3 100644 --- a/extension/src/experiments/workspace.ts +++ b/extension/src/experiments/workspace.ts @@ -95,83 +95,43 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews< ) } - public async addFilter(overrideRoot?: string) { - const dvcRoot = await this.getDvcRoot(overrideRoot) - if (!dvcRoot) { - return - } - return this.getRepository(dvcRoot).addFilter() + public addFilter(overrideRoot?: string) { + return this.getRepositoryThenUpdate('addFilter', overrideRoot) } - public async addStarredFilter(overrideRoot?: string) { - const dvcRoot = await this.getDvcRoot(overrideRoot) - if (!dvcRoot) { - return - } - return this.getRepository(dvcRoot).addStarredFilter() + public addStarredFilter(overrideRoot?: string) { + return this.getRepositoryThenUpdate('addStarredFilter', overrideRoot) } - public async removeFilters() { - const dvcRoot = await this.getFocusedOrOnlyOrPickProject() - if (!dvcRoot) { - return - } - return this.getRepository(dvcRoot).removeFilters() + public removeFilters() { + return this.getRepositoryThenUpdate('removeFilters') } - public async addSort(overrideRoot?: string) { - const dvcRoot = await this.getDvcRoot(overrideRoot) - if (!dvcRoot) { - return - } - return this.getRepository(dvcRoot).addSort() + public addSort(overrideRoot?: string) { + return this.getRepositoryThenUpdate('addSort', overrideRoot) } - public async addStarredSort(overrideRoot?: string) { - const dvcRoot = await this.getDvcRoot(overrideRoot) - if (!dvcRoot) { - return - } - return this.getRepository(dvcRoot).addStarredSort() + public addStarredSort(overrideRoot?: string) { + return this.getRepositoryThenUpdate('addStarredSort', overrideRoot) } - public async removeSorts() { - const dvcRoot = await this.getFocusedOrOnlyOrPickProject() - if (!dvcRoot) { - return - } - - return this.getRepository(dvcRoot).removeSorts() + public removeSorts() { + return this.getRepositoryThenUpdate('removeSorts') } - public async selectExperimentsToPlot(overrideRoot?: string) { - const dvcRoot = await this.getDvcRoot(overrideRoot) - if (!dvcRoot) { - return - } - return this.getRepository(dvcRoot).selectExperimentsToPlot() + public selectExperimentsToPlot(overrideRoot?: string) { + return this.getRepositoryThenUpdate('selectExperimentsToPlot', overrideRoot) } - public async selectColumns(overrideRoot?: string) { - const dvcRoot = await this.getDvcRoot(overrideRoot) - if (!dvcRoot) { - return - } - return this.getRepository(dvcRoot).selectColumns() + public selectColumns(overrideRoot?: string) { + return this.getRepositoryThenUpdate('selectColumns', overrideRoot) } - public async selectQueueTasksToKill() { - const cwd = await this.getFocusedOrOnlyOrPickProject() - if (!cwd) { - return - } - - const taskIds = await this.getRepository(cwd).pickQueueTasksToKill() - - if (!taskIds || isEmpty(taskIds)) { - return - } - return this.runCommand(AvailableCommands.QUEUE_KILL, cwd, ...taskIds) + public selectQueueTasksToKill() { + return this.pickIdsThenRun( + 'pickQueueTasksToKill', + AvailableCommands.QUEUE_KILL + ) } public async selectExperimentsToPush(setup: Setup) { @@ -190,20 +150,11 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews< return pushCommand({ dvcRoot, ids }) } - public async selectExperimentsToRemove() { - const cwd = await this.getFocusedOrOnlyOrPickProject() - if (!cwd) { - return - } - - const experimentIds = await this.getRepository( - cwd - ).pickExperimentsToRemove() - if (!experimentIds || isEmpty(experimentIds)) { - return - } - - return this.runCommand(AvailableCommands.EXP_REMOVE, cwd, ...experimentIds) + public selectExperimentsToRemove() { + return this.pickIdsThenRun( + 'pickExperimentsToRemove', + AvailableCommands.EXP_REMOVE + ) } public async modifyExperimentParamsAndRun( @@ -453,6 +404,25 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews< ) } + private async getRepositoryThenUpdate( + method: + | 'addFilter' + | 'addStarredFilter' + | 'removeFilters' + | 'addSort' + | 'addStarredSort' + | 'removeSorts' + | 'selectExperimentsToPlot' + | 'selectColumns', + overrideRoot?: string + ) { + const dvcRoot = await this.getDvcRoot(overrideRoot) + if (!dvcRoot) { + return + } + return this.getRepository(dvcRoot)[method]() + } + private async shouldRun() { const cwd = await this.getFocusedOrOnlyOrPickProject() if (!cwd) { @@ -555,6 +525,25 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews< return { command, enteredManually, trainingScript } } + private async pickIdsThenRun( + pickMethod: 'pickQueueTasksToKill' | 'pickExperimentsToRemove', + commandId: + | typeof AvailableCommands.QUEUE_KILL + | typeof AvailableCommands.EXP_REMOVE + ) { + const cwd = await this.getFocusedOrOnlyOrPickProject() + if (!cwd) { + return + } + + const ids = await this.getRepository(cwd)[pickMethod]() + + if (!ids || isEmpty(ids)) { + return + } + return this.runCommand(commandId, cwd, ...ids) + } + private async pickExpThenRun( commandId: CommandId, pickFunc: (cwd: string) => Thenable | undefined