# In [None]:
class AnimeMainPicModel(tf.keras.Model):
    """Anime embedding derived from each anime's main picture.

    At construction time every picture URL is downloaded, resized and pushed
    through a frozen, ImageNet-pretrained CNN backbone (see ``get_base_model``).
    The resulting feature vectors are stored in a non-trainable ``Embedding``
    table. At call time an anime id is mapped to its row in that table and
    projected by a trainable ReLU ``Dense`` layer.

    NOTE(review): the id lookup maps ``unique_anime_ids[i]`` to table row ``i``,
    so this assumes ``anime_main_pic_url`` is ordered like ``unique_anime_ids``.
    Confirm at the call site.
    """

    def __init__(self,
                 unique_anime_ids,
                 anime_main_pic_url,
                 img_height = 192,
                 img_width = 128,
                 base_model_name = "efficientnet",
                 anime_embedding_size = 32):
        """Precompute image features for every anime.

        Args:
            unique_anime_ids: vocabulary of anime id strings. The lookup has
                no OOV slot (``num_oov_indices=0``).
            anime_main_pic_url: one picture URL per anime.
            img_height, img_width: size every image is resized to.
            base_model_name: backbone name understood by ``get_base_model``.
            anime_embedding_size: width of the output embedding.
        """
        super().__init__()

        self.anime_id_lookup_layer = tf.keras.layers.StringLookup(
            vocabulary = unique_anime_ids,
            num_oov_indices = 0,
            name = 'anime_pic_model_id_lookup'
        )

        anime_img_tf_ds = self.__class__.get_image_dataset(anime_main_pic_url, img_height, img_width)
        anime_img_model = self.__class__.get_image_embedding_model(base_model_name, img_height, img_width)

        # Features are computed once here. The backbone is a local and is not
        # stored on the model; only its outputs are kept.
        anime_image_embeddings = anime_img_model.predict(anime_img_tf_ds)
        num_animes = anime_image_embeddings.shape[0]
        img_emb_dim = anime_image_embeddings.shape[1]
        self.image_embedding_layer = tf.keras.layers.Embedding(
            num_animes,
            img_emb_dim,
            embeddings_initializer = tf.keras.initializers.Constant(anime_image_embeddings),
            trainable = False,
            name = 'anime_img_base_model_embedding'
        )

        self.final_layer = tf.keras.layers.Dense(anime_embedding_size, activation = 'relu')

    def call(self, anime_id):
        """Map anime id strings to ``anime_embedding_size``-wide embeddings."""
        anime_idx = self.anime_id_lookup_layer(anime_id)
        anime_image_embedding = self.image_embedding_layer(anime_idx)
        return self.final_layer(anime_image_embedding)

    @staticmethod
    def download_image(img_url, img_height = 192, img_width = 128, num_retries = 10):
        """Download, decode and resize a single image.

        Returns an RGB ``uint8`` array of shape ``(img_height, img_width, 3)``.
        Any failure (network error, undecodable bytes) is printed and the
        download is retried up to ``num_retries`` times. If every attempt
        fails, an all-black image of the same shape and dtype is returned, so
        one bad URL does not abort the whole dataset.
        """
        for _ in range(num_retries):
            try:
                with urlopen(img_url) as request:
                    img_array = np.asarray(bytearray(request.read()), dtype=np.uint8)
                img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
                if img is None:
                    # imdecode signals bad/empty data by returning None rather than raising.
                    raise ValueError("could not decode image bytes")
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                return cv2.resize(img, (img_width, img_height), interpolation = cv2.INTER_AREA)
            except Exception as e:
                print(e, img_url)
        # dtype must match the Tout=tf.uint8 declared in get_image_dataset.
        return np.zeros((img_height, img_width, 3), dtype=np.uint8)

    @staticmethod
    def get_image_dataset(anime_main_pic_df, img_height = 192, img_width = 128):
        """Build a dataset of decoded images, one per URL, in input order, batched by 128."""
        def _load(img_url):
            img = tf.py_function(
                func = lambda x: AnimeMainPicModel.download_image(x.numpy().decode('utf-8'), img_height, img_width),
                inp = [img_url],
                Tout = tf.uint8,
            )
            # py_function outputs carry no static shape; restore the known one so
            # the fixed-shape Keras Input downstream sees (H, W, 3).
            img.set_shape((img_height, img_width, 3))
            return img

        anime_image_ds = tf.data.Dataset.from_tensor_slices(anime_main_pic_df)
        return anime_image_ds.map(_load).batch(128)

    @staticmethod
    def get_base_model(base_model_name, img_height = 192, img_width = 128):
        """Return ``(preprocess_input, backbone)`` for a Keras applications model.

        Recognised names (case-insensitive): "densenet", "efficientnet",
        "mobilenet", "nasnet", "resnet". Any other name falls back to ResNet50V2.
        The backbone loads ImageNet weights, drops its classifier head and
        applies global max pooling.
        """
        name = base_model_name.lower()

        # Imports stay local so only the selected architecture's module is loaded.
        if name == "densenet":
            from tensorflow.keras.applications.densenet import DenseNet121 as backbone_cls, preprocess_input
        elif name == "efficientnet":
            from tensorflow.keras.applications.efficientnet import EfficientNetB0 as backbone_cls, preprocess_input
        elif name == "mobilenet":
            from tensorflow.keras.applications import MobileNetV3Small as backbone_cls
            from tensorflow.keras.applications.mobilenet_v3 import preprocess_input
        elif name == "nasnet":
            # NOTE(review): Keras' NASNetMobile may reject a non-224x224
            # input_shape when loading ImageNet weights. Confirm before using.
            from tensorflow.keras.applications.nasnet import NASNetMobile as backbone_cls, preprocess_input
        else:
            # "resnet" and any unrecognised name.
            from tensorflow.keras.applications.resnet_v2 import ResNet50V2 as backbone_cls, preprocess_input

        backbone = backbone_cls(weights='imagenet', input_shape = (img_height, img_width, 3), include_top=False, pooling = 'max')
        return preprocess_input, backbone

    @staticmethod
    def get_image_embedding_model(base_model_name, img_height = 192, img_width = 128):
        """Wrap a backbone so it takes raw ``uint8`` images and returns pooled features."""
        preprocessing_function, base_model = AnimeMainPicModel.get_base_model(base_model_name, img_height, img_width)
        raw_image = tf.keras.layers.Input(shape=(img_height, img_width, 3), dtype = tf.uint8, name = 'raw_image')
        # Cast to float32 before the backbone-specific preprocess_input.
        preprocessed_image = preprocessing_function(tf.cast(raw_image, dtype = tf.float32))
        image_embedding = base_model(preprocessed_image)
        return tf.keras.Model(raw_image, image_embedding)

class AnimeTextModel(tf.keras.Model):
    """Anime embedding derived from one text field per anime via a frozen small-BERT encoder.

    The text of every anime is encoded once, at construction time. The BERT
    ``pooled_output`` vectors are stored in a non-trainable ``Embedding`` table.
    At call time an anime id is mapped to its row and projected by a trainable
    ReLU ``Dense`` layer.

    NOTE(review): the id lookup maps ``unique_anime_ids[i]`` to table row ``i``,
    so ``anime_text_feature`` must be ordered like ``unique_anime_ids``.
    Confirm at the call site.
    """

    # TF-Hub handles for the only text backbone currently implemented.
    _BERT_PREPROCESS_URL = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
    _BERT_ENCODER_URL = "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2"

    def __init__(self,
                 unique_anime_ids,
                 anime_text_feature,
                 base_model_name = "bert",
                 anime_embedding_size = 32):
        """Precompute text features for every anime.

        Args:
            unique_anime_ids: vocabulary of anime id strings. The lookup has
                no OOV slot (``num_oov_indices=0``).
            anime_text_feature: one text string per anime.
            base_model_name: passed to ``get_base_model``.
            anime_embedding_size: width of the output embedding.
        """
        super().__init__()
        self.anime_id_lookup_layer = tf.keras.layers.StringLookup(
            vocabulary=unique_anime_ids,
            num_oov_indices=0,
            name = 'anime_text_model_id_lookup'
        )
        text_tf_ds = tf.data.Dataset.from_tensor_slices(anime_text_feature).batch(128)
        text_embedding_model = self.__class__.get_text_embedding_model(base_model_name)

        anime_text_embeddings = text_embedding_model.predict(text_tf_ds)

        num_animes = anime_text_embeddings.shape[0]
        text_emb_dim = anime_text_embeddings.shape[1]

        self.text_embedding_layer = tf.keras.layers.Embedding(
            num_animes,
            text_emb_dim,
            embeddings_initializer=tf.keras.initializers.Constant(anime_text_embeddings),
            trainable=False,
            name = 'text_embedding_layer'
        )

        self.final_layer = tf.keras.layers.Dense(anime_embedding_size, activation = 'relu')

    def call(self, anime_id):
        """Map anime id strings to ``anime_embedding_size``-wide embeddings."""
        anime_idx = self.anime_id_lookup_layer(anime_id)
        anime_text_embedding = self.text_embedding_layer(anime_idx)
        return self.final_layer(anime_text_embedding)

    @staticmethod
    def get_base_model(base_model_name):
        """Return ``(preprocess_layer, encoder_layer)`` TF-Hub layers.

        Only "bert" (small BERT, L-2 / H-128) is implemented. Any other name
        currently falls back to the same model. The encoder is frozen.
        """
        preprocess_model = hub.KerasLayer(AnimeTextModel._BERT_PREPROCESS_URL)
        model = hub.KerasLayer(AnimeTextModel._BERT_ENCODER_URL, trainable=False)
        return preprocess_model, model

    @staticmethod
    def get_text_embedding_model(base_model_name):
        """Build a Keras model mapping raw strings to BERT ``pooled_output`` vectors."""
        preprocessing_function, base_model = AnimeTextModel.get_base_model(base_model_name)
        text = tf.keras.layers.Input(shape = (), dtype = tf.string, name = 'text')
        encoder_outputs = base_model(preprocessing_function(text))
        return tf.keras.Model(text, encoder_outputs["pooled_output"])

class AnimeOneHotModel(tf.keras.Model):
    """Fixed one-hot encoding of a single categorical anime attribute.

    The attribute value of every anime is one-hot encoded once, at
    construction time, and stored in a non-trainable ``Embedding`` table. At
    call time an anime id is mapped to its row. The model has no trainable
    weights.

    NOTE(review): the id lookup maps ``unique_anime_ids[i]`` to table row ``i``,
    so ``one_hot_feature`` must be ordered like ``unique_anime_ids``.
    Confirm at the call site.
    """

    def __init__(self,
                 unique_anime_ids,
                 one_hot_feature,
                 vocabulary):
        """Build the one-hot table.

        Args:
            unique_anime_ids: vocabulary of anime id strings. The lookup has
                no OOV slot (``num_oov_indices=0``).
            one_hot_feature: one category string per anime.
            vocabulary: the allowed category strings. The encoder also has no
                OOV slot, so values outside it are not expected.
        """
        super().__init__()

        self.anime_id_lookup_layer = tf.keras.layers.StringLookup(
            vocabulary = unique_anime_ids,
            num_oov_indices = 0,
            name = 'anime_onehot_model_id_lookup'
        )

        one_hot_layer = tf.keras.layers.StringLookup(vocabulary = vocabulary,
                                                     output_mode = "one_hot",
                                                     num_oov_indices = 0
                                                     )
        # Encode every row in one call on a string tensor. Keras layers operate
        # on tensors; the previous code passed a batched tf.data.Dataset here.
        one_hot_encodings = one_hot_layer(tf.constant(list(one_hot_feature), dtype = tf.string))

        num_animes = one_hot_encodings.shape[0]
        num_one_hot_dims = one_hot_encodings.shape[1]

        self.one_hot_encoding_layer = tf.keras.layers.Embedding(
            num_animes,
            num_one_hot_dims,
            embeddings_initializer = tf.keras.initializers.Constant(one_hot_encodings),
            trainable = False,
            name = 'one_hot_enconding_layer'
        )

    def call(self, anime_id):
        """Return the stored one-hot row for each anime id string."""
        anime_idx = self.anime_id_lookup_layer(anime_id)
        return self.one_hot_encoding_layer(anime_idx)

class AnimeMultiHotModel(tf.keras.Model):
    """Fixed multi-hot encoding of a multi-valued anime attribute.

    Each anime's list of values is multi-hot encoded once, at construction
    time, and stored in a non-trainable ``Embedding`` table. At call time an
    anime id is mapped to its row. The model has no trainable weights.

    NOTE(review): the id lookup maps ``unique_anime_ids[i]`` to table row ``i``,
    so ``multi_hot_feature`` must be ordered like ``unique_anime_ids``.
    Confirm at the call site.
    """

    def __init__(self,
                 unique_anime_ids,
                 multi_hot_feature,
                 vocabulary):
        """Build the multi-hot table.

        Args:
            unique_anime_ids: vocabulary of anime id strings. The lookup has
                no OOV slot (``num_oov_indices=0``).
            multi_hot_feature: a pandas-like Series (``.apply`` is used) whose
                entries are sequences of value strings, one entry per anime.
            vocabulary: the value strings that get a column in the encoding.
        """
        super().__init__()

        self.anime_id_lookup_layer = tf.keras.layers.StringLookup(
            vocabulary = unique_anime_ids,
            num_oov_indices = 0,
            name = 'anime_multihot_model_id_lookup'
        )

        # Pad/truncate every entry to the same length so all rows stack into one tensor.
        multi_hot_feature = multi_hot_feature.apply(lambda x : self.__class__.multi_hot_same_shape(x, max_len=len(vocabulary)))
        multi_hot_feature = list(multi_hot_feature)

        multi_hot_layer = tf.keras.layers.StringLookup(vocabulary = vocabulary,
                                                       output_mode = "multi_hot",
                                                       num_oov_indices=1
                                                       )
        multi_hot_encodings = multi_hot_layer(multi_hot_feature)
        # Column 0 is the single OOV slot, which absorbs the "[UNK]" padding and
        # unknown values. Drop it so only vocabulary columns remain.
        multi_hot_encodings = multi_hot_encodings[:, 1:]

        num_animes = multi_hot_encodings.shape[0]
        num_multi_hot_dims = multi_hot_encodings.shape[1]

        self.multi_hot_encoding_layer = tf.keras.layers.Embedding(
            num_animes,
            num_multi_hot_dims,
            embeddings_initializer=tf.keras.initializers.Constant(multi_hot_encodings),
            trainable = False,
            name = 'multi_hot_enconding_layer'
        )

    def call(self, anime_id):
        """Return the stored multi-hot row for each anime id string."""
        anime_idx = self.anime_id_lookup_layer(anime_id)
        return self.multi_hot_encoding_layer(anime_idx)

    @staticmethod
    def multi_hot_same_shape(list_entities, max_len = 30):
        """Truncate or pad a sequence of value strings to exactly ``max_len`` entries.

        Padding uses "[UNK]". This assumes "[UNK]" is not in the vocabulary, so
        the multi-hot lookup sends it to the OOV column that ``__init__`` drops.
        Accepts numpy arrays as well as plain Python lists.
        """
        list_entities = list_entities[:max_len]
        # len() rather than .shape[0], so plain lists (e.g. from Series.str.split) work too.
        num_add = max_len - len(list_entities)
        return np.concatenate([list_entities , num_add * ["[UNK]"]])

class CombinedAnimeModel(tf.keras.Model):
    """Fuse several anime sub-models into a single anime embedding.

    Every sub-model is called on the same anime ids. Their outputs are
    concatenated along the last axis and projected by a trainable ReLU
    ``Dense`` layer of width ``anime_embedding_size``.
    """

    def __init__(self, sub_models = (), anime_embedding_size = 32):
        """
        Args:
            sub_models: at least two callables mapping anime ids to embeddings.
            anime_embedding_size: width of the fused output embedding.

        Raises:
            ValueError: if fewer than two sub-models are given.
        """
        super().__init__()

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(sub_models) <= 1:
            raise ValueError(
                f"CombinedAnimeModel needs at least 2 sub_models, got {len(sub_models)}")
        # Copy so the caller's container (or a shared default) is never aliased.
        self.sub_models = list(sub_models)
        self.final_layer = tf.keras.layers.Dense(anime_embedding_size, activation = 'relu', name = 'combined_anime_final_layer')

    def call(self, anime_id):
        """Concatenate all sub-model embeddings and project them."""
        concat_embedding = tf.concat([sub_model(anime_id) for sub_model in self.sub_models], axis=-1)
        return self.final_layer(concat_embedding)

###########################

class CombinedUserModel(tf.keras.Model):
    """Fuse several user sub-models into a single user embedding.

    Every sub-model is called on the same user ids. Their outputs are
    concatenated along the last axis and projected by a trainable ReLU
    ``Dense`` layer of width ``user_embedding_size``.
    """

    def __init__(self, sub_models = (), user_embedding_size = 32):
        """
        Args:
            sub_models: at least two callables mapping user ids to embeddings.
            user_embedding_size: width of the fused output embedding.

        Raises:
            ValueError: if fewer than two sub-models are given.
        """
        super().__init__()

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(sub_models) <= 1:
            raise ValueError(
                f"CombinedUserModel needs at least 2 sub_models, got {len(sub_models)}")
        # Copy so the caller's container (or a shared default) is never aliased.
        self.sub_models = list(sub_models)
        self.final_layer = tf.keras.layers.Dense(user_embedding_size, activation = 'relu', name = 'combined_user_final_layer')

    def call(self, user_id):
        """Concatenate all sub-model embeddings and project them."""
        concat_embedding = tf.concat([sub_model(user_id) for sub_model in self.sub_models], axis=-1)
        return self.final_layer(concat_embedding)
##############################
    
class RandomUserAnimeListRankingModel(BaseUserAnimeListRankingModel):
    """Baseline ranker that scores every candidate anime with uniform noise in [1.0, 10.0).

    ``topn`` and ``positive_threshold`` are forwarded unchanged to
    ``BaseUserAnimeListRankingModel``.
    """

    def __init__(self, topn = 5, positive_threshold = 8.0):
        super().__init__(topn, positive_threshold)

    def call(self, features):
        # One independent random score per anime id, matching the ids' shape.
        ids_shape = tf.shape(features['anime_id'])
        return tf.random.uniform(ids_shape, minval = 1.0, maxval = 10.0)

class PerfectUserAnimeListRankingModel(BaseUserAnimeListRankingModel):
    """Oracle ranker whose "prediction" is the input ``score`` feature itself.

    Presumably intended as an upper-bound reference for the other rankers.
    ``topn`` and ``positive_threshold`` are forwarded unchanged to
    ``BaseUserAnimeListRankingModel``.
    """

    def __init__(self, topn = 5, positive_threshold = 8.0):
        super().__init__(topn, positive_threshold)

    def call(self, features):
        return features['score']

class AverageUserAnimeRatingListRankingModel(BaseUserAnimeListRankingModel):
    """Ranker that scores each anime with a fixed per-anime value looked up by id.

    Presumably ``anime_scores`` holds each anime's average rating. Ids absent
    from ``anime_ids`` score -1. ``topn`` and ``positive_threshold`` are
    forwarded unchanged to ``BaseUserAnimeListRankingModel``.
    """

    def __init__(self, anime_ids, anime_scores, topn = 5, positive_threshold = 8.0):
        super().__init__(topn, positive_threshold)
        id_to_score = tf.lookup.KeyValueTensorInitializer(anime_ids, anime_scores)
        self.ratings = tf.lookup.StaticHashTable(id_to_score, default_value=-1)

    def call(self, features):
        return self.ratings.lookup(features['anime_id'])
