In [1]:
'''
Host-virus coexpression plot wrapper
Takes as an input the homology search for the virus and the host and overlay the results on a UMAP projection of
single-cell gene expression.
Returns two scatter plots: 
1. UMAP projections of all annotated cells captured. 
2. A close up view on a region of high diversity in the UMAP projection, with infected cells highlighted.
Also looks at the subpopulation of infected Katablepharidaceae and return 
the barcodes of infected Katablepharidaceae cells from bag 4 day 20

'''

/Users/amirf/Dropbox (Weizmann Institute)/scripts/single_cell_metag/working_scripts


# Host-virus coexpression wrapper

In [2]:
import pandas as pd
import scprep
import sys
import matplotlib.pyplot as plt
from math import isnan
import nbimporter
%matplotlib qt

In [3]:
# import annotation functions from sankey plot wrapper
import sankey_wrapper

In [4]:
# import annotation functions from the sankey wrapper
lenset = sankey_wrapper.lenset
dedup_data = sankey_wrapper.dedup_data
dedup_cells = sankey_wrapper.dedup_cells
annotation_data = sankey_wrapper.annotation_data
replace = sankey_wrapper.replace
replacement_process = sankey_wrapper.replacement_process

In [5]:
# Annotate the cells based on 18s rRNA homology according to blast results.
def annotate(path,
             metapr2_file,
             pr2_file,
            replacement_file):

    names_df = pd.read_table(path+ replacement_file, sep = "\t")
    dict_colors = dict(zip(names_df["new_name"],names_df["color"]))
    dict_colors['Other eukaryotes'] = "#797979"
    metapr2_data = dedup_data(path = path,
        file = metapr2_file , data_type = "metapr2",replacement_df = names_df)

    pr2_data = dedup_data(path = path,
        file = pr2_file , data_type = "pr2",replacement_df = names_df)

    data_annotations = annotation_data(pr2_data,metapr2_data)
    data_annotations['cell_name'] = data_annotations['cell']
    
    return data_annotations,dict_colors

In [6]:
# The sample names as sppear in the UMI tables are different than 
# the standardized names of the appear in in the paper, this is a standardization function
def edit_annotations(data_annotations):
    
    dict_samples = {"Mes_1_S1": "B7T16",
    "Mes_2_S2": "B7T18",
    "Mes_3_S3": "B4T13",
    "Mes_4_S4": "B4T15",
    "Mes_5_S5": "B4T20",
    "Mes_6_S6": "B3T15",
    "Mes_7_S7": "B3T20",
    "Mes_8_S8": "B6T17",
    "B7_S3": "B7T17",
    "B4_S4": "B4T19"}

    data_annotations['barcode'] =  data_annotations['cell_name'].str.split(".", expand = True)[0]
    data_annotations['Sample'] =  data_annotations['cell_name'].str.split(".", expand = True)[1]
    data_annotations['Bag/Day'] = data_annotations['Sample'].map(dict_samples)
    data_annotations['cell_name'] = data_annotations['barcode']+ "-1." + data_annotations['Bag/Day']
    return data_annotations

### Find viral contigs and measure viral expression

In [7]:
def annotation_dictionary(annotations):
    dict_contigs_percell = dict(zip(annotations.cell_name,annotations["Annotation"]))

    clean_dict = {k: dict_contigs_percell[k] for k in dict_contigs_percell if not pd.isna(dict_contigs_percell[k])}
    clean_dict = {k: clean_dict[k] for k in clean_dict if not pd.isna(k)}
    return clean_dict

def get_viral_contigs(path_blastx):
    columns = ["qseqid","sseqid","pident","evalue","bitscore","Domain","Supergroup","Family","Genus","Species","cell_barcode"]
    blastx_results = pd.read_table(path_blastx, sep = "\t",index_col = None, header = None, names = columns )
    viral_contigs = list(blastx_results[blastx_results['Domain'] == 'Viruses']['qseqid'])
    return viral_contigs

def viral_expression(data,metadata,viral_contigs):
    series_df = data[[x for x in data.columns if x in viral_contigs]].sum(axis = 1) >= 10
    virus_cells = data[series_df].index
    return virus_cells

Load data

In [8]:
def load_data(path,file, df_type):
    if df_type == "data":
        df = pd.read_pickle(path+file)
    #data_magic = pd.read_pickle(path+"/data_magic_nd.gz")
    elif df_type == "metadata":
        df = pd.read_pickle(path+file)
    elif df_type == "blast":
        df = pd.read_table(path+file, sep = "\t")
    #wells_cells = pd.read_table(path+"/wells_cells.txt")
    else:
        sys.exit("No data type found")
    return df

In [9]:
def sort_legend(handles, labels, to_end):


    handles = list(pd.DataFrame(labels,handles).sort_values(by = 0,ascending = True).index)
    labels = list(pd.DataFrame(labels,handles).sort_values(by = 0,ascending = True).iloc[:,0])
    
    for i in to_end:
        others_idx = labels.index(i)
        labels_poped = labels.pop(others_idx)
        handles_poped = handles.pop(others_idx)
        labels.insert(len(labels),labels_poped)
        handles.insert(len(labels),handles_poped)
    
    
    return(handles, labels )

In [10]:
def virus_host_coexpression_cells_small(metadata_n,dim,virus_cells,dict_colors):
    fig, ax = plt.subplots(figsize=(6,6), dpi = 300)
    print("Removing unidentified cells...")
    metadata_n = metadata_n[metadata_n['host'] != 'Unknown']
    print("Subsetting the data...")
    metadata_n = metadata_n[(metadata_n['UMAP2'] < 0.8) & (metadata_n['UMAP1'] > -0.5)]
    print("n = ",metadata_n.shape[0])
    print("splitting data into highly and lowly infected cells...")
    
    metadata_n1, metadata_n2 = metadata_n[~metadata_n.index.isin(virus_cells)],metadata_n[metadata_n.index.isin(virus_cells)]
    print("n infected = ",metadata_n2.shape[0])
    
    print("n katablepharidacea = ",metadata_n[metadata_n['host'] == 'Katablepharidaceae'].shape[0])
    dim = dim

    dict_colors_n = {your_key: dict_colors[your_key] for your_key in set(metadata_n['host']) }
    scprep.plot.scatter2d(metadata_n1[['{}1'.format(dim),'{}2'.format(dim)]], 
                          c=metadata_n1['host'].map(dict_colors_n), 
                          s=10, 
                          ticks=False, 
                           legend_title='Predicted host',
                          legend = False,
                          ax=ax)
    
    scprep.plot.scatter2d(metadata_n2[['{}1'.format(dim),'{}2'.format(dim)]], 
                          c=metadata_n2['host'].map(dict_colors_n), 
                          s = 60,
                          vmin = 0,
                          vmax = 1,
                          ticks=False, 
                          xlabel=None, ylabel=None,
                          edgecolor = "black",
                          linewidths = 1,
                          legend=False, 
                          ax=ax)
    
    ax.invert_yaxis()
    ax.invert_xaxis()
    
    plt.tight_layout()
    plt.show()
    


In [11]:
# threshold for UMIs
def virus_host_coexpression_cells(metadata_n,dim,virus_cells,dict_colors):
    fig, ax = plt.subplots(figsize=(6,6), dpi = 300)
    print("Removing unidentified cells...")
    metadata_n = metadata_n[metadata_n['host'] != 'Unknown']
    print("n = ",metadata_n.shape[0])
   
    dict_colors_n = {your_key: dict_colors[your_key] for your_key in set(metadata_n['host']) }
    scprep.plot.scatter2d(metadata_n[['{}1'.format(dim),'{}2'.format(dim)]], 
                          c=metadata_n['host'].map(dict_colors_n), 
                          s=2, 
                          ticks=False, 
                         
                           legend_title='Predicted host',
                    
                          ax=ax)
    
   
    markers = [plt.Line2D([0,0],[0,0],color=color, marker='o', linestyle='') for color in dict_colors_n.values()]
    handles, labels = sort_legend(markers, dict_colors_n.keys(),['Other eukaryotes'])
    ax.legend(handles, labels, 
              numpoints=1,
              title = 'Predicted host',
              title_fontsize = 12,
            loc='upper left',
            bbox_to_anchor=(-0.02, 1),
              fontsize = 12,
               frameon=False,
              handletextpad = 0.3,
              labelspacing = 0.2,
             )
    ax.invert_yaxis()
    ax.invert_xaxis()
    
    plt.tight_layout()
    plt.show()
    


In [12]:
def find_kata_cells(path, metadata, name_file,virus_cells):
    kata_b4t20 = metadata.loc[[x for x in metadata.index if x in virus_cells]][(metadata['sample'] == 'B4T20') & (metadata['host'] == 'Katablepharidaceae')]
    kata_b4t20_cells = kata_b4t20.index.str.split("-").str[0]
    with open(path+name_file, 'w') as f:
        for line in kata_b4t20_cells:
            f.write("%s\n" % line)

In [13]:
def main(path,
         data_file,
         data_raw_file,
         metadata_file,
        blastx_file,
         metapr2_file,
         pr2_file,
        replacement_file,
        out_file):

    if 'data_raw' not in locals():
        data_raw = load_data(path ,data_raw_file, "data")
    if 'data' not in locals():
        data = load_data(path ,data_file, "data")

    metadata = load_data(path , metadata_file, "metadata")


    # Annotate each cell based on the result of the blast search
    data_annotations, dict_colors = annotate(path,
        metapr2_file,
        pr2_file,
        replacement_file)

    annotations = edit_annotations(data_annotations)

    clean_dict = annotation_dictionary(annotations)

    metadata['host'] = metadata.index.map(clean_dict).fillna('Unknown')

    # Find infected cells as those expressing more than 10 viral UMIs
    path_blastx = path + blastx_file

    viral_contigs = get_viral_contigs(path_blastx)
    virus_cells = viral_expression(data_raw,metadata,viral_contigs)

    # plot figure 3a
    virus_host_coexpression_cells(metadata,"UMAP",virus_cells,dict_colors)

    # plot figure 3b as a subset of figure 3a
    virus_host_coexpression_cells_small(metadata,"UMAP",virus_cells,dict_colors)
    
    find_kata_cells(path,metadata, out_file,virus_cells)

In [14]:
'''
Input:
# path = Path for all the files
# data_file = The combined processed UMI table
# data_raw_file = The combined raw UMI table
# metadata_file = Metadata containing the UMAP coordinated for all the cells
# blastx_file = The virus homology results for the assembled transcripts (blastx)
# metapr2_file = The host homology results for the assembled transcripts againt metapr2 database
# pr2_file = The host homology results for the assembled transcripts againt pr2 database
# replacement_file = A tab delimited file for replacing taxonomic names and assigning colors for each taxonomic group.
Must contain most of the main groups present in the analysis (an inspection of the pr2 and metapr2 results is required)
Contains the following columns:
    old_name: The taxonomic name to replace (For example: Bacillariophyta)
    level: The taxonomic level of the old name (For example: Class)
    new_name: The new name to use in the plot (For example: Diatoms)
    color: The color assigned to the taxonomic group in the Sankey plot (For example: #8C613C)

# out_file = The name of the output file for the infected Katablepharidacea cells

'''

if __name__ == "__main__":
    main(path = '/Users/amirf/Dropbox (Weizmann Institute)/scGVDB/pipeline_files/',
         data_file = "UMI_tables/data.pickle.gz",
         data_raw_file = "UMI_tables/data_raw.pickle.gz",
         metadata_file = "UMI_tables/metadata_dimentionality_reduction_clusters.pickle.gz",
        blastx_file = 'cells.filtered.blastx.tsv',
         metapr2_file = 'summary_metapr2_all.tsv',
         pr2_file = 'summary_pr2_all.tsv',
        replacement_file = 'replacement_taxa_allcells.txt',
        out_file = 'kata_cells.txt')
  

Subsetting metapr2 data
Subsetting pr2 data
Annotating cells
Create a dataset for host
n =  18793
Create a dataset for host
n =  612
splitting data into highly and lowly infected cells
n infected =  81
n katablepharidacea =  37


  kata_b4t20 = metadata.loc[[x for x in metadata.index if x in virus_cells]][(metadata['sample'] == 'B4T20') & (metadata['host'] == 'Katablepharidaceae')]
