# Cleaned notebook of functions and classes used in the corona project

## Webscraping

### GISAID data

In [None]:
import io
import pathlib
import sys
import time

import clipboard
import numpy as np
import pandas as pd
from selenium import webdriver
from selenium.webdriver.common.action_chains import ActionChains
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys

def download_gisaid_files(files: list[str], user: str, password: str) -> None:
    '''
    Downloads all the sequences in the list given to the function.

    Drives a Chrome browser through the web interface at epicov.org: logs in,
    opens the search page, pastes the ids into the selection dialog and
    triggers the download dialog. The ids are transferred through the system
    clipboard, so the previous clipboard content is overwritten.

    Parameters:
        files: ids of the sequences to download; joined with spaces before pasting.
        user: login name for the site.
        password: password for the site.

    Returns:
        None. The browser is always closed, also when a step fails.

    Raises:
        Whatever selenium raises when an element cannot be found or clicked.
    '''

    clipboard.copy(' '.join(files))

    # The paste shortcut is Cmd+V on macOS and Ctrl+V everywhere else.
    paste_modifier = Keys.COMMAND if sys.platform == 'darwin' else Keys.CONTROL

    driver = webdriver.Chrome()
    try:
        driver.get('https://www.epicov.org/epi3/frontend#df91e')
        time.sleep(4)

        # Log in.
        username = driver.find_element(By.NAME, 'login')
        username.send_keys(user)
        passw = driver.find_element(By.NAME, 'password')
        passw.send_keys(password)
        login_button = driver.find_element(By.CLASS_NAME, 'form_button_submit')
        login_button.click()
        time.sleep(4)

        # Open the search view and then the selection dialog.
        search = driver.find_element(By.XPATH, '/html/body/form/div[5]/div/div[2]/div/div[1]/div/div/div[3]')
        search.click()
        time.sleep(4)

        select = driver.find_element(By.XPATH, '/html/body/form/div[5]/div/div[2]/div/div[2]/div[2]/div[2]/div[3]/button[2]')
        select.click()
        time.sleep(4)

        # The dialog lives in an iframe; switch into it before touching its elements.
        driver.switch_to.frame(driver.find_element(By.TAG_NAME, 'iframe'))
        time.sleep(4)

        # Paste the ids from the clipboard into the text area.
        text_area = driver.find_element(By.XPATH, '/html/body/form/div[5]/div/div[1]/div/div[1]/table[2]/tbody/tr/td/div/div[1]/textarea')
        text_area.send_keys(paste_modifier, 'v')
        time.sleep(3)

        ok = driver.find_element(By.XPATH, '/html/body/form/div[5]/div/div[2]/div/div/div[2]/div/button')
        ok.click()
        time.sleep(4)

        # Send a plain RETURN key press (presumably confirms a follow-up message box).
        actions = ActionChains(driver)
        actions.send_keys(Keys.RETURN)
        actions.perform()
        time.sleep(4)

        # Back to the main page to reach the download button.
        driver.switch_to.parent_frame()
        time.sleep(4)

        download = driver.find_element(By.XPATH, '/html/body/form/div[5]/div/div[2]/div/div[2]/div[2]/div[2]/div[3]/button[4]')
        download.click()
        time.sleep(4)

        # The download dialog is again inside an iframe.
        driver.switch_to.frame(driver.find_element(By.TAG_NAME, 'iframe'))
        time.sleep(4)

        input_auger = driver.find_element(By.XPATH, '/html/body/form/div[5]/div/div[1]/div/div/table[1]/tbody/tr/td[2]/div/div[1]/div[2]/div[1]/input')
        input_auger.click()
        time.sleep(4)

        get_files = driver.find_element(By.XPATH, '/html/body/form/div[5]/div/div[2]/div/div/div[2]/div/button')
        get_files.click()
        # Long fixed wait, presumably to let the browser finish the download
        # before it is closed.
        time.sleep(230)
    finally:
        # Always close the browser; otherwise every failed run leaves a
        # Chrome process behind.
        driver.quit()
    
def split_dataframe_chunk(path_to_csv: str, chunk_size: int = 5000) -> list[pd.DataFrame]:
    '''
    Splits the csv file into chunks of at most chunk_size rows.

    The file is read without a header into a single column named 'id'. The
    rows are divided as evenly as possible over the smallest number of chunks
    that keeps every chunk <= chunk_size (the first chunks get one extra row
    when the rows do not divide evenly), in the original row order.

    Parameters:
        path_to_csv: path of the csv file with one id per row.
        chunk_size: maximum number of rows per chunk.

    Returns:
        A list of dataframes; empty when the file holds no rows.

    Raises:
        ValueError: if chunk_size is smaller than 1.
    '''
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be at least 1, got {chunk_size}')

    df = pd.read_csv(path_to_csv, header=None, names=['id'])
    n_rows = len(df)
    if n_rows == 0:
        return []

    n_chunks = int(np.ceil(n_rows / chunk_size))
    # Same size distribution as np.array_split: 'extra' chunks of base + 1
    # rows followed by chunks of 'base' rows.
    base, extra = divmod(n_rows, n_chunks)

    chunks = []
    start = 0
    for position in range(n_chunks):
        stop = start + base + (1 if position < extra else 0)
        chunks.append(df.iloc[start:stop])
        start = stop
    return chunks

def save_gisaid_csv(path_to_csv: str, output_folder: str, chunk_size=5000) -> None:
    '''
    Splits the input csv file with split_dataframe_chunk and writes every
    chunk to its own csv file in output_folder.

    The files are named 1.csv, 2.csv, ... in chunk order and are written
    without the dataframe index. The folder (including missing parents) is
    created when needed.
    '''
    pieces = split_dataframe_chunk(path_to_csv, chunk_size)

    target_dir = pathlib.Path(output_folder)
    target_dir.mkdir(parents=True, exist_ok=True)

    for counter, piece in enumerate(pieces, start=1):
        piece.to_csv(target_dir / f'{counter}.csv', index=False)
        

def download_gisaid_files_from_csv(path_csv: str, username: str, password: str) -> list:
    '''
    Downloads data from GISAID for every csv file in the folder given as argument.

    Each csv file must have an 'id' column; its ids are passed to
    download_gisaid_files in one call. A failing download is reported and
    recorded, and the remaining files are still processed.

    Parameters:
        path_csv: folder that contains the csv files.
        username: login name passed on to download_gisaid_files.
        password: password passed on to download_gisaid_files.

    Returns:
        The names of all csv files whose download raised an exception
        (an empty list when everything succeeded or the folder has no csv files).
    '''
    folder = pathlib.Path(path_csv)
    csv_files = [file for file in folder.iterdir() if file.suffix == '.csv']

    # Collected across ALL files, so it must live outside the loop.
    not_downloaded = []

    for file in csv_files:
        sequences = pd.read_csv(file).id.to_list()
        try:
            download_gisaid_files(sequences, username, password)
            print(f'{file.name}')
        except Exception as e:
            # Best effort: report the failure and continue with the next file.
            print(f'{file.name} was not downloaded due to exception {e}')
            not_downloaded.append(file.name)

    return not_downloaded

### Outbreak.info data

In [None]:
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.common.action_chains import ActionChains
import time
import re
import pathlib



def download_mutations_table(pango: str) -> pd.DataFrame:
    '''
    Downloads the table of characteristic mutations of the given pango lineage.

    Opens the outbreak.info situation report of the lineage in headless
    Chrome, waits 60 seconds for the page to render, then reads the first
    html table of the page and the sequence count shown on it.

    Parameters:
        pango: name of the pango lineage, inserted into the report url.

    Returns:
        The first table of the page with two extra columns, 'pango' and
        'number_sequences', and the column 'amino acid' renamed to 'amino_acid'.

    Raises:
        ValueError: if the count element holds no number (pandas also raises
            ValueError when the page holds no table).
        Whatever selenium raises when the page or the count element cannot be loaded.
    '''
    options = webdriver.ChromeOptions()
    # The 'headless' attribute was removed from newer selenium releases;
    # the command line switch works with old and new versions.
    options.add_argument('--headless')
    driver = webdriver.Chrome(options=options)

    try:
        driver.get(f'https://outbreak.info/situation-reports?pango={pango}')
        # Fixed wait so the page can finish rendering before it is read.
        time.sleep(60)
        number = driver.find_element(By.XPATH, '/html/body/div/div/div/div/div[5]/div/div[1]/div[3]/div/span')
        count_text = number.text
        page_source = driver.page_source
    finally:
        # Close the browser even when loading or locating the element fails.
        driver.quit()

    # First run of digits and thousands separators, e.g. '1,234,567'.
    match = re.search(r'\d[\d,]*', count_text)
    if match is None:
        raise ValueError(f'no sequence count found in {count_text!r}')
    number_sequences = int(match.group().replace(',', ''))

    # Wrap the html in StringIO: passing literal html to read_html is deprecated.
    df = pd.read_html(io.StringIO(page_source))[0]
    df['pango'] = pango
    df['number_sequences'] = number_sequences
    df.rename(columns={'amino acid': 'amino_acid'}, inplace=True)

    return df


def save_mutation_table(pango: str, folder: str) -> None:
    '''
    Saves the mutation table from outbreak.info in a csv file.

    The table is fetched with download_mutations_table and written to
    '<folder>/<pango>.csv' without the dataframe index. The folder is created
    first (including missing parents), so a slow download is not lost because
    of a missing output folder.

    Any exception is caught and printed instead of raised, which lets a caller
    loop over many lineages without stopping at the first failure.

    Parameters:
        pango: name of the pango lineage.
        folder: folder in which the csv file is written.
    '''
    try:
        output_dir = pathlib.Path(folder)
        output_dir.mkdir(parents=True, exist_ok=True)

        df = download_mutations_table(pango)
        df.to_csv(output_dir / f'{pango}.csv', index=False)
        print(f'{pango} saved')
    except Exception as e:
        print(f'{pango} failed due to {e}')

## Data cleaning

In [None]:
def nextclade_mutations(csv: str) -> pd.DataFrame:
    '''
    Builds a dataframe of mutations from a tab separated nextclade file.

    Every comma separated entry of the 'aaSubstitutions' column becomes one
    row; the entry is split at ':' into the columns 'gene' and 'amino_acid'.
    The rows are numbered from 0 again.
    '''
    table = pd.read_csv(csv, sep='\t')

    # One row per substitution, with a fresh 0..n-1 index.
    substitutions = table['aaSubstitutions'].str.split(',').explode(ignore_index=True)

    mutations = substitutions.to_frame(name='aa')
    mutations[['gene', 'amino_acid']] = mutations['aa'].str.split(':', expand=True)
    return mutations.drop(columns=['aa'])


## Machine learning classes 

In [None]:
from Bio import SeqIO
import re
import pathlib
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import CountVectorizer
import pandas as pd



class CoronaData:
    '''
    Keeps the information and the sequence of the corona lineages in a pandas
    dataframe (self.df).

    Methods for manipulating the sequences (trimming, k-mers, chunks), for
    splitting the data into train and test sets and for vectorizing and
    classifying the sequences. All methods work on the instance and store
    their result on it instead of returning it.
    '''

    def __init__(self, path_meta: str, path_fasta: str, multiple: bool = False):
        '''
        Builds self.df from metadata and fasta data.

        Parameters:
            path_meta: tsv file with the metadata, or a folder of tsv files
                when multiple is True.
            path_fasta: fasta file with the sequences, or a folder of fasta
                files when multiple is True.
            multiple: whether the two paths are folders with several files.
        '''
        if multiple:
            self.df = self.__multiple_corona_dataframe(path_meta, path_fasta)
        else:
            self.df = self.__corona_dataframe(path_meta, path_fasta)

    def __corona_dataframe(self, path_meta: str, path_fasta: str) -> pd.DataFrame:
        '''
        Returns a dataframe with the columns id, pangolin, clade,
        originating_lab, date, sequence and raw_length.

        The id is the third '/'-separated field of the 'strain' column. The
        sequences are matched to the metadata rows by position, so both files
        must list the same samples in the same order.
        '''
        df = pd.read_csv(path_meta, sep='\t', parse_dates=['date'])
        # Copy the selection so the assignments below work on an independent
        # frame and not on a slice of the full metadata table.
        df = df[['strain', 'pangolin_lineage', 'GISAID_clade', 'originating_lab', 'date']].copy()
        df['strain'] = [x.split('/')[2] for x in df['strain']]

        df['sequence'] = [str(record.seq) for record in SeqIO.parse(path_fasta, 'fasta')]

        df.columns = ['id', 'pangolin', 'clade', 'originating_lab', 'date', 'sequence']

        df['raw_length'] = [len(x) for x in df['sequence']]

        return df

    def __multiple_corona_dataframe(self, path_meta: str, path_fasta: str) -> pd.DataFrame:
        '''
        If the folders contain multiple fasta and tsv files, everything is
        concatenated into one big dataframe.

        The '.tsv' and '.fasta' files are paired by their sorted order.

        Raises:
            ValueError: if the number of tsv and fasta files differs, since the
                pairing by sorted order would then be unreliable.
        '''
        path_meta = pathlib.Path(path_meta)
        path_fasta = pathlib.Path(path_fasta)

        metas = sorted(meta for meta in path_meta.iterdir() if meta.suffix == '.tsv')
        fastas = sorted(fasta for fasta in path_fasta.iterdir() if fasta.suffix == '.fasta')

        # zip would silently drop the surplus files and may pair the wrong ones.
        if len(metas) != len(fastas):
            raise ValueError(
                f'found {len(metas)} tsv files but {len(fastas)} fasta files; '
                'they must come in pairs'
            )

        dataframes = [self.__corona_dataframe(meta, fasta) for meta, fasta in zip(metas, fastas)]
        return pd.concat(dataframes)

    def trim_sequence(self) -> None:
        '''
        Operates on the self.df object.
        Removes everything from the sequence which is not ATCG.
        Creates a new column 'trimmed_length' with the length of the trimmed sequence.
        '''
        self.df['sequence'] = [re.sub(r'[^ATCG]', '', x) for x in self.df['sequence']]
        self.df['trimmed_length'] = [len(x) for x in self.df['sequence']]

    def kmerize(self, k: int = 6) -> None:
        '''
        Operates on the self.df object.
        Replaces every sequence by a string of its overlapping k-mers of
        length k, separated by spaces. E.g. 'abcd' with k=3 -> 'abc bcd'.
        Contains the inner function 'get_kmer'.
        '''
        def get_kmer(sequence, k):
            return [sequence[x: x + k] for x in range(len(sequence) - k + 1)]

        self.df['sequence'] = [' '.join(get_kmer(x, k)) for x in self.df['sequence']]

    def train_test_split(self, train_size: float = 0.8) -> None:
        '''
        Splits the self.df object into self.train and self.test based on the
        train_size value. The split is stratified on the pangolin column and
        uses a fixed random state.
        '''
        self.train, self.test = train_test_split(self.df, random_state=42, stratify=self.df.pangolin, train_size=train_size)

    def vectorize(self, vectorizer: 'CountVectorizer') -> None:
        '''
        Transforms the sequences of the dataframe with the given vectorizer
        and stores the result in self.transformed.
        '''
        self.transformed = vectorizer.transform(self.df.sequence)

    def classify(self, classifier: 'MultinomialNB') -> None:
        '''
        Classifies the transformed data (self.transformed, set by vectorize)
        and stores the predictions in self.predicted.
        '''
        self.predicted = classifier.predict(self.transformed)

    def split_sequence(self, k: int = 6) -> None:
        '''
        Operates on the self.df object.
        Splits every sequence into consecutive, non-overlapping chunks of
        length k, separated by spaces; the last chunk may be shorter.
        E.g. 'abcde' with k=2 -> 'ab cd e'.
        '''
        def make_chunks(sequence, k):
            return [sequence[start: start + k] for start in range(0, len(sequence), k)]

        self.df['sequence'] = [' '.join(make_chunks(x, k)) for x in self.df['sequence']]

# Thoughts and ideas

### How to search mutations with query on outbreak.info

&muts={gene}%3A{substitution}

### Also possible to filter the mutation df with the mutations of choice

Everything so far only covers substitutions, not deletions. Maybe add deletions later?

# Idea

To test how many substitutions the query shares with each designated lineage in the database, compare the set of query mutations with the set of characteristic mutations of the lineage, e.g. the ratio of their intersection to their union.

It is also possible to compare the query to the mutations of all sequences in Sweden to find the sequence that is most similar!

# NLP on the mutation table of all sequences from Sweden!
Do TF-IDF on the mutations, e.g. S:K128L, and calculate the tf-idf for all the lineages! This will probably result in just a few common mutations which can be used for identifying the real characteristics!!

Check whether a mutation is new for a sequence by comparing all mutations in the sequence with the union set of all mutations available in the database. If a mutation in the new sequence is not found in the database, this may be a new variant or something to keep an eye on.

## nextclade command
nextclade \
   --in-order \
   --input-fasta 1648629336772.sequences.fasta \
   --input-dataset data/sars-cov-2 \
   --output-tsv output/nextclade.tsv \
   --output-tree output/nextclade.auspice.json \
   --output-dir output/ \
   --output-basename nextclade