In [None]:
import torch
import transformers


# Hub checkpoint shared by the tokenizer and the model, so the two can never
# point at different repositories by accident.
MODEL_NAME = "zhihan1996/DNABERT-2-117M"

# trust_remote_code=True is needed because this checkpoint ships its own
# Python code on the hub; that code is executed locally, so only use it with
# a repository you trust.
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = transformers.AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)

In [None]:
# Embed a small batch of DNA sequences and pool the token embeddings into one
# vector per sequence.
#
# Shape legend: B = batch size, N = padded token length.

# input
dna_list = [                        # [B]
    "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC",
    "CGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTA",
    "TACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACG",
]
encoded = tokenizer(
    dna_list, return_tensors='pt', padding=True
)
inputs = encoded["input_ids"]               # [B, N]
# padding=True pads shorter sequences up to N; the mask marks real tokens (1)
# versus padding (0) and must be honoured by both the model and the pooling.
attention_mask = encoded["attention_mask"]  # [B, N]

# output
# Inference only: no autograd graph is needed for embedding extraction.
with torch.no_grad():
    hidden_states = model(inputs, attention_mask=attention_mask)  # ( [B, N, 768], [B, 768] )
sequence_output = hidden_states[0]  # [B, N, 768]
pooled_output   = hidden_states[1]  # [B, 768]

# Boolean mask broadcastable over the hidden dimension.
token_mask = attention_mask.unsqueeze(-1).bool()        # [B, N, 1]

# embedding with mean pooling (padding positions excluded from sum and count)
token_count = token_mask.sum(dim=1).clamp(min=1)        # [B, 1]; clamp guards against an all-padding row
embedding_mean = (sequence_output * token_mask).sum(dim=1) / token_count    # [B, 768]

# embedding with max  pooling (padding positions set to -inf so they never win)
embedding_max = sequence_output.masked_fill(
    ~token_mask, float("-inf")
).max(dim=1)[0]                                         # [B, 768]

In [16]:
import pysam

# Print the first alignment found on each of the first 25 references, as a
# quick sanity check of the file contents.
# The context manager guarantees the file handle is released even if fetch()
# raises (e.g. missing index); the original never closed it.
with pysam.AlignmentFile("data/bam/SRR8924580.bam", "r") as samfile:
    # ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', 'X', 'Y', 'M']
    contig_list = list(samfile.references)[:25]

    for contig in contig_list:
        for read in samfile.fetch(contig):
            print(read)
            # Only the first read per contig is of interest.
            break

HISEQ1011:883:H5K7HBCX2:1:1111:14583:9891	99	#0	10001	0	110M15S	#0	10015	139	TAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCAAACCCTAACCCTACCCCAAACCCAAACCCAACCC	array('B', [38, 38, 38, 38, 38, 40, 40, 40, 38, 38, 40, 40, 40, 38, 40, 40, 40, 40, 38, 38, 40, 40, 38, 40, 38, 38, 38, 38, 40, 40, 38, 40, 40, 38, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 38, 13, 38, 38, 32, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 32, 38, 38, 38, 40, 38, 38, 38, 32, 27, 38, 38, 13, 32, 13, 38, 38, 38, 13, 38, 32, 38, 38, 38, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])	[('NM', 2), ('MD', '90T13A5'), ('MC', '125M'), ('AS', 100), ('XS', 100)]
HISEQ1011:883:H5K7HBCX2:1:1101:10177:74924	2225	#1	10340	0	5H22M7I36M55H	#2	198173993	0	ACCCTGACCCTGACCCTGACCGTAACCCTTAACCCTAACCCTAACCCGAACCCGAACCCGAACCC	array('B', [38, 40, 40, 38, 38, 38, 38, 40, 38, 38, 32, 38, 38, 40, 40, 38, 38, 38, 38, 38, 

In [17]:
import pysam


# Per-chromosome read count and mean read length for SRR8924581, followed by
# the total number of reads over the selected references.
# The context manager closes the file even if iteration fails part-way.
with pysam.AlignmentFile("data/bam/SRR8924581.bam", "r") as samfile:
    # ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', 'X', 'Y', 'M']
    contig_list = list(samfile.references)[:25]

    total_num = 0
    for contig in contig_list:
        num = 0
        length = 0
        for read in samfile.fetch(contig):
            num += 1
            # query_length is 0 when the record stores no sequence, whereas
            # len(read.query_sequence) would raise TypeError on None.
            length += read.query_length
        total_num += num
        # A reference with no aligned reads would otherwise divide by zero.
        mean_length = length / num if num else float("nan")
        print("Chromosome {:<4} #reads: {:<10} length/read: {:.2f}".format(
            contig, num, mean_length
        ))
print(total_num)

Chromosome 1    #reads: 14365821   length/read: 124.91
Chromosome 2    #reads: 11719391   length/read: 124.80
Chromosome 3    #reads: 8972469    length/read: 124.90
Chromosome 4    #reads: 6865470    length/read: 124.88
Chromosome 5    #reads: 7387724    length/read: 124.89
Chromosome 6    #reads: 8102749    length/read: 124.90
Chromosome 7    #reads: 8284689    length/read: 124.89
Chromosome 8    #reads: 5706703    length/read: 124.88
Chromosome 9    #reads: 5387576    length/read: 124.90
Chromosome 10   #reads: 6425285    length/read: 124.88
Chromosome 11   #reads: 8160671    length/read: 124.92
Chromosome 12   #reads: 7868704    length/read: 124.91
Chromosome 13   #reads: 3275567    length/read: 124.87
Chromosome 14   #reads: 5022530    length/read: 124.90
Chromosome 15   #reads: 5277768    length/read: 124.91
Chromosome 16   #reads: 6004828    length/read: 124.90
Chromosome 17   #reads: 6911851    length/read: 124.90
Chromosome 18   #reads: 2847626    length/read: 124.89
Chromosome

In [21]:
import pysam


# Per-chromosome read count and mean read length for SRR8924580, followed by
# the total number of reads over the selected references.
# The context manager closes the file even if iteration fails part-way.
with pysam.AlignmentFile("data/bam/SRR8924580.bam", "r") as samfile:
    # ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', 'X', 'Y', 'M']
    contig_list = list(samfile.references)[:25]

    total_num = 0
    for contig in contig_list:
        num = 0
        length = 0
        for read in samfile.fetch(contig):
            num += 1
            # query_length is 0 when the record stores no sequence, whereas
            # len(read.query_sequence) would raise TypeError on None.
            length += read.query_length
        total_num += num
        # A reference with no aligned reads would otherwise divide by zero.
        mean_length = length / num if num else float("nan")
        print("Chromosome {:<4} #reads: {:<10} length/read: {:.2f}".format(
            contig, num, mean_length
        ))
print(total_num)

Chromosome 1    #reads: 14153155   length/read: 124.91
Chromosome 2    #reads: 11614344   length/read: 124.81
Chromosome 3    #reads: 8924525    length/read: 124.90
Chromosome 4    #reads: 6866045    length/read: 124.88
Chromosome 5    #reads: 7368754    length/read: 124.89
Chromosome 6    #reads: 8027427    length/read: 124.91
Chromosome 7    #reads: 8076612    length/read: 124.89
Chromosome 8    #reads: 5691124    length/read: 124.88
Chromosome 9    #reads: 5651471    length/read: 124.90
Chromosome 10   #reads: 6387098    length/read: 124.88
Chromosome 11   #reads: 8049973    length/read: 124.92
Chromosome 12   #reads: 7769051    length/read: 124.91
Chromosome 13   #reads: 3291614    length/read: 124.88
Chromosome 14   #reads: 4955527    length/read: 124.90
Chromosome 15   #reads: 5218518    length/read: 124.91
Chromosome 16   #reads: 5950267    length/read: 124.90
Chromosome 17   #reads: 7060133    length/read: 124.90
Chromosome 18   #reads: 2848967    length/read: 124.90
Chromosome