Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 65 additions & 21 deletions crates/assistant2/src/active_thread.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
use crate::thread::{
LastRestoreCheckpoint, MessageId, RequestKind, Thread, ThreadError, ThreadEvent,
};
use crate::thread_store::ThreadStore;
use crate::tool_use::{ToolUse, ToolUseStatus};
use crate::ui::ContextPill;
use collections::HashMap;
use editor::{Editor, MultiBuffer};
use gpui::{
list, percentage, AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent,
DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment, ListOffset,
ListState, StyleRefinement, Subscription, Task, TextStyleRefinement, Transformation,
UnderlineStyle, WeakEntity,
list, percentage, pulsating_between, AbsoluteLength, Animation, AnimationExt, AnyElement, App,
ClickEvent, DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment,
ListOffset, ListState, StyleRefinement, Subscription, Task, TextStyleRefinement,
Transformation, UnderlineStyle, WeakEntity,
};
use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
Expand All @@ -18,7 +20,7 @@ use settings::Settings as _;
use std::sync::Arc;
use std::time::Duration;
use theme::ThemeSettings;
use ui::{prelude::*, Disclosure, KeyBinding};
use ui::{prelude::*, Disclosure, KeyBinding, Tooltip};
use util::ResultExt as _;
use workspace::{OpenOptions, Workspace};

Expand Down Expand Up @@ -401,7 +403,6 @@ impl ActiveThread {
window,
cx,
);

self.render_scripting_tool_use_markdown(
tool_use.id.clone(),
tool_use.name.as_ref(),
Expand Down Expand Up @@ -463,6 +464,7 @@ impl ActiveThread {
}
}
}
ThreadEvent::CheckpointChanged => cx.notify(),
}
}

Expand Down Expand Up @@ -789,20 +791,62 @@ impl ActiveThread {
v_flex()
.when(ix == 0, |parent| parent.child(self.render_rules_item(cx)))
.when_some(checkpoint, |parent, checkpoint| {
parent.child(
h_flex().pl_2().child(
Button::new(("restore-checkpoint", ix), "Restore Checkpoint")
.icon(IconName::Undo)
.size(ButtonSize::Compact)
.on_click(cx.listener(move |this, _, _window, cx| {
this.thread.update(cx, |thread, cx| {
thread
.restore_checkpoint(checkpoint.clone(), cx)
.detach_and_log_err(cx);
});
})),
),
)
let mut is_pending = false;
let mut error = None;
if let Some(last_restore_checkpoint) =
self.thread.read(cx).last_restore_checkpoint()
{
if last_restore_checkpoint.message_id() == message_id {
match last_restore_checkpoint {
LastRestoreCheckpoint::Pending { .. } => is_pending = true,
LastRestoreCheckpoint::Error { error: err, .. } => {
error = Some(err.clone());
}
}
}
}

let restore_checkpoint_button =
Button::new(("restore-checkpoint", ix), "Restore Checkpoint")
.icon(if error.is_some() {
IconName::XCircle
} else {
IconName::Undo
})
.size(ButtonSize::Compact)
.disabled(is_pending)
.icon_color(if error.is_some() {
Some(Color::Error)
} else {
None
})
.on_click(cx.listener(move |this, _, _window, cx| {
this.thread.update(cx, |thread, cx| {
thread
.restore_checkpoint(checkpoint.clone(), cx)
.detach_and_log_err(cx);
});
}));

let restore_checkpoint_button = if is_pending {
restore_checkpoint_button
.with_animation(
("pulsating-restore-checkpoint-button", ix),
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.alpha(delta),
)
.into_any_element()
} else if let Some(error) = error {
restore_checkpoint_button
.tooltip(Tooltip::text(error.to_string()))
.into_any_element()
} else {
restore_checkpoint_button.into_any_element()
};

parent.child(h_flex().pl_2().child(restore_checkpoint_button))
})
.child(styled_message)
.into_any()
Expand Down
48 changes: 46 additions & 2 deletions crates/assistant2/src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,25 @@ pub struct ThreadCheckpoint {
git_checkpoint: GitStoreCheckpoint,
}

pub enum LastRestoreCheckpoint {
Pending {
message_id: MessageId,
},
Error {
message_id: MessageId,
error: String,
},
}

impl LastRestoreCheckpoint {
pub fn message_id(&self) -> MessageId {
match self {
LastRestoreCheckpoint::Pending { message_id } => *message_id,
LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
}
}
}

/// A thread of conversation with the LLM.
pub struct Thread {
id: ThreadId,
Expand All @@ -118,6 +137,7 @@ pub struct Thread {
tools: Arc<ToolWorkingSet>,
tool_use: ToolUseState,
action_log: Entity<ActionLog>,
last_restore_checkpoint: Option<LastRestoreCheckpoint>,
scripting_session: Entity<ScriptingSession>,
scripting_tool_use: ToolUseState,
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
Expand Down Expand Up @@ -147,6 +167,7 @@ impl Thread {
project: project.clone(),
prompt_builder,
tools: tools.clone(),
last_restore_checkpoint: None,
tool_use: ToolUseState::new(tools.clone()),
scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
scripting_tool_use: ToolUseState::new(tools),
Expand Down Expand Up @@ -207,6 +228,7 @@ impl Thread {
checkpoints_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
last_restore_checkpoint: None,
project,
prompt_builder,
tools,
Expand Down Expand Up @@ -279,17 +301,38 @@ impl Thread {
checkpoint: ThreadCheckpoint,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
message_id: checkpoint.message_id,
});
cx.emit(ThreadEvent::CheckpointChanged);

let project = self.project.read(cx);
let restore = project
.git_store()
.read(cx)
.restore_checkpoint(checkpoint.git_checkpoint, cx);
cx.spawn(async move |this, cx| {
restore.await?;
this.update(cx, |this, cx| this.truncate(checkpoint.message_id, cx))
let result = restore.await;
this.update(cx, |this, cx| {
if let Err(err) = result.as_ref() {
this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
message_id: checkpoint.message_id,
error: err.to_string(),
});
} else {
this.last_restore_checkpoint = None;
this.truncate(checkpoint.message_id, cx);
}
cx.emit(ThreadEvent::CheckpointChanged);
})?;
result
})
}

pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
self.last_restore_checkpoint.as_ref()
}

pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
let Some(message_ix) = self
.messages
Expand Down Expand Up @@ -1361,6 +1404,7 @@ pub enum ThreadEvent {
/// Whether the tool was canceled by the user.
canceled: bool,
},
CheckpointChanged,
}

impl EventEmitter<ThreadEvent> for Thread {}
Expand Down
3 changes: 2 additions & 1 deletion crates/project/src/git.rs
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,8 @@ impl GitStore {
let mut tasks = Vec::new();
for (dot_git_abs_path, checkpoint) in checkpoint.checkpoints_by_dot_git_abs_path {
if let Some(repository) = repositories_by_dot_git_abs_path.get(&dot_git_abs_path) {
tasks.push(repository.read(cx).restore_checkpoint(checkpoint));
let restore = repository.read(cx).restore_checkpoint(checkpoint);
tasks.push(async move { restore.await? });
}
}
cx.background_spawn(async move {
Expand Down
Loading