# Spark LLM Assistant

## Initialization

In [1]:
import sys
# Show the module search path as it is *before* the append below; the
# recorded cell output reflects that pre-append state.
print(sys.path)
# Make the local spark-llm checkout importable.
# NOTE(review): this absolute path is specific to one machine -- change it to
# the location of your own clone before running the notebook.
sys.path.append("/Users/amanda.liu/Documents/Databricks/spark-llm")

['/Users/amanda.liu/Documents/Databricks/spark-llm/examples', '/Users/amanda.liu/anaconda3/envs/llm-spark/lib/python311.zip', '/Users/amanda.liu/anaconda3/envs/llm-spark/lib/python3.11', '/Users/amanda.liu/anaconda3/envs/llm-spark/lib/python3.11/lib-dynload', '', '/Users/amanda.liu/anaconda3/envs/llm-spark/lib/python3.11/site-packages']


In [2]:
from langchain.chat_models import ChatOpenAI
from spark_llm import SparkLLMAssistant

# Build the chat model and hand it to the assistant; verbose=True is passed
# so the assistant reports what it is doing.
llm = ChatOpenAI(model_name='gpt-4') # using gpt-4 can achieve better results
assistant=SparkLLMAssistant(llm=llm, verbose=True)
# NOTE(review): presumably this is what makes the llm_* DataFrame methods used
# in later cells (llm_transform, llm_explain, ...) available -- confirm.
assistant.activate() # activate partial functions for Spark DataFrame

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/06/14 10:58:21 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Example 1: Auto sales by brand in US 2022

In [3]:
# Search and ingest web content into a DataFrame
auto_df = assistant.create_df("2022 USA national auto sales by brand")
auto_df.show()

KeyboardInterrupt: 

In [None]:
auto_df.llm_plot()

In [None]:
# Apply a transform, described in natural language, to the DataFrame and
# display the resulting DataFrame.
auto_top_growth_df=auto_df.llm_transform("top brand with the highest growth")
auto_top_growth_df.show()

In [None]:
# Explain what a DataFrame is retrieving.
auto_top_growth_df.llm_explain()

## Example 2: USA Presidents

In [None]:
# You can also specify the expected columns for the ingestion.
df=assistant.create_df("USA presidents", ["president", "vice_president"])
df.show()

In [None]:
presidents_who_were_vp = df.llm_transform("presidents who were also vice presidents")
presidents_who_were_vp.show()

In [None]:
presidents_who_were_vp.llm_explain()

## Example 3: Top 10 tech companies

In [None]:
# Search and ingest web content into a DataFrame
company_df=assistant.create_df("Top 10 tech companies by market cap", ['company', 'cap', 'country'])
company_df.show()

In [None]:
us_company_df=company_df.llm_transform("companies in USA")
us_company_df.show()

In [None]:
us_company_df.llm_explain()

In [None]:
us_company_df.llm_plot()

## Example 4: Ingestion from a URL
Instead of searching for the web page, you can also ask the assistant to ingest from a URL.

In [None]:
assistant.create_df('https://time.com/6235186/best-albums-2022/').show()

## Example 5: Verify dataframe
You can ask the assistant to verify that a DataFrame has an expected property.

In [None]:
tswift_df = assistant.create_df("Taylor Swift Top Songs 2022")

In [None]:
tswift_df.llm_verify('expect no null data')
# tswift_df.llm_verify('expect 2 columns')

## Example 6: Test generation
You can ask the assistant to generate test code for a given dataframe transformation function.

In [4]:
import pyspark.sql.functions as F

def remove_non_word_characters(col):
    """Build a column expression with punctuation-like characters removed.

    Every run of characters that is neither a word character nor whitespace
    is replaced with the empty string via ``F.regexp_replace``.

    Args:
        col: The column (or column name) to clean; passed straight through
            to ``F.regexp_replace``.

    Returns:
        Whatever ``F.regexp_replace`` returns for the given column.
    """
    unwanted_run = "[^\\w\\s]+"
    replacement = ""
    return F.regexp_replace(col, unwanted_run, replacement)

assistant.test_llm(remove_non_word_characters)

Generated test code:
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, StructField, StructType
import unittest

def remove_non_word_characters(df):
    return df.withColumn("text", F.regexp_replace(F.col("text"), "[^\\w\\s]", ""))

class TestRemoveNonWordCharacters(unittest.TestCase):
    def setUp(self):
        self.spark = SparkSession.builder \
            .appName("TestRemoveNonWordCharacters") \
            .getOrCreate()
        self.input_data = [
            ("Hi, my name is John!",),
            ("Today is a great day?!",),
            ("What's the plan for tomorrow?",)
        ]
        self.expected_data = [
            ("Hi my name is John",),
            ("Today is a great day",),
            ("Whats the plan for tomorrow",)
        ]
        self.input_schema = StructType([StructField("text", StringType(), True)])
        self.expected_schema = StructType([StructField("text", StringType(), True)])

    def test_remove_non_word_characters(self

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType
from pyspark.sql.functions import udf
import unittest
import re

def remove_non_word_characters(string: str) -> str:
    """Return *string* with every run of non-word characters deleted.

    Only characters in the regex word class (alphanumerics and underscore)
    survive; whitespace and punctuation are both dropped.
    """
    non_word_run = re.compile(r'\W+')
    return non_word_run.sub('', string)

class TestRemoveNonWordCharacters(unittest.TestCase):
    """Exercise ``remove_non_word_characters`` when it is applied as a Spark UDF."""

    def setUp(self):
        """Obtain a Spark session and wrap the function under test as a string UDF."""
        builder = SparkSession.builder.appName("Python Spark SQL test")
        self.spark = builder.getOrCreate()
        self.remove_non_word_characters_udf = udf(remove_non_word_characters, StringType())

    def tearDown(self):
        """Stop the Spark session once the test has finished."""
        self.spark.stop()

    def test_remove_non_word_characters(self):
        """Each row's cleaned ``input`` must equal its ``expected`` value."""
        # (raw input, expected cleaned value) pairs.
        cases = [
            ("Hello, World!", "HelloWorld"),
            ("$100.00", "10000"),
            ("test_case1", "test_case1"),
            ("!@#$%^&*()", ""),
        ]
        frame = self.spark.createDataFrame(cases, ["input", "expected"])
        cleaned = self.remove_non_word_characters_udf(frame["input"])
        frame = frame.withColumn("result", cleaned)
        for row in frame.select("result", "expected").collect():
            self.assertEqual(row["result"], row["expected"])

if __name__ == '__main__':
    # A bare unittest.main() is not used here: it would read sys.argv and
    # call sys.exit() when the tests finish. An explicit argv (whose first
    # element is only a program-name placeholder) avoids picking up the host
    # process's arguments, and exit=False keeps the interpreter running.
    unittest.main(argv=['first-arg-is-ignored'], exit=False)