In [None]:
import os
import gzip

import numpy

from sparsedat import Sparse_Data_Table as SDT

from capblood_seq import config

In [None]:
# Names of the samples to merge. Each name is used below as a sub-directory of
# "demux_data" holding that sample's pileup files.
SAMPLES = config.SAMPLE_NAMES
# Optional per-sample droplet cap. When truthy, only the NUM_DROPLETS barcodes
# with the largest total transcript counts are kept from each sample; None (or
# 0) disables the filter and keeps every droplet.
NUM_DROPLETS = None

In [None]:
# Build one SNP numbering shared by all samples and write it as the merged
# .var file.
#
# A SNP is identified across samples by the tuple of columns 1-5 of its .var
# row, written back out below under the headers CHROM, POS, REF, ALT, AF. The
# first time a tuple is seen it receives the next merged id; every
# (sample_index, per-sample SNP id) pair is then mapped onto that merged id so
# later cells can translate the per-sample .plp files.
snp_id_stats = {}        # merged SNP id -> (CHROM, POS, REF, ALT, AF)
snp_stats_id_map = {}    # (CHROM, POS, REF, ALT, AF) -> merged SNP id
sample_snp_id_map = {}   # (sample_index, per-sample SNP id) -> merged SNP id
next_snp_id = 0

for sample_index, sample in enumerate(SAMPLES):

    sample_dir = os.path.join("demux_data", sample)

    # To get the SNP ids, we only need to look at the first pileup file
    var_file_name = "pileup-000.var.gz"
    var_file_path = os.path.join(sample_dir, var_file_name)

    with gzip.open(var_file_path, mode="rt", encoding="utf-8") as var_file:

        # Discard the first line (the column header)
        var_file.readline()

        # Stream the remaining rows instead of materialising them all with
        # readlines(); the rows are only ever visited once.
        for line in var_file:
            line_elements = [x.strip() for x in line.split("\t")]
            sample_snp_id = int(line_elements[0])
            snp_stats = (
                line_elements[1],
                int(line_elements[2]),
                line_elements[3],
                line_elements[4],
                line_elements[5],
            )

            # Merged ids start at 0, so test against None rather than truthiness
            snp_id = snp_stats_id_map.get(snp_stats)
            if snp_id is None:
                snp_id = next_snp_id
                snp_stats_id_map[snp_stats] = snp_id
                snp_id_stats[snp_id] = snp_stats
                next_snp_id += 1

            sample_snp_id_map[(sample_index, sample_snp_id)] = snp_id

merged_dir = os.path.join("demux_data", "merged")
# gzip.open in write mode does not create missing parent directories, so make
# sure the output directory exists before writing into it.
os.makedirs(merged_dir, exist_ok=True)

var_file_name = "pileup.var.gz"
var_file_path = os.path.join(merged_dir, var_file_name)

with gzip.open(var_file_path, mode="wt", encoding="utf-8") as var_file:

    var_file.write("%s\n" % "\t".join(["#SNP_ID", "CHROM", "POS", "REF", "ALT", "AF"]))

    # One row per merged SNP, in ascending merged-id order
    for snp_id, stats in sorted(snp_id_stats.items(), key=lambda x: x[0]):
        var_file.write("%i\t%s\n" % (snp_id, "\t".join([stats[0], str(stats[1]), stats[2], stats[3], stats[4]])))

In [None]:
# Release the two SNP tables that were only needed to build and write the
# merged .var file. sample_snp_id_map is deliberately kept: the .plp merge
# below still uses it to translate per-sample SNP ids.
del snp_stats_id_map, snp_id_stats

In [None]:
# Merge every sample's pileup files (.cel, .plp, .umi) into a single set of
# files under demux_data/merged.
#
# Droplets are renumbered with one global counter (next_droplet_id) and each
# barcode has its original "-suffix" replaced by "-<sample_index>". SNP ids in
# the .plp rows are translated through sample_snp_id_map, which an earlier
# cell builds.
merged_dir = os.path.join("demux_data", "merged")
# gzip.open in write mode does not create missing parent directories, so make
# sure the output directory exists before opening the outputs.
os.makedirs(merged_dir, exist_ok=True)

cel_file_name = "pileup.cel.gz"
cel_file_path = os.path.join(merged_dir, cel_file_name)

plp_file_name = "pileup.plp.gz"
plp_file_path = os.path.join(merged_dir, plp_file_name)

umi_file_name = "pileup.umi.gz"
umi_file_path = os.path.join(merged_dir, umi_file_name)

next_droplet_id = 0
# (sample_index, pileup_index, per-file droplet id) -> merged droplet id
sample_pileup_droplet_id_map = {}

# Hold the three outputs in a with-statement so they are flushed and closed
# even if reading one of the inputs raises part-way through the merge.
with gzip.open(cel_file_path, mode="wt", encoding="utf-8") as out_cel_file, \
        gzip.open(plp_file_path, mode="wt", encoding="utf-8") as out_plp_file, \
        gzip.open(umi_file_path, mode="wt", encoding="utf-8") as out_umi_file:

    # The merged .umi file gets no header line, matching how the per-sample
    # .umi files are read below (no line is skipped).
    out_cel_file.write("%s\n" % "\t".join(["#DROPLET_ID", "BARCODE", "NUM.READ", "NUM.UMI", "NUM.UMIwSNP", "NUM.SNP"]))
    out_plp_file.write("%s\n" % "\t".join(["#DROPLET_ID", "SNP_ID", "ALLELES", "BASEQS"]))

    for sample_index, sample in enumerate(SAMPLES):

        print(sample)

        sample_dir = os.path.join("demux_data", sample)

        # NOTE(review): the counts table is only consulted when NUM_DROPLETS is
        # set; it is still loaded unconditionally to preserve existing behaviour.
        barcode_transcript_counts = SDT(
            os.path.join("..", "capblood-seq", "examples", "data", sample, "barcode_transcript_counts.sdt")
        )

        if NUM_DROPLETS:
            # Keep only the NUM_DROPLETS barcodes with the largest total
            # transcript counts, compared without their "-suffix".
            barcode_total_transcript_counts = barcode_transcript_counts.sum(axis=1)
            barcode_size_ranks = numpy.argsort(barcode_total_transcript_counts)[-NUM_DROPLETS:]
            barcodes_to_include = barcode_transcript_counts[sorted(barcode_size_ranks)].row_names
            barcodes_to_include = set([barcode.split("-")[0] for barcode in barcodes_to_include])

        # A pileup "prefix" is the part of a file name before the first ".";
        # each prefix is expected to have .cel.gz, .plp.gz and .umi.gz files.
        pileup_prefixes = set()

        for file in os.listdir(sample_dir):
            if "pileup" in file:
                pileup_prefix = file.split(".")[0]
                pileup_prefixes.add(pileup_prefix)

        for pileup_index, pileup_prefix in enumerate(sorted(pileup_prefixes)):

            print(pileup_prefix)

            droplet_id_cell_barcode_map = {}
            droplet_id_cel_stats = {}

            droplet_id_snp_id_plp_stats = {}
            droplet_id_umi_stats = {}

            # ---- .cel: one row per droplet; assigns the merged droplet ids ----
            cel_file_name = "%s.cel.gz" % pileup_prefix
            cel_file_path = os.path.join(sample_dir, cel_file_name)

            with gzip.open(cel_file_path, mode="rt", encoding="utf-8") as cel_file:

                # Discard the first line (the column header)
                cel_file.readline()

                for line in cel_file:
                    line_elements = [x.strip() for x in line.split("\t")]
                    cell_barcode = line_elements[1].split("-")[0]

                    if NUM_DROPLETS and cell_barcode not in barcodes_to_include:
                        continue

                    sample_droplet_id = int(line_elements[0])
                    sample_pileup_droplet_id_map[(sample_index, pileup_index, sample_droplet_id)] = next_droplet_id
                    cell_barcode = "%s-%i" % (cell_barcode, sample_index)
                    droplet_id_cell_barcode_map[next_droplet_id] = cell_barcode
                    droplet_id_cel_stats[next_droplet_id] = tuple([int(x) for x in line_elements[2:]])
                    next_droplet_id += 1

            for droplet_id, stats in sorted(droplet_id_cel_stats.items(), key=lambda x: x[0]):
                cell_barcode = droplet_id_cell_barcode_map[droplet_id]
                out_cel_file.write("%i\t%s\t%s\n" % (droplet_id, cell_barcode, "\t".join([str(x) for x in stats])))

            # ---- .plp: one row per (droplet, SNP); both ids are translated ----
            plp_file_name = "%s.plp.gz" % pileup_prefix
            plp_file_path = os.path.join(sample_dir, plp_file_name)

            with gzip.open(plp_file_path, mode="rt", encoding="utf-8") as plp_file:

                # Discard the first line (the column header)
                plp_file.readline()

                for line in plp_file:
                    line_elements = [x.strip() for x in line.split("\t")]
                    sample_droplet_id = int(line_elements[0])

                    # Droplets dropped while reading the .cel file have no
                    # merged id. Merged ids start at 0, so compare to None.
                    droplet_id = sample_pileup_droplet_id_map.get((sample_index, pileup_index, sample_droplet_id))
                    if droplet_id is None:
                        continue

                    sample_snp_id = int(line_elements[1])
                    snp_id = sample_snp_id_map[(sample_index, sample_snp_id)]

                    droplet_id_snp_id_plp_stats[(droplet_id, snp_id)] = tuple(line_elements[2:])

            # Rows are written ordered by merged SNP id, then merged droplet id
            for droplet_id_snp_id, stats in sorted(droplet_id_snp_id_plp_stats.items(), key=lambda x: (x[0][1], x[0][0])):
                out_plp_file.write("%i\t%i\t%s\n" % (droplet_id_snp_id[0], droplet_id_snp_id[1], "\t".join(stats)))

            # ---- .umi: one row per (droplet, UMI); read without a header ----
            umi_file_name = "%s.umi.gz" % pileup_prefix
            umi_file_path = os.path.join(sample_dir, umi_file_name)

            with gzip.open(umi_file_path, mode="rt", encoding="utf-8") as umi_file:

                for line in umi_file:
                    line_elements = [x.strip() for x in line.split("\t")]
                    sample_droplet_id = int(line_elements[0])

                    droplet_id = sample_pileup_droplet_id_map.get((sample_index, pileup_index, sample_droplet_id))
                    if droplet_id is None:
                        continue

                    UMI = line_elements[1]
                    count_1 = int(line_elements[2])
                    count_2 = int(line_elements[3])
                    positions = line_elements[4:]

                    droplet_id_umi_stats[(droplet_id, UMI)] = tuple([count_1, count_2] + positions)

            # Rows are written ordered by merged droplet id, then UMI
            for droplet_id_barcode, stats in sorted(droplet_id_umi_stats.items(), key=lambda x: (x[0][0], x[0][1])):
                out_umi_file.write("%i\t%s\t%i\t%i\t%s\n" % (droplet_id_barcode[0], droplet_id_barcode[1], stats[0], stats[1], "\t".join(stats[2:])))