In [1]:
import os
IS_COLAB_ENV = True

try:
    from google.colab import drive
    IS_COLAB_ENV = True
except:
    IS_COLAB_ENV = False

if IS_COLAB_ENV:
    drive.mount('/content/drive')
    !pip install einops
    !pip install tf_sentence_transformers
    ROOT_DIR = "/content/drive/MyDrive/work/AI-RecommenderSystem/"
    os.chdir('/content/drive/MyDrive/work/' + 'AI-RecommenderSystem/Recall/YoutubeDNN/my_implementation')
else:
    ROOT_DIR = "/mnt/g/My Drive/work/AI-RecommenderSystem/"

try:
    os.chdir(ROOT_DIR + 'Recall/YoutubeDNN/my_implementation')
except FileNotFoundError:
    ROOT_DIR = "/Users/hanshen/Library/CloudStorage/GoogleDrive-shawnhan1029@gmail.com/My Drive/work/AI-RecommenderSystem/"
    os.chdir(ROOT_DIR + 'Recall/YoutubeDNN/my_implementation')
DATA_DIR = ROOT_DIR + "Dataset/news_data_bigger/"


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# LOAD USER AND DOC INFO

In [2]:
from collections import namedtuple
import pandas as pd
from typing import Dict
import numpy as np
import tensorflow as tf
from tensorflow import keras
import os, pickle
from tf_sentence_transformers import SentenceTransformer


# Column names passed as `names=` to pd.read_csv for the tab-separated input files.
# user_info.txt: user id, device name, operating system, province, city, age, gender.
user_info_cols = ["userid", "device", "operating_system", "province", "city", "age", "gender"]
# doc_info.txt: doc id, title, publish time, image count, level-1 category, level-2 category, keywords.
doc_info_cols = ["docid", "title", "create_time", "image_num", "cate1", "cate2", "keywords"]
# Show log: user id, doc id, exposure time, network, refresh count, display position, clicked flag, reading time (seconds).
show_info_cols = ["userid", "docid", "exp_time", "network", "rt", "rit", "click", "reading_time"]

In [3]:
import re
# Choose the device used by later `tf.device(device_name)` scopes: the GPU
# reported by TensorFlow if there is one, otherwise the first physical CPU.
device_name = tf.test.gpu_device_name()
USE_GPU = len(device_name) > 0
if USE_GPU:
    print("Found GPU at: {}".format(device_name))
else:
    # Physical device names carry a "physical_device:" prefix; strip it.
    cpu_physical_name = tf.config.list_physical_devices('CPU')[0].name
    device_name = re.sub("physical_device:", "", cpu_physical_name)
    print("No GPU, using {}.".format(device_name))

Found GPU at: /device:GPU:0


# BUILD converter_layers

In [4]:
# Features fed to the model, in order. "docid" is currently disabled, which
# also disables the click-sequence branch further down the notebook.
feat_names = [
    # user-side features
    "userid", "device", "operating_system", "province", "city", "age", "gender",
    # "docid",
    # show-log (context) features
    "network", "rt", "rit",
    # document-side features
    "cate1", "cate2", "title",
]

# Categorical features: each one is mapped to integer ids by a StringLookup layer.
CATE_FEAT = {
    "userid", "device", "operating_system", "province", "city", "age", "gender",
    "docid",
    "network", "rt", "rit",
    "cate1", "cate2",
}

# Free-text features: encoded with a sentence transformer instead of a lookup.
TEXT_FEAT = {"title"}


def clean_create_time(value):
    """Convert a raw ``create_time`` field to ``np.uint32``; an empty field becomes 0."""
    return np.uint32(value if len(value) != 0 else 0)

def clean_image_num(value):
    """Convert a raw ``image_num`` field to ``np.uint8``; an empty field becomes 0."""
    return np.uint8(value if len(value) != 0 else 0)

def reduce_proba(row: str):
    """Collapse a ``key1:p1,key2:p2,...`` probability string to its most likely key.

    Returns ``"UNK"`` when ``row`` is not a string, is empty, or when no
    probability is strictly greater than 0. On ties the first key wins.
    A pair that does not split into exactly ``key:proba`` raises ``ValueError``.
    """
    if not isinstance(row, str) or len(row) == 0:
        return "UNK"
    pairs = row.split(",")
    assert len(pairs) >= 1, "unkown format: [{}]".format(row)
    best_label, best_score = "UNK", 0
    for pair in pairs:
        label, score_text = pair.split(":")
        score = float(score_text)
        # Strict comparison: later pairs only win with a larger probability.
        if score > best_score:
            best_label, best_score = label, score
    return best_label


def get_vocab(feat: str):
    """Return the unique values of categorical feature ``feat``.

    The source table is chosen by which column list contains ``feat``
    (user info first, then doc info, then the show log) and is loaded lazily
    into the module-level ``user_info`` / ``doc_info`` / ``show_info`` variable
    the first time it is needed.

    Args:
        feat: a feature name that must be a member of ``CATE_FEAT``.

    Returns:
        The result of ``Series.unique()`` on the matching column.

    Raises:
        AssertionError: if ``feat`` is not in ``CATE_FEAT``.
        Exception: if ``feat`` is in none of the known column lists.
    """
    assert feat in CATE_FEAT, "not support feat!"
    global user_info, doc_info, show_info
    if feat in user_info_cols:
        if user_info is None:
            user_info = pd.read_csv(
                DATA_DIR + "user_info.txt",
                sep="\t", header=None, names=user_info_cols,
                dtype=str,
                keep_default_na=False
            )
            # age / gender are stored as "key:proba,..." strings; keep the top key.
            user_info["age"] = user_info["age"].apply(reduce_proba)
            user_info["gender"] = user_info["gender"].apply(reduce_proba)
        return user_info[feat].unique()
    # Test membership against the column list, not the (possibly None) frame.
    if feat in doc_info_cols:
        if doc_info is None:
            doc_info = pd.read_csv(
                DATA_DIR + "doc_info.txt",
                sep="\t", header=None, names=doc_info_cols,
                dtype=str,
                keep_default_na=False
            )
        return doc_info[feat].unique()
    if feat in show_info_cols:
        # Check the frame itself: the column list is never None, so the
        # show log would otherwise never be loaded.
        if show_info is None:
            show_info = pd.read_csv(
                DATA_DIR + "sorted_train_data.txt",
                sep="\t", names=show_info_cols,
                dtype=str,
                keep_default_na=False,
                # Only the first million rows are used to build the vocabulary.
                nrows=1000000,
            )
        return show_info[feat].unique()
    raise Exception("Dont know where to find {}".format(feat))



In [5]:
# Read the user table with every column kept as a plain string
# (keep_default_na=False leaves empty fields as "" instead of NaN).
user_info: pd.DataFrame = pd.read_csv(
    DATA_DIR + "user_info.txt",
    sep="\t",
    header=None,
    names=user_info_cols,
    dtype=str,
    keep_default_na=False,
)
# age / gender hold "key:proba,..." strings; reduce each to its top key.
for proba_col in ("age", "gender"):
    user_info[proba_col] = user_info[proba_col].apply(reduce_proba)

# Read the document table the same way; create_time / image_num stay as
# strings (the clean_* converters are intentionally not applied here).
doc_info: pd.DataFrame = pd.read_csv(
    DATA_DIR + "doc_info.txt",
    sep="\t",
    header=None,
    names=doc_info_cols,
    dtype=str,
    keep_default_na=False,
)

# The show log is not loaded here; get_vocab() treats None as "not loaded yet".
show_info: pd.DataFrame = None

In [6]:
# Maps each feature name to the layer that converts its raw values:
# a StringLookup for categorical features, a sentence transformer for "title".
converter_layers: Dict[str, keras.layers.StringLookup] = {}

for key in feat_names:
    if key in CATE_FEAT:
        layer_path = "./{}.pkl".format(key)
        if os.path.exists(layer_path):
            print("trying to load {} StringLookup layer...".format(key))
            # NOTE(security): pickle.load can execute arbitrary code; only load
            # cache files that were written by this notebook.
            with open(layer_path, "rb") as layer_file:
                from_disk = pickle.load(layer_file)
            # from_config is a classmethod, so no throwaway instance is needed.
            new_layer = keras.layers.StringLookup.from_config(from_disk["config"])
            # Adapt on a dummy value before restoring the saved weights.
            new_layer.adapt(tf.data.Dataset.from_tensor_slices(["xyz"]))
            new_layer.set_weights(from_disk['weights'])
        else:
            vocab = get_vocab(key)
            print("trying to create {} StringLookup layer...".format(key))
            new_layer = keras.layers.StringLookup(num_oov_indices=1)
            new_layer.adapt(data=vocab)
            # Cache config + weights so later runs can skip the vocabulary scan.
            with open(layer_path, "wb") as layer_file:
                pickle.dump({
                    "config": new_layer.get_config(),
                    "weights": new_layer.get_weights(),
                }, layer_file)
    elif key == "title":
        print("trying to download bert model...")
        with tf.device(device_name):
            new_layer = SentenceTransformer.from_pretrained('uer/sbert-base-chinese-nli', from_pt=True)
    else:
        raise ValueError("unsupported feat [" + key + "]!")
    converter_layers[key] = new_layer

trying to load userid StringLookup layer...




trying to load device StringLookup layer...
trying to load operating_system StringLookup layer...
trying to load province StringLookup layer...
trying to load city StringLookup layer...




trying to load age StringLookup layer...
trying to load gender StringLookup layer...
trying to load network StringLookup layer...
trying to load rt StringLookup layer...
trying to load rit StringLookup layer...
trying to load cate1 StringLookup layer...
trying to load cate2 StringLookup layer...
trying to download bert model...


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['embeddings.position_ids']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


## click sequence (one time)

### make seq file

In [7]:
# from collections import deque

# SEQ_LENGTH = 20

# ShowLog = namedtuple("Userinfo", ["userid", "docid", "exp_time", "network", "rt", "rit", "click", "reading_time"])

# user_seq_buffer: Dict[str, deque] = {}

# with open(DATA_DIR + "sorted_train_data.txt", "r") as show_log_f, open(DATA_DIR + "clk_seq_from_sorted_train_data.txt", 'w') as seq_f:
#     for line in show_log_f:
#         parts = line[:-1].split("\t")
#         show_log = ShowLog(*parts)
#         if show_log.userid not in user_seq_buffer:
#             user_seq_buffer[show_log.userid] = deque(maxlen=SEQ_LENGTH)
#         dq_of_this_user = user_seq_buffer[show_log.userid]
#         seq_str = " ".join(dq_of_this_user)
#         seq_f.write(",".join(show_log[:3] + (seq_str,)) + "\n")
#         if show_log.click == "1":
#             dq_of_this_user.append(show_log.docid)

### make dataset

In [8]:
def get_seq_feat(ele):
    """Turn a batch of click-sequence lines into looked-up doc-id sequences.

    Each line is split on ","; the field at index 3 is split on " " and the
    resulting tokens are passed through ``converter_layers["docid"]``.
    Both splits are densified with ``to_tensor()``.
    """
    fields = tf.strings.split(ele, ",").to_tensor()
    seq_column = fields[:, 3]
    doc_tokens = tf.strings.split(seq_column, " ").to_tensor()
    return converter_layers["docid"](doc_tokens)


# The click-sequence dataset only exists when "docid" is an active feature.
seq_dataset = None
if "docid" in feat_names:
    seq_dataset = (
        tf.data.TextLineDataset([DATA_DIR + "/clk_seq_from_sorted_train_data.txt"])
        .batch(1024, drop_remainder=True)
        .map(get_seq_feat)
    )

## make title embedding file

In [24]:
# Overwrite the title of row 510178 in place before generating embeddings.
# NOTE(review): presumably this row's original title broke the embedding run
# (the resume counter printed by the generation cell is 510178) — confirm.
doc_info.loc[510178, 'title'] = '"最强鸿蒙概念股”润和软件再收关注函!20多天5倍涨幅,散户抱团到尽头?'

In [25]:
from tqdm import tqdm

# for member in tqdm(members):
#     # current contents of your for loop


def write_block(input_blk, mode, idx, tmp_ids):
    """Embed one block of titles and append the result to the embedding file.

    Args:
        input_blk: titles to embed, passed to ``tf.constant`` as-is.
        mode: file mode used to open ``doc_title_embedding.txt``.
        idx: index of the last row in this block; ``idx + 1`` is written to
            ``doc_title_embedding_info.txt`` as the progress counter.
        tmp_ids: doc ids, paired positionally with the rows of ``input_blk``.

    Each output line is ``<docid>,<v0> <v1> ...`` with components rounded to
    5 decimals. The info file is rewritten on every call.
    """
    titles = tf.constant(input_blk)
    with tf.device(device_name):
        embeddings = converter_layers["title"](titles)
    lines_blk = [
        doc_id + "," + " ".join(str(round(component, 5)) for component in vector) + "\n"
        for doc_id, vector in zip(tmp_ids, embeddings.numpy().tolist())
    ]
    embedding_path = DATA_DIR + "doc_title_embedding.txt"
    info_path = DATA_DIR + "doc_title_embedding_info.txt"
    with open(embedding_path, mode) as title_embedding_file, open(info_path, "w") as info_file:
        title_embedding_file.writelines(lines_blk)
        info_file.write(str(idx + 1))


# Total number of documents whose titles need an embedding.
DOC_NUMBER = doc_info.shape[0]
# DOC_NUMBER = 30


# Resumable generation of the title-embedding file: the info file stores how
# many rows have already been written, and processing continues from there.
if "title" in feat_names:
    if os.path.exists(DATA_DIR + "doc_title_embedding_info.txt"):
        with open(DATA_DIR + "doc_title_embedding_info.txt" ,"r") as info_file:
            number_str = info_file.readline()
            appended_number = int(number_str)
    else:
        # No progress file: nothing has been written yet.
        appended_number = 0
    print("wrote number: " + str(appended_number))
    if appended_number >= DOC_NUMBER:
        print("No need to generate.")
        pass
    else:
        print("Still need to process {} docs.".format(str(DOC_NUMBER - appended_number)))
        # Start a fresh file on the first run, append when resuming.
        if appended_number == 0:
            mode = "w"
        else:
            mode = "a"
        print("mode: ", mode)
        input_blk = []
        tmp_ids = []
        # .loc slices by index label and is end-inclusive, hence DOC_NUMBER-1.
        # NOTE(review): assumes doc_info keeps its default 0..N-1 index — confirm.
        doc_info__ = doc_info.loc[appended_number:DOC_NUMBER-1, :]
        for idx, row in tqdm(doc_info__.iterrows()):
            input_blk.append([row["title"]])
            tmp_ids.append(row["docid"])
            # Flush every 500 rows; after the first flush always append.
            if len(tmp_ids) >= 500:
                write_block(input_blk, mode, idx, tmp_ids)
                tmp_ids.clear()
                input_blk.clear()
                mode = "a"
        # Flush the final, partially filled block.
        if len(tmp_ids) > 0:
            write_block(input_blk, mode, idx, tmp_ids)
            pass
        print("Done generate.")

wrote number: 510178
Still need to process 82568 docs.
mode:  a


82568it [02:46, 496.02it/s]


Done generate.


In [26]:
# Hash table from an int64 key to a 768-float vector; keys that were never
# inserted return the all-zero default. It is filled from
# doc_title_embedding.txt by build_title_embd_hash_table.
title_embd_table = tf.lookup.experimental.MutableHashTable(
    key_dtype=tf.int64,
    value_dtype=tf.float32,
    default_value=[0.0] * 768,
)

def build_title_embd_hash_table(ele):
    """Insert one batch of ``<docid>,<v0> <v1> ...`` lines into ``title_embd_table``.

    The doc ids are parsed to int64 and the space-separated components to
    float32. Both are wrapped in a list before parsing, so keys and values
    carry an extra leading axis of size 1. Returns the constant 1 so the
    function can be used with ``Dataset.map``.
    """
    fields = tf.strings.split(ele, ",").to_tensor()
    doc_ids = tf.strings.to_number([fields[:, 0]], out_type=tf.int64)
    vector_tokens = tf.strings.split(fields[:, 1], " ").to_tensor()
    vectors = tf.strings.to_number([vector_tokens], out_type=tf.float32)
    title_embd_table.insert(doc_ids, vectors)
    return 1

# Stream the embedding file in batches of 1000 lines; the map function inserts
# each batch into title_embd_table as a side effect.
title_embd_ = (
    tf.data.TextLineDataset([DATA_DIR + "doc_title_embedding.txt"])
    .batch(1000)
    .map(build_title_embd_hash_table)
)

# Must run one pass: the inserts only happen when the dataset is iterated.
for ele in title_embd_:
    pass

In [27]:
# Spot check: look up a single key in the title-embedding table.
title_embd_table.lookup(tf.constant([349635709], tf.int64))

<tf.Tensor: shape=(1, 768), dtype=float32, numpy=
array([[ 0.06875,  0.02232, -0.02387, -0.02996,  0.10184,  0.02768,
        -0.04228,  0.01384, -0.02292,  0.0283 , -0.00192, -0.01466,
         0.00466, -0.0935 ,  0.02055, -0.03486, -0.00847, -0.00972,
        -0.03999,  0.03547,  0.00991,  0.02884,  0.08565, -0.01345,
         0.00349, -0.0626 , -0.01813,  0.03217, -0.00993, -0.04048,
        -0.01647, -0.06433,  0.02927, -0.00067, -0.00317,  0.04936,
        -0.06789,  0.02903, -0.03834,  0.00525, -0.03111, -0.0302 ,
        -0.04978, -0.01178, -0.03246,  0.01466,  0.01287, -0.02434,
         0.0625 ,  0.02376, -0.11231, -0.06029, -0.06636,  0.03758,
        -0.0168 , -0.02797,  0.06975, -0.0041 ,  0.02385,  0.0421 ,
        -0.02723, -0.00443, -0.06479,  0.00531,  0.03648,  0.03837,
        -0.04302,  0.04137, -0.03322, -0.04695, -0.07088, -0.00611,
         0.02463,  0.04293,  0.00055,  0.03174,  0.02429, -0.04159,
         0.05676, -0.02921, -0.02157, -0.00119, -0.01214, -0.02329

## other feature (normal)

### build user MutableHashTable, to build feature later

In [28]:
# Columns: user id, device name, operating system, province, city, age, gender.

# Hash table from an int64 key to a row of 7 strings; keys that were never
# inserted return a row of seven "1" strings.
user_info_table = tf.lookup.experimental.MutableHashTable(
        key_dtype=tf.int64,
        value_dtype=tf.string,
        default_value=["1", "1", "1", "1", "1", "1", "1"],
    )


def build_hash_table(ele):
    """Insert one batch of user rows into ``user_info_table``.

    The key is column 0 parsed to int64 and wrapped in a list, so it has a
    leading axis of size 1; the rows get a matching leading axis via
    ``expand_dims``. Returns the constant 1 so it can be used with
    ``Dataset.map``.
    """
    user_ids = tf.strings.to_number([ele[:, 0]], out_type=tf.int64)
    rows = tf.expand_dims(ele, axis=0)
    user_info_table.insert(user_ids, rows)
    return 1


# Feed the user frame through build_hash_table in batches of 1000 rows.
user_info_ = (
    tf.data.Dataset.from_tensor_slices(user_info)
    .batch(1000)
    .map(build_hash_table)
)

# Must run one pass: the inserts only happen when the dataset is iterated.
for ele in user_info_:
    pass

In [29]:
# Spot check: look up a single key in the user table.
user_info_table.lookup(tf.constant([1001384888], dtype=tf.int64))

<tf.Tensor: shape=(1, 7), dtype=string, numpy=
array([[b'1001384888', b'M2007J22C', b'Android',
        b'\xe6\xb2\xb3\xe5\x8c\x97',
        b'\xe7\x9f\xb3\xe5\xae\xb6\xe5\xba\x84', b'A_40+', b'male']],
      dtype=object)>

### build doc MutableHashTable

In [30]:
# Columns: doc id, title, publish time, image count, level-1 category, level-2 category, keywords.
# Hash table from an int64 key to a row of 7 strings; keys that were never
# inserted return a row of seven "1" strings.
doc_info_table = tf.lookup.experimental.MutableHashTable(
        key_dtype=tf.int64,
        value_dtype=tf.string,
        default_value=["1", "1", "1", "1", "1", "1", "1"],
    )


def build_doc_hash_table(ele):
    """Insert one batch of document rows into ``doc_info_table``.

    The key is column 0 parsed to int64 and wrapped in a list, so it has a
    leading axis of size 1; the rows get a matching leading axis via
    ``expand_dims``. Returns the constant 1 so it can be used with
    ``Dataset.map``.
    """
    doc_ids = tf.strings.to_number([ele[:, 0]], out_type=tf.int64)
    rows = tf.expand_dims(ele, axis=0)
    doc_info_table.insert(doc_ids, rows)
    return 1


# Feed the document frame through build_doc_hash_table in batches of 1000 rows.
doc_info_ = (
    tf.data.Dataset.from_tensor_slices(doc_info)
    .batch(1000)
    .map(build_doc_hash_table)
)

# Must run one pass: the inserts only happen when the dataset is iterated.
for ele in doc_info_:
    pass

### make feat by tf method

In [40]:
# Show-log columns: user id, doc id, exposure time, network, refresh count,
# display position, clicked flag, reading time (seconds).
# Each dict maps a column name to its position within the corresponding row.
user_feat_location = dict(zip(user_info_cols, range(len(user_info_cols))))
doc_feat_location = dict(zip(doc_info_cols, range(len(doc_info_cols))))
show_feat_location = dict(zip(show_info_cols, range(len(show_info_cols))))


def tf_get_feat_from_table(ele: tf.Tensor):
    """Build ``(features, label)`` for one batch of tab-separated show-log lines.

    Columns 0 and 1 are parsed as int64 user / doc ids and used to look up the
    user row, the doc row and the title embedding; column 6 becomes the label,
    reshaped to ``[batch, 1]``. For every name in ``feat_names`` the raw value
    is taken from the user row, the doc row, or the log line itself (checked
    in that order) and passed through its converter layer. "title" is taken
    directly from the embedding table instead.

    Returns:
        A tuple ``(feat_dict, label)`` where ``feat_dict`` maps feature name
        to tensor.
    """
    fields: tf.Tensor = tf.strings.split(ele, "\t").to_tensor()
    # Wrapping the column in a list adds a leading axis of size 1, matching
    # the layout the hash tables were filled with.
    uids = tf.strings.to_number([fields[:, 0]], out_type=tf.int64)
    docids = tf.strings.to_number([fields[:, 1]], out_type=tf.int64)
    label = tf.reshape(
        tf.strings.to_number([fields[:, 6]], out_type=tf.int64),
        shape=[-1, 1],
    )  # [1, batch] to [batch, 1]

    # Drop the leading size-1 axis again after each lookup.
    user_feat_values = tf.squeeze(user_info_table.lookup(uids), axis=0)
    doc_feat_values = tf.squeeze(doc_info_table.lookup(docids), axis=0)
    title_embd_values = tf.squeeze(title_embd_table.lookup(docids), axis=0)

    feat_dict = {}
    for feat in feat_names:
        if feat == "title":
            # Dense feature: used as-is, no lookup layer.
            feat_dict[feat] = tf.ensure_shape(title_embd_values, [None, 768])
            continue
        if feat in user_feat_location:
            raw = user_feat_values[:, user_feat_location[feat]]
        elif feat in doc_feat_location:
            raw = doc_feat_values[:, doc_feat_location[feat]]
        else:
            raw = fields[:, show_feat_location[feat]]
        feat_dict[feat] = tf.ensure_shape(converter_layers[feat](raw), [None,])
    return (feat_dict, label)



# Show log as batches of 1024 lines (incomplete last batch dropped), each
# converted to (features, label) by tf_get_feat_from_table.
train_show_log = (
    tf.data.TextLineDataset([DATA_DIR + "sorted_train_data.txt"])
    .batch(1024, drop_remainder=True)
    .map(tf_get_feat_from_table)
)

### concat show log and req feat

In [41]:
def merge_sparse_with_seq_into_one_dict(a, seq, batch_size=1024):
    """Attach the click sequence to a ``(features, label)`` element.

    Args:
        a: tuple ``(sparse, label)`` where ``sparse`` maps feature name to tensor.
        seq: the click-sequence tensor, stored under the key ``"docid_seq"``.
        batch_size: static size to pin on the leading axis of every feature
            (defaults to 1024, the value previously hard-coded).

    Returns:
        ``(sparse, label)`` with ``sparse`` updated in place.

    Only the leading axis is pinned; trailing dimensions are kept, so
    features with more than one axis (e.g. a ``[batch, 768]`` embedding) no
    longer fail with a rank mismatch.
    """
    sparse, label = a
    for value in sparse.values():
        if value.shape.rank is None:
            # Unknown rank: keep the previous behaviour of declaring a vector.
            trailing_dims = []
        else:
            trailing_dims = value.shape.as_list()[1:]
        value.set_shape([batch_size] + trailing_dims)
    sparse["docid_seq"] = seq
    return (sparse, label)

# Final input pipeline: the show-log features, plus the click sequence when
# "docid" is an active feature.
if "docid" not in feat_names:
    dataset = train_show_log.prefetch(tf.data.AUTOTUNE)
else:
    zipped = tf.data.Dataset.zip((train_show_log, seq_dataset))
    dataset = zipped.map(merge_sparse_with_seq_into_one_dict).prefetch(tf.data.AUTOTUNE)

In [42]:
# train_dataset, test_dataset = tf.keras.utils.split_dataset(dataset.take(1_000_000), left_size=0.9)
# dataset has approximately 185319 batches.

# Full-size split on Colab, a tiny split everywhere else. The test split
# starts right after the last training batch.
n_train_batches, n_test_batches = (100_000, 30_000) if IS_COLAB_ENV else (100, 30)
train_dataset = dataset.take(n_train_batches)
test_dataset = dataset.skip(n_train_batches).take(n_test_batches)

### make negative sample

In [34]:
# Columns: user id, doc id, exposure time, network environment, refresh count, display position, clicked flag, reading time (seconds).

# doc_click_freq = tf.lookup.experimental.MutableHashTable(
#     key_dtype=tf.int64,
#     value_dtype=tf.int64,
#     default_value=0,
# )


# def split_ele(ele):
#     return tf.strings.split(ele, "\t")


# def is_click(ele):
#     res = (ele[6] == tf.constant(["1"]))
#     # tf.print(res[0], type(res[0]))
#     return res[0]


# click_log_dataset = tf.data.TextLineDataset(DATA_DIR + "/sorted_train_data.txt")\
#     .map(split_ele)
#     .filter(is_click)

# for e in click_log_dataset.take(10):
#     print(e)
#     pass

# TRAIN MODEL

In [35]:
%load_ext tensorboard

In [44]:
import importlib
import model
import datetime
importlib.reload(model)

# Model input configuration derived from feat_names:
#   sparse_configs - one EmbeddingConfig per categorical feature
#   doc_embedding  - the EmbeddingConfig for "cate2", kept separately
#   dense_configs  - one DenseConfig for the 768-wide "title" feature
sparse_configs = []
doc_embedding = None
dense_configs = []

for feat in feat_names:
    layer = converter_layers[feat]
    if feat == "title":
        # Dense input: no vocabulary, nothing to print.
        dense_configs.append(model.DenseConfig("title", 768))
        continue
    vocab_size = layer.vocabulary_size()
    if feat == "cate2":
        doc_embedding = model.EmbeddingConfig(feat, 64, vocab_size)
    else:
        # userid / device / rit use 16-wide embeddings, the rest 64.
        embd_dim = 16 if feat in ["userid", "device", "rit"] else 64
        sparse_configs.append(model.EmbeddingConfig(feat, embd_dim, vocab_size))
    print("{} embedding space size is {}".format(feat, vocab_size))

# Build and compile the model on the selected device.
with tf.device(device_name):
    optimizer = keras.optimizers.Adam()
    youtubednn = model.YouTubeDNN(
        sparse_configs, doc_embedding, use_seq_feat=False, dense_feat=dense_configs, dnn_dims=[1024, 512],
        # text_feat=["title"], text_transformer=converter_layers["title"]
    )
    youtubednn.compile(
        optimizer=optimizer,
        loss=keras.losses.BinaryCrossentropy(from_logits=False),
        # Name the metric explicitly: an unnamed AUC() is auto-numbered
        # ("auc_1", "auc_2", ...) each time this cell is re-run, which changes
        # the "val_auc" key that the checkpoint callback monitors.
        metrics=[keras.metrics.AUC(name="auc")],
    )

userid embedding space size is 1538385
device embedding space size is 3097
operating_system embedding space size is 4
province embedding space size is 329
city embedding space size is 769
age embedding space size is 6
gender embedding space size is 4
network embedding space size is 5
rt embedding space size is 215
rit embedding space size is 1185
cate1 embedding space size is 40
cate2 embedding space size is 202


In [45]:
# Directory holding checkpoints and TensorBoard logs; the trailing "/" matters
# because paths below are built by plain string concatenation.
checkpoint_dir = "./ckpt_1014_bert/"
# Checkpoint file names encode the epoch and validation AUC, e.g. "01-0.750.hdf5";
# get_latext_ckpt() parses this format back.
checkpoint_filepath = checkpoint_dir + "{epoch:02d}-{val_auc:.3f}.hdf5"

# epoch = tf.Variable(0)
# Number of epochs passed to fit().
EPOCHS = 2

# other_ckpt = tf.train.Checkpoint(
#     optimizer=optimizer,
#     epoch = epoch,
# )
# other_chkpt_manager = tf.train.CheckpointManager(
#     other_ckpt,
#     checkpoint_dir,
#     max_to_keep=3,
# )

class CheckPointCallback(keras.callbacks.ModelCheckpoint):
    """ModelCheckpoint that always exposes the validation AUC as ``"val_auc"``.

    The first log key that starts with ``"val_auc"`` is copied to the key
    ``"val_auc"`` before delegating to ``ModelCheckpoint``, so the ``monitor``
    and file-name template keep working when the metric key has a suffix.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Alias the validation-AUC key in ``logs`` and run the normal checkpoint logic.

        Args:
            epoch: index of the epoch that just finished.
            logs: metric dict for the epoch; it is updated in place. ``None``
                is treated as an empty dict instead of raising.
        """
        if logs is None:
            logs = {}
        # Iterate over a snapshot of the keys because logs is modified below.
        for key in list(logs):
            if key.startswith("val_auc"):
                logs["val_auc"] = logs[key]
                break
        super().on_epoch_end(epoch, logs)

# Save weights only, and only when "val_auc" improves (mode="max").
model_checkpoint_callback = CheckPointCallback(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor="val_auc",
    mode="max",
    save_best_only=True,
    verbose=1,
)

# backup_and_restore = keras.callbacks.BackupAndRestore(
#     backup_dir=checkpoint_dir[:-1] + "_backup",
#     save_freq=100000)

# TensorBoard logs go to a timestamped sub-directory of the checkpoint
# directory, so each run gets its own log folder.
log_dir = checkpoint_dir + "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

def get_latext_ckpt(ckpt_dir=None):
    """Find the checkpoint with the highest validation AUC.

    Args:
        ckpt_dir: directory to scan; must end with a path separator because
            the returned path is ``ckpt_dir + file_name``. Defaults to the
            module-level ``checkpoint_dir``.

    Returns:
        ``(init_epoch, max_auc, ckpt_name)`` for the best checkpoint, or
        ``(0, 0, None)`` when the directory is missing or holds no checkpoint.
        When several files share the top AUC, the one listed last wins.

    Only files named ``<epoch>-<auc>.hdf5`` are considered; anything else
    (the ``logs`` directory, stray files) is ignored instead of raising.
    """
    import re  # local import keeps the function self-contained

    if ckpt_dir is None:
        ckpt_dir = checkpoint_dir
    init_epoch = 0
    max_auc = 0
    ckpt_name = None
    try:
        file_names = os.listdir(ckpt_dir)
    except OSError:
        # Directory does not exist (or cannot be read): nothing to resume from.
        return init_epoch, max_auc, ckpt_name
    name_pattern = re.compile(r"(\d+)-(\d+(?:\.\d+)?)\.hdf5")
    for file_name in file_names:
        matched = name_pattern.fullmatch(file_name)
        if matched is None:
            continue
        epoch = int(matched.group(1))
        auc = float(matched.group(2))
        if max_auc <= auc:
            max_auc = auc
            init_epoch = epoch
            ckpt_name = ckpt_dir + file_name
    return init_epoch, max_auc, ckpt_name

# Resume from the best existing checkpoint, if there is one.
init_epoch, _, ckpt_name = get_latext_ckpt()

if ckpt_name is not None:
    # NOTE(review): presumably this single-batch evaluate exists to build the
    # model's variables so that load_weights can restore into them — confirm.
    youtubednn.evaluate(train_dataset.take(1))
    youtubednn.load_weights(ckpt_name)
    print("Load file "+ckpt_name)


In [None]:
# if other_chkpt_manager.latest_checkpoint:
#     youtubednn.load_weights(checkpoint_dir)
#     other_ckpt.restore(other_chkpt_manager.latest_checkpoint)
#     print(f'Latest checkpoint restored!!. Last epoch is {int(epoch)}')
# else:
#     print("Didn't find checkpoint. Train from scratch.")


# Train, resuming at init_epoch when a checkpoint was restored.
# steps_per_epoch is deliberately not set: train_dataset is finite and not
# repeated, so each epoch runs until the dataset is exhausted. A fixed
# steps_per_epoch of 100_000 asks for more batches than the dataset can supply
# (100 batches off Colab; 100_000 in total across EPOCHS=2 on Colab) and
# training stops early with "input ran out of data".
youtubednn.fit(
    x=train_dataset,
    validation_data=test_dataset,
    callbacks=[
        model_checkpoint_callback,
        # backup_and_restore,
        tensorboard_callback,
    ],
    epochs=EPOCHS,
    initial_epoch=init_epoch,
)

Epoch 1/2