Skip to content

Commit

Permalink
fix: properly update the pulled images count if some images are skipped
Browse files Browse the repository at this point in the history
The code had an issue: if the loop pulled at least one image and one of
the images in the end of the loop already exist on the node, it was
skipping the update, so only the outdated event was sent to the channel.

Signed-off-by: Artem Chernyshev <artem.chernyshev@talos-systems.com>
  • Loading branch information
Unix4ever committed Jun 7, 2024
1 parent 5a4251c commit 2fcd0fd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (m *mockImageClient) ListImagesOnNode(_ context.Context, cluster, node stri
node: node,
})

return []string{node + "-image-1"}, nil // mimic <node>-image-1 being already on the node
return []string{node + "-image-1", node + "-image-4"}, nil // mimic <node>-image-1 and image-4 being already on the node
}

func (m *mockImageClient) PullImageToNode(_ context.Context, cluster, node, image string) error {
Expand Down Expand Up @@ -105,7 +105,7 @@ func (suite *ImagePullStatusControllerSuite) TestImagePullStatus() {
},
{
Node: "node-2",
Images: []string{"node-2-image-1", "node-2-image-2", "node-2-image-3"},
Images: []string{"node-2-image-1", "node-2-image-2", "node-2-image-3", "node-2-image-4"},
},
}

Expand Down Expand Up @@ -162,14 +162,14 @@ func (suite *ImagePullStatusControllerSuite) TestImagePullStatus() {
sts1Cluster, _ := sts1.Metadata().Labels().Get(omni.LabelCluster)
assert.Equal(collect, "pr-1-cluster", sts1Cluster)

// pr-1 will pull three images, so the version should be 3
assert.Equal(collect, "3", sts1.Metadata().Version().String())
// pr-1 will pull four images, so the version should be 4
assert.Equal(collect, "4", sts1.Metadata().Version().String())

assert.Equal(collect, sts1.TypedSpec().Value.GetRequestVersion(), pr1.Metadata().Version().String())
assert.Equal(collect, sts1.TypedSpec().Value.GetLastProcessedNode(), "node-2")
assert.Equal(collect, sts1.TypedSpec().Value.GetLastProcessedImage(), "node-2-image-3")
assert.Equal(collect, sts1.TypedSpec().Value.GetProcessedCount(), uint32(5)) // the processed count also includes images already on the node
assert.Equal(collect, sts1.TypedSpec().Value.GetTotalCount(), uint32(5)) // the total count also includes images already on the node
assert.Equal(collect, sts1.TypedSpec().Value.GetLastProcessedImage(), "node-2-image-4")
assert.Equal(collect, sts1.TypedSpec().Value.GetProcessedCount(), uint32(6)) // the processed count also includes images already on the node
assert.Equal(collect, sts1.TypedSpec().Value.GetTotalCount(), uint32(6)) // the total count also includes images already on the node
assert.Equal(collect, sts1.TypedSpec().Value.GetLastProcessedError(), "")

sts2, err := safe.StateGet[*omni.ImagePullStatus](suite.ctx, suite.state, omni.NewImagePullStatus(resources.DefaultNamespace, pr2.Metadata().ID()).Metadata())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,8 @@ func (p PullTaskSpec) RunTask(ctx context.Context, _ *zap.Logger, pullStatusCh P
currentNum, totalNum int
currentNode, currentImage string
currentError error
skipFinalEvent bool
)

// make sure to call the callback at least once - even if there is
// no image to be pulled or if there was an error early on
defer func() {
if !skipFinalEvent {
channel.SendWithContext(ctx, pullStatusCh, PullStatus{
Request: p.request,
Node: currentNode,
Image: currentImage,
CurrentNum: currentNum,
TotalNum: totalNum,
Error: currentError,
})
}
}()

clusterID, ok := p.request.Metadata().Labels().Get(omni.LabelCluster)
if !ok {
return fmt.Errorf("missing cluster label on %q", p.request.Metadata())
Expand Down Expand Up @@ -102,8 +86,6 @@ func (p PullTaskSpec) RunTask(ctx context.Context, _ *zap.Logger, pullStatusCh P
errs = multierror.Append(errs, currentError)
}

skipFinalEvent = true // we are already sending an event, so no need to send one at the end

if !channel.SendWithContext(ctx, pullStatusCh, PullStatus{
Request: p.request,
Node: currentNode,
Expand All @@ -118,6 +100,15 @@ func (p PullTaskSpec) RunTask(ctx context.Context, _ *zap.Logger, pullStatusCh P
}
}

channel.SendWithContext(ctx, pullStatusCh, PullStatus{
Request: p.request,
Node: currentNode,
Image: currentImage,
CurrentNum: currentNum,
TotalNum: totalNum,
Error: currentError,
})

return errs
}

Expand Down

0 comments on commit 2fcd0fd

Please sign in to comment.