# Gene2Vec

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# Imports
import numpy as np
import pandas as pd

from pathlib import Path
from util import *
from tcga_dna import *

# Pre-processing MSigDB

In [3]:
def read_msigdb(path):
    """ Read a single MsigDb (GMT) file.

    Each non-empty line is tab-separated: the gene-set name, a second field that
    is skipped (description / URL), then the member gene names.

    Args:
        path: pathlib.Path of the file to read.

    Returns:
        dict gene-set name -> dict {gene name: 1}; the inner dict acts as an
        insertion-ordered set of genes.
    """
    msigdb = dict()
    for line in path.read_text().split('\n'):
        fields = line.split('\t')
        geneset_name = fields[0]
        if geneset_name:
            # Skip empty fields (e.g. produced by a trailing tab) so that ''
            # never ends up registered as a gene name.
            msigdb[geneset_name] = {f: 1 for f in fields[2:] if f}
    return msigdb


def read_msigdb_all(path, regex="*.gmt"):
    """ Read every MsigDb file found under `path` and merge them into one dictionary.

    Args:
        path: directory to search. NOTE(review): relies on a `find_files` method,
            which is not part of standard pathlib; presumably provided by the
            star-imported `util` / `tcga_dna` modules — confirm.
        regex: file-name pattern handed to `find_files`; despite its name the
            default "*.gmt" looks like a glob pattern rather than a regex.

    Returns:
        dict gene-set name -> {gene name: 1}. If two files define the same
        gene-set name, the file read later overwrites the earlier entry.
    """
    merged = {}
    for gmt_file in path.find_files(regex):
        print(f"File: {gmt_file}")  # progress: one line per file read
        file_sets = read_msigdb(gmt_file)
        merged.update(file_sets)
    return merged


def msigdb2genes(msigdb):
    """ Return every distinct gene that belongs to at least one gene set, sorted alphabetically """
    all_genes = {gene for members in msigdb.values() for gene in members}
    return sorted(all_genes)


def msigdb2gene_sets(msigdb):
    """ Return the names of all gene sets in MsigDb, sorted alphabetically """
    return sorted(msigdb)


def msigdb2df(path):
    """ Read all MsigDb files in `path` and build a binary gene-set x gene matrix.

    Args:
        path: directory passed on to `read_msigdb_all`.

    Returns:
        pandas DataFrame with one row per gene set and one column per gene (both
        sorted alphabetically). A cell is 1 if the gene belongs to the gene set,
        otherwise 0. Values are stored as int8.
    """
    msigdb = read_msigdb_all(path)
    gene_sets = msigdb2gene_sets(msigdb)
    genes = msigdb2genes(msigdb)
    gene_col = {g: j for j, g in enumerate(genes)}
    # Fill a dense int8 matrix directly. Building the frame from the dict of
    # dicts NaN-pads missing genes, so the requested int8 dtype could not be
    # honored and the result was float (note the float sum in the recorded
    # output). That takes 8x the memory, and the per-column construction is slow.
    membership = np.zeros((len(gene_sets), len(genes)), dtype='int8')
    for row, gs in enumerate(gene_sets):
        membership[row, [gene_col[g] for g in msigdb[gs]]] = 1
    return pd.DataFrame(membership, index=gene_sets, columns=genes)


def geneset_gene_pairs(msigdb):
    """ Lazily yield one (gene-set name, gene name) tuple per membership in MsigDb """
    yield from ((gs, gene) for gs, members in msigdb.items() for gene in members)

def save_pairs(msigdb, path_save):
    """ Write one 'geneset,gene' line per membership to `path_save`.

    No header row and no trailing newline are written; an existing file is overwritten.
    """
    lines = (f"{gs},{g}" for gs, g in geneset_gene_pairs(msigdb))
    path_save.write_text('\n'.join(lines))

In [7]:
path = Path('data') / 'msigdb'
# Load the gene sets from the files matched by read_msigdb_all's default '*.gmt' pattern
msigdb = read_msigdb_all(path)
# save_pairs(msigdb, path/'msigdb_pairs.csv')

File: data/msigdb/c6.all.v7.0.symbols.gmt
File: data/msigdb/c2.all.v7.0.symbols.gmt
File: data/msigdb/c7.all.v7.0.symbols.gmt
File: data/msigdb/c5.all.v7.0.symbols.gmt
File: data/msigdb/h.all.v7.0.symbols.gmt


In [5]:
# %%time
# df = msigdb2df(path)
# df.shape, df.sum().sum()

File: data/msigdb/c6.all.v7.0.symbols.gmt
File: data/msigdb/c2.all.v7.0.symbols.gmt
File: data/msigdb/c7.all.v7.0.symbols.gmt
File: data/msigdb/c5.all.v7.0.symbols.gmt
File: data/msigdb/h.all.v7.0.symbols.gmt
CPU times: user 1min 13s, sys: 3.06 s, total: 1min 16s
Wall time: 1min 16s


((20608, 23112), 2320475.0)

In [6]:
# %%time
# from sklearn.decomposition import NMF

# for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 150, 200]:
#     nmf = NMF(n_components=n)
#     nmf_fit = nmf.fit(df)
#     print(f"n: {n}, reconstruction_err_: {nmf.reconstruction_err_}, n_iter_: {nmf.n_iter_}")


n: 10, reconstruction_err_: 1456.1366669984125, n_iter_: 94
n: 20, reconstruction_err_: 1433.0421291386274, n_iter_: 71
n: 30, reconstruction_err_: 1415.7370651881865, n_iter_: 199
n: 40, reconstruction_err_: 1401.893353272043, n_iter_: 182
n: 50, reconstruction_err_: 1391.152196044853, n_iter_: 103
n: 60, reconstruction_err_: 1381.664265597847, n_iter_: 140
n: 70, reconstruction_err_: 1373.7059099861337, n_iter_: 143
n: 80, reconstruction_err_: 1366.3325620239789, n_iter_: 177
n: 90, reconstruction_err_: 1359.61818469111, n_iter_: 199
n: 100, reconstruction_err_: 1353.0666435705025, n_iter_: 199
n: 150, reconstruction_err_: 1326.5354100212003, n_iter_: 106
n: 200, reconstruction_err_: 1305.9657480770732, n_iter_: 199
CPU times: user 6h 54min 14s, sys: 32min 54s, total: 7h 27min 9s
Wall time: 2h 9min 7s


In [None]:
# n: 10, reconstruction_err_: 1456.1366669984125, n_iter_: 94
# n: 20, reconstruction_err_: 1433.0421291386274, n_iter_: 71
# n: 30, reconstruction_err_: 1415.7370651881865, n_iter_: 199
# n: 40, reconstruction_err_: 1401.893353272043, n_iter_: 182
# n: 50, reconstruction_err_: 1391.152196044853, n_iter_: 103
# n: 60, reconstruction_err_: 1381.664265597847, n_iter_: 140
# n: 70, reconstruction_err_: 1373.7059099861337, n_iter_: 143
# n: 80, reconstruction_err_: 1366.3325620239789, n_iter_: 177
# n: 90, reconstruction_err_: 1359.61818469111, n_iter_: 199
# n: 100, reconstruction_err_: 1353.0666435705025, n_iter_: 199
# n: 150, reconstruction_err_: 1326.5354100212003, n_iter_: 106
# n: 200, reconstruction_err_: 1305.9657480770732, n_iter_: 199
# CPU times: user 6h 54min 14s, sys: 32min 54s, total: 7h 27min 9s
# Wall time: 2h 9min 7s
    

In [19]:
def save_by_gene(msigdb, path, sep=','):
    """ Save the reverse (gene -> gene sets) mapping of MsigDb as a two-column text file.

    The file starts with the header line 'gene<sep>genesets'. It is followed by
    one line per gene, genes sorted alphabetically: '<gene><sep><gs1> <gs2> ...'.
    Each gene's gene sets are space-separated and sorted. No trailing newline is
    written after the last gene.

    Args:
        msigdb: dict gene-set name -> {gene name: 1}, as returned by read_msigdb_all.
        path: pathlib.Path of the output file (overwritten).
        sep: separator between the gene column and the gene-set column (default ',').
    """
    # Reverse mapping (gene -> list of gene sets). Collect names in lists and
    # join once at the end. The old `d[g] += ' ' + gs` copied the whole growing
    # string on every append, which is quadratic for genes in many sets.
    genesets_by_gene = dict()
    for gs in sorted(msigdb):
        for g in msigdb[gs]:
            genesets_by_gene.setdefault(g, []).append(gs)
    # Create string to save (header line always ends with '\n', even if empty)
    strout = f"gene{sep}genesets\n"
    strout += '\n'.join(f"{g}{sep}{' '.join(genesets_by_gene[g])}" for g in sorted(genesets_by_gene))
    path.write_text(strout)

# Write the gene -> gene-sets table next to the source MsigDb files
save_by_gene(msigdb, path / 'msigdb.by_gene.csv')