## Pip installs

In [None]:
!pip install torch

In [None]:
!pip install transformers

In [None]:
!pip install datasets

In [None]:
!pip install pysam

In [None]:
!pip install HTSeq

In [None]:
!pip install enformer-pytorch>=0.5

In [None]:
!pip install polars

## Load modules and variable declarations

In [5]:
import torch
import polars as pl
from enformer_pytorch import Enformer, GenomeIntervalDataset
from datasets import concatenate_datasets, load_dataset
import os
import numpy as np
import pandas as pd
import pysam

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
from llama_cpp import Llama

In [7]:
# Path to the quantized model binary loaded below; it can be changed to any other quantized model binary.
# TODO - make a script version of this with option for model path, relative context length, tsv_path, bam_path
model_path="/scratch/users/sschulz/ggml-model-q4_1.bin"

In [8]:
# Load the model with llama-cpp-python; only the model path is passed, every other setting is left at its default.
llm = Llama(model_path=model_path)

llama.cpp: loading model from /scratch/users/sschulz/ggml-model-q4_1.bin
llama_model_load_internal: format     = ggjt v1 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 3 (mostly Q4_1)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size = 110.30 KB
llama_model_load_internal: mem required  = 25573.12 MB (+ 3124.00 MB per state)
llama_init_from_file: kv self size  =  780.00 MB
AVX = 1 | AVX2 = 1 | AVX512 = 1 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | VSX = 0 | 


In [None]:
tsv_dir= "/scratch/users/sschulz/pta_on_normal"
bed_path = "/scratch/users/sschulz/pta_on_normal/chr10.bed"


In [None]:
#good testing but use a gvcf instead containing all known mutations first
tsv_path = tsv_dir + "/CARTPt04_Scan2_svc_merged_extract_snp.hg38_multianno.tsv"

In [None]:
tsv = pd.read_table(tsv_path, sep='\t')

In [None]:
tsv['CHROM'][0]

In [None]:
bam_path=tsv_dir + '/CART-MRD-BALL-PTA-NEXTERA-WGS-CCT5007Pt04-D4_S26.realigned_deduped_sorted.bam'

## Get the gvcf files (Ensembl GVF downloads)

In [None]:
# Remember the directory the kernel is currently working in.
sav_dir = os.getcwd()

In [None]:
# Directory the variation files are downloaded into.
gvcf_dir = '/scratch/users/sschulz/pta_on_normal/gvcf'
# Log directory inside gvcf_dir. It is derived from gvcf_dir so the path is
# absolute; a literal without the leading "/" would be resolved relative to
# whatever the current working directory happens to be.
download_log_dir = gvcf_dir + '/logs/'

In [None]:
# Change the kernel's working directory to gvcf_dir; relative paths used after this cell resolve against it.
os.chdir(gvcf_dir)

In [None]:

# Chromosome labels: the three non-numeric ones first, followed by 1 through 22.
new_list = ['MT', 'X', 'Y']
new_list.extend(range(1, 23))
new_list

In [None]:
def downloadEnsembleGVCFs(output_dir, download_log_dir, dry_run=False):
    '''
    Submit one Slurm job per chromosome that downloads the Ensembl release-109
    "homo_sapiens_incl_consequences" GVF file for that chromosome with wget.

    Each job is submitted with output_dir as its working directory, so that is
    where wget writes the file. The working directory of the calling process is
    left untouched.

    Parameters
    ----------
    output_dir : str
        Directory the download jobs run in.
    download_log_dir : str
        Directory that receives one Slurm log file per job. A relative path is
        interpreted relative to output_dir.
    dry_run : bool, optional
        When True the commands are only built and printed; nothing is created
        on disk and nothing is submitted. Defaults to False.

    Returns
    -------
    list of list of str
        The sbatch argument vectors in submission order (MT, X, Y, 1 ... 22).
    '''
    import subprocess

    base_url = 'https://ftp.ensembl.org/pub/release-109/variation/gvf/homo_sapiens/'
    chromosomes = ['MT', 'X', 'Y'] + list(range(1, 23))

    # os.path.join keeps an absolute download_log_dir as it is and anchors a
    # relative one at output_dir, which is where the jobs run.
    log_dir = os.path.join(output_dir, download_log_dir)
    if not dry_run:
        # Slurm does not create missing directories for its output files.
        os.makedirs(log_dir, exist_ok=True)

    commands = []
    for chrom in chromosomes:
        wget_command = f"wget {base_url}homo_sapiens_incl_consequences-chr{chrom}.gvf.gz"
        # --output takes a file name pattern, not a directory; %j expands to the job id.
        log_file = os.path.join(log_dir, f"chr{chrom}-%j.out")
        command = [
            'sbatch', '-c', '2', '--mem=32G', '-p', 'cgawad',
            f"--output={log_file}",
            f"--wrap={wget_command}",
        ]
        print(' '.join(command[:-1]) + f" --wrap='{wget_command}'")
        commands.append(command)

        if dry_run:
            continue

        # Passing an argument list avoids shell quoting problems; the output is
        # captured so it can be shown in the notebook.
        result = subprocess.run(command, cwd=output_dir, capture_output=True, text=True)
        if result.stdout:
            print(result.stdout.strip())
        if result.returncode != 0:
            # Report the failure and keep going so one bad submission does not
            # block the remaining chromosomes.
            print(f"sbatch failed for chromosome {chrom} (exit code {result.returncode}): {result.stderr.strip()}")

    return commands

In [None]:
# Calls the download helper, which loops over 25 chromosome labels (MT, X, Y, 1-22) and issues one sbatch command for each.
downloadEnsembleGVCFs(gvcf_dir, download_log_dir)

In [None]:
# Print (without running anything) the sbatch command for each non-numeric chromosome.
for chrom_label in ['MT', 'X', 'Y']:
    print(f"sbatch -c 2 --mem=32G -p cgawad --out='/scratch/users/sschulz/pta_on_normal/gvcf/logs/' --wrap='wget https://ftp.ensembl.org/pub/release-109/variation/gvf/homo_sapiens/homo_sapiens_incl_consequences-chr{chrom_label}.gvf.gz'")
     

## Function and class definitions

In [None]:
def makeLlamaDataset(tsv_dir, bam_path, bed_path):
    '''
    Build prompt/answer examples for llama from annotated variant tables and a BAM file.

    Every file in tsv_dir whose name ends in 'tsv' is read as a tab-separated
    table. For each row whose genotype (GT) is '0/1' or '1/1', the reads in
    bam_path that cover the variant position are piled up and one dictionary
    entry is written per read:

        key   : "Reference Genome: hg38, Read: <read sequence>"
        value : text stating the variant position from the table, the start of
                the read, the gene, and the base the read carries at the
                variant position

    A read that covers several variants keeps only the entry written last.
    Progress information is printed for every variant and every read.

    The tables must provide the columns CHROM, POS, SAMPLE, Gene.refGene, GT
    and ALT. POS is assumed to be 1-based (VCF convention) -- TODO confirm for
    the tables in use. The BAM file must be indexed, because reads are looked
    up by region.

    Ideas that are not implemented yet:
      * instruction-style records, e.g.
          "instruction": "The gene {gene} is mutated at the {start_pos} basepair. What is the sequence? What is the mutation?",
          "input": "{read_seq}",
          "output": "5"
      * whether the variant is exonic / the amino acid change
      * [WIP] instructions incorporating answers from databases
        (ClinVar, NCBI, GeneCards)

    Parameters
    ----------
    tsv_dir : str
        Directory that is scanned for variant tables.
    bam_path : str
        Path of the BAM file the reads are taken from.
    bed_path : str
        Currently unused; kept so existing calls keep working.

    Returns
    -------
    dict
        Maps the read prompt string to its description string. Empty when
        tsv_dir contains no table.
    '''
    called_genotypes = ('0/1', '1/1')
    needed_columns = ('CHROM', 'POS', 'SAMPLE', 'Gene.refGene', 'GT', 'ALT')

    # Sorted so the order of the examples does not depend on the file system.
    tsv_files = sorted(
        os.path.join(tsv_dir, filename)
        for filename in os.listdir(tsv_dir)
        if filename.endswith('tsv')
    )

    mutation_dictionary = {}
    if not tsv_files:
        # Nothing to look up, so the BAM file is not opened at all.
        return mutation_dictionary

    # Open the BAM file once for all variants; the with-block closes it again.
    with pysam.AlignmentFile(bam_path, "rb") as samfile:
        for tsv_file in tsv_files:
            print("the tsv file is: ")
            print(tsv_file)

            # Read the table that was just found in the directory.
            variants = pd.read_table(tsv_file, sep='\t')
            called = variants[variants['GT'].isin(called_genotypes)]

            for chrom, start_pos, sample, gene, gt, alt in zip(*(called[name] for name in needed_columns)):
                start_pos = int(start_pos)
                print(start_pos)
                print(sample, gt)
                print(alt)
                print(gene)

                # pysam coordinates are 0-based and half-open, so the 1-based
                # table position P is the interval [P - 1, P).
                variant_index = start_pos - 1
                if variant_index < 0:
                    continue

                # truncate=True restricts the pileup to the variant column itself
                # instead of every column touched by the overlapping reads.
                pileup = samfile.pileup(
                    str(chrom), variant_index, variant_index + 1,
                    truncate=True, min_mapping_quality=58,
                )
                for column in pileup:
                    for pileup_read in column.pileups:
                        # Deletions and reference skips have no base at this position.
                        if pileup_read.is_del or pileup_read.is_refskip:
                            continue
                        alignment = pileup_read.alignment
                        read_seq = alignment.query_sequence
                        if read_seq is None or pileup_read.query_position is None:
                            continue
                        read_start = alignment.reference_start
                        # query_position is the offset inside the read that is
                        # aligned to the variant column.
                        mutated_base = read_seq[pileup_read.query_position]

                        print(f"the start pos from tsv is {start_pos} the start pos from pileup is {read_start} the the gene is: {gene} the read is: {read_seq} and the mutated base is: {mutated_base}")
                        print(f"for sanity, the mutated allele was: {alt}")
                        mutation_dictionary["Reference Genome: hg38, Read: " + read_seq] = f"the start pos from tsv is {start_pos} the start pos from pileup is {read_start} the the gene is: {gene} and the mutated base is: {mutated_base}"

    return mutation_dictionary

In [None]:
# Example read sequence. NOTE(review): no later cell shown here references `read` -- confirm whether it is still needed.
read = 'GTGTCAGACACTGTGGTGGAGCCCTACAACGCCACCCTCTCAGTCCACCAGCTCATAGAAAATGTGGATGAGACCTTCTGCATAGATAACGAAGCGCTAT'

## Few shot learning

In [None]:
'''
    The idea was to fine-tune with the dataset from makeLlamaDataset, but for now we only use it for few-shot learning: a few of its examples
    are placed in the prompt so that the model tells you the mutated base in a read you give it
'''

In [None]:
# Build the dictionary of examples (read prompt -> description) from the variant tables and the BAM file.
prompt_dictionary = makeLlamaDataset(tsv_dir, bam_path, bed_path)

In [None]:
# Inspect the read prompts that were collected.
prompt_dictionary.keys()

In [None]:
# Controls how many dictionary examples are placed in the few-shot prompt.
# It can be raised to give the model more context, but 4 is already a lot and very slow.
relative_context_length = 4

In [None]:
# Assemble the few-shot prefix from the first examples in the dictionary.
# NOTE(review): the 1-based position is compared with "<", so
# relative_context_length - 1 examples are used -- confirm this is intended.
prompt_string = ''
counter = 0
for counter, (example_read, example_answer) in enumerate(prompt_dictionary.items(), start=1):
    if counter >= relative_context_length:
        continue
    prompt_string += "Input: " + example_read + "\n" + " Output: " + example_answer + "\n"

In [None]:
print(prompt_string)

In [None]:
#test to check if llama is working
prompt = "Why do giraffes have long necks?"
output = llm("\n" + "Input: " + prompt + "\n" + "Output: ", max_tokens=32, stop=["Input:"], echo=True)
print(output)

In [None]:
prompt = 'Reference Genome: hg38, Read: TAGAAAATGTGGATGAGACCTTCTGCATAGATAACGAAGCGCTATATGACATATGTTCCAGGACCCTAAAACTGCCCACACCCACCTATGGTGACCTGAA'
output = llm(prompt_string + "\n" + "Input: " + prompt + "\n" + "Output: ", max_tokens=32, stop=["Input:"], echo=True)
print(output)

In [None]:
# Show the stored description for the query read. The read is hard-coded, so it
# may not be among the keys; .get() avoids a KeyError that would stop the notebook.
print(prompt_dictionary.get(prompt, 'no dataset entry for this read'))