# Gene2Vec in PyTorch: Part 1



We want to create gene embeddings using "Gene Sets" from MsigDb.
The idea is to train an encoder that maps gene -> geneSets, then use the middle layer as an embedding to represent genes.

### Network

- Input (genes): One-hot encoded. There are 23,112 genes in the MsigDb sets we are interested in, so the inputs are vectors of dimension 23,112 where all entries are zero, except one (one-hot)
- Hidden layer 1: Converts the input from 23,112 to lower dimension vector (e.g. 100), this is the "Embedding" we'll use
- Hidden layer 2: Linear + Relu
- Output (gene sets): Many-hot encoded, because one gene can belong to multiple gene sets (i.e. this is a Multi-label classification problem). There are 20,608 gene sets, so the output is a vector of dimension 20,608 and most of the outputs should be zero.
- Output non-linearity: We use a sigmoid function. Why not something like softmax? Because this is a multi-label classification problem, the output of one neuron should not be influenced by the outputs of the other neurons (i.e. the sum of the outputs doesn't have to be 1.0)
- Loss function: MSE does not work well here, because an easy way to lower the MSE is to make all the outputs zero (most of the targets are 0.0, except for a few 1.0). In multi-label classification a typical loss function is binary cross-entropy (the intuition is that it treats each output neuron as a separate binary classification problem).
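
To make the last two points concrete, here is a small sketch in plain Python (no PyTorch needed). The logits, the 50-label target and the "lazy" predictor that outputs 0.01 everywhere are made-up illustration values, not taken from MsigDb. It shows that softmax forces the outputs to compete while sigmoid does not, and that a missed positive is penalised far more heavily by binary cross-entropy than by squared error.

```python
import math

def sigmoid(z):
    """Element-wise logistic function."""
    return 1.0 / (1.0 + math.exp(-z))

def softmax(zs):
    """Softmax over a list of logits (shifted by the max for stability)."""
    m = max(zs)
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]

def mse(pred, target):
    """Mean squared error over all output neurons."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)

def bce(pred, target, eps=1e-7):
    """Mean binary cross-entropy; predictions are clamped away from 0 and 1."""
    total = 0.0
    for p, t in zip(pred, target):
        p = min(max(p, eps), 1.0 - eps)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(target)

# Hypothetical logits where two labels should be "on" at the same time
logits = [3.0, 3.0, -3.0, -3.0]
sig = [sigmoid(z) for z in logits]   # each output is independent of the others
soft = softmax(logits)               # outputs are forced to sum to 1

# Hypothetical sparse target: 50 labels, only two of them are 1
target = [0.0] * 50
target[9] = 1.0
target[23] = 1.0

# A "lazy" predictor that says "almost certainly not" for every label
near_zero = [0.01] * 50
mse_lazy = mse(near_zero, target)
bce_lazy = bce(near_zero, target)

# Cost of a single missed positive (target 1, prediction 0.01) under each loss
missed_mse = (0.01 - 1.0) ** 2
missed_bce = -math.log(0.01)

print("sigmoid:", [round(s, 4) for s in sig], "sum =", round(sum(sig), 4))
print("softmax:", [round(s, 4) for s in soft], "sum =", round(sum(soft), 4))
print("lazy predictor: MSE =", round(mse_lazy, 4), " BCE =", round(bce_lazy, 4))
print("one missed positive: squared error =", round(missed_mse, 4), " BCE =", round(missed_bce, 4))
```

The squared error of a missed positive can never exceed 1, whereas its cross-entropy term keeps growing as the predicted probability approaches 0, so the all-zeros shortcut is less attractive under binary cross-entropy.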


In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# Imports
import numpy as np
import pandas as pd

from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tcga.msigdb import *
from tcga.util import *
from torch.utils.data import Dataset, DataLoader

# Dataset

We read the MsigDb and produce a 'dataset' with inputs and outputs.
- Inputs: Genes as one-hot vectors, e.g. `[0, 0, ..., 0, 1, 0, ..., 0]`
- Outputs: GeneSets as many-hot vectors, e.g. `[0, 1, ..., 0, 1, 1, ..., 0, 1, 0]`
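
As a rough illustration of the two encodings, the sketch below uses plain Python lists and a tiny, made-up collection of gene sets (`SET_A`, `G1`, and so on are hypothetical names, not real MsigDb entries). The `names2num` helper mirrors the one defined in the next cell.

```python
def names2num(names):
    """Map each unique name to an index, in sorted order."""
    return {n: i for i, n in enumerate(sorted(set(names)))}

# Hypothetical toy data: gene set name -> genes in that set
toy_gene_sets = {
    "SET_A": ["G1", "G2"],
    "SET_B": ["G2", "G3"],
    "SET_C": ["G2"],
}

gene2num = names2num(g for genes in toy_gene_sets.values() for g in genes)
geneset2num = names2num(toy_gene_sets)

def one_hot(gene):
    """Input encoding: a single 1 at the gene's index."""
    vec = [0] * len(gene2num)
    vec[gene2num[gene]] = 1
    return vec

def many_hot(gene):
    """Output encoding: a 1 for every gene set that contains the gene."""
    vec = [0] * len(geneset2num)
    for name, genes in toy_gene_sets.items():
        if gene in genes:
            vec[geneset2num[name]] = 1
    return vec

for gene in sorted(gene2num):
    print(gene, one_hot(gene), many_hot(gene))
```

Each gene has exactly one 1 in its input vector, but it can have any number of 1s in its output vector, which is what makes this a multi-label problem.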


In [3]:
def names2num(names):
    """ Create a mapping from names to numbers """
    names = list(set(names))  # Make sure names are unique
    return {n: i for i,n in enumerate(sorted(names))}
    
class DatasetMsigDb(Dataset):
    """ 
    Custom dataset: We have to override methods __len__ and __getitem__
    In our network, the inputs are genes and the outputs are gene-sets.
    We convert genes and gene sets to numbers, then store the forward and
    reverse mappings.
    
    Genes: One hot encoding
    
    Gene sets: We encode gene-> gene_sets as a dictionary indexed by
    gene, with tensors having 1 on the GeneSets the gene belongs to.
    
    The method __getitem__ returns a tuple with the gene (one-hot
    encoded) and the gene-set tensor (having ones on all gene-sets
    the gene belongs to)
    """
    def __init__(self, path):
        self.msigdb = read_msigdb_all(path)
        # Gene <-> number: forward and reverse mapping
        self.gene2num = names2num(msigdb2genes(self.msigdb))
        self.num2gene = {n: g for g, n in self.gene2num.items()}
        # GeneSet <-> number: forward and reverse mapping
        self.geneset2num = names2num(msigdb2gene_sets(self.msigdb))
        self.num2geneset = {n: gs for gs, n in self.geneset2num.items()}
        # Gene -> GeneSets mapping (use gene_set numbers, in a tensor)
        self.init_genes()
        self.init_genesets()

    def genesets2num(self, genesets):
        " Convert to a list of numerically encoded gene-sets "
        return [self.geneset2num[gs] for gs in genesets]

    def gene2tensor(self, gene):
        " Convert to a one-hot encoding "
        gene_tensor = torch.zeros(len(self.gene2num))
        gene_tensor[self.gene2num[gene]] = 1
        return gene_tensor
        
    def genesets2tensor(self, genesets):
        " Convert to a vector having 1 in each geneset position "
        geneset_idxs = [self.geneset2num[gs] for gs in genesets]
        geneset_tensor = torch.zeros(len(self.msigdb))
        geneset_tensor[geneset_idxs] = 1
        return geneset_tensor
        
    def init_genes(self):
        " Create a one-hot encoding for every gene "
        self.gene_tensors = dict()
        for gene in self.gene2num.keys():
            self.gene_tensors[gene] = self.gene2tensor(gene)
        
    def init_genesets(self):
        " Map each gene to its gene sets. Gene sets are many-hot encoded "
        self.gene_genesets = dict()
        self.gene_genesets_num = dict()
        self.gene_genesets_tensors = dict()
        for gene, genesets in gene_genesets(self.msigdb).items():
            self.gene_genesets[gene] = genesets
            self.gene_genesets_num[gene] = self.genesets2num(genesets)
            self.gene_genesets_tensors[gene] = self.genesets2tensor(genesets)
    
    def __len__(self):
        " Len: Count number of genes "
        return len(self.gene2num)

    def gene_sets_size(self):
        " Count number of gene sets "
        return len(self.msigdb)
    
    def __getitem__(self, idx):
        " Get item 'idx': A tuple of gene number 'idx' and gene set tensor for that gene "
        gene = self.num2gene[idx]
        return (self.gene_tensors[gene], self.gene_genesets_tensors[gene])
    
    def __repr__(self):
        return self.__str__()
    
    def __str__(self):
        " Show (a few) mappings gene -> gene_set tensor "
        out = f"Genes: {len(self.gene2num)}, Gene Sets: {len(self.geneset2num)}\n"
        for i in range(10):  #range(len(self)):
            gene = self.num2gene[i]
            gene_tensor, geneset_tensor = self[i]
            out += f"\tGene: {gene}, {i}, {gene_tensor}\n\tGeneSet: {self.gene_genesets[gene]}, {self.gene_genesets_num[gene]}, {geneset_tensor}\n\n"
        return out + "..."


In [4]:
path = Path('data/msigdb/small')
dataset = DatasetMsigDb(path)
dataset

File: data/msigdb/small/h.all.v7.0.symbols.gmt


Genes: 4384, Gene Sets: 50
	Gene: A2M, 0, tensor([1., 0., 0.,  ..., 0., 0., 0.])
	GeneSet: {'HALLMARK_IL6_JAK_STAT3_SIGNALING', 'HALLMARK_COAGULATION'}, [23, 9], tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

	Gene: AAAS, 1, tensor([0., 1., 0.,  ..., 0., 0., 0.])
	GeneSet: {'HALLMARK_DNA_REPAIR'}, [11], tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

	Gene: AADAT, 2, tensor([0., 0., 1.,  ..., 0., 0., 0.])
	GeneSet: {'HALLMARK_FATTY_ACID_METABOLISM'}, [16], tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.

In [5]:
class Gene2GeneSetModule(nn.Module):
    def __init__(self, dataset_msigdb, embedding_dim, layer_size=10):
        super(Gene2GeneSetModule, self).__init__()
        genes_vocab_size = len(dataset_msigdb)
        genesets_vocab_size = dataset_msigdb.gene_sets_size()
        self.embeddings = nn.Linear(genes_vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, layer_size)
        self.linear2 = nn.Linear(layer_size, genesets_vocab_size)

    def forward(self, inputs):
        x = F.relu(self.embeddings(inputs))
        x = F.relu(self.linear1(x))
        probs = torch.sigmoid(self.linear2(x))
        return probs
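
Because the input is one-hot, the first `nn.Linear` layer acts as a lookup table: multiplying by a one-hot vector selects one column of the weight matrix, and the bias is added on top. The sketch below checks this on a small stand-alone layer with made-up sizes (6 genes, 3 embedding dimensions). Note that the `forward` method above also applies a ReLU after this layer; whether to read the embedding before or after the ReLU is a choice this sketch leaves open.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes (assumptions for illustration only)
n_genes, embedding_dim = 6, 3
embeddings = nn.Linear(n_genes, embedding_dim)

# One-hot input for a hypothetical gene number 4
gene_idx = 4
one_hot = torch.zeros(n_genes)
one_hot[gene_idx] = 1.0

with torch.no_grad():
    # Embedding obtained by running the one-hot vector through the layer
    via_forward = embeddings(one_hot)
    # Same vector read directly from the parameters:
    # nn.Linear stores weight with shape (out_features, in_features)
    via_lookup = embeddings.weight[:, gene_idx] + embeddings.bias

same = torch.allclose(via_forward, via_lookup)
print("forward :", via_forward)
print("lookup  :", via_lookup)
print("identical:", same)
```

In practice this means the (pre-ReLU) embedding of every gene can be read from the trained layer's `weight` and `bias` without a forward pass.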

In [6]:
def train(model, dataloader, epochs, lr, momentum):
    # Note: 'momentum' is accepted but not used, because Adam takes no momentum argument
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for n_epoch in range(epochs):
        for n_batch, batch in enumerate(dataloader):
            x, y = batch
            optimizer.zero_grad()
            output = model(x)
            # No squeeze: output and y are both (batch, n_genesets), and squeezing would break a batch of size 1
            loss = F.binary_cross_entropy(output, y)
            loss.backward()
            optimizer.step()
            if n_epoch % 1 == 0 and n_batch == 0:
                print(f"Train Epoch: {n_epoch} / {epochs}\tn_batch: {n_batch}\tLoss: {loss.item():.6f}\tx.shape: {x.shape}\ty.shape: {y.shape}")

In [7]:
dataloader = DataLoader(dataset, batch_size=1000)
model = Gene2GeneSetModule(dataset, 20, 10)
train(model, dataloader, epochs=100, lr=0.01, momentum=0.8)

Train Epoch: 0 / 100	n_batch: 0	Loss: 0.707272	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 1 / 100	n_batch: 0	Loss: 0.671602	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 2 / 100	n_batch: 0	Loss: 0.627928	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 3 / 100	n_batch: 0	Loss: 0.562565	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 4 / 100	n_batch: 0	Loss: 0.459402	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 5 / 100	n_batch: 0	Loss: 0.321267	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 6 / 100	n_batch: 0	Loss: 0.197739	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 7 / 100	n_batch: 0	Loss: 0.152121	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 8 / 100	n_batch: 0	Loss: 0.159890	x.shape: torch.Size([1000, 4384])	y.shape: torch.

Train Epoch: 73 / 100	n_batch: 0	Loss: 0.066506	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 74 / 100	n_batch: 0	Loss: 0.065086	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 75 / 100	n_batch: 0	Loss: 0.063780	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 76 / 100	n_batch: 0	Loss: 0.062576	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 77 / 100	n_batch: 0	Loss: 0.061500	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 78 / 100	n_batch: 0	Loss: 0.060442	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 79 / 100	n_batch: 0	Loss: 0.059402	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 80 / 100	n_batch: 0	Loss: 0.058556	x.shape: torch.Size([1000, 4384])	y.shape: torch.Size([1000, 50])
Train Epoch: 81 / 100	n_batch: 0	Loss: 0.057738	x.shape: torch.Size([1000, 4384])	y.shap
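
A quick sanity check on the first logged value: a sigmoid output of 0.5 has a binary cross-entropy of ln 2 (about 0.693) whatever the target is. The loss of 0.707272 at epoch 0 is close to that, which is what we would expect if the untrained network's outputs start near 0.5 (an assumption about the initial outputs, not something printed above). A minimal check in plain Python:

```python
import math

def bce_single(p, t):
    """Binary cross-entropy of one prediction p against one target t."""
    return -(t * math.log(p) + (1 - t) * math.log(1 - p))

# At p = 0.5 the loss is the same for a positive and a negative target
loss_at_half_pos = bce_single(0.5, 1.0)
loss_at_half_neg = bce_single(0.5, 0.0)

print(round(loss_at_half_pos, 6), round(loss_at_half_neg, 6), round(math.log(2), 6))
```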