In [4]:
import tensorflow as tf
import tensorflow_recommenders as tfrs
import numpy as np

# Synthetic movie catalogue: five titles, each tagged with a genre
movies = tf.data.Dataset.from_tensor_slices(
    dict(
        title=[f"Movie {idx}" for idx in range(5)],
        genre=["Action", "Comedy", "Drama", "Comedy", "Action"],
    )
)

# Keep only the title of each movie record; these titles are the candidates
movie_titles = movies.map(lambda movie: movie["title"])

# Synthetic user table: three string ids ("0".."2"), each with an age
users = tf.data.Dataset.from_tensor_slices(
    dict(
        user_id=list(map(str, range(3))),
        age=[20, 30, 40],
    )
)

# Define a simple recommendation model
class SimpleModel(tfrs.Model):
    """Two-tower retrieval model over string user ids and movie titles.

    Each tower is a ``StringLookup`` followed by an ``Embedding``. The
    retrieval task scores user embeddings against movie embeddings and tracks
    a ``FactorizedTopK`` metric over the module-level ``movie_titles``
    dataset, which must exist before the model is instantiated.
    """

    def __init__(self, embedding_dim=32):
        """Build the user tower, the movie tower and the retrieval task.

        Args:
            embedding_dim: Width of the user and movie embeddings. Both
                towers share it so that their outputs have matching sizes.
        """
        super().__init__()

        user_vocabulary = ["0", "1", "2"]
        movie_vocabulary = [
            "Movie 0", "Movie 1", "Movie 2", "Movie 3", "Movie 4",
        ]

        # With its default settings StringLookup reserves one extra index for
        # out-of-vocabulary strings, so each embedding table needs
        # len(vocabulary) + 1 rows (4 for users, 6 for movies).

        # Compute embeddings for users
        self.user_model = tf.keras.Sequential([
            tf.keras.layers.StringLookup(vocabulary=user_vocabulary),
            tf.keras.layers.Embedding(len(user_vocabulary) + 1, embedding_dim),
        ])

        # Compute embeddings for movies
        self.movie_model = tf.keras.Sequential([
            tf.keras.layers.StringLookup(vocabulary=movie_vocabulary),
            tf.keras.layers.Embedding(len(movie_vocabulary) + 1, embedding_dim),
        ])

        # Task: the top-K metric ranks against every candidate title, embedded
        # with the same movie tower that is being trained.
        self.task = tfrs.tasks.Retrieval(metrics=tfrs.metrics.FactorizedTopK(
            candidates=movie_titles.batch(128).map(self.movie_model)
        ))

    def compute_loss(self, features, training=False):
        """Embed a batch of interactions and return the retrieval task's result.

        Args:
            features: Mapping that provides a ``"user_id"`` entry and a
                ``"title"`` entry for the batch.
            training: Accepted for interface compatibility; this
                implementation does not use it.

        Returns:
            Whatever ``self.task`` returns for the two embedding batches
            (the loss that training minimizes).
        """
        user_embeddings = self.user_model(features["user_id"])
        movie_embeddings = self.movie_model(features["title"])

        return self.task(user_embeddings, movie_embeddings)

# Synthetic interactions: five (user_id, title) pairs, served as one batch
interactions = tf.data.Dataset.from_tensor_slices(
    dict(
        user_id=["0", "1", "2", "1", "2"],
        title=["Movie 0", "Movie 1", "Movie 2", "Movie 3", "Movie 4"],
    )
).batch(32)

# Instantiate the model and attach an Adagrad optimizer (learning rate 0.1)
model = SimpleModel()
model.compile(
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1),
)

# Run three passes over the interaction data
model.fit(interactions, epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x349347400>