## Importing all libraries


In [1]:
import os
import pathlib
import numpy as np
import string
import glob
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from collections import Counter
from torch.utils.data import DataLoader, Dataset, Subset

plt.style.use("ggplot")

  from .autonotebook import tqdm as notebook_tqdm


### Setting seeds for reproducibility


In [2]:
def set_seed(seed=1234):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # for multiple gpus
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True


set_seed()

In [3]:
OUTPUT_DIR = os.path.join("outputs", "imdb_movie_review_embedding")
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [4]:
data_dir = os.path.join("input", "aclImdb")
dataset_dir = os.path.join(data_dir, "aclImdb")
train_dir = os.path.join(dataset_dir, "train")
print(f"The dataset directory is {dataset_dir}")
print(f"The train directory is {train_dir}")

The dataset directory is input/aclImdb/aclImdb
The train directory is input/aclImdb/aclImdb/train


### Preprocessing


In [17]:
def find_longest_length(text_file_paths):
    """
    Find the longest review length in the entire training set.
    """
    longest = 0
    for path in text_file_paths:
        with open(path, "r") as f:
            text = f.read()
            # Remove <br> tags.
            text = re.sub("<[^>]+>+", "", text)
            corpus = [word for word in text.split()]
        if len(corpus) > longest:
            longest = len(corpus)
    return longest

In [18]:
file_paths = []
file_paths.extend(glob.glob(os.path.join(dataset_dir, "train", "pos", "*.txt")))
file_paths.extend(glob.glob(os.path.join(dataset_dir, "train", "neg", "*.txt")))
longest_sentence_length = find_longest_length(file_paths)
print(f"Longest review length: {longest_sentence_length} words")

Longest review length: 2450 words


In [42]:
def find_avg_length(text_file_paths):
    """
    Find the average review size in the training
    """
    all_lens = []
    for path in text_file_paths:
        with open(path, "r") as f:
            text = f.read()
            # Remove <br> tag
            text = re.sub("<[^>]+>+", "", text)
            corpus = [word for word in text.split()]
        all_lens.append(len(corpus))
    np_array = np.array(all_lens)
    return np.mean(np_array)

In [44]:
file_paths = []
file_paths.extend(glob.glob(os.path.join(dataset_dir, "train", "pos", "*.txt")))
file_paths.extend(glob.glob(os.path.join(data_dir, "train", "neg", "*.txt")))
average_length = find_avg_length(file_paths)
print(f"Average review length: {average_length} words")

Average review length: 232.76296 words


### Defining some general hyper parameters


In [72]:
MAX_LEN = int(longest_sentence_length)
NUM_WORDS = -1
BATCH_SIZE = 512
VALID_SPLIT = 0.25

In [73]:
def find_word_frequency(text_file_paths, most_common=None):
    """
    Returns a list of tuples, (<word>,<frequency>)
    """
    corpus = []
    for path in text_file_paths:
        with open(path, "r") as f:
            text = f.read()
            text = re.sub("<^>+>+", "", text)
            corpus.extend([word for word in text.split()])
    count_words = Counter(corpus)
    word_frequency = count_words.most_common(n=most_common)
    return word_frequency

In [74]:
 def word2int(input_words, num_words):
    """
        int_mapping: A dictionary of word and a integer mapping as 
            key-value pair. Example, {'Hello,': 1, 'the': 2, 'let': 3}
    """
    if num_words > -1:
        int_mapping = {
            w:i+1 for i, (w, c) in enumerate(input_words) \
                if i <= num_words - 1 # -1 to avoid getting (num_words + 1) integer mapping.
        }
    else:
        int_mapping = {w:i+1 for i, (w, c) in enumerate(input_words)}
    return int_mapping

### Creating a pytorch dataset


In [83]:
class ClassificationDataset(Dataset):
    def __init__(self, file_paths, word_frequency, int_mapping, max_len) -> None:
        self.word_frequency = word_frequency
        self.int_mapping = int_mapping
        self.max_len = max_len
        self.file_paths = file_paths

    def standardize_text(self, input_text: str) -> str:
        text = input_text.lower()
        text = re.sub(r"<^>+>+", "", text)
        text = "".join(
            [character for character in text if character not in string.punctuation]
        )
        return text

    def return_int_vector(self, int_mapping, text_file_path):
        """
        Map word to int in list
        """
        with open(text_file_path, "r") as f:
            text = f.read()
            text = self.standardize_text(text)
            corpus = [word for word in text.split()]
        int_vector = [int_mapping[word] for word in text.split() if word in int_mapping]
        return int_vector

    def pad_features(self, int_vector, max_len):
        """
        Return features of `int_vector`, where each vector is padded
        with 0's or truncated to the input seq_length. Return as Numpy
        array.
        """
        features = np.zeros((1, max_len), dtype=int)
        if len(int_vector) <= max_len:
            zeros = list(np.zeros(max_len - len(int_vector)))
            new = zeros + int_vector
        else:
            new = int_vector[:max_len]
        features = np.array(new)
        return features

    def encode_labels(self, text_file_path):
        file_path = pathlib.Path(text_file_path)
        class_label = str(file_path).split(os.path.sep)[-2]
        if class_label == "pos":
            int_label = 1
        else:
            int_label = 0
        return int_label

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        int_vector = self.return_int_vector(self.int_mapping, file_path)
        padded_features = self.pad_features(int_vector, self.max_len)
        label = self.encode_labels(file_path)
        return {
            "text": torch.tensor(padded_features, dtype=torch.int32),
            "label": torch.tensor(label, dtype=torch.long),
        }

In [84]:
file_paths = []
file_paths.extend(glob.glob(os.path.join(
    dataset_dir, 'train', 'pos', '*.txt'
)))
file_paths.extend(glob.glob(os.path.join(
    dataset_dir, 'train', 'neg', '*.txt'
)))
test_file_paths = []
test_file_paths.extend(glob.glob(os.path.join(
    dataset_dir, 'test', 'pos', '*.txt'
)))
test_file_paths.extend(glob.glob(os.path.join(
    dataset_dir, 'test', 'neg', '*.txt'
)))

In [85]:
word_frequency = find_word_frequency(file_paths)
int_mapping = word2int(word_frequency, num_words=NUM_WORDS)
print(f'The word freq is {word_frequency}')

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [86]:
dataset = ClassificationDataset(
    file_paths, word_frequency, int_mapping, MAX_LEN
)
dataset_size = len(dataset)
# Calculate the validation dataset size.
valid_size = int(VALID_SPLIT*dataset_size)
# Radomize the data indices.
indices = torch.randperm(len(dataset)).tolist()
# Training and validation sets.
dataset_train = Subset(dataset, indices[:-valid_size])
dataset_valid = Subset(dataset, indices[-valid_size:])
dataset_test = ClassificationDataset(
    test_file_paths, word_frequency, int_mapping, MAX_LEN
)
# dataset_valid = NLPClassificationDataset()
print(f"Number of training samples: {len(dataset_train)}")
print(f"Number of validation samples: {len(dataset_valid)}")
print(f"Number of test samples: {len(dataset_test)}")

Number of training samples: 18750
Number of validation samples: 6250
Number of test samples: 25000


In [87]:
int_to_word_train = {value: key for key, value in int_mapping.items()}
inputs = ''
for x in dataset_train[0]['text']:
    if x != 0:
        inputs += ' ' + int_to_word_train[int(x)]
print(inputs)
print('#'*25)
if int(dataset_train[0]['label']) == 1:
    label = 'Positive'
else:
    label = 'Negative'
print('Label:', label)

 what a stunning episode for this fine series this is television excellence at its best the story takes place in 1968 and its beautifully filmed in black white almost a film noir style with its deep shadows and stark images this is a story about two men who fall in love but i dont want to spoil this it is a rare presentation of what homosexuals faced in the 1960s in america written by the superb tom and directed by the great we move through their lives their love for each other and their tragedy taking on such a sensitive issue makes this episode all the more stunning our emotions are as torn and on edge as the characters chills ran up my spine at the end when they played bob gorgeous ah but i was so much older then im younger than that now as sung by the this one goes far past a 10 and all the way to the stars beautiful
#########################
Label: Positive


### Preparing the train and test data loader

In [88]:
train_loader = DataLoader(
    dataset_train, 
    batch_size=BATCH_SIZE,
    shuffle=True, 
    num_workers=4
)
valid_loader = DataLoader(
    dataset_valid, 
    batch_size=BATCH_SIZE,
    shuffle=False, 
    num_workers=4
)
test_loader = DataLoader(
    dataset_test, 
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4
)

In [89]:
def binary_accuracy(labels, outputs, train_running_correct):
    outputs = torch.sigmoid(outputs)
    running_correct = 0
    for i, label in enumerate(labels):
        if label < 0.5 and outputs[i] < 0.5:
            running_correct += 1
        elif label >= 0.5 and outputs[i] >= 0.5:
            running_correct += 1
    return running_correct

In [90]:
def train(model, trainloader, optimizer, criterion, device):
    model.train()
    print('Training........')
    train_running_loss = 0.0
    train_running_correct = 0
    counter = 0
    for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):
        counter += 1
        inputs, labels = data['text'], data['label']
        inputs = inputs.to(device)
        labels = torch.tensor(labels, dtype=torch.float32).to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        outputs = torch.squeeze(outputs, -1)
        loss = criterion(outputs, labels)
        train_running_loss += loss.item()
        running_correct = binary_accuracy(
            labels, outputs, train_running_correct
        )
        train_running_correct += running_correct
        loss.backward()
        optimizer.step()
    
    epoch_loss = train_running_loss / counter
    epoch_acc = 100. * (train_running_correct / len(trainloader.dataset))
    return epoch_loss, epoch_acc
# Validation function.
def validate(model, testloader, criterion, device):
    model.eval()
    print('Validation')
    valid_running_loss = 0.0
    valid_running_correct = 0
    counter = 0
    
    with torch.no_grad():
        for i, data in tqdm(enumerate(testloader), total=len(testloader)):
            counter += 1
            inputs, labels = data['text'], data['label']
            inputs = inputs.to(device)
            labels = torch.tensor(labels, dtype=torch.float32).to(device)
            outputs = model(inputs)
            outputs = torch.squeeze(outputs, -1)
            loss = criterion(outputs, labels)
            valid_running_loss += loss.item()
            running_correct = binary_accuracy(
                labels, outputs, valid_running_correct
            )
            valid_running_correct += running_correct
        
    epoch_loss = valid_running_loss / counter
    epoch_acc = 100. * (valid_running_correct / len(testloader.dataset))
    return epoch_loss, epoch_acc

In [95]:
 class TextEmbedding(nn.Module):
    def __init__(self, vocab_size, max_len, embed_dim):
        super(TextEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim=embed_dim)
        self.linear1 = nn.Linear(max_len, 1)
        
        self.dropout = nn.Dropout(0.5)
    def forward(self, x):
        x = self.embedding(x)
        x = self.dropout(x)
        bs, _, _ = x.shape
        x = F.adaptive_avg_pool1d(x, 1).reshape(bs, -1)
        out = self.linear1(x)
        return out   

In [96]:
EMBED_DIM = 50
model = TextEmbedding(
    len(int_mapping)+1, 
    MAX_LEN,
    EMBED_DIM
).to(device)

### Actually Training the model

In [97]:
print(model)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(
    model.parameters(), 
    lr=0.001,
)
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.\n")

TextEmbedding(
  (embedding): Embedding(280618, 50)
  (linear1): Linear(in_features=2450, out_features=1, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
14,033,351 total parameters.
14,033,351 training parameters.



In [98]:
epochs = 50
train_loss, valid_loss = [], []
train_acc, valid_acc = [], []
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss, train_epoch_acc = train(model, train_loader, 
                                            optimizer, criterion, device)
    valid_epoch_loss, valid_epoch_acc = validate(model, valid_loader,  
                                                criterion, device)
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    train_acc.append(train_epoch_acc)
    valid_acc.append(valid_epoch_acc)
    print(f"Training loss: {train_epoch_loss}, training acc: {train_epoch_acc}")
    print(f"Validation loss: {valid_epoch_loss}, validation acc: {valid_epoch_acc}")
    # Save model.
    torch.save(
        model, os.path.join(OUTPUTS_DIR, 'model.pth')
    )
    print('-'*50)

Epoch 1 of 50
Training........


  labels = torch.tensor(labels, dtype=torch.float32).to(device)
100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  8.13it/s]

Validation



  labels = torch.tensor(labels, dtype=torch.float32).to(device)
100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.52it/s]


Training loss: 0.6971754512271365, training acc: 50.34133333333334
Validation loss: 0.6925831849758441, validation acc: 51.632
--------------------------------------------------
Epoch 2 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  8.98it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.54it/s]


Training loss: 0.6946328491777987, training acc: 51.136
Validation loss: 0.692703888966487, validation acc: 51.632
--------------------------------------------------
Epoch 3 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  8.76it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 12.94it/s]


Training loss: 0.6920583634763151, training acc: 52.202666666666666
Validation loss: 0.6910931834807763, validation acc: 53.888000000000005
--------------------------------------------------
Epoch 4 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  9.15it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 12.39it/s]


Training loss: 0.6883447073601388, training acc: 53.99466666666667
Validation loss: 0.6876389659368075, validation acc: 52.528
--------------------------------------------------
Epoch 5 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  9.16it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 12.80it/s]


Training loss: 0.6823924828220058, training acc: 56.250666666666675
Validation loss: 0.6786368351716262, validation acc: 55.296
--------------------------------------------------
Epoch 6 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  8.91it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.73it/s]


Training loss: 0.6684939667985246, training acc: 60.848
Validation loss: 0.660184057859274, validation acc: 63.983999999999995
--------------------------------------------------
Epoch 7 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  9.01it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.26it/s]


Training loss: 0.641685290916546, training acc: 66.672
Validation loss: 0.6314420241575974, validation acc: 70.48
--------------------------------------------------
Epoch 8 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:03<00:00,  9.44it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.37it/s]


Training loss: 0.6061335834296974, training acc: 70.98133333333332
Validation loss: 0.5920733075875503, validation acc: 73.776
--------------------------------------------------
Epoch 9 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:03<00:00,  9.38it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.37it/s]


Training loss: 0.5633737412658898, training acc: 74.34666666666668
Validation loss: 0.5483204218057486, validation acc: 76.4
--------------------------------------------------
Epoch 10 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  8.87it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.64it/s]


Training loss: 0.5201743180687363, training acc: 77.312
Validation loss: 0.5100624286211454, validation acc: 78.256
--------------------------------------------------
Epoch 11 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  9.08it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.66it/s]


Training loss: 0.4822179022673014, training acc: 79.48266666666667
Validation loss: 0.478204417687196, validation acc: 80.11200000000001
--------------------------------------------------
Epoch 12 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:03<00:00,  9.44it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.31it/s]


Training loss: 0.4503240657819284, training acc: 81.08266666666667
Validation loss: 0.4510009197088388, validation acc: 81.328
--------------------------------------------------
Epoch 13 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:03<00:00,  9.44it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.46it/s]


Training loss: 0.42427539019971283, training acc: 82.37333333333333
Validation loss: 0.4284611138013693, validation acc: 82.368
--------------------------------------------------
Epoch 14 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  9.09it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 10.84it/s]


Training loss: 0.39865417093844024, training acc: 83.408
Validation loss: 0.4087434663222386, validation acc: 83.232
--------------------------------------------------
Epoch 15 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:03<00:00,  9.47it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.93it/s]


Training loss: 0.381512384156923, training acc: 84.17066666666668
Validation loss: 0.392691616828625, validation acc: 83.824
--------------------------------------------------
Epoch 16 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:03<00:00,  9.44it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.92it/s]


Training loss: 0.3609704488032573, training acc: 85.12533333333333
Validation loss: 0.37935539392324596, validation acc: 84.208
--------------------------------------------------
Epoch 17 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:03<00:00,  9.47it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.83it/s]


Training loss: 0.3446130905602429, training acc: 85.88799999999999
Validation loss: 0.3671519183195554, validation acc: 84.816
--------------------------------------------------
Epoch 18 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:03<00:00,  9.47it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.76it/s]


Training loss: 0.33018439927616633, training acc: 86.75733333333334
Validation loss: 0.35730383487848133, validation acc: 85.232
--------------------------------------------------
Epoch 19 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  8.11it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.84it/s]


Training loss: 0.3185387004066158, training acc: 87.36533333333333
Validation loss: 0.3494497835636139, validation acc: 85.37599999999999
--------------------------------------------------
Epoch 20 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:03<00:00,  9.54it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.79it/s]


Training loss: 0.308582067489624, training acc: 87.62666666666667
Validation loss: 0.341267296901116, validation acc: 85.792
--------------------------------------------------
Epoch 21 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:03<00:00,  9.55it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.12it/s]


Training loss: 0.2971411734013944, training acc: 88.29333333333334
Validation loss: 0.335135810650312, validation acc: 86.112
--------------------------------------------------
Epoch 22 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  9.23it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.95it/s]


Training loss: 0.28704881023716283, training acc: 88.848
Validation loss: 0.3291157942551833, validation acc: 86.44800000000001
--------------------------------------------------
Epoch 23 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:03<00:00,  9.41it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.75it/s]


Training loss: 0.28036722621402227, training acc: 88.976
Validation loss: 0.32395079273443955, validation acc: 86.64
--------------------------------------------------
Epoch 24 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:03<00:00,  9.48it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 12.71it/s]


Training loss: 0.27146198217933243, training acc: 89.14666666666666
Validation loss: 0.3192700674900642, validation acc: 86.688
--------------------------------------------------
Epoch 25 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  8.77it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.80it/s]


Training loss: 0.2630383211213189, training acc: 89.61066666666667
Validation loss: 0.31593991013673633, validation acc: 86.92800000000001
--------------------------------------------------
Epoch 26 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.46it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  9.97it/s]


Training loss: 0.2579832830139109, training acc: 89.80266666666667
Validation loss: 0.31201120752554673, validation acc: 87.072
--------------------------------------------------
Epoch 27 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.70it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.14it/s]


Training loss: 0.25213703553418854, training acc: 90.144
Validation loss: 0.30880354459469134, validation acc: 86.896
--------------------------------------------------
Epoch 28 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.82it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 10.44it/s]


Training loss: 0.2430432877830557, training acc: 90.49066666666667
Validation loss: 0.3056671481866103, validation acc: 87.248
--------------------------------------------------
Epoch 29 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.75it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.37it/s]


Training loss: 0.23808520750419512, training acc: 90.82666666666667
Validation loss: 0.3029696735051962, validation acc: 87.136
--------------------------------------------------
Epoch 30 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.71it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.37it/s]


Training loss: 0.2336964764305063, training acc: 91.01333333333334
Validation loss: 0.30079577977840716, validation acc: 87.248
--------------------------------------------------
Epoch 31 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.69it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.22it/s]


Training loss: 0.2289726790544149, training acc: 91.11466666666666
Validation loss: 0.298579073869265, validation acc: 87.408
--------------------------------------------------
Epoch 32 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.76it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 10.90it/s]


Training loss: 0.22445862881235173, training acc: 91.24799999999999
Validation loss: 0.2969231972327599, validation acc: 87.36
--------------------------------------------------
Epoch 33 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.78it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.09it/s]


Training loss: 0.21973957725473353, training acc: 91.45066666666666
Validation loss: 0.29499539274435777, validation acc: 87.424
--------------------------------------------------
Epoch 34 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.69it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.31it/s]


Training loss: 0.2142953345099011, training acc: 91.632
Validation loss: 0.2935543633424319, validation acc: 87.6
--------------------------------------------------
Epoch 35 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.79it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.55it/s]


Training loss: 0.21124208940042033, training acc: 91.92
Validation loss: 0.29225873717894923, validation acc: 87.68
--------------------------------------------------
Epoch 36 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.46it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.19it/s]


Training loss: 0.20243493447432648, training acc: 92.4
Validation loss: 0.29101309868005604, validation acc: 87.712
--------------------------------------------------
Epoch 37 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.70it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 10.93it/s]


Training loss: 0.20210852051103437, training acc: 92.31466666666667
Validation loss: 0.2902596569978274, validation acc: 87.872
--------------------------------------------------
Epoch 38 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.78it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.63it/s]


Training loss: 0.19813975893162392, training acc: 92.58133333333333
Validation loss: 0.28896288917614865, validation acc: 87.872
--------------------------------------------------
Epoch 39 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.63it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.36it/s]


Training loss: 0.1939967528388307, training acc: 92.66133333333333
Validation loss: 0.28811487784752476, validation acc: 87.968
--------------------------------------------------
Epoch 40 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.81it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.62it/s]


Training loss: 0.19200391543878093, training acc: 92.69333333333334
Validation loss: 0.28725635088407075, validation acc: 87.856
--------------------------------------------------
Epoch 41 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.45it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.24it/s]


Training loss: 0.18313315832937085, training acc: 93.12533333333334
Validation loss: 0.2866937563969539, validation acc: 87.936
--------------------------------------------------
Epoch 42 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.47it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 10.22it/s]


Training loss: 0.18261135631316416, training acc: 93.22666666666667
Validation loss: 0.2861699370237497, validation acc: 88.0
--------------------------------------------------
Epoch 43 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.82it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.25it/s]


Training loss: 0.18148557398770307, training acc: 93.09866666666666
Validation loss: 0.28575536150198716, validation acc: 88.096
--------------------------------------------------
Epoch 44 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.64it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.48it/s]


Training loss: 0.1800202670935038, training acc: 93.27466666666666
Validation loss: 0.2852632082425631, validation acc: 88.128
--------------------------------------------------
Epoch 45 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.83it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 10.81it/s]


Training loss: 0.17295608770202947, training acc: 93.73333333333333
Validation loss: 0.28552706425006574, validation acc: 87.936
--------------------------------------------------
Epoch 46 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.85it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.31it/s]


Training loss: 0.16790141406896952, training acc: 93.78666666666666
Validation loss: 0.28470727572074306, validation acc: 88.096
--------------------------------------------------
Epoch 47 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.88it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.04it/s]


Training loss: 0.16515989883525953, training acc: 93.96266666666668
Validation loss: 0.284686994094115, validation acc: 88.064
--------------------------------------------------
Epoch 48 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.64it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 10.31it/s]


Training loss: 0.16597838498450615, training acc: 93.70666666666668
Validation loss: 0.28439922286913943, validation acc: 88.048
--------------------------------------------------
Epoch 49 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.87it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.15it/s]


Training loss: 0.16295965237391963, training acc: 94.128
Validation loss: 0.2843192689693891, validation acc: 88.128
--------------------------------------------------
Epoch 50 of 50
Training........


100%|█████████████████████████████████████████████████████████████| 37/37 [00:04<00:00,  7.80it/s]

Validation



100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 11.37it/s]


Training loss: 0.15676193382288958, training acc: 94.17066666666668
Validation loss: 0.28474500660712904, validation acc: 88.144
--------------------------------------------------


In [None]:
def save_plots(train_acc, valid_acc, train_loss, valid_loss):
    """
    Function to save the loss and accuracy plots to disk.
    """
    plt.figure(figsize=(10, 7))
    plt.plot(
        train_acc, color='blue', linestyle='-', 
        label='train accuracy'
    )
    plt.plot(
        valid_acc, color='red', linestyle='-', 
        label='validataion accuracy'
    )
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.savefig(os.path.join(OUTPUTS_DIR, 'accuracy.png'))
    plt.show()
    # Loss plots.
    plt.figure(figsize=(10, 7))
    plt.plot(
        train_loss, color='blue', linestyle='-', 
        label='train loss'
    )
    plt.plot(
        valid_loss, color='red', linestyle='-', 
        label='validataion loss'
    )
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(os.path.join(OUTPUTS_DIR, 'loss.png'))
    plt.show() 
save_plots(train_acc, valid_acc, train_loss, valid_loss)