From 97b578a674c5487592aec3184c92c8c5cf54b872 Mon Sep 17 00:00:00 2001 From: Olafur Geirsson Date: Wed, 8 May 2024 16:29:53 +0200 Subject: [PATCH] TSC Graph Context: fix performance Fixes CODY-1434 Previously, the `tsc-mixed` context retriever made the UI freeze because it was doing super inefficient things with context expansion. This PR fixes the problem so that the performance is back to normal. --- .../retrievers/section-history/nextTick.ts | 3 + .../section-history-retriever.test.ts | 5 +- .../retrievers/tsc/SymbolFormatter.test.ts | 4 +- .../context/retrievers/tsc/SymbolFormatter.ts | 47 +++-- .../retrievers/tsc/relevantTypeIdentifiers.ts | 54 ++++-- .../retrievers/tsc/tsc-retriever.test.ts | 28 ++- .../context/retrievers/tsc/tsc-retriever.ts | 183 ++++++++++++------ 7 files changed, 228 insertions(+), 96 deletions(-) create mode 100644 vscode/src/completions/context/retrievers/section-history/nextTick.ts diff --git a/vscode/src/completions/context/retrievers/section-history/nextTick.ts b/vscode/src/completions/context/retrievers/section-history/nextTick.ts new file mode 100644 index 00000000000..5651b1cc54d --- /dev/null +++ b/vscode/src/completions/context/retrievers/section-history/nextTick.ts @@ -0,0 +1,3 @@ +export function nextTick() { + return new Promise(resolve => process.nextTick(resolve)) +} diff --git a/vscode/src/completions/context/retrievers/section-history/section-history-retriever.test.ts b/vscode/src/completions/context/retrievers/section-history/section-history-retriever.test.ts index 161716fc566..88d6bb2a5f8 100644 --- a/vscode/src/completions/context/retrievers/section-history/section-history-retriever.test.ts +++ b/vscode/src/completions/context/retrievers/section-history/section-history-retriever.test.ts @@ -6,6 +6,7 @@ import { testFileUri } from '@sourcegraph/cody-shared' import { range, withPosixPathsInString } from '../../../../testutils/textDocument' import * as docContextGetters from '../../../doc-context-getters' +import { nextTick } from './nextTick' import { SectionHistoryRetriever } from './section-history-retriever' const document1Uri = testFileUri('document1.ts') @@ -274,7 +275,3 @@ describe('GraphSectionObserver', () => { }) }) }) - -function nextTick() { - return new Promise(resolve => process.nextTick(resolve)) -} diff --git a/vscode/src/completions/context/retrievers/tsc/SymbolFormatter.test.ts b/vscode/src/completions/context/retrievers/tsc/SymbolFormatter.test.ts index c7b0ff0379a..45b339ccd51 100644 --- a/vscode/src/completions/context/retrievers/tsc/SymbolFormatter.test.ts +++ b/vscode/src/completions/context/retrievers/tsc/SymbolFormatter.test.ts @@ -19,7 +19,7 @@ describe('SymbolFormatter', () => { // quickly adds up, but it's good enough for now. const program = ts.createProgram(['test.ts'], {}, host) const checker = program.getTypeChecker() - const formatter = new SymbolFormatter(checker) + const formatter = new SymbolFormatter(checker, 10) const sourceFile = program.getSourceFile('test.ts') if (!sourceFile) { return [] @@ -34,7 +34,7 @@ describe('SymbolFormatter', () => { if (!symbol) { continue } - result.push(formatter.formatSymbol(symbol)) + result.push(formatter.formatSymbol(statement, symbol, 0)) } return result } diff --git a/vscode/src/completions/context/retrievers/tsc/SymbolFormatter.ts b/vscode/src/completions/context/retrievers/tsc/SymbolFormatter.ts index ccea350f9b0..2559146cb5e 100644 --- a/vscode/src/completions/context/retrievers/tsc/SymbolFormatter.ts +++ b/vscode/src/completions/context/retrievers/tsc/SymbolFormatter.ts @@ -26,24 +26,24 @@ const path = defaultPathFunctions() * `ts.{Symbol,Type}`). */ export class SymbolFormatter { - private queue = new Set() + public queue = new Map() public isRendered = new Set() - constructor(private checker: ts.TypeChecker) {} + private depth = 0 + constructor( + private checker: ts.TypeChecker, + private maxDepth: number + ) {} - public formatSymbolWithQueue(sym: ts.Symbol): { formatted: string; queue: ts.Symbol[] } { - const formatted = this.formatSymbol(sym) + public formatSymbol( + declaration: ts.Node, + sym: ts.Symbol, + depth: number, + params?: { stripEnclosingInformation?: boolean } + ): string { + const oldDepth = this.depth + this.depth = depth this.queueRelatedSymbols(sym) - - // TODO: return a queue of symbols that got added in this function call, - // not accumulated list - return { formatted, queue: [...this.queue] } - } - - public formatSymbol(sym: ts.Symbol, params?: { stripEnclosingInformation?: boolean }): string { - const declaration = sym.declarations?.[0] - if (!declaration) { - return '' - } + this.depth = oldDepth if (ts.isClassLike(declaration) || ts.isInterfaceDeclaration(declaration)) { return this.formatClassOrInterface(declaration, sym) @@ -65,6 +65,9 @@ export class SymbolFormatter { } private queueRelatedSymbols(sym: ts.Symbol): void { + if (this.depth > this.maxDepth) { + return + } const walkNode = (node: ts.Node | undefined): void => { if (!node) { return @@ -108,7 +111,7 @@ export class SymbolFormatter { if (isStdLibSymbol(s)) { return } - this.queue.add(s) + this.queue.set(s, this.depth + 1) } private registerRenderedNode(node: ts.Node): void { @@ -184,9 +187,17 @@ export class SymbolFormatter { } const name = declarationName(decl) if (name) { - p.line(this.formatSymbol(member, { stripEnclosingInformation: true })) + p.line( + this.formatSymbol(decl, member, this.depth + 1, { + stripEnclosingInformation: true, + }) + ) } else if (memberName === ts.InternalSymbolName.Constructor) { - p.line(this.formatSymbol(member, { stripEnclosingInformation: true })) + p.line( + this.formatSymbol(decl, member, this.depth + 1, { + stripEnclosingInformation: true, + }) + ) } } }) diff --git a/vscode/src/completions/context/retrievers/tsc/relevantTypeIdentifiers.ts b/vscode/src/completions/context/retrievers/tsc/relevantTypeIdentifiers.ts index 003cc68c5ac..b22471db773 100644 --- a/vscode/src/completions/context/retrievers/tsc/relevantTypeIdentifiers.ts +++ b/vscode/src/completions/context/retrievers/tsc/relevantTypeIdentifiers.ts @@ -1,26 +1,43 @@ import ts from 'typescript' import { declarationName } from './SymbolFormatter' +export type NodeMatchKind = + | 'imports' + | 'call-expression' + | 'property-access' + | 'function-declaration' + | 'declaration' + | 'none' + /** * Returns a list of identifier nodes that should be added to the Cody context. * * The logic for this function is going to be evolving as we add support for * more syntax constructs where we want to inject graph context. */ -export function relevantTypeIdentifiers(checker: ts.TypeChecker, node: ts.Node): ts.Node[] { - const result: ts.Node[] = [] - pushTypeIdentifiers(result, checker, node) - return result +export function relevantTypeIdentifiers( + checker: ts.TypeChecker, + node: ts.Node +): { kind: NodeMatchKind; nodes: ts.Node[] } { + const nodes: ts.Node[] = [] + const kind = pushTypeIdentifiers(nodes, checker, node) + return { kind, nodes } } -export function pushTypeIdentifiers(result: ts.Node[], checker: ts.TypeChecker, node: ts.Node): void { +export function pushTypeIdentifiers( + result: ts.Node[], + checker: ts.TypeChecker, + node: ts.Node +): NodeMatchKind { if (ts.isSourceFile(node)) { ts.forEachChild(node, child => { if (ts.isImportDeclaration(child)) { pushDescendentIdentifiers(result, child) } }) - } else if ( + return 'imports' + } + if ( ts.isSetAccessorDeclaration(node) || ts.isGetAccessorDeclaration(node) || ts.isConstructorDeclaration(node) || @@ -36,19 +53,24 @@ export function pushTypeIdentifiers(result: ts.Node[], checker: ts.TypeChecker, if (node.type) { pushDescendentIdentifiers(result, node.type) } - } else if (ts.isCallExpression(node)) { + return 'function-declaration' + } + if (ts.isCallExpression(node)) { result.push(...rightmostIdentifier(node.expression)) - } else if (ts.isPropertyAccessExpression(node)) { + return 'call-expression' + } + if (ts.isPropertyAccessExpression(node)) { result.push(...rightmostIdentifier(node.expression)) - } else { - const name = declarationName(node) - if (name) { - result.push(name) - } else { - // Uncomment below to debug what kind of if (ts.isX) case to handle - // console.log({ text: node.getText(), kindString: ts.SyntaxKind[node.kind] }) - } + return 'property-access' + } + const name = declarationName(node) + if (name) { + result.push(name) + return 'declaration' } + // Uncomment below to debug what kind of if (ts.isX) case to handle + // console.log({ text: node.getText(), kindString: ts.SyntaxKind[node.kind] }) + return 'none' } // A hacky way to get the `ts.Identifier` node furthest to the right. Ideally, diff --git a/vscode/src/completions/context/retrievers/tsc/tsc-retriever.test.ts b/vscode/src/completions/context/retrievers/tsc/tsc-retriever.test.ts index 829e73a0b3a..27a420456bf 100644 --- a/vscode/src/completions/context/retrievers/tsc/tsc-retriever.test.ts +++ b/vscode/src/completions/context/retrievers/tsc/tsc-retriever.test.ts @@ -47,7 +47,7 @@ describe.skipIf(isWindows())('TscRetriever', () => { docContext, document, position, - hints: { maxChars: 1000, maxMs: 100 }, + hints: { maxChars: 10_000, maxMs: 100 }, }) return { snippets: result, moduleName, namespaceName } } @@ -57,6 +57,28 @@ describe.skipIf(isWindows())('TscRetriever', () => { } it('imports', async () => { + expect( + await retrieveText(dedent` + import { execFileSync } from 'child_process' + const a = █ + `) + // TODO: drill into Holder + ).toMatchInlineSnapshot(` + [ + "function execFileSync(file: string): Buffer", + "function execFileSync(file: string, options: ExecFileSyncOptionsWithStringEncoding): string", + "function execFileSync(file: string, options: ExecFileSyncOptionsWithBufferEncoding): Buffer", + "function execFileSync(file: string, options?: ExecFileSyncOptions | undefined): string | Buffer", + "function execFileSync(file: string, args: readonly string[]): Buffer", + "function execFileSync(file: string, args: readonly string[], options: ExecFileSyncOptionsWithStringEncoding): string", + "function execFileSync(file: string, args: readonly string[], options: ExecFileSyncOptionsWithBufferEncoding): Buffer", + "function execFileSync(file: string, args?: readonly string[] | undefined, options?: ExecFileSyncOptions | undefined): string | Buffer", + "const a: any", + ] + `) + }) + + it('imports2', async () => { const { moduleName } = await retrieve( dedent` export interface Holder { bananas: number } @@ -254,7 +276,9 @@ describe.skipIf(isWindows())('TscRetriever', () => { interface A { value: number } interface B { a(): A } const b: B = {} - b.a().█ + function foo() { + b.a().█ + } `) ).toMatchInlineSnapshot(` [ diff --git a/vscode/src/completions/context/retrievers/tsc/tsc-retriever.ts b/vscode/src/completions/context/retrievers/tsc/tsc-retriever.ts index cdcefb45172..430acc57c11 100644 --- a/vscode/src/completions/context/retrievers/tsc/tsc-retriever.ts +++ b/vscode/src/completions/context/retrievers/tsc/tsc-retriever.ts @@ -11,9 +11,10 @@ import { import ts from 'typescript' import * as vscode from 'vscode' import type { ContextRetriever, ContextRetrieverOptions } from '../../../types' +import { nextTick } from '../section-history/nextTick' import { SymbolFormatter, isStdLibNode } from './SymbolFormatter' import { getTSSymbolAtLocation } from './getTSSymbolAtLocation' -import { relevantTypeIdentifiers } from './relevantTypeIdentifiers' +import { type NodeMatchKind, relevantTypeIdentifiers } from './relevantTypeIdentifiers' interface LoadedCompiler { service: ts.LanguageService @@ -46,7 +47,18 @@ interface TscRetrieverOptions { maxNodeMatches: number /** - * The "symbol depth" determines how many nested layers of signatures we + * For each node match, include at most these number of matches. + */ + maxSnippetsPerNodeMatch: Map + + /** For node match kinds that are undefined in maxSnippetsPerNodeMatch, use this value. */ + defaultSnippetsPerNodeMatch: number + + /** Return at most this number of total symbol snippets per request. */ + maxTotalSnippets: number + + /** + * The "symbol depth" determines how many nested ljyers of signatures we * want to emit for a given symbol. For example, * * - Depth 0: does nothing @@ -66,10 +78,15 @@ interface TscRetrieverOptions { export function defaultTscRetrieverOptions(): TscRetrieverOptions { return { - includeSymbolsInCurrentFile: false, + // it's confusing when we skip results from the local file. Also, the + // prefix/suffix are often only a fraction of the open file anyways. + includeSymbolsInCurrentFile: true, maxNodeMatches: vscode.workspace .getConfiguration('sourcegraph') .get('cody.autocomplete.experimental.maxTscResults', 1), + maxSnippetsPerNodeMatch: new Map([['imports', 3]]), + defaultSnippetsPerNodeMatch: 5, + maxTotalSnippets: 10, maxSymbolDepth: 1, } } @@ -78,6 +95,12 @@ interface TscLanguageService { service: ts.LanguageService host: TscLanguageServiceHost } + +interface DocumentSnapshot { + text: string + version: string +} + /** * The tsc retriever uses the TypeScript compiler API to retrieve contextual * information about the autocomplete request location. @@ -85,12 +108,19 @@ interface TscLanguageService { export class TscRetriever implements ContextRetriever { public identifier = 'tsc' - constructor(private options: TscRetrieverOptions = defaultTscRetrieverOptions()) {} + constructor(private options: TscRetrieverOptions = defaultTscRetrieverOptions()) { + this.disposables.push( + vscode.workspace.onDidChangeTextDocument(event => { + this.snapshots.delete(event.document.fileName) + }) + ) + } private servicesByTsconfigPath = new Map() private baseCompilerHost: ts.FormatDiagnosticsHost = ts.createCompilerHost({}) private disposables: vscode.Disposable[] = [] private documentRegistry = ts.createDocumentRegistry(isMacOS() || isWindows(), currentDirectory()) + private snapshots = new Map() private getOrLoadCompiler(file: FileURI): LoadedCompiler | undefined { const fromCache = this.getCompiler(file) @@ -98,16 +128,23 @@ export class TscRetriever implements ContextRetriever { return fromCache } this.loadCompiler(file) + this.documentRegistry.updateDocument return this.getCompiler(file) } private readDocument(fileName: string): { text: string; version: string } { + const fromCache = this.snapshots.get(fileName) + if (fromCache) { + return fromCache + } for (const document of vscode.workspace.textDocuments) { if (isFileURI(document.uri) && document.uri.fsPath === fileName) { return { text: document.getText(), version: document.version.toString() } } } - return { text: ts.sys.readFile(fileName) ?? '', version: '0' } + const result = { text: ts.sys.readFile(fileName) ?? '', version: '0' } + this.snapshots.set(fileName, result) + return result } private loadCompiler(file: FileURI): undefined { @@ -229,7 +266,7 @@ export class TscRetriever implements ContextRetriever { return result ?? this.servicesByTsconfigPath.get(process.cwd()) } - private doBlockingRetrieve(options: ContextRetrieverOptions): AutocompleteContextSnippet[] { + private async doRetrieve(options: ContextRetrieverOptions): Promise { const uri = options.document.uri if (!isFileURI(uri)) { return [] @@ -238,25 +275,26 @@ export class TscRetriever implements ContextRetriever { if (!compiler) { return [] } - try { - return new SymbolCollector(compiler, this.options, options.position).relevantSymbols() - } catch (error) { - logError('tsc-retriever', 'unexpected error', error) - return [] - } + + // Loading the compiler can block the thread for a while, so we hand + // back control to allow other promises to run before running symbol + // collection. + await nextTick() + + return new SymbolCollector(compiler, this.options, options, options.position).relevantSymbols() } - public retrieve(options: ContextRetrieverOptions): Promise { - return new Promise(resolve => { - tracer.startActiveSpan('graph-context.tsc', span => { - span.setAttribute('sampled', true) - try { - resolve(this.doBlockingRetrieve(options)) - } catch (error) { - logError('tsc-retriever', String(error)) - resolve([]) - } - }) + public async retrieve(options: ContextRetrieverOptions): Promise { + return tracer.startActiveSpan('graph-context.tsc', async span => { + span.setAttribute('sampled', true) + try { + const result = await this.doRetrieve(options) + // logDebug('tsc-retriever', JSON.stringify(result, null, 2)) + return result + } catch (error) { + logError('tsc-retriever', String(error)) + return [] + } }) } @@ -268,9 +306,7 @@ export class TscRetriever implements ContextRetriever { languageId === 'javascriptreact' ) } - public dispose() { - vscode.Disposable.from(...this.disposables).dispose() - } + public dispose() {} } // Copy-pasted and adapted code from scip-typescript @@ -340,17 +376,22 @@ type TscLanguageServiceHost = ts.LanguageServiceHost & { class SymbolCollector { private snippets: AutocompleteContextSnippet[] = [] - private toplevelNodes = new Set() - private isDone = () => this.toplevelNodes.size >= this.options.maxNodeMatches + private nodeMatches = new Set() + private hasRemainingNodeMatches = () => this.nodeMatches.size < this.options.maxNodeMatches + private hasRemainingChars = () => this.addedContentChars < this.contextOptions.hints.maxChars + private addedContentChars = 0 private isAdded = new Set() private formatter: SymbolFormatter private offset: number + private searchState: SearchState = SearchState.Continue + private isSearchDone = () => this.searchState === SearchState.Done constructor( private readonly compiler: LoadedCompiler, private options: TscRetrieverOptions, + private contextOptions: ContextRetrieverOptions, position: vscode.Position ) { - this.formatter = new SymbolFormatter(this.compiler.checker) + this.formatter = new SymbolFormatter(this.compiler.checker, this.options.maxSymbolDepth) this.offset = this.compiler.sourceFile.getPositionOfLineAndCharacter( position.line, position.character @@ -358,29 +399,41 @@ class SymbolCollector { } public relevantSymbols(): AutocompleteContextSnippet[] { - this.loop(this.compiler.sourceFile) + this.tryNodeMatch(this.compiler.sourceFile) + for (const [queued, depth] of this.formatter.queue.entries()) { + if (depth > this.options.maxSymbolDepth) { + continue + } + const budget = this.options.maxTotalSnippets - this.snippets.length + this.addSymbol(queued, budget, depth) + } return this.snippets } - private addSymbol(sym: ts.Symbol, depth: number): boolean { + private addSymbol( + sym: ts.Symbol, + remainingNodeMatchKindSnippetBudget: number, + depth: number + ): number { if (depth > this.options.maxSymbolDepth) { - return false + return 0 } if (this.isAdded.has(sym)) { - return false + return 0 } if (this.formatter.isRendered.has(sym)) { // Skip this symbol if it's a child of a symbol that we have already // formatted. For example, if we render `interface A { a: number }` // then we don't need to render `(property) A.a: number` separately // because it's redunant with the interface declaration. - return false + return 0 } this.isAdded.add(sym) // Symbols with multiple declarations are normally overloaded // functions, in which case we want to show all available // signatures. - let isAdded = false + let addedCount = 0 + for (const declaration of sym.declarations ?? []) { if (isStdLibNode(declaration)) { // Skip stdlib types because the LLM most likely knows how @@ -405,13 +458,10 @@ class SymbolCollector { case ts.SyntaxKind.NamespaceImport: continue } - if (this.isDone()) { - continue - } const sourceFile = declaration.getSourceFile() const start = sourceFile.getLineAndCharacterOfPosition(declaration.getStart()) const end = sourceFile.getLineAndCharacterOfPosition(declaration.getEnd()) - const { formatted: content, queue } = this.formatter.formatSymbolWithQueue(sym) + const content = this.formatter.formatSymbol(declaration, sym, depth) if (!ts.isModuleDeclaration(declaration)) { // Skip module declarations because they can be too large. // We still format them to queue the referenced types. @@ -422,44 +472,69 @@ class SymbolCollector { endLine: end.line, uri: vscode.Uri.file(sourceFile.fileName), } + this.addedContentChars += content.length this.snippets.push(snippet) + addedCount++ + if (this.snippets.length >= this.options.maxTotalSnippets) { + this.searchState = SearchState.Done + break + } + if (!this.hasRemainingChars()) { + this.searchState = SearchState.Done + break + } + if (remainingNodeMatchKindSnippetBudget - addedCount <= 0) { + break + } } - for (const queued of queue) { - this.addSymbol(queued, depth + 1) - } - isAdded = true } - return isAdded + return addedCount } - private loop(node: ts.Node): void { - if (this.isDone()) { + private tryNodeMatch(node: ts.Node): void { + if (this.isSearchDone()) { return } // Loop on children first to boost symbol results that are closer to the // cursor location. - ts.forEachChild(node, child => this.loop(child)) + ts.forEachChild(node, child => { + if (this.isSearchDone()) { + return + } + this.tryNodeMatch(child) + }) - if (this.isDone()) { + if (this.isSearchDone()) { return } if (this.offset < node.getStart() || this.offset > node.getEnd()) { + // Subtree does not enclose the request position. return } - let isAdded = false - for (const identifier of relevantTypeIdentifiers(this.compiler.checker, node)) { + let addedCount = 0 + const { kind, nodes } = relevantTypeIdentifiers(this.compiler.checker, node) + const budget = + this.options.maxSnippetsPerNodeMatch.get(kind) ?? this.options.defaultSnippetsPerNodeMatch + for (const identifier of nodes) { const symbol = getTSSymbolAtLocation(this.compiler.checker, identifier) if (symbol) { - const gotAdded = this.addSymbol(symbol, 0) - isAdded ||= gotAdded + addedCount += this.addSymbol(symbol, budget - addedCount, 0) } } - if (isAdded) { - this.toplevelNodes.add(node) + if (addedCount > 0) { + this.nodeMatches.add(node) + if (!this.hasRemainingNodeMatches()) { + this.searchState = SearchState.Done + } } } } + +enum SearchState { + Done = 1, + Continue = 2, +}