In [8]:
cfg = {
    'dev_train_len': 5*10**3,
    'dev_validation_len': 1*10**3,
    'learning_rate': 0.001,
    'epochs': 100,
    'embedding_dim': 50,
    'batch_size': 32,
    'dropout': 0.1,
    'optimizer': 'Adam',
    'num_layers': 2
}

cfg['hidden_dim'] = cfg['embedding_dim']

In [9]:
DEV_TRAIN_LEN = cfg['dev_train_len']
DEV_VALIDATION_LEN = cfg['dev_validation_len']
LEARNING_RATE = cfg['learning_rate']
EPOCHS = cfg['epochs']
BATCH_SIZE = cfg['batch_size']
DROPOUT = cfg['dropout']
OPTIMIZER = cfg['optimizer']
NUM_LAYERS = cfg['num_layers']
HIDDEN_DIM = cfg['hidden_dim']
EMBEDDING_DIM = cfg['embedding_dim']

DIR = '/scratch/shu7bh/RES/Word/1'

Create Dir

In [10]:
import os
if not os.path.exists(DIR):
    os.makedirs(DIR)

Set Device

In [43]:
import torch

if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(DEVICE)

cuda


Prepare Data

In [112]:
from gensim.downloader import load

glove_dict = {
    '50': 'glove-wiki-gigaword-50',
    '100': 'glove-wiki-gigaword-100',
    '200': 'glove-wiki-gigaword-200'
}

glove_dict[str(EMBEDDING_DIM)] = load(glove_dict[str(EMBEDDING_DIM)])
# glove_dict['100'] = api.load(glove_dict['100'])
# glove_dict['200'] = api.load(glove_dict['200'])

glove = glove_dict[str(EMBEDDING_DIM)]

In [49]:
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
import pandas as pd
import unicodedata
import re

def normalize_unicode(text: str) -> str:
    return unicodedata.normalize('NFD', text)

def clean_data(text: str) -> str:
    text = normalize_unicode(text.lower().strip())
    text = re.sub(r"([.!?])", r" \1", text)
    text = re.sub(r"[^a-zA-Z.!?]+", r" ", text)
    return text

lemmatizer = WordNetLemmatizer()
freq_words = dict()

def tokenize_data(text: str, create_unique_words: bool) -> list:
    tokens = [lemmatizer.lemmatize(token) for token in word_tokenize(text)]
    tokens = [token if token in glove else '<unk>' for token in tokens]

    if '<unk>' in tokens:
        return tokens

    if create_unique_words:
        for token in tokens:
            if token not in freq_words:
                freq_words[token] = 1
            else:
                freq_words[token] += 1
    return tokens

def replace_words(tokens: list, filter_rare_words: bool) -> list:
    tokens = [token if token in freq_words else '<unk>' for token in tokens]
    if filter_rare_words:
        tokens = [token if freq_words[token] >= 4 else '<unk>' for token in tokens]
    return tokens

def read_data(path: str, create_unique_words, filter_rare_words) -> pd.DataFrame:
    df = pd.read_csv(path)
    df = df.sample(frac=1, random_state=0).reset_index(drop=True)
    df['Description'] = df['Description'].apply(clean_data)
    df['Description'] = df['Description'].apply(tokenize_data, create_unique_words=create_unique_words)

    df = df[df['Description'].apply(lambda x: '<unk>' not in x)]
    df = df.reset_index(drop=True)

    df['Class Index'] = df['Class Index'].apply(lambda x: x-1)
    ydf = df.copy(deep=True)
    ydf['Description'] = ydf['Description'].apply(replace_words, filter_rare_words=filter_rare_words)
    return df, ydf

In [50]:
freq_words = dict()
xdf, ydf = read_data(
    'data/train.csv', 
    create_unique_words=True, 
    filter_rare_words=True
)

unique_words = set()
for tokens in ydf['Description']:
    unique_words.update(tokens)
print(len(unique_words))

21489


In [51]:
xdf

Unnamed: 0,Class Index,Description
0,3,"[london, british, airline, magnate, richard, b..."
1,3,"[regardless, space, competition, are, poised, ..."
2,0,"[cbs, million, of, folded, paper, crane, flutt..."
3,3,"[in, the, time, it, take, you, to, read, this,..."
4,0,"[washington, a, highly, classified, u, intelli..."
...,...,...
101981,1,"[toronto, reuters, national, hockey, league, t..."
101982,3,"[com, september, am, pt, ., there, s, no, doub..."
101983,0,"[pakistani, security, force, have, arrested, m..."
101984,3,"[palmsource, finally, unveiled, it, new, o, ve..."


In [52]:
NUM_CLASSES = len(set(xdf['Class Index']))
NUM_CLASSES

4

In [53]:
# Create a dictionary of all words
word_to_idx = {word: idx + 1 for idx, word in enumerate(unique_words)}

# Add special tokens
word_to_idx['<pad>'] = 0
word_to_idx['<sos>'] = len(word_to_idx)
word_to_idx['<eos>'] = len(word_to_idx)

# Create a dictionary of all words
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

# print the length of the word to index mapping
print(len(word_to_idx))

21492


In [64]:
idx_to_vec = glove.vectors

In [114]:
glove_key_to_idx = glove.key_to_index

In [116]:
glove_key_to_idx['<pad>'] = len(glove_key_to_idx)
glove_key_to_idx['<sos>'] = len(glove_key_to_idx)
glove_key_to_idx['<eos>'] = len(glove_key_to_idx)
glove_key_to_idx['<unk>'] = len(glove_key_to_idx)

In [130]:
glove_idx_to_key = {idx: key for key, idx in glove_key_to_idx.items()}

In [121]:
import numpy as np

glove_idx_to_vec = glove.vectors

pad_vec = np.zeros((1, EMBEDDING_DIM))
sos_vec = np.random.rand(1, EMBEDDING_DIM)
eos_vec = np.random.rand(1, EMBEDDING_DIM)
unk_vec = np.mean(glove_idx_to_vec, axis=0, keepdims=True)

glove_idx_to_vec = np.concatenate((glove_idx_to_vec, pad_vec, sos_vec, eos_vec, unk_vec), axis=0)

In [125]:
dev_train_raw_x = xdf[:DEV_TRAIN_LEN]
dev_train_raw_y = ydf[:DEV_TRAIN_LEN]

dev_validation_raw_x = xdf[DEV_TRAIN_LEN:DEV_TRAIN_LEN+DEV_VALIDATION_LEN]
dev_validation_raw_y = ydf[DEV_TRAIN_LEN:DEV_TRAIN_LEN+DEV_VALIDATION_LEN]

Dataset

In [127]:
from torch.utils.data import Dataset

class Sentences(Dataset):
    def __init__(
            self, 
            adf: pd.DataFrame, 
            pdf: pd.DataFrame, 
            word_to_idx: dict,
            glove_key_to_idx: dict
        ) -> None:

        self.X = []
        self.Y_ = []

        for sentence in adf['Description']:
            self.X += [torch.tensor(
                [glove_key_to_idx[w] for w in sentence] + 
                [glove_key_to_idx['<eos>']]
            )]

        for sentence in pdf['Description']:
            self.Y_ += [torch.tensor(
                [word_to_idx['<sos>']] + 
                [word_to_idx[w] for w in sentence] + 
                [word_to_idx['<eos>']] + 
                [word_to_idx['<pad>']]
            )]

        self.Y = torch.tensor(adf['Class Index'].tolist())

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, idx: int) -> tuple:
        return self.X[idx], self.Y[idx], torch.tensor(len(self.X[idx])), self.Y_[idx]

Create Dataset

In [128]:
dev_train_dataset = Sentences(dev_train_raw_x, dev_train_raw_y, word_to_idx, glove_key_to_idx)
dev_validation_dataset = Sentences(dev_validation_raw_x, dev_validation_raw_y, word_to_idx, glove_key_to_idx)

In [143]:
def collate_fn(batch: list) -> tuple:
    x, y, l, y_ = zip(*batch)

    x = torch.nn.utils.rnn.pad_sequence(x, padding_value=glove_key_to_idx['<pad>'], batch_first=True)
    y_ = torch.nn.utils.rnn.pad_sequence(y_, padding_value=word_to_idx['<pad>'], batch_first=True)
    return x, torch.stack(y), torch.stack(l), y_[..., 2:], y_[..., :-2]

In [144]:
from torch.utils.data import DataLoader

dev_train_loader = DataLoader(dev_train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
dev_validation_loader = DataLoader(dev_validation_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

In [155]:
type(idx_to_word)

dict

ELMo

In [251]:
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class ELMo(nn.Module):
    def __init__(
            self, 
            glove_idx_to_vec: np.ndarray,
            idx_to_word: dict,
            dropout: float, 
            num_layers: int, 
            hidden_dim: int, 
            word_embed_dim: int,
            filename: str = None
        ) -> None:

        super(ELMo, self).__init__()

        self.word_embed = nn.Embedding.from_pretrained(torch.from_numpy(glove_idx_to_vec).float(), padding_idx=glove_key_to_idx['<pad>'])

        self.lstmf = nn.LSTM(
            input_size=word_embed_dim, 
            hidden_size=hidden_dim, 
            num_layers=num_layers, 
            dropout=dropout,
            batch_first=True
        )

        self.lstmb = nn.LSTM(
            input_size=word_embed_dim, 
            hidden_size=hidden_dim, 
            num_layers=num_layers, 
            dropout=dropout,
            batch_first=True
        )

        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        if filename:
            self.load_state_dict(torch.load(filename))

    def forward(self, x: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
        x = self.word_embed(x)
        input = x.detach().clone()

        xf = pack_padded_sequence(x, lengths=l, batch_first=True, enforce_sorted=False)
        xb = pack_padded_sequence(x.flip([1]), lengths=l, batch_first=True, enforce_sorted=False)

        xf, (hsf, csf) = self.lstmf(xf)
        xb, (hsb, csb) = self.lstmb(xb)

        xf, _ = pad_packed_sequence(xf, batch_first=True)
        xb, _ = pad_packed_sequence(xb, batch_first=True)

        xb = xb.flip([1])

        xf = self.dropout(xf)
        xb = self.dropout(xb)
        return xf, xb, input, (hsf, csf), (hsb, csb)

    def state_dict_custom(self):
        state_dict = self.state_dict()
        del state_dict['word_embed.weight'] 
        return state_dict

    def save(self, filename: str) -> None:
        torch.save(self.state_dict_custom(), filename)

Early Stopping

In [252]:
import numpy as np

class EarlyStopping:
    def __init__(self, patience:int = 3, delta:float = 0.001):
        self.patience = patience
        self.counter = 0
        self.best_loss:float = np.inf
        self.best_model_pth = 0
        self.delta = delta

    def __call__(self, loss, epoch: int):
        should_stop = False

        if loss >= self.best_loss - self.delta:
            self.counter += 1
            if self.counter > self.patience:
                should_stop = True
        else:
            self.best_loss = loss
            self.counter = 0
            self.best_model_pth = epoch
        return should_stop

WANDB

In [253]:
# import wandb

# run = wandb.init(project='ELMo', entity='shu7bh', name='UpStream and DownStream')
# config = wandb.config

# config.dev_train_len = DEV_TRAIN_LEN
# config.dev_validation_len = DEV_VALIDATION_LEN
# config.learning_rate = LEARNING_RATE
# config.epochs = EPOCHS
# config.char_embedding_dim = CHAR_EMBEDDING_DIM
# config.batch_size = BATCH_SIZE
# config.dropout = DROPOUT
# config.optimizer = OPTIMIZER
# config.num_layers = NUM_LAYERS
# config.word_emb_dim = WORD_EMB_DIM
# config.max_word_len = MAX_WORD_LEN
# config.hidden_dim = HIDDEN_DIM
# config.char_out_channels = CHAR_OUT_CHANNELS

LM

In [254]:
from tqdm import tqdm

class LM(nn.Module):
    def __init__(self, 
            hidden_dim: int, 
            vocab_size: int, 
            filename: str = None
        ) -> None:

        super(LM, self).__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.elmo = ELMo(
            glove_idx_to_vec=glove_idx_to_vec,
            idx_to_word=idx_to_word,
            dropout=DROPOUT, 
            num_layers=NUM_LAYERS, 
            hidden_dim=HIDDEN_DIM,
            word_embed_dim=EMBEDDING_DIM
        )
        self.linear_forward = nn.Linear(hidden_dim, vocab_size)
        self.linear_backward = nn.Linear(hidden_dim, vocab_size)

        if filename:
            self.load_state_dict(torch.load(filename))

    def forward(self, x: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
        xf, xb, _, _, _ = self.elmo(x, l)
        yf = self.linear_forward(xf)
        yb = self.linear_backward(xb)
        return yf, yb

    def fit(self, train_loader: DataLoader, validation_loader: DataLoader, epochs: int, learning_rate: float, filename: str) -> None:
        self.es = EarlyStopping()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss(ignore_index=word_to_idx['<pad>'])

        for epoch in range(epochs):
            print('----------------------------------------')
            self._train(train_loader)
            loss = self._evaluate(validation_loader)
            print(f'Epoch: {epoch + 1} | Loss: {loss:7.4f}')
            if self.es(loss, epoch):
                break
            if self.es.counter == 0:
                # torch.save(self.state_dict(), os.path.join(DIR, f'{filename}_lm.pth'))
                # torch.save(self.elmo.state_dict(), os.path.join(DIR, f'{filename}_elmo.pth'))
                self.elmo.save(os.path.join(DIR, f'{filename}_elmo.pth'))

    def _call(self, x: torch.Tensor, y: torch.Tensor, l: torch.Tensor, yf: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:
        x, y, yf, yb = x.to(DEVICE), y.to(DEVICE), yf.to(DEVICE), yb.to(DEVICE)
        yf_hat, yb_hat = self(x, l)

        yf_hat = yf_hat.view(-1, self.vocab_size)
        yb_hat = yb_hat.view(-1, self.vocab_size)

        yf = yf.view(-1)
        yb = yb.view(-1)

        loss1 = self.criterion(yf_hat, yf)
        loss2 = self.criterion(yb_hat, yb)

        loss = (loss1 + loss2) / 2

        return loss, loss1, loss2

    def _train(self, train_loader: DataLoader) -> None:
        self.train()
        epoch_loss = []
        epoch_loss1 = []
        epoch_loss2 = []

        pbar = tqdm(train_loader)
        for x, y, l, yf, yb in pbar:

            loss, loss1, loss2 = self._call(x, y, l, yf, yb)
            epoch_loss.append(loss.item())
            epoch_loss1.append(loss1.item())
            epoch_loss2.append(loss2.item())
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            pbar.set_description(f'T Loss: {loss.item():7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}, Avg Loss1: {np.mean(epoch_loss1):7.4f}, Avg Loss2: {np.mean(epoch_loss2):7.4f}')

        # run.log({'upstream_train_loss': np.mean(epoch_loss)})

    def _evaluate(self, validation_loader: DataLoader) -> float:
        self.eval()
        epoch_loss = []
        epoch_loss1 = []
        epoch_loss2 = []
        pbar = tqdm(validation_loader)
        with torch.no_grad():
            for x, y, l, yf, yb in pbar:
                loss, loss1, loss2 = self._call(x, y, l, yf, yb)
                epoch_loss.append(loss.item())
                epoch_loss1.append(loss1.item())
                epoch_loss2.append(loss2.item())
                pbar.set_description(f'V Loss: {epoch_loss[-1]:7.4f}, Avg Loss: {np.mean(epoch_loss):7.4f}, Avg Loss1: {np.mean(epoch_loss1):7.4f}, Avg Loss2: {np.mean(epoch_loss2):7.4f}, Counter: {self.es.counter}, Best Loss: {self.es.best_loss:7.4f}')

        # run.log({'upstream_validation_loss': np.mean(epoch_loss)})
        return np.mean(epoch_loss)

Initialize Model

In [255]:
lm = LM(
    hidden_dim=HIDDEN_DIM, 
    vocab_size=len(word_to_idx), 
    filename=None
).to(DEVICE)
print(lm)

LM(
  (elmo): ELMo(
    (word_embed): Embedding(400004, 50, padding_idx=400000)
    (lstmf): LSTM(50, 50, num_layers=2, batch_first=True, dropout=0.1)
    (lstmb): LSTM(50, 50, num_layers=2, batch_first=True, dropout=0.1)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (linear_forward): Linear(in_features=50, out_features=21492, bias=True)
  (linear_backward): Linear(in_features=50, out_features=21492, bias=True)
)


In [256]:
from torchinfo import summary

summary(lm, device=DEVICE)

Layer (type:depth-idx)                   Param #
LM                                       --
├─ELMo: 1-1                              --
│    └─Embedding: 2-1                    (20,000,200)
│    └─LSTM: 2-2                         40,800
│    └─LSTM: 2-3                         40,800
│    └─Dropout: 2-4                      --
├─Linear: 1-2                            1,096,092
├─Linear: 1-3                            1,096,092
Total params: 22,273,984
Trainable params: 2,273,784
Non-trainable params: 20,000,200

In [257]:
lm.fit(dev_train_loader, dev_validation_loader, epochs=EPOCHS, learning_rate=LEARNING_RATE, filename='best')

----------------------------------------


T Loss:  7.7357, Avg Loss:  8.3316, Avg Loss1:  7.7042, Avg Loss2:  8.9589: 100%|██████████| 157/157 [00:04<00:00, 33.61it/s]
V Loss:  8.0127, Avg Loss:  7.9358, Avg Loss1:  7.0450, Avg Loss2:  8.8267, Counter: 0, Best Loss:     inf: 100%|██████████| 32/32 [00:00<00:00, 85.14it/s]


Epoch: 1 | Loss:  7.9358
----------------------------------------


T Loss:  7.6471, Avg Loss:  7.7860, Avg Loss1:  6.9782, Avg Loss2:  8.5938: 100%|██████████| 157/157 [00:04<00:00, 33.73it/s]
V Loss:  8.0163, Avg Loss:  7.8747, Avg Loss1:  7.0445, Avg Loss2:  8.7049, Counter: 0, Best Loss:  7.9358: 100%|██████████| 32/32 [00:00<00:00, 86.34it/s]


Epoch: 2 | Loss:  7.8747
----------------------------------------


T Loss:  7.1592, Avg Loss:  7.7499, Avg Loss1:  6.9560, Avg Loss2:  8.5438: 100%|██████████| 157/157 [00:04<00:00, 33.44it/s]
V Loss:  7.5219, Avg Loss:  7.8069, Avg Loss1:  7.0528, Avg Loss2:  8.5610, Counter: 0, Best Loss:  7.8747: 100%|██████████| 32/32 [00:00<00:00, 86.80it/s]


Epoch: 3 | Loss:  7.8069
----------------------------------------


T Loss:  7.1676, Avg Loss:  7.6396, Avg Loss1:  6.9479, Avg Loss2:  8.3312: 100%|██████████| 157/157 [00:04<00:00, 33.68it/s]
V Loss:  7.2229, Avg Loss:  7.7423, Avg Loss1:  7.0605, Avg Loss2:  8.4241, Counter: 0, Best Loss:  7.8069: 100%|██████████| 32/32 [00:00<00:00, 86.42it/s]


Epoch: 4 | Loss:  7.7423
----------------------------------------


T Loss:  7.6188, Avg Loss:  7.5761, Avg Loss1:  6.9415, Avg Loss2:  8.2106: 100%|██████████| 157/157 [00:04<00:00, 33.54it/s]
V Loss:  7.7310, Avg Loss:  7.7299, Avg Loss1:  7.0616, Avg Loss2:  8.3982, Counter: 0, Best Loss:  7.7423: 100%|██████████| 32/32 [00:00<00:00, 84.27it/s]


Epoch: 5 | Loss:  7.7299
----------------------------------------


T Loss:  7.2065, Avg Loss:  7.5033, Avg Loss1:  6.9358, Avg Loss2:  8.0708: 100%|██████████| 157/157 [00:04<00:00, 34.36it/s]
V Loss:  7.7913, Avg Loss:  7.7049, Avg Loss1:  7.0665, Avg Loss2:  8.3433, Counter: 0, Best Loss:  7.7299: 100%|██████████| 32/32 [00:00<00:00, 82.87it/s]


Epoch: 6 | Loss:  7.7049
----------------------------------------


T Loss:  7.2917, Avg Loss:  7.4862, Avg Loss1:  6.9304, Avg Loss2:  8.0420: 100%|██████████| 157/157 [00:04<00:00, 33.74it/s]
V Loss:  7.3596, Avg Loss:  7.6711, Avg Loss1:  7.0677, Avg Loss2:  8.2745, Counter: 0, Best Loss:  7.7049: 100%|██████████| 32/32 [00:00<00:00, 83.23it/s]


Epoch: 7 | Loss:  7.6711
----------------------------------------


T Loss:  7.0419, Avg Loss:  7.4251, Avg Loss1:  6.9177, Avg Loss2:  7.9325: 100%|██████████| 157/157 [00:04<00:00, 33.78it/s]
V Loss:  7.5515, Avg Loss:  7.5772, Avg Loss1:  7.0517, Avg Loss2:  8.1028, Counter: 0, Best Loss:  7.6711: 100%|██████████| 32/32 [00:00<00:00, 84.20it/s]


Epoch: 8 | Loss:  7.5772
----------------------------------------


T Loss:  6.9948, Avg Loss:  7.3812, Avg Loss1:  6.8987, Avg Loss2:  7.8637: 100%|██████████| 157/157 [00:04<00:00, 33.76it/s]
V Loss:  7.2113, Avg Loss:  7.5489, Avg Loss1:  7.0312, Avg Loss2:  8.0666, Counter: 0, Best Loss:  7.5772: 100%|██████████| 32/32 [00:00<00:00, 85.16it/s]


Epoch: 9 | Loss:  7.5489
----------------------------------------


T Loss:  7.1530, Avg Loss:  7.3254, Avg Loss1:  6.8762, Avg Loss2:  7.7745: 100%|██████████| 157/157 [00:04<00:00, 33.89it/s]
V Loss:  7.1497, Avg Loss:  7.4800, Avg Loss1:  7.0094, Avg Loss2:  7.9505, Counter: 0, Best Loss:  7.5489: 100%|██████████| 32/32 [00:00<00:00, 85.52it/s]


Epoch: 10 | Loss:  7.4800
----------------------------------------


T Loss:  7.0758, Avg Loss:  7.2820, Avg Loss1:  6.8506, Avg Loss2:  7.7135: 100%|██████████| 157/157 [00:04<00:00, 35.65it/s]
V Loss:  6.6700, Avg Loss:  7.4620, Avg Loss1:  6.9768, Avg Loss2:  7.9472, Counter: 0, Best Loss:  7.4800: 100%|██████████| 32/32 [00:00<00:00, 83.65it/s]


Epoch: 11 | Loss:  7.4620
----------------------------------------


T Loss:  7.3008, Avg Loss:  7.2232, Avg Loss1:  6.8188, Avg Loss2:  7.6277: 100%|██████████| 157/157 [00:04<00:00, 37.96it/s]
V Loss:  7.3456, Avg Loss:  7.4064, Avg Loss1:  6.9598, Avg Loss2:  7.8530, Counter: 0, Best Loss:  7.4620: 100%|██████████| 32/32 [00:00<00:00, 83.07it/s]


Epoch: 12 | Loss:  7.4064
----------------------------------------


T Loss:  7.0713, Avg Loss:  7.1581, Avg Loss1:  6.7766, Avg Loss2:  7.5396: 100%|██████████| 157/157 [00:04<00:00, 38.14it/s]
V Loss:  7.2525, Avg Loss:  7.3298, Avg Loss1:  6.9087, Avg Loss2:  7.7510, Counter: 0, Best Loss:  7.4064: 100%|██████████| 32/32 [00:00<00:00, 85.83it/s]


Epoch: 13 | Loss:  7.3298
----------------------------------------


T Loss:  7.0402, Avg Loss:  7.1091, Avg Loss1:  6.7283, Avg Loss2:  7.4898: 100%|██████████| 157/157 [00:04<00:00, 33.92it/s]
V Loss:  7.3682, Avg Loss:  7.2713, Avg Loss1:  6.8725, Avg Loss2:  7.6701, Counter: 0, Best Loss:  7.3298: 100%|██████████| 32/32 [00:00<00:00, 86.18it/s]


Epoch: 14 | Loss:  7.2713
----------------------------------------


T Loss:  6.8825, Avg Loss:  7.0604, Avg Loss1:  6.6878, Avg Loss2:  7.4330: 100%|██████████| 157/157 [00:04<00:00, 33.74it/s]
V Loss:  7.1336, Avg Loss:  7.2402, Avg Loss1:  6.8471, Avg Loss2:  7.6334, Counter: 0, Best Loss:  7.2713: 100%|██████████| 32/32 [00:00<00:00, 85.01it/s]


Epoch: 15 | Loss:  7.2402
----------------------------------------


T Loss:  6.8511, Avg Loss:  6.9993, Avg Loss1:  6.6517, Avg Loss2:  7.3469: 100%|██████████| 157/157 [00:04<00:00, 33.57it/s]
V Loss:  6.7793, Avg Loss:  7.1651, Avg Loss1:  6.8060, Avg Loss2:  7.5241, Counter: 0, Best Loss:  7.2402: 100%|██████████| 32/32 [00:00<00:00, 85.90it/s]


Epoch: 16 | Loss:  7.1651
----------------------------------------


T Loss:  6.7121, Avg Loss:  6.9542, Avg Loss1:  6.6115, Avg Loss2:  7.2968: 100%|██████████| 157/157 [00:04<00:00, 33.54it/s]
V Loss:  6.8602, Avg Loss:  7.1449, Avg Loss1:  6.7776, Avg Loss2:  7.5122, Counter: 0, Best Loss:  7.1651: 100%|██████████| 32/32 [00:00<00:00, 82.48it/s]


Epoch: 17 | Loss:  7.1449
----------------------------------------


T Loss:  6.8247, Avg Loss:  6.9043, Avg Loss1:  6.5701, Avg Loss2:  7.2385: 100%|██████████| 157/157 [00:04<00:00, 34.27it/s]
V Loss:  6.7695, Avg Loss:  7.0722, Avg Loss1:  6.7390, Avg Loss2:  7.4053, Counter: 0, Best Loss:  7.1449: 100%|██████████| 32/32 [00:00<00:00, 87.46it/s]


Epoch: 18 | Loss:  7.0722
----------------------------------------


T Loss:  6.6003, Avg Loss:  6.8496, Avg Loss1:  6.5249, Avg Loss2:  7.1742: 100%|██████████| 157/157 [00:04<00:00, 33.79it/s]
V Loss:  6.6418, Avg Loss:  7.0376, Avg Loss1:  6.7077, Avg Loss2:  7.3675, Counter: 0, Best Loss:  7.0722: 100%|██████████| 32/32 [00:00<00:00, 84.84it/s]


Epoch: 19 | Loss:  7.0376
----------------------------------------


T Loss:  6.4215, Avg Loss:  6.8175, Avg Loss1:  6.4813, Avg Loss2:  7.1538: 100%|██████████| 157/157 [00:04<00:00, 33.66it/s]
V Loss:  6.7888, Avg Loss:  7.0102, Avg Loss1:  6.6750, Avg Loss2:  7.3455, Counter: 0, Best Loss:  7.0376: 100%|██████████| 32/32 [00:00<00:00, 84.69it/s]


Epoch: 20 | Loss:  7.0102
----------------------------------------


T Loss:  7.1230, Avg Loss:  6.7498, Avg Loss1:  6.4400, Avg Loss2:  7.0597: 100%|██████████| 157/157 [00:04<00:00, 33.85it/s]
V Loss:  6.6513, Avg Loss:  6.9461, Avg Loss1:  6.6469, Avg Loss2:  7.2453, Counter: 0, Best Loss:  7.0102: 100%|██████████| 32/32 [00:00<00:00, 88.06it/s]


Epoch: 21 | Loss:  6.9461
----------------------------------------


T Loss:  6.3518, Avg Loss:  6.6994, Avg Loss1:  6.3956, Avg Loss2:  7.0031: 100%|██████████| 157/157 [00:04<00:00, 33.98it/s]
V Loss:  6.7949, Avg Loss:  6.9255, Avg Loss1:  6.6098, Avg Loss2:  7.2412, Counter: 0, Best Loss:  6.9461: 100%|██████████| 32/32 [00:00<00:00, 84.27it/s]


Epoch: 22 | Loss:  6.9255
----------------------------------------


T Loss:  6.4426, Avg Loss:  6.6598, Avg Loss1:  6.3513, Avg Loss2:  6.9682: 100%|██████████| 157/157 [00:04<00:00, 33.24it/s]
V Loss:  6.6499, Avg Loss:  6.8971, Avg Loss1:  6.5788, Avg Loss2:  7.2154, Counter: 0, Best Loss:  6.9255: 100%|██████████| 32/32 [00:00<00:00, 84.59it/s]


Epoch: 23 | Loss:  6.8971
----------------------------------------


T Loss:  6.3408, Avg Loss:  6.6031, Avg Loss1:  6.3077, Avg Loss2:  6.8984: 100%|██████████| 157/157 [00:04<00:00, 33.76it/s]
V Loss:  6.5192, Avg Loss:  6.8432, Avg Loss1:  6.5464, Avg Loss2:  7.1400, Counter: 0, Best Loss:  6.8971: 100%|██████████| 32/32 [00:00<00:00, 84.40it/s]


Epoch: 24 | Loss:  6.8432
----------------------------------------


T Loss:  6.2900, Avg Loss:  6.5672, Avg Loss1:  6.2677, Avg Loss2:  6.8668: 100%|██████████| 157/157 [00:04<00:00, 33.61it/s]
V Loss:  6.7848, Avg Loss:  6.8152, Avg Loss1:  6.5234, Avg Loss2:  7.1070, Counter: 0, Best Loss:  6.8432: 100%|██████████| 32/32 [00:00<00:00, 86.67it/s]


Epoch: 25 | Loss:  6.8152
----------------------------------------


T Loss:  6.2401, Avg Loss:  6.5175, Avg Loss1:  6.2255, Avg Loss2:  6.8095: 100%|██████████| 157/157 [00:04<00:00, 33.94it/s]
V Loss:  6.5660, Avg Loss:  6.7851, Avg Loss1:  6.4867, Avg Loss2:  7.0835, Counter: 0, Best Loss:  6.8152: 100%|██████████| 32/32 [00:00<00:00, 84.44it/s]


Epoch: 26 | Loss:  6.7851
----------------------------------------


T Loss:  6.0176, Avg Loss:  6.4791, Avg Loss1:  6.1829, Avg Loss2:  6.7752: 100%|██████████| 157/157 [00:04<00:00, 33.72it/s]
V Loss:  6.4070, Avg Loss:  6.7432, Avg Loss1:  6.4571, Avg Loss2:  7.0293, Counter: 0, Best Loss:  6.7851: 100%|██████████| 32/32 [00:00<00:00, 81.92it/s]


Epoch: 27 | Loss:  6.7432
----------------------------------------


T Loss:  6.4902, Avg Loss:  6.4448, Avg Loss1:  6.1468, Avg Loss2:  6.7429: 100%|██████████| 157/157 [00:04<00:00, 32.98it/s]
V Loss:  6.2481, Avg Loss:  6.7358, Avg Loss1:  6.4274, Avg Loss2:  7.0442, Counter: 0, Best Loss:  6.7432: 100%|██████████| 32/32 [00:00<00:00, 84.68it/s]


Epoch: 28 | Loss:  6.7358
----------------------------------------


T Loss:  6.2315, Avg Loss:  6.3955, Avg Loss1:  6.1072, Avg Loss2:  6.6838: 100%|██████████| 157/157 [00:04<00:00, 33.86it/s]
V Loss:  5.8876, Avg Loss:  6.6910, Avg Loss1:  6.4023, Avg Loss2:  6.9798, Counter: 0, Best Loss:  6.7358: 100%|██████████| 32/32 [00:00<00:00, 85.20it/s]


Epoch: 29 | Loss:  6.6910
----------------------------------------


T Loss:  6.4070, Avg Loss:  6.3750, Avg Loss1:  6.0689, Avg Loss2:  6.6810: 100%|██████████| 157/157 [00:04<00:00, 33.56it/s]
V Loss:  6.0192, Avg Loss:  6.6481, Avg Loss1:  6.3744, Avg Loss2:  6.9218, Counter: 0, Best Loss:  6.6910: 100%|██████████| 32/32 [00:00<00:00, 85.80it/s]


Epoch: 30 | Loss:  6.6481
----------------------------------------


T Loss:  6.1001, Avg Loss:  6.3221, Avg Loss1:  6.0287, Avg Loss2:  6.6154: 100%|██████████| 157/157 [00:04<00:00, 34.07it/s]
V Loss:  6.8969, Avg Loss:  6.6418, Avg Loss1:  6.3747, Avg Loss2:  6.9090, Counter: 0, Best Loss:  6.6481: 100%|██████████| 32/32 [00:00<00:00, 86.36it/s]


Epoch: 31 | Loss:  6.6418
----------------------------------------


T Loss:  6.3737, Avg Loss:  6.2908, Avg Loss1:  5.9956, Avg Loss2:  6.5861: 100%|██████████| 157/157 [00:04<00:00, 34.22it/s]
V Loss:  6.3028, Avg Loss:  6.6345, Avg Loss1:  6.3425, Avg Loss2:  6.9265, Counter: 0, Best Loss:  6.6418: 100%|██████████| 32/32 [00:00<00:00, 84.57it/s]


Epoch: 32 | Loss:  6.6345
----------------------------------------


T Loss:  5.9302, Avg Loss:  6.2493, Avg Loss1:  5.9581, Avg Loss2:  6.5405: 100%|██████████| 157/157 [00:04<00:00, 34.80it/s]
V Loss:  6.5327, Avg Loss:  6.6081, Avg Loss1:  6.3302, Avg Loss2:  6.8860, Counter: 0, Best Loss:  6.6345: 100%|██████████| 32/32 [00:00<00:00, 87.00it/s]


Epoch: 33 | Loss:  6.6081
----------------------------------------


T Loss:  6.2055, Avg Loss:  6.2249, Avg Loss1:  5.9278, Avg Loss2:  6.5219: 100%|██████████| 157/157 [00:04<00:00, 35.02it/s]
V Loss:  7.0226, Avg Loss:  6.6069, Avg Loss1:  6.3299, Avg Loss2:  6.8840, Counter: 0, Best Loss:  6.6081: 100%|██████████| 32/32 [00:00<00:00, 87.78it/s]


Epoch: 34 | Loss:  6.6069
----------------------------------------


T Loss:  6.0545, Avg Loss:  6.1953, Avg Loss1:  5.8958, Avg Loss2:  6.4949: 100%|██████████| 157/157 [00:04<00:00, 34.37it/s]
V Loss:  6.3802, Avg Loss:  6.5630, Avg Loss1:  6.2930, Avg Loss2:  6.8331, Counter: 0, Best Loss:  6.6069: 100%|██████████| 32/32 [00:00<00:00, 86.40it/s]


Epoch: 35 | Loss:  6.5630
----------------------------------------


T Loss:  5.9455, Avg Loss:  6.1658, Avg Loss1:  5.8642, Avg Loss2:  6.4674: 100%|██████████| 157/157 [00:04<00:00, 34.59it/s]
V Loss:  6.2454, Avg Loss:  6.5595, Avg Loss1:  6.2770, Avg Loss2:  6.8421, Counter: 0, Best Loss:  6.5630: 100%|██████████| 32/32 [00:00<00:00, 86.46it/s]


Epoch: 36 | Loss:  6.5595
----------------------------------------


T Loss:  5.8889, Avg Loss:  6.1511, Avg Loss1:  5.8357, Avg Loss2:  6.4665: 100%|██████████| 157/157 [00:04<00:00, 34.34it/s]
V Loss:  6.1214, Avg Loss:  6.5333, Avg Loss1:  6.2563, Avg Loss2:  6.8103, Counter: 0, Best Loss:  6.5595: 100%|██████████| 32/32 [00:00<00:00, 84.20it/s]


Epoch: 37 | Loss:  6.5333
----------------------------------------


T Loss:  6.0339, Avg Loss:  6.1191, Avg Loss1:  5.8099, Avg Loss2:  6.4283: 100%|██████████| 157/157 [00:04<00:00, 34.00it/s]
V Loss:  6.3917, Avg Loss:  6.5194, Avg Loss1:  6.2514, Avg Loss2:  6.7873, Counter: 0, Best Loss:  6.5333: 100%|██████████| 32/32 [00:00<00:00, 89.77it/s]


Epoch: 38 | Loss:  6.5194
----------------------------------------


T Loss:  5.6737, Avg Loss:  6.0928, Avg Loss1:  5.7804, Avg Loss2:  6.4052: 100%|██████████| 157/157 [00:04<00:00, 34.36it/s]
V Loss:  6.4057, Avg Loss:  6.5045, Avg Loss1:  6.2454, Avg Loss2:  6.7637, Counter: 0, Best Loss:  6.5194: 100%|██████████| 32/32 [00:00<00:00, 86.40it/s]


Epoch: 39 | Loss:  6.5045
----------------------------------------


T Loss:  5.7361, Avg Loss:  6.0729, Avg Loss1:  5.7557, Avg Loss2:  6.3900: 100%|██████████| 157/157 [00:04<00:00, 33.78it/s]
V Loss:  6.5385, Avg Loss:  6.5177, Avg Loss1:  6.2350, Avg Loss2:  6.8003, Counter: 0, Best Loss:  6.5045: 100%|██████████| 32/32 [00:00<00:00, 82.54it/s]


Epoch: 40 | Loss:  6.5177
----------------------------------------


T Loss:  6.3386, Avg Loss:  6.0493, Avg Loss1:  5.7343, Avg Loss2:  6.3644: 100%|██████████| 157/157 [00:04<00:00, 34.04it/s]
V Loss:  6.6042, Avg Loss:  6.4758, Avg Loss1:  6.2236, Avg Loss2:  6.7279, Counter: 1, Best Loss:  6.5045: 100%|██████████| 32/32 [00:00<00:00, 85.89it/s]


Epoch: 41 | Loss:  6.4758
----------------------------------------


T Loss:  5.8612, Avg Loss:  6.0389, Avg Loss1:  5.7057, Avg Loss2:  6.3720: 100%|██████████| 157/157 [00:04<00:00, 34.02it/s]
V Loss:  6.0829, Avg Loss:  6.4909, Avg Loss1:  6.2078, Avg Loss2:  6.7741, Counter: 0, Best Loss:  6.4758: 100%|██████████| 32/32 [00:00<00:00, 84.49it/s]


Epoch: 42 | Loss:  6.4909
----------------------------------------


T Loss:  5.8588, Avg Loss:  6.0145, Avg Loss1:  5.6845, Avg Loss2:  6.3445: 100%|██████████| 157/157 [00:04<00:00, 34.13it/s]
V Loss:  6.7094, Avg Loss:  6.4848, Avg Loss1:  6.2120, Avg Loss2:  6.7576, Counter: 1, Best Loss:  6.4758: 100%|██████████| 32/32 [00:00<00:00, 83.94it/s]


Epoch: 43 | Loss:  6.4848
----------------------------------------


T Loss:  5.8558, Avg Loss:  5.9931, Avg Loss1:  5.6626, Avg Loss2:  6.3236: 100%|██████████| 157/157 [00:04<00:00, 34.30it/s]
V Loss:  6.3451, Avg Loss:  6.4780, Avg Loss1:  6.1979, Avg Loss2:  6.7582, Counter: 2, Best Loss:  6.4758: 100%|██████████| 32/32 [00:00<00:00, 84.93it/s]


Epoch: 44 | Loss:  6.4780
----------------------------------------


T Loss:  5.7883, Avg Loss:  5.9874, Avg Loss1:  5.6383, Avg Loss2:  6.3364: 100%|██████████| 157/157 [00:04<00:00, 33.82it/s]
V Loss:  6.2860, Avg Loss:  6.4569, Avg Loss1:  6.1870, Avg Loss2:  6.7267, Counter: 3, Best Loss:  6.4758: 100%|██████████| 32/32 [00:00<00:00, 88.23it/s]


Epoch: 45 | Loss:  6.4569
----------------------------------------


T Loss:  5.9093, Avg Loss:  5.9571, Avg Loss1:  5.6176, Avg Loss2:  6.2965: 100%|██████████| 157/157 [00:04<00:00, 34.22it/s]
V Loss:  6.3533, Avg Loss:  6.4758, Avg Loss1:  6.1849, Avg Loss2:  6.7666, Counter: 0, Best Loss:  6.4569: 100%|██████████| 32/32 [00:00<00:00, 82.70it/s]


Epoch: 46 | Loss:  6.4758
----------------------------------------


T Loss:  5.7294, Avg Loss:  5.9361, Avg Loss1:  5.5939, Avg Loss2:  6.2782: 100%|██████████| 157/157 [00:04<00:00, 34.59it/s]
V Loss:  6.1091, Avg Loss:  6.4490, Avg Loss1:  6.1777, Avg Loss2:  6.7203, Counter: 1, Best Loss:  6.4569: 100%|██████████| 32/32 [00:00<00:00, 87.22it/s]


Epoch: 47 | Loss:  6.4490
----------------------------------------


T Loss:  5.8467, Avg Loss:  5.9220, Avg Loss1:  5.5772, Avg Loss2:  6.2668: 100%|██████████| 157/157 [00:04<00:00, 34.75it/s]
V Loss:  6.7187, Avg Loss:  6.4557, Avg Loss1:  6.1823, Avg Loss2:  6.7291, Counter: 0, Best Loss:  6.4490: 100%|██████████| 32/32 [00:00<00:00, 85.86it/s]


Epoch: 48 | Loss:  6.4557
----------------------------------------


T Loss:  5.8715, Avg Loss:  5.9165, Avg Loss1:  5.5558, Avg Loss2:  6.2772: 100%|██████████| 157/157 [00:04<00:00, 34.31it/s]
V Loss:  6.0149, Avg Loss:  6.4443, Avg Loss1:  6.1675, Avg Loss2:  6.7211, Counter: 1, Best Loss:  6.4490: 100%|██████████| 32/32 [00:00<00:00, 85.04it/s]


Epoch: 49 | Loss:  6.4443
----------------------------------------


T Loss:  6.0206, Avg Loss:  5.9073, Avg Loss1:  5.5391, Avg Loss2:  6.2755: 100%|██████████| 157/157 [00:04<00:00, 34.27it/s]
V Loss:  6.5106, Avg Loss:  6.4494, Avg Loss1:  6.1747, Avg Loss2:  6.7241, Counter: 0, Best Loss:  6.4443: 100%|██████████| 32/32 [00:00<00:00, 85.81it/s]


Epoch: 50 | Loss:  6.4494
----------------------------------------


T Loss:  5.7627, Avg Loss:  5.8801, Avg Loss1:  5.5153, Avg Loss2:  6.2449: 100%|██████████| 157/157 [00:04<00:00, 34.19it/s]
V Loss:  6.3778, Avg Loss:  6.4526, Avg Loss1:  6.1657, Avg Loss2:  6.7395, Counter: 1, Best Loss:  6.4443: 100%|██████████| 32/32 [00:00<00:00, 84.38it/s]


Epoch: 51 | Loss:  6.4526
----------------------------------------


T Loss:  5.8192, Avg Loss:  5.8639, Avg Loss1:  5.4981, Avg Loss2:  6.2297: 100%|██████████| 157/157 [00:04<00:00, 34.01it/s]
V Loss:  6.5331, Avg Loss:  6.4436, Avg Loss1:  6.1627, Avg Loss2:  6.7245, Counter: 2, Best Loss:  6.4443: 100%|██████████| 32/32 [00:00<00:00, 86.58it/s]


Epoch: 52 | Loss:  6.4436
----------------------------------------


T Loss:  5.8883, Avg Loss:  5.8579, Avg Loss1:  5.4819, Avg Loss2:  6.2339: 100%|██████████| 157/157 [00:04<00:00, 34.16it/s]
V Loss:  6.4844, Avg Loss:  6.4513, Avg Loss1:  6.1642, Avg Loss2:  6.7384, Counter: 3, Best Loss:  6.4443: 100%|██████████| 32/32 [00:00<00:00, 82.98it/s]

Epoch: 53 | Loss:  6.4513



