In [121]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import obspy
from obspy.clients.fdsn import Client
from obspy import UTCDateTime
import matplotlib.pyplot as plt
import h5py
import pandas as pd

In [None]:
class SteadDataset(Dataset):
    """Map-style dataset over one or more STEAD-style HDF5 chunks.

    Every chunk is an HDF5 file whose waveforms are read from
    ``data/<trace_name>``, paired with a CSV file (the chunk path with
    ``hdf5`` replaced by ``csv``) whose ``trace_name`` column lists the traces
    of that chunk. All chunks are exposed through one flat index.

    Args:
        chunk_files: Iterable of paths to the HDF5 chunk files.
        channel_first: If True, each trace is returned with its two axes
            swapped (``trace.transpose(1, 0)``).

    The HDF5 files are opened in ``__init__`` and stay open until ``close()``
    is called.
    """

    def __init__(self, chunk_files, channel_first):
        self.files = []
        self.event_lists = []
        cumulative_counts = []
        total = 0
        for chunk in chunk_files:
            # Read the CSV first so a missing metadata file does not leave an
            # HDF5 handle open.
            metadata = pd.read_csv(chunk.replace('hdf5', 'csv'))
            ev_list = metadata['trace_name'].astype('str').to_list()
            file = h5py.File(chunk, 'r')
            self.files.append(file)
            self.event_lists.append(ev_list)
            total += len(ev_list)
            cumulative_counts.append(total)
        # stopping_indices[k] is the number of samples in chunks 0..k, i.e. the
        # first flat index that no longer belongs to chunk k.
        self.stopping_indices = np.array(cumulative_counts, dtype=np.int64)
        self.channel_first = channel_first

    def __len__(self):
        """Total number of traces across all chunks."""
        if len(self.stopping_indices) == 0:
            return 0
        return int(self.stopping_indices[-1])

    @staticmethod
    def _attr_to_float(value):
        """Convert an HDF5 attribute value to a Python float (empty -> NaN)."""
        if isinstance(value, bytes):
            value = value.decode('utf-8', errors='replace')
        if isinstance(value, str):
            value = value.strip()
            return float(value) if value else float('nan')
        arr = np.asarray(value)
        if arr.size == 0:
            return float('nan')
        item = arr.item()  # raises ValueError for multi-element attributes
        if isinstance(item, (str, bytes)):
            return SteadDataset._attr_to_float(item)
        return float(item)

    def __getitem__(self, idx):
        """Return ``(trace, p_arrival, s_arrival, coda_end, event_name)``.

        The three arrival values are Python floats read from the
        ``p_arrival_sample``, ``s_arrival_sample`` and ``coda_end_sample``
        attributes; an empty attribute yields ``nan``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            KeyError: If the trace named in the CSV is absent from the HDF5 file.
        """
        n_items = len(self)
        idx = int(idx)
        if idx < 0:
            idx += n_items
        if not 0 <= idx < n_items:
            raise IndexError(f"index {idx} out of range for dataset of size {n_items}")

        # First chunk whose cumulative count exceeds idx.
        chunk_idx = int(np.searchsorted(self.stopping_indices, idx, side='right'))
        chunk_start = int(self.stopping_indices[chunk_idx - 1]) if chunk_idx > 0 else 0
        relative_idx = idx - chunk_start

        event_name = self.event_lists[chunk_idx][relative_idx]
        file = self.files[chunk_idx].get('data/' + event_name)
        if file is None:
            raise KeyError('data/' + event_name)
        trace = np.array(file)
        p_arrival = self._attr_to_float(file.attrs['p_arrival_sample'])
        s_arrival = self._attr_to_float(file.attrs['s_arrival_sample'])
        coda_end = self._attr_to_float(file.attrs['coda_end_sample'])
        if self.channel_first:
            trace = trace.transpose(1, 0)
        return trace, p_arrival, s_arrival, coda_end, event_name

    def close(self):
        """Close every HDF5 file opened by this dataset."""
        for file in self.files:
            file.close()
        self.files = []

# OLD

In [None]:
#time is the last dimension
class ConvFeatureEncoder(nn.Module):
    """Stack of ``Conv1d`` + ``GELU`` layers applied along the last (time) axis.

    One conv layer is created per ``(kernel_size, stride, padding)`` triple.
    The first layer maps ``in_ch`` channels to ``dim``; every later layer maps
    ``dim`` to ``dim``. Layers are registered as ``conv_<i>`` / ``gelu_<i>``.
    """

    def __init__(self, in_ch, dim, kernel_sizes, strides, paddings):
        super().__init__()
        layers = nn.Sequential()
        prev_channels = in_ch
        layer_specs = zip(kernel_sizes, strides, paddings)
        for layer_idx, (kernel, stride, pad) in enumerate(layer_specs):
            layers.add_module(
                f"conv_{layer_idx}",
                nn.Conv1d(prev_channels, dim, kernel_size=kernel, stride=stride, padding=pad),
            )
            layers.add_module(f"gelu_{layer_idx}", nn.GELU())
            # Every layer after the first consumes the previous layer's `dim` channels.
            prev_channels = dim
        self.net = layers

    def forward(self, x):
        """Run the conv stack on ``x`` and return its output."""
        encoded = self.net(x)
        return encoded
#time is the last dimension
class ConvPositionalEncoding(nn.Module):
    """Residual depthwise-convolution positional encoding over the last axis.

    Args:
        channels: Number of input (and output) channels; the convolution uses
            ``groups=channels``, i.e. one filter per channel.
        kernel_size: Width of the convolution kernel. Both odd and even sizes
            are supported.
    """

    def __init__(self, channels, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv1d(
            channels, channels,
            kernel_size=kernel_size,
            groups=channels,
            padding=kernel_size // 2,
            bias=True
        )

    def forward(self, x):
        """Return ``x + conv(x)`` with the same shape as ``x``."""
        pos = self.conv(x)
        # With padding = kernel_size // 2 an even kernel yields one extra time
        # step; drop it so the residual add lines up. No-op for odd kernels.
        pos = pos[..., : x.size(-1)]
        return x + pos
# channel is the last dimension
class ContextEncoder(nn.Module):
    """Transformer encoder over batch-first sequences.

    Wraps ``nn.TransformerEncoder`` built from ``n_layers`` pre-norm
    (``norm_first=True``), GELU-activated encoder layers with
    ``batch_first=True``.
    """

    def __init__(self, dim, n_layers, n_heads, ffn_dim, dropout):
        super().__init__()
        layer_config = dict(
            d_model=dim,
            nhead=n_heads,
            dim_feedforward=ffn_dim,
            dropout=dropout,
            activation='gelu',
            norm_first=True,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_config),
            num_layers=n_layers,
        )

    def forward(self, x):
        """Apply the transformer stack to ``x`` and return the result."""
        encoded = self.transformer(x)
        return encoded
    
class MaskedEncoderModel(nn.Module):
    """Convolutional feature encoder + transformer with masked-step prediction.

    Pipeline: ``ConvFeatureEncoder`` -> optional ``ConvPositionalEncoding`` ->
    transpose to (B, T_enc, dim) -> ``ContextEncoder``. In masked mode a random
    subset of time steps is replaced by a learned ``mask_embedding`` before the
    context encoder, and ``pred_head`` is applied to the context output.

    Args:
        in_ch: Number of input channels of the raw signal.
        dim: Feature width used by all sub-modules.
        mask_prob: Per-time-step probability of masking.
        feature_enc_kernel_sizes, feature_enc_strides, feature_enc_paddings:
            Per-layer settings forwarded to ``ConvFeatureEncoder``.
        context_n_layers, context_n_heads, context_ffn_dim, context_dropout:
            Settings forwarded to ``ContextEncoder``.
        use_cpe: Whether to apply the convolutional positional encoding.
    """

    def __init__(
        self,
        in_ch,
        dim,
        mask_prob,
        feature_enc_kernel_sizes,
        feature_enc_strides,
        feature_enc_paddings,
        context_n_layers,
        context_n_heads,
        context_ffn_dim,
        context_dropout=0.1,
        use_cpe=True
    ):
        super().__init__()
        self.mask_prob = mask_prob
        self.feature_encoder = ConvFeatureEncoder(
            in_ch=in_ch,
            dim=dim,
            kernel_sizes=feature_enc_kernel_sizes,
            strides=feature_enc_strides,
            paddings=feature_enc_paddings,
        )
        self.use_cpe = use_cpe
        if use_cpe:
            self.cpe = ConvPositionalEncoding(dim, kernel_size=3)

        self.context_encoder = ContextEncoder(
            dim=dim,
            n_layers=context_n_layers,
            n_heads=context_n_heads,
            ffn_dim=context_ffn_dim,
            dropout=context_dropout
        )

        # Learned vector substituted for masked time steps.
        self.mask_embedding = nn.Parameter(torch.zeros(dim))
        nn.init.normal_(self.mask_embedding, mean=0.0, std=0.02)

        self.pred_head = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    @staticmethod
    def random_mask(x, mask_p):
        """Sample a boolean mask of shape (B, T) for ``x`` of shape (B, T, D).

        Each position is masked independently with probability ``mask_p``.
        A row that ends up with no masked position gets exactly one position,
        chosen uniformly at random, so every sample contributes to the loss.
        The mask is created on ``x.device``.
        """
        B, T, D = x.shape
        mask = torch.rand(B, T, device=x.device) < mask_p
        if T > 0:
            empty_rows = (~mask.any(dim=1)).nonzero(as_tuple=True)[0]
            if empty_rows.numel() > 0:
                fallback_cols = torch.randint(0, T, (empty_rows.numel(),), device=x.device)
                mask[empty_rows, fallback_cols] = True
        return mask

    def forward(self, x, run_with_mask):
        """Encode ``x`` of shape (B, in_ch, T_raw).

        Returns:
            If ``run_with_mask`` is truthy: ``(preds, ctx, mask_bool)`` where
            ``ctx`` (B, T_enc, dim) is the context-encoder output for the
            *masked* input, ``preds`` is ``pred_head(ctx)`` and ``mask_bool``
            (B, T_enc) is True at masked steps.
            Otherwise: ``ctx`` (B, T_enc, dim) for the unmasked input.
        """
        # 1) Conv feature encoder: (B, in_ch, T_raw) -> (B, dim, T_enc)
        feats = self.feature_encoder(x)

        # 2) Optional convolutional positional encoding: (B, dim, T_enc)
        if self.use_cpe:
            feats = self.cpe(feats)

        # 3) Prepare for transformer: (B, dim, T_enc) -> (B, T_enc, dim)
        feats_t = feats.transpose(1, 2)

        if run_with_mask:
            # 4) Replace masked steps with the learned mask embedding.
            mask_bool = MaskedEncoderModel.random_mask(feats_t, self.mask_prob)
            masked_input = torch.where(
                mask_bool.unsqueeze(-1),
                self.mask_embedding.to(feats_t.dtype),
                feats_t,
            )
            ctx = self.context_encoder(masked_input)  # (B, T_enc, dim)
            preds = self.pred_head(ctx)               # (B, T_enc, dim)
            return preds, ctx, mask_bool

        ctx = self.context_encoder(feats_t)  # (B, T_enc, dim)
        return ctx

class LatentDirichletRegression(nn.Module):
    """Conv head mapping a (B, input_dim, T) sequence to Dirichlet concentrations.

    A stack of ``Conv1d`` -> ``GroupNorm`` (one group per channel) -> ``GELU``
    layers is followed by a 1x1 conv + GELU and a mean over the time axis.
    Two linear layers then produce the output: a softmax over ``output_dim``
    scores scaled by a softplus-activated scalar, so each row of the result
    sums to that scalar.
    """

    def __init__(self,  input_dim, output_dim, kernel_sizes, channel_sizes, strides, paddings):
        super().__init__()
        self.net = nn.Sequential()
        prev_width = input_dim
        for idx, (kernel, stride, pad) in enumerate(zip(kernel_sizes, strides, paddings)):
            width = channel_sizes[idx]
            self.net.add_module(
                f"conv_{idx}",
                nn.Conv1d(
                    in_channels=prev_width,
                    out_channels=width,
                    kernel_size=kernel,
                    stride=stride,
                    padding=pad,
                ),
            )
            self.net.add_module(f"norm_{idx}", nn.GroupNorm(num_groups=width, num_channels=width))
            self.net.add_module(f"gelu_{idx}", nn.GELU())
            prev_width = width

        final_width = channel_sizes[-1]
        pointwise = nn.Conv1d(
            in_channels=final_width,
            out_channels=final_width,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.output_conv = nn.Sequential(pointwise, nn.GELU())

        self.output_alphas = nn.Linear(final_width, output_dim)
        self.output_alpha0 = nn.Linear(final_width, 1)

    def forward(self, x):
        """Return concentrations of shape (B, output_dim)."""
        hidden = self.output_conv(self.net(x))
        pooled = hidden.mean(dim=2)  # global average pooling over the time axis
        proportions = F.softmax(self.output_alphas(pooled), dim=-1)
        total = F.softplus(self.output_alpha0(pooled))
        return proportions * total

In [None]:
# Dataset split by chunk number: chunks 2-3 for training, chunk 4 for validation.
root = 'STEAD/'


def _chunk_path(chunk):
    """Path of the HDF5 file for chunk number ``chunk`` under ``root``."""
    return root + f'chunk{chunk}/chunk{chunk}.hdf5'


train_chunks = list(map(_chunk_path, range(2, 4)))
val_chunks = list(map(_chunk_path, range(4, 5)))
train_dataset = SteadDataset(train_chunks, channel_first=True)
val_dataset = SteadDataset(val_chunks, channel_first=True)

In [None]:
# Masked-encoder hyperparameters. The three feature-encoder strides (5, 4, 2)
# multiply to an overall temporal downsampling of roughly 40x.
input_dim = 3
feature_enc_kernel_sizes = [10, 8, 4]
feature_enc_strides = [5, 4, 2]
feature_enc_paddings = [5, 4, 2]
p_mask = 0.15

# Use the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = MaskedEncoderModel(
    in_ch=input_dim,
    dim=256,
    mask_prob=p_mask,
    feature_enc_kernel_sizes=feature_enc_kernel_sizes,
    feature_enc_strides=feature_enc_strides,
    feature_enc_paddings=feature_enc_paddings,
    context_n_layers=6,
    context_n_heads=4,
    context_ffn_dim=1024,
    context_dropout=0.1,
    use_cpe=True,
)
model = model.to(device)

In [None]:
# model.load_state_dict(torch.load('STEAD/maskedencoder_epoch10.pth', map_location=device ))

In [None]:
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

# Pretraining schedule: AdamW with a linear warm-up over the first 10% of all
# optimisation steps, then cosine annealing down to 1e-6 for the remainder.
num_epochs = 10
batch_size = 32
learning_rate = 3e-4

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0) # must be 0 for hdf5
steps_per_epoch = len(train_loader)
warmup_steps = int(0.1 * num_epochs * steps_per_epoch)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)

warmup_scheduler = LinearLR(optimizer, start_factor=1e-3, total_iters=warmup_steps)
cosine_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=num_epochs * steps_per_epoch - warmup_steps,
    eta_min=1e-6,
)
# The scheduler is stepped once per batch, so the milestone is in steps.
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_steps],
)

In [None]:
lambda_var = 1e-3  # weight of the regulariser in the total pretraining loss


def vicreg_reg(z, mask, eps=1e-4, gamma=1.0):
    """Variance/covariance regulariser computed on the unmasked positions of ``z``.

    Args:
        z: Tensor of shape [B, T, C].
        mask: Boolean tensor of shape [B, T]; True marks masked positions,
            which are excluded.
        eps: Added to the per-dimension variance before the square root.
        gamma: Target standard deviation per dimension.

    Returns:
        ``(loss, std)``: ``loss`` is ``var_loss + 0.01 * cov_loss`` and ``std``
        is the per-dimension standard deviation, shape [C]. With fewer than
        two unmasked positions the statistics are undefined: ``loss`` is a
        zero that stays attached to ``z``'s graph and ``std`` is filled with NaN.
    """
    u = z[~mask]                           # [N, C] unmasked vectors
    if u.shape[0] < 2:
        # Same tuple shape as the normal path so callers can always unpack.
        return z.sum() * 0.0, z.new_full((z.shape[-1],), float('nan'))

    u = u - u.mean(dim=0, keepdim=True)

    # Variance term: pull each dimension's std towards gamma.
    std = torch.sqrt(u.var(dim=0, unbiased=False) + eps)
    var_loss = 0.5*torch.mean((std-gamma)**2)

    # Covariance term: penalise off-diagonal covariance to decorrelate dims.
    N, C = u.shape
    cov = (u.T @ u) / (N - 1)              # [C, C]
    offdiag = cov - torch.diag(torch.diag(cov))
    cov_loss = (offdiag**2).mean()

    return var_loss + 0.01 * cov_loss, std      # 0.01 is a decent start

# Masked-prediction pretraining: MSE between the prediction head's output and
# the (detached) context output at masked steps, plus the weighted regulariser.
# A checkpoint is written after every epoch.
model.train()  # make sure dropout is active regardless of the mode the model was left in
for epoch in range(num_epochs):
    for i, (traces, p_arrivals, s_arrivals, coda_ends, event_names) in enumerate(train_loader):
        # Standardise every channel of every trace over the time axis (dim=2).
        traces_mean = traces.mean(dim=2, keepdim=True)
        traces_std = traces.std(dim=2, keepdim=True) + 1e-9
        normalized_traces = (traces - traces_mean) / traces_std
        optimizer.zero_grad()
        ctx_preds, ctx, mask_bool = model(normalized_traces.to(device), run_with_mask=True)
        masked_preds = ctx_preds[mask_bool]
        # NOTE(review): the target is the context output of this same masked
        # forward pass, not the pre-mask encoder features that the model's
        # forward() labels "encoded targets" -- confirm this is the intended
        # objective.
        masked_targets = ctx.detach()[mask_bool]
        loss = F.mse_loss(masked_preds, masked_targets)
        reg_loss, unmasked_ctx_std = vicreg_reg(ctx, mask_bool)
        loss = loss + lambda_var * reg_loss
        loss.backward()
        optimizer.step()
        scheduler.step()  # per-batch schedule

        if i % 250 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.1e}")
            with torch.no_grad():
                # Mean per-dimension std of the unmasked context vectors.
                print(unmasked_ctx_std.cpu().mean().item())

    torch.save(model.state_dict(), f'STEAD/maskedencoder_epoch{epoch+1}.pth')

In [None]:
# Sanity check on one validation batch: print the context features and their
# variance over batch and time, averaged over feature dimensions.
val_loader = DataLoader(val_dataset, batch_size=1000, shuffle=False, num_workers=0) # must be 0 for hdf5
model.eval()  # inference: disable dropout so the features are deterministic
with torch.no_grad():
    for i, (traces, p_arrivals, s_arrivals, coda_ends, event_names) in enumerate(val_loader):
        traces_mean = traces.mean(dim=2, keepdim=True)
        traces_std = traces.std(dim=2, keepdim=True) + 1e-9
        normalized_traces = (traces - traces_mean) / traces_std
        ctx = model(normalized_traces.to(device), run_with_mask=False)
        print(ctx)
        print(ctx.var(dim=(0,1)).mean())
        break  # only the first batch is inspected

In [None]:
# Recompute the context features for the most recent normalised batch.
model.eval()  # inference: disable dropout
with torch.no_grad():
    ctx = model(normalized_traces.to(device), run_with_mask=False)

In [None]:
# One figure per sample in the batch: the sample's context features averaged
# over axis 0 (the time-step axis of a (T_enc, dim) slice).
n_samples = ctx.size(0)
for idx in range(n_samples):
    sample_ctx = ctx[idx].cpu().numpy()
    plt.plot(sample_ctx.mean(axis=0))
    plt.show()

In [None]:
# Setup for the Dirichlet interval head trained on top of the encoder features.
# Note: `num_epochs`, `batch_size`, `learning_rate` and `scheduler` overwrite
# the pretraining values of the same names.
num_epochs = 10
batch_size = 32
learning_rate = 1e-3
# `dataset` was never defined in this notebook; train on the training split.
interval_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0) # must be 0 for hdf5
interval_model = LatentDirichletRegression(
    input_dim=256,   # must match the encoder's feature width
    output_dim=4,    # four interval fractions per trace
    kernel_sizes=[3,3,3],
    channel_sizes=[128,64,32],
    strides=[2,2,2],
    paddings=[1,1,1],).to(device)
optimizer_interval = optim.Adam(interval_model.parameters(), lr=learning_rate, weight_decay=1e-5)
steps_per_epoch = len(interval_dataloader)
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
# Stepped once per batch: first restart after 2 epochs, each cycle twice as long.
scheduler = CosineAnnealingWarmRestarts(
    optimizer_interval,
    T_0=2 * steps_per_epoch,
    T_mult=2,
    eta_min=1e-6
)

In [None]:
# Restore the interval head from the epoch-10 checkpoint. The file follows the
# naming pattern written by the interval training loop, so it must already
# exist from an earlier run.
interval_checkpoint = torch.load('STEAD/latdirichlet_epoch10.pth', map_location=device )
interval_model.load_state_dict(interval_checkpoint)

In [None]:
# Train the Dirichlet interval head on frozen encoder features by minimising
# the negative log-likelihood of the four interval fractions.
TARGET_EPS = 1e-6  # smallest allowed fraction; keeps targets strictly inside the simplex

model.eval()            # encoder is only a feature extractor here: disable dropout
interval_model.train()
for epoch in range(num_epochs):
    for i, (traces, p_arrivals, s_arrivals, coda_ends, event_names) in enumerate(interval_dataloader):
        traces_mean = traces.mean(dim=2, keepdim=True)
        traces_std = traces.std(dim=2, keepdim=True) + 1e-9
        normalized_traces = (traces - traces_mean) / traces_std  # normalize input traces
        num_timesteps = traces.size(-1)

        # Fractions of the trace before P, between P and S, between S and the
        # coda end, and after the coda end; they sum to 1 by construction.
        s1 = p_arrivals/num_timesteps
        s2 = s_arrivals/num_timesteps - s1
        s3 = coda_ends/num_timesteps - s1 - s2
        s4 = 1.0 - (s1 + s2 + s3)
        target = torch.stack([s1, s2, s3, s4], dim=-1)

        # Drop samples with missing (NaN/inf) picks, and nudge zero or negative
        # fractions into the open simplex: the Dirichlet log-density is
        # undefined or infinite on and outside its boundary.
        valid = torch.isfinite(target).all(dim=-1)
        if not valid.any():
            continue
        target = target[valid].clamp_min(TARGET_EPS)
        target = target / target.sum(dim=-1, keepdim=True)

        with torch.no_grad():
            ctx = model(normalized_traces[valid].to(device), run_with_mask=False)  # (B, T_enc, dim)
            ctx_t = ctx.transpose(1, 2)  # (B, dim, T_enc)
        optimizer_interval.zero_grad()
        alphas = interval_model(ctx_t)  # (B, 4) Dirichlet concentrations
        dist = torch.distributions.Dirichlet(alphas + 1e-9)
        loss = -dist.log_prob(target.to(device)).mean()
        loss.backward()
        optimizer_interval.step()
        scheduler.step()
        if i % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(interval_dataloader)}], Loss: {loss.item():.1e}")
    torch.save(interval_model.state_dict(), f'STEAD/latdirichlet_epoch{epoch+1}.pth')

In [None]:
# Collect true interval fractions, predicted Dirichlet means/concentrations and
# the normalised traces for the first six batches.
# NOTE(review): this iterates `interval_dataloader`, the same loader the head
# is trained on -- switch to a held-out loader for a real evaluation.
model.eval()           # inference: disable dropout in the encoder
interval_model.eval()
with torch.no_grad():
    trues = []
    preds = []
    pred_alphas = []
    ntraces = []
    for i, (traces, p_arrivals, s_arrivals, coda_ends, event_names) in enumerate(interval_dataloader):
        traces_mean = traces.mean(dim=2, keepdim=True)
        traces_std = traces.std(dim=2, keepdim=True) + 1e-9
        normalized_traces = (traces - traces_mean) / traces_std  # normalize input traces
        num_timesteps = traces.size(-1)
        ctx = model(normalized_traces.to(device), run_with_mask=False)  # (B, T_enc, dim)
        ctx_t = ctx.transpose(1, 2)  # (B, dim, T_enc)
        alphas = interval_model(ctx_t)  # (B, 4) Dirichlet concentrations
        dist = torch.distributions.Dirichlet(alphas + 1e-9)
        s1 = p_arrivals/num_timesteps
        s2 = s_arrivals/num_timesteps - s1
        s3 = coda_ends/num_timesteps - s1 - s2
        s4 = 1.0 - (s1 + s2 + s3)
        target = torch.stack([s1, s2, s3, s4], dim=-1)
        trues.append(target.cpu().numpy())
        preds.append(dist.mean.cpu().numpy())
        # Convert explicitly so np.concatenate below receives NumPy arrays.
        ntraces.append(normalized_traces.cpu().numpy())
        pred_alphas.append(alphas.cpu().numpy())
        if i >= 5:
            break
    trues = np.concatenate(trues, axis=0)
    preds = np.concatenate(preds, axis=0)
    ntraces = np.concatenate(ntraces, axis=0)
    pred_alphas = np.concatenate(pred_alphas, axis=0)

In [None]:
import scipy.stats as stats

# For up to 21 samples: first channel of the normalised trace (gray) with the
# true interval boundaries (red), overlaid with histograms of the boundary
# positions sampled from the predicted Dirichlet. The x axis is the fraction of
# the trace length.
n_plots = min(trues.shape[0], 21)
for i in range(n_plots):
    # Cumulative sums turn interval fractions into boundary positions.
    true_times = np.cumsum(trues[i])
    # Same 1e-9 offset as wherever the Dirichlet is built from the model output,
    # so a concentration that underflowed to 0 does not raise here.
    distr = torch.distributions.Dirichlet(torch.from_numpy(pred_alphas[i]) + 1e-9)
    pred_samples = np.cumsum(distr.sample((10000,)).numpy(), axis=-1)

    plt.plot(np.linspace(0,1, ntraces.shape[-1]), ntraces[i,0,:], color='gray')
    y_min, y_max = plt.gca().get_ylim()
    # The last cumulative value is the end of the trace (1.0), not a boundary.
    plt.vlines(true_times[:-1], y_min, y_max, color='red')
    ax2 = plt.gca().twinx()
    for boundary in range(pred_samples.shape[-1] - 1):
        ax2.hist(pred_samples[:,boundary], bins=100, density=True, alpha=0.5)
    plt.show()