In [None]:
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sequences import *

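### Model and training objective
A training window containing $n$ events is represented as
$$ z_1 = (\tau_1, x_1), \dots, z_n = (\tau_n, x_n)$$
where $\tau_i$ is the waiting time before event $i$ and $x_i$ are its marks (magnitude and standardized latitude/longitude). Each event is embedded, and an all-zero start token is prepended:
$$ h_i = \textrm{emb}(z_i) $$
$$ h_0 = 0 $$
The gap between the last event $t_n$ and the end of the window $t_{\rm end}$ is right-censored (no event was observed in it):
$$ \tau_c = t_{\rm end} - t_n$$
A causal network maps the history to the parameters of the next waiting-time distribution:
$$ \theta_i = \textrm{net}(h_0, \dots, h_{i-1}) $$
With $f$ the (Weibull) waiting-time density and $S$ its survival function, the waiting-time loss is
$$ L = -\sum_{i=1}^n \ln f(\tau_i|\theta_i) - \ln S(\tau_c|\theta_{n+1}) $$
The training loop additionally adds the negative log-likelihoods of the event positions (Gaussian) and magnitudes (log-normal), both predicted from the same hidden states.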

In [193]:
from networkx import sigma
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import TransformerEncoderLayer, TransformerDecoderLayer
import torch.optim as optim
class WaitingTimeModel(nn.Module):
    """Abstract base for heads mapping encoder states to a waiting-time distribution.

    Args:
        input_dim: size of the encoder state fed to the head.
        num_params: number of distribution parameters the head predicts.

    Subclasses must override ``forward``.
    """

    def __init__(self, input_dim, num_params):
        super(WaitingTimeModel, self).__init__()
        self.input_dim = input_dim
        self.num_params = num_params

    def forward(self, x, tau, *args, **kwargs):
        # Fail loudly instead of silently returning None: returning None here would
        # also hide nn.Module's own "missing forward" error for subclasses.
        raise NotImplementedError(f"{type(self).__name__} must implement forward()")


class WeibullWaitingTimeModel(WaitingTimeModel):
    """Weibull waiting-time head with a survival term for the censored final gap.

    A linear layer maps each encoder state to (scale, shape) of a Weibull
    distribution, both made positive with softplus. The density is

        f(t) = (k / lam) * (t / lam)**(k - 1) * exp(-(t / lam)**k)

    and the survival function is ``S(t) = exp(-(t / lam)**k)``.
    """

    def __init__(self, input_dim):
        super(WeibullWaitingTimeModel, self).__init__(input_dim, num_params=2)
        self.linear = nn.Linear(input_dim, 2)

    def forward(self, x, tau, tau_cens, sequence_lengths):
        """Compute per-event Weibull log-densities and the censored-gap log-survival.

        Args:
            x: encoder output of rank 3 (batch, L + 1, input_dim); position i holds
                the parameters for waiting time i, and the last valid position is
                used for the censored gap.
            tau: waiting times, broadcastable to (batch, L); presumably padded
                entries are masked by the caller — TODO confirm.
            tau_cens: censored gap after the last event, shape (batch,).
            sequence_lengths: long tensor (batch,) of real event counts; selects the
                position whose parameters describe the censored gap.

        Returns:
            (log_prob (batch, L), log_surv (batch,), raw_scale (batch, L + 1),
            raw_shape (batch, L + 1)).
        """
        params = self.linear(x)
        raw_scale = F.softplus(params[..., 0])
        raw_shape = F.softplus(params[..., 1])

        # Parameters for the censored gap come from the position right after the
        # last real event, i.e. index sequence_lengths[b].
        end_idx = sequence_lengths.unsqueeze(-1)
        scales_surv = torch.gather(raw_scale, dim=1, index=end_idx)[:, 0]
        shapes_surv = torch.gather(raw_shape, dim=1, index=end_idx)[:, 0]

        # Position i parameterises tau[:, i]; the final position is only used
        # (via the gather above) for the censored gap.
        scales_events = raw_scale[:, :-1]
        shapes_events = raw_shape[:, :-1]

        # Clamp to keep log() finite for zero / padded waiting times.
        tau_clamped = torch.clamp(tau, min=1e-6)
        tau_cens_clamped = torch.clamp(tau_cens, min=1e-6)

        # log f = log k - log lam + (k - 1) * (log t - log lam) - (t / lam)**k.
        # The (k - 1) * log(lam) term is required for f to integrate to 1.
        log_scale = torch.log(scales_events)
        log_prob = (
            torch.log(shapes_events)
            - log_scale
            + (shapes_events - 1) * (torch.log(tau_clamped) - log_scale)
            - (tau_clamped / scales_events) ** shapes_events
        )
        log_surv = -(tau_cens_clamped / scales_surv) ** shapes_surv
        return log_prob, log_surv, raw_scale, raw_shape

class NormalModel(nn.Module):
    """Diagonal Gaussian head predicting a mean and variance per output dimension.

    Args:
        input_dim: size of the encoder state.
        output_dim: number of independent Gaussian outputs.
    """

    def __init__(self, input_dim, output_dim):
        super(NormalModel, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.mean = nn.Linear(input_dim, output_dim)
        # Softplus keeps the predicted variances positive.
        self.vars = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.Softplus()
        )

    def forward(self, x, target):
        """Gaussian log-density of ``target`` under the predicted parameters.

        Args:
            x: encoder output (..., L + 1, input_dim); the last position along the
                second-to-last axis is dropped for the log-density.
            target: observed values (..., L, output_dim).

        Returns:
            (log_prob (..., L, output_dim), means (..., L + 1, output_dim),
            variances (..., L + 1, output_dim)).
        """
        means = self.mean(x)
        variances = self.vars(x)
        event_means = means[..., :-1, :]
        event_vars = variances[..., :-1, :]
        # Full normalised log N(target; mean, var), including the -0.5*log(2*pi) constant.
        log_prob = (
            -0.5 * math.log(2.0 * math.pi)
            - 0.5 * torch.log(event_vars)
            - (target - event_means) ** 2 / (2 * event_vars)
        )
        return log_prob, means, variances

class LogNormalModel(nn.Module):
    """Log-normal head: predicts mean and variance of log(target) per output dimension.

    Args:
        input_dim: size of the encoder state.
        output_dim: number of independent log-normal outputs.
    """

    def __init__(self, input_dim, output_dim):
        super(LogNormalModel, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.lognormal_mean = nn.Linear(input_dim, output_dim)
        # Softplus keeps the predicted variances of log(target) positive.
        self.lognormal_vars = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.Softplus()
        )

    def forward(self, x, target):
        """Log-normal log-density of ``target`` under the predicted parameters.

        Args:
            x: encoder output (..., L + 1, input_dim); the last position along the
                second-to-last axis is dropped for the log-density.
            target: observed values (..., L, output_dim); values below 1e-6
                (including non-positive ones) are clamped to 1e-6 before the log.

        Returns:
            (log_prob (..., L, output_dim), means (..., L + 1, output_dim),
            variances (..., L + 1, output_dim)) where means/variances are those of
            log(target).
        """
        means = self.lognormal_mean(x)
        variances = self.lognormal_vars(x)
        event_means = means[..., :-1, :]
        event_vars = variances[..., :-1, :]
        target_clamped = torch.clamp(target, min=1e-6)
        log_target = torch.log(target_clamped)
        # Full normalised log-normal density: Gaussian density of log(target)
        # (including -0.5*log(2*pi)) plus the Jacobian term -log(target).
        log_prob = (
            -0.5 * math.log(2.0 * math.pi)
            - 0.5 * torch.log(event_vars)
            - (log_target - event_means) ** 2 / (2 * event_vars)
            - log_target
        )
        return log_prob, means, variances

class SequenceEncoder(torch.nn.Module):
    def __init__(self, starting_token_feature_dim : int, feature_dim: int, emb_dim: int, nhead: int, dim_feedforward: int, num_layers: int):
        super(SequenceEncoder, self).__init__()
        self.s_token_linear = torch.nn.Linear(starting_token_feature_dim, emb_dim)
        self.in_linear = torch.nn.Linear(feature_dim, emb_dim)
        self.out_layer = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.LayerNorm(emb_dim),
        )
        encoder_layer = TransformerEncoderLayer(d_model=emb_dim, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True)
        self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
    def forward(self, start_conditioner : torch.Tensor, x: torch.Tensor, sequence_lengths : torch.Tensor):
        #starting_token = self.s_token_linear(start_conditioner).unsqueeze(1)
        starting_token = torch.zeros((x.size(0), 1, self.in_linear.out_features), device=x.device)
        output = self.in_linear(x)
        output = torch.cat([starting_token, output], dim=1)
        causal_mask = torch.nn.Transformer.generate_square_subsequent_mask(output.size(1), device=output.device, dtype=torch.bool)
        padding_mask = (torch.arange(output.size(1), device=output.device).unsqueeze(0) >= (sequence_lengths+1).unsqueeze(1))
        output = self.transformer_encoder(output, is_causal=True, mask=causal_mask, src_key_padding_mask=padding_mask)
        output = self.out_layer(output)
        return output

In [194]:
def gam_filter(values, min_val, max_val):
    """Return the pairwise matrix sin(phi_i - phi_j) of angle-encoded values.

    Each value is mapped linearly from [min_val, max_val] onto [-1, 1] and then
    to an angle phi = arccos(.) in [0, pi]. For 1-D ``values`` of length n the
    result is an (n, n) array with entry [i, j] = sin(phi_i - phi_j), which is
    antisymmetric with a zero diagonal.

    Values outside [min_val, max_val] yield NaN angles (arccos domain), and
    min_val == max_val divides by zero; neither case is guarded.
    """
    span = max_val - min_val
    unit_interval = 2 * (values - min_val) / span - 1.0
    angles = np.arccos(unit_interval)
    pairwise_diff = angles[:, np.newaxis] - angles[np.newaxis, :]
    return np.sin(pairwise_diff)

In [195]:
# Load the JMA catalog and split it into a training part (before `train_stop`)
# and a test part (from the first event at/after `train_stop`).
time_origin = pd.to_datetime("1997-01-01T00:00:00").to_numpy()
train_stop = pd.to_datetime("2020-01-01T00:00:00").to_numpy()
catalog = pd.read_csv("japanese-cat.csv", sep=" ", parse_dates=["Origin_Time(UT)"])

# Event times as float days since `time_origin`, deliberately kept in float64.
# In float32, values above 2**13 = 8192 days (mid-2019) are spaced 2**-10 day
# (~84 s) apart, which collapses the waiting times between closely spaced events.
# Conversion to float32 happens later, when tensors are built.
times_days = (catalog["Origin_Time(UT)"].to_numpy() - time_origin) / np.timedelta64(1, 'D')
times = times_days
train_stop_days = (train_stop - time_origin) / np.timedelta64(1, 'D')

# Index of the first event at or after `train_stop`, i.e. the first test event.
test_event_indices = np.flatnonzero(times >= train_stop_days)
if test_event_indices.size == 0:
    raise ValueError("The catalog has no events at or after train_stop; cannot build a test split.")
index_stop = test_event_indices[0]

magnitudes = catalog["JMA_Magnitude(Mj)"].to_numpy()
latitudes = catalog["Latitude(deg)"].to_numpy()
longitudes = catalog["Longitude(deg)"].to_numpy()
# Columns: magnitude, latitude, longitude.
features = np.vstack([magnitudes, latitudes, longitudes]).T
# Standardise latitude/longitude using training-period statistics only (no
# leakage from the test period); magnitude stays on its original scale.
coord_mean = features[:index_stop, 1:].mean(axis=0)
coord_std = features[:index_stop, 1:].std(axis=0)
features[..., 1:] = (features[..., 1:] - coord_mean) / coord_std

train_sequence = Sequence(arrival_times=times[:index_stop], features=features[:index_stop])
test_sequence = Sequence(arrival_times=times[index_stop:], features=features[index_stop:])

In [None]:
# Seed every RNG *before* building the models, so the initial weights (not only
# the sampled training windows) are reproducible.
torch.manual_seed(42)
np.random.seed(42)

EMB_DIM = 16
# Per-event input features: waiting time + magnitude + latitude + longitude.
encoder = SequenceEncoder(starting_token_feature_dim=1, feature_dim=4, emb_dim=EMB_DIM, nhead=2, dim_feedforward=64, num_layers=1)
wtmodel = WeibullWaitingTimeModel(input_dim=EMB_DIM)
posmodel = NormalModel(input_dim=EMB_DIM, output_dim=2)     # latitude, longitude
magmodel = LogNormalModel(input_dim=EMB_DIM, output_dim=1)  # magnitude
optimizer = optim.Adam(
    [*encoder.parameters(), *wtmodel.parameters(), *posmodel.parameters(), *magmodel.parameters()],
    lr=1e-4,
)

In [None]:
# Training: each step samples random sub-windows of the training catalog and
# minimises the mean (over windows) negative joint log-likelihood of waiting
# times (including the censored final gap), positions and magnitudes.
for module in (encoder, wtmodel, posmodel, magmodel):
    # Explicitly enable training mode (dropout): a later cell may have switched
    # these modules to eval() before this cell is re-run.
    module.train()

loss_history = []
for loop in range(5000):
    optimizer.zero_grad()
    subseq_inter_times, subseq_cens_inter_times, subseq_features, t_start, t_end = train_sequence.sample_sequences(
        max_num_sequences=10,
        duration_scale=15.0,
        return_inter_times=True)
    # One row per event: [waiting time, *event features]. The slicing below
    # assumes the features are (magnitude, lat, lon) as built in the data cell.
    subseq_rows = [np.vstack([taus, feats.T]).T for taus, feats in zip(subseq_inter_times, subseq_features)]
    packed_features, sequence_lengths = Sequence.pack_sequences(subseq_rows)

    packed_features = torch.tensor(packed_features, dtype=torch.float32)
    sequence_lengths = torch.tensor(sequence_lengths, dtype=torch.long)
    t_start = torch.tensor(t_start, dtype=torch.float32)
    subseq_cens_inter_times = torch.tensor(subseq_cens_inter_times, dtype=torch.float32)
    output = encoder(
        torch.log(1 + t_start[..., None]),
        packed_features,
        sequence_lengths)

    log_prob, log_surv, raw_scale, raw_shape = wtmodel(
        output, packed_features[..., 0], subseq_cens_inter_times, sequence_lengths)
    # True for real events, False for padding.
    log_prob_mask = (torch.arange(log_prob.size(1))[None, :] < sequence_lengths[:, None]).to(log_prob.device)

    pos_log_prob, pos_means, pos_vars = posmodel(output, packed_features[..., 2:])
    mag_log_prob, mag_means, mag_vars = magmodel(output, packed_features[..., 1:2])

    # Per-window negative log-likelihood (padding masked out).
    loss = - (log_prob * log_prob_mask).sum(-1) - log_surv
    loss += - (pos_log_prob * log_prob_mask[:, :, None]).sum([-2, -1])
    loss += - (mag_log_prob * log_prob_mask[:, :, None]).sum([-2, -1])
    # Check for any non-finite value (not just NaN), so that an infinite loss is
    # also caught before backward()/optimizer.step() can push it into the weights.
    if not torch.isfinite(loss).all():
        print("Non-finite value encountered in loss computation at step ", loop)
        break
    loss = loss.mean()
    loss.backward()
    optimizer.step()
    loss_history.append(loss.item())
    if loop % 100 == 0:
        print(f"Loop {loop}, loss = {loss.item():.4f}")

In [None]:
# Autoregressive sampling: condition on one observed test event and draw
# `num_steps` further events for each of `num_input_samples` independent paths.
num_input_samples = 100
num_steps = 20
sampling_origin = index_stop + 100
input_t_start = times[sampling_origin - 1]
input_feature = features[sampling_origin]
input_tau = times[sampling_origin] - input_t_start
# Row layout matches training: [waiting time, magnitude, lat, lon].
input_tensor = torch.tensor(np.hstack([input_tau, input_feature])[None, None, :], dtype=torch.float32)
input_tensor = torch.repeat_interleave(input_tensor, repeats=num_input_samples, dim=0)
input_t_start = torch.tensor(np.array([input_t_start] * num_input_samples), dtype=torch.float32)
input_sequence_length = torch.tensor([1] * num_input_samples, dtype=torch.long)

# Disable dropout while sampling. Call .train() on these modules before training again.
for module in (encoder, wtmodel, posmodel, magmodel):
    module.eval()

with torch.no_grad():
    for step in range(num_steps):
        output = encoder(
                torch.log(1 + input_t_start[..., None]),
                input_tensor,
                input_sequence_length)

        # tau_cens=inf: the survival term is not needed here, only the parameters.
        log_prob, log_surv, raw_scale, raw_shape = wtmodel(output, input_tensor[..., 0],
                                    torch.tensor([np.inf] * num_input_samples, dtype=torch.float32),
                                    input_sequence_length)

        pos_log_prob, pos_means, pos_vars = posmodel(output, input_tensor[..., 2:])
        mag_log_prob, mag_means, mag_vars = magmodel(output, input_tensor[..., 1:2])

        # Inverse-CDF sampling of Weibull(scale, shape): S(t) = exp(-(t/scale)**shape),
        # so with E ~ Exp(1), t = scale * E**(1/shape) is Weibull distributed.
        unit_exponential = torch.empty((num_input_samples,)).exponential_()
        new_event_waiting_times = raw_scale[:, -1] * unit_exponential ** (1.0 / raw_shape[:, -1])

        new_positions = torch.normal(pos_means[:, -1, :], torch.sqrt(pos_vars[:, -1, :]))
        new_magnitudes = torch.exp(torch.normal(mag_means[:, -1, :], torch.sqrt(mag_vars[:, -1, :])))

        new_features = torch.hstack([new_event_waiting_times.unsqueeze(-1), new_magnitudes, new_positions])[:, None, :]
        input_tensor = torch.cat([input_tensor, new_features], dim=1)
        input_sequence_length += 1

In [None]:
from scipy.stats import binned_statistic
# Overlay all sampled continuations on the observed test events that follow the
# conditioning event (x = 0 is the event at index sampling_origin - 1).
fig, ax = plt.subplots()
observed_end = sampling_origin + 10 * num_steps
ax.plot(times[sampling_origin:observed_end] - times[sampling_origin - 1],
        features[sampling_origin:observed_end, 0],
        color='C0', label='observed')
for sample_idx in range(num_input_samples):
    ax.plot(torch.cumsum(input_tensor[sample_idx, :, 0], dim=0).numpy(),
            input_tensor[sample_idx, :, 1].numpy(),
            color='red', alpha=0.2,
            label='sampled' if sample_idx == 0 else None)
ax.set_xlabel("Days since the conditioning event")
ax.set_ylabel("Magnitude")
ax.set_title(f"Observed events vs. {num_input_samples} sampled continuations")
ax.legend()
plt.show()