# Gene2Vec in Pytorch: Part 2



This time we solve the exact same problem, but with a few improvements.

In the previous version, we created one-hot input vectors.

The input is then applied to the first linear layer in the network. Since all but one of the inputs are 0 (one-hot encoded), only one of the rows of the layer is used (i.e. the one corresponding to the '1' in the input). This is wasteful because we spend a lot of time multiplying by zeros. PyTorch provides an `Embedding` layer that speeds this up by performing a lookup instead (much faster). We also use less memory, since we now only have to provide an input index instead of a one-hot encoded vector of dimension 23,112.

Changes in this version
- Network: Use Pytorch Embedding, instead of linear layer
- Dataset inputs: Use single number (gene index), instead of one-hot encoded vector


In [11]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# Imports
import numpy as np
import pandas as pd

from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tcga.msigdb import *
from tcga.util import *
from torch.utils.data import Dataset, DataLoader

# Dataset

In [13]:
def names2num(names):
    """
    Build a name -> number mapping.

    Duplicates in 'names' are dropped and the remaining names are
    numbered 0, 1, 2, ... following their sorted order.
    """
    mapping = {}
    for name in sorted(set(names)):
        # The next free number is simply the count of names seen so far
        mapping[name] = len(mapping)
    return mapping
    
class DatasetMsigDb(Dataset):
    """
    Custom dataset: We have to override methods __len__ and __getitem__
    In our network, the inputs are genes and the outputs are gene-sets.
    We convert genes and gene sets to numbers, then store the forward and
    reverse mappings.

    Genes: Encoded as an index, i.e. a LongTensor holding the single
    number assigned to the gene (suitable as input for an Embedding layer).

    Gene sets: We encode gene -> gene_sets as a dictionary indexed by
    gene, with tensors having 1 on the GeneSets the gene belongs to.

    The method __getitem__ returns a tuple with the gene (index tensor)
    and the gene-set tensor (having ones on all gene-sets the gene
    belongs to)
    """
    def __init__(self, path):
        """ Read MsigDb from 'path' and build all mappings and tensors """
        self.msigdb = read_msigdb_all(path)
        # Gene <-> number: forward and reverse mapping
        self.gene2num = names2num(msigdb2genes(self.msigdb))
        self.num2gene = {n: g for g, n in self.gene2num.items()}
        # GeneSet <-> number: forward and reverse mapping
        self.geneset2num = names2num(msigdb2gene_sets(self.msigdb))
        # The reverse mapping must be built from 'geneset2num' (not 'gene2num')
        self.num2geneset = {n: gs for gs, n in self.geneset2num.items()}
        # Gene -> GeneSets mapping (use gene_set numbers, in a tensor)
        self.init_genes()
        self.init_genesets()

    def genesets2num(self, genesets):
        " Convert to a list of numerically encoded gene-sets "
        return [self.geneset2num[gs] for gs in genesets]

    def gene2tensor(self, gene):
        " Convert to an index tensor (yes, it's just a number) "
        gene_tensor = torch.LongTensor([self.gene2num[gene]])
        return gene_tensor

    def genesets2tensor(self, genesets):
        " Convert to a vector having 1 in each geneset position "
        geneset_idxs = self.genesets2num(genesets)
        # One position per entry in msigdb (same count as gene_sets_size())
        geneset_tensor = torch.zeros(len(self.msigdb))
        geneset_tensor[geneset_idxs] = 1
        return geneset_tensor

    def init_genes(self):
        " Create the index tensor for every gene (dictionary by gene name) "
        self.gene_tensors = {gene: self.gene2tensor(gene) for gene in self.gene2num}

    def init_genesets(self):
        """
        Map Gene to GeneSets. For each gene we keep the gene sets as
        returned by gene_genesets(), their numbers, and a hot-encoded tensor
        """
        self.gene_genesets = dict()
        self.gene_genesets_num = dict()
        self.gene_genesets_tensors = dict()
        for gene, genesets in gene_genesets(self.msigdb).items():
            self.gene_genesets[gene] = genesets
            self.gene_genesets_num[gene] = self.genesets2num(genesets)
            self.gene_genesets_tensors[gene] = self.genesets2tensor(genesets)

    def __len__(self):
        " Len: Count number of genes "
        return len(self.gene2num)

    def gene_sets_size(self):
        " Count number of gene sets "
        return len(self.msigdb)

    def __getitem__(self, idx):
        " Get item 'idx': A tuple of gene number 'idx' and gene set tensor for that gene "
        gene = self.num2gene[idx]
        return (self.gene_tensors[gene], self.gene_genesets_tensors[gene])

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        " Show (a few) mappings gene -> gene_set tensor "
        out = f"Genes: {len(self.gene2num)}, Gene Sets: {len(self.geneset2num)}\n"
        # Show at most 10 genes; small datasets may have fewer than that
        for i in range(min(10, len(self))):
            gene = self.num2gene[i]
            gene_tensor, geneset_tensor = self[i]
            out += f"\tGene: {gene}, {i}, {gene_tensor}\n\tGeneSet: {self.gene_genesets[gene]}, {self.gene_genesets_num[gene]}, {geneset_tensor}\n\n"
        return out + "..."


In [14]:
# Build the dataset from the MsigDb files under 'data/msigdb/small'
path = Path('data/msigdb/small')
dataset = DatasetMsigDb(path)
# Last expression in the cell: the notebook displays the dataset's __repr__
dataset

File: data/msigdb/small/h.all.v7.0.symbols.gmt


Genes: 4384, Gene Sets: 50
	Gene: A2M, 0, tensor([0])
	GeneSet: {'HALLMARK_COAGULATION', 'HALLMARK_IL6_JAK_STAT3_SIGNALING'}, [9, 23], tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

	Gene: AAAS, 1, tensor([1])
	GeneSet: {'HALLMARK_DNA_REPAIR'}, [11], tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

	Gene: AADAT, 2, tensor([2])
	GeneSet: {'HALLMARK_FATTY_ACID_METABOLISM'}, [16], tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

	Gene: AARS, 3, tensor([3])
	GeneSet:

In [15]:
class Gene2GeneSetModule(nn.Module):
    """
    Network mapping a gene index to one score per gene set.

    Layers: Embedding (one row per gene) -> Linear -> Linear, using ReLU
    after the embedding and after the first linear layer, and a sigmoid
    on the output so that every gene-set score is in (0, 1).
    """
    def __init__(self, dataset_msigdb, embedding_dim, layer_size=10):
        """
        dataset_msigdb: Provides the number of genes (len) and the number
                        of gene sets (gene_sets_size())
        embedding_dim: Size of the embedding vector for each gene
        layer_size: Size of the hidden layer
        """
        super().__init__()
        n_genes = len(dataset_msigdb)
        n_genesets = dataset_msigdb.gene_sets_size()
        # Layers are created in this order: embedding, hidden, output
        self.embeddings = nn.Embedding(n_genes, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, layer_size)
        self.linear2 = nn.Linear(layer_size, n_genesets)

    def forward(self, inputs):
        """ Return the gene-set scores for a tensor of gene indexes """
        embedded = self.embeddings(inputs).relu()
        hidden = self.linear1(embedded).relu()
        return self.linear2(hidden).sigmoid()

In [16]:
def train(model, dataloader, epochs, lr, momentum):
    """
    Train 'model' using the Adam optimizer and binary cross entropy loss.

    model: Network to train; called as model(x), its output must contain
           values in [0, 1] (as required by binary cross entropy)
    dataloader: Iterable of (x, y) batches. Here 'x' has shape (batch, 1)
                (gene indexes) and 'y' has shape (batch, gene_sets)
    epochs: Number of passes over 'dataloader'
    lr: Learning rate for Adam
    momentum: Not used (Adam is created with 'lr' only). The parameter is
              kept so that existing calls keep working

    Progress is printed for the first batch of every 10th epoch.
    Returns nothing, the model is updated in place.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for n_epoch in range(epochs):
        for n_batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            output = model(x)
            # Only remove the index dimension (dim 1). A plain squeeze() would
            # also remove the batch dimension when the batch has a single
            # sample, and the shapes of output and 'y' would not match.
            loss = F.binary_cross_entropy(output.squeeze(1), y)
            loss.backward()
            optimizer.step()
            if n_epoch % 10 == 0 and n_batch == 0:
                print(f"Train Epoch: {n_epoch} / {epochs}\tn_batch: {n_batch}\tLoss: {loss.item():.6f}\tx.shape: {x.shape}\ty.shape: {y.shape}")

In [17]:
# Serve the dataset in batches of 1000 genes
dataloader = DataLoader(dataset, batch_size=1000)
# Embedding dimension: 20, hidden layer size: 10
model = Gene2GeneSetModule(dataset, 20, 10)
# Note: train() builds Adam from 'lr' only, so 'momentum' has no effect here
train(model, dataloader, epochs=150, lr=0.01, momentum=0.8)

Train Epoch: 0 / 150	n_batch: 0	Loss: 0.690097	x.shape: torch.Size([1000, 1])	y.shape: torch.Size([1000, 50])
Train Epoch: 10 / 150	n_batch: 0	Loss: 0.149002	x.shape: torch.Size([1000, 1])	y.shape: torch.Size([1000, 50])
Train Epoch: 20 / 150	n_batch: 0	Loss: 0.140674	x.shape: torch.Size([1000, 1])	y.shape: torch.Size([1000, 50])
Train Epoch: 30 / 150	n_batch: 0	Loss: 0.137686	x.shape: torch.Size([1000, 1])	y.shape: torch.Size([1000, 50])
Train Epoch: 40 / 150	n_batch: 0	Loss: 0.135548	x.shape: torch.Size([1000, 1])	y.shape: torch.Size([1000, 50])
Train Epoch: 50 / 150	n_batch: 0	Loss: 0.132318	x.shape: torch.Size([1000, 1])	y.shape: torch.Size([1000, 50])
Train Epoch: 60 / 150	n_batch: 0	Loss: 0.125862	x.shape: torch.Size([1000, 1])	y.shape: torch.Size([1000, 50])
Train Epoch: 70 / 150	n_batch: 0	Loss: 0.115653	x.shape: torch.Size([1000, 1])	y.shape: torch.Size([1000, 50])
Train Epoch: 80 / 150	n_batch: 0	Loss: 0.103543	x.shape: torch.Size([1000, 1])	y.shape: torch.Size([1000, 50])
Tr