In [40]:
import pandas as pd
import os
from Bio import SeqIO, AlignIO
import pysam
# from Bio.Align.Applications import PrankCommandline
# import subprocess
from collections import Counter
import statistics

In [41]:
def groupUMI(ucfile):
    """
    Load a USEARCH ``.uc`` file and return the records of multi-read clusters.

    Only columns 0 (record type), 1 (cluster number) and 8 (read label) are
    read, and they are renamed ``a``, ``cluster`` and ``name``. Rows whose
    record type is ``C`` or ``H`` are kept (presumably the per-cluster record
    and the member-hit records of the UC format — confirm against the USEARCH
    docs). Finally, only clusters that still have two or more rows are
    returned.

    Returns
    -------
    pandas.DataFrame
        Columns ``a``, ``cluster``, ``name``; original row order and index.
    """
    table = pd.read_csv(ucfile, delimiter="\t", header=None, usecols=[0, 1, 8])
    table.columns = ["a", "cluster", "name"]

    wanted_record = table["a"].isin(["C", "H"])
    table = table.loc[wanted_record]

    # keep=False marks every row of a repeated cluster, i.e. clusters of size >= 2
    in_multi_read_cluster = table.duplicated(subset="cluster", keep=False)
    return table.loc[in_multi_read_cluster]


def dictBam(bamfile, outDict):
    """
    Call mismatches and indels for every usable read of a BAM file and write
    them to a tab-separated text file, one read per line (no header).

    Output columns: read name, mut_type, mut_qual, indel_type, indel_qual.
    The last four fields are the lists returned by ``callMutation`` rendered
    with ``str()`` (e.g. ``['2048TC']`` or ``[]``); ``removeBracket`` parses
    them back.

    A read is skipped (not written) when it has no ``MD`` tag (needed by
    ``callMutation``), has no stored query sequence, or is a secondary
    alignment.

    Parameters
    ----------
    bamfile : str
        Input BAM, read sequentially (``until_eof=True``) and opened with
        ``check_sq=False`` so a header without ``@SQ`` lines is accepted.
    outDict : str
        Path of the text file to (over)write.
    """
    with pysam.AlignmentFile(bamfile, "rb", check_sq=False) as infile, open(outDict, "w") as outfile:
        for read in infile.fetch(until_eof=True):
            # `continue`, not `return`: an unusable read must only be skipped;
            # it must not truncate the output for every read that follows it.
            if not read.has_tag("MD") or read.query_sequence is None or read.is_secondary:
                continue

            mut_type, mut_qual, indel_type, indel_qual = callMutation(read)
            outfile.write(
                f"{read.query_name}\t{mut_type}\t{mut_qual}\t{indel_type}\t{indel_qual}\n"
            )

def callMutation(read):
    """
    Get mismatches and indels from one aligned read that carries an MD tag.

    Walks ``read.get_aligned_pairs(with_seq=True)``. With an MD tag present,
    pysam reports the reference base in lower case at substitutions, and
    ``None`` as the reference position for inserted and soft-clipped bases.
    Reference positions in the output are the 0-based positions reported by
    pysam.

    Returns
    -------
    tuple(list, list, list, list)
        mut_type
            ``"<ref_pos><REF><ALT>"`` for every mismatched base.
        mut_qual
            Base quality of the ALT base, parallel to ``mut_type``.
        indel_type
            One entry per indel run. ``"<ref_pos>D"`` uses the first deleted
            reference position. ``"<ref_pos>I"`` uses the last aligned
            reference position before the inserted bases.
        indel_qual
            Parallel to ``indel_type``. For a deletion it is the quality of
            the query base preceding it; for an insertion, the quality of the
            first inserted base.

    Soft-clipped bases (outside ``query_alignment_start`` ..
    ``query_alignment_end``) are ignored; they are not insertions.
    Assumes the read has base qualities (``query_qualities`` is not None).
    """
    # Look these up once instead of once per aligned base.
    seq = read.query_sequence
    quals = read.query_qualities
    aln_start = read.query_alignment_start  # first non-soft-clipped query index
    aln_end = read.query_alignment_end      # one past the last non-soft-clipped index

    mut_type = []
    mut_qual = []
    indel_type = []
    indel_qual = []

    last_ref = read.reference_start - 1  # last reference position walked over
    last_query = -1                      # last query position walked over
    prev_op = None                       # "D" / "I" if the previous pair was part of an indel run

    for query_pos, ref_pos, ref_base in read.get_aligned_pairs(with_seq=True):

        if query_pos is None:  # deletion: reference base absent from the read
            if prev_op != "D":  # report only the first base of a deletion run
                indel_type.append(f"{ref_pos}D")
                indel_qual.append(quals[last_query])
            last_ref = ref_pos
            prev_op = "D"

        elif ref_pos is None:  # insertion or soft clip
            last_query = query_pos
            if query_pos < aln_start or query_pos >= aln_end:
                continue  # soft-clipped base, not an insertion
            if prev_op != "I":  # report only the first base of an insertion run
                indel_type.append(f"{last_ref}I")
                indel_qual.append(quals[query_pos])
            prev_op = "I"

        else:  # aligned base: match or substitution
            last_ref = ref_pos
            last_query = query_pos
            prev_op = None
            if ref_base.islower():
                mut_type.append(f"{ref_pos}{ref_base.upper()}{seq[query_pos]}")
                mut_qual.append(quals[query_pos])

    return (mut_type, mut_qual, indel_type, indel_qual)


# def generateConsensus(dictbam, uc, my_cluster, frequency_threshold = 0.5, quality_threshold = 4):
#     uc = groupUMI(ucfile)
#     my_cluster = 2
#     clusters = ucsam.cluster.unique()
#     for my_cluster in clusters:
#         sub=uc.loc[uc["cluster"]==my_cluster]
#         if len(sub)>=2:
#             ssMuts = ccSeq(dictbam, sub, frequency_threshold, quality_threshold)
            
            

def ccSeq(sub, frequency_threshold, quality_threshold):
    """
    Return the consensus calls of one UMI cluster: mismatches first, then indels.

    ``sub`` is expected to hold one row per read of the cluster, with the
    string-encoded list columns ``mut_type``, ``mut_qual``, ``indel_type`` and
    ``indel_qual`` (the format written by ``dictBam``). Each column is parsed
    with ``removeBracket`` and the pooled calls are filtered by ``filterMuts``
    with ``num = len(sub)`` and the given thresholds.
    """
    n_reads = len(sub)

    # Parse in the same order as the columns are listed; *_qual columns hold ints.
    parsed = {
        column: removeBracket(sub[column].to_numpy(), qual=column.endswith("_qual"))
        for column in ("mut_type", "mut_qual", "indel_type", "indel_qual")
    }

    consensus_mismatches = filterMuts(
        parsed["mut_type"], parsed["mut_qual"], n_reads,
        frequency_threshold, quality_threshold,
    )
    consensus_indels = filterMuts(
        parsed["indel_type"], parsed["indel_qual"], n_reads,
        frequency_threshold, quality_threshold,
    )
    return consensus_mismatches + consensus_indels

def removeBracket(s,qual=True):
    """
    Flatten a sequence of stringified Python lists into one flat list.

    Example (qual=False):
        input:  ["['2048TC']", "[]", "['1283CT', '1529GC']"]
        output: ['2048TC', '1283CT', '1529GC']

    Square brackets and single quotes are stripped, each entry is split on
    commas, and every non-empty piece has its spaces removed. With
    ``qual=True`` pieces are converted with ``int``; otherwise they are kept
    as strings. ``s=None`` yields an empty list.
    """
    if s is None:
        return []

    drop_chars = str.maketrans("", "", "[]'")
    flattened = []
    for entry in s:
        for piece in entry.translate(drop_chars).split(","):
            if not piece:
                continue
            piece = piece.replace(" ", "")
            flattened.append(int(piece) if qual else str(piece))
    return flattened
    
def filterMuts(muts, mut_quals, num, frequency_threshold = 0.5, quality_threshold = 4):
    """
    Keep the calls that are both frequent enough and of high enough quality.

    A call is kept when it occurs at least ``num * frequency_threshold`` times
    in ``muts`` AND the mean of its qualities (``mut_quals`` matched to
    ``muts`` position by position) is >= ``quality_threshold``. Kept calls are
    returned once each, in order of first occurrence in ``muts``.

    Parameters
    ----------
    muts : list
        Calls pooled over all reads of a cluster.
    mut_quals : list of numbers
        Qualities parallel to ``muts``.
    num : int
        Number of reads the calls were pooled from.
    """
    occurrences = Counter(muts)
    min_count = num * frequency_threshold

    # Gather every quality of each call in one pass instead of rescanning per call.
    quals_by_call = {}
    for call, q in zip(muts, mut_quals):
        quals_by_call.setdefault(call, []).append(q)

    return [
        call
        for call, count in occurrences.items()
        if count >= min_count
        and statistics.mean(quals_by_call.get(call, [])) >= quality_threshold
    ]


def flatList(Alist):
    """Flatten one level: return a new list with the items of every sub-iterable of ``Alist``, in order."""
    return [item for sublist in Alist for item in sublist]




In [39]:
if __name__ == "__main__":
    # Generate UMI-consensus mutation calls for every sample listed in
    # SAMPLE_BARCODES (sample name = first tab-separated field of each line).
    # Per sample: dictBam writes per-read calls (output file 1); reads are then
    # grouped by UMI cluster, and each cluster with >= numUC reads that has at
    # least one consensus call is written as
    # "cluster<TAB>n_calls<TAB>call1,call2,..." (output file 2).

    SAMPLE_NAME = "DSST"
    SAMPLE_BARCODES = "/data/zhaolian/LineageTracing/DSS/PacBio/sampleBarcodes_"+SAMPLE_NAME+".txt"
    BAMDIR="/data/zhaolian/LineageTracing/DSS/PacBio/5.bam_split/"
    UCDIR="/data/zhaolian/LineageTracing/DSS/PacBio/4.umi_clustering/"
    OUTDIR="/data/zhaolian/LineageTracing/DSS/PacBio/6.umiConsensus/"
    numUC=3 ## minimum number of ccs reads in one consensus cluster
    mutFreq=0.8 ## export the mutation if mutation frequency >= mutFreq
    minQual=4 ## export the mutation only if its mean base quality >= minQual

    # numUC and mutFreq are encoded in the output file names, so the filters
    # below must use them too (not hard-coded copies of their values).
    os.makedirs(OUTDIR, exist_ok=True)
    colnames = ["name", "mut_type", "mut_qual", "indel_type", "indel_qual"]

    with open(SAMPLE_BARCODES, "r") as f_barcode:
        for line in f_barcode:
            sample = line.rstrip("\r\n").split("\t")[0]
            if not sample:  # blank line: no sample to process
                continue
            print(sample)

            bamfile = BAMDIR+sample+".bam"
            outDict = OUTDIR+sample+"_mutation_per_read"+str(numUC)+"_"+str(mutFreq)+".txt" ## output file 1
            dictBam(bamfile, outDict)

            ucfile = UCDIR+sample+"_umi_UsearchClusters.uc"
            outfile = OUTDIR+sample+"_umiConsensus"+str(numUC)+"_"+str(mutFreq)+".tsv" ## output file 2

            uc = groupUMI(ucfile)
            dictbam = pd.read_csv(outDict, delimiter="\t", header=None, names=colnames)
            ucbam = pd.merge(uc, dictbam, on="name", how="inner")

            # Split the reads by cluster once instead of scanning ucbam per cluster.
            reads_by_cluster = {key: group for key, group in ucbam.groupby("cluster")}

            with open(outfile, "w") as f:
                # Clusters in order of first appearance in `uc`.
                for my_cluster in uc.cluster.unique():
                    sub = reads_by_cluster.get(my_cluster)
                    if sub is None or len(sub) < numUC:
                        continue
                    ssMuts = ccSeq(sub, frequency_threshold=mutFreq, quality_threshold=minQual)
                    if ssMuts:
                        f.write(f"{my_cluster}\t{len(ssMuts)}\t{','.join(ssMuts)}\n")


4T
5T
16T
47_1T
47_4T
47_5T
47_6T
47_8T
50T
66_1T


In [None]:
# SAMPLE_NAME = "DSST"
# SAMPLE_BARCODES = "/data/zhaolian/LineageTracing/DSS/PacBio/sampleBarcodes_"+SAMPLE_NAME+".txt"
# Re-run the consensus pipeline for a single sample.
# NOTE: `sample` is not assigned in this cell. It must already exist in the
# notebook session (e.g. left over from the barcode loop in the `__main__`
# cell, which leaves the last listed sample). Assign `sample = "..."` here to
# target a specific sample.
BAMDIR="/data/zhaolian/LineageTracing/DSS/PacBio/5.bam_split/"
UCDIR="/data/zhaolian/LineageTracing/DSS/PacBio/4.umi_clustering/"
OUTDIR="/data/zhaolian/LineageTracing/DSS/PacBio/6.umiConsensus/"
numUC=3 ## minimum number of ccs reads in one consensus cluster
mutFreq=0.8 ## export the mutation if mutation frequency >= mutFreq
minQual=4 ## export the mutation only if its mean base quality >= minQual

os.makedirs(OUTDIR, exist_ok=True)
print(sample)

bamfile = BAMDIR+sample+".bam"
outDict = OUTDIR+sample+"_mutation_per_read"+str(numUC)+"_"+str(mutFreq)+".txt" ## output file 1
dictBam(bamfile, outDict)

ucfile = UCDIR+sample+"_umi_UsearchClusters.uc"
outfile = OUTDIR+sample+"_umiConsensus"+str(numUC)+"_"+str(mutFreq)+".tsv" ## output file 2

uc = groupUMI(ucfile)
colnames = ["name", "mut_type", "mut_qual", "indel_type", "indel_qual"]
dictbam = pd.read_csv(outDict, delimiter="\t", header=None, names=colnames)
ucbam = pd.merge(uc, dictbam, on="name", how="inner")

# Split the reads by cluster once instead of scanning ucbam per cluster.
reads_by_cluster = {key: group for key, group in ucbam.groupby("cluster")}

# numUC / mutFreq appear in the output file names, so they drive the filters too.
with open(outfile, "w") as f:
    for my_cluster in uc.cluster.unique():
        sub = reads_by_cluster.get(my_cluster)
        if sub is None or len(sub) < numUC:
            continue
        ssMuts = ccSeq(sub, frequency_threshold=mutFreq, quality_threshold=minQual)
        if ssMuts:
            f.write(f"{my_cluster}\t{len(ssMuts)}\t{','.join(ssMuts)}\n")
