diff --git a/Dockerfile b/Dockerfile index 9ae3dae1..756eaea2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -230,6 +230,11 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins g++ \ && rm -rf /var/lib/apt/lists/* +# Install gui and client +COPY clients clients + +RUN cd clients/gui && \ + make install # Final image FROM base diff --git a/clients/gui/Makefile b/clients/gui/Makefile new file mode 100644 index 00000000..55b7aed7 --- /dev/null +++ b/clients/gui/Makefile @@ -0,0 +1,3 @@ +install: + pip install pip --upgrade + pip install -r ./requirements.txt diff --git a/clients/gui/lorax-gui/__init__.py b/clients/gui/lorax-gui/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/clients/gui/lorax-gui/app.py b/clients/gui/lorax-gui/app.py new file mode 100644 index 00000000..8bbaf7b4 --- /dev/null +++ b/clients/gui/lorax-gui/app.py @@ -0,0 +1,123 @@ +import streamlit as st +import lorax +from lorax.types import Parameters, MERGE_STRATEGIES, ADAPTER_SOURCES, MAJORITY_SIGN_METHODS, MergedAdapters + +LORAX_PORT = 8080 +HOST = f"http://localhost:{LORAX_PORT}" + +def add_merge_adapter(i): + merge_adapter_id = st.text_input(f"Merge Adapter ID {i+1}", key=f"merge_adapter_id_{i}", value=None, placeholder="Merge Adapter Id") + merge_adapter_weight = st.number_input(f"Merge Adapter Weight {i+1}", key=f"merge_adapter_weight_{i}", value=None, placeholder="Merge Adapter weight") + st.divider() + return merge_adapter_id, merge_adapter_weight + +def render_parameters(params: Parameters): + with st.expander("Request parameters"): + max_new_tokens = st.slider("Max new tokens", 0, 256, 20) + repetition_penalty = st.number_input( + "Repetition penalty", + help="The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.", + min_value=0.0, + value=None, + placeholder="No repition penalty by default", + ) + return_full_text = st.checkbox("Return full text", help="Whether to return the full text or just the generated part") + stop_tokens = st.text_input("Stop tokens", help="A comma seperated list of tokens where generation is stopped if the model encounters any of them") + stop_sequences = stop_tokens.split(",") if stop_tokens else [] + seed = st.number_input("Seed", help="Random seed for generation", min_value=0, value=None, placeholder="Random seed for generation") + temperature = st.number_input("Temperature", help="The value used to module the next token probabilities", min_value=0.0, value=None, placeholder="Temperature for generation") + best_of = st.number_input("Best of", help="The number of independently computed samples to generate and then pick the best from", value=None, placeholder="Best of for generation") + best_of_int = int(best_of) if best_of else None + watermark = st.checkbox("Watermark", help="Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)") + decoder_input_details = st.checkbox("Decoder input details", help="Whether to return the decoder input details") + + do_sample_val = st.checkbox("Do Sample", help="Whether to use sampling or greedy decoding for text generation") + top_k_int, top_p, typical_p = None, None, None + if do_sample_val: + top_k = st.number_input( + "Top K", + help="The number of highest probability vocabulary tokens to keep for top-k-filtering", + value=None, + placeholder="Top K for generation", + format="%d" + ) + top_k_int = int(top_k) if top_k else None + top_p = st.number_input("Top P", help="The cumulative probability of parameter for top-p-filtering", value=None, placeholder="Top P for generation") + typical_p = st.number_input( + "Typical P", + help="The typical decoding mass. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information", + value=None, + placeholder="Typical P for generation" + ) + params.max_new_tokens = max_new_tokens + params.repetition_penalty = repetition_penalty + params.return_full_text = return_full_text + params.stop = stop_sequences + params.seed = seed + params.temperature = temperature + params.best_of = best_of_int + params.watermark = watermark + params.decoder_input_details = decoder_input_details + params.do_sample = do_sample_val + params.top_k = top_k_int + params.top_p = top_p + params.typical_p = typical_p + st.write(params) + + + with st.expander("Adapter Configuration"): + adapter_id = st.text_input("Adapter ID", value=None, placeholder="Adapter id") + adapter_source = st.selectbox("Adapter Source", options=ADAPTER_SOURCES, index=None) + api_token = st.text_input("API Token", value=None, placeholder="API token") + if st.checkbox("Merged Adapters"): + num_adapters = st.slider("Number of Merge Adapters", value=1, min_value=1, max_value=10) + merge_strategy = st.selectbox("Merge Strategy", options=MERGE_STRATEGIES, index=None) + majority_sign_method = st.selectbox("Majority Sign Method", options=MAJORITY_SIGN_METHODS, index=None) + density = st.number_input("Density", value=0.0, placeholder="Density") + st.divider() + merge_adapters_list = [add_merge_adapter(i) for i in range(num_adapters)] + merge_adapter_ids, merge_adapter_weights = zip(*merge_adapters_list) + merge_adapters = MergedAdapters( + ids=merge_adapter_ids, + weights=merge_adapter_weights, + density=density, + merge_strategy=merge_strategy, + majority_sign_method=majority_sign_method, + ) + params.merged_adapters = merge_adapters + params.adapter_id = adapter_id + params.adapter_source = adapter_source + params.api_token = api_token + st.write(params) + + +def main(): + st.markdown( + r""" + + """, unsafe_allow_html=True + ) + + st.title("Lorax GUI") + params = Parameters() + render_parameters(params) + + txt = st.text_area("Enter prompt", "Type Here ...") + client = lorax.Client(HOST) + params_dict = params.dict() + params_dict['stop_sequences'] = params_dict['stop'] + params_dict.pop('stop') + params_dict.pop('details') + if st.button("Generate"): + resp = client.generate(prompt=txt, **params_dict) + st.write(resp.generated_text) + with st.expander("Response details"): + st.write(resp.details.dict()) + + +if __name__ == "__main__": + main() diff --git a/clients/gui/requirements.txt b/clients/gui/requirements.txt new file mode 100644 index 00000000..6b0df2d3 --- /dev/null +++ b/clients/gui/requirements.txt @@ -0,0 +1,2 @@ +streamlit +lorax-client \ No newline at end of file diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 0f3c673f..9a8e536c 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -357,6 +357,15 @@ struct Args { /// Download model weights only #[clap(long, env)] download_only: bool, + + /// Launch testing GUI + /// This will launch a GUI that will allow you to test the model + #[clap(long, env)] + launch_gui: bool, + + /// Gui port + #[clap(long, env)] + gui_port: Option, } #[derive(Debug)] @@ -958,6 +967,30 @@ fn spawn_shards( Ok(()) } +// fn spawn_gui(port: u16) -> Result { +// tracing::info!("Starting GUI"); + +// let mut webserver = match Command::new("lorax-router") +// .stdout(Stdio::piped()) +// .stderr(Stdio::piped()) +// .process_group(0) +// .spawn() +// { +// Ok(p) => p, +// Err(err) => { +// tracing::error!("Failed to start webserver: {}", err); +// if err.kind() == io::ErrorKind::NotFound { +// tracing::error!("lorax-router not found in PATH"); +// tracing::error!("Please install it with `make install-router`") +// } else { +// tracing::error!("{}", err); +// } +// return Err(LauncherError::WebserverCannotStart); +// } +// }; +// Ok(webserver) +// } + fn spawn_webserver( args: Args, shutdown: Arc,