In [None]:
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.DataFrame

    //val spark = SparkSession.builder.appName("Simple Application").getOrCreate()  
      
/** Reads a headered CSV file into a DataFrame.
  *
  * No schema is supplied and schema inference is not enabled, so every column is
  * loaded as a string. DROPMALFORMED makes the reader skip rows it cannot parse
  * instead of failing.
  *
  * NOTE(review): relies on a `spark` SparkSession being pre-bound by the notebook
  * kernel (see the commented-out builder above for standalone use).
  */
def loadCsv(path: String): DataFrame =
  spark.read
    .format("csv")
    .option("header", "true") // first line in file has headers
    .option("mode", "DROPMALFORMED")
    .load(path)

val station = loadCsv("./station.csv")
station.printSchema()
// registerTempTable is deprecated since Spark 2.0; createOrReplaceTempView is its replacement
// and is safe to re-run in a notebook (an existing view of the same name is replaced).
station.createOrReplaceTempView("station")

val trip = loadCsv("./trip.csv")
trip.printSchema()
trip.createOrReplaceTempView("trip")

// Query the registered temp views back through the SparkSession SQL entry point
// (equivalent to the legacy spark.sqlContext.sql path).
val bikeStations = spark.sql("SELECT * FROM station")
bikeStations.printSchema()

val tripData = spark.sql("SELECT * FROM trip")

// (station_id, name) lookup table from station.csv.
// The id is cast to bigint, the same type as the trip-side join key built below.
// Casting to float would lose precision for ids above 2^24, and it would force the
// join to compare in floating point.
val justStations = bikeStations
  .selectExpr("cast(id as bigint) as station_id", "name")
  .distinct()

// Every station id used as the start or end of any trip, in one bigint column "value".
// Ids are cast before distinct() so string variants of the same number collapse to one row.
val stations = tripData
  .select(col("start_station_id").cast("long").alias("value"))
  .union(tripData.select(col("end_station_id").cast("long").alias("value")))
  .distinct()

stations.take(1) // this is just a station_id at this point

// Vertices: (VertexId, station name) for each trip station id that also has a row
// in the station table. The inner join drops trip ids with no matching station row.
val stationVertices: RDD[(VertexId, String)] = stations
  .join(justStations, stations("value") === justStations("station_id"))
  .select(col("station_id").cast("long"), col("name"))
  .rdd
  .map(row => (row.getLong(0), row.getString(1))) // maintain type information

stationVertices.take(1)

// Edges: one per trip, from start station to end station, carrying a dummy
// attribute of 1L.
// Rows whose ids fail to cast to long (null) are dropped first, because
// Row.getLong throws a NullPointerException on a null value and would abort the job.
// Compare the edge count against the raw trip count to spot any dropped rows.
val stationEdges: RDD[Edge[Long]] = tripData
  .select(col("start_station_id").cast("long"), col("end_station_id").cast("long"))
  .na.drop()
  .rdd
  .map(row => Edge(row.getLong(0), row.getLong(1), 1L))

stationEdges.take(1)

// Assemble the property graph. Any edge endpoint with no entry in stationVertices
// becomes a vertex whose attribute is this placeholder name.
val defaultStation: String = "Missing Station"
val stationGraph: Graph[String, Long] = Graph(stationVertices, stationEdges, defaultStation)
stationGraph.cache() // both counts below run jobs over the same graph

println(s"Total Number of Stations: ${stationGraph.numVertices}")
println(s"Total Number of Trips: ${stationGraph.numEdges}")
// sanity check: should equal the edge count above when every trip row became an edge
println(s"Total Number of Trips in Original Data: ${tripData.count()}")

In [None]:
TEST RESULT
Expected output result:

root
 |-- id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- lat: string (nullable = true)
 |-- long: string (nullable = true)
 |-- dock_count: string (nullable = true)
 |-- city: string (nullable = true)
 |-- installation_date: string (nullable = true)

root
 |-- id: string (nullable = true)
 |-- duration: string (nullable = true)
 |-- start_date: string (nullable = true)
 |-- start_station_name: string (nullable = true)
 |-- start_station_id: string (nullable = true)
 |-- end_date: string (nullable = true)
 |-- end_station_name: string (nullable = true)
 |-- end_station_id: string (nullable = true)
 |-- bike_id: string (nullable = true)
 |-- subscription_type: string (nullable = true)
 |-- zip_code: string (nullable = true)

root
 |-- id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- lat: string (nullable = true)
 |-- long: string (nullable = true)
 |-- dock_count: string (nullable = true)
 |-- city: string (nullable = true)
 |-- installation_date: string (nullable = true)

Total Number of Stations: 70
Total Number of Trips: 669959
Total Number of Trips in Original Data: 669959
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.DataFrame
station: org.apache.spark.sql.DataFrame = [id: string, name: string ... 5 more fields]
trip: org.apache.spark.sql.DataFrame = [id: string, duration: string ... 9 more fields]
bikeStations: org.apache.spark.sql.DataFrame = [id: string, name: string ... 5 more fields]
tripData: org.apache.spark.sql.DataFrame = [id: string, duration: string ... 9 more fields]
justStations: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [station_id: float, name: string]
stations: org.apache.spark.sql.DataFrame = [value: bigint]
stationVertices: org.apache.spark.rdd.RDD[(org.apache.spark.graphx.VertexId, String)] = ...
​