<a href="https://colab.research.google.com/github/thxsxth/RLMimic/blob/master/Model/Pyro_Model_building.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## This notebook implements a deep latent inference network for sepsis patients.
Structured Inference Networks, as described in this [paper](https://arxiv.org/abs/1609.09869), are the basis for the implementation, modified to respect causality and to handle continuous observations.

The Pyro probabilistic programming language is used for training via stochastic (amortized) variational inference.

In [1]:
cd 'drive/My Drive/sepsis3-cohort'

/content/drive/My Drive/sepsis3-cohort


Install dependencies

In [3]:
!pip install pyro-ppl

Collecting pyro-ppl
[?25l  Downloading https://files.pythonhosted.org/packages/c0/77/4db4946f6b5bf0601869c7b7594def42a7197729167484e1779fff5ca0d6/pyro_ppl-1.3.1-py3-none-any.whl (520kB)
[K     |████████████████████████████████| 522kB 2.9MB/s 
Collecting pyro-api>=0.1.1
  Downloading https://files.pythonhosted.org/packages/fc/81/957ae78e6398460a7230b0eb9b8f1cb954c5e913e868e48d89324c68cec7/pyro_api-0.1.2-py3-none-any.whl
Installing collected packages: pyro-api, pyro-ppl
Successfully installed pyro-api-0.1.2 pyro-ppl-1.3.1


In [5]:
!pip install tensorboardX

Collecting tensorboardX
[?25l  Downloading https://files.pythonhosted.org/packages/35/f1/5843425495765c8c2dd0784a851a93ef204d314fc87bcc2bbb9f662a3ad1/tensorboardX-2.0-py2.py3-none-any.whl (195kB)
[K     |█▊                              | 10kB 16.7MB/s eta 0:00:01[K     |███▍                            | 20kB 6.0MB/s eta 0:00:01[K     |█████                           | 30kB 7.1MB/s eta 0:00:01[K     |██████▊                         | 40kB 7.2MB/s eta 0:00:01[K     |████████▍                       | 51kB 7.0MB/s eta 0:00:01[K     |██████████                      | 61kB 7.8MB/s eta 0:00:01[K     |███████████▊                    | 71kB 7.7MB/s eta 0:00:01[K     |█████████████▍                  | 81kB 8.2MB/s eta 0:00:01[K     |███████████████                 | 92kB 7.7MB/s eta 0:00:01[K     |████████████████▊               | 102kB 8.0MB/s eta 0:00:01[K     |██████████████████▍             | 112kB 8.0MB/s eta 0:00:01[K     |████████████████████            | 122kB 8.

In [4]:
import torch
import numpy as np
import pandas as pd
import datetime as dt
import random
import time
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F
import os
import glob
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence,pad_packed_sequence
# from tensorboardX import SummaryWriter



In [5]:
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions import TransformedDistribution
from pyro.distributions.transforms import affine_autoregressive
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO, TraceEnum_ELBO, TraceTMC_ELBO, config_enumerate
from pyro.optim import ClippedAdam
from modules import Emitter,Gated_Transition,Combiner,Encoder
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [9]:
# The file appears to store icustay_id - 200000, one id per line (the ids printed below start at 200003).
# int() tolerates surrounding whitespace/newlines, so the last line is parsed correctly even
# without a trailing newline; blank lines are skipped.
with open('patientIDs_MIMIC3.csv') as file:
    icustays = [200000 + int(line) for line in file if line.strip()]

In [None]:
icustays[0:3]

[200003, 200014, 200030]

In [6]:
# Pivoted vitals, and pivoted SOFA scores (the SOFA table also holds the vasopressor rates used below).
vitals = pd.read_csv('../Vitals/Vitals.csv',
                     parse_dates=['charttime'])
sofa = pd.read_csv('../pivoted_sofa/pivoted_sofa.csv',
                   parse_dates=['endtime', 'starttime'])


In [8]:
# Sepsis-3 adult cohort + demographics
co = pd.read_csv(
    'sepsis3_adults.csv',
    parse_dates=['intime', 'outtime', 'suspected_infection_time_poe'],
)

In [9]:
# Fluid inputs from the two source tables: CV (keyed by charttime) and MV (starttime/endtime).
input_cv = pd.read_csv('../Fluids/cleaned_input_cv.csv', parse_dates=['charttime'])
input_mv = pd.read_csv('../Fluids/input_eventsMV.csv', parse_dates=['starttime', 'endtime'])

In [10]:
## Keep only ICU stays that belong to the cohort
vitals = vitals.loc[vitals['icustay_id'].isin(set(co['icustay_id']))]
sofa = sofa.loc[sofa['icustay_id'].isin(set(co['icustay_id']))]

Next, restrict the data to the cohort defined in https://gitlab.doc.ic.ac.uk/AIClinician/AIClinician (the ids loaded from `patientIDs_MIMIC3.csv`).

In [11]:
# Restrict both tables to the AI Clinician stays listed in `icustays`.
vitals = vitals.loc[vitals['icustay_id'].isin(icustays)]
sofa = sofa.loc[sofa['icustay_id'].isin(icustays)]

In [14]:
vitals['icustay_id'].nunique(dropna=False)

14286

In [None]:
# vitals.to_csv('vitals_demo.csv')

### Cleaning Dataframes and Creating the Treatment Columns

In [11]:
# Missing vasopressor rates are set to 0 (assumption: no record means no infusion — confirm).
sofa = sofa.fillna({'rate_epinephrine': 0, 'rate_norepinephrine': 0,
                    'rate_dopamine': 0, 'rate_dobutamine': 0})

In [12]:
# Total vasopressor rate = plain sum of the four raw infusion rates (same addition order as before).
# NOTE(review): the rates are not converted to norepinephrine-equivalents — confirm this is intended.
sofa['vaso_rate'] = (sofa['rate_epinephrine']
                     .add(sofa['rate_norepinephrine'])
                     .add(sofa['rate_dobutamine'])
                     .add(sofa['rate_dopamine']))
sofa['vaso_rate'].describe()

count    3.821099e+06
mean     1.534592e-01
std      1.389909e+00
min      0.000000e+00
25%      0.000000e+00
50%      0.000000e+00
75%      0.000000e+00
max      5.247707e+02
Name: vaso_rate, dtype: float64

In [13]:
# Keep the stay id, timestamp, vasopressor rates, urine output and the SOFA (sub)scores.
sofa = sofa[[
    'icustay_id', 'endtime',
    'vaso_rate', 'rate_norepinephrine', 'rate_dopamine', 'rate_dobutamine',
    'urineoutput',
    'cardiovascular_24hours', 'liver_24hours', 'cns_24hours', 'renal_24hours', 'SOFA_24hours',
]]

In [14]:
## Clean and merge the fluid inputs from the CV and MV tables.
# New frames are built instead of overwriting input_cv / input_mv, so the raw columns
# (e.g. MV's 'tev') are still there if the cell is re-run; .rename/.assign also avoid
# assigning into column slices of another frame.
cv_fluids = input_cv[['icustay_id', 'charttime', 'tev']]
# MV events are keyed by endtime: expose it as charttime and keep their volume as tev_mv.
mv_fluids = (input_mv[['icustay_id', 'endtime', 'tev']]
             .rename(columns={'tev': 'tev_mv'})
             .assign(charttime=lambda d: d['endtime']))
# Outer-merge on (stay, time); a missing volume on either side counts as 0.
input_fluids = (mv_fluids
                .merge(cv_fluids, on=['icustay_id', 'charttime'], how='outer')
                [['icustay_id', 'charttime', 'tev', 'tev_mv']]
                .fillna({'tev': 0, 'tev_mv': 0})
                .assign(volume=lambda d: d['tev'] + d['tev_mv']))
input_fluids = input_fluids[input_fluids.icustay_id.isin(set(co.icustay_id))]

Add age, gender, BMI and the suspected-infection time to the vitals DataFrame.

In [15]:
## Re-index the cohort by ICU stay so demographics can be looked up per stay.
# Guarded so re-running the cell does not fail once icustay_id has become the index.
if 'icustay_id' in co.columns:
    co = co.set_index('icustay_id')
# One aligned lookup for every vitals row instead of four separate .loc calls;
# .assign returns a new frame rather than writing into the filtered `vitals` slice.
demographics = co.loc[vitals['icustay_id'], ['age', 'is_male', 'bmi', 'suspected_infection_time_poe']]
vitals = vitals.assign(
    age=demographics['age'].values,
    gender=demographics['is_male'].values,
    bmi=demographics['bmi'].values,
    sus_time=demographics['suspected_infection_time_poe'].values,
)


In [None]:
(vitals.loc[vitals['bmi'].notna(), 'icustay_id'].nunique(dropna=False),
 vitals['icustay_id'].nunique(dropna=False))

(21398, 31211)

In [None]:
sofa.head(5), vitals.head(5), co.head(5)

(   icustay_id             endtime  ...  renal_24hours  SOFA_24hours
 0      200001 2181-11-25 19:00:00  ...              2             3
 1      200001 2181-11-25 20:00:00  ...              2             3
 2      200001 2181-11-25 21:00:00  ...              2             3
 3      200001 2181-11-25 22:00:00  ...              2             3
 4      200001 2181-11-25 23:00:00  ...              3             4
 
 [5 rows x 11 columns],
    subject_id  icustay_id           charttime  ...   age  gender       bmi
 0       55973      200001 2181-11-25 19:06:00  ...  61.0       0  21.06264
 1       55973      200001 2181-11-25 19:07:00  ...  61.0       0  21.06264
 2       55973      200001 2181-11-25 19:08:00  ...  61.0       0  21.06264
 3       55973      200001 2181-11-25 19:14:00  ...  61.0       0  21.06264
 4       55973      200001 2181-11-25 19:16:00  ...  61.0       0  21.06264
 
 [5 rows x 14 columns],
             Unnamed: 0  hadm_id  excluded  ... abx_poe sepsis-3 sofa>=2
 icus

### Device selection

In [16]:
# Same rule as in the imports cell: GPU when available, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

#### Let's define training and validation cohorts

In [18]:
# training_cohort=np.random.choice(list(co.index),int(0.8*len(list(co.index))),replace=False)
# valid_cohort=np.array(list(set(co.index)-set(training_cohort)))

In [132]:
# training_cohort=np.random.choice(list(vitals.icustay_id.unique()),int(0.8*len(vitals.icustay_id.unique())),replace=False)
# valid_cohort=np.array(list((set(vitals.icustay_id.unique())-set(training_cohort))))

In [135]:
# np.save('train_cohort',training_cohort)
# np.save('valid_cohort',valid_cohort)

In [30]:
# Pre-computed split of ICU stays (created by the commented-out cells above and saved with np.save).
training_cohort = np.load('train_cohort.npy')
valid_cohort = np.load('valid_cohort.npy')

In [31]:
# Held-out stays: in the cohort but in neither the training nor the validation split, so the
# test set cannot overlap the validation set used for model selection. Sorted for a stable order.
# NOTE(review): the saved splits look like an 80/20 split of stays with vitals, so this remainder
# may only contain stays without usable vitals — a proper three-way split may be needed.
test_cohort = np.array(sorted(set(co.index) - set(training_cohort) - set(valid_cohort)))

In [None]:
# set(valid_cohort).intersection(set(training_cohort)),len(training_cohort)+len(valid_cohort)==len(list(co.index))

(set(), True)

In [50]:
# Validation = every stay that has vitals but is not in the training split.
valid_cohort = list(set(vitals['icustay_id'].unique()) - set(training_cohort))
len(valid_cohort)

2858

As expected, the training and validation cohorts share no stays, and together they account for every stay.

In [19]:
def get_mini_batch_mask(mini_batch, seq_lengths):
    """
    Build a float mask of shape ``mini_batch.shape[:2]`` (batch x time).

    Row ``b`` is 1.0 on its first ``seq_lengths[b]`` time steps and 0.0 on the
    padded tail. The mask is built on the CPU and then moved to ``device``.
    """
    n_rows, n_steps = mini_batch.shape[0], mini_batch.shape[1]
    mask = torch.zeros(n_rows, n_steps)
    for row in range(n_rows):
        valid_len = seq_lengths[row]
        mask[row, :valid_len] = torch.ones(valid_len)
    return mask.to(device)

In [20]:
class MyDataLoader():
  """
  Iterable over mini-batches of patient trajectories and treatments.

  Each iteration yields ``(padded_trajectories, mask, padded_treatments, seq_lens)``:
    padded_trajectories : B x T x 12 FloatTensor with the columns in OBS_COLS
    mask                : B x T mask from get_mini_batch_mask (1 for real steps, 0 for padding)
    padded_treatments   : B x T x 2 FloatTensor with the columns in ACTION_COLS
    seq_lens            : list with the unpadded length of every trajectory
  Every patient is visited once per pass (the last batch may be smaller).
  Patients whose trajectory is empty after cleaning are skipped, so B can be below batch_size.
  """

  # Observation columns, in tensor order.
  OBS_COLS = ['age','HeartRate','SysBP','DiasBP','MeanBP','RespRate','SpO2',
              'liver_24hours','cardiovascular_24hours',
              'cns_24hours','renal_24hours','SOFA_24hours']
  # Treatment columns, in tensor order.
  ACTION_COLS = ['vaso_rate','volume']
  # Columns kept from the hourly per-patient frame.
  FRAME_COLS = ['volume','vaso_rate'] + OBS_COLS

  def __init__(self,sofa_df=sofa,vitals_df=vitals,input_df=input_fluids,cohort=co,batch_size=16,icustay_list=training_cohort,train=True,max_train_len=50):
    """
    sofa_df (pd.DataFrame): pivoted SOFA table (also holds the vasopressor rate 'vaso_rate')
    vitals_df (pd.DataFrame): pivoted vitals, including the 'age' and 'sus_time' columns
    input_df (pd.DataFrame): input fluids (CV and MV merged), with a 'volume' column
    cohort (pd.DataFrame): cohort table with demographics (stored; not read by the loader)
    batch_size (int): number of patients per batch
    icustay_list (iterable): ICU stay ids to iterate over. Note the default is bound once,
        when the class is defined; pass it explicitly to pick up a reloaded split.
    train (bool): shuffle the patients every pass and crop long trajectories to a random window
    max_train_len (int): window length (hours) used for cropping in training mode
    """
    self.sofa=sofa_df
    self.vitals=vitals_df
    self.batch_size=batch_size
    self.icustays=icustay_list
    self.input_fluids=input_df
    self.cohort=cohort
    self.train=train
    self.max_train_len=max_train_len

  def _patient_frame(self, pat):
    """Hourly FRAME_COLS frame for one ICU stay, from the suspected infection time onwards."""
    hourly = pd.concat([
        self.vitals[self.vitals['icustay_id'] == pat].set_index('charttime'),
        self.input_fluids[self.input_fluids['icustay_id'] == pat].set_index('charttime'),
        self.sofa[self.sofa['icustay_id'] == pat].set_index('endtime'),
    ]).resample('H').last()
    # Only vitals rows carry sus_time, so an hour filled solely by fluid/SOFA rows has NaT
    # there; use the first non-missing value rather than whatever sits in row 0.
    sus_times = hourly['sus_time'].dropna()
    if sus_times.empty:
      return hourly.iloc[0:0].reindex(columns=self.FRAME_COLS)
    hourly = hourly.truncate(before=sus_times.iloc[0])[self.FRAME_COLS]
    # Forward-fill, then drop hours that still have a missing value (keeps the hourly structure).
    # NOTE(review): ffill also carries fluid volume / vasopressor rate into hours without a new
    # record, and dropna removes every hour before each column has been seen once — confirm.
    return hourly.ffill().dropna()

  def __iter__(self):
    # Shuffle a copy: np.random.shuffle would reorder the caller's array in place
    # (including the module-level array used as the icustay_list default).
    patients = np.random.permutation(self.icustays) if self.train else self.icustays
    # Visit every patient; the final batch may be smaller than batch_size.
    for start in range(0, len(patients), self.batch_size):
      trajectories, treatments, seq_lens = [], [], []
      for pat in patients[start:start + self.batch_size]:
        df = self._patient_frame(pat)
        if df.shape[0] < 1:
          continue  # nothing usable for this stay
        if self.train and df.shape[0] > self.max_train_len:
          # randint's upper bound is exclusive: +1 makes the last window reachable.
          offset = np.random.randint(0, df.shape[0] - self.max_train_len + 1)
          df = df.iloc[offset:offset + self.max_train_len]
        trajectories.append(torch.FloatTensor(df[self.OBS_COLS].values).to(device))
        treatments.append(torch.FloatTensor(df[self.ACTION_COLS].values).to(device))
        seq_lens.append(df.shape[0])

      if not trajectories:
        continue  # every stay in this slice was empty; pad_sequence needs at least one tensor
      padded_trajectories=pad_sequence(trajectories,batch_first=True)
      padded_treatments=pad_sequence(treatments,batch_first=True)
      mask=get_mini_batch_mask(padded_trajectories,seq_lens)

      yield padded_trajectories,mask, padded_treatments,seq_lens
         

              

         
     


#### Testing the data loader

Each batch is built from a list of L trajectories, where `trajectory[i]` has shape T×D (T time steps, D observation dimensions); `pad_sequence` then pads them to a common length.

In [32]:
# Pass every split explicitly: MyDataLoader's default icustay_list is bound once, when the
# class is defined, and would silently go stale if training_cohort is reloaded afterwards.
train_loader = MyDataLoader(batch_size=32, icustay_list=training_cohort)
validation_loader = MyDataLoader(batch_size=32, icustay_list=valid_cohort, train=False)
test_loader = MyDataLoader(batch_size=32, icustay_list=test_cohort, train=False)

In [22]:
# Smoke test: shapes of the first five validation batches and the shortest sequence in each.
for i, (trajectory, mask, treatment, lens) in enumerate(validation_loader):
  print(f'Batch number {i}')
  for shaped in (trajectory, treatment, mask):
    print(shaped.shape)
  print(min(lens))
  if i == 4:
    break

Batch number 0
torch.Size([32, 606, 12])
torch.Size([32, 606, 2])
torch.Size([32, 606])
13
Batch number 1
torch.Size([29, 545, 12])
torch.Size([29, 545, 2])
torch.Size([29, 545])
11
Batch number 2
torch.Size([30, 1366, 12])
torch.Size([30, 1366, 2])
torch.Size([30, 1366])
13
Batch number 3
torch.Size([30, 798, 12])
torch.Size([30, 798, 2])
torch.Size([30, 798])
4
Batch number 4
torch.Size([30, 1368, 12])
torch.Size([30, 1368, 2])
torch.Size([30, 1368])
15


In [None]:
# The mask row of a sequence should sum to its length (uses `mask`/`lens` from the loop above).
mask[12].sum(), lens[12]

(tensor(6., device='cuda:0'), 6)

### Defining the model and hyperparameters

In [23]:
class DMM(nn.Module):
  """
  Deep Markov Model with treatments for sepsis trajectories.

  Encapsulates the generative model and the variational distribution (the guide).
  Modified from https://github.com/pyro-ppl/pyro/blob/dev/examples/dmm/dmm.py

  ``self.model`` and ``self.guide`` are stored on the instance as plain functions
  (not bound methods), so the module is passed explicitly as the first argument:
  ``svi = SVI(dmm.model, dmm.guide, ...)`` and then ``svi.step(dmm, batch=..., ...)``.
  """
  def __init__(self,z_dim,u_dim,x_dim,binary_dim=None,rnn_dim=1024,
               hidden_emitter_dim=512,hidden_gated_dim=512,hidden_layers=[16,8,4]
               ,num_iafs=0, iaf_dim=50):
    """
    z_dim (int) : dimension of the latent space
    u_dim (int) : dimension of the action (treatment) space
    x_dim (int) : dimension of the continuous observations
    binary_dim (int or None) : dimension of the binary observations; falsy disables them
    rnn_dim (int) : hidden size of the Encoder
    hidden_emitter_dim (int) : hidden width passed to Emitter
    hidden_gated_dim (int) : hidden width passed to Gated_Transition
    hidden_layers (sequence of 3 ints): layer counts for Emitter, Gated_Transition and Encoder, respectively
    num_iafs (int) : number of affine autoregressive flows applied to q(z_t | ...)
    iaf_dim (int) : hidden width of each flow
    """
    super(DMM,self).__init__()
    self.emitter=Emitter(z_dim,x_dim,binary_dim,hidden_emitter_dim,hidden_layers[0])
    self.trans=Gated_Transition(z_dim,u_dim, hidden_gated_dim,hidden_layers[1])
    self.combiner=Combiner(z_dim,rnn_dim)
    self.binary=binary_dim

    # The encoder reads the continuous observations, with the binary ones appended when present.
    encoder_input_dim = x_dim + binary_dim if binary_dim else x_dim
    self.rnn=Encoder(encoder_input_dim,hidden_dim=rnn_dim,n_layers=hidden_layers[2])

    # Optional normalizing flows for the guide; the ModuleList registers their parameters.
    self.iafs = [affine_autoregressive(z_dim, hidden_dims=[iaf_dim]) for _ in range(num_iafs)]
    self.iafs_modules = nn.ModuleList(self.iafs)

    # Trainable z_0 / z_q_0 define p(z_1) and q(z_1), since at t = 1 there is no previous latent.
    self.z_0 = nn.Parameter(torch.zeros(z_dim))
    self.z_q_0 = nn.Parameter(torch.zeros(z_dim))

    # Trainable initial hidden state of the encoder RNN.
    self.h_0 = nn.Parameter(torch.zeros(self.rnn.n_layers, 1, rnn_dim))

    if device=='cuda':
      self.cuda()

    # Plain (unbound) functions: callers pass the module explicitly (see the class docstring).
    self.model = type(self)._model
    self.guide = type(self)._guide

  def _model(self,batch,batch_lens,actions,mask,binary=None,
             annealing_factor=1.0):
    """
    Generative model: z_t ~ p(z_t | z_{t-1}, u_{t-1}) and x_t ~ p(x_t | z_t).

    batch : continuous observations, indexed as batch[:, t, :] (B x T x x_dim)
    batch_lens : list of sequence lengths (unused here; kept so model and guide share a signature)
    actions : treatments, indexed as actions[:, t, :] (B x T x u_dim)
    mask : B x T mask, 1 for real time steps and 0 for padding
    binary : binary observations (B x T x binary_dim); required when the model has binary_dim
    annealing_factor : KL-annealing weight applied to the latent sample sites
    """
    if self.binary and binary is None:
      raise ValueError("this DMM was built with binary_dim, so `binary` observations are required")

    T_max = batch.size(1)

    # z_prev = z_0 starts the recursion p(z_t | z_{t-1}, u_{t-1}); the first treatment is zero.
    z_prev = self.z_0.expand(batch.size(0), self.z_0.size(0))            # B x z_dim
    u_prev = torch.zeros(actions.shape[0], actions.shape[2]).to(device)  # B x u_dim

    # Sequences in the mini-batch are conditionally independent of each other.
    with pyro.plate("z_minibatch", len(batch)):
      for t in pyro.markov(range(1, T_max + 1)):
        step_mask = mask[:, t - 1:t]

        z_mean, z_scale = self.trans(z_prev, u_prev)
        # poutine.scale implements KL annealing on the latent site.
        with poutine.scale(scale=annealing_factor):
          z_t = pyro.sample("z_%d" % t,
                            dist.Normal(z_mean, z_scale)
                                .mask(step_mask)
                                .to_event(1))

        if self.binary:
          # Keep the emitter's Bernoulli parameters separate from the observed `binary` tensor.
          mu, sigma, binary_probs = self.emitter(z_t)
        else:
          mu, sigma = self.emitter(z_t)

        pyro.sample("cts_x_%d" % t,
                    dist.Normal(mu, sigma)
                        .mask(step_mask)
                        .to_event(1),
                    obs=batch[:, t - 1, :])

        if self.binary:
          pyro.sample("binary_x_%d" % t,
                      dist.Bernoulli(binary_probs)
                          .mask(step_mask)
                          .to_event(1),
                      obs=binary[:, t - 1, :])

        z_prev = z_t
        u_prev = actions[:, t - 1, :]

  def _guide(self,batch,batch_lens,actions,mask,
             binary=None,annealing_factor=1.0):
    """
    Variational distribution q(z_t | z_{t-1}, encoder output at t), one step at a time.

    Arguments mirror the model. When the model has binary_dim, the binary observations
    are concatenated to the continuous ones (along the last axis) before the encoder,
    matching the encoder's x_dim + binary_dim input size.
    """
    T_max = batch.size(1)
    # Register every PyTorch (sub)module with Pyro.
    pyro.module("dmm", self)

    rnn_input = batch
    if self.binary:
      if binary is None:
        raise ValueError("this DMM was built with binary_dim, so `binary` observations are required")
      # assumes binary is B x T x binary_dim, aligned with batch — TODO confirm against Encoder
      rnn_input = torch.cat([batch, binary.to(batch.dtype)], dim=-1)

    # On GPU the broadcast initial state has to be contiguous.
    h_0_contig = self.h_0.expand(self.rnn.n_layers, batch.size(0), self.rnn.hidden_dim).contiguous()
    _, rnn_output = self.rnn(rnn_input, batch_lens, h_0_contig)  # read below as rnn_output[:, t - 1, :]

    # z_prev = z_q_0 starts the recursion in q(z_t | ...).
    z_prev = self.z_q_0.expand(batch.size(0), self.z_q_0.size(0))

    with pyro.plate("z_minibatch", len(batch)):
      # pyro.markov lets TraceEnum_ELBO reuse samples across time steps.
      for t in pyro.markov(range(1, T_max + 1)):
        z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

        # With normalizing flows, push the base Normal through self.iafs.
        if len(self.iafs) > 0:
          z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs)
          assert z_dist.event_shape == (self.z_q_0.size(0),)
          assert z_dist.batch_shape[-1:] == (len(batch),)
        else:
          z_dist = dist.Normal(z_loc, z_scale)
          assert z_dist.event_shape == ()
          assert z_dist.batch_shape[-2:] == (len(batch), self.z_q_0.size(0))

        with pyro.poutine.scale(scale=annealing_factor):
          if len(self.iafs) > 0:
            # The flow's output dimensions are correlated, so the event shape is already non-empty.
            z_t = pyro.sample("z_%d" % t,
                              z_dist.mask(mask[:, t - 1]))
          else:
            # Without flows the latent dimensions are independent: declare them as one event.
            z_t = pyro.sample("z_%d" % t,
                              z_dist.mask(mask[:, t - 1:t])
                              .to_event(1))

        # Condition the next step on the latent sampled here.
        z_prev = z_t




def get_annhealing_factor(epoch):
  """
  KL-annealing weight for the latent sample sites at a given epoch.

  Stays at 0.5 for the first epochs, then follows 0.25 + 0.005 * epoch until it
  reaches 1.0. The 0.5 floor keeps the schedule monotone: the linear ramp alone
  would drop the weight from 0.5 to 0.30 at epoch 10 and only climb back to 0.5
  at epoch 50.
  """
  return min(1.0, max(0.5, 0.25 + 0.005 * epoch))



In [24]:

def validate(batch,batch_lens,actions,masks):
    """
    Negative ELBO per time step for one validation batch.

    Uses the global ``dmm`` and ``svi``. The whole DMM (including the combiner and
    any flows) is switched to evaluation mode, e.g. to turn off dropout if present.
    Its previous mode is restored afterwards, even if the evaluation raises.
    """
    was_training = dmm.training
    dmm.eval()
    try:
        val_nll = svi.evaluate_loss(dmm,batch,batch_lens,actions,masks
                                   ) / np.sum(batch_lens)
    finally:
        dmm.train(was_training)
    return val_nll

In [25]:

"""### Defining Model and hyperparamters"""

# Optimizer (ClippedAdam) hyper-parameters
learning_rate = 0.00001
# learning_rate=25
beta1, beta2 = 0.96, 0.999
clip_norm = 20
lr_decay = 0.99996
weight_decay = 0.0

# Model: DMM(z_dim=64, u_dim=2, x_dim=12)
dmm = DMM(64, 2, 12)
# dmm=DMM(256,2,12,binary_dim=None,rnn_dim=512,
#                hidden_emitter_dim=1024,hidden_gated_dim=512,hidden_layers=[16,16,5]
#                ,num_iafs=0, iaf_dim=50)

# Training schedule
N_epochs = 5000
annhealing_factor = 0.5  # KL weight for the first epoch; updated at the end of every epoch
# writer=SummaryWriter(logdir='logs/exp1') #change this as needed

# Optimizer and inference algorithm (SVI with Trace_ELBO)
adam_params = dict(lr=learning_rate, betas=(beta1, beta2), clip_norm=clip_norm,
                   lrd=lr_decay, weight_decay=weight_decay)
optimizer = ClippedAdam(adam_params)
svi = SVI(dmm.model, dmm.guide, optimizer, Trace_ELBO())


### Training

In [None]:
times = [time.time()]
SAVE_NLL_THRESHOLD = 60  # checkpoint only epochs whose mean validation NLL is below this

for epoch in range(1, N_epochs + 1):  # runs exactly N_epochs epochs
  # ---- training ----
  train_nll, train_steps = 0.0, 0
  for batch, masks, actions, batch_lens in train_loader:
    if min(batch_lens) == 0:
      continue  # guard against empty trajectories

    loss = svi.step(dmm, batch=batch, batch_lens=batch_lens, actions=actions, mask=masks,
                    binary=None, annealing_factor=annhealing_factor)

    # Re-evaluate with the default annealing_factor (1.0) so the logged training NLL
    # is on the same scale as the validation NLL.
    batch_nll = svi.evaluate_loss(dmm, batch, batch_lens, actions, masks) / np.sum(batch_lens)
    train_nll += batch_nll
    train_steps += 1
    print('\rBatch : ', train_steps, ' Training Loss :', train_nll / train_steps, end='')
  print()

  # ---- validation ----
  val_nll, val_steps = 0.0, 0
  for batch, masks, actions, batch_lens in validation_loader:
    if min(batch_lens) == 0:
      continue
    # evaluate_loss draws fresh guide samples, so call validate() once and reuse the value.
    batch_val_nll = validate(batch, batch_lens, actions, masks)
    val_nll += batch_val_nll
    val_steps += 1
    print('\rValidating : ', batch_val_nll, ' running mean :', val_nll / val_steps, end='')
  print()
  print('-'*125)

  # NaN when a loader produced no batches (instead of a ZeroDivisionError).
  mean_train_nll = train_nll / train_steps if train_steps else float('nan')
  mean_val_nll = val_nll / val_steps if val_steps else float('nan')

  if mean_val_nll < SAVE_NLL_THRESHOLD:  # False for NaN, so an empty validation pass never saves
    torch.save(dmm.state_dict(), 'state_dict_epoch{}_{:.4f}.pt'.format(epoch, mean_val_nll))

  # writer.add_scalar('train_nll', mean_train_nll, epoch)
  print('Train nll {}'.format(mean_train_nll))
  # writer.add_scalar('vall_nll', mean_val_nll, epoch)
  print('Validation nll {}'.format(mean_val_nll))
  times.append(time.time())
  print('Epoch {} took {:.1f}s'.format(epoch, times[-1] - times[-2]))

  annhealing_factor = get_annhealing_factor(epoch)

In [40]:
#  torch.save(dmm.state_dict(),'state_dict_crazy.pt')

In [28]:
# dmm.load_state_dict(torch.load('state_dict_new.pt',map_location=torch.device('cpu')))

<All keys matched successfully>