# Setup

In [1]:
import numpy as np, scipy as sp
from scipy.stats import norm
import torch, time, os, math, glob
from tqdm import tqdm

from torch import nn, optim, autograd
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
# from torchvision import datasets, transforms, models
# import segmentation_models_pytorch as smp

# from sklearn.model_selection import train_test_split

from matplotlib import pyplot as plt

# When True, the notebook refuses to run without a CUDA device instead of
# silently falling back to the CPU.
REQUIRE_CUDA = True
is_cuda_available = torch.cuda.is_available()

if not is_cuda_available and REQUIRE_CUDA:
    raise Exception('cuda is unavailable and requested')

# Every tensor/model in this notebook is placed on `device`.
if is_cuda_available:
    device = torch.device("cuda")
    _gpu_count_note = f" · count is {torch.cuda.device_count()}"
else:
    device = torch.device("cpu")
    _gpu_count_note = ''
print(f"using {device} device" + _gpu_count_note)

if is_cuda_available:
    # Enables cuDNN's benchmark mode.
    torch.backends.cudnn.benchmark = True
    # NOTE(review): private API; presumably clears TorchDynamo state left by
    # a previous run of this cell — confirm it is still needed.
    torch._dynamo.reset()
    #torch.set_default_device('cuda')

#

# Set to True to delete every previously saved checkpoint before training.
CLEAR_WEIGHT_FOLDER = False

if CLEAR_WEIGHT_FOLDER:
    import shutil

    try:
        shutil.rmtree("./weight")
    except FileNotFoundError:
        # Nothing to clear: the folder has not been created yet.
        pass

#

# exist_ok=True replaces the separate existence check (no check-then-create
# race, and no error when the folder is already there).
os.makedirs('./weight', exist_ok=True)

using cuda device · count is 1


# U-Net Architecture

In [2]:
class CoordConv2d_onlyY_fixedDims(nn.Module):
    """``nn.Conv2d`` that first appends two fixed coordinate channels to its input.

    Both extra channels vary only along the last tensor axis and are
    precomputed for one fixed spatial size:

    * ``y_coords``: ``(k + 0.5) / (dimensions[1] + 1)`` for
      ``k = 1 .. dimensions[1]``, repeated ``dimensions[0]`` times along axis 2;
    * ``y_coords_log``: ``(11 + log2(y_coords)) / 11``.

    The coordinate maps are registered as *non-persistent* buffers: they follow
    the module through ``.to()`` / ``.cuda()`` / ``.double()`` but are not
    written to ``state_dict``, so checkpoints keep the same keys as before.

    Args:
        dimensions: ``(size of axis 2, size of axis 3)`` of the inputs this
            layer will receive.
        in_channels: channels of the input *before* the two coordinate
            channels are appended.
        out_channels, kernel_size, stride, padding, dilation, groups, bias:
            forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(
        self,
        dimensions,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super().__init__()

        # "+ 2" accounts for the two coordinate channels added in forward().
        self.conv = nn.Conv2d(
            in_channels + 2,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

        y_coords = (
            (torch.arange(1, 1 + dimensions[1]) + 1.0 / 2.0) / (1.0 + dimensions[1])
        ).repeat((1, 1, dimensions[0], 1))
        # Buffers (instead of plain attributes pinned to a global device) so
        # the maps always live on the same device as the convolution weights.
        self.register_buffer("y_coords", y_coords, persistent=False)
        self.register_buffer(
            "y_coords_log", (11.0 + torch.log2(y_coords)) / 11.0, persistent=False
        )

    def forward(self, x):
        """Concatenate the coordinate channels to ``x`` and convolve the result."""
        n_items, _, size_axis_2, size_axis_3 = x.size()

        # expand() broadcasts the (1, 1, H, W) maps over the batch without copying.
        x = torch.cat(
            [
                x,
                self.y_coords.expand(n_items, 1, size_axis_2, size_axis_3),
                self.y_coords_log.expand(n_items, 1, size_axis_2, size_axis_3),
            ],
            dim=1,
        )

        return self.conv(x)


class DeeplySupervizedUnet(nn.Module):

    def convblock_enc(
        self, dims, n_in_channels, n_out_channels, permits_slimming=False
    ):
        if not permits_slimming:
            assert n_in_channels <= n_out_channels
        # https://debuggercafe.com/unet-from-scratch-using-pytorch
        """
        In the original paper implementation, the convolution operations were
        not padded but we are padding them here. This is because, we need the
        output result size to be same as input size.
        """
        biggest_n_channels = max(n_in_channels, n_out_channels)
        conv_op = nn.Sequential(
            CoordConv2d_onlyY_fixedDims(
                dims,
                n_in_channels,
                biggest_n_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(biggest_n_channels),
            nn.Mish(inplace=True),
            CoordConv2d_onlyY_fixedDims(
                dims,
                biggest_n_channels,
                n_out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(n_out_channels),
            nn.LeakyReLU(inplace=True),
        )
        return conv_op

    def convblock_dec(
        self, n_in_channels, n_out_channels, permits_thicken=False, dropout_p=0.0
    ):
        if not permits_thicken:
            assert n_out_channels <= n_in_channels
        # https://debuggercafe.com/unet-from-scratch-using-pytorch
        """
        In the original paper implementation, the convolution operations were
        not padded but we are padding them here. This is because, we need the
        output result size to be same as input size.
        """
        mean_n_channels = (n_in_channels + n_out_channels) // 2
        conv_op = nn.Sequential(
            (nn.Dropout2d(p=dropout_p) if dropout_p > 0 else nn.Identity()),
            nn.Conv2d(
                n_in_channels, mean_n_channels, kernel_size=3, padding=1, bias=True
            ),
            nn.Mish(inplace=True),
            nn.Conv2d(
                mean_n_channels, n_out_channels, kernel_size=3, padding=1, bias=True
            ),
            nn.LeakyReLU(inplace=True),
        )
        return conv_op

    def downsampler(self, n_channels, ratios):
        r_x, r_y = ratios
        assert r_x == int(r_x) and r_y == int(r_y)

        # depthwise-seperable
        conv_op = nn.Sequential(
            (
                nn.Conv2d(
                    n_channels,
                    n_channels,
                    kernel_size=(1, 1 + 2 * (r_y - 1)),
                    padding=(0, r_y - 1),
                    bias=True,
                    groups=n_channels,
                )
                if r_x != 1
                else nn.Identity()
            ),
            (
                nn.Conv2d(
                    n_channels,
                    n_channels,
                    kernel_size=(1 + 2 * (r_x - 1), 1),
                    padding=(r_x - 1, 0),
                    bias=True,
                    groups=n_channels,
                )
                if r_y != 1
                else nn.Identity()
            ),
            nn.Conv2d(
                n_channels,
                n_channels,
                kernel_size=1,
                bias=True,
            ),
            nn.MaxPool2d(kernel_size=(r_x, r_y), stride=(r_x, r_y)),
            # nn.LeakyReLU(inplace=True),
        )
        return conv_op

    def upsampler(self, n_channels, ratios):
        r_x, r_y = ratios
        assert r_x == int(r_x) and r_y == int(r_y)
        conv_op = nn.Sequential(
            nn.Upsample(
                scale_factor=ratios,
                mode="bilinear",
                align_corners=False,
            ),
            (
                nn.Conv2d(
                    n_channels,
                    n_channels,
                    kernel_size=(1 + 2 * (r_x - 1), 1),
                    padding=(r_x - 1, 0),
                    bias=True,
                    groups=n_channels,
                )
                if r_x != 1
                else nn.Identity()
            ),
            (
                nn.Conv2d(
                    n_channels,
                    n_channels,
                    kernel_size=(1, 1 + 2 * (r_y - 1)),
                    padding=(0, r_y - 1),
                    bias=True,
                    groups=n_channels,
                )
                if r_y != 1
                else nn.Identity()
            ),
            nn.Conv2d(
                n_channels,
                n_channels,
                kernel_size=1,
                bias=True,
            ),
            # nn.LeakyReLU(inplace=True),
        )
        return conv_op

    def exporter(self, n_in_channels, n_intermediate_channels):
        return nn.Sequential(
            nn.Conv2d(
                n_in_channels,
                n_intermediate_channels,
                kernel_size=1,
                stride=1,
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                n_intermediate_channels,
                1,
                kernel_size=1,
                stride=1,
            ),
        )

    def __init__(self):
        """Create all sub-modules of the three-input, deeply supervised U-Net.

        The network has three single-channel input paths that are each
        encoded separately, concatenated at ``DIMS_AT_JOIN``, encoded further
        down to ``DIMS_SMALLEST``, summarised by a fully connected head, and
        decoded back up with skip connections. Three extra ``exporter`` heads
        provide the intermediate (deep supervision) outputs.

        All sizes are hard-coded below; module creation order is significant
        because it fixes both parameter initialisation order and the
        ``state_dict`` layout.
        """
        # Per input path: (dims fed to its first conv block,
        #                  dims fed to its second conv block).
        INPUTS_PATHS_DIMS = [
            [(64, 256), (32, 256)],
            [(32, 512), (16, 512)],
            [(16, 1024), (16, 512)],
        ]
        N_INPUT_PATHS = len(INPUTS_PATHS_DIMS)
        # Spatial size at which the three input paths are concatenated.
        DIMS_AT_JOIN = (16, 256)
        # NOTE(review): not referenced anywhere else in this method.
        DIMS_OUTS = [DIMS_AT_JOIN, (32, 512), (32, 1024)]
        # Spatial size at the bottom of the encoder (input of the FC head).
        DIMS_SMALLEST = (2, 4)
        # Channel widths per tier, from shallowest (1) to deepest (6).
        N_TIER_1_FEATURE_MAPS = 32
        N_TIER_2_FEATURE_MAPS = 64
        N_TIER_3_FEATURE_MAPS = 96
        N_TIER_4_FEATURE_MAPS = 128
        N_TIER_5_FEATURE_MAPS = 256
        N_TIER_6_FEATURE_MAPS = 384
        N_FC_HIDDEN_LAYER_NODES = 512
        # Size of the FC output vector that is concatenated, as extra
        # channels, into every decoder stage.
        N_FC_OUTPUTS = 24
        # Intermediate width of every exporter (output) head.
        N_TIER_1_PRE_DIMENSIONALITY_REDUC_FEATURE_MAPS = 16

        super(DeeplySupervizedUnet, self).__init__()

        # --- Separate encoders, one per input -------------------------------
        # Each is: conv block, downsampler, conv block, downsampler.
        # Path 1 halves axis 2 twice.
        self.input_1_path = nn.ModuleList(
            [
                self.convblock_enc(INPUTS_PATHS_DIMS[0][0], 1, N_TIER_1_FEATURE_MAPS),
                self.downsampler(N_TIER_1_FEATURE_MAPS, (2, 1)),
                self.convblock_enc(
                    INPUTS_PATHS_DIMS[0][1],
                    N_TIER_1_FEATURE_MAPS,
                    N_TIER_2_FEATURE_MAPS,
                ),
                self.downsampler(N_TIER_2_FEATURE_MAPS, (2, 1)),
            ]
        )
        # Path 2 halves axis 2, then axis 3.
        self.input_2_path = nn.ModuleList(
            [
                self.convblock_enc(INPUTS_PATHS_DIMS[1][0], 1, N_TIER_1_FEATURE_MAPS),
                self.downsampler(N_TIER_1_FEATURE_MAPS, (2, 1)),
                self.convblock_enc(
                    INPUTS_PATHS_DIMS[1][1],
                    N_TIER_1_FEATURE_MAPS,
                    N_TIER_2_FEATURE_MAPS,
                ),
                self.downsampler(N_TIER_2_FEATURE_MAPS, (1, 2)),
            ]
        )
        # Path 3 halves axis 3 twice.
        self.input_3_path = nn.ModuleList(
            [
                self.convblock_enc(INPUTS_PATHS_DIMS[2][0], 1, N_TIER_1_FEATURE_MAPS),
                self.downsampler(N_TIER_1_FEATURE_MAPS, (1, 2)),
                self.convblock_enc(
                    INPUTS_PATHS_DIMS[2][1],
                    N_TIER_1_FEATURE_MAPS,
                    N_TIER_2_FEATURE_MAPS,
                ),
                self.downsampler(N_TIER_2_FEATURE_MAPS, (1, 2)),
            ]
        )
        # --- Shared encoder after the three paths are concatenated ----------
        # The first block reduces 3 * tier-2 channels to tier-3 channels,
        # hence permits_slimming=True.
        self.joined_path_encode = nn.ModuleList(
            [
                self.convblock_enc(
                    DIMS_AT_JOIN,
                    N_INPUT_PATHS * N_TIER_2_FEATURE_MAPS,
                    N_TIER_3_FEATURE_MAPS,
                    permits_slimming=True,
                ),
                self.downsampler(N_TIER_3_FEATURE_MAPS, (2, 4)),  # 8 x 64
                self.convblock_enc(
                    (8, 64),
                    N_TIER_3_FEATURE_MAPS,
                    N_TIER_4_FEATURE_MAPS,
                ),
                self.downsampler(N_TIER_4_FEATURE_MAPS, (2, 4)),  # 4 x 16
                self.convblock_enc(
                    (4, 16),
                    N_TIER_4_FEATURE_MAPS,
                    N_TIER_5_FEATURE_MAPS,
                ),
                self.downsampler(N_TIER_5_FEATURE_MAPS, (2, 4)),  # 2 x 4
                self.convblock_enc(
                    DIMS_SMALLEST,
                    N_TIER_5_FEATURE_MAPS,
                    N_TIER_6_FEATURE_MAPS,
                ),
            ]
        )
        # --- Skip-connection rescalers --------------------------------------
        # Tier 0 = the raw single-channel inputs, tier 1 / tier 2 = outputs of
        # each path's first / second conv block. Each scaler resizes its
        # path's tensor so the three paths can be concatenated in forward().
        self.skip_inp_1_tier_0_scaler = nn.Sequential(
            nn.AvgPool2d(kernel_size=(2, 1), stride=(2, 1)),
            self.upsampler(1, (1, 4)),
        )
        self.skip_inp_2_tier_0_scaler = self.upsampler(1, (1, 2))
        self.skip_inp_3_tier_0_scaler = self.upsampler(1, (2, 1))
        self.skip_inp_1_tier_1_scaler = nn.Sequential(
            nn.AvgPool2d(kernel_size=(2, 1), stride=(2, 1)),
            self.upsampler(N_TIER_1_FEATURE_MAPS, (1, 4)),
        )
        self.skip_inp_2_tier_1_scaler = self.upsampler(N_TIER_1_FEATURE_MAPS, (1, 2))
        self.skip_inp_3_tier_1_scaler = self.upsampler(N_TIER_1_FEATURE_MAPS, (2, 1))
        self.skip_inp_1_tier_2_scaler = self.upsampler(N_TIER_2_FEATURE_MAPS, (1, 2))
        self.skip_inp_2_tier_2_scaler = self.upsampler(N_TIER_2_FEATURE_MAPS, (2, 1))
        self.skip_inp_3_tier_2_scaler = self.upsampler(N_TIER_2_FEATURE_MAPS, (2, 1))
        # --- Fully connected head on the flattened bottleneck ---------------
        self.fc = nn.Sequential(
            nn.Flatten(),  # default : start_dim=1
            nn.Dropout(p=0.1),
            nn.Linear(
                N_TIER_6_FEATURE_MAPS * DIMS_SMALLEST[0] * DIMS_SMALLEST[1],
                N_FC_HIDDEN_LAYER_NODES,
            ),
            nn.Mish(True),
            nn.Dropout(p=0.05),
            nn.Linear(
                N_FC_HIDDEN_LAYER_NODES,
                N_FC_OUTPUTS,
            ),
            nn.ReLU(True),
            nn.Dropout(p=0.01),
        )  # so the expanding path can factor in holistic information about the image as a whole and not just local patterns
        # --- Decoder ---------------------------------------------------------
        # Alternates upsampler / conv block. Every conv block's input width is
        # the FC vector + the skip connection(s) + the upsampled tensor. The
        # last entry is the exporter producing the final output.
        self.path_decode = nn.ModuleList(
            [
                self.upsampler(N_TIER_6_FEATURE_MAPS, (2, 4)),  # 4 x 16
                self.convblock_dec(
                    N_FC_OUTPUTS + N_TIER_6_FEATURE_MAPS + N_TIER_5_FEATURE_MAPS,
                    N_TIER_5_FEATURE_MAPS,
                    dropout_p=0.1,
                ),
                self.upsampler(N_TIER_5_FEATURE_MAPS, (2, 4)),  # 8 x 64
                self.convblock_dec(
                    N_FC_OUTPUTS + N_TIER_5_FEATURE_MAPS + N_TIER_4_FEATURE_MAPS,
                    N_TIER_4_FEATURE_MAPS,
                    dropout_p=0.1,
                ),
                self.upsampler(N_TIER_4_FEATURE_MAPS, (2, 4)),  # 16 x 256
                self.convblock_dec(
                    N_FC_OUTPUTS + N_TIER_4_FEATURE_MAPS + N_TIER_3_FEATURE_MAPS,
                    N_TIER_3_FEATURE_MAPS,
                    dropout_p=0.1,
                ),
                self.upsampler(N_TIER_3_FEATURE_MAPS, (2, 2)),  # 32 x 512
                self.convblock_dec(
                    N_FC_OUTPUTS
                    + N_TIER_3_FEATURE_MAPS
                    + N_INPUT_PATHS * N_TIER_2_FEATURE_MAPS,
                    N_TIER_2_FEATURE_MAPS,
                    dropout_p=0.05,
                ),
                self.upsampler(N_TIER_2_FEATURE_MAPS, (1, 2)),  # 32 x 1024
                self.convblock_dec(
                    N_FC_OUTPUTS
                    + N_TIER_2_FEATURE_MAPS
                    + N_INPUT_PATHS * N_TIER_1_FEATURE_MAPS,
                    N_TIER_1_FEATURE_MAPS,
                    dropout_p=0.05,
                ),
                self.exporter(
                    N_FC_OUTPUTS + N_TIER_1_FEATURE_MAPS + N_INPUT_PATHS,
                    N_TIER_1_PRE_DIMENSIONALITY_REDUC_FEATURE_MAPS,
                ),
            ]
        )
        # --- Deep supervision heads ------------------------------------------
        # One exporter per intermediate decoder tier (4, 3 and 2).
        self.deep_superv_out_t4 = self.exporter(
            N_TIER_4_FEATURE_MAPS,
            N_TIER_1_PRE_DIMENSIONALITY_REDUC_FEATURE_MAPS,
        )
        self.deep_superv_out_t3 = self.exporter(
            N_TIER_3_FEATURE_MAPS,
            N_TIER_1_PRE_DIMENSIONALITY_REDUC_FEATURE_MAPS,
        )
        self.deep_superv_out_t2 = self.exporter(
            N_TIER_2_FEATURE_MAPS,
            N_TIER_1_PRE_DIMENSIONALITY_REDUC_FEATURE_MAPS,
        )

    def forward(self, x1, x2, x3, do_deep_supervision=False):
        """Run the three inputs through the encoders, the FC head and the decoder.

        Args:
            x1, x2, x3: the three input tensors, one per input path.
            do_deep_supervision: when True, also evaluate the three
                intermediate output heads.

        Returns:
            The final output tensor, or — with ``do_deep_supervision`` — the
            tuple ``(tier-4 output, tier-3 output, tier-2 output, final output)``.
        """
        # Tier-0 skip: the raw inputs, rescaled per path and stacked as channels.
        skip_t_0 = torch.cat(
            (
                self.skip_inp_1_tier_0_scaler(x1),
                self.skip_inp_2_tier_0_scaler(x2),
                self.skip_inp_3_tier_0_scaler(x3),
            ),
            dim=1,
        )

        def run_input_path(stages, tier_1_scaler, tier_2_scaler, signal):
            # conv block -> tap skip -> downsample -> conv block -> tap skip -> downsample
            feats = stages[0](signal)
            tier_1_skip = tier_1_scaler(feats)
            feats = stages[2](stages[1](feats))
            tier_2_skip = tier_2_scaler(feats)
            return stages[3](feats), tier_1_skip, tier_2_skip

        in_path_1, skip_1_1, skip_1_2 = run_input_path(
            self.input_1_path,
            self.skip_inp_1_tier_1_scaler,
            self.skip_inp_1_tier_2_scaler,
            x1,
        )
        in_path_2, skip_2_1, skip_2_2 = run_input_path(
            self.input_2_path,
            self.skip_inp_2_tier_1_scaler,
            self.skip_inp_2_tier_2_scaler,
            x2,
        )
        in_path_3, skip_3_1, skip_3_2 = run_input_path(
            self.input_3_path,
            self.skip_inp_3_tier_1_scaler,
            self.skip_inp_3_tier_2_scaler,
            x3,
        )

        # Shared encoder: every conv block's output doubles as a skip connection.
        joined = self.joined_path_encode
        skip_t_3 = joined[0](torch.cat((in_path_1, in_path_2, in_path_3), dim=1))
        skip_t_4 = joined[2](joined[1](skip_t_3))
        skip_t_5 = joined[4](joined[3](skip_t_4))
        enc = joined[6](joined[5](skip_t_5))

        # Global summary vector, shaped (N, C, 1, 1) so it can be tiled spatially.
        fc_map = self.fc(enc)[:, :, None, None]

        def with_global_context(height, width, *feature_maps):
            # Prepend the tiled FC vector to the given feature maps (channel axis).
            return torch.cat(
                [fc_map.repeat((1, 1, height, width)), *feature_maps], 1
            )

        decode = self.path_decode

        dec = decode[0](enc)
        dec = decode[1](with_global_context(4, 16, skip_t_5, dec))

        dec = decode[2](dec)
        dec = decode[3](with_global_context(8, 64, skip_t_4, dec))
        deep_superv_out_t4 = None
        if do_deep_supervision:
            deep_superv_out_t4 = self.deep_superv_out_t4(dec)

        dec = decode[4](dec)
        dec = decode[5](with_global_context(16, 256, skip_t_3, dec))
        deep_superv_out_t3 = None
        if do_deep_supervision:
            deep_superv_out_t3 = self.deep_superv_out_t3(dec)

        dec = decode[6](dec)
        dec = decode[7](
            with_global_context(32, 512, skip_1_2, skip_2_2, skip_3_2, dec)
        )
        deep_superv_out_t2 = None
        if do_deep_supervision:
            deep_superv_out_t2 = self.deep_superv_out_t2(dec)

        dec = decode[8](dec)
        dec = decode[9](
            with_global_context(32, 1024, skip_1_1, skip_2_1, skip_3_1, dec)
        )

        # Final head also sees the (rescaled) raw inputs.
        dec = decode[10](with_global_context(32, 1024, skip_t_0, dec))

        if not do_deep_supervision:
            return dec
        return (deep_superv_out_t4, deep_superv_out_t3, deep_superv_out_t2, dec)

# Loss Function

~~Combination of MSE in decibels and MSE in linear amplitude~~

Volume-weighted mean of the squared modified error, where the modified error function penalizes same-sign undershooting less than it penalizes overshooting and opposite-sign predictions.

In [3]:
class CustomMSELoss(nn.Module):
    """Volume-weighted mean of a squared, asymmetric error.

    The per-element error penalizes same-sign undershooting (prediction has
    the target's sign but a smaller magnitude) less than overshooting or
    opposite-sign predictions. Elements are weighted by a "volume" derived
    from ``base + residual`` so that louder elements count more.
    """

    def __init__(self):
        super().__init__()
        # Scale of the softplus used to turn levels into positive weights.
        self.softplus_c = 50
        # Weights are expressed relative to this level.
        self.vol_ref_dB = 100
        # Keeps the relative error finite where the target is 0.
        self.epsilon = 0.000001
        # Base of the exponential fade towards the plain absolute error:
        # blend factor = this ** |target| (1 at target 0, -> 0 as |target| grows).
        self.fade_to_unabridge_error_mag_thresh_dB = 0.01

    def softplus(self, x, c):
        """Return the scaled softplus ``c * log(1 + exp(x / c))``.

        ``F.softplus`` switches to the identity once ``x / c`` exceeds its
        threshold, i.e. where softplus and ``x`` agree to float precision.
        The previous hand-written version compared ``x`` itself (not
        ``x / c``) against the threshold, which for ``c = 50`` produced a jump
        from about 65.66 down to 50 at ``x = 50``.
        """
        return F.softplus(x, beta=1.0 / c, threshold=20)

    def sign_continuous_differentiable(self, x):
        """Smooth stand-in for ``sign(x)``: a very steep ``tanh``."""
        return torch.tanh(1000000 * x)

    @staticmethod
    def _undershoot_discount(relative_err):
        """Discount applied to negative relative errors.

        Zero for ``relative_err >= 0``; for negative values it is
        ``tanh(-r)`` gated by a sigmoid-like factor that fades out as ``r``
        moves well below -1/2.
        """
        gate = (1 + torch.tanh(2 * (1 / 2 + relative_err))) / 2
        return torch.clamp(torch.tanh(-relative_err), min=0) * gate

    def modified_absolute_err(self, predictions, targets):
        """Return the element-wise asymmetric absolute error.

        The signed error is expressed relative to the target (positive =
        overshoot in the target's direction), discounted when negative, and
        scaled back by the target magnitude. For targets near zero the result
        is blended towards the plain absolute error.
        """
        abs_targets = torch.abs(targets)

        relative_err = (
            self.sign_continuous_differentiable(targets)
            * (predictions - targets)
            / (self.epsilon + abs_targets)
        )
        discounted = torch.abs(relative_err) - self._undershoot_discount(relative_err) / 2
        mod_err = (self.epsilon + abs_targets) * discounted

        pure_err = torch.abs(predictions - targets)
        tinyness = torch.pow(self.fade_to_unabridge_error_mag_thresh_dB, abs_targets)

        return tinyness * pure_err + (1 - tinyness) * mod_err

    def weighted_mse(self, predictions, targets, weights):
        """Return the batch mean of each item's weighted mean squared error.

        Weights vary inside an item; every item of the batch contributes
        equally because each item is normalised by its own weight sum.
        Reduction runs over axes 1-3, so 4-D tensors are expected.
        """
        squared_err = self.modified_absolute_err(predictions, targets) ** 2
        per_item = torch.sum(weights * squared_err, dim=(1, 2, 3)) / torch.sum(
            weights, dim=(1, 2, 3)
        )
        return torch.mean(per_item)

    def calc_weights(self, base, residuals_predicted, residuals_targets):
        """Return per-element weights: the larger softplus-ed level, over ``vol_ref_dB``."""
        base_predicted = self.softplus(base + residuals_predicted, self.softplus_c)
        base_targets = self.softplus(base + residuals_targets, self.softplus_c)
        return torch.maximum(base_predicted, base_targets) / self.vol_ref_dB

    def forward(self, inputs_2048, predictions, targets):
        """Return the loss; ``inputs_2048`` is the base that the residuals are added to."""
        weights = self.calc_weights(inputs_2048, predictions, targets)
        return self.weighted_mse(predictions, targets, weights)

# Trainer + Profiler

In [4]:
# Upper bound on the time budget that train_and_val accepts, in hours.
MAX_SESSION_HOURS = 12

# Checkpoint locations: folder, then the two file names stored inside it.
MODEL_STORE_PATH = "./weight/"
NEWEST_MODEL_SAVE_FILE_NAME = "last.pth"
BEST_MODEL_SAVE_FILE_NAME = "best.pth"
# Full path of the checkpoint that is overwritten after every epoch.
CHECKPOINT_FILE_PATH = f"{MODEL_STORE_PATH}{NEWEST_MODEL_SAVE_FILE_NAME}"
def train_and_val(epochs, model, train_loader, len_train, val_loader, len_val, criterion, optimizer, scheduler,
                  intermediates_contributions_decays_per_epoch,
                  device, time_limit_hrs=None, early_stop_thresh_proportion_epoch=0.25, use_amp=True):
    """Train ``model`` with deep supervision; validate and checkpoint after every epoch.

    Args:
        epochs: maximum number of epochs to run.
        model: called as ``model(x1, x2, x3, do_deep_supervision=...)``.
        train_loader, val_loader: yield ``((x1, x2, x3), (t4, t3, t2, t))``
            batches. ``train_loader.dataset`` must provide an ``epoch()``
            method, which is called once after each training epoch.
        len_train, len_val: divisors that turn the summed per-batch losses
            into the reported epoch losses.
        criterion: called as ``criterion(prediction, target)``.
        optimizer: optimizer stepped through the AMP gradient scaler.
        scheduler: learning-rate scheduler, stepped once per batch.
        intermediates_contributions_decays_per_epoch: NumPy array with one
            entry per intermediate output (3 are used). The loss weight of an
            intermediate output is ``decay ** epochs_elapsed``, with epochs
            elapsed derived from ``scheduler.last_epoch``.
        device: device the batches are moved to.
        time_limit_hrs: time budget in hours; required, and must be below
            ``MAX_SESSION_HOURS``.
        early_stop_thresh_proportion_epoch: early stopping is considered only
            after ``ceil(proportion * epochs / 2)`` epochs without a new best
            validation loss (see the stop condition below).
        use_amp: enable float16 autocast and gradient scaling.

    Returns:
        dict with the number of epochs run and the per-epoch learning rate,
        losses, deep-supervision weights and validation peak errors.

    Raises:
        AssertionError: if ``time_limit_hrs`` is missing or too large.
        Exception: if the training or validation loss becomes NaN.
    """
    assert (time_limit_hrs is not None) and time_limit_hrs < MAX_SESSION_HOURS, \
        "time_limit_hrs must be given and be below MAX_SESSION_HOURS"
    time_limit_s = time_limit_hrs * 60 ** 2

    min_epochs_train_until_consider_early_stop = math.ceil(early_stop_thresh_proportion_epoch * epochs / 2)

    n_train_batches = len(train_loader)
    n_val_batches = len(val_loader)
    n_deep_outputs = len(intermediates_contributions_decays_per_epoch)

    torch.cuda.empty_cache()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    per_epoch_metrics = {
        'lr': [],
        'train_loss': [],
        'val_loss': [],
        'val_means_of_items_peak_errors': [],
        'val_means_of_batches_peak_errors': [],
        'val_epochs_peak_errors': [],
        'deep_weights': [
            [], [], [],
        ],
    }

    lowest_val_loss = None
    recent_i_epoch_at_lowest_val_loss = 0

    # NOTE(review): the model is not moved here; presumably the caller has
    # already placed it on `device` — confirm.
    fit_time = time.time()
    time_to_halt = fit_time + time_limit_s
    dur_last_epoch_s = 0
    n_e = 0  # keeps `history` well defined when epochs == 0
    for i_e in range(epochs):

        print('┅┅┅┅┅┅┅┅┅┅')

        n_e = 1 + i_e
        since = time.time()

        running_loss = 0
        running_lr = 0
        # Accumulated on the host: the weights are NumPy values already, so a
        # device round trip per batch buys nothing.
        running_deep_loss_weights = np.zeros(n_deep_outputs)

        # ---- training -------------------------------------------------------
        # model.eval() is set at the end of every epoch, so switch back once here.
        model.train()
        with tqdm(total=n_train_batches, disable=IS_KERNEL_OFFLINE_RUN) as pbar:
            for (resonant1, resonant2, resonant3), (deresonated_deep4, deresonated_deep3, deresonated_deep2, deresonated_the) in train_loader:

                # Fractional epoch count: the scheduler is stepped once per batch.
                scheduler_epochs_elapsed = scheduler.last_epoch / n_train_batches
                intermediates_contributions_per_epoch = intermediates_contributions_decays_per_epoch ** scheduler_epochs_elapsed
                running_deep_loss_weights += intermediates_contributions_per_epoch

                resonant1 = resonant1.to(device)
                resonant2 = resonant2.to(device)
                resonant3 = resonant3.to(device)
                deresonated_deep4 = deresonated_deep4.to(device)
                deresonated_deep3 = deresonated_deep3.to(device)
                deresonated_deep2 = deresonated_deep2.to(device)
                deresonated_the = deresonated_the.to(device)

                optimizer.zero_grad()

                with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):

                    # forward
                    output_hidden_tier4, output_hidden_tier3, output_hidden_tier2, output_proper = model(
                        resonant1, resonant2, resonant3, do_deep_supervision=True)

                    # loss: final output plus the decaying intermediate outputs.
                    # float(): plain Python scalars multiply tensors unambiguously.
                    loss_out_layer = criterion(output_proper, deresonated_the)
                    loss_intermediates = (
                        float(intermediates_contributions_per_epoch[0]) * criterion(output_hidden_tier4, deresonated_deep4) +
                        float(intermediates_contributions_per_epoch[1]) * criterion(output_hidden_tier3, deresonated_deep3) +
                        float(intermediates_contributions_per_epoch[2]) * criterion(output_hidden_tier2, deresonated_deep2)
                    )
                    loss = loss_out_layer + loss_intermediates

                # Scales loss. Calls ``backward()`` on scaled loss to create scaled gradients.
                scaler.scale(loss).backward()

                # ``scaler.step()`` first unscales the gradients of the optimizer's assigned parameters.
                # If these gradients do not contain ``inf``s or ``NaN``s, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()

                running_lr += np.mean(scheduler.get_last_lr())
                scheduler.step()

                # Only the final output's loss is reported.
                running_loss += loss_out_layer.item()

                pbar.update(1)

        train_loader.dataset.epoch() # requires non-persistent DataLoader-workers

        # ---- validation -----------------------------------------------------
        model.eval()
        val_losses = 0
        val_peaks = 0
        val_batchwide_peaks = 0
        val_epochwide_peak = 0
        with torch.no_grad():
            with tqdm(total=n_val_batches, disable=IS_KERNEL_OFFLINE_RUN) as pb:
                for (resonant1, resonant2, resonant3), (deresonated_deep4, deresonated_deep3, deresonated_deep2, deresonated_the) in val_loader:

                    resonant1 = resonant1.to(device)
                    resonant2 = resonant2.to(device)
                    resonant3 = resonant3.to(device)
                    deresonated_the = deresonated_the.to(device)

                    with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):

                        # forward
                        output_proper = model(resonant1, resonant2, resonant3, do_deep_supervision=False)

                        # loss
                        loss_out_layer = criterion(output_proper, deresonated_the)

                    val_losses += loss_out_layer.item()

                    # metric: largest absolute deviation per item (1-D, one entry per batch item)
                    peak_per_item = torch.amax(torch.abs(output_proper - deresonated_the), dim=(1, 2, 3))
                    val_peaks += torch.mean(peak_per_item).item()
                    peak_this_batch = torch.max(peak_per_item).item()
                    val_batchwide_peaks += peak_this_batch
                    if peak_this_batch > val_epochwide_peak:
                        val_epochwide_peak = peak_this_batch

                    pb.update(1)

        # ---- epoch bookkeeping ----------------------------------------------
        # NOTE(review): the losses are sums of per-batch values divided by the
        # caller-supplied len_train / len_val — confirm those are the
        # intended divisors.
        current_train_loss = running_loss / len_train
        current_val_loss = val_losses / len_val
        # One LR sample is collected per batch, so average over the batches.
        current_epoch_mean_lr = running_lr / n_train_batches
        if math.isnan(current_train_loss) or math.isnan(current_val_loss):
            raise Exception("💥 NaN loss 💥")
        per_epoch_metrics['train_loss'].append(current_train_loss)
        per_epoch_metrics['val_loss'].append(current_val_loss)
        per_epoch_metrics['lr'].append(current_epoch_mean_lr)
        for i in range(n_deep_outputs):
            per_epoch_metrics['deep_weights'][i].append(float(running_deep_loss_weights[i]) / n_train_batches)

        val_itemwide_peaks_mean = val_peaks / n_val_batches
        val_batchwide_peaks_mean = val_batchwide_peaks / n_val_batches
        per_epoch_metrics['val_means_of_items_peak_errors'].append(val_itemwide_peaks_mean)
        per_epoch_metrics['val_means_of_batches_peak_errors'].append(val_batchwide_peaks_mean)
        per_epoch_metrics['val_epochs_peak_errors'].append(val_epochwide_peak)

        # saves progress: always the newest state, plus a copy when it is the best so far
        info_to_save = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }
        torch.save(info_to_save, CHECKPOINT_FILE_PATH)
        if lowest_val_loss is None or current_val_loss < lowest_val_loss:
            recent_i_epoch_at_lowest_val_loss = i_e
            lowest_val_loss = current_val_loss
            torch.save(info_to_save, MODEL_STORE_PATH + BEST_MODEL_SAVE_FILE_NAME)

        allotted_seconds_left = time_to_halt - time.time()
        dur_this_epoch_s = time.time() - since
        print("Epoch:{}/{}..".format(n_e, epochs),
              "Train Loss: {:.5f}..".format(current_train_loss),
              "Val Loss: {:.5f}..".format(current_val_loss),
              "LR: {:.9f}..".format(current_epoch_mean_lr),
              "Epoch’s Duration: {:.3f} s..".format(dur_this_epoch_s),
              "Time Left: {:.3f} m".format(allotted_seconds_left / 60),
        )
        print(
            "Val Epoch-Mean Batch-Mean Max Deviation in Decibels: {:.4f}..".format(val_itemwide_peaks_mean),
            "Val Epoch-Mean Batch-Max Max Deviation in Decibels: {:.4f}..".format(val_batchwide_peaks_mean),
            "Val Epoch-Max Batch-Max Max Deviation in Decibels: {:.4f}".format(val_epochwide_peak),
        )

        if n_e < epochs:
            # quota management: stop when another epoch is not expected to
            # fit. The estimate includes the epoch that just finished;
            # previously only the epoch before it was used, which is 0 after
            # the first epoch.
            if allotted_seconds_left <= max(dur_last_epoch_s, dur_this_epoch_s):
                print("🛑 TIME 🛑")
                break
            # early stop
            if recent_i_epoch_at_lowest_val_loss >= (min_epochs_train_until_consider_early_stop-1) and i_e - recent_i_epoch_at_lowest_val_loss >= min_epochs_train_until_consider_early_stop and n_e / (1+recent_i_epoch_at_lowest_val_loss) >= 2:
                print("🛑 EARLY STOP = validation loss didn’t decrease for too long 🛑")
                break

        dur_last_epoch_s = dur_this_epoch_s

    history = {
        'n_epochs': n_e,
        'learning_rate': per_epoch_metrics['lr'],
        'train_loss': per_epoch_metrics['train_loss'],
        'val_loss': per_epoch_metrics['val_loss'],
        'deep_weights': per_epoch_metrics['deep_weights'],
        'error_peak_item': per_epoch_metrics['val_means_of_items_peak_errors'],
        'error_peak_batch': per_epoch_metrics['val_means_of_batches_peak_errors'],
        'error_peak_epoch': per_epoch_metrics['val_epochs_peak_errors'],
    }
    print('▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬')
    print('Total time: {:.3f} m'.format((time.time() - fit_time) / 60))

    return history

# Plots Loss + Accuracy Metrics Throughout Epochs

In [5]:
def plot_metrics(history, save_path='./weight/charts.png'):
    """Draw the per-epoch training curves on three stacked axes and optionally save them.

    Parameters
    ----------
    history : dict
        Must provide 'n_epochs' and the per-epoch sequences 'val_loss', 'train_loss',
        'error_peak_item', 'error_peak_batch', 'error_peak_epoch' and 'learning_rate',
        plus 'deep_weights', whose entries 0, 1 and 2 are each a per-epoch sequence.
    save_path : str or None
        File the figure is written to at 300 dpi. Pass None to skip saving.
    """
    # epochs are numbered from 1 on the shared x-axis
    x = 1 + np.arange(0,history['n_epochs'])

    fig,(ax1,ax2,ax3) = plt.subplots(3, 1, sharex=True, figsize=(12, 12))
    #fig.tight_layout()

    def plot_loss():

        ax1.plot(x, history['val_loss'], label='val', marker='o')
        ax1.plot(x, history['train_loss'], label='train', marker='o')

        # highlight the epoch with the lowest validation loss in red;
        # argmin raises on an empty sequence, so skip the marker when nothing was recorded
        if len(history['val_loss']) > 0:
            i_least_val_loss = np.argmin(history['val_loss'])
            x_least_val_loss = x[i_least_val_loss]
            ax1.plot(x_least_val_loss,history['val_loss'][i_least_val_loss],'ro')

        ax1.set_title('Loss per Epoch')
        ax1.set_ylabel('MSE') # Modified Weighted MSE Loss
        ax1.set_xlabel('Epoch')
        ax1.legend()
        ax1.grid()

    def plot_inaccuracy():

        ax2.plot(x, history['error_peak_item'], label='item-max|batch-mean', marker='o')
        ax2.plot(x, history['error_peak_batch'], label='batch-max|epoch-mean', marker='o')
        ax2.plot(x, history['error_peak_epoch'], label='epoch-max', marker='o')

        ax2.set_title('Peak Inaccuracy (dB) per Epoch')
        ax2.set_ylabel('dB Off Target')
        ax2.set_xlabel('Epoch')
        ax2.legend()
        ax2.grid()

    def plot_learnrate():

        ax3.plot(x, history['learning_rate'], label='learn rate', marker='o')

        ax3.set_title('Mean Learn Rate per Epoch')
        ax3.set_ylabel('Learn Rate')
        ax3.set_xlabel('Epoch')
        ax3.grid()

        # second y-axis on the same subplot; it shares the x-axis (and its label) with ax3
        ax4 = ax3.twinx()
        ax4.set_ylabel('Intermediate Outputs’ Contributions to Total Loss')
        ax4.plot(x, history['deep_weights'][0], label='4th tier = 8x64')
        ax4.plot(x, history['deep_weights'][1], label='3rd tier = 16x256')
        ax4.plot(x, history['deep_weights'][2], label='2nd tier = 32x512')
        ax4.legend(loc='upper right')

        ax3.legend(loc='upper left')

    plot_loss()
    plot_inaccuracy()
    plot_learnrate()

    if save_path is not None:
        # save through the figure handle so an unrelated "current figure" can never be written instead
        fig.savefig(save_path,dpi=300)

# Datasets + Runtime Augmentations

In [None]:
class ArraysChainer():
    """Expose several indexable arrays as one virtual concatenation, without copying them.

    ``arrays`` is any sequence of objects that support ``len()`` and integer indexing.
    Item ``i`` of the chainer is item ``i`` of the arrays laid end to end; negative
    indices count from the end of the whole chain and out-of-range indices raise
    ``IndexError``, as with a built-in sequence.
    """

    def __init__(self,arrays):
        self.arrays = arrays
        # Exclusive end offset of every array inside the virtual concatenation, so a
        # lookup is one binary search instead of a walk over all the arrays.
        self._ends = np.cumsum([len(a) for a in self.arrays],dtype=np.int64)
        self.len_total = int(self._ends[-1]) if len(self._ends) > 0 else 0

    def __len__(self):
        return self.len_total

    def __getitem__(self,idx):
        position = int(idx)
        if position < 0:
            position += self.len_total
        if not 0 <= position < self.len_total:
            raise IndexError(f"index {idx} is out of range for {self.len_total} chained items")
        # side='right' picks the first array whose end offset is strictly greater than
        # the position, which also steps over zero-length arrays
        which_array = int(np.searchsorted(self._ends,position,side='right'))
        starts_at = int(self._ends[which_array - 1]) if which_array > 0 else 0
        return self.arrays[which_array][position - starts_at]

class NumpyMemMapPairsChainsDataset(Dataset):
    """Dataset yielding ``((x1, x2, x3), y)`` tensors read from four parallel chainers.

    ``x_chainers`` must unpack into exactly three chainers; each of them has to be as
    long as ``y_chainer``. Every fetched element is converted with ``astype(dtype)``
    and then wrapped with ``torch.from_numpy``.
    """

    def __init__(self, x_chainers:tuple, y_chainer:ArraysChainer, dtype=np.float32): # , _device=device
        self.x1, self.x2, self.x3 = x_chainers
        self.y = y_chainer
        # every input chain has to pair up one-to-one with the targets
        for inputs_chain in (self.x1, self.x2, self.x3):
            assert len(inputs_chain) == len(self.y)
        self.dtype = dtype

    def __len__(self):
        return len(self.y)

    def _as_tensor(self, chain, idx):
        """Fetch element ``idx`` of ``chain`` and return it as a tensor of ``self.dtype``."""
        return torch.from_numpy(chain[idx].astype(self.dtype))

    def __getitem__(self, idx):
        inputs = tuple(self._as_tensor(chain, idx) for chain in (self.x1, self.x2, self.x3))
        target = self._as_tensor(self.y, idx)
        return inputs, target
    
class AugmentationDatasetWrapper(Dataset):
    """Expand a base dataset of ``((x1, x2, x3), y)`` items with crop, flip and mix augmentations.

    The virtual index space consists of three consecutive regions:

    1. ``[0, augmented_main_len)``: every base item repeated
       ``1 + view_augmentation_multiplier`` times, each repetition cropped along dim 0
       of the stored tensors at a randomly drawn position;
    2. ``[augmented_main_len, augmented_nolayers_len)``: extra crops of the base items
       listed in ``augmentation_map``, reversed along the cropped axis;
    3. ``[augmented_nolayers_len, len(self))``: mixes of two items drawn from regions
       1 and 2 (taken before any flip), blended in the amplitude domain.

    All random draws are a function of ``(self.seed, index)``, so an item is reproducible
    until ``epoch()`` reseeds the wrapper.

    Parameters
    ----------
    _train_subset : indexable with ``len()``
        Base dataset; each item unpacks as ``((x1, x2, x3), y)`` tensors.
        NOTE(review): the crop widths and masks below assume x1/x2/x3 hold at least
        128/64/32 entries along dim 0 and that y is 64 x 1024 — confirm against the data.
    view_augmentation_multiplier : int
        Number of additional crops per base item (must be whole-valued).
    flip_augmentation_multiplier : float
        Size of region 2 as a fraction of region 1.
    layers_augmentation_multiplier : float
        Size of region 3 as a fraction of regions 1 and 2 together.
    """
    # https://stackoverflow.com/a/74921042
    def __init__(self, _train_subset, view_augmentation_multiplier=2, flip_augmentation_multiplier=0.25,layers_augmentation_multiplier=0.2):
        assert view_augmentation_multiplier == int(view_augmentation_multiplier)
        # a whole-valued float such as 2.0 passes the assert; normalise it so that every
        # length derived below stays an integer (len() rejects floats)
        view_augmentation_multiplier = int(view_augmentation_multiplier)

        self.base_set = _train_subset

        self.n_samples_per_base_item = 1 + view_augmentation_multiplier

        self._l = len(self.base_set)
        # each crop keeps this fraction of the width listed in w_bay_per_x
        self.w_crop_ratio = 1/2
        # uncropped widths of x1, x2 and x3 along the cropped axis
        self.w_bay_per_x = np.array([2*64,2*32,2*16])
        assert np.all(self.w_bay_per_x * self.w_crop_ratio == np.floor(self.w_bay_per_x * self.w_crop_ratio))
        # y is cropped with the shift and width of the x at this position (x2)
        self.w_bay_y_xmatch_index = 1
        coarsest_res = min(self.w_bay_per_x)
        shiftmax_coarsest_res = coarsest_res * (1-self.w_crop_ratio)
        shifts_coarsest_res = np.arange(shiftmax_coarsest_res)
        # one row per candidate crop: the start offset in x1, x2 and x3 that corresponds
        # to the same shift of the coarsest input; rows with fractional offsets are dropped
        self.views_positions = [np.array([s * res / coarsest_res for res in self.w_bay_per_x]) for s in shifts_coarsest_res]
        self.views_positions = np.array([poses.astype(int) for poses in self.views_positions if np.all((poses - poses.astype(int)) == 0)])
        wholenum_shifts_coarsest_res = self.views_positions[:,-1]

        # For the k-th repetition of a base item, crop positions are drawn from a normal
        # bump centred at k / (repetitions - 1) of the shift range, so the repetitions
        # tend to cover different parts of the item. A single repetition draws uniformly.
        if self.n_samples_per_base_item > 1:
            n_gaps = self.n_samples_per_base_item - 1
            normalised_shifts = wholenum_shifts_coarsest_res / shiftmax_coarsest_res
            unnormalised = [
                norm.pdf(normalised_shifts, loc=_i / n_gaps, scale=1 / n_gaps / 2)
                for _i in np.arange(self.n_samples_per_base_item)
            ]
        else:
            unnormalised = [np.ones((len(self.views_positions),),dtype=np.float64)]
        self.probability_per_pos_per_remainder = [_p / np.sum(_p) for _p in unnormalised]

        self.augmented_main_len = self._l * self.n_samples_per_base_item
        self.augmentation_extra_len = int(self.augmented_main_len * flip_augmentation_multiplier)
        self.augmented_nolayers_len = self.augmented_main_len + self.augmentation_extra_len
        self.augmented_layers_len = int(self.augmented_nolayers_len * layers_augmentation_multiplier)

        # Three 32 x 1024 masks that sum to one everywhere: each is a piecewise-linear
        # ramp over log2 of the normalised bin position along the last axis.
        octaves_below_nyquist = torch.log2((torch.arange(1, 1 + 1024) + 1.0 / 2.0) / 1024).repeat((32, 1))
        upper_ramp = torch.clamp(octaves_below_nyquist,min=-2,max=-1)+2
        lower_ramp = torch.clamp(octaves_below_nyquist,min=-4,max=-3)+4
        self.spect_cascade_masks = (
            upper_ramp,
            (1-upper_ramp) * lower_ramp,
            1-lower_ramp,
        )

        self.seed = 0

        self.epoch()

    def _rng(self, idx_getter, stream):
        """Return a generator keyed by (epoch seed, item index, purpose of the draw).

        Seeding with the whole triple keeps the streams independent. Adding the three
        numbers together instead would hand item ``i`` / stream ``k`` the very same
        generator as item ``i + 1`` / stream ``k - 1``.
        """
        return np.random.default_rng([int(self.seed), int(idx_getter), int(stream)])

    def epoch(self):
        """Reseed the per-item randomness and redraw which base items get a flipped copy."""
        self.seed = int(time.time() * np.random.random())

        base_indices = np.arange(self._l)
        if self.augmentation_extra_len <= self._l:
            # enough base items to pick every flipped copy from a different one
            self.augmentation_map = np.random.choice(base_indices,size=self.augmentation_extra_len,replace=False)
            return
        # more flipped copies than base items: cycle through a shuffled order
        np.random.shuffle(base_indices)
        self.augmentation_map = np.resize(base_indices,self.augmentation_extra_len)

    def rancrop(self,x,y,idx_getter):
        """Crop ``x = (x1, x2, x3)`` and ``y`` at one matching, randomly drawn position.

        Every tensor first gains a leading dim of size 1; the crop is applied to dim 1
        of the result. Returns ``((x1_crop, x2_crop, x3_crop), y_crop)``.
        """
        x1,x2,x3 = x
        x1 = torch.unsqueeze(x1,dim=0)
        x2 = torch.unsqueeze(x2,dim=0)
        x3 = torch.unsqueeze(x3,dim=0)
        _y = torch.unsqueeze(y,dim=0)

        n_views = len(self.views_positions)
        if self.n_samples_per_base_item <= 1:
            # Generator.integers excludes the upper bound, so pass the count itself;
            # otherwise the last position could never be drawn
            which_shiftsset = self._rng(idx_getter, 0).integers(0,n_views)
        else:
            modulo = idx_getter % self.n_samples_per_base_item
            which_shiftsset = self._rng(idx_getter, 1).choice(np.arange(n_views),p=self.probability_per_pos_per_remainder[modulo])

        shifts = self.views_positions[which_shiftsset]
        widths = [int(w * self.w_crop_ratio) for w in self.w_bay_per_x]

        y_shift = shifts[self.w_bay_y_xmatch_index]
        y_fullsize = _y[:,y_shift:y_shift+widths[self.w_bay_y_xmatch_index],:]
        return (
            tuple(x_item[:,s:s+w,:] for x_item,s,w in zip((x1,x2,x3),shifts,widths)),
            y_fullsize,
        )

    def __len__(self):
        return self.augmented_nolayers_len + self.augmented_layers_len

    def __getitem__(self, idx,is_intermediate=False):
        """Return ``(x_crops, y_pyramid)`` for the virtual index ``idx``.

        ``y_pyramid`` is the cropped target bilinearly resized to 8x64, 16x256 and
        32x512, followed by the unresized crop. With ``is_intermediate=True`` the raw
        ``(x_crops, y_crop)`` pair is returned instead, before flipping and resizing;
        ``dual_layer_getitem`` uses this to fetch its two sources.

        Raises ``IndexError`` when ``idx`` is outside ``[-len(self), len(self))``.
        """
        total_len = len(self)
        requested_idx = idx
        if idx < 0:
            idx += total_len
        if not 0 <= idx < total_len:
            # without this, plain iteration over the dataset would never terminate
            raise IndexError(f"index {requested_idx} is out of range for {total_len} augmented items")

        layers = idx >= self.augmented_nolayers_len
        flip = (not layers) and idx >= self.augmented_main_len
        if layers:
            x_cropped,y_cropped = self.dual_layer_getitem(idx)
        else:
            base_idx = self.augmentation_map[idx-self.augmented_main_len] if flip else idx // self.n_samples_per_base_item
            x_cropped,y_cropped = self.rancrop(*self.base_set[base_idx],idx)
        if is_intermediate:
            return x_cropped,y_cropped
        if flip:
            x_cropped = tuple([torch.flip(x_item,(1,)) for x_item in x_cropped])
            y_cropped = torch.flip(y_cropped,(1,))

        # F.interpolate wants a batch dim, which is added here and stripped again with [0]
        y_batched = torch.unsqueeze(y_cropped,dim=0)
        y_cropped = tuple(
            F.interpolate(y_batched, size=size, mode='bilinear', align_corners=False)[0]
            for size in ((8,64),(16,256),(32,512))
        ) + (y_cropped,)
        return x_cropped,y_cropped

    def spect_cascade(self,Y1024,Y2048,Y4096):
        """Blend three candidate targets with the precomputed masks (the masks sum to one)."""
        return Y1024 * self.spect_cascade_masks[0] + Y2048 * self.spect_cascade_masks[1] + Y4096 * self.spect_cascade_masks[2]

    def dual_layer_getitem(self,idx_getter):
        """Build a synthetic item by mixing two randomly chosen non-mixed items.

        Values are treated as decibels: both sources are converted to amplitude, blended
        with a random factor and converted back. One target candidate is produced per
        input of the triple (each source target weighted by the share its input, resized
        to the target's size, contributes to the mix), and the three candidates are
        merged by ``spect_cascade``.
        """
        to_amp = lambda dB : 10.0 ** (dB / 20.0)
        from_amp = lambda a : 20.0 * torch.log10(a)

        pointee_idx_1 = self._rng(idx_getter, 2).integers(0,self.augmented_nolayers_len)
        pointee_idx_2 = self._rng(idx_getter, 3).integers(0,self.augmented_nolayers_len)
        pointee_1_x,pointee_1_y = self.__getitem__(pointee_idx_1,True)
        pointee_2_x,pointee_2_y = self.__getitem__(pointee_idx_2,True)
        mix_fac = self._rng(idx_getter, 4).uniform(0,1)

        # the targets' amplitudes do not depend on which input is being processed
        y_amp_1 = to_amp(pointee_1_y)
        y_amp_2 = to_amp(pointee_2_y)

        Xs = []
        Ys = []
        for p1x,p2x in zip(pointee_1_x,pointee_2_x):
            # mixed input, at the input's own size
            Xs.append(from_amp((1-mix_fac) * to_amp(p1x) + mix_fac * to_amp(p2x)))

            # per-cell share of each source in the mix, evaluated at the target's size
            p1x_scaled = F.interpolate(torch.unsqueeze(p1x,dim=0), size=pointee_1_y.shape[1:],
                                          mode='bilinear', align_corners=False)[0]
            p2x_scaled = F.interpolate(torch.unsqueeze(p2x,dim=0), size=pointee_2_y.shape[1:],
                                          mode='bilinear', align_corners=False)[0]
            amp_1 = to_amp(p1x_scaled)
            amp_2 = to_amp(p2x_scaled)
            amp_total = (1-mix_fac) * amp_1 + mix_fac * amp_2

            contrib_1 = (1-mix_fac) * amp_1 / amp_total
            contrib_2 = mix_fac * amp_2 / amp_total
            Ys.append(from_amp(contrib_1 * y_amp_1 + contrib_2 * y_amp_2))

        return tuple(Xs),self.spect_cascade(*Ys)
    
def load_train_data(filter='', folder='/kaggle/input/', in_mem_limit_gib=15):
    """Locate the ``train_{x1,x2,x3,y}_*.npy`` file groups and chain them into one dataset.

    Parameters
    ----------
    filter : str
        Optional infix; when non-empty only files named ``train_<kind>_<filter>_*.npy``
        are used. (The name shadows the builtin but is kept for existing callers.)
    folder : str
        Prefix the file names are appended to, so it should end with a path separator.
    in_mem_limit_gib : float
        RAM budget. A file is read into memory only while it still fits in the budget;
        the remaining files are opened as read-only memory maps.

    Returns
    -------
    NumpyMemMapPairsChainsDataset
        Dataset over the chained arrays of all file groups.

    Raises
    ------
    FileNotFoundError
        If no target file matches.
    AssertionError
        If a target file lacks one of its three input files.
    """
    FILENAME_STEM = 'train_'
    INPUT_KINDS = ('x1','x2','x3')
    specific_filter = (filter + '_') if len(filter) > 0 else filter

    def find_files(kind):
        return glob.glob(folder+FILENAME_STEM+f'{kind}_{specific_filter}*.npy')

    # sets give O(1) membership tests when pairing the files below
    input_files_per_kind = {kind: set(find_files(kind)) for kind in INPUT_KINDS}

    files_pairs = []
    for yf in find_files('y'):
        correspond_xfs = tuple(yf.replace(FILENAME_STEM+'y',FILENAME_STEM+kind) for kind in INPUT_KINDS)
        missing = [xf for xf,kind in zip(correspond_xfs,INPUT_KINDS) if xf not in input_files_per_kind[kind]]
        assert not missing, f"input file(s) missing for {yf}: {missing}"
        files_pairs.append((correspond_xfs,yf))
    if not files_pairs:
        raise FileNotFoundError(f"no file matches {folder}{FILENAME_STEM}y_{specific_filter}*.npy")

    # Oldest first. The path breaks ties between equal timestamps: without it the order
    # would fall back to the arbitrary listing order of glob, and any index-based split
    # of the returned dataset would no longer be reproducible between runs.
    files_pairs = sorted(files_pairs, key=lambda a : (os.path.getmtime(a[1]), a[1]))
    for (_x1,_x2,_x3),_y in files_pairs:
        print(_x1,_x2,_x3,_y)

    IN_MEM_LIMIT_B = in_mem_limit_gib * 1024 ** 3

    mem_allocated_total_B = 0

    def load_to_mem_if_room(_fpath):
        nonlocal mem_allocated_total_B
        # Decide from the on-disk size before reading anything: loading first and
        # checking afterwards could overshoot the budget by a whole file.
        if mem_allocated_total_B + os.path.getsize(_fpath) > IN_MEM_LIMIT_B:
            return np.load(_fpath, mmap_mode='r')
        loaded_to_ram = np.load(_fpath)
        mem_allocated_total_B += loaded_to_ram.nbytes
        return loaded_to_ram

    X1_arrays = []
    X2_arrays = []
    X3_arrays = []
    Y_arrays = []
    for (x1f,x2f,x3f),yf in files_pairs:
        Y_arrays.append(load_to_mem_if_room(yf))
        X1_arrays.append(load_to_mem_if_room(x1f))
        X2_arrays.append(load_to_mem_if_room(x2f))
        X3_arrays.append(load_to_mem_if_room(x3f))

    X1_all = ArraysChainer(X1_arrays)
    X2_all = ArraysChainer(X2_arrays)
    X3_all = ArraysChainer(X3_arrays)
    Y_all = ArraysChainer(Y_arrays)

    return NumpyMemMapPairsChainsDataset((X1_all,X2_all,X3_all),Y_all) # dtype=np.float16

# Loads Data

In [None]:
with torch.device('cpu'):

    # ---- raw data ----
    dataset = load_train_data() # filter='130'
    n_pairs = len(dataset)

    # ---- split sizes ----
    # Target share of validation items once the training set has been inflated; the
    # expression stays below one half and shrinks as the number of pairs grows.
    VAL_SIZE_PROPORTION = np.tanh(np.sqrt(10**3/(10**3+np.e*n_pairs)))/2
    print(VAL_SIZE_PROPORTION)
    # The training wrapper below is built with its default of 2 extra views, i.e. 3
    # samples per base item, whereas the validation wrapper gets 0 extra views. The
    # share handed to random_split is therefore enlarged so that the target share above
    # holds after the inflation.
    TRAIN_SET_COUNT_INFLATION_RATIO = 3
    VAL_SIZE_PROPORTION_BEFORE_INFLATE_TRAIN = -(TRAIN_SET_COUNT_INFLATION_RATIO*VAL_SIZE_PROPORTION)/(VAL_SIZE_PROPORTION-TRAIN_SET_COUNT_INFLATION_RATIO*VAL_SIZE_PROPORTION-1)

    # ---- split + augmentation ----
    # fixed seed: prevents data leakage artificially lowering val loss when resumes
    generator1 = torch.Generator(device='cpu').manual_seed(42)
    train_dataset, val_dataset = random_split(
        dataset,
        [1-VAL_SIZE_PROPORTION_BEFORE_INFLATE_TRAIN,VAL_SIZE_PROPORTION_BEFORE_INFLATE_TRAIN],
        generator=generator1,
    )
    train_dataset = AugmentationDatasetWrapper(train_dataset)
    val_dataset = AugmentationDatasetWrapper(val_dataset,view_augmentation_multiplier=0)

    # ---- loaders ----
    BATCH_SIZE_TRAIN = 32
    # inert string literal: an alternative, size-dependent choice of the train batch size
    """8 if n_pairs < 250 else (
        16 if n_pairs < 1000 else (
            32 if n_pairs < 75000 else (
                64 if n_pairs < 1000000 else 128)))"""
    BATCH_SIZE_VAL = 32 # recommendation = as big as hardware can handle | but smaller performed faster in tests with P100

    N_LOADERWORKERS = 4
    # both loaders share every option except the batch size
    # https://discuss.pytorch.org/t/selective-augmentation-modifying-dataloader/108304/2
    shared_loader_options = dict(
        shuffle=True,
        drop_last=True,
        num_workers=N_LOADERWORKERS,
        pin_memory=True,
        persistent_workers=False,
    )
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN, **shared_loader_options)
    valid_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE_VAL, **shared_loader_options)

    # ---- resulting sizes (after augmentation) ----
    len_train = len(train_dataset)
    len_val = len(val_dataset)
    print(len_train)
    print(len_val)

# Run

Resumes Training If Checkpoint Exists

In [None]:
# ---- model, loss, optimiser ----
net = DeeplySupervizedUnet() # smp.Unet("resnet34", encoder_weights=None, activation=None)
net = net.to(device)
# net.half()
# net = torch.compile(net)
# net = nn.DataParallel(net)

loss_function = nn.MSELoss() #CustomMSELoss()
# loss_function = torch.compile(loss_function)
optimizer = optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-3) # , amsgrad=True

# ---- learning-rate schedule ----
# One cycle = 2 epochs of ramp-up followed by 10 epochs of ramp-down, between
# base_lr and max_lr.
# cycle_momentum is switched off; the first link covers using CyclicLR with Adam.
# https://discuss.pytorch.org/t/cant-use-cycliclr-with-adam/74820
# https://medium.com/analytics-vidhya/cyclical-learning-rates-a922a60e8c04
n_iters_per_epoch = len(train_loader)
#print(n_iters_per_epoch)
scheduler = optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=5e-5,max_lr=5e-4,
    step_size_up=int(2*n_iters_per_epoch),step_size_down=int(10*n_iters_per_epoch), # https://medium.com/analytics-vidhya/cyclical-learning-rates-a922a60e8c04
    mode="triangular2",
    cycle_momentum=False
)

# 🟢
IS_KERNEL_OFFLINE_RUN = True # 🟢
# 🟢

# ---- checkpoint handling ----
PRETRAINED_MODEL_PATH = '/kaggle/input/'
DO_RESUME_FROM_CHECKPOINT = True      # resume when a checkpoint file exists
MUST_RESUME_FROM_CHECKPOINT = True    # abort instead of training from scratch when it does not
RESET_OPTIM_LR_SCHED = False          # True restores the weights only
# Resumes Training If Checkpoint Exists
real_chkpnt_pth = PRETRAINED_MODEL_PATH if IS_KERNEL_OFFLINE_RUN else CHECKPOINT_FILE_PATH
wants_resume = MUST_RESUME_FROM_CHECKPOINT or DO_RESUME_FROM_CHECKPOINT
if wants_resume and os.path.isfile(real_chkpnt_pth):
    # map_location puts the stored tensors on the device in use, so a checkpoint
    # written on another device (other GPU index, or GPU vs CPU) still loads
    checkpoint = torch.load(real_chkpnt_pth, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    if not RESET_OPTIM_LR_SCHED:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    print("⏯️ RESUMES from CHECKPOINT ⏯️")
elif MUST_RESUME_FROM_CHECKPOINT:
    # fail before announcing a from-scratch run that is not going to happen
    raise Exception("💥 checkpoint file not found 💥")
else:
    print("▶️ TRAINS from SCRATCH ▶️")

# ---- training budget ----
HOURS_TO_TRAIN = 5
# KNOWN_DUR_SECONDS_EACH_EPOCH = 49
# n_epochs = int(HOURS_TO_TRAIN * 60 ** 2 / KNOWN_DUR_SECONDS_EACH_EPOCH)
n_epochs = 100

# The three decay factors are the 10th, 20th and 30th roots of 1/2.
# NOTE(review): presumably applied once per epoch, which would halve each
# intermediate output's contribution every 10, 20 and 30 epochs — confirm in train_and_val.
#with autograd.detect_anomaly():
history = train_and_val(n_epochs, net, train_loader, len_train, valid_loader, len_val, loss_function, optimizer, scheduler,
                        intermediates_contributions_decays_per_epoch = np.array([0.5**(1/10),0.5**(1/20),0.5**(1/30)],dtype=np.float32),
                        device=device,
                        time_limit_hrs=HOURS_TO_TRAIN,
                        use_amp=False,
                       )

plot_metrics(history)