In [None]:
# import all necessary libraries
import numpy as np
import pandas as pd
import datetime
import torch
import torch.utils
from datetime import timezone
import time

In [None]:
# import loader, model and utils
from Synthetic_with_attribute_data_loader import *
from DyRep import *
from utils import *

In [None]:
# 1. First thing to change: the path to the event data (data directory/file).
# Both splits are built from the same CSV; presumably the first argument
# selects which portion of the events the dataset serves -- confirm in the loader.
train_set, test_set = (
    SyntheticAttributeDataset(split, data_dir='/simulated_data/final_hawkes_with_features.csv')
    for split in ('train', 'test')
)
# Random (unseeded) initial embeddings with shape (N_nodes, 32).
initial_embeddings = np.random.standard_normal((train_set.N_nodes, 32))
A_initial = train_set.get_Adjacency()

In [None]:
# Show the first five events of each split as a quick sanity check of the loader.
for preview_title, preview_set in (
    ("Train set preview (first 5 events):", train_set),
    ("\nTest set preview (first 5 events):", test_set),
):
    print(preview_title)
    for event in preview_set.all_events[:5]:
        print(event)

In [None]:
N_nodes = A_initial.shape[0]
# Only a flat (1-D) adjacency needs fixing up; any other rank is passed through untouched.
if A_initial.ndim == 1:
    if A_initial.size == N_nodes * N_nodes:
        # Reshape and add a trailing relationship-type dimension.
        A_initial = A_initial.reshape(N_nodes, N_nodes)[:, :, None]
    else:
        # A vector that does not match the expected size: fall back to an all-zero
        # adjacency with an extra dimension for relationship types.
        A_initial = np.zeros((N_nodes, N_nodes, 1))
n_assoc_types = 1
n_event_types = 1
n_relations = n_assoc_types + n_event_types

# NOTE(review): A_initial above is the full return value of get_Adjacency(), while
# Adj_all is its element [0]. If get_Adjacency() returns a plain (N, N) array, [0] is a
# single row and the "degrees" below are just that row's entries -- confirm the return type.
Adj_all = train_set.get_Adjacency()[0]
if not isinstance(Adj_all, list):
    Adj_all = [Adj_all]

# One float vector per relation; entry u is the sum over Adj[u].
node_degree_global = [
    np.array([np.sum(adj[node]) for node in range(adj.shape[0])], dtype=float)
    for adj in Adj_all
]

Adj_all = Adj_all[0]
print("A_initial dimensions:", A_initial.shape)

In [None]:
# Instantiate the model
# Baseline model
# model = DyRep_update(
#     node_embeddings=initial_embeddings,
#     A_initial=A_initial,
#     N_surv_samples=5,
#     n_hidden=32,
#     node_degree_global= node_degree_global,
#     N_hops=1,
# )    

In [None]:
# Attribute-aware model configuration (N_hops=2, with_attributes=True, gamma=0.5).
model_config = dict(
    node_embeddings=initial_embeddings,
    A_initial=A_initial,
    N_surv_samples=5,
    n_hidden=32,
    node_degree_global=node_degree_global,
    N_hops=2,
    with_attributes=True,
    gamma=0.5,
)
model = DyRep_update(**model_config)

In [None]:
from torch.utils.data import DataLoader

# shuffle=False keeps the events in dataset order for both splits; 200 events per batch.
train_loader, test_loader = (
    DataLoader(split_set, batch_size=200, shuffle=False)
    for split_set in (train_set, test_set)
)

In [None]:
def test(model, n_test_batches=10, epoch=0):
    """Evaluate ``model`` on the global ``test_loader`` and print ranking metrics.

    Every test event is ranked with ``MAR`` and bucketed into the first entry of
    ``test_loader.dataset.TEST_TIMESLOTS`` whose ordinal date is >= the event's date.
    Results are reported per time slot and per event type; the extra ``'Com'``
    bucket collects every event whose type index ``k`` is > 0.

    Args:
        model: Model to evaluate. It is switched to eval mode and is NOT switched
            back; the caller is responsible for calling ``model.train()`` again.
        n_test_batches (int | None): Stop after this many batches; ``None`` runs
            the whole test loader.
        epoch (int): Only used to label the printed summary.

    Returns:
        tuple: ``(mar_all, hits_10_all, mean_loss)`` where the first two are dicts
        mapping event type -> flat list of per-event values over all time slots,
        and ``mean_loss`` is the summed loss divided by the number of evaluated
        batches.
    """
    # Kept as a function-local import, as in the original cell, so that
    # ``datetime`` below is guaranteed to be the module.
    import datetime

    model.eval()
    loss = 0
    # Running [min, max] of output[0] and output[1] over all evaluated batches.
    # np.inf (lowercase): the np.Inf alias was removed in NumPy 2.0.
    losses = [[np.inf, 0], [np.inf, 0]]
    n_samples = 0
    # Time slots with 10 days intervals as in the DyRep paper
    timeslots = [t.toordinal() for t in test_loader.dataset.TEST_TIMESLOTS]
    event_types = list(test_loader.dataset.event_types_num.keys())  # ['comm', 'assoc']
    # Re-order so that event_types[k] is the name of event type index k.
    for event_t in test_loader.dataset.event_types_num:
        event_types[test_loader.dataset.event_types_num[event_t]] = event_t

    event_types += ['Com']

    # One list of per-event results for every (event type, time slot) pair.
    mar = {event_t: [[] for _ in timeslots] for event_t in event_types}
    hits_10 = {event_t: [[] for _ in timeslots] for event_t in event_types}

    start = time.time()
    with torch.no_grad():
        for batch_idx, data in enumerate(test_loader):
            data[2] = data[2].float()
            data[4] = data[4].double()
            data[5] = data[5].double()
            output = model(data)
            # The epsilon belongs inside the log: outside it only shifts the sum by a
            # constant and does not protect against log(0) = -inf.
            loss += (-torch.sum(torch.log(output[0] + 1e-10)) + torch.sum(output[1])).item()
            for i in range(len(losses)):
                m1 = output[i].min()
                m2 = output[i].max()
                if m1 < losses[i][0]:
                    losses[i][0] = m1
                if m2 > losses[i][1]:
                    losses[i][1] = m2
            n_samples += 1
            A_pred, Survival_term = output[2]
            u, v, k = data[0], data[1], data[3]

            time_cur = data[5]
            ranks, hits = MAR(A_pred, u, v, k, Survival_term=Survival_term)
            assert len(time_cur) == len(ranks) == len(hits) == len(k)
            for t, rank, hit, k_ in zip(time_cur, ranks, hits, k):
                d = datetime.datetime.fromtimestamp(t.item()).toordinal()
                event_t = event_types[k_.item()]
                # Assign the event to the first slot whose end date it does not exceed;
                # events later than the last slot are dropped.
                for c, slot in enumerate(timeslots):
                    if d <= slot:
                        mar[event_t][c].append(rank)
                        hits_10[event_t][c].append(hit)
                        if k_ > 0:
                            mar['Com'][c].append(rank)
                            hits_10['Com'][c].append(hit)
                        if c > 0:
                            assert slot > timeslots[c-1] and d > timeslots[c-1], (d, slot, timeslots[c-1])
                        break

            if batch_idx % 10 == 0:
                print('test', batch_idx)

            if n_test_batches is not None and batch_idx >= n_test_batches - 1:
                break

    time_iter = time.time() - start

    print('\nTEST batch={}/{}, loss={:.3f}, psi={}, loss1 min/max={:.4f}/{:.4f}, '
          'loss2 min/max={:.4f}/{:.4f}, integral time stamps={}, sec/iter={:.4f}'.
          format(batch_idx + 1, len(test_loader), (loss / n_samples),
                 [model.psi[c].item() for c in range(len(model.psi))],
                 losses[0][0], losses[0][1], losses[1][0], losses[1][1],
                 len(model.Lambda_dict), time_iter / (batch_idx + 1)))

    # Report results for different time slots in the test set
    for c, slot in enumerate(timeslots):
        s = 'Slot {}: '.format(c)
        for event_t in event_types:
            sfx = '' if event_t == event_types[-1] else ', '
            if len(mar[event_t][c]) > 0:
                s += '{} ({} events): MAR={:.2f}+-{:.2f}, HITS_10={:.3f}+-{:.3f}'.\
                    format(event_t, len(mar[event_t][c]), np.mean(mar[event_t][c]), np.std(mar[event_t][c]),
                            np.mean(hits_10[event_t][c]), np.std(hits_10[event_t][c]))
            else:
                s += '{} (no events)'.format(event_t)
            s += sfx
        print(s)

    # Flatten the per-slot lists into one list per event type.
    mar_all, hits_10_all = {}, {}
    for event_t in event_types:
        mar_all[event_t] = []
        hits_10_all[event_t] = []
        for c, slot in enumerate(timeslots):
            mar_all[event_t].extend(mar[event_t][c])
            hits_10_all[event_t].extend(hits_10[event_t][c])

    s = 'Epoch {}: results per event type for all test time slots: \n'.format(epoch)
    print('-' * 100)
    for event_t in event_types:
        if len(mar_all[event_t]) > 0:
            s += '====== {:10s}\t ({:7s} events): \tMAR={:.2f}+-{:.2f}\t HITS_10={:.3f}+-{:.3f}'.format(
                str(event_t),  # the ':10s' format spec requires a string
                str(len(mar_all[event_t])),  # likewise for ':7s'
                np.mean(mar_all[event_t]),
                np.std(mar_all[event_t]),
                np.mean(hits_10_all[event_t]),
                np.std(hits_10_all[event_t])
            )
        else:
            s += '====== {:10s}\t (no events)'.format(str(event_t))
        if event_t != event_types[-1]:
            s += '\n'
    print(s)
    print('-' * 100)

    return mar_all, hits_10_all, loss / n_samples

In [None]:
# Only parameters that require gradients are handed to the optimizer.
params_main = [param for _, param in model.named_parameters() if param.requires_grad]

In [None]:
optimizer = torch.optim.Adam([{"params": params_main, "weight_decay":0}], lr=0.0002, betas=(0.5, 0.999))
# Halve the learning rate once the scheduler has been stepped 10 times.
# `milestones` must be a sequence of ints: the previous string '10' was turned into
# the characters '1' and '0', which never equal the integer epoch counter, so the
# learning rate was never decayed.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.5)

In [None]:
# (N_nodes, 1) float column, every entry set to the POSIX timestamp of the dataset's first date.
time_bar = np.full((N_nodes, 1), train_set.FIRST_DATE.timestamp(), dtype=float)

In [None]:
import copy
def get_temporal_variables():
    """Snapshot the mutable temporal state so it can be restored after evaluation.

    Reads the notebook-level ``time_bar``, ``node_degree_global`` and ``model``.
    Plain Python/NumPy objects are deep-copied and model tensors are cloned, so
    the snapshot is independent of later in-place updates.

    Returns:
        dict: keys ``time_bar``, ``node_degree_global``, ``time_keys``, ``z``,
        ``S``, ``A`` and ``Lambda_dict``.
    """
    return {
        'time_bar': copy.deepcopy(time_bar),
        'node_degree_global': copy.deepcopy(node_degree_global),
        'time_keys': copy.deepcopy(model.time_keys),
        'z': model.z.clone(),
        'S': model.S.clone(),
        'A': model.A.clone(),
        'Lambda_dict': model.Lambda_dict.clone(),
    }

In [None]:
def set_temporal_variables(variables, model, train_loader, test_loader):
    """Restore a snapshot of the temporal state onto the model and both datasets.

    Args:
        variables (dict): Snapshot with the keys ``time_bar``, ``node_degree_global``,
            ``time_keys``, ``z``, ``S``, ``A`` and ``Lambda_dict``.
        model: Object whose temporal attributes are overwritten.
        train_loader, test_loader: Loaders whose ``.dataset.time_bar`` is replaced;
            both datasets receive the very same array object.

    Returns:
        The restored ``time_bar`` (a deep copy of the snapshot entry), so the
        caller can rebind its own ``time_bar`` variable to it.
    """
    restored_time_bar = copy.deepcopy(variables['time_bar'])
    for loader in (train_loader, test_loader):
        loader.dataset.time_bar = restored_time_bar

    # Plain Python / NumPy state is deep-copied, tensors are cloned, so the snapshot
    # itself stays untouched and can be restored again later.
    for attr in ('node_degree_global', 'time_keys'):
        setattr(model, attr, copy.deepcopy(variables[attr]))
    for attr in ('z', 'S', 'A', 'Lambda_dict'):
        setattr(model, attr, variables[attr].clone())
    return restored_time_bar

In [None]:
# Training schedule.
epoch_start, epochs = 1, 100
batch_start = 0
batch_size = 200  # also used to normalise the summed training loss
weight = 1        # multiplier applied to the second loss term (sum of output[1])
log_interval = 20  # log and evaluate every `log_interval` training batches

# Histories filled in by the training loop (each one is its own list).
loss_l = []
losses_events = []
losses_nonevents = []
losses_sum = []
test_MAR = []
test_HITS10 = []
test_loss = []

In [None]:
class EarlyStopping:
    """Signal when training should stop because the validation loss stopped improving.

    Call the instance with the latest validation loss after every evaluation.
    Whenever the loss counts as an improvement the model's ``state_dict`` is
    written to ``path``; after ``patience`` consecutive evaluations without an
    improvement ``early_stop`` becomes ``True``. The object only raises the
    flag -- the caller has to check it and leave the training loop.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How many consecutive evaluations without improvement
                            are tolerated before ``early_stop`` is set.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss
                            improvement and for each counter increment.
                            Default: False
            delta (float): Minimum decrease of the validation loss (relative to the
                            best one seen so far) that qualifies as an improvement.
                            Default: 0
            path (str): File the model ``state_dict`` is saved to on improvement.
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf (lowercase): the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record a new validation loss and update the counter / checkpoint.

        Args:
            val_loss (float): Latest validation loss (lower is better).
            model: Model whose ``state_dict`` is saved when the loss improved.
        """
        # Work with a score where higher is better.
        score = -val_loss

        if self.best_score is None:
            # First evaluation: always becomes the reference point.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not better than the best score by at least `delta`.
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model's state_dict to ``self.path`` and records the new best loss.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
# Tolerate 10 consecutive non-improving evaluations; print every decision.
early_stopping = EarlyStopping(verbose=True, patience=10)

In [None]:
# Training loop
# Every `log_interval` batches (and on the last batch of an epoch) the model is
# evaluated on the test loader; the temporal state is snapshotted before and restored
# after, so that evaluation does not leak into training.
for epoch in range(epoch_start, epochs + 1):
    if epoch > epoch_start:
        # From the second epoch on, re-initialise the temporal state before the event
        # stream is replayed (exact semantics live in `initialize_state`).
        time_bar, node_degree_global = initialize_state(
            dataset=train_loader.dataset,
            model=model,
            node_embeddings=initial_embeddings,
            keepS=epoch > 1  # Only keep states after the first epoch if needed
        )
        model.node_degree_global = node_degree_global

    # Setting the global time_bar for the datasets
    train_loader.dataset.time_bar = time_bar
    test_loader.dataset.time_bar = time_bar

    start = time.time()

    for batch_idx, data_batch in enumerate(train_loader):
        # test() leaves the model in eval mode, so re-enable train mode every batch.
        model.train()
        optimizer.zero_grad()

        # Ensure the data is in the correct format
        data_batch[2] = data_batch[2].float()
        data_batch[4] = data_batch[4].double()
        data_batch[5] = data_batch[5].double()

        output = model(data_batch)
        # The epsilon belongs inside the log: outside it only shifts the sum by a
        # constant and does not protect against log(0) = -inf.
        losses = [-torch.sum(torch.log(output[0] + 1e-10)), weight * torch.sum(output[1])]

        # KL losses (if there are additional items in output to process as losses)
        if len(output) > 3 and output[-1] is not None:
            losses.extend(output[-1])

        loss = torch.sum(torch.stack(losses)) / batch_size
        loss.backward()
        # Spelled via `torch.nn` so the cell does not rely on a star import exposing `nn`.
        torch.nn.utils.clip_grad_value_(model.parameters(), 100)
        optimizer.step()
        losses_events.append(losses[0].item())
        losses_nonevents.append(losses[1].item())
        losses_sum.append(loss.item())

        # Clamping psi to prevent numerical overflow
        model.psi.data = torch.clamp(model.psi.data, 1e-1, 1e+3)

        time_iter = time.time() - start

        # Detach computational graph to prevent unwanted backprop
        model.z = model.z.detach()
        model.S = model.S.detach()

        if (batch_idx + 1) % log_interval == 0 or batch_idx == len(train_loader) - 1:
            print(f'\nTRAIN epoch={epoch}/{epochs}, batch={batch_idx + 1}/{len(train_loader)}, '
                  f'sec/iter: {time_iter / (batch_idx + 1):.4f}, loss={loss.item():.3f}, '
                  f'loss components: {[l.item() for l in losses]}')
            loss_l.append(loss.item())

            # Save state before testing
            variables = get_temporal_variables()
            print('time', datetime.datetime.fromtimestamp(np.max(time_bar)))

            # Testing and collecting results (full test set only on the last batch of the epoch)
            result = test(model, n_test_batches=None if batch_idx == len(train_loader) - 1 else 10, epoch=epoch)
            test_MAR.append(np.mean(result[0]['Com']))
            test_HITS10.append(np.mean(result[1]['Com']))
            test_loss.append(result[2])
            early_stopping(result[2], model)

            # Restore state after testing -- also on the early-stopping path, so the
            # model is left in its training-time state rather than a post-test one.
            time_bar = set_temporal_variables(variables, model, train_loader, test_loader)

            if early_stopping.early_stop:
                print("Early stopping")
                break

    # The `break` above only leaves the batch loop; without this check training
    # would simply continue with the next epoch.
    if early_stopping.early_stop:
        break

    scheduler.step()

In [None]:
def grid_search_bg(data_dir, gamma, epochs):
    """
    Grid-search the ``gamma`` hyperparameter of ``DyRep_update``.

    For every candidate value a fresh model is built and trained for ``epochs``
    epochs on the 'train' split. The model is evaluated on the 'test' split every
    ``log_interval`` training batches and on the last batch of each epoch. A
    candidate's score is the mean of all test losses recorded during its run; the
    candidate with the lowest score wins.

    Args:
        data_dir (str): Path handed to ``SyntheticAttributeDataset`` for both splits.
        gamma (iterable of float): Candidate values, each forwarded as ``gamma=`` to
            ``DyRep_update``.
        epochs (int): Number of training epochs per candidate.

    Returns:
        tuple: ``(best_gamma, best_loss)``. ``best_gamma`` is ``None`` (and
        ``best_loss`` is ``inf``) when no candidate produced a loss below infinity,
        e.g. for an empty candidate list or when every run ended in NaN.
    """
    import copy
    import datetime
    import time

    best_loss = float('inf')
    # Defined up front so the final `return` cannot hit an unbound name.
    best_gamma = None

    train_set = SyntheticAttributeDataset('train', data_dir=data_dir)
    test_set = SyntheticAttributeDataset('test',  data_dir=data_dir)
    initial_embeddings = np.random.randn(train_set.N_nodes, 32)
    A_initial = train_set.get_Adjacency()

    N_nodes = A_initial.shape[0]
    if A_initial.ndim == 1 and A_initial.size == N_nodes * N_nodes:
        A_initial = A_initial.reshape(N_nodes, N_nodes)[:, :, None]  # Reshape and add relationship type dimension
    elif A_initial.ndim == 1:  # If it's just a vector that doesn't match the expected size
        # Initialize A_initial as a zero matrix with an extra dimension for types
        A_initial = np.zeros((N_nodes, N_nodes, 1))

    # NOTE(review): this uses element [0] of get_Adjacency() while A_initial is the
    # full return value -- confirm which of the two is the adjacency matrix.
    Adj_all = train_set.get_Adjacency()[0]

    if not isinstance(Adj_all, list):
        Adj_all = [Adj_all]

    # Pristine per-relation degree vectors; every candidate starts from a copy of these.
    node_degree_initial = []
    for rel, A in enumerate(Adj_all):
        node_degree_initial.append(np.zeros(A.shape[0]))
        for u in range(A.shape[0]):
            node_degree_initial[rel][u] = np.sum(A[u])

    epoch_start = 1
    batch_size = 200
    weight = 1
    log_interval = 20

    # The three helpers below read `model`, `time_bar`, `node_degree_global` and
    # `test_loader` from this function's scope at call time, i.e. they always see
    # the objects of the candidate currently being trained.

    def get_temporal_variables():
        """Snapshot the temporal state of the current model/run."""
        variables = {}
        variables['time_bar'] = copy.deepcopy(time_bar)
        variables['node_degree_global'] = copy.deepcopy(node_degree_global)
        variables['time_keys'] = copy.deepcopy(model.time_keys)
        variables['z'] = model.z.clone()
        variables['S'] = model.S.clone()
        variables['A'] = model.A.clone()
        variables['Lambda_dict'] = model.Lambda_dict.clone()
        return variables

    def set_temporal_variables(variables, model, train_loader, test_loader):
        """Restore a snapshot onto ``model`` and both datasets; returns the restored time_bar."""
        time_bar = copy.deepcopy(variables['time_bar'])
        train_loader.dataset.time_bar = time_bar
        test_loader.dataset.time_bar = time_bar
        model.node_degree_global = copy.deepcopy(variables['node_degree_global'])
        model.time_keys = copy.deepcopy(variables['time_keys'])
        model.z = variables['z'].clone()
        model.S = variables['S'].clone()
        model.A = variables['A'].clone()
        model.Lambda_dict = variables['Lambda_dict'].clone()
        return time_bar

    def test(model, n_test_batches=10, epoch=0):
        """Evaluate ``model`` on the current ``test_loader``.

        Returns ``(mar_all, hits_10_all, mean_loss)`` where ``mean_loss`` is the
        summed loss divided by the number of evaluated batches.
        """
        model.eval()
        loss = 0
        # Running [min, max] of output[0] and output[1].
        # np.inf (lowercase): the np.Inf alias was removed in NumPy 2.0.
        losses = [[np.inf, 0], [np.inf, 0]]
        n_samples = 0
        # Time slots with 10 days intervals as in the DyRep paper
        timeslots = [t.toordinal() for t in test_loader.dataset.TEST_TIMESLOTS]
        event_types = list(test_loader.dataset.event_types_num.keys())  # ['comm', 'assoc']
        # Re-order so that event_types[k] is the name of event type index k.
        for event_t in test_loader.dataset.event_types_num:
            event_types[test_loader.dataset.event_types_num[event_t]] = event_t

        event_types += ['Com']

        mar, hits_10 = {}, {}
        for event_t in event_types:
            mar[event_t] = []
            hits_10[event_t] = []
            for c, slot in enumerate(timeslots):
                mar[event_t].append([])
                hits_10[event_t].append([])

        start = time.time()
        with torch.no_grad():
            for batch_idx, data in enumerate(test_loader):
                data[2] = data[2].float()
                data[4] = data[4].double()
                data[5] = data[5].double()
                data[6] = data[6].float()
                output = model(data)
                # Epsilon inside the log so that it actually guards against log(0).
                loss += (-torch.sum(torch.log(output[0] + 1e-10)) + torch.sum(output[1])).item()
                for i in range(len(losses)):
                    m1 = output[i].min()
                    m2 = output[i].max()
                    if m1 < losses[i][0]:
                        losses[i][0] = m1
                    if m2 > losses[i][1]:
                        losses[i][1] = m2
                n_samples += 1
                A_pred, Survival_term = output[2]
                u, v, k = data[0], data[1], data[3]

                time_cur = data[5]
                ranks, hits = MAR(A_pred, u, v, k, Survival_term=Survival_term)
                assert len(time_cur) == len(ranks) == len(hits) == len(k)
                for t, rank, hit, k_ in zip(time_cur, ranks, hits, k):
                    d = datetime.datetime.fromtimestamp(t.item()).toordinal()
                    event_t = event_types[k_.item()]
                    # Assign the event to the first slot whose end date it does not exceed.
                    for c, slot in enumerate(timeslots):
                        if d <= slot:
                            mar[event_t][c].append(rank)
                            hits_10[event_t][c].append(hit)
                            if k_ > 0:
                                mar['Com'][c].append(rank)
                                hits_10['Com'][c].append(hit)
                            if c > 0:
                                assert slot > timeslots[c-1] and d > timeslots[c-1], (d, slot, timeslots[c-1])
                            break

                if batch_idx % 10 == 0:
                    print('test', batch_idx)

                if n_test_batches is not None and batch_idx >= n_test_batches - 1:
                    break

        time_iter = time.time() - start

        print('\nTEST batch={}/{}, loss={:.3f}, psi={}, loss1 min/max={:.4f}/{:.4f}, '
            'loss2 min/max={:.4f}/{:.4f}, integral time stamps={}, sec/iter={:.4f}'.
            format(batch_idx + 1, len(test_loader), (loss / n_samples),
                    [model.psi[c].item() for c in range(len(model.psi))],
                    losses[0][0], losses[0][1], losses[1][0], losses[1][1],
                    len(model.Lambda_dict), time_iter / (batch_idx + 1)))

        # Report results for different time slots in the test set
        for c, slot in enumerate(timeslots):
            s = 'Slot {}: '.format(c)
            for event_t in event_types:
                sfx = '' if event_t == event_types[-1] else ', '
                if len(mar[event_t][c]) > 0:
                    s += '{} ({} events): MAR={:.2f}+-{:.2f}, HITS_10={:.3f}+-{:.3f}'.\
                        format(event_t, len(mar[event_t][c]), np.mean(mar[event_t][c]), np.std(mar[event_t][c]),
                                np.mean(hits_10[event_t][c]), np.std(hits_10[event_t][c]))
                else:
                    s += '{} (no events)'.format(event_t)
                s += sfx
            print(s)

        # Flatten the per-slot lists into one list per event type.
        mar_all, hits_10_all = {}, {}
        for event_t in event_types:
            mar_all[event_t] = []
            hits_10_all[event_t] = []
            for c, slot in enumerate(timeslots):
                mar_all[event_t].extend(mar[event_t][c])
                hits_10_all[event_t].extend(hits_10[event_t][c])

        s = 'Epoch {}: results per event type for all test time slots: \n'.format(epoch)
        print(''.join(['-']*100))
        for event_t in event_types:
            if len(mar_all[event_t]) > 0:
                s += '====== {:10s}\t ({:7s} events): \tMAR={:.2f}+-{:.2f}\t HITS_10={:.3f}+-{:.3f}'.format(
                    str(event_t),  # the ':10s' format spec requires a string
                    str(len(mar_all[event_t])),  # likewise for ':7s'
                    np.mean(mar_all[event_t]),
                    np.std(mar_all[event_t]),
                    np.mean(hits_10_all[event_t]),
                    np.std(hits_10_all[event_t])
                )
            else:
                s += '====== {:10s}\t (no events)'.format(str(event_t))
            if event_t != event_types[-1]:
                s += '\n'
        print(s)
        print(''.join(['-'] * 100))

        return mar_all, hits_10_all, loss / n_samples

    for g in gamma:
        print(f"Testing gamma: {g}")

        # Fresh bookkeeping and a fresh model for every candidate.
        test_loss = []
        avg_test_loss = float('inf')  # stays inf if nothing gets evaluated
        # `node_degree_global` is rebound during training (epoch reset below); start
        # every candidate from the pristine degrees instead of the previous run's state.
        node_degree_global = copy.deepcopy(node_degree_initial)

        # NOTE(review): unlike the model built at notebook level, N_surv_samples is not
        # passed here, so DyRep_update's default applies -- confirm this is intended.
        model = DyRep_update(
            node_embeddings=initial_embeddings,
            A_initial=A_initial,
            n_hidden=32,
            node_degree_global=node_degree_global,
            N_hops=2,
            with_attributes=True,
            gamma=g
        )
        params_main = [param for _, param in model.named_parameters() if param.requires_grad]
        train_loader = DataLoader(train_set, batch_size=200, shuffle=False)
        test_loader = DataLoader(test_set, batch_size=200, shuffle=False)
        optimizer = torch.optim.Adam([{"params": params_main, "weight_decay":0}], lr=0.0002, betas=(0.5, 0.999))
        # `milestones` must be a sequence of ints; the former string '10' never matched
        # the integer epoch counter, so the learning rate was never decayed.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.5)
        time_bar = np.zeros((N_nodes, 1)) + train_set.FIRST_DATE.timestamp()

        # Train and test the model for a specified number of epochs
        for epoch in range(epoch_start, epochs + 1):
            if epoch > epoch_start:
                time_bar, node_degree_global = initialize_state(
                    dataset=train_loader.dataset,
                    model=model,
                    node_embeddings=initial_embeddings,
                    keepS=epoch > 1  # Only keep states after the first epoch if needed
                )
                model.node_degree_global = node_degree_global

            # Setting the global time_bar for the datasets
            train_loader.dataset.time_bar = time_bar
            test_loader.dataset.time_bar = time_bar

            start = time.time()

            for batch_idx, data_batch in enumerate(train_loader):
                # test() leaves the model in eval mode, so re-enable train mode every batch.
                model.train()
                optimizer.zero_grad()

                # Ensure the data is in the correct format
                data_batch[2] = data_batch[2].float()
                data_batch[4] = data_batch[4].double()
                data_batch[5] = data_batch[5].double()
                data_batch[6] = data_batch[6].float()

                output = model(data_batch)
                # Epsilon inside the log so that it actually guards against log(0).
                losses = [-torch.sum(torch.log(output[0] + 1e-10)), weight * torch.sum(output[1])]

                # KL losses (if there are additional items in output to process as losses)
                if len(output) > 3 and output[-1] is not None:
                    losses.extend(output[-1])

                loss = torch.sum(torch.stack(losses)) / batch_size
                loss.backward()
                # Spelled via `torch.nn` so this does not rely on a star import exposing `nn`.
                torch.nn.utils.clip_grad_value_(model.parameters(), 100)
                optimizer.step()

                # Clamping psi to prevent numerical overflow
                model.psi.data = torch.clamp(model.psi.data, 1e-1, 1e+3)

                time_iter = time.time() - start

                # Detach computational graph to prevent unwanted backprop
                model.z = model.z.detach()
                model.S = model.S.detach()

                if (batch_idx + 1) % log_interval == 0 or batch_idx == len(train_loader) - 1:

                    # Save state before testing
                    variables = get_temporal_variables()

                    # Testing and collecting results (full test set only on the last batch)
                    result = test(model, n_test_batches=None if batch_idx == len(train_loader) - 1 else 10, epoch=epoch)
                    test_loss.append(result[2])

                    # Restore state after testing
                    time_bar = set_temporal_variables(variables, model, train_loader, test_loader)

            # Running mean over every evaluation of this candidate so far.
            if test_loss:
                avg_test_loss = sum(test_loss) / len(test_loss)
            print(f"Epoch {epoch}, Gamma_value:{g}, Test Loss: {avg_test_loss}")
            scheduler.step()

        if avg_test_loss < best_loss:
            best_loss = avg_test_loss
            best_gamma = g

    return best_gamma, best_loss

In [None]:
# Grid-search inputs: the event CSV, the training epochs per candidate and the
# candidate gamma values. Note that this rebinds the notebook-level `epochs`.
data_dir = '/simulated_data/final_hawkes_with_features.csv'
epochs = 100
gamma = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]
best_gamma, best_loss = grid_search_bg(data_dir=data_dir, gamma=gamma, epochs=epochs)