## Pip installs

In [None]:
!pip install torch

In [None]:
!pip install transformers

In [None]:
!pip install datasets

In [None]:
!pip install pysam

In [None]:
!pip install HTSeq

In [None]:
!pip install enformer-pytorch>=0.5

In [None]:
!pip install polars

## Load modules and variable declarations

In [1]:
from dnaDataSet import dnaDataSet

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import polars as pl
from enformer_pytorch import Enformer, GenomeIntervalDataset
from datasets import concatenate_datasets, load_dataset
import os
import numpy as np
import pandas as pd
import pysam
import json 
import pickle

In [None]:
from llama_cpp import Llama

In [None]:
#can change the model path to any quantized model binary
#TODO - make a script version of this with option for model path, relative context length, tsv_path, bam_path
modelPath="/home/shawn/Programming/ai_stuff/llama.cpp/models/30B/ggml-model-q4_0.bin"
memoryDir="/home/shawn/datasets/llm_memory"

In [None]:
llm = Llama(model_path=model_path)

In [None]:
tsv_dir= "/home/shawn/datasets/enformer_data"
bed_path = "/scratch/users/sschulz/pta_on_normal/chr10.bed"


In [None]:
#good testing but use a gvcf instead containing all known mutations first
tsv_path = tsv_dir + "/CARTPt04_Scan2_svc_merged_extract_snp.hg38_multianno.tsv"

In [None]:
tsv = pd.read_table(tsv_path, sep='\t')

In [None]:
tsv['CHROM'][0]

In [None]:
bam_path=tsv_dir + '/CART-MRD-BALL-PTA-NEXTERA-WGS-CCT5007Pt04-D4_S26.realigned_deduped_sorted.bam'

## Get the Ensembl GVF files (stored in the `gvcf` directory)

In [None]:
sav_dir = os.getcwd()

In [None]:
gvcf_dir = '/scratch/users/sschulz/pta_on_normal/gvcf'
download_log_dir='scratch/users/sschulz/pta_on_normal/gvcf/logs/'

In [None]:
os.chdir(gvcf_dir)

In [None]:

new_list = [chrom for sublist in [('MT', 'X', 'Y'), list(range(1,23))] for chrom in sublist]
new_list

In [None]:
def downloadEnsembleGVCFs(output_dir, download_log_dir):
    '''
        Submit one SLURM job per chromosome (MT, X, Y, 1..22) that downloads the
        Ensembl release-109 "incl_consequences" GVF file for that chromosome.

        output_dir: directory the jobs are submitted from, so wget writes there
        download_log_dir: value passed to sbatch's --out option
            NOTE(review): sbatch presumably expects a file name pattern here
            rather than a directory - confirm.

        Each command is printed before it is submitted. The working directory of
        the notebook itself is left untouched.

        Raises FileNotFoundError if sbatch is not on PATH or output_dir does not exist.
    '''
    # Local import so this cell does not depend on the notebook's import cell.
    import subprocess

    chromosomes = ['MT', 'X', 'Y', *range(1, 23)]
    for i in chromosomes:
        url = f"https://ftp.ensembl.org/pub/release-109/variation/gvf/homo_sapiens/homo_sapiens_incl_consequences-chr{i}.gvf.gz"
        print(f"sbatch -c 2 --mem=32G -p cgawad --out={download_log_dir} --wrap='wget {url}'")
        # Argument list instead of a shell string: no quoting pitfalls (the old
        # `--wrap=f"wget ..."` reached the shell as the command `fwget ...`).
        # cwd= replaces the chdir/chdir-back pair and cannot leave the notebook
        # in the wrong directory if a submission fails.
        subprocess.run(
            [
                "sbatch", "-c", "2", "--mem=32G", "-p", "cgawad",
                f"--out={download_log_dir}",
                f"--wrap=wget {url}",
            ],
            cwd=output_dir,
            check=False,  # like the `!` shell escape, a failed submission does not raise
        )

In [None]:
# Submit the download jobs using the directories configured above.
downloadEnsembleGVCFs(gvcf_dir, download_log_dir)

In [None]:
# Dry run: only prints the sbatch command for MT, X and Y; nothing is submitted here.
for i in ('MT', 'X', 'Y'):
    print(f"sbatch -c 2 --mem=32G -p cgawad --out='/scratch/users/sschulz/pta_on_normal/gvcf/logs/' --wrap='wget https://ftp.ensembl.org/pub/release-109/variation/gvf/homo_sapiens/homo_sapiens_incl_consequences-chr{i}.gvf.gz'")
     

## Function and class definitions

In [None]:
# this is all in the dnaDataSet python file but I've kept it here just in case lol
# NOTE(review): running this cell rebinds the name `dnaDataSet` imported at the top of
# the notebook. This copy defines no `finetune` method, which a later cell calls -
# confirm which definition is meant to be live.
class dnaDataSet:
    '''
        Holds read -> mutation-description examples built from an annotated TSV and a
        BAM file, plus the settings used to prompt a llama model with them.

        Attributes
            mutationDictionary: maps "Reference Genome: hg38, Read: <sequence>" to a
                text description of the variant seen on that read
            bamsDictionary: maps a BAM path to the pysam handle opened for it
            tsv: pandas DataFrame (empty on construction)
            relativeContextLength: few-shot prompts use relativeContextLength - 1 examples
            memoryDir: directory used by the save helpers when memoryDir=True is passed
            promptOutput: output of a prompting run (empty string on construction)
            modelPath: path of the model in use, or False when none was given
    '''
    def save(self, fp):
        '''
            save dnaDataSet as pickle at the path fp

            NOTE(review): open pysam handles stored in bamsDictionary are presumably
            not picklable - confirm before saving after makeLlamaDataset has run.
        '''
        file_name = fp
        with open(file_name, 'wb') as file:
            pickle.dump(self, file)
            print(f'dnaDataSet object successfully saved to "{file_name}"')

    def __init__(self, modelPath=False, memoryDir=None):
        '''
            modelPath: path to the model, False when unset
            memoryDir: directory for saved outputs; defaults to the working directory
                at construction time (a default of os.getcwd() in the signature would be
                frozen when the class is defined, before any later chdir)
        '''
        self.mutationDictionary={}
        self.bamsDictionary={}
        self.tsv=pd.DataFrame()
        self.relativeContextLength=4
        self.memoryDir=os.getcwd() if memoryDir is None else memoryDir
        #promptOutput is a string formatted as a hf dataset
        self.promptOutput=''
        self.modelPath=modelPath

    def __getitem__(self, position):
        '''
            positional access to the fields, in this order: mutationDictionary,
            bamsDictionary, tsv, relativeContextLength, memoryDir, promptOutput
        '''
        items=[
            self.mutationDictionary,
            self.bamsDictionary,
            self.tsv,
            self.relativeContextLength,
            self.memoryDir,
            #promptOutput is a string formatted as a hf dataset
            self.promptOutput
        ]
        return items[position]

    def __repr__(self):
        '''
            multi-line summary of every field
        '''
        return (
            f'dnaDataSet object\n current model being used: {self.modelPath!s},'
            f'\n mutationDictionary: {self.mutationDictionary!s},'
            f'\n bamsDictionary: {self.bamsDictionary!s},'
            f'\n tsv: {self.tsv!s},'
            f'\n relativeContextLength: {self.relativeContextLength!s},'
            f'\n memoryDir: {self.memoryDir!s},'
            f'\n promptOutput: {self.promptOutput!s}'
        )

    def __add__(self, other):
        '''
            returns a NEW dnaDataSet with consolidated mutationDictionary and
            bamsDictionary; all other fields are taken from the left operand.
            Neither operand is modified. On duplicate keys the right operand wins.
        '''
        merged = dnaDataSet(modelPath=self.modelPath, memoryDir=self.memoryDir)
        merged.mutationDictionary = dict(self.mutationDictionary)
        merged.mutationDictionary.update(other.mutationDictionary)
        merged.bamsDictionary = dict(self.bamsDictionary)
        merged.bamsDictionary.update(other.bamsDictionary)
        merged.tsv = self.tsv
        merged.relativeContextLength = self.relativeContextLength
        merged.promptOutput = self.promptOutput
        return merged

    def __len__(self):
        '''
            number of entries in mutationDictionary (no printing: len() should be
            free of side effects)
        '''
        return len(self.mutationDictionary)

    def setRelativeContextLength(self, contextLength):
        '''
            takes int contextLength and sets it in the dataset
        '''
        self.relativeContextLength=contextLength

    def saveOutput(self, fp, memoryDir=False):
        '''
            saves promptOutput with its save_to_disk method.

            fp: target path; used as given when memoryDir is falsy (the default)
            memoryDir: when truthy, the target becomes self.memoryDir + '/' + fp

            promptOutput must be an object that provides save_to_disk; the empty
            string it is initialised with does not, so AttributeError is raised
            until it has been replaced.
        '''
        target = self.memoryDir + '/' + fp if memoryDir else fp
        self.promptOutput.save_to_disk(target)

    def saveMutationDictionary(self, fp, memoryDir=False):
        '''
            saves mutationDictionary as a json file.

            fp: target path; used as given when memoryDir is falsy (the default)
            memoryDir: when truthy, the target becomes self.memoryDir + '/' + fp
        '''
        target = self.memoryDir + '/' + fp if memoryDir else fp
        with open(target, "w") as outfile:
            json.dump(self.mutationDictionary, outfile)

    @staticmethod
    def _basesAtVariant(samfile, chrom, start_pos):
        '''
            yields (read_start, read_seq, base) for every read that has a base
            aligned at the variant position.

            NOTE(review): POS is assumed to be 1-based (VCF convention) and is
            converted to pysam's 0-based coordinates - confirm against the TSV.
        '''
        pos0 = int(start_pos) - 1
        # truncate=True restricts the pileup to the requested column only
        for column in samfile.pileup(str(chrom), pos0, pos0 + 1, truncate=True, min_mapping_quality=58):
            for pileup_read in column.pileups:
                # deletions / reference skips have no base in the read at this column
                if pileup_read.is_del or pileup_read.is_refskip:
                    continue
                alignment = pileup_read.alignment
                read_seq = alignment.query_sequence
                if read_seq is None:
                    continue
                yield alignment.reference_start, read_seq, read_seq[pileup_read.query_position]

    def makeLlamaDataset(self, tsv_dir, bam_path, bed_path):
        '''
            from a directory containing annotated tsv files and a bam file, fill
            mutationDictionary with one entry per read covering a called variant.

            tsv_dir: directory scanned for files whose name ends in 'tsv'; each is
                read as a tab-separated table with the columns CHROM, POS, SAMPLE,
                Gene.refGene, GT and ALT
            bam_path: BAM file the reads are taken from; opened once, only when a
                variant row is found, and kept in bamsDictionary
            bed_path: currently unused

            Only rows with GT '0/1' or '1/1' are used.

            returns mutationDictionary

            start by just passing lines from vcf to llama for fine tuning, along with a line that says
            "The read/basepairs/sequence at this position is:
            The read information from reference is:"

            This is a pretty brute force way to do it but maybe it'll create something coherent from llama.

            Getting correct sequence instruction:
            "instruction": f"The gene {gene} is mutated at the {start_pos} basepair. What is the sequence? What is the mutation?",
            "input": f"{read_seq}",
            "output": "5"

            [WIP] Instructions incorporating answers from databases: Clinvar, NCBI, Genecards
        '''
        samfile = self.bamsDictionary.get(bam_path)
        for filename in os.listdir(tsv_dir):
            if not filename.endswith('tsv'):
                continue
            tsv_file = os.path.join(tsv_dir, filename)
            print("the tsv file is: ")
            print(tsv_file)
            # read the table that was found instead of relying on a notebook global
            table = pd.read_table(tsv_file, sep='\t')
            called = table[table['GT'].isin(['0/1', '1/1'])]
            columns = (called[name] for name in ('CHROM', 'POS', 'SAMPLE', 'Gene.refGene', 'GT', 'ALT'))
            for chrom, start_pos, sample, gene, gt, alt in zip(*columns):
                print(start_pos)
                print(sample, gt)
                print(alt)
                print(gene)
                if samfile is None:
                    samfile = pysam.AlignmentFile(bam_path, "rb")
                    self.bamsDictionary[bam_path] = samfile
                for read_start, read_seq, mutated_base in self._basesAtVariant(samfile, chrom, start_pos):
                    description = f"the start pos from tsv is {start_pos} the start pos from pileup is {read_start} the the gene is: {gene} and the mutated base is: {mutated_base}"
                    print(f"the start pos from tsv is {start_pos} the start pos from pileup is {read_start} the the gene is: {gene} the read is: {read_seq} and the mutated base is: {mutated_base}")
                    print(f'for sanity, the mutated allele was: {alt}')
                    self.mutationDictionary["Reference Genome: hg38, Read: " + read_seq] = description
        return self.mutationDictionary

    def fewShotLearning(self, read):
        '''
            takes a read as a prompt, prefixes it with the first
            relativeContextLength - 1 entries of mutationDictionary as examples and
            sends it to the notebook-level `llm` object.

            prints and returns the model output
        '''
        shots = []
        for counter, (key, value) in enumerate(self.mutationDictionary.items(), start=1):
            if counter >= self.relativeContextLength:
                break
            shots.append("Input: " + key + "\n" + " Output: " + value + "\n")
        prompt_string = ''.join(shots)
        prompt = 'Reference Genome: hg38, Read: ' + read
        output = llm(prompt_string + "\n" + "Input: " + prompt + "\n" + "Output: ", max_tokens=32, stop=["Input:"], echo=True)
        print(output)
        return output

In [None]:
def makeLlamaDataset(tsv_dir, bam_path, bed_path):
    '''
        from a directory containing annotated tsv files and a bam file, build a
        dictionary with one entry per read covering a called variant.

        tsv_dir: directory scanned for files whose name ends in 'tsv'; each is read
            as a tab-separated table with the columns CHROM, POS, SAMPLE,
            Gene.refGene, GT and ALT
        bam_path: BAM file the reads are taken from; opened once, only when a
            variant row is found, and closed before returning
        bed_path: currently unused

        Only rows with GT '0/1' or '1/1' are used.

        returns a dict mapping "Reference Genome: hg38, Read: <sequence>" to a text
        description of the variant position, gene and base seen on that read

        start by just passing lines from vcf to llama for fine tuning, along with a line that says
        "The read/basepairs/sequence at this position is:
        The read information from reference is:"

        This is a pretty brute force way to do it but maybe it'll create something coherent from llama.

        Getting correct sequence instruction:
        "instruction": f"The gene {gene} is mutated at the {start_pos} basepair. What is the sequence? What is the mutation?",
        "input": f"{read_seq}",
        "output": "5"

        [WIP] Instructions incorporating answers from databases: Clinvar, NCBI, Genecards
    '''
    mutation_dictionary = {}
    samfile = None  # opened lazily so a directory without variants never touches the BAM
    try:
        for filename in os.listdir(tsv_dir):
            if not filename.endswith('tsv'):
                continue
            tsv_file = os.path.join(tsv_dir, filename)
            print("the tsv file is: ")
            print(tsv_file)
            # read the table that was found instead of relying on a notebook global
            table = pd.read_table(tsv_file, sep='\t')
            called = table[table['GT'].isin(['0/1', '1/1'])]
            columns = (called[name] for name in ('CHROM', 'POS', 'SAMPLE', 'Gene.refGene', 'GT', 'ALT'))
            for chrom, start_pos, sample, gene, gt, alt in zip(*columns):
                print(start_pos)
                print(sample, gt)
                print(alt)
                print(gene)
                if samfile is None:
                    samfile = pysam.AlignmentFile(bam_path, "rb")
                for read_start, read_seq, mutated_base in _readBasesAtVariant(samfile, chrom, start_pos):
                    description = f"the start pos from tsv is {start_pos} the start pos from pileup is {read_start} the the gene is: {gene} and the mutated base is: {mutated_base}"
                    print(f"the start pos from tsv is {start_pos} the start pos from pileup is {read_start} the the gene is: {gene} the read is: {read_seq} and the mutated base is: {mutated_base}")
                    print(f'for sanity, the mutated allele was: {alt}')
                    mutation_dictionary["Reference Genome: hg38, Read: " + read_seq] = description
    finally:
        if samfile is not None:
            samfile.close()

    return mutation_dictionary


def _readBasesAtVariant(samfile, chrom, start_pos):
    '''
        yields (read_start, read_seq, base) for every read that has a base aligned
        at the variant position.

        NOTE(review): POS is assumed to be 1-based (VCF convention) and is converted
        to pysam's 0-based coordinates - confirm against the TSV.
    '''
    pos0 = int(start_pos) - 1
    # truncate=True restricts the pileup to the requested column only
    for column in samfile.pileup(str(chrom), pos0, pos0 + 1, truncate=True, min_mapping_quality=58):
        for pileup_read in column.pileups:
            # deletions / reference skips have no base in the read at this column
            if pileup_read.is_del or pileup_read.is_refskip:
                continue
            alignment = pileup_read.alignment
            read_seq = alignment.query_sequence
            if read_seq is None:
                continue
            yield alignment.reference_start, read_seq, read_seq[pileup_read.query_position]

## test stuff with the class

In [5]:
dnaset = dnaDataSet()

In [6]:
dnaset

dnaDataSet object
 current model being used: False,
 mutationDictionary: {},
 bamsDictionary: {},
 tsv: Empty DataFrame
Columns: []
Index: [],
 relativeContextLength: 4,
 memoryDir: /oak/stanford/groups/cgawad/Scripts/dna-llama,
 promptOutput: 

In [7]:
dnaset2=dnaDataSet(modelPath=modelPath, memoryDir=memoryDir)

NameError: name 'modelPath' is not defined

In [None]:
dnaset2.save("/scratch/users/sschulz/dnaset.pickle")

## Few shot learning

In [None]:
'''
    The idea was to fine-tune with the dataset from makeLlamaDataset, but for now we are just trying to use it for few-shot learning: take some examples
    from it and use them to get the model to tell you the mutated base in a read you give it
'''

In [None]:
dnaset2.makeLlamaDataset(tsv_dir, bam_path, bed_path)

In [None]:
dnaset2.fewShotLearning('TAGAAAATGTGGATGAGACCTTCTGCATAGATAACGAAGCGCTATATGACATATGTTCCAGGACCCTAAAACTGCCCACACCCACCTATGGTGACCTGAA')

## Fine tuning

In [3]:
# Settings for the fine-tuning run. These rebind the notebook-level modelPath and
# memoryDir that were set for the llama.cpp model earlier in the notebook.
modelPath='/home/shawn/datasets/LLMs/llama_7b/config.json'
memoryDir='/home/shawn/datasets'
trainingDataset='pollner/dna_dataset' 
fineTuningDataset=dnaDataSet(modelPath=modelPath, memoryDir=memoryDir)
# The loaded dataset (train/test/validation splits, see the repr in the next cell)
# is stored in mutationDictionary in place of the plain dict the class starts with.
fineTuningDataset.mutationDictionary = load_dataset(trainingDataset)

Found cached dataset parquet (/home/shawn/.cache/huggingface/datasets/pollner___parquet/pollner--dna_dataset-df9452f1a4694811/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
100%|██████████| 3/3 [00:00<00:00, 1115.70it/s]


In [4]:
fineTuningDataset

dnaDataSet object
 current model being used: /home/shawn/datasets/LLMs/llama_7b/config.json,
 mutationDictionary: DatasetDict({
    train: Dataset({
        features: ['target', 'context'],
        num_rows: 12844
    })
    test: Dataset({
        features: ['target', 'context'],
        num_rows: 1606
    })
    validation: Dataset({
        features: ['target', 'context'],
        num_rows: 1605
    })
}),
 bamsDictionary: {},
 tsv: Empty DataFrame
Columns: []
Index: [],
 relativeContextLength: 4,
 memoryDir: /home/shawn/datasets,
 promptOutput: 

In [5]:
# NOTE(review): finetune() is not defined in the notebook copy of dnaDataSet; it
# presumably comes from the imported dnaDataSet module - confirm. The recorded run
# below ended in "IndexError: Invalid key: 12684 is out of bounds for size 0".
model = fineTuningDataset.finetune()


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /home/shawn/.local/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 117
CUDA SETUP: Loading binary /home/shawn/.local/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so...


  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
Loading checkpoint shards: 100%|██████████| 33/33 [00:07<00:00,  4.39it/s]
  0%|          | 0/9633 [00:00<?, ?it/s]

IndexError: Invalid key: 12684 is out of bounds for size 0