Skip to content

Commit

Permalink
Add support for custom context in plan execution (microsoft#826)
Browse files Browse the repository at this point in the history
This commit fixes a bug that allows users to pass a custom SKContext
object to a plan invocation, which can override or augment the default
context variables. This enables more flexibility and control over the
plan execution and the output. The commit also adds a unit test to
verify the functionality of the new feature.

Co-authored-by: Lee Miller <lemillermicrosoft@users.noreply.github.com>
Co-authored-by: Shawn Callegari <36091529+shawncal@users.noreply.github.com>
  • Loading branch information
3 people authored and codebrain committed May 16, 2023
1 parent 9cb0a60 commit 183b29f
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ public ContextVariables Clone()
return clone;
}

#region private ================================================================================
internal const string MainKey = "INPUT";

private const string MainKey = "INPUT";
#region private ================================================================================

// Important: names are case insensitive
private readonly ConcurrentDictionary<string, string> _variables = new(StringComparer.OrdinalIgnoreCase);
Expand Down
179 changes: 179 additions & 0 deletions dotnet/src/SemanticKernel.UnitTests/Planning/PlanTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,185 @@ public async Task CanExecutePlanWithOneStepAndStateAsync()
mockFunction.Verify(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default), Times.Once);
}

[Fact]
public async Task CanExecutePlanWithStateAsync()
{
// Arrange
var kernel = new Mock<IKernel>();
var log = new Mock<ILogger>();
var memory = new Mock<ISemanticTextMemory>();
var skills = new Mock<ISkillCollection>();

var returnContext = new SKContext(
new ContextVariables(),
memory.Object,
skills.Object,
log.Object
);

var mockFunction = new Mock<ISKFunction>();
mockFunction.Setup(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default))
.Callback<SKContext, CompleteRequestSettings, ILogger, CancellationToken?>((c, s, l, ct) =>
{
c.Variables.Get("type", out var t);
returnContext.Variables.Update($"Here is a {t} about " + c.Variables.Input);
})
.Returns(() => Task.FromResult(returnContext));

var planStep = new Plan(mockFunction.Object);
planStep.Parameters.Set("type", string.Empty);
var plan = new Plan(string.Empty);
plan.AddSteps(planStep);
plan.State.Set("input", "Cleopatra");
plan.State.Set("type", "poem");

// Act
var result = await plan.InvokeAsync();

// Assert
Assert.NotNull(result);
Assert.Equal($"Here is a poem about Cleopatra", result.Result);
mockFunction.Verify(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default), Times.Once);
}

[Fact]
public async Task CanExecutePlanWithCustomContextAsync()
{
// Arrange
var kernel = new Mock<IKernel>();
var log = new Mock<ILogger>();
var memory = new Mock<ISemanticTextMemory>();
var skills = new Mock<ISkillCollection>();

var returnContext = new SKContext(
new ContextVariables(),
memory.Object,
skills.Object,
log.Object
);

var mockFunction = new Mock<ISKFunction>();
mockFunction.Setup(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default))
.Callback<SKContext, CompleteRequestSettings, ILogger, CancellationToken?>((c, s, l, ct) =>
{
c.Variables.Get("type", out var t);
returnContext.Variables.Update($"Here is a {t} about " + c.Variables.Input);
})
.Returns(() => Task.FromResult(returnContext));

var plan = new Plan(mockFunction.Object);
plan.State.Set("input", "Cleopatra");
plan.State.Set("type", "poem");

// Act
var result = await plan.InvokeAsync();

// Assert
Assert.NotNull(result);
Assert.Equal($"Here is a poem about Cleopatra", result.Result);
mockFunction.Verify(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default), Times.Once);

plan = new Plan(mockFunction.Object);
plan.State.Set("input", "Cleopatra");
plan.State.Set("type", "poem");

var contextOverride = new SKContext(
new ContextVariables(),
memory.Object,
skills.Object,
log.Object
);
contextOverride.Variables.Set("type", "joke");
contextOverride.Variables.Update("Medusa");

// Act
result = await plan.InvokeAsync(contextOverride);

// Assert
Assert.NotNull(result);
Assert.Equal($"Here is a joke about Medusa", result.Result);
mockFunction.Verify(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default), Times.Exactly(2));
}

[Fact]
public async Task CanExecutePlanWithCustomStateAsync()
{
// Arrange
var kernel = new Mock<IKernel>();
var log = new Mock<ILogger>();
var memory = new Mock<ISemanticTextMemory>();
var skills = new Mock<ISkillCollection>();

var returnContext = new SKContext(
new ContextVariables(),
memory.Object,
skills.Object,
log.Object
);

var mockFunction = new Mock<ISKFunction>();
mockFunction.Setup(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default))
.Callback<SKContext, CompleteRequestSettings, ILogger, CancellationToken?>((c, s, l, ct) =>
{
c.Variables.Get("type", out var t);
returnContext.Variables.Update($"Here is a {t} about " + c.Variables.Input);
})
.Returns(() => Task.FromResult(returnContext));

var planStep = new Plan(mockFunction.Object);
planStep.Parameters.Set("type", string.Empty);
var plan = new Plan("A plan");
plan.State.Set("input", "Medusa");
plan.State.Set("type", "joke");
plan.AddSteps(planStep);

// Act
var result = await plan.InvokeAsync();

// Assert
Assert.NotNull(result);
Assert.Equal($"Here is a joke about Medusa", result.Result);
mockFunction.Verify(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default), Times.Once);

planStep = new Plan(mockFunction.Object);
plan = new Plan("A plan");
planStep.Parameters.Set("input", "Medusa");
planStep.Parameters.Set("type", "joke");
plan.State.Set("input", "Cleopatra"); // state input will not override parameter
plan.State.Set("type", "poem");
plan.AddSteps(planStep);

// Act
result = await plan.InvokeAsync();

// Assert
Assert.NotNull(result);
Assert.Equal($"Here is a poem about Medusa", result.Result);
mockFunction.Verify(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default), Times.Exactly(2));

planStep = new Plan(mockFunction.Object);
plan = new Plan("A plan");
planStep.Parameters.Set("input", "Cleopatra");
planStep.Parameters.Set("type", "poem");
plan.AddSteps(planStep);
var contextOverride = new SKContext(
new ContextVariables(),
memory.Object,
skills.Object,
log.Object
);
contextOverride.Variables.Set("type", "joke");
contextOverride.Variables.Update("Medusa"); // context input will not override parameters

// Act
result = await plan.InvokeAsync(contextOverride);

// Assert
Assert.NotNull(result);
Assert.Equal($"Here is a joke about Cleopatra", result.Result);
mockFunction.Verify(x => x.InvokeAsync(It.IsAny<SKContext>(), null, null, default), Times.Exactly(3));
}

[Fact]
public async Task CanExecutePlanWithJoinedResultAsync()
{
Expand Down
71 changes: 54 additions & 17 deletions dotnet/src/SemanticKernel/Planning/Plan.cs
Original file line number Diff line number Diff line change
Expand Up @@ -470,46 +470,83 @@ private SKContext UpdateContextWithOutputs(SKContext context)
/// <returns>The context variables for the next step in the plan.</returns>
private ContextVariables GetNextStepVariables(ContextVariables variables, Plan step)
{
// If the current step is passing to another plan, we set the default input to an empty string.
// Otherwise, we use the description from the current plan as the default input.
// We then set the input to the value from the SKContext, or the input from the Plan.State, or the default input.
var defaultInput = step.Steps.Count > 0 ? string.Empty : this.Description ?? string.Empty;
var planInput = string.IsNullOrEmpty(variables.Input) ? this.State.Input : variables.Input;
var stepInput = string.IsNullOrEmpty(planInput) ? defaultInput : planInput;
var stepVariables = new ContextVariables(stepInput);
// Priority for Input
// - Parameters (expand from variables if needed)
// - SKContext.Variables
// - Plan.State
// - Empty if sending to another plan
// - Plan.Description

var input = string.Empty;
if (!string.IsNullOrEmpty(step.Parameters.Input))
{
input = this.ExpandFromVariables(variables, step.Parameters.Input);
}
else if (!string.IsNullOrEmpty(variables.Input))
{
input = variables.Input;
}
else if (!string.IsNullOrEmpty(this.State.Input))
{
input = this.State.Input;
}
else if (step.Steps.Count > 0)
{
input = string.Empty;
}
else if (!string.IsNullOrEmpty(this.Description))
{
input = this.Description;
}

var stepVariables = new ContextVariables(input);

// Priority for remaining stepVariables is:
// - Parameters (pull from State by a key value)
// - Parameters (from context)
// - Parameters (from State)
// - Function Parameters (pull from variables or state by a key value)
// - Step Parameters (pull from variables or state by a key value)
var functionParameters = step.Describe();
foreach (var param in functionParameters.Parameters)
{
if (variables.Get(param.Name, out var value) && !string.IsNullOrEmpty(value))
if (param.Name.Equals(ContextVariables.MainKey, StringComparison.OrdinalIgnoreCase))
{
continue;
}

if (variables.Get(param.Name, out var value))
{
stepVariables.Set(param.Name, value);
}
else if (this.State.Get(param.Name, out value) && !string.IsNullOrEmpty(value))
else if (this.State.Get(param.Name, out value))
{
stepVariables.Set(param.Name, value);
}
}

foreach (var item in step.Parameters)
{
if (!string.IsNullOrEmpty(item.Value))
// Don't overwrite variable values that are already set
if (stepVariables.Get(item.Key, out _))
{
var value = this.ExpandFromVariables(variables, item.Value);
stepVariables.Set(item.Key, value);
continue;
}

var expandedValue = this.ExpandFromVariables(variables, item.Value);
if (!expandedValue.Equals(item.Value, StringComparison.OrdinalIgnoreCase))
{
stepVariables.Set(item.Key, expandedValue);
}
else if (variables.Get(item.Key, out var value) && !string.IsNullOrEmpty(value))
else if (variables.Get(item.Key, out var value))
{
stepVariables.Set(item.Key, value);
}
else if (this.State.Get(item.Key, out value) && !string.IsNullOrEmpty(value))
else if (this.State.Get(item.Key, out value))
{
stepVariables.Set(item.Key, value);
}
else
{
stepVariables.Set(item.Key, expandedValue);
}
}

return stepVariables;
Expand Down

0 comments on commit 183b29f

Please sign in to comment.