# Gene2Vec: PyTorch


References:

https://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html

https://discuss.pytorch.org/t/what-kind-of-loss-is-better-to-use-in-multilabel-classification/32203


In [1]:
# Notebook setup: enable the autoreload extension in mode 2 so edited
# modules are re-imported before each cell runs, and render matplotlib
# figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# Imports
import numpy as np
import pandas as pd

from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tcga.msigdb import *
from tcga.util import *
from torch.utils.data import Dataset, DataLoader

# Dataset

In [318]:
def names2num(names):
    """
    Assign a consecutive integer index to every distinct name.

    Duplicates are collapsed and the distinct names are ranked in sorted
    order, so the smallest name maps to 0, the next one to 1, and so on.
    Returns a dictionary {name: index}.
    """
    ordered = sorted(set(names))
    return dict(zip(ordered, range(len(ordered))))
    
class DatasetMsigDb(Dataset):
    """
    Custom dataset: We have to override methods __len__ and __getitem__
    In our network, the inputs are genes and the outputs are gene-sets.
    We convert genes and gene sets to numbers, then store the forward and
    reverse mappings.

    Genes: One hot encoding

    Gene sets: We encode gene -> gene_sets as a dictionary indexed by
    gene, with tensors having 1 on the GeneSets the gene belongs to.

    The method __getitem__ returns a tuple with the gene (one-hot
    encoded) and the gene-set tensor (having ones on all gene-sets
    the gene belongs to)

    Attributes:
        msigdb: Whatever read_msigdb_all(path) returns
            (presumably a mapping gene-set name -> genes; TODO confirm).
        gene2num / num2gene: Gene name <-> index, forward and reverse.
        geneset2num / num2geneset: Gene-set name <-> index, forward and reverse.
        gene_tensors: Gene name -> one-hot tensor.
        gene_genesets: Gene name -> gene sets, as returned by gene_genesets().
        gene_genesets_num: Gene name -> list of gene-set indices.
        gene_genesets_tensors: Gene name -> multi-hot gene-set tensor.
    """
    def __init__(self, path):
        """ Read the gene sets found at 'path' and build all mappings and tensors """
        self.msigdb = read_msigdb_all(path)
        # Gene <-> number: forward and reverse mapping
        self.gene2num = names2num(msigdb2genes(self.msigdb))
        self.num2gene = {n: g for g, n in self.gene2num.items()}
        # GeneSet <-> number: forward and reverse mapping.
        # The reverse mapping must invert geneset2num (not gene2num).
        self.geneset2num = names2num(msigdb2gene_sets(self.msigdb))
        self.num2geneset = {n: gs for gs, n in self.geneset2num.items()}
        # Gene -> GeneSets mapping (use gene_set numbers, in a tensor)
        self.init_genes()
        self.init_genesets()

    def genesets2num(self, genesets):
        " Convert to a list of numerically encoded gene-sets "
        return [self.geneset2num[gs] for gs in genesets]

    def gene2tensor(self, gene):
        " Convert to a one-hot encoding "
        gene_tensor = torch.zeros(len(self.gene2num))
        gene_tensor[self.gene2num[gene]] = 1
        return gene_tensor

    def genesets2tensor(self, genesets):
        " Convert to a vector having 1 in each geneset position "
        # An explicit long tensor keeps the indexing valid even when
        # 'genesets' is empty (the result is then all zeros).
        geneset_idxs = torch.tensor(self.genesets2num(genesets), dtype=torch.long)
        # The vector length comes from the same mapping that produced the
        # indices, so every index is guaranteed to be in range.
        geneset_tensor = torch.zeros(len(self.geneset2num))
        geneset_tensor[geneset_idxs] = 1
        return geneset_tensor

    def init_genes(self):
        " Create a one-hot encoding for each gene "
        self.gene_tensors = {gene: self.gene2tensor(gene) for gene in self.gene2num}

    def init_genesets(self):
        " Map Gene to GeneSets. GeneSets are hot-encoded "
        self.gene_genesets = dict()
        self.gene_genesets_num = dict()
        self.gene_genesets_tensors = dict()
        for gene, genesets in gene_genesets(self.msigdb).items():
            self.gene_genesets[gene] = genesets
            self.gene_genesets_num[gene] = self.genesets2num(genesets)
            self.gene_genesets_tensors[gene] = self.genesets2tensor(genesets)

    def __len__(self):
        " Len: Count number of genes "
        return len(self.gene2num)

    def gene_sets_size(self):
        " Count number of gene sets (same length as the gene-set tensors) "
        return len(self.geneset2num)

    def __getitem__(self, idx):
        " Get item 'idx': A tuple of gene number 'idx' and gene set tensor for that gene "
        try:
            gene = self.num2gene[idx]
        except KeyError:
            # Raise IndexError so that plain iteration over the dataset
            # (which relies on IndexError to stop) terminates cleanly.
            raise IndexError(f"gene index out of range: {idx}") from None
        return (self.gene_tensors[gene], self.gene_genesets_tensors[gene])

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        " Show (a few) mappings gene -> gene_set tensor "
        header = f"Genes: {len(self.gene2num)}, Gene Sets: {len(self.geneset2num)}\n"
        entries = []
        # Show at most 10 genes; smaller datasets show all of them
        for i in range(min(10, len(self))):
            gene = self.num2gene[i]
            gene_tensor, geneset_tensor = self[i]
            entries.append(f"\tGene: {gene}, {i}, {gene_tensor}\n\tGeneSet: {self.gene_genesets[gene]}, {self.gene_genesets_num[gene]}, {geneset_tensor}\n\n")
        return header + "".join(entries) + "..."


In [319]:
# Build the dataset from the gene-set files under data/msigdb/small and
# display it; the bare `dataset` expression shows DatasetMsigDb.__repr__.
path = Path('data/msigdb/small')
dataset = DatasetMsigDb(path)
dataset

File: data/msigdb/small/h.all.v7.0.symbols.gmt


Genes: 4384, Gene Sets: 50
	Gene: A2M, 0, tensor([1., 0., 0.,  ..., 0., 0., 0.])
	GeneSet: {'HALLMARK_IL6_JAK_STAT3_SIGNALING', 'HALLMARK_COAGULATION'}, [23, 9], tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

	Gene: AAAS, 1, tensor([0., 1., 0.,  ..., 0., 0., 0.])
	GeneSet: {'HALLMARK_DNA_REPAIR'}, [11], tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

	Gene: AADAT, 2, tensor([0., 0., 1.,  ..., 0., 0., 0.])
	GeneSet: {'HALLMARK_FATTY_ACID_METABOLISM'}, [16], tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.

In [320]:
class Gene2GeneSetModule(nn.Module):
    """
    Feed-forward network mapping a gene vector to gene-set probabilities.

    Three dense layers are stacked: 'embeddings' (genes -> embedding_dim),
    'linear1' (embedding_dim -> layer_size) and 'linear2'
    (layer_size -> gene sets). The first two are followed by ReLU and the
    last one by a sigmoid, so each output is an independent value in (0, 1).
    """
    def __init__(self, dataset_msigdb, embedding_dim, layer_size=10):
        """
        dataset_msigdb: object providing len() (number of genes) and
                        gene_sets_size() (number of gene sets)
        embedding_dim:  output width of the 'embeddings' layer
        layer_size:     output width of the hidden 'linear1' layer
        """
        super().__init__()
        n_genes = len(dataset_msigdb)
        n_genesets = dataset_msigdb.gene_sets_size()
        # A dense Linear layer over the gene vector plays the role of the
        # embedding table (used in place of nn.Embedding).
        self.embeddings = nn.Linear(n_genes, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, layer_size)
        self.linear2 = nn.Linear(layer_size, n_genesets)

    def forward(self, inputs):
        """ Return per-gene-set probabilities for a batch of gene vectors """
        embedded = self.embeddings(inputs).relu()
        hidden = self.linear1(embedded).relu()
        return self.linear2(hidden).sigmoid()

In [321]:
def train(model, dataloader, epochs, lr, momentum, log_every=1):
    """
    Train 'model' in place with Adam and binary cross-entropy loss.

    model:      module whose output holds probabilities in [0, 1]
    dataloader: iterable yielding (x, y) batches; y must be a float tensor
                with the same shape as the model output
    epochs:     number of full passes over 'dataloader'
    lr:         learning rate given to the Adam optimizer
    momentum:   not used (the optimizer is Adam, built from 'lr' only);
                kept so existing callers keep working
    log_every:  print the loss of the first batch every 'log_every' epochs;
                0 or None disables printing

    Returns None; progress is reported through print().
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for n_epoch in range(epochs):
        for n_batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            output = model(x)
            # Only squeeze when the shapes disagree: an unconditional
            # squeeze() drops the batch dimension of a size-1 batch and
            # makes binary_cross_entropy fail on the shape mismatch.
            if output.shape != y.shape:
                output = output.squeeze()
            loss = F.binary_cross_entropy(output, y)
            loss.backward()
            optimizer.step()
            if n_batch == 0 and log_every and n_epoch % log_every == 0:
                print(f"Train Epoch: {n_epoch} / {epochs}\tn_batch: {n_batch}\tLoss: {loss.item():.6f}\tx.shape: {x.shape}\ty.shape: {y.shape}")

In [322]:
# Train a small model (embedding size 20, hidden layer size 10) for 100
# epochs on batches of up to 1000 genes.
# NOTE(review): train() builds an Adam optimizer from `lr` only and never
# reads `momentum`, so momentum=0.8 has no effect here.
dataloader = DataLoader(dataset, batch_size=1000)
model = Gene2GeneSetModule(dataset, 20, 10)
train(model, dataloader, epochs=100, lr=0.01, momentum=0.8)

Train Epoch: 0 / 100	n_batch: 0	Loss: 0.696408	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 1 / 100	n_batch: 0	Loss: 0.660038	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 2 / 100	n_batch: 0	Loss: 0.620970	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 3 / 100	n_batch: 0	Loss: 0.563703	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 4 / 100	n_batch: 0	Loss: 0.477061	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 5 / 100	n_batch: 0	Loss: 0.360194	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 6 / 100	n_batch: 0	Loss: 0.236580	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 7 / 100	n_batch: 0	Loss: 0.162840	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 8 / 100	n_batch: 0	Loss: 0.154340	x.shape: torch.Size([1000, 4384])	y.shape: torch.

Train Epoch: 72 / 100	n_batch: 0	Loss: 0.050298	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 73 / 100	n_batch: 0	Loss: 0.048843	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 74 / 100	n_batch: 0	Loss: 0.047670	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 75 / 100	n_batch: 0	Loss: 0.046468	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 76 / 100	n_batch: 0	Loss: 0.045415	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 77 / 100	n_batch: 0	Loss: 0.044514	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 78 / 100	n_batch: 0	Loss: 0.043807	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 79 / 100	n_batch: 0	Loss: 0.043033	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 80 / 100	n_batch: 0	Loss: 0.042129	x.shape: torch.Size([1000, 4384])	y.shap

In [323]:
# Rebind the loader with batch_size=1 to inspect one gene at a time.
dataloader = DataLoader(dataset, batch_size=1)

In [324]:
# Take the first batch from the loader (a single gene, since batch_size=1).
batch = next(iter(dataloader))

In [325]:
# Split the batch into the gene input x and the gene-set target y, then
# show both shapes.
x, y = batch
x.shape, y.shape

(torch.Size([1, 4384]), torch.Size([1, 50]))

In [326]:
# Display the one-hot gene input of this batch.
x

tensor([[1., 0., 0.,  ..., 0., 0., 0.]])

In [334]:
# Display the target vector and the number of gene sets it marks (sum of ones).
y, y.sum()

(tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
 tensor(2.))

In [328]:
# Run the model on the single-gene batch and display the predicted
# per-gene-set probabilities.
output = model(x)
output

tensor([[4.9600e-02, 2.5676e-05, 1.5609e-11, 1.2241e-05, 1.0953e-09, 9.6507e-05,
         1.4133e-05, 3.4642e-04, 1.6628e-06, 8.2783e-01, 7.5926e-02, 2.1992e-06,
         6.1009e-12, 4.6155e-06, 1.5533e-04, 3.2543e-02, 4.0315e-04, 1.6284e-13,
         1.3446e-08, 1.4023e-08, 9.9435e-14, 8.8443e-07, 4.1290e-02, 7.2486e-01,
         5.9077e-05, 3.2605e-02, 7.3728e-03, 7.3745e-06, 3.7296e-02, 1.8355e-11,
         8.1447e-10, 3.9899e-13, 6.5733e-14, 4.6673e-07, 9.2048e-08, 3.3211e-04,
         6.2012e-05, 3.1749e-06, 1.4756e-02, 1.1596e-08, 1.3772e-08, 5.0771e-07,
         9.8594e-09, 3.7925e-09, 8.1214e-04, 1.0957e-08, 5.6726e-13, 5.8267e-06,
         1.0721e-08, 6.6117e-02]], grad_fn=<SigmoidBackward>)

In [337]:
# Show side by side: predictions thresholded at 0.1, the target vector, and
# the predicted probabilities masked to the true gene-set positions.
(output > 0.1).int(), y, output * y

(tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0]], dtype=torch.int32),
 tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.8278, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7249, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=<MulBackward0>))

In [301]:
# Residual per gene set: predicted probability minus target.
# NOTE(review): this cell's execution count (301) is lower than that of the
# cell computing `output` above (328), so the output displayed below may
# come from an earlier model state - re-run to confirm.
(output - y)

tensor([[ 0.0057,  0.0155,  0.0148,  0.0012,  0.0084,  0.0125,  0.0310,  0.0058,
          0.0028, -0.8903,  0.0162,  0.0179,  0.0104,  0.0032,  0.0119,  0.0161,
          0.0051,  0.0193,  0.0099,  0.0033,  0.0083,  0.0027,  0.0038, -0.8947,
          0.0139,  0.0144,  0.0254,  0.0361,  0.0181,  0.0065,  0.0090,  0.0129,
          0.0037,  0.0046,  0.0317,  0.0083,  0.0154,  0.0067,  0.0058,  0.0038,
          0.0071,  0.0324,  0.0064,  0.0208,  0.0104,  0.0245,  0.0063,  0.0212,
          0.0092,  0.0064]], grad_fn=<SubBackward0>)

In [302]:
# Step through the forward pass by hand, starting with the raw
# (pre-activation) output of the embedding layer.
he = model.embeddings(x)
he

tensor([[ 0.6591, -0.1252,  0.6525,  0.6117,  0.1725, -0.1092,  0.5916,  0.6021,
          0.3134,  0.4587, -0.0986,  0.1905,  0.6319,  0.6196, -0.1314,  0.3249,
         -0.1442, -0.0804,  0.6251,  0.6490]], grad_fn=<AddmmBackward>)

In [303]:
# Apply ReLU to the embedding output (negative entries become 0).
xe = F.relu(he)
xe

tensor([[0.6591, 0.0000, 0.6525, 0.6117, 0.1725, 0.0000, 0.5916, 0.6021, 0.3134,
         0.4587, 0.0000, 0.1905, 0.6319, 0.6196, 0.0000, 0.3249, 0.0000, 0.0000,
         0.6251, 0.6490]], grad_fn=<ReluBackward0>)

In [304]:
# Raw (pre-activation) output of the hidden layer linear1.
hl1 = model.linear1(xe)
hl1

tensor([[-0.2316, -0.4473, -1.4305, -0.3557,  2.9244,  3.2764, -0.7536, -0.3363,
          3.9937, -0.1228]], grad_fn=<AddmmBackward>)

In [305]:
# Apply ReLU to the hidden layer output.
xl1 = F.relu(hl1)
xl1

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 2.9244, 3.2764, 0.0000, 0.0000, 3.9937,
         0.0000]], grad_fn=<ReluBackward0>)

In [306]:
# Output-layer logits: one value per gene set, before the sigmoid.
hl2 = model.linear2(xl1)
hl2

tensor([[-5.1570, -4.1510, -4.1982, -6.6995, -4.7714, -4.3663, -3.4408, -5.1393,
         -5.8730, -4.6303, -4.1082, -4.0032, -4.5511, -5.7505, -4.4198, -4.1140,
         -5.2817, -3.9262, -4.6030, -5.6990, -4.7828, -5.9047, -5.5611, -5.2333,
         -4.2636, -4.2290, -3.6491, -3.2849, -3.9926, -5.0236, -4.6992, -4.3374,
         -5.5866, -5.3676, -3.4187, -4.7862, -4.1592, -5.0032, -5.1436, -5.5789,
         -4.9371, -3.3968, -5.0464, -3.8539, -4.5558, -3.6838, -5.0565, -3.8329,
         -4.6815, -5.0523]], grad_fn=<AddmmBackward>)

In [307]:
# Turn the logits into probabilities; this reproduces the last step of
# Gene2GeneSetModule.forward for the same input.
torch.sigmoid(hl2)

tensor([[0.0057, 0.0155, 0.0148, 0.0012, 0.0084, 0.0125, 0.0310, 0.0058, 0.0028,
         0.0097, 0.0162, 0.0179, 0.0104, 0.0032, 0.0119, 0.0161, 0.0051, 0.0193,
         0.0099, 0.0033, 0.0083, 0.0027, 0.0038, 0.0053, 0.0139, 0.0144, 0.0254,
         0.0361, 0.0181, 0.0065, 0.0090, 0.0129, 0.0037, 0.0046, 0.0317, 0.0083,
         0.0154, 0.0067, 0.0058, 0.0038, 0.0071, 0.0324, 0.0064, 0.0208, 0.0104,
         0.0245, 0.0063, 0.0212, 0.0092, 0.0064]], grad_fn=<SigmoidBackward>)