In [1]:
import time
import os


import scanpy
import numpy as np
import scipy.sparse as sp

import torch
from torch import optim

import gae.gae.optimizer as optimizer
import gae.gae.model
import gae.gae.preprocessing as preprocessing

import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import umap
import pandas as pd
from sklearn.preprocessing import scale
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import pairwise_distances

In [2]:
# Settings
# os.environ["CUDA_VISIBLE_DEVICES"] = "0" 
use_cuda=True
fastmode=False #Validate during training pass
seed=6
epochs=20000
saveFreq=20
batchsize=2000
lr=0.0001 #initial learning rate
lr_adv=0.001
weight_decay=0 #Weight for L2 loss on embedding matrix.

hidden1=30000 #Number of units in hidden layer 1
hidden2=30000 #Number of units in hidden layer 2
# hidden3=2048
# hidden4=2048
# hidden5=128
fc_dim1=30000
# fc_dim2=128
# fc_dim3=128
# fc_dim4=128
# gcn_dim1=2600
# clf_hidden=256
adv_hidden=128

dropout=0.01
testNodes=0.1 #fraction of total nodes for testing
valNodes=0.05 #fraction of total nodes for validation
XreconWeight=20
# clfweight=20
advWeight=2
# randFeatureSubset=None
model_str='fc1_dca_sharded'
clf=None
adv=None  #'clf_fc1_eq'  #'clf_fc1_control_eq' #'clf_fc1_control'  #'clf_fc1'
protein=None #'nearest' #None #'scaled_binary'
adj_decodeName=None #gala or None
ridgeL=0.01
shareGenePi=True

targetBatch=None
training_sample_X='logminmax'
switchFreq=10
standardizeX=False
tissue='Bone_Marrow'
name='allk20XA_02_dca_over_FCXonly_scrnaseq_'+tissue 
logsavepath='/data/xinyi/log/train_gae_scrnaseq/'+name
modelsavepath='/data/xinyi/models/train_gae_scrnaseq/'+name
plotsavepath='/data/xinyi/plots/train_gae_scrnaseq/'+name
datadir='/data/xinyi/STACI/scrnaseq/'




In [3]:
data=scanpy.read_h5ad(os.path.join(datadir,'Immune_ALL_human.h5ad'))
data=data[:,np.array(np.sum(data.layers['counts'],axis=0)>3).flatten()]
if tissue!= 'all':
    data=data[data.obs['tissue']==tissue]

In [8]:
if training_sample_X=='logminmax':
    data_train=np.log2(data.layers['counts']+1/2)
    scaler = MinMaxScaler()
    data_train=np.transpose(scaler.fit_transform(np.transpose(data_train)))
    data_train=torch.tensor(data_train)
    
if 'dca' in model_str:
    data_raw=torch.tensor(data.layers['counts']+1/2)
    

# Set cuda and seed
np.random.seed(seed)
if use_cuda and (not torch.cuda.is_available()):
    print('cuda not available')
    use_cuda=False
torch.manual_seed(seed)
if use_cuda:
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = True



In [9]:
if not os.path.exists(logsavepath):
    os.mkdir(logsavepath)
if not os.path.exists(modelsavepath):
    os.mkdir(modelsavepath)
if not os.path.exists(plotsavepath):
    os.mkdir(plotsavepath)

In [10]:
# loop over all train/validation sets
np.random.seed(seed)
torch.manual_seed(seed)
if use_cuda:
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = True
    
num_features=data_train.shape[1]
mse=torch.nn.MSELoss()
# mse=torch.nn.MSELoss(reduction=None)
# Create model
if model_str=='gcn_vae_xa':
    model = gae.gae.model.GCNModelVAE_XA(num_features, hidden1, hidden2,fc_dim1,fc_dim2,fc_dim3,fc_dim4, dropout)
    loss_kl=optimizer.optimizer_kl
    loss_x=optimizer.optimizer_MSE
    loss_a=optimizer.optimizer_CE
elif model_str == 'gcn_vae_gcnX_inprA':
    model = gae.gae.model.GCNModelVAE_gcnX_inprA(num_features, hidden1, hidden2,gcn_dim1, dropout)
    loss_kl=optimizer.optimizer_kl
    loss_x=optimizer.optimizer_MSE
    loss_a=optimizer.optimizer_CE
    
elif model_str == 'gcn_vae_gcnX_inprA_w':
    model = gae.gae.model.GCNModelVAE_gcnX_inprA_w(num_features, hidden1, hidden2,gcn_dim1, dropout)
    loss_kl=optimizer.optimizer_kl
    loss_x=optimizer.optimizer_MSE
    loss_a=optimizer.optimizer_CE    
    
elif model_str=='gcn_vae_xa_e3':
    model = gae.gae.model.GCNModelVAE_XA_e3(num_features, hidden1, hidden2,hidden3,fc_dim1,fc_dim2,fc_dim3,fc_dim4, dropout)
    loss_kl=optimizer.optimizer_kl
    loss_x=optimizer.optimizer_MSE
    loss_a=optimizer.optimizer_CE
    
elif model_str == 'gcn_vae_xa_e1':
    model = gae.gae.model.GCNModelVAE_XA_e1(num_features, hidden1, dropout)
    loss_kl=optimizer.optimizer_kl
    loss_x=optimizer.optimizer_MSE
    loss_a=optimizer.optimizer_CE

elif model_str == 'gcn_vae_xa_e2_d1':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1(num_features, hidden1,hidden2, dropout)
    loss_kl=optimizer.optimizer_kl
    loss_x=optimizer.optimizer_MSE
    loss_a=optimizer.optimizer_CE
    
elif model_str=='gcn_vae_xa_e2_d1_dca':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCA(num_features, hidden1,hidden2,fc_dim1, dropout)
    loss_kl=optimizer.optimizer_kl
    loss_x=optimizer.optimizer_zinb
    loss_a=optimizer.optimizer_CE

elif model_str=='gcn_vae_xa_e2_d1_dca_fca':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCA_fca(num_features, hidden1,hidden2,fc_dim1, dropout)
    loss_kl=optimizer.optimizer_kl
    loss_x=optimizer.optimizer_zinb
    loss_a=optimizer.optimizer_CE
    
elif model_str=='gcn_vae_xa_e2_d1_dcaFork':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCAfork(num_features, hidden1,hidden2,fc_dim1, dropout)
    loss_kl=optimizer.optimizer_kl
    loss_x=optimizer.optimizer_zinb
    loss_a=optimizer.optimizer_CE

elif model_str=='gcn_vae_xa_e2_d1_dcaElemPi':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCAelemPi(num_features, hidden1,hidden2,fc_dim1, dropout,shareGenePi)
    loss_kl=optimizer.optimizer_kl
    loss_x=optimizer.optimizer_zinb
    loss_a=optimizer.optimizer_CE

elif model_str=='gcn_vae_xa_e2_d1_dcaConstantDisp':
    model = gae.gae.model.GCNModelVAE_XA_e2_d1_DCA_constantDisp(num_features, hidden1,hidden2,fc_dim1, dropout,shareGenePi)
    loss_kl=optimizer.optimizer_kl
    loss_x=optimizer.optimizer_zinb
    loss_a=optimizer.optimizer_CE    
    
elif model_str == 'gcn_vae_xa_e4_d1':
    model = gae.gae.model.GCNModelVAE_XA_e4_d1(num_features, hidden1,hidden2,hidden3,hidden4, dropout)
    loss_kl=optimizer.optimizer_kl
    loss_x=optimizer.optimizer_MSE
    loss_a=optimizer.optimizer_CE    
    
elif model_str=='fc':
    model = gae.gae.model.FCVAE(num_features, hidden1, hidden2,hidden3,hidden4,hidden5,fc_dim1,fc_dim2,fc_dim3,fc_dim4, dropout)
    loss_kl=optimizer.optimizer_kl
    loss_x=optimizer.optimizer_MSE
#     loss_x=mse
    loss_a=optimizer.optimizer_CE
elif model_str=='fcae':
    model = gae.gae.model.FCAE(num_features, hidden1, hidden2,hidden3,hidden4,hidden5,fc_dim1,fc_dim2,fc_dim3,fc_dim4, dropout)
    loss_x=optimizer.optimizer_MSE
#     loss_x=mse

elif model_str=='fcae1':
    model = gae.gae.model.FCAE1(num_features, dropout,hidden1)
    loss_x=optimizer.optimizer_MSE
#     loss_x=mse
elif model_str=='fcae2':
    model = gae.gae.model.FCAE2(num_features, dropout,hidden1,hidden2)
    loss_x=optimizer.optimizer_MSE

elif model_str=='fc1':
    model = gae.gae.model.FCVAE1(num_features, hidden1,dropout)
    loss_x=optimizer.optimizer_MSE
    loss_kl=optimizer.optimizer_kl
    loss_a=optimizer.optimizer_CE

elif model_str=='fc1_fca':
    model = gae.gae.model.FCVAE1_fca(num_features, hidden1,dropout)
    loss_x=optimizer.optimizer_MSE
    loss_kl=optimizer.optimizer_kl
    loss_a=optimizer.optimizer_CE    
    
elif model_str=='fc1_dca':
    model = gae.gae.model.FCVAE1_DCA(num_features, hidden1,fc_dim1, dropout)
    loss_x=optimizer.optimizer_zinb
    loss_kl=optimizer.optimizer_kl
    loss_a=optimizer.optimizer_CE
    
elif model_str=='fc1_dca_sharded':
    model = gae.gae.model.FCVAE1_DCA_sharded(num_features, hidden1,fc_dim1, dropout)
    loss_x=optimizer.optimizer_zinb
    loss_kl=optimizer.optimizer_kl
    loss_a=optimizer.optimizer_CE
    
if clf=='clf_fc1':
    modelClf=gae.gae.model.Clf_fc1(hidden2, dropout,clf_hidden,ct_unique.size)
    loss_clf=torch.nn.CrossEntropyLoss()
    
if adv=='clf_fc1' or adv=='clf_fc1_eq' or adv=='clf_fc1_control' or adv=='clf_fc1_control_eq':
    modelAdv=gae.gae.model.Clf_fc1(hidden2, dropout,adv_hidden,sampleLabellist_ae['control13'].size()[1])
    loss_adv=optimizer.optimizer_CEclf
    
if adv=='clf_linear1' or adv=='clf_linear1_control':
    modelAdv=gae.gae.model.Clf_linear1(hidden2, dropout,sampleLabellist_ae['control13'].size()[1])
    loss_adv=optimizer.optimizer_CEclf
    
if 'NB' in name:
    print('using NB loss for X')
    loss_x=optimizer.optimizer_nb
    
# if use_cuda:
#     model.cuda()
    

optimizerVAEXA = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
if clf:
    optimizerClf=optim.Adam(modelClf.parameters(), lr=lr, weight_decay=weight_decay)
if adv:
    optimizerAdv=optim.Adam(modelAdv.parameters(), lr=lr_adv, weight_decay=weight_decay)    


In [11]:
pctVal=0.05
pctTest=0.1
np.random.seed(3)
allIdx=np.arange(data.shape[0])
np.random.shuffle(allIdx)
valIdx=allIdx[:int(pctVal*data.shape[0])]
testIdx=allIdx[int(pctVal*data.shape[0]):(int(pctVal*data.shape[0])+int(pctTest*data.shape[0]))]
trainIdx=allIdx[(int(pctVal*data.shape[0])+int(pctTest*data.shape[0])):]


In [12]:
def train(epoch):
    t = time.time()
    model.train()
    
    loss_all=0
    loss_x_all=0
    loss_kl_all=0
    for i in range(int(np.ceil(trainIdx.shape[0]/batchsize))):
        trainIdx_i=trainIdx[i*batchsize:min((i+1)*batchsize,trainIdx.shape[0])]
        
    
        optimizerVAEXA.zero_grad()

        features=data_train[trainIdx_i].cuda(0).float()
        if 'dca' in model_str:
            features_raw=data_raw[trainIdx_i].cuda(0)

        if use_cuda:
            train_nodes_idx=torch.tensor(np.repeat(True,trainIdx_i.size)).cuda(0)

        adj_recon,mu,logvar,z,features_recon = model(features, None)

        adj_recon=adj_recon.cuda(0)
        mu=mu.cuda(0)
        logvar=logvar.cuda(0)
        z=z.cuda(0)

        loss_kl_train=loss_kl(mu, logvar, train_nodes_idx)

        if 'dca' in model_str:
            if 'NB' in name:
                loss_x_train=loss_x(features_recon, features,train_nodes_idx,XreconWeight)
            else:
                loss_x_train=loss_x(features_recon, features,train_nodes_idx,XreconWeight,ridgeL,features_raw)
        else:
            loss_x_train=loss_x(features_recon, features,train_nodes_idx,XreconWeight,mse)
    
    
        loss= loss_kl_train*0.0001+loss_x_train
        loss.backward()
        optimizerVAEXA.step()
        loss_all+=loss.item()
        loss_x_all+=loss_x_train.item()
        loss_kl_all+=loss_kl_train.item()
        
    loss_all=loss_all/int(np.ceil(trainIdx.shape[0]/batchsize))
    loss_x_all=loss_x_all/int(np.ceil(trainIdx.shape[0]/batchsize))
    loss_kl_all=loss_kl_all/int(np.ceil(trainIdx.shape[0]/batchsize))

    with torch.no_grad():
        model.eval()
        loss_val_all=0
        nvalBatches=int(np.ceil(valIdx.shape[0]/batchsize))
        for i in range(nvalBatches):
            valIdx_i=valIdx[i*batchsize:min((i+1)*batchsize,valIdx.shape[0])]
            
            features=data_train[valIdx_i].cuda(0).float()
            if 'dca' in model_str:
                features_raw=data_raw[valIdx_i].cuda(0)
            adj_recon,mu,logvar,z, features_recon = model(features,None)
            adj_recon=adj_recon.cuda(0)
            mu=mu.cuda(0)
            logvar=logvar.cuda(0)
            z=z.cuda(0)
    
            if use_cuda:
                val_nodes_idx=torch.tensor(np.repeat(True,valIdx_i.size)).cuda(0)
            if 'dca' in model_str:
                if 'NB' in name:
                    loss_x_val=loss_x(features_recon, features,val_nodes_idx,XreconWeight)
                else:
                    loss_x_val=loss_x(features_recon, features,val_nodes_idx,XreconWeight,ridgeL,features_raw)
            else:
                loss_x_val=loss_x(features_recon, features,val_nodes_idx,XreconWeight,mse)


            loss_val_all += loss_x_val.item()
        loss_val_all=loss_val_all/nvalBatches
    print(' Epoch: {:04d}'.format(epoch),
          'loss_train: {:.4f}'.format(loss_all),
          'loss_kl_train: {:.4f}'.format(loss_kl_all),
          'loss_x_train: {:.4f}'.format(loss_x_all),
          'loss_val: {:.4f}'.format(loss_val_all),
          'time: {:.4f}s'.format(time.time() - t))
    return loss_all,loss_kl_all,loss_x_all,loss_val_all

    
# print('cross-validation ',seti)
train_loss_ep=[None]*epochs
train_loss_kl_ep=[None]*epochs
train_loss_x_ep=[None]*epochs
val_loss_ep=[None]*epochs
t_ep=time.time()

for ep in range(epochs):
# for ep in range(10000,20000):
    
    train_loss_ep[ep],train_loss_kl_ep[ep],train_loss_x_ep[ep],val_loss_ep[ep]=train(ep)

        
    if ep%saveFreq == (saveFreq-1):
        torch.save(model.state_dict(), os.path.join(modelsavepath,str(ep)+'.pt'))
    if use_cuda:
#         model.cuda()
        torch.cuda.empty_cache()
print(' total time: {:.4f}s'.format(time.time() - t_ep))


 Epoch: 0000 loss_train: 3.9689 loss_kl_train: 238.9152 loss_x_train: 3.9450 loss_val: 6.3519 time: 7.0443s
 Epoch: 0001 loss_train: 2.0433 loss_kl_train: 797.7281 loss_x_train: 1.9635 loss_val: 4.2195 time: 5.7016s
 Epoch: 0002 loss_train: 3.3342 loss_kl_train: 7616.9347 loss_x_train: 2.5725 loss_val: 4.3637 time: 5.7282s
 Epoch: 0003 loss_train: 2.8537 loss_kl_train: 7570.0473 loss_x_train: 2.0967 loss_val: 3.2791 time: 5.7369s
 Epoch: 0004 loss_train: 3.3372 loss_kl_train: 13723.3738 loss_x_train: 1.9649 loss_val: 2.7111 time: 5.7179s
 Epoch: 0005 loss_train: 6.4816 loss_kl_train: 46460.9623 loss_x_train: 1.8356 loss_val: 2.4949 time: 5.7989s
 Epoch: 0006 loss_train: 6.2741 loss_kl_train: 41419.7039 loss_x_train: 2.1321 loss_val: 2.7203 time: 5.7441s
 Epoch: 0007 loss_train: 5.5228 loss_kl_train: 35551.0493 loss_x_train: 1.9677 loss_val: 2.5361 time: 5.7528s
 Epoch: 0008 loss_train: 6.5626 loss_kl_train: 47818.9674 loss_x_train: 1.7807 loss_val: 2.3176 time: 5.6544s
 Epoch: 0009 los

 Epoch: 0076 loss_train: 1.3001 loss_kl_train: 86.5619 loss_x_train: 1.2914 loss_val: 1.3134 time: 5.8011s
 Epoch: 0077 loss_train: 1.2999 loss_kl_train: 86.0599 loss_x_train: 1.2913 loss_val: 1.3128 time: 5.8017s
 Epoch: 0078 loss_train: 1.2997 loss_kl_train: 85.5457 loss_x_train: 1.2911 loss_val: 1.3124 time: 5.7860s
 Epoch: 0079 loss_train: 1.2996 loss_kl_train: 85.1380 loss_x_train: 1.2910 loss_val: 1.3123 time: 5.7353s
 Epoch: 0080 loss_train: 1.2993 loss_kl_train: 84.4747 loss_x_train: 1.2909 loss_val: 1.3122 time: 5.7936s
 Epoch: 0081 loss_train: 1.2992 loss_kl_train: 83.9703 loss_x_train: 1.2908 loss_val: 1.3118 time: 5.7792s
 Epoch: 0082 loss_train: 1.2990 loss_kl_train: 83.4681 loss_x_train: 1.2906 loss_val: 1.3118 time: 5.6784s
 Epoch: 0083 loss_train: 1.2988 loss_kl_train: 83.0953 loss_x_train: 1.2905 loss_val: 1.3118 time: 5.6802s
 Epoch: 0084 loss_train: 1.2986 loss_kl_train: 82.5636 loss_x_train: 1.2904 loss_val: 1.3120 time: 5.7812s
 Epoch: 0085 loss_train: 1.2984 loss_

 Epoch: 0153 loss_train: 1.2919 loss_kl_train: 61.8252 loss_x_train: 1.2857 loss_val: 1.3077 time: 5.8833s
 Epoch: 0154 loss_train: 1.2920 loss_kl_train: 61.6415 loss_x_train: 1.2858 loss_val: 1.3070 time: 5.8709s
 Epoch: 0155 loss_train: 1.2918 loss_kl_train: 61.4580 loss_x_train: 1.2856 loss_val: 1.3068 time: 5.8465s
 Epoch: 0156 loss_train: 1.2915 loss_kl_train: 61.2386 loss_x_train: 1.2854 loss_val: 1.3078 time: 5.8892s
 Epoch: 0157 loss_train: 1.2913 loss_kl_train: 60.9787 loss_x_train: 1.2852 loss_val: 1.3090 time: 5.8245s
 Epoch: 0158 loss_train: 1.2912 loss_kl_train: 60.8216 loss_x_train: 1.2852 loss_val: 1.3092 time: 5.8018s
 Epoch: 0159 loss_train: 1.2914 loss_kl_train: 60.6049 loss_x_train: 1.2853 loss_val: 1.3091 time: 5.7530s
 Epoch: 0160 loss_train: 1.2915 loss_kl_train: 60.4296 loss_x_train: 1.2854 loss_val: 1.3080 time: 5.7873s
 Epoch: 0161 loss_train: 1.2913 loss_kl_train: 60.2614 loss_x_train: 1.2853 loss_val: 1.3065 time: 5.7683s
 Epoch: 0162 loss_train: 1.2911 loss_

 Epoch: 0230 loss_train: 1.2876 loss_kl_train: 50.1425 loss_x_train: 1.2826 loss_val: 1.3047 time: 5.8357s
 Epoch: 0231 loss_train: 1.2874 loss_kl_train: 49.9958 loss_x_train: 1.2824 loss_val: 1.3050 time: 5.8305s
 Epoch: 0232 loss_train: 1.2873 loss_kl_train: 49.9012 loss_x_train: 1.2823 loss_val: 1.3063 time: 5.8613s
 Epoch: 0233 loss_train: 1.2873 loss_kl_train: 49.7833 loss_x_train: 1.2823 loss_val: 1.3084 time: 5.7581s
 Epoch: 0234 loss_train: 1.2876 loss_kl_train: 49.6595 loss_x_train: 1.2826 loss_val: 1.3074 time: 5.7056s
 Epoch: 0235 loss_train: 1.2877 loss_kl_train: 49.6346 loss_x_train: 1.2828 loss_val: 1.3056 time: 5.8376s
 Epoch: 0236 loss_train: 1.2875 loss_kl_train: 49.5237 loss_x_train: 1.2826 loss_val: 1.3045 time: 5.8224s
 Epoch: 0237 loss_train: 1.2873 loss_kl_train: 49.4566 loss_x_train: 1.2823 loss_val: 1.3048 time: 5.8897s
 Epoch: 0238 loss_train: 1.2871 loss_kl_train: 49.2505 loss_x_train: 1.2822 loss_val: 1.3062 time: 5.8648s
 Epoch: 0239 loss_train: 1.2874 loss_

 Epoch: 0307 loss_train: 1.5102 loss_kl_train: 1974.0967 loss_x_train: 1.3128 loss_val: 1.3285 time: 5.8429s
 Epoch: 0308 loss_train: 2.3431 loss_kl_train: 9679.3044 loss_x_train: 1.3752 loss_val: 1.3616 time: 5.7701s
 Epoch: 0309 loss_train: 4.8837 loss_kl_train: 8846.0609 loss_x_train: 3.9991 loss_val: 1.4836 time: 5.7330s
 Epoch: 0310 loss_train: 13.1742 loss_kl_train: 76219.1052 loss_x_train: 5.5523 loss_val: 1.6024 time: 5.8124s
 Epoch: 0311 loss_train: 26.4357 loss_kl_train: 245405.2128 loss_x_train: 1.8952 loss_val: 1.8267 time: 5.8534s
 Epoch: 0312 loss_train: 3.4633 loss_kl_train: 16580.1079 loss_x_train: 1.8053 loss_val: 1.8121 time: 5.8299s
 Epoch: 0313 loss_train: 1.9184 loss_kl_train: 2141.6075 loss_x_train: 1.7043 loss_val: 1.7146 time: 5.8653s
 Epoch: 0314 loss_train: 1.6950 loss_kl_train: 928.6013 loss_x_train: 1.6021 loss_val: 1.6831 time: 5.7937s
 Epoch: 0315 loss_train: 1.6551 loss_kl_train: 1012.4858 loss_x_train: 1.5539 loss_val: 1.5480 time: 5.7689s
 Epoch: 0316 l

 Epoch: 0393 loss_train: nan loss_kl_train: nan loss_x_train: nan loss_val: nan time: 5.2414s
 Epoch: 0394 loss_train: nan loss_kl_train: nan loss_x_train: nan loss_val: nan time: 5.2098s
 Epoch: 0395 loss_train: nan loss_kl_train: nan loss_x_train: nan loss_val: nan time: 5.2142s
 Epoch: 0396 loss_train: nan loss_kl_train: nan loss_x_train: nan loss_val: nan time: 5.2641s
 Epoch: 0397 loss_train: nan loss_kl_train: nan loss_x_train: nan loss_val: nan time: 5.2642s
 Epoch: 0398 loss_train: nan loss_kl_train: nan loss_x_train: nan loss_val: nan time: 5.1533s
 Epoch: 0399 loss_train: nan loss_kl_train: nan loss_x_train: nan loss_val: nan time: 5.1668s
 Epoch: 0400 loss_train: nan loss_kl_train: nan loss_x_train: nan loss_val: nan time: 5.1698s
 Epoch: 0401 loss_train: nan loss_kl_train: nan loss_x_train: nan loss_val: nan time: 5.1616s
 Epoch: 0402 loss_train: nan loss_kl_train: nan loss_x_train: nan loss_val: nan time: 5.1361s
 Epoch: 0403 loss_train: nan loss_kl_train: nan loss_x_train

KeyboardInterrupt: 

In [12]:
torch.save(model.state_dict(), os.path.join(modelsavepath,str(ep)+'.pt'))
torch.cuda.empty_cache()

In [1]:
with open(os.path.join(logsavepath,'train_loss'), 'wb') as output:
    pickle.dump(train_loss_ep, output, pickle.HIGHEST_PROTOCOL)
with open(os.path.join(logsavepath,'train_loss_kl'), 'wb') as output:
    pickle.dump(train_loss_kl_ep, output, pickle.HIGHEST_PROTOCOL)
with open(os.path.join(logsavepath,'train_loss_x'), 'wb') as output:
    pickle.dump(train_loss_x_ep, output, pickle.HIGHEST_PROTOCOL)
with open(os.path.join(logsavepath,'val_loss'), 'wb') as output:
    pickle.dump(val_loss_ep, output, pickle.HIGHEST_PROTOCOL)


NameError: name 'os' is not defined

In [None]:
plt.plot(np.arange(epochs),train_loss_x_ep)
plt.plot(np.arange(epochs),val_loss_ep)
plt.plot(np.arange(epochs),train_loss_kl_ep)

plt.ylim((0,5))
# plt.xlim((0,3000))
plt.legend(['training x recon loss','validation x recon loss','training kl loss'],loc='upper right')
plt.savefig(os.path.join(plotsavepath,'loss_seed6_zoom.jpg'))
plt.show()
