In [1]:
# Inspect the Databricks sample-docs folder: one entry (path, name, size) per file.
docs_listing = dbutils.fs.ls("/databricks-datasets/samples/docs/")
display(docs_listing)

path,name,size
dbfs:/databricks-datasets/samples/docs/README.md,README.md,3137


In [2]:
# Transformation: define a DataFrame over the README with one row per line
# (a string column named `value`). No data is scanned until an action runs.
textFile = spark.read.format("text").load("/databricks-datasets/samples/docs/README.md")

In [3]:
# Action: runs a Spark job that counts the lines of the README.
readme_line_count = textFile.count()
readme_line_count

In [4]:
# Column names of the README DataFrame, taken straight from its schema.
textFile.schema.names

In [5]:
# Fetch the first line of the README as a single Row (head() with no argument
# behaves like first()).
textFile.head()

In [6]:
# Keep only the lines that mention "Spark". This is a lazy transformation.
linesWithSpark = textFile.where(textFile["value"].contains("Spark"))

# Count the matching lines. This action triggers the actual scan.
c = linesWithSpark.count()

In [7]:
# Bring all c matching lines back to the driver as a list of Rows.
linesWithSpark.limit(c).collect()

In [8]:
%python
# Use the Spark CSV datasource with options specifying:
# - First line of file is a header
# - Automatically infer the schema of the data
data = spark.read.csv("/databricks-datasets/samples/population-vs-price/data_geo.csv", header="true", inferSchema="true")
data.cache() # Cache data for faster reuse


In [9]:
# Row count of the geo dataset. Because `data` was cached, this first action
# also populates the cache.
geo_row_count = data.count()
geo_row_count

In [10]:
# DataFrames are immutable: drop_duplicates() returns a new DataFrame and leaves
# `data` untouched, so calling it on its own has no effect. Compare row counts
# instead to see how many exact-duplicate rows the dataset contains.
num_duplicate_rows = data.count() - data.drop_duplicates().count()
num_duplicate_rows

In [11]:
# (column name, Spark SQL type name) pairs for every column of the geo dataset.
[(str(field.name), field.dataType.simpleString()) for field in data.schema.fields]

In [12]:
# Default per-column summary statistics (count, mean, stddev, min, quartiles,
# max). That is fewer than 10 rows, so every statistic row is returned.
data.summary().limit(10).collect()

In [13]:
%python
data.take(10)

In [14]:
%python
# Render the geo DataFrame as an interactive Databricks table.
display(data)

2014 rank,City,State,State Code,2014 Population estimate,2015 median sales price
101,Birmingham,Alabama,AL,212247.0,162.9
125,Huntsville,Alabama,AL,188226.0,157.7
122,Mobile,Alabama,AL,194675.0,122.5
114,Montgomery,Alabama,AL,200481.0,129.0
64,Anchorage[19],Alaska,AK,301010.0,
78,Chandler,Arizona,AZ,254276.0,
86,Gilbert[20],Arizona,AZ,239277.0,
88,Glendale,Arizona,AZ,237517.0,
38,Mesa,Arizona,AZ,464704.0,
148,Peoria,Arizona,AZ,166934.0,


In [15]:
%python
# Register table so it is accessible via SQL Context
data.createOrReplaceTempView("data_geo")

In [16]:
# Query the temp view with SQL from Python. The backticks quote a column name
# that contains a space.
wa_query = "select City from data_geo where `State Code` = 'WA'"
a = spark.sql(wa_query)
display(a)

City
Bellevue
Everett
Kent
Seattle
Spokane
Tacoma
Vancouver


In [17]:
%sql
-- Cities in Washington state, queried directly with the %sql magic.
SELECT City
FROM data_geo
WHERE `State Code` = 'WA';

City
Bellevue
Everett
Kent
Seattle
Spokane
Tacoma
Vancouver


In [18]:
import os
import subprocess
import uuid

In [19]:
# Give this notebook run its own DBFS scratch directory, made unique by a
# random UUID suffix, and create it.
run_id = str(uuid.uuid4())
work_dir = os.path.join("/ml/tmp/petastorm", run_id)
dbutils.fs.mkdirs(work_dir)

def get_local_path(dbfs_path):
  """Map a DBFS path to its local-filesystem path under the /dbfs mount.

  Accepts plain paths ("/ml/tmp/x" or "ml/tmp/x") as well as scheme-qualified
  paths ("dbfs:/ml/tmp/x"), which is the form dbutils.fs.ls reports.

  Args:
    dbfs_path: path inside DBFS, with or without a leading "dbfs:" scheme.

  Returns:
    The same location expressed as "/dbfs/<path>".
  """
  # Without this, "dbfs:/x" would become the bogus "/dbfs/dbfs:/x".
  if dbfs_path.startswith("dbfs:"):
    dbfs_path = dbfs_path[len("dbfs:"):]
  return os.path.join("/dbfs", dbfs_path.lstrip("/"))

In [20]:
# Download MNIST in LIBSVM format (bzip2-compressed) into the working
# directory, writing through the local /dbfs mount. check_output raises if
# wget exits with a non-zero status.
data_url = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2"
libsvm_path = os.path.join(work_dir, "mnist.bz2")
local_libsvm_path = get_local_path(libsvm_path)
wget_cmd = ["wget", data_url, "-O", local_libsvm_path]
subprocess.check_output(wget_cmd)

In [21]:
# Parse the LIBSVM file into (label, features) rows. Setting numFeatures fixes
# the sparse vector size at 784 instead of inferring it with an extra pass.
df = spark.read.load(libsvm_path, format="libsvm", numFeatures="784")

In [22]:
# Render the parsed LIBSVM DataFrame: a numeric `label` column (e.g. 5.0) and a
# sparse `features` vector of size 784 per row.
display(df)

label,features
5.0,"List(0, 784, List(152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 260, 261, 262, 263, 264, 265, 266, 268, 269, 289, 290, 291, 292, 293, 319, 320, 321, 322, 347, 348, 349, 350, 376, 377, 378, 379, 380, 381, 405, 406, 407, 408, 409, 410, 434, 435, 436, 437, 438, 439, 463, 464, 465, 466, 467, 493, 494, 495, 496, 518, 519, 520, 521, 522, 523, 524, 544, 545, 546, 547, 548, 549, 550, 551, 570, 571, 572, 573, 574, 575, 576, 577, 578, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 676, 677, 678, 679, 680, 681, 682, 683), List(3.0, 18.0, 18.0, 18.0, 126.0, 136.0, 175.0, 26.0, 166.0, 255.0, 247.0, 127.0, 30.0, 36.0, 94.0, 154.0, 170.0, 253.0, 253.0, 253.0, 253.0, 253.0, 225.0, 172.0, 253.0, 242.0, 195.0, 64.0, 49.0, 238.0, 253.0, 253.0, 253.0, 253.0, 253.0, 253.0, 253.0, 253.0, 251.0, 93.0, 82.0, 82.0, 56.0, 39.0, 18.0, 219.0, 253.0, 253.0, 253.0, 253.0, 253.0, 198.0, 182.0, 247.0, 241.0, 80.0, 156.0, 107.0, 253.0, 253.0, 205.0, 11.0, 43.0, 154.0, 14.0, 1.0, 154.0, 253.0, 90.0, 139.0, 253.0, 190.0, 2.0, 11.0, 190.0, 253.0, 70.0, 35.0, 241.0, 225.0, 160.0, 108.0, 1.0, 81.0, 240.0, 253.0, 253.0, 119.0, 25.0, 45.0, 186.0, 253.0, 253.0, 150.0, 27.0, 16.0, 93.0, 252.0, 253.0, 187.0, 249.0, 253.0, 249.0, 64.0, 46.0, 130.0, 183.0, 253.0, 253.0, 207.0, 2.0, 39.0, 148.0, 229.0, 253.0, 253.0, 253.0, 250.0, 182.0, 24.0, 114.0, 221.0, 253.0, 253.0, 253.0, 253.0, 201.0, 78.0, 23.0, 66.0, 213.0, 253.0, 253.0, 253.0, 253.0, 198.0, 81.0, 2.0, 18.0, 171.0, 219.0, 253.0, 253.0, 253.0, 253.0, 195.0, 80.0, 9.0, 55.0, 172.0, 226.0, 253.0, 253.0, 253.0, 253.0, 244.0, 133.0, 11.0, 136.0, 253.0, 253.0, 253.0, 212.0, 135.0, 132.0, 16.0))"
0.0,"List(0, 784, List(127, 128, 129, 130, 131, 154, 155, 156, 157, 158, 159, 181, 182, 183, 184, 185, 186, 187, 188, 189, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 289, 290, 291, 292, 293, 294, 295, 296, 297, 300, 301, 302, 316, 317, 318, 319, 320, 321, 328, 329, 330, 343, 344, 345, 346, 347, 348, 349, 356, 357, 358, 371, 372, 373, 374, 384, 385, 386, 399, 400, 401, 412, 413, 414, 426, 427, 428, 429, 440, 441, 442, 454, 455, 456, 457, 466, 467, 468, 469, 470, 482, 483, 484, 493, 494, 495, 496, 497, 510, 511, 512, 520, 521, 522, 523, 538, 539, 540, 547, 548, 549, 550, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 622, 623, 624, 625, 626, 627, 628, 629, 630, 651, 652, 653, 654, 655, 656, 657), List(51.0, 159.0, 253.0, 159.0, 50.0, 48.0, 238.0, 252.0, 252.0, 252.0, 237.0, 54.0, 227.0, 253.0, 252.0, 239.0, 233.0, 252.0, 57.0, 6.0, 10.0, 60.0, 224.0, 252.0, 253.0, 252.0, 202.0, 84.0, 252.0, 253.0, 122.0, 163.0, 252.0, 252.0, 252.0, 253.0, 252.0, 252.0, 96.0, 189.0, 253.0, 167.0, 51.0, 238.0, 253.0, 253.0, 190.0, 114.0, 253.0, 228.0, 47.0, 79.0, 255.0, 168.0, 48.0, 238.0, 252.0, 252.0, 179.0, 12.0, 75.0, 121.0, 21.0, 253.0, 243.0, 50.0, 38.0, 165.0, 253.0, 233.0, 208.0, 84.0, 253.0, 252.0, 165.0, 7.0, 178.0, 252.0, 240.0, 71.0, 19.0, 28.0, 253.0, 252.0, 195.0, 57.0, 252.0, 252.0, 63.0, 253.0, 252.0, 195.0, 198.0, 253.0, 190.0, 255.0, 253.0, 196.0, 76.0, 246.0, 252.0, 112.0, 253.0, 252.0, 148.0, 85.0, 252.0, 230.0, 25.0, 7.0, 135.0, 253.0, 186.0, 12.0, 85.0, 252.0, 223.0, 7.0, 131.0, 252.0, 225.0, 71.0, 85.0, 252.0, 145.0, 48.0, 165.0, 252.0, 173.0, 86.0, 253.0, 225.0, 114.0, 238.0, 253.0, 162.0, 85.0, 252.0, 249.0, 146.0, 48.0, 29.0, 85.0, 178.0, 225.0, 253.0, 223.0, 167.0, 56.0, 85.0, 252.0, 252.0, 252.0, 229.0, 215.0, 252.0, 252.0, 252.0, 196.0, 130.0, 28.0, 199.0, 
252.0, 252.0, 253.0, 252.0, 252.0, 233.0, 145.0, 25.0, 128.0, 252.0, 253.0, 252.0, 141.0, 37.0))"
4.0,"List(0, 784, List(160, 161, 162, 172, 173, 188, 189, 190, 200, 201, 215, 216, 217, 218, 228, 229, 243, 244, 245, 256, 257, 271, 272, 273, 283, 284, 285, 299, 300, 301, 311, 312, 313, 326, 327, 328, 329, 339, 340, 341, 354, 355, 356, 357, 367, 368, 369, 379, 380, 381, 382, 383, 384, 395, 396, 397, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 452, 453, 454, 455, 456, 457, 458, 459, 465, 466, 467, 493, 494, 495, 521, 522, 523, 549, 550, 551, 577, 578, 579, 605, 606, 607, 633, 634, 635, 661, 662, 663, 689, 690, 691), List(67.0, 232.0, 39.0, 62.0, 81.0, 120.0, 180.0, 39.0, 126.0, 163.0, 2.0, 153.0, 210.0, 40.0, 220.0, 163.0, 27.0, 254.0, 162.0, 222.0, 163.0, 183.0, 254.0, 125.0, 46.0, 245.0, 163.0, 198.0, 254.0, 56.0, 120.0, 254.0, 163.0, 23.0, 231.0, 254.0, 29.0, 159.0, 254.0, 120.0, 163.0, 254.0, 216.0, 16.0, 159.0, 254.0, 67.0, 14.0, 86.0, 178.0, 248.0, 254.0, 91.0, 159.0, 254.0, 85.0, 47.0, 49.0, 116.0, 144.0, 150.0, 241.0, 243.0, 234.0, 179.0, 241.0, 252.0, 40.0, 150.0, 253.0, 237.0, 207.0, 207.0, 207.0, 253.0, 254.0, 250.0, 240.0, 198.0, 143.0, 91.0, 28.0, 5.0, 233.0, 250.0, 119.0, 177.0, 177.0, 177.0, 177.0, 177.0, 98.0, 56.0, 102.0, 254.0, 220.0, 169.0, 254.0, 137.0, 169.0, 254.0, 57.0, 169.0, 254.0, 57.0, 169.0, 255.0, 94.0, 169.0, 254.0, 96.0, 169.0, 254.0, 153.0, 169.0, 255.0, 153.0, 96.0, 254.0, 153.0))"
1.0,"List(0, 784, List(158, 159, 160, 161, 185, 186, 187, 188, 189, 213, 214, 215, 216, 217, 240, 241, 242, 243, 244, 245, 267, 268, 269, 270, 271, 295, 296, 297, 298, 322, 323, 324, 325, 326, 349, 350, 351, 352, 353, 377, 378, 379, 380, 381, 404, 405, 406, 407, 408, 431, 432, 433, 434, 435, 459, 460, 461, 462, 463, 486, 487, 488, 489, 490, 514, 515, 516, 517, 518, 542, 543, 544, 545, 569, 570, 571, 572, 573, 596, 597, 598, 599, 600, 601, 624, 625, 626, 627, 652, 653, 654, 655, 680, 681, 682, 683), List(124.0, 253.0, 255.0, 63.0, 96.0, 244.0, 251.0, 253.0, 62.0, 127.0, 251.0, 251.0, 253.0, 62.0, 68.0, 236.0, 251.0, 211.0, 31.0, 8.0, 60.0, 228.0, 251.0, 251.0, 94.0, 155.0, 253.0, 253.0, 189.0, 20.0, 253.0, 251.0, 235.0, 66.0, 32.0, 205.0, 253.0, 251.0, 126.0, 104.0, 251.0, 253.0, 184.0, 15.0, 80.0, 240.0, 251.0, 193.0, 23.0, 32.0, 253.0, 253.0, 253.0, 159.0, 151.0, 251.0, 251.0, 251.0, 39.0, 48.0, 221.0, 251.0, 251.0, 172.0, 234.0, 251.0, 251.0, 196.0, 12.0, 253.0, 251.0, 251.0, 89.0, 159.0, 255.0, 253.0, 253.0, 31.0, 48.0, 228.0, 253.0, 247.0, 140.0, 8.0, 64.0, 251.0, 253.0, 220.0, 64.0, 251.0, 253.0, 220.0, 24.0, 193.0, 253.0, 220.0))"
9.0,"List(0, 784, List(208, 209, 210, 211, 212, 213, 214, 215, 216, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 289, 290, 291, 292, 293, 296, 297, 298, 299, 300, 316, 317, 318, 319, 320, 324, 325, 326, 327, 343, 344, 345, 346, 347, 350, 351, 352, 353, 354, 370, 371, 372, 373, 377, 378, 379, 380, 381, 382, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 489, 490, 491, 492, 517, 518, 519, 520, 546, 547, 548, 573, 574, 575, 576, 601, 602, 603, 604, 629, 630, 631, 632, 658, 659, 660, 686, 687, 688, 689, 714, 715, 716, 717, 718, 743, 744, 745, 746), List(55.0, 148.0, 210.0, 253.0, 253.0, 113.0, 87.0, 148.0, 55.0, 87.0, 232.0, 252.0, 253.0, 189.0, 210.0, 252.0, 252.0, 253.0, 168.0, 4.0, 57.0, 242.0, 252.0, 190.0, 65.0, 5.0, 12.0, 182.0, 252.0, 253.0, 116.0, 96.0, 252.0, 252.0, 183.0, 14.0, 92.0, 252.0, 252.0, 225.0, 21.0, 132.0, 253.0, 252.0, 146.0, 14.0, 215.0, 252.0, 252.0, 79.0, 126.0, 253.0, 247.0, 176.0, 9.0, 8.0, 78.0, 245.0, 253.0, 129.0, 16.0, 232.0, 252.0, 176.0, 36.0, 201.0, 252.0, 252.0, 169.0, 11.0, 22.0, 252.0, 252.0, 30.0, 22.0, 119.0, 197.0, 241.0, 253.0, 252.0, 251.0, 77.0, 16.0, 231.0, 252.0, 253.0, 252.0, 252.0, 252.0, 226.0, 227.0, 252.0, 231.0, 55.0, 235.0, 253.0, 217.0, 138.0, 42.0, 24.0, 192.0, 252.0, 143.0, 62.0, 255.0, 253.0, 109.0, 71.0, 253.0, 252.0, 21.0, 253.0, 252.0, 21.0, 71.0, 253.0, 252.0, 21.0, 106.0, 253.0, 252.0, 21.0, 45.0, 255.0, 253.0, 21.0, 218.0, 252.0, 56.0, 96.0, 252.0, 189.0, 42.0, 14.0, 184.0, 252.0, 170.0, 11.0, 14.0, 147.0, 252.0, 42.0))"
2.0,"List(0, 784, List(155, 156, 157, 158, 159, 181, 182, 183, 184, 185, 186, 187, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 261, 262, 263, 264, 265, 266, 267, 269, 270, 271, 272, 289, 290, 291, 292, 293, 294, 297, 298, 299, 300, 317, 318, 319, 320, 325, 326, 327, 328, 353, 354, 355, 356, 377, 378, 379, 380, 381, 382, 383, 384, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 455, 456, 457, 458, 459, 460, 462, 463, 464, 465, 466, 467, 468, 469, 482, 483, 484, 485, 486, 487, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 509, 510, 511, 512, 513, 514, 516, 517, 518, 519, 520, 522, 523, 524, 525, 526, 527, 528, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 553, 554, 555, 556, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 593, 594, 595, 596, 597, 598, 599, 600, 601, 621, 622, 623, 624, 625, 626), List(13.0, 25.0, 100.0, 122.0, 7.0, 33.0, 151.0, 208.0, 252.0, 252.0, 252.0, 146.0, 40.0, 152.0, 244.0, 252.0, 253.0, 224.0, 211.0, 252.0, 232.0, 40.0, 15.0, 152.0, 239.0, 252.0, 252.0, 252.0, 216.0, 31.0, 37.0, 252.0, 252.0, 60.0, 96.0, 252.0, 252.0, 252.0, 252.0, 217.0, 29.0, 37.0, 252.0, 252.0, 60.0, 181.0, 252.0, 252.0, 220.0, 167.0, 30.0, 77.0, 252.0, 252.0, 60.0, 26.0, 128.0, 58.0, 22.0, 100.0, 252.0, 252.0, 60.0, 157.0, 252.0, 252.0, 60.0, 110.0, 121.0, 122.0, 121.0, 202.0, 252.0, 194.0, 3.0, 10.0, 53.0, 179.0, 253.0, 253.0, 255.0, 253.0, 253.0, 228.0, 35.0, 5.0, 54.0, 227.0, 252.0, 243.0, 228.0, 170.0, 242.0, 252.0, 252.0, 231.0, 117.0, 6.0, 6.0, 78.0, 252.0, 252.0, 125.0, 59.0, 18.0, 208.0, 252.0, 252.0, 252.0, 252.0, 87.0, 7.0, 5.0, 135.0, 252.0, 252.0, 180.0, 16.0, 21.0, 203.0, 253.0, 247.0, 129.0, 173.0, 252.0, 252.0, 184.0, 66.0, 49.0, 49.0, 3.0, 136.0, 252.0, 241.0, 106.0, 17.0, 53.0, 200.0, 252.0, 216.0, 65.0, 14.0, 72.0, 163.0, 241.0, 252.0, 252.0, 223.0, 105.0, 252.0, 242.0, 88.0, 18.0, 73.0, 170.0, 
244.0, 252.0, 126.0, 29.0, 89.0, 180.0, 180.0, 37.0, 231.0, 252.0, 245.0, 205.0, 216.0, 252.0, 252.0, 252.0, 124.0, 3.0, 207.0, 252.0, 252.0, 252.0, 252.0, 178.0, 116.0, 36.0, 4.0, 13.0, 93.0, 143.0, 121.0, 23.0, 6.0))"
1.0,"List(0, 784, List(124, 125, 126, 127, 151, 152, 153, 154, 155, 179, 180, 181, 182, 183, 208, 209, 210, 211, 235, 236, 237, 238, 239, 263, 264, 265, 266, 267, 268, 292, 293, 294, 295, 296, 321, 322, 323, 324, 349, 350, 351, 352, 377, 378, 379, 380, 405, 406, 407, 408, 433, 434, 435, 436, 461, 462, 463, 464, 489, 490, 491, 492, 493, 517, 518, 519, 520, 521, 545, 546, 547, 548, 549, 574, 575, 576, 577, 578, 602, 603, 604, 605, 606, 630, 631, 632, 633, 634, 658, 659, 660, 661, 662), List(145.0, 255.0, 211.0, 31.0, 32.0, 237.0, 253.0, 252.0, 71.0, 11.0, 175.0, 253.0, 252.0, 71.0, 144.0, 253.0, 252.0, 71.0, 16.0, 191.0, 253.0, 252.0, 71.0, 26.0, 221.0, 253.0, 252.0, 124.0, 31.0, 125.0, 253.0, 252.0, 252.0, 108.0, 253.0, 252.0, 252.0, 108.0, 255.0, 253.0, 253.0, 108.0, 253.0, 252.0, 252.0, 108.0, 253.0, 252.0, 252.0, 108.0, 253.0, 252.0, 252.0, 108.0, 255.0, 253.0, 253.0, 170.0, 253.0, 252.0, 252.0, 252.0, 42.0, 149.0, 252.0, 252.0, 252.0, 144.0, 109.0, 252.0, 252.0, 252.0, 144.0, 218.0, 253.0, 253.0, 255.0, 35.0, 175.0, 252.0, 252.0, 253.0, 35.0, 73.0, 252.0, 252.0, 253.0, 35.0, 31.0, 211.0, 252.0, 253.0, 35.0))"
3.0,"List(0, 784, List(151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 261, 262, 263, 264, 269, 270, 271, 272, 273, 274, 297, 298, 299, 300, 301, 324, 325, 326, 327, 328, 329, 350, 351, 352, 353, 354, 355, 356, 357, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 492, 493, 494, 495, 520, 521, 522, 523, 538, 539, 540, 547, 548, 549, 550, 551, 565, 566, 567, 568, 573, 574, 575, 576, 577, 578, 579, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 678, 679, 680, 681, 682, 683, 684), List(38.0, 43.0, 105.0, 255.0, 253.0, 253.0, 253.0, 253.0, 253.0, 174.0, 6.0, 43.0, 139.0, 224.0, 226.0, 252.0, 253.0, 252.0, 252.0, 252.0, 252.0, 252.0, 252.0, 158.0, 14.0, 178.0, 252.0, 252.0, 252.0, 252.0, 253.0, 252.0, 252.0, 252.0, 252.0, 252.0, 252.0, 252.0, 59.0, 109.0, 252.0, 252.0, 230.0, 132.0, 133.0, 132.0, 132.0, 189.0, 252.0, 252.0, 252.0, 252.0, 59.0, 4.0, 29.0, 29.0, 24.0, 14.0, 226.0, 252.0, 252.0, 172.0, 7.0, 85.0, 243.0, 252.0, 252.0, 144.0, 88.0, 189.0, 252.0, 252.0, 252.0, 14.0, 91.0, 212.0, 247.0, 252.0, 252.0, 252.0, 204.0, 9.0, 32.0, 125.0, 193.0, 193.0, 193.0, 253.0, 252.0, 252.0, 252.0, 238.0, 102.0, 28.0, 45.0, 222.0, 252.0, 252.0, 252.0, 252.0, 253.0, 252.0, 252.0, 252.0, 177.0, 45.0, 223.0, 253.0, 253.0, 253.0, 253.0, 255.0, 253.0, 253.0, 253.0, 253.0, 74.0, 31.0, 123.0, 52.0, 44.0, 44.0, 44.0, 44.0, 143.0, 252.0, 252.0, 74.0, 15.0, 252.0, 252.0, 74.0, 86.0, 252.0, 252.0, 74.0, 5.0, 75.0, 9.0, 98.0, 242.0, 252.0, 252.0, 74.0, 61.0, 
183.0, 252.0, 29.0, 18.0, 92.0, 239.0, 252.0, 252.0, 243.0, 65.0, 208.0, 252.0, 252.0, 147.0, 134.0, 134.0, 134.0, 134.0, 203.0, 253.0, 252.0, 252.0, 188.0, 83.0, 208.0, 252.0, 252.0, 252.0, 252.0, 252.0, 252.0, 252.0, 252.0, 253.0, 230.0, 153.0, 8.0, 49.0, 157.0, 252.0, 252.0, 252.0, 252.0, 252.0, 217.0, 207.0, 146.0, 45.0, 7.0, 103.0, 235.0, 252.0, 172.0, 103.0, 24.0))"
1.0,"List(0, 784, List(152, 153, 154, 180, 181, 182, 183, 208, 209, 210, 211, 236, 237, 238, 239, 264, 265, 266, 267, 292, 293, 294, 295, 320, 321, 322, 323, 349, 350, 351, 377, 378, 379, 405, 406, 407, 433, 434, 435, 461, 462, 463, 489, 490, 491, 492, 517, 518, 519, 520, 546, 547, 548, 574, 575, 576, 602, 603, 604, 630, 631, 632, 658, 659, 660, 686, 687, 688), List(5.0, 63.0, 197.0, 20.0, 254.0, 230.0, 24.0, 20.0, 254.0, 254.0, 48.0, 20.0, 254.0, 255.0, 48.0, 20.0, 254.0, 254.0, 57.0, 20.0, 254.0, 254.0, 108.0, 16.0, 239.0, 254.0, 143.0, 178.0, 254.0, 143.0, 178.0, 254.0, 143.0, 178.0, 254.0, 162.0, 178.0, 254.0, 240.0, 113.0, 254.0, 240.0, 83.0, 254.0, 245.0, 31.0, 79.0, 254.0, 246.0, 38.0, 214.0, 254.0, 150.0, 144.0, 241.0, 8.0, 144.0, 240.0, 2.0, 144.0, 254.0, 82.0, 230.0, 247.0, 40.0, 168.0, 209.0, 31.0))"
4.0,"List(0, 784, List(134, 135, 161, 162, 163, 188, 189, 190, 191, 216, 217, 218, 236, 237, 238, 243, 244, 245, 246, 264, 265, 266, 271, 272, 273, 292, 293, 294, 298, 299, 300, 301, 319, 320, 321, 322, 325, 326, 327, 328, 329, 346, 347, 348, 349, 353, 354, 355, 373, 374, 375, 376, 380, 381, 382, 383, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 482, 483, 484, 488, 489, 490, 491, 492, 493, 494, 510, 511, 516, 517, 518, 519, 520, 521, 522, 543, 544, 545, 546, 571, 572, 573, 574, 598, 599, 600, 601, 626, 627, 628, 654, 655, 656), List(189.0, 190.0, 143.0, 247.0, 153.0, 136.0, 247.0, 242.0, 86.0, 192.0, 252.0, 187.0, 62.0, 185.0, 18.0, 89.0, 236.0, 217.0, 47.0, 216.0, 253.0, 60.0, 212.0, 255.0, 81.0, 206.0, 252.0, 68.0, 48.0, 242.0, 253.0, 89.0, 131.0, 251.0, 212.0, 21.0, 11.0, 167.0, 252.0, 197.0, 5.0, 29.0, 232.0, 247.0, 63.0, 153.0, 252.0, 226.0, 45.0, 219.0, 252.0, 143.0, 116.0, 249.0, 252.0, 103.0, 4.0, 96.0, 253.0, 255.0, 253.0, 200.0, 122.0, 7.0, 25.0, 201.0, 250.0, 158.0, 92.0, 252.0, 252.0, 253.0, 217.0, 252.0, 252.0, 200.0, 227.0, 252.0, 231.0, 87.0, 251.0, 247.0, 231.0, 65.0, 48.0, 189.0, 252.0, 252.0, 253.0, 252.0, 251.0, 227.0, 35.0, 190.0, 221.0, 98.0, 42.0, 196.0, 252.0, 253.0, 252.0, 252.0, 162.0, 111.0, 29.0, 62.0, 239.0, 252.0, 86.0, 42.0, 42.0, 14.0, 15.0, 148.0, 253.0, 218.0, 121.0, 252.0, 231.0, 28.0, 31.0, 221.0, 251.0, 129.0, 218.0, 252.0, 160.0, 122.0, 252.0, 82.0))"


In [23]:
%scala

import org.apache.spark.ml.linalg.Vector

// Convert an ML Vector (sparse or dense) into a plain Array[Double]. It is
// registered under the SQL name "toArray" so the Python cells can call it
// through selectExpr.
val toArray = udf((vec: Vector) => vec.toArray)
spark.udf.register("toArray", toArray)

In [24]:
# Densify the feature vectors (via the "toArray" UDF) and write the data as
# Parquet. Petastorm samples Parquet row groups into batches, so the row-group
# size set through parquet.block.size drives the batch size, which matters for
# both I/O and compute utilization.
parquet_path = os.path.join(work_dir, "parquet")
row_group_bytes = 1024 * 1024  # 1 MiB row groups
(
  df.selectExpr("toArray(features) AS features", "int(label) AS label")
    .repartition(10)
    .write.mode("overwrite")
    .option("parquet.block.size", row_group_bytes)
    .parquet(parquet_path)
)

In [25]:
## Load data using Petastorm and feed data into a DL framework

# Use Petastorm to load the Parquet data as a tf.data.Dataset, then fit a simple convolutional neural network with tf.keras.

In [26]:
# Install petastorm using the `pip` found on the driver's shell PATH.
# NOTE(review): the version is unpinned; pin it for reproducible runs. Databricks
# documents that `%pip` resets the notebook's Python state, so switching to it
# belongs in a cell at the top of the notebook rather than here — confirm for
# the runtime in use.
!pip install petastorm

In [27]:
!conda install -c conda-forge pyarrow

#!pip install pyarrow

In [28]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import models, layers

from petastorm import make_batch_reader
from petastorm.tf_utils import make_petastorm_dataset


In [29]:
def get_model(input_shape=(28, 28, 1), num_classes=10):
  """Build (but do not compile) a small convolutional classifier.

  Layers: Conv2D(32, 3x3, ReLU) -> Conv2D(64, 3x3, ReLU) -> MaxPooling2D(2x2)
  -> Dropout(0.25) -> Flatten -> Dense(128, ReLU) -> Dropout(0.5)
  -> Dense(num_classes, softmax).

  Args:
    input_shape: shape of one example, excluding the batch dimension. The
      default (28, 28, 1) matches the training cell's reshape to [-1, 28, 28, 1].
    num_classes: width of the softmax output layer. The default of 10 matches
      the training cell's tf.one_hot(..., 10) labels.

  Returns:
    An uncompiled `models.Sequential` model.
  """
  model = models.Sequential()
  model.add(layers.Conv2D(32, kernel_size=(3, 3),
                          activation='relu',
                          input_shape=input_shape))
  model.add(layers.Conv2D(64, (3, 3), activation='relu'))
  model.add(layers.MaxPooling2D(pool_size=(2, 2)))
  model.add(layers.Dropout(0.25))
  model.add(layers.Flatten())
  model.add(layers.Dense(128, activation='relu'))
  model.add(layers.Dropout(0.5))
  model.add(layers.Dense(num_classes, activation='softmax'))
  return model

In [30]:
import pyarrow.parquet as pq

# Spark leaves non-data files whose names start with "_" (presumably commit
# markers such as _SUCCESS) next to the Parquet part files. Add them to
# pyarrow's exclusion set so they are not treated as Parquet data.
# NOTE(review): EXCLUDED_PARQUET_PATHS comes from pyarrow's legacy dataset
# code; confirm it exists in the installed pyarrow version.
parquet_dir_entries = os.listdir(get_local_path(parquet_path))
underscore_files = list(filter(lambda entry: entry.startswith("_"), parquet_dir_entries))
pq.EXCLUDED_PARQUET_PATHS.update(underscore_files)


In [31]:
# We use make_batch_reader to load Parquet row groups into batches, so the batch
# size follows the row-group size chosen when the Parquet data was written.
# HINT: Use cur_shard and shard_count params to shard data in distributed training.
petastorm_dataset_url = "file://" + get_local_path(parquet_path)


def _to_model_inputs(batch):
  """Turn one Petastorm batch into (images, one-hot labels) for Keras.

  `batch.features` holds flat rows of 784 pixel values in [0, 255], and
  `batch.label` holds integer class ids. Pixels are cast to float32, scaled to
  [0, 1] (unscaled 0-255 inputs slow down and destabilize training), and
  reshaped to (batch, 28, 28, 1).
  """
  images = tf.cast(batch.features, tf.float32) / 255.0
  images = tf.reshape(images, [-1, 28, 28, 1])
  labels = tf.one_hot(batch.label, 10)
  return images, labels


with make_batch_reader(petastorm_dataset_url, num_epochs=100) as reader:
  dataset = make_petastorm_dataset(reader).map(_to_model_inputs)
  model = get_model()
  # NOTE(review): tf.keras's Adadelta defaults to learning_rate=0.001, and the
  # Keras docs recommend larger values (1.0 matches the original paper).
  # Tune this if training stalls.
  optimizer = keras.optimizers.Adadelta()
  model.compile(optimizer=optimizer,
                loss='categorical_crossentropy',
                metrics=['accuracy'])
  model.fit(dataset, steps_per_epoch=10, epochs=10)

In [32]:
# Remove the scratch directory recursively, including the downloaded LIBSVM
# file and the Parquet output written under it.
dbutils.fs.rm(work_dir, True)