# scbasset implementation

## outline
**1. write static method that:**
- downloads DNA seq
- add sequences from fasta to anndata
- adds regions x 1344 sparse matrix
    - adata.varm['dna'] and ints of each a,c,t,g (0,1,2,3)
- take transpose: bdata = adata.transpose()
- SCBASSET.add_seqs(adata, species:str)

**2. implement model in pytorch lightning**
- use anndataloader
    - x.varm['codes'] --> ADL --> for md in ADL: mb['codes'] (regions x 1334 x 4 matrix) and mb['x'] (loss)
    
## resources
- [scvi-intro-colab](https://colab.research.google.com/drive/1RSm8IU3NK-xGNiRTDcTnX8Hbcj0gS_1S)
- [scvi-data-loading](https://colab.research.google.com/drive/1iQOo2SoqNSC_uRPt9jiTA8xUaV4-909K#scrollTo=5-_VxGJ9xWaW)
- [scvi-peakvi](https://colab.research.google.com/github/scverse/scvi-tutorials/blob/0.17.4/PeakVI.ipynb)

In [2]:
import os
from pathlib import Path
import scvi
import anndata
from anndata import AnnData
import pandas as pd
import numpy as np
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pysam
import random
import scipy

Global seed set to 0
  new_rank_zero_deprecation(
  return new_rank_zero_deprecation(*args, **kwargs)


## preprocess
### download fasta file
- use gget
- from scbasset, make_bed_seqs_from_df
    - reimplement this function
    - pysam package
- scvi.data.read_10x_atac
- adata.var has the coordinates
    - index is regions

In [3]:
# download sample data
## 10x atac pbmc 5k nextgem data
!wget https://cf.10xgenomics.com/samples/cell-atac/1.2.0/atac_pbmc_5k_nextgem/atac_pbmc_5k_nextgem_filtered_peak_bc_matrix.tar.bgz
!tar -xvf atac_pbmc_5k_nextgem_filtered_peak_bc_matrix.tar.bgz

--2022-11-29 19:23:51--  https://cf.10xgenomics.com/samples/cell-atac/1.2.0/atac_pbmc_5k_nextgem/atac_pbmc_5k_nextgem_filtered_peak_bc_matrix.tar.bgz
Resolving cf.10xgenomics.com (cf.10xgenomics.com)... 2606:4700::6812:ad, 2606:4700::6812:1ad, 104.18.1.173, ...
Connecting to cf.10xgenomics.com (cf.10xgenomics.com)|2606:4700::6812:ad|:443... connected.
HTTP request sent, awaiting response... 403 Forbidden
2022-11-29 19:23:55 ERROR 403: Forbidden.

tar: Error opening archive: Failed to open 'atac_pbmc_5k_nextgem_filtered_peak_bc_matrix.tar.bgz'


In [4]:
# read data
adata = scvi.data.read_10x_atac("filtered_peak_bc_matrix")
adata

  return AnnData(data.tocsr(), var=coords, obs=cell_annot)


AnnData object with n_obs × n_vars = 4585 × 115554
    obs: 'batch_id'
    var: 'chr', 'start', 'end'

In [5]:
# download sample fasta file
## from scbasset tutorial, homo sapiens motif fasta file
!wget https://storage.googleapis.com/scbasset_tutorial_data/Homo_sapiens_motif_fasta.tar.bgz
!tar -xvf Homo_sapiens_motif_fasta.tar.bgz

--2022-11-29 19:25:09--  https://storage.googleapis.com/scbasset_tutorial_data/Homo_sapiens_motif_fasta.tar.bgz
Resolving storage.googleapis.com (storage.googleapis.com)... 2607:f8b0:4005:814::2010, 2607:f8b0:4005:80d::2010, 2607:f8b0:4005:811::2010, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|2607:f8b0:4005:814::2010|:443... connected.
HTTP request sent, awaiting response... 404 Not Found
2022-11-29 19:25:09 ERROR 404: Not Found.

tar: Error opening archive: Failed to open 'Homo_sapiens_motif_fasta.tar.bgz'


In [6]:
# import subprocess

# download_savepath = '../data/downloads'
# os.makedirs(download_savepath, exist_ok=True)

# if not os.path.exists('%s/Homo_sapiens_motif_fasta.tar.bgz'%download_savepath):
#     subprocess.run('wget -P %s https://storage.googleapis.com/scbasset_tutorial_data/Homo_sapiens_motif_fasta.tar.bgz'%download_savepath, shell=True)
# subprocess.run('tar -xzf %s/Homo_sapiens_motif_fasta.tar.bgz -C %s/'%(download_savepath, download_savepath), shell=True)

In [7]:
# fasta path for example peaks
fasta_path = "Homo_sapiens_motif_fasta/example_peaks.fasta"

In [8]:
!wget https://cf.10xgenomics.com/samples/cell-atac/1.2.0/atac_pbmc_5k_nextgem/atac_pbmc_5k_nextgem_filtered_peak_bc_matrix.tar.gz
!tar -xvf atac_pbmc_5k_nextgem_filtered_peak_bc_matrix.tar.gz

--2022-11-29 19:25:11--  https://cf.10xgenomics.com/samples/cell-atac/1.2.0/atac_pbmc_5k_nextgem/atac_pbmc_5k_nextgem_filtered_peak_bc_matrix.tar.gz
Resolving cf.10xgenomics.com (cf.10xgenomics.com)... 2606:4700::6812:1ad, 2606:4700::6812:ad, 104.18.1.173, ...
Connecting to cf.10xgenomics.com (cf.10xgenomics.com)|2606:4700::6812:1ad|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 114015463 (109M) [application/x-tar]
Saving to: ‘atac_pbmc_5k_nextgem_filtered_peak_bc_matrix.tar.gz.2’


2022-11-29 19:27:23 (847 KB/s) - ‘atac_pbmc_5k_nextgem_filtered_peak_bc_matrix.tar.gz.2’ saved [114015463/114015463]

x filtered_peak_bc_matrix/
x filtered_peak_bc_matrix/matrix.mtx
x filtered_peak_bc_matrix/peaks.bed
x filtered_peak_bc_matrix/barcodes.tsv


In [9]:
# read the count matrix into a sparse matrix, and the cell and region annotations as pandas DataFrames
counts = scipy.io.mmread("filtered_peak_bc_matrix/matrix.mtx").T
regions = pd.read_csv("filtered_peak_bc_matrix/peaks.bed", sep='\t', header=None, names=['chr','start','end'])
cells = pd.read_csv("filtered_peak_bc_matrix/barcodes.tsv", header=None, names=['barcodes'])

# then initialize a new AnnData object
adata = anndata.AnnData(X=counts, obs=cells, var=regions)
adata

  adata = anndata.AnnData(X=counts, obs=cells, var=regions)


AnnData object with n_obs × n_vars = 4585 × 115554
    obs: 'barcodes'
    var: 'chr', 'start', 'end'

In [10]:
adata

AnnData object with n_obs × n_vars = 4585 × 115554
    obs: 'barcodes'
    var: 'chr', 'start', 'end'

In [11]:
# do i need this?

def sequence_code(seq, seq_len):
    # get middle seq_len nucleotides
    if seq_len is None:
        seq_start = 0
        seq_len = len(seq)
    else:
        if seq_len <= len(seq):
            seq_start = 0
            seq_trim = (len(seq) - seq_len) // 2
            seq = seq[seq_trim:seq_trim+seq_len]
        else:
            seq_start = (seq_len - len(seq)) // 2
    seq = seq.upper()

    # initialize seq_code
    seq_code = np.zeros((seq_len, ), dtype="int8")

    # sequence code matrix: seq_len x 4
    for i in range(seq_len):
        if i >= seq_start and i - seq_start < len(seq):
            nt = seq[i - seq_start]
            if nt == "A":
                seq_code[i] = 0
            elif nt == "C":
                seq_code[i] = 1
            elif nt == "G":
                seq_code[i] = 2
            elif nt == "T":
                seq_code[i] = 3
            else:
                seq_code[i] = random.randint(0, 3)
    # 4 x seq_len
    return seq_code.T

In [12]:
sequence_code('ATCGTGCATTCGAT', 10)

array([1, 2, 3, 2, 1, 0, 3, 3, 1, 2], dtype=int8)

In [13]:
adata.var

Unnamed: 0,chr,start,end
0,chr1,10404,10411
1,chr1,237567,237947
2,chr1,565116,565538
3,chr1,569178,569639
4,chr1,713460,715296
...,...,...,...
115549,chrY,23602417,23602787
115550,chrY,23898794,23899450
115551,chrY,28816591,28817535
115552,chrY,58827188,58827516


In [16]:
adata.varm

AxisArrays with keys: 

In [14]:
# @static_method
def add_sequence_to_adata(adata: AnnData, 
                          fasta_file: Path, 
                          seq_len: int, 
                          chr_var_key: str='chr',
                          start_var_key: str='start',
                          end_var_key: str='end',
                          sequence_varm_key='sequence', 
                          code_varm_key='code'
                         ) -> None:
    # assume we start with the annData object
        # read_10_atac returns anndata object, where var=coordinates, obs=cell annotations

    # get fasta file
    fasta_open = pysam.Fastafile(fasta_file)
    
    seqs_dna = []
    seqs_coords = []

    for i in range(adata.shape[1]):
        chr = adata.var.loc[i, chr_var_key]
        start = int(adata.var.loc[i, start_var_key])
        end = int(adata.var.loc[i, end_var_key])

        # determine sequence limits
        mid = (start + end) // 2
        seq_start = mid - seq_len // 2
        seq_end = seq_start + seq_len

        seqs_coords.append((chr, seq_start, seq_end))

        # initialize sequence
        dna = ""

        # get dna
        dna += fasta_open.fetch(chr, seq_start, seq_end).upper()
        seqs_dna.append(dna)

    fasta_open.close()


    # function for sequence --> sequence codes
    # initialize seq_code
    seq_code = np.zeros((seq_len, ), dtype="int8")

    # sequence code matrix: seq_len x 4
    for i in range(seq_len):
        if i >= seq_start and i - seq_start < len(seqs_dna):
            nt = seqs_dna[i - seq_start]
            if nt == "A":
                seq_code[i] = 0
            elif nt == "C":
                seq_code[i] = 1
            elif nt == "G":
                seq_code[i] = 2
            elif nt == "T":
                seq_code[i] = 3
            else:
                seq_code[i] = random.randint(0, 3)
    # 4 x seq_len
    seqs_dna_code = seq_code # or seq_code.T?

    # add seqs and codes to adata.varm
    adata.varm[sequence_varm_key] = seqs_dna
    adata.varm[code_varm_key] = seqs_dna_code

In [15]:
add_sequence_to_adata(adata, fasta_file=fasta_path, seq_len=1344)

KeyError: 0

In [None]:
# def sparse(adata, h5_name, fasta_file, seq_len, batch_size):
#     ad = read_10x_atac(adata)
#     n_peaks = ad.shape[1]
#     bed_df = ad.var.loc[:, ['chr', 'start', 'end']] # bed file
#     bed_df.index = np.arange(bed_df.shape[0])
#     n_batch = int(np.floor(n_peaks / batch_size))
#     batches = np.array_split(np.arange(n_peaks), n_batch)

## model
- take transpose: bdata = adata.transpose()
- use anndataloader
    - x.varm['codes'] --> ADL --> for md in ADL: mb['codes'] (regions x 1334 x 4 matrix) and mb['x'] (loss)

In [17]:
# annDataLoader
from scvi.dataloaders import AnnDataLoader

In [None]:
class CNN(nn.Module):
    def cnn(n_cells, seq_len=1344, bottleneck_size):
        """cnn"""
        def conv_tower(inputs, filters_init, filters_end=None, filters_mult=None, divisible_by=1, repeat=1, **kwargs):
            """this function was taken from scbasset"""
            def round(x):
                return int(np.round(x / divisible_by) * divisible_by)
        
            current = inputs
            rep_filters = filters_init

            if filters_mult is None:
                filters_mult = np.exp(np.log(filters_end / filters_init) / (repeat - 1))

            for i in range(repeat):
                current = nn.Conv1d(in_channels=current, filters=round(rep_filters), **kwargs)
                rep_filters *= filters_mult

            return current


        def __init__(self):
            super(CNN, self).__init__()
            # input_shape=(seq_len=1344, 4) so input shape is 1344x4
            # stochasticshift(3)?

            # for in_channels, we want [batch_size, seq_length, features], in_channels = 3?
            self.conv1 = nn.Conv1d(in_channels=3, out_channels=288, kernel_size=17)
            # self.pool = nn.MaxPool1d(kernel_size=3)
            
            # conv tower (reducing conv blocks) ?
            self.conv2 = conv_tower(inputs=__, filters_init=288, filters_mult=1.122, repeat=6, kernel_size=5)
            # conv block
            self.conv3 = nn.Conv1d(in_channels=___, out_channels=256, kernel_size=1)

            # dense conv block (linear)
            # self.fc1 = nn.Linear(in_features=__, out_features=bottleneck_size)
            # self.fc2 = nn.Linear(in_features=bottleneck_size, out_features=n_cells)

            self.fc1 = scvi.nn.FCLayers(n_in=__, in_out=bottleneck_size, dropout_rate=0.2, activation_fn=F.gelu)
            self.fc2 = scvi.nn.FCLayers(n_in=bottleneck_size, in_out=n_cells, dropout_rate=0.2, activation_fn=F.gelu)

            # switchreverse?


        def forward(self, x):
            x = F.gelu(self.conv1(x))
            x = F.max_pool1d(x, 3)
            x = F.gelu(self.conv2(x))
            x = F.max_pool1d(x, 2)
            x = F.gelu(self.conv3(x))
            x = F.max_pool1d(x, 1)
            x = F.gelu(self.fc1(x))
            # x = F.dropout(x, p=0.2)
            x = F.sigmoid(self.fc2(x))

            return x
