In [38]:
from pyspark import SparkConf, SparkContext, RDD
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, IntegerType
from pyspark.sql.types import *
from pyspark.sql.functions import split, col, size
from pyspark.sql.functions import udf
from pyspark.sql.functions import *
from pyspark.sql.types import DoubleType
from typing import List, Tuple
import numpy as np
from statistics import pvariance

def get_spark_context(on_server) -> SparkContext:
    spark_conf = SparkConf().setAppName("2AMD15")
    if not on_server:
        spark_conf = spark_conf.setMaster("local[*]")
    spark_context = SparkContext.getOrCreate(spark_conf)

    if on_server:
        # TODO: You may want to change ERROR to WARN to receive more info. For larger data sets, to not set the
        # log level to anything below WARN, Spark will print too much information.
        spark_context.setLogLevel("ERROR")

    return spark_context

on_server = False  # TODO: Set this to true if and only if deploying to the server

spark_context = get_spark_context(on_server)

# q1

In [124]:
def q1a(spark_context: SparkContext, on_server: bool) -> DataFrame:
    vectors_file_path = "/vectors.csv" if on_server else "./vectors.csv"

    spark = SparkSession(spark_context) 

    # TODO: Implement Q1a here by creating a Dataset of DataFrame out of the file at {@code vectors_file_path}.

    # Read CSV file into DataFrame
    df = spark.read.option("header", "false") \
        .csv(vectors_file_path) \
        .withColumnRenamed("_c0", "key").withColumnRenamed("_c1", "value") \
        .select('key', split(col("value"),";").cast("array<int>").alias("value")) \
        .sort('value')


    # df.sort('value').explain() # explain(): 显示当前任务的lineage

    df.show(n=5)
    # df.take(6) # take 拿出来的是array
    

    # df_size = df.select(size("value").alias("vector length")) # size(): return the vector length
    # df_size.show()

    return df

df = q1a(spark_context, on_server)

+----+--------------------+
| key|               value|
+----+--------------------+
|QH9W| [6, 17, 66, 66, 39]|
|YRBC|[16, 60, 34, 35, 32]|
|CDL3|[23, 87, 61, 62, 43]|
|X1E8|[25, 50, 29, 58, 34]|
|P1WH|[26, 39, 34, 38, 41]|
+----+--------------------+
only showing top 5 rows



In [103]:
df.printSchema()

root
 |-- key: string (nullable = true)
 |-- value: array (nullable = true)
 |    |-- element: integer (containsNull = true)



In [106]:
type(df.collect()[0][1])

list

In [35]:
def str2vector(cell:str) -> list:
    ''' 
    Turn a dataframe cell that is a str form of a vector into a numpy vector / or a list
    '''
    return [int(num) for num in cell.split(';')]

cell = df.collect()[0][1]
str2vector(cell=cell)



[44, 33, 29, 27, 37]

In [34]:
# q1b
def q1b(spark_context: SparkContext, on_server: bool) -> RDD:
    vectors_file_path = "/vectors.csv" if on_server else "./vectors.csv"

    # TODO: Implement Q1b here by creating an RDD out of the file at {@code vectors_file_path}.

    vectors_rdd01 = spark_context.textFile(vectors_file_path)

    return rdd 

    



In [47]:
vectors_file_path = "/vectors.csv" if on_server else "./vectors.csv"
rdd = spark_context.textFile(vectors_file_path)


./vectors.csv MapPartitionsRDD[372] at textFile at NativeMethodAccessorImpl.java:0

# Q2

In [118]:
# Define UDF

compute_var = udf(lambda vector: pvariance(vector), FloatType())
aggregate = udf(lambda x, y, z: pvariance(
    [x[i] + y[i] + z[i]
    for i in range( len(x) ) ]
))

# df.select(
#     compute_var(col('value')).alias('variance')
# ).show(10)

sqlWay.select(
    aggregate(col('v1'),col('v2'),col('v3')).alias('var')
).count()
    

64000

In [128]:
# Use UDF in SQL

spark.udf.register("compute_var", compute_var) # register in SQL
spark.udf.register('aggregate_var', aggregate)

23/03/08 15:39:11 WARN SimpleFunctionRegistry: The function compute_var replaced a previously registered function.
23/03/08 15:39:11 WARN SimpleFunctionRegistry: The function aggregate_var replaced a previously registered function.


<function __main__.<lambda>(x, y, z)>

[('key', 'string'), ('value', 'string')]

In [134]:
df.createOrReplaceTempView("vectors")
spark = SparkSession(spark_context) 

# aggregate_var(c1,c2,c3) as var

tau = 410

sqlWay = spark.sql('''
SELECT id1, id2, id3, var 
FROM (
    SELECT id1, id2, id3, aggregate_var(v1,v2,v3) as var
    FROM(
        SELECT  vectors1.key as id1, vectors2.key as id2, vectors3.key as id3, vectors1.value as v1, vectors2.value as v2, vectors3.value as v3
        FROM vectors as vectors1, vectors as vectors2, vectors as vectors3  
        WHERE vectors1.value < vectors2.value and vectors2.value < vectors3.value
    )
)
WHERE var < 410
''')

# sqlWay = spark.sql('''
# SELECT vectors1.value as v1, vectors2.value as v2, vectors3.value as v3
# FROM vectors as vectors1, vectors as vectors2, vectors as vectors3  
# WHERE vectors1.value < vectors2.value and vectors2.value < vectors3.value
# ''')


print('number of rows ' + str(sqlWay.count()))
sqlWay.show()

number of rows 2689
+----+----+----+------------------+
| id1| id2| id3|               var|
+----+----+----+------------------+
|K573|QEL9|FH4I|            331.76|
|K573|ZQ5O|UJXN|            256.56|
|K573|ZQ5O|PW9U|            283.76|
|K573|EEO5|ZQ5O|213.35999999999999|
|K573|EEO5|RJJ7|            395.36|
|K573|OE4G|ZQ5O|             212.4|
|K573|OE4G|FH4I|            265.04|
|K573|OE4G|RJJ7|              73.6|
|K573|OE4G|BX50|387.84000000000003|
|K573|OE4G|WAPJ|             88.64|
|K573|OE4G|JU8G|228.64000000000001|
|K573|FH4I|ZQ5O|            400.24|
|K573|XLCC|ZQ5O|328.96000000000004|
|K573|XLCC|RJJ7|            298.16|
|K573|RJJ7|ZQ5O|             335.6|
|K573|RJJ7|FH4I|            373.84|
|K573|RJJ7|UJXN|            216.96|
|K573|RJJ7|BX50|            255.84|
|K573|RJJ7|PW9U|229.76000000000002|
|K573|BX50|ZQ5O|            263.44|
+----+----+----+------------------+
only showing top 20 rows



In [None]:
def q2(spark_context: SparkContext, data_frame: DataFrame):
    # TODO: Implement Q2 here
    

In [None]:
q2(spark_context, data_frame)

# main

In [31]:
on_server = False  # TODO: Set this to true if and only if deploying to the server

spark_context = get_spark_context(on_server)

data_frame = q1a(spark_context, on_server)

data_frame.show()

+----+--------------+
| key|         value|
+----+--------------+
|K573|44;33;29;27;37|
|ZNSR|37;39;15;37;42|
|CDL3|23;87;61;62;43|
|PQ9E|29;40;46;13;33|
|QWOS|37;24;51;33;43|
|HJ0O| 36;2;68;36;44|
|M46E|41;41;29;78;53|
|X1E8|25;50;29;58;34|
|LNZN|27;17;38;26;39|
|V1N7|39;58;63;57;23|
|QEL9|61;84;10;75;52|
|ZQ5O|62;45;47;47;79|
|ZD9B|  77;97;7;20;1|
|ZAXW|86;80;22;74;27|
|EEO5| 53;42;59;50;5|
|XCXR|27;65;16;50;34|
|OE4G|47;54;37;54;33|
|FH4I|62;37;80;51;37|
|YRBC|16;60;34;35;32|
|LU0Q|91;81;41;40;99|
+----+--------------+
only showing top 20 rows

