In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from iris import irisRequests
import datetime
from numba import njit
from numba.typed import List

In [15]:
# Scratch check: size of the first dimension of a freshly sampled (10, 20) tensor.
torch.randn((10,20)).shape[0]

10

In [2]:
# Bounding box per study region, passed positionally to the catalogue query.
# NOTE(review): values look like (lat_min, lat_max, lon_min, lon_max) — confirm
# against irisRequests.retrieve_events_box.
regions = {
    'greece': (30, 45, 18, 44),
    'california': (30, 41, -125, -113),
    'japan': (20, 50, 120, 150),
    'italy': (35, 46, 6, 19),
}

In [11]:
# Query window shared by every region.
start_time = datetime.datetime(2003, 1, 1)
end_time = datetime.datetime(2024, 1, 1)

# One event table per region, keyed by region name.
catalogs = {}
for name, region in regions.items():
    # The four box values are forwarded positionally, in the order stored in `regions`.
    df = irisRequests.retrieve_events_box(start_time, end_time, *region, minmag=3, magtype="MW")
    # Drop the timezone so the Time column holds naive timestamps.
    df["Time"] = df["Time"].dt.tz_localize(None)
    catalogs[name] = df

  df = pd.read_csv(download_url, sep="|", comment="#")
  df = pd.read_csv(download_url, sep="|", comment="#")


In [150]:
@njit(nogil=True)
def split_times(times, origin_shift, t_end):
    """Split an ascending array of event times into consecutive windows.

    The first window starts at the first event whose time is >= ``origin_shift``.
    Each window keeps absorbing following events while their time difference to
    the window's first event is <= ``t_end``; the next window starts at the first
    event that did not fit.

    Parameters
    ----------
    times : 1-D array of event times (assumed sorted ascending — TODO confirm with caller).
    origin_shift : events strictly before this time are skipped.
    t_end : maximum time span of a window, in the same units as ``times``.

    Returns
    -------
    (begin_indices, end_indices) : two lists of ints of equal length. Window ``j``
        covers ``times[begin_indices[j]:end_indices[j]]`` (end index exclusive).
        Both lists are empty when no event reaches ``origin_shift``.
    """
    n = len(times)
    # Default to n so that no window is emitted when no event reaches origin_shift
    # (previously the search silently fell back to index 0).
    origin_i = n
    for i in range(n):
        if times[i] >= origin_shift:
            origin_i = i
            break
    begin_indices = []
    end_indices = []
    while origin_i < n:
        begin_indices.append(origin_i)
        end_i = origin_i + 1
        # Check the bound BEFORE indexing: when the window starts on the last
        # event, end_i == n and times[end_i] would read past the array (numba
        # does not bounds-check by default).
        while end_i < n and times[end_i] - times[origin_i] <= t_end:
            end_i += 1
        end_indices.append(end_i)

        origin_i = end_i
    return begin_indices, end_indices

In [151]:
# Japan catalogue: the feature columns and the raw event times.
features = catalogs['japan'][['Latitude', 'Longitude', 'Magnitude']]
times = catalogs['japan']['Time'].values
# Offset from the first event, cast to seconds and then to float.
times = (times - times[0]).astype('timedelta64[s]').astype('float')

In [146]:
# Inspect the converted event times (float seconds relative to the first event).
times

array([0.00000000e+00, 5.62800000e+03, 7.97100000e+03, ...,
       6.55432316e+08, 6.55440602e+08, 6.55633038e+08])

In [32]:
class EmbeddingNPP(nn.Module):
    """Linear embedding followed by dropout and an optional LayerNorm.

    Parameters
    ----------
    in_features : size of the last dimension of the input.
    emb_dim : size of the last dimension of the output.
    dropout : dropout probability applied to the linear projection.
    layer_norm : when True, a LayerNorm over ``emb_dim`` is applied last.
    """

    def __init__(self, in_features : int, emb_dim : int, dropout : float, layer_norm : bool):
        super().__init__()
        self.linear = nn.Linear(in_features, emb_dim)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(emb_dim) if layer_norm else None

    def forward(self, x):
        """Embed ``x`` (..., in_features) into (..., emb_dim)."""
        x = self.dropout(self.linear(x))
        if self.norm is not None:
            x = self.norm(x)
        # Always return the tensor; previously the no-LayerNorm path fell off
        # the end of the function and returned None.
        return x

In [38]:
class Weibull:
    """Weibull distribution parameterised by ``b`` and ``k``.

    The density is f(x) = b * k * x^(k-1) * exp(-b * x^k) and the survival
    function is S(x) = exp(-b * x^k).

    Parameters
    ----------
    b, k : strictly positive tensors of the same shape.
    eps : lower clamp applied to ``x`` before taking logs/powers.
    """

    def __init__(self, b, k, eps=1e-8):
        # b and k are strictly positive tensors of the same shape
        # (the first parameter was previously misnamed, so `b` resolved to a
        # global or raised NameError).
        self.b = b
        self.k = k
        self.eps = eps

    def log_prob(self, x):
        """Logarithm of the probability density function log(f(x))."""
        # x must have the same shape as self.b and self.k
        x = x.clamp_min(self.eps)  # pow is unstable for inputs close to 0
        return (self.b.log() + self.k.log() + (self.k - 1) * x.log()
                + self.b.neg() * torch.pow(x, self.k))

    def log_survival(self, x):
        """Logarithm of the survival function log(S(x))."""
        x = x.clamp_min(self.eps)
        return self.b.neg() * torch.pow(x, self.k)

In [106]:
class Exponential:
    """Exponential distribution with rate tensor ``b``."""

    def __init__(self, b):
        self.b = b

    def log_prob(self, x):
        """Log-density log(f(x)) = log(b) - b * x."""
        rate = self.b
        return torch.log(rate) - rate * x

In [107]:
class AttentiveNPP(nn.Module):
    """Point-process encoder: embedding -> multi-head self-attention -> GRU.

    ``forward`` produces a per-event context vector (shifted one step along the
    time axis), and ``get_time_nll`` turns that context into Weibull parameters
    and evaluates the negative log-likelihood of the inter-event times.

    Parameters
    ----------
    in_features : size of the concatenated input (features of ``z`` plus one
        channel for the inter-event time).
    emb_dim : embedding / attention dimension; must be divisible by ``num_heads``.
    num_heads : number of attention heads.
    dropout : dropout probability used in the embedding and the attention.
    batch_first : True for (N, L, ...) inputs, False for (L, N, ...) inputs.
    hidden_dim : GRU hidden size, i.e. the size of the context vectors.
    """

    def __init__(self, in_features : int, emb_dim : int, num_heads : int, dropout : float, batch_first : bool, hidden_dim : int):
        super().__init__()
        assert emb_dim % num_heads == 0
        self.batch_first = batch_first
        self.emb = EmbeddingNPP(in_features, emb_dim, dropout, True)  # apply everywhere the embedding
        self.mha = nn.MultiheadAttention(emb_dim, num_heads,
                                         dropout,
                                         batch_first=batch_first,
                                         kdim=emb_dim,
                                         vdim=emb_dim)

        self.rnn = nn.GRU(emb_dim, hidden_dim,
                          num_layers=1,
                          batch_first=batch_first)
        self.gr_mod = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Softplus())  # GR parameter
        self.weibull_mod = nn.Sequential(
            nn.Linear(hidden_dim, 2), nn.Softplus())  # Waiting time parameters

    def forward(self, z, inter_times):  # z contains all the other info (mag, lat, lon, depth) for example
        """Return the context tensor, shifted by one step along the time axis.

        With ``batch_first=True``: ``inter_times`` is (N, L), ``z`` is (N, L, K)
        and the result is (N, L, hidden_dim). With ``batch_first=False`` the
        first two axes are swapped. The first time step of the result is zeros.
        """
        x = torch.cat([z, inter_times.unsqueeze(-1)], dim=-1)
        x = self.emb(x)
        # NOTE(review): no attn_mask is passed, so every position attends to
        # later events as well; the one-step shift below does not remove that
        # look-ahead. Confirm whether a causal mask is intended.
        attn_output, _ = self.mha(x, x, x)  # no need the weights, atm
        rnn_output, _ = self.rnn(attn_output)
        # shift forward along the time dimension and pad with zeros at t=0;
        # the time axis is dim 1 when batch_first, dim 0 otherwise.
        if self.batch_first:
            context = F.pad(rnn_output[:, :-1, :], (0, 0, 1, 0))
        else:
            context = F.pad(rnn_output[:-1], (0, 0, 0, 0, 1, 0))
        return context

    def get_time_nll(self, inter_times, context, seq_lengths):
        """Negative log-likelihood of the inter-event times, one value per sequence.

        Sums the Weibull log-density over the first ``seq_lengths[n]`` entries of
        each sequence and adds the log-survival of the entry at index
        ``seq_lengths[n]``. Every ``seq_lengths[n]`` must therefore be a valid
        index along the time axis (strictly smaller than its length).

        ``inter_times`` and ``context`` use the layout selected by
        ``batch_first``; ``seq_lengths`` is an integer tensor of shape (N,).
        Returns a tensor of shape (N,).
        """
        if not self.batch_first:
            # The masking/gather logic below is written for (N, L) layout.
            inter_times = inter_times.transpose(0, 1)
            context = context.transpose(0, 1)
        weibull_params = self.weibull_mod(context)
        weibull = Weibull(weibull_params[..., 0], weibull_params[..., 1])
        log_pdf = weibull.log_prob(inter_times)
        arange = torch.arange(inter_times.shape[1], device=seq_lengths.device)
        mask = (arange[None, :] < seq_lengths[:, None]).float()  # (N, L)
        log_like = (log_pdf * mask).sum(-1)  # (N,)
        log_surv = weibull.log_survival(inter_times)  # (N, L)
        end_idx = seq_lengths.long().unsqueeze(-1)  # (N, 1); gather needs int64 indices
        log_surv_last = torch.gather(log_surv, dim=-1, index=end_idx)  # (N, 1)
        log_like += log_surv_last.squeeze(-1)  # (N,)
        return -log_like

In [112]:
# 3 feature channels + 1 inter-event-time channel -> 4 input features.
model = AttentiveNPP(
    in_features=4,
    emb_dim=64,
    num_heads=4,
    dropout=0.2,
    batch_first=True,
    hidden_dim=128,
)

In [109]:
# Smoke test on synthetic data: 10 sequences of 100 exponentially distributed
# inter-event times, paired with 3 random feature channels per event.
inter_times = torch.zeros(10, 100).exponential_()
context = model(z=torch.randn(10, 100, 3), inter_times=inter_times)

In [111]:
# NOTE(review): `get_prob_part` is not defined on the AttentiveNPP class shown in this
# notebook (it only defines `forward` and `get_time_nll`). This call presumably relied on
# an earlier class definition still alive in the kernel and would raise AttributeError
# after Restart & Run All — confirm the intended method and update this cell.
model.get_prob_part(inter_times,context)

tensor([[-1.2121, -0.5218, -0.6996, -1.1791, -2.0170, -0.8205, -0.5913, -0.7628,
         -1.0624, -0.4477, -1.1397, -1.1023, -1.3869, -0.8901, -1.4265, -0.9131,
         -0.5053, -0.7892, -0.4738, -0.5507, -0.5950, -0.4919, -1.7662, -0.6275,
         -2.3188, -0.6468, -1.4172, -0.5552, -0.8088, -0.6932, -1.6195, -1.0932,
         -0.6750, -0.7479, -0.4915, -0.9567, -0.6205, -0.6935, -0.5089, -1.4511,
         -0.8642, -0.5096, -1.0778, -1.2879, -0.7289, -0.5897, -1.3612, -0.6941,
         -1.4936, -1.3143, -1.0286, -0.7985, -1.7222, -3.8011, -0.5063, -0.6236,
         -1.3824, -0.9613, -0.5173, -1.2240, -0.6452, -1.7097, -0.6350, -0.8988,
         -0.5568, -0.6035, -0.6933, -1.0979, -1.9476, -0.7162, -1.4590, -0.8381,
         -0.6385, -1.0851, -0.8592, -1.4093, -1.3474, -2.6896, -1.8480, -0.4968,
         -0.5954, -0.7744, -0.5878, -0.7064, -1.6996, -0.7946, -2.7677, -1.1159,
         -1.2443, -0.5470, -1.4339, -0.7899, -0.9116, -0.8763, -0.7571, -1.0739,
         -0.6575, -1.9952, -