<img style="float: left" src="images/spark.png" />
<img style="float: right" src="images/surfsara.png" />
<hr style="clear: both" />

# Counting kmers with Apache Spark

_You can edit the cells below and execute the code by selecting the cell and pressing Shift-Enter. Code completion is supported by use of the Tab key._

During the exercises you may want to refer to the [PySpark documentation](https://spark.apache.org/docs/1.6.1/api/python/pyspark.html#pyspark.RDD) for more information on possible transformations and actions.

In [None]:
# initialize Spark
from pyspark import SparkContext, SparkConf

if 'sc' not in globals(): # This Python 'trick' makes sure the SparkContext sc is initialized exactly once
    conf = SparkConf().setMaster('local[*]')
    sc = SparkContext(conf=conf)

## RDD from a FASTA file

Read a FASTA file from disk. `textFile` creates an RDD with one record per line, so every header line and every sequence line becomes a separate record.

In [None]:
reads = sc.textFile("file:////home/jovyan/work/data/blast/input/CAM_SMPL_GS108.fa")

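These records are not what we want: a FASTA header and the sequence that belongs to it end up in different records. There are two options: write a custom InputFormat, or do some data munging in Spark.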
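To make the problem concrete, here is a plain-Python sketch (no Spark involved) of the grouping a record-aware reader would have to perform: attach every `>` header to the sequence lines that follow it. The helper `parse_fasta` and the sample lines are hypothetical and only illustrate the idea; they are not part of PySpark.

```python
def parse_fasta(lines):
    """Group FASTA lines into (header, sequence) tuples (hypothetical helper).

    Sequence lines that follow a header are concatenated, so records whose
    sequence is wrapped over several lines are handled as well. Sequence
    lines that appear before the first header are ignored.
    """
    records = []
    header, chunks = None, []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        if line.startswith(">"):
            if header is not None:
                records.append((header, "".join(chunks)))
            header, chunks = line[1:], []
        else:
            chunks.append(line)
    if header is not None:
        records.append((header, "".join(chunks)))
    return records


# Made-up input: read2 has its sequence wrapped over two lines
lines = [">read1", "GATTACA", ">read2", "CCGG", "TTAA"]
print(parse_fasta(lines))  # [('read1', 'GATTACA'), ('read2', 'CCGGTTAA')]
```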

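In this case we are not interested in the metadata. We select only the sequence data by adding an index number to every record and keeping only the odd-numbered records. This works because every sequence in this file fits on a single line, so header lines (even indices) and sequence lines (odd indices) alternate.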
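The index trick can be mimicked in plain Python on a small, made-up list of lines. This is only a local sketch of what `zipWithIndex`, `filter` and `keys` do; note that `zipWithIndex` yields `(element, index)` pairs, which is why the solution uses `keys()` to drop the index again.

```python
# Made-up single-line FASTA content: headers and sequences alternate
lines = [">read1 meta", "GATTACA", ">read2 meta", "CCGGTTAA"]

# Like reads.zipWithIndex(): pair each element with its position as (element, index)
indexed = [(line, i) for i, line in enumerate(lines)]

# Like .filter(lambda x: x[1] % 2 == 1).keys(): keep odd positions, drop the index
sequences = [line for line, i in indexed if i % 2 == 1]
print(sequences)  # ['GATTACA', 'CCGGTTAA']
```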

In [None]:
import pprint
pp = pprint.PrettyPrinter(indent=2)

indexedReads = reads.zipWithIndex()
print(indexedReads.take(2))

In [None]:
### BEGIN SOLUTION
sequences = indexedReads.filter(lambda x: x[1] % 2 == 1).keys()
### END SOLUTION
sequences.cache()
sequences.take(2)

## Extending wordcount: basecount

Alright, time to do some programming of our own: extend the wordcount example so that it counts letters, or bases, in a sequence.

Just as we split lines into words, we now split sequences (strings) into bases (characters).

Hint: the easiest way in Python to split a string `s` into characters is: `list(s)`.
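The reason to use `flatMap` rather than `map` here can be seen with plain Python lists (a local illustration, not Spark code): `map` produces one output element per sequence, whereas `flatMap` concatenates the per-sequence lists into a single flat collection of bases.

```python
from itertools import chain

sequences = ["GATC", "AAG"]

# map: one output element per input element, giving a list of lists
mapped = [list(s) for s in sequences]

# flatMap: the per-element lists are concatenated into one flat collection
flat_mapped = list(chain.from_iterable(list(s) for s in sequences))

print(mapped)       # [['G', 'A', 'T', 'C'], ['A', 'A', 'G']]
print(flat_mapped)  # ['G', 'A', 'T', 'C', 'A', 'A', 'G']
```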

In [None]:
### BEGIN SOLUTION
bases = sequences.flatMap(lambda s: list(s))
### END SOLUTION 

pp.pprint(bases.take(5))

Make key-value pairs `(base, 1)` and sum the values per base using `reduceByKey`.
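As a mental model, the semantics of `reduceByKey` can be sketched locally: group all values that share a key, then fold each group with the given function. The helper `reduce_by_key` below is hypothetical and single-machine only; Spark additionally combines values within each partition before shuffling, which is why the function passed to `reduceByKey` should be associative and commutative (as `+` is).

```python
from functools import reduce

bases = ["G", "A", "T", "C", "A", "A", "G"]

# Like bases.map(lambda b: (b, 1))
pairs = [(b, 1) for b in bases]


def reduce_by_key(func, pairs):
    """Hypothetical local stand-in for RDD.reduceByKey."""
    grouped = {}
    for key, value in pairs:
        grouped.setdefault(key, []).append(value)
    # Fold the values of each key with func
    return {key: reduce(func, values) for key, values in grouped.items()}


counts = reduce_by_key(lambda a, b: a + b, pairs)
print(sorted(counts.items()))  # [('A', 3), ('C', 1), ('G', 2), ('T', 1)]
```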

In [None]:
### BEGIN SOLUTION
basecounts = bases.map(lambda b: (b, 1)).reduceByKey(lambda a,b: a + b)
### END SOLUTION

Since the number of records (one per distinct base) is now very small (5), it is safe to call `collect` on the RDD and print the results.

In [None]:
pp.pprint(basecounts.collect())

## Extending basecount: kmercount

Let's make it a bit more interesting.

Select all substrings of length k, the so-called k-mers. We have written a helper function for you.

In [None]:
def sliding(seq, size):
    result = []
    for i in range(0, len(seq) - size + 1):
        result.append(seq[i:i + size])
    return result

sliding("GAGATCTCCTGTGGTGTCCTTGGTCATAGTGATTTGCTCCTACAA", 5)
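The same helper can be written as a single list comprehension; `sliding_compact` is a hypothetical name chosen so it does not shadow `sliding`. The sketch also checks a useful property: a sequence of length L yields L - k + 1 k-mers, and a sequence shorter than k yields none, because the `range` is then empty.

```python
def sliding_compact(seq, size):
    """List-comprehension equivalent of the sliding helper (hypothetical name)."""
    return [seq[i:i + size] for i in range(len(seq) - size + 1)]


seq = "GAGATCTCCTGTGGTGTCCTTGGTCATAGTGATTTGCTCCTACAA"
kmers = sliding_compact(seq, 5)

# A sequence of length L yields L - k + 1 k-mers
assert len(kmers) == len(seq) - 5 + 1
print(len(seq), len(kmers))

# Sequences shorter than k contribute no k-mers at all
print(sliding_compact("ACG", 5))  # []
```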

Create an RDD with all the subsequences of length 21.

In [None]:
### BEGIN SOLUTION
kmers = sequences.flatMap(lambda s: sliding(s, 21))
### END SOLUTION

In [None]:
### BEGIN SOLUTION
kmercounts = kmers.map(lambda x: (x, 1)).reduceByKey(lambda x,y: x + y)
### END SOLUTION
kmercounts.cache()

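Use `takeOrdered` to get the 10 most frequent 21-mers. `takeOrdered` returns the elements with the smallest keys, so the key function should negate the count.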

In [None]:
### BEGIN SOLUTION
top10 = kmercounts.takeOrdered(10, key=lambda x: -x[1])
### END SOLUTION
pp.pprint(top10)
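Locally, `takeOrdered(n, key=lambda x: -x[1])` behaves like `heapq.nlargest(n, ..., key=lambda x: x[1])`: both return the n pairs with the highest counts. The sketch below uses made-up counts. In PySpark, `kmercounts.top(10, key=lambda x: x[1])` is an alternative way to get the same descending result.

```python
import heapq

# Made-up (k-mer, count) pairs standing in for kmercounts
kmercounts = [("AAAAA", 7), ("CCCCC", 2), ("GATTA", 9), ("TTTTT", 4)]

# Smallest negated count first, i.e. largest count first
by_negated_key = sorted(kmercounts, key=lambda x: -x[1])[:2]

# Same result without sorting everything
top2 = heapq.nlargest(2, kmercounts, key=lambda x: x[1])
print(top2)  # [('GATTA', 9), ('AAAAA', 7)]
```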

Finally, we look at the distribution of 21-mers: the x-axis shows the number of occurrences of a 21-mer, and the y-axis shows how many distinct 21-mers occur exactly that many times.

In [None]:
### BEGIN SOLUTION
kmerdist = kmercounts.map(lambda x: (x[1], 1)).reduceByKey(lambda x, y: x + y)
### END SOLUTION 
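The computation above is a histogram of the counts. A local sketch with `collections.Counter` and made-up data produces the same `(occurrences, number of k-mers)` shape as `kmerdist`. In Spark, `kmercounts.values().countByValue()` would compute the same histogram, but it returns a Python dict to the driver, which is fine only because the number of distinct counts is small.

```python
from collections import Counter

# Made-up (k-mer, occurrences) pairs standing in for kmercounts
kmercounts = [("AAA", 1), ("CCC", 1), ("GGG", 2), ("TTT", 5), ("ACG", 2)]

# occurrences -> number of k-mers that occur exactly that many times
kmerdist = Counter(count for _, count in kmercounts)
print(sorted(kmerdist.items()))  # [(1, 2), (2, 2), (5, 1)]
```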

In [None]:
# kmerdist already holds (occurrences, number of 21-mers) pairs, so sort by occurrences directly
kmerdistsorted = kmerdist.sortByKey(True).cache()

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt

# Collect the (occurrences, number of 21-mers) pairs once so x and y stay aligned
x, y = zip(*kmerdistsorted.collect())

plt.plot(x, y, label="21-mers")
plt.yscale('log')
plt.title("21-mer distribution")
plt.xlabel("occurrences of a 21-mer")
plt.ylabel("number of distinct 21-mers")
plt.legend()
plt.show()