## Pip installs

In [None]:
!pip install torch

In [None]:
!pip install transformers

In [None]:
!pip install datasets

In [None]:
!pip install pysam

In [None]:
!pip install HTSeq

In [5]:
!pip install enformer-pytorch>=0.5

In [3]:
!pip install polars

Collecting polars
  Downloading polars-0.16.18-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.4/16.4 MB[0m [31m30.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: polars
Successfully installed polars-0.16.18


In [87]:
!pip install scikit-learn



## Module imports and variable definitions

In [1]:
import torch
import polars as pl
from enformer_pytorch import Enformer, GenomeIntervalDataset
from datasets import concatenate_datasets, load_dataset
import os
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
### define variables
hg38_bed_path = "/scratch/users/sschulz/pta_on_normal/Homo_sapiens_assembly38_n25chr.bed"
bam_dir= "/scratch/users/sschulz/pta_on_normal"
fasta_dir = bam_dir + '/fastas'
dataset_name='test_dataset.hf'
preprocess_bam=True
vcf_file=bam_dir +'/CARTPt04_Scan2_svc_merged_extract_snp.vcf.gz'
ref_fasta=bam_dir + '/Homo_sapiens_assembly38.fasta'
results_dir=bam_dir
output_fasta_fn='chr_test_fasta.fasta'
sbatch=0

## Function definitions

In [3]:
def preProcessVCF(vcf_path, ref_fasta, results_dir, output_fasta_fn, sbatch):
    '''
        take directory of a combined vcf file with somatic mutations called for normal cells
        with pta run on them and make a fasta file 
        
        TO-DO: make something that generates the vcf for you if given normal bam files
    '''

    !sh vcf2fasta.sh --vcf_path $vcf_path --ref_fasta $ref_fasta --results_dir $results_dir --output_name $output_fasta_fn --sbatch $sbatch

In [4]:
def makeEnformerDataset(fasta_path, bed_path):
    '''
        makes enformer dataset for genome intervals in bed file at bed_path, returns the dataset
    '''
    #filter_train = lambda df: df.filter(pl.col('column_4') == 'train')
    ds = GenomeIntervalDataset(
        bed_file = bed_path, ##<- this can just be the whole hg38.bed for all chromosomes
        fasta_file = fasta_path,  ## path to fasta file
        #filter_df_fn = filter_train,                        # filter dataframe function
        return_seq_indices = False,                          # return nucleotide indices (ACGTN) or one hot encodings
        shift_augs = (-2, 2),                               # random shift augmentations from -2 to +2 basepairs
        rc_aug = True,                                      # use reverse complement augmentation with 50% probability
        context_length = 196_608,
        return_augs = True                                  # return the augmentation meta data
        )
    return ds

In [5]:
def trainEnformerPTA(dataset):
    '''
        should train the enformer model
    '''
    
    #filter_train = lambda df: df.filter(pl.col('column_4') == 'train')

    model = HeadAdapterWrapper(
                enformer = dataset,
                num_tracks = 128,
                post_transformer_embed = False   # by default, embeddings are taken from after the final pointwise block w/ conv -> gelu - but if you'd like the embeddings right after the transformer block with a learned layernorm, set this to True
            ).cuda()

    return model

In [7]:
def makeOneHot(fastaFilePath, num_lines):
    '''
       Give path to a fasta file and the number of lines you want to turn into one hot encodings, get the one_hot encoded
       tensor
    '''
    with open(fastaFilePath, 'r') as input_file:
        try:
            head = [next(input_file).strip('\n') for _ in range(num_lines)]
        except StopIteration as e:
            print("Exceeded length of input file, stopping.")
    seq = str_to_one_hot(head)
    input_file.close()
    return(seq)
        

In [8]:
## this was all taken from lucidrains/enformer-pytorch, i have no clue if i could just source this somehow and i don't care

def torch_fromstring(seq_strs):
    def cast_list(t):
        return t if isinstance(t, list) else [t]
    batched = not isinstance(seq_strs, str)
    seq_strs = cast_list(seq_strs)
    np_seq_chrs = list(map(lambda t: np.fromstring(t, dtype = np.uint8), seq_strs))
    seq_chrs = list(map(torch.from_numpy, np_seq_chrs))
    return torch.stack(seq_chrs) if batched else seq_chrs[0]

def str_to_one_hot(seq_strs):
    seq_chrs = torch_fromstring(seq_strs)
    return one_hot_embed[seq_chrs.long()]

one_hot_embed = torch.zeros(256, 4)
one_hot_embed[ord('a')] = torch.Tensor([1., 0., 0., 0.])
one_hot_embed[ord('c')] = torch.Tensor([0., 1., 0., 0.])
one_hot_embed[ord('g')] = torch.Tensor([0., 0., 1., 0.])
one_hot_embed[ord('t')] = torch.Tensor([0., 0., 0., 1.])
one_hot_embed[ord('n')] = torch.Tensor([0., 0., 0., 0.])
one_hot_embed[ord('A')] = torch.Tensor([1., 0., 0., 0.])
one_hot_embed[ord('C')] = torch.Tensor([0., 1., 0., 0.])
one_hot_embed[ord('G')] = torch.Tensor([0., 0., 1., 0.])
one_hot_embed[ord('T')] = torch.Tensor([0., 0., 0., 1.])
one_hot_embed[ord('N')] = torch.Tensor([0., 0., 0., 0.])
one_hot_embed[ord('.')] = torch.Tensor([0.25, 0.25, 0.25, 0.25])

## Preprocessing

In [33]:
## preprocessing the target fasta
vcf_file=bam_dir +'/CARTPt04_Scan2_svc_merged_extract_snp.vcf.gz'
ref_fasta=bam_dir + '/Homo_sapiens_assembly38.fasta'
results_dir=bam_dir
output_fasta_fn='final_fasta.fasta'
if preprocess_bam:
    preProcessVCF(vcf_file, ref_fasta, results_dir, output_fasta_fn, sbatch)

Note: the --sample option not given, applying all records regardless of the genotype
The site chr1:159782470 overlaps with another variant, skipping...
The site chr1:206939504 overlaps with another variant, skipping...
The site chr2:32811257 overlaps with another variant, skipping...
The site chr5:141375502 overlaps with another variant, skipping...
The site chr6:349203 overlaps with another variant, skipping...
The site chr6:42236345 overlaps with another variant, skipping...
The site chr7:142760369 overlaps with another variant, skipping...
The site chr7:142772094 overlaps with another variant, skipping...
The site chr9:33796768 overlaps with another variant, skipping...
The site chr9:33796801 overlaps with another variant, skipping...
The site chr9:106926826 overlaps with another variant, skipping...
The site chr11:1017451 overlaps with another variant, skipping...
The site chr11:1017461 overlaps with another variant, skipping...
The site chr11:1017466 overlaps with another variant,

In [None]:
## preprocessing the input fasta
vcf_file=bam_dir +'/CARTPt04_Scan2_svc_merged_extract_snp.vcf.gz'
ref_fasta=bam_dir + '/Homo_sapiens_assembly38.fasta'
results_dir=bam_dir
output_fasta_fn='test_fasta.fasta'
if preprocess_bam:
    preProcessVCF(vcf_file, ref_fasta, results_dir, output_fasta_fn, sbatch)

## get one-hot encodings with torch

In [11]:
lines_of_fasta_file=3000
targetFastaFilePath = bam_dir + '/final_fasta.fasta'

target = makeOneHot(targetFastaFilePath, lines_of_fasta_file)

  np_seq_chrs = list(map(lambda t: np.fromstring(t, dtype = np.uint8), seq_strs))


In [None]:
lines_of_fasta_file=3000
inputFastaFilePath = bam_dir + '/test_fasta.fasta'

input_seq = makeOneHot(targetFastaFilePath, lines_of_fasta_file)

In [12]:
input_seq = input_seq.cuda()

3000

In [None]:
target = target.cuda()

## Run the model on new DNA tracks with somatic mutations detected in normal cells

In [None]:
import torch
from enformer_pytorch import Enformer, seq_indices_to_one_hot

# will need to figure out correct way to configure hyperparams
model = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313), #<- tbh idk what the correct number of output heads should be
    target_length = 3000,
)


output, embeddings = model(seq, return_embeddings = True)

'''
    target should EITHER be the one hot encodings of an input fasta with many artifacts, or the weights from 
    just running the model on the one hot encodings without a target, not exactly sure which but try the 
    one hot encodigns first
    
    input_seq should be the one-hot embeddings of any fasta
    
    corr_coef should be higher if many artifacts, lower if not many artifacts. check against a bulk sequencing sample,
    another fasta that should have normal artifacts in it, and a single cell DNA sequencing sample
'''

loss, embeddings = model(
    input_seq,
    head = 'human',
    target = target,
    return_embeddings = True
)

loss.backward()

# after much training

corr_coef = model(
    seq,
    head = 'human',
    target = target,
    return_corr_coef = True
)

corr_coef # pearson R, used as a metric in the paper. For our first test, this should return something close to 1, since 
          # we are running it on the exact same fasta file as a sanity check


## Run the model with finetuning

In [26]:
import torch
from enformer_pytorch import Enformer
from enformer_pytorch.finetune import HeadAdapterWrapper

enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')

model = HeadAdapterWrapper(
    enformer = enformer,
    num_tracks = 128,
    post_transformer_embed = False   # by default, embeddings are taken from after the final pointwise block w/ conv -> gelu - but if you'd like the embeddings right after the transformer block with a learned layernorm, set this to True
).cuda()

target = torch.randn(1, 200, 128).cuda()  # 128 tracks

loss = model(seq, target = target)
loss.backward()

Downloading (…)lve/main/config.json: 100%|████████████████████████████████████████████████████████████████████████████████████| 439/439 [00:00<00:00, 273kB/s]
Downloading pytorch_model.bin: 100%|█████████████████████████████████████████████████████████████████████████████████████| 1.01G/1.01G [00:11<00:00, 87.2MB/s]


OSError: [Errno 28] No space left on device

In [20]:
pred = model(seq, head = 'human') # (3000, 3000)

In [21]:
pred

tensor([[[0.7276, 0.7898, 0.7651,  ..., 0.8316, 0.6573, 0.5948]],

        [[0.7207, 0.7824, 0.8117,  ..., 0.7090, 0.5889, 0.5795]],

        [[0.5686, 0.8286, 0.8646,  ..., 0.7037, 0.6471, 0.6555]],

        ...,

        [[0.7175, 0.7805, 0.7743,  ..., 0.7038, 0.8123, 0.5317]],

        [[0.6933, 0.6910, 0.7489,  ..., 0.7125, 0.6352, 0.6702]],

        [[0.6963, 0.7716, 0.8335,  ..., 0.6973, 0.6662, 0.6498]]],
       grad_fn=<SoftplusBackward0>)

## Old stuff

In [None]:
dataset_list = []
fai_dir=fasta_dir + '/fais'
!mkdir -p {fai_dir}
for fasta_name in os.listdir(fasta_dir):
    !mv {fasta_dir}/*.fai {fai_dir}
    fasta_path=fasta_dir + '/' + fasta_name
    if not os.path.splitext(fasta_path)[1] == 'fai':
        print('Now collecting dataset for: ' + fasta_path)
        dataset_list.append(makeEnformerDataset(fasta_path, hg38_bed_path))

mv: cannot stat ‘/scratch/users/sschulz/pta_on_normal/fastas/*.fai’: No such file or directory
Now collecting dataset for: /scratch/users/sschulz/pta_on_normal/fastas/CART-MRD-BALL-PTA-NEXTERA-WGS-CCT5007Pt04-A9_S57.realigned_deduped_sorted.fasta
Now collecting dataset for: /scratch/users/sschulz/pta_on_normal/fastas/CART-MRD-BALL-PTA-NEXTERA-WGS-CCT5007Pt04-B10_S65.realigned_deduped_sorted.fasta


In [8]:
def combineDatasets(dataset_list):
    '''
        uses hf datasets to concatenate a list of datasets
    '''
    return concatenate_datasets(dataset_list)

In [None]:
combined_dataset = combineDatasets(dataset_list)

In [None]:
combined_dataset.save_to_disk(bam_dir + '/' + dataset_name)

In [None]:
model = trainEnformerPTA(combined_datset)

In [None]:
seq = ### enter a sequence to test
pred = model(seq, head = 'human') # (896, 5313) ###<- what does human mean here?