In [None]:
import logging
import math,random
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.Attention import Attention
from models.Embeddings.BERT import BERT_Embedding
from models.Embeddings.GLOVE import GLOVE_Embedding
from models.Encoders.CNN import CNN_Encoder
from models.Encoders.RNN import RNN_User_Encoder
from models.Interactors.CNN import CNN_Interactor
from models.Interactors.FIM import FIM_Interactor
from models.Interactors.KNRM import KNRM_Interactor
from models.Interactors.BERT_Overlook import BERT_Interactor
from models.Interactors.BERT_Onepass import BERT_Interactor

from models.ESM import ESM

from data.configs.demo import config
from utils.utils import prepare
from utils.Manager import Manager

from models.Modules.DRM import DRM_Matching
# from models.Modules.TFM import TFM

In [None]:
# Per-run overrides of the imported demo config; all of them are applied
# before Manager(config) is constructed below.
config.embedding = 'bert'
config.hidden_dim = 768  # presumably the BERT hidden width -- TODO confirm
config.his_size = 10
config.learning_rate = 1e-6
config.device = 1  # parameters displayed later report device='cuda:1'

# config.path = "C:/"

manager = Manager(config)

# prepare() hands back the vocabulary and a sequence of loaders; keep the
# first item yielded by the first loader around for quick experiments.
vocab, loaders = prepare(manager)
record = next(iter(loaders[0]))

In [None]:
# Instantiate every ESM component from the shared manager. The construction
# order is kept as-is so that any random initialisation happens in the same
# sequence.
embedding = BERT_Embedding(manager)
encoderN = CNN_Encoder(manager)
encoderU = RNN_User_Encoder(manager)
docReducer = DRM_Matching(manager)
# termFuser = TFM(manager.his_size, manager.k)
# interactor = CNN_Interactor(manager)

# NOTE: BERT_Interactor is imported twice in the import cell (from
# BERT_Overlook, then from BERT_Onepass); the later import wins, so this is
# the BERT_Onepass implementation.
interactor = BERT_Interactor(manager)

# The sixth positional argument is passed as None -- presumably the slot of
# the disabled termFuser above; verify against ESM's signature.
esm = ESM(
    manager,
    embedding,
    encoderN,
    encoderU,
    docReducer,
    None,
    interactor,
).to(manager.device)

In [20]:
# Snapshot the model's parameters twice: as (name, parameter) pairs and as
# bare parameters, so both can be indexed in the cells below.
name_p = [*esm.named_parameters()]
p = [*esm.parameters()]

In [22]:
# Shape of the fourth parameter in the bare-parameter list.
p[3].size()

torch.Size([1, 768])

In [24]:
# Same lookup through the (name, parameter) list: the parameter is the last
# element of each pair.
name_p[3][-1].shape

torch.Size([1, 768])

In [33]:
from itertools import chain

# Named parameters of every direct child of `esm` except the embedding and the
# interactor, concatenated into one stream.
# NOTE: `chain` produces a single-pass iterator -- once something has iterated
# over `base_params` it is exhausted and must be rebuilt to be used again.
base_params = chain.from_iterable(
    [
        child.named_parameters()
        for child_name, child in esm.named_children()
        if child_name not in ('embedding', 'interactor')
    ]
)

In [26]:
# List only the names from the (name, parameter) pairs.
[param_name for param_name, _param in name_p]

['embedding.embedding.weight',
 'embedding.layerNorm.weight',
 'embedding.layerNorm.bias',
 'encoderN.query_words',
 'encoderN.wordQueryProject.weight',
 'encoderN.wordQueryProject.bias',
 'encoderN.CNN.weight',
 'encoderN.CNN.bias',
 'encoderN.layerNorm.weight',
 'encoderN.layerNorm.bias',
 'encoderU.lstm.weight_ih_l0',
 'encoderU.lstm.weight_hh_l0',
 'encoderU.lstm.bias_ih_l0',
 'encoderU.lstm.bias_hh_l0',
 'interactor.order_embedding',
 'interactor.cdd_pos_embedding',
 'interactor.pst_pos_embedding',
 'interactor.sep_embedding',
 'interactor.cls_embedding',
 'interactor.bert.layer.0.attention.self.query.weight',
 'interactor.bert.layer.0.attention.self.query.bias',
 'interactor.bert.layer.0.attention.self.key.weight',
 'interactor.bert.layer.0.attention.self.key.bias',
 'interactor.bert.layer.0.attention.self.value.weight',
 'interactor.bert.layer.0.attention.self.value.bias',
 'interactor.bert.layer.0.attention.output.dense.weight',
 'interactor.bert.layer.0.attention.output.dense.

In [34]:
# Collect the names from `base_params`.
# NOTE: this iterates `base_params` to the end; because it is a single-pass
# iterator, it yields nothing to any later consumer.
my_name = [param_name for param_name, _param in base_params]

In [36]:
# Display the first parameter yielded by the model.
next(iter(esm.parameters()))

Parameter containing:
tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
        [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
        [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
        ...,
        [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
        [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
        [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]],
       device='cuda:1', requires_grad=True)

In [37]:
from itertools import cycle
# NOTE: `base_params` is a single-pass `chain` iterator, and the earlier cell
# that builds `my_name` has already iterated it to the end. The cycle created
# here therefore wraps an exhausted iterator and yields nothing. Rebuild
# `base_params` (or materialise it as a list) before this line if a non-empty
# cycle is wanted.
c = cycle(base_params)

In [40]:
from itertools import islice

# Preview the cycle with a hard bound. `list(c)` on an itertools.cycle only
# returns when the cycle is empty; on a non-empty cycle it never terminates
# and keeps allocating memory. islice() caps the preview at 10 items and
# returns [] when the cycle is empty.
list(islice(c, 10))

[]