<a href="https://colab.research.google.com/github/t0matoOtk/ML2022-Spring/blob/main/ML2022Spring_HW2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Homework 2 Phoneme Classification**

* Slides: https://docs.google.com/presentation/d/1v6HkBWiJb8WNDcJ9_-2kwVstxUWml87b9CnA16Gdoio/edit?usp=sharing
* Kaggle: https://www.kaggle.com/c/ml2022spring-hw2
* Video: TBA


In [None]:
!nvidia-smi

Tue Feb 11 23:15:41 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   52C    P8             12W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## Download Data
Download the data (the main link is a GitHub release; the backup links point to Google Drive and Dropbox), then unzip it.

After running the following block, you should have:
- `libriphone/train_split.txt`
- `libriphone/train_labels.txt`
- `libriphone/test_split.txt`
- `libriphone/feat/train/*.pt`: training features
- `libriphone/feat/test/*.pt`: testing features

> **Note: if the links are dead, you can download the data directly from [Kaggle](https://www.kaggle.com/c/ml2022spring-hw2/data) and upload it to the workspace, or use [the Kaggle API](https://www.kaggle.com/general/74235) to download the data directly into Colab.**
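Before preprocessing, it can save time to confirm that the archive unpacked into the layout listed above. The sketch below assumes the data sits in `./libriphone`; the helper name `missing_paths` is hypothetical, and the label file name `train_labels.txt` is taken from `preprocess_data`, which opens `{mode}_labels.txt`. It only checks that each feature folder holds at least one `.pt` file, not that every utterance is present.

```python
from pathlib import Path

# Expected files and feature folders, relative to the dataset root.
EXPECTED_FILES = ["train_split.txt", "train_labels.txt", "test_split.txt"]
EXPECTED_FEAT_DIRS = ["feat/train", "feat/test"]

def missing_paths(root="libriphone"):
    """Return the expected paths under `root` that are absent (hypothetical helper)."""
    root = Path(root)
    missing = [name for name in EXPECTED_FILES if not (root / name).is_file()]
    for sub in EXPECTED_FEAT_DIRS:
        folder = root / sub
        # A feature folder counts as missing if it does not exist or contains no .pt files.
        if not folder.is_dir() or not any(folder.glob("*.pt")):
            missing.append(sub + "/*.pt")
    return missing
```

In the notebook, `print(missing_paths() or "all expected paths found")` right after the download cell would show what, if anything, is still missing.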


### Download train/test metadata

In [None]:
# Main link
!wget -O libriphone.zip "https://github.com/xraychen/shiny-robot/releases/download/v1.0/libriphone.zip"

# Backup Link 0
# !pip install --upgrade gdown
# !gdown --id '1o6Ag-G3qItSmYhTheX6DYiuyNzWyHyTc' --output libriphone.zip

# Backup link 1
# !pip install --upgrade gdown
# !gdown --id '1R1uQYi4QpX0tBfUWt2mbZcncdBsJkxeW' --output libriphone.zip

# Backup link 2
# !wget -O libriphone.zip "https://www.dropbox.com/s/wqww8c5dbrl2ka9/libriphone.zip?dl=1"

# Backup link 3
# !wget -O libriphone.zip "https://www.dropbox.com/s/p2ljbtb2bam13in/libriphone.zip?dl=1"

!unzip -q libriphone.zip
!ls libriphone

--2025-02-11 23:15:41--  https://github.com/xraychen/shiny-robot/releases/download/v1.0/libriphone.zip
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 302 Found

### Preparing Data

**Helper functions to pre-process the training data from raw MFCC features of each utterance.**

A phoneme may span several frames and depends on past and future frames. \
Hence we concatenate neighboring frames for training to achieve higher accuracy. The **concat_feat** function concatenates the past and future k frames around each frame (2k+1 = n frames in total), and we predict the phoneme of the center frame. At the start and end of an utterance, the first or last frame is repeated to fill the window.
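To make the windowing concrete, here is a plain-Python sketch of the same frame concatenation that can be used to check `concat_feat` on tiny inputs. The helper name `concat_frames_reference` is hypothetical and not part of the homework code; it follows my reading of `shift`, where indices before the first frame or past the last frame are clamped, so edge frames are repeated and every input frame still yields exactly one output row.

```python
def concat_frames_reference(frames, concat_n):
    """Plain-Python reference for frame concatenation (hypothetical helper).

    frames: list of T frames, each a list of feature values.
    Returns T rows; row t joins frames t-k .. t+k (k = concat_n // 2),
    clamping out-of-range indices to the first/last frame.
    """
    assert concat_n % 2 == 1, "concat_n must be odd"
    k = concat_n // 2
    last = len(frames) - 1
    rows = []
    for t in range(len(frames)):
        row = []
        for offset in range(-k, k + 1):
            # Repeat the edge frame instead of dropping frames near the boundaries.
            src = min(max(t + offset, 0), last)
            row.extend(frames[src])
        rows.append(row)
    return rows
```

For example, `concat_frames_reference([[0], [1], [2]], 3)` returns `[[0, 0, 1], [0, 1, 2], [1, 2, 2]]`; passing `torch.tensor([[0], [1], [2]])` to `concat_feat(..., 3)` should give matching rows.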

Feel free to modify the data preprocessing functions, but **do not drop any frame** (if you modify them, check that the number of frames is the same as stated in the slides).

In [None]:
import os
import random
import pandas as pd
import torch
from tqdm import tqdm

def load_feat(path):
    feat = torch.load(path)
    return feat

def shift(x, n):
    if n < 0:
        left = x[0].repeat(-n, 1)
        right = x[:n]

    elif n > 0:
        right = x[-1].repeat(n, 1)
        left = x[n:]
    else:
        return x

    return torch.cat((left, right), dim=0)

def concat_feat(x, concat_n):
    assert concat_n % 2 == 1 # n must be odd
    if concat_n < 2:
        return x
    seq_len, feature_dim = x.size(0), x.size(1)
    x = x.repeat(1, concat_n)
    x = x.view(seq_len, concat_n, feature_dim).permute(1, 0, 2) # concat_n, seq_len, feature_dim
    mid = (concat_n // 2)
    for r_idx in range(1, mid+1):
        x[mid + r_idx, :] = shift(x[mid + r_idx], r_idx)
        x[mid - r_idx, :] = shift(x[mid - r_idx], -r_idx)

    return x.permute(1, 0, 2).view(seq_len, concat_n * feature_dim)

def preprocess_data(split, feat_dir, phone_path, concat_nframes, train_ratio=0.8, train_val_seed=1337):
    class_num = 41 # NOTE: pre-computed, should not need change
    mode = 'train' if (split == 'train' or split == 'val') else 'test'

    # map utterance ID -> list of per-frame phoneme labels (no labels for the test set)
    label_dict = {}
    if mode != 'test':
        with open(os.path.join(phone_path, f'{mode}_labels.txt')) as f:
            phone_file = f.readlines()

        for line in phone_file:
            line = line.strip('\n').split(' ')
            label_dict[line[0]] = [int(p) for p in line[1:]]

    if split == 'train' or split == 'val':
        # split training and validation data by utterance (deterministic for a fixed seed)
        with open(os.path.join(phone_path, 'train_split.txt')) as f:
            usage_list = f.readlines()
        random.seed(train_val_seed)
        random.shuffle(usage_list)
        percent = int(len(usage_list) * train_ratio)
        usage_list = usage_list[:percent] if split == 'train' else usage_list[percent:]
    elif split == 'test':
        with open(os.path.join(phone_path, 'test_split.txt')) as f:
            usage_list = f.readlines()
    else:
        raise ValueError('Invalid \'split\' argument for dataset: PhoneDataset!')

    usage_list = [line.strip('\n') for line in usage_list]
    print('[Dataset] - # phone classes: ' + str(class_num) + ', number of utterances for ' + split + ': ' + str(len(usage_list)))

    # preallocate buffers large enough for all frames, then trim to the filled length
    max_len = 3000000
    X = torch.empty(max_len, 39 * concat_nframes)
    if mode != 'test':
        y = torch.empty(max_len, dtype=torch.long)

    idx = 0
    for i, fname in tqdm(enumerate(usage_list)):
        feat = load_feat(os.path.join(feat_dir, mode, f'{fname}.pt'))
        cur_len = len(feat)
        feat = concat_feat(feat, concat_nframes)
        if mode != 'test':
            label = torch.LongTensor(label_dict[fname])

        X[idx: idx + cur_len, :] = feat
        if mode != 'test':
            y[idx: idx + cur_len] = label

        idx += cur_len

    X = X[:idx, :]
    if mode != 'test':
        y = y[:idx]

    print(f'[INFO] {split} set')
    print(X.shape)
    if mode != 'test':
        print(y.shape)
        return X, y
    else:
        return X
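The train/validation split in `preprocess_data` is made over utterance IDs rather than individual frames, so all frames of one utterance land on the same side and neighboring frames cannot leak between the two sets. The sketch below isolates that step; the name `split_utterances` is hypothetical. It uses a private `random.Random(seed)` instead of reseeding the global generator, which with the same integer seed should produce the same permutation as `random.seed(seed)` followed by `random.shuffle`.

```python
import random

def split_utterances(ids, train_ratio=0.8, seed=1337):
    """Shuffle utterance IDs with a fixed seed and cut at train_ratio (hypothetical helper)."""
    ids = list(ids)            # copy so the caller's list is not shuffled in place
    rng = random.Random(seed)  # local generator; preprocess_data reseeds the global one instead
    rng.shuffle(ids)
    cut = int(len(ids) * train_ratio)
    return ids[:cut], ids[cut:]
```

The logs further down print 3857 training and 429 validation utterances, i.e. 4286 in total; with `train_ratio = 0.9`, `int(4286 * 0.9)` is 3857, consistent with that output.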


## Define Dataset

In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class LibriDataset(Dataset):
    def __init__(self, X, y=None):
        self.data = X
        if y is not None:
            self.label = torch.LongTensor(y)
        else:
            self.label = None

    def __getitem__(self, idx):
        if self.label is not None:
            return self.data[idx], self.label[idx]
        else:
            return self.data[idx]

    def __len__(self):
        return len(self.data)


## Define Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(BasicBlock, self).__init__()

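        # dropout_prob is read from the global defined in the Hyper-parameters cell,
        # so that cell must run before a model is constructed
        self.block = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
            nn.BatchNorm1d(output_dim),  # add BatchNorm
            nn.Dropout(dropout_prob)     # add Dropout
        )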

    def forward(self, x):
        x = self.block(x)
        return x


class Classifier(nn.Module):
    def __init__(self, input_dim, output_dim=41, hidden_layers=1, hidden_dim=256):
        super(Classifier, self).__init__()

        self.fc = nn.Sequential(
            BasicBlock(input_dim, hidden_dim),
            *[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)],
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        x = self.fc(x)
        return x
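A quick way to sanity-check the model size is to count parameters by hand. The sketch below (hypothetical helper `count_classifier_params`) mirrors the layer list above; in the notebook the same total should come from `sum(p.numel() for p in model.parameters())`. Note that `hidden_layers` counts only the hidden-to-hidden blocks, so `hidden_layers = 2` gives three `BasicBlock`s in total.

```python
def count_classifier_params(input_dim, output_dim=41, hidden_layers=1, hidden_dim=256):
    """Count trainable parameters of Classifier as defined above (hypothetical helper).

    Each BasicBlock = Linear(in, out) weight + bias, plus BatchNorm1d(out) weight + bias.
    ReLU and Dropout have no parameters; BatchNorm running statistics are buffers.
    """
    def block(in_dim, out_dim):
        return in_dim * out_dim + out_dim + 2 * out_dim

    total = block(input_dim, hidden_dim)                    # first block: input -> hidden
    total += hidden_layers * block(hidden_dim, hidden_dim)  # hidden -> hidden blocks
    total += hidden_dim * output_dim + output_dim           # final Linear layer
    return total

print(count_classifier_params(39 * 21, hidden_layers=2, hidden_dim=1024))  # 2987049
```

With the hyper-parameters used in this notebook, the model therefore has roughly 3 million trainable parameters.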

## Hyper-parameters

In [None]:
# data parameters
concat_nframes = 21             # the number of frames to concat with, n must be odd (total 2k+1 = n frames)
train_ratio = 0.9               # the ratio of data used for training, the rest will be used for validation

# training parameters
seed = 0                        # random seed
batch_size = 512                # batch size
num_epoch = 125                 # the number of training epochs
learning_rate = 0.0001          # learning rate
model_path = './model.ckpt'     # the path where the checkpoint will be saved

# model parameters
input_dim = 39 * concat_nframes # the input dim of the model, you should not change the value
hidden_layers = 2               # the number of hidden-to-hidden BasicBlocks (hidden_layers + 1 blocks in total)
hidden_dim = 1024               # the hidden dim
dropout_prob = 0.5              # dropout probability
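Only the window size `concat_nframes` is a free choice here; `k` and `input_dim` follow from it. A small check like the hypothetical `derived_sizes` below catches an even window before the slow preprocessing step; the per-frame feature size of 39 is taken from `preprocess_data`.

```python
def derived_sizes(concat_nframes, feat_dim=39):
    """Return (k, input_dim) for a window of concat_nframes = 2k + 1 frames (hypothetical helper)."""
    if concat_nframes < 1 or concat_nframes % 2 == 0:
        raise ValueError("concat_nframes must be a positive odd number (n = 2k + 1)")
    k = concat_nframes // 2  # frames taken on each side of the center frame
    return k, feat_dim * concat_nframes

print(derived_sizes(21))  # (10, 819)
```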

## Prepare dataset and model

In [None]:
import gc

# preprocess data
train_X, train_y = preprocess_data(split='train', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio)
val_X, val_y = preprocess_data(split='val', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio)

# get dataset
train_set = LibriDataset(train_X, train_y)
val_set = LibriDataset(val_X, val_y)

# remove raw feature to save memory
del train_X, train_y, val_X, val_y
gc.collect()

# get dataloader
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

[Dataset] - # phone classes: 41, number of utterances for train: 3857


3857it [00:12, 305.22it/s]


[INFO] train set
torch.Size([2379588, 819])
torch.Size([2379588])
[Dataset] - # phone classes: 41, number of utterances for val: 429


429it [00:02, 192.99it/s]


[INFO] val set
torch.Size([264570, 819])
torch.Size([264570])
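The progress-bar lengths in the training log follow from the sizes printed above: `DataLoader` keeps a final partial batch by default (`drop_last=False`), so the number of batches is the ceiling of frames / batch size. The sketch below (hypothetical helpers) also estimates the nominal float32 size of the feature matrices and of the `max_len` buffer preallocated in `preprocess_data`, which is useful when judging whether the data fits in RAM.

```python
import math

def batches_per_epoch(num_frames, batch_size):
    """Number of DataLoader batches with drop_last=False: the last batch may be partial."""
    return math.ceil(num_frames / batch_size)

def float32_gib(rows, cols):
    """Nominal size in GiB of a float32 tensor of shape (rows, cols), at 4 bytes per element."""
    return rows * cols * 4 / 2**30

print(batches_per_epoch(2379588, 512))      # 4648, length of each training progress bar
print(batches_per_epoch(264570, 512))       # 517, length of each validation progress bar
print(round(float32_gib(2379588, 819), 2))  # 7.26, training features
print(round(float32_gib(3000000, 819), 2))  # 9.15, the max_len buffer in preprocess_data
```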


In [None]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'DEVICE: {device}')

DEVICE: cuda:0


In [None]:
import numpy as np

# fix random seeds for torch and numpy (Python's random module is seeded separately in preprocess_data)
def same_seeds(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
# fix random seed
same_seeds(seed)

# create model, define a loss function, and optimizer
model = Classifier(input_dim=input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
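The training loop in the next section reports accuracy as total correct frames divided by the dataset size, but loss as the mean of per-batch mean losses, so each batch's loss counts equally, including the smaller final batch. It saves a checkpoint only when validation accuracy strictly improves. The hypothetical helpers below reproduce both rules in plain Python so they can be checked against the printed log.

```python
def epoch_metrics(batch_results, dataset_size):
    """Aggregate (num_correct, mean_batch_loss) tuples the way the training loop does."""
    total_correct = sum(correct for correct, _ in batch_results)
    mean_loss = sum(loss for _, loss in batch_results) / len(batch_results)
    return total_correct / dataset_size, mean_loss

def checkpoint_epochs(val_accs):
    """Return the 1-based epochs at which the loop would save, given per-epoch val accuracies."""
    best, saved = 0.0, []
    for epoch, acc in enumerate(val_accs, start=1):
        if acc > best:  # strict improvement only, as in `if val_acc > best_acc`
            best = acc
            saved.append(epoch)
    return saved
```

For instance, the validation accuracies of epochs 25 to 27 in the log (0.734263, 0.734116, 0.734569) give `checkpoint_epochs(...) == [1, 3]`, matching the log, where epoch 26 prints no "saving model" line. The loop compares raw correct counts rather than fractions, which is equivalent because the validation set size is fixed.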

## Training

In [None]:
best_acc = 0.0
for epoch in range(num_epoch):
    train_acc = 0.0
    train_loss = 0.0
    val_acc = 0.0
    val_loss = 0.0

    # training
    model.train() # set the model to training mode
    for i, batch in enumerate(tqdm(train_loader)):
        features, labels = batch
        features = features.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(features)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        _, train_pred = torch.max(outputs, 1) # get the index of the class with the highest probability
        train_acc += (train_pred.detach() == labels.detach()).sum().item()
        train_loss += loss.item()

    # validation
    if len(val_set) > 0:
        model.eval() # set the model to evaluation mode
        with torch.no_grad():
            for i, batch in enumerate(tqdm(val_loader)):
                features, labels = batch
                features = features.to(device)
                labels = labels.to(device)
                outputs = model(features)

                loss = criterion(outputs, labels)

                _, val_pred = torch.max(outputs, 1) # get the index of the class with the highest probability
                val_acc += (val_pred == labels).sum().item()
                val_loss += loss.item()

            print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f} | Val Acc: {:3.6f} loss: {:3.6f}'.format(
                epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader), val_acc/len(val_set), val_loss/len(val_loader)
            ))

            # if the model improves, save a checkpoint at this epoch
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(model.state_dict(), model_path)
                print('saving model with acc {:.3f}'.format(best_acc/len(val_set)))
    else:
        print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f}'.format(
            epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader)
        ))

# if not validating, save the last epoch
if len(val_set) == 0:
    torch.save(model.state_dict(), model_path)
    print('saving model at last epoch')


100%|██████████| 4648/4648 [00:37<00:00, 123.09it/s]
100%|██████████| 517/517 [00:02<00:00, 207.83it/s]


[001/125] Train Acc: 0.536236 Loss: 1.565348 | Val Acc: 0.630003 loss: 1.201458
saving model with acc 0.630


100%|██████████| 4648/4648 [00:39<00:00, 118.99it/s]
100%|██████████| 517/517 [00:02<00:00, 210.99it/s]


[002/125] Train Acc: 0.602674 Loss: 1.297175 | Val Acc: 0.658982 loss: 1.096029
saving model with acc 0.659


100%|██████████| 4648/4648 [00:37<00:00, 123.98it/s]
100%|██████████| 517/517 [00:02<00:00, 204.92it/s]


[003/125] Train Acc: 0.623939 Loss: 1.218500 | Val Acc: 0.674483 loss: 1.042299
saving model with acc 0.674


100%|██████████| 4648/4648 [00:37<00:00, 124.63it/s]
100%|██████████| 517/517 [00:02<00:00, 172.87it/s]


[004/125] Train Acc: 0.636521 Loss: 1.171854 | Val Acc: 0.684042 loss: 1.004234
saving model with acc 0.684


100%|██████████| 4648/4648 [00:37<00:00, 123.52it/s]
100%|██████████| 517/517 [00:02<00:00, 207.91it/s]


[005/125] Train Acc: 0.645573 Loss: 1.138883 | Val Acc: 0.692902 loss: 0.975356
saving model with acc 0.693


100%|██████████| 4648/4648 [00:37<00:00, 124.25it/s]
100%|██████████| 517/517 [00:02<00:00, 207.27it/s]


[006/125] Train Acc: 0.652975 Loss: 1.113094 | Val Acc: 0.697218 loss: 0.957714
saving model with acc 0.697


100%|██████████| 4648/4648 [00:42<00:00, 109.46it/s]
100%|██████████| 517/517 [00:02<00:00, 194.35it/s]


[007/125] Train Acc: 0.658553 Loss: 1.093017 | Val Acc: 0.702177 loss: 0.939802
saving model with acc 0.702


100%|██████████| 4648/4648 [00:38<00:00, 121.36it/s]
100%|██████████| 517/517 [00:02<00:00, 200.00it/s]


[008/125] Train Acc: 0.663350 Loss: 1.075930 | Val Acc: 0.706735 loss: 0.925219
saving model with acc 0.707


100%|██████████| 4648/4648 [00:37<00:00, 123.38it/s]
100%|██████████| 517/517 [00:03<00:00, 166.14it/s]


[009/125] Train Acc: 0.667334 Loss: 1.062174 | Val Acc: 0.710470 loss: 0.913580
saving model with acc 0.710


100%|██████████| 4648/4648 [00:38<00:00, 121.49it/s]
100%|██████████| 517/517 [00:02<00:00, 207.26it/s]


[010/125] Train Acc: 0.670710 Loss: 1.050026 | Val Acc: 0.713339 loss: 0.901223
saving model with acc 0.713


100%|██████████| 4648/4648 [00:37<00:00, 122.88it/s]
100%|██████████| 517/517 [00:02<00:00, 196.17it/s]


[011/125] Train Acc: 0.673875 Loss: 1.039763 | Val Acc: 0.714926 loss: 0.895110
saving model with acc 0.715


100%|██████████| 4648/4648 [00:38<00:00, 121.55it/s]
100%|██████████| 517/517 [00:02<00:00, 198.13it/s]


[012/125] Train Acc: 0.676452 Loss: 1.030195 | Val Acc: 0.717867 loss: 0.887680
saving model with acc 0.718


100%|██████████| 4648/4648 [00:37<00:00, 122.66it/s]
100%|██████████| 517/517 [00:02<00:00, 198.57it/s]


[013/125] Train Acc: 0.678406 Loss: 1.022725 | Val Acc: 0.720611 loss: 0.879606
saving model with acc 0.721


100%|██████████| 4648/4648 [00:37<00:00, 124.03it/s]
100%|██████████| 517/517 [00:03<00:00, 154.34it/s]


[014/125] Train Acc: 0.680644 Loss: 1.014867 | Val Acc: 0.722338 loss: 0.874252
saving model with acc 0.722


100%|██████████| 4648/4648 [00:38<00:00, 120.12it/s]
100%|██████████| 517/517 [00:02<00:00, 197.72it/s]


[015/125] Train Acc: 0.682550 Loss: 1.008200 | Val Acc: 0.723752 loss: 0.867771
saving model with acc 0.724


100%|██████████| 4648/4648 [00:37<00:00, 124.55it/s]
100%|██████████| 517/517 [00:02<00:00, 213.21it/s]


[016/125] Train Acc: 0.684223 Loss: 1.001963 | Val Acc: 0.725063 loss: 0.862582
saving model with acc 0.725


100%|██████████| 4648/4648 [00:37<00:00, 124.72it/s]
100%|██████████| 517/517 [00:03<00:00, 167.12it/s]


[017/125] Train Acc: 0.685665 Loss: 0.997080 | Val Acc: 0.726900 loss: 0.858874
saving model with acc 0.727


100%|██████████| 4648/4648 [00:37<00:00, 122.50it/s]
100%|██████████| 517/517 [00:02<00:00, 193.82it/s]


[018/125] Train Acc: 0.686846 Loss: 0.991590 | Val Acc: 0.727176 loss: 0.855032
saving model with acc 0.727


100%|██████████| 4648/4648 [00:37<00:00, 124.22it/s]
100%|██████████| 517/517 [00:02<00:00, 197.19it/s]


[019/125] Train Acc: 0.688109 Loss: 0.987916 | Val Acc: 0.728866 loss: 0.849518
saving model with acc 0.729


100%|██████████| 4648/4648 [00:37<00:00, 124.00it/s]
100%|██████████| 517/517 [00:03<00:00, 171.26it/s]


[020/125] Train Acc: 0.689798 Loss: 0.982627 | Val Acc: 0.729727 loss: 0.847239
saving model with acc 0.730


100%|██████████| 4648/4648 [00:37<00:00, 123.88it/s]
100%|██████████| 517/517 [00:02<00:00, 210.25it/s]


[021/125] Train Acc: 0.691110 Loss: 0.978997 | Val Acc: 0.730234 loss: 0.843453
saving model with acc 0.730


100%|██████████| 4648/4648 [00:37<00:00, 124.77it/s]
100%|██████████| 517/517 [00:02<00:00, 178.77it/s]


[022/125] Train Acc: 0.692050 Loss: 0.974883 | Val Acc: 0.731655 loss: 0.840293
saving model with acc 0.732


100%|██████████| 4648/4648 [00:38<00:00, 121.55it/s]
100%|██████████| 517/517 [00:02<00:00, 178.92it/s]


[023/125] Train Acc: 0.693184 Loss: 0.971927 | Val Acc: 0.731753 loss: 0.838232
saving model with acc 0.732


100%|██████████| 4648/4648 [00:37<00:00, 124.79it/s]
100%|██████████| 517/517 [00:02<00:00, 211.35it/s]


[024/125] Train Acc: 0.693772 Loss: 0.967799 | Val Acc: 0.732834 loss: 0.836324
saving model with acc 0.733


100%|██████████| 4648/4648 [00:37<00:00, 124.58it/s]
100%|██████████| 517/517 [00:02<00:00, 212.38it/s]


[025/125] Train Acc: 0.695034 Loss: 0.964391 | Val Acc: 0.734263 loss: 0.833194
saving model with acc 0.734


100%|██████████| 4648/4648 [00:37<00:00, 123.01it/s]
100%|██████████| 517/517 [00:02<00:00, 184.95it/s]


[026/125] Train Acc: 0.696095 Loss: 0.962005 | Val Acc: 0.734116 loss: 0.831696


100%|██████████| 4648/4648 [00:38<00:00, 122.10it/s]
100%|██████████| 517/517 [00:02<00:00, 204.21it/s]


[027/125] Train Acc: 0.696476 Loss: 0.959457 | Val Acc: 0.734569 loss: 0.827758
saving model with acc 0.735


100%|██████████| 4648/4648 [00:38<00:00, 121.48it/s]
100%|██████████| 517/517 [00:03<00:00, 164.90it/s]


[028/125] Train Acc: 0.697052 Loss: 0.956448 | Val Acc: 0.736043 loss: 0.825463
saving model with acc 0.736


100%|██████████| 4648/4648 [00:37<00:00, 123.73it/s]
100%|██████████| 517/517 [00:02<00:00, 209.28it/s]


[029/125] Train Acc: 0.698222 Loss: 0.954208 | Val Acc: 0.736565 loss: 0.824060
saving model with acc 0.737


100%|██████████| 4648/4648 [00:37<00:00, 123.68it/s]
100%|██████████| 517/517 [00:02<00:00, 173.97it/s]


[030/125] Train Acc: 0.698809 Loss: 0.951784 | Val Acc: 0.736410 loss: 0.823657


100%|██████████| 4648/4648 [00:38<00:00, 121.78it/s]
100%|██████████| 517/517 [00:02<00:00, 186.75it/s]


[031/125] Train Acc: 0.699770 Loss: 0.949241 | Val Acc: 0.737344 loss: 0.820522
saving model with acc 0.737


100%|██████████| 4648/4648 [00:37<00:00, 124.50it/s]
100%|██████████| 517/517 [00:02<00:00, 211.58it/s]


[032/125] Train Acc: 0.699893 Loss: 0.947685 | Val Acc: 0.738039 loss: 0.817505
saving model with acc 0.738


100%|██████████| 4648/4648 [00:37<00:00, 124.39it/s]
100%|██████████| 517/517 [00:02<00:00, 202.16it/s]


[033/125] Train Acc: 0.700233 Loss: 0.945017 | Val Acc: 0.738538 loss: 0.817969
saving model with acc 0.739


100%|██████████| 4648/4648 [00:37<00:00, 123.15it/s]
100%|██████████| 517/517 [00:02<00:00, 193.14it/s]


[034/125] Train Acc: 0.701252 Loss: 0.942983 | Val Acc: 0.738954 loss: 0.815912
saving model with acc 0.739


100%|██████████| 4648/4648 [00:37<00:00, 122.55it/s]
100%|██████████| 517/517 [00:02<00:00, 211.32it/s]


[035/125] Train Acc: 0.701669 Loss: 0.940871 | Val Acc: 0.739585 loss: 0.815099
saving model with acc 0.740


100%|██████████| 4648/4648 [00:37<00:00, 124.83it/s]
100%|██████████| 517/517 [00:02<00:00, 184.79it/s]


[036/125] Train Acc: 0.702094 Loss: 0.939644 | Val Acc: 0.740496 loss: 0.811725
saving model with acc 0.740


100%|██████████| 4648/4648 [00:37<00:00, 124.07it/s]
100%|██████████| 517/517 [00:02<00:00, 199.29it/s]


[037/125] Train Acc: 0.702692 Loss: 0.937650 | Val Acc: 0.741029 loss: 0.811189
saving model with acc 0.741


100%|██████████| 4648/4648 [00:37<00:00, 124.87it/s]
100%|██████████| 517/517 [00:02<00:00, 198.25it/s]


[038/125] Train Acc: 0.703342 Loss: 0.934969 | Val Acc: 0.741059 loss: 0.810001
saving model with acc 0.741


100%|██████████| 4648/4648 [00:38<00:00, 122.08it/s]
100%|██████████| 517/517 [00:02<00:00, 190.06it/s]


[039/125] Train Acc: 0.703283 Loss: 0.934708 | Val Acc: 0.741569 loss: 0.808575
saving model with acc 0.742


100%|██████████| 4648/4648 [00:37<00:00, 123.95it/s]
100%|██████████| 517/517 [00:02<00:00, 210.88it/s]


[040/125] Train Acc: 0.704204 Loss: 0.932304 | Val Acc: 0.741656 loss: 0.807397
saving model with acc 0.742


100%|██████████| 4648/4648 [00:37<00:00, 125.10it/s]
100%|██████████| 517/517 [00:02<00:00, 201.56it/s]


[041/125] Train Acc: 0.704292 Loss: 0.931311 | Val Acc: 0.741925 loss: 0.806518
saving model with acc 0.742


100%|██████████| 4648/4648 [00:37<00:00, 124.93it/s]
100%|██████████| 517/517 [00:02<00:00, 191.98it/s]


[042/125] Train Acc: 0.705030 Loss: 0.929504 | Val Acc: 0.742658 loss: 0.803987
saving model with acc 0.743


100%|██████████| 4648/4648 [00:38<00:00, 121.69it/s]
100%|██████████| 517/517 [00:02<00:00, 210.33it/s]


[043/125] Train Acc: 0.705459 Loss: 0.928299 | Val Acc: 0.742851 loss: 0.804078
saving model with acc 0.743


100%|██████████| 4648/4648 [00:37<00:00, 124.92it/s]
100%|██████████| 517/517 [00:02<00:00, 209.35it/s]


[044/125] Train Acc: 0.705656 Loss: 0.927327 | Val Acc: 0.743346 loss: 0.802945
saving model with acc 0.743


100%|██████████| 4648/4648 [00:37<00:00, 124.95it/s]
100%|██████████| 517/517 [00:03<00:00, 169.24it/s]


[045/125] Train Acc: 0.706192 Loss: 0.925421 | Val Acc: 0.743380 loss: 0.802203
saving model with acc 0.743


100%|██████████| 4648/4648 [00:37<00:00, 124.18it/s]
100%|██████████| 517/517 [00:02<00:00, 208.90it/s]


[046/125] Train Acc: 0.706505 Loss: 0.923757 | Val Acc: 0.744215 loss: 0.801452
saving model with acc 0.744


100%|██████████| 4648/4648 [00:38<00:00, 121.82it/s]
100%|██████████| 517/517 [00:02<00:00, 210.88it/s]


[047/125] Train Acc: 0.706706 Loss: 0.922901 | Val Acc: 0.744196 loss: 0.799321


100%|██████████| 4648/4648 [00:37<00:00, 123.96it/s]
100%|██████████| 517/517 [00:02<00:00, 187.25it/s]


[048/125] Train Acc: 0.707021 Loss: 0.922119 | Val Acc: 0.743762 loss: 0.798886


100%|██████████| 4648/4648 [00:37<00:00, 123.91it/s]
100%|██████████| 517/517 [00:02<00:00, 199.26it/s]


[049/125] Train Acc: 0.707667 Loss: 0.920499 | Val Acc: 0.744230 loss: 0.798076
saving model with acc 0.744


100%|██████████| 4648/4648 [00:37<00:00, 122.51it/s]
100%|██████████| 517/517 [00:02<00:00, 204.25it/s]


[050/125] Train Acc: 0.707981 Loss: 0.918806 | Val Acc: 0.744567 loss: 0.798235
saving model with acc 0.745


100%|██████████| 4648/4648 [00:38<00:00, 121.60it/s]
100%|██████████| 517/517 [00:03<00:00, 165.10it/s]


[051/125] Train Acc: 0.708582 Loss: 0.917083 | Val Acc: 0.744877 loss: 0.795632
saving model with acc 0.745


100%|██████████| 4648/4648 [00:37<00:00, 124.38it/s]
100%|██████████| 517/517 [00:02<00:00, 199.92it/s]


[052/125] Train Acc: 0.708274 Loss: 0.917247 | Val Acc: 0.744790 loss: 0.796523


100%|██████████| 4648/4648 [00:37<00:00, 124.64it/s]
100%|██████████| 517/517 [00:02<00:00, 208.36it/s]


[053/125] Train Acc: 0.708963 Loss: 0.915988 | Val Acc: 0.746181 loss: 0.794254
saving model with acc 0.746


100%|██████████| 4648/4648 [00:37<00:00, 123.60it/s]
100%|██████████| 517/517 [00:03<00:00, 164.93it/s]


[054/125] Train Acc: 0.709088 Loss: 0.914648 | Val Acc: 0.745606 loss: 0.793664


100%|██████████| 4648/4648 [00:37<00:00, 122.77it/s]
100%|██████████| 517/517 [00:02<00:00, 207.38it/s]


[055/125] Train Acc: 0.709480 Loss: 0.913479 | Val Acc: 0.746045 loss: 0.792615


100%|██████████| 4648/4648 [00:37<00:00, 123.29it/s]
100%|██████████| 517/517 [00:02<00:00, 197.43it/s]


[056/125] Train Acc: 0.710065 Loss: 0.912454 | Val Acc: 0.745973 loss: 0.792843


100%|██████████| 4648/4648 [00:38<00:00, 121.75it/s]
100%|██████████| 517/517 [00:02<00:00, 201.20it/s]


[057/125] Train Acc: 0.710089 Loss: 0.911280 | Val Acc: 0.746593 loss: 0.791360
saving model with acc 0.747


100%|██████████| 4648/4648 [00:37<00:00, 123.77it/s]
100%|██████████| 517/517 [00:02<00:00, 199.63it/s]


[058/125] Train Acc: 0.710093 Loss: 0.911027 | Val Acc: 0.746714 loss: 0.790187
saving model with acc 0.747


100%|██████████| 4648/4648 [00:38<00:00, 121.33it/s]
100%|██████████| 517/517 [00:03<00:00, 163.94it/s]


[059/125] Train Acc: 0.710395 Loss: 0.909737 | Val Acc: 0.746547 loss: 0.790453


100%|██████████| 4648/4648 [00:37<00:00, 123.18it/s]
100%|██████████| 517/517 [00:02<00:00, 198.54it/s]


[060/125] Train Acc: 0.710762 Loss: 0.909123 | Val Acc: 0.746600 loss: 0.791052


100%|██████████| 4648/4648 [00:37<00:00, 123.65it/s]
100%|██████████| 517/517 [00:02<00:00, 209.00it/s]


[061/125] Train Acc: 0.711213 Loss: 0.907672 | Val Acc: 0.747080 loss: 0.789816
saving model with acc 0.747


100%|██████████| 4648/4648 [00:37<00:00, 122.72it/s]
100%|██████████| 517/517 [00:03<00:00, 165.17it/s]


[062/125] Train Acc: 0.711510 Loss: 0.907378 | Val Acc: 0.747194 loss: 0.788405
saving model with acc 0.747


100%|██████████| 4648/4648 [00:38<00:00, 121.71it/s]
100%|██████████| 517/517 [00:02<00:00, 197.04it/s]


[063/125] Train Acc: 0.711805 Loss: 0.906017 | Val Acc: 0.747530 loss: 0.787869
saving model with acc 0.748


100%|██████████| 4648/4648 [00:37<00:00, 123.50it/s]
100%|██████████| 517/517 [00:02<00:00, 182.93it/s]


[064/125] Train Acc: 0.711916 Loss: 0.904993 | Val Acc: 0.748074 loss: 0.787303
saving model with acc 0.748


100%|██████████| 4648/4648 [00:38<00:00, 121.89it/s]
100%|██████████| 517/517 [00:02<00:00, 201.85it/s]


[065/125] Train Acc: 0.711963 Loss: 0.904755 | Val Acc: 0.747519 loss: 0.786946


100%|██████████| 4648/4648 [00:37<00:00, 123.39it/s]
100%|██████████| 517/517 [00:02<00:00, 206.66it/s]


[066/125] Train Acc: 0.712782 Loss: 0.903205 | Val Acc: 0.748312 loss: 0.786008
saving model with acc 0.748


100%|██████████| 4648/4648 [00:38<00:00, 121.38it/s]
100%|██████████| 517/517 [00:03<00:00, 156.17it/s]


[067/125] Train Acc: 0.712493 Loss: 0.903464 | Val Acc: 0.748218 loss: 0.786170


100%|██████████| 4648/4648 [00:37<00:00, 123.28it/s]
100%|██████████| 517/517 [00:02<00:00, 209.80it/s]


[068/125] Train Acc: 0.712675 Loss: 0.902231 | Val Acc: 0.748471 loss: 0.784182
saving model with acc 0.748


100%|██████████| 4648/4648 [00:37<00:00, 123.18it/s]
100%|██████████| 517/517 [00:02<00:00, 208.72it/s]


[069/125] Train Acc: 0.713126 Loss: 0.901718 | Val Acc: 0.748528 loss: 0.785082
saving model with acc 0.749


100%|██████████| 4648/4648 [00:37<00:00, 122.92it/s]
100%|██████████| 517/517 [00:02<00:00, 172.62it/s]


[070/125] Train Acc: 0.713283 Loss: 0.900312 | Val Acc: 0.749442 loss: 0.783716
saving model with acc 0.749


100%|██████████| 4648/4648 [00:38<00:00, 121.23it/s]
100%|██████████| 517/517 [00:02<00:00, 201.11it/s]


[071/125] Train Acc: 0.713143 Loss: 0.899812 | Val Acc: 0.748944 loss: 0.783078


100%|██████████| 4648/4648 [00:37<00:00, 123.81it/s]
100%|██████████| 517/517 [00:02<00:00, 192.41it/s]


[072/125] Train Acc: 0.713266 Loss: 0.899726 | Val Acc: 0.749174 loss: 0.783356


100%|██████████| 4648/4648 [00:37<00:00, 123.12it/s]
100%|██████████| 517/517 [00:02<00:00, 206.89it/s]


[073/125] Train Acc: 0.713821 Loss: 0.898377 | Val Acc: 0.749726 loss: 0.781408
saving model with acc 0.750


100%|██████████| 4648/4648 [00:37<00:00, 123.72it/s]
100%|██████████| 517/517 [00:02<00:00, 206.97it/s]


[074/125] Train Acc: 0.713961 Loss: 0.897829 | Val Acc: 0.748826 loss: 0.781940


100%|██████████| 4648/4648 [00:38<00:00, 122.10it/s]
100%|██████████| 517/517 [00:03<00:00, 160.28it/s]


[075/125] Train Acc: 0.714050 Loss: 0.897442 | Val Acc: 0.749235 loss: 0.782944


100%|██████████| 4648/4648 [00:37<00:00, 123.64it/s]
100%|██████████| 517/517 [00:02<00:00, 210.01it/s]


[076/125] Train Acc: 0.714582 Loss: 0.896515 | Val Acc: 0.749877 loss: 0.780105
saving model with acc 0.750


100%|██████████| 4648/4648 [00:37<00:00, 123.08it/s]
100%|██████████| 517/517 [00:02<00:00, 207.96it/s]


[077/125] Train Acc: 0.714814 Loss: 0.895863 | Val Acc: 0.750327 loss: 0.780154
saving model with acc 0.750


100%|██████████| 4648/4648 [00:37<00:00, 123.56it/s]
100%|██████████| 517/517 [00:03<00:00, 154.21it/s]


[078/125] Train Acc: 0.714619 Loss: 0.895944 | Val Acc: 0.750225 loss: 0.779623


100%|██████████| 4648/4648 [00:38<00:00, 120.93it/s]
100%|██████████| 517/517 [00:02<00:00, 193.35it/s]


[079/125] Train Acc: 0.714614 Loss: 0.894854 | Val Acc: 0.750270 loss: 0.779515


100%|██████████| 4648/4648 [00:37<00:00, 122.59it/s]
100%|██████████| 517/517 [00:02<00:00, 208.14it/s]


[080/125] Train Acc: 0.715240 Loss: 0.894218 | Val Acc: 0.749575 loss: 0.779937


100%|██████████| 4648/4648 [00:38<00:00, 121.97it/s]
100%|██████████| 517/517 [00:02<00:00, 201.73it/s]


[081/125] Train Acc: 0.715133 Loss: 0.893461 | Val Acc: 0.749802 loss: 0.779406


100%|██████████| 4648/4648 [00:37<00:00, 124.34it/s]
100%|██████████| 517/517 [00:02<00:00, 197.49it/s]


[082/125] Train Acc: 0.715242 Loss: 0.893039 | Val Acc: 0.750690 loss: 0.779006
saving model with acc 0.751


100%|██████████| 4648/4648 [00:38<00:00, 122.29it/s]
100%|██████████| 517/517 [00:02<00:00, 184.42it/s]


[083/125] Train Acc: 0.715592 Loss: 0.892190 | Val Acc: 0.751098 loss: 0.777778
saving model with acc 0.751


100%|██████████| 4648/4648 [00:37<00:00, 122.57it/s]
100%|██████████| 517/517 [00:02<00:00, 204.66it/s]


[084/125] Train Acc: 0.715829 Loss: 0.892176 | Val Acc: 0.750569 loss: 0.778600


100%|██████████| 4648/4648 [00:37<00:00, 123.48it/s]
100%|██████████| 517/517 [00:02<00:00, 204.71it/s]


[085/125] Train Acc: 0.716130 Loss: 0.890922 | Val Acc: 0.751204 loss: 0.777053
saving model with acc 0.751


100%|██████████| 4648/4648 [00:37<00:00, 123.75it/s]
100%|██████████| 517/517 [00:03<00:00, 158.55it/s]


[086/125] Train Acc: 0.716110 Loss: 0.890586 | Val Acc: 0.750652 loss: 0.777406


100%|██████████| 4648/4648 [00:38<00:00, 121.50it/s]
100%|██████████| 517/517 [00:02<00:00, 207.76it/s]


[087/125] Train Acc: 0.716026 Loss: 0.890052 | Val Acc: 0.751427 loss: 0.775782
saving model with acc 0.751


100%|██████████| 4648/4648 [00:37<00:00, 124.11it/s]
100%|██████████| 517/517 [00:02<00:00, 211.23it/s]


[088/125] Train Acc: 0.716239 Loss: 0.889917 | Val Acc: 0.751083 loss: 0.775932


100%|██████████| 4648/4648 [00:37<00:00, 123.46it/s]
100%|██████████| 517/517 [00:03<00:00, 159.59it/s]


[089/125] Train Acc: 0.716508 Loss: 0.889480 | Val Acc: 0.750531 loss: 0.775702


100%|██████████| 4648/4648 [00:38<00:00, 121.69it/s]
100%|██████████| 517/517 [00:02<00:00, 191.85it/s]


[090/125] Train Acc: 0.716457 Loss: 0.889019 | Val Acc: 0.751177 loss: 0.775911


100%|██████████| 4648/4648 [00:38<00:00, 121.68it/s]
100%|██████████| 517/517 [00:02<00:00, 196.47it/s]


[091/125] Train Acc: 0.717084 Loss: 0.888015 | Val Acc: 0.751457 loss: 0.774652
saving model with acc 0.751


100%|██████████| 4648/4648 [00:37<00:00, 122.52it/s]
100%|██████████| 517/517 [00:02<00:00, 207.45it/s]


[092/125] Train Acc: 0.717177 Loss: 0.887614 | Val Acc: 0.751922 loss: 0.775109
saving model with acc 0.752


100%|██████████| 4648/4648 [00:37<00:00, 124.01it/s]
100%|██████████| 517/517 [00:02<00:00, 195.11it/s]


[093/125] Train Acc: 0.717351 Loss: 0.886969 | Val Acc: 0.751529 loss: 0.774440


100%|██████████| 4648/4648 [00:37<00:00, 122.32it/s]
100%|██████████| 517/517 [00:03<00:00, 163.81it/s]


[094/125] Train Acc: 0.717280 Loss: 0.886713 | Val Acc: 0.752164 loss: 0.773264
saving model with acc 0.752


100%|██████████| 4648/4648 [00:38<00:00, 119.29it/s]
100%|██████████| 517/517 [00:02<00:00, 213.58it/s]


[095/125] Train Acc: 0.717028 Loss: 0.886445 | Val Acc: 0.752678 loss: 0.773201
saving model with acc 0.753


100%|██████████| 4648/4648 [00:37<00:00, 123.74it/s]
100%|██████████| 517/517 [00:02<00:00, 207.86it/s]


[096/125] Train Acc: 0.717388 Loss: 0.885721 | Val Acc: 0.751673 loss: 0.774785


100%|██████████| 4648/4648 [00:37<00:00, 122.76it/s]
100%|██████████| 517/517 [00:03<00:00, 153.79it/s]


[097/125] Train Acc: 0.717639 Loss: 0.885497 | Val Acc: 0.752077 loss: 0.772389


100%|██████████| 4648/4648 [00:38<00:00, 122.07it/s]
100%|██████████| 517/517 [00:02<00:00, 206.40it/s]


[098/125] Train Acc: 0.717441 Loss: 0.884702 | Val Acc: 0.751544 loss: 0.773373


100%|██████████| 4648/4648 [00:39<00:00, 118.51it/s]
100%|██████████| 517/517 [00:03<00:00, 156.22it/s]


[099/125] Train Acc: 0.717953 Loss: 0.884107 | Val Acc: 0.752512 loss: 0.772170


100%|██████████| 4648/4648 [00:39<00:00, 118.61it/s]
100%|██████████| 517/517 [00:02<00:00, 200.42it/s]


[100/125] Train Acc: 0.718162 Loss: 0.883792 | Val Acc: 0.752175 loss: 0.771498


100%|██████████| 4648/4648 [00:38<00:00, 121.08it/s]
100%|██████████| 517/517 [00:02<00:00, 190.53it/s]


[101/125] Train Acc: 0.718087 Loss: 0.883242 | Val Acc: 0.752848 loss: 0.771483
saving model with acc 0.753


100%|██████████| 4648/4648 [00:38<00:00, 120.13it/s]
100%|██████████| 517/517 [00:02<00:00, 206.06it/s]


[102/125] Train Acc: 0.718551 Loss: 0.882577 | Val Acc: 0.752205 loss: 0.771988


100%|██████████| 4648/4648 [00:38<00:00, 120.16it/s]
100%|██████████| 517/517 [00:02<00:00, 204.29it/s]


[103/125] Train Acc: 0.718326 Loss: 0.883337 | Val Acc: 0.753419 loss: 0.771652
saving model with acc 0.753


100%|██████████| 4648/4648 [00:38<00:00, 122.00it/s]
100%|██████████| 517/517 [00:03<00:00, 153.26it/s]


[104/125] Train Acc: 0.718378 Loss: 0.882071 | Val Acc: 0.752780 loss: 0.771144


100%|██████████| 4648/4648 [00:37<00:00, 122.35it/s]
100%|██████████| 517/517 [00:02<00:00, 195.11it/s]


[105/125] Train Acc: 0.718429 Loss: 0.881710 | Val Acc: 0.752803 loss: 0.770550


100%|██████████| 4648/4648 [00:37<00:00, 123.02it/s]
100%|██████████| 517/517 [00:02<00:00, 199.66it/s]


[106/125] Train Acc: 0.718793 Loss: 0.881326 | Val Acc: 0.753302 loss: 0.769728


100%|██████████| 4648/4648 [00:39<00:00, 119.06it/s]
100%|██████████| 517/517 [00:02<00:00, 199.46it/s]


[107/125] Train Acc: 0.719113 Loss: 0.880083 | Val Acc: 0.752319 loss: 0.770421


100%|██████████| 4648/4648 [00:37<00:00, 123.64it/s]
100%|██████████| 517/517 [00:02<00:00, 197.23it/s]


[108/125] Train Acc: 0.719190 Loss: 0.880127 | Val Acc: 0.752591 loss: 0.770391


100%|██████████| 4648/4648 [00:37<00:00, 122.89it/s]
100%|██████████| 517/517 [00:03<00:00, 169.99it/s]


[109/125] Train Acc: 0.719198 Loss: 0.879770 | Val Acc: 0.753449 loss: 0.769133
saving model with acc 0.753


100%|██████████| 4648/4648 [00:38<00:00, 119.96it/s]
100%|██████████| 517/517 [00:02<00:00, 200.34it/s]


[110/125] Train Acc: 0.719317 Loss: 0.879428 | Val Acc: 0.752746 loss: 0.769864


100%|██████████| 4648/4648 [00:38<00:00, 120.36it/s]
100%|██████████| 517/517 [00:02<00:00, 204.70it/s]


[111/125] Train Acc: 0.719533 Loss: 0.879503 | Val Acc: 0.752727 loss: 0.769711


100%|██████████| 4648/4648 [00:38<00:00, 121.64it/s]
100%|██████████| 517/517 [00:02<00:00, 179.42it/s]


[112/125] Train Acc: 0.719587 Loss: 0.878495 | Val Acc: 0.753286 loss: 0.769010


100%|██████████| 4648/4648 [00:37<00:00, 123.07it/s]
100%|██████████| 517/517 [00:02<00:00, 207.80it/s]


[113/125] Train Acc: 0.719375 Loss: 0.879037 | Val Acc: 0.753464 loss: 0.769688
saving model with acc 0.753


100%|██████████| 4648/4648 [00:37<00:00, 122.97it/s]
100%|██████████| 517/517 [00:02<00:00, 172.94it/s]


[114/125] Train Acc: 0.719304 Loss: 0.878729 | Val Acc: 0.753264 loss: 0.769182


100%|██████████| 4648/4648 [00:39<00:00, 118.54it/s]
100%|██████████| 517/517 [00:02<00:00, 205.50it/s]


[115/125] Train Acc: 0.720031 Loss: 0.877500 | Val Acc: 0.754039 loss: 0.767620
saving model with acc 0.754


100%|██████████| 4648/4648 [00:37<00:00, 122.59it/s]
100%|██████████| 517/517 [00:02<00:00, 196.98it/s]


[116/125] Train Acc: 0.719991 Loss: 0.876854 | Val Acc: 0.753986 loss: 0.767915


100%|██████████| 4648/4648 [00:37<00:00, 122.56it/s]
100%|██████████| 517/517 [00:03<00:00, 161.54it/s]


[117/125] Train Acc: 0.719599 Loss: 0.877218 | Val Acc: 0.754526 loss: 0.768112
saving model with acc 0.755


100%|██████████| 4648/4648 [00:37<00:00, 122.99it/s]
100%|██████████| 517/517 [00:02<00:00, 207.87it/s]


[118/125] Train Acc: 0.719934 Loss: 0.876613 | Val Acc: 0.754073 loss: 0.766987


100%|██████████| 4648/4648 [00:37<00:00, 123.58it/s]
100%|██████████| 517/517 [00:02<00:00, 185.19it/s]


[119/125] Train Acc: 0.720071 Loss: 0.876500 | Val Acc: 0.754016 loss: 0.766540


100%|██████████| 4648/4648 [00:39<00:00, 118.66it/s]
100%|██████████| 517/517 [00:02<00:00, 180.76it/s]


[120/125] Train Acc: 0.720288 Loss: 0.875790 | Val Acc: 0.753162 loss: 0.767447


100%|██████████| 4648/4648 [00:38<00:00, 122.27it/s]
100%|██████████| 517/517 [00:02<00:00, 206.14it/s]


[121/125] Train Acc: 0.720874 Loss: 0.875049 | Val Acc: 0.754299 loss: 0.766008


100%|██████████| 4648/4648 [00:37<00:00, 122.89it/s]
100%|██████████| 517/517 [00:02<00:00, 174.46it/s]


[122/125] Train Acc: 0.720804 Loss: 0.875379 | Val Acc: 0.753971 loss: 0.767497


100%|██████████| 4648/4648 [00:37<00:00, 122.51it/s]
100%|██████████| 517/517 [00:02<00:00, 194.08it/s]


[123/125] Train Acc: 0.720650 Loss: 0.874646 | Val Acc: 0.753453 loss: 0.767000


100%|██████████| 4648/4648 [00:38<00:00, 120.79it/s]
100%|██████████| 517/517 [00:02<00:00, 199.40it/s]


[124/125] Train Acc: 0.720581 Loss: 0.874749 | Val Acc: 0.754341 loss: 0.767414


100%|██████████| 4648/4648 [00:37<00:00, 123.57it/s]
100%|██████████| 517/517 [00:03<00:00, 168.36it/s]

[125/125] Train Acc: 0.720630 Loss: 0.874394 | Val Acc: 0.754349 loss: 0.766653
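The log prints `saving model with acc ...` only on epochs whose validation accuracy beats every earlier epoch. As a result, the checkpoint at `model_path` comes from epoch 117 (Val Acc 0.754526), not from the final epoch. The training cell is not part of this section, so the code below is only a minimal sketch of that save-on-improvement rule. It assumes a strict `>` comparison against a running best. `saved_epochs` is a hypothetical helper that replays the rule over the validation accuracies logged for epochs 110 to 125.

```python
def saved_epochs(val_accs, best_acc=0.0, start_epoch=1):
    """Replay a hypothetical 'save when validation accuracy improves' rule.

    Returns the epochs at which a checkpoint would be written and the final best accuracy.
    """
    saved = []
    for epoch, acc in enumerate(val_accs, start=start_epoch):
        if acc > best_acc:  # strictly better than every earlier epoch
            best_acc = acc
            saved.append(epoch)
    return saved, best_acc


# Val Acc for epochs 110-125, copied from the log above.
# 0.753449 is the accuracy saved at epoch 109, the best before epoch 110.
log_accs = [0.752746, 0.752727, 0.753286, 0.753464, 0.753264, 0.754039,
            0.753986, 0.754526, 0.754073, 0.754016, 0.753162, 0.754299,
            0.753971, 0.753453, 0.754341, 0.754349]
epochs, best = saved_epochs(log_accs, best_acc=0.753449, start_epoch=110)
print(epochs, best)  # [113, 115, 117] 0.754526
```

The replayed epochs match the `saving model` lines in the log, which is why the later epochs (118 to 125) leave the checkpoint untouched.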





In [None]:
del train_loader, val_loader
gc.collect()

0

## Testing
Create the test dataset and load the model from the saved checkpoint.

In [None]:
# load data
test_X = preprocess_data(split='test', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes)
test_set = LibriDataset(test_X, None)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

[Dataset] - # phone classes: 41, number of utterances for test: 1078


1078it [00:05, 180.06it/s]

[INFO] test set
torch.Size([646268, 819])





In [None]:
# load model
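model = Classifier(input_dim=input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim).to(device)
# map_location lets a checkpoint saved on GPU load on the current device;
# weights_only=True (PyTorch >= 1.13) limits unpickling to tensors and basic containers
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))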



<All keys matched successfully>
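`<All keys matched successfully>` is how PyTorch reports that every parameter name in the checkpoint matched one in the freshly built model. Below is a minimal sketch of the same `state_dict` round trip, under the following assumptions:

- The two-layer `make_model` is a hypothetical stand-in for `Classifier`. Only the 819 inputs and 41 classes come from the outputs above.
- An in-memory buffer stands in for the file at `model_path`.

```python
import io

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical stand-in for the notebook's Classifier.
    return nn.Sequential(nn.Linear(819, 64), nn.ReLU(), nn.Linear(64, 41))


torch.manual_seed(0)
trained = make_model()

buffer = io.BytesIO()                     # stands in for the file at model_path
torch.save(trained.state_dict(), buffer)  # save parameters only, not the module object
buffer.seek(0)

restored = make_model()                   # same architecture, fresh random weights
state = torch.load(buffer, map_location="cpu", weights_only=True)
result = restored.load_state_dict(state)
print(result)                             # <All keys matched successfully>

x = torch.randn(4, 819)
with torch.no_grad():
    same = torch.equal(trained(x), restored(x))  # identical weights -> identical outputs
```

Saving only the `state_dict` is what makes this work. The architecture is rebuilt in code, so the hyperparameters passed to `Classifier` at test time must match the ones used during training.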

Make predictions on the test set.

In [None]:
pred_batches = []  # one array of predicted class indices per batch

model.eval()
with torch.no_grad():
    for features in tqdm(test_loader):  # the test set has no labels, so each batch is just features
        features = features.to(device)

        outputs = model(features)

        _, test_pred = torch.max(outputs, 1) # index of the highest-scoring class for each frame
        pred_batches.append(test_pred.cpu().numpy())

pred = np.concatenate(pred_batches, axis=0)  # concatenate once instead of once per batch


100%|██████████| 1263/1263 [00:05<00:00, 235.19it/s]
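`torch.max(outputs, 1)` returns both the maximum score and its index along the class dimension, and only the index is kept. Softmax preserves the order of the scores within each row. Taking the index of the largest raw output therefore gives the same class as taking the index of the highest probability.

The check below runs on hypothetical random logits (3 batches of 5 frames over 41 classes). It also collects per-batch predictions in a list and concatenates them once at the end.

```python
import numpy as np
import torch

torch.manual_seed(0)
# Hypothetical logits: 3 batches of 5 frames, 41 phone classes.
batches = [torch.randn(5, 41) for _ in range(3)]

preds = []
for logits in batches:
    _, idx = torch.max(logits, 1)                    # (values, indices) along the class dim
    assert torch.equal(idx, logits.argmax(dim=1))    # same as argmax
    # softmax is monotonic within a row, so the winning class is unchanged
    assert torch.equal(idx, torch.softmax(logits, dim=1).argmax(dim=1))
    preds.append(idx.cpu().numpy())

pred = np.concatenate(preds)  # a single concatenation at the end
print(pred.shape)             # (15,)
```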


Write the predictions to a CSV file.

After this block finishes running, download `prediction.csv` from the Files panel on the left-hand side and submit it to Kaggle.

In [None]:
with open('prediction.csv', 'w') as f:
    f.write('Id,Class\n')
    for i, y in enumerate(pred):
        f.write('{},{}\n'.format(i, y))
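The loop above writes the `Id,Class` header and then one row per test frame. The same file can be produced with Python's built-in `csv` module. In the sketch below, `write_submission` is a hypothetical helper, and `lineterminator="\n"` keeps the line endings identical to the hand-written version.

```python
import csv
import os
import tempfile

import numpy as np


def write_submission(pred, path="prediction.csv"):
    """Write an Id,Class header, then one (frame index, predicted class) row per frame."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["Id", "Class"])
        writer.writerows((i, int(y)) for i, y in enumerate(pred))


# Usage on hypothetical predictions, written to a temporary directory.
demo_pred = np.array([3, 0, 40], dtype=np.int32)
with tempfile.TemporaryDirectory() as tmp:
    out_path = os.path.join(tmp, "prediction.csv")
    write_submission(demo_pred, out_path)
    with open(out_path, newline="") as f:
        rows = list(csv.reader(f))
print(rows)  # [['Id', 'Class'], ['0', '3'], ['1', '0'], ['2', '40']]
```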