# Scala Spark Connect via Local Python Proxy (Path B)

**Goal:** Run Scala Spark code in a Workspace Notebook by:
1. Starting a local Spark Connect server in Python (consumes SPCS OAuth token)
2. Connecting a Scala Spark client to `sc://localhost:15002` via JPype

This avoids the need for a programmatic access token (PAT): the Python proxy
authenticates through the container's Snowpark session, and the Scala client
only talks to localhost.

**Architecture:**
```
Python: start_session(is_daemon=False, remote_url="sc://localhost:15002")
   ↓  (consumes SPCS OAuth token via Snowpark session)
Scala (JPype) → sc://localhost:15002 → Python proxy → Snowflake warehouse
```

**Self-contained:** This notebook installs everything it needs. It can run
on a fresh container or one where other notebooks have already been executed.
All installs are idempotent.

## Contents

1. [Setup: Install dependencies](#1)
2. [Start local Spark Connect server](#2)
3. [Verify PySpark via local server](#3)
4. [Setup JVM + Scala Spark Connect client JARs](#4)
5. [Connect Scala to local Spark Connect server](#5)
6. [Scala Spark SQL tests](#6)
7. [Findings](#7)

## 1. Setup: Install dependencies

All dependencies are installed from scratch. Each install is idempotent —
safe to re-run if already installed from a prior notebook or previous run.

We need:
- `snowpark-connect[jdk]` + `pyspark` (PySpark + bundled JDK)
- `JPype1` (Python-JVM bridge)
- `coursier` (JVM dependency resolver, for fetching Spark Connect client JARs)
- `spark-connect-client-jvm` Scala JAR (fetched via coursier)
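
As a rough pre-flight check, the list above can be compared against what the container already provides. This is a minimal sketch: `missing_dependencies` is a hypothetical helper (not part of any library used here), and it only reports what is absent without installing anything.

```python
import importlib.util
import shutil

def missing_dependencies(modules, binaries):
    """Return (missing_modules, missing_binaries) without installing anything."""
    def _has_module(name):
        try:
            return importlib.util.find_spec(name) is not None
        except (ImportError, ValueError):
            # A dotted name whose parent package is absent raises instead of returning None.
            return False

    missing_modules = [m for m in modules if not _has_module(m)]
    missing_binaries = [b for b in binaries if shutil.which(b) is None]
    return missing_modules, missing_binaries

mods, bins = missing_dependencies(["pyspark", "jpype"], ["java", "cs"])
print("Missing Python modules:", mods)
print("Missing binaries:      ", bins)
```

An empty result for both lists suggests the install cells below will mostly be no-ops.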

In [None]:
# Cell 1
import subprocess, sys, os, time, shutil, pathlib, socket

def _pip_install(package, label=None):
    """Run `pip install`; pip itself skips requirements that are already satisfied."""
    label = label or package
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", package, "-q"],
        capture_output=True, text=True, timeout=300
    )
    if result.returncode == 0:
        print(f"  {label}: OK")
    else:
        print(f"  {label}: FAILED — {result.stderr[-200:]}")

print("=== Installing Python packages (idempotent) ===")
_pip_install("snowpark-connect[jdk]", "snowpark-connect")
_pip_install("pyspark==3.5.6", "pyspark")
_pip_install("JPype1", "JPype1")

import snowflake.snowpark_connect
import pyspark
import jpype
print(f"\nsnowpark-connect: imported OK")
print(f"PySpark version:  {pyspark.__version__}")
print(f"JPype1 version:   {jpype.__version__}")

In [None]:
# Cell 2 — Find or install Java (idempotent)
MICROMAMBA_ROOT = pathlib.Path.home() / "micromamba"
MM_ENV = MICROMAMBA_ROOT / "envs" / "jvm_env"

def _set_java_home(java_home_path):
    """Set JAVA_HOME and update PATH."""
    os.environ["JAVA_HOME"] = str(java_home_path)
    bin_dir = pathlib.Path(java_home_path) / "bin"
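    # Compare whole PATH entries; a substring test can match an unrelated directory.
    path_entries = os.environ.get("PATH", "").split(os.pathsep)
    if str(bin_dir) not in path_entries:
        os.environ["PATH"] = os.pathsep.join([str(bin_dir)] + path_entries)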

java_search_paths = [
    ("PATH (pre-existing)",          lambda: shutil.which("java")),
    ("snowpark-connect[jdk] bundle", lambda: pathlib.Path(sys.prefix) / "jdk"),
    ("micromamba jvm_env",           lambda: MM_ENV),
]

java_found = False
for label, path_fn in java_search_paths:
    path = path_fn()
    if not path:
        continue
    if isinstance(path, str):
        # shutil.which returned .../bin/java, so JAVA_HOME is two levels up.
        java_home = pathlib.Path(path).resolve().parent.parent
    elif (path / "bin" / "java").exists():
        # A directory candidate counts only if it actually contains a java binary.
        java_home = path
    else:
        continue
    _set_java_home(java_home)
    java_found = True
    print(f"Java source: {label}")
    break

if not java_found:
    print("No Java found — installing OpenJDK 17 via micromamba...")
    os.environ["MAMBA_ROOT_PREFIX"] = str(MICROMAMBA_ROOT)
    mm_bin = MICROMAMBA_ROOT / "bin" / "micromamba"

    if not mm_bin.exists():
        print("  Downloading micromamba...")
        (MICROMAMBA_ROOT / "bin").mkdir(parents=True, exist_ok=True)
        subprocess.run(
            "cd /tmp && "
            "curl -Ls --retry 3 --retry-delay 2 "
            "https://micro.mamba.pm/api/micromamba/linux-64/latest "
            "| tar -xvj bin/micromamba 2>/dev/null && "
            # Group the best-effort rmdir so a failed download or mv still fails the command.
            f"mv bin/micromamba {mm_bin} && (rmdir bin 2>/dev/null || true)",
            shell=True, check=True, timeout=60
        )
        print(f"  micromamba installed: {mm_bin}")

    os.environ["PATH"] = f"{MICROMAMBA_ROOT / 'bin'}:" + os.environ.get("PATH", "")

    if not MM_ENV.exists():
        print("  Creating jvm_env with OpenJDK 17...")
        subprocess.run(
            f"{mm_bin} create -y -n jvm_env -c conda-forge openjdk=17 -q",
            shell=True, check=True, timeout=120,
            env={**os.environ, "MAMBA_ROOT_PREFIX": str(MICROMAMBA_ROOT)}
        )

    if MM_ENV.exists():
        _set_java_home(MM_ENV)
        print(f"  JDK ready: {MM_ENV}")
    else:
        raise RuntimeError("FATAL: Could not install Java via micromamba")

java_path = shutil.which("java")
print(f"\nJava binary:  {java_path or 'NOT FOUND'}")
if java_path:
    jv = subprocess.run(["java", "-version"], capture_output=True, text=True)
    print(f"Java version: {jv.stderr.splitlines()[0]}")
print(f"JAVA_HOME:    {os.environ.get('JAVA_HOME', 'NOT SET')}")

# --- Install coursier (idempotent) ---
CS_DIR = os.path.expanduser("~/scala_jars")
os.makedirs(CS_DIR, exist_ok=True)
cs_path = shutil.which("cs") or os.path.join(CS_DIR, "cs")

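# Require an executable file: a failed download can leave an empty, non-executable stub behind.
if not (os.path.isfile(cs_path) and os.access(cs_path, os.X_OK)):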
    print("\nInstalling coursier...")
    subprocess.run(
        f"curl -fL https://github.com/coursier/coursier/releases/latest/download/cs-x86_64-pc-linux.gz "
        f"| gzip -d > {cs_path} && chmod +x {cs_path}",
        shell=True, check=True, timeout=60
    )
    print(f"Coursier installed: {cs_path}")
else:
    print(f"\nCoursier already at: {cs_path}")

## 2. Start local Spark Connect server

Use the `snowpark_connect` API to start a local Spark Connect gRPC
server that listens on `localhost:15002`. This server consumes the
Workspace's Snowpark session (SPCS OAuth) and proxies requests to
Snowflake.
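
The connection details for this server are written to a TOML file by plain string interpolation, which breaks if a value contains a double quote or a backslash. The sketch below shows one way the values could be escaped first. `render_connection_toml` is a hypothetical helper, and it assumes that JSON string escaping produces a valid TOML basic string for typical identifier and token values.

```python
import json

def render_connection_toml(name, params):
    """Render a [connections.<name>] TOML table with every value escaped as a string."""
    lines = [f"[connections.{name}]"]
    for key, value in params.items():
        # json.dumps adds the surrounding quotes and backslash-escapes quotes and control characters.
        lines.append(f"{key} = {json.dumps(str(value), ensure_ascii=False)}")
    return "\n".join(lines) + "\n"

# Illustrative values only; real values come from the active Snowpark session.
print(render_connection_toml("spark-connect", {
    "account": "MY_ACCOUNT",
    "role": 'ROLE_WITH_"QUOTE"',
    "authenticator": "oauth",
}))
```

On Python 3.11 or later, the result can be round-tripped through `tomllib.loads` to confirm that it parses.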

The server runs in a background thread so the notebook stays interactive.
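
The background-thread pattern can be sketched independently of Spark: run a blocking call in a daemon thread, capture any exception it raises, and poll the port until something accepts connections. The helper names below are hypothetical, and a plain TCP listener stands in for the real server.

```python
import socket
import threading
import time

def wait_for_port(port, host="127.0.0.1", timeout_s=30.0, interval_s=0.2):
    """Poll until a TCP listener accepts connections on host:port, or time out."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=1):
                return True
        except OSError:
            time.sleep(interval_s)
    return False

def start_in_background(blocking_fn):
    """Run a blocking callable in a daemon thread; collect any exception it raises."""
    errors = []

    def _run():
        try:
            blocking_fn()
        except Exception as exc:
            errors.append(exc)

    thread = threading.Thread(target=_run, daemon=True)
    thread.start()
    return thread, errors

# Stand-in for the real server: a TCP listener on an OS-assigned port.
listener = socket.socket()
listener.bind(("127.0.0.1", 0))
listener.listen()
demo_port = listener.getsockname()[1]
print("listener reachable:", wait_for_port(demo_port, timeout_s=2))
listener.close()
```

Because the thread is a daemon, it does not keep the kernel alive on shutdown, and the `errors` list lets the polling loop stop early when startup fails.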

In [None]:
# Cell 3 — Write config.toml for the Spark Connect server
from snowflake.snowpark.context import get_active_session

session = get_active_session()

def _safe(fn):
    try:
        v = fn()
        return v.strip('"') if v else ""
    except Exception:
        return ""

conn_info = {
    "account": _safe(lambda: session.sql("SELECT CURRENT_ACCOUNT()").collect()[0][0]),
    "user": _safe(lambda: session.sql("SELECT CURRENT_USER()").collect()[0][0]),
    "role": _safe(session.get_current_role),
    "database": _safe(session.get_current_database),
    "schema": _safe(session.get_current_schema),
    "warehouse": _safe(session.get_current_warehouse),
    "host": os.environ.get("SNOWFLAKE_HOST", ""),
}

spcs_token = ""
if os.path.isfile("/snowflake/session/token"):
    with open("/snowflake/session/token") as f:
        spcs_token = f.read().strip()

config_dir = pathlib.Path.home() / ".snowflake"
config_dir.mkdir(parents=True, exist_ok=True)
config_file = config_dir / "config.toml"

toml_content = f'''[connections.spark-connect]
host = "{conn_info['host']}"
account = "{conn_info['account']}"
user = "{conn_info['user']}"
token = "{spcs_token}"
authenticator = "oauth"
warehouse = "{conn_info['warehouse']}"
database = "{conn_info['database']}"
schema = "{conn_info['schema']}"
role = "{conn_info['role']}"
'''

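# Restrict permissions before the token is written, so it is never world-readable.
config_file.touch(mode=0o600, exist_ok=True)
config_file.chmod(0o600)
config_file.write_text(toml_content)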
print(f"Config written to {config_file}")
print(f"Account: {conn_info['account']}, User: {conn_info['user']}")

In [None]:
# Cell 4 — Explore snowpark_connect API
import snowflake.snowpark_connect as spc

print("=== snowpark_connect top-level attributes ===")
for attr in sorted(dir(spc)):
    if not attr.startswith("_"):
        obj = getattr(spc, attr)
        print(f"  {attr}: {type(obj).__name__}")

# Check for server/start_session APIs
if hasattr(spc, 'server'):
    print("\n=== snowpark_connect.server attributes ===")
    for attr in sorted(dir(spc.server)):
        if not attr.startswith("_"):
            print(f"  {attr}: {type(getattr(spc.server, attr)).__name__}")

if hasattr(spc, 'client'):
    print("\n=== snowpark_connect.client attributes ===")
    for attr in sorted(dir(spc.client)):
        if not attr.startswith("_"):
            print(f"  {attr}: {type(getattr(spc.client, attr)).__name__}")

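# Check for start_session specifically
import inspect

for mod_name in ["", "server", "client"]:
    # Look the submodule up with getattr instead of eval.
    mod = getattr(spc, mod_name, None) if mod_name else spc
    if mod is None or not hasattr(mod, "start_session"):
        continue
    qualified = f"spc.{mod_name}" if mod_name else "spc"
    try:
        sig = inspect.signature(mod.start_session)
    except (TypeError, ValueError):
        continue  # signature is not introspectable
    print(f"\n=== {qualified}.start_session signature ===")
    print(f"  {sig}")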

In [None]:
# Cell 5 — Start local Spark Connect server
import threading
import snowflake.snowpark_connect as spc  # repeated from Cell 4 so this cell runs on its own

SERVER_PORT = 15002

# Check if server is already listening (idempotent, e.g. re-run without restart)
def _port_open(port):
    try:
        # create_connection closes the socket on exit, including when connect fails
        with socket.create_connection(("127.0.0.1", port), timeout=1):
            return True
    except OSError:
        return False

if _port_open(SERVER_PORT):
    print(f"Spark Connect server already listening on localhost:{SERVER_PORT}")
else:
    server_error = [None]

    def _start_server():
        try:
            spc.start_session(
                is_daemon=False,
                remote_url=f"sc://localhost:{SERVER_PORT}",
                snowpark_session=session,
            )
        except Exception as e:
            server_error[0] = str(e)

    server_thread = threading.Thread(target=_start_server, daemon=True)
    server_thread.start()

    print(f"Waiting for Spark Connect server on localhost:{SERVER_PORT}...")
    for attempt in range(30):
        time.sleep(1)
        if server_error[0]:
            print(f"Server failed: {server_error[0]}")
            break
        if _port_open(SERVER_PORT):
            print(f"Server LISTENING on localhost:{SERVER_PORT} (took {attempt+1}s)")
            break
        if attempt % 5 == 4:
            print(f"  Still waiting ({attempt+1}s)...")
    else:
        print("Server did not start within 30s")

## 3. Verify PySpark via local server

Before trying Scala, verify that a PySpark client can connect to the
local server (not the remote endpoint).

In [None]:
# Cell 6 — Verify PySpark via local server
from pyspark.sql import SparkSession

spark_local = SparkSession.builder.remote(f"sc://localhost:{SERVER_PORT}").getOrCreate()
print(f"PySpark connected to local server: {type(spark_local).__name__}")
print(f"Spark version: {spark_local.version}")

# Basic test
spark_local.sql("SELECT 1 AS test, CURRENT_USER() AS user").show()
print("PySpark via local server: WORKING")

## 4. Setup JVM + Scala Spark Connect client JARs

Fetch the `spark-connect-client-jvm` JAR and its dependencies using
coursier (installed in step 1). This JAR lets a Scala client speak
the Spark Connect gRPC protocol. Idempotent — skips if already fetched.
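
A cached classpath is only useful if every JAR it lists still exists on disk. The following sketch splits a colon-separated classpath into present and missing entries. `split_classpath` is a hypothetical helper, and the example paths are placeholders.

```python
import os

def split_classpath(classpath, sep=":"):
    """Split a classpath string into (existing, missing) file paths."""
    entries = [p for p in classpath.strip().split(sep) if p]
    existing = [p for p in entries if os.path.isfile(p)]
    missing = [p for p in entries if not os.path.isfile(p)]
    return existing, missing

# Placeholder paths for illustration; a real classpath comes from `cs fetch --classpath`.
existing, missing = split_classpath("/nonexistent/a.jar:/nonexistent/b.jar")
print(f"{len(existing)} present, {len(missing)} missing")
```

A non-empty `missing` list is a reasonable signal to delete the cached classpath file and fetch again.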

In [None]:
# Cell 7 — Fetch Spark Connect client JARs via coursier
import pyspark
SPARK_VERSION = pyspark.__version__  # e.g. "3.5.6"
SCALA_VERSION = "2.12"
JAR_DIR = os.path.expanduser("~/spark_connect_jars")
CP_FILE = os.path.join(JAR_DIR, "spark_connect_classpath.txt")
os.makedirs(JAR_DIR, exist_ok=True)

# cs_path was set in the setup cell above
artifact = f"org.apache.spark:spark-connect-client-jvm_{SCALA_VERSION}:{SPARK_VERSION}"

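# Idempotent: skip if classpath already resolved and every listed JAR still exists
jars = []
if os.path.isfile(CP_FILE):
    with open(CP_FILE) as f:
        listed = [j for j in f.read().strip().split(":") if j]
    jars = [j for j in listed if os.path.exists(j)]
    if jars and len(jars) == len(listed):
        print(f"Spark Connect client JARs already resolved: {len(jars)} JARs")
    else:
        # Empty or stale classpath (some JARs are gone): force a re-fetch.
        os.remove(CP_FILE)
        jars = []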

if not os.path.isfile(CP_FILE):
    print(f"Fetching {artifact} via coursier...")
    t0 = time.time()
    result = subprocess.run(
        [cs_path, "fetch", artifact, "--classpath"],
        capture_output=True, text=True, timeout=300
    )
    elapsed = time.time() - t0
    
    if result.returncode == 0:
        classpath = result.stdout.strip()
        jars = [j for j in classpath.split(":") if j]
        with open(CP_FILE, "w") as f:
            f.write(classpath)
        print(f"Resolved {len(jars)} JARs in {elapsed:.1f}s")
    else:
        print(f"Coursier failed: {result.stderr[-300:]}")
        jars = []

## 5. Connect Scala to local Spark Connect server

Start a JVM via JPype with the `spark-connect-client-jvm` JARs on the
classpath, then create a Scala SparkSession connected to `sc://localhost:15002`.
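
For reference, a Spark Connect connection string has the form `sc://host:port`, optionally followed by `/;key=value` parameters. The sketch below assembles one. `spark_connect_url` is a hypothetical helper, `user_id` is used only as an example parameter, and this notebook needs no parameters because authentication happens in the Python proxy.

```python
def spark_connect_url(host, port, **params):
    """Build sc://host:port, appending ;key=value parameters after '/' when given."""
    url = f"sc://{host}:{port}"
    if params:
        url += "/;" + ";".join(f"{key}={value}" for key, value in params.items())
    return url

print(spark_connect_url("localhost", 15002))
print(spark_connect_url("localhost", 15002, user_id="demo"))
```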

In [None]:
# Cell 8 — Start JVM with Spark Connect JARs

if not jpype.isJVMStarted():
    cp_file = os.path.expanduser("~/spark_connect_jars/spark_connect_classpath.txt")
    if os.path.isfile(cp_file):
        with open(cp_file) as f:
            classpath = [j for j in f.read().strip().split(":") if j]
    else:
        raise RuntimeError("No classpath file found. Run the JAR fetch cell first.")
    
    jvm_options = [
        "--add-opens=java.base/java.nio=ALL-UNNAMED",
        "-Xmx1g",
    ]
    
    print(f"Starting JVM with {len(classpath)} JARs...")
    jpype.startJVM(
        jpype.getDefaultJVMPath(),
        *jvm_options,
        classpath=classpath,
        convertStrings=True,
    )
    import jpype.imports
    print(f"JVM started: {jpype.getDefaultJVMPath()}")
else:
    print(f"JVM already running: {jpype.getDefaultJVMPath()}")

In [None]:
# Cell 9 — Create Scala SparkSession via Spark Connect
try:
    # Named ScalaSparkSession so it does not shadow PySpark's SparkSession from Cell 6.
    ScalaSparkSession = jpype.JClass("org.apache.spark.sql.SparkSession")
    
    scala_spark = (
        ScalaSparkSession.builder()
        .remote(f"sc://localhost:{SERVER_PORT}")
        .getOrCreate()
    )
    
    print(f"Scala SparkSession created: {type(scala_spark)}")
    print(f"Spark version: {scala_spark.version()}")
    print("\n*** Scala Spark Connect via local proxy: WORKING ***")
    
except Exception as e:
    print(f"Failed to create Scala SparkSession: {type(e).__name__}: {e}")
    import traceback
    traceback.print_exc()

## 6. Scala Spark SQL tests

If the Scala SparkSession is live, test running SQL on Snowflake.
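
On the JVM side, `show()` prints to the JVM's standard output, which some notebook kernels may not display in the cell. One workaround is to collect the rows and render them in Python. The sketch below covers only the rendering step with plain Python values, and `format_table` is a hypothetical helper.

```python
def format_table(columns, rows):
    """Render column names and row values as a fixed-width text table."""
    cells = [[str(c) for c in columns]] + [[str(v) for v in row] for row in rows]
    widths = [max(len(r[i]) for r in cells) for i in range(len(columns))]

    def fmt(row):
        return " | ".join(value.ljust(width) for value, width in zip(row, widths))

    rule = "-+-".join("-" * width for width in widths)
    return "\n".join([fmt(cells[0]), rule] + [fmt(r) for r in cells[1:]])

print(format_table(["id", "msg"], [[1, "hello from scala"]]))
```

With a live session, the inputs could plausibly be built as `list(result.columns())` and `[[row.get(i) for i in range(row.length())] for row in result.collectAsList()]`. That assumes JPype exposes these Dataset and Row methods as written, which has not been verified in this notebook.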

In [None]:
# Cell 10 — Scala Spark SQL basic test
try:
    result = scala_spark.sql("SELECT 1 AS id, 'hello from scala' AS msg")
    result.show()
    print("Scala Spark SQL: WORKING")
except Exception as e:
    print(f"Scala SQL failed: {e}")

In [None]:
# Cell 11 — Scala queries against Snowflake
try:
    result = scala_spark.sql("SELECT CURRENT_USER() AS user")
    result.show()
except Exception as e:
    print(f"CURRENT_USER query failed: {e}")

try:
    result = scala_spark.sql(
        "SELECT TABLE_NAME, ROW_COUNT FROM INFORMATION_SCHEMA.TABLES LIMIT 3"
    )
    result.show()
except Exception as e:
    print(f"INFORMATION_SCHEMA query failed: {e}")

In [None]:
# Cell 12 — Interop: Snowpark Python writes, Scala Spark reads
try:
    session.sql("""
        CREATE OR REPLACE TRANSIENT TABLE _SCALA_CONNECT_TEST AS
        SELECT 1 AS id, 'from_snowpark_python' AS source
    """).collect()
    print("Snowpark Python: wrote _SCALA_CONNECT_TEST")
    
    scala_df = scala_spark.sql("SELECT * FROM _SCALA_CONNECT_TEST")
    scala_df.show()
    print("Scala Spark read it via local proxy!")
except Exception as e:
    print(f"Interop test failed: {e}")
finally:
    # Drop the test table even when the write or read above fails.
    try:
        session.sql("DROP TABLE IF EXISTS _SCALA_CONNECT_TEST").collect()
        print("Cleanup done")
    except Exception as e:
        print(f"Cleanup failed: {e}")

## 7. Findings

| Question | Result | Notes |
|----------|--------|-------|
| Local server starts? | | `start_session` on localhost:15002 |
| PySpark via local server? | | Separate from remote endpoint test |
| JVM + spark-connect JARs? | | coursier fetch |
| Scala SparkSession created? | | `SparkSession.builder().remote(...)` |
| Scala SQL to Snowflake? | | Through local proxy |
| Scala reads Snowflake tables? | | INFORMATION_SCHEMA etc |
| Snowpark ↔ Scala Spark interop? | | Via transient tables |
| Auth: no PAT needed? | | SPCS token via Python proxy |