In [1]:
import sys
import os
import json
import tensorflow as tf
import tqdm
from models.model import *
from ltv_utils import *
from losses.custom_loss import *
# NOTE(review): `pd` is not imported explicitly in this cell; presumably it is
# re-exported by one of the star imports above — confirm, or import pandas directly.
pd.set_option('display.float_format', '{:.4f}'.format)  # render floats with 4 decimal places (adjustable)
import warnings
# From this point on, ignore warnings of category UserWarning.
warnings.simplefilter(action='ignore', category=UserWarning)




def parse_function(serialized_example):
    """Decode one serialized example into a dict of tensors.

    Every field is read as a fixed-length feature. Scalar fields are string,
    int64 or float32; the four ``user_*_features`` fields are float32 vectors
    whose length equals the number of names listed under the same key in the
    module-level ``group_2_features`` mapping. That mapping is looked up when
    this function runs, so it must be defined before the first call.

    Args:
        serialized_example: One serialized record, as passed to
            ``tf.io.parse_single_example``.

    Returns:
        Dict mapping each feature name to its parsed tensor.
    """
    # Shared specs for the three scalar kinds used by the schema.
    scalar_str = tf.io.FixedLenFeature([], tf.string)
    scalar_int = tf.io.FixedLenFeature([], tf.int64)
    scalar_float = tf.io.FixedLenFeature([], tf.float32)

    def float_vector(group_name):
        # One float32 slot per feature name configured for this group.
        return tf.io.FixedLenFeature([len(group_2_features[group_name])], tf.float32)

    schema = {
        'deviceid': scalar_str,
        'install_date': scalar_str,
        'dim_os_name1': scalar_str,
        'creative_classify1': scalar_str,
        'total_pay_amount1': scalar_float,
        'channel1': scalar_str,
        'b2_sale_amt_bias': scalar_int,
        'b2_sale_amt_7d': scalar_int,
        'install_time': scalar_str,
        'install_order_diff': scalar_int,
        'all_install_order_7d_diff': scalar_int,
        'is_a1x_a33': scalar_int,
        'platform_label': scalar_str,
        'user_dense_price_features': float_vector('user_dense_price_features'),
        'user_dense_duration_features': float_vector('user_dense_duration_features'),
        'user_dense_features': float_vector('user_dense_features'),
        'user_sparse_features': float_vector('user_sparse_features'),
    }
    return tf.io.parse_single_example(serialized_example, schema)


# Load the feature configuration and build the TFRecord input pipeline.
group_2_features = read_feature_json_config('features/feature_list.json')
file_name = 'data/loca_test_tf.tfrecords'
data_path = file_name

# Pipeline order is parse -> batch -> prefetch. Prefetching is the final step so
# that whole batches are prepared in the background while the previous batch is
# being consumed; placing it before batch() only buffers single examples.
# AUTOTUNE lets tf.data pick the buffer size and the parsing parallelism.
# Element order is unchanged because map() is deterministic by default.
dataset = tf.data.TFRecordDataset(data_path)
dataset = dataset.map(parse_function, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(2048)
dataset = dataset.prefetch(tf.data.AUTOTUNE)


# Feature-name lists for each feature group, taken from the JSON configuration.
user_dense_price_features = group_2_features['user_dense_price_features']
user_dense_duration_features = group_2_features['user_dense_duration_features']
user_dense_features = group_2_features['user_dense_features']
user_sparse_features = group_2_features['user_sparse_features']


In [2]:
def create_tf_dataset(dataset, target_hour=4, flip_rate=0.1, seed=None):
    """Turn a batched dataset of feature dicts into ``(features, labels)`` pairs.

    For every batch, ``b2_sale_amt_7d`` and ``total_pay_amount1`` are left out
    of the feature dict and packed into a ``(batch, 2)`` float32 label tensor:
    column 0 is the binary label ``b2_sale_amt_7d > 0`` and column 1 is
    ``total_pay_amount1`` (kept for other losses / multi-task use).

    As a simulated perturbation, a random share of the rows whose ``hour``
    value (first column of ``user_sparse_features`` minus one) equals
    ``target_hour`` get ``b2_sale_amt_7d`` increased by 1 before the binary
    label is derived.

    Args:
        dataset: Batched ``tf.data.Dataset`` yielding dicts that contain at
            least ``user_sparse_features``, ``b2_sale_amt_7d`` and
            ``total_pay_amount1``.
        target_hour: Hour value whose rows are eligible for the perturbation.
        flip_rate: Probability with which an eligible row is perturbed.
        seed: ``None`` (default) draws a fresh random mask on every pass over
            the data. An int makes the mask a deterministic function of
            ``(seed, batch index)``, so the same rows are perturbed in every
            epoch and in both ``fit`` and ``evaluate`` (assuming the input
            dataset yields batches in a fixed order).

    Returns:
        A ``tf.data.Dataset`` built with ``from_generator`` that yields
        ``(features_dict, y_true_packed)``.

    Raises:
        ValueError: If ``dataset`` yields no batch, so the output signature
            cannot be inferred.
    """
    label_keys = ('b2_sale_amt_7d', 'total_pay_amount1')

    # One batch is peeked only to learn the feature shapes and dtypes.
    sample_batch = next(iter(dataset), None)
    if sample_batch is None:
        raise ValueError('create_tf_dataset: dataset is empty, cannot infer the output signature')
    sample_data = {k: v for k, v in sample_batch.items() if k not in label_keys}

    def generator():
        for batch_index, batch in enumerate(dataset):
            # Build the feature dict without mutating the incoming batch.
            features = {k: v for k, v in batch.items() if k not in label_keys}

            hour = tf.cast(tf.gather(batch['user_sparse_features'], indices=0, axis=1) - 1, tf.int64)
            b2_7d = tf.cast(tf.reshape(batch['b2_sale_amt_7d'], (-1, 1)), tf.float32)
            total_amt_1h = tf.reshape(batch['total_pay_amount1'], (-1, 1))

            # Select rows at the target hour, then keep a random share of them.
            in_target_hour = tf.equal(hour, target_hour)
            if seed is None:
                rand = tf.random.uniform(shape=tf.shape(in_target_hour), minval=0.0, maxval=1.0)
            else:
                # Stateless draw keyed by the batch position: identical on every pass.
                rand = tf.random.stateless_uniform(
                    shape=tf.shape(in_target_hour),
                    seed=tf.constant([seed, batch_index], dtype=tf.int64),
                    minval=0.0,
                    maxval=1.0,
                )
            bump_mask = tf.reshape(tf.logical_and(in_target_hour, rand < flip_rate), (-1, 1))
            b2_7d = tf.where(bump_mask, b2_7d + 1.0, b2_7d)

            # Binary label: 1 when b2_7d > 0, otherwise 0.
            binary_label = tf.cast(b2_7d > 0.0, tf.float32)

            # total_amt_1h rides along as the second label column.
            y_true_packed = tf.concat([binary_label, total_amt_1h], axis=1)

            yield features, y_true_packed

    output_signature = (
        {
            # Leading batch dimension is left open; the last batch may be smaller.
            name: tf.TensorSpec(shape=(None,) + v.shape[1:], dtype=v.dtype)
            for name, v in sample_data.items()
        },
        tf.TensorSpec(shape=(None, 2), dtype=tf.float32),  # column 0 is the classification label
    )

    return tf.data.Dataset.from_generator(generator, output_signature=output_signature)


In [7]:

# Feature names handed to the model as its last constructor argument.
emb_features = [
    'creative_classify',
    'dim_device_manufacture',
    'car_add_type_most',
    'show_order_is_2arrival_latest',
    'selecttirecount_most',
    'show_order_is_2arrival_most',
    'selecttirecount_latest',
    'new_sitename',
    'advsite',
    'car_add_type_latest',
    'platform_level',
    'tire_list_click_avg_index',
    'tire_list_click_most_pid_level',
    'tire_order_page_most_pid_level',
]

# Arguments are passed positionally, in the original order; their meaning is
# defined by MULTI_HEAD_LTV_MODEL itself.
# NOTE(review): the leading 5 presumably sets the number of heads — confirm
# against the model definition.
model = MULTI_HEAD_LTV_MODEL(
    5,
    [128],
    [200, 128, 128],
    'user_dense_features',
    'user_dense_price_features',
    'user_dense_duration_features',
    'user_sparse_features',
    user_sparse_features,
    emb_features,
)

# Peek at one batch to record per-feature shapes (not used further in this cell).
sample = next(iter(dataset))
input_shape = {name: tensor.shape for name, tensor in sample.items()}

# Early stopping is currently disabled; kept here for reference.
# early_stopping = tf.keras.callbacks.EarlyStopping(
#     monitor='val_auc',  # metric monitored on the validation set
#     patience=3,          # stop after 3 consecutive epochs without improvement
#     restore_best_weights=True  # restore the best weights once training ends
# )
loss_fn = UnifiedLTVLoss('binary')
model.compile(loss=loss_fn, optimizer='adam')

model.fit(create_tf_dataset(dataset), epochs=10)


Epoch 1/10
Tensor("IteratorGetNext:15", shape=(None, 2), dtype=float32)
Tensor("IteratorGetNext:15", shape=(None, 2), dtype=float32)
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x16cd9ae1450>

In [11]:
# Inspect the head sub-networks stored on the model.
model.hour2headnn

ListWrapper([<models.model.HEAD_DNN object at 0x0000016CD9AC1F90>, <models.model.HEAD_DNN object at 0x0000016CD9AC27A0>, <models.model.HEAD_DNN object at 0x0000016CD9AD1E10>, <models.model.HEAD_DNN object at 0x0000016CD9AD2680>, <models.model.HEAD_DNN object at 0x0000016CD9AD32B0>])

In [9]:
# Evaluate on the same dataset used for training and tabulate the result.
# NOTE(review): the rendered table shows raw tensor reprs rather than numbers;
# if the cells are scalar tensors, converting them to float would let the
# float_format display option apply — confirm the element type first.
eval_output = model.evaluate(create_tf_dataset(dataset))
res = pd.DataFrame(eval_output)
display(res)

Unnamed: 0,pred sum:,true sum:
0,"tf.Tensor(4237.78, shape=(), dtype=float32)","tf.Tensor(4427.0, shape=(), dtype=float32)"
1,"tf.Tensor(4930.1714, shape=(), dtype=float32)","tf.Tensor(4463.0, shape=(), dtype=float32)"
2,"tf.Tensor(5433.9263, shape=(), dtype=float32)","tf.Tensor(4520.0, shape=(), dtype=float32)"
3,"tf.Tensor(4873.4976, shape=(), dtype=float32)","tf.Tensor(4556.0, shape=(), dtype=float32)"
4,"tf.Tensor(8623.543, shape=(), dtype=float32)","tf.Tensor(7712.0, shape=(), dtype=float32)"


In [None]:
# Two-stage training recipe:
#   stage 1 - train all layers;
#   stage 2 - train only the head layers, with the shared layers frozen.
# NOTE(review): the `...` arguments below are literal Ellipsis placeholders;
# presumably they must be replaced with the real compile()/fit() arguments
# before this cell is run.

# Stage 1: every layer is trainable.
for shared_layer in (model.sharebottom, model.process_dense_layer, model.process_emb_layer):
    shared_layer.trainable = True
for head in model.hour2headnn:
    head.trainable = True

model.compile(...)  # compile
model.fit(...)

# Stage 2: freeze the shared layers and train only the heads.
for shared_layer in (model.sharebottom, model.process_dense_layer, model.process_emb_layer):
    shared_layer.trainable = False
for head in model.hour2headnn:
    head.trainable = True

# Re-compile (recommended), or continue training directly.
model.compile(...)
model.fit(...)