In [13]:
import numpy as np
import argparse
import os
import imp
import re
import pickle
import datetime
import random
import math
import copy


import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils import data
from torch.autograd import Variable
import torch.nn.functional as F


from utils import utils
from utils.readers import InHospitalMortalityReader
from utils.preprocessing import Discretizer, Normalizer
from utils import metrics
from utils import common_utils

### Prepare

In [14]:
data_path = './data/'
file_name = './model/concare0'
small_part = False
arg_timestep = 1.0
batch_size = 256
epochs = 100

In [15]:
# Build readers, discretizers, normalizers
train_reader = InHospitalMortalityReader(dataset_dir=os.path.join(data_path, 'train'),
                                         listfile=os.path.join(data_path, 'train_listfile.csv'),
                                         period_length=48.0)

val_reader = InHospitalMortalityReader(dataset_dir=os.path.join(data_path, 'train'),
                                       listfile=os.path.join(data_path, 'val_listfile.csv'),
                                       period_length=48.0)

discretizer = Discretizer(timestep=arg_timestep,
                          store_masks=True,
                          impute_strategy='previous',
                          start_time='zero')

In [16]:
discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[1].split(',')
cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
normalizer_state = 'ihm_normalizer'
normalizer_state = os.path.join(os.path.dirname(data_path), normalizer_state)
normalizer.load_params(normalizer_state)

In [17]:
n_trained_chunks = 0
train_raw = utils.load_data(train_reader, discretizer, normalizer, small_part, return_names=True)
val_raw = utils.load_data(val_reader, discretizer, normalizer, small_part, return_names=True)

In [18]:
demographic_data = []
diagnosis_data = []
idx_list = []

demo_path = data_path + 'demographic/'
for cur_name in os.listdir(demo_path):
    cur_id, cur_episode = cur_name.split('_', 1)
    cur_episode = cur_episode[:-4]
    cur_file = demo_path + cur_name

    with open(cur_file, "r") as tsfile:
        header = tsfile.readline().strip().split(',')
        if header[0] != "Icustay":
            continue
        cur_data = tsfile.readline().strip().split(',')
        
    if len(cur_data) == 1:
        cur_demo = np.zeros(12)
        cur_diag = np.zeros(128)
    else:
        if cur_data[3] == '':
            cur_data[3] = 60.0
        if cur_data[4] == '':
            cur_data[4] = 160
        if cur_data[5] == '':
            cur_data[5] = 60

        cur_demo = np.zeros(12)
        cur_demo[int(cur_data[1])] = 1
        cur_demo[5 + int(cur_data[2])] = 1
        cur_demo[9:] = cur_data[3:6]
        cur_diag = np.array(cur_data[8:], dtype=np.int)

    demographic_data.append(cur_demo)
    diagnosis_data.append(cur_diag)
    idx_list.append(cur_id+'_'+cur_episode)

for each_idx in range(9,12):
    cur_val = []
    for i in range(len(demographic_data)):
        cur_val.append(demographic_data[i][each_idx])
    cur_val = np.array(cur_val)
    _mean = np.mean(cur_val)
    _std = np.std(cur_val)
    _std = _std if _std > 1e-7 else 1e-7
    for i in range(len(demographic_data)):
        demographic_data[i][each_idx] = (demographic_data[i][each_idx] - _mean) / _std

In [19]:
device = torch.device("cuda:0" if torch.cuda.is_available() == True else 'cpu')
#device = torch.device('cpu')
print("available device: {}".format(device))

available device: cuda:0


### model

In [20]:
class SingleAttention(nn.Module):
    def __init__(self, attention_input_dim, attention_hidden_dim, attention_type='add', demographic_dim=12, time_aware=False, use_demographic=False):
        super(SingleAttention, self).__init__()
        
        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim
        self.use_demographic = use_demographic
        self.demographic_dim = demographic_dim
        self.time_aware = time_aware

        # batch_time = torch.arange(0, batch_mask.size()[1], dtype=torch.float32).reshape(1, batch_mask.size()[1], 1)
        # batch_time = batch_time.repeat(batch_mask.size()[0], 1, 1)
        
        if attention_type == 'add':
            if self.time_aware == True:
                # self.Wx = nn.Parameter(torch.randn(attention_input_dim+1, attention_hidden_dim))
                self.Wx = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
                self.Wtime_aware = nn.Parameter(torch.randn(1, attention_hidden_dim))
                nn.init.kaiming_uniform_(self.Wtime_aware, a=math.sqrt(5))
            else:
                self.Wx = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            self.Wt = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            self.Wd = nn.Parameter(torch.randn(demographic_dim, attention_hidden_dim))
            self.bh = nn.Parameter(torch.zeros(attention_hidden_dim,))
            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(torch.zeros(1,))
            
            nn.init.kaiming_uniform_(self.Wd, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == 'mul':
            self.Wa = nn.Parameter(torch.randn(attention_input_dim, attention_input_dim))
            self.ba = nn.Parameter(torch.zeros(1,))
            
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == 'concat':
            if self.time_aware == True:
                self.Wh = nn.Parameter(torch.randn(2*attention_input_dim+1, attention_hidden_dim))
            else:
                self.Wh = nn.Parameter(torch.randn(2*attention_input_dim, attention_hidden_dim))

            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(torch.zeros(1,))
            
            nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
            
        elif attention_type == 'new':
            self.Wt = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            self.Wx = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))

            self.rate = nn.Parameter(torch.zeros(1)+0.8)
            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))
            
        else:
            raise RuntimeError('Wrong attention type.')
        
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
    
    def forward(self, input, demo=None):
 
        batch_size, time_step, input_dim = input.size() # batch_size * time_step * hidden_dim(i)
        #assert(input_dim == self.input_dim)

        # time_decays = torch.zeros((time_step,time_step)).to(device)# t*t
        # for this_time in range(time_step):
        #     for pre_time in range(time_step):
        #         if pre_time > this_time:
        #             break
        #         time_decays[this_time][pre_time] = torch.tensor(this_time - pre_time, dtype=torch.float32).to(device)
        # b_time_decays = tile(time_decays, 0, batch_size).view(batch_size,time_step,time_step).unsqueeze(-1).to(device)# b t t 1

        time_decays = torch.tensor(range(47,-1,-1), dtype=torch.float32).unsqueeze(-1).unsqueeze(0).to(device)# 1*t*1
        b_time_decays = time_decays.repeat(batch_size,1,1)+1# b t 1
        
        if self.attention_type == 'add': #B*T*I  @ H*I
            q = torch.matmul(input[:,-1,:], self.Wt)# b h
            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim)) #B*1*H
            if self.time_aware == True:
                # k_input = torch.cat((input, time), dim=-1)
                k = torch.matmul(input, self.Wx)#b t h
                # k = torch.reshape(k, (batch_size, 1, time_step, self.attention_hidden_dim)) #B*1*T*H
                time_hidden = torch.matmul(b_time_decays, self.Wtime_aware)#  b t h
            else:
                k = torch.matmul(input, self.Wx)# b t h
                # k = torch.reshape(k, (batch_size, 1, time_step, self.attention_hidden_dim)) #B*1*T*H
            if self.use_demographic == True:
                d = torch.matmul(demo, self.Wd) #B*H
                d = torch.reshape(d, (batch_size, 1, self.attention_hidden_dim)) # b 1 h
            h = q + k + self.bh # b t h
            if self.time_aware == True:
                h += time_hidden
            h = self.tanh(h) #B*T*H
            e = torch.matmul(h, self.Wa) + self.ba #B*T*1
            e = torch.reshape(e, (batch_size, time_step))# b t
        elif self.attention_type == 'mul':
            e = torch.matmul(input[:,-1,:], self.Wa)#b i
            e = torch.matmul(e.unsqueeze(1), input.permute(0,2,1)).squeeze() + self.ba #b t
        elif self.attention_type == 'concat':
            q = input[:,-1,:].unsqueeze(1).repeat(1,time_step,1)# b t i
            k = input
            c = torch.cat((q, k), dim=-1) #B*T*2I
            if self.time_aware == True:
                c = torch.cat((c, b_time_decays), dim=-1) #B*T*2I+1
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba #B*T*1
            e = torch.reshape(e, (batch_size, time_step)) # b t 
            
        elif self.attention_type == 'new':
            
            q = torch.matmul(input[:,-1,:], self.Wt)# b h
            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim)) #B*1*H
            k = torch.matmul(input, self.Wx)#b t h
            dot_product = torch.matmul(q, k.transpose(1, 2)).squeeze() # b t
            denominator =  self.sigmoid(self.rate) * (torch.log(2.72 +  (1-self.sigmoid(dot_product)))* (b_time_decays.squeeze()))
            e = self.relu(self.sigmoid(dot_product)/(denominator)) # b * t
#          * (b_time_decays.squeeze())
        # e = torch.exp(e - torch.max(e, dim=-1, keepdim=True).values)
        
        # if self.attention_width is not None:
        #     if self.history_only:
        #         lower = torch.arange(0, time_step).to(device) - (self.attention_width - 1)
        #     else:
        #         lower = torch.arange(0, time_step).to(device) - self.attention_width // 2
        #     lower = lower.unsqueeze(-1)
        #     upper = lower + self.attention_width
        #     indices = torch.arange(0, time_step).unsqueeze(0).to(device)
        #     e = e * (lower <= indices).float() * (indices < upper).float()
        
        # s = torch.sum(e, dim=-1, keepdim=True)
        # mask = subsequent_mask(time_step).to(device) # 1 t t 下三角
        # scores = e.masked_fill(mask == 0, -1e9)# b t t 下三角
        a = self.softmax(e) #B*T
        v = torch.matmul(a.unsqueeze(1), input).squeeze() #B*I

        return v, a

class FinalAttentionQKV(nn.Module):
    def __init__(self, attention_input_dim, attention_hidden_dim, attention_type='add', dropout=None):
        super(FinalAttentionQKV, self).__init__()
        
        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim


        self.W_q = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_k = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_v = nn.Linear(attention_input_dim, attention_hidden_dim)

        self.W_out = nn.Linear(attention_hidden_dim, 1)

        self.b_in = nn.Parameter(torch.zeros(1,))
        self.b_out = nn.Parameter(torch.zeros(1,))

        nn.init.kaiming_uniform_(self.W_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_k.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_v.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_out.weight, a=math.sqrt(5))

        self.Wh = nn.Parameter(torch.randn(2*attention_input_dim, attention_hidden_dim))
        self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
        self.ba = nn.Parameter(torch.zeros(1,))
        
        nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        
        self.dropout = nn.Dropout(p=dropout)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, input):
 
        batch_size, time_step, input_dim = input.size() # batch_size * input_dim + 1 * hidden_dim(i)
        input_q = self.W_q(input[:, -1, :]) # b h
        input_k = self.W_k(input)# b t h
        input_v = self.W_v(input)# b t h

        if self.attention_type == 'add': #B*T*I  @ H*I

            q = torch.reshape(input_q, (batch_size, 1, self.attention_hidden_dim)) #B*1*H
            h = q + input_k + self.b_in # b t h
            h = self.tanh(h) #B*T*H
            e = self.W_out(h) # b t 1
            e = torch.reshape(e, (batch_size, time_step))# b t

        elif self.attention_type == 'mul':
            q = torch.reshape(input_q, (batch_size, self.attention_hidden_dim, 1)) #B*h 1
            e = torch.matmul(input_k, q).squeeze()#b t
            
        elif self.attention_type == 'concat':
            q = input_q.unsqueeze(1).repeat(1,time_step,1)# b t h
            k = input_k
            c = torch.cat((q, k), dim=-1) #B*T*2I
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba #B*T*1
            e = torch.reshape(e, (batch_size, time_step)) # b t 
        
        a = self.softmax(e) #B*T
        if self.dropout is not None:
            a = self.dropout(a)
        v = torch.matmul(a.unsqueeze(1), input_v).squeeze() #B*I

        return v, a

def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def tile(a, dim, n_tile):
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device)
    return torch.index_select(a, dim, order_index).to(device)

class PositionwiseFeedForward(nn.Module): # new added
    "Implements FFN equation."
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x)))), None

# class PositionwiseFeedForwardConv(nn.Module):

#     def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
#         super(PositionalWiseFeedForward, self).__init__()
#         self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
#         self.w2 = nn.Conv1d(model_dim, ffn_dim, 1)
#         self.dropout = nn.Dropout(dropout)
#         self.layer_norm = nn.LayerNorm(model_dim)

#     def forward(self, x):
#         output = x.transpose(1, 2)
#         output = self.w2(F.relu(self.w1(output)))
#         output = self.dropout(output.transpose(1, 2))

#         # add residual and norm layer
#         output = self.layer_norm(x + output)
#         return output

class PositionalEncoding(nn.Module): # new added / not use anymore
    "Implement the PE function."
    def __init__(self, d_model, dropout, max_len=400):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], 
                         requires_grad=False)
        return self.dropout(x)

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0 # 下三角矩阵

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)# b h t d_k
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k) # b h t t
    if mask is not None:# 1 1 t t
        scores = scores.masked_fill(mask == 0, -1e9)# b h t t 下三角
    p_attn = F.softmax(scores, dim = -1)# b h t t
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn # b h t v (d_k) 
    
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, self.d_k * self.h), 3)
        self.final_linear = nn.Linear(d_model, d_model)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1) # 1 1 t t

        nbatches = query.size(0)# b
        input_dim = query.size(1)# i+1
        feature_dim = query.size(-1)# i+1

        #input size -> # batch_size * d_input * hidden_dim
        
        # d_model => h * d_k 
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))] # b num_head d_input d_k
        
       
        x, self.attn = attention(query, key, value, mask=mask, 
                                 dropout=self.dropout)# b num_head d_input d_v (d_k) 

      
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)# batch_size * d_input * hidden_dim

        #DeCov 
        DeCov_contexts = x.transpose(0, 1).transpose(1, 2) # I+1 H B
        Covs = cov(DeCov_contexts[0,:,:])
        DeCov_loss = 0.5 * (torch.norm(Covs, p = 'fro')**2 - torch.norm(torch.diag(Covs))**2 ) 
        for i in range(feature_dim -1 + 1):
            Covs = cov(DeCov_contexts[i+1,:,:])
            DeCov_loss += 0.5 * (torch.norm(Covs, p = 'fro')**2 - torch.norm(torch.diag(Covs))**2 ) 


        return self.final_linear(x), DeCov_loss

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-7):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

def cov(m, y=None):
    if y is not None:
        m = torch.cat((m, y), dim=0)
    m_exp = torch.mean(m, dim=1)
    x = m - m_exp[:, None]
    cov = 1 / (x.size(1) - 1) * x.mm(x.t())
    return cov

class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        returned_value = sublayer(self.norm(x))
        return x + self.dropout(returned_value[0]) , returned_value[1]

class ConCare(nn.Module):
    def __init__(self, input_dim, hidden_dim, d_model,  MHD_num_head, d_ff, output_dim, keep_prob=0.5):
        super(ConCare, self).__init__()

        # hyperparameters
        self.input_dim = input_dim  
        self.hidden_dim = hidden_dim  # d_model
        self.d_model = d_model
        self.MHD_num_head = MHD_num_head
        self.d_ff = d_ff
        self.output_dim = output_dim
        self.keep_prob = keep_prob

        # layers
        self.PositionalEncoding = PositionalEncoding(self.d_model, dropout = 0, max_len = 400)

        self.GRUs = clones(nn.GRU(1, self.hidden_dim, batch_first = True), self.input_dim)
        self.LastStepAttentions = clones(SingleAttention(self.hidden_dim, 8, attention_type='new', demographic_dim=12, time_aware=True, use_demographic=False),self.input_dim)
        
        self.FinalAttentionQKV = FinalAttentionQKV(self.hidden_dim, self.hidden_dim, attention_type='mul',dropout = 1 - self.keep_prob)

        self.MultiHeadedAttention = MultiHeadedAttention(self.MHD_num_head, self.d_model,dropout = 1 - self.keep_prob)
        self.SublayerConnection = SublayerConnection(self.d_model, dropout = 1 - self.keep_prob)

        self.PositionwiseFeedForward = PositionwiseFeedForward(self.d_model, self.d_ff, dropout=0.1)

        self.demo_proj_main = nn.Linear(12, self.hidden_dim)
        self.demo_proj = nn.Linear(12, self.hidden_dim)
        self.output0 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.output1 = nn.Linear(self.hidden_dim, self.output_dim)

        self.dropout = nn.Dropout(p = 1 - self.keep_prob)
        self.tanh=nn.Tanh()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
        self.relu=nn.ReLU()

    def forward(self, input, demo_input):
        # input shape [batch_size, timestep, feature_dim]
        demo_main = self.tanh(self.demo_proj_main(demo_input)).unsqueeze(1)# b hidden_dim
        
        batch_size = input.size(0)
        time_step = input.size(1)
        feature_dim = input.size(2)
        assert(feature_dim == self.input_dim)# input Tensor : 256 * 48 * 76
        assert(self.d_model % self.MHD_num_head == 0)

        # Initialization
        #cur_hs = Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0))

        # forward
        GRU_embeded_input = self.GRUs[0](input[:,:,0].unsqueeze(-1), Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0)).to(device))[0] # b t h
        Attention_embeded_input = self.LastStepAttentions[0](GRU_embeded_input)[0].unsqueeze(1)# b 1 h
        for i in range(feature_dim-1):
            embeded_input = self.GRUs[i+1](input[:,:,i+1].unsqueeze(-1), Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0)).to(device))[0] # b 1 h
            embeded_input = self.LastStepAttentions[i+1](embeded_input)[0].unsqueeze(1)# b 1 h
            Attention_embeded_input = torch.cat((Attention_embeded_input, embeded_input), 1)# b i h

        Attention_embeded_input = torch.cat((Attention_embeded_input, demo_main), 1)# b i+1 h
        posi_input = self.dropout(Attention_embeded_input) # batch_size * d_input+1 * hidden_dim

#         GRU_embeded_input = self.GRUs[0](input[:,:,0].unsqueeze(-1), Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0)).to(device))[0][:,-1,:].unsqueeze(1) # b 1 h
#         for i in range(feature_dim-1):
#             embeded_input = self.GRUs[i+1](input[:,:,i+1].unsqueeze(-1), Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0)).to(device))[0][:,-1,:].unsqueeze(1) # b 1 h
#             GRU_embeded_input = torch.cat((GRU_embeded_input, embeded_input), 1)

#         GRU_embeded_input = torch.cat((GRU_embeded_input, demo_main), 1)# b i+1 h
#         posi_input = self.dropout(GRU_embeded_input) # batch_size * d_input * hidden_dim


        #mask = subsequent_mask(time_step).to(device) # 1 t t 下三角 N to 1任务不用mask
        contexts = self.SublayerConnection(posi_input, lambda x: self.MultiHeadedAttention(posi_input, posi_input, posi_input, None))# # batch_size * d_input * hidden_dim
    
        DeCov_loss = contexts[1]
        contexts = contexts[0]

        contexts = self.SublayerConnection(contexts, lambda x: self.PositionwiseFeedForward(contexts))[0]# # batch_size * d_input * hidden_dim
        #contexts = contexts.view(batch_size, feature_dim * self.hidden_dim)#
        # contexts = torch.matmul(self.Wproj, contexts) + self.bproj
        # contexts = contexts.squeeze()
        # demo_key = self.demo_proj(demo_input)# b hidden_dim
        # demo_key = self.relu(demo_key)
        # input_dim_scores = torch.matmul(contexts, demo_key.unsqueeze(-1)).squeeze() # b i
        # input_dim_scores = self.dropout(self.sigmoid(input_dim_scores)).unsqueeze(1)# b i
        
        # weighted_contexts = torch.matmul(input_dim_scores, contexts).squeeze()

        weighted_contexts = self.FinalAttentionQKV(contexts)[0]
        output = self.output1(self.relu(self.output0(weighted_contexts)))# b 1
        output = self.sigmoid(output)
          
        return output, DeCov_loss
    #, self.MultiHeadedAttention.attn



In [21]:
def get_loss(y_pred, y_true):
    loss = torch.nn.BCELoss()
    return loss(y_pred, y_true)

In [22]:
class Dataset(data.Dataset):
    def __init__(self, x, y, name):
        self.x = x
        self.y = y
        self.name = name

    def __getitem__(self, index):#返回的是tensor
        return self.x[index], self.y[index], self.name[index]

    def __len__(self):
        return len(self.x)

In [23]:
train_dataset = Dataset(train_raw['data'][0], train_raw['data'][1], train_raw['names'])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataset = Dataset(val_raw['data'][0], val_raw['data'][1], val_raw['names'])
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

### Run

In [24]:
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED) #numpy
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED) # cpu
torch.cuda.manual_seed(RANDOM_SEED) #gpu
torch.backends.cudnn.deterministic=True # cudnn

model = ConCare(input_dim = 76, hidden_dim = 64, d_model = 64,  MHD_num_head = 4 , d_ff = 256, output_dim = 1).to(device)
# input_dim, d_model, d_k, d_v, MHD_num_head, d_ff, output_dim
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

max_roc = 0
max_prc = 0
train_loss = []
train_model_loss = []
train_decov_loss = []
valid_loss = []
valid_model_loss = []
valid_decov_loss = []
history = []
np.set_printoptions(threshold=np.inf)
np.set_printoptions(precision=2)
np.set_printoptions(suppress=True)

for each_epoch in range(epochs):
    batch_loss = []
    model_batch_loss = []
    decov_batch_loss = []

    model.train()
 
    for step, (batch_x, batch_y, batch_name) in enumerate(train_loader):   
        optimizer.zero_grad()
        batch_x = batch_x.float().to(device)
        batch_y = batch_y.float().to(device)

        batch_demo = []
        print('len(batch_name) : ', len(batch_name))
        print('idx_list : ',len(idx_list))
        for i in range(len(batch_name)):
            cur_id, cur_ep, _ = batch_name[i].split('_', 2)
            cur_idx = cur_id + '_' + cur_ep
            if cur_idx in idx_list:
                cur_demo = torch.tensor(demographic_data[idx_list.index(cur_idx)], dtype=torch.float32)
                batch_demo.append(cur_demo)
        
        
        batch_demo = torch.stack(batch_demo).to(device)
        print('batch_demo : ',batch_demo.shape)
        print('batch_x : ',batch_x.shape)
        output, decov_loss = model(batch_x, batch_demo)
        
        
        model_loss = get_loss(output, batch_y.unsqueeze(-1))
        loss = model_loss + 800* decov_loss
        
        batch_loss.append(loss.cpu().detach().numpy())
        model_batch_loss.append(model_loss.cpu().detach().numpy())
        decov_batch_loss.append(decov_loss.cpu().detach().numpy())
        loss.backward()
        optimizer.step()
        
        if step % 30 == 0:
            print('Epoch %d Batch %d: Train Loss = %.4f'%(each_epoch, step, np.mean(np.array(batch_loss))))
            print('Model Loss = %.4f, Decov Loss = %.4f'%(np.mean(np.array(model_batch_loss)), np.mean(np.array(decov_batch_loss))))
    train_loss.append(np.mean(np.array(batch_loss)))
    train_model_loss.append(np.mean(np.array(model_batch_loss)))
    train_decov_loss.append(np.mean(np.array(decov_batch_loss)))
    
    batch_loss = []
    model_batch_loss = []
    decov_batch_loss = []
    
    y_true = []
    y_pred = []
    with torch.no_grad():
        model.eval()
        for step, (batch_x, batch_y, batch_name) in enumerate(valid_loader):
            batch_x = batch_x.float().to(device)
            batch_y = batch_y.float().to(device)
            batch_demo = []
            print('len(batch_name) : ', len(batch_name))
            print('idx_list : ',len(idx_list))
            for i in range(len(batch_name)):
                cur_id, cur_ep, _ = batch_name[i].split('_', 2)
                cur_idx = cur_id + '_' + cur_ep
                if cur_idx in idx_list:
                    cur_demo = torch.tensor(demographic_data[idx_list.index(cur_idx)], dtype=torch.float32)
                    batch_demo.append(cur_demo)
            
            
            batch_demo = torch.stack(batch_demo).to(device)
            print('batch_demo 1 : ',batch_demo.shape)
            print('batch_x : ',batch_x.shape)
            output,decov_loss = model(batch_x, batch_demo)
            
            model_loss = get_loss(output, batch_y.unsqueeze(-1))

            loss = model_loss + 10* decov_loss
            batch_loss.append(loss.cpu().detach().numpy())
            model_batch_loss.append(model_loss.cpu().detach().numpy())
            decov_batch_loss.append(decov_loss.cpu().detach().numpy())
            y_pred += list(output.cpu().detach().numpy().flatten())
            y_true += list(batch_y.cpu().numpy().flatten())
            
    valid_loss.append(np.mean(np.array(batch_loss)))
    valid_model_loss.append(np.mean(np.array(model_batch_loss)))
    valid_decov_loss.append(np.mean(np.array(decov_batch_loss)))
    
    print("\n==>Predicting on validation")
    print('Valid Loss = %.4f'%(valid_loss[-1]))
    print('valid_model Loss = %.4f'%(valid_model_loss[-1]))
    print('valid_decov Loss = %.4f'%(valid_decov_loss[-1]))
    y_pred = np.array(y_pred)
    y_pred = np.stack([1 - y_pred, y_pred], axis=1)
    ret = metrics.print_metrics_binary(y_true, y_pred)
    history.append(ret)
    

    cur_auroc = ret['auroc']
    
    if cur_auroc > max_roc:
        max_roc = cur_auroc
        state = {
            'net': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': each_epoch
        }
        torch.save(state, file_name)
        print('\n------------ Save best model ------------\n')

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 0 Batch 0: Train Loss = 0.9476
Model Loss = 0.7103, Decov Loss = 0.0003
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_

  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])



------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 1 Batch 0: Train Loss = 0.4248
Model Loss = 0.4202, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_

  prec1 = cf[1][1] / (cf[1][1] + cf[0][1])


batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 2 Batch 0: Train Loss = 0.4028
Model Loss = 0.4001, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 3 Batch 0: Train Loss = 0.3474
Model Loss = 0.3450, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 4 Batch 0: Train Loss = 0.3112
Model Loss = 0.3094, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 5 Batch 0: Train Loss = 0.3476
Model Loss = 0.3459, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 6 Batch 0: Train Loss = 0.3374
Model Loss = 0.3356, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 7 Batch 0: Train Loss = 0.3063
Model Loss = 0.3046, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 8 Batch 0: Train Loss = 0.3684
Model Loss = 0.3664, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 9 Batch 0: Train Loss = 0.3312
Model Loss = 0.3293, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 10 Batch 0: Train Loss = 0.3090
Model Loss = 0.3073, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 11 Batch 0: Train Loss = 0.3131
Model Loss = 0.3110, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 12 Batch 0: Train Loss = 0.2859
Model Loss = 0.2845, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 13 Batch 0: Train Loss = 0.2754
Model Loss = 0.2730, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 14 Batch 0: Train Loss = 0.3632
Model Loss = 0.3607, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 15 Batch 0: Train Loss = 0.2919
Model Loss = 0.2898, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 16 Batch 0: Train Loss = 0.3356
Model Loss = 0.3327, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 17 Batch 0: Train Loss = 0.2856
Model Loss = 0.2833, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 18 Batch 0: Train Loss = 0.2635
Model Loss = 0.2606, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 19 Batch 0: Train Loss = 0.2926
Model Loss = 0.2895, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 20 Batch 0: Train Loss = 0.3238
Model Loss = 0.3195, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 21 Batch 0: Train Loss = 0.3216
Model Loss = 0.3202, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 22 Batch 0: Train Loss = 0.3417
Model Loss = 0.3397, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 23 Batch 0: Train Loss = 0.2955
Model Loss = 0.2933, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 24 Batch 0: Train Loss = 0.3280
Model Loss = 0.3265, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 25 Batch 0: Train Loss = 0.3024
Model Loss = 0.3009, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 26 Batch 0: Train Loss = 0.2645
Model Loss = 0.2626, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 27 Batch 0: Train Loss = 0.3657
Model Loss = 0.3637, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 28 Batch 0: Train Loss = 0.2748
Model Loss = 0.2730, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 29 Batch 0: Train Loss = 0.2191
Model Loss = 0.2173, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 30 Batch 0: Train Loss = 0.3202
Model Loss = 0.3187, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 31 Batch 0: Train Loss = 0.3618
Model Loss = 0.3596, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 32 Batch 0: Train Loss = 0.3103
Model Loss = 0.3081, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 33 Batch 0: Train Loss = 0.2888
Model Loss = 0.2873, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 34 Batch 0: Train Loss = 0.2477
Model Loss = 0.2453, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 35 Batch 0: Train Loss = 0.2241
Model Loss = 0.2230, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 36 Batch 0: Train Loss = 0.2573
Model Loss = 0.2562, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 37 Batch 0: Train Loss = 0.2669
Model Loss = 0.2659, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 38 Batch 0: Train Loss = 0.3069
Model Loss = 0.3060, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 39 Batch 0: Train Loss = 0.2580
Model Loss = 0.2572, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 40 Batch 0: Train Loss = 0.3738
Model Loss = 0.3727, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 41 Batch 0: Train Loss = 0.2720
Model Loss = 0.2712, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 42 Batch 0: Train Loss = 0.2905
Model Loss = 0.2898, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 43 Batch 0: Train Loss = 0.2028
Model Loss = 0.2020, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 44 Batch 0: Train Loss = 0.3246
Model Loss = 0.3237, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 45 Batch 0: Train Loss = 0.2359
Model Loss = 0.2351, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 46 Batch 0: Train Loss = 0.2608
Model Loss = 0.2595, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


------------ Save best model ------------

len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 47 Batch 0: Train Loss = 0.3035
Model Loss = 0.3026, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 48 Batch 0: Train Loss = 0.3296
Model Loss = 0.3290, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 49 Batch 0: Train Loss = 0.2246
Model Loss = 0.2234, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 50 Batch 0: Train Loss = 0.2753
Model Loss = 0.2743, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 51 Batch 0: Train Loss = 0.2478
Model Loss = 0.2470, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 52 Batch 0: Train Loss = 0.3579
Model Loss = 0.3569, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 53 Batch 0: Train Loss = 0.2314
Model Loss = 0.2306, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 54 Batch 0: Train Loss = 0.2393
Model Loss = 0.2387, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 55 Batch 0: Train Loss = 0.2930
Model Loss = 0.2924, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 56 Batch 0: Train Loss = 0.2979
Model Loss = 0.2973, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 57 Batch 0: Train Loss = 0.2419
Model Loss = 0.2413, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 58 Batch 0: Train Loss = 0.2189
Model Loss = 0.2182, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 59 Batch 0: Train Loss = 0.2712
Model Loss = 0.2705, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 60 Batch 0: Train Loss = 0.2966
Model Loss = 0.2959, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch


len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 61 Batch 0: Train Loss = 0.2496
Model Loss = 0.2487, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 62 Batch 0: Train Loss = 0.2522
Model Loss = 0.2517, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 63 Batch 0: Train Loss = 0.2380
Model Loss = 0.2371, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 64 Batch 0: Train Loss = 0.2369
Model Loss = 0.2364, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 65 Batch 0: Train Loss = 0.2256
Model Loss = 0.2242, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 66 Batch 0: Train Loss = 0.2487
Model Loss = 0.2481, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 67 Batch 0: Train Loss = 0.3036
Model Loss = 0.3029, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 68 Batch 0: Train Loss = 0.2296
Model Loss = 0.2290, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 69 Batch 0: Train Loss = 0.3375
Model Loss = 0.3367, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 70 Batch 0: Train Loss = 0.3060
Model Loss = 0.3055, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 71 Batch 0: Train Loss = 0.2374
Model Loss = 0.2366, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 72 Batch 0: Train Loss = 0.2408
Model Loss = 0.2400, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 73 Batch 0: Train Loss = 0.3008
Model Loss = 0.3001, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 74 Batch 0: Train Loss = 0.2511
Model Loss = 0.2505, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 75 Batch 0: Train Loss = 0.2419
Model Loss = 0.2411, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 76 Batch 0: Train Loss = 0.2625
Model Loss = 0.2620, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 77 Batch 0: Train Loss = 0.2247
Model Loss = 0.2241, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 78 Batch 0: Train Loss = 0.2327
Model Loss = 0.2318, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 79 Batch 0: Train Loss = 0.2486
Model Loss = 0.2479, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 80 Batch 0: Train Loss = 0.2738
Model Loss = 0.2732, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 81 Batch 0: Train Loss = 0.2797
Model Loss = 0.2789, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 82 Batch 0: Train Loss = 0.2864
Model Loss = 0.2859, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 83 Batch 0: Train Loss = 0.2666
Model Loss = 0.2660, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 84 Batch 0: Train Loss = 0.2032
Model Loss = 0.2027, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 85 Batch 0: Train Loss = 0.2430
Model Loss = 0.2424, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 86 Batch 0: Train Loss = 0.2477
Model Loss = 0.2470, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 87 Batch 0: Train Loss = 0.2415
Model Loss = 0.2410, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 88 Batch 0: Train Loss = 0.2570
Model Loss = 0.2564, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 89 Batch 0: Train Loss = 0.2467
Model Loss = 0.2463, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 90 Batch 0: Train Loss = 0.2278
Model Loss = 0.2274, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])


  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 91 Batch 0: Train Loss = 0.2185
Model Loss = 0.2180, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 92 Batch 0: Train Loss = 0.2738
Model Loss = 0.2733, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 93 Batch 0: Train Loss = 0.2658
Model Loss = 0.2651, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 94 Batch 0: Train Loss = 0.2776
Model Loss = 0.2771, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 95 Batch 0: Train Loss = 0.3084
Model Loss = 0.3078, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 96 Batch 0: Train Loss = 0.2511
Model Loss = 0.2506, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 97 Batch 0: Train Loss = 0.2694
Model Loss = 0.2689, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 98 Batch 0: Train Loss = 0.2909
Model Loss = 0.2899, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T


Epoch 99 Batch 0: Train Loss = 0.2340
Model Loss = 0.2333, Decov Loss = 0.0000
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch_name) :  256
idx_list :  42025
batch_demo :  torch.Size([256, 12])
batch_x :  torch.Size([256, 48, 76])
len(batch

### Run for test

In [25]:
checkpoint = torch.load(file_name)
save_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
model.eval()

test_reader = InHospitalMortalityReader(dataset_dir=os.path.join(data_path, 'test'),
                                            listfile=os.path.join(data_path, 'test_listfile.csv'),
                                            period_length=48.0)
test_raw = utils.load_data(test_reader, discretizer, normalizer, small_part, return_names=True)
test_dataset = Dataset(test_raw['data'][0], test_raw['data'][1], test_raw['names'])
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [26]:
batch_loss = []
y_true = []
y_pred = []
with torch.no_grad():
    model.eval()
    for step, (batch_x, batch_y, batch_name) in enumerate(test_loader):
        batch_x = batch_x.float().to(device)
        batch_y = batch_y.float().to(device)
        batch_demo = []
        for i in range(len(batch_name)):
            cur_id, cur_ep, _ = batch_name[i].split('_', 2)
            cur_idx = cur_id + '_' + cur_ep
            if cur_idx in idx_list:
                cur_demo = torch.tensor(demographic_data[idx_list.index(cur_idx)], dtype=torch.float32)
                batch_demo.append(cur_demo)

        batch_demo = torch.stack(batch_demo).to(device)
        output = model(batch_x, batch_demo)[0]

        loss = get_loss(output, batch_y.unsqueeze(-1))
        batch_loss.append(loss.cpu().detach().numpy())
        y_pred += list(output.cpu().detach().numpy().flatten())
        y_true += list(batch_y.cpu().numpy().flatten())

print("\n==>Predicting on test")
print('Test Loss = %.4f'%(np.mean(np.array(batch_loss))))
y_pred = np.array(y_pred)
y_pred = np.stack([1 - y_pred, y_pred], axis=1)
test_res = metrics.print_metrics_binary(y_true, y_pred)

  a = self.softmax(e) #B*T
  a = self.softmax(e) #B*T



==>Predicting on test
Test Loss = 0.2523
confusion matrix:
[[2803   59]
 [ 256  118]]
accuracy = 0.9026576280593872
precision class 0 = 0.9163125157356262
precision class 1 = 0.6666666865348816
recall class 0 = 0.9793850183486938
recall class 1 = 0.31550800800323486
AUC of ROC = 0.8689031454014805
AUC of PRC = 0.520446972384111
min(+P, Se) = 0.5106951871657754
f1_score = 0.42831215147454527


In [27]:
# Bootstrap
N = len(y_true)
N_idx = np.arange(N)
K = 1000

auroc = []
auprc = []
minpse = []
for i in range(K):
    boot_idx = np.random.choice(N_idx, N, replace=True)
    boot_true = np.array(y_true)[boot_idx]
    boot_pred = y_pred[boot_idx, :]
    test_ret = metrics.print_metrics_binary(boot_true, boot_pred, verbose=0)
    auroc.append(test_ret['auroc'])
    auprc.append(test_ret['auprc'])
    minpse.append(test_ret['minpse'])
    print('%d/%d'%(i+1,K))
    
print('auroc %.4f(%.4f)'%(np.mean(auroc), np.std(auroc)))
print('auprc %.4f(%.4f)'%(np.mean(auprc), np.std(auprc)))
print('minpse %.4f(%.4f)'%(np.mean(minpse), np.std(minpse)))

1/1000
2/1000
3/1000
4/1000
5/1000
6/1000
7/1000
8/1000
9/1000
10/1000
11/1000
12/1000
13/1000
14/1000
15/1000
16/1000
17/1000
18/1000
19/1000
20/1000
21/1000
22/1000
23/1000
24/1000
25/1000
26/1000
27/1000
28/1000
29/1000
30/1000
31/1000
32/1000
33/1000
34/1000
35/1000
36/1000
37/1000
38/1000
39/1000
40/1000
41/1000
42/1000
43/1000
44/1000
45/1000
46/1000
47/1000
48/1000
49/1000
50/1000
51/1000
52/1000
53/1000
54/1000
55/1000
56/1000
57/1000
58/1000
59/1000
60/1000
61/1000
62/1000
63/1000
64/1000
65/1000
66/1000
67/1000
68/1000
69/1000
70/1000
71/1000
72/1000
73/1000
74/1000
75/1000
76/1000
77/1000
78/1000
79/1000
80/1000
81/1000
82/1000
83/1000
84/1000
85/1000
86/1000
87/1000
88/1000
89/1000
90/1000
91/1000
92/1000
93/1000
94/1000
95/1000
96/1000
97/1000
98/1000
99/1000
100/1000
101/1000
102/1000
103/1000
104/1000
105/1000
106/1000
107/1000
108/1000
109/1000
110/1000
111/1000
112/1000
113/1000
114/1000
115/1000
116/1000
117/1000
118/1000
119/1000
120/1000
121/1000
122/1000
123/1000
1

940/1000
941/1000
942/1000
943/1000
944/1000
945/1000
946/1000
947/1000
948/1000
949/1000
950/1000
951/1000
952/1000
953/1000
954/1000
955/1000
956/1000
957/1000
958/1000
959/1000
960/1000
961/1000
962/1000
963/1000
964/1000
965/1000
966/1000
967/1000
968/1000
969/1000
970/1000
971/1000
972/1000
973/1000
974/1000
975/1000
976/1000
977/1000
978/1000
979/1000
980/1000
981/1000
982/1000
983/1000
984/1000
985/1000
986/1000
987/1000
988/1000
989/1000
990/1000
991/1000
992/1000
993/1000
994/1000
995/1000
996/1000
997/1000
998/1000
999/1000
1000/1000
auroc 0.8694(0.0088)
auprc 0.5226(0.0271)
minpse 0.5124(0.0221)
