# Gene2Vec

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# Imports
import numpy as np
import pandas as pd

from pathlib import Path

from fastai.collab import *
from fastai.tabular import *
from tcga.util import *
from tcga.tcga_dna import *

# Pre-processing MsigDb

In [3]:
def read_msigdb(path):
    """
    Read an MsigDb (single) file and return a dictionary (by gene set) of dictionaries (gene name: 1)

    Every non-empty line is split on tabs: field 0 is the gene-set name, field 1 is
    ignored, and the remaining fields are taken as the member gene names.

    Args:
        path: object providing `read_text()` (e.g. a `pathlib.Path`)

    Returns:
        dict: gene-set name -> {gene name: 1}. If a gene-set name occurs on more
        than one line, the last occurrence wins.
    """
    msigdb = dict()
    for line in path.read_text().split('\n'):
        fields = line.split('\t')
        geneset_name = fields[0]
        if geneset_name:
            # Drop empty fields (e.g. from a trailing or doubled tab) so that
            # no bogus '' gene ends up in the gene set
            msigdb[geneset_name] = {f: 1 for f in fields[2:] if f}
    return msigdb


def read_msigdb_all(path, regex="*.gmt"):
    """
    Read every MsigDb file returned by `path.find_files(regex)` and merge the results.

    Each file name is printed as it is read. The result is a single dictionary
    (gene-set name -> {gene name: 1}); when two files define the same gene-set
    name, the file read later replaces the earlier entry.
    """
    merged = {}
    for gmt_file in path.find_files(regex):
        print(f"File: {gmt_file}")
        merged.update(read_msigdb(gmt_file))
    return merged


def msigdb2genes(msigdb):
    """ Return all gene names that occur in any gene set of MsigDb: de-duplicated, as a sorted list """
    unique_genes = set()
    for members in msigdb.values():
        unique_genes.update(members.keys())
    return sorted(unique_genes)


def msigdb2gene_sets(msigdb):
    """ Return the Gene-Set names (the keys of MsigDb) as a new, sorted list """
    return sorted(msigdb.keys())


def msigdb2df(path):
    """
    Read all MsigDb in the path and create a membership dataframe.

    Returns:
        DataFrame with one row per gene set and one column per gene (both sorted
        by name). A cell holds 1 when the gene is in the gene set and 0 otherwise,
        stored as int8.
    """
    msigdb = read_msigdb_all(path)
    # Genes that are missing from a gene set come out of the constructor as NaN,
    # which int8 cannot represent. So the gaps are filled first and the int8 cast
    # is applied afterwards instead of being requested in the constructor.
    df = pd.DataFrame(msigdb, index=msigdb2genes(msigdb), columns=msigdb2gene_sets(msigdb))
    return df.fillna(0).astype('int8').transpose()


def gene_genesets(msigdb):
    """
    Invert an MsigDb dictionary: map every gene to the gene sets that contain it.

    Args:
        msigdb: dict gene-set name -> dict keyed by gene name

    Returns:
        dict gene name -> one string holding the names of the gene sets the gene
        belongs to, separated by a single space and in the iteration order of
        `msigdb`. Genes appear in the order in which they are first seen.
    """
    sets_by_gene = dict()
    for gs, genes in msigdb.items():
        for gene in genes.keys():
            # Collect the names in a list and join once at the end, instead of
            # re-building an ever longer string for each additional gene set
            sets_by_gene.setdefault(gene, []).append(gs)
    return {gene: ' '.join(gss) for gene, gss in sets_by_gene.items()}


def geneset_gene_pairs(msigdb):
    """ Lazily yield a (geneset, gene) tuple for every gene of every gene set in the MsigDb dictionary """
    for geneset_name, members in msigdb.items():
        yield from ((geneset_name, gene_name) for gene_name in members.keys())


def save_gene_geneset(msigdb, path_save):
    """
    Write the gene -> gene sets mapping of MsigDb to `path_save` as text.

    One line per gene, formatted as `<gene>,<gene-set names separated by spaces>`.
    No header line is written, nothing is quoted, and the last line has no
    trailing newline.
    """
    gene2sets = gene_genesets(msigdb)
    rows = (f"{gene},{genesets}" for gene, genesets in gene2sets.items())
    path_save.write_text('\n'.join(rows))

In [4]:
# Directory with the MsigDb gene-set files (relative path, resolved against the working directory)
path = Path('data/msigdb')
# Merge all gene-set files found there into one dictionary: gene set -> {gene: 1}
msigdb = read_msigdb_all(path)

File: data/msigdb/c5.all.v7.0.symbols.gmt
File: data/msigdb/h.all.v7.0.symbols.gmt
File: data/msigdb/c7.all.v7.0.symbols.gmt
File: data/msigdb/c2.all.v7.0.symbols.gmt
File: data/msigdb/c6.all.v7.0.symbols.gmt


In [5]:
# Write one "gene,<space-separated gene sets>" line per gene next to the input files
save_gene_geneset(msigdb, path/'gene2genesets.csv')

In [6]:
# Keep only the first rows of the gene -> gene sets file as a small sample for the cells below.
# Written in Python (instead of `! head -n 100 ...`) so the cell does not depend on a POSIX shell.
N_HEAD_ROWS = 100
with open(path/'gene2genesets.csv') as full_csv, open(path/'gene2genesets.head.csv', 'w') as head_csv:
    # zip() stops after N_HEAD_ROWS lines, or earlier if the file is shorter
    head_csv.writelines(row for _, row in zip(range(N_HEAD_ROWS), full_csv))

# Tabular Learner

In [7]:
# Load the sample file; column names are passed explicitly via `names`
msigdb_df = pd.read_csv(path/'gene2genesets.head.csv', names=['gene', 'genesets'])
print(f"Shape: {msigdb_df.shape}")
msigdb_df.head()

Shape: (100, 2)


Unnamed: 0,gene,genesets
0,CDK9,GO_POSITIVE_REGULATION_OF_VIRAL_TRANSCRIPTION ...
1,CHD1,GO_POSITIVE_REGULATION_OF_VIRAL_TRANSCRIPTION ...
2,DHX9,GO_POSITIVE_REGULATION_OF_VIRAL_TRANSCRIPTION ...
3,EP300,GO_POSITIVE_REGULATION_OF_VIRAL_TRANSCRIPTION ...
4,SNW1,GO_POSITIVE_REGULATION_OF_VIRAL_TRANSCRIPTION ...


In [8]:
# Pre-processing steps handed to the fastai tabular pipeline
procs = [Categorify, Normalize]
# Name of the label column (space-separated gene-set names)
dep_var = 'genesets'
# The gene name is the only input column and is treated as categorical
cat_names = ['gene']

# Stack the table on top of itself; `valid_idx` covers the positions of the second copy.
# NOTE(review): the row index is repeated (no `ignore_index`), and `dep_var` / `valid_idx`
# appear to be used only by the commented-out `TabularDataBunch.from_df` call — confirm.
msigdb_df2 = pd.concat([msigdb_df, msigdb_df])
valid_idx = range(len(msigdb_df), len(msigdb_df2))

In [9]:
# Earlier variant kept for reference: explicit validation indices through `TabularDataBunch.from_df`
# data = TabularDataBunch.from_df(path=path, df=msigdb_df2, dep_var=dep_var, valid_idx=valid_idx, procs=procs, cat_names=cat_names)
# data

# Data-block API: random train/validation split, labels obtained by splitting on ' '
# (one label per gene-set name).
# NOTE(review): no seed is passed to `split_by_rand_pct`, so the split presumably changes
# between runs; `label_from_df` is called without `cols` and presumably picks the
# `genesets` column — TODO confirm both.
data = (TabularList.from_df(msigdb_df2, cat_names=cat_names, procs=procs)
        .split_by_rand_pct()
        .label_from_df(label_delim=' ')
        .databunch()
        )
data

TabularDataBunch;

Train: LabelList (160 items)
x: TabularList
gene CDK9; ,gene DHX9; ,gene EP300; ,gene SNW1; ,gene RRP1B; 
y: MultiCategoryList
GO_POSITIVE_REGULATION_OF_VIRAL_TRANSCRIPTION;GO_DNA_DEPENDENT_DNA_REPLICATION_MAINTENANCE_OF_FIDELITY;GO_CHROMOSOME_ORGANIZATION;GO_DNA_TEMPLATED_TRANSCRIPTION_ELONGATION;GO_DNA_DEPENDENT_DNA_REPLICATION;GO_POSITIVE_REGULATION_OF_BINDING;GO_PROTEIN_MODIFICATION_BY_SMALL_PROTEIN_CONJUGATION_OR_REMOVAL;GO_CELL_CYCLE_ARREST;GO_POSITIVE_REGULATION_OF_BIOSYNTHETIC_PROCESS;GO_POSITIVE_REGULATION_OF_ORGANELLE_ORGANIZATION;GO_REGULATION_OF_DNA_REPAIR;GO_POSITIVE_REGULATION_OF_CELLULAR_COMPONENT_ORGANIZATION;GO_POSITIVE_REGULATION_OF_GENE_EXPRESSION;GO_REGULATION_OF_TRANSCRIPTION_ELONGATION_FROM_RNA_POLYMERASE_II_PROMOTER;GO_NCRNA_TRANSCRIPTION;GO_POSITIVE_REGULATION_OF_MOLECULAR_FUNCTION;GO_MUSCLE_STRUCTURE_DEVELOPMENT;GO_NEGATIVE_REGULATION_OF_CELL_CYCLE_PROCESS;GO_DNA_REPAIR;GO_POSITIVE_REGULATION_OF_TRANSCRIPTION_ELONGATION_FROM_RNA_POLYMERASE_II

In [10]:
class TabularModelZzz(Module):
    """
    Basic model for tabular data.

    Each categorical column is passed through its own embedding, the continuous columns
    through a batch-norm; both parts are concatenated and fed to the layer stack built in
    `__init__`. `forward` additionally prints its categorical input (debug aid).
    NOTE(review): apart from that `print`, this looks like a copy of fastai's
    `TabularModel` — confirm.
    """
    def __init__(self, emb_szs:ListSizes, n_cont:int, out_sz:int, layers:Collection[int], ps:Collection[float]=None,
                 emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True, bn_final:bool=False):
        """
        Args:
            emb_szs: iterable of (ni, nf) pairs, one embedding is created per pair
            n_cont: number of continuous input features
            out_sz: size of the output layer
            layers: sizes of the hidden layers
            ps: dropout value(s) for the hidden layers; zeros when None
            emb_drop: dropout applied to the concatenated embeddings
            y_range: optional (low, high); when given, the output is squashed into it with a sigmoid
            use_bn: forwarded as `bn` to `bn_drop_lin` for every layer except the first
            bn_final: append a `BatchNorm1d` after the last layer
        """
        super().__init__()
        # Presumably: fall back to zero dropout when `ps` is None, then expand `ps`
        # to match `layers` (fastai helpers `ifnone` / `listify`)
        ps = ifnone(ps, [0]*len(layers))
        ps = listify(ps, layers)
        self.embeds = nn.ModuleList([embedding(ni, nf) for ni,nf in emb_szs])
        self.emb_drop = nn.Dropout(emb_drop)
        self.bn_cont = nn.BatchNorm1d(n_cont)
        # Total width of all embedding outputs
        n_emb = sum(e.embedding_dim for e in self.embeds)
        self.n_emb,self.n_cont,self.y_range = n_emb,n_cont,y_range
        sizes = self.get_sizes(layers, out_sz)
        # ReLU after every layer except the last one, which gets no activation
        actns = [nn.ReLU(inplace=True) for _ in range(len(sizes)-2)] + [None]
        # From here on `layers` is re-used for the list of modules being built
        layers = []
        # `[0.]+ps`: the first layer gets dropout 0, the following ones the values from `ps`
        for i,(n_in,n_out,dp,act) in enumerate(zip(sizes[:-1],sizes[1:],[0.]+ps,actns)):
            layers += bn_drop_lin(n_in, n_out, bn=use_bn and i!=0, p=dp, actn=act)
        if bn_final: layers.append(nn.BatchNorm1d(sizes[-1]))
        self.layers = nn.Sequential(*layers)

    def get_sizes(self, layers, out_sz):
        "Layer widths: input (embeddings + continuous features), the hidden `layers`, then `out_sz`."
        return [self.n_emb + self.n_cont] + layers + [out_sz]

    def forward(self, x_cat:Tensor, x_cont:Tensor) -> Tensor:
        "Embed `x_cat`, batch-norm `x_cont`, concatenate them and run the result through `self.layers`."
        # Debug output: prints the categorical input on every forward pass
        print(x_cat)
        if self.n_emb != 0:
            # Column i of `x_cat` is looked up in embedding i
            x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
            x = torch.cat(x, 1)
            x = self.emb_drop(x)
        if self.n_cont != 0:
            x_cont = self.bn_cont(x_cont)
            x = torch.cat([x, x_cont], 1) if self.n_emb != 0 else x_cont
        # NOTE(review): `x` is not assigned when both `n_emb` and `n_cont` are 0
        x = self.layers(x)
        if self.y_range is not None:
            # Rescale the sigmoid output from (0, 1) to (y_range[0], y_range[1])
            x = (self.y_range[1]-self.y_range[0]) * torch.sigmoid(x) + self.y_range[0]
        return x
    
def tabular_learner_zzz(data:DataBunch, layers:Collection[int], emb_szs:Dict[str,int]=None, metrics=None,
        ps:Collection[float]=None, emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True, **learn_kwargs):
    """
    Get a `Learner` using `data`, with `metrics`, whose model is a `TabularModelZzz`.

    Embedding sizes come from `data.get_emb_szs` (given `emb_szs`, or an empty dict when it is None),
    the number of continuous inputs from `data.cont_names` and the output size from `data.c`.
    The remaining parameters are passed to `TabularModelZzz`; `learn_kwargs` go to `Learner`.
    """
    embedding_sizes = data.get_emb_szs(ifnone(emb_szs, {}))
    n_continuous = len(data.cont_names)
    net = TabularModelZzz(embedding_sizes, n_continuous, out_sz=data.c, layers=layers, ps=ps,
                          emb_drop=emb_drop, y_range=y_range, use_bn=use_bn)
    return Learner(data, net, metrics=metrics, **learn_kwargs)

In [11]:
# Metrics with the decision threshold (`thresh`) fixed at 0.2
acc_02 = partial(accuracy_thresh, thresh=0.2)
f_score = partial(fbeta, thresh=0.2)

# One hidden layer of 10 units and an embedding size of 10 for the `gene` column
learn = tabular_learner_zzz(data, layers=[10], emb_szs={'gene': 10}, metrics=[acc_02, f_score])

In [12]:
# Display the module structure of the network held by the learner
learn.model

TabularModelZzz(
  (embeds): ModuleList(
    (0): Embedding(98, 10)
  )
  (emb_drop): Dropout(p=0.0, inplace=False)
  (bn_cont): BatchNorm1d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layers): Sequential(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): ReLU(inplace=True)
    (2): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=10, out_features=9130, bias=True)
  )
)

In [13]:
# Per-layer output shapes and parameter counts; the tensor printed first comes from
# the debug `print(x_cat)` in `TabularModelZzz.forward`
learn.summary()

tensor([[83]])


TabularModelZzz
Layer (type)         Output Shape         Param #    Trainable 
Embedding            [10]                 980        True      
______________________________________________________________________
Dropout              [10]                 0          False     
______________________________________________________________________
Linear               [10]                 110        True      
______________________________________________________________________
ReLU                 [10]                 0          False     
______________________________________________________________________
BatchNorm1d          [10]                 20         True      
______________________________________________________________________
Linear               [9130]               100,430    True      
______________________________________________________________________

Total params: 101,540
Total trainable params: 101,540
Total non-trainable params: 0
Optimized with 'torch.opt

### Multilabel classification in fastai

Reference: https://forums.fast.ai/t/multi-label-classification-how-does-it-work/42154/5

If you take a look at the data_block.py file in the fastai library, you’ll see that when you create your data, it creates either a CategoryList (single-label classification) or a MultiCategoryList (multi-label classification), depending on your case. It then assigns the appropriate loss function: categorical cross-entropy for a CategoryList, or binary cross-entropy for a MultiCategoryList (our case).

Then, in the basic_train.py file, you can see that each loss function is linked to a particular final activation function (softmax for cross-entropy, sigmoid for binary cross-entropy), which is applied to the output of your model when you make predictions.

In [None]:
# learn.lr_find()

In [None]:
# Train with `fit_one_cycle`: first argument 1 (presumably one epoch), second argument 1e-2
# (presumably the maximum learning rate) — confirm against the fastai version in use
learn.fit_one_cycle(1, 1e-2)

In [None]:
# IPython `??` introspection: show the source of `learn.predict` (development aid)
??learn.predict

In [None]:
# Pick the first item to try a single prediction on.
# NOTE(review): this assumes the DataBunch supports integer indexing; `data.train_ds[0]`
# may be what is intended — TODO confirm.
x = data[0]
x

In [None]:
# Predict for the first element of `x` (presumably the input part of an (input, label) pair — confirm)
learn.predict(x[0])