## Pip installs

In [None]:
!pip install torch

In [None]:
!pip install transformers

In [None]:
!pip install datasets

In [None]:
!pip install pysam

In [None]:
!pip install HTSeq

## Loading DNA sequence data from BAM files to create a dataset for fine-tuning DNABERT

### Imports

In [None]:
import HTSeq
import pysam
from collections import defaultdict

### Function definitions

In [None]:
def seq2kmer_label(seq, k):
    """
    Convert a sequence into its overlapping kmers, joined by single spaces.

    A window of width k slides along seq one position at a time, so a
    sequence of length n yields n - k + 1 kmers. Despite the function name,
    no mutation label is computed here; only the kmer string is returned.

    Arguments:
    seq -- str, original sequence.
    k -- int, kmer of length k specified; must be at least 1.

    Returns:
    kmers -- str, kmers separated by space ("" when k is longer than seq).

    Raises:
    ValueError -- if k is smaller than 1.
    """
    # A non-positive k would otherwise produce empty or wrapped-around slices
    # instead of kmers, so reject it explicitly.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")
    kmer = [seq[x:x+k] for x in range(len(seq)+1-k)]
    kmers = " ".join(kmer)
    return kmers


In [None]:
def addBAM(bam_path, vcf_path, max_alignments=6, verbose=True):
    '''
        Collect aligned reads from a bam file into a dict keyed by genomic interval.

        Only the bam-reading step is implemented so far. The longer-term plan is
        to return a json or hf dataset object with the following info:
            1. iterate through the bam file and get sequence info on reads containing mutations appearing in the vcf file
            2. for each read, the chr #, a description, and whether it is in an exonic portion of a gene
                -if not exonic, ideally what type of non-coding element the read is a part of
            3. if a mutation, info for the mutation from Clinvar, Cosmic, NCBI
            4. desc of gene from NCBI

        Arguments:
        bam_path -- str, path passed to HTSeq.BAM_Reader.
        vcf_path -- str, path to the vcf file; accepted but not read yet.
        max_alignments -- int or None, number of records to read from the start of
            the file before stopping. The default of 6 matches the previous
            hard-coded testing limit; pass None to read the whole file.
        verbose -- bool, print the dict after every record (debugging aid).

        Returns:
        data_dict -- dict mapping '<chrom> START: <start> END: <end>' to a list of
            {'read_name': ..., 'seq': ...} entries, one per aligned read with that interval.

            desc of sam_alignment from htseq:
                >>> aln.iv
                <GenomicInterval object 'IV', [246048,246084), strand '+'>
                >>> aln.iv.chrom
                'IV'
                >>> aln.iv.start
                246048
                >>> aln.iv.end
                246084
                >>> aln.iv.strand
                '+'
    '''
    data_dict = {}
    with HTSeq.BAM_Reader(bam_path) as f:
        for i, sam_alignment in enumerate(f):
            ## for testing don't do the whole thing
            if max_alignments is not None and i >= max_alignments:
                break
            ### did this a little backwards, should iterate through gtf file to get gene name and interval, then (hopefully) use that
            ### to index into the bam file to get the sequence rather than iterating through all parts of the bam file
            if sam_alignment.aligned:
                ### i'm not sure what's gonna be the most useful thing for llama to map all the info to, starting with a string with
                ### chrom and pos
                chrom_pos_identifier = (
                    sam_alignment.iv.chrom
                    + ' START: ' + str(sam_alignment.iv.start)
                    + ' END: ' + str(sam_alignment.iv.end)
                )
                # Append rather than assign, so reads that share an interval
                # (e.g. duplicates) are all kept instead of overwriting each other.
                data_dict.setdefault(chrom_pos_identifier, []).append(
                    {
                        'read_name': sam_alignment.read.name,
                        # NOTE(review): this stores the HTSeq read object itself, not a
                        # plain sequence string -- confirm that is what consumers expect.
                        'seq': sam_alignment.read,
                    }
                )
            if verbose:
                print(data_dict)
    return data_dict

In [None]:
def prepareBertDataset(bam_file, vcf_file, contig=None, start=None, stop=None):
    '''
        Goal: get mutations from the vcf file, then get the read containing each mutation from the bam file, then mask that
        mutation and format for BERT (i.e. tokenize with the mask at the mutation position so it can learn what PTA artifacts look like).

        Current state: only the read-fetching step is implemented. This is a generator that opens bam_file with pysam
        and yields the reads returned by fetch() for the requested region. vcf_file is accepted but not read yet, and
        no masking or dataset formatting is done.

        Arguments:
        bam_file -- str, path to the bam file to open.
        vcf_file -- str, path to the vcf file (unused for now).
        contig -- str or None, reference/chromosome name passed to fetch().
        start -- int or None, region start passed to fetch().
        stop -- int or None, region end passed to fetch().
            NOTE(review): with all three left as None pysam is expected to iterate over every mapped read, and
            region fetches presumably need a bam index -- confirm against the pysam docs.

        Yields:
        the read objects produced by pysam's fetch(); the file is opened on first iteration (so a bad path only
        raises then) and closed when the generator finishes.

        the final dataset should look like this:
        DatasetDict({
                train: Dataset({
                    features: ['sequence', 'label'],
                    num_rows: 25000
                })
                test: Dataset({
                    features: ['sequence', 'label'],
                    num_rows: 25000
                })
                unsupervised: Dataset({
                    features: ['sequence', 'label'],
                    num_rows: 50000
                })
            })

        for row in sample:
            print(f"\n'>>> Sequence: {row['sequence']}'")
            print(f"'>>> Label: {row['label']}'")

        '>>> Sequence: ACTAGATAGATA'
        '>>> Label: 0'

        '>>> Sequence: ACATAGATATATA'
                                ^
                            mutated base
        '>>> Label: 1'

        ### If we do this the model will only learn whether a mutation is in the sequence, can we label it so the number corresponds to
        ### what base is mutated, i.e. 0 is no mutation?

        '>>> Sequence: ACTAGATAGATA'
        '>>> Label: 0'

        '>>> Sequence: ACATAGATATATA'
                                ^
                            mutated base
        '>>> Label: 12'

        ### this also seems like it won't leverage the power of the whole dna sequence

    '''

    import pysam

    # Use the caller's path and region instead of the former placeholder
    # literals; the with-block closes the file once iteration is done.
    with pysam.AlignmentFile(bam_file, "rb") as bamfile:
        yield from bamfile.fetch(contig, start, stop)
    

In [None]:
## variable definitions
# Directory holding both input files, followed by the two file names.
bam_dir = '/scratch/users/sschulz/pta_on_normal'
bam_file = 'CART-MRD-BALL-PTA-NEXTERA-WGS-CCT5007Pt04-B8_S51.realigned_deduped_sorted.bam'
vcf_name = 'CARTPt04_Scan2_svc_merged_extract_snp.vcf'

# Full paths: directory and file name joined with a single '/'.
bam_path = f'{bam_dir}/{bam_file}'
vcf_path = f'{bam_dir}/{vcf_name}'

In [None]:
# Open the BAM at bam_path in binary read mode ("rb") and request the reads
# for chr1 between positions 10 and 100000.
# NOTE(review): region fetches presumably require a BAM index alongside the file -- confirm.
bamfile = pysam.AlignmentFile(bam_path, "rb")
reads = bamfile.fetch("chr1", 10, 100000)

In [None]:
# Bare expression: in a notebook this displays the repr of the AlignmentFile object.
bamfile

In [None]:
# Alias the open BAM handle, build a pileup over chr1 for the 1000-20000
# window, and print the string form of every item the pileup yields.
samfile=bamfile
pileup = samfile.pileup('chr1', 1000, 20000)
for x in pileup:
    print(str(x))

In [None]:
# Bare expression: in a notebook this displays the repr of the pileup object.
pileup

## DNA tokenizer Test


In [None]:
from transformers import AutoTokenizer

# Load the tokenizer published under the "AIRI-Institute/gena-lm-bert-base" id.
# NOTE(review): from_pretrained presumably downloads the files on first use,
# so this cell needs network access unless they are already cached -- confirm.
tokenizer = AutoTokenizer.from_pretrained("AIRI-Institute/gena-lm-bert-base")


In [None]:
# Tokenize one DNA string; the result is indexed by "input_ids" further down.
encoded_input = tokenizer("ACGTGGTATGATGATAGATGATGA")


In [None]:
# Print the full tokenizer output for the single sequence.
print(encoded_input)

In [None]:
# Turn the token ids back into text to see what the tokenizer produced.
tokenizer.decode(encoded_input["input_ids"])

In [None]:
# Three DNA strings of different lengths, used to try batch tokenization.
batch_sequences = [
    "ACGTAGCTGACTGACTTAGTGA",
    "ACTAGCATGCATCGTAGCTAGCTAGACTGA",
    "ATATATATTACACACACGAGACTAGCTT",
]

In [None]:
# Tokenize the whole batch with padding and truncation switched on;
# this rebinds encoded_input from the single-sequence result.
encoded_input=tokenizer(batch_sequences, padding=True, truncation=True)

In [None]:
# Print the full tokenizer output for the batch.
print(encoded_input)

In [None]:
# Decode each sequence's token ids back to text, one printed line per sequence.
for i in encoded_input['input_ids']:
    print(tokenizer.decode(i))

In [None]:
## Same batch call as above, with padding and truncation (no explicit maximum
## length is given), but this time asking for tensors via return_tensors="pt".
## NOTE(review): "pt" presumably selects PyTorch tensors -- confirm in the transformers docs.
encoded_input = tokenizer(batch_sequences, padding=True, truncation=True, return_tensors="pt")

In [None]:
## so we can tokenize DNA sequences, but how do we 