# Snowpark Scala in Workspace Notebooks (Prototype)

This notebook demonstrates running **Scala** and **Snowpark Scala** within a
Snowflake Workspace Notebook using a `%%scala` cell magic powered by JPype.

**Architecture:** Python kernel → JPype (JNI) → JVM (in-process) → Scala REPL → Snowpark

---

## Contents

1. [Installation & Configuration](#1)
2. [Basic Scala Execution](#2)
3. [Python ↔ Scala Interoperability](#3)
4. [Snowpark Scala Session](#4)
5. [Diagnostics](#5)

---
## 1. Installation & Configuration

### 1.1 Install JDK, Scala, and Snowpark JAR

Run the setup script. The first run takes roughly 2 to 4 minutes while it
installs OpenJDK 17, Scala 2.12, Ammonite, and the Snowpark JAR using
micromamba and coursier.

Subsequent runs detect what is already installed and skip those steps.

In [None]:
!bash setup_scala_environment.sh
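
The "skip what is already installed" behaviour can be pictured with a small Python sketch. This is an illustration only, not the contents of `setup_scala_environment.sh`: the function name `missing_tools` and the executable names in `REQUIRED_TOOLS` are assumptions made for the example.

```python
import shutil
from typing import Callable, Iterable, List, Optional


def missing_tools(
    tools: Iterable[str],
    which: Callable[[str], Optional[str]] = shutil.which,
) -> List[str]:
    """Return the tools that cannot be found on PATH, preserving input order."""
    return [tool for tool in tools if which(tool) is None]


# Hypothetical executable names; the real script may check different ones.
REQUIRED_TOOLS = ["java", "scala", "cs"]

todo = missing_tools(REQUIRED_TOOLS)
print("Nothing to install" if not todo else f"Would install: {todo}")
```

Passing `which` as a parameter keeps the check easy to exercise without touching the real `PATH`.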

### 1.2 Configure Python Environment & Register Scala Magics

This cell:
1. Sets `JAVA_HOME` and `PATH`
2. Installs JPype1 into the kernel venv (if needed)
3. Starts the JVM in-process with the Scala + Snowpark classpath
4. Initialises the Scala REPL (Ammonite-lite or IMain)
5. Registers `%%scala` (cell) and `%scala` (line) magics

In [None]:
from scala_helpers import setup_scala_environment

result = setup_scala_environment()

print(f"Success:          {result['success']}")
print(f"Java version:     {result['java_version']}")
print(f"Scala version:    {result['scala_version']}")
print(f"Interpreter type: {result['interpreter_type']}")
print(f"JVM started:      {result['jvm_started']}")
print(f"Magic registered: {result['magic_registered']}")
if result.get('jvm_options'):
    print(f"JVM options:      {result['jvm_options']}")

if result['errors']:
    print("\nErrors:")
    for err in result['errors']:
        print(f"  - {err}")
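
Step 1 above (setting `JAVA_HOME` and `PATH`) might look roughly like the following. This is a minimal sketch under assumptions, not the code inside `scala_helpers`: the helper name `with_java_env` is hypothetical, and it returns a new mapping instead of mutating the process environment so that it is safe to try out.

```python
import os
from typing import Dict, Mapping


def with_java_env(java_home: str, env: Mapping[str, str]) -> Dict[str, str]:
    """Return a copy of env with JAVA_HOME set and its bin directory first on PATH."""
    updated = dict(env)
    java_bin = os.path.join(java_home, "bin")
    updated["JAVA_HOME"] = java_home

    # Split PATH into entries and prepend the JDK bin directory only once,
    # so calling the helper repeatedly does not grow PATH.
    entries = [p for p in updated.get("PATH", "").split(os.pathsep) if p]
    if java_bin not in entries:
        entries.insert(0, java_bin)
    updated["PATH"] = os.pathsep.join(entries)
    return updated
```

To apply the result to the running kernel, one could call `os.environ.update(with_java_env("/opt/jdk", os.environ))`, where `/opt/jdk` is a placeholder path.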

### 1.3 Verify Scala Execution

In [None]:
%%scala
println(s"Hello from Scala ${util.Properties.versionString}")
println(s"Java: ${System.getProperty("java.version")}")
println(s"OS: ${System.getProperty("os.name")}")

### 1.4 Single-line Scala (`%scala`)

The `%scala` line magic runs a single Scala expression inline — handy
for quick checks without a full `%%scala` cell.

In [None]:
%scala println(s"Quick check: 2 + 2 = ${2 + 2}")
%scala println(s"Scala version: ${util.Properties.versionString}")

---
## 2. Basic Scala Execution

State persists across `%%scala` cells — vals, defs, imports, and classes
defined in one cell are available in the next.

In [None]:
%%scala
// Define a value
val greeting = "Hello from Snowflake Workspace Notebook!"
println(greeting)

In [None]:
%%scala
// Previous cell's 'greeting' is still in scope
println(s"Greeting length: ${greeting.length}")

// Define a function
def factorial(n: Int): BigInt = if (n <= 1) 1 else n * factorial(n - 1)

println(s"10! = ${factorial(10)}")
println(s"20! = ${factorial(20)}")
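
As a loose analogy for why definitions survive between cells, the Python sketch below runs each "cell" against one shared namespace. It is not how the Scala REPL is implemented; the class name `PersistentCells` is invented for the illustration.

```python
from typing import Any, Dict


class PersistentCells:
    """Run snippets of Python against a single namespace that outlives each call."""

    def __init__(self) -> None:
        self.namespace: Dict[str, Any] = {}

    def run(self, source: str) -> None:
        # Every cell sees, and can add to, the names defined by earlier cells.
        exec(source, self.namespace)


cells = PersistentCells()
cells.run("greeting = 'Hello'")
cells.run("length = len(greeting)")  # 'greeting' is still in scope here
```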

In [None]:
%%scala
// Collections and functional programming
val numbers = (1 to 10).toList
val squares = numbers.map(n => n * n)
val evenSquares = squares.filter(_ % 2 == 0)

println(s"Numbers:      $numbers")
println(s"Squares:      $squares")
println(s"Even squares: $evenSquares")
println(s"Sum:          ${evenSquares.sum}")

In [None]:
%%scala
// Case classes and pattern matching
case class Employee(name: String, department: String, salary: Double)

val employees = List(
  Employee("Alice", "Engineering", 120000),
  Employee("Bob", "Engineering", 115000),
  Employee("Carol", "Data Science", 130000),
  Employee("Dave", "Data Science", 125000),
  Employee("Eve", "Product", 110000)
)

val byDept = employees.groupBy(_.department).map {
  case (dept, emps) => (dept, emps.map(_.salary).sum / emps.size)
}

byDept.toList.sortBy(-_._2).foreach {
  case (dept, avgSalary) =>
    println(f"  $dept%-20s $$${avgSalary}%,.0f")
}

---
## 3. Python ↔ Scala Interoperability

### 3.1 Push values from Python to Scala

In [None]:
from scala_helpers import push_to_scala

# Push a string and number from Python into the Scala interpreter
push_to_scala("pythonMessage", "Hello from Python!")
push_to_scala("pythonNumber", 42)

In [None]:
%%scala
// Access the variables pushed from Python
println(s"From Python: $pythonMessage")
println(s"Number: $pythonNumber")

### 3.2 Pull values from Scala to Python

In [None]:
%%scala
val scalaResult = (1 to 100).sum
println(s"Sum 1..100 = $scalaResult")

In [None]:
from scala_helpers import pull_from_scala

value = pull_from_scala("scalaResult")
print(f"Pulled from Scala: {value} (type: {type(value).__name__})")

### 3.3 Magic flags: `-i` and `-o` (like rpy2's `%%R`)

Instead of calling `push_to_scala()` / `pull_from_scala()` explicitly,
you can use **`-i`** and **`-o`** flags directly on the `%%scala` line —
the same pattern as rpy2's `%%R -i` / `%%R -o`.

In [None]:
# Define Python variables to push into Scala
py_limit = 50
py_label = "first N numbers"

In [None]:
%%scala -i py_limit,py_label -o scala_sum --time
// py_limit and py_label were pushed from Python automatically
val n = py_limit.asInstanceOf[Int]
val scala_sum = (1 to n).sum
println(s"Sum of $py_label (1 to $n) = $scala_sum")

In [None]:
# scala_sum was pulled back into Python automatically via -o
print(f"Back in Python: scala_sum = {scala_sum} (type: {type(scala_sum).__name__})")
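
One way such a flag line could be parsed is with `argparse`. The sketch below is an assumption about the general shape, not the parser that `scala_helpers` actually uses: `parse_magic_line` is a hypothetical name, and `--time` is treated as a plain boolean switch because its real behaviour is not described here.

```python
import argparse
import shlex
from typing import List


def _csv(value: str) -> List[str]:
    """Split 'a,b,c' into ['a', 'b', 'c'], ignoring empty items."""
    return [name.strip() for name in value.split(",") if name.strip()]


def parse_magic_line(line: str) -> argparse.Namespace:
    """Parse the text that follows %%scala on the first line of a cell."""
    parser = argparse.ArgumentParser(prog="%%scala", add_help=False)
    parser.add_argument("-i", dest="inputs", type=_csv, default=[])
    parser.add_argument("-o", dest="outputs", type=_csv, default=[])
    parser.add_argument("--time", action="store_true")
    return parser.parse_args(shlex.split(line))


opts = parse_magic_line("-i py_limit,py_label -o scala_sum --time")
print(opts.inputs, opts.outputs, opts.time)
```

A magic built this way would push each name in `opts.inputs` before running the cell and pull each name in `opts.outputs` afterwards.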

---
## 4. Snowpark Scala Session

### 4.1 Inject credentials

Extract the connection details from the active Python session and the SPCS
container token, then set them as Java system properties so the Scala side
can read them.

Inside a Workspace Notebook, the SPCS OAuth token at `/snowflake/session/token`
is used automatically, so no personal access token (PAT) is needed.

In [None]:
from snowflake.snowpark.context import get_active_session
from scala_helpers import inject_session_credentials

session = get_active_session()
creds = inject_session_credentials(session)

print("Credentials injected as Java System properties:")
for k, v in creds.items():
    if 'TOKEN' in k:
        print(f"  {k}: {'SET (' + str(len(v)) + ' chars)' if v else 'NOT SET'}")
    else:
        print(f"  {k}: {v}")
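
For readers curious about the token handling, here is a minimal sketch of reading the token file and masking it for display. It is an assumption about the general approach, not the body of `inject_session_credentials()`; the names `read_spcs_token` and `mask` are hypothetical, and only the file path comes from this notebook.

```python
from pathlib import Path
from typing import Optional

# Path quoted in this notebook; it exists only inside the container.
TOKEN_PATH = Path("/snowflake/session/token")


def read_spcs_token(path: Path = TOKEN_PATH) -> Optional[str]:
    """Return the token text, or None if the file is missing, unreadable, or empty."""
    try:
        token = path.read_text(encoding="utf-8").strip()
    except OSError:
        return None
    return token or None


def mask(name: str, value: Optional[str]) -> str:
    """Render a value for logging, hiding anything whose name contains TOKEN."""
    if "TOKEN" in name:
        return f"SET ({len(value)} chars)" if value else "NOT SET"
    return str(value)
```

Reporting only the token length, as the cell above does, confirms that a token was found without writing the secret into the notebook output.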

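### 4.2 Preview Session Code

`create_snowpark_scala_session_code()` returns Scala source for creating the
session as a string; print it to inspect the code before running anything.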

In [None]:
from scala_helpers import create_snowpark_scala_session_code

code = create_snowpark_scala_session_code()
print(code)

### 4.3 Create Snowpark Scala Session

In [None]:
%%scala
import com.snowflake.snowpark._
import com.snowflake.snowpark.functions._

def prop(k: String): String = {
  val v = System.getProperty(k)
  require(v != null, s"System property '$k' not set. Run inject_session_credentials() first.")
  v
}

val sfSession = Session.builder.configs(Map(
  "URL"           -> prop("SNOWFLAKE_URL"),
  "USER"          -> prop("SNOWFLAKE_USER"),
  "ROLE"          -> prop("SNOWFLAKE_ROLE"),
  "DB"            -> prop("SNOWFLAKE_DATABASE"),
  "SCHEMA"        -> prop("SNOWFLAKE_SCHEMA"),
  "WAREHOUSE"     -> prop("SNOWFLAKE_WAREHOUSE"),
  "TOKEN"         -> prop("SNOWFLAKE_TOKEN"),
  "AUTHENTICATOR" -> prop("SNOWFLAKE_AUTH_TYPE")
)).create

println("Snowpark Scala session created!")
val _user = sfSession.sql("SELECT CURRENT_USER()").collect()(0).getString(0)
val _role = sfSession.sql("SELECT CURRENT_ROLE()").collect()(0).getString(0)
val _db = sfSession.sql("SELECT CURRENT_DATABASE()").collect()(0).getString(0)
println(s"  User:      ${_user}")
println(s"  Role:      ${_role}")
println(s"  Database:  ${_db}")

### 4.4 Query Snowflake from Scala

In [None]:
%%scala
// Basic query
sfSession.sql("SELECT CURRENT_USER() AS user, CURRENT_ROLE() AS role, CURRENT_WAREHOUSE() AS warehouse").show()

In [None]:
%%scala
// DataFrame operations
val df = sfSession.sql("SELECT 'Scala' AS language, 'Snowpark' AS framework, CURRENT_TIMESTAMP() AS ts")
df.show()

In [None]:
%%scala
// Show available tables
sfSession.sql("SHOW TABLES LIMIT 5").show()

### 4.5 Cross-language Data Sharing

The Python and Scala Snowpark sessions are **separate connections**, so
`TEMPORARY TABLE`s (which are session-scoped) are not visible across them.
Use a `TRANSIENT TABLE` instead, and drop it when done.

In [None]:
# Python: create a transient table (visible across sessions, unlike TEMPORARY)
session.sql("""
    CREATE OR REPLACE TRANSIENT TABLE scala_demo (
        id INT, name STRING, value DOUBLE
    ) AS
    SELECT column1, column2, column3 FROM VALUES
        (1, 'alpha', 10.5),
        (2, 'beta', 20.3),
        (3, 'gamma', 30.7)
""").collect()
print("Transient table 'scala_demo' created from Python")

In [None]:
%%scala
// Scala: read the transient table created by Python
val demo = sfSession.table("scala_demo")
demo.show()

// Compute something
val total = demo.select(sum(col("VALUE"))).collect()(0).getDouble(0)
println(s"Total value: $total")

In [None]:
# Cleanup: drop the transient demo table
session.sql("DROP TABLE IF EXISTS scala_demo").collect()
print("Table 'scala_demo' dropped")
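
The create, use, drop pattern above can be wrapped so that the table is dropped even if a cell fails part-way. The following is a sketch under assumptions: `transient_table` is a hypothetical helper, it relies only on the `session.sql(...).collect()` calls already used in this notebook, and `RecordingSession` is a stand-in that lets the helper be tried without a live connection.

```python
from contextlib import contextmanager
from typing import Any, Iterator, List


@contextmanager
def transient_table(session: Any, name: str, select_sql: str) -> Iterator[str]:
    """Create a transient table from a SELECT and drop it on exit."""
    # 'name' is interpolated directly, so it must be a trusted identifier.
    session.sql(f"CREATE OR REPLACE TRANSIENT TABLE {name} AS {select_sql}").collect()
    try:
        yield name
    finally:
        session.sql(f"DROP TABLE IF EXISTS {name}").collect()


class RecordingSession:
    """Stand-in for a Snowpark session that only records the SQL it receives."""

    def __init__(self) -> None:
        self.statements: List[str] = []

    def sql(self, text: str) -> "RecordingSession":
        self.statements.append(text)
        return self

    def collect(self) -> list:
        return []
```

With a real session this would read `with transient_table(session, "scala_demo", "SELECT ...") as t:`, and the Scala cells that read the table would need to run before the `with` block exits.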

---
## 5. Diagnostics

Run the diagnostics check to verify the JVM, Scala interpreter,
Snowpark classpath, credentials, and disk space are all healthy.

In [None]:
from scala_helpers import print_diagnostics
print_diagnostics()
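
Of the checks listed above, the disk-space one is simple to reproduce by hand. The sketch below is an assumption about what such a check could look like, not the implementation of `print_diagnostics()`; the helper names and the 1 GiB default threshold are invented for the example.

```python
import shutil


def free_gib(path: str = ".") -> float:
    """Return the free space, in GiB, on the filesystem that contains path."""
    return shutil.disk_usage(path).free / 2**30


def disk_ok(path: str = ".", min_free_gib: float = 1.0) -> bool:
    """Report whether at least min_free_gib GiB are free (threshold is hypothetical)."""
    return free_gib(path) >= min_free_gib


print(f"Free space: {free_gib():.1f} GiB, healthy: {disk_ok()}")
```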