# 1. Imports

In [1]:
# Standard libraries
import logging
import os
import sys
import time
import re

# External libraries
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from CONSTANTS import DEVICE, LOG_ROOT, PROJECT_ROOT, SESSION
from models.gru import AttGRUModel
from models.decoder import MLPDecoder
from module.Common import data_iter, generate_tinsts_binary_label
from module.Optimizer import Optimizer
from preprocessing.datacutter.SimpleCutting import cut_by
from preprocessing.Preprocess import Preprocessor
from representations.sequences.statistics import Sequential_TF
from representations.templates.statistics import (
    Template_TF_IDF_without_clean,
)
from utils.Vocab import Vocab

2025-04-01 19:28:24,915 - AttGRU - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: Construct logger for Attention-Based GRU succeeded, current working directory: /Users/wind/Projects/AI/MTALog, logs will be written in /Users/wind/Projects/AI/MTALog/logs
2025-04-01 19:28:24,945 - Preprocessor - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: Construct logger for MTALog succeeded, current working directory: /Users/wind/Projects/AI/MTALog, logs will be written in /Users/wind/Projects/AI/MTALog/logs
2025-04-01 19:28:24,948 - StatisticsRepresentation. - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: Construct logger for Statistics Representation succeeded, current working directory: /Users/wind/Projects/AI/MTALog, logs will be written in /Users/wind/Projects/AI/MTALog/logs
2025-04-01 19:28:24,951 - Statistics_Template_Encoder - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: Construct logger for Statistics Template Encoder succeeded, current working directory: /Users/wind/Projects/AI/M

# 2. Custom default params

## 2.1. Hyper-params

In [2]:
# ========== Embedding Configuration ==========
word2vec_file = "glove.6B.300d.txt"
word2vec_dim = 300

# ========== Meta-Learning Hyperparameters ==========
alpha = 8e-3         # Inner loop learning rate (meta-train)
beta = 1.0           # Outer loop scaling factor (meta-test loss weight)
gamma = 8e-3         # Learning rate for optimizer
lambda_recon = 1.0   # Weight for reconstruction loss in total objective


## 2.2. Network model params

In [3]:
# ========== Model Architecture ==========
lstm_hidden_units = 64   # Hidden size of each GRU direction
num_layers = 4           # Number of GRU layers
dropout_rate = 0.5       # Dropout rate applied to input embeddings

# ========== Training Configuration ==========
training_batch_size = 100    # Mini-batch size for training
num_epochs = 5               # Number of training epochs
prediction_threshold = 0.5   # Threshold for binary anomaly prediction


## 2.3. Dataset params

In [4]:
# ========== Experiment Settings ==========
parser = "IBM"     # Log parser to use (e.g., Drain, Spell, IBM)
mode = "train"     # Mode can be 'train' or 'eval'

# 3. Output paths for models and results

In [5]:
def get_model_and_result_paths(parser: str, project_root: str) -> tuple[str, str]:
    """
    Build the output directories used for MTALog checkpoints and detection results.

    Both directories live under ``<project_root>/outputs`` and are namespaced by
    the parser name, so runs with different parsers do not overwrite each other.
    Nothing is created on disk here; callers create directories as needed.

    Args:
        parser (str): Parser name (e.g., "IBM").
        project_root (str): Root directory of the project.

    Returns:
        tuple[str, str]:
            - model directory: ``outputs/models/MTALog/<parser>/model``
            - result directory: ``outputs/results/MTALog/<parser>/detect_res``
    """
    outputs_root = os.path.join(project_root, "outputs")
    model_dir = os.path.join(outputs_root, "models", "MTALog", parser, "model")
    result_dir = os.path.join(outputs_root, "results", "MTALog", parser, "detect_res")
    return model_dir, result_dir


# Resolve the shared output directories once; output_model_dir is used for checkpoints in 7.2.
output_model_dir, output_res_dir = get_model_and_result_paths(
    parser=parser, project_root=PROJECT_ROOT
)


# 4. Parameter-update helpers for meta-learning

In [6]:
def get_updated_network(old: nn.Module, new: nn.Module, lr: float, load: bool = False) -> nn.Module:
    """
    Apply one manual gradient step to `old`'s parameters and install the result into `new`.

    Typically used in meta-learning inner loops. For every entry of ``old.state_dict()``:
    if it is a parameter with a populated ``.grad``, the new value is
    ``param - lr * param.grad``; otherwise (buffers, parameters without a gradient)
    the state-dict value is copied unchanged.

    Args:
        old (nn.Module): Source model whose gradients have been computed.
        new (nn.Module): Model that receives the updated values.
        lr (float): Inner-loop learning rate (alpha).
        load (bool): If True, copy the values into ``new`` via ``load_state_dict``
            (the result is detached from the autograd graph). Otherwise assign the
            tensors directly with ``put_theta``, keeping them attached to the graph.

    Returns:
        nn.Module: ``new``, now holding the updated values (in both modes).
    """
    named_params = dict(old.named_parameters())
    updated_theta = {}

    for key, value in old.state_dict().items():
        param = named_params.get(key)
        if param is not None and param.grad is not None:
            # Computed outside torch.no_grad(), so the result stays attached to `param`'s graph.
            updated_theta[key] = param - lr * param.grad
        else:
            updated_theta[key] = value

    if load:
        # load_state_dict returns a (missing_keys, unexpected_keys) record, not the module,
        # so return `new` explicitly to honour the documented return type.
        new.load_state_dict(updated_theta)
        return new
    return put_theta(new, updated_theta)


def put_theta(model: nn.Module, theta: dict) -> nn.Module:
    """
    Overwrite entries of each submodule's ``_parameters`` with tensors from ``theta``.

    Keys are dotted module paths (``"<child>.<grandchild>.<param>"``), the same
    naming used by ``state_dict``. Parameters that are ``None`` or whose key is
    absent from ``theta`` are left untouched. Assignment goes straight into
    ``_parameters`` so non-Parameter tensors (e.g. ones carrying an autograd
    graph) can be installed.

    Args:
        model (nn.Module): Model to update in place.
        theta (dict): Mapping from dotted parameter names to replacement tensors.

    Returns:
        nn.Module: The same ``model`` object.
    """
    pending = [("", model)]
    while pending:
        prefix, module = pending.pop()

        # Queue children (reversed so they are popped in registration order).
        for child_name, child in reversed(list(module._modules.items())):
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            pending.append((child_prefix, child))

        for param_name, param in module._parameters.items():
            if param is None:
                continue
            full_key = f"{prefix}.{param_name}" if prefix else param_name
            if full_key in theta:
                module._parameters[param_name] = theta[full_key]

    return model


# 5. Logging

## 5.1. Logging config

In [7]:
def setup_logger(name="MTALog", log_file="MTALog.log", level=logging.DEBUG):
    """
    Create (or fetch) a named logger that writes to stderr and to a file under LOG_ROOT.

    Handlers are attached only when the logger has none yet, so re-running this
    cell does not duplicate output. The console handler uses ``level``; the file
    handler always records INFO and above. Every line carries the SESSION id.

    Args:
        name (str): Logger name passed to ``logging.getLogger``.
        log_file (str): File name created inside ``LOG_ROOT``.
        level (int): Level for the logger and its console handler (default DEBUG).

    Returns:
        logging.Logger: The configured logger.
    """
    configured = logging.getLogger(name)
    configured.setLevel(level)

    line_format = logging.Formatter(
        f"%(asctime)s - %(name)s - {SESSION} - %(levelname)s: %(message)s"
    )

    def _attach(handler, threshold):
        # Shared wiring for every handler: level, format, registration.
        handler.setLevel(threshold)
        handler.setFormatter(line_format)
        configured.addHandler(handler)

    if not configured.handlers:
        _attach(logging.StreamHandler(sys.stderr), level)
        _attach(logging.FileHandler(os.path.join(LOG_ROOT, log_file)), logging.INFO)

    configured.info(
        f"Logger for {name} constructed successfully. Current working directory: {os.getcwd()}. Logs will be written in {LOG_ROOT}."
    )
    return configured


# Initialize logger
logger = setup_logger()


2025-04-01 19:28:33,300 - MTALog - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: Logger for MTALog constructed successfully. Current working directory: /Users/wind/Projects/AI/MTALog. Logs will be written in /Users/wind/Projects/AI/MTALog/logs.


## 5.2. Log custom params

In [8]:
def _log_param_section(title, rows, label_width):
    """Log a '=== title ===' header, then one 'label : value' line per row with labels left-padded to label_width."""
    logger.info(f"=== {title} ===")
    for label, value in rows:
        logger.info(f"{label:<{label_width}}: {value}")


# Log architecture parameters
_log_param_section(
    "Model Architecture",
    (
        ("LSTM hidden units", lstm_hidden_units),
        ("Number of GRU layers", num_layers),
        ("Dropout rate", dropout_rate),
        ("Latent representation dim", 2 * lstm_hidden_units),
    ),
    label_width=26,
)

# Log training hyperparameters
_log_param_section(
    "Training Hyperparameters",
    (
        ("Meta-train step size (alpha)", alpha),
        ("Meta-test loss weight (beta)", beta),
        ("Learning rate (gamma)", gamma),
        ("Reconstruction loss weight", lambda_recon),
        ("Word2Vec file used", word2vec_file),
    ),
    label_width=33,
)


2025-04-01 19:28:35,181 - MTALog - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: === Model Architecture ===
2025-04-01 19:28:35,183 - MTALog - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: LSTM hidden units         : 64
2025-04-01 19:28:35,184 - MTALog - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: Number of GRU layers      : 4
2025-04-01 19:28:35,185 - MTALog - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: Dropout rate              : 0.5
2025-04-01 19:28:35,186 - MTALog - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: Latent representation dim : 128
2025-04-01 19:28:35,187 - MTALog - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: === Training Hyperparameters ===
2025-04-01 19:28:35,188 - MTALog - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: Meta-train step size (alpha)     : 0.008
2025-04-01 19:28:35,189 - MTALog - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: Meta-test loss weight (beta)     : 1.0
2025-04-01 19:28:35,189 - MTALog - SESSION_7924cf0409ef6591c1a6

# 6. Load and encode datasets

In [9]:
# Template encoder built from the word2vec file configured in 2.1; its `.present`
# method is what preprocess_data passes to the Preprocessor as `template_encoding`.
template_encoder = Template_TF_IDF_without_clean(word2vec_file)

2025-04-01 19:28:36,878 - Statistics_Template_Encoder - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: Loading word2vec dict from glove.6B.300d.txt.
2025-04-01 19:28:36,879 - Statistics_Template_Encoder - SESSION_7924cf0409ef6591c1a6984a1406ad92 - INFO: Loading word2vec dict.
100%|██████████| 400000/400000 [00:21<00:00, 18562.52it/s]


In [None]:
def preprocess_data(dataset, parser, cut_func, template_encoder):
    """
    Run the project Preprocessor on one dataset and return its splits.

    Args:
        dataset (str): Dataset name (e.g. "HDFS", "BGL").
        parser (str): Parsing method, forwarded to the Preprocessor as ``parsing=``.
        cut_func (callable): Split strategy, e.g. the result of ``cut_by(...)``.
        template_encoder (object): Object whose ``.present`` is used as the template encoding.

    Returns:
        tuple: (train_data, valid_data, test_data, processor), where ``processor`` is
        the Preprocessor instance that produced the splits.
    """
    processor = Preprocessor()
    splits = processor.process(
        dataset=dataset,
        parsing=parser,
        cut_func=cut_func,
        template_encoding=template_encoder.present,
    )
    train_split, valid_split, test_split = splits
    return train_split, valid_split, test_split, processor


def encode_log_sequences(processor, train_data, test_data=None):
    """
    Attach Sequential_TF representations (built from ``processor.embedding``) to instances.

    Args:
        processor (Preprocessor): Supplies the template embeddings.
        train_data (list[Instance]): Training instances; ``.repr`` is set in place.
        test_data (list[Instance], optional): Test instances, encoded the same way when given.

    Returns:
        tuple: (train_data, test_data) as the same list objects, or
        (train_data, None) when ``test_data`` is None.
    """
    seq_encoder = Sequential_TF(processor.embedding)

    def _assign(instances):
        # One present() call per split; vectors are matched to instances by position.
        vectors = seq_encoder.present(instances)
        for idx, inst in enumerate(instances):
            inst.repr = vectors[idx]
        return instances

    encoded_train = _assign(train_data)
    if test_data is None:
        return encoded_train, None
    return encoded_train, _assign(test_data)


def encode_log_sequences_with_gru(model, vocab, instances, batch_size=128):
    """
    Encode log sequences into latent vectors with an AttGRUModel-style encoder.

    Runs the model in eval mode under ``torch.no_grad()`` and stores each
    instance's latent vector (a NumPy array) in ``inst.repr``. The model's
    previous train/eval mode is restored afterwards, so calling this helper
    does not silently disable dropout for later training.

    Args:
        model (AttGRUModel): Encoder whose forward returns a 3-tuple whose last
            element is the batch of latent vectors.
        vocab (Vocab): Vocabulary used for token indexing.
        instances (list[Instance]): Log instances; updated in place.
        batch_size (int): Instances per forward pass (default 128).

    Returns:
        list[Instance]: The same list object, with ``.repr`` set on every instance.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in data_iter(instances, batch_size=batch_size, shuffle=False):
                tinst = generate_tinsts_binary_label(batch, vocab)
                tinst.to(DEVICE)

                _, _, latent = model(tinst.inputs)
                # One device-to-host transfer per batch instead of one per row.
                latent_np = latent.detach().cpu().numpy()
                for i, inst in enumerate(batch):
                    inst.repr = latent_np[i]
    finally:
        # Put the model back in whatever mode the caller had it in.
        model.train(was_training)

    return instances

## 6.1. Load the target dataset

In [None]:
# === Few-shot Setup for TARGET System ===
TARGET_SYSTEM = "BGL"
FEWSHOT_NORMAL_RATIO = 0.01  # fraction passed to cut_by as the train (support) split

logger.info(f"Preparing target system: {TARGET_SYSTEM} with few-shot ratio {FEWSHOT_NORMAL_RATIO}")

# Step 1: split the target data; the train split becomes the support set, the test split the query set.
# NOTE(review): anomalous_rate=0.0 is intended to keep the support set normal-only — confirm against cut_by.
cut_func = cut_by(train=FEWSHOT_NORMAL_RATIO, dev=0.0, anomalous_rate=0.0)
support_set, _, query_set, processor_target = preprocess_data(
    TARGET_SYSTEM, parser, cut_func, template_encoder
)

# Step 2: target-specific vocabulary built from the processor's template embeddings.
vocab_target = Vocab()
vocab_target.load_from_dict(processor_target.embedding)

# Step 3: a dedicated encoder for the target system, placed on DEVICE.
encoder_target = AttGRUModel(
    vocab=vocab_target,
    lstm_layers=num_layers,
    lstm_hiddens=lstm_hidden_units,
    dropout=dropout_rate,
).to(DEVICE)

# Step 4: attach latent vectors to both splits.
# NOTE(review): encoder_target was just constructed (untrained), so these .repr values
# come from its initial weights.
encoded_support_set = encode_log_sequences_with_gru(encoder_target, vocab_target, support_set)
encoded_query_set = encode_log_sequences_with_gru(encoder_target, vocab_target, query_set)

# Names consumed by the training / evaluation cells.
prototype_support_set, evaluation_query_set = encoded_support_set, encoded_query_set

logger.info(f"Target system '{TARGET_SYSTEM}' prepared — Support: {len(prototype_support_set)}, Query: {len(evaluation_query_set)}")

## 6.2. Load the source datasets

In [None]:
from collections import OrderedDict
import random

source_systems = ["HDFS", "OpenStack"]

source_processors = OrderedDict()
source_vocabularies = OrderedDict()
source_encoders = OrderedDict()
source_support_sets = OrderedDict()
source_query_sets = OrderedDict()


def _prepare_source_system(system):
    """
    Build everything the meta-training loop needs for one source system.

    Preprocesses the system with ``cut_by(train=1.0, dev=0.0, anomalous_rate=1.0)``,
    builds its vocabulary and a system-specific AttGRUModel encoder, encodes every
    training instance, then shuffles and splits the encoded instances 50/50.

    Kept as a function so per-system temporaries (train_data, vocab, encoder,
    support_set, query_set, cut_func, ...) stay local instead of overwriting the
    globals of the same name created by the target-system cell.

    Args:
        system (str): Source dataset name.

    Returns:
        tuple: (processor, vocab, encoder, support_set, query_set)
    """
    cut_func = cut_by(train=1.0, dev=0.0, anomalous_rate=1.0)
    train_data, _, _, processor = preprocess_data(system, parser, cut_func, template_encoder)

    vocab = Vocab()
    vocab.load_from_dict(processor.embedding)

    encoder = AttGRUModel(
        vocab=vocab,
        lstm_layers=num_layers,
        lstm_hiddens=lstm_hidden_units,
        dropout=dropout_rate,
    ).to(DEVICE)

    encoded_data = encode_log_sequences_with_gru(encoder, vocab, train_data)

    # 50/50 support/query split after an in-place shuffle with the module-level,
    # unseeded `random` generator (TODO: seed it for reproducible runs).
    split_index = int(0.5 * len(encoded_data))
    random.shuffle(encoded_data)
    return processor, vocab, encoder, encoded_data[:split_index], encoded_data[split_index:]


for system in source_systems:
    logger.info(f"=== Processing source system: {system} ===")
    (
        source_processors[system],
        source_vocabularies[system],
        source_encoders[system],
        source_support_sets[system],
        source_query_sets[system],
    ) = _prepare_source_system(system)


# 7. Training

## 7.1. MTALog class

In [None]:
class MTALog:
    """
    Meta-learning helper for log-based anomaly detection using prototype learning and autoencoding.

    The encoder and vocabulary are not stored here: they are passed into each call,
    so encoders/vocabularies can be system-specific. The object only holds loss
    weights, loss functions and the evaluation batch size.
    """

    def __init__(self, recon_weight=1.0, proto_weight=1.0):
        """
        Args:
            recon_weight (float): Weight for reconstruction loss.
            proto_weight (float): Weight for prototype loss.
        """
        self.recon_weight = recon_weight
        self.proto_weight = proto_weight
        self.test_batch_size = 1024
        self.cls_loss_fn = nn.BCELoss()
        self.recon_loss_fn = nn.MSELoss()

    def compute_prototype(self, instances, device=None):
        """
        Compute the prototype vector: the mean of the instances' ``.repr`` vectors.

        Args:
            instances (list): Non-empty list of support instances; each must have a
                ``.repr`` (NumPy array or tensor) of the same shape.
            device (torch.device | str | None): Where to place the prototype.
                Defaults to the notebook-wide ``DEVICE``.

        Returns:
            Tensor: float32 prototype vector with the shape of a single ``.repr``.

        Raises:
            ValueError: If ``instances`` is empty.
        """
        if len(instances) == 0:
            raise ValueError("compute_prototype needs at least one instance.")
        # Instances carry no device information, so the target device is explicit.
        target_device = DEVICE if device is None else device
        vecs = [torch.as_tensor(inst.repr, dtype=torch.float32) for inst in instances]
        return torch.stack(vecs).to(target_device).mean(dim=0)

    def forward_with_proto(self, encoder, inputs, prototype):
        """
        Unsupervised forward pass: prototype MSE plus reconstruction MSE.

        Args:
            encoder (nn.Module): Encoder returning ``(tag_logits, recon, latent)``.
            inputs (tuple): Model input passed straight to ``encoder``.
            prototype (Tensor): Prototype vector, broadcast to the latent batch.

        Returns:
            tuple: (total_loss, latent, recon)
        """
        _, recon, latent = encoder(inputs)
        # NOTE(review): reconstruction is compared against the latent itself — confirm
        # this matches what the encoder's `recon` output is meant to reconstruct.
        recon_loss = self.recon_loss_fn(recon, latent)
        proto_loss = F.mse_loss(latent, prototype.expand_as(latent))
        total_loss = self.proto_weight * proto_loss + self.recon_weight * recon_loss
        return total_loss, latent, recon

    def forward(self, encoder, inputs, targets):
        """
        Supervised forward pass: BCE on softmax probabilities plus weighted reconstruction loss.

        Args:
            encoder (nn.Module): Encoder returning ``(tag_logits, recon, latent)``.
            inputs (tuple): Model input passed straight to ``encoder``.
            targets (Tensor): Targets for ``nn.BCELoss``; must be float and match
                the shape of the softmax output (TODO confirm with the tinst builder).

        Returns:
            tuple: (total_loss, tag_logits, latent, recon)
        """
        tag_logits, recon, latent = encoder(inputs)
        tag_probs = F.softmax(tag_logits, dim=1)
        cls_loss = self.cls_loss_fn(tag_probs, targets)
        recon_loss = self.recon_loss_fn(recon, latent)
        total_loss = cls_loss + self.recon_weight * recon_loss
        return total_loss, tag_logits, latent, recon

    def predict(self, encoder, inputs, prototype, threshold=None):
        """
        Score instances by Euclidean distance of their latent vector to the prototype.

        Args:
            encoder (nn.Module): Encoder returning ``(tag_logits, recon, latent)``.
            inputs (tuple): Model input passed straight to ``encoder``.
            prototype (Tensor): Prototype vector.
            threshold (float | None): Distances strictly above it are labelled 1.

        Returns:
            tuple: (pred_labels, distances). ``pred_labels`` is a long tensor of 0/1,
            or None when ``threshold`` is None.
        """
        with torch.no_grad():
            _, _, latent = encoder(inputs)
            distances = torch.norm(latent - prototype.expand_as(latent), dim=1)
            pred_tags = (distances > threshold).long() if threshold is not None else None
        return pred_tags, distances

    def evaluate(self, encoder, vocab, dataset_name, instances, prototype, threshold=0.5):
        """
        Evaluate prototype-distance anomaly detection and log precision/recall/F1.

        The encoder runs in eval mode; its previous train/eval mode is restored on exit.

        Args:
            encoder (nn.Module): Encoder model.
            vocab (Vocab): Vocab used to build model inputs.
            dataset_name (str): Label used in the log line.
            instances (list): Instances with a ``.label`` ("Anomalous" counts as positive).
            prototype (Tensor): Support-set prototype; its device is used for inputs.
            threshold (float): Distance threshold for flagging an anomaly.

        Returns:
            tuple: precision, recall, f1_score (in %).

        Raises:
            ValueError: If ``threshold`` is None (no labels could be produced).
        """
        if threshold is None:
            raise ValueError("evaluate() requires a numeric threshold.")

        TP = TN = FP = FN = 0
        was_training = encoder.training
        encoder.eval()

        try:
            with torch.no_grad():
                for batch in data_iter(instances, self.test_batch_size, shuffle=False):
                    tinst = generate_tinsts_binary_label(batch, vocab, is_train=False)
                    tinst.to(prototype.device)
                    pred_tags, _ = self.predict(encoder, tinst.inputs, prototype, threshold)
                    gold = [1 if inst.label == "Anomalous" else 0 for inst in batch]

                    # Compare plain ints instead of per-element 0-d tensors.
                    for p, g in zip(pred_tags.tolist(), gold):
                        if p == 1 and g == 1:
                            TP += 1
                        elif p == 0 and g == 0:
                            TN += 1
                        elif p == 1 and g == 0:
                            FP += 1
                        elif p == 0 and g == 1:
                            FN += 1
        finally:
            encoder.train(was_training)

        precision = 100 * TP / (TP + FP) if TP + FP > 0 else 0.0
        recall = 100 * TP / (TP + FN) if TP + FN > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
        logger.info(f"[{dataset_name}] Precision: {precision:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}")
        return precision, recall, f1


In [None]:
# Instantiate the MTALog class with loss weights only.
# recon_weight comes from the `lambda_recon` hyper-parameter (section 2.1) instead of a
# duplicated literal, so changing the config actually changes the loss.
metalog = MTALog(
    recon_weight=lambda_recon,  # Weight for reconstruction loss
    proto_weight=1.0,           # Weight for prototype loss (important in few-shot meta-test)
)

## 7.2. Checkpoint file paths

In [None]:
# Make sure the checkpoint directory exists before anything is saved into it.
os.makedirs(output_model_dir, exist_ok=True)

# Checkpoint names encode the configuration, e.g. "layer=4_hidden=64_dropout=0.5_epoch=5_best.pt".
info = "_".join(
    (
        f"layer={num_layers}",
        f"hidden={lstm_hidden_units}",
        f"dropout={dropout_rate}",
        f"epoch={num_epochs}",
    )
)
best_model_file, last_model_file = (
    os.path.join(output_model_dir, f"{info}_{suffix}.pt") for suffix in ("best", "last")
)


## 7.3. Training

In [None]:
import random

def split_support_query(instances, ratio=0.5):
    """
    Randomly partition instances into a support set and a query set.

    A shuffled index order (drawn from the module-level ``random`` generator)
    decides membership. The first ``int(len(instances) * ratio)`` shuffled
    positions go to the support set and the rest go to the query set.

    Args:
        instances (list): Instances to split (already encoded).
        ratio (float): Fraction assigned to the support set (0.5 means half/half).

    Returns:
        tuple: (support_set, query_set) as new lists; ``instances`` is not modified.
    """
    order = list(range(len(instances)))
    random.shuffle(order)

    cut = int(len(instances) * ratio)
    support_set = [instances[idx] for idx in order[:cut]]
    query_set = [instances[idx] for idx in order[cut:]]
    return support_set, query_set


In [None]:
if mode == "train":
    # NOTE(review): this cell reads metalog.model, metalog.bk_model and
    # metalog.bk_forward_with_proto, but the MTALog class in 7.1 defines none of
    # them. It only stores loss weights/loss functions and exposes compute_prototype,
    # forward_with_proto, forward, predict and evaluate. As written, the Optimizer
    # construction below raises AttributeError. The encoders actually used for
    # the losses are source_encoders[src] and encoder_target.
    # TODO: decide which parameters the optimizer and the inner update
    # (get_updated_network) should own, then wire them here.

    # Optimizer for encoder (not decoder)
    optimizer = Optimizer(
        filter(lambda p: p.requires_grad, metalog.model.parameters()), lr=gamma
    )

    global_step = 0
    best_f1_score = 0

    for epoch in range(1, num_epochs + 1):
        metalog.model.train()
        metalog.bk_model.train()
        logger.info(f"Epoch {epoch} | Start time: {time.strftime('%H:%M:%S')} | Alpha={alpha}, Beta={beta}, Gamma={gamma}")

        # Loaders for each source system
        # Meta-train batches are drawn from the source *query* splits built in 6.2;
        # source_support_sets is not read anywhere in this cell.
        meta_train_loaders = {
            sys_name: data_iter(source_query_sets[sys_name], training_batch_size, shuffle=True)
            for sys_name in source_systems
        }
        # Meta-test batches come from the target-system query set.
        meta_test_loader = data_iter(evaluation_query_set, training_batch_size, shuffle=True)

        # Steps per epoch = the largest full-batch count among the target query set
        # and every source query set; exhausted loaders are recreated below.
        total_batches = max(
            len(evaluation_query_set) // training_batch_size,
            max(len(source_query_sets[s]) // training_batch_size for s in source_systems)
        )

        for _ in range(total_batches):
            optimizer.zero_grad()

            # === Meta-Train ===
            # Pick one source system at random per step (unseeded np.random).
            src = np.random.choice(source_systems)
            # next() assumes data_iter returns an iterator/generator — TODO confirm.
            try:
                meta_train_batch = next(meta_train_loaders[src])
            except StopIteration:
                # Restart this source's loader once it is exhausted.
                meta_train_loaders[src] = data_iter(source_query_sets[src], training_batch_size, True)
                meta_train_batch = next(meta_train_loaders[src])

            # The batch is split in half again: normal instances of one half build the
            # prototype, the other half is scored against it.
            support_set, query_set = split_support_query(meta_train_batch, ratio=0.5)
            support_normal = [inst for inst in support_set if inst.label == "Normal"]
            if len(support_normal) == 0:
                # No normal instance to build a prototype from: skip this step.
                continue

            # NOTE(review): the prototype averages inst.repr, which was set once in
            # section 6 by encode_log_sequences_with_gru and is not recomputed as the
            # encoders train — confirm this is intended. compute_prototype (7.1) also
            # reads inst.device — confirm instances carry that attribute.
            prototype = metalog.compute_prototype(support_normal)

            tinst_query = generate_tinsts_binary_label(query_set, source_vocabularies[src], is_train=True)
            tinst_query.to(DEVICE)
            # NOTE(review): forward_with_proto (7.1) returns a 3-tuple
            # (total_loss, latent, recon); unpacking it into four names raises ValueError.
            train_loss, _, _, _ = metalog.forward_with_proto(
                inputs=tinst_query.inputs,
                prototype=prototype,
                encoder=source_encoders[src]
            )
            train_loss_value = train_loss.item()
            # retain_graph=True keeps the graph alive for the meta-test backward below.
            train_loss.backward(retain_graph=True)

            # === Update backup model ===
            # Inner-loop step: bk_model receives metalog.model's parameters minus
            # alpha * grad (assigned via put_theta because load=False).
            metalog.bk_model = (
                get_updated_network(metalog.model, metalog.bk_model, alpha)
                .train()
                .to(DEVICE)
            )

            # === Meta-Test ===
            try:
                meta_test_batch = next(meta_test_loader)
            except StopIteration:
                meta_test_loader = data_iter(evaluation_query_set, training_batch_size, True)
                meta_test_batch = next(meta_test_loader)

            test_support, test_query = split_support_query(meta_test_batch, ratio=0.5)
            support_normal_test = [inst for inst in test_support if inst.label == "Normal"]
            if len(support_normal_test) == 0:
                # train_loss.backward() has already run; those gradients are cleared by
                # the next iteration's optimizer.zero_grad() without an optimizer step.
                continue

            prototype_test = metalog.compute_prototype(support_normal_test)

            tinst_test_query = generate_tinsts_binary_label(test_query, vocab_target, is_train=True)
            tinst_test_query.to(DEVICE)

            test_loss, _, _, _ = metalog.bk_forward_with_proto(
                inputs=tinst_test_query.inputs,
                prototype=prototype_test,
                encoder=encoder_target
            )
            test_loss = beta * test_loss
            # Log the unscaled meta-test loss (ZeroDivisionError if beta == 0).
            test_loss_value = test_loss.item() / beta
            test_loss.backward()

            optimizer.step()
            global_step += 1

            if global_step % 10 == 0:
                logger.info(
                    f"Step {global_step} | Epoch {epoch} | Src: {src} | Train loss: {train_loss_value:.4f} | Test loss: {test_loss_value:.4f}"
                )

        # === Evaluate at the end of epoch ===
        if evaluation_query_set:
            # NOTE(review): prototype_test is whatever the last completed meta-test step
            # produced; it is undefined if every step of the first epoch hit `continue`.
            _, _, f1_score = metalog.evaluate(
                dataset_name="Test",
                instances=evaluation_query_set,
                prototype=prototype_test,
                encoder=encoder_target,
                vocab=vocab_target
            )
            if f1_score > best_f1_score:
                logger.info(f"New best F1: {f1_score:.2f} (prev {best_f1_score:.2f})")
                torch.save(metalog.model.state_dict(), best_model_file)
                best_f1_score = f1_score

        logger.info(f"Epoch {epoch} finished.")
        # The "last" checkpoint is overwritten at the end of every epoch.
        torch.save(metalog.model.state_dict(), last_model_file)


# 8. Evaluate

In [None]:
def _evaluate_checkpoint(checkpoint_path, dataset_name, banner):
    """
    Load a saved checkpoint into the target encoder and evaluate it on the target query set.

    NOTE(review): assumes the checkpoint holds a state_dict compatible with
    `encoder_target`, the only encoder built with `vocab_target`. The training cell
    saves `metalog.model.state_dict()`; confirm the two agree once training is wired up.

    Args:
        checkpoint_path (str): Path of the `.pt` file to load.
        dataset_name (str): Label used in the evaluation log line.
        banner (str): Header logged before evaluating.

    Returns:
        tuple | None: (precision, recall, f1) from MTALog.evaluate, or None when the
        checkpoint is missing or no normal support instance is available.
    """
    if not os.path.exists(checkpoint_path):
        return None

    logger.info(banner)
    state_dict = torch.load(checkpoint_path, map_location=DEVICE)
    encoder_target.load_state_dict(state_dict)
    encoder_target.to(DEVICE)

    # Re-encode the support set with the loaded weights so the prototype reflects this checkpoint.
    encode_log_sequences_with_gru(encoder_target, vocab_target, prototype_support_set)
    normal_support = [inst for inst in prototype_support_set if inst.label == "Normal"]
    if not normal_support:
        logger.warning(f"No normal support instances available; skipping evaluation of {checkpoint_path}.")
        return None

    prototype = torch.tensor(
        np.stack([np.asarray(inst.repr) for inst in normal_support]),
        dtype=torch.float32,
        device=DEVICE,
    ).mean(dim=0)

    encoder_target.eval()
    # Keyword arguments: MTALog.evaluate's positional order is
    # (encoder, vocab, dataset_name, instances, prototype, threshold).
    return metalog.evaluate(
        encoder=encoder_target,
        vocab=vocab_target,
        dataset_name=dataset_name,
        instances=evaluation_query_set,
        prototype=prototype,
        threshold=prediction_threshold,
    )


# === Evaluate last model ===
_evaluate_checkpoint(last_model_file, "Final Model on Test BGL", "=== Evaluating Final (Last) Model ===")

# === Evaluate best model ===
_evaluate_checkpoint(best_model_file, "Best Model on Test BGL", "=== Evaluating Best Model ===")

logger.info("All evaluations finished!")


# 9. Export to graph

## 9.1. Constants

In [None]:
# Log files read back for plotting. They are anchored at LOG_ROOT (where setup_logger's
# FileHandler writes) instead of the current working directory, so the cells work from any cwd.
# NOTE(review): assumes the statistics-template encoder also writes its log into LOG_ROOT.
STATISTICS_TEMPLATE_LOG_PATH = os.path.join(LOG_ROOT, "Statistics_Template.log")
MTALOG_LOG_PATH = os.path.join(LOG_ROOT, "MTALog.log")

## 9.2. Extracting

In [None]:
def extract_word2vec_file(log_path, session):
    """
    Return the word2vec file recorded for `session` in the statistics-template log.

    Args:
        log_path (str): Path of the Statistics_Template log file.
        session (str): Session id; matched literally (regex metacharacters are escaped).

    Returns:
        str | None: The file name from the first matching "Loading word2vec dict from ..."
        line, or None when the session has no such line.
    """
    pattern = re.compile(
        rf"^.+ - Statistics_Template_Encoder - {re.escape(session)} - INFO: Loading word2vec dict from (.+)\.$"
    )
    # errors="replace": tolerate stray undecodable bytes instead of aborting the parse.
    with open(log_path, "r", errors="replace") as file:
        for line in file:
            match = pattern.search(line)
            if match:
                return match.group(1)
    return None

def extract_f1_scores(log_path, session):
    """
    Extract per-evaluation train and test F1 scores for `session` from the MTALog log.

    Two line formats are recognised:
      * current (as written by ``MTALog.evaluate`` with dataset_name "Train"/"Test"):
        ``INFO: [Test] Precision: 50.00, Recall: 100.00, F1: 66.67``
      * legacy: ``INFO: Test: F1 score = 66.67 | Precision = 50.0 | Recall = 100.0``
    Evaluations under any other dataset name (e.g. "Final Model on Test BGL") are ignored.

    Args:
        log_path (str): Path of the MTALog log file.
        session (str): Session id; matched literally.

    Returns:
        tuple[list[float], list[float]]: (train_f1_scores, test_f1_scores) in file order.
    """
    sess = re.escape(session)
    current = re.compile(
        rf"^.+ - MTALog - {sess} - INFO: \[(Train|Test)\] Precision: (\S+), Recall: (\S+), F1: (\S+)$"
    )
    legacy = re.compile(
        rf"^.+ - MTALog - {sess} - INFO: (Train|Test): F1 score = (.+) \| Precision = (.+) \| Recall = (.+)$"
    )

    scores = {"Train": [], "Test": []}
    with open(log_path, "r", errors="replace") as file:
        for line in file:
            match = current.search(line)
            if match:
                scores[match.group(1)].append(float(match.group(4)))
                continue
            match = legacy.search(line)
            if match:
                scores[match.group(1)].append(float(match.group(2)))
    return scores["Train"], scores["Test"]

def extract_meta_losses(log_path, session):
    """
    Extract meta-train and meta-test losses for `session` from the MTALog log.

    Two line formats are recognised:
      * current (training cell):
        ``INFO: Step 10 | Epoch 1 | Src: HDFS | Train loss: 0.5000 | Test loss: 0.2500``
      * legacy:
        ``INFO: Step: 10 | Epoch: 1 | Meta-train loss: 0.5 | Meta-test loss: 0.25.``

    Args:
        log_path (str): Path of the MTALog log file.
        session (str): Session id; matched literally.

    Returns:
        tuple[list[float], list[float]]: (meta_train_losses, meta_test_losses) in file order.
    """
    sess = re.escape(session)
    current = re.compile(
        rf"^.+ - MTALog - {sess} - INFO: Step (\d+) \| Epoch (\d+) \| Src: .+ \| Train loss: (\S+) \| Test loss: (\S+)$"
    )
    legacy = re.compile(
        rf"^.* - MTALog - {sess} - INFO: Step: (.+) \| Epoch: (.+) \| Meta-train loss: (.+) \| Meta-test loss: (.+)\.$"
    )

    meta_train_losses, meta_test_losses = [], []
    with open(log_path, "r", errors="replace") as file:
        for line in file:
            match = current.search(line) or legacy.search(line)
            if match:
                meta_train_losses.append(float(match.group(3)))
                meta_test_losses.append(float(match.group(4)))

    return meta_train_losses, meta_test_losses

# Extract the word2vec file path actually used by this session; keep the configured
# value when the log has no record (the extractor then returns None).
_logged_word2vec_file = extract_word2vec_file(STATISTICS_TEMPLATE_LOG_PATH, SESSION)
if _logged_word2vec_file is not None:
    word2vec_file = _logged_word2vec_file

# Name every source system (not just HDFS) and the actual target in the figure title.
title = (
    f"BILATERAL GENERALIZATION TRANSFERRING {' + '.join(source_systems).upper()} "
    f"TO {TARGET_SYSTEM.upper()} USING {word2vec_file}"
)

# Extract F1 scores and losses
train_f1_scores, test_f1_scores = extract_f1_scores(MTALOG_LOG_PATH, SESSION)
meta_train_losses, meta_test_losses = extract_meta_losses(MTALOG_LOG_PATH, SESSION)

print(train_f1_scores, test_f1_scores)
print(meta_train_losses, meta_test_losses)

## 9.3. Plotting

In [None]:
def plot_f1_scores(ax, num_epochs, train_f1, test_f1):
    """Plot train and test F1 scores on the provided axis."""
    ax.set_ylim(0, 110)
    ax.plot(num_epochs, train_f1, color="tab:blue")
    ax.plot(num_epochs, test_f1, color="tab:orange")
    ax.legend(["Train", "Test"])
    ax.set_xlabel("Epoch")
    ax.set_ylabel("F1 Score")

    for i in range(len(num_epochs)):
        ax.plot(num_epochs[i], train_f1[i], "o", color="tab:blue", zorder=10)
        ax.text(num_epochs[i], train_f1[i] + 5, round(train_f1[i], 2), ha="center")

        ax.plot(num_epochs[i], test_f1[i], "o", color="tab:orange", zorder=10)
        ax.text(num_epochs[i], test_f1[i] - 10, round(test_f1[i], 2), ha="center")

def plot_meta_losses(ax, num_steps, meta_train_losses, meta_test_losses):
    """Plot meta-train and meta-test losses on the provided axis."""
    ax.plot(num_steps, meta_train_losses, color="tab:blue")
    ax.plot(num_steps, meta_test_losses, color="tab:orange")
    ax.legend(["Meta-train loss", "Meta-test loss"])
    ax.set_xlabel("Step")
    ax.set_ylabel("Loss")
    
# Plot F1 scores (top panel) and meta-losses (bottom panel).
fig, axs = plt.subplots(2, 1, figsize=(16, 8))

# Dedicated names: `num_epochs` is the training hyper-parameter from 2.2 and must not be
# overwritten by plotting code (re-running training afterwards would break).
epoch_indices = list(range(len(test_f1_scores)))
# The training cell logs losses when global_step % 10 == 0, i.e. at steps 10, 20, ...
step_indices = [10 * (i + 1) for i in range(len(meta_train_losses))]

# plot_f1_scores indexes train and test scores per epoch, so only call it when both exist and align.
if train_f1_scores and len(train_f1_scores) == len(test_f1_scores):
    plot_f1_scores(axs[0], epoch_indices, train_f1_scores, test_f1_scores)
plot_meta_losses(axs[1], step_indices, meta_train_losses, meta_test_losses)

## 9.4. Saving and exporting

In [None]:
# Set the title for the plot. The log may hold no test F1 entries (e.g. training never
# ran), and max() on an empty list would raise ValueError.
best_test_f1_score = max(test_f1_scores) if test_f1_scores else "N/A"
fig_title = f"{title}\nBest model F1 Score = {best_test_f1_score}\nLSTM hidden units = {lstm_hidden_units} | Layers = {num_layers} | Drop out = {dropout_rate} | Alpha = {alpha} | Beta = {beta} | Gamma = {gamma}"
fig.suptitle(fig_title)

# Define the path to save the plot
plot_dir = os.path.join("visualization", "graphs")
# Use only the file name of the embedding: a path-like value (e.g. "dir/glove.txt") would
# otherwise add sub-directories to plot_path that the makedirs call below never creates.
word2vec_name = os.path.basename(str(word2vec_file))
plot_filename = f"{alpha}-{beta}-{gamma}-{word2vec_name}-{SESSION}.png"
plot_path = os.path.join(plot_dir, plot_filename)

# Ensure the directory exists
os.makedirs(plot_dir, exist_ok=True)

# Save the plot
fig.savefig(plot_path)