In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sequences import *

In [None]:
def get_nll_loss(log_prob, log_surv, seq_lengths):
    """Negative log-likelihood of right-padded event sequences with a censored tail.

    Args:
        log_prob: ``(batch, max_len)`` tensor of per-event log-densities. Entries
            at positions ``>= seq_lengths[b]`` are treated as padding and ignored.
        log_surv: log-survival term of the censored interval after the last event.
            Either a ``(batch, n_positions)`` tensor, from which the entry at index
            ``seq_lengths[b]`` is selected for every sequence, or an
            already-selected ``(batch,)`` tensor.
        seq_lengths: ``(batch,)`` integer (long) tensor with the number of real
            events in each sequence.

    Returns:
        ``(batch,)`` tensor holding ``-sum_i log_prob[b, i] - log_surv`` per sequence.
    """
    # Build the position index on the same device as the data so the function
    # also works when the tensors do not live on the CPU.
    seq_lengths = seq_lengths.to(log_prob.device)
    positions = torch.arange(log_prob.size(1), device=log_prob.device)
    mask = positions[None, :] < seq_lengths[:, None]

    # Zero-fill the padding instead of multiplying by the mask: 0 * inf is NaN,
    # so a non-finite value in a padded slot would otherwise poison the sum.
    nll_loss = -log_prob.masked_fill(~mask, 0.0).sum(-1)

    if log_surv.dim() == 1:
        # The survival term was already selected per sequence.
        log_surv_last = log_surv
    else:
        log_surv_last = torch.gather(
            log_surv, dim=-1, index=seq_lengths.unsqueeze(-1)
        ).squeeze(-1)
    return nll_loss - log_surv_last

### What happens
$$ z_1 = (\tau_1, x_1), \dots, z_n = (\tau_n, x_n)$$
$$ h_i = \textrm{emb}(z_i) $$
$$ h_0 = 0 $$
$$ \tau_c = t_{\rm start} - t_n$$
$$ \theta_i = \textrm{net}(h_0, \dots, h_{i-1}) $$
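$$ L = -\sum_{i=1}^n \ln f(\tau_i|\theta_i) - \ln S(\tau_c|\theta_{n+1}) $$

Here $f(\cdot|\theta)$ is the waiting-time density, $S(\cdot|\theta)$ is the corresponding survival function, and $\tau_c$ is the censored waiting time after the last observed event.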

In [211]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import TransformerEncoderLayer, TransformerDecoderLayer

class WaitingTimeModel(nn.Module):
    """Base class for parametric waiting-time (inter-event time) heads.

    Stores the size of the input embedding and the number of distribution
    parameters. Subclasses implement ``forward``.
    """

    def __init__(self, input_dim, num_params):
        super().__init__()
        self.input_dim = input_dim
        self.num_params = num_params

    def forward(self, x, tau, tau_cens=None, sequence_lengths=None):
        """Abstract: return ``(log_prob, log_surv)`` for the given embeddings.

        The optional arguments mirror the signature used by
        ``WeibullWaitingTimeModel.forward``.

        Raises:
            NotImplementedError: always; the base class defines no distribution.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )


class WeibullWaitingTimeModel(WaitingTimeModel):
    """Waiting-time head with a Weibull distribution per sequence position.

    A linear layer maps each embedding to two raw values that are passed through
    softplus to obtain a positive scale (lambda) and shape (k).
    """

    def __init__(self, input_dim):
        super().__init__(input_dim, num_params=2)
        self.linear = nn.Linear(input_dim, 2)

    def forward(self, x, tau, tau_cens, sequence_lengths):
        """Evaluate event log-densities and the censored log-survival term.

        Args:
            x: embeddings, expected as ``(batch, L + 1, input_dim)``; position
                ``i`` parameterises the waiting time of event ``i`` (0-based)
                and position ``sequence_lengths[b]`` the censored interval.
            tau: ``(batch, L)`` observed inter-event times.
            tau_cens: ``(batch,)`` censored waiting time after the last event.
            sequence_lengths: ``(batch,)`` long tensor with the number of real
                events per sequence.

        Returns:
            Tuple ``(log_prob, log_surv)``: ``log_prob`` has the shape of ``tau``
            and holds ``log f(tau)`` for every position (padding included, the
            caller is expected to mask it); ``log_surv`` is ``(batch,)`` and
            holds ``log S(tau_cens)``.
        """
        params = self.linear(x)
        scale = F.softplus(params[..., 0])
        shape = F.softplus(params[..., 1])

        # Parameters of the position right after the last real event.
        end_idx = sequence_lengths.unsqueeze(-1)
        scales_surv = torch.gather(scale, dim=1, index=end_idx)[:, 0]
        shapes_surv = torch.gather(shape, dim=1, index=end_idx)[:, 0]

        # The last position has no matching entry in tau, so it is dropped.
        scales_events = scale[:, :-1]
        shapes_events = shape[:, :-1]

        # Clamp so that log(0) cannot occur for zero (e.g. padded) waiting times.
        tau_clamped = torch.clamp(tau, min=1e-6)
        tau_cens_clamped = torch.clamp(tau_cens, min=1e-6)

        # Weibull: f(t) = (k / lam) * (t / lam)^(k - 1) * exp(-(t / lam)^k)
        #   => log f(t) = log k - k * log lam + (k - 1) * log t - (t / lam)^k
        # The scale term carries the factor k; using a plain "- log lam" would
        # drop -(k - 1) * log lam and leave the density unnormalised.
        log_prob = (
            torch.log(shapes_events)
            - shapes_events * torch.log(scales_events)
            + (shapes_events - 1) * torch.log(tau_clamped)
            - (tau_clamped / scales_events) ** shapes_events
        )
        # Weibull survival: S(t) = exp(-(t / lam)^k)
        log_surv = -((tau_cens_clamped / scales_surv) ** shapes_surv)
        return log_prob, log_surv
    

class SequenceEncoder(torch.nn.Module):
    """Causal Transformer encoder over right-padded event sequences.

    Every event feature vector is projected to ``emb_dim``, an all-zero start
    token is prepended, and a causally masked ``TransformerEncoder`` followed by
    a linear layer produces one output vector per token (start token included).
    """

    def __init__(self, starting_token_feature_dim: int, feature_dim: int, emb_dim: int,
                 nhead: int, dim_feedforward: int, num_layers: int):
        super().__init__()
        # Projection for a learned, conditioned start token. forward() currently
        # uses an all-zero start token and reads only this layer's output width.
        self.s_token_linear = torch.nn.Linear(starting_token_feature_dim, emb_dim)
        self.in_linear = torch.nn.Linear(feature_dim, emb_dim)
        self.out_layer = torch.nn.Linear(emb_dim, emb_dim)
        encoder_layer = TransformerEncoderLayer(
            d_model=emb_dim, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True
        )
        self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, start_conditioner: torch.Tensor, x: torch.Tensor, sequence_lengths: torch.Tensor):
        """Encode a batch of padded sequences.

        Args:
            start_conditioner: accepted for interface compatibility but not used
                at present; the conditioned start token
                ``self.s_token_linear(start_conditioner).unsqueeze(1)`` is disabled.
            x: ``(batch, L, feature_dim)`` right-padded event features.
            sequence_lengths: ``(batch,)`` integer tensor with the number of real
                events per sequence.

        Returns:
            ``(batch, L + 1, emb_dim)`` tensor; index 0 belongs to the start token.
        """
        starting_token = torch.zeros((x.size(0), 1, self.s_token_linear.out_features), device=x.device)
        tokens = torch.cat([starting_token, self.in_linear(x)], dim=1)
        num_tokens = tokens.size(1)

        causal_mask = torch.nn.Transformer.generate_square_subsequent_mask(num_tokens, device=tokens.device)

        # Tokens at index >= length + 1 (the +1 accounts for the start token) are padding.
        positions = torch.arange(num_tokens, device=tokens.device).unsqueeze(0)
        num_valid = (sequence_lengths.to(tokens.device) + 1).unsqueeze(1)
        is_padding = positions >= num_valid
        # A float key-padding mask is *added* to the attention logits, so padded
        # keys must carry -inf (a 0/1 float mask would boost them instead). A
        # float mask is used so that its type matches the float causal mask.
        padding_mask = torch.zeros(
            is_padding.shape, dtype=causal_mask.dtype, device=tokens.device
        ).masked_fill(is_padding, float("-inf"))

        encoded = self.transformer_encoder(
            tokens, is_causal=True, mask=causal_mask, src_key_padding_mask=padding_mask
        )
        return self.out_layer(encoded)

In [212]:
def gam_filter(values, min_val, max_val):
    """Return the matrix of sines of pairwise angular differences.

    ``values`` is mapped linearly so that ``min_val -> -1`` and ``max_val -> 1``,
    each rescaled value is converted to an angle with ``arccos``, and entry
    ``[i, j]`` of the result is ``sin(angle_i - angle_j)``.

    Values outside ``[min_val, max_val]`` fall outside the domain of ``arccos``
    and therefore produce NaN.
    """
    value_range = max_val - min_val
    unit_scaled = 2*(values - min_val) / value_range - 1.0
    angles = np.arccos(unit_scaled)
    row_angles = angles[:, np.newaxis]
    col_angles = angles[np.newaxis, :]
    return np.sin(row_angles - col_angles)

In [213]:
# Read the space-separated event catalogue; "Origin_Time(UT)" is parsed as datetimes.
catalog = pd.read_csv("japanese-cat.csv", sep=" ", parse_dates=["Origin_Time(UT)"])

# Event times are expressed in days elapsed since this reference instant.
time_origin = pd.to_datetime("1997-01-01T00:00:00").to_numpy()
times_days = (catalog["Origin_Time(UT)"].to_numpy() - time_origin) / np.timedelta64(1, 'D')
# NOTE(review): float32 keeps only ~7 significant digits; for day counts in the
# thousands this limits the time resolution -- confirm that is acceptable.
times = times_days.astype(np.float32)

# Per-event covariates, stacked column-wise into an (n_events, 3) array.
magnitudes, latitudes, longitudes = (
    catalog[column].to_numpy()
    for column in ("JMA_Magnitude(Mj)", "Latitude(deg)", "Longitude(deg)")
)
features = np.vstack([magnitudes, latitudes, longitudes]).T

In [214]:
# Seed torch before the modules are built so that their randomly initialised
# weights are reproducible across kernel restarts.
torch.manual_seed(42)

# feature_dim=4: the inter-event time plus the three catalogue features
# (magnitude, latitude, longitude) that are stacked together before packing.
encoder = SequenceEncoder(starting_token_feature_dim=1, feature_dim=4, emb_dim=16, nhead=2, dim_feedforward=64, num_layers=1)
# input_dim must equal the encoder's emb_dim.
wtmodel = WeibullWaitingTimeModel(input_dim=16)
sequence = Sequence(arrival_times=times, features=features)

# Seed NumPy's global RNG last, right before the sampling cell runs.
np.random.seed(42)

In [217]:
# Sample sub-sequences from the catalogue together with their inter-event times.
subseq_inter_times, subseq_cens_inter_times, subseq_features, t_start, t_end = sequence.sample_sequences(
    max_num_sequences=100,
    duration_scale=30.0,
    return_inter_times=True)
# Put each event's inter-event time in feature column 0, ahead of the catalogue features.
subseq_features = [ np.vstack([taus, feats.T]).T for taus, feats in zip(subseq_inter_times, subseq_features) ]
packed_features, sequence_lengths = Sequence.pack_sequences(
    subseq_features)

# Convert everything the models consume to tensors.
packed_features = torch.tensor(packed_features, dtype=torch.float32)
sequence_lengths = torch.tensor(sequence_lengths, dtype=torch.long)
t_start = torch.tensor(t_start, dtype=torch.float32)
subseq_cens_inter_times = torch.tensor(subseq_cens_inter_times, dtype=torch.float32)

output = encoder(
    t_start[..., None],
    packed_features,
    sequence_lengths)

# Column 0 of the packed features holds the inter-event times (see above).
log_prob, log_surv = wtmodel(output, packed_features[...,0],
                              subseq_cens_inter_times, sequence_lengths)

# True for real events, False for padding. Both operands are created on / moved
# to log_prob's device *before* the comparison, so this also works when the
# tensors do not live on the CPU.
log_prob_mask = (
    torch.arange(log_prob.size(1), device=log_prob.device)[None, :]
    < sequence_lengths.to(log_prob.device)[:, None]
)

# Per-sequence negative log-likelihood: observed events plus the censored tail.
loss = - (log_prob * log_prob_mask).sum(-1) - log_surv
loss = loss.mean()