# 7.2.3-Subclassing API

In [17]:
import tensorflow as tf, numpy as np
from tensorflow import keras

In [7]:
vocabulary_size = 10000
num_tags = 100
num_departments = 4

In [9]:
num_samples = 1280

In [10]:
# Dummy input data
title_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
text_body_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
tags_data = np.random.randint(0, 2, size=(num_samples, num_tags))

In [24]:
# Dummy target data
priority_data = np.random.random(size=(num_samples, 1))
department_data = np.random.randint(0, 2, size=(num_samples, num_departments))

In [25]:
class CustomerTicketModel(keras.Model):

    def __init__(self, num_departments):
        super().__init__()
        self.concat_layer = keras.layers.Concatenate()
        self.mixing_layer = keras.layers.Dense(64, activation='relu')
        self.priority_scorer = keras.layers.Dense(1, activation='sigmoid')
        self.department_classifier = keras.layers.Dense(num_departments, activation='softmax')

    def call(self, inputs):
        title=inputs['title']
        text_body = inputs['text_body']
        tags = inputs['tags']

        features = self.concat_layer([title, text_body, tags])
        features = self.mixing_layer(features)
        priority = self.priority_scorer(features)
        department = self.department_classifier(features)
        return priority, department

In [26]:
model = CustomerTicketModel(num_departments=4)

In [27]:
priority, department = model(
    {'title': title_data, 'text_body': text_body_data, 'tags': tags_data}
)

In [28]:
model.compile(optimizer='rmsprop',
             loss=['mean_squared_error', 'categorical_crossentropy'],
             metrics=[['mean_absolute_error'], ['accuracy']])

In [29]:
model.fit(
{
    'title': title_data,
    'text_body': text_body_data,
    'tags': tags_data,
},
[
    priority_data, department_data,
],
epochs=1       )

[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.2299 - loss: 47.5817 - mean_absolute_error: 0.4846


<keras.src.callbacks.history.History at 0x225d3313830>

In [30]:
model.evaluate({
    'title': title_data,
    'text_body': text_body_data,
    'tags': tags_data,
},
[
    priority_data,
    department_data,
]
              )

[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.2441 - loss: 12.3791 - mean_absolute_error: 0.5052 


[12.388519287109375, 0.5086504817008972, 0.2515625059604645]

In [31]:
priority_preds, department_preds = model.predict({'title': title_data,
                                                 'text_body': text_body_data,
                                                 'tags': tags_data})

[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step
