In [2]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt

In [3]:
import matplotlib.colors as mcolors

# Create a ListedColormap: 0 -> transparent, 1 -> orange
cmap_orange_transparent = mcolors.ListedColormap([(1,1,1,0), (1,0.5,0,1)])  # 0: transparent, 1: orange

In [4]:
### Parameters for DAS 
sample_rate = 25
dchan = 9.5714
ch_max = 4500  # max channel of each cable (4500 or 6000)
ch_itv=2  # channels are downsampled for faster picking

### Directories and files
raw_dir = '/fd1/QibinShi_data/akdas/qibin_data/'
out_dir = raw_dir + 'largerEQ_plots_test_picking_dec_ch' + str(ch_max) + '/'
record_time_file = 'recording_times_larger.csv'
qml = raw_dir + 'ak_Dec1_31_a120b065.xml'

In [5]:
### Read phase picks from the previous session
with h5py.File(out_dir + 'phase_picks.hdf5', 'r') as f: #filepaths should be 
    raw_picks = f["raw_alldata_picks"][:]
    one_picks = f["one_denoise_picks"][:]
    mul_picks = f["mul_denoise_picks"][:]
    pred_picks = f["predicted_picks"][:]
    array_dist = f["array_dist"][:]
    
### Read raw and denoised DAS
with h5py.File(raw_dir + 'KKFLStill2024_02_24.hdf5', 'r') as f:
    raw_quake_kkfls = f["raw_quake"][:, :4500, :] # original could be 500:5000, check 
    fk_quake_kkfls = f["fk_quake"][:, :4500, :]

with h5py.File(raw_dir + 'TERRAtill2024_02_24.hdf5', 'r') as f:
    raw_quake_terra = f["raw_quake"][:, :4500, :]
    fk_quake_terra = f["fk_quake"][:, :4500, :]


In [6]:
raw_quake_kkfls.shape
raw_quake_terra.shape

#flip the terra data on the channel axis, the number of ea
quakes = np.concatenate((raw_quake_kkfls[:, ::-1, :], raw_quake_terra), axis=1)


In [7]:
### Bandpass filter
b, a = butter(4, (0.5, 12), fs=sample_rate, btype='bandpass')
filt = filtfilt(b, a, quakes, axis=2)
rawdata = filt / np.std(filt, axis=(1,2), keepdims=True)  ## Rawdata w.r.t. Denoised 

In [8]:
#import the masks generated by previous notebook
p_s_quake_masks = np.load("/home/arose17/FM_Segmentation_DAS/src/data/pick_masking/cleaned_picks_with_data/alex_verified_06042025/masks/p_s_quake_masks_06042025.npy")
p_wave_masks = np.load("/home/arose17/FM_Segmentation_DAS/src/data/pick_masking/cleaned_picks_with_data/alex_verified_06042025/masks/mask_p_waves_06042025.npy")
s_wave_masks = np.load("/home/arose17/FM_Segmentation_DAS/src/data/pick_masking/cleaned_picks_with_data/alex_verified_06042025/masks/mask_s_waves_06042025.npy")
#load the indices files
p_wave_indices = np.load("/home/arose17/FM_Segmentation_DAS/src/data/pick_masking/cleaned_picks_with_data/alex_verified_06042025/masks/p_wave_indices_06042025.npy")
s_wave_indices = np.load("/home/arose17/FM_Segmentation_DAS/src/data/pick_masking/cleaned_picks_with_data/alex_verified_06042025/masks/s_wave_indices_06042025.npy")
both_p_s_indices = np.load("/home/arose17/FM_Segmentation_DAS/src/data/pick_masking/cleaned_picks_with_data/alex_verified_06042025/masks/both_p_s_indices_06042025.npy")

#combined the indices into one array
combined_indices = np.concatenate((p_wave_indices, s_wave_indices, both_p_s_indices), axis=0)

#print(combined_indices.shape)
#print(combined_indices)

#combined the p_wave masks, s_wave masks and p_s_quake_masks into one array
combined_masks = np.concatenate((p_wave_masks, s_wave_masks, p_s_quake_masks), axis=0)

#create a for loop that takes only the i values from combined_indices and makes a new rawdata array with the earthquake data indexed in the specific order
new_rawdata = rawdata[combined_indices, :, :]

#duplicate every odd row and create an even row with the same data
new_combined = np.repeat(combined_masks, 2, axis=1)

print(new_rawdata.shape)
print(new_combined.shape)


(40, 9000, 1500)
(40, 9000, 1500)


In [9]:
#MODEL FROM (Shi et al., 2025), utilizing datalabel class for denodas_train
import torch
from torch import nn
import torch.nn.functional as F

import numpy as np
import gc


import torch
import torch.nn as nn

class unet(nn.Module):
    def __init__(self, ch_in, ch0, ch_max, factors=None, kernel=(3, 3), pad=(1, 1), use_att=False):
        super(unet, self).__init__()
        self.level = len(factors)
        self.factor = factors
        self.relu = nn.ReLU()
        self.kernel = kernel
        self.pad = pad
        self.use_att = use_att
        self.layer = nn.ModuleList([])

        if self.use_att:
            self.attgates = nn.ModuleList([])
            for i in range(self.level):
                nch = min(ch0 * 2 ** i, ch_max)
                self.attgates.append(AttentionGate(nch))

        for i in range(self.level + 1):
            if i == 0:
                nch_input = ch_in
            else:
                nch_input = nch_output
            nch_output = min(ch0 * 2 ** i, ch_max)
            self.layer.append(nn.Conv2d(nch_input, nch_output, self.kernel, padding=self.pad))
            self.layer.append(nn.Conv2d(nch_output, nch_output, self.kernel, padding=self.pad))
            if i > self.level - 2:
                self.layer.append(nn.Dropout(p=0.2))
            if i < self.level:
                self.layer.append(MaxBlurPool2d(nch_output, kernel_size=(self.factor[i], self.factor[i])))

        for i in range(self.level):
            nch_input = min(ch0 * 2 ** (self.level - i), ch_max)
            nch_output = min(ch0 * 2 ** (self.level - i - 1), ch_max)
            scale_factor = (self.factor[self.level - 1 - i], self.factor[self.level - 1 - i])
            self.layer.append(nn.Upsample(scale_factor=scale_factor, mode='nearest'))
            self.layer.append(nn.Conv2d(nch_input, nch_output, self.kernel, padding=self.pad))
            self.layer.append(nn.Conv2d(nch_input, nch_output, self.kernel, padding=self.pad))
            self.layer.append(nn.Conv2d(nch_output, nch_output, self.kernel, padding=self.pad))

        self.layer.append(nn.Conv2d(nch_output, ch_in, self.kernel, padding=self.pad))

        # Track decoder layer indices
        self.decoder_layer_indices = list(range(3 * self.level + 4, len(self.layer)))

        self.initialize_weights()

    def forward(self, x):
        cat_content = []
        if len(x.shape) == 2:
            x = x.unsqueeze(0)
        if len(x.shape) == 3:
            x = x.unsqueeze(1)

        for i in range(self.level - 1):
            x = self.layer[3 * i + 0](x)
            x = self.relu(x)
            x = self.layer[3 * i + 1](x)
            x = self.relu(x)
            cat_content.append(x)
            x = self.layer[3 * i + 2](x)

        x = self.layer[3 * (self.level - 1) + 0](x)
        x = self.relu(x)
        x = self.layer[3 * (self.level - 1) + 1](x)
        x = self.relu(x)
        x = self.layer[3 * (self.level - 1) + 2](x)
        cat_content.append(x)
        x = self.layer[3 * (self.level - 1) + 3](x)

        x = self.layer[3 * self.level + 1](x)
        x = self.relu(x)
        x = self.layer[3 * self.level + 2](x)
        x = self.relu(x)
        x = self.layer[3 * self.level + 3](x)

        st_lvl = 3 * self.level + 4
        for i in range(self.level):
            x = self.layer[st_lvl + i * 4 + 0](x)
            x = self.layer[st_lvl + i * 4 + 1](x)
            x = self.relu(x)

            if self.use_att:
                cat = self.attgates[-1 * (i + 1)](cat_content[-1 * (i + 1)], x)
            else:
                cat = cat_content[-1 * (i + 1)]
            x = torch.cat((cat, x), dim=1)

            x = self.layer[st_lvl + i * 4 + 2](x)
            x = self.relu(x)
            x = self.layer[st_lvl + i * 4 + 3](x)
            x = self.relu(x)

        x = self.layer[7 * self.level + 4](x)
        x = x.squeeze(1)
        x = x.squeeze(1)
        return x

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def reinit_decoder(self):
        for idx in self.decoder_layer_indices:
            self.layer[idx].apply(self._reinit_single_layer)

    def freeze_encoder(self):
        encoder_indices = list(range(0, 3 * self.level + 4))
        for idx in encoder_indices:
            for param in self.layer[idx].parameters():
                param.requires_grad = False

    def _reinit_single_layer(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

def load_weights_and_reset_decoder(model, checkpoint_path):
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.reinit_decoder()
    model.freeze_encoder()



class MaxBlurPool2d(nn.Module):
    def __init__(self, nch, kernel_size=(2, 2)):
        """ must specify:
            Max pool
        """
        super(MaxBlurPool2d, self).__init__()
        self.kernel_size = kernel_size
        a = self.gaussion_filter(self.kernel_size[0])
        b = self.gaussion_filter(self.kernel_size[1])
        f = torch.matmul(a[:, None], b[None, :])

        f = f / torch.sum(f)
        f = f[None, None, :, :]
        f = f.repeat(nch, nch, 1, 1)

        pad1 = (kernel_size[0] - 1) // 2
        pad2 = kernel_size[0] - 1 - pad1
        pad3 = (kernel_size[1] - 1) // 2
        pad4 = kernel_size[1] - 1 - pad3
        pads = np.array([pad3, pad4, pad1, pad2])
        pads = torch.from_numpy(pads)
        filter = f.to(dtype=torch.float32)

        self.register_buffer('pads', pads)
        self.register_buffer('filter', filter)

    def forward(self, x):
        x = nn.MaxPool2d(kernel_size=self.kernel_size)(x)
        x = F.pad(x, self.pads.tolist(), 'constant', 0)
        x = F.conv2d(x, self.filter, stride=(1, 1), padding='valid')
        return x

    def gaussion_filter(self, kernel_size):

        if kernel_size == 1:
            f = torch.tensor([1., ])
        elif kernel_size == 2:
            f = torch.tensor([1., 1.])
        elif kernel_size == 3:
            f = torch.tensor([1., 2., 1.])
        elif kernel_size == 4:
            f = torch.tensor([1., 3., 3., 1.])
        elif kernel_size == 5:
            f = torch.tensor([1., 4., 6., 4., 1.])
        elif kernel_size == 6:
            f = torch.tensor([1., 5., 10., 10., 5., 1.])
        elif kernel_size == 7:
            f = torch.tensor([1., 6., 15., 20., 15., 6., 1.])
        return f


class AttentionGate(nn.Module):
    def __init__(self, nch):
        super(AttentionGate, self).__init__()
        self.conv1 = nn.Conv2d(nch, nch, (1, 1), padding=0)
        self.conv2 = nn.Conv2d(nch, nch, (1, 1), padding=0)
        self.conv3 = nn.Conv2d(nch, nch, (1, 1), padding=0)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, enc, dec):
        x = self.conv1(enc)
        y = self.conv2(dec)
        z = self.relu(x+y)
        z = self.sigmoid(self.conv3(z))

        return enc * z

class datalabel(nn.Module):

    def __init__(self, X, Y, Nx_sub=1500, stride=750, mask_ratio=0.1, n_masks=10):
        """ This code assumes input size to be Ni, Nx=1500*n, Nt=1500；
            extract 1500^2 square samples and do masking in a bootstrap manner"""

        self.X = X  # DAS matrix
        self.Y = Y  # DAS matrix
        self.Ni = X.shape[0]
        self.Nx = X.shape[1]
        self.Nt = X.shape[2]
        self.Nx_sub = Nx_sub  # Number of channels per sample
        self.stride = stride
        self.n_masks = n_masks  # number of times repeating the mask
        self.mask_traces = int(mask_ratio * Nx_sub)  # the number traces to mask
        self.__data_generation()

    def __len__(self):
        """ Number of samples """
        return int(self.n_masks * self.Ni * ((self.Nx - self.Nx_sub) / self.stride + 1))

    def __getitem__(self, idx):
        return (self.samples[idx], self.labels[idx])

    def __data_generation(self):
        X = self.X
        Y = self.Y
        Ni = self.Ni
        Nt = self.Nt
        Nx = self.Nx
        Nx_sub = self.Nx_sub
        stride = self.stride
        n_masks = self.n_masks
        mask_traces = self.mask_traces

        n_total = self.__len__()  # total number of samples
        samples = np.zeros((n_total, Nx_sub, Nt), dtype=np.float32)
        labels = np.zeros((n_total, Nx_sub, Nt), dtype=np.float32)
        masks = np.ones_like(samples, dtype=np.float32)
        print(samples.shape)

        # Loop over samples
        for k in range(n_masks):
            for i in range(Ni):
                for j, st_ch in enumerate(np.arange(0, Nx-Nx_sub+1, stride)):
                    # %% slice each big image along channels
                    s = (k * Ni + i) * int((Nx-Nx_sub)//stride+1) + j
                    samples[s, :, :] = X[i, st_ch:st_ch + Nx_sub, :]
                    labels[s, :, :] = Y[i, st_ch:st_ch + Nx_sub, :]

                    # rng = np.random.default_rng(s + 11)
                    # trace_masked = rng.choice(Nx_sub, size=mask_traces, replace=False)
                    # masks[s, trace_masked, :] = masks[s, trace_masked, :] * 0
                    
                    # del trace_masked, rng
                    gc.collect()

        self.samples = samples
        # self.masks = masks
        self.labels = labels
        
        del X, Y
        gc.collect()
        
        pass

# %%


In [10]:
#model from (Shi et al., 2023), loss_fn function changed from MSEloss to cross entropy loss, also utilized datalabel class
#intead of the dataflow as I am utiliing the labeled data.

import h5py
import time
import torch
import configparser
import numpy as np
import torch.nn as nn
from matplotlib import pyplot as plt
from numpy.random import default_rng
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split


def train(filtered_data, picks,configure_file='config.ini'):
    # Data
    print("filtered_data shape:", filtered_data.shape, "picks shape:", picks.shape)

    x = filtered_data
    y = picks

    #normalize the x data
    
    for i in range(x.shape[0]):
        x_min = np.min(x[i])
        x_max = np.max(x[i])
        x[i] = (x[i] - x_min) / (x_max - x_min)

    x = np.repeat(x, 2, axis=0)
    y = np.repeat(y, 2, axis=0)
    #normalize the x data

    print("x shape:", x.shape)
    print("y shape:", y.shape)
    
    print(x.shape, y.shape, x.dtype, y.dtype) #batch 1st dimension, then 1500 by 1500

    eighty_percent_length = int(x.shape[0] * 0.8)
    print("before datalabel")
    training_data = datalabel(x[:eighty_percent_length], y[:eighty_percent_length])  # Use datalabel for training data
    validation_data = datalabel(x[eighty_percent_length:], y[eighty_percent_length:])  # Use datalabel for validation data

    print("Attempting to initialize the U-net model...")
    # Initialize the U-net model
    model = unet(1, 16, 1024, factors=(5, 3, 2, 2))
    print("U-net model initialized successfully.")
    devc = try_gpu(i=1)
    model.to(devc)


    print("Post device utilization")

    # Hyper-parameters for training
    batch_size = 2
    lr = 1e-2
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    train_iter = DataLoader(training_data, batch_size=batch_size, shuffle=True)
    validate_iter = DataLoader(validation_data, batch_size=batch_size, shuffle=True)

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',        # reduce LR when the monitored quantity has stopped decreasing
        factor=0.5,        # reduce by half
        patience=3,        # wait 3 epochs before reducing
    )

    print("Start training...")
    # Train the model
    model, avg_train_losses, avg_valid_losses = train_augmentation(train_iter,
                                                                   validate_iter,
                                                                   model,
                                                                   loss_fn,
                                                                   optimizer,
                                                                   lr_scheduler,
                                                                   epochs=50,
                                                                   patience=6,
                                                                   device=devc,
                                                                   minimum_epochs=5)


def try_gpu(i=0):  # @save
    """Return gpu(i) if exists, otherwise return cpu()."""
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """Args:
            patience (int): How long to wait after last time validation loss improved.
            verbose (bool): If True, prints a message for each validation loss improvement.
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
            path (str): Path for the checkpoint to be saved to.
            trace_func (function): trace print function.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Saves model when validation loss decrease."""
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


def train_augmentation(train_dataloader, validate_dataloader, model, loss_fn, optimizer, lr_schedule, epochs,
                        patience, device, minimum_epochs=None):
    # get early_stopping ready
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    # save history of losses every epoch
    avg_train_losses = []
    avg_valid_losses = []

    print(f'Training on {device} ...')

    for epoch in range(1, epochs + 1):
        starttime = time.time()  # record time for each epoch
        train_losses = []  # save loss for every batch
        valid_losses = []
        print(f'Epoch {epoch}/{epochs}')
        # ======================= training =======================
      #  print("training")
        model.train()  # train mode on
        for batch, (X, y) in enumerate(train_dataloader):
            X, y = X.to(device), y.to(device)

            # #change x and y to numpy arrays
            # x_np = X.cpu().numpy()
            # y_np = y.cpu().numpy()

            # fig, ax = plt.subplots(1, 2, figsize=(12, 6))
            # ax[0].imshow(x_np[1], aspect='auto', cmap = "RdBu", vmin = -0.5, vmax = 0.5, interpolation = "none")
            # ax[0].imshow(y_np[1], aspect='auto', cmap = cmap_orange_transparent, alpha=0.5, interpolation = "none")

            # ax[1].imshow(x_np[1], aspect='auto', cmap = "RdBu", vmin = -0.5, vmax = 0.5, interpolation = "none")
            # ax[1].imshow(y_np[1], aspect='auto', cmap = cmap_orange_transparent, alpha=0.5, interpolation = "none")

            # plt.show()
            # plt.close()

            # predict and loss
            pred = model(X)
            #print()
            loss = loss_fn(pred, y)
            train_losses.append(loss.item())

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # ======================= validating =======================
      #  print("validating")
        model.eval()  # evaluation model on
        with torch.no_grad():
            for (X, y) in validate_dataloader:
                X, y = X.to(device), y.to(device)
               # print(X, y)

                pred = model(X)
                loss = loss_fn(pred, y)
                valid_losses.append(loss.item())
        
        # average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        lr_schedule.step(valid_loss)

        # ==================== history monitoring ====================
        
      #  print("history monitoring")
        # print training/validation statistics
        epoch_len = len(str(epochs))
        print_msg = (f'[{epoch:>{epoch_len}}/{epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f} ' +
                     f'time per epoch: {(time.time() - starttime):.3f} s')
        print(print_msg)

        if (minimum_epochs is None) or ((minimum_epochs is not None) and (epoch > minimum_epochs)):
            # if the current valid loss is lowest, save a checkpoint of model weights
            early_stopping(valid_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # load the last checkpoint as the best model
    model.load_state_dict(torch.load('checkpoint.pt'))

    return model, avg_train_losses, avg_valid_losses

#if __name__ == '__main__':
#    train()

In [12]:
model = train(new_rawdata[:2], new_combined[:2])

filtered_data shape: (2, 9000, 1500) picks shape: (2, 9000, 1500)
x shape: (4, 9000, 1500)
y shape: (4, 9000, 1500)
(4, 9000, 1500) (4, 9000, 1500) float64 uint8
before datalabel
(330, 1500, 1500)
(110, 1500, 1500)
Attempting to initialize the U-net model...
U-net model initialized successfully.
Post device utilization
Start training...
Training on cuda:1 ...
Epoch 1/50
[ 1/50] train_loss: 16424663.55168 valid_loss: 1938.34199 time per epoch: 24.641 s
Epoch 2/50
[ 2/50] train_loss: 2474.42902 valid_loss: 1938.30096 time per epoch: 24.227 s
Epoch 3/50
[ 3/50] train_loss: 2474.41041 valid_loss: 1938.29070 time per epoch: 24.307 s
Epoch 4/50
[ 4/50] train_loss: 2474.40465 valid_loss: 1938.28241 time per epoch: 24.213 s
Epoch 5/50
[ 5/50] train_loss: 2474.40169 valid_loss: 1938.28088 time per epoch: 24.300 s
Epoch 6/50
[ 6/50] train_loss: 2474.40018 valid_loss: 1938.28048 time per epoch: 24.297 s
Validation loss decreased (inf --> 1938.280482).  Saving model ...
Epoch 7/50


KeyboardInterrupt: 

## Debugging Below:

In [None]:
# test = datalabel(new_rawdata[0:2], new_combined[0:2])
# test2 = datalabel(new_rawdata[2:4], new_combined[2:4])

In [None]:
new_rawdata.shape

(40, 9000, 1500)

In [None]:
x = new_rawdata[:2]
y = new_combined[:2]

x = (x - np.mean(x, axis=(1, 2), keepdims=True)) / np.std(x, axis=(1, 2), keepdims=True)

for i in range(x.shape[0]):
    x_min = np.min(x[i])
    x_max = np.max(x[i])
    x[i] = (x[i] - x_min) / (x_max - x_min)
    print(np.max(x[i,:,:]))



# x = np.repeat(x, 2, axis=0)
# y = np.repeat(y, 2, axis=0)
# #normalize the x data

# print("x shape:", x.shape)
# print("y shape:", y.shape)

# print(x.shape, y.shape, x.dtype, y.dtype) #batch 1st dimension, then 1500 by 1500

# #take 80% of the data total length of x[0] for training and 20% for validation
# eighty_percent_length = int(x.shape[0] * 0.8)


1.0
1.0


In [None]:
print(eighty_percent_length)

6


In [None]:
training_data = datalabel(x[:eighty_percent_length], y[:eighty_percent_length])  # Use datalabel for training data
validation_data = datalabel(x[eighty_percent_length:], y[eighty_percent_length:])

print(x.shape, y.shape, x.dtype, y.dtype) #batch 1st dimension, then 1500 by 1500

(660, 1500, 1500)
(220, 1500, 1500)
(8, 9000, 1500) (8, 9000, 1500) float64 uint8


In [None]:
print(training_data.samples.shape, training_data.labels.shape)
# Initialize the U-net od
print(validation_data.samples.shape, validation_data.labels.shape)

(660, 1500, 1500) (660, 1500, 1500)
(220, 1500, 1500) (220, 1500, 1500)


In [None]:
batch_size = 2
train_iter = DataLoader(training_data, batch_size=batch_size, shuffle=True)
validate_iter = DataLoader(validation_data, batch_size=batch_size, shuffle=True)

print(train_iter, validate_iter)

<torch.utils.data.dataloader.DataLoader object at 0x7fe0cc6c0a10> <torch.utils.data.dataloader.DataLoader object at 0x7fe0cc6c2350>


In [None]:
for batch, (X, y) in enumerate(train_iter):
    print(f"Batch {batch}: X shape: {X.shape}, y shape: {y.shape}")     

                # pred = model(X)
                # loss = loss_fn(pred, y)
                # valid_losses.append(loss.item())

Batch 0: X shape: torch.Size([2, 1500, 1500]), y shape: torch.Size([2, 1500, 1500])
Batch 1: X shape: torch.Size([2, 1500, 1500]), y shape: torch.Size([2, 1500, 1500])
Batch 2: X shape: torch.Size([2, 1500, 1500]), y shape: torch.Size([2, 1500, 1500])
Batch 3: X shape: torch.Size([2, 1500, 1500]), y shape: torch.Size([2, 1500, 1500])
Batch 4: X shape: torch.Size([2, 1500, 1500]), y shape: torch.Size([2, 1500, 1500])
Batch 5: X shape: torch.Size([2, 1500, 1500]), y shape: torch.Size([2, 1500, 1500])
Batch 6: X shape: torch.Size([2, 1500, 1500]), y shape: torch.Size([2, 1500, 1500])
Batch 7: X shape: torch.Size([2, 1500, 1500]), y shape: torch.Size([2, 1500, 1500])
Batch 8: X shape: torch.Size([2, 1500, 1500]), y shape: torch.Size([2, 1500, 1500])
Batch 9: X shape: torch.Size([2, 1500, 1500]), y shape: torch.Size([2, 1500, 1500])
Batch 10: X shape: torch.Size([2, 1500, 1500]), y shape: torch.Size([2, 1500, 1500])
Batch 11: X shape: torch.Size([2, 1500, 1500]), y shape: torch.Size([2, 150

In [None]:
model = train(new_rawdata[:4], new_combined[:4])

filtered_data shape: (4, 9000, 1500) picks shape: (4, 9000, 1500)
x shape: (8, 9000, 1500)
y shape: (8, 9000, 1500)
(8, 9000, 1500) (8, 9000, 1500) float64 uint8
before datalabel
(660, 1500, 1500)
(220, 1500, 1500)
Attempting to initialize the U-net model...
1
2
U-net model initialized successfully.
Post device utilization


TypeError: ReduceLROnPlateau.__init__() got an unexpected keyword argument 'verbose'