In [1]:
import tensorflow as tf
import numpy as np
import os
import logging
import datetime
import random

from config import PATH, DATA, TRAINING
from model.mukara import Mukara
from model.dataloader import load_gt
import model.utils as utils


class MukaraTrainer:
    def __init__(self):
        """
        Initializes the MukaraTrainer with model, optimizer, and logging configuration.
        """
        # Set random seeds for reproducibility
        np.random.seed(TRAINING['seed'])
        tf.random.set_seed(TRAINING['seed'])
        random.seed(TRAINING['seed'])

        # Disable GPUs if required
        if TRAINING['use_gpu'] == False:
            tf.config.set_visible_devices([], 'GPU')

        # Configure logging
        logging.basicConfig(
            filename=os.path.join(PATH["evaluate"], 'training_log.log'),
            level=logging.INFO,
            format='%(asctime)s:%(levelname)s:%(message)s',
            filemode='w'
        )

        # Initialize model and optimizer
        self.model = Mukara()
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=TRAINING['lr'])

        self.edge_to_gt, self.scaler = load_gt()
        self.train_ids, self.test_ids = utils.train_test_sampler(
            list(self.edge_to_gt.keys()), TRAINING['train_prop']
        )

    def compute_loss(self, gt, pred, traffic_loss_function):
        """
        Computes the total loss with multiple components.
        """
        traffic_loss = getattr(utils, traffic_loss_function)(gt, pred, self.scaler)
        
        return traffic_loss

    def train_model(self):
        """
        Training loop for the Mukara model.
        Evaluates the model every `TRAINING['eval_interval']` edges trained.
        """
        for epoch in range(TRAINING['epoch']):
            random.shuffle(self.train_ids)

            for i, edge_id in enumerate(self.train_ids, start=1):
                with tf.GradientTape() as tape:
                    pred = self.model(edge_id)
                    if pred:
                        gt = tf.convert_to_tensor(self.edge_to_gt[edge_id], dtype=tf.float32)
                        loss = self.compute_loss(gt, pred, TRAINING['loss_function'])
                    else:
                        continue

                # Compute and apply gradients
                grads = tape.gradient(loss, self.model.trainable_variables)
                grads = [tf.clip_by_value(grad, -TRAINING['clip_gradient'], TRAINING['clip_gradient']) for grad in grads]
                self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

                print(f"Training complete for epoch {epoch}, step {i}/{len(self.train_ids)}.")

                # Evaluate model periodically
                if i % TRAINING['eval_interval'] == 0 or i == len(self.train_ids):
                    train_loss = self.evaluate_model(self.train_ids)
                    test_loss = self.evaluate_model(self.test_ids)

                    logging.info(f"Epoch {epoch}, step {i}: Train Loss: {self.format_loss(train_loss)}")
                    logging.info(f"Epoch {epoch}, step {i}: Valid Loss: {self.format_loss(test_loss)}")

            # Save model weights
            self.save_model()

    def evaluate_model(self, ids):
        """
        Evaluates the model on a subset of edge IDs and logs all loss components.
        """
        sampled_ids = random.sample(ids, TRAINING['eval_samples'])
        loss = {}

        for metric in TRAINING['eval_metrics']:
            loss[metric] = 0.0

        for edge_id in sampled_ids:
            pred = self.model(edge_id)
            gt = tf.convert_to_tensor(self.edge_to_gt[edge_id], dtype=tf.float32)

            # Compute traffic loss using different evaluation metrics
            for metric in TRAINING['eval_metrics']:
                traffic_loss = self.compute_loss(gt, pred, metric)
                loss[metric] += traffic_loss.numpy()

        # Normalize by the number of evaluated samples
        num_samples = len(sampled_ids)
        for key in loss.keys():
            loss[key] /= num_samples

        return loss

    @staticmethod
    def format_loss(loss_dict):
        """
        Formats the loss dictionary into a readable string with .6f precision.
        """
        return ", ".join([f"{key}: {value:.6f}" for key, value in loss_dict.items()])

    def save_model(self):
        """
        Saves the model weights.
        """
        formatted_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
        model_filename = os.path.join(PATH["param"], f"{formatted_time}.h5")
        self.model.save_weights(model_filename)
        print("Model saved successfully.")

    def cleanup_and_log_config(self):
        """
        Clears the cache if enabled and logs the configuration.
        """
        if DATA['clear_cache']:
            for filename in os.listdir(PATH['cache']):
                os.remove(os.path.join(PATH['cache'], filename))

        # Close logging handlers
        for handler in logging.root.handlers[:]:
            handler.close()
            logging.root.removeHandler(handler)

        # Append model configuration to log file
        with open('config.py', 'r') as config_file:
            config_content = config_file.read()

        with open(os.path.join(PATH["evaluate"], "training_log.log"), 'a') as log_file:
            log_file.write('\n\n# Contents of config.py\n')
            log_file.write(config_content)


In [2]:

print("Initiating model...")
trainer = MukaraTrainer()


Initiating model...
Loading nodes...
Loading edges...
Grid features shape (653, 573, 14)
Preloading subgraphs to memory...


Loading subgraphs: 100%|██████████| 4510/4510 [14:36<00:00,  5.14file/s]  

Mean: 20259.4824 Standard Deviation: 17446.0566





In [19]:
trainer.train_ids

[114118,
 186995,
 36829,
 137243,
 48013,
 78668,
 151326,
 124434,
 186832,
 20209,
 116583,
 145817,
 129981,
 36750,
 129985,
 44785,
 89794,
 149493,
 59529,
 149685,
 33954,
 78998,
 144676,
 132979,
 122015,
 59387,
 126230,
 151282,
 146939,
 175641,
 110895,
 49335,
 178621,
 168486,
 112827,
 166782,
 112916,
 37582,
 15148,
 86235,
 106103,
 37269,
 168453,
 115527,
 115512,
 132482,
 49164,
 36916,
 183580,
 151337,
 151233,
 49166,
 89870,
 36937,
 178602,
 15122,
 33420,
 39845,
 132978,
 84156,
 22731,
 33860,
 43829,
 146484,
 188656,
 43816,
 44786,
 78722,
 155861,
 75339,
 199325,
 36765,
 77987,
 120906,
 165109,
 123082,
 119079,
 137806,
 122627,
 105750,
 180180,
 36952,
 49476,
 92892,
 92775,
 118434,
 36901,
 78999,
 137822,
 135311,
 84162,
 114484,
 190672,
 81065,
 137182,
 49318,
 53175,
 95300,
 141215,
 106582,
 122657,
 59513,
 43060,
 149817,
 122624,
 41379,
 173317,
 89865,
 194793,
 118500,
 168510,
 137277,
 179883,
 168576,
 36800,
 188932,
 17895

In [20]:
print("Training started...")
trainer.train_model()
trainer.save_model()
print(trainer.model.summary())
trainer.cleanup_and_log_config()

Training started...
{'0': {'node2edge': [{'node': 6799995982, 'edge': '997'}, {'node': 1974389728, 'edge': '824'}, {'node': 30810725, 'edge': '825'}], 'edge2node': [{'node': 304253236, 'edge': ['997']}, {'node': 30810729, 'edge': ['824', '825']}]}, '1': {'node2edge': [{'node': 1986463312, 'edge': '1695'}, {'node': 593893057, 'edge': '1487'}, {'node': 249595810, 'edge': '823'}], 'edge2node': [{'node': 6799995982, 'edge': ['1695']}, {'node': 1974389728, 'edge': ['1487']}, {'node': 30810725, 'edge': ['823']}]}, '2': {'node2edge': [{'node': 304253239, 'edge': '1496'}, {'node': 5442752941, 'edge': '1096'}, {'node': 267844788, 'edge': '1097'}, {'node': 249595811, 'edge': '866'}, {'node': 282563780, 'edge': '867'}], 'edge2node': [{'node': 1986463312, 'edge': ['1496']}, {'node': 593893057, 'edge': ['1096', '1097']}, {'node': 249595810, 'edge': ['866', '867']}]}, '3': {'node2edge': [{'node': 304253638, 'edge': '999'}, {'node': 304253251, 'edge': '1000'}, {'node': 267844800, 'edge': '1668'}, {'n

TypeError: Exception encountered when calling Mukara.call().

[1mTensor is unhashable. Instead, use tensor.ref() as the key.[0m

Arguments received by Mukara.call():
  • edge_id=tf.Tensor(shape=(), dtype=int64)