In [1]:
# Load the autoreload extension and set it to mode 2 so that edits to
# imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [3]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import pandas as pd

# The bundled seaborn style sheets were renamed to 'seaborn-v0_8-*' in newer
# matplotlib releases and the old names were later removed, so prefer the new
# name when it is available and fall back to the legacy one otherwise.
plt.style.use(
    'seaborn-v0_8-notebook'
    if 'seaborn-v0_8-notebook' in plt.style.available
    else 'seaborn-notebook'
)

In [24]:
# `df` was not defined by any earlier cell, so this preview raised NameError
# on a fresh kernel (Restart & Run All).
# NOTE(review): assumes the previewed table is the same CSV that the Lincs
# dataset reads below ('level5_1000.csv') - confirm.
df = pd.read_csv('level5_1000.csv')
df.head()

Unnamed: 0,rid,CPC005_A375_6H:BRD-A85280935-003-01-7:10,CPC005_A375_6H:BRD-A07824748-001-02-6:10,CPC004_A375_6H:BRD-K20482099-001-01-1:10,CPC005_A375_6H:BRD-K62929068-001-03-3:10,CPC005_A375_6H:BRD-K43405658-001-01-8:10,CPC004_A375_6H:BRD-K03670461-001-02-0:10,CPC004_A375_6H:BRD-K36737713-001-01-6:10,CPC005_A375_6H:BRD-K51223576-001-01-3:10,CPC004_A375_6H:BRD-A14966924-001-03-0:10,...,CPC005_A375_24H:BRD-A59303141-001-03-9:10,CPC005_A375_24H:BRD-K54665485-001-04-6:10,CPC005_A375_24H:BRD-A54236247-003-03-5:10,CPC005_A375_24H:BRD-K10098805-001-02-0:10,CPC005_A375_24H:BRD-K13725475-001-02-4:10,CPC005_A375_24H:BRD-A08003242-001-02-7:10,CPC005_A375_24H:BRD-K43796186-001-01-1:10,CPC005_A375_24H:BRD-K83063356-003-01-7:10,CPC005_A375_24H:BRD-K86600316-003-01-2:10,CPC005_A375_24H:BRD-A92585442-237-01-0:10
0,5720,0.773769,-0.645586,-5.449666,0.193408,1.006298,-5.388713,-1.00024,0.49011,0.063297,...,0.160586,-0.193009,0.247968,0.384757,0.352685,-0.23349,0.281433,1.141963,-0.302364,1.02505
1,466,-0.818468,-0.810749,2.393775,-0.582243,0.455536,1.867731,-1.106092,0.595174,-0.962553,...,0.161364,-0.244689,0.559568,0.592947,-1.140376,-2.4135,-1.134386,0.623217,-0.170404,0.265432
2,6009,0.189572,0.45906,1.27979,-0.178977,0.631738,0.281383,-0.422545,-0.224163,0.521552,...,-0.663482,-0.235831,0.684576,1.720635,-0.25445,-0.414349,-0.796767,0.418341,0.870858,-0.539486
3,2309,-0.146031,-0.224676,2.167868,-1.182025,-0.936414,1.378175,0.406279,-0.244783,0.182361,...,0.552385,0.21892,-0.601392,-0.404516,0.662811,1.789149,0.664427,1.452139,-0.774794,-0.752421
4,387,-0.654002,-0.335681,2.333199,-1.012651,-1.213203,1.290522,-0.218671,-0.124029,0.572183,...,-0.37329,1.054628,0.458266,0.077265,0.079647,-1.665797,-0.32378,0.285577,0.879944,0.504271


In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [38]:
class Lincs(Dataset):
    """Map-style dataset that serves one CSV column (one signature) per item.

    The table is read with ``pd.read_csv``. Each item is a whole column,
    returned as a 1-D float32 tensor whose length is the number of rows.

    Identifier columns are never served as samples: the first column is
    always skipped, and so is any column named in ``ID_COLUMNS``. The table
    preview in this notebook shows an ``Unnamed: 0`` column followed by
    ``rid`` ahead of the signature columns; skipping only one leading column
    would therefore hand the ``rid`` identifiers to the model as a sample.
    NOTE(review): assumes that preview is of the same CSV - confirm.

    Args:
        path: CSV file to load. Defaults to ``'level5_1000.csv'``.
    """

    # Column names that hold row identifiers rather than measurements.
    ID_COLUMNS = ('rid',)

    def __init__(self, path='level5_1000.csv'):
        super().__init__()
        self.df = pd.read_csv(path)
        # Positional indices of the columns that are real samples. Positions
        # (not labels) are stored so duplicate column names cannot turn a
        # single-column lookup into a multi-column frame.
        self._sample_positions = [
            pos for pos, name in enumerate(self.df.columns)
            if pos > 0 and name not in self.ID_COLUMNS
        ]

    def shape(self):
        """Return ``(n_rows, n_columns)`` of the full table, identifier columns included."""
        return self.df.shape

    def __len__(self):
        """Number of sample columns (identifier columns excluded)."""
        return len(self._sample_positions)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a float32 tensor of length ``n_rows``."""
        column = self.df.iloc[:, self._sample_positions[idx]]
        return torch.as_tensor(column.to_numpy(), dtype=torch.float32)

lincs = Lincs()
# Shuffle every epoch: the columns in the table preview are grouped by
# condition (6H signatures before 24H ones), so unshuffled batches would
# each come from a single condition and arrive in the same order every epoch.
dloader = DataLoader(lincs, batch_size=32, shuffle=True)

In [39]:
class VAE(nn.Module):
    """Fully connected variational autoencoder with a diagonal Gaussian latent.

    The encoder is two ReLU hidden layers followed by two linear heads that
    produce the latent mean and log-variance. The decoder is two ReLU hidden
    layers followed by a linear output layer with no final activation.

    Args:
        input_shape: number of features in each input vector; the decoder
            output has the same width.
        hidden_dim: width of every hidden layer (default 64).
        latent_dim: size of the latent code (default 7).
    """

    def __init__(self, input_shape, hidden_dim=64, latent_dim=7):
        super().__init__()
        # Encoder
        self.fc1 = nn.Linear(input_shape, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.enc_mu = nn.Linear(hidden_dim, latent_dim)
        self.enc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc_out1 = nn.Linear(latent_dim, hidden_dim)
        self.fc_out2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, input_shape)

    def encode(self, x):
        """Map inputs to the latent mean and log-variance, ``(mu, logvar)``."""
        hid = F.relu(self.fc1(x))
        hid = F.relu(self.fc2(hid))
        return self.enc_mu(hid), self.enc_logvar(hid)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling through this expression keeps ``z`` differentiable with
        respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)  # logvar is log(sigma^2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map a latent code back to input space."""
        hid = F.relu(self.fc_out1(z))
        hid = F.relu(self.fc_out2(hid))
        return self.out(hid)

    def forward(self, t):
        """Encode, sample, decode; returns ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(t)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [40]:
# The network's input width is the first entry of the dataset's shape tuple;
# Adam then optimises every model parameter with a step size of 1e-3.
model = VAE(input_shape=lincs.shape()[0])
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
def loss_function(recon_x, x, mu, logvar):
    """VAE loss summed over the batch: squared reconstruction error plus KL term.

    Args:
        recon_x: reconstruction produced by the decoder, same shape as ``x``.
        x: original inputs.
        mu: latent means.
        logvar: latent log-variances, same shape as ``mu``.

    Returns:
        A scalar tensor: the summed squared error between ``recon_x`` and
        ``x`` plus the closed-form KL divergence between
        ``N(mu, diag(exp(logvar)))`` and ``N(0, I)``, summed over all samples.
    """
    # Summed over every element, so this is already a batch total.
    mse = F.mse_loss(recon_x, x, reduction='sum')
    # Reduce the KL term to a scalar as well before adding. Adding the scalar
    # reconstruction total to a per-sample KL vector and then summing would
    # count the reconstruction term once per sample in the batch.
    kld = 0.5 * (mu.pow(2) + torch.exp(logvar) - logvar - 1).sum()

    return mse + kld

In [41]:
def train(epoch):
    """Run one pass over ``dloader`` and return the mean loss per batch.

    Uses the notebook-level ``model``, ``optimizer``, ``dloader`` and
    ``loss_function``. ``epoch`` is accepted but not used in the body.
    """
    model.train()
    running_total = 0
    for batch in dloader:
        optimizer.zero_grad()
        reconstruction, latent_mu, latent_logvar = model(batch)
        batch_loss = loss_function(reconstruction, batch, latent_mu, latent_logvar)
        batch_loss.backward()
        running_total += batch_loss.item()
        optimizer.step()
    n_batches = len(dloader)
    return running_total / n_batches

In [43]:
# Train for 1000 epochs, recording the value returned by train() each time
# and reporting it on every 100th epoch; then plot the recorded curve.
train_losses = []
for epoch in range(1000):
    train_losses.append(train(epoch))
    if not epoch % 100:
        print(f'=======> Epoch: {epoch} Average loss: {train_losses[-1]}')
epoch_axis = np.arange(len(train_losses))
plt.plot(epoch_axis, train_losses)



KeyboardInterrupt: 