## Pip installs

In [None]:
!pip install torch

In [None]:
!pip install transformers

In [None]:
!pip install datasets

In [None]:
!pip install pysam

In [None]:
!pip install HTSeq

In [5]:
!pip install enformer-pytorch>=0.5

In [3]:
!pip install polars

Collecting polars
  Downloading polars-0.16.18-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.4/16.4 MB[0m [31m30.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: polars
Successfully installed polars-0.16.18


## Loading DNA sequence data from BAMs to create dataset for fine tuning enformer

In [2]:
import torch
import polars as pl
from enformer_pytorch import Enformer, GenomeIntervalDataset
from datasets import concatenate_datasets, load_dataset
import os

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
### define variables
hg38_bed_path = "/scratch/users/sschulz/pta_on_normal/chr10.bed"
bam_dir= "/scratch/users/sschulz/pta_on_normal"
fasta_dir = bam_dir + '/fastas'
dataset_name='chr10_test_dataset.hf'
preprocess_bam=False

In [4]:
def preProcessBam(bam_dir, fasta_dir):
    '''
        take directory of bam files and make a subdirecotry fasta_dir 
        containing fasta files
    '''

    !sh bam2fasta.sh {bam_dir} {fasta_dir}

In [5]:
def makeEnformerDataset(fasta_path, bed_path):
    '''
        makes enformer dataset for genome intervals in bed file at bed_path, returns the dataset
    '''
    #filter_train = lambda df: df.filter(pl.col('column_4') == 'train')
    ds = GenomeIntervalDataset(
        bed_file = bed_path, ##<- this can just be the whole hg38.bed for all chromosomes
        fasta_file = fasta_path,  ## path to fasta file
        #filter_df_fn = filter_train,                        # filter dataframe function
        return_seq_indices = False,                          # return nucleotide indices (ACGTN) or one hot encodings
        shift_augs = (-2, 2),                               # random shift augmentations from -2 to +2 basepairs
        rc_aug = True,                                      # use reverse complement augmentation with 50% probability
        context_length = 196_608,
        return_augs = True                                  # return the augmentation meta data
        )
    return ds

In [6]:
def combineDatasets(dataset_list):
    '''
        uses hf datasets to concatenate a list of datasets
    '''
    return concatenate_datasets(dataset_list)

In [7]:
def trainEnformerPTA(dataset):
    '''
        should train the enformer model
    '''
    
    #filter_train = lambda df: df.filter(pl.col('column_4') == 'train')

    model = HeadAdapterWrapper(
                enformer = dataset,
                num_tracks = 128,
                post_transformer_embed = False   # by default, embeddings are taken from after the final pointwise block w/ conv -> gelu - but if you'd like the embeddings right after the transformer block with a learned layernorm, set this to True
            ).cuda()

    return model

In [19]:
if preprocess_bam:
    preProcessBam(bam_dir, fasta_dir)

-rw-r--r-- 1 sschulz cgawad 0 Apr  4 14:18 /scratch/users/sschulz/pta_on_normal/fastas/CART-MRD-BALL-PTA-NEXTERA-WGS-CCT5007Pt04-B8_S51.deduped_sorted.fasta
/scratch/users/sschulz/pta_on_normal/fastas/CART-MRD-BALL-PTA-NEXTERA-WGS-CCT5007Pt04-B8_S51.deduped_sorted.fasta
bam2fasta.sh: line 21: -s: command not found

Submitted batch job 14875317
-rw-r--r-- 1 sschulz cgawad 0 Apr  4 14:18 /scratch/users/sschulz/pta_on_normal/fastas/CART-MRD-BALL-PTA-NEXTERA-WGS-CCT5007Pt04-B8_S51.realigned_deduped_sorted.fasta
/scratch/users/sschulz/pta_on_normal/fastas/CART-MRD-BALL-PTA-NEXTERA-WGS-CCT5007Pt04-B8_S51.realigned_deduped_sorted.fasta
bam2fasta.sh: line 21: -s: command not found

Submitted batch job 14875318
-rw-r--r-- 1 sschulz cgawad 0 Apr  4 14:18 /scratch/users/sschulz/pta_on_normal/fastas/CART-MRD-BALL-PTA-NEXTERA-WGS-CCT5007Pt04-B8_S51.recalibrated_realigned_deduped_sorted.fasta
/scratch/users/sschulz/pta_on_normal/fastas/CART-MRD-BALL-PTA-NEXTERA-WGS-CCT5007Pt04-B8_S51.recalibrated_re

In [None]:
dataset_list = []
fai_dir=fasta_dir + '/fais'
!mkdir -p {fai_dir}
for fasta_name in os.listdir(fasta_dir):
    !mv {fasta_dir}/*.fai {fai_dir}
    fasta_path=fasta_dir + '/' + fasta_name
    if not os.path.splitext(fasta_path)[1] == 'fai':
        print('Now collecting dataset for: ' + fasta_path)
        dataset_list.append(makeEnformerDataset(fasta_path, hg38_bed_path))

mv: cannot stat ‘/scratch/users/sschulz/pta_on_normal/fastas/*.fai’: No such file or directory
Now collecting dataset for: /scratch/users/sschulz/pta_on_normal/fastas/CART-MRD-BALL-PTA-NEXTERA-WGS-CCT5007Pt04-A9_S57.realigned_deduped_sorted.fasta
Now collecting dataset for: /scratch/users/sschulz/pta_on_normal/fastas/CART-MRD-BALL-PTA-NEXTERA-WGS-CCT5007Pt04-B10_S65.realigned_deduped_sorted.fasta


In [None]:
combined_dataset = combineDatasets(dataset_list)

In [None]:
combined_dataset.save_to_disk(bam_dir + '/' + dataset_name)

In [None]:
model = trainEnformerPTA(combined_datset)

In [None]:
seq = ### enter a sequence to test
pred = model(seq, head = 'human') # (896, 5313) ###<- what does human mean here?