Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streamlit frontend to query models #257

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,11 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
g++ \
&& rm -rf /var/lib/apt/lists/*

# Install gui and client
# Copies the `clients` directory into the image, then runs the `install`
# target of clients/gui/Makefile.
COPY clients clients

RUN cd clients/gui && \
make install

# Final image
FROM base
Expand Down
3 changes: 3 additions & 0 deletions clients/gui/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# `install` names an action rather than a file, so declare it phony; otherwise
# a file or directory called `install` would make the target look up to date.
.PHONY: install

# Upgrade pip, then install the packages listed in ./requirements.txt.
install:
	pip install pip --upgrade
	pip install -r ./requirements.txt
Empty file.
123 changes: 123 additions & 0 deletions clients/gui/lorax-gui/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import streamlit as st
import lorax
from lorax.types import Parameters, MERGE_STRATEGIES, ADAPTER_SOURCES, MAJORITY_SIGN_METHODS, MergedAdapters

# Port interpolated into HOST below.
# NOTE(review): presumably the port the LoRAX server listens on; confirm it matches the launcher's port setting.
LORAX_PORT = 8080
# Base URL handed to `lorax.Client` in `main()`; always points at localhost.
HOST = f"http://localhost:{LORAX_PORT}"

def add_merge_adapter(i):
    """Draw the ID and weight inputs for the merge adapter at index ``i``.

    Labels shown to the user are 1-based (``i + 1``) while the widget keys use
    the 0-based ``i``. A divider is rendered after the two inputs.

    Args:
        i: Zero-based index of the merge adapter.

    Returns:
        A ``(adapter_id, adapter_weight)`` tuple holding the values returned by
        the text and number inputs (both widgets are created with ``value=None``).
    """
    position = i + 1
    adapter_id = st.text_input(
        f"Merge Adapter ID {position}",
        key=f"merge_adapter_id_{i}",
        value=None,
        placeholder="Merge Adapter Id",
    )
    adapter_weight = st.number_input(
        f"Merge Adapter Weight {position}",
        key=f"merge_adapter_weight_{i}",
        value=None,
        placeholder="Merge Adapter weight",
    )
    st.divider()
    return adapter_id, adapter_weight

def render_parameters(params: Parameters):
    """Render the parameter widgets and store the chosen values on ``params``.

    Two expanders are drawn: "Request parameters" (generation settings) and
    "Adapter Configuration" (adapter id, source, API token and, optionally, a
    set of merged adapters). ``params`` is mutated in place and its current
    state is echoed with ``st.write`` at the end of each expander.

    Args:
        params: The ``Parameters`` object to populate.

    Returns:
        None.
    """
    with st.expander("Request parameters"):
        max_new_tokens = st.slider("Max new tokens", 0, 256, 20)
        repetition_penalty = st.number_input(
            "Repetition penalty",
            help="The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.",
            min_value=0.0,
            value=None,
            placeholder="No repition penalty by default",
        )
        return_full_text = st.checkbox("Return full text", help="Whether to return the full text or just the generated part")
        stop_tokens = st.text_input("Stop tokens", help="A comma seperated list of tokens where generation is stopped if the model encounters any of them")
        # An empty text box yields no stop sequences rather than [""].
        stop_sequences = stop_tokens.split(",") if stop_tokens else []
        seed = st.number_input("Seed", help="Random seed for generation", min_value=0, value=None, placeholder="Random seed for generation")
        temperature = st.number_input("Temperature", help="The value used to module the next token probabilities", min_value=0.0, value=None, placeholder="Temperature for generation")
        # step=1 makes this an integer widget; with value=None and no integer
        # argument Streamlit would treat the count as a float.
        best_of = st.number_input(
            "Best of",
            help="The number of independently computed samples to generate and then pick the best from",
            value=None,
            step=1,
            placeholder="Best of for generation",
        )
        best_of_int = int(best_of) if best_of else None
        watermark = st.checkbox("Watermark", help="Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)")
        decoder_input_details = st.checkbox("Decoder input details", help="Whether to return the decoder input details")

        do_sample_val = st.checkbox("Do Sample", help="Whether to use sampling or greedy decoding for text generation")
        # The sampling knobs are only shown (and only set) when sampling is on.
        top_k_int, top_p, typical_p = None, None, None
        if do_sample_val:
            # step=1 makes the widget integer-typed so it agrees with
            # format="%d"; a float-typed input with an integer format makes
            # Streamlit show a type/format mismatch warning.
            top_k = st.number_input(
                "Top K",
                help="The number of highest probability vocabulary tokens to keep for top-k-filtering",
                value=None,
                step=1,
                placeholder="Top K for generation",
                format="%d",
            )
            top_k_int = int(top_k) if top_k else None
            top_p = st.number_input("Top P", help="The cumulative probability of parameter for top-p-filtering", value=None, placeholder="Top P for generation")
            typical_p = st.number_input(
                "Typical P",
                help="The typical decoding mass. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information",
                value=None,
                placeholder="Typical P for generation",
            )

        params.max_new_tokens = max_new_tokens
        params.repetition_penalty = repetition_penalty
        params.return_full_text = return_full_text
        params.stop = stop_sequences
        params.seed = seed
        params.temperature = temperature
        params.best_of = best_of_int
        params.watermark = watermark
        params.decoder_input_details = decoder_input_details
        params.do_sample = do_sample_val
        params.top_k = top_k_int
        params.top_p = top_p
        params.typical_p = typical_p
        # Echo the current parameter state (presumably a debugging aid).
        st.write(params)

    with st.expander("Adapter Configuration"):
        adapter_id = st.text_input("Adapter ID", value=None, placeholder="Adapter id")
        adapter_source = st.selectbox("Adapter Source", options=ADAPTER_SOURCES, index=None)
        api_token = st.text_input("API Token", value=None, placeholder="API token")
        if st.checkbox("Merged Adapters"):
            num_adapters = st.slider("Number of Merge Adapters", value=1, min_value=1, max_value=10)
            merge_strategy = st.selectbox("Merge Strategy", options=MERGE_STRATEGIES, index=None)
            majority_sign_method = st.selectbox("Majority Sign Method", options=MAJORITY_SIGN_METHODS, index=None)
            density = st.number_input("Density", value=0.0, placeholder="Density")
            st.divider()
            merge_adapters_list = [add_merge_adapter(i) for i in range(num_adapters)]
            merge_adapter_ids, merge_adapter_weights = zip(*merge_adapters_list)
            # The inputs start out empty (None). Building MergedAdapters from
            # incomplete values raises a validation error that would abort the
            # rest of the page, so wait until every id and weight is filled in.
            # A weight of 0.0 is a legitimate value, hence the explicit None test.
            ids_complete = all(merge_adapter_ids)
            weights_complete = all(weight is not None for weight in merge_adapter_weights)
            if ids_complete and weights_complete:
                params.merged_adapters = MergedAdapters(
                    ids=list(merge_adapter_ids),
                    weights=list(merge_adapter_weights),
                    density=density,
                    merge_strategy=merge_strategy,
                    majority_sign_method=majority_sign_method,
                )
            else:
                st.info("Enter an ID and a weight for every merge adapter to use merged adapters.")
        params.adapter_id = adapter_id
        params.adapter_source = adapter_source
        params.api_token = api_token
        # Echo the current parameter state (presumably a debugging aid).
        st.write(params)


def main():
    """Build the Lorax GUI page and send the prompt when "Generate" is pressed.

    The page hides Streamlit's deploy button, renders the parameter widgets via
    ``render_parameters``, and shows a prompt text area. On "Generate" the
    prompt and the collected parameters are passed to ``client.generate``; the
    generated text is written to the page and the response details are shown
    in an expander.
    """
    st.markdown(
        r"""
        <style>
        .stDeployButton {
            visibility: hidden;
        }
        </style>
        """, unsafe_allow_html=True
    )

    st.title("Lorax GUI")
    params = Parameters()
    render_parameters(params)

    prompt = st.text_area("Enter prompt", "Type Here ...")
    client = lorax.Client(HOST)

    # Adapt the Parameters fields to the keyword arguments used for generate():
    # `stop` is passed as `stop_sequences`, and `details` is not forwarded.
    request_kwargs = params.dict()
    request_kwargs['stop_sequences'] = request_kwargs.pop('stop')
    del request_kwargs['details']

    if not st.button("Generate"):
        return

    response = client.generate(prompt=prompt, **request_kwargs)
    st.write(response.generated_text)
    with st.expander("Response details"):
        st.write(response.details.dict())


# Build the page only when this file is executed as the main script.
if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions clients/gui/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
streamlit
lorax-client
33 changes: 33 additions & 0 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,15 @@ struct Args {
/// Download model weights only
#[clap(long, env)]
download_only: bool,

/// Launch testing GUI
/// This will launch a GUI that will allow you to test the model
#[clap(long, env)]
launch_gui: bool,

/// Gui port
#[clap(long, env)]
gui_port: Option<u16>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -958,6 +967,30 @@ fn spawn_shards(
Ok(())
}

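// TODO(review): disabled draft of the GUI launcher. As written it spawns `lorax-router`
// (its log messages still refer to the webserver) and never uses `port`, so
// `--launch-gui` / `--gui-port` currently have no visible effect; finish or remove before merging.
// fn spawn_gui(port: u16) -> Result<Child, LauncherError> {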
// tracing::info!("Starting GUI");

// let mut webserver = match Command::new("lorax-router")
// .stdout(Stdio::piped())
// .stderr(Stdio::piped())
// .process_group(0)
// .spawn()
// {
// Ok(p) => p,
// Err(err) => {
// tracing::error!("Failed to start webserver: {}", err);
// if err.kind() == io::ErrorKind::NotFound {
// tracing::error!("lorax-router not found in PATH");
// tracing::error!("Please install it with `make install-router`")
// } else {
// tracing::error!("{}", err);
// }
// return Err(LauncherError::WebserverCannotStart);
// }
// };
// Ok(webserver)
// }

fn spawn_webserver(
args: Args,
shutdown: Arc<AtomicBool>,
Expand Down