In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import obspy
from obspy.clients.fdsn import Client
from obspy import UTCDateTime
import matplotlib.pyplot as plt
import h5py
import pandas as pd
from torchvision.ops.misc import ConvNormActivation
from torchvision.ops import StochasticDepth
from typing import List, Optional, Union, Callable

### Models

In [None]:
# ===== MAE-1D FIXES (canonical ids_keep/ids_restore, no ragged padding, correct LayerScale, patch regression) =====
# Drop-in replacements for: TransformerLayerRoPE, MAEEncoder1d, MAEDecoder1d
# + helper functions for MAE masking and patchify/unpatchify targets.

from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple


# ----------------------------
# Helpers: patchify + MAE mask
# ----------------------------
def patchify_1d(x: torch.Tensor, patch_size: int, pad_to_patch: bool = True) -> Tuple[torch.Tensor, int]:
    """
    Cut a batch of multi-channel 1-D signals into non-overlapping patches.

    Args:
        x: tensor of shape (B, C, L).
        patch_size: number of samples P in each patch.
        pad_to_patch: if True, zero-pad the right end of the signal so that its
            length becomes a multiple of P; if False, L must already be
            divisible by P (checked with an assertion).

    Returns:
        patches: (B, N, C*P) tensor; within a patch the values are laid out
            channel by channel (all P samples of channel 0, then channel 1, ...).
        pad: number of samples appended on the right (0 when nothing was added).
    """
    batch, channels, length = x.shape
    remainder = length % patch_size

    pad = 0
    if not pad_to_patch:
        assert remainder == 0, f"L={length} not divisible by patch_size={patch_size}"
    elif remainder:
        pad = patch_size - remainder
        x = F.pad(x, (0, pad))
        length += pad

    num_patches = length // patch_size
    patches = (
        x.unfold(2, patch_size, patch_size)      # (B, C, N, P)
        .permute(0, 2, 1, 3)                     # (B, N, C, P)
        .contiguous()
        .view(batch, num_patches, channels * patch_size)
    )
    return patches, pad


def mae_random_masking(B: int, N: int, mask_ratio: float, device=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Canonical MAE masking: draw an independent random permutation per sample
    and keep its first N_keep entries as the visible tokens.

    Args:
        B: batch size.
        N: number of tokens per sample.
        mask_ratio: fraction of tokens to hide, in [0, 1).
        device: device on which the returned tensors are created.

    Returns:
      ids_keep:    (B, N_keep) indices of visible tokens in shuffled order
      ids_restore: (B, N) inverse permutation to restore original order
      mask:        (B, N) float mask in original order: 1=masked, 0=visible

    Note:
        Whenever N >= 1 at least one token is kept visible, even if
        round(N * (1 - mask_ratio)) would be 0. An empty visible set cannot be
        processed by the encoder (its RoPE path takes the max over positions).
    """
    assert 0.0 <= mask_ratio < 1.0, f"mask_ratio must be in [0, 1), got {mask_ratio}"
    N_keep = int(round(N * (1.0 - mask_ratio)))
    # Clamp to [1, N] (or 0 when there are no tokens at all).
    N_keep = min(N, max(N_keep, 1))

    # Sorting uniform noise yields a uniformly random permutation per row.
    noise = torch.rand(B, N, device=device)
    ids_shuffle = noise.argsort(dim=1)           # (B, N)
    ids_restore = ids_shuffle.argsort(dim=1)     # (B, N) inverse permutation

    ids_keep = ids_shuffle[:, :N_keep]           # (B, N_keep)

    # Build the mask in shuffled order ([visible | masked]) and unshuffle it.
    mask = torch.ones(B, N, device=device)
    mask[:, :N_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return ids_keep, ids_restore, mask


# ----------------------------
# Patch encoder (unchanged)
# ----------------------------
class ConvPatchEncoder1d(nn.Module):
    """
    Turn a (B, C, L) signal into one embedding per non-overlapping patch.

    Every patch of `patch_size` samples is processed independently by a small
    Conv1d+GELU stack, averaged over its remaining time axis and layer-normed,
    which yields a (B, N, emb_dim) token sequence.

    Args:
        in_channels: channels C of the input signal.
        emb_dim: output channels of every conv layer and size of the embedding.
        num_conv_layers: number of Conv1d+GELU pairs.
        patch_size: samples per patch.
        kernel_size, padding, stride: passed to every Conv1d.
        pad_to_patch: zero-pad the right end so L becomes a multiple of
            patch_size; otherwise L must already be divisible (asserted).
    """

    def __init__(
        self,
        in_channels: int,
        emb_dim: int,
        num_conv_layers: int,
        patch_size: int,
        kernel_size: int = 3,
        padding: int = 1,
        stride: int = 1,
        pad_to_patch: bool = True
    ):
        super().__init__()
        self.patch_size = patch_size
        self.pad_to_patch = pad_to_patch

        # Channel sizes seen by consecutive conv layers: C -> emb_dim -> emb_dim ...
        channel_plan = [in_channels] + [emb_dim] * num_conv_layers
        stack = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            stack.append(nn.Conv1d(c_in, c_out, kernel_size=kernel_size, stride=stride, padding=padding))
            stack.append(nn.GELU())
        self.net = nn.Sequential(*stack)

        self.norm = nn.LayerNorm(emb_dim)
        self.pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode x of shape (B, C, L) into patch tokens of shape (B, N, emb_dim)."""
        batch, channels, length = x.shape
        p = self.patch_size
        remainder = length % p

        if not self.pad_to_patch:
            assert remainder == 0, f"L={length} not divisible by patch_size={p}"
        elif remainder:
            x = F.pad(x, (0, p - remainder))
            length = length + (p - remainder)

        num_patches = length // p

        # Fold the patch axis into the batch axis so the convs see one patch at a time.
        windows = x.unfold(dimension=2, size=p, step=p)                           # (B, C, N, P)
        windows = windows.permute(0, 2, 1, 3).reshape(batch * num_patches, channels, p)

        feats = self.net(windows)                                                 # (B*N, emb_dim, p')
        feats = self.pool(feats).squeeze(-1)                                      # (B*N, emb_dim)
        feats = self.norm(feats)                                                  # (B*N, emb_dim)

        return feats.view(batch, num_patches, feats.shape[-1])                    # (B, N, emb_dim)


# --------------------------------
# FIXED Transformer layer (LayerScale once, proper RoPE positions, optional attn_mask)
# --------------------------------
class TransformerLayerRoPE(nn.Module):
    """
    Pre-norm Transformer block whose self-attention uses rotary position
    embeddings (RoPE).

    Layout: x + LS1(drop(Attn(LN(x)))) followed by x + LS2(drop(FFN(LN(x)))),
    where LS1/LS2 are optional learnable per-channel scales (LayerScale).

    Args:
        emb_dim: token width D; must be divisible by `nheads`.
        nheads: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: probability used for attention dropout, the FFN dropout and
            the dropout on both residual branches.
        layer_scale: initial LayerScale value; LayerScale is disabled when <= 0.
        bias: whether the linear layers carry a bias.
    """

    def __init__(
        self,
        emb_dim: int,
        nheads: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        layer_scale: float = 1e-2,
        bias: bool = True
    ):
        super().__init__()
        assert emb_dim % nheads == 0, "emb_dim must be divisible by nheads"
        self.emb_dim = emb_dim
        self.nheads = nheads
        self.head_dim = emb_dim // nheads

        # The rotary embedding is applied per head, hence dim=head_dim.
        self.pos_emb = RotaryEmbedding(dim=self.head_dim)

        # Separate projections for queries, keys, values and the attention output.
        self.q_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.k_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.v_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.out_proj = nn.Linear(emb_dim, emb_dim, bias=bias)

        self.norm1 = nn.LayerNorm(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)

        self.ffn = nn.Sequential(
            nn.Linear(emb_dim, dim_feedforward, bias=bias),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, emb_dim, bias=bias),
        )

        # One LayerScale vector per residual branch (None disables scaling).
        if layer_scale > 0:
            self.ls1 = nn.Parameter(layer_scale * torch.ones(emb_dim))
            self.ls2 = nn.Parameter(layer_scale * torch.ones(emb_dim))
        else:
            self.ls1 = None
            self.ls2 = None
        self.drop = nn.Dropout(dropout)

    def _shape_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Split the channel axis into heads: (B, N, D) -> (B, H, N, Dh)."""
        batch, seq_len, _ = x.shape
        per_head = x.view(batch, seq_len, self.nheads, self.head_dim)
        return per_head.transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of `_shape_heads`: (B, H, N, Dh) -> (B, N, H*Dh)."""
        batch, heads, seq_len, head_dim = x.shape
        return x.transpose(1, 2).contiguous().view(batch, seq_len, heads * head_dim)

    def _apply_rope(self, q: torch.Tensor, k: torch.Tensor, positions: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Rotate queries and keys of shape (B, H, N, Dh).

        positions:
          - None       -> the library's default sequential positions [0..N-1]
          - (B, N) long -> absolute position of every token
        """
        if positions is None:
            rotate = self.pos_emb.rotate_queries_or_keys
            return rotate(q), rotate(k)

        # Build a frequency table for every position up to the largest one that
        # occurs, then look up the rows needed by each token.
        highest = int(positions.max().item())
        all_positions = torch.arange(highest + 1, device=positions.device)   # (highest+1,)
        table = self.pos_emb.forward(all_positions)                          # (highest+1, Dh)
        freqs = table[positions].unsqueeze(1)                                # (B, 1, N, Dh), broadcast over heads

        return apply_rotary_emb(freqs, q), apply_rotary_emb(freqs, k)

    def _add_branch(self, residual: torch.Tensor, branch: torch.Tensor, scale: Optional[torch.Tensor]) -> torch.Tensor:
        """Add a dropout-regularised branch to the residual stream, applying LayerScale once if enabled."""
        branch = self.drop(branch)
        if scale is None:
            return residual + branch
        return residual + branch * scale[None, None, :]

    def forward(
        self,
        src: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        src: (B, N, D)
        positions: None or (B, N) long absolute indices for RoPE
        attn_mask: forwarded to scaled_dot_product_attention; a bool mask uses
                   True=allowed. A typical key-padding mask (True=keep keys)
                   has a shape broadcastable from (B, 1, 1, N).
        Returns a tensor of shape (B, N, D).
        """
        # ---- attention branch (pre-norm) ----
        normed = self.norm1(src)
        q = self._shape_heads(self.q_proj(normed))
        k = self._shape_heads(self.k_proj(normed))
        v = self._shape_heads(self.v_proj(normed))
        q, k = self._apply_rope(q, k, positions)

        # Attention dropout is only active in training mode.
        attn_dropout = self.drop.p if self.training else 0.0
        attn = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=attn_dropout,
            is_causal=False,
        )
        attn = self.out_proj(self._merge_heads(attn))
        src = self._add_branch(src, attn, self.ls1)

        # ---- feed-forward branch (pre-norm) ----
        ffn_out = self.ffn(self.norm2(src))
        return self._add_branch(src, ffn_out, self.ls2)


# ----------------------------
# FIXED MAE encoder: canonical ids_keep, no ragged padding, correct RoPE positions
# ----------------------------
class MAEEncoder1d(nn.Module):
    """
    MAE encoder: selects the visible tokens and runs them through a stack of
    `TransformerLayerRoPE` blocks, using each token's original index as its
    RoPE position.

    Args:
        emb_dim: token width D.
        nheads: attention heads per layer.
        num_layers: number of Transformer layers.
        dim_feedforward, dropout, layer_scale: forwarded to every layer.
    """

    def __init__(
        self,
        emb_dim: int,
        nheads: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        layer_scale: float = 1e-2
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerLayerRoPE(
                emb_dim=emb_dim,
                nheads=nheads,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                layer_scale=layer_scale,
            )
            for _ in range(num_layers)
        ])

    def forward(
        self,
        x: torch.Tensor,
        mask_ratio: Optional[float] = None,
        ids_keep: Optional[torch.Tensor] = None,
        ids_restore: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        x: (B, N, D)
        Either provide:
          - mask_ratio (ids_keep/ids_restore are then sampled), or
          - ids_keep + ids_restore (precomputed; mask_ratio is ignored)

        Returns:
          z_vis:      (B, N_keep, D) encoded visible tokens, in ids_keep order
          ids_keep:   (B, N_keep)
          ids_restore:(B, N)
          mask:       (B, N) float, 1=masked, 0=visible. When ids_keep is
                      supplied by the caller the mask is rebuilt from it, so a
                      mask is returned on both paths.
        """
        B, N, D = x.shape
        if ids_keep is None:
            assert mask_ratio is not None, "Provide mask_ratio or ids_keep/ids_restore"
            ids_keep, ids_restore, mask = mae_random_masking(B, N, mask_ratio, device=x.device)
        else:
            assert ids_restore is not None, "If ids_keep is provided, ids_restore must be provided"
            # Everything is masked except the positions listed in ids_keep.
            mask = torch.ones(B, N, device=x.device)
            mask.scatter_(1, ids_keep, 0.0)

        # Gather visible tokens; their original indices double as RoPE positions.
        x_vis = x.gather(1, ids_keep[..., None].expand(-1, -1, D))  # (B, N_keep, D)
        pos_vis = ids_keep                                           # (B, N_keep)

        # Every sample keeps the same number of tokens, so no padding mask is needed.
        for layer in self.layers:
            x_vis = layer(x_vis, positions=pos_vis, attn_mask=None)

        return x_vis, ids_keep, ids_restore, mask


# ----------------------------
# FIXED MAE decoder: append mask tokens, unshuffle with ids_restore, predict patch values
# ----------------------------
class MAEDecoder1d(nn.Module):
    """
    MAE decoder: projects the encoded visible tokens to the decoder width,
    fills the hidden slots with a shared learnable mask token, restores the
    original token order and regresses the raw values of every patch.

    Args:
        enc_dim: width of the encoder tokens.
        dec_dim: width used inside the decoder.
        patch_dim: size of each predicted patch (= in_channels * patch_size).
        nheads: attention heads per decoder layer.
        num_layers: number of `TransformerLayerRoPE` layers.
        dim_feedforward, dropout, layer_scale: forwarded to every layer.
    """

    def __init__(
        self,
        enc_dim: int,
        dec_dim: int,
        patch_dim: int,          # = in_channels * patch_size
        nheads: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        layer_scale: float = 1e-2,
    ):
        super().__init__()
        self.dec_dim = dec_dim

        self.dec_embed = nn.Linear(enc_dim, dec_dim)

        # Single learnable vector that stands in for every masked patch.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_dim))
        nn.init.normal_(self.mask_token, std=0.02)

        decoder_blocks = [
            TransformerLayerRoPE(
                emb_dim=dec_dim,
                nheads=nheads,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                layer_scale=layer_scale,
            )
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(decoder_blocks)

        self.norm = nn.LayerNorm(dec_dim)
        self.pred = nn.Linear(dec_dim, patch_dim)

    def forward(self, z_vis: torch.Tensor, ids_keep: torch.Tensor, ids_restore: torch.Tensor) -> torch.Tensor:
        """
        z_vis: (B, N_keep, enc_dim) encoded visible tokens
        ids_keep: (B, N_keep) -- accepted for interface symmetry; not read here
        ids_restore: (B, N) inverse permutation produced by the masking step
        returns:
          pred_patches: (B, N, patch_dim)
        """
        batch, n_visible, _ = z_vis.shape
        n_total = ids_restore.shape[1]

        visible = self.dec_embed(z_vis)                                   # (B, N_keep, dec_dim)

        # Shuffled order is [visible | masked]: pad with mask tokens up to N.
        fillers = self.mask_token.expand(batch, n_total - n_visible, self.dec_dim)
        shuffled = torch.cat([visible, fillers], dim=1)                   # (B, N, dec_dim)

        # Undo the shuffle so that token i sits at sequence position i again.
        gather_index = ids_restore[..., None].expand(-1, -1, self.dec_dim)
        z_full = shuffled.gather(1, gather_index)                         # (B, N, dec_dim)

        # Tokens are back in order, so sequential RoPE positions apply (positions=None).
        for layer in self.layers:
            z_full = layer(z_full, positions=None, attn_mask=None)

        return self.pred(self.norm(z_full))                               # (B, N, patch_dim)


# ----------------------------
# Full MAE wrapper (optional convenience)
# ----------------------------
class MAE1d(nn.Module):
    """
    Convenience wrapper that wires the full 1-D masked autoencoder together:
    ConvPatchEncoder1d -> MAEEncoder1d -> MAEDecoder1d, plus the masked
    reconstruction loss in input space.

    Args:
        in_channels: channels of the input traces.
        patch_size: samples per patch.
        enc_dim / dec_dim: token widths of encoder and decoder.
        enc_layers / dec_layers: number of Transformer layers in each.
        nheads_enc / nheads_dec: attention heads in each.
        enc_dim_ff / dec_dim_ff: feed-forward widths in each.
        dropout, layer_scale: shared by encoder and decoder layers.
        num_conv_layers: conv layers inside the patch encoder.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        enc_dim: int,
        dec_dim: int,
        enc_layers: int,
        dec_layers: int,
        nheads_enc: int,
        nheads_dec: int,
        enc_dim_ff: int,
        dec_dim_ff: int,
        dropout: float = 0.1,
        layer_scale: float = 1e-2,
        num_conv_layers: int = 2,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        # Each reconstructed patch holds all channels of one window.
        self.patch_dim = in_channels * patch_size

        self.patch_encoder = ConvPatchEncoder1d(
            in_channels=in_channels,
            emb_dim=enc_dim,
            num_conv_layers=num_conv_layers,
            patch_size=patch_size,
        )
        self.encoder = MAEEncoder1d(
            emb_dim=enc_dim,
            nheads=nheads_enc,
            num_layers=enc_layers,
            dim_feedforward=enc_dim_ff,
            dropout=dropout,
            layer_scale=layer_scale,
        )
        self.decoder = MAEDecoder1d(
            enc_dim=enc_dim,
            dec_dim=dec_dim,
            patch_dim=self.patch_dim,
            nheads=nheads_dec,
            num_layers=dec_layers,
            dim_feedforward=dec_dim_ff,
            dropout=dropout,
            layer_scale=layer_scale,
        )

    def forward(self, x: torch.Tensor, mask_ratio: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        x: (B, C, L) normalized traces
        mask_ratio: fraction of patches hidden from the encoder
        returns:
          loss:  scalar mean-squared error averaged over masked patches only
          pred:  (B, N, patch_dim) predicted patches
          mask:  (B, N) with 1=masked, 0=visible
        """
        # Regression targets are the raw (right-padded) input patches.
        target_patches, _ = patchify_1d(x, self.patch_size, pad_to_patch=True)   # (B, N, patch_dim)

        tokens = self.patch_encoder(x)                                           # (B, N, enc_dim)
        z_vis, ids_keep, ids_restore, mask = self.encoder(tokens, mask_ratio=mask_ratio)
        pred = self.decoder(z_vis, ids_keep, ids_restore)                        # (B, N, patch_dim)

        # Per-patch MSE, then average over the masked patches; the clamp avoids
        # a division by zero when nothing is masked.
        squared_error = (pred - target_patches).pow(2)
        per_patch = squared_error.mean(dim=-1)                                   # (B, N)
        masked_error = (per_patch * mask).sum()
        loss = masked_error / mask.sum().clamp_min(1.0)

        return loss, pred, mask


In [None]:
class SteadDataset(Dataset):
    """
    Map-style dataset over one or more STEAD chunks.

    Each chunk is an HDF5 file accompanied by a CSV metadata file that has the
    same path except for the extension. Trace names are read from the CSV
    column ``trace_name``; waveforms and pick attributes are read from the HDF5
    group ``data``.

    Args:
        chunk_files: iterable of paths to the ``.hdf5`` chunk files.
        channel_first: if True, the 2-D trace array is transposed before being
            returned (the notebook uses this to obtain channel-first traces).

    The HDF5 files are opened in ``__init__`` and stay open until ``close()``
    is called.
    """

    def __init__(self, chunk_files, channel_first):
        self.files = []
        self.event_lists = []
        for chunk in chunk_files:
            # Swap only the trailing extension; a blanket str.replace would also
            # rewrite any 'hdf5' appearing in a directory name.
            stem, dot, ext = chunk.rpartition('.')
            csv_path = f"{stem}.csv" if dot and ext == 'hdf5' else chunk.replace('hdf5', 'csv')
            metadata = pd.read_csv(csv_path)
            self.event_lists.append(metadata['trace_name'].astype('str').to_list())
            self.files.append(h5py.File(chunk, 'r'))
        # stopping_indices[k] = number of traces contained in chunks 0..k.
        self.stopping_indices = np.cumsum(
            [len(ev_list) for ev_list in self.event_lists], dtype=np.int64
        )
        self.channel_first = channel_first

    def __len__(self):
        """Total number of traces across all chunks."""
        if len(self.stopping_indices) == 0:
            return 0
        return int(self.stopping_indices[-1])

    def close(self):
        """Close every open HDF5 handle; the dataset cannot be indexed afterwards."""
        for handle in self.files:
            handle.close()
        self.files = []

    @staticmethod
    def _attr_to_scalar(value):
        """
        Convert an HDF5 attribute to a plain Python number.

        Empty strings (missing picks) become NaN, non-empty strings are parsed
        as floats, and numpy scalars / single-element arrays are unwrapped.
        """
        if isinstance(value, (str, bytes)):
            value = value.strip()
            return float(value) if value else float('nan')
        return np.asarray(value).item()

    def __getitem__(self, idx):
        """
        Returns:
            (trace, p_arrival, s_arrival, coda_end, event_name); the three pick
            values are Python numbers, NaN when the attribute is empty.

        Raises:
            IndexError: if idx is outside [-len(self), len(self)).
        """
        total = len(self)
        idx = int(idx)
        if idx < 0:
            idx += total            # Python-style negative indexing
        if not 0 <= idx < total:
            raise IndexError(f"index {idx} out of range for dataset of size {total}")

        # First chunk whose cumulative count exceeds idx.
        chunk_idx = int(np.searchsorted(self.stopping_indices, idx, side='right'))
        chunk_start = int(self.stopping_indices[chunk_idx - 1]) if chunk_idx > 0 else 0
        event_name = self.event_lists[chunk_idx][idx - chunk_start]

        dataset = self.files[chunk_idx]['data/' + event_name]
        trace = np.array(dataset)
        p_arrival = self._attr_to_scalar(dataset.attrs['p_arrival_sample'])
        s_arrival = self._attr_to_scalar(dataset.attrs['s_arrival_sample'])
        coda_end = self._attr_to_scalar(dataset.attrs['coda_end_sample'])

        if self.channel_first:
            trace = trace.transpose(1, 0)
        return trace, p_arrival, s_arrival, coda_end, event_name

### Train

In [None]:
root = 'STEAD/'


def _chunk_file(chunk):
    """Path of the HDF5 file of STEAD chunk number `chunk`, located under `root`."""
    return root + f'chunk{chunk}/chunk{chunk}.hdf5'


# Chunks 2 and 3 are used for training, chunk 4 for validation.
train_chunks = [_chunk_file(chunk_id) for chunk_id in range(2, 4)]
val_chunks = [_chunk_file(chunk_id) for chunk_id in range(4, 5)]

train_dataset = SteadDataset(train_chunks, channel_first=True)
val_dataset = SteadDataset(val_chunks, channel_first=True)

In [None]:
batch_size = 64
num_epochs = 10

# The dataset keeps open h5py handles created in its constructor, so loading
# stays in the main process (num_workers must be 0 for hdf5).
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

In [None]:
# --- hyperparams ---
patch_size = 80      # samples per patch
in_channels = 3      # channels of each trace
mask_ratio = 0.4     # fraction of patches hidden from the encoder

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ---- build model: patch encoder + MAE encoder/decoder ----
mae = MAE1d(
    in_channels=in_channels,
    patch_size=patch_size,
    enc_dim=128,
    dec_dim=64,
    enc_layers=4,
    dec_layers=2,
    nheads_enc=8,
    nheads_dec=4,
    enc_dim_ff=512,
    dec_dim_ff=256,
    dropout=0.1,
    layer_scale=1e-2,
    num_conv_layers=2,
).to(device)

optimizer = optim.AdamW(mae.parameters(), lr=1e-4, weight_decay=1e-2)

# Cosine decay spans every optimisation step of the run, so the scheduler is
# stepped once per batch rather than once per epoch.
total_steps = num_epochs * len(train_loader)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

# Gradient-norm clipping threshold; set to None to disable clipping.
grad_clip = 1.0

mae.train()
global_step = 0
for epoch in range(num_epochs):
    for i, (traces, p_arrivals, s_arrivals, coda_ends, event_names) in enumerate(train_loader):
        optimizer.zero_grad(set_to_none=True)

        traces = traces.to(device)  # (B,3,L)

        # Standardise every channel of every trace over time; the clamp guards
        # against division by zero on constant channels.
        traces_mean = traces.mean(dim=2, keepdim=True)
        traces_std = traces.std(dim=2, keepdim=True).clamp_min(1e-9)
        normalized_traces = (traces - traces_mean) / traces_std

        # The model already returns the loss restricted to masked patches.
        loss, pred_patches, mask = mae(normalized_traces, mask_ratio=mask_ratio)
        loss.backward()

        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(mae.parameters(), grad_clip)

        optimizer.step()
        scheduler.step()
        global_step += 1

        if i % 100 == 0:
            # mask is 1 on masked patches, so its mean is the masked fraction.
            print(f"Epoch {epoch} Step {i}  Loss {loss.item():.3e}  masked_frac {mask.mean().item():.2f}")


In [None]:
class LatentDirichletRegression(nn.Module):
    """
    Convolutional head that maps a latent sequence to Dirichlet concentration
    parameters.

    A stack of Conv1d -> GroupNorm (one group per channel) -> GELU layers is
    followed by a 1x1 convolution and global average pooling over time. Two
    linear heads then produce a softmax-normalised mean vector and a positive
    total concentration; their product is returned.

    Args:
        input_dim: channels of the input sequence.
        output_dim: number of Dirichlet components.
        kernel_sizes, strides, paddings: per-layer Conv1d settings (the number
            of layers is the length of the shortest of these).
        channel_sizes: output channels of each conv layer.
    """

    def __init__(self,  input_dim, output_dim, kernel_sizes, channel_sizes, strides, paddings):
        super().__init__()
        self.net = nn.Sequential()
        prev_channels = input_dim
        for i, (kernel, stride, pad) in enumerate(zip(kernel_sizes, strides, paddings)):
            out_channels = channel_sizes[i]
            self.net.add_module(
                f"conv_{i}",
                nn.Conv1d(
                    in_channels=prev_channels,
                    out_channels=out_channels,
                    kernel_size=kernel,
                    stride=stride,
                    padding=pad,
                ),
            )
            # num_groups == num_channels: every channel is normalised on its own.
            self.net.add_module(f"norm_{i}", nn.GroupNorm(num_groups=out_channels, num_channels=out_channels))
            self.net.add_module(f"gelu_{i}", nn.GELU())
            prev_channels = out_channels

        last_channels = channel_sizes[-1]
        self.output_conv = nn.Conv1d(
            in_channels=last_channels,
            out_channels=last_channels,
            kernel_size=1,
            stride=1,
            padding=0)

        self.output_alphas = nn.Linear(last_channels, output_dim)
        self.output_alpha0 = nn.Linear(last_channels, 1)

    def forward(self, x):
        """x: (B, input_dim, T) -> concentrations of shape (B, output_dim)."""
        features = self.output_conv(self.net(x))
        pooled = features.mean(dim=2)  # global average pooling over the time axis
        mean_vector = F.softmax(self.output_alphas(pooled), dim=-1)   # sums to 1 per sample
        total_concentration = F.softplus(self.output_alpha0(pooled))  # positive scalar per sample
        return mean_vector * total_concentration

In [None]:
downstream_batch_size = 64
downstream_num_epochs = 10

# Same training split as pretraining; loading stays in the main process
# (num_workers must be 0 for hdf5).
downstream_train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=downstream_batch_size,
    shuffle=True,
    num_workers=0,
)

# input_dim equals the encoder token width (enc_dim=128 above); the four
# outputs are the concentrations of the four interval fractions used as target.
interval_model = LatentDirichletRegression(
    input_dim=128,
    output_dim=4,
    kernel_sizes=[3, 3, 3],
    channel_sizes=[128, 64, 32],
    strides=[2, 2, 2],
    paddings=[1, 1, 1],
).to(device)

downstream_optimizer = optim.AdamW(interval_model.parameters(), lr=1e-4, weight_decay=1e-2)

# Cosine schedule over all optimisation steps (stepped once per batch).
downstream_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    downstream_optimizer,
    T_max=downstream_num_epochs * len(downstream_train_loader),
)

In [None]:
# Train the Dirichlet interval head on top of the frozen, pretrained MAE encoder.
mae.eval()
interval_model.train()
target_eps = 1e-6  # lower bound for each target fraction so Dirichlet.log_prob stays finite
for epoch in range(downstream_num_epochs):
    for i, (traces, p_arrivals, s_arrivals, coda_ends, event_names) in enumerate(downstream_train_loader):
        downstream_optimizer.zero_grad()
        traces = traces.to(device)

        # Same per-channel standardisation as in pretraining (clamp, not "+ eps").
        traces_mean = traces.mean(dim=2, keepdim=True)
        traces_std = traces.std(dim=2, keepdim=True).clamp_min(1e-9)
        normalized_traces = (traces - traces_mean) / traces_std
        num_timesteps = traces.size(-1)

        with torch.no_grad():
            tokens = mae.patch_encoder(normalized_traces)  # (B, Np, D)
            # The encoder requires either mask_ratio or explicit ids. Identity
            # ids keep every token visible and in temporal order, which the
            # convolutional head below relies on.
            all_ids = torch.arange(tokens.size(1), device=tokens.device).expand(tokens.size(0), -1)
            embeddings = mae.encoder(tokens, ids_keep=all_ids, ids_restore=all_ids)[0]  # (B, Np, D)

        alphas = interval_model(embeddings.transpose(1, 2))  # (B, 4) Dirichlet concentrations

        # Target: fractions of the window spent before P, between P and S,
        # between S and coda end, and after the coda end.
        s1 = p_arrivals / num_timesteps
        s2 = s_arrivals / num_timesteps - s1
        s3 = coda_ends / num_timesteps - s1 - s2
        s4 = 1.0 - (s1 + s2 + s3)
        target = torch.stack([s1, s2, s3, s4], dim=-1).to(device=device, dtype=alphas.dtype)

        # Drop samples with missing picks (NaN) and move the remaining targets
        # strictly inside the simplex, where the Dirichlet density is finite.
        valid = torch.isfinite(target).all(dim=-1)
        if not bool(valid.any()):
            continue
        target = target[valid].clamp_min(target_eps)
        target = target / target.sum(dim=-1, keepdim=True)

        dist = torch.distributions.Dirichlet(alphas[valid] + 1e-9)
        loss = -dist.log_prob(target).mean()
        loss.backward()
        downstream_optimizer.step()
        downstream_scheduler.step()
        if i % 100 == 0:
            print(f"Downstream Epoch {epoch}, Step {i}, Loss: {loss.item():.2e}")

# OLD

In [None]:
class Conv1dNormActivation(ConvNormActivation):
    """
    Configurable Convolution1d-Normalization-Activation block.

    Thin wrapper around torchvision's ``ConvNormActivation`` that fixes the
    convolution type to ``torch.nn.Conv1d`` and defaults the normalisation to
    ``torch.nn.BatchNorm1d``.

    Args:
        in_channels (int): Number of channels in the input signal
        out_channels (int): Number of channels produced by the block
        kernel_size (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to both ends of the input. Default: None, in which case the base class derives it from ``kernel_size`` and ``dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer stacked on top of the convolution. If ``None`` no norm layer is used. Default: ``torch.nn.BatchNorm1d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation stacked on top of the normalization layer (or of the conv layer if there is none). If ``None`` no activation is used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Passed to the activation layer, which may then operate in-place. Default ``True``
        bias (bool, optional): Whether the convolution has a bias. Default: ``None``, which leaves the decision to the base class.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: Optional[Union[int, tuple[int, int], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm1d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: int = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:
        # Forward everything unchanged; only the convolution class is pinned to 1-D.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            norm_layer=norm_layer,
            activation_layer=activation_layer,
            dilation=dilation,
            inplace=inplace,
            bias=bias,
            conv_layer=torch.nn.Conv1d,
        )
class GlobalResponseNorm1d(nn.Module):
    """
    Global Response Normalization for sequences whose channel axis is last.

    The L2 norm of every channel is taken over dim=1, divided by the mean of
    those norms across channels, and used to rescale the input. `gamma` and
    `beta` start at zero, so the layer is initially the identity.

    Args:
        dim: number of channels (size of the last axis).
        eps: added to the mean norm to avoid division by zero.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        # Per-channel energy aggregated over dim=1.
        channel_norm = x.norm(p=2, dim=1, keepdim=True)
        # Each channel's energy relative to the average over channels.
        mean_norm = channel_norm.mean(dim=-1, keepdim=True)
        relative_norm = channel_norm / (mean_norm + self.eps)
        # Learnable modulation plus the identity shortcut.
        return self.gamma * (x * relative_norm) + self.beta + x
class LayerNorm1d(nn.LayerNorm):
    """LayerNorm over the channel axis of channel-first (N, C, L) tensors."""

    def forward(self, x):
        # nn.LayerNorm normalises the trailing axis, so move channels last,
        # normalise with the parent implementation, and move them back.
        channels_last = x.permute(0, 2, 1)              # NCL -> NLC
        normalised = super().forward(channels_last)
        return normalised.permute(0, 2, 1)              # NLC -> NCL
class CN2Upsample1d(nn.Module):
    """Channel-wise LayerNorm followed by a kernel-2, stride-2 transposed convolution on (N, C, L) input."""

    def __init__(self, in_dim : int, out_dim : int):
        super().__init__()
        self.norm = LayerNorm1d(in_dim)
        self.conv = nn.ConvTranspose1d(in_dim, out_dim, kernel_size=2, stride=2, bias=True)

    def forward(self, input):
        """input: (N, in_dim, L) -> conv(norm(input))."""
        return self.conv(self.norm(input))
class CN2Downsample1d(nn.Module):
    """Channel-wise LayerNorm followed by a kernel-2, stride-2 convolution on (N, C, L) input."""

    def __init__(self, in_dim : int, out_dim : int):
        super().__init__()
        self.norm = LayerNorm1d(in_dim)
        self.conv = nn.Conv1d(in_dim, out_dim, kernel_size=2, stride=2, bias=True)

    def forward(self, input):
        """input: (N, in_dim, L) -> conv(norm(input))."""
        return self.conv(self.norm(input))
class CN2Block1d(nn.Module):
    """
    Residual 1-D block: depthwise 7-tap convolution, LayerNorm, a 4x expansion
    MLP with GELU and GlobalResponseNorm1d, and stochastic depth on the branch.

    An optional boolean mask (True = masked) zeroes the input and the
    convolution output at masked positions.

    Args:
        dim: number of channels (kept unchanged by the block).
        stochastic_depth_prob: drop probability given to StochasticDepth ("row" mode).
    """

    def __init__(self, dim: int, stochastic_depth_prob: float):
        super().__init__()
        self.conv1 = nn.Conv1d(dim, dim, kernel_size=7, stride=1, padding=3, groups=dim, bias=True)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.linear1 = nn.Linear(dim, 4 * dim, bias=True)
        self.grn = GlobalResponseNorm1d(4 * dim)
        self.linear2 = nn.Linear(4 * dim, dim, bias=True)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    @staticmethod
    def _as_ncl_mask(mask: torch.Tensor) -> torch.Tensor:
        """Reshape a (L,) or (N, L) mask so it broadcasts over (N, C, L); 3-D masks pass through."""
        if mask.dim() == 1:
            return mask[None, None, :]
        if mask.dim() == 2:
            return mask[:, None, :]
        return mask

    def _block(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Residual branch. x: (N, C, L); mask: True = masked."""
        masked = mask is not None
        if masked:
            mask = self._as_ncl_mask(mask)
            x = x.masked_fill(mask, 0.0)

        x = self.conv1(x)

        # Zero the masked positions again after the convolution.
        if masked:
            x = x.masked_fill(mask, 0.0)

        # The pointwise MLP operates with channels last.
        x = x.permute(0, 2, 1)  # NCL -> NLC
        x = self.linear1(self.norm1(x))
        x = self.grn(F.gelu(x))
        x = self.linear2(x)
        x = x.permute(0, 2, 1)  # NLC -> NCL

        return self.stochastic_depth(x)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return x + branch(x); with a mask, the shortcut input is zeroed at masked positions too."""
        if mask is not None:
            mask = self._as_ncl_mask(mask)
            x = x.masked_fill(mask, 0.0)
        return x + self._block(x, mask)
class CNBlock1d(nn.Module):
    """
    Residual 1-D block: depthwise 7-tap convolution, LayerNorm, a 4x expansion
    MLP with GELU, and a learnable per-channel layer scale initialised to 1e-6.
    Neither the sequence length nor the channel count is changed.

    The `mask` arguments are accepted for signature compatibility and ignored.
    """

    def __init__(self, dim : int):
        super().__init__()
        # Depthwise conv with padding 3 keeps the length unchanged.
        self.conv1 = nn.Conv1d(dim, dim, kernel_size=7, stride=1, padding=3, groups=dim, bias=True)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.linear1 = nn.Linear(dim, 4 * dim, bias=True)
        self.linear2 = nn.Linear(4 * dim, dim, bias=True)
        self.layer_scale = nn.Parameter(1e-6 * torch.ones((dim)), requires_grad=True)

    def _block(self, input: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Residual branch for input of shape (N, C, L)."""
        hidden = self.conv1(input).permute(0, 2, 1)      # NCL -> NLC
        hidden = self.linear1(self.norm1(hidden))
        hidden = self.linear2(F.gelu(hidden))
        hidden = hidden * self.layer_scale               # per-channel scale (channels last here)
        return hidden.permute(0, 2, 1)                   # NLC -> NCL

    def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return input + branch(input)."""
        return input + self._block(input, mask)
    

class CN2Encoder1d(nn.Module):
    """
    Convolutional encoder: a stride-4 stem followed by stages of CN2Block1d
    blocks, with a stride-2 CN2Downsample1d between consecutive stages.

    An optional boolean mask (True = masked) is applied to the input and then
    average-pooled alongside the features, so that a downsampled position is
    masked whenever any input position it covers is masked.

    Args:
        input_channels: channels of the input signal.
        embed_dim: channel width used throughout the encoder.
        num_blocks: number of blocks in each stage.
        stochastic_depth_prob: drop probability handed to every block.
    """

    def __init__(self, input_channels : int,
                embed_dim : int,
                num_blocks : List[int],
                stochastic_depth_prob : float):
        super().__init__()
        self.stem = Conv1dNormActivation(
            in_channels=input_channels,
            out_channels=embed_dim,
            kernel_size=4,
            padding=0,
            stride=4,
            norm_layer=LayerNorm1d,
            activation_layer=None,
            bias=True
        )
        # Poolers that shrink the mask by the same factor as stem / downsample.
        self.stem_pooler = nn.AvgPool1d(kernel_size=4, stride=4)
        self.down_pooler = nn.AvgPool1d(kernel_size=2, stride=2)

        self.num_stages = len(num_blocks)
        self.stages = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        for stage_idx, depth in enumerate(num_blocks):
            self.stages.append(nn.ModuleList([
                CN2Block1d(dim=embed_dim, stochastic_depth_prob=stochastic_depth_prob)
                for _ in range(depth)
            ]))
            # A downsample follows every stage except the last one.
            if stage_idx < self.num_stages - 1:
                self.downsamples.append(CN2Downsample1d(in_dim=embed_dim, out_dim=embed_dim))

    def get_downscale_factor(self) -> int:
        """Total temporal reduction: 4 from the stem times 2 per downsample layer."""
        return 4 * (2 ** (len(self.downsamples)))

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        x: (N, C, L)
        mask: optional, shape (L,), (N, L) or (N, 1, L); True (non-zero) = masked.
        Returns the encoded features with masked positions set to zero.
        """
        x_mask = None
        if mask is not None:
            # Accept a mask of shape (L,), (N, L) or (N, 1, L).
            if mask.dim() == 1:
                mask = mask[None, :]             # -> (1, L)
            elif mask.dim() == 2:
                mask = mask[:, None]             # -> (N, 1, L)
            mask = mask.to(device=x.device, dtype=torch.bool)

            # Apply to the input (broadcast over channels).
            x = x.masked_fill(mask, 0.0)

        # Stem: stride 4.
        x = self.stem(x)

        if mask is not None:
            # Pool the mask consistently with the stem (stride 4).
            x_mask = self.stem_pooler(mask.float()).to(dtype=torch.bool)
            x = x.masked_fill(x_mask, 0.0)       # broadcast over channels

        for stage_idx, stage in enumerate(self.stages):
            for block in stage:
                x = block(x)
                if x_mask is not None:
                    x = x.masked_fill(x_mask, 0.0)

            if stage_idx >= len(self.downsamples):
                continue

            x = self.downsamples[stage_idx](x)   # stride 2
            if x_mask is not None:
                # Pool the mask with stride 2, consistent with the downsample.
                x_mask = self.down_pooler(x_mask.float()).to(dtype=torch.bool)
                x = x.masked_fill(x_mask, 0.0)
        return x
    
class CN2Decoder1d(nn.Module):
    """Decoder head: upsampling stages, one CNBlock1d and a 1x1 output convolution.

    Args:
        embed_dim: channel width kept through the upsampling stages and the block.
        out_channels: number of channels produced by the final 1x1 convolution.
        num_upsamples: number of (CN2Upsample1d -> LayerNorm1d -> GELU) stages.
    """

    def __init__(self, embed_dim: int, out_channels: int, num_upsamples: int):
        super().__init__()
        upsample_stages = []
        for _ in range(num_upsamples):
            upsample_stages.append(
                nn.Sequential(
                    CN2Upsample1d(embed_dim, embed_dim),
                    LayerNorm1d(embed_dim),
                    nn.GELU(),
                )
            )
        self.upsampler = nn.Sequential(*upsample_stages)
        self.block = CNBlock1d(embed_dim)
        # Pointwise projection from the embedding width to the output channels.
        self.final_conv = nn.Conv1d(
            embed_dim,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply upsampler, block and final convolution to ``x`` of shape (N, C, L)."""
        upsampled = self.upsampler(x)
        refined = self.block(upsampled)
        return self.final_conv(refined)

In [None]:
#time is the last dimension
class ConvFeatureEncoder(nn.Module):
    """Stack of Conv1d + GELU layers applied along the last (time) dimension.

    Args:
        in_ch: number of input channels of the first convolution.
        dim: number of output channels of every convolution.
        kernel_sizes, strides, paddings: per-layer Conv1d hyper-parameters;
            one layer is created per zipped triple.
    """

    def __init__(self, in_ch, dim, kernel_sizes, strides, paddings):
        super().__init__()
        self.net = nn.Sequential()
        prev_channels = in_ch
        layer_specs = zip(kernel_sizes, strides, paddings)
        for idx, (kernel, stride, pad) in enumerate(layer_specs):
            self.net.add_module(
                f"conv_{idx}",
                nn.Conv1d(prev_channels, dim, kernel_size=kernel, stride=stride, padding=pad),
            )
            self.net.add_module(f"gelu_{idx}", nn.GELU())
            # Every layer after the first maps dim -> dim.
            prev_channels = dim

    def forward(self, x):
        """Run the convolutional stack on ``x`` (channels first, time last)."""
        return self.net(x)
#time is the last dimension
class ConvPositionalEncoding(nn.Module):
    """Residual depthwise Conv1d used as a positional encoding (time is the last dimension).

    Args:
        channels: number of input/output channels; also used as ``groups`` so
            each channel is convolved independently.
        kernel_size: kernel width; the padding is ``kernel_size // 2``.
    """

    def __init__(self, channels, kernel_size=3):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=half_kernel,
            groups=channels,
            bias=True,
        )

    def forward(self, x):
        """Return ``x`` plus its depthwise-convolved version."""
        positional = self.conv(x)
        return x + positional
# channel is the last dimension
class ContextEncoder(nn.Module):
    """Pre-norm, batch-first Transformer encoder (channel is the last dimension).

    Args:
        dim: model width (``d_model``).
        n_layers: number of stacked encoder layers.
        n_heads: number of attention heads per layer.
        ffn_dim: hidden width of the feed-forward sub-layer.
        dropout: dropout probability used inside each layer.
    """

    def __init__(self, dim, n_layers, n_heads, ffn_dim, dropout):
        super().__init__()
        layer_config = dict(
            d_model=dim,
            nhead=n_heads,
            dim_feedforward=ffn_dim,
            dropout=dropout,
            activation='gelu',
            norm_first=True,
            batch_first=True,
        )
        layer_template = nn.TransformerEncoderLayer(**layer_config)
        self.transformer = nn.TransformerEncoder(layer_template, num_layers=n_layers)

    def forward(self, x):
        """Run the transformer stack on ``x`` of shape (B, T, dim)."""
        return self.transformer(x)
    
class MaskedEncoderModel(nn.Module):
    """Convolutional feature encoder + transformer context encoder with token masking.

    Pipeline: ConvFeatureEncoder -> optional ConvPositionalEncoding ->
    (optional random replacement of time steps by a learned mask embedding) ->
    ContextEncoder -> MLP prediction head.

    Args:
        in_ch: number of input channels of the raw signal.
        dim: feature / model width.
        mask_prob: per-time-step masking probability used by ``random_mask``.
        feature_enc_kernel_sizes, feature_enc_strides, feature_enc_paddings:
            per-layer hyper-parameters of the convolutional feature encoder.
        context_n_layers, context_n_heads, context_ffn_dim, context_dropout:
            hyper-parameters of the transformer context encoder.
        use_cpe: whether to apply the convolutional positional encoding.
    """

    def __init__(
        self,
        in_ch,
        dim,
        mask_prob,
        feature_enc_kernel_sizes,
        feature_enc_strides,
        feature_enc_paddings,
        context_n_layers,
        context_n_heads,
        context_ffn_dim,
        context_dropout=0.1,
        use_cpe=True
    ):
        super().__init__()
        self.mask_prob = mask_prob
        self.feature_encoder = ConvFeatureEncoder(
            in_ch=in_ch,
            dim=dim,
            kernel_sizes=feature_enc_kernel_sizes,
            strides=feature_enc_strides,
            paddings=feature_enc_paddings,
        )
        self.use_cpe = use_cpe
        if use_cpe:
            self.cpe = ConvPositionalEncoding(dim, kernel_size=3)

        self.context_encoder = ContextEncoder(
            dim=dim,
            n_layers=context_n_layers,
            n_heads=context_n_heads,
            ffn_dim=context_ffn_dim,
            dropout=context_dropout
        )

        # Learned vector that replaces the features of masked time steps.
        self.mask_embedding = nn.Parameter(torch.zeros(dim))
        nn.init.normal_(self.mask_embedding, mean=0.0, std=0.02)

        self.pred_head = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    @staticmethod
    def random_mask(x, mask_p):
        """Draw an independent Bernoulli(mask_p) mask over the time steps of ``x``.

        Args:
            x: tensor of shape (B, T, D); only its shape and device are used.
            mask_p: probability of masking each time step.

        Returns:
            Bool tensor of shape (B, T) on ``x.device`` where True marks a
            masked time step. Whenever T > 0, every row contains at least one
            True: rows that drew no masked step get one uniformly random
            position masked.
        """
        B, T, _ = x.shape
        # Sample directly on the target device, for the whole batch at once.
        mask = torch.rand(B, T, device=x.device) < mask_p
        if T == 0:
            # Nothing can be masked in an empty sequence.
            return mask

        empty_rows = ~mask.any(dim=1)
        if empty_rows.any():
            # Guarantee at least one masked position per sample so the masked
            # selection used by the loss is never empty for a row.
            rows = torch.nonzero(empty_rows, as_tuple=True)[0]
            cols = torch.randint(0, T, (rows.numel(),), device=x.device)
            mask[rows, cols] = True
        return mask

    def forward(self, x, run_with_mask):
        """Encode ``x``; optionally mask encoded time steps before the transformer.

        Args:
            x: raw input of shape (B, in_ch, T_raw).
            run_with_mask: if truthy, mask a random subset of encoded time
                steps and also return predictions and the mask.

        Returns:
            If ``run_with_mask``: ``(preds, ctx, mask_bool)`` where ``preds``
            and ``ctx`` are (B, T_enc, dim) and ``mask_bool`` is (B, T_enc).
            Otherwise: ``ctx`` of shape (B, T_enc, dim).
        """
        # 1) Conv feature encoder: (B, in_ch, T_raw) -> (B, dim, T_enc)
        feats = self.feature_encoder(x)

        # 2) Optional convolutional positional encoding, still (B, dim, T_enc)
        if self.use_cpe:
            feats = self.cpe(feats)

        # 3) Channels-last layout for the transformer: (B, T_enc, dim)
        feats_t = feats.transpose(1, 2)

        if not run_with_mask:
            return self.context_encoder(feats_t)  # (B, T_enc, dim)

        # 4) Replace the masked time steps with the learned mask embedding.
        mask_bool = MaskedEncoderModel.random_mask(feats_t, self.mask_prob)
        masked_input = feats_t.clone()
        masked_input[mask_bool] = self.mask_embedding

        ctx = self.context_encoder(masked_input)  # (B, T_enc, dim)
        preds = self.pred_head(ctx)               # (B, T_enc, dim)
        return preds, ctx, mask_bool

class LatentDirichletRegression(nn.Module):
    """Convolutional head that outputs Dirichlet concentration parameters.

    A stack of Conv1d -> GroupNorm -> GELU layers is followed by a 1x1
    convolution with GELU, global average pooling over time, and two linear
    heads: one producing the normalised mean (softmax) and one producing the
    total concentration (softplus). Their product is returned.

    Args:
        input_dim: number of input channels.
        output_dim: number of Dirichlet components.
        kernel_sizes, strides, paddings: per-layer Conv1d hyper-parameters.
        channel_sizes: output channels of each convolutional layer.
    """

    def __init__(self,  input_dim, output_dim, kernel_sizes, channel_sizes, strides, paddings):
        super().__init__()
        self.net = nn.Sequential()
        in_channels = input_dim
        layer_specs = zip(kernel_sizes, strides, paddings)
        for idx, (kernel, stride, pad) in enumerate(layer_specs):
            out_channels = channel_sizes[idx]
            self.net.add_module(
                f"conv_{idx}",
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel,
                    stride=stride,
                    padding=pad,
                ),
            )
            # One group per channel: each channel is normalised on its own.
            self.net.add_module(
                f"norm_{idx}",
                nn.GroupNorm(num_groups=out_channels, num_channels=out_channels),
            )
            self.net.add_module(f"gelu_{idx}", nn.GELU())
            in_channels = out_channels

        last_channels = channel_sizes[-1]
        self.output_conv = nn.Sequential(
            nn.Conv1d(
                in_channels=last_channels,
                out_channels=last_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.GELU(),
        )

        self.output_alphas = nn.Linear(last_channels, output_dim)
        self.output_alpha0 = nn.Linear(last_channels, 1)

    def forward(self, x):
        """Map ``x`` (B, input_dim, T) to concentrations of shape (B, output_dim)."""
        features = self.output_conv(self.net(x))
        pooled = features.mean(dim=2)  # global average pooling over the time dimension
        mean_simplex = F.softmax(self.output_alphas(pooled), dim=-1)
        total_concentration = F.softplus(self.output_alpha0(pooled))
        return mean_simplex * total_concentration

In [None]:
# Directory holding the STEAD HDF5 chunks (relative path).
root = 'STEAD/'

# Chunks 2-3 feed training, chunk 4 feeds validation.
train_chunks = [f'{root}chunk{chunk}/chunk{chunk}.hdf5' for chunk in range(2, 4)]
val_chunks = [f'{root}chunk{chunk}/chunk{chunk}.hdf5' for chunk in range(4, 5)]

train_dataset = SteadDataset(train_chunks, channel_first=True)
val_dataset = SteadDataset(val_chunks, channel_first=True)

In [None]:
# Run on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Input / masking configuration.
input_dim = 3
p_mask = 0.15

# One entry per convolutional layer of the feature encoder.
feature_enc_kernel_sizes = [10, 8, 4]
feature_enc_strides = [5, 4, 2]
feature_enc_paddings = [5, 4, 2]

model = MaskedEncoderModel(
    in_ch=input_dim,
    dim=256,
    mask_prob=p_mask,
    feature_enc_kernel_sizes=feature_enc_kernel_sizes,
    feature_enc_strides=feature_enc_strides,
    feature_enc_paddings=feature_enc_paddings,
    context_n_layers=6,
    context_n_heads=4,
    context_ffn_dim=1024,
    context_dropout=0.1,
    use_cpe=True,
)
model = model.to(device)

In [None]:
# model.load_state_dict(torch.load('STEAD/maskedencoder_epoch10.pth', map_location=device ))

In [None]:
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

# Pretraining hyper-parameters.
num_epochs = 10
batch_size = 32
learning_rate = 3e-4

# num_workers=0: the original author notes this must be 0 for hdf5.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
steps_per_epoch = len(train_loader)

# The first 10% of all optimisation steps are used for linear warm-up.
warmup_steps = int(0.1 * num_epochs * steps_per_epoch)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)

# Linear warm-up followed by cosine annealing over the remaining steps.
warmup_scheduler = LinearLR(optimizer, start_factor=1e-3, total_iters=warmup_steps)
cosine_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=num_epochs * steps_per_epoch - warmup_steps,
    eta_min=1e-6,
)
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_steps],
)

In [None]:
# Weight of the variance/covariance regulariser in the pretraining loss.
lambda_var = 1e-3


def vicreg_reg(z, mask, eps=1e-4, gamma=1.0):
    """Variance + covariance regulariser computed on the unmasked tokens.

    Args:
        z: features of shape [B, T, C].
        mask: bool tensor of shape [B, T]; True marks masked tokens, which are
            excluded from the statistics.
        eps: constant added to the variance before the square root.
        gamma: target per-dimension standard deviation.

    Returns:
        Tuple ``(loss, std)``. ``loss`` is a scalar tensor equal to the
        variance term plus 0.01 times the covariance term; ``std`` has shape
        [C] and holds the per-dimension standard deviation of the unmasked
        tokens. When fewer than two tokens are unmasked the statistics are
        undefined, so a zero loss (still attached to ``z``'s graph) and a zero
        ``std`` are returned.
    """
    u = z[~mask]                           # [N, C] unmasked tokens only
    if u.shape[0] < 2:
        # Keep the (loss, std) contract so callers can always unpack two values.
        zero_loss = z.sum() * 0.0
        return zero_loss, z.new_zeros(z.shape[-1])

    u = u - u.mean(dim=0, keepdim=True)

    # Variance term: pull each dimension's std towards gamma.
    std = torch.sqrt(u.var(dim=0, unbiased=False) + eps)
    var_loss = 0.5 * torch.mean((std - gamma) ** 2)

    # Covariance term: penalise off-diagonal entries to decorrelate dimensions.
    N, C = u.shape
    cov = (u.T @ u) / (N - 1)              # [C, C]
    offdiag = cov - torch.diag(torch.diag(cov))
    cov_loss = (offdiag ** 2).mean()

    # 0.01 is the covariance weight chosen by the original author as a starting point.
    return var_loss + 0.01 * cov_loss, std

# Masked-prediction pretraining of the encoder.
# Explicitly enable training mode so dropout is active even if another cell
# switched the model to eval mode beforehand.
model.train()
for epoch in range(num_epochs):
    for i, (traces, p_arrivals, s_arrivals, coda_ends, event_names) in enumerate(train_loader):
        # Per-trace, per-channel standardisation along the time axis.
        traces_mean = traces.mean(dim=2, keepdim=True)
        traces_std = traces.std(dim=2, keepdim=True) + 1e-9
        normalized_traces = (traces - traces_mean) / traces_std  # normalize input traces

        optimizer.zero_grad()
        ctx_preds, ctx, mask_bool = model(normalized_traces.to(device), run_with_mask=True)

        # Regress the prediction head output onto the detached context features
        # at the masked positions.
        # NOTE(review): the target is the detached input of the prediction head
        # itself — confirm this is the intended target.
        masked_preds = ctx_preds[mask_bool]
        masked_targets = ctx.detach()[mask_bool]
        loss = F.mse_loss(masked_preds, masked_targets)

        # Variance/covariance regulariser on the unmasked context features.
        reg_loss, unmasked_ctx_std = vicreg_reg(ctx, mask_bool)
        loss += lambda_var * reg_loss

        loss.backward()
        optimizer.step()
        scheduler.step()  # the schedule is defined per optimisation step

        if(i % 250 == 0):
            # Report progress against the loader that is actually being iterated.
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.1e}")
            with torch.no_grad():
                print(unmasked_ctx_std.cpu().mean().item())

    # One checkpoint per epoch.
    torch.save(model.state_dict(), f'STEAD/maskedencoder_epoch{epoch+1}.pth')

In [None]:
# Sanity check of the encoder on a single validation batch.
val_loader = DataLoader(val_dataset, batch_size=1000, shuffle=False, num_workers=0) # must be 0 for hdf5

# Disable dropout so the inspected features are deterministic.
model.eval()
with torch.no_grad():
    for i, (traces, p_arrivals, s_arrivals, coda_ends, event_names) in enumerate(val_loader):
        # Same per-trace standardisation as used during training.
        traces_mean = traces.mean(dim=2, keepdim=True)
        traces_std = traces.std(dim=2, keepdim=True) + 1e-9
        normalized_traces = (traces - traces_mean) / traces_std
        ctx = model(normalized_traces.to(device), run_with_mask=False)
        print(ctx)
        # Mean per-dimension variance over batch and time; values near zero
        # would indicate collapsed representations.
        print(ctx.var(dim=(0,1)).mean())
        break  # only the first batch is inspected

In [None]:
# Recompute the context features for the batch currently held in
# `normalized_traces` (left over from whichever cell assigned it last).
# eval() disables dropout so the features are deterministic.
model.eval()
with torch.no_grad():
    ctx = model(normalized_traces.to(device), run_with_mask=False)

In [None]:
# One figure per sample: the context features averaged over the time axis,
# plotted against the feature dimension.
for sample_ctx in ctx:
    time_averaged = sample_ctx.cpu().numpy().mean(axis=0)
    plt.plot(time_averaged)
    plt.show()

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Hyper-parameters of the interval (Dirichlet) head.
# NOTE: this cell rebinds `num_epochs`, `batch_size`, `learning_rate`,
# `steps_per_epoch` and `scheduler`, which the pretraining cells also use.
num_epochs = 10
batch_size = 32
learning_rate = 1e-3

# NOTE(review): `dataset` is not defined in the neighbouring cells (only
# `train_dataset` / `val_dataset` are) — confirm it exists on a fresh kernel.
# num_workers=0: the original author notes this must be 0 for hdf5.
interval_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
steps_per_epoch = len(interval_dataloader)

# Head mapping 256-channel context features to 4 Dirichlet concentrations.
interval_model = LatentDirichletRegression(
    input_dim=256,
    output_dim=4,
    kernel_sizes=[3, 3, 3],
    channel_sizes=[128, 64, 32],
    strides=[2, 2, 2],
    paddings=[1, 1, 1],
)
interval_model = interval_model.to(device)

optimizer_interval = optim.Adam(
    interval_model.parameters(),
    lr=learning_rate,
    weight_decay=1e-5,
)

# Cosine schedule with warm restarts; the first cycle spans two epochs and
# every following cycle is twice as long as the previous one.
scheduler = CosineAnnealingWarmRestarts(
    optimizer_interval,
    T_0=2 * steps_per_epoch,
    T_mult=2,
    eta_min=1e-6,
)

In [None]:
# Restore the interval head from the epoch-10 checkpoint, mapping tensors onto `device`.
interval_state = torch.load('STEAD/latdirichlet_epoch10.pth', map_location=device)
interval_model.load_state_dict(interval_state)

In [None]:
# Train the Dirichlet interval head on top of the frozen encoder.
# The encoder only provides features here, so run it in eval mode (no dropout);
# the head is the module being optimised.
model.eval()
interval_model.train()
for epoch in range(num_epochs):
    for i, (traces, p_arrivals, s_arrivals, coda_ends, event_names) in enumerate(interval_dataloader):
        # Per-trace, per-channel standardisation along the time axis.
        traces_mean = traces.mean(dim=2, keepdim=True)
        traces_std = traces.std(dim=2, keepdim=True) + 1e-9
        normalized_traces = (traces - traces_mean) / traces_std  # normalize input traces
        num_timesteps = traces.size(-1)

        # No gradients flow into the encoder.
        with torch.no_grad():
            ctx = model(normalized_traces.to(device), run_with_mask=False)  # (B, T_enc, dim)
            ctx_t = ctx.transpose(1, 2)  # (B, dim, T_enc)

        optimizer_interval.zero_grad()
        alphas = interval_model(ctx_t)  # (B, 4) Dirichlet concentrations
        dist = torch.distributions.Dirichlet(alphas + 1e-9)

        # Target: the four consecutive segment lengths as fractions of the
        # trace length (start->P, P->S, S->coda end, coda end->end).
        # NOTE(review): assumes the arrival values are sample indices within
        # the trace — confirm against the dataset.
        s1 = p_arrivals/num_timesteps
        s2 = s_arrivals/num_timesteps - s1
        s3 = coda_ends/num_timesteps - s1 - s2
        s4 = 1.0 - (s1 + s2 + s3)
        target = torch.stack([s1, s2, s3, s4], dim=-1)

        # Negative log-likelihood of the target under the predicted Dirichlet.
        loss = -dist.log_prob(target.to(device)).mean()
        loss.backward()
        optimizer_interval.step()
        scheduler.step()
        if(i % 100 == 0):
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(interval_dataloader)}], Loss: {loss.item():.1e}")
    # One checkpoint per epoch.
    torch.save(interval_model.state_dict(), f'STEAD/latdirichlet_epoch{epoch+1}.pth')

In [None]:
# Collect targets, predicted means, concentrations and input traces for the
# first six batches of `interval_dataloader`.
# Both networks are put in eval mode so dropout does not perturb the predictions.
model.eval()
interval_model.eval()
with torch.no_grad():
    trues = []
    preds = []
    pred_alphas = []
    ntraces = []
    for i, (traces, p_arrivals, s_arrivals, coda_ends, event_names) in enumerate(interval_dataloader):
        # Same per-trace standardisation as used during training.
        traces_mean = traces.mean(dim=2, keepdim=True)
        traces_std = traces.std(dim=2, keepdim=True) + 1e-9
        normalized_traces = (traces - traces_mean) / traces_std  # normalize input traces
        num_timesteps = traces.size(-1)

        ctx = model(normalized_traces.to(device), run_with_mask=False)  # (B, T_enc, dim)
        ctx_t = ctx.transpose(1, 2)  # (B, dim, T_enc)
        alphas = interval_model(ctx_t)  # (B, 4) Dirichlet concentrations
        dist = torch.distributions.Dirichlet(alphas + 1e-9)

        # Segment lengths as fractions of the trace length (same construction
        # as in the training loop).
        s1 = p_arrivals/num_timesteps
        s2 = s_arrivals/num_timesteps - s1
        s3 = coda_ends/num_timesteps - s1 - s2
        s4 = 1.0 - (s1 + s2 + s3)
        target = torch.stack([s1, s2, s3, s4], dim=-1)

        trues.append(target.cpu().numpy())
        preds.append(dist.mean.cpu().numpy())
        # Store numpy arrays, like the other lists, before concatenating.
        ntraces.append(normalized_traces.cpu().numpy())
        pred_alphas.append(alphas.cpu().numpy())
        if(i>=5):
            break
    trues = np.concatenate(trues, axis=0)
    preds = np.concatenate(preds, axis=0)
    ntraces = np.concatenate(ntraces, axis=0)
    pred_alphas = np.concatenate(pred_alphas, axis=0)

In [None]:
import scipy.stats as stats

# For up to the first 21 collected samples, overlay the true segment
# boundaries (red lines) on the first channel of the trace, and show
# histograms of boundaries sampled from the predicted Dirichlet on a twin axis.
for i in range(min(trues.shape[0], 21)):
    # Cumulative sums turn segment fractions into boundary positions in [0, 1].
    true_times = np.cumsum(trues[i])
    pred_times = np.cumsum(preds[i])

    distr = torch.distributions.Dirichlet(torch.from_numpy(pred_alphas[i]))
    pred_samples = np.cumsum(distr.sample((10000,)).numpy(), axis=-1)

    time_axis = np.linspace(0, 1, ntraces.shape[-1])
    plt.plot(time_axis, ntraces[i, 0, :], color='gray')
    trace_ax = plt.gca()
    y_min, y_max = trace_ax.get_ylim()
    # The last cumulative value is the end of the trace, so it is not drawn.
    trace_ax.vlines(true_times[:-1], y_min, y_max, color='red')

    ax2 = trace_ax.twinx()
    for boundary in range(3):
        ax2.hist(pred_samples[:, boundary], bins=100, density=True, alpha=0.5)
    plt.show()