# Reformat 1kgenome vcf format to vsf

In [1]:
from pyspark.sql import SparkSession
from pyspark import SparkContext, SparkConf
import pyspark

# ask Zhong for the keys
aws_access_key_id = ''
aws_secret_access_key = ''

spark = (SparkSession 
            .builder 
            .appName("1kgenome") 
            .master('local[72]')
            .config('spark.executor.heartbeatInterval', '1000s')
            .config('spark.network.timeout', '10000s')
            .config('spark.executor.memory', '4000G')
            .config('spark.executor.extraJavaOptions','-Dcom.amazonaws.services.s3.enableV4=true')
            .config('spark.driver.extraJavaOptions', '-Dcom.amazonaws.services.s3.enableV4=true')
            .config('spark.kryoserializer.buffer.max', '1G') # higher causes oom （GC limit）
            .config('spark.local.dir','/mnt/data/spark') # /tmp might be full
            .config('spark.sql.autoBroadcastJoinThreshold', '-1')
            .getOrCreate()
        )
sc = spark.sparkContext
sc.setSystemProperty('com.amazonaws.services.s3.enableV4', 'true')

hadoopConf = sc._jsc.hadoopConfiguration()
hadoopConf.set('fs.s3a.access.key', aws_access_key_id)
hadoopConf.set('fs.s3a.secret.key', aws_secret_access_key)


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/01/09 09:42:52 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/01/09 09:42:52 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


## 0. Formatting DBSNP

In [None]:
import pyspark.sql.functions as F

# assuming dbsnp database has been downloaded here
path = 's3a://zhong-active/PGS/dbsnp328/'

dbsnps = [
    {'vcf': path + 'GRCh37.gz', 'out': path + 'dbsnp328_grc37_p13.pq'},
    {'vcf': path + 'GRCh39.gz', 'out': path + 'dbsnp328_grc38_p13.pq'},
]

def dbsnp_vcf_to_parquet(dbsnp, parquet_file, partitions=200):
    vcf = (spark
         .read
         .csv(dbsnp, sep='\t', header=None, comment='#')
         .toDF('chr', 'pos', 'rsID', 'ref', 'alt', 'q1', 'q2', 'annot')
         .select(
           F.regexp_extract('chr', r'NC_0+(\d+).', 1).astype('int').alias('chr'),
           F.col('pos').astype('long'),
           F.regexp_extract('rsID', r'rs(\d+)', 1).astype('long').alias('rsID'),
           F.posexplode(F.split(F.concat_ws(',', 'ref', 'alt'), ',')).alias('code', 'allele')                         
         )
         .drop_duplicates(['chr', 'pos', 'allele', 'rsID'])
         .where(F.col('chr')>0)         
        )
    vcf.repartition(partitions).write.mode('overwrite').parquet(parquet_file)
    
for dbsnp in dbsnps:
    dbsnp_vcf_to_parquet(dbsnp['vcf'], dbsnp['out'])

## 1. Combine chromosomal VCFs into a spark dataframe

In [2]:
import pyspark.sql.functions as F

In [3]:
# 1000 genomes on S3
data_path ='s3a://1000genomes/release/20130502/'
prefix = 'ALL.chr'
suffix = '.phase3_shapeit2_mvncall_integrated_v5a.20130502.genotypes.vcf.gz'
chroms = [data_path + prefix + str(c) + suffix for c in range(1, 23)] 
chrY = data_path + 'ALL.chrY.phase3_integrated_v1b.20130502.genotypes.vcf.gz'
chrX = data_path + 'ALL.chrX.phase3_shapeit2_mvncall_integrated_v1b.20130502.genotypes.vcf.gz'
chroms = chroms + [chrX, chrY]

In [4]:
# change chromosome# to numerical ID, using GRC37 specification
mapping = {
  'X': '23',
  'Y': '24',
  'M': '25'
}
mapping.update({str(i+1):str(i+1) for i in range(22)})
# other chromosomes will throw out errors later
apply_mapping_udf = F.udf(lambda x: mapping.get(x, x))


def get_vcf_headers(vcf, numHeaderLines=1000):
    """
    column headers of vcf files
    """
    return (spark
            .read
            .csv(vcf, sep='\n', header=None)
            .limit(numHeaderLines)
            .where(F.substring('_c0', 1, 6) == '#CHROM')
            .collect()[0]['_c0'][1:]
            .split('\t')
    )
def chr_vcf_to_spark_df(vcfs, limit=0):
    """
    combine a list of vcf files (one per chromosome) into one spark df
    set limit to a small number to debug
    """
    all_vcf = None
    for chrom in vcfs:
 #       print("Adding " + chrom)
        headers = get_vcf_headers(chrom)
        if limit>0:
            vcf_c = spark.read.csv(chrom, comment='#', sep='\t', header=None).limit(limit).toDF(*headers)
        else:
            vcf_c = spark.read.csv(chrom, comment='#', sep='\t', header=None).toDF(*headers)
        # do a count to force read the file, otherwise headers will be changed
 #       print("Adding %d records." % vcf_c.count())
        if all_vcf is None:
            all_vcf = vcf_c
        else:
            all_vcf = all_vcf.unionByName(vcf_c, allowMissingColumns=True)  
    # change letter chromosome IDs to numerical
    print("Totally we added %d records." % all_vcf.count())
    all_vcf = all_vcf.withColumn('CHROM', apply_mapping_udf('CHROM').astype('int'))
    return all_vcf

In [5]:
all_vcf = chr_vcf_to_spark_df(chroms)

22/12/15 16:11:06 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties


                                                                                

22/12/15 16:14:15 WARN DAGScheduler: Broadcasting large task binary with size 2.3 MiB


                                                                                

Totally we added 84801880 records.


## 2. Assign SNPid from dbSNP

In [6]:
dbsnp = 's3a://zhong-active/PGS/dbsnp328/dbsnp328_grc37_p13.pq'

dbsnp = (spark
       .read
       .parquet(dbsnp)
       .select(F.col('chr').alias('CHROM'), F.col('pos').alias('POS'), F.col('rsID').alias('ID'))
       .drop_duplicates(['CHROM', 'POS'])
      )    

In [7]:
all_vcf = (all_vcf
        .withColumnRenamed('ID', 'oldID')
        .join(dbsnp, on=['CHROM', 'POS'], how='left')
        )

In [None]:
# takes a long time ~3 days
all_vcf.repartition(200).write.mode('overwrite').parquet('1kgenome_dbSNP328.pq')  

22/12/15 16:19:46 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


[Stage 124:>                                                       (0 + 0) / 24]

22/12/15 16:20:00 WARN DAGScheduler: Broadcasting large task binary with size 22.7 MiB


python3: /home/zhongw/.conda/envs/pyspark/bin/../lib/libuuid.so.1: no version information available (required by /lib64/libndctl.so.6)
python3: /home/zhongw/.conda/envs/pyspark/bin/../lib/libuuid.so.1: no version information available (required by /lib64/libdaxctl.so.1)
python3: /home/zhongw/.conda/envs/pyspark/bin/../lib/libuuid.so.1: no version information available (required by /lib64/libblkid.so.1)

## 3.Covert genotypes to sparse format

In [2]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
import pyspark.sql.functions as F
genomes = spark.read.parquet('1kgenome_dbSNP328.pq')


                                                                                

In [3]:
genomes.where(F.col('ID') != F.col('oldID')).select('CHROM', 'POS', 'oldID', 'ID','REF', 'ALT').show(5)

[Stage 1:>                                                          (0 + 1) / 1]

+-----+---------+-----+---------+---+---+
|CHROM|      POS|oldID|       ID|REF|ALT|
+-----+---------+-----+---------+---+---+
|   23| 58310816|    .|767511503|  G|  C|
|   23|100682653|    .|782358386|  G|  A|
|   23| 85977482|    .|749600777|  G|  T|
|   23| 79606493|    .|776730835|  G|  A|
|   23| 74469398|    .|758530346|  G|  A|
+-----+---------+-----+---------+---+---+
only showing top 5 rows



                                                                                

In [5]:
# sanity check for ID assignment
genomes.where((F.col('oldID')!= '.') & (F.col('ID') != F.col('oldID'))).select('CHROM', 'POS', 'oldID', 'ID','REF', 'ALT').show(20)



+-----+---+-----+---+---+---+
|CHROM|POS|oldID| ID|REF|ALT|
+-----+---+-----+---+---+---+
+-----+---+-----+---+---+---+



                                                                                

In [4]:
samples = genomes.columns[9:-1]
len(samples)

2504

In [5]:
samples[-1]

'NA21144'

In [6]:
# split alleles

a1 = (genomes
      .select(
        'ID',
        *((F.split(c, '\|').getItem(0).astype('int').alias(c) for c in samples))
      )
      .fillna(0)
      .withColumn('gts', F.array(*samples))
      .select('ID', 'gts')
     )

a2 = (genomes
      .select(
        'ID',
        *((F.split(c, '\|').getItem(1).astype('int').alias(c) for c in samples))
      )
      .fillna(0)
      .withColumn('gts', F.array(*samples))
      .select('ID', 'gts')
     )

In [7]:
from pyspark.mllib.linalg.distributed import IndexedRow, IndexedRowMatrix
# this takes 64.6 hrs + 50.3 hrs
a1 = IndexedRowMatrix(a1.rdd.map(lambda row: IndexedRow(*row))).toCoordinateMatrix()
a2 = IndexedRowMatrix(a2.rdd.map(lambda row: IndexedRow(*row))).toCoordinateMatrix()

python3: /home/zhongw/.conda/envs/pyspark/bin/../lib/libuuid.so.1: no version information available (required by /lib64/libndctl.so.6)
python3: /home/zhongw/.conda/envs/pyspark/bin/../lib/libuuid.so.1: no version information available (required by /lib64/libdaxctl.so.1)
python3: /home/zhongw/.conda/envs/pyspark/bin/../lib/libuuid.so.1: no version information available (required by /lib64/libblkid.so.1)
                                                                                

23/01/09 09:44:13 WARN DAGScheduler: Broadcasting large task binary with size 1003.9 KiB


[Stage 3:>                                                       (0 + 72) / 400]

23/01/09 09:44:56 ERROR Executor: Exception in task 53.0 in stage 3.0 (TID 56)
java.lang.OutOfMemoryError: GC overhead limit exceeded
	at java.lang.AbstractStringBuilder.<init>(AbstractStringBuilder.java:68)
	at java.lang.StringBuilder.<init>(StringBuilder.java:106)
	at scala.collection.mutable.StringBuilder.<init>(StringBuilder.scala:52)
	at scala.collection.mutable.StringBuilder.<init>(StringBuilder.scala:65)
	at org.apache.spark.sql.catalyst.expressions.codegen.Block$.org$apache$spark$sql$catalyst$expressions$codegen$Block$$foldLiteralArgs(javaCode.scala:256)
	at org.apache.spark.sql.catalyst.expressions.codegen.Block$BlockHelper$.code$extension(javaCode.scala:243)
	at org.apache.spark.sql.catalyst.expressions.BoundReference.doGenCode(BoundAttribute.scala:58)
	at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:151)
	at org.apache.spark.sql.catalyst.expressions.Expression$$Lambda$2594/1968334508.apply(Unknown Source)
	at scala.Option.getOrElse

[Stage 3:>                                                       (0 + 72) / 400]

23/01/09 09:45:17 ERROR Executor: Exception in task 26.0 in stage 3.0 (TID 29)
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:45:22 ERROR Executor: Exception in task 69.0 in stage 3.0 (TID 72)
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:45:28 ERROR Executor: Exception in task 1.0 in stage 3.0 (TID 4)
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:45:33 ERROR Executor: Exception in task 5.0 in stage 3.0 (TID 8)
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:45:41 ERROR Executor: Exception in task 19.0 in stage 3.0 (TID 22)
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:45:45 ERROR Executor: Exception in task 33.0 in stage 3.0 (TID 36)
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:45:50 ERROR Executor: Exception in task 67.0 in stage 3.0 (TID 70)
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:45:56 ERROR Executor: Exception in task 11.0 in stage 

[Stage 3:>                                                       (0 + 72) / 400]

23/01/09 09:46:14 ERROR Executor: Exception in task 36.0 in stage 3.0 (TID 39)
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:46:14 ERROR SparkUncaughtExceptionHandler: [Container in shutdown] Uncaught exception in thread Thread[Executor task launch worker for task 36.0 in stage 3.0 (TID 39),5,main]
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:46:20 ERROR Executor: Exception in task 34.0 in stage 3.0 (TID 37)
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:46:20 ERROR SparkUncaughtExceptionHandler: [Container in shutdown] Uncaught exception in thread Thread[Executor task launch worker for task 34.0 in stage 3.0 (TID 37),5,main]
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:46:27 ERROR Executor: Exception in task 39.0 in stage 3.0 (TID 42)
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:46:28 ERROR SparkUncaughtExceptionHandler: [Container in shutdown] Uncaught exception in thread Thread

[Stage 3:>                                                       (0 + 73) / 400]

23/01/09 09:46:42 WARN TaskSetManager: Lost task 19.0 in stage 3.0 (TID 22) (aep12.eng.memverge.com executor driver): java.lang.OutOfMemoryError: GC overhead limit exceeded

23/01/09 09:46:42 ERROR Executor: Exception in task 57.0 in stage 3.0 (TID 60)
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:46:42 ERROR SparkUncaughtExceptionHandler: [Container in shutdown] Uncaught exception in thread Thread[Executor task launch worker for task 57.0 in stage 3.0 (TID 60),5,main]
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:46:42 ERROR Inbox: Ignoring error
java.util.concurrent.RejectedExecutionException: Task org.apache.spark.executor.Executor$TaskRunner@5035c51c rejected from java.util.concurrent.ThreadPoolExecutor@380b80a7[Shutting down, pool size = 64, active threads = 64, queued tasks = 0, completed tasks = 11]
	at java.util.concurrent.ThreadPoolExecutor$AbortPolicy.rejectedExecution(ThreadPoolExecutor.java:2063)
	at java.util.concurrent.ThreadPoolE

[Stage 3:>                                                       (0 + 72) / 400]

23/01/09 09:46:53 ERROR Executor: Exception in task 63.0 in stage 3.0 (TID 66)
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:46:53 ERROR SparkUncaughtExceptionHandler: [Container in shutdown] Uncaught exception in thread Thread[Executor task launch worker for task 63.0 in stage 3.0 (TID 66),5,main]
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:47:01 ERROR Executor: Exception in task 55.0 in stage 3.0 (TID 58)
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:47:01 ERROR SparkUncaughtExceptionHandler: [Container in shutdown] Uncaught exception in thread Thread[Executor task launch worker for task 55.0 in stage 3.0 (TID 58),5,main]
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:47:07 ERROR Executor: Exception in task 4.0 in stage 3.0 (TID 7)
java.lang.OutOfMemoryError: GC overhead limit exceeded
23/01/09 09:47:07 WARN ShutdownHookManager: ShutdownHook 'ClientFinalizer' timeout, java.util.concurrent.TimeoutExce

ERROR:root:KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "/home/zhongw/.conda/envs/pyspark/lib/python3.10/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/home/zhongw/.conda/envs/pyspark/lib/python3.10/site-packages/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "/home/zhongw/.conda/envs/pyspark/lib/python3.10/socket.py", line 705, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt


KeyboardInterrupt: 

In [None]:
a1 =(spark
 .createDataFrame(a1.transpose().entries)
 .where(F.col('value')>0)
 .toDF('sample', 'rsID', 'code')
)
a2 =(spark
 .createDataFrame(a2.transpose().entries)
 .where(F.col('value')>0)
 .toDF('sample', 'rsID', 'code')
)

In [None]:
# create the genotype sparse matrix
# rows are samples
# columns are genotype (rsID:gt)
# values are doses (how many alleles)
gt = a1.union(a2)
gt = (gt
      .withColumn('index', F.concat_ws(':', *['rsID', 'code']))
      .withColumn('dose', F.lit(1.0))
      .select('sample', 'index', 'dose')
      .groupby('sample', 'index')
      .agg(F.sum('dose').alias('dose'))
     )

gt.write.mode('overwrite').parquet('1kgenome_gt_coo.pq')

In [16]:
# create indexer
from pyspark.ml.feature import StringIndexer
indexx = gt.select('index')
indexer = StringIndexer(inputCol='index', outputCol='i').fit(indexx)
indexx = indexer.transform(indexx)
gt = gt.join(indexx, on='index', how='left').select('index', 'sample', 'i', 'dose')
gt.write.mode('overwrite').parquet('1kgenome_gt_indexed_coo.pq')


22/12/23 14:12:38 WARN DAGScheduler: Broadcasting large task binary with size 1787.1 KiB




23/01/08 02:01:33 WARN DAGScheduler: Broadcasting large task binary with size 1794.7 KiB




23/01/08 04:01:07 WARN DAGScheduler: Broadcasting large task binary with size 1792.1 KiB


[Stage 21:>                                                         (0 + 1) / 1]

23/01/08 04:30:14 ERROR Executor: Exception in task 0.0 in stage 21.0 (TID 2209)
org.apache.spark.SparkException: Kryo serialization failed: Buffer overflow. Available: 0, required: 6
Serialization trace:
_data (org.apache.spark.util.collection.OpenHashSet)
org$apache$spark$util$collection$OpenHashMap$$_keySet (org.apache.spark.util.collection.OpenHashMap$mcJ$sp). To avoid this, increase spark.kryoserializer.buffer.max value.
	at org.apache.spark.serializer.KryoSerializerInstance.serialize(KryoSerializer.scala:391)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
	at org.apache.spark.sql.execution.aggregate.ComplexTypedAggregateExpression.eval(TypedAggregateExpression.scala:260)
	at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.eval(interfaces.scala:594)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateResultProjection$5(AggregationIterator.scala:257)
	at org.apach

Py4JJavaError: An error occurred while calling o30348.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 21.0 failed 1 times, most recent failure: Lost task 0.0 in stage 21.0 (TID 2209) (aep12.eng.memverge.com executor driver): org.apache.spark.SparkException: Kryo serialization failed: Buffer overflow. Available: 0, required: 6
Serialization trace:
_data (org.apache.spark.util.collection.OpenHashSet)
org$apache$spark$util$collection$OpenHashMap$$_keySet (org.apache.spark.util.collection.OpenHashMap$mcJ$sp). To avoid this, increase spark.kryoserializer.buffer.max value.
	at org.apache.spark.serializer.KryoSerializerInstance.serialize(KryoSerializer.scala:391)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
	at org.apache.spark.sql.execution.aggregate.ComplexTypedAggregateExpression.eval(TypedAggregateExpression.scala:260)
	at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.eval(interfaces.scala:594)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateResultProjection$5(AggregationIterator.scala:257)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.next(ObjectAggregationIterator.scala:96)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.next(ObjectAggregationIterator.scala:32)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:365)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:890)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:890)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:750)
Caused by: com.esotericsoftware.kryo.KryoException: Buffer overflow. Available: 0, required: 6
Serialization trace:
_data (org.apache.spark.util.collection.OpenHashSet)
org$apache$spark$util$collection$OpenHashMap$$_keySet (org.apache.spark.util.collection.OpenHashMap$mcJ$sp)
	at com.esotericsoftware.kryo.io.Output.require(Output.java:167)
	at com.esotericsoftware.kryo.io.Output.writeAscii_slow(Output.java:499)
	at com.esotericsoftware.kryo.io.Output.writeString(Output.java:348)
	at com.esotericsoftware.kryo.serializers.DefaultSerializers$StringSerializer.write(DefaultSerializers.java:195)
	at com.esotericsoftware.kryo.serializers.DefaultSerializers$StringSerializer.write(DefaultSerializers.java:188)
	at com.esotericsoftware.kryo.Kryo.writeObjectOrNull(Kryo.java:629)
	at com.esotericsoftware.kryo.serializers.DefaultArraySerializers$StringArraySerializer.write(DefaultArraySerializers.java:272)
	at com.esotericsoftware.kryo.serializers.DefaultArraySerializers$StringArraySerializer.write(DefaultArraySerializers.java:258)
	at com.esotericsoftware.kryo.Kryo.writeObject(Kryo.java:575)
	at com.esotericsoftware.kryo.serializers.ObjectField.write(ObjectField.java:79)
	at com.esotericsoftware.kryo.serializers.FieldSerializer.write(FieldSerializer.java:508)
	at com.esotericsoftware.kryo.Kryo.writeObject(Kryo.java:575)
	at com.esotericsoftware.kryo.serializers.ObjectField.write(ObjectField.java:79)
	at com.esotericsoftware.kryo.serializers.FieldSerializer.write(FieldSerializer.java:508)
	at com.esotericsoftware.kryo.Kryo.writeClassAndObject(Kryo.java:651)
	at com.esotericsoftware.kryo.serializers.DefaultArraySerializers$ObjectArraySerializer.write(DefaultArraySerializers.java:361)
	at com.esotericsoftware.kryo.serializers.DefaultArraySerializers$ObjectArraySerializer.write(DefaultArraySerializers.java:302)
	at com.esotericsoftware.kryo.Kryo.writeClassAndObject(Kryo.java:651)
	at org.apache.spark.serializer.KryoSerializerInstance.serialize(KryoSerializer.scala:387)
	... 20 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2228)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2249)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2268)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2293)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1021)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1020)
	at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:424)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.$anonfun$executeCollect$1(AdaptiveSparkPlanExec.scala:340)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.withFinalPlanUpdate(AdaptiveSparkPlanExec.scala:368)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.executeCollect(AdaptiveSparkPlanExec.scala:340)
	at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:3868)
	at org.apache.spark.sql.Dataset.$anonfun$collect$1(Dataset.scala:3120)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:3858)
	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:510)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3856)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:109)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:169)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:95)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:779)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3856)
	at org.apache.spark.sql.Dataset.collect(Dataset.scala:3120)
	at org.apache.spark.ml.feature.StringIndexer.countByValue(StringIndexer.scala:204)
	at org.apache.spark.ml.feature.StringIndexer.sortByFreq(StringIndexer.scala:212)
	at org.apache.spark.ml.feature.StringIndexer.fit(StringIndexer.scala:242)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Thread.java:750)
Caused by: org.apache.spark.SparkException: Kryo serialization failed: Buffer overflow. Available: 0, required: 6
Serialization trace:
_data (org.apache.spark.util.collection.OpenHashSet)
org$apache$spark$util$collection$OpenHashMap$$_keySet (org.apache.spark.util.collection.OpenHashMap$mcJ$sp). To avoid this, increase spark.kryoserializer.buffer.max value.
	at org.apache.spark.serializer.KryoSerializerInstance.serialize(KryoSerializer.scala:391)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
	at org.apache.spark.sql.execution.aggregate.ComplexTypedAggregateExpression.eval(TypedAggregateExpression.scala:260)
	at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.eval(interfaces.scala:594)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateResultProjection$5(AggregationIterator.scala:257)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.next(ObjectAggregationIterator.scala:96)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.next(ObjectAggregationIterator.scala:32)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:365)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:890)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:890)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more
Caused by: com.esotericsoftware.kryo.KryoException: Buffer overflow. Available: 0, required: 6
Serialization trace:
_data (org.apache.spark.util.collection.OpenHashSet)
org$apache$spark$util$collection$OpenHashMap$$_keySet (org.apache.spark.util.collection.OpenHashMap$mcJ$sp)
	at com.esotericsoftware.kryo.io.Output.require(Output.java:167)
	at com.esotericsoftware.kryo.io.Output.writeAscii_slow(Output.java:499)
	at com.esotericsoftware.kryo.io.Output.writeString(Output.java:348)
	at com.esotericsoftware.kryo.serializers.DefaultSerializers$StringSerializer.write(DefaultSerializers.java:195)
	at com.esotericsoftware.kryo.serializers.DefaultSerializers$StringSerializer.write(DefaultSerializers.java:188)
	at com.esotericsoftware.kryo.Kryo.writeObjectOrNull(Kryo.java:629)
	at com.esotericsoftware.kryo.serializers.DefaultArraySerializers$StringArraySerializer.write(DefaultArraySerializers.java:272)
	at com.esotericsoftware.kryo.serializers.DefaultArraySerializers$StringArraySerializer.write(DefaultArraySerializers.java:258)
	at com.esotericsoftware.kryo.Kryo.writeObject(Kryo.java:575)
	at com.esotericsoftware.kryo.serializers.ObjectField.write(ObjectField.java:79)
	at com.esotericsoftware.kryo.serializers.FieldSerializer.write(FieldSerializer.java:508)
	at com.esotericsoftware.kryo.Kryo.writeObject(Kryo.java:575)
	at com.esotericsoftware.kryo.serializers.ObjectField.write(ObjectField.java:79)
	at com.esotericsoftware.kryo.serializers.FieldSerializer.write(FieldSerializer.java:508)
	at com.esotericsoftware.kryo.Kryo.writeClassAndObject(Kryo.java:651)
	at com.esotericsoftware.kryo.serializers.DefaultArraySerializers$ObjectArraySerializer.write(DefaultArraySerializers.java:361)
	at com.esotericsoftware.kryo.serializers.DefaultArraySerializers$ObjectArraySerializer.write(DefaultArraySerializers.java:302)
	at com.esotericsoftware.kryo.Kryo.writeClassAndObject(Kryo.java:651)
	at org.apache.spark.serializer.KryoSerializerInstance.serialize(KryoSerializer.scala:387)
	... 20 more


In [17]:
gt.show(2)

23/01/08 10:35:49 WARN DAGScheduler: Broadcasting large task binary with size 1795.1 KiB


ERROR:root:KeyboardInterrupt while sending command.              (0 + 72) / 800]
Traceback (most recent call last):
  File "/home/zhongw/.conda/envs/pyspark/lib/python3.10/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/home/zhongw/.conda/envs/pyspark/lib/python3.10/site-packages/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "/home/zhongw/.conda/envs/pyspark/lib/python3.10/socket.py", line 705, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt


KeyboardInterrupt: 

## 4. Formating GWAS results

In [None]:
# convert alleles to code
import pyspark.sql.functions as F
genomes = spark.read.parquet('s3a://zhong-active/PGS/1kgenome_dbSNP328.pq')
grc37_p13 = spark.read.parquet('s3a://zhong-active/PGS/dbsnp328/dbsnp328_grc37_p13.pq')

code = (genomes
 .select(
   F.col('ID').astype('long'),
   F.posexplode(F.split(F.concat_ws(',', 'REF', 'ALT'), ',')).alias('code1k', 'allele') 
 )
 .select('ID', F.col('code1k').astype('int'), 'allele')
 .join(grc37_p13.select(F.col('rsID').alias('ID'), 'code', 'allele'), on=['ID', 'allele'])
 .drop('allele')
 )

code.write.mode('overwrite').parquet('s3a://zhong-active/PGS/1kgenome_dbSNP328_code_mapping.pq')

In [None]:
# only need to run this once
gwas = 's3a://zhong-active/PGS/shared-dna/notebooks/gwas_2022-11-04.csv.gz'
grc37_p13 = 's3a://zhong-active/PGS/dbsnp328/dbsnp328_grc37_p13.pq'
grc38_p13 = 's3a://zhong-active/PGS/dbsnp328/dbsnp328_grc38_p13.pq'
dbsnp1 = spark.read.parquet(grc37_p13).select('chr', 'pos', 'rsID')
dbsnp2 = spark.read.parquet(grc38_p13).select('chr', 'pos', 'rsID')
ref = (spark
       .read
       .csv(gwas, sep='\t', header=True)
       .withColumn('CHR_ID', F.when(F.isnull('CHR_ID'), F.regexp_extract('SNPs', r'chr([\dXYM]+):(\d+)', 1)).otherwise(F.col('CHR_ID')))
       .withColumn('CHR_POS', F.when(F.isnull('CHR_POS'), F.regexp_extract('SNPs', r'chr([\dXYM]+):(\d+)', 2)).otherwise(F.col('CHR_POS')))       
       .select(F.col('SNP_ID_CURRENT').astype('long').alias('ID'), 
               F.col('rallele').alias('alt'),
               F.col('OR or BETA').astype('float').alias('OR'),
               F.col('CHR_ID').alias('chr'),
               F.col('CHR_POS').astype('long').alias('pos'),
               F.col('DISEASE/TRAIT').alias('trait'),
               F.col('DATE ADDED TO CATALOG').astype('date').alias('date'),
               F.col('PUBMEDID').astype('long').alias('pubmedID')
       ).where(
         (F.col('OR')>0.0)
       )
)

# replace chromosome numbering
mapping = {
  'X': '23',
  'Y': '24',
  'M': '25'
}
mapping.update({str(i+1):str(i+1) for i in range(22)})
print(mapping)
apply_mapping_udf = F.udf(lambda x: mapping.get(x, x))
ref = ref.withColumn('chr', apply_mapping_udf('chr').astype('int'))

ref = (ref
       .join(dbsnp1, on=['chr', 'pos'], how='left')
       .withColumn('ID', F.when(F.isnull('ID'), F.col('rsID')).otherwise(F.col('ID')))
       .drop('rsID')
       .join(dbsnp2, on=['chr', 'pos'], how='left')
       .withColumn('rsID', F.when(F.isnull('ID'), F.col('rsID')).otherwise(F.col('ID')))
       .select('trait', 'rsID', 'alt', 'OR', 'pubmedID', 'date')
      )               
ref = ref.drop_duplicates(['trait', 'rsID', 'alt'])
(ref
 .write
 .mode('overwrite')
 .parquet('s3a://zhong-active/PGS/shared-dna/notebooks/gwas_2022-11-04_ORs.pq')
)

## 5. Convert gwas to sparse format

In [None]:
code = spark.read.parquet('s3a://zhong-active/PGS/1kgenome_dbSNP328_code_mapping.pq')
ref = spark.read.parquet('s3a://zhong-active/PGS/shared-dna/notebooks/gwas_2022-11-04_ORs.pq')
# mapping ref codes to 1kgenome codes
ref1k = ref.join(code.select(F.col('ID').alias('rsID'), 'code', 'code1k'), on=['rsID', 'code'], how='left').fillna(0)

# create gwas ref sparse matrix
# rows are genotype(rsID:gt)
# columns are OR

ref1k = (ref1k
 .withColumn('trait', F.concat_ws('|', *['trait', 'pubmedID']))
 .withColumn('index', F.concat_ws(':', *['rsID', 'code1k']))
 .select('index', 'trait', 'OR')
) 


In [None]:
# create indexer for traits
from pyspark.ml.feature import StringIndexer
indexer = StringIndexer(inputCol='trait', outputCol='j').fit(ref1k)
ref1k = indexer.transform(ref1k).select('index', 'j', 'OR', 'trait')

In [None]:
# match genotypes in reference gwas using the indexed 1kgenome
# variants not in the 1kgenome are ignored

ref1k = (
    ref1k
    .join(gt.select('index', 'i'), on='index', how='left')
    .fillna({'i': -1})
    .where(F.col('i')>=0)
    .select('i', 'j', 'OR', 'trait')
)


In [None]:
# save results
ref1k.write.mode('overwrite').parquet('gwas_2022-11-04_ORs_1kgenome_indexed_coo.pq')

## 6. Sparse score calculation

In [None]:
from pyspark.mllib.linalg.distributed import CoordinateMatrix

gt = CoordinateMatrix(gt.rdd.map(lambda r: (r[0], r[1], r[2])))
gwas = CoordinateMatrix(ref1k.rdd.map(lambda r: (r[0], r[1], r[2])))

In [None]:
results = gt.toRowMatrix().multiply(gwas.toRowMatrix())

In [None]:
# save the score matrix
(spark
 .createDataFrame(results.entries)
 .toDF('sample', 'trait', 'score')
 .write
 .mode('overwrite')
 .parquet('s3a://zhong-active/PGS/1kgenome_gwas_scores_coo.pq')
)