In [None]:
:dep tapis-authenticator = { git = "https://github.com/tapis-project/tapis-rust-sdk", package = "tapis-authenticator" }
:dep tapis-pods = { git = "https://github.com/tapis-project/tapis-rust-sdk", package = "tapis-pods" }
:dep serde_json = "1.0.149"
:dep serde = { version = "1.0.188", features = ["derive"] }
:dep sha2 = "0.10.9"
:dep base62 = "2.2.3"

In [None]:
// :dep tapis-pods = { git = "https://github.com/tapis-project/tapis-rust-sdk", package = "tapis-pods" }
// :dep serde_json = "1.0.149"
// :dep serde = { version = "1.0.188", features = ["derive"] }

use std::collections::HashMap;
use tapis_pods;
use serde_json::Value;
use serde::{Serialize, Deserialize};
use sha2::{Sha256, Digest};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};


/// Inference backends that a deployment can target.
///
/// Every variant carries a `command` vector (presumably the argv used to
/// launch that backend — confirm with the consumer). Under serde the variant
/// names are lowercased: `transformers`, `vllm`, `sglang`, `trtllm`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Backend {
    /// Serialized as `transformers`.
    Transformers { command: Vec<String> },
    /// Serialized as `vllm`.
    VLlm { command: Vec<String> },
    /// Serialized as `sglang`.
    SGLang { command: Vec<String> },
    /// Serialized as `trtllm`.
    TrtLlm { command: Vec<String> },
}

impl Backend {
    /// Returns the lowercase identifier of this backend.
    ///
    /// The result is always one of the literals `"transformers"`, `"vllm"`,
    /// `"sglang"` or `"trtllm"`, hence the `'static` lifetime.
    pub fn as_str(&self) -> &'static str {
        match self {
            Backend::Transformers { .. } => "transformers",
            Backend::VLlm { .. } => "vllm",
            Backend::SGLang { .. } => "sglang",
            Backend::TrtLlm { .. } => "trtllm",
        }
    }

    /// Creates a Transformers parameter builder seeded with a clone of this
    /// variant's `command`.
    ///
    /// # Panics
    /// Panics with `"Backend is not Transformers"` when `self` is any other
    /// variant.
    pub fn transformers(&self) -> TransformersParametersBuilder {
        match self {
            Backend::Transformers { command } => {
                TransformersParametersBuilder::new(command.clone())
            }
            _ => panic!("Backend is not Transformers"),
        }
    }

    /// Creates a vLLM parameter builder seeded with a clone of this variant's
    /// `command`.
    ///
    /// The `_command` argument is ignored: the command stored in the variant
    /// is the one used. It is kept only so existing call sites keep compiling.
    ///
    /// # Panics
    /// Panics with `"Backend is not vLLM"` when `self` is any other variant.
    pub fn vllm(&self, _command: Vec<String>) -> VLlmParametersBuilder {
        match self {
            Backend::VLlm { command } => VLlmParametersBuilder::new(command.clone()),
            _ => panic!("Backend is not vLLM"),
        }
    }

    /// Creates an SGLang parameter builder seeded with a clone of this
    /// variant's `command`.
    ///
    /// The `_command` argument is ignored: the command stored in the variant
    /// is the one used. It is kept only so existing call sites keep compiling.
    ///
    /// # Panics
    /// Panics with `"Backend is not SGLang"` when `self` is any other variant.
    pub fn sglang(&self, _command: Vec<String>) -> SGLangParametersBuilder {
        match self {
            Backend::SGLang { command } => SGLangParametersBuilder::new(command.clone()),
            _ => panic!("Backend is not SGLang"),
        }
    }

    /// Creates a TRT-LLM parameter builder seeded with a clone of this
    /// variant's `command`.
    ///
    /// The `_command` argument is ignored: the command stored in the variant
    /// is the one used. It is kept only so existing call sites keep compiling.
    ///
    /// # Panics
    /// Panics with `"Backend is not TRT-LLM"` when `self` is any other variant.
    pub fn trtllm(&self, _command: Vec<String>) -> TrtLlmParametersBuilder {
        match self {
            Backend::TrtLlm { command } => TrtLlmParametersBuilder::new(command.clone()),
            _ => panic!("Backend is not TRT-LLM"),
        }
    }
}

/// Parameters collected for one backend: a command plus two key/value maps.
///
/// `params` stores arbitrary JSON values, so each backend can keep its own
/// set of options; `env` stores plain string pairs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendParameters {
    /// Command vector supplied at construction time.
    pub command: Vec<String>,
    /// Backend options keyed by name, held as JSON values.
    pub params: HashMap<String, Value>,
    /// String pairs (presumably environment variables — confirm with the consumer).
    pub env: HashMap<String, String>,
}

impl BackendParameters {
    /// Builds a parameter set for `command` with empty `params` and `env`.
    pub fn new(command: Vec<String>) -> Self {
        let params = HashMap::new();
        let env = HashMap::new();
        Self { command, params, env }
    }

    /// Stores `value` under `key` in `params`, replacing any earlier entry,
    /// and returns `self` for chaining.
    ///
    /// # Panics
    /// Panics if `value` cannot be converted into a `serde_json::Value`.
    pub fn insert_param<T: Serialize>(&mut self, key: impl Into<String>, value: T) -> &mut Self {
        let name = key.into();
        let json = serde_json::to_value(value).unwrap();
        self.params.insert(name, json);
        self
    }

    /// Stores `value` under `key` in `env`, replacing any earlier entry, and
    /// returns `self` for chaining.
    pub fn insert_env(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
        let name = key.into();
        let text = value.into();
        self.env.insert(name, text);
        self
    }

    /// Looks up `key` in `params`; `env` is not consulted.
    pub fn get(&self, key: &str) -> Option<&Value> {
        let found = self.params.get(key);
        found
    }
}



/// Builder for transformers parameters
pub struct TransformersParametersBuilder {
    params: BackendParameters,
}

impl TransformersParametersBuilder {
    pub fn new(command: Vec<String>) -> Self {
        Self {
            params: BackendParameters::new(command),
        }
    }

    pub fn default_model(mut self, model: &str) -> Self {
        self.params.insert_param("default-model", model);
        self
    }
    pub fn default_embedding_model(mut self, model: &str) -> Self {
        self.params.insert_param("default-embedding-model", model);
        self
    }
    pub fn host(mut self, host: &str) -> Self {
        self.params.insert_param("host", host);
        self
    }
    pub fn port(mut self, port: u16) -> Self {
        self.params.insert_param("port", port);
        self
    }
    pub fn device(mut self, device: &str) -> Self {
        self.params.insert_param("device", device);
        self
    }
    pub fn dtype(mut self, dtype: &str) -> Self {
        self.params.insert_param("dtype", dtype);
        self
    }
    pub fn continuous_batching(mut self, enabled: bool) -> Self {
        self.params.insert_param("continuous-batching", enabled);
        self
    }
    pub fn flexserv_token(mut self, token: &str) -> Self {
        self.params.insert_param("flexserv-token", token);
        self
    }
    pub fn force_default_model(mut self, force: bool) -> Self {
        self.params.insert_param("force-default-model", force);
        self
    }
    pub fn force_default_embedding_model(mut self, force: bool) -> Self {
        self.params.insert_param("force-default-embedding-model", force);
        self
    }
    pub fn log_level(mut self, level: &str) -> Self {
        self.params.insert_param("log-level", level);
        self
    }
    pub fn quantization(mut self, quant: &str) -> Self {
        self.params.insert_param("quantization", quant);
        self
    }
    pub fn trust_remote_code(mut self, trust: bool) -> Self {
        self.params.insert_param("trust-remote-code", trust);
        self
    }
    pub fn attn_implementation(mut self, implementation: &str) -> Self {
        self.params.insert_param("attn-implementation", implementation);
        self
    }
    pub fn enable_cors(mut self, enable: bool) -> Self {
        self.params.insert_param("enable-cors", enable);
        self
    }
    pub fn non_blocking(mut self, non_blocking: bool) -> Self {
        self.params.insert_param("non-blocking", non_blocking);
        self
    }
    pub fn build(self) -> BackendParameters {
        self.params
    }
}

/// Chainable builder that fills a `BackendParameters` with vLLM options.
///
/// Every setter consumes the builder, writes one key into `params`, and hands
/// the builder back; `build` yields the finished `BackendParameters`.
pub struct VLlmParametersBuilder {
    params: BackendParameters,
}

impl VLlmParametersBuilder {
    /// Starts a builder for `command` with no options set.
    ///
    /// Takes the command because `BackendParameters::new` requires one and
    /// `Backend::vllm` passes it; this mirrors
    /// `TransformersParametersBuilder::new`.
    pub fn new(command: Vec<String>) -> Self {
        Self {
            params: BackendParameters::new(command),
        }
    }

    /// Sets `tensor_parallel_size`.
    pub fn tensor_parallel_size(mut self, size: u32) -> Self {
        self.params.insert_param("tensor_parallel_size", size);
        self
    }

    /// Sets `pipeline_parallel_size`.
    pub fn pipeline_parallel_size(mut self, size: u32) -> Self {
        self.params.insert_param("pipeline_parallel_size", size);
        self
    }

    /// Sets `max_model_len`.
    pub fn max_model_len(mut self, len: u32) -> Self {
        self.params.insert_param("max_model_len", len);
        self
    }

    /// Sets `gpu_memory_utilization`.
    pub fn gpu_memory_utilization(mut self, util: f32) -> Self {
        self.params.insert_param("gpu_memory_utilization", util);
        self
    }

    /// Finishes the builder and returns the collected parameters.
    pub fn build(self) -> BackendParameters {
        self.params
    }
}

/// Chainable builder that fills a `BackendParameters` with SGLang options.
///
/// Every setter consumes the builder, writes one key into `params`, and hands
/// the builder back; `build` yields the finished `BackendParameters`.
pub struct SGLangParametersBuilder {
    params: BackendParameters,
}

impl SGLangParametersBuilder {
    /// Starts a builder for `command` with no options set.
    ///
    /// Takes the command because `BackendParameters::new` requires one and
    /// `Backend::sglang` passes it; this mirrors
    /// `TransformersParametersBuilder::new`.
    pub fn new(command: Vec<String>) -> Self {
        Self {
            params: BackendParameters::new(command),
        }
    }

    /// Sets `tp_size`.
    pub fn tp_size(mut self, size: u32) -> Self {
        self.params.insert_param("tp_size", size);
        self
    }

    /// Sets `mem_fraction_static`.
    pub fn mem_fraction_static(mut self, fraction: f32) -> Self {
        self.params.insert_param("mem_fraction_static", fraction);
        self
    }

    /// Finishes the builder and returns the collected parameters.
    pub fn build(self) -> BackendParameters {
        self.params
    }
}

/// Chainable builder that fills a `BackendParameters` with TRT-LLM options.
///
/// Every setter consumes the builder, writes one key into `params`, and hands
/// the builder back; `build` yields the finished `BackendParameters`.
pub struct TrtLlmParametersBuilder {
    params: BackendParameters,
}

impl TrtLlmParametersBuilder {
    /// Starts a builder for `command` with no options set.
    ///
    /// Takes the command because `BackendParameters::new` requires one and
    /// `Backend::trtllm` passes it; this mirrors
    /// `TransformersParametersBuilder::new`.
    pub fn new(command: Vec<String>) -> Self {
        Self {
            params: BackendParameters::new(command),
        }
    }

    /// Sets `max_batch_size`.
    pub fn max_batch_size(mut self, size: u32) -> Self {
        self.params.insert_param("max_batch_size", size);
        self
    }

    /// Sets `max_input_len`.
    pub fn max_input_len(mut self, len: u32) -> Self {
        self.params.insert_param("max_input_len", len);
        self
    }

    /// Finishes the builder and returns the collected parameters.
    pub fn build(self) -> BackendParameters {
        self.params
    }
}


pub struct FlexServ {
    /// tenant url
    pub tenant_url: String,

    /// tapis username
    pub tapis_user: String,

    /// model to deploy (e.g., "meta-llama/Llama-3-70b-hf")
    pub default_model: String,

    /// default embedding model
    pub default_embedding_model: Option<String>,

    /// backend to use
    pub backend: Backend,
    
}
impl FlexServ {
    pub fn new(tenant_url: String, tapis_user: String, default_model: String, default_embedding_model: Option<String>, backend: Backend) -> Self {
        FlexServ { tenant_url, tapis_user, default_model, default_embedding_model, backend }
    }

    pub fn deployment_hash(&self) -> String {
        // Create a unique hash for the deployment configuration
        let config_string = format!("{}@{}-{}-{:?}", self.tapis_user, self.tenant_url, self.default_model, self.backend);
        let digest = Sha256::digest(config_string.as_bytes());
        let hash = 
        format!("{:x}", sha256::compute(config_string))
    }
}

/// State for a FlexServ deployment backed by Tapis Pods.
///
/// Holds request models (`NewVolume`, `NewPod`), response models
/// (`VolumeResponseModel`, `PodResponseModel`) and the `FlexServ` description.
/// NOTE(review): presumably the `new_*` fields are what gets sent to the Pods
/// API and the `*_info` fields are what comes back — confirm once the
/// `FlexServDeployment` impl is filled in.
pub struct FlexServPodDeployment {
    // Volume request model.
    new_volume: tapis_pods::models::NewVolume,
    // Pod request model.
    new_pod: tapis_pods::models::NewPod,
    // Volume response model.
    volume_info: tapis_pods::models::VolumeResponseModel,
    // Pod response model.
    pod_info: tapis_pods::models::PodResponseModel,
    // Server description this deployment belongs to.
    server: FlexServ,
}

/// State for a FlexServ deployment on HPC; currently it only wraps the
/// `FlexServ` description.
/// NOTE(review): no `FlexServDeployment` impl for this type is visible here.
pub struct FlexServHPCDeployment {
    // Server description this deployment belongs to.
    server: FlexServ,
}
/// Outcome of a successful deployment operation, one variant per deployment
/// kind.
///
/// Both variants carry `tapis_user`, `tapis_tenant` and `model_sha`
/// (presumably the value from `FlexServ::deployment_hash` — confirm).
#[derive(Debug)]
pub enum DeploymentResult {
    /// Result of a pod-based deployment.
    PodResult {
        /// Pod response model.
        pod_info: tapis_pods::models::PodResponseModel,
        /// Volume response model.
        volume_info: tapis_pods::models::VolumeResponseModel,
        tapis_user: String,
        tapis_tenant: String,
        model_sha: String,
    },
    /// Result of an HPC deployment.
    HPCResult {
        /// Job information as a string.
        job_info: String,
        tapis_user: String,
        tapis_tenant: String,
        model_sha: String,
    },
}

/// Deployment-related errors.
///
/// Every variant carries a `String` with more detailed error information.
/// The type implements `Display` and `std::error::Error`, so it can be
/// boxed as `Box<dyn std::error::Error>` and propagated with `?`.
#[derive(Debug)]
pub enum DeploymentError {
    /// Authentication to Tapis failed.
    TapisAuthFailed(String),
    /// Tapis API is unreachable.
    TapisAPIUnreachable(String),
    /// Bad request to Tapis API.
    TapisBadRequest(String),
    /// Request to Tapis API timed out.
    TapisTimeout(String),
    /// Tapis API internal server error.
    TapisInternalServerError(String),
    /// Unknown error. (The variant name is misspelt; it is kept as is so
    /// existing code that names it keeps compiling.)
    UnkownError(String),
    /// Model uploading failed for a reason not covered by the Tapis or
    /// unknown-error variants.
    ModelUploadingFailed(String),
    /// Pod creation failed for a reason not covered by the Tapis or
    /// unknown-error variants.
    PodCreationFailed(String),
    /// Job creation failed for a reason not covered by the Tapis or
    /// unknown-error variants.
    JobCreationFailed(String),
}

impl std::fmt::Display for DeploymentError {
    /// Writes `"<summary>: <detail>"`, where the summary depends on the
    /// variant and the detail is the carried message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (summary, detail) = match self {
            DeploymentError::TapisAuthFailed(msg) => ("authentication to Tapis failed", msg),
            DeploymentError::TapisAPIUnreachable(msg) => ("Tapis API is unreachable", msg),
            DeploymentError::TapisBadRequest(msg) => ("bad request to Tapis API", msg),
            DeploymentError::TapisTimeout(msg) => ("request to Tapis API timed out", msg),
            DeploymentError::TapisInternalServerError(msg) => {
                ("Tapis API internal server error", msg)
            }
            DeploymentError::UnkownError(msg) => ("unknown error", msg),
            DeploymentError::ModelUploadingFailed(msg) => ("model uploading failed", msg),
            DeploymentError::PodCreationFailed(msg) => ("pod creation failed", msg),
            DeploymentError::JobCreationFailed(msg) => ("job creation failed", msg),
        };
        write!(f, "{}: {}", summary, detail)
    }
}

impl std::error::Error for DeploymentError {}

/// Lifecycle operations shared by FlexServ deployment kinds.
///
/// Every operation returns a `DeploymentResult` on success and a
/// `DeploymentError` on failure. `create` is the only one that takes
/// `&mut self`; the others take `&self`.
pub trait FlexServDeployment {
    /// Creates the deployment.
    fn create(&mut self) -> Result<DeploymentResult, DeploymentError>;

    /// Starts the deployment.
    fn start(&self) -> Result<DeploymentResult, DeploymentError>;

    /// Stops the deployment.
    fn stop(&self) -> Result<DeploymentResult, DeploymentError>;

    /// Terminates the deployment.
    fn terminate(&self) -> Result<DeploymentResult, DeploymentError>;

    /// Reports on the deployment.
    fn monitor(&self) -> Result<DeploymentResult, DeploymentError>;
}

/// Pod-based implementation of the deployment lifecycle.
///
/// Every method is still a stub: each body is `todo!()`, which panics when
/// called. The comments inside record the intended work.
impl FlexServDeployment for FlexServPodDeployment {
    /// Not implemented yet; panics via `todo!()`.
    fn create(&mut self) -> Result<DeploymentResult, DeploymentError> {
        // Intended: create the volume and the pod.
        todo!()
    }

    /// Not implemented yet; panics via `todo!()`.
    fn start(&self) -> Result<DeploymentResult, DeploymentError> {
        // Intended: start the pod.
        todo!()
    }

    /// Not implemented yet; panics via `todo!()`.
    fn stop(&self) -> Result<DeploymentResult, DeploymentError> {
        // Intended: stop the pod.
        todo!()
    }

    /// Not implemented yet; panics via `todo!()`.
    fn terminate(&self) -> Result<DeploymentResult, DeploymentError> {
        // Intended: terminate the pod and delete the volume.
        todo!()
    }

    /// Not implemented yet; panics via `todo!()`.
    fn monitor(&self) -> Result<DeploymentResult, DeploymentError> {
        // Intended: monitor the pod status.
        todo!()
    }
}