Skip to content

Commit

Permalink
agent status managing & paused job can be canceled (#285)
Browse files Browse the repository at this point in the history
* allow paused job be canceled

* agent status online & offline managing
  • Loading branch information
anda-ren committed May 10, 2022
1 parent 59f5ab6 commit 3c41091
Show file tree
Hide file tree
Showing 22 changed files with 221 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import javax.validation.Valid;
import org.springframework.http.ResponseEntity;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
Expand Down Expand Up @@ -67,6 +68,24 @@ ResponseEntity<ResponseMessage<PageInfo<AgentVO>>> listAgent(
@RequestParam(value = "pageSize", required = false, defaultValue = "10")
Integer pageSize);

@Operation(summary = "remove offline agent")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "ok",
content =
@Content(
mediaType = "application/json",
schema = @Schema(implementation = String.class)))
})
@DeleteMapping(value = "/system/agent")
ResponseEntity<ResponseMessage<String>> deleteAgent(
@Parameter(in = ParameterIn.QUERY, description = "the serialNumber of the agent to be deleted", schema = @Schema())
@Valid
@RequestParam(value = "serialNumber", required = true)
String serialNumber);

@Operation(summary = "Upgrade system version or cancel upgrade")
@ApiResponses(value = {@ApiResponse(responseCode = "200", description = "ok")})
@PostMapping(value = "/system/version/{action}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ public ResponseEntity<ResponseMessage<PageInfo<AgentVO>>> listAgent(String ip, I
return ResponseEntity.ok(Code.success.asResponse(pageInfo));
}

@Override
public ResponseEntity<ResponseMessage<String>> deleteAgent(String serialNumber) {
return ResponseEntity.ok(Code.success.asResponse(systemService.deleteOfflineAgent(serialNumber)));
}

@Override
public ResponseEntity<ResponseMessage<String>> systemVersionAction(String action) {
return ResponseEntity.ok(Code.success.asResponse("Unknown action"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,11 @@

package ai.starwhale.mlops.api.protocol.agent;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue;
import ai.starwhale.mlops.domain.system.agent.AgentStatus;
import io.swagger.v3.oas.annotations.media.Schema;
import java.io.Serializable;
import lombok.Builder;
import lombok.Data;
import org.aspectj.weaver.loadtime.Agent;
import org.checkerframework.checker.units.qual.A;
import org.springframework.validation.annotation.Validated;

@Data
Expand All @@ -36,45 +33,17 @@ public class AgentVO implements Serializable {

private String ip;

private String serialNumber;

private Long connectedTime;

private StatusEnum status;
private AgentStatus status;

private String version;

public static AgentVO empty() {
return new AgentVO("", "", -1L, StatusEnum.OFFLINE, "");
return new AgentVO("", "", "",-1L, AgentStatus.OFFLINE, "");
}

/**
* Gets or Sets status
*/
public enum StatusEnum {
ACTIVE("active"),

OFFLINE("offline");

private final String value;

StatusEnum(String value) {
this.value = value;
}

@Override
@JsonValue
public String toString() {
return String.valueOf(value);
}

@JsonCreator
public static StatusEnum fromValue(String text) {
for (StatusEnum b : StatusEnum.values()) {
if (String.valueOf(b.value).equals(text)) {
return b;
}
}
return null;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,14 @@ Stream<Job> findAllNewJobs(){
public void cancelJob(Long jobId){
Collection<Task> tasks = livingTaskCache.ofJob(jobId);
if(null == tasks || tasks.isEmpty()){
throw new StarWhaleApiException(new SWValidationException(ValidSubject.JOB).tip("freezing job can't be canceled "),
throw new StarWhaleApiException(new SWValidationException(ValidSubject.JOB).tip("freeze job can't be canceled "),
HttpStatus.BAD_REQUEST);
}
JobStatus desiredJobStatus = taskJobStatusHelper.desiredJobStatus(tasks);
if(desiredJobStatus != JobStatus.RUNNING){
if(desiredJobStatus != JobStatus.RUNNING
&& desiredJobStatus != JobStatus.PAUSED
&& desiredJobStatus != JobStatus.TO_COLLECT_RESULT
&& desiredJobStatus != JobStatus.COLLECTING_RESULT){
throw new SWValidationException(ValidSubject.JOB).tip("not running job can't be canceled ");
}
jobMapper.updateJobStatus(List.of(jobId), JobStatus.TO_CANCEL);
Expand All @@ -208,8 +211,9 @@ public void cancelJob(Long jobId){
|| task.getStatus() == TaskStatus.ASSIGNING).collect(Collectors.toList()), TaskStatus.TO_CANCEL);
List<Task> tobeCanceledTasks = tasks.parallelStream()
.filter(task -> task.getStatus() == TaskStatus.CREATED
|| task.getStatus() == TaskStatus.PAUSED
|| task.getStatus() == TaskStatus.UNKNOWN).collect(Collectors.toList());
swTaskScheduler.stopSchedule(tobeCanceledTasks.parallelStream().map(Task::getId).collect(
swTaskScheduler.stopSchedule(tobeCanceledTasks.parallelStream().filter(task -> task.getStatus() == TaskStatus.CREATED).map(Task::getId).collect(
Collectors.toList()));
updateTaskStatus(tobeCanceledTasks, TaskStatus.CANCELED);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ public AgentVO convert(AgentEntity agentEntity) throws ConvertException {
.id(idConvertor.convert(agentEntity.getId()))
.ip(agentEntity.getAgentIp())
.connectedTime(localDateTimeConvertor.convert(agentEntity.getCreatedTime()))
.status(agentEntity.getStatus())
.version(agentEntity.getAgentVersion())
.serialNumber(agentEntity.getSerialNumber())
.build();
}

Expand All @@ -54,6 +56,7 @@ public AgentEntity revert(AgentVO agentVO) throws ConvertException {
return AgentEntity.builder()
.id(idConvertor.revert(agentVO.getId()))
.agentIp(agentVO.getIp())
.serialNumber(agentVO.getSerialNumber())
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package ai.starwhale.mlops.domain.system;

import ai.starwhale.mlops.common.BaseEntity;
import ai.starwhale.mlops.domain.system.agent.AgentStatus;
import java.time.LocalDateTime;
import lombok.AllArgsConstructor;
import lombok.Builder;
Expand All @@ -43,4 +44,6 @@ public class AgentEntity extends BaseEntity {

private String deviceInfo;

private AgentStatus status;

}
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,9 @@ public PageInfo<AgentVO> listAgents(String ipPrefix, PageParams pageParams) {
public String controllerVersion(){
return controllerVersion;
}

public String deleteOfflineAgent(String agentSerialNumber) {
agentCache.removeOfflineAgent(agentSerialNumber);
return "";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ public class Agent {

Long connectTime;

AgentStatus status;

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down Expand Up @@ -109,6 +111,11 @@ public Long getConnectTime() {
return agent.getConnectTime();
}

@Override
public AgentStatus getStatus(){
return agent.getStatus();
}

@Override
public void setId(Long id) {
throw new UnsupportedOperationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import ai.starwhale.mlops.domain.system.AgentEntity;
import ai.starwhale.mlops.domain.system.agent.Agent.AgentUnModifiable;
import ai.starwhale.mlops.domain.system.mapper.AgentMapper;
import ai.starwhale.mlops.exception.SWValidationException;
import ai.starwhale.mlops.exception.SWValidationException.ValidSubject;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -43,10 +45,13 @@ public class AgentCache implements CommandLineRunner {

final AgentConverter agentConverter;

public AgentCache(AgentMapper agentMapper,
AgentConverter agentConverter) {
final List<AgentStatusWatcher> agentStatusWatchers;

public AgentCache(AgentMapper agentMapper, AgentConverter agentConverter,
List<AgentStatusWatcher> agentStatusWatchers) {
this.agentMapper = agentMapper;
this.agentConverter = agentConverter;
this.agentStatusWatchers = agentStatusWatchers;
agents = new ConcurrentHashMap<>();
}

Expand All @@ -55,6 +60,18 @@ public List<Agent> agents(){
Collectors.toList());
}

public void removeOfflineAgent(String agentSerialNumber){
Agent tobeDeleteAgent = agents.get(agentSerialNumber);
if(null == tobeDeleteAgent){
return;
}
if(tobeDeleteAgent.getStatus() != AgentStatus.OFFLINE){
throw new SWValidationException(ValidSubject.NODE).tip("you can't remove online agent manually!");
}
agentMapper.deleteById(tobeDeleteAgent.getId());
agents.remove(agentSerialNumber);
}

public Agent nodeReport(Node node){
log.debug("node reported {}",node.getSerialNumber());
Agent agentReported = agentConverter.fromNode(node);
Expand All @@ -65,6 +82,7 @@ public Agent nodeReport(Node node){
return new AgentUnModifiable(agentReported);
}else {
residentAgent.setAgentVersion(agentReported.getAgentVersion());
residentAgent.setStatus(AgentStatus.ONLINE);
residentAgent.setNodeInfo(agentReported.getNodeInfo());
residentAgent.setConnectTime(agentReported.getConnectTime());
return new AgentUnModifiable(residentAgent);
Expand All @@ -73,7 +91,15 @@ public Agent nodeReport(Node node){

@Scheduled(initialDelay = 10000,fixedDelay = 30000)
public void flushDb(){
long now = System.currentTimeMillis();
int bareTimeMilli = 30000;
List<AgentEntity> agentEntities = agents.values().stream()
.peek(agent -> {
if( (now - agent.getConnectTime())> bareTimeMilli && AgentStatus.ONLINE == agent.getStatus()){
agent.setStatus(AgentStatus.OFFLINE);
agentStatusWatchers.parallelStream().forEach(watcher->watcher.agentStatusChange(agent,AgentStatus.OFFLINE));
}
})
.map(agent -> agentConverter.toEntity(agent))
.collect(Collectors.toList());
if(null == agentEntities || agentEntities.isEmpty()){
Expand All @@ -99,7 +125,7 @@ public void run(String... args) throws Exception {
private void initCache() {
List<AgentEntity> agentEntities = agentMapper.listAgents();
agentEntities.parallelStream().forEach(entity -> {
agents.put(entity.getAgentIp(),agentConverter.fromEntity(entity));
agents.put(entity.getSerialNumber(),agentConverter.fromEntity(entity));
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public Agent fromNode(Node node){
.serialNumber(node.getSerialNumber())
.nodeInfo(new NodeInfo(node.getMemorySizeGB(),node.getDevices()))
.agentVersion(node.getAgentVersion())
.status(AgentStatus.ONLINE)
.connectTime(Instant.now().toEpochMilli())
.build();
}
Expand All @@ -63,6 +64,7 @@ public Agent fromEntity(AgentEntity entity){
.ip(entity.getAgentIp())
.serialNumber(entity.getSerialNumber())
.agentVersion(entity.getAgentVersion())
.status(entity.getStatus())
.nodeInfo(nodeInfo)
.connectTime(entity.getConnectTime().atZone(ZoneId.systemDefault()).toInstant().toEpochMilli())
.build();
Expand All @@ -80,6 +82,7 @@ public AgentEntity toEntity(Agent agent){
.serialNumber(agent.getSerialNumber())
.agentIp(agent.getIp())
.agentVersion(agent.getAgentVersion())
.status(agent.getStatus())
.connectTime(Instant.ofEpochMilli(agent.getConnectTime()).atZone(ZoneId.systemDefault()).toLocalDateTime())
.deviceInfo(deviceInfo)
.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.starwhale.mlops.domain.system.agent;

public enum AgentStatus {
ONLINE,OFFLINE
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.starwhale.mlops.domain.system.agent;

public interface AgentStatusWatcher {
void agentStatusChange(Agent agent,AgentStatus newStatus);
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public interface AgentMapper {

Long addAgent(@Param("agent")AgentEntity agent);

void deleteById(@Param("agentId")Long agentId);

void updateAgents(@Param("agents")List<AgentEntity> agents);

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package ai.starwhale.mlops.domain.task;

import ai.starwhale.mlops.domain.job.status.JobStatusMachine;
import ai.starwhale.mlops.domain.task.bo.Task.StatusUnModifiableTask;
import ai.starwhale.mlops.domain.task.status.TaskStatus;
import static ai.starwhale.mlops.domain.task.status.TaskStatus.*;

Expand Down Expand Up @@ -126,7 +127,13 @@ public LivingTaskCacheImpl(TaskMapper taskMapper, JobMapper jobMapper,

@Override
public void adopt(Collection<Task> tasks, final TaskStatus status) {
tasks.parallelStream().forEach(task -> {
tasks.parallelStream().map(task -> {
if(task instanceof StatusUnModifiableTask){
StatusUnModifiableTask statusUnModifiableTask = (StatusUnModifiableTask) task;
return statusUnModifiableTask.getOTask();
}
return task;
}).forEach(task -> {
updateCache(status, task);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,13 @@ public static class StatusUnModifiableTask extends Task{
public StatusUnModifiableTask(Task task){
this.oTask = task;
}

public Task getOTask(){
if(oTask instanceof StatusUnModifiableTask){
StatusUnModifiableTask statusUnModifiableTask = (StatusUnModifiableTask) oTask;
return statusUnModifiableTask.getOTask();
}
return oTask;
}
@Override
public Long getId() {
return oTask.id;
Expand Down
Loading

0 comments on commit 3c41091

Please sign in to comment.