In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sequences import *

In [3]:
def get_nll_loss(log_prob, log_surv, seq_lengths):
    """Per-sequence negative log-likelihood for padded waiting-time batches.

    For each row ``b`` the loss is ``-sum(log_prob[b, :L]) - log_surv[b, L]``
    with ``L = seq_lengths[b]``. The first ``L`` positions contribute their
    log-density, and position ``L`` contributes a log-survival term.

    Args:
        log_prob: (batch, max_len) log-densities. Entries at index >= L are
            ignored, even if they are -inf or NaN.
        log_surv: (batch, max_len) log-survival values. Only the entry at index
            ``L`` of each row is used.
        seq_lengths: (batch,) integer lengths (tensor or array-like, any
            integer dtype).

    Returns:
        (batch,) tensor of NLL values.

    NOTE(review): requires ``seq_lengths[b] < max_len`` (``gather`` raises
    otherwise). That is, the padded tensors presumably hold one extra slot per
    sequence for the survival interval. Confirm against ``pack_sequences``.
    """
    lengths = torch.as_tensor(seq_lengths, dtype=torch.long, device=log_prob.device)
    positions = torch.arange(log_prob.size(1), device=log_prob.device)
    mask = positions[None, :] < lengths[:, None]
    # Select with torch.where instead of multiplying by the mask, because
    # -inf * 0 and nan * 0 are both NaN and would poison the sum.
    # This only fixes the forward value: gradients can still be NaN if
    # log_prob itself was non-finite at masked slots.
    nll_loss = -torch.where(mask, log_prob, torch.zeros_like(log_prob)).sum(-1)
    index = lengths.unsqueeze(-1).to(log_surv.device)
    log_surv_last = torch.gather(log_surv, dim=-1, index=index).squeeze(-1)
    return nll_loss - log_surv_last

In [4]:
def gam_filter(values, min_val, max_val):
    """Gramian Angular Difference Field of a 1-D series.

    Each value is rescaled linearly from ``[min_val, max_val]`` to ``[-1, 1]``
    and mapped to an angle ``phi = arccos(x)``. The result is the matrix
    ``G[i, j] = sin(phi_i - phi_j)``.

    Args:
        values: 1-D array-like of numbers.
        min_val, max_val: bounds of the linear rescaling; they must differ.

    Returns:
        (n, n) numpy array, where n = len(values).
    """
    values = np.asarray(values, dtype=float)
    rescaled_values = 2 * (values - min_val) / (max_val - min_val) - 1.0
    # Inputs outside [min_val, max_val] (or rounding error at the edges) would
    # fall outside arccos's domain and silently become NaN. Saturate them.
    rescaled_values = np.clip(rescaled_values, -1.0, 1.0)
    phi_values = np.arccos(rescaled_values)
    return np.sin(phi_values[:, None] - phi_values[None, :])

In [5]:
# Reference epoch: event times are expressed as fractional days since this instant.
time_origin = pd.to_datetime("1997-01-01T00:00:00").to_numpy()
catalog = pd.read_csv("japanese-cat.csv", sep=" ", parse_dates=["Origin_Time(UT)"])

# Raw datetime64 timestamps get their own name. `times` below holds floats,
# and reusing one name for both makes the cell non-idempotent.
origin_times = catalog["Origin_Time(UT)"].to_numpy()
times_days = (origin_times - time_origin) / np.timedelta64(1, "D")
# Keep float64. At ~1e4 days after the origin, float32 spacing is 2**-10 day
# (~84 s). Events less than about a minute apart would collapse onto the same
# time and produce zero waiting times downstream.
times = times_days.astype(np.float64)

magnitudes = catalog["JMA_Magnitude(Mj)"].to_numpy()
latitudes = catalog["Latitude(deg)"].to_numpy()
longitudes = catalog["Longitude(deg)"].to_numpy()
# One row per event: (magnitude, latitude, longitude).
features = np.column_stack([magnitudes, latitudes, longitudes])

In [6]:
def _append_survival_row(waiting_times, event_features):
    """Return an (n+1, 1+F) array with waiting times in column 0 and event features after it.

    The (n, F) event features get one extra all-zero row, so they line up with
    the final waiting time. That final row is the interval with no event.
    """
    padded_features = np.vstack([event_features, np.zeros(event_features.shape[1])])
    return np.column_stack((waiting_times, padded_features))


subsequence_arrival_times, subsequence_features, t_start, t_end = extract_subsequences(
    times,
    features,
    max_num_sequences=100,
    duration_scale=30,
)

subsequence_waiting_times = [
    arrival_to_inter_times(arrivals, start, end)
    for arrivals, start, end in zip(subsequence_arrival_times, t_start, t_end)
]
subsequence_list = [
    _append_survival_row(waits, event_feats)
    for waits, event_feats in zip(subsequence_waiting_times, subsequence_features)
]

# valid_mask_is_true=False: the returned mask is presumably True at padded
# positions, the convention expected by src_key_padding_mask. Confirm in `sequences`.
packed_features, seq_lengths, masks = pack_sequences(
    subsequence_list, batch_first=True, valid_mask_is_true=False)

In [None]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerDecoderLayer

In [None]:
class WaitingTimeModel(nn.Module):
    """Abstract base for heads that score waiting times given a context vector.

    Subclasses must override ``forward(x, tau)``. This base class only records
    the configuration.
    """

    def __init__(self, input_dim, num_params):
        super().__init__()
        self.input_dim = input_dim  # feature size of the context vectors
        self.num_params = num_params  # number of distribution parameters the head predicts

    def forward(self, x, tau):
        # Fail loudly instead of silently returning None from an unimplemented head.
        raise NotImplementedError(f"{type(self).__name__} must implement forward(x, tau)")
        
class WeibullWaitingTimeModel(WaitingTimeModel):
    """Weibull waiting-time head: a linear map from context to (scale, shape).

    The linear layer emits two raw values per context vector. Softplus maps
    them to a positive scale ``lam`` and shape ``k``. ``forward`` returns

        log f(tau) = log k - log lam + (k - 1) * log(tau / lam) - (tau / lam) ** k
        log S(tau) = -(tau / lam) ** k

    i.e. the Weibull log-density and log-survival at ``tau``.
    """

    def __init__(self, input_dim, eps=1e-8):
        super().__init__(input_dim, num_params=2)
        self.linear = nn.Linear(input_dim, 2)
        # Lower bound applied to scale, shape and tau so every log() argument
        # stays strictly positive.
        self.eps = eps

    def forward(self, x, tau):
        """Return ``(log_prob, log_surv)``, each shaped like ``x[..., 0]`` broadcast with ``tau``."""
        raw_scale, raw_shape = self.linear(x).unbind(-1)
        # softplus computes log(1 + exp(.)) without overflowing for large inputs
        # or rounding to exactly 0 for very negative ones, which made scale == 0.
        scale = nn.functional.softplus(raw_scale).clamp_min(self.eps)
        shape = nn.functional.softplus(raw_shape).clamp_min(self.eps)
        # A zero waiting time (padding, presumably, or coincident timestamps)
        # would give log(0) = -inf here, and then NaN losses/gradients.
        scaled_taus = tau.clamp_min(self.eps) / scale
        cum_hazard = scaled_taus ** shape
        log_prob = torch.log(shape) - torch.log(scale) + (shape - 1) * torch.log(scaled_taus) - cum_hazard
        log_surv = -cum_hazard
        return log_prob, log_surv

In [None]:
class SequenceEncoder(torch.nn.Module):
    """Causal Transformer encoder over a batch of padded, batch-first sequences.

    Each step's ``feature_dim`` inputs are linearly embedded to ``emb_dim``.
    The embeddings pass through ``num_layers`` encoder layers with a causal
    attention mask, so output step ``i`` depends only on input steps
    ``0..i`` (inclusive). No positional encoding is added.
    """

    def __init__(self, feature_dim: int, emb_dim: int, nhead: int, dim_feedforward: int, num_layers: int):
        super().__init__()
        self.emb_layer = torch.nn.Linear(feature_dim, emb_dim)
        encoder_layer = TransformerEncoderLayer(d_model=emb_dim, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True)
        self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, src: torch.Tensor, src_key_padding_mask: torch.Tensor = None):
        """Encode ``src``.

        Args:
            src: (batch, seq_len, feature_dim) inputs.
            src_key_padding_mask: optional (batch, seq_len) mask that is True
                (or nonzero) at padded positions, which are ignored as keys.

        Returns:
            (batch, seq_len, emb_dim) encodings.
        """
        embedded = self.emb_layer(src)
        seq_len = embedded.shape[1]
        # Build the causal mask as bool (True = may not attend). Mixing a float
        # attention mask with a bool padding mask is deprecated in PyTorch.
        causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=embedded.device).triu(diagonal=1)
        if src_key_padding_mask is not None:
            # A float/int padding mask would be *added* to the attention scores
            # rather than excluding keys. Normalize it to bool (nonzero -> ignored).
            src_key_padding_mask = src_key_padding_mask.to(device=embedded.device, dtype=torch.bool)
        return self.transformer_encoder(
            embedded, mask=causal_mask, is_causal=True, src_key_padding_mask=src_key_padding_mask)

In [None]:
# Seed torch so the randomly initialized weights, and every output and loss
# computed from them below, are reproducible under Restart & Run All.
torch.manual_seed(0)
# feature_dim=4: waiting time + (magnitude, latitude, longitude) per step.
encoder = SequenceEncoder(feature_dim=4, emb_dim=16, nhead=2, dim_feedforward=64, num_layers=1)

In [None]:
packed_tensor = torch.tensor(packed_features, dtype=torch.float32)
padding_mask = torch.tensor(masks)
# Causal encodings of every step: (batch, seq_len, emb_dim).
output = encoder(packed_tensor, src_key_padding_mask=padding_mask)

In [None]:
# Size the head from the encoder's embedding width so the two cannot drift apart.
wtmodel = WeibullWaitingTimeModel(input_dim=encoder.emb_layer.out_features)

In [None]:
# Column 0 of each step is that step's waiting time.
waiting_times = torch.tensor(packed_features, dtype=torch.float32)[..., 0]
# The causal mask lets step i attend to itself, and step i's input contains
# tau_i. Scoring tau_i with output[:, i] would therefore leak the target.
# Shift the encodings one step right so tau_i is conditioned only on steps < i.
# The first step gets an all-zero context.
history = torch.cat([torch.zeros_like(output[:, :1]), output[:, :-1]], dim=1)
log_prob, log_surv = wtmodel(history, waiting_times)

In [None]:
seq_lengths_tensor = torch.from_numpy(seq_lengths)
get_nll_loss(log_prob, log_surv, seq_lengths_tensor)