## Setting up environment

In [2]:
import os

# Restrict CUDA to device index 3.
# NOTE(review): presumably this must run before any CUDA-using library
# initialises its GPU context -- confirm it stays ahead of the torch import.
os.environ.update(CUDA_VISIBLE_DEVICES='3')

# Work from the project root so the relative `data/...` paths used in the
# following cells resolve.
os.chdir('/home/yz979/code/kaggle-perturbation/')

In [3]:
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Union

import anndata
import networkx as nx
import numpy as np
import pandas as pd
import pickle
import requests
import scanpy as sc
import torch
import torch.nn as nn
from tqdm import tqdm

## Load Gene Names
Get all gene names from ``de_train.h5ad``.

In [15]:
# Read the training AnnData object from the data folder.
data_path = 'data/'
de_train = sc.read_h5ad(os.path.join(data_path, 'de_train.h5ad'))

# The variable names of the AnnData object are used as the gene list; each
# gene is mapped to its position in that list (the node index used later).
gene_list = de_train.var_names.tolist()
node_map = dict(zip(gene_list, range(len(gene_list))))

## Download Gene Ontology (GO) from Dataverse
The helper below downloads a file from Dataverse and shows a progress bar. Here it is used to fetch the gene-to-GO mapping (`gene2go.pkl`), which annotates the genes in the network and is saved in the `data/grn` folder.

- ``url``: The URL of the GO file in Dataverse.
- ``save_path``: The path to save the GO file. If the file already exists at this path, it is not downloaded again.

In [4]:
def dataverse_download(url, save_path):
    """
    Dataverse download helper with progress bar

    The payload is streamed into ``<save_path>.part`` and only renamed to
    ``save_path`` after the download finished, so an interrupted or failed
    download is never mistaken for a valid local copy on the next run.

    Args:
        url (str): the url of the dataset
        save_path (str): the path to save the dataset

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """

    if os.path.exists(save_path):
        print('Found local copy...')
        return

    print("Downloading...")
    block_size = 1024
    partial_path = f'{save_path}.part'
    try:
        with requests.get(url, stream=True) as response:
            # Fail loudly instead of storing an HTTP error page as the dataset.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            with open(partial_path, 'wb') as file:
                with tqdm(total=total_size_in_bytes, unit='iB',
                          unit_scale=True) as progress_bar:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
        # Publish the file only once it is complete.
        os.replace(partial_path, save_path)
    finally:
        # After a successful rename nothing is left; after a failure this
        # removes the truncated download.
        if os.path.exists(partial_path):
            os.remove(partial_path)

In [5]:
data_path = 'data/grn/'
# Make sure the target folder exists; otherwise a first-time download fails
# when the output file is opened for writing.
os.makedirs(data_path, exist_ok=True)
if not os.path.exists(os.path.join(data_path, 'gene2go.pkl')):
    # download gene2go.pkl
    server_path = 'https://dataverse.harvard.edu/api/access/datafile/6153417' 
    dataverse_download(server_path, os.path.join(data_path, 'gene2go.pkl'))

# NOTE(security): unpickling can execute arbitrary code. This file comes from
# an external server, so only load it if the source is trusted.
with open(os.path.join(data_path, 'gene2go.pkl'), 'rb') as f:
    gene2go = pickle.load(f)

## Gene Ontology (GO) Graph
We follow the instructions from ``GEARS[Roohani et al., 2023]`` and construct a GO graph. The following function is used to construct the GO graph.

- ``gene_list``: The list of genes to be annotated.
- ``gene2go``: The GO annotations of all genes.
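- ``save_path``: The path to save the GO graph. If a saved GO graph already exists at this path, it is loaded instead of being constructed again.
- ``threshold``: The threshold used to filter the GO graph: an edge is kept only if the Jaccard similarity of the two genes' GO term sets is greater than this value. The default value is 0.1.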

In [8]:
def go_auto(gene_list, gene2go, save_path, threshold=0.1):
    """
    Generate gene ontology data

    Every ordered pair of annotated genes (including a gene paired with
    itself) is scored with the Jaccard similarity of the two genes' GO term
    sets; pairs scoring strictly above ``threshold`` become edges. If
    ``save_path`` already exists it is read and returned without recomputing;
    otherwise the result is written there as CSV.

    Args:
        gene_list (list): list of genes
        gene2go (dict): gene2go mapping
        save_path (str): the path to save the dataset
        threshold (float): threshold for filtering edges, defaults to 0.1.

    Returns:
        pd.DataFrame: gene ontology data with columns ``source``, ``target``
            and ``importance``
    """

    if os.path.exists(save_path):
        return pd.read_csv(save_path)

    # filter gene2go mapping to current genes; sets make the overlap cheap
    go_terms = {gene: set(gene2go[gene]) for gene in gene_list if gene in gene2go}

    # score and filter in one pass so the full gene x gene pair list is never
    # held in memory
    edge_list = []
    for g1, terms1 in tqdm(go_terms.items()):
        for g2, terms2 in go_terms.items():
            n_common = len(terms1 & terms2)
            n_union = len(terms1) + len(terms2) - n_common
            if n_union == 0:
                # neither gene has any GO term: similarity is undefined
                continue
            score = n_common / n_union
            if score > threshold:
                edge_list.append((g1, g2, score))

    # explicit columns keep the schema stable even when no edge passes
    edge_df = pd.DataFrame(edge_list, columns=['source', 'target', 'importance'])
    edge_df.to_csv(save_path, index=False)
    return edge_df

In [10]:
# Build the GO edge table, or load it from `go.csv` when that file already
# exists (default threshold).
data_path = 'data/grn/'
go_graph = go_auto(
    gene_list=gene_list,
    gene2go=gene2go,
    save_path=os.path.join(data_path, 'go.csv'),
)

## Gene Co-expression Network
We follow the instructions from ``GEARS[Roohani et al., 2023]`` and construct a gene co-expression network. The following function is used to construct the gene co-expression network.

- ``adata``: The AnnData object of the dataset.
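- ``save_path``: The path to save the gene co-expression network. If a saved network already exists at this path, it is loaded instead of being constructed again.
- ``threshold``: The threshold used to construct the gene co-expression network: two genes are connected only if the absolute value of their correlation is greater than this value. The default value is 0.1.
- ``method``: The correlation method. The default value is ``'pearson'``.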

In [11]:
def coexpress_auto(adata, save_path, threshold=0.1, method='pearson'):
    """
    Generate coexpression data

    Computes the gene-by-gene correlation matrix of ``adata.to_df()`` and
    keeps each unordered gene pair (upper triangle, no self-pairs) whose
    absolute correlation is strictly greater than ``threshold``. If
    ``save_path`` already exists it is read and returned without recomputing;
    otherwise the result is written there as CSV.

    Args:
        adata (anndata.AnnData): anndata object
        save_path (str): the path to save the dataset
        threshold (float): threshold for filtering edges, defaults to 0.1.
        method (str): method for calculating correlation, defaults to 'pearson'.

    Returns:
        pd.DataFrame: coexpression data with columns ``source``, ``target``
            and ``importance`` (the signed correlation)
    """

    if os.path.exists(save_path):
        return pd.read_csv(save_path)

    # calculate correlation matrix
    cor_matrix = adata.to_df().corr(method=method)
    # take the names from the matrix itself so labels always match its rows
    gene_names = cor_matrix.columns
    cor_values = cor_matrix.to_numpy()

    # filter edges: one vectorised comparison per row instead of one `.iloc`
    # lookup per gene pair; pair order is the same as a nested i < j loop
    edges = []
    for i in range(len(gene_names) - 1):
        upper = cor_values[i, i + 1:]
        for offset in np.flatnonzero(np.abs(upper) > threshold):
            edges.append((gene_names[i], gene_names[i + 1 + offset], upper[offset]))

    edge_df = pd.DataFrame(edges, columns=['source', 'target', 'importance'])
    edge_df.to_csv(save_path, index=False)
    return edge_df

In [12]:
# Build the co-expression edge table from the training data, or load it from
# `gc.csv` when that file already exists (default threshold and method).
data_path = 'data/grn/'
gc_graph = coexpress_auto(
    adata=de_train,
    save_path=os.path.join(data_path, 'gc.csv'),
)

## Gene Regulatory Network (GRN)
We construct a gene regulatory network from either the gene co-expression network or the GO graph. The following function is used to construct the gene regulatory network.

- ``network``: The gene co-expression network or the GO graph.
- ``gene_list``: The list of genes to be annotated.
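- ``node_map``: The mapping from gene names to node indices.
- ``save_path``: Optional folder in which ``edge_index.npy`` and ``edge_weight.npy`` are saved. Nothing is saved when it is ``None`` (the default).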

In [23]:
def gene_similarity_network(network, gene_list, node_map, save_path=None):
    """
    Generate gene similarity network

    Builds an undirected graph from the edge table, adds every gene in
    ``gene_list`` as a node, drops nodes that have no entry in ``node_map``
    and returns the remaining edges as index and weight arrays.

    Args:
        network (pd.DataFrame): gene similarity data with columns
            ``source``, ``target`` and ``importance``
        gene_list (list): list of genes
        node_map (dict): mapping from gene to node
        save_path (str): the path to save the dataset, defaults to None.
            When given, ``edge_index.npy`` and ``edge_weight.npy`` are
            written into this folder.

    Returns:
        edge_index (np.array): edge index, shape ``(2, n_edges)``
        edge_weight (np.array): edge weight, shape ``(n_edges,)``
    """

    # `from_pandas_edgelist` already stores `importance` on every edge, so no
    # second insertion of the same edges is needed (re-adding them from
    # `network.values` only worked for frames with exactly three columns).
    G = nx.from_pandas_edgelist(network, source='source',
                                target='target', edge_attr=['importance'],
                                create_using=nx.Graph())
    G.add_nodes_from(gene_list)
    G.remove_nodes_from([n for n in G.nodes if n not in node_map])

    # Walk the edges once so indices and weights are guaranteed to line up.
    edges = list(G.edges(data='importance'))
    # reshape keeps the (2, n_edges) layout even when there are no edges
    edge_index = np.array([(node_map[u], node_map[v]) for u, v, _ in edges],
                          dtype=int).reshape(-1, 2).T
    edge_weight = np.array([w for _, _, w in edges])

    if save_path:
        np.save(os.path.join(save_path, 'edge_index.npy'), edge_index)
        np.save(os.path.join(save_path, 'edge_weight.npy'), edge_weight)
    return edge_index, edge_weight
    

In [24]:
# Convert the GO edge table into (edge_index, edge_weight) arrays; nothing is
# written to disk because no save_path is passed.
gene_similarity_network(
    network=go_graph,
    gene_list=gene_list,
    node_map=node_map,
)

(array([[    0,     0,     0, ..., 16791, 16967, 17637],
        [    0,     2,    71, ..., 16791, 16967, 17637]]),
 array([1.        , 0.18181818, 0.125     , ..., 1.        , 1.        ,
        1.        ]))