# Spark LLM Assistant

## Initialization

In [2]:
import sys
# Show the interpreter's module search path; useful for diagnosing import
# failures such as the `spark_llm` ModuleNotFoundError in the next cell.
print(sys.path)

['/Users/amanda.liu/Documents/Databricks/spark-llm/examples', '/Users/amanda.liu/anaconda3/envs/llm-spark/lib/python311.zip', '/Users/amanda.liu/anaconda3/envs/llm-spark/lib/python3.11', '/Users/amanda.liu/anaconda3/envs/llm-spark/lib/python3.11/lib-dynload', '', '/Users/amanda.liu/anaconda3/envs/llm-spark/lib/python3.11/site-packages']


In [None]:
from langchain.chat_models import ChatOpenAI
from spark_llm import SparkLLMAssistant

# Using GPT-4 as the backing chat model can achieve better results.
llm = ChatOpenAI(model_name='gpt-4')
assistant = SparkLLMAssistant(llm=llm, verbose=True)

# Activate the assistant's partial functions on Spark DataFrames
# (the llm_transform / llm_explain / llm_plot methods used in later cells).
assistant.activate()

ModuleNotFoundError: No module named 'spark_llm'

## Example 1: Auto sales by brand in US 2022

In [None]:
# Search for web content matching the query and ingest it into a DataFrame.
auto_sales_query = "2022 USA national auto sales by brand"
auto_df = assistant.create_df(auto_sales_query)
auto_df.show()

In [None]:
# Ask the assistant to visualize auto_df via the llm_plot DataFrame helper.
auto_df.llm_plot()

In [None]:
# Apply a natural-language transformation to a DataFrame.
growth_request = "top brand with the highest growth"
auto_top_growth_df = auto_df.llm_transform(growth_request)
auto_top_growth_df.show()

In [None]:
# Explain what a DataFrame is retrieving -- here, the result of the
# llm_transform call in the previous cell.
auto_top_growth_df.llm_explain()

## Example 2: USA Presidents

In [None]:
# create_df also accepts the list of columns the ingested data should have.
expected_columns = ["president", "vice_president"]
df = assistant.create_df("USA presidents", expected_columns)
df.show()

In [None]:
# Describe the desired subset in plain English and let the assistant derive it.
vp_request = "presidents who were also vice presidents"
presidents_who_were_vp = df.llm_transform(vp_request)
presidents_who_were_vp.show()

In [None]:
# Ask the assistant to explain what the transformed DataFrame retrieves.
presidents_who_were_vp.llm_explain()

## Example 3: Top 10 tech companies

In [None]:
# Search for web content and ingest it, constraining the columns to extract.
company_columns = ['company', 'cap', 'country']
company_df = assistant.create_df("Top 10 tech companies by market cap", company_columns)
company_df.show()

In [None]:
# Narrow the table down with a natural-language request.
us_request = "companies in USA"
us_company_df = company_df.llm_transform(us_request)
us_company_df.show()

In [None]:
# Ask the assistant to explain what us_company_df retrieves.
us_company_df.llm_explain()

In [None]:
# Ask the assistant to visualize us_company_df via the llm_plot DataFrame helper.
us_company_df.llm_plot()

## Example 4: Ingestion from a URL
Instead of searching the web for a page, you can also ask the assistant to ingest content directly from a URL.

In [None]:
# Pass a URL instead of a search query to ingest that page directly.
assistant.create_df('https://time.com/6235186/best-albums-2022/').show()

## Example 5: Test generation
You can ask the assistant to generate test code for a given DataFrame transformation function.

In [None]:
import pyspark.sql.functions as F


def remove_non_word_characters(col):
    """Return a Column expression with `col`'s characters removed wherever
    they are neither word characters nor whitespace.

    The pattern is applied by Spark's ``regexp_replace``; each matching run
    of characters is replaced with the empty string.
    """
    pattern = r"[^\w\s]+"  # same value as the escaped "[^\\w\\s]+"
    return F.regexp_replace(col, pattern, "")


# Ask the assistant to generate test code for the function above.
assistant.test_llm(remove_non_word_characters)

In [None]:
# Source of an LLM-generated unit test for a DataFrame-level variant of
# `remove_non_word_characters`, executed in-process below.
# A raw '''-delimited literal lets the generated code keep its own
# """docstrings""" and regex backslashes without any extra escaping.
code = r'''
import unittest
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.sql.functions import col, regexp_replace

def remove_non_word_characters(df, column_name):
    """
    This function takes a dataframe and a column name as input, and returns a dataframe
    with non-word characters removed from the specified column.
    """
    return df.withColumn(column_name, regexp_replace(col(column_name), r"\W+", ""))


class TestRemoveNonWordCharacters(unittest.TestCase):

    def setUp(self):
        self.spark = (
            SparkSession.builder
            .master("local[1]")
            .appName("TestRemoveNonWordCharacters")
            .getOrCreate()
        )
        # Stored on self so the test method can reuse it.
        self.schema = StructType([
            StructField("text", StringType(), True)
        ])
        self.input_data = [
            ("Hello, World!",),
            ("I'm a test-case.",),
            ("123,456,789.00",),
        ]
        self.expected_data = [
            ("HelloWorld",),
            ("Imatestcase",),
            ("12345678900",),
        ]

    def test_remove_non_word_characters(self):
        input_df = self.spark.createDataFrame(self.input_data, schema=self.schema)
        expected_df = self.spark.createDataFrame(self.expected_data, schema=self.schema)

        result_df = remove_non_word_characters(input_df, "text")

        self.assertTrue(result_df.subtract(expected_df).count() == 0 and
                        expected_df.subtract(result_df).count() == 0,
                        msg="The function did not remove non-word characters correctly.")

    def tearDown(self):
        self.spark.stop()


if __name__ == "__main__":
    # unittest.main() loads tests from sys.modules["__main__"] (the notebook),
    # not from this exec namespace, and with exit=False it does not raise on
    # failure. Load and run this TestCase explicitly and record the real outcome.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestRemoveNonWordCharacters)
    outcome = unittest.TextTestRunner(verbosity=2).run(suite)
    result = "OK" if outcome.wasSuccessful() else "Error"
'''

# Execute the generated code in ONE namespace, used as both globals and locals.
# With separate dicts (exec(code, {}, locals_)), the top-level imports and the
# function under test would land only in the locals dict. The test methods
# resolve free names through their globals, so they could not see them.
# Binding __name__ makes the generated `if __name__ == "__main__":` guard fire.
# SECURITY: exec runs LLM-generated code with full privileges -- review it first.
# NOTE(review): setUp's getOrCreate() returns any already-active SparkSession,
# so tearDown's stop() would also stop a session this notebook may be using --
# confirm that is acceptable.
locals_ = {"__name__": "__main__"}
exec(code, locals_)

# The generated code records its outcome ("OK" / "Error") in `result`.
print(f"\nResult: {locals_.get('result', 'Error')}")