In [1]:
import numpy as np

In [2]:
import argparse
import torch
from tqdm import tqdm
import model.metric as module_metric
import model.loss as module_loss
from parse_config import ConfigParser
from utils.util import create_model, create_dataloader


In [3]:
# Command-line options consumed by ConfigParser.from_args in the next cell.
# The trailing `;` on the last call suppresses the `_StoreAction(...)` repr in the notebook output.
args = argparse.ArgumentParser(description='PyTorch Template')
args.add_argument('-c', '--config', default="test.json", type=str,
                  help='config file path (default: test.json)')
args.add_argument('-r', '--resume', default=None, type=str,
                  help='path to latest checkpoint (default: None)')
args.add_argument('-d', '--device', default=None, type=str,
                  help='indices of GPUs to enable (default: all)');

_StoreAction(option_strings=['-d', '--device'], dest='device', nargs=None, const=None, default=None, type=<class 'str'>, choices=None, help='indices of GPUs to enable (default: all)', metavar=None)

In [4]:
# Build the experiment config from the options defined above (config file, resume path, device).
# NOTE(review): this parses sys.argv; inside a Jupyter kernel argv holds the kernel's own
# arguments — confirm ConfigParser.from_args tolerates them.
config = ConfigParser.from_args(args)

In [5]:
#config._config

In [6]:
# Build the model from the config via the project helper; the output reports which architecture was created.
model = create_model(config)

model [AMER] was created


In [7]:
# Prefer CUDA when it is available, otherwise fall back to CPU, then let the model set itself up for that device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.initialize(config, device)

In [8]:
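# Unused alternative: the checkpoint file is a dict whose weights sit under 'state_dict'
# (see the next cell), so passing it to load_state_dict directly presumably won't match the model's keys.
# model.load_state_dict(torch.load("data/pretrained/train_test_9.pth"))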

In [9]:
# Load the trained checkpoint and pull out the model weights.
# `map_location=device` remaps tensors that were saved on a GPU onto the device chosen
# above, so this cell also works on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
CHECKPOINT_PATH = "data/pretrained/train_test_9.pth"
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
state_dict = checkpoint['state_dict']

In [10]:
# Copy the trained weights into the model (strict key matching is PyTorch's default; the output confirms all keys matched).
model.load_state_dict(state_dict)

<All keys matched successfully>

In [11]:
# Prepare the model for testing: move its parameters to `device` and switch it to
# eval mode (e.g. Dropout becomes a no-op). Module.to/.eval both return the module itself.
model = model.to(device).eval()
model

AMER(
  (attn): ScaledDotProductAttention(
    (dropout): Dropout(p=0, inplace=False)
  )
  (enc_v): Sequential(
    (0): Linear(in_features=4302, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=384, bias=True)
    (3): ReLU()
    (4): Linear(in_features=384, out_features=256, bias=True)
  )
  (enc_a): Sequential(
    (0): Linear(in_features=6373, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=256, bias=True)
  )
  (enc_t): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
  )
  (enc_p): Sequential(
    (0): Linear(in_features=118, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
  )
  (out_layer): Sequential(
    (0): Linear(in_features=512, out_features=256, bia

In [12]:
# Resolve the metric and loss callables from the names listed in the config.
metric_fns = [getattr(module_metric, name) for name in config['metrics']]
loss_fn = getattr(module_loss, config['loss'])
# Running sums (weighted by batch size) accumulated by the evaluation loop.
total_metrics = torch.zeros(len(metric_fns))
total_loss = 0.0

In [13]:
# Build the evaluation data loader from the config; per the output, this also initialises the
# visual/audio/text/personality feature extractors.
data_loader = create_dataloader(config)

Initializing VisualFeatureExtractor...
Initializing AudioFeatureExtracor...
Initializing TextFeatureExtractor...


  1%|▍                                                  | 14/1707 [00:00<00:12, 139.55it/s]

Initializing PersonalityFeatureExtractor...


100%|█████████████████████████████████████████████████| 1707/1707 [00:12<00:00, 139.68it/s]

dataset [MEmoRDataLoader] was created





In [14]:
import warnings
warnings.filterwarnings('ignore')

In [15]:
# Run the model over the whole data loader and accumulate batch-size-weighted loss and metrics.
# The accumulators are reset here so that re-running this cell does not double-count a previous pass.
total_loss = 0.0
total_metrics = torch.zeros(len(metric_fns))
with torch.no_grad():
    for batch in tqdm(data_loader):
        (target, U_v, U_a, U_t, U_p, M_v, M_a, M_t,
         target_loc, umask, seg_len, n_c) = [d.to(device) for d in batch]
        # Per-sample length = index of the last position where umask == 1, plus one.
        # NOTE(review): raises IndexError if a row of umask contains no 1 at all.
        seq_lengths = [(umask[j] == 1).nonzero().tolist()[-1][0] + 1 for j in range(len(umask))]
        output = model(U_v, U_a, U_t, U_p, M_v, M_a, M_t, seq_lengths, target_loc, seg_len, n_c)
        target = target.squeeze(1)
        loss = loss_fn(output, target)
        batch_size = U_v.shape[0]
        # Weight by batch size so dividing by the sample count later yields a per-sample mean.
        total_loss += loss.item() * batch_size
        for k, metric in enumerate(metric_fns):
            total_metrics[k] += metric(output, target) * batch_size

100%|████████████████████████████████████████████████████| 214/214 [00:12<00:00, 17.76it/s]


In [16]:
# Logger named 'test', used below to report the final averaged metrics.
logger = config.get_logger('test')

In [17]:
# Turn the batch-size-weighted running sums into per-sample averages and log them,
# with 'loss' first followed by each metric under its function name.
n_samples = len(data_loader.sampler)
log = {'loss': total_loss / n_samples}
log.update(
    (metric.__name__, total_metrics[k].item() / n_samples)
    for k, metric in enumerate(metric_fns)
)
logger.info(log)

{'loss': 4.352610321860601, 'accuracy': 0.4774458113649678, 'macro_f1': 0.3351415400625183, 'weighted_f1': 0.46724096207710897}


In [18]:
# Show which annotation file this run was configured with (public subscript access,
# consistent with config['metrics'] / config['loss'] used above, instead of the private `_config`).
config['anno_file']

'data/anno.json'

In [19]:
# Earlier reference result, evaluated on all data (kept for comparison):
# {'loss': 0.9204926241213913, 'accuracy': 0.8795688847235239, 'macro_f1': 0.8151952713141402, 'weighted_f1': 0.8721875205929299}