In [1]:
from tqdm import tqdm
from scipy import io
from scipy import sparse
import scipy
import gzip
import scanpy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset
from torch.distributions.beta import Beta

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os
import pickle




class CustomDataset(Dataset):
    """Serve the columns of a sparse matrix as dense float tensors.

    The matrix shape is interpreted as ``(number_of_genes, number_of_cells)``:
    the dataset length is the number of columns, and item ``idx`` is column
    ``idx`` densified.

    Args:
        mtx: A SciPy sparse matrix (any object exposing ``tocsc()`` and a
            two-element ``shape``).
    """

    def __init__(self, mtx):
        super().__init__()
        # Convert once to CSC; SciPy documents CSC as the format with
        # efficient column slicing, which is what __getitem__ does.
        self.mtx = mtx.tocsc()
        self.number_of_genes, self.number_of_cells = self.mtx.shape

    def __len__(self):
        """Return the number of columns (cells)."""
        return self.number_of_cells

    def __getitem__(self, idx):
        """Return column ``idx`` as a 1-D float32 tensor of length ``number_of_genes``."""
        column = self.mtx[:, idx].toarray()
        # reshape(-1) instead of squeeze(): squeeze() would collapse a
        # single-gene column to a 0-d array and silently drop the feature axis.
        # astype() returns a fresh writable array, so from_numpy is safe here.
        return torch.from_numpy(column.reshape(-1).astype(np.float32))

class BetaVAE(nn.Module):
    """Autoencoder whose 10-dimensional latent code is drawn from Beta distributions.

    The encoder maps an input of ``number_of_genes`` features to 100 hidden
    units; two linear heads produce the log of the Beta concentration
    parameters ``a`` and ``b``; a latent ``z ~ Beta(a, b)`` is sampled and a
    single linear layer decodes it back to ``number_of_genes`` features.

    Args:
        number_of_genes: Size of the input (and reconstructed output) feature axis.
    """

    def __init__(self, number_of_genes):
        super().__init__()

        self.number_of_genes = number_of_genes

        # The layer order/indices are kept exactly as before so that state
        # dicts saved from earlier runs keep loading.
        self.encode = nn.Sequential(
            nn.Linear(self.number_of_genes, 1000),
            nn.ReLU(),
            nn.BatchNorm1d(1000),
            nn.Dropout(),

            nn.Linear(1000, 100),
            nn.ReLU(),
            nn.BatchNorm1d(100),
            nn.Dropout(),
        )
        # Heads producing log(a) and log(b); exp() in forward keeps both positive.
        self.linear_a = nn.Linear(100, 10)
        self.linear_b = nn.Linear(100, 10)

        self.decode = nn.Linear(10, self.number_of_genes)
        # Start every decoder weight at 0.5 (the bias keeps its default init).
        with torch.no_grad():
            self.decode.weight.fill_(0.5)

    def forward(self, x):
        """Encode ``x``, sample a Beta latent, and return the linear reconstruction.

        Args:
            x: Float tensor whose last dimension is ``number_of_genes``.
               The BatchNorm1d layers imply a 2-D ``(batch, genes)`` input.

        Returns:
            Tensor with the same trailing dimension as ``x``.
        """
        # Encoding
        x = self.encode(x)
        a = torch.exp(self.linear_a(x))
        b = torch.exp(self.linear_b(x))

        # Reparameterized sampling: rsample() keeps the draw differentiable
        # w.r.t. a and b. The previous .sample() call returned a tensor detached
        # from the graph, so the encoder and both heads never got gradients.
        z = Beta(a, b).rsample()

        # Decoding
        x_decoded = self.decode(z)

        return x_decoded

class Discriminator(nn.Module):
    """Fully connected binary classifier over gene-expression vectors.

    Five hidden stages (512, 256, 128, 64, 32 units), each made of
    Linear -> ReLU -> BatchNorm1d -> Dropout, followed by a single-unit linear
    layer; ``forward`` applies a sigmoid to that unit.

    Args:
        number_of_genes: Size of the input feature axis.
    """

    def __init__(self, number_of_genes):
        super().__init__()
        self.number_of_genes = number_of_genes

        stages = []
        fan_in = self.number_of_genes
        for fan_out in (512, 256, 128, 64, 32):
            # Modules are appended in the same order as a hand-written stack,
            # so the indices inside the Sequential stay 0..19 for the hidden part.
            stages.extend([
                nn.Linear(fan_in, fan_out),
                nn.ReLU(),
                nn.BatchNorm1d(fan_out),
                nn.Dropout(),
            ])
            fan_in = fan_out
        stages.append(nn.Linear(fan_in, 1))

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return the sigmoid of the network's single output unit for ``x``."""
        logits = self.model(x)
        return torch.sigmoid(logits)
        

In [None]:
# Temporary data setup: build the `heart` matrix and record its first dimension.
# NOTE(review): `normalize` and `heart_raw` are not defined in any cell visible
# here — presumably they come from an earlier or removed cell; confirm this
# cell still runs after Restart Kernel -> Run All.
heart = normalize(heart_raw)
# The first axis is treated as genes: CustomDataset unpacks `shape` as
# (number_of_genes, number_of_cells).
number_of_genes = heart.shape[0]

In [None]:
def train(model, device, train_dataloader, optim, epoch, log_every=1):
    """Run one epoch of MSE reconstruction training and report the mean batch loss.

    Note: a later cell in this notebook defines another function that is also
    called ``train`` (with a different signature); once that cell runs, it
    replaces this one in the kernel namespace.

    Args:
        model: Module mapping a batch to a reconstruction of the same shape.
            It must expose ``model.decode.weight`` (see the constraint below).
        device: Device (string or ``torch.device``) each batch is moved to.
        train_dataloader: Iterable yielding input batches (tensors).
        optim: Optimizer over ``model``'s parameters. Inside this function the
            name refers to this argument, not to the ``torch.optim`` module.
        epoch: Epoch index, used only for logging.
        log_every: Print the summary when ``epoch % log_every == 0``
            (positive int; the default of 1 prints every epoch).

    Returns:
        The mean per-batch MSE loss as a float, or ``nan`` if the loader
        yielded no batches.
    """
    mse = nn.MSELoss()
    model.train()
    total_loss = 0.0
    num_batches = 0
    for X in train_dataloader:
        X = X.to(device)
        optim.zero_grad()
        # Call the module rather than .forward() so registered hooks run.
        pred = model(X)
        loss = mse(pred, X)
        total_loss += loss.item()
        loss.backward()
        optim.step()
        num_batches += 1

        # Keep the decoder weights inside (0, 1) by passing them through a
        # sigmoid after every optimizer step.
        # NOTE(review): this re-squashes the already-squashed weights each
        # step rather than clamping them — confirm that is the intended
        # constraint.
        with torch.no_grad():
            model.decode.weight.copy_(torch.sigmoid(model.decode.weight))

    # Count batches ourselves: this avoids a ZeroDivisionError on an empty
    # loader and also works for iterables that do not implement __len__.
    average_loss = total_loss / num_batches if num_batches else float('nan')
    if epoch % log_every == 0:
        print('epoch', epoch)
        print('TRAIN loss =', average_loss*1000)
    return average_loss

In [None]:
# Pretrain the BetaVAE as a plain reconstruction model on the `heart` matrix.
dataset = CustomDataset(heart)

train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# Use the first GPU when CUDA is present; otherwise fall back to the CPU
# instead of failing on `.to(device)`.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Same value as the global `number_of_genes` (both are heart.shape[0]), read
# from the dataset so this cell depends on less notebook state.
vae = BetaVAE(dataset.number_of_genes).to(device)
learning_rate = 0.000002
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)
epochs = 10
for epoch in range(epochs):
    train(vae, device, train_loader, optimizer, epoch)
# NOTE(review): a later cell loads '/data/home/kimds/GAN/betavae.pt', but the
# weights trained here are never saved in the visible cells — confirm where
# that checkpoint is produced.
torch.cuda.empty_cache()

In [None]:
def train(epoch, device, train_loader, vae, discriminator, gen_opt, disc_opt, log_every=1):
    """Run one adversarial epoch: the VAE acts as generator against the discriminator.

    For every batch the generator is updated first (its reconstruction should
    be scored as real), then the discriminator is updated on the real batch
    (target 1) and on the detached reconstruction (target 0).

    Note: this definition shares the name ``train`` with an earlier cell's
    reconstruction-training function and replaces it once this cell is run.

    Args:
        epoch: Epoch index, used only for logging.
        device: Device (string or ``torch.device``) each batch is moved to.
        train_loader: Iterable yielding input batches (tensors).
        vae: Generator module mapping a batch to a reconstruction.
        discriminator: Module returning probabilities in [0, 1] (fed to BCELoss).
        gen_opt: Optimizer over the generator's parameters.
        disc_opt: Optimizer over the discriminator's parameters.
        log_every: Print the summary when ``epoch % log_every == 0``
            (positive int; the default of 1 prints every epoch).

    Returns:
        ``(mean_generator_loss, mean_discriminator_loss)`` as floats; both are
        ``nan`` if the loader yielded no batches.
    """
    vae.train()
    discriminator.train()
    total_g = 0.0
    total_d = 0.0
    num_batches = 0

    bce = nn.BCELoss()

    for X in train_loader:
        X = X.to(device)

        # --- generator step ---
        gen_opt.zero_grad()
        X_ = vae(X)
        fake_score = discriminator(X_)
        # Build the targets from an output we need anyway. The previous code
        # ran an extra discriminator forward pass only to read its size, which
        # cost compute and, in train mode, also advanced BatchNorm running
        # statistics and the dropout RNG.
        ones = torch.ones_like(fake_score)
        zeros = torch.zeros_like(fake_score)
        gen_loss = bce(fake_score, ones)
        gen_loss.backward()
        total_g += gen_loss.item()
        gen_opt.step()

        # --- discriminator step ---
        # zero_grad() also discards the discriminator gradients accumulated by
        # gen_loss.backward() above.
        disc_opt.zero_grad()
        actual_loss = bce(discriminator(X), ones)
        # detach(): the fake branch must not push gradients into the generator.
        fake_loss = bce(discriminator(X_.detach()), zeros)
        disc_loss = (actual_loss + fake_loss) / 2
        disc_loss.backward()
        total_d += disc_loss.item()
        disc_opt.step()
        num_batches += 1

    # Count batches ourselves to avoid a ZeroDivisionError on an empty loader.
    if num_batches:
        avg_g = total_g / num_batches
        avg_d = total_d / num_batches
    else:
        avg_g = avg_d = float('nan')
    if epoch % log_every == 0:
        print('epoch', epoch, 'G', avg_g, 'D', avg_d)
    return avg_g, avg_d


In [None]:
# Adversarial fine-tuning of the pretrained BetaVAE on the `immune` matrix.
# NOTE(review): `immune` is not defined in any cell visible here — confirm it
# is created before this cell on a fresh kernel.
immune_dataset = CustomDataset(immune)
dataset = immune_dataset
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Use GPU index 4 when the machine actually has it; otherwise run on the CPU.
device = 'cuda:4' if torch.cuda.device_count() > 4 else 'cpu'
learning_rate = 0.01
epochs = 100

# `number_of_genes` was derived from the heart matrix in an earlier cell; both
# networks are sized with it, so the immune matrix must have the same gene
# count. Fail early with a clear message instead of a shape error mid-training.
assert immune_dataset.number_of_genes == number_of_genes, (
    'immune matrix gene count does not match number_of_genes'
)

vae = BetaVAE(number_of_genes)
# map_location lets a checkpoint saved on another device load onto `device`.
vae.load_state_dict(torch.load('/data/home/kimds/GAN/betavae.pt', map_location=device))
# Both networks must live on the same device as the batches: train() moves
# every batch to `device`, so leaving the models on the CPU raised a
# device-mismatch error on the first forward pass.
vae = vae.to(device)
discriminator = Discriminator(number_of_genes).to(device)
# Build the optimizers after the modules are on their final device.
gen_opt = optim.Adam(vae.parameters(), lr=learning_rate)
disc_opt = optim.Adam(discriminator.parameters(), lr=learning_rate)

for epoch in range(epochs):
    train(epoch, device, train_loader, vae, discriminator, gen_opt, disc_opt)
torch.cuda.empty_cache()