Skip to content

Commit

Permalink
Don't send command used to start streamlit to frontend (streamlit#7787)
Browse files Browse the repository at this point in the history
* Display more generic rerun streamlit message when disconnected

* Stop sending command line to frontend
  • Loading branch information
vdonato authored and zyxue committed Apr 16, 2024
1 parent 8ec1c8e commit 271d573
Show file tree
Hide file tree
Showing 22 changed files with 74 additions and 91 deletions.
6 changes: 3 additions & 3 deletions frontend/app/src/App.test.tsx
Expand Up @@ -131,7 +131,7 @@ const NEW_SESSION_JSON: INewSession = {
scriptIsRunning: false,
},
sessionId: "sessionId",
commandLine: "commandLine",
isHello: false,
},
appPages: [
{ pageScriptHash: "page_script_hash", pageName: "streamlit_app" },
Expand Down Expand Up @@ -844,7 +844,7 @@ describe("App.handleNewSession", () => {
},
initialize: {
...NEW_SESSION_JSON.initialize,
commandLine: "streamlit hello",
isHello: true,
},
})
)
Expand Down Expand Up @@ -884,7 +884,7 @@ describe("App.onHistoryChange", () => {
scriptIsRunning: false,
},
sessionId: "sessionId",
commandLine: "commandLine",
isHello: false,
},
appPages: [
{ pageScriptHash: "top_hash", pageName: "streamlit_app" },
Expand Down
1 change: 1 addition & 0 deletions frontend/app/src/StreamlitLib.test.tsx
Expand Up @@ -137,6 +137,7 @@ class StreamlitLibExample extends PureComponent<Props, State> {
installationId: "",
installationIdV3: "",
commandLine: "",
isHello: false,
})

// Initialize React state
Expand Down
50 changes: 16 additions & 34 deletions frontend/app/src/connection/WebsocketConnection.test.tsx
Expand Up @@ -77,7 +77,6 @@ describe("doInitPings", () => {
maxTimeoutMs: 100,
retryCallback: jest.fn(),
setAllowedOrigins: jest.fn(),
userCommandLine: "streamlit run not-a-real-script.py",
}

let originalAxiosGet: any
Expand Down Expand Up @@ -112,8 +111,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
MOCK_PING_DATA.retryCallback,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)
expect(uriIndex).toEqual(0)
expect(MOCK_PING_DATA.setAllowedOrigins).toHaveBeenCalledWith(
Expand All @@ -131,8 +129,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
MOCK_PING_DATA.retryCallback,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)
expect(uriIndex).toEqual(0)
expect(MOCK_PING_DATA.setAllowedOrigins).toHaveBeenCalledWith(
Expand All @@ -151,8 +148,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
MOCK_PING_DATA.retryCallback,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)
expect(uriIndex).toEqual(1)
expect(MOCK_PING_DATA.setAllowedOrigins).toHaveBeenCalledWith(
Expand All @@ -174,8 +170,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
MOCK_PING_DATA.retryCallback,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)

expect(MOCK_PING_DATA.retryCallback).toHaveBeenCalledWith(
Expand All @@ -199,8 +194,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
MOCK_PING_DATA.retryCallback,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)

expect(MOCK_PING_DATA.retryCallback).toHaveBeenCalledWith(
Expand Down Expand Up @@ -228,8 +222,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
MOCK_PING_DATA.retryCallback,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)

expect(MOCK_PING_DATA.retryCallback).toHaveBeenCalledWith(
Expand All @@ -255,8 +248,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
MOCK_PING_DATA.retryCallback,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)

expect(MOCK_PING_DATA.retryCallback).toHaveBeenCalledWith(
Expand Down Expand Up @@ -288,9 +280,7 @@ describe("doInitPings", () => {
just restart it in your terminal:
</p>
<pre>
<StyledBashCode>
{MOCK_PING_DATA_LOCALHOST.userCommandLine}
</StyledBashCode>
<StyledBashCode>streamlit run yourscript.py</StyledBashCode>
</pre>
</Fragment>
)
Expand All @@ -306,8 +296,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA_LOCALHOST.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
MOCK_PING_DATA_LOCALHOST.retryCallback,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA_LOCALHOST.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)

expect(MOCK_PING_DATA_LOCALHOST.retryCallback).toHaveBeenCalledWith(
Expand Down Expand Up @@ -346,8 +335,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
MOCK_PING_DATA.retryCallback,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)

expect(MOCK_PING_DATA.retryCallback).toHaveBeenCalledWith(
Expand Down Expand Up @@ -376,8 +364,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
MOCK_PING_DATA.retryCallback,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)

expect(MOCK_PING_DATA.retryCallback).toHaveBeenCalledWith(
Expand Down Expand Up @@ -405,8 +392,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
MOCK_PING_DATA.retryCallback,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)

expect(MOCK_PING_DATA.retryCallback).toHaveBeenCalledTimes(5)
Expand Down Expand Up @@ -439,8 +425,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
callback,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)

expect(timeouts.length).toEqual(5)
Expand Down Expand Up @@ -482,8 +467,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
callback,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)

expect(timeouts.length).toEqual(5)
Expand Down Expand Up @@ -521,8 +505,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
callback,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)

const timeouts2: number[] = []
Expand All @@ -539,8 +522,7 @@ describe("doInitPings", () => {
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
callback2,
MOCK_PING_DATA.setAllowedOrigins,
MOCK_PING_DATA.userCommandLine
MOCK_PING_DATA.setAllowedOrigins
)

expect(timeouts[0]).toEqual(10)
Expand Down
17 changes: 5 additions & 12 deletions frontend/app/src/connection/WebsocketConnection.tsx
Expand Up @@ -256,11 +256,7 @@ export class WebsocketConnection {
// Perform pre-callback actions when entering certain states.
switch (this.state) {
case ConnectionState.PINGING_SERVER:
this.pingServer(
this.args.sessionInfo.isSet
? this.args.sessionInfo.current.commandLine
: undefined
)
this.pingServer()
break

default:
Expand Down Expand Up @@ -367,14 +363,13 @@ export class WebsocketConnection {
)
}

private async pingServer(userCommandLine?: string): Promise<void> {
private async pingServer(): Promise<void> {
this.uriIndex = await doInitPings(
this.args.baseUriPartsList,
PING_MINIMUM_RETRY_PERIOD_MS,
PING_MAXIMUM_RETRY_PERIOD_MS,
this.args.onRetry,
this.args.onHostConfigResp,
userCommandLine
this.args.onHostConfigResp
)

this.stepFsm("SERVER_PING_SUCCEEDED")
Expand Down Expand Up @@ -612,8 +607,7 @@ export function doInitPings(
minimumTimeoutMs: number,
maximumTimeoutMs: number,
retryCallback: OnRetry,
onHostConfigResp: (resp: IHostConfigResponse) => void,
userCommandLine?: string
onHostConfigResp: (resp: IHostConfigResponse) => void
): Promise<number> {
const resolver = new Resolver<number>()
let totalTries = 0
Expand Down Expand Up @@ -652,15 +646,14 @@ export function doInitPings(
const uri = new URL(buildHttpUri(uriParts, ""))

if (uri.hostname === "localhost") {
const commandLine = userCommandLine || "streamlit run yourscript.py"
retry(
<Fragment>
<p>
Is Streamlit still running? If you accidentally stopped Streamlit,
just restart it in your terminal:
</p>
<pre>
<StyledBashCode>{commandLine}</StyledBashCode>
<StyledBashCode>streamlit run yourscript.py</StyledBashCode>
</pre>
</Fragment>
)
Expand Down
19 changes: 8 additions & 11 deletions frontend/lib/src/SessionInfo.test.ts
Expand Up @@ -46,19 +46,15 @@ describe("SessionInfo.setCurrent", () => {
})

describe("SessionInfo.isHello", () => {
test("is true only when initialized with `streamlit hello` commandline", () => {
test("is true only when `isHello` is true in current SessionInfo", () => {
const sessionInfo = new SessionInfo()
expect(sessionInfo.isHello).toBe(false)

sessionInfo.setCurrent(
mockSessionInfoProps({ commandLine: "random command line" })
)
expect(sessionInfo.isHello).toBe(false)

sessionInfo.setCurrent(
mockSessionInfoProps({ commandLine: "streamlit hello" })
)
sessionInfo.setCurrent(mockSessionInfoProps({ isHello: true }))
expect(sessionInfo.isHello).toBe(true)

sessionInfo.setCurrent(mockSessionInfoProps({ isHello: false }))
expect(sessionInfo.isHello).toBe(false)
})
})

Expand All @@ -84,7 +80,7 @@ test("Props can be initialized from a protobuf", () => {
scriptIsRunning: false,
},
sessionId: "sessionId",
commandLine: "commandLine",
isHello: false,
},
})

Expand All @@ -95,5 +91,6 @@ test("Props can be initialized from a protobuf", () => {
expect(props.installationId).toEqual("installationId")
expect(props.installationIdV3).toEqual("installationIdV3")
expect(props.maxCachedMessageAge).toEqual(31)
expect(props.commandLine).toEqual("commandLine")
expect(props.commandLine).toBeUndefined()
expect(props.isHello).toBeFalsy()
})
9 changes: 4 additions & 5 deletions frontend/lib/src/SessionInfo.ts
Expand Up @@ -35,7 +35,8 @@ export interface Props {
readonly installationId: string
readonly installationIdV3: string
readonly maxCachedMessageAge: number
readonly commandLine: string
readonly commandLine?: string // Unused, but kept around for compatibility
readonly isHello: boolean
}

export class SessionInfo {
Expand Down Expand Up @@ -82,9 +83,7 @@ export class SessionInfo {

/** True if `SessionInfo.current` refers to a "streamlit hello" session. */
public get isHello(): boolean {
return (
this._current != null && this._current.commandLine === "streamlit hello"
)
return this._current != null && this._current.isHello
}

/** Create SessionInfo Props from the relevant bits of an initialize message. */
Expand All @@ -101,7 +100,7 @@ export class SessionInfo {
installationId: userInfo.installationId,
installationIdV3: userInfo.installationIdV3,
maxCachedMessageAge: config.maxCachedMessageAge,
commandLine: initialize.commandLine,
isHello: initialize.isHello,
}
}
}
2 changes: 1 addition & 1 deletion frontend/lib/src/mocks/mocks.ts
Expand Up @@ -32,7 +32,7 @@ export function mockSessionInfoProps(
installationId: "mockInstallationId",
installationIdV3: "mockInstallationIdV3",
maxCachedMessageAge: 123,
commandLine: "mockCommandLine",
isHello: false,
...overrides,
}
}
Expand Down
2 changes: 1 addition & 1 deletion lib/streamlit/runtime/app_session.py
Expand Up @@ -653,7 +653,7 @@ def _create_new_session_message(self, page_script_hash: str) -> ForwardMsg:
self._state == AppSessionState.APP_IS_RUNNING
)

imsg.command_line = self._script_data.command_line
imsg.is_hello = self._script_data.is_hello
imsg.session_id = self.id

return msg
Expand Down
13 changes: 8 additions & 5 deletions lib/streamlit/runtime/runtime.py
Expand Up @@ -83,8 +83,8 @@ class RuntimeConfig:
# The filesystem path of the Streamlit script to run.
script_path: str

# The (optional) command line that Streamlit was started with
# (e.g. "streamlit run app.py")
# DEPRECATED: We need to keep this field around for compatibility reasons, but we no
# longer use this anywhere.
command_line: Optional[str]

# The storage backend for Streamlit's MediaFileManager.
Expand All @@ -104,6 +104,9 @@ class RuntimeConfig:
# The SessionStorage instance for the SessionManager to use.
session_storage: SessionStorage = field(default_factory=MemorySessionStorage)

# True if the command used to start Streamlit was `streamlit hello`.
is_hello: bool = False


class RuntimeState(Enum):
INITIAL = "INITIAL"
Expand Down Expand Up @@ -183,7 +186,7 @@ def __init__(self, config: RuntimeConfig):
self._loop_coroutine_task: Optional[asyncio.Task[None]] = None

self._main_script_path = config.script_path
self._command_line = config.command_line or ""
self._is_hello = config.is_hello

self._state = RuntimeState.INITIAL

Expand Down Expand Up @@ -367,7 +370,7 @@ def connect_session(

session_id = self._session_mgr.connect_session(
client=client,
script_data=ScriptData(self._main_script_path, self._command_line or ""),
script_data=ScriptData(self._main_script_path, self._is_hello),
user_info=user_info,
existing_session_id=existing_session_id,
session_id_override=session_id_override,
Expand Down Expand Up @@ -540,7 +543,7 @@ async def does_script_run_without_error(self) -> Tuple[bool, str]:
# SessionManager intentionally. This isn't a "real" session and is only being
# used to test that the script runs without error.
session = AppSession(
script_data=ScriptData(self._main_script_path, self._command_line),
script_data=ScriptData(self._main_script_path, self._is_hello),
uploaded_file_manager=self._uploaded_file_mgr,
script_cache=self._script_cache,
message_enqueued_callback=self._enqueued_some_message,
Expand Down

0 comments on commit 271d573

Please sign in to comment.