In [1]:
import torch
from torch import optim
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
import torchvision
import pandas as pd
import copy
import random
import pickle
import torch.nn.functional as F
from torch.nn import Linear

In [2]:
#load the processed split data
def load_df(name):
    return pd.read_pickle('data/'+name+'.pkl')

[x_train, y_train, x_test, y_test, x_val, y_val] = map(load_df, ['x_train','y_train','x_test','y_test', 'x_val','y_val']) 

In [3]:
def typing(df):
    sequences=df.astype(np.float32).to_numpy().tolist()
    dataset = torch.stack([torch.tensor(s).unsqueeze(1).float() for s in sequences])
    return dataset

x_train, y_train, x_test, y_test, x_val, y_val = map(typing, [x_train, y_train.drop('sample_collection_site',axis=1), x_test, y_test.drop('sample_collection_site',axis=1), x_val, y_val.drop('sample_collection_site',axis=1)])

In [12]:
x_train[:,[1,4]].shape

torch.Size([1052, 2, 1])

In [4]:
from torch.utils.data import DataLoader,TensorDataset
from torch import Tensor

#y_train_ohe = y_train.drop('sample_collection_site', axis=1)
dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=2, pin_memory=True)

#y_val_ohe = y_val.drop('sample_collection_site', axis=1)
dataset_val = TensorDataset(x_val, y_val)
val_loader = DataLoader(dataset_val, batch_size=8, shuffle=False, num_workers=2, pin_memory=True)


In [21]:
from sklearn.preprocessing import StandardScaler
sc = StandardScaler()
x_train = sc.fit_transform(x_train).astype(np.float64)
dataset = TensorDataset(Tensor(x_train), Tensor(y_train_ohe.astype(np.float32).to_numpy()))
train_loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2, pin_memory=True)


In [17]:
next(iter(train_loader)).shape

torch.Size([4, 19221])

In [3]:
from torch.utils.data import DataLoader,TensorDataset
from torch import Tensor

y_train_ohe = y_train.drop('sample_collection_site', axis=1)
dataset = TensorDataset(Tensor(x_train.astype(np.float64).to_numpy()), Tensor(y_train_ohe.astype(np.float64).to_numpy()))
train_loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

y_val_ohe = y_val.drop('sample_collection_site', axis=1)
dataset_val = TensorDataset(Tensor(x_val.astype(np.float64).to_numpy()), Tensor(y_val_ohe.astype(np.float64).to_numpy()))
val_loader = DataLoader(dataset_val, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)


In [4]:
class EnCl(nn.Module):
    def __init__(self, input_size=19221, drop_rate=0.2, num_class=39):
        super(EnCl, self).__init__()
        self.encode = nn.Sequential(
            nn.Linear(input_size, input_size//2),
            nn.ReLU(inplace=True),
            nn.Dropout(drop_rate),
            nn.Linear(input_size//2, input_size//2//4),
            nn.ReLU(inplace=True),
            nn.Dropout(drop_rate),
            nn.Linear(input_size//2//4, input_size//2//4//4),
            nn.ReLU(inplace=True),
            #nn.relu(drop_rate),
            nn.Linear(input_size//2//4//4, input_size//2//4//4//8),
            nn.ReLU(inplace=True),
            #nn.relu(drop_rate),
            nn.Linear(input_size//2//4//4//8, num_class)
        )
        
        self.classify = nn.Sequential(
            nn.BatchNorm1d(num_class),
            nn.ReLU(inplace=True),
            nn.Linear(num_class, num_class),
            nn.ReLU(inplace=True),
            nn.Linear(num_class, num_class),
            nn.ReLU(inplace=True),
            nn.Linear(num_class, num_class),
            #nn.ReLU(True),
            nn.Softmax(dim=1)
        )
        
    def forward(self, x):
        code = self.encode(x)
        x = self.classify(code)
        return x

In [4]:
class Encoder(nn.Module):
    def __init__(self, input_size=18252, num_class=39, drop_rate=0.2):
        super(Encoder, self).__init__()
        self.drop_rate = drop_rate
        self.enc1 = nn.Linear(input_size, input_size//2)
        self.enc2 = nn.Linear(input_size//2, input_size//2//4)
        self.enc3 = nn.Linear(input_size//2//4, input_size//2//4//4)
        self.enc4 = nn.Linear(input_size//2//4//4, input_size//2//4//4//8)
        self.enc5 = nn.Linear(input_size//2//4//4//8, num_class)
    def forward(self, features):
        x = torch.relu(torch.dropout(self.enc1(features), p=self.drop_rate, train=True))
        x = torch.relu(torch.dropout(self.enc2(x), p=self.drop_rate, train=True))
        x = torch.relu(torch.dropout(self.enc3(x), p=self.drop_rate, train=True))
        x = torch.relu(torch.dropout(self.enc4(x), p=self.drop_rate, train=True))
        #x = torch.dropout(torch.relu(self.enc5(x)), p=drop_rate)
        code = self.enc5(x)
        return code

class MC_LR(nn.Module):
    def __init__(self, input_size=18252, num_class=39, drop_rate=0.2):
        super(MC_LR, self).__init__()
        self.fc1 = nn.Linear(num_class, num_class)
        self.fc2 = nn.Linear(num_class, num_class)
        self.fc3 = nn.Linear(num_class, num_class)
    def forward(self, code):
        x = torch.relu(self.fc1(code))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        cl = torch.softmax(x,dim=1)
        return cl

class Enc_MC_LR(nn.Module):
    def __init__(self):
        super(Enc_MC_LR, self).__init__()
        self.encoder = Encoder()
        self.regressor = MC_LR()
    def forward(self, features):
        latent = self.encoder(features)
        return self.regressor(latent)

In [11]:
19221//16

1201

In [4]:
class EnClS(nn.Module):
    def __init__(self, input_size=18252, drop_rate=0, num_class=39):
        super(EnClS, self).__init__()
        self.encode = nn.Sequential(
            nn.Linear(input_size, input_size//2),
            nn.ReLU(inplace=True),
            nn.Dropout(drop_rate),
            nn.Linear(input_size//2, input_size//16),
            nn.ReLU(inplace=True),
            nn.Dropout(drop_rate),
            nn.Linear(input_size//16, num_class)
        )
        
        self.classify = nn.Sequential(
            #nn.BatchNorm1d(num_class),
            nn.ReLU(inplace=True),
            nn.Linear(num_class, num_class),
            #nn.ReLU(True),
            nn.Softmax(dim=1)
        )
        
    def forward(self, x):
        code = self.encode(x)
        x = self.classify(code)
        return x

In [18]:
from torchsummary import summary
print(summary(EnClS().to('cuda'),(1054,18252)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1           [-1, 1054, 9126]     166,576,878
              ReLU-2           [-1, 1054, 9126]               0
           Dropout-3           [-1, 1054, 9126]               0
            Linear-4           [-1, 1054, 1140]      10,404,780
              ReLU-5           [-1, 1054, 1140]               0
           Dropout-6           [-1, 1054, 1140]               0
            Linear-7             [-1, 1054, 39]          44,499
              ReLU-8             [-1, 1054, 39]               0
            Linear-9             [-1, 1054, 39]           1,560
          Softmax-10             [-1, 1054, 39]               0
Total params: 177,027,717
Trainable params: 177,027,717
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 73.39
Forward/backward pass size (MB): 248.91
Params size (MB): 675.3

In [5]:
def train_valid(model, device, train_loader, valid_loader, l1_lambda=0.01, class_num=39, input_len=18252, n_epochs=100):
  #optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.05)
  optimizer = optim.Adam(model.parameters(), lr=1e-5)
  criterion = nn.CrossEntropyLoss()
  losses = []
  min_val_loss = np.inf
  for epoch in tqdm(range(n_epochs)):

      model.train()
      loss = 0
      this = 0
      for batched, label in train_loader:
          this += 1
          #reshape
          #batch_features = batch_features.view(-1, input_len).to(device)
          # reset the gradients back to zero
          # PyTorch accumulates gradients on subsequent backward passes
          optimizer.zero_grad()
          # compute reconstructions
          #codes = model.encode(batch_features)#[1]
          #outputs = model.classify(codes)
          #sm = nn.Softmax(dim=1)
          #outputs = sm(codes)
          outputs = model(batched.to(device))
          # print(batch_features)
          # print(batch_features.shape)
          # print(codes)
          # print(codes.shape)
          #print(outputs)
          # print(outputs.shape)
          # print(label)
          # print(label.shape)
          #print(s)
          #compute training reconstruction loss
          # print(this)
          # print(outputs)
          # print(outputs.shape)
          #label = label.view(-1, class_num).to(device)
          train_loss = criterion(outputs, label.to(device))
          # compute accumulated gradients
          train_loss.backward()
          # perform parameter update based on current gradients
          optimizer.step()
          # add the mini-batch training loss to epoch loss
          loss += (train_loss.item()) #+ l1_lambda*sum(p.abs().sum() for p in model.parameters()))

      model.eval()
      val_loss = 0
      #with torch.no_grad():
      for feats, lv in valid_loader:
          feats = feats.to(device)
          target = model(feats)
          l = criterion(target, lv.to(device))
          val_loss += (l.item())# + l1_lambda*sum(p.abs().sum() for p in model.parameters()))

      # compute the epoch training loss
      loss = loss / len(train_loader)
      val_loss = val_loss / len(val_loader)
      losses.append((loss, val_loss))
      # display the epoch training loss
      print("epoch : {}/{}, loss = {:.6f}, val_loss = {:.6f}".format(epoch + 1, n_epochs, loss, val_loss))
      
      if min_val_loss > val_loss:
        print('val loss decreased, saved model')
        min_val_loss = val_loss
        torch.save(model.state_dict(), 'results/dl_encmclr_dopout0p2_noznorm.pth' )

  with open('results/loss_dl_encmclr_dropout0p2__noznorm.pkl','wb') as f:
    pickle.dump(losses,f)

  return model, model.eval()

In [6]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')
model = Enc_MC_LR().to(device) #double?

In [7]:
trained,ev = train_valid(model, device, train_loader, val_loader, class_num=39, input_len=18252, n_epochs=300)

  0%|          | 0/300 [00:00<?, ?it/s]

epoch : 1/300, loss = 3.663340, val_loss = 3.664159
val loss decreased, saved model


  0%|          | 1/300 [00:09<48:53,  9.81s/it]

epoch : 2/300, loss = 3.662934, val_loss = 3.664142
val loss decreased, saved model


  1%|          | 3/300 [00:29<45:32,  9.20s/it]  

epoch : 3/300, loss = 3.662266, val_loss = 3.664487
epoch : 4/300, loss = 3.660522, val_loss = 3.664075
val loss decreased, saved model


  2%|▏         | 5/300 [00:52<48:48,  9.93s/it]  

epoch : 5/300, loss = 3.657474, val_loss = 3.664756


  2%|▏         | 6/300 [00:56<39:12,  8.00s/it]

epoch : 6/300, loss = 3.648962, val_loss = 3.664782


  2%|▏         | 7/300 [01:04<38:30,  7.88s/it]

epoch : 7/300, loss = 3.632536, val_loss = 3.668773


  3%|▎         | 8/300 [01:09<34:01,  6.99s/it]

epoch : 8/300, loss = 3.617600, val_loss = 3.674835


  3%|▎         | 9/300 [01:14<31:04,  6.41s/it]

epoch : 9/300, loss = 3.613538, val_loss = 3.674349


  3%|▎         | 10/300 [01:20<29:54,  6.19s/it]

epoch : 10/300, loss = 3.613954, val_loss = 3.672470


  4%|▎         | 11/300 [01:30<35:41,  7.41s/it]

epoch : 11/300, loss = 3.613014, val_loss = 3.672718


  4%|▍         | 12/300 [01:38<35:33,  7.41s/it]

epoch : 12/300, loss = 3.613680, val_loss = 3.674172


  4%|▍         | 13/300 [01:42<31:50,  6.66s/it]

epoch : 13/300, loss = 3.613220, val_loss = 3.674604


  5%|▍         | 14/300 [01:48<30:00,  6.29s/it]

epoch : 14/300, loss = 3.612589, val_loss = 3.674942


  5%|▌         | 15/300 [01:56<32:21,  6.81s/it]

epoch : 15/300, loss = 3.612132, val_loss = 3.675138


  5%|▌         | 16/300 [02:01<29:04,  6.14s/it]

epoch : 16/300, loss = 3.614459, val_loss = 3.676552


  6%|▌         | 17/300 [02:05<26:02,  5.52s/it]

epoch : 17/300, loss = 3.612518, val_loss = 3.675257


  6%|▌         | 18/300 [02:09<23:48,  5.06s/it]

epoch : 18/300, loss = 3.613626, val_loss = 3.674441


  6%|▋         | 19/300 [02:13<22:28,  4.80s/it]

epoch : 19/300, loss = 3.611361, val_loss = 3.674230


  7%|▋         | 20/300 [02:20<25:35,  5.48s/it]

epoch : 20/300, loss = 3.611738, val_loss = 3.675070


  7%|▋         | 21/300 [02:24<23:16,  5.01s/it]

epoch : 21/300, loss = 3.613182, val_loss = 3.675248


  7%|▋         | 22/300 [02:34<30:33,  6.59s/it]

epoch : 22/300, loss = 3.612290, val_loss = 3.674554


  8%|▊         | 23/300 [02:40<29:59,  6.50s/it]

epoch : 23/300, loss = 3.612805, val_loss = 3.675274


  8%|▊         | 24/300 [02:46<29:12,  6.35s/it]

epoch : 24/300, loss = 3.613947, val_loss = 3.675585


  8%|▊         | 25/300 [02:51<26:42,  5.83s/it]

epoch : 25/300, loss = 3.612324, val_loss = 3.674893


  9%|▊         | 26/300 [02:58<28:49,  6.31s/it]

epoch : 26/300, loss = 3.612230, val_loss = 3.677862


  9%|▉         | 27/300 [03:04<27:06,  5.96s/it]

epoch : 27/300, loss = 3.613107, val_loss = 3.675594


  9%|▉         | 28/300 [03:09<27:01,  5.96s/it]

epoch : 28/300, loss = 3.612027, val_loss = 3.675001


 10%|▉         | 29/300 [03:18<30:08,  6.67s/it]

epoch : 29/300, loss = 3.612923, val_loss = 3.675020


 10%|█         | 30/300 [03:25<30:44,  6.83s/it]

epoch : 30/300, loss = 3.612986, val_loss = 3.675058


 10%|█         | 31/300 [03:31<29:47,  6.64s/it]

epoch : 31/300, loss = 3.612652, val_loss = 3.674650


 11%|█         | 32/300 [03:38<29:50,  6.68s/it]

epoch : 32/300, loss = 3.613021, val_loss = 3.675575


 11%|█         | 33/300 [03:42<26:47,  6.02s/it]

epoch : 33/300, loss = 3.612029, val_loss = 3.675281


 11%|█▏        | 34/300 [03:47<24:19,  5.49s/it]

epoch : 34/300, loss = 3.612456, val_loss = 3.674711


 12%|█▏        | 35/300 [03:52<24:21,  5.51s/it]

epoch : 35/300, loss = 3.612178, val_loss = 3.675016


 12%|█▏        | 36/300 [03:56<22:08,  5.03s/it]

epoch : 36/300, loss = 3.612495, val_loss = 3.675118


 12%|█▏        | 37/300 [04:02<23:28,  5.35s/it]

epoch : 37/300, loss = 3.612789, val_loss = 3.675438


 13%|█▎        | 38/300 [04:10<26:26,  6.06s/it]

epoch : 38/300, loss = 3.612701, val_loss = 3.674827


 13%|█▎        | 39/300 [04:14<24:11,  5.56s/it]

epoch : 39/300, loss = 3.612708, val_loss = 3.674430


 13%|█▎        | 40/300 [04:22<26:17,  6.07s/it]

epoch : 40/300, loss = 3.612579, val_loss = 3.675145


 14%|█▎        | 41/300 [04:31<30:13,  7.00s/it]

epoch : 41/300, loss = 3.612658, val_loss = 3.675233


 14%|█▍        | 42/300 [04:38<30:02,  6.99s/it]

epoch : 42/300, loss = 3.612512, val_loss = 3.675628


 14%|█▍        | 43/300 [04:44<29:18,  6.84s/it]

epoch : 43/300, loss = 3.612275, val_loss = 3.674922


 15%|█▍        | 44/300 [04:48<25:23,  5.95s/it]

epoch : 44/300, loss = 3.612388, val_loss = 3.675340


 15%|█▌        | 45/300 [04:54<24:32,  5.78s/it]

epoch : 45/300, loss = 3.612578, val_loss = 3.674972


 15%|█▌        | 46/300 [05:01<26:24,  6.24s/it]

epoch : 46/300, loss = 3.612804, val_loss = 3.675581


 16%|█▌        | 47/300 [05:07<26:39,  6.32s/it]

epoch : 47/300, loss = 3.613139, val_loss = 3.675332


 16%|█▌        | 48/300 [05:15<28:41,  6.83s/it]

epoch : 48/300, loss = 3.611903, val_loss = 3.675322


 16%|█▋        | 49/300 [05:20<25:44,  6.15s/it]

epoch : 49/300, loss = 3.612546, val_loss = 3.675351


 17%|█▋        | 50/300 [05:30<30:00,  7.20s/it]

epoch : 50/300, loss = 3.612319, val_loss = 3.675442


 17%|█▋        | 51/300 [05:35<27:37,  6.66s/it]

epoch : 51/300, loss = 3.612206, val_loss = 3.675277


 17%|█▋        | 52/300 [05:43<29:26,  7.12s/it]

epoch : 52/300, loss = 3.612623, val_loss = 3.675040


 18%|█▊        | 53/300 [05:49<27:55,  6.78s/it]

epoch : 53/300, loss = 3.612815, val_loss = 3.675306


 18%|█▊        | 54/300 [05:59<31:46,  7.75s/it]

epoch : 54/300, loss = 3.612293, val_loss = 3.675284


 18%|█▊        | 55/300 [06:04<28:30,  6.98s/it]

epoch : 55/300, loss = 3.613753, val_loss = 3.675339


 19%|█▊        | 56/300 [06:09<24:53,  6.12s/it]

epoch : 56/300, loss = 3.612649, val_loss = 3.675303


 19%|█▉        | 57/300 [06:16<26:43,  6.60s/it]

epoch : 57/300, loss = 3.612133, val_loss = 3.675187


 19%|█▉        | 58/300 [06:22<25:07,  6.23s/it]

epoch : 58/300, loss = 3.612269, val_loss = 3.675381


 20%|█▉        | 59/300 [06:26<22:25,  5.58s/it]

epoch : 59/300, loss = 3.612463, val_loss = 3.675866


 20%|██        | 60/300 [06:30<20:20,  5.08s/it]

epoch : 60/300, loss = 3.612746, val_loss = 3.675358


 20%|██        | 61/300 [06:35<20:16,  5.09s/it]

epoch : 61/300, loss = 3.612588, val_loss = 3.675288


 21%|██        | 62/300 [06:45<26:28,  6.68s/it]

epoch : 62/300, loss = 3.612026, val_loss = 3.675312


 21%|██        | 63/300 [06:54<28:53,  7.32s/it]

epoch : 63/300, loss = 3.612770, val_loss = 3.675188


 21%|██▏       | 64/300 [07:00<27:02,  6.88s/it]

epoch : 64/300, loss = 3.612772, val_loss = 3.675257


 22%|██▏       | 65/300 [07:07<26:55,  6.88s/it]

epoch : 65/300, loss = 3.612356, val_loss = 3.675446


 22%|██▏       | 66/300 [07:14<27:01,  6.93s/it]

epoch : 66/300, loss = 3.612402, val_loss = 3.675242


 22%|██▏       | 67/300 [07:18<24:22,  6.28s/it]

epoch : 67/300, loss = 3.612661, val_loss = 3.675430


 23%|██▎       | 68/300 [07:26<26:08,  6.76s/it]

epoch : 68/300, loss = 3.612469, val_loss = 3.675402


 23%|██▎       | 69/300 [07:31<23:58,  6.23s/it]

epoch : 69/300, loss = 3.612377, val_loss = 3.675369


 23%|██▎       | 70/300 [07:36<21:38,  5.64s/it]

epoch : 70/300, loss = 3.613453, val_loss = 3.675394


 24%|██▎       | 71/300 [07:46<26:47,  7.02s/it]

epoch : 71/300, loss = 3.612440, val_loss = 3.675352


 24%|██▍       | 72/300 [07:50<23:09,  6.09s/it]

epoch : 72/300, loss = 3.612397, val_loss = 3.675323


 24%|██▍       | 73/300 [07:55<21:50,  5.77s/it]

epoch : 73/300, loss = 3.612277, val_loss = 3.675415


 25%|██▍       | 74/300 [08:00<21:39,  5.75s/it]

epoch : 74/300, loss = 3.612481, val_loss = 3.675203


 25%|██▌       | 75/300 [08:06<21:14,  5.66s/it]

epoch : 75/300, loss = 3.612691, val_loss = 3.675328


 25%|██▌       | 76/300 [08:10<19:35,  5.25s/it]

epoch : 76/300, loss = 3.612542, val_loss = 3.675360


 26%|██▌       | 77/300 [08:14<18:06,  4.87s/it]

epoch : 77/300, loss = 3.612445, val_loss = 3.675435


 26%|██▌       | 78/300 [08:20<19:20,  5.23s/it]

epoch : 78/300, loss = 3.612464, val_loss = 3.675207


 26%|██▋       | 79/300 [08:27<20:53,  5.67s/it]

epoch : 79/300, loss = 3.612438, val_loss = 3.675178


 27%|██▋       | 80/300 [08:35<22:56,  6.26s/it]

epoch : 80/300, loss = 3.612474, val_loss = 3.675350


 27%|██▋       | 81/300 [08:41<22:51,  6.26s/it]

epoch : 81/300, loss = 3.612363, val_loss = 3.675403


 27%|██▋       | 82/300 [08:45<20:33,  5.66s/it]

epoch : 82/300, loss = 3.612482, val_loss = 3.675417


 28%|██▊       | 83/300 [08:49<18:42,  5.17s/it]

epoch : 83/300, loss = 3.612402, val_loss = 3.675390


 28%|██▊       | 84/300 [08:53<17:38,  4.90s/it]

epoch : 84/300, loss = 3.612555, val_loss = 3.675410


 28%|██▊       | 85/300 [09:03<23:03,  6.43s/it]

epoch : 85/300, loss = 3.612510, val_loss = 3.675408


 29%|██▊       | 86/300 [09:07<20:25,  5.73s/it]

epoch : 86/300, loss = 3.612435, val_loss = 3.675429


 29%|██▉       | 87/300 [09:14<20:39,  5.82s/it]

epoch : 87/300, loss = 3.612595, val_loss = 3.675391


 29%|██▉       | 88/300 [09:23<24:29,  6.93s/it]

epoch : 88/300, loss = 3.612426, val_loss = 3.675372


 30%|██▉       | 89/300 [09:28<22:08,  6.30s/it]

epoch : 89/300, loss = 3.612383, val_loss = 3.675364


 30%|███       | 90/300 [09:38<25:35,  7.31s/it]

epoch : 90/300, loss = 3.612400, val_loss = 3.675251


 30%|███       | 91/300 [09:41<21:48,  6.26s/it]

epoch : 91/300, loss = 3.612307, val_loss = 3.675297


 31%|███       | 92/300 [09:47<21:21,  6.16s/it]

epoch : 92/300, loss = 3.612478, val_loss = 3.675407


 31%|███       | 93/300 [09:55<22:33,  6.54s/it]

epoch : 93/300, loss = 3.612437, val_loss = 3.675396


 31%|███▏      | 94/300 [10:01<21:55,  6.39s/it]

epoch : 94/300, loss = 3.612507, val_loss = 3.675407


 32%|███▏      | 95/300 [10:07<21:22,  6.26s/it]

epoch : 95/300, loss = 3.612963, val_loss = 3.675384


 32%|███▏      | 96/300 [10:16<24:41,  7.26s/it]

epoch : 96/300, loss = 3.612415, val_loss = 3.675256


 32%|███▏      | 97/300 [10:22<22:42,  6.71s/it]

epoch : 97/300, loss = 3.612468, val_loss = 3.675213


 33%|███▎      | 98/300 [10:26<19:49,  5.89s/it]

epoch : 98/300, loss = 3.612372, val_loss = 3.675462


 33%|███▎      | 99/300 [10:37<24:53,  7.43s/it]

epoch : 99/300, loss = 3.612416, val_loss = 3.675324


 33%|███▎      | 100/300 [10:41<21:13,  6.37s/it]

epoch : 100/300, loss = 3.612319, val_loss = 3.674557


 34%|███▎      | 101/300 [10:47<20:42,  6.24s/it]

epoch : 101/300, loss = 3.612437, val_loss = 3.675377


 34%|███▍      | 102/300 [10:53<21:04,  6.39s/it]

epoch : 102/300, loss = 3.612580, val_loss = 3.675375


 34%|███▍      | 103/300 [11:02<23:33,  7.18s/it]

epoch : 103/300, loss = 3.612481, val_loss = 3.675336


 35%|███▍      | 104/300 [11:10<23:55,  7.32s/it]

epoch : 104/300, loss = 3.612447, val_loss = 3.675468


 35%|███▌      | 105/300 [11:14<21:01,  6.47s/it]

epoch : 105/300, loss = 3.612382, val_loss = 3.675412


 35%|███▌      | 106/300 [11:19<18:51,  5.83s/it]

epoch : 106/300, loss = 3.612480, val_loss = 3.675382


 36%|███▌      | 107/300 [11:27<20:45,  6.45s/it]

epoch : 107/300, loss = 3.612564, val_loss = 3.675416


 36%|███▌      | 108/300 [11:32<19:17,  6.03s/it]

epoch : 108/300, loss = 3.612381, val_loss = 3.675345


 36%|███▋      | 109/300 [11:38<19:04,  5.99s/it]

epoch : 109/300, loss = 3.612414, val_loss = 3.674931


 37%|███▋      | 110/300 [11:46<20:48,  6.57s/it]

epoch : 110/300, loss = 3.612406, val_loss = 3.675170


 37%|███▋      | 111/300 [11:50<18:38,  5.92s/it]

epoch : 111/300, loss = 3.612415, val_loss = 3.674251


 37%|███▋      | 112/300 [11:54<17:02,  5.44s/it]

epoch : 112/300, loss = 3.612420, val_loss = 3.675418


 38%|███▊      | 113/300 [12:06<23:17,  7.47s/it]

epoch : 113/300, loss = 3.612456, val_loss = 3.675435


 38%|███▊      | 114/300 [12:11<19:59,  6.45s/it]

epoch : 114/300, loss = 3.612561, val_loss = 3.675397


 38%|███▊      | 115/300 [12:20<23:06,  7.49s/it]

epoch : 115/300, loss = 3.612487, val_loss = 3.675032


 39%|███▊      | 116/300 [12:25<20:06,  6.56s/it]

epoch : 116/300, loss = 3.612430, val_loss = 3.675191


 39%|███▉      | 117/300 [12:34<22:21,  7.33s/it]

epoch : 117/300, loss = 3.612391, val_loss = 3.675407


 39%|███▉      | 118/300 [12:40<21:09,  6.98s/it]

epoch : 118/300, loss = 3.612410, val_loss = 3.675408


 40%|███▉      | 119/300 [12:47<20:49,  6.90s/it]

epoch : 119/300, loss = 3.612435, val_loss = 3.675317


 40%|████      | 120/300 [12:53<20:05,  6.69s/it]

epoch : 120/300, loss = 3.612382, val_loss = 3.675457


 40%|████      | 121/300 [12:58<18:45,  6.29s/it]

epoch : 121/300, loss = 3.612474, val_loss = 3.675374


 41%|████      | 122/300 [13:06<20:09,  6.80s/it]

epoch : 122/300, loss = 3.612401, val_loss = 3.675482


 41%|████      | 123/300 [13:11<18:06,  6.14s/it]

epoch : 123/300, loss = 3.612360, val_loss = 3.675415


 41%|████▏     | 124/300 [13:15<16:11,  5.52s/it]

epoch : 124/300, loss = 3.612444, val_loss = 3.675472


 42%|████▏     | 125/300 [13:21<16:49,  5.77s/it]

epoch : 125/300, loss = 3.612384, val_loss = 3.675220


 42%|████▏     | 126/300 [13:31<19:41,  6.79s/it]

epoch : 126/300, loss = 3.612419, val_loss = 3.675378


 42%|████▏     | 127/300 [13:36<18:42,  6.49s/it]

epoch : 127/300, loss = 3.612414, val_loss = 3.675359


 43%|████▎     | 128/300 [13:45<20:06,  7.02s/it]

epoch : 128/300, loss = 3.612426, val_loss = 3.675427


 43%|████▎     | 129/300 [13:51<19:48,  6.95s/it]

epoch : 129/300, loss = 3.612367, val_loss = 3.675271


 43%|████▎     | 130/300 [13:58<19:03,  6.72s/it]

epoch : 130/300, loss = 3.612357, val_loss = 3.675378


 44%|████▎     | 131/300 [14:07<21:17,  7.56s/it]

epoch : 131/300, loss = 3.612637, val_loss = 3.675363


 44%|████▍     | 132/300 [14:13<19:53,  7.11s/it]

epoch : 132/300, loss = 3.612461, val_loss = 3.675385


 44%|████▍     | 133/300 [14:21<20:04,  7.21s/it]

epoch : 133/300, loss = 3.612438, val_loss = 3.675448


 45%|████▍     | 134/300 [14:26<18:15,  6.60s/it]

epoch : 134/300, loss = 3.612440, val_loss = 3.675380


 45%|████▌     | 135/300 [14:36<21:27,  7.80s/it]

epoch : 135/300, loss = 3.612189, val_loss = 3.675325


 45%|████▌     | 136/300 [14:44<21:07,  7.73s/it]

epoch : 136/300, loss = 3.612380, val_loss = 3.675394


 46%|████▌     | 137/300 [14:49<19:00,  6.99s/it]

epoch : 137/300, loss = 3.612318, val_loss = 3.675407


 46%|████▌     | 138/300 [14:53<16:27,  6.09s/it]

epoch : 138/300, loss = 3.612316, val_loss = 3.675211


 46%|████▋     | 139/300 [15:00<17:12,  6.41s/it]

epoch : 139/300, loss = 3.612465, val_loss = 3.675333


 47%|████▋     | 140/300 [15:05<15:45,  5.91s/it]

epoch : 140/300, loss = 3.612427, val_loss = 3.675273


 47%|████▋     | 141/300 [15:11<15:25,  5.82s/it]

epoch : 141/300, loss = 3.612384, val_loss = 3.675389


 47%|████▋     | 142/300 [15:15<13:48,  5.24s/it]

epoch : 142/300, loss = 3.612819, val_loss = 3.675391


 48%|████▊     | 143/300 [15:19<12:44,  4.87s/it]

epoch : 143/300, loss = 3.612358, val_loss = 3.675244


 48%|████▊     | 144/300 [15:25<13:55,  5.36s/it]

epoch : 144/300, loss = 3.612476, val_loss = 3.675340


 48%|████▊     | 145/300 [15:30<13:47,  5.34s/it]

epoch : 145/300, loss = 3.612418, val_loss = 3.675372


 49%|████▊     | 146/300 [15:38<15:39,  6.10s/it]

epoch : 146/300, loss = 3.612393, val_loss = 3.675249


 49%|████▉     | 147/300 [15:44<15:03,  5.91s/it]

epoch : 147/300, loss = 3.612381, val_loss = 3.675408


 49%|████▉     | 148/300 [15:51<16:03,  6.34s/it]

epoch : 148/300, loss = 3.612489, val_loss = 3.674559


 50%|████▉     | 149/300 [15:56<15:13,  6.05s/it]

epoch : 149/300, loss = 3.612314, val_loss = 3.675556


 50%|█████     | 150/300 [16:05<16:58,  6.79s/it]

epoch : 150/300, loss = 3.613120, val_loss = 3.675213


 50%|█████     | 151/300 [16:11<16:26,  6.62s/it]

epoch : 151/300, loss = 3.612553, val_loss = 3.675285


 51%|█████     | 152/300 [16:17<15:30,  6.29s/it]

epoch : 152/300, loss = 3.612412, val_loss = 3.675086


 51%|█████     | 153/300 [16:21<13:51,  5.66s/it]

epoch : 153/300, loss = 3.612349, val_loss = 3.675334


 51%|█████▏    | 154/300 [16:31<16:58,  6.98s/it]

epoch : 154/300, loss = 3.612355, val_loss = 3.675360


 52%|█████▏    | 155/300 [16:36<15:19,  6.34s/it]

epoch : 155/300, loss = 3.612383, val_loss = 3.675392


 52%|█████▏    | 156/300 [16:47<18:23,  7.66s/it]

epoch : 156/300, loss = 3.612321, val_loss = 3.675384


 52%|█████▏    | 157/300 [16:56<19:13,  8.06s/it]

epoch : 157/300, loss = 3.612384, val_loss = 3.675288


 53%|█████▎    | 158/300 [17:01<17:26,  7.37s/it]

epoch : 158/300, loss = 3.612442, val_loss = 3.675377


 53%|█████▎    | 159/300 [17:12<19:28,  8.29s/it]

epoch : 159/300, loss = 3.612391, val_loss = 3.675362


 53%|█████▎    | 160/300 [17:16<16:28,  7.06s/it]

epoch : 160/300, loss = 3.612360, val_loss = 3.675371


 54%|█████▎    | 161/300 [17:25<17:38,  7.61s/it]

epoch : 161/300, loss = 3.612369, val_loss = 3.675279


 54%|█████▍    | 162/300 [17:29<15:25,  6.70s/it]

epoch : 162/300, loss = 3.612376, val_loss = 3.675318


 54%|█████▍    | 163/300 [17:38<16:23,  7.18s/it]

epoch : 163/300, loss = 3.612346, val_loss = 3.675314


 55%|█████▍    | 164/300 [17:42<14:35,  6.44s/it]

epoch : 164/300, loss = 3.612889, val_loss = 3.675344


 55%|█████▌    | 165/300 [17:48<14:09,  6.30s/it]

epoch : 165/300, loss = 3.612259, val_loss = 3.676210


 55%|█████▌    | 166/300 [17:56<14:49,  6.64s/it]

epoch : 166/300, loss = 3.612603, val_loss = 3.675327


 56%|█████▌    | 167/300 [18:00<12:51,  5.80s/it]

epoch : 167/300, loss = 3.612222, val_loss = 3.675321


 56%|█████▌    | 168/300 [18:08<14:18,  6.50s/it]

epoch : 168/300, loss = 3.612373, val_loss = 3.675229


 56%|█████▋    | 169/300 [18:12<12:49,  5.87s/it]

epoch : 169/300, loss = 3.612434, val_loss = 3.675172


 57%|█████▋    | 170/300 [18:17<11:49,  5.46s/it]

epoch : 170/300, loss = 3.612355, val_loss = 3.675219


 57%|█████▋    | 171/300 [18:25<13:35,  6.32s/it]

epoch : 171/300, loss = 3.612010, val_loss = 3.675098


 57%|█████▋    | 172/300 [18:30<12:17,  5.76s/it]

epoch : 172/300, loss = 3.612124, val_loss = 3.675198


 58%|█████▊    | 173/300 [18:34<11:22,  5.37s/it]

epoch : 173/300, loss = 3.612215, val_loss = 3.674355


 58%|█████▊    | 174/300 [18:44<14:14,  6.78s/it]

epoch : 174/300, loss = 3.612366, val_loss = 3.674892


 58%|█████▊    | 175/300 [18:52<14:32,  6.98s/it]

epoch : 175/300, loss = 3.611789, val_loss = 3.674034


 59%|█████▊    | 176/300 [19:00<15:05,  7.30s/it]

epoch : 176/300, loss = 3.611112, val_loss = 3.672352


 59%|█████▉    | 177/300 [19:05<13:44,  6.71s/it]

epoch : 177/300, loss = 3.610950, val_loss = 3.672621


 59%|█████▉    | 178/300 [19:09<12:15,  6.03s/it]

epoch : 178/300, loss = 3.613034, val_loss = 3.674698


 60%|█████▉    | 179/300 [19:17<13:18,  6.60s/it]

epoch : 179/300, loss = 3.613085, val_loss = 3.675080


 60%|██████    | 180/300 [19:22<11:56,  5.97s/it]

epoch : 180/300, loss = 3.611642, val_loss = 3.675024


 60%|██████    | 181/300 [19:26<10:52,  5.48s/it]

epoch : 181/300, loss = 3.611541, val_loss = 3.672823
epoch : 182/300, loss = 3.608388, val_loss = 3.664001
val loss decreased, saved model


 61%|██████    | 183/300 [19:55<17:52,  9.16s/it]

epoch : 183/300, loss = 3.609901, val_loss = 3.671644


 61%|██████▏   | 184/300 [20:00<15:06,  7.82s/it]

epoch : 184/300, loss = 3.607546, val_loss = 3.667311


 62%|██████▏   | 185/300 [20:09<16:01,  8.36s/it]

epoch : 185/300, loss = 3.603618, val_loss = 3.665842
epoch : 186/300, loss = 3.598814, val_loss = 3.662389
val loss decreased, saved model


 62%|██████▏   | 187/300 [20:21<12:54,  6.86s/it]

epoch : 187/300, loss = 3.596808, val_loss = 3.663068
epoch : 188/300, loss = 3.588650, val_loss = 3.656172
val loss decreased, saved model


 63%|██████▎   | 189/300 [20:51<18:35, 10.05s/it]

epoch : 189/300, loss = 3.590085, val_loss = 3.661750
epoch : 190/300, loss = 3.588026, val_loss = 3.650299
val loss decreased, saved model


 64%|██████▎   | 191/300 [21:13<18:16, 10.06s/it]

epoch : 191/300, loss = 3.588377, val_loss = 3.653251


 64%|██████▍   | 192/300 [21:19<15:32,  8.64s/it]

epoch : 192/300, loss = 3.588422, val_loss = 3.658530


 64%|██████▍   | 193/300 [21:25<14:03,  7.88s/it]

epoch : 193/300, loss = 3.590022, val_loss = 3.655520


 65%|██████▍   | 194/300 [21:32<13:26,  7.60s/it]

epoch : 194/300, loss = 3.585742, val_loss = 3.659124


 65%|██████▌   | 195/300 [21:36<11:40,  6.67s/it]

epoch : 195/300, loss = 3.588402, val_loss = 3.654302
epoch : 196/300, loss = 3.579162, val_loss = 3.638193
val loss decreased, saved model


 66%|██████▌   | 197/300 [21:49<10:52,  6.33s/it]

epoch : 197/300, loss = 3.582223, val_loss = 3.653120
epoch : 198/300, loss = 3.570727, val_loss = 3.570259
val loss decreased, saved model


 66%|██████▋   | 199/300 [22:08<12:28,  7.41s/it]

epoch : 199/300, loss = 3.581209, val_loss = 3.634018


 67%|██████▋   | 200/300 [22:12<10:34,  6.34s/it]

epoch : 200/300, loss = 3.574896, val_loss = 3.637397


 67%|██████▋   | 201/300 [22:16<09:22,  5.68s/it]

epoch : 201/300, loss = 3.564952, val_loss = 3.643009


 67%|██████▋   | 202/300 [22:22<09:46,  5.98s/it]

epoch : 202/300, loss = 3.563418, val_loss = 3.587984


 68%|██████▊   | 203/300 [22:32<11:32,  7.13s/it]

epoch : 203/300, loss = 3.553714, val_loss = 3.587391


 68%|██████▊   | 204/300 [22:39<11:28,  7.17s/it]

epoch : 204/300, loss = 3.544686, val_loss = 3.602902


 68%|██████▊   | 205/300 [22:46<10:57,  6.92s/it]

epoch : 205/300, loss = 3.543983, val_loss = 3.594223


 69%|██████▊   | 206/300 [22:53<10:58,  7.01s/it]

epoch : 206/300, loss = 3.543179, val_loss = 3.588466


 69%|██████▉   | 207/300 [23:00<11:03,  7.14s/it]

epoch : 207/300, loss = 3.541556, val_loss = 3.585746


 69%|██████▉   | 208/300 [23:05<09:49,  6.41s/it]

epoch : 208/300, loss = 3.538739, val_loss = 3.588208


 70%|██████▉   | 209/300 [23:18<12:26,  8.20s/it]

epoch : 209/300, loss = 3.535818, val_loss = 3.585272


 70%|███████   | 210/300 [23:22<10:33,  7.04s/it]

epoch : 210/300, loss = 3.536752, val_loss = 3.583851


 70%|███████   | 211/300 [23:30<10:47,  7.28s/it]

epoch : 211/300, loss = 3.531509, val_loss = 3.581836


 71%|███████   | 212/300 [23:35<09:44,  6.64s/it]

epoch : 212/300, loss = 3.532900, val_loss = 3.576962


 71%|███████   | 213/300 [23:39<08:30,  5.86s/it]

epoch : 213/300, loss = 3.530702, val_loss = 3.583205


 71%|███████▏  | 214/300 [23:45<08:31,  5.94s/it]

epoch : 214/300, loss = 3.528400, val_loss = 3.579530


 72%|███████▏  | 215/300 [23:57<11:04,  7.82s/it]

epoch : 215/300, loss = 3.524900, val_loss = 3.584484


 72%|███████▏  | 216/300 [24:02<09:50,  7.03s/it]

epoch : 216/300, loss = 3.528712, val_loss = 3.587047


 72%|███████▏  | 217/300 [24:07<08:46,  6.34s/it]

epoch : 217/300, loss = 3.527243, val_loss = 3.584941


 73%|███████▎  | 218/300 [24:11<07:41,  5.63s/it]

epoch : 218/300, loss = 3.526608, val_loss = 3.584366


 73%|███████▎  | 219/300 [24:16<07:16,  5.39s/it]

epoch : 219/300, loss = 3.526971, val_loss = 3.580376


 73%|███████▎  | 220/300 [24:24<08:13,  6.17s/it]

epoch : 220/300, loss = 3.527542, val_loss = 3.585968


 74%|███████▎  | 221/300 [24:29<07:35,  5.77s/it]

epoch : 221/300, loss = 3.527132, val_loss = 3.586934


 74%|███████▍  | 222/300 [24:38<09:00,  6.93s/it]

epoch : 222/300, loss = 3.525018, val_loss = 3.583298


 74%|███████▍  | 223/300 [24:43<08:00,  6.24s/it]

epoch : 223/300, loss = 3.525468, val_loss = 3.586751


 75%|███████▍  | 224/300 [24:50<08:18,  6.57s/it]

epoch : 224/300, loss = 3.524216, val_loss = 3.585532
epoch : 225/300, loss = 3.526534, val_loss = 3.551427
val loss decreased, saved model


 75%|███████▌  | 226/300 [25:18<11:31,  9.34s/it]

epoch : 226/300, loss = 3.526001, val_loss = 3.581815


 76%|███████▌  | 227/300 [25:24<10:23,  8.54s/it]

epoch : 227/300, loss = 3.522822, val_loss = 3.561330


 76%|███████▌  | 228/300 [25:34<10:41,  8.91s/it]

epoch : 228/300, loss = 3.524972, val_loss = 3.578599


 76%|███████▋  | 229/300 [25:38<08:48,  7.44s/it]

epoch : 229/300, loss = 3.526266, val_loss = 3.586059


 77%|███████▋  | 230/300 [25:44<08:18,  7.12s/it]

epoch : 230/300, loss = 3.524954, val_loss = 3.574127


 77%|███████▋  | 231/300 [25:53<08:41,  7.55s/it]

epoch : 231/300, loss = 3.522726, val_loss = 3.579090


 77%|███████▋  | 232/300 [26:01<08:36,  7.60s/it]

epoch : 232/300, loss = 3.521952, val_loss = 3.580426


 78%|███████▊  | 233/300 [26:07<08:02,  7.20s/it]

epoch : 233/300, loss = 3.522872, val_loss = 3.583968


 78%|███████▊  | 234/300 [26:15<08:19,  7.57s/it]

epoch : 234/300, loss = 3.522316, val_loss = 3.588661


 78%|███████▊  | 235/300 [26:24<08:22,  7.73s/it]

epoch : 235/300, loss = 3.522859, val_loss = 3.568732


 79%|███████▊  | 236/300 [26:34<09:03,  8.50s/it]

epoch : 236/300, loss = 3.521457, val_loss = 3.579089


 79%|███████▉  | 237/300 [26:40<08:04,  7.69s/it]

epoch : 237/300, loss = 3.521470, val_loss = 3.578547


 79%|███████▉  | 238/300 [26:48<08:04,  7.82s/it]

epoch : 238/300, loss = 3.520866, val_loss = 3.577839


 80%|███████▉  | 239/300 [26:53<07:14,  7.13s/it]

epoch : 239/300, loss = 3.521670, val_loss = 3.551920


 80%|████████  | 240/300 [27:04<08:07,  8.13s/it]

epoch : 240/300, loss = 3.521281, val_loss = 3.575213


 80%|████████  | 241/300 [27:09<07:02,  7.16s/it]

epoch : 241/300, loss = 3.521018, val_loss = 3.570215


 81%|████████  | 242/300 [27:13<06:02,  6.25s/it]

epoch : 242/300, loss = 3.520213, val_loss = 3.575741


 81%|████████  | 243/300 [27:19<05:55,  6.23s/it]

epoch : 243/300, loss = 3.519921, val_loss = 3.571206


 81%|████████▏ | 244/300 [27:27<06:21,  6.81s/it]

epoch : 244/300, loss = 3.519885, val_loss = 3.579385


 82%|████████▏ | 245/300 [27:35<06:30,  7.11s/it]

epoch : 245/300, loss = 3.520754, val_loss = 3.582092
epoch : 246/300, loss = 3.519454, val_loss = 3.541178
val loss decreased, saved model


 82%|████████▏ | 247/300 [27:59<07:57,  9.01s/it]

epoch : 247/300, loss = 3.521987, val_loss = 3.584784


 83%|████████▎ | 248/300 [28:06<07:14,  8.35s/it]

epoch : 248/300, loss = 3.522484, val_loss = 3.556334


 83%|████████▎ | 249/300 [28:13<06:44,  7.94s/it]

epoch : 249/300, loss = 3.523122, val_loss = 3.579995


 83%|████████▎ | 250/300 [28:18<05:54,  7.08s/it]

epoch : 250/300, loss = 3.518015, val_loss = 3.580323


 84%|████████▎ | 251/300 [28:32<07:19,  8.97s/it]

epoch : 251/300, loss = 3.518687, val_loss = 3.584321


 84%|████████▍ | 252/300 [28:36<06:11,  7.74s/it]

epoch : 252/300, loss = 3.518541, val_loss = 3.562242


 84%|████████▍ | 253/300 [28:48<06:56,  8.87s/it]

epoch : 253/300, loss = 3.519279, val_loss = 3.583975


 85%|████████▍ | 254/300 [28:52<05:38,  7.36s/it]

epoch : 254/300, loss = 3.518255, val_loss = 3.576805


 85%|████████▌ | 255/300 [28:56<04:54,  6.53s/it]

epoch : 255/300, loss = 3.517573, val_loss = 3.562420


 85%|████████▌ | 256/300 [29:01<04:16,  5.83s/it]

epoch : 256/300, loss = 3.516493, val_loss = 3.577949
epoch : 257/300, loss = 3.516770, val_loss = 3.536707
val loss decreased, saved model


 86%|████████▌ | 258/300 [29:24<05:43,  8.17s/it]

epoch : 258/300, loss = 3.518184, val_loss = 3.579004


 86%|████████▋ | 259/300 [29:37<06:28,  9.49s/it]

epoch : 259/300, loss = 3.517717, val_loss = 3.582613


 87%|████████▋ | 260/300 [29:42<05:24,  8.10s/it]

epoch : 260/300, loss = 3.518163, val_loss = 3.560668


 87%|████████▋ | 261/300 [29:46<04:27,  6.87s/it]

epoch : 261/300, loss = 3.516476, val_loss = 3.577608


 87%|████████▋ | 262/300 [29:54<04:40,  7.39s/it]

epoch : 262/300, loss = 3.515052, val_loss = 3.540607


 88%|████████▊ | 263/300 [30:00<04:15,  6.91s/it]

epoch : 263/300, loss = 3.514571, val_loss = 3.581324


 88%|████████▊ | 264/300 [30:06<03:57,  6.61s/it]

epoch : 264/300, loss = 3.514429, val_loss = 3.574195


 88%|████████▊ | 265/300 [30:13<03:54,  6.70s/it]

epoch : 265/300, loss = 3.513968, val_loss = 3.563431


 89%|████████▊ | 266/300 [30:19<03:47,  6.70s/it]

epoch : 266/300, loss = 3.514518, val_loss = 3.538750


 89%|████████▉ | 267/300 [30:25<03:29,  6.36s/it]

epoch : 267/300, loss = 3.515779, val_loss = 3.577000


 89%|████████▉ | 268/300 [30:36<04:06,  7.71s/it]

epoch : 268/300, loss = 3.513073, val_loss = 3.575818


 90%|████████▉ | 269/300 [30:41<03:31,  6.82s/it]

epoch : 269/300, loss = 3.513619, val_loss = 3.543741


 90%|█████████ | 270/300 [30:47<03:16,  6.56s/it]

epoch : 270/300, loss = 3.513238, val_loss = 3.538488


 90%|█████████ | 271/300 [30:56<03:35,  7.43s/it]

epoch : 271/300, loss = 3.512870, val_loss = 3.556352


 91%|█████████ | 272/300 [31:01<03:05,  6.62s/it]

epoch : 272/300, loss = 3.513084, val_loss = 3.579627


 91%|█████████ | 273/300 [31:13<03:46,  8.40s/it]

epoch : 273/300, loss = 3.512835, val_loss = 3.542550


 91%|█████████▏| 274/300 [31:18<03:12,  7.41s/it]

epoch : 274/300, loss = 3.512417, val_loss = 3.549431


 92%|█████████▏| 275/300 [31:25<03:00,  7.24s/it]

epoch : 275/300, loss = 3.514247, val_loss = 3.570941


 92%|█████████▏| 276/300 [31:35<03:13,  8.05s/it]

epoch : 276/300, loss = 3.512424, val_loss = 3.542037


 92%|█████████▏| 277/300 [31:43<03:04,  8.00s/it]

epoch : 277/300, loss = 3.513531, val_loss = 3.542369


 93%|█████████▎| 278/300 [31:51<02:58,  8.09s/it]

epoch : 278/300, loss = 3.511916, val_loss = 3.541590


 93%|█████████▎| 279/300 [31:57<02:35,  7.41s/it]

epoch : 279/300, loss = 3.512643, val_loss = 3.544121


 93%|█████████▎| 280/300 [32:08<02:49,  8.46s/it]

epoch : 280/300, loss = 3.511916, val_loss = 3.540215


 94%|█████████▎| 281/300 [32:12<02:16,  7.19s/it]

epoch : 281/300, loss = 3.512487, val_loss = 3.536793


 94%|█████████▍| 282/300 [32:17<01:55,  6.41s/it]

epoch : 282/300, loss = 3.512913, val_loss = 3.543126


 94%|█████████▍| 283/300 [32:25<01:55,  6.81s/it]

epoch : 283/300, loss = 3.513192, val_loss = 3.542921


 95%|█████████▍| 284/300 [32:31<01:47,  6.70s/it]

epoch : 284/300, loss = 3.512301, val_loss = 3.569187


 95%|█████████▌| 285/300 [32:36<01:30,  6.04s/it]

epoch : 285/300, loss = 3.511761, val_loss = 3.548139
epoch : 286/300, loss = 3.512308, val_loss = 3.536267
val loss decreased, saved model


 96%|█████████▌| 287/300 [32:54<01:33,  7.22s/it]

epoch : 287/300, loss = 3.512820, val_loss = 3.540750


 96%|█████████▌| 288/300 [33:01<01:27,  7.29s/it]

epoch : 288/300, loss = 3.512122, val_loss = 3.549146


 96%|█████████▋| 289/300 [33:07<01:15,  6.84s/it]

epoch : 289/300, loss = 3.512206, val_loss = 3.573745


 97%|█████████▋| 290/300 [33:12<01:02,  6.23s/it]

epoch : 290/300, loss = 3.512395, val_loss = 3.542763


 97%|█████████▋| 291/300 [33:16<00:51,  5.68s/it]

epoch : 291/300, loss = 3.511049, val_loss = 3.548139


 97%|█████████▋| 292/300 [33:23<00:46,  5.84s/it]

epoch : 292/300, loss = 3.512640, val_loss = 3.538276


 98%|█████████▊| 293/300 [33:27<00:36,  5.28s/it]

epoch : 293/300, loss = 3.511480, val_loss = 3.578825


 98%|█████████▊| 294/300 [33:34<00:35,  5.99s/it]

epoch : 294/300, loss = 3.511057, val_loss = 3.545604


 98%|█████████▊| 295/300 [33:42<00:32,  6.45s/it]

epoch : 295/300, loss = 3.511876, val_loss = 3.542101


 99%|█████████▊| 296/300 [33:50<00:28,  7.11s/it]

epoch : 296/300, loss = 3.511383, val_loss = 3.552651


 99%|█████████▉| 297/300 [33:55<00:18,  6.30s/it]

epoch : 297/300, loss = 3.511335, val_loss = 3.567878


 99%|█████████▉| 298/300 [34:01<00:12,  6.20s/it]

epoch : 298/300, loss = 3.511082, val_loss = 3.547419


100%|█████████▉| 299/300 [34:07<00:06,  6.30s/it]

epoch : 299/300, loss = 3.510762, val_loss = 3.571260


100%|██████████| 300/300 [34:12<00:00,  6.84s/it]

epoch : 300/300, loss = 3.510644, val_loss = 3.546981





In [68]:
criterion = nn.CrossEntropyLoss()

In [40]:
# Example of target with class indices
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()
# Example of target with class probabilities
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)
output.backward()

In [44]:
output

tensor(2.1669, grad_fn=<DivBackward1>)

In [43]:
torch.randn(3, 5, requires_grad=True)

tensor([[ 1.2228, -0.6118, -0.3990, -0.1868, -1.2650],
        [-1.0470, -0.2454, -0.5919,  1.3557,  0.2780],
        [-1.9957, -2.3080,  1.0985,  2.4081, -0.2818]], requires_grad=True)

In [41]:
torch.empty(3, dtype=torch.long).random_(5)

tensor([0, 4, 0])

In [42]:
torch.randn(3, 5).softmax(dim=1)

tensor([[0.3029, 0.0985, 0.3255, 0.1294, 0.1438],
        [0.3081, 0.2550, 0.1167, 0.1676, 0.1527],
        [0.2497, 0.0885, 0.3613, 0.1574, 0.1432]])

In [2]:
torch.cuda.empty_cache()
torch.cuda.memory_summary(device=None)

