In [37]:
%%writefile dataframes.py
import pandas as pd
from Bio import SeqIO

class CoronaDataframe:
    '''Builds the pandas dataframes consumed by the plotting code.

    The query and reference sequences are read from fasta files and trimmed
    to the index window ``[START:END]``. The mutations table is re-read from
    ``mutations`` every time a mutations dataframe is requested.
    '''

    # Slice bounds applied to both sequences. The original author notes the
    # window is meant to match pangolin.
    # NOTE(review): confirm the exact bounds against pangolin's trim settings.
    START = 256
    END = 29674

    def __init__(self, query: str, ref: str, mutations: str) -> None:
        '''Load and trim the sequences.

        Args:
            query: fasta file holding the (aligned) query sequence.
            ref: fasta file holding the reference sequence.
            mutations: csv source passed to ``pd.read_csv``; the code reads
                the columns ``pango``, ``mutation`` and ``number_mutations``.
        '''
        self.query = SeqIO.read(query, 'fasta').seq[self.START:self.END]
        self.ref = SeqIO.read(ref, 'fasta').seq[self.START:self.END]
        self.mutations = mutations

    def make_alignment_df(self) -> pd.DataFrame:
        '''Return one row per position of the trimmed query sequence.

        Columns: ``nt`` (query nucleotide), ``id`` (constant ``'query'``),
        ``pos`` (index counted from ``START``), ``value`` (numeric code from
        ``_convert_alignment_value``) and ``ref`` (reference nucleotide).
        '''
        reference = list(self.ref)
        query = list(self.query)
        converted_values = self._convert_alignment_value()

        return pd.DataFrame({'nt': query, 'id': 'query',
                             'pos': range(self.START, len(query) + self.START),
                             'value': converted_values,
                             'ref': reference})

    def make_mutations_df(self) -> pd.DataFrame:
        '''Read the mutations csv and add the columns the plots need.

        Added columns:
            ``unique``: 1 where ``number_mutations`` equals 1, else 0.
            ``kind``: 2 where the mutation string contains ``'-'``,
                otherwise the value of ``unique``.
            ``position``: first run of digits in the mutation string, as int.
        '''
        mutations_df = pd.read_csv(self.mutations)
        # Bracket access throughout: a column named 'unique' is easy to
        # confuse with a pandas attribute when read as ``df.unique``.
        mutations_df['unique'] = [1 if count == 1 else 0
                                  for count in mutations_df['number_mutations']]
        mutations_df['kind'] = [2 if '-' in str(mutation) else flag
                                for mutation, flag in zip(mutations_df['mutation'],
                                                          mutations_df['unique'])]
        mutations_df['position'] = self._position_from_mutation(mutations_df['mutation'])

        return mutations_df

    def make_query_df(self) -> pd.DataFrame:
        '''Return the query's differences from the reference as a dataframe.

        Columns: ``pango`` (constant ``'QUERY'``), ``mutation`` (strings from
        ``_extract_snp``) and ``position`` (the number inside each string).
        '''
        mutations_list = self._extract_snp()
        query_df = pd.DataFrame({'pango': 'QUERY', 'mutation': mutations_list})
        query_df['position'] = self._position_from_mutation(query_df['mutation'])

        return query_df

    def make_mutations_inspection_df(self) -> pd.DataFrame:
        '''Return, for every query position that also occurs in the mutations
        table, the lineages and mutations recorded at that position.

        The result has the query columns plus ``mutation_list`` (dict mapping
        pango -> mutation at that position), ``pango_with_mutation`` (size of
        that dict) and ``color`` (1 if the query mutation ends with ``N``,
        2 if it ends with ``-``, otherwise 3).
        '''
        mutations = self.make_mutations_df()
        query_df = self.make_query_df()

        # Keep only positions present in both tables.
        mutations = mutations[mutations['position'].isin(query_df['position'])]
        query_df = query_df[query_df['position'].isin(mutations['position'])]

        grouped = (mutations.groupby('position')[['mutation', 'pango']]
                   .aggregate(list)
                   .reset_index())
        grouped['mutation_list'] = [dict(zip(pangos, muts))
                                    for pangos, muts in zip(grouped['pango'],
                                                            grouped['mutation'])]
        per_position = grouped[['position', 'mutation_list']]

        # 'position' is the only shared column, so name it explicitly.
        merged = query_df.merge(per_position, on='position')
        merged['pango_with_mutation'] = merged['mutation_list'].apply(len)
        merged['color'] = [1 if mutation.endswith('N')
                           else 2 if mutation.endswith('-')
                           else 3
                           for mutation in merged['mutation']]

        return merged

    @staticmethod
    def _position_from_mutation(mutation_column: pd.Series) -> pd.Series:
        '''Extract the first run of digits from each mutation string as int.'''
        # Raw string: '\d' in a normal literal is an invalid escape sequence.
        return mutation_column.str.extract(r'(\d+)', expand=False).astype(int)

    def _convert_alignment_value(self) -> list:
        '''Encode every aligned position as a number.

        0 = query equals reference, -1 = query is ``N``, -2 = query is ``-``
        (deletion), 1 = any other difference.
        '''
        values = []
        for q, r in zip(self.query, self.ref):
            if q == r:
                values.append(0)
            elif q == 'N':
                values.append(-1)
            elif q == '-':
                values.append(-2)
            else:
                values.append(1)
        return values

    # TODO (translated from the original note): +1 must be added to the index
    # later. The number emitted below is the raw sequence index (enumerate
    # starts at START, matching the slice), not a 1-based position.
    def _extract_snp(self) -> list:
        '''Return ``'<ref><index><query>'`` for every position where the
        reference and query characters differ.'''
        return [f'{a}{index}{b}'
                for index, (a, b) in enumerate(zip(self.ref, self.query), self.START)
                if a != b]



Overwriting dataframes.py


In [28]:
%%writefile plotting.py
import altair as alt
import pandas as pd
alt.data_transformers.disable_max_rows()

from .dataframes import CoronaDataframe

class CoronaPlot:
    '''Altair charts built from the dataframes of ``CoronaDataframe``.'''

    def __init__(self, query: str, ref: str, mutations: str) -> None:
        '''Create the dataframe provider; the arguments are forwarded
        unchanged to ``CoronaDataframe``.'''
        self.dataframes = CoronaDataframe(query, ref, mutations)

    def plot_alignment_map(self):
        '''Return ``(chart, view)`` for the alignment map of the query.

        ``view`` carries an x-interval selection; ``chart`` uses that
        selection as the domain of its x scale.
        '''
        # TODO: cache the dataframe in a json file?
        df = self.dataframes.make_alignment_df()
        interval = alt.selection_interval(encodings=['x'])

        # Colour codes: 0 -> white, -1 -> red, -2 -> black, 1 -> green.
        base = alt.Chart(df, title="Coverage and mutations").mark_rect().encode(
            x=alt.X('pos:N', axis=alt.Axis(labels=False, ticks=False)),
            y='id:N',
            color=alt.Color('max(value):N', legend=None,
                            scale=alt.Scale(domain=[0, -1, -2, 1],
                                            range=['white', 'red', 'black', 'green'])),
            tooltip=[alt.Tooltip('nt:N', title='nt'),
                     alt.Tooltip('pos:N', title='Position'),
                     alt.Tooltip('ref:N', title='Reference')])

        chart = base.encode(
            x=alt.X('pos:N', scale=alt.Scale(domain=interval.ref()),
                    axis=alt.Axis(labels=False, ticks=False))
        ).properties(width=1200, height=300)

        view = base.add_selection(interval).properties(width=1200, height=100)

        return chart, view

    def plot_mutation_heatmap(self):
        '''Return a heatmap of every mutation in the mutations csv, one row
        per pango lineage, with mutations ordered by the digits they contain.'''
        mutations = self.dataframes.make_mutations_df()

        unique_mut = list(mutations.mutation.unique())
        # Sort by the number formed from all digit characters in the label.
        order_level = sorted(unique_mut,
                             key=lambda x: int("".join([i for i in x if i.isdigit()])))

        fig = alt.Chart(
            mutations,
            title='Mutations. Dark blue == deletions. Light blue == unique mutations for that lineage'
        ).mark_rect(stroke='black').encode(
            x=alt.X('mutation', axis=alt.Axis(labels=False, ticks=False), sort=order_level),
            y='pango',
            color=alt.Color('kind', legend=None),
            tooltip=[
                alt.Tooltip('mutation', title='Mutation'),
                alt.Tooltip('pango', title='Pango')]).properties(width=1200)

        return fig

    def plot_compare(self, compare_to: str):
        '''Compare the query's mutations against those of one lineage.

        Args:
            compare_to: value of the ``pango`` column selecting the lineage.

        Query rows get ``kind`` 3 when the exact mutation string also occurs
        in the lineage and 4 otherwise; lineage rows keep their own ``kind``.
        '''
        mutations_db = self.dataframes.make_mutations_df()
        mutations_db = mutations_db[['pango', 'mutation', 'kind', 'position']]
        lineage_df = mutations_db[mutations_db['pango'] == compare_to]

        query_df = self.dataframes.make_query_df()
        # Open question (translated from the original note): is it better to
        # compare only against positions of the chosen lineage, or to keep
        # every query mutation whose position occurs anywhere in the table?
        # To try the latter, filter on mutations_db['position'] instead.
        # .copy() so the column assignment below does not write to a slice.
        query_df = query_df[query_df['position'].isin(lineage_df['position'])].copy()

        # Build the lookup once instead of rebuilding a list for every row.
        lineage_mutations = set(lineage_df['mutation'])
        query_df['kind'] = [3 if mutation in lineage_mutations else 4
                            for mutation in query_df['mutation']]

        concated = pd.concat([lineage_df, query_df])
        position_level = sorted(concated.position.unique())

        fig = alt.Chart(
            concated,
            title={'text': [f'Compared against {compare_to}'],
                   'subtitle': ['Green == same, pink == unique, blue == deletion, red == differs from comparison'],
                   'color': 'green'}
        ).mark_rect(stroke='black').encode(
            x=alt.X('position:N', axis=alt.Axis(), sort=position_level),
            y='pango',
            # First colour previously carried a stray trailing space.
            color=alt.Color('kind', legend=None,
                            scale=alt.Scale(domain=[0, 1, 2, 3, 4],
                                            range=['#e7fc98', '#f64fe7', '#98fcf3',
                                                   '#e7fc98', '#fc9898'])),
            tooltip=[
                alt.Tooltip('mutation', title='Mutation'),
                alt.Tooltip('pango', title='Pango')]).properties(width=1000, height=50)

        return fig

    def plot_mutations_inspection(self):
        '''Return a point chart of the query mutations where point size is the
        number of lineages recorded at that position and fill encodes the
        ``color`` column (1 -> red, 2 -> blue, 3 -> green).'''
        df = self.dataframes.make_mutations_inspection_df()

        fig = alt.Chart(
            df,
            title={'text': ['Mutations in'],
                   'subtitle': ['Green == regular, blue == deletion, red == N',
                                'Size shows how common mutation is'],
                   'color': 'green'}
        ).mark_point(stroke='black').encode(
            x=alt.X('position:N', axis=alt.Axis()),
            y='pango',
            fill=alt.Fill('color:N', legend=None,
                          scale=alt.Scale(domain=[1, 2, 3], range=['red', 'blue', 'green'])),
            size=alt.Size('pango_with_mutation', legend=None,
                          scale=alt.Scale(range=[100, 1500])),
            tooltip=[
                alt.Tooltip('mutation', title='Mutation'),
                alt.Tooltip('mutation_list', title='Pangos with mutation')]
        ).properties(width=1200, height=150)

        return fig
        
        


Overwriting plotting.py


In [2]:
%%writefile nextclade.py
import subprocess
import pandas as pd

class Nextclade:
    '''Runs the ``nextclade`` command line tool and exposes its output.

    All paths are relative to the current working directory.
    '''

    def __init__(self) -> None:
        '''Set the fixed input, reference and output locations.'''
        self.OUTDIR = 'nextclade_runs/'
        self.REF_DIR = 'reference/'
        self.INPUT = 'input.fasta'
        self.META_TSV = self.OUTDIR + 'meta.tsv'
        self.QUERY = self.OUTDIR + 'input.aligned.fasta'
        self.REF = self.REF_DIR + 'reference.fasta'

    def run_nextclade(self) -> None:
        '''Invoke nextclade on ``INPUT`` and write results into ``OUTDIR``.

        Raises:
            subprocess.CalledProcessError: if nextclade exits with a non-zero
                status, so that a failed run is not mistaken for a fresh
                ``META_TSV``.
            FileNotFoundError: if the ``nextclade`` executable is not found.
        '''
        command = ['nextclade', '--in-order',
                   '--input-fasta', self.INPUT,
                   '--input-dataset', self.REF_DIR,
                   '--output-dir', self.OUTDIR,
                   '--output-tsv', self.META_TSV]

        # check=True: the exit status was previously discarded.
        subprocess.run(command, check=True)

    def make_meta_df(self) -> pd.DataFrame:
        '''Return the tab-separated file at ``META_TSV`` as a dataframe.'''
        return pd.read_csv(self.META_TSV, sep='\t')
        
        
        

Overwriting nextclade.py
