diff --git a/extension/src/experiments/checkpoints/collect.test.ts b/extension/src/experiments/checkpoints/collect.test.ts index f1ea1c934a..48f02afd32 100644 --- a/extension/src/experiments/checkpoints/collect.test.ts +++ b/extension/src/experiments/checkpoints/collect.test.ts @@ -51,6 +51,34 @@ describe('collectHasCheckpoints', () => { expect(hasCheckpoints).toBe(false) }) + it('should not fail if a train stage is not provided', () => { + const hasCheckpoints = collectHasCheckpoints({ + stages: { + extract: { + cmd: 'tar -xzf data/images.tar.gz --directory data', + deps: ['data/images.tar.gz'], + outs: [{ 'data/images/': { cache: false } }] + } + } + } as PartialDvcYaml) + + expect(hasCheckpoints).toBe(false) + }) + + it('should return true if any stage has checkpoints', () => { + const hasCheckpoints = collectHasCheckpoints({ + stages: { + extract: { + cmd: 'tar -xzf data/images.tar.gz --directory data', + deps: ['data/images.tar.gz'], + outs: [{ 'data/images/': { cache: false, checkpoint: true } }] + } + } + } as PartialDvcYaml) + + expect(hasCheckpoints).toBe(true) + }) + it('should correctly classify a more complex dvc.yaml without checkpoint', () => { const hasCheckpoints = collectHasCheckpoints({ stages: { diff --git a/extension/src/experiments/checkpoints/collect.ts b/extension/src/experiments/checkpoints/collect.ts index 3ea8ddea13..b1bcd850e2 100644 --- a/extension/src/experiments/checkpoints/collect.ts +++ b/extension/src/experiments/checkpoints/collect.ts @@ -1,13 +1,22 @@ -import { PartialDvcYaml } from '../../fileSystem' +import { Out, PartialDvcYaml } from '../../fileSystem' -export const collectHasCheckpoints = (yaml: PartialDvcYaml): boolean => { - return !!yaml.stages.train.outs.some(out => { +const stageHasCheckpoints = (outs: Out[] = []): boolean => { + for (const out of outs) { if (typeof out === 'string') { - return false + continue } - if (Object.values(out).some(file => file?.checkpoint)) { return true } - }) + } + return false +} + +export const collectHasCheckpoints = (yaml: PartialDvcYaml): boolean => { + for (const stage of Object.values(yaml?.stages || {})) { + if (stageHasCheckpoints(stage?.outs)) { + return true + } + } + return false } diff --git a/extension/src/fileSystem/index.ts b/extension/src/fileSystem/index.ts index 5c1bf1d66c..cc3f235cd5 100644 --- a/extension/src/fileSystem/index.ts +++ b/extension/src/fileSystem/index.ts @@ -67,9 +67,15 @@ export const isSameOrChild = (root: string, path: string) => { return !rel.startsWith('..') } +export type Out = + | string + | Record + export type PartialDvcYaml = { stages: { - train: { outs: (string | Record)[] } + [stage: string]: { + outs?: Out[] + } } }