In [2]:
from fastai.basics import *
from fastai.text.all import *
from torch.utils.data import Dataset, TensorDataset
from Bio import SeqIO
from tqdm.notebook import tqdm

In [None]:
class PeakCountDataset(Dataset):
    """Dataset of one-hot encoded peak sequences paired with their counts.

    ``count_file`` is an ``.npz`` archive holding a ``peaks`` array of byte ids
    ``b'<chrom>-<start>-<end>'`` and a ``values`` array that is flattened and
    aligned with ``peaks`` by position. Each sequence is sliced from record
    ``'chr' + chrom`` of the FASTA ``genome_file`` and upper-cased.

    Args:
        count_file: path to the counts ``.npz`` archive.
        genome_file: path to the genome FASTA.
        target_len: optional peak width (``end - start``). When given, only
            peaks of exactly this width are kept, together with their counts.
            All kept sequences are stacked into a single tensor, so a file whose
            peaks have mixed widths needs this set (``None`` keeps every peak).

    Items are ``(one_hot, count)``; ``one_hot`` is an int64 tensor of shape
    ``(width, len(self.chars))``.
    """

    def __init__(self, count_file, genome_file, target_len=None):
        genes = SeqIO.to_dict(SeqIO.parse(genome_file,'fasta'))
        # NpzFile keeps the archive open until closed; read both arrays eagerly.
        with np.load(count_file) as npzf:
            peaks, counts = npzf['peaks'], npzf['values'].flatten()
        self.chars = ['A','T','C','G','N']
        self.int2char = dict(enumerate(self.chars))
        self.char2int = {ch: ii for ii, ch in self.int2char.items()}

        chromo, seqs, one_hots, keep = [], [], [], []
        for i, peak in enumerate(tqdm(peaks)):
            chrom, start, end = self._parse_peak(peak)
            if target_len is not None and end - start != target_len:
                continue
            seq = genes['chr'+chrom][start:end]
            chromo.append(chrom)
            seqs.append(seq)
            one_hots.append(self._encode(str(seq.seq)))
            keep.append(i)

        self.chromo, self.seqs = tuple(chromo), tuple(seqs)
        # Drop the counts of filtered-out peaks so inputs and targets stay aligned.
        self.counts = counts if target_len is None else counts[keep]
        self.genes = genes
        if one_hots:
            # torch.stack requires every sequence to have the same width (see target_len).
            self.one_hot = torch.stack(one_hots)
        else:
            self.one_hot = torch.zeros((0, target_len or 0, len(self.chars)), dtype=torch.long)

    @staticmethod
    def _parse_peak(peak):
        """Split a byte id ``b'<chrom>-<start>-<end>'`` into ``(chrom, start, end)``."""
        chrom, start, end = peak.decode().split('-')
        return chrom, int(start), int(end)

    def _encode(self, text):
        """One-hot encode ``text`` (case-insensitive) as int64 ``(len(text), len(self.chars))``.

        Characters outside ``self.chars`` (e.g. IUPAC ambiguity codes) are
        encoded as ``'N'`` instead of raising ``KeyError``.
        """
        unknown = self.char2int['N']
        idx = torch.tensor([self.char2int.get(ch, unknown) for ch in text.upper()], dtype=torch.long)
        return torch.nn.functional.one_hot(idx, len(self.chars))

    def __getitem__(self, index):
        """Return ``(one_hot[index], counts[index])``."""
        x = self.one_hot[index]
        y = self.counts[index]

        return x, y

    def __len__(self):
        """Number of retained peaks."""
        return len(self.counts)

In [None]:
# Build the per-peak dataset straight from the counts archive and the rn6 FASTA.
# NOTE(review): the peaks in this archive do not all share one width (only
# 403,624 of 411,969 are 601 bp), so stacking all of them into one tensor may
# fail — filter by width first.
ds = PeakCountDataset(
    count_file='data/new_pseudobulk/300bp_slop_counts/Astrocytes.npz',
    genome_file='data/rn6.fa',
)

In [3]:
# Load the peak ids (b'<chrom>-<start>-<end>'), their counts, and the rn6 genome
# as an in-memory {record id: SeqRecord} dict for random-access slicing.
# np.load on an .npz returns an NpzFile that holds the archive open, so read the
# arrays inside a `with` block and let it close the file.
with np.load('data/new_pseudobulk/300bp_slop_counts/Astrocytes.npz') as npzf:
    peaks, counts = npzf['peaks'], torch.tensor(npzf['values'].flatten())
genes = SeqIO.to_dict(SeqIO.parse('data/rn6.fa','fasta'))

In [20]:
# One-hot encode every peak that is exactly `target_len` bp wide and keep its
# count alongside it, so `peak_seqs[k]` and `keep_counts[k]` describe the same peak.
nucleotides = ['A','T','C','G','N']
int2char = dict(enumerate(nucleotides))
char2int = {ch: ii for ii, ch in int2char.items()}
target_len = 601

# Byte -> class-index lookup table. Any byte outside `nucleotides` (e.g. IUPAC
# ambiguity codes) falls into the 'N' class instead of raising KeyError, and the
# table encodes a whole sequence in one numpy call instead of a per-char loop.
byte2int = np.full(256, char2int['N'], dtype=np.uint8)
for ch, ii in char2int.items():
    byte2int[ord(ch)] = ii

encoded = []   # per-peak uint8 class indices, each of length target_len
keep_idx = []  # positions in `peaks` / `counts` of the peaks that were kept
for i, peak in tqdm(enumerate(peaks), total=len(peaks)):
    c, start, end = peak.decode().split('-')
    start, end = int(start), int(end)
    if (end-start) == target_len:
        seq = genes['chr'+c][start:end]
        seq_bytes = str(seq.seq).upper().encode('ascii', errors='replace')
        encoded.append(byte2int[np.frombuffer(seq_bytes, dtype=np.uint8)])
        keep_idx.append(i)

# Index the counts once instead of gathering hundreds of thousands of 0-d tensors.
keep_counts = counts[torch.tensor(keep_idx, dtype=torch.long)]
codes = (np.stack(encoded, axis=0) if encoded
         else np.empty((0, target_len), dtype=np.uint8))
# uint8 one-hot of shape (n_kept, target_len, len(nucleotides)). F.one_hot would
# produce int64, i.e. 8 bytes per element (~9.7 GB for ~404k x 601 x 5 here).
peak_seqs = torch.from_numpy(np.eye(len(nucleotides), dtype=np.uint8)[codes])

  0%|          | 0/411969 [00:00<?, ?it/s]

In [24]:
# Shape after moving the one-hot axis to dim 1 (channels-first layout).
peak_seqs.transpose(1, 2).shape

torch.Size([403624, 5, 601])

In [25]:
# Pair channels-first float32 one-hot inputs with their counts.
ds = TensorDataset(peak_seqs.transpose(1, 2).to(torch.float32), keep_counts)

In [26]:
# Cache the assembled dataset so later sessions can skip the genome pass.
torch.save(obj=ds, f='peaks.pt')

In [None]:
# Reload the cached dataset. peaks.pt is a pickled TensorDataset written by this
# notebook (a trusted local file). PyTorch >= 2.6 defaults torch.load to
# weights_only=True, which refuses arbitrary pickled objects such as
# TensorDataset, so opt out explicitly where the keyword exists.
try:
    ds = torch.load('peaks.pt', weights_only=False)
except TypeError:  # torch < 1.13 has no `weights_only` keyword
    ds = torch.load('peaks.pt')

In [None]:
# Batches of 16 samples over the current `ds`.
dl = DataLoader(dataset=ds, bs=16)

In [None]:
# Grab one batch to sanity-check the input and target shapes.
x, y = dl.one_batch()
tuple(t.shape for t in (x, y))

In [None]:
# Show the raw sequence string of `seq`.
# NOTE(review): `seq` is not defined in this cell — it is whatever the last
# matching iteration of the peak-encoding loop left behind, so the output
# depends on execution history. TODO: bind the sequence explicitly here.
str(seq.seq)

In [None]:
# Tokenize the sequence with the SentencePiece model; out_type='str' requests
# the pieces as strings rather than ids.
# NOTE(review): `sp` is only created by a later cell (SentencePieceProcessor on
# m.model), and `seq` leaks from the encoding loop, so this cell fails under
# Restart & Run All — move it below the tokenizer cells.
sp.encode(str(seq.seq), out_type='str')

In [None]:


def get_len(x):
    """Return the width (end - start) of a peak id ``b'<chrom>-<start>-<end>'``.

    For peaks wider than 1000 bp the corresponding slice of the module-level
    ``genes`` dict is printed so outliers can be eyeballed.
    """
    chrom, lo, hi = x.decode().split('-')
    lo, hi = int(lo), int(hi)
    width = hi - lo
    if width > 1000:
        print(genes['chr' + chrom][lo:hi])
    return width

# Width of every peak as a Series, for the histogram below.
s = pd.Series(list(map(get_len, peaks)))

In [None]:
# Distribution of peak widths (end - start, in bp) as returned by get_len.
ax = s.hist()
ax.set(title='Peak width distribution', xlabel='peak width (bp)', ylabel='number of peaks');

In [None]:
import sentencepiece as spm

In [None]:
# Write the upper-cased genome as a plain-text corpus for SentencePiece.
# - Use the module-level `genes` dict: `ds` has been rebound to a
#   TensorDataset / torch.load result by this point and has no `.genes`.
# - SentencePiece skips any input line longer than `max_sentence_length`
#   (4192 bytes by default), so whole chromosomes on one line would all be
#   dropped; emit each chromosome in fixed-size chunks instead.
CORPUS_LINE_LEN = 4096
with open('rn6_out.txt','w') as rn6_file:
    for chrom in tqdm(genes.keys()):
        chrom_seq = str(genes[chrom].seq).upper()
        for pos in range(0, len(chrom_seq), CORPUS_LINE_LEN):
            rn6_file.write(chrom_seq[pos:pos + CORPUS_LINE_LEN] + '\n')

In [None]:
# Train a 1000-piece SentencePiece vocabulary on the genome corpus; the model is
# written under the prefix 'm' (m.model is loaded by the tokenizer cell).
# NOTE(review): the trainer may hold every input line in memory — for a
# genome-sized corpus consider input_sentence_size / shuffle_input_sentence.
spm.SentencePieceTrainer.train(
    input='rn6_out.txt',
    model_prefix='m',
    vocab_size=1000,
)

In [None]:
# Load the tokenizer produced by the SentencePieceTrainer cell (model_prefix='m').
sp = spm.SentencePieceProcessor(model_file='m.model')

In [None]:
# Peek at the first raw peak id (bytes of the form b'<chrom>-<start>-<end>').
peaks[0]

In [None]:
# `encoded` holds one array per kept peak (~400k), so displaying the whole list
# floods the notebook output; show its size and the first entry instead.
len(encoded), encoded[:1]

In [None]:
# NOTE(review): IPython help lookup left over from exploration; remove before
# sharing the notebook.
sp.encode?

In [None]:
# Inspect the input of the first sample of whichever `ds` is currently bound.
# NOTE(review): `ds` is reassigned by several cells (PeakCountDataset,
# TensorDataset, torch.load), so this output depends on execution order.
ds[0][0]