In [1]:
import logging
import math,random
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.Attention import Attention
from models.Embeddings.BERT import BERT_Embedding
from models.Embeddings.GLOVE import GLOVE_Embedding
from models.Encoders.CNN import CNN_Encoder
from models.Encoders.RNN import RNN_User_Encoder
from models.Interactors.CNN import CNN_Interactor
from models.Interactors.FIM import FIM_Interactor
from models.Interactors.KNRM import KNRM_Interactor
from models.Interactors.BERT import BERT_Interactor

from models.ESM import ESM

from data.configs.demo import config
from utils.utils import prepare
from utils.Manager import Manager

from models.Modules.DRM import DRM_Matching
# from models.Modules.TFM import TFM

ModuleNotFoundError: No module named 'transformers'

In [3]:
config.his_size = 10
config.embedding = 'bert'
config.hidden_dim = 768
config.device = 0
config.learning_rate = 1e-6

config.path = "C:/"

manager = Manager(config)

vocab, loaders = prepare(manager)
record = next(iter(loaders[0]))

[2021-08-09 13:01:48,774] INFO (root) Hyper Parameters are scale:demo, mode:tune, epochs:8, batch_size:10, k:3, threshold:0, title_length:20, abs_length:40, signal_length:50, npratio:4, his_size:10, dropout_p:0.2, device:0, learning_rate:1e-06, metrics:auc,mean_mrr,ndcg@5,ndcg@10, embedding:bert, embedding_dim:300, hidden_dim:768, rank:0, world_size:0, step:[0], seeds:42, interval:10, val_freq:2, schedule:None, path:C:/, tb:False, bert:bert-base-uncased, cdd_size:5
[2021-08-09 13:01:48,776] INFO (root) preparing dataset...


In [3]:
embedding = BERT_Embedding(manager)
encoderN = CNN_Encoder(manager)
encoderU = RNN_User_Encoder(manager)
docReducer = DRM_Matching(manager)
# termFuser = TFM(manager.his_size, manager.k)
# interactor = CNN_Interactor(manager)
interactor = BERT_Interactor(manager)

esm = ESM(manager, embedding, encoderN, encoderU, docReducer, None, interactor).to(manager.device)

In [4]:
esm(record)

tensor([[-1.6428, -1.6586, -1.5995, -1.5979, -1.5519],
        [-1.6173, -1.6472, -1.5759, -1.6027, -1.6053],
        [-1.5719, -1.6049, -1.6151, -1.6163, -1.6402],
        [-1.6038, -1.6090, -1.6464, -1.6010, -1.5880],
        [-1.5379, -1.6163, -1.6820, -1.6166, -1.5996],
        [-1.6240, -1.6177, -1.6030, -1.6219, -1.5813],
        [-1.6487, -1.6017, -1.6328, -1.5972, -1.5687],
        [-1.6289, -1.6082, -1.5650, -1.6197, -1.6268],
        [-1.6331, -1.6486, -1.5963, -1.5826, -1.5883],
        [-1.6860, -1.5988, -1.6386, -1.6056, -1.5251]], device='cuda:0',
       grad_fn=<LogSoftmaxBackward>)

In [5]:
manager.tune(esm, loaders)

[2021-08-08 02:23:37,875] INFO (utils.Manager) training...
epoch 1 , step 140 , loss: 1.5820:  49%|████▉     | 146/295 [00:46<00:47,  3.12it/s][2021-08-08 02:24:24,940] INFO (utils.Manager) evaluating...






100%|██████████| 1812/1812 [00:55<00:00, 32.52it/s]
[2021-08-08 02:25:21,153] INFO (utils.Manager) current result of esm_cnn_rnn-user-encoder_matching-based_bert is {'auc': 0.4966, 'mean_mrr': 0.2273, 'ndcg@5': 0.2398, 'ndcg@10': 0.3023, 'epoch': 1, 'step': 146}
epoch 1 , step 290 , loss: 1.5815:  99%|█████████▉| 292/295 [02:30<00:01,  1.94it/s][2021-08-08 02:26:08,356] INFO (utils.Manager) evaluating...






100%|██████████| 1812/1812 [01:00<00:00, 29.90it/s]
[2021-08-08 02:27:09,454] INFO (utils.Manager) current result of esm_cnn_rnn-user-encoder_matching-based_bert is {'auc': 0.4976, 'mean_mrr': 0.2288, 'ndcg@5': 0.2282, 'ndcg@10': 0.2929, 'epoch': 1, 'step': 292}
epoch 1 , step 290 , loss: 1.5815: 100%|██████████| 295/295 [03:32<00:00,  1.39it/s]
epoch 2 , step 140 , loss: 1.6237:  49%|████▉     | 146/295 [00:46<00:47,  3.11it/s][2021-08-08 02:27:57,494] INFO (utils.Manager) evaluating...






100%|██████████| 1812/1812 [01:00<00:00, 30.08it/s]
[2021-08-08 02:28:58,234] INFO (utils.Manager) current result of esm_cnn_rnn-user-encoder_matching-based_bert is {'auc': 0.5059, 'mean_mrr': 0.2279, 'ndcg@5': 0.2332, 'ndcg@10': 0.2922, 'epoch': 2, 'step': 146}
epoch 2 , step 290 , loss: 1.6153:  99%|█████████▉| 292/295 [02:34<00:01,  1.89it/s][2021-08-08 02:29:45,378] INFO (utils.Manager) evaluating...






100%|██████████| 1812/1812 [00:59<00:00, 30.23it/s]
[2021-08-08 02:30:45,788] INFO (utils.Manager) current result of esm_cnn_rnn-user-encoder_matching-based_bert is {'auc': 0.5086, 'mean_mrr': 0.2311, 'ndcg@5': 0.233, 'ndcg@10': 0.2967, 'epoch': 2, 'step': 292}
epoch 2 , step 290 , loss: 1.6153: 100%|██████████| 295/295 [03:36<00:00,  1.36it/s]
epoch 3 , step 140 , loss: 1.6180:  49%|████▉     | 146/295 [00:47<00:48,  3.04it/s][2021-08-08 02:31:34,898] INFO (utils.Manager) evaluating...






100%|██████████| 1812/1812 [00:58<00:00, 30.77it/s]
[2021-08-08 02:32:34,268] INFO (utils.Manager) current result of esm_cnn_rnn-user-encoder_matching-based_bert is {'auc': 0.4918, 'mean_mrr': 0.2242, 'ndcg@5': 0.2221, 'ndcg@10': 0.2813, 'epoch': 3, 'step': 146}
epoch 3 , step 290 , loss: 1.6101:  99%|█████████▉| 292/295 [02:34<00:01,  1.90it/s][2021-08-08 02:33:20,925] INFO (utils.Manager) evaluating...






100%|██████████| 1812/1812 [00:59<00:00, 30.28it/s]
[2021-08-08 02:34:21,232] INFO (utils.Manager) current result of esm_cnn_rnn-user-encoder_matching-based_bert is {'auc': 0.5133, 'mean_mrr': 0.2376, 'ndcg@5': 0.2347, 'ndcg@10': 0.2989, 'epoch': 3, 'step': 292}
epoch 3 , step 290 , loss: 1.6101: 100%|██████████| 295/295 [03:35<00:00,  1.37it/s]
epoch 4 , step 140 , loss: 1.6173:  49%|████▉     | 146/295 [00:46<00:47,  3.12it/s][2021-08-08 02:35:09,128] INFO (utils.Manager) evaluating...






100%|██████████| 1812/1812 [00:59<00:00, 30.50it/s]
[2021-08-08 02:36:09,024] INFO (utils.Manager) current result of esm_cnn_rnn-user-encoder_matching-based_bert is {'auc': 0.4922, 'mean_mrr': 0.2182, 'ndcg@5': 0.2153, 'ndcg@10': 0.2786, 'epoch': 4, 'step': 146}
epoch 4 , step 270 , loss: 1.6103:  92%|█████████▏| 271/295 [02:26<00:12,  1.85it/s]