# Gene2Vec: Pytorch


References:

https://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html

https://discuss.pytorch.org/t/what-kind-of-loss-is-better-to-use-in-multilabel-classification/32203


In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# Imports
import numpy as np
import pandas as pd

from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tcga.msigdb import *
from tcga.util import *
from torch.utils.data import Dataset, DataLoader

# Dataset

In [10]:
def names2num(names):
    """ Create a mapping from names to numbers """
    names = list(set(names))  # Make sure names are unique
    return {n: i for i,n in enumerate(sorted(names))}
    
class DatasetMsigDb(Dataset):
    """ 
    Custom dataset: We have to override methods __len__ and __getitem__
    In our network, the inputs are genes and the outputs are gene-sets.
    We convert genes and gene sets to numbers, then store the forward and
    reverse mappings.
    We encode gene-> gene_sets as a dictionary indexed by gene, with
    tensors having 1 on the GeneSets the gene belongs to.
    The method __getitem__ returns a tuple with the gene number and
    the GeneSet tensor
    """
    def __init__(self, path):
        self.msigdb = read_msigdb_all(path)
        # Gene <-> number: forward and reverse mapping
        self.gene2num = names2num(msigdb2genes(self.msigdb))
        self.num2gene = {n: g for g, n in self.gene2num.items()}
        # GeneSet <-> number: forward and reverse mapping
        self.geneset2num = names2num(msigdb2gene_sets(self.msigdb))
        self.num2geneset = {n: g for g, n in self.gene2num.items()}
        # Gene -> GeneSets mapping (use gene_set numbers, in a tensor)
        self.init_genesets()

    def genesets2num(self, genesets):
        " Convert to a list of numerically encoded gene-sets "
        return [self.geneset2num[gs] for gs in genesets]

    def genesets2tensor(self, genesets):
        " Convert to a vector having 1 in each geneset position "
        geneset_idxs = [self.geneset2num[gs] for gs in genesets]
        geneset_tensor = torch.zeros(len(self.msigdb))
        geneset_tensor[geneset_idxs] = 1
        return geneset_tensor
        
    def init_genesets(self):
        " Map Gene to GeneSets. GeneSets are hot-encoded "
        self.gene_genesets = dict()
        self.gene_genesets_num = dict()
        self.gene_genesets_tensors = dict()
        num_genesets = len(self.geneset2num)
        for gene, genesets in gene_genesets(self.msigdb).items():
            self.gene_genesets[gene] = genesets
            self.gene_genesets_num[gene] = self.genesets2num(genesets)
            self.gene_genesets_tensors[gene] = self.genesets2tensor(genesets)
    
    def __len__(self):
        " Len: Count number of genes "
        return len(self.gene2num)

    def gene_sets_count(self):
        " Count number of gene sets "
        return len(self.msigdb)
    
    def __getitem__(self, idx):
        " Get item 'idx': A tuple of gene number 'idx' and gene set tensor for that gene "
        gene = self.num2gene[idx]
        return (idx, self.gene_genesets_tensors[gene])
    
    def __repr__(self):
        return self.__str__()
    
    def __str__(self):
        " Show (a few) mappings gene -> gene_set tensor "
        out = f"Genes: {len(self.gene2num)}, Gene Sets: {len(self.geneset2num)}"
        for i in range(10):  #range(len(self)):
            idx, gs_tensor = self[i]
            gene = self.num2gene[idx]
            out += f"\t{gene}: {idx}\t{self.gene_genesets[gene]}\t{self.gene_genesets_num[gene]}\t{gs_tensor}\n"
        return out + "..."


In [11]:
path = Path('data/msigdb/small')
dataset = DatasetMsigDb(path)
dataset

File: data/msigdb/small/h.all.v7.0.symbols.gmt


IndexError: too many indices for tensor of dimension 1

In [73]:
acc_02 = partial(accuracy_thresh, thresh=0.2)
f_score = partial(fbeta, thresh=0.2)

learn = tabular_learner_zzz(data, layers=[10], emb_szs={'gene': 10}, metrics=[acc_02, f_score])

NameError: name 'partial' is not defined

In [221]:
class Gene2GeneSetModule(nn.Module):
    def __init__(self, dataset_msigdb, embedding_dim):
        super(Gene2GeneSetModule, self).__init__()
        genes_vocab_size = len(dataset_msigdb)
        genesets_vocab_size = dataset_msigdb.gene_sets_count()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, 10)
        self.linear2 = nn.Linear(10, genesets_vocab_size)

    def forward(self, inputs):
        embeds = self.embeddings(inputs).view((1, -1))
        out = F.relu(self.linear1(embeds))
        out = self.linear2(out)
        log_probs = F.log_softmax(out, dim=1)
        return log_probs

In [222]:
loss = nn.CrossEntropyLoss()
yhat = torch.randn(3, 5, requires_grad=True)
yhat

tensor([[-1.2933,  1.3426,  0.0185,  0.7019,  0.2557],
        [ 0.7450, -0.2813,  0.6328, -0.1570, -0.5152],
        [ 1.5880, -0.4097, -0.6683,  0.0106, -1.3654]], requires_grad=True)

In [223]:
y = torch.empty(3, dtype=torch.long).random_(5)
y

tensor([4, 2, 3])

In [224]:
output = loss(yhat, y)
output

tensor(1.6832, grad_fn=<NllLossBackward>)

# Cross-entropy loss

In [235]:
def xeloss(yhat, y):
    s = 0
    for i in range(len(yhat)):
        c = y[i]
        sl = yhat[i].exp().sum().log()
        s += - yhat[i, c] + sl
    return s / len(yhat)

def xeloss2(yhat, y):
    yhat_c = yhat.gather(1, idx).squeeze()
    exp_sum_log = yhat.exp().sum(axis=1).log() 
    return (-yhat_c + exp_sum_log).mean()


In [237]:
xeloss(yhat, y), xeloss2(yhat, y)

(tensor(1.6832, grad_fn=<DivBackward0>),
 tensor(1.6832, grad_fn=<MeanBackward0>))