<img style="float: left" src="images/spark.png" />
<img style="float: right" src="images/surfsara.png" />
<hr style="clear: both" />

# Counting kmers with Apache Spark

In our second notebook we will look at genomics data and count the number of occurrences of specific subsequences in the reads.

_You can edit the cells below and execute the code by selecting a cell and pressing Shift-Enter. Code completion is supported with the Tab key._

During the exercises you may want to refer to the [PySpark documentation](https://spark.apache.org/docs/1.6.1/api/python/pyspark.html#pyspark.RDD) for more information on possible transformations and actions.

In [None]:
# initialize Spark
from pyspark import SparkContext, SparkConf

if 'sc' not in globals():  # Create the SparkContext only once; re-running this cell must not start a second one
    conf = SparkConf().setMaster('local[*]')
    sc = SparkContext(conf=conf)

## RDD from a FASTA file

We will start by creating an RDD from a FASTA file. We will read this using the `sc.textFile` method we also used in the first notebook.

In [None]:
import pprint
pp = pprint.PrettyPrinter(indent=2)

reads = sc.textFile("file:///home/jovyan/work/data/blast/input/CAM_SMPL_GS108.fa")

pp.pprint(reads.take(3))

We can read the FASTA file, but it is split into records incorrectly. `textFile` creates one record per line, so the metadata (header) line and the sequence line of a single read end up in two separate records. The clean solution would be to write a custom Hadoop InputFormat in Java or Scala. To keep this tutorial simple, we will instead do some text munging in Spark to get the same effect.
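
For comparison, the following plain-Python sketch (not Spark) shows what a record-aware reader has to do: group each `>` header line with the sequence line(s) that follow it. The helper name `parse_fasta` is hypothetical. The sketch assumes the usual FASTA convention that every record starts with a line beginning with `>`, and the toy input lines are made up.

```python
def parse_fasta(lines):
    """Group FASTA lines into (header, sequence) records.

    Hypothetical helper for illustration: assumes every record starts with a
    line beginning with '>' and that sequence data may span several lines.
    """
    header, chunks = None, []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        if line.startswith('>'):
            if header is not None:
                yield header, ''.join(chunks)  # emit the previous record
            header, chunks = line[1:], []
        else:
            chunks.append(line)  # sequence data belonging to the current header
    if header is not None:
        yield header, ''.join(chunks)  # emit the last record

toy_lines = ['>read_1 some metadata', 'ATTTACAATA', '>read_2', 'GAGA', 'TCTC']
print(list(parse_fasta(toy_lines)))
# [('read_1 some metadata', 'ATTTACAATA'), ('read_2', 'GAGATCTC')]
```

This kind of grouping depends on state carried from one line to the next. That is why it does not map directly onto Spark's independent per-line records, and why the notebook uses a simpler index-based trick instead.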

## Remove metadata

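In this case we are not interested in the metadata. We select only the sequence data by adding an index number to every record with `zipWithIndex` and then using a [`filter`](https://spark.apache.org/docs/1.6.1/api/python/pyspark.html#pyspark.RDD.filter) to keep only the odd-numbered records. This works because header lines and sequence lines alternate. Counting from 0, the headers have even indices and the sequences have odd ones. This assumes that every sequence fits on a single line.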
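
The index trick can be checked locally. The plain-Python sketch below mimics `zipWithIndex` followed by the odd-index `filter` and `keys()` on a small in-memory list. It assumes one header line and one sequence line per read, and the toy lines are made up.

```python
# Toy FASTA lines (made up): header, sequence, header, sequence, ...
toy_lines = ['>read_1', 'ATTTACAATA', '>read_2', 'GAGATCTCCT']

# Like RDD.zipWithIndex: pair each element with its position -> (line, index)
indexed = [(line, i) for i, line in enumerate(toy_lines)]

# Keep the odd-indexed records (the sequence lines) and drop the index,
# mirroring filter(lambda x: x[1] % 2 == 1) followed by keys()
local_sequences = [line for line, i in indexed if i % 2 == 1]

print(local_sequences)  # ['ATTTACAATA', 'GAGATCTCCT']
```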

In [None]:
indexed_reads = reads.zipWithIndex()

pp.pprint(indexed_reads.take(2))

In [None]:
# Create an RDD 'indexed_sequences' from 'indexed_reads' that only contains the records with an odd index number.

### BEGIN SOLUTION
indexed_sequences = indexed_reads.filter(lambda x: x[1] % 2 == 1)
### END SOLUTION

sequences = indexed_sequences.keys().cache()

pp.pprint(sequences.take(2))

## Examining the data

Before we get started, let's have a quick look at the data. How many sequences are there in total in this file?

In [None]:
### BEGIN SOLUTION
num_seq = sequences.count()
### END SOLUTION

print(num_seq)

What are the lengths of the shortest and the longest sequences in this file? [`RDD`](https://spark.apache.org/docs/1.6.1/api/python/pyspark.html#pyspark.RDD) has two useful methods for this.

Extra: Also determine the average length.

In [None]:
seq_lengths = sequences.map(len).cache()

### BEGIN SOLUTION
shortest = seq_lengths.min()
longest = seq_lengths.max()
### END SOLUTION

print('The shortest sequence has length: ' + str(shortest))
print('The longest sequence has length: ' + str(longest))
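
Calling `min()` and `max()` separately runs two Spark jobs over `seq_lengths`. For the extra question, PySpark's `RDD.mean()` returns the average. `RDD.stats()` returns a `StatCounter` that holds the count, mean, minimum and maximum, all computed in a single pass. The single-pass idea is sketched below in plain Python. `length_stats` is a hypothetical helper, not part of PySpark, and the toy lengths are made up.

```python
def length_stats(lengths):
    """Return (count, min, max, mean) of an iterable of lengths in one pass.

    Hypothetical helper for illustration; assumes at least one element.
    """
    count, total = 0, 0
    lo, hi = float('inf'), float('-inf')
    for n in lengths:
        count += 1
        total += n
        lo = min(lo, n)
        hi = max(hi, n)
    return count, lo, hi, float(total) / count  # float() keeps the mean exact on Python 2 too

print(length_stats([5, 12, 8, 15]))  # (4, 5, 15, 10.0)
```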

## Extending WordCount: BaseCount

Now that we have our input data in the format we want, we can start to do something with it. Our first exercise is a variation on the WordCount we did in the first notebook. This time we will not count words, but the individual bases in the sequences.

Hint: the easiest way in Python to split a string `s` into characters is: `list(s)`.

In [None]:
# Create an RDD 'bases' from 'sequences' containing all the individual bases (letters).

### BEGIN SOLUTION
bases = sequences.flatMap(lambda s: list(s))
### END SOLUTION 

pp.pprint(bases.take(5))

In [None]:
from test_helper import Test

Test.assertEquals(bases.take(5), [u'A', u'T', u'T', u'T', u'A'], 'incorrect value for bases')

Now that we have all the bases, all we need to do is count them by creating (key, 1) tuples and summing these per key. We can chain these two operations on a single line.
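
The sketch below shows in plain Python what `map(lambda b: (b, 1))` followed by `reduceByKey(lambda a, b: a + b)` computes. It emits a `(base, 1)` pair per base, then folds the values of each key together with the same addition function. `reduce_by_key` is a hypothetical local stand-in. Spark additionally combines values within each partition before shuffling. The toy sequence is made up.

```python
def reduce_by_key(pairs, func):
    """Local stand-in for RDD.reduceByKey: combine the values of each key with func."""
    result = {}
    for key, value in pairs:
        result[key] = func(result[key], value) if key in result else value
    return result

toy_bases = list("ATTTACAATA")          # like flatMap(lambda s: list(s))
pairs = [(b, 1) for b in toy_bases]     # like map(lambda b: (b, 1))
toy_counts = reduce_by_key(pairs, lambda a, b: a + b)

print(sorted(toy_counts.items()))  # [('A', 5), ('C', 1), ('T', 4)]
```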

In [None]:
# Create an RDD 'basecounts' from 'bases' with (base, #occurrences) pairs.

### BEGIN SOLUTION
basecounts = bases.map(lambda b: (b, 1)).reduceByKey(lambda a,b: a + b)
### END SOLUTION

We know this RDD contains very few records: 5, one for each of the bases A, C, G and T, plus N, which marks an unknown base. It is therefore safe to call `collect` on it and print the results.

In [None]:
pp.pprint(basecounts.collect())

In [None]:
Test.assertEquals(sorted(basecounts.collect()),
                  [(u'A', 189420125), (u'C', 89225753), (u'G', 90183181), (u'N', 134734), (u'T', 187685193)],
                  'incorrect value for basecounts')

## Extending BaseCount: KmerCount

We have counted all the 1-mers. Let's extend this to the general case of k-mers. For this we need to generate all overlapping substrings of length k. We can no longer just split the sequence, but need to use a sliding window. We have already written a helper function for you that does this. 

In [None]:
def sliding(seq, size):
    result = []
    for i in range(0, len(seq) - size + 1):
        result.append(seq[i:i + size])
    return result

print(sliding("GAGATCTCCTGTGGTGTCCTTGGTCATAGTGATTTGCTCCTACAA", 5))
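
A sequence of length L yields L - k + 1 overlapping k-mers, or none if L < k. `flatMap` accepts any iterable, so the helper could also be written as a generator. This avoids building a full list per read. The sketch below uses the hypothetical name `sliding_gen`.

```python
def sliding_gen(seq, size):
    """Yield every overlapping substring of length `size` (a lazy sliding window)."""
    for i in range(len(seq) - size + 1):
        yield seq[i:i + size]

seq = "GAGATCTCCTGTGGTGTCCTTGGTCATAGTGATTTGCTCCTACAA"
local_kmers = list(sliding_gen(seq, 5))
print(len(seq), len(local_kmers))  # 45 41
print(local_kmers[:3])             # ['GAGAT', 'AGATC', 'GATCT']
```

In the notebook it would be used as `sequences.flatMap(lambda s: sliding_gen(s, 10))`, just like the list-based `sliding`.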

Create an RDD with all the subsequences of length 10.

In [None]:
# Create an RDD 'kmers' from 'sequences' with all the 10-mers

### BEGIN SOLUTION
kmers = sequences.flatMap(lambda s: sliding(s, 10))
### END SOLUTION

print(kmers.take(5))

In [None]:
Test.assertEquals(kmers.take(5),
                  [u'ATTTACAATA', u'TTTACAATAA', u'TTACAATAAT', u'TACAATAATT', u'ACAATAATTT'],
                  'incorrect value for kmers')

Count the occurrences of each unique 10-mer in the RDD. Looks familiar?

In [None]:
# Create an RDD 'kmercounts' from 'kmers' with (10-mer, #occurrences) pairs.

### BEGIN SOLUTION
kmercounts = kmers.map(lambda x: (x, 1)).reduceByKey(lambda x,y: x + y)
### END SOLUTION

kmercounts.cache()

Use `takeOrdered` to get the 10 most frequent 10-mers. This may take some time (~ 2 minutes).
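
`takeOrdered(n, key=...)` returns the n smallest elements according to `key`. Negating the count therefore turns it into "the n largest counts". The same selection can be reproduced locally with Python's `heapq` module, as sketched below. The toy counts are made up for illustration.

```python
import heapq

# Toy (kmer, count) pairs, made up for illustration
toy_counts = [('AAAAATAAAA', 7), ('TTTTATTTTT', 3), ('ACGTACGTAC', 9), ('GGGGCCCCAA', 1)]

# Local equivalent of kmercounts.takeOrdered(2, key=lambda x: -x[1])
top_neg = heapq.nsmallest(2, toy_counts, key=lambda x: -x[1])
# The same selection expressed directly as "largest counts"
top_max = heapq.nlargest(2, toy_counts, key=lambda x: x[1])

print(top_neg)  # [('ACGTACGTAC', 9), ('AAAAATAAAA', 7)]
```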

In [None]:
### BEGIN SOLUTION
top10_kmers = kmercounts.takeOrdered(10, key=lambda x: -x[1])
### END SOLUTION

print(top10_kmers)

In [None]:
Test.assertEquals(top10_kmers,
                  [(u'AAAAATAAAA', 53752), (u'AAAAAAATAA', 50949), (u'AAAAAATAAA', 50909), (u'AAAAAATTAA', 48357),
                   (u'AAAAAAATTA', 48239), (u'TTTTATTTTT', 47295), (u'AAAAATTAAA', 46774), (u'AAAATAAAAA', 45347),
                   (u'AAAATTAAAA', 44951), (u'TTTTTAAAAA', 43321)], 
                  'incorrect value for top10_kmers')

Finally, we look at the distribution of 10-mers. We want to know how many unique 10-mers occur only once, how many twice, etc. We will plot the results using Python's matplotlib.
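
The distribution is a "count of counts": for each number of occurrences m, it gives how many distinct 10-mers occur exactly m times. The plain-Python sketch below applies `collections.Counter` to made-up toy counts.

```python
from collections import Counter

# Toy (kmer, count) pairs, made up for illustration
toy_kmercounts = [('AAAA', 1), ('CCCC', 1), ('GGGG', 2), ('TTTT', 5), ('ACGT', 1)]

# Like map(lambda x: (x[1], 1)).reduceByKey(add): #occurrences -> #unique k-mers
toy_dist = Counter(count for _, count in toy_kmercounts)

print(sorted(toy_dist.items()))  # [(1, 3), (2, 1), (5, 1)]
```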

In [None]:
# Create an RDD 'kmerdist' from 'kmercounts' with (#occurrences, #unique kmers) pairs

### BEGIN SOLUTION
kmerdist = kmercounts.map(lambda x: (x[1], 1)).reduceByKey(lambda x, y: x + y)
### END SOLUTION

In [None]:
# For the plot we sort the RDD by number of occurrences (the key)
kmerdistsorted = kmerdist.sortByKey(True).cache()

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt

# x: number of occurrences of a 10-mer, y: number of unique 10-mers with that count
x = kmerdistsorted.keys().collect()
y = kmerdistsorted.values().collect()

plt.plot(x, y, label='10-mers')  # a label is needed for plt.legend() to show an entry
plt.yscale('log')
plt.title("kmer distribution")
plt.xlabel("number of occurrences")
plt.ylabel("number of unique 10-mers")
plt.legend()
plt.show()

## Optional 1: [1,k]-mers in a single go

Create a version of the KmerCount example that counts all subsequences of length 1, 2, 3, ..., k in a single go. Don't use any `for` loops over RDDs (looping over the contents of a single record is OK), and don't use `collect` to combine results.
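
One possible building block is a per-record generator that emits every substring of length 1 to k. It loops only within a single record, which the exercise allows. The plain-Python sketch below uses the hypothetical name `sliding_upto`. Its output could then be passed to `flatMap` and counted with the same `map`/`reduceByKey` pattern as before.

```python
def sliding_upto(seq, k):
    """Yield all overlapping substrings of seq with lengths 1, 2, ..., k."""
    for size in range(1, k + 1):
        for i in range(len(seq) - size + 1):
            yield seq[i:i + size]

print(list(sliding_upto("ACGT", 2)))
# ['A', 'C', 'G', 'T', 'AC', 'CG', 'GT']
```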

## Optional 2: Runtime as a function of k

You might have noticed that the runtime increases as k increases. Measure the runtime of the KmerCount for different values of k and plot the results in a graph. You can get an estimate of the runtime of a notebook cell by putting `%%time` on the first line of the cell (`%time` only times the statement on its own line).
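
You can also collect timings in a loop and plot them, instead of reading `%%time` output. The sketch below times a hypothetical local workload for several values of k using wall-clock time. In the notebook you would replace the body of `workload` with a Spark action on the k-mer counts. Keep in mind that Spark caching can make the first run slower than later ones. All names in the sketch are hypothetical, and the toy input is made up.

```python
import time

def workload(seqs, k):
    """Hypothetical stand-in for the Spark job: count distinct k-mers locally."""
    counts = {}
    for s in seqs:
        for i in range(len(s) - k + 1):  # sliding window, as in the notebook's helper
            kmer = s[i:i + k]
            counts[kmer] = counts.get(kmer, 0) + 1
    return len(counts)

def time_for_k(seqs, ks):
    """Return a list of (k, seconds) pairs measured with wall-clock time."""
    timings = []
    for k in ks:
        start = time.time()
        workload(seqs, k)
        timings.append((k, time.time() - start))
    return timings

toy_seqs = ["GAGATCTCCTGTGGTGTCCTTGGTCATAGTGATTTGCTCCTACAA"] * 200
for k, secs in time_for_k(toy_seqs, [2, 5, 10, 20]):
    print("k={0}: {1:.4f} s".format(k, secs))
```

The resulting pairs can be plotted with `plt.plot` in the same way as the k-mer distribution.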

## End of the second notebook

Congratulations! You finished the second notebook. You can continue to the next notebook "03 - Running Blast on Spark". Close this notebook via 'File' -> 'Close and Halt' to stop the underlying kernel and release computing resources.