In [16]:
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import StructField, StructType, StringType, IntegerType, FloatType

In [2]:
# Spark configuration: set spark.executor.memory to 6g.
conf = (
    SparkConf()
    .set("spark.executor.memory", "6g")
)

In [3]:
# Reuse the active SparkContext, or start one configured with `conf`.
# getOrCreate is called on the class itself: instantiating SparkContext() first
# would start a context with default settings, after which getOrCreate returns
# that already-running context and ignores `conf`. Calling it on the class also
# keeps this cell safe to re-run, since no second context is ever constructed.
sc = SparkContext.getOrCreate(conf=conf)

In [4]:
# Obtain (or create) the SparkSession used for all DataFrame work below.
spark = (
    SparkSession.builder
    .appName('Collaborative Filtering')
    .getOrCreate()
)

In [5]:
# Explicit schema for ratings.csv; every column is declared nullable.
schema = StructType([
    StructField(field_name, field_type, nullable=True)
    for field_name, field_type in (
        ('userId', IntegerType()),
        ('movieId', IntegerType()),
        ('rating', FloatType()),
        ('timestamp', IntegerType()),
    )
])

In [6]:
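# Disabled: earlier RDD-based load of ratings.csv, kept for reference only.
# NOTE: collect() returns every line to the driver; prefer take(n) if re-enabled.
# rdd = sc.textFile('ratings.csv')
# rdd.collect()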

In [7]:
# Disabled: RDD-based removal of the CSV header line, kept for reference only.
# headers = rdd.first()
# rdd = rdd.filter(lambda x: x != headers)
# rdd.collect()

In [8]:
%%time
# Read ratings.csv with the explicit schema; the first line is treated as a header.
rdf = (
    spark.read
    .options(header=True)
    .schema(schema)
    .csv('ratings.csv')
)
rdf

CPU times: total: 0 ns
Wall time: 1.4 s


DataFrame[userId: int, movieId: int, rating: float, timestamp: int]

In [9]:
# Preview the first 20 rows of the ratings DataFrame.
rdf.show(n=20)

+------+-------+------+----------+
|userId|movieId|rating| timestamp|
+------+-------+------+----------+
|     1|    307|   3.5|1256677221|
|     1|    481|   3.5|1256677456|
|     1|   1091|   1.5|1256677471|
|     1|   1257|   4.5|1256677460|
|     1|   1449|   4.5|1256677264|
|     1|   1590|   2.5|1256677236|
|     1|   1591|   1.5|1256677475|
|     1|   2134|   4.5|1256677464|
|     1|   2478|   4.0|1256677239|
|     1|   2840|   3.0|1256677500|
|     1|   2986|   2.5|1256677496|
|     1|   3020|   4.0|1256677260|
|     1|   3424|   4.5|1256677444|
|     1|   3698|   3.5|1256677243|
|     1|   3826|   2.0|1256677210|
|     1|   3893|   3.5|1256677486|
|     2|    170|   3.5|1192913581|
|     2|    849|   3.5|1192913537|
|     2|   1186|   3.5|1192913611|
|     2|   1235|   3.0|1192913585|
+------+-------+------+----------+
only showing top 20 rows



In [10]:
# Print the column names, types and nullability of the loaded DataFrame
# to confirm they match the schema supplied to the reader.
rdf.printSchema()

root
 |-- userId: integer (nullable = true)
 |-- movieId: integer (nullable = true)
 |-- rating: float (nullable = true)
 |-- timestamp: integer (nullable = true)



In [11]:
# Number of partitions in the RDD underlying the ratings DataFrame.
rdf.rdd.getNumPartitions()

12

In [12]:
# Total number of rating rows in the DataFrame.
rdf.count()

27753444

In [13]:
# Project the three rating columns, referencing them by name.
(
    rdf
    .select('userId', 'movieId', 'rating')
    .show(n=20)
)

+------+-------+------+
|userId|movieId|rating|
+------+-------+------+
|     1|    307|   3.5|
|     1|    481|   3.5|
|     1|   1091|   1.5|
|     1|   1257|   4.5|
|     1|   1449|   4.5|
|     1|   1590|   2.5|
|     1|   1591|   1.5|
|     1|   2134|   4.5|
|     1|   2478|   4.0|
|     1|   2840|   3.0|
|     1|   2986|   2.5|
|     1|   3020|   4.0|
|     1|   3424|   4.5|
|     1|   3698|   3.5|
|     1|   3826|   2.0|
|     1|   3893|   3.5|
|     2|    170|   3.5|
|     2|    849|   3.5|
|     2|   1186|   3.5|
|     2|   1235|   3.0|
+------+-------+------+
only showing top 20 rows



In [15]:
# Project the three rating columns, referencing them as DataFrame attributes.
(
    rdf
    .select(rdf.userId, rdf.movieId, rdf.rating)
    .show(n=20)
)

+------+-------+------+
|userId|movieId|rating|
+------+-------+------+
|     1|    307|   3.5|
|     1|    481|   3.5|
|     1|   1091|   1.5|
|     1|   1257|   4.5|
|     1|   1449|   4.5|
|     1|   1590|   2.5|
|     1|   1591|   1.5|
|     1|   2134|   4.5|
|     1|   2478|   4.0|
|     1|   2840|   3.0|
|     1|   2986|   2.5|
|     1|   3020|   4.0|
|     1|   3424|   4.5|
|     1|   3698|   3.5|
|     1|   3826|   2.0|
|     1|   3893|   3.5|
|     2|    170|   3.5|
|     2|    849|   3.5|
|     2|   1186|   3.5|
|     2|   1235|   3.0|
+------+-------+------+
only showing top 20 rows



In [17]:
# Project a single column, referencing it through the col() function.
(
    rdf
    .select(col('userId'))
    .show(n=20)
)

+------+
|userId|
+------+
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     2|
|     2|
|     2|
|     2|
+------+
only showing top 20 rows

