### Data

In [None]:
from data_process_and_get_mask import create_mask
from get_data import pbmc_train, pbmc_val

tf_list = [
    'IRF1', 'IRF2', 'IRF2BPL', 'IRF3', 'IRF4', 'IRF5', 'IRF6', 'IRF7', 'IRF8', 'IRF9', 'STAT1', 'STAT2', 'STAT3', 'STAT4', 'STAT5A','STAT5B', 'STAT6'
]
#["GTF2I","GTF3A","NRF1","ELF1","STAT1","STAT2","IRF9","STAT3","STAT4","STAT5A","STAT5B","IRF3","IRF7","IRF1","IRF5","IRF8"]

mask, train, valid = create_mask(pbmc_train, pbmc_val, tf_list)
print("Printing mask:")
print(mask)
print("\nPrinting train:")
print(train)
print("\nPrinting valid:")
print(valid)

Printing mask:
shape: (357, 16)
┌─────────┬────────┬────────┬───────┬───┬───────┬──────┬──────┬───────────────┐
│ target  ┆ STAT5A ┆ STAT5B ┆ STAT4 ┆ … ┆ STAT2 ┆ IRF7 ┆ IRF9 ┆ unannotated_1 │
│ ---     ┆ ---    ┆ ---    ┆ ---   ┆   ┆ ---   ┆ ---  ┆ ---  ┆ ---           │
│ str     ┆ f64    ┆ f64    ┆ f64   ┆   ┆ f64   ┆ f64  ┆ f64  ┆ i32           │
╞═════════╪════════╪════════╪═══════╪═══╪═══════╪══════╪══════╪═══════════════╡
│ IL2     ┆ 1.0    ┆ 1.0    ┆ 0.0   ┆ … ┆ 0.0   ┆ 0.0  ┆ 0.0  ┆ 1             │
│ IRF1    ┆ 1.0    ┆ 1.0    ┆ 1.0   ┆ … ┆ 1.0   ┆ 0.0  ┆ 0.0  ┆ 1             │
│ CCND1   ┆ 1.0    ┆ 1.0    ┆ 0.0   ┆ … ┆ 0.0   ┆ 0.0  ┆ 0.0  ┆ 1             │
│ PRF1    ┆ 1.0    ┆ 1.0    ┆ 1.0   ┆ … ┆ 0.0   ┆ 0.0  ┆ 0.0  ┆ 1             │
│ IFNG    ┆ 1.0    ┆ 1.0    ┆ 1.0   ┆ … ┆ 1.0   ┆ 0.0  ┆ 1.0  ┆ 1             │
│ …       ┆ …      ┆ …      ┆ …     ┆ … ┆ …     ┆ …    ┆ …    ┆ …             │
│ ZC3HAV1 ┆ 0.0    ┆ 0.0    ┆ 0.0   ┆ … ┆ 0.0   ┆ 0.0  ┆ 0.0  ┆ 1             │
│ ZNF181

In [12]:
import torch
from device import device
#training loop
def trainVEGA(vae, data, epochs=60, beta = 0.0001, learning_rate = 0.001):
  #epochs 60 looked enough when we tried 100
  data = torch.utils.data.DataLoader(data.X.toarray(), batch_size=128)
  opt = torch.optim.Adam(vae.parameters(), lr = learning_rate, weight_decay = 5e-4)
  vae.train() #train mode
  losses = []
  klds = []
  mses = []

  for epoch in range(epochs):
      loss_e = 0
      kld_e = 0
      mse_e = 0

      for x in data:
          x = x.to(device)
          opt.zero_grad()
          x_hat = vae(x)
          mse = ((x - x_hat)**2).sum()
          kld = beta* vae.encoder.kl
          loss = mse +  kld # loss calculation
          loss.backward()
          opt.step()
          loss_e += loss.to('cpu').detach().numpy()
          kld_e += kld.to('cpu').detach().numpy()
          mse_e += mse.to('cpu').detach().numpy()
          vae.decoder.weights() # this used to be positive_weights
          vae.encoder.clamp_mu() # this used to be positive_weights

      losses.append(loss_e/(len(data)*128))
      klds.append(kld_e/(len(data)*128))
      mses.append(mse_e/(len(data)*128))

      print("epoch: ", epoch, " loss: ", loss_e/(len(data)*128))

  return vae, losses, klds, mses

In [21]:
from model.vega import VEGA
from model.decoder_default import DecoderVEGA
from model.encoder_with_mu_clamp import Encoder

mask_np = mask.drop("target").to_numpy()

latent_dims= mask_np.shape[1]
input_dims = mask_np.shape[0]
dropout = 0.3
z_dropout = 0.5
vega = VEGA(
    encoder=Encoder(latent_dims, input_dims, dropout, z_dropout),
    decoder=DecoderVEGA(mask_np.T)
    ).to(device)

In [22]:
# model training
vega, vega_losses, vega_klds, vega_mses = trainVEGA(vega, train, epochs = 20, beta = 0.0001) #takes about 2 mins on GPU # epoch 100

epoch:  0  loss:  83.86034
epoch:  1  loss:  62.688946
epoch:  2  loss:  50.798916
epoch:  3  loss:  47.324947
epoch:  4  loss:  44.367134
epoch:  5  loss:  42.468143
epoch:  6  loss:  41.829674
epoch:  7  loss:  41.599037
epoch:  8  loss:  40.101185
epoch:  9  loss:  39.86486
epoch:  10  loss:  40.877087
epoch:  11  loss:  40.34149
epoch:  12  loss:  40.60262
epoch:  13  loss:  39.786972
epoch:  14  loss:  37.985283
epoch:  15  loss:  37.17095
epoch:  16  loss:  37.727222
epoch:  17  loss:  36.522495
epoch:  18  loss:  36.338856
epoch:  19  loss:  37.105186
