# In [None]:
import sys

from models.DeepDense import DeepDense
from models.TextLSTM import TextLSTM
from models.DNNAttr import DNNAttr
from models.Wide import Wide

from models.WideDeep import WideDeep
from optim.Initializer import KaimingNormal, XavierNormal
from optim.radam import RAdam
from pandas import DataFrame
from preprocessing.Preprocessor import WidePreprocessor, DeepPreprocessor, DeepTextPreprocessor, MultiDeepTextPreprocessor

import numpy as np
import pandas as pd
import torch, os
from torch.utils.data import DataLoader
import time
import torch


# Wide (linear) component; wide_dim=17 input features, single output.
wide = Wide(wide_dim=17, output_dim=1)

# Deep-dense component. deep_column_idx maps each deep column name to its
# position (0..4, in the order listed) — presumably its column index in the
# deep input tensor; confirm against the preprocessing that feeds the model.
deep_column_idx = {
    column: position
    for position, column in enumerate([
        'family_pred_gender',
        'family_pred_age_level',
        'category',
        'ott_uv_norm',
        'category_prefer',
    ])
}
# Embedding spec per categorical column as (column, n_values, embed_dim).
# NOTE(review): the middle entry is presumably the number of distinct values
# (vocabulary size) — confirm against DeepDense.
emb_col_val_dim_tuple = [
    ('family_pred_gender', 4, 8),
    ('family_pred_age_level', 1024, 12),
    ('category', 28, 8),
]
deepdense = DeepDense(
    hidden_layers=[32],
    dropout=[0.2],
    deep_column_idx=deep_column_idx,
    embed_input=emb_col_val_dim_tuple,
    continuous_cols=['ott_uv_norm', 'category_prefer'],
)

# Text/attention component; passed to WideDeep below in the deeptext slot.
transformer = DNNAttr()

wide_deep_model = WideDeep(wide=wide, deepdense=deepdense, deeptext=transformer, head_layers=[128])

# Trained weights to restore. map_location='cpu' lets a checkpoint that was
# saved from GPU tensors load on a machine without CUDA; the model itself is
# never moved to a GPU in this notebook, so load_state_dict copies the values
# into the same parameters as before.
CHECKPOINT_PATH = 'log/05-06_16.29/sug_saved_model.pt'
wide_deep_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))

# List every parameter name and shape; the export cell below selects
# parameters by (fragments of) these names.
for name, param in wide_deep_model.named_parameters():
    print(name, '        ', param.size())

# In [None]:
# Export the model's weights to param.npz, rounded to ROUND_DECIMALS places,
# under flat key names.
#
# Each npz key maps to a fragment of a PyTorch parameter name. A parameter is
# exported under every key whose fragment occurs in its name (substring
# match, so a wrapper prefix on the name would still match). If several
# parameters match the same fragment, the last one yielded by
# named_parameters() wins.
PARAM_NAME_FRAGMENTS = {
    # wide
    'wide_linear_weight': 'wide.wide_linear.weight',
    'wide_linear_bias': 'wide.wide_linear.bias',
    # deepdense: categorical embeddings and dense layers
    'emb_layer_category': 'deepdense.embed_layers_dic.emb_layer_category.weight',
    'emb_layer_family_pred_age_level': 'deepdense.embed_layers_dic.emb_layer_family_pred_age_level.weight',
    'emb_layer_family_pred_gender': 'deepdense.embed_layers_dic.emb_layer_family_pred_gender.weight',
    'dense_layer_0_w': 'deepdense.dense_sequential.dense_layer_0.0.weight',
    'dense_layer_0_bias': 'deepdense.dense_sequential.dense_layer_0.0.bias',
    'dense_layer_1_w': 'deepdense.dense_sequential.dense_layer_1.0.weight',
    'dense_layer_1_bias': 'deepdense.dense_sequential.dense_layer_1.0.bias',
    # deeptext
    'query_embedding': 'deeptext.embedding.weight',
    'query_embedding_fc_w': 'deeptext.fc.weight',
    'query_embedding_fc_bias': 'deeptext.fc.bias',
    # prefix embedding
    'prefix_embedding': 'prefix_embedding.weight',
    # deephead
    'head_layer_0_w': 'deephead.head_layer_0.0.weight',
    'head_layer_0_bias': 'deephead.head_layer_0.0.bias',
    'head_layer_1_w': 'deephead.head_layer_1.0.weight',
    'head_layer_1_bias': 'deephead.head_layer_1.0.bias',
    'head_out_w': 'deephead.head_out.weight',
    'head_out_bias': 'deephead.head_out.bias',
}
ROUND_DECIMALS = 3

# Every key starts as the '' placeholder so param.npz always contains the same
# set of keys, in the same order. NOTE(review): with hidden_layers=[32] and
# head_layers=[128] above, dense_layer_1 / head_layer_1 may legitimately not
# exist — confirm what the consumer of param.npz expects for those keys.
exported_params = dict.fromkeys(PARAM_NAME_FRAGMENTS, '')
matched_keys = set()
for name, param in wide_deep_model.named_parameters():
    for key, fragment in PARAM_NAME_FRAGMENTS.items():
        if fragment in name:
            exported_params[key] = np.round(param.detach().cpu().numpy(), ROUND_DECIMALS)
            matched_keys.add(key)

# Unmatched keys used to be written as '' without any notice, which hides a
# renamed or missing layer; report them while keeping the placeholder.
missing_keys = [key for key in PARAM_NAME_FRAGMENTS if key not in matched_keys]
if missing_keys:
    print('warning: no model parameter matched these keys; saved as empty placeholders: '
          + ', '.join(missing_keys), file=sys.stderr)

np.savez("param.npz", **exported_params)
        
# Reload param.npz as a sanity check and preview the query embedding table:
# its shape, the full matrix, then the first PREVIEW_ROWS rows both as raw
# bytes and as a comma-separated line of values rounded to 3 decimals.
PREVIEW_ROWS = 1

# Read the array once and close the archive; indexing an NpzFile re-reads
# the member from the zip on every access.
with np.load("param.npz") as np_file:
    loaded_query_embedding = np_file['query_embedding']

print(loaded_query_embedding.shape)
print(loaded_query_embedding)

# Slicing also covers a single-row table, which the old
# `range(rows - 1)` loop skipped entirely.
for row in loaded_query_embedding[:PREVIEW_ROWS]:
    # ndarray.tostring() was deprecated in NumPy 1.19 and removed in 2.0;
    # tobytes() returns the same raw buffer.
    print(row.tobytes())
    print('\n')
    print(','.join(str(round(item, 3)) for item in row.tolist()))
