In [None]:
import os
import json
import random
import collections

import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
## things to try:
# 1) smaller img
# 2) save both best auc and best acc
# 3) change to fold 3
# 4) add warm up to learning rate; try new optimizer
# 5) swapping channels
# 6) shorten cqt convolution kernel
# 7) change pretrain to True (with learning rate warm up)
# 8) remove mixup; soft labels
# 9) try stochastic depth
# 10) try noisy student training

# Import libraries

In [None]:
!pip install efficientnet_pytorch -qq

!pip install -q nnAudio -qq
import torch
from nnAudio.Spectrogram import CQT1992v2, CQT2010v2

import time
import gc
import torch
from torch import nn
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as torch_functional
from torch.autograd import Variable
import efficientnet_pytorch
from tqdm.auto import tqdm
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau, ExponentialLR
from torchaudio.functional import lfilter
from torch.fft import fft, rfft, ifft
import numpy as np
import torchvision



from sklearn.metrics import roc_auc_score

from sklearn.model_selection import StratifiedKFold

import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm

In [None]:
#sys.path.append('../input/pytorch-swa')
#import swa

# Load data

In [None]:
# Competition files: the submission template and the ground-truth training labels.
submission = pd.read_csv("../input/g2net-gravitational-wave-detection/sample_submission.csv")
train_df = pd.read_csv("../input/g2net-gravitational-wave-detection/training_labels.csv")
# Predictions on the training set read from a separate dataset; its `target`
# column is blended into the soft labels in the next cell.
train_df_pred = pd.read_csv("../input/train-pred-cqt-v10/train_preds_CQT_V10.csv")

In [None]:
# Soft labels: blend the hard label with the earlier model's prediction.
# NOTE(review): the prediction column is attached by index alignment, not by `id` —
# confirm both CSV files list the samples in the same order.
weight = 0.5  # share of the earlier prediction in the blended label
train_df['preds'] = train_df_pred['target']
train_df['soft_target'] = (
    train_df['preds']*weight
    + train_df['target']*(1-weight)
)
train_df

In [None]:
# Inspect the predictions that were blended into `soft_target` above.
train_df_pred

# Define config

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class CFG:
    """Notebook-wide configuration, used as a plain namespace of class attributes.

    The scheduler values are computed while the class body executes, so
    ``T_max`` (and ``gamma`` for ExponentialLR) reflect the ``EPOCHS``, ``BATCH``
    and ``schedulerStepFreq`` values written here; changing those attributes
    afterwards does not recompute them.
    """
    
    TRAIN = True
    
    WARM_START = False
    
    EPOCHS = 10
    
    # batch size
    BATCH = 192
    
    
    
    # Number of stratified folds and the fold held out for validation.
    n_fold = 5
    fold = 3
    
    # scheduler_params
    scheduler='CosineAnnealingLR'
    if scheduler == 'CosineAnnealingLR':
        lr = 1.5*1e-2
        schedulerStepFreq = 10
        # Number of scheduler steps over the whole run; 3500*128 is presumably the
        # number of training samples (TODO confirm).
        T_max=EPOCHS*3500*128/BATCH/schedulerStepFreq # CosineAnnealingLR
        T_0=3 # CosineAnnealingWarmRestarts
        min_lr=1e-6
    elif scheduler == 'ExponentialLR':
        lr = 3e-2
        schedulerStepFreq = 500
        T_max=EPOCHS*3500*128/BATCH/schedulerStepFreq # CosineAnnealingLR
        min_lr = 1e-4
        # Decay factor chosen so that lr * gamma**T_max == min_lr.
        gamma = (min_lr/lr)**(1/T_max)
    
    
    
    # Parameters CWT
    cwt_params = {'fs':2048, 'lower_freq':10, 'upper_freq': 500, 
                  'n_scales':81, 'wavelet_width':1, 'stride':12, 'border_crop':0, 'train_width':True}
    
    useCWT = False
    #cqt_params = {'sr':2048, 'fmin':25, 'fmax':512, 'hop_length':32, 'n_bins':69, 'bins_per_octave': 8}
    # Keyword arguments forwarded to nnAudio's CQT1992v2.
    cqt_params = {'sr':2048, 'fmin':20, 'fmax':512, 'hop_length':32, 'bins_per_octave': 12, 'filter_scale': 0.6}
    # cqt_params = {'sr':2048, 'fmin':20, 'fmax':512, 'hop_length':32, 'bins_per_octave': 25, 'norm':1}
    
    # When True, filterSig / batch_preprocessing apply the band-pass filter.
    BPfilter = True
    
    
    
    
    # Post Proc Option
    PREPROC = 'Q_transform'
    
    # scale:linear or log
    SCALE = 'linear'
    
    # DEBUG and SMALL_TRAIN_SET shrink the training frame and epoch count (see the
    # overrides right after this class).
    DEBUG = False
    
    SMALL_TRAIN_SET = False
    
    seed = 42
    
    model_name = 'tf_efficientnet_b0' #'tf_efficientnet_b4' #'efficientnet-b7'
    pretrained = True
    unfreezeStep = 5 # set 0 for no freezing
    
    useSoftLabels = True
    # When True, confident test-set predictions are appended to the training frame.
    useTestLabels = False
    
    NORM_MODE = 'global_channel' # 'global' or 'local' or 'local_std' or 'global_channel'
    
    DENOISE = False
    
    # Spectrograms are resized to this shape before entering the backbone.
    image_size = (101,201)
    
# Run-mode overrides layered on top of the CFG defaults.
# NOTE: CFG.T_max was computed inside the class body from the default EPOCHS;
# only the WARM_START branch below recomputes it.
if CFG.DEBUG:
    # Tiny subset and two epochs for a quick end-to-end smoke run.
    CFG.EPOCHS = 2
    train_df = train_df.sample(n=1000, random_state=CFG.seed).reset_index(drop=True)
elif CFG.SMALL_TRAIN_SET:
    CFG.EPOCHS = 4
    train_df = train_df.sample(n=CFG.BATCH*500, random_state=CFG.seed).reset_index(drop=True)
elif CFG.WARM_START:
    # Short fine-tuning run at a low learning rate, with test pseudo-labels enabled.
    CFG.EPOCHS = 3
    CFG.lr = 1e-3
    CFG.useTestLabels = True
    if CFG.scheduler == 'CosineAnnealingLR':
        # scheduler_params
        CFG.schedulerStepFreq = 10
        CFG.T_max=CFG.EPOCHS*4200*128/CFG.BATCH/CFG.schedulerStepFreq # CosineAnnealingLR
        CFG.min_lr=1e-7


if not CFG.pretrained:
    # Write the override onto CFG itself; a bare `unfreezeStep = 5` only creates an
    # unrelated module-level variable and leaves the configuration untouched.
    CFG.unfreezeStep = 5
    
    

In [None]:
import random

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports PYTHONHASHSEED and, when CUDA is available, seeds every GPU
    and switches cuDNN to deterministic mode.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


set_seed(42)

# CWT

https://www.kaggle.com/anjum48/continuous-wavelet-transform-cwt-in-pytorch/notebook

In [None]:
#!pip install torch --upgrade --quiet

In [None]:
import numpy as np
import pandas as pd
import pywt
import torch
import torch.nn as nn
from scipy import signal
from scipy import optimize
import matplotlib.pyplot as plt
from pathlib import Path
from IPython.display import Image

# Root directory of the competition data set.
INPUT_PATH = Path("../input/g2net-gravitational-wave-detection/")

In [None]:
# From https://github.com/tomrunia/PyTorchWavelets/blob/master/wavelets_pytorch/wavelets.py
class Morlet(object):
    """Complex Morlet mother wavelet with helpers for scale/period conversion."""

    def __init__(self, w0=6):
        """Store the nondimensional frequency constant ``w0``.

        Too small a value samples the wavelet poorly: anything above 5 should
        be fine, and Torrence and Compo use 6. The reconstruction constant
        ``C_d`` (TC98) is only known, and only set, for ``w0 == 6``.
        """
        self.w0 = w0
        if w0 == 6:
            self.C_d = 0.776  # value of C_d from TC98

    def __call__(self, *args, **kwargs):
        # Calling the wavelet evaluates its time-domain form.
        return self.time(*args, **kwargs)

    def time(self, t, s=1.0, complete=True):
        """Evaluate the complex Morlet wavelet, centred at zero, in the time domain.

        Parameters
        ----------
        t : float
            Time. With the default ``s`` this can be read as the
            non-dimensional time t/s.
        s : float
            Scaling factor. Default is 1.
        complete : bool
            Use the complete version (True) or the standard version (False).

        Returns
        -------
        out : complex
            Wavelet value at the given time.

        Notes
        -----
        Standard version::
            pi**-0.25 * exp(1j*w*x) * exp(-0.5*(x**2))
        Complete version, which adds a correction term that improves
        admissibility (negligible for ``w`` greater than 5)::
            pi**-0.25 * (exp(1j*w*x) - exp(-0.5*(w**2))) * exp(-0.5*(x**2))
        The energy of the returned wavelet is not normalised according to `s`.
        """
        omega = self.w0
        scaled_t = t / s

        # Carrier oscillation, optionally corrected, then the Gaussian envelope.
        wavelet = np.exp(1j * omega * scaled_t)
        if complete:
            wavelet -= np.exp(-0.5 * (omega ** 2))
        wavelet *= np.exp(-0.5 * (scaled_t ** 2)) * np.pi ** (-0.25)
        return wavelet

    # Fourier wavelengths
    def fourier_period(self, s):
        """Return the Fourier period equivalent to scale ``s``."""
        denominator = self.w0 + (2 + self.w0 ** 2) ** 0.5
        return 4 * np.pi * s / denominator

    def scale_from_period(self, period):
        """Return the scale whose equivalent Fourier period is ``period``.

        This is ``fourier_period`` solved for the scale.
        """
        root = np.sqrt(self.w0 * self.w0 + 2)
        numerator = period * (root + self.w0)
        return numerator / (4.0 * np.pi)

    # Frequency representation
    def frequency(self, w, s=1.0):
        """Evaluate the Morlet wavelet in the frequency domain.

        Parameters
        ----------
        w : float
            Angular frequency. With the default ``s`` this can be read as the
            non-dimensional angular frequency w * s.
        s : float
            Scaling factor. Default is 1.

        Returns
        -------
        out : complex
            Wavelet value at the given frequency.
        """
        scaled_w = w * s
        # Heaviside step built on a copy of w: 1 where w > 0, 0 where w <= 0.
        step = np.array(w)
        is_positive = w > 0
        step[w <= 0] = 0
        step[is_positive] = 1
        gaussian = np.exp((-((scaled_w - self.w0) ** 2)) / 2)
        return np.pi ** -0.25 * step * gaussian

    def coi(self, s):
        """Return the e-folding time of the wavelet power autocorrelation at scale ``s``.

        This is the timescale over which an edge effect decays by a factor of
        1/e^2, obtained analytically from
            |Y_0(T)|^2 / |Y_0(0)|^2 = 1 / e^2
        """
        root_two = 2 ** 0.5
        return root_two * s


class CWT(nn.Module):
    """Continuous wavelet transform implemented as a 2D convolution with a Morlet filter bank.

    The filter bank is built lazily on the first ``forward`` call from the length
    of that first input and is then reused for every later call.
    """

    def __init__(
        self,
        dj=0.0625,
        dt=1 / 2048,
        fmin: int = 20,
        fmax: int = 500,
        output_format="Magnitude",
        trainable=False,
        padding = 0,
        stride = (1,1)
    ):
        """Store the transform settings; no filters are built here.

        Args:
            dj: exponent step between consecutive scales (scales grow as 2**(dj*k)).
            dt: sampling interval in seconds.
            fmin: scales whose equivalent frequency is below this are dropped.
            fmax: scales whose equivalent frequency is above this are dropped.
            output_format: "Magnitude" returns the modulus; any other value
                returns real and imaginary parts stacked on a new last axis.
            trainable: stored but not used by the visible code.
            padding: forwarded to ``conv2d``.
            stride: forwarded to ``conv2d``.
        """
        super().__init__()
        self.wavelet = Morlet()

        self.dt = dt
        self.dj = dj
        self.fmin = fmin
        self.fmax = fmax
        self.output_format = output_format
        self.trainable = trainable  # TODO make kernel a trainable parameter
        self.stride = stride  # Strides > 1 not yet supported
        self.padding = padding

        self._scale_minimum = self.compute_minimum_scale()

        # Filled in by the first forward() call.
        self.signal_length = None
        self._channels = None

        self._scales = None
        self._kernel = None
        self._kernel_real = None
        self._kernel_imag = None

    def compute_optimal_scales(self):
        """
        Determines the optimal scale distribution (see. Torrence & Combo, Eq. 9-10).
        :return: np.ndarray, collection of scales
        :raises ValueError: if ``signal_length`` has not been set yet
        """
        if self.signal_length is None:
            raise ValueError(
                "Please specify signal_length before computing optimal scales."
            )
        J = int(
            (1 / self.dj) * np.log2(self.signal_length * self.dt / self._scale_minimum)
        )
        scales = self._scale_minimum * 2 ** (self.dj * np.arange(0, J + 1))

        # Remove high and low frequencies
        # (scales increase, so the equivalent frequencies decrease along the array:
        # the fmin cut keeps a prefix of `scales`, the fmax cut keeps a suffix)
        frequencies = np.array([1 / self.wavelet.fourier_period(s) for s in scales])
        if self.fmin:
            frequencies = frequencies[frequencies >= self.fmin]
            scales = scales[0 : len(frequencies)]
        if self.fmax:
            frequencies = frequencies[frequencies <= self.fmax]
            scales = scales[len(scales) - len(frequencies) : len(scales)]

        return scales

    def compute_minimum_scale(self):
        """
        Choose s0 so that the equivalent Fourier period is 2 * dt.
        See Torrence & Combo Sections 3f and 3h.
        :return: float, minimum scale level
        """
        dt = self.dt

        def func_to_solve(s):
            return self.wavelet.fourier_period(s) - 2 * dt

        # Numerical root of fourier_period(s) == 2 * dt, starting from s = 1.
        return optimize.fsolve(func_to_solve, 1)[0]

    def _build_filters(self):
        """Sample one time-reversed, conjugated wavelet per scale, then pad them to a common length."""
        self._filters = []
        for scale_idx, scale in enumerate(self._scales):
            # Number of points needed to capture wavelet
            M = 10 * scale / self.dt
            # Times to use, centred at zero
            t = torch.arange((-M + 1) / 2.0, (M + 1) / 2.0) * self.dt
            if len(t) % 2 == 0:
                t = t[0:-1]  # requires odd filter size
            # Sample wavelet and normalise
            norm = (self.dt / scale) ** 0.5
            filter_ = norm * self.wavelet(t, scale)
            self._filters.append(torch.conj(torch.flip(filter_, [-1])))

        self._pad_filters()

    def _pad_filters(self):
        """Zero-pad every filter symmetrically to the length of the last one in the list."""
        filter_len = self._filters[-1].shape[0]
        padded_filters = []

        for f in self._filters:
            pad = (filter_len - f.shape[0]) // 2
            padded_filters.append(nn.functional.pad(f, (pad, pad)))

        self._filters = padded_filters

    def _build_wavelet_bank(self):
        """This function builds a 2D wavelet filter using wavelets at different scales

        Returns:
            tensor: Tensor of shape (num_scales, 1, 1, filter_len)
        """
        self._build_filters()
        wavelet_bank = torch.stack(self._filters)
        wavelet_bank = wavelet_bank.view(
            wavelet_bank.shape[0], 1, 1, wavelet_bank.shape[1]
        )
        return wavelet_bank

    def forward(self, x):
        """Compute CWT arrays from a batch of multi-channel inputs

        Args:
            x (torch.tensor): Tensor of shape (batch_size, channels, time)

        Returns:
            torch.tensor: Tensor of shape (batch_size, channels, widths, time)
            when output_format is "Magnitude"; for a complex kernel with any other
            output_format the real and imaginary parts are stacked on an extra
            trailing axis.
        """
        if self.signal_length is None:
            # First call: derive scales and kernel from this input's length.
            self.signal_length = x.shape[-1]
            # NOTE: this writes `channels`, not the `_channels` attribute
            # initialised in __init__.
            self.channels = x.shape[-2]
            self._scales = self.compute_optimal_scales()
            self._kernel = self._build_wavelet_bank()

            if self._kernel.is_complex():
                self._kernel_real = self._kernel.real
                self._kernel_imag = self._kernel.imag

        # Add a singleton dim so conv2d sees a single input plane of (channels, time).
        x = x.unsqueeze(1)

        if self._kernel.is_complex():
            # Move/cast the cached kernels lazily so they follow the input tensor.
            if (
                x.dtype != self._kernel_real.dtype
                or x.device != self._kernel_real.device
            ):
                self._kernel_real = self._kernel_real.to(device=x.device, dtype=x.dtype)
                self._kernel_imag = self._kernel_imag.to(device=x.device, dtype=x.dtype)

            output_real = nn.functional.conv2d(
                x, self._kernel_real, padding=self.padding, stride=self.stride
            )
            output_imag = nn.functional.conv2d(
                x, self._kernel_imag, padding=self.padding, stride=self.stride
            )
            # Swap the scale and channel axes.
            output_real = torch.transpose(output_real, 1, 2)
            output_imag = torch.transpose(output_imag, 1, 2)

            if self.output_format == "Magnitude":
                return torch.sqrt(output_real ** 2 + output_imag ** 2)
            else:
                return torch.stack([output_real, output_imag], -1)

        else:
            # Real kernel: only a device mismatch triggers the move (dtype is not compared).
            if x.device != self._kernel.device:
                self._kernel = self._kernel.to(device=x.device, dtype=x.dtype)

            output = nn.functional.conv2d(
                x, self._kernel, padding=self.padding, stride=self.stride
            )
            return torch.transpose(output, 1, 2)

# Data retrieving and related functions

In [None]:
from scipy import signal 

from scipy import signal

# Butterworth band-pass design (20-500 Hz at a 2048 Hz sampling rate), computed once.
bHP, aHP = signal.butter(8, (20, 500), btype='bandpass', fs= 2048)
def filterSig(waves, a=aHP, b=bHP, axis = 1):
    '''Band-pass filter ``waves`` along ``axis`` with zero-phase filtering.

    Returns ``waves`` unchanged when CFG.BPfilter is switched off. ``a`` and
    ``b`` default to the module-level Butterworth coefficients.
    '''
    if CFG.BPfilter:
        # filtfilt rather than lfilter: lfilter introduces a larger spike around 20hz
        return signal.filtfilt(b, a, waves, axis = axis)
    return waves

class DataRetriever(torch_data.Dataset):
    """Dataset that loads one .npy array per path and pairs it with its target."""

    def __init__(self, paths, targets):
        self.paths = paths
        self.targets = targets

    def __len__(self):
        return len(self.paths)

    def __get_qtransform(self, x):
        """Rescale the loaded array and return it as a float tensor.

        Despite the name, no Q-transform is computed here. The band-pass
        filter is applied at this stage only when CFG.DEBUG is set.
        """
        scaled = x/(2.8679e-20)
        if CFG.DEBUG:
            scaled = filterSig(scaled).copy()
        # scaled is [chan x time]
        return torch.tensor(scaled).float()

    def __getitem__(self, index):
        """Return ``{"X": float tensor, "y": float target tensor}`` for ``index``."""
        waves = np.load(self.paths[index])
        sample = {"X": self.__get_qtransform(waves)}
        sample["y"] = torch.tensor(self.targets[index], dtype=torch.float)
        return sample
    
    
class TestDataRetriever(torch_data.Dataset):
    """Dataset for unlabelled files: yields the scaled array and its path."""

    def __init__(self, paths):
        self.paths = paths

        # Built only for the 'Q_transform' preprocessing option; the methods of
        # this class do not read it.
        if CFG.PREPROC == 'Q_transform':
            self.q_transform = CQT1992v2(
                sr=2048, fmin=20, fmax=1024, hop_length=32
            )
        else:
            self.q_transform = None

    def __len__(self):
        return len(self.paths)

    def __get_qtransform(self, x):
        """Rescale the loaded array and return it as a float tensor.

        Despite the name, no Q-transform is computed here. The band-pass
        filter is applied at this stage only when CFG.DEBUG is set.
        """
        scaled = x/(2.8679e-20)
        if CFG.DEBUG:
            scaled = filterSig(scaled).copy()
        # scaled is [chan x time]
        return torch.tensor(scaled).float()

    def __getitem__(self, index):
        """Return ``{"X": float tensor, "id": path}`` for ``index``."""
        path = self.paths[index]
        waves = np.load(path)
        return {"X": self.__get_qtransform(waves), "id": self.paths[index]}

In [None]:
if True:
    # Stratified K-fold on the hard target; each row is tagged with the fold in
    # which it serves as validation data.
    Fold = StratifiedKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
    for n, (train_index, val_index) in enumerate(Fold.split(train_df, train_df['target'])):
        train_df.loc[val_index, 'fold'] = int(n)
    train_df['fold'] = train_df['fold'].astype(int)
    display(train_df.groupby(['fold', 'target']).size())

    # Take explicit copies: columns are added to both frames later on, and
    # assigning into a slice of train_df triggers SettingWithCopyWarning.
    df_train = train_df.loc[train_df['fold'] != CFG.fold,:].copy()
    df_valid= train_df.loc[train_df['fold'] == CFG.fold,:].copy()
else:
    # Alternative: a single stratified 80/20 hold-out split.
    df_train, df_valid = sk_model_selection.train_test_split(
    train_df, 
    test_size=0.2, 
    random_state=42, 
    stratify=train_df["target"],
    )




In [None]:
# Per-fold column means, separately for positive and negative samples.
train_df_true = train_df.loc[train_df.target == 1]
train_df_false = train_df.loc[train_df.target == 0]
# numeric_only=True: the frame also carries the string `id` column, which recent
# pandas versions refuse to average instead of silently dropping.
display(train_df_true.groupby('fold').mean(numeric_only=True))
display(train_df_false.groupby('fold').mean(numeric_only=True))

In [None]:
def get_train_file_path(image_id):
    """Return the path of a training sample's .npy file.

    Files are sharded into nested directories named after the first three
    characters of the id.
    """
    template = "../input/g2net-gravitational-wave-detection/train/{}/{}/{}/{}.npy"
    first, second, third = image_id[0], image_id[1], image_id[2]
    return template.format(first, second, third, image_id)

def get_test_file_path(image_id):
    """Return the path of a test sample's .npy file (same sharding as the training set)."""
    template = "../input/g2net-gravitational-wave-detection/test/{}/{}/{}/{}.npy"
    first, second, third = image_id[0], image_id[1], image_id[2]
    return template.format(first, second, third, image_id)

# Resolve every sample id to the .npy file holding its data.
for labelled_frame in (df_train, df_valid):
    labelled_frame['file_path'] = labelled_frame['id'].apply(get_train_file_path)

# A previous submission; its predictions can be used as test pseudo-labels below.
submission_8748 = pd.read_csv('../input/submission-08748/submission.csv')

for test_frame in (submission, submission_8748):
    test_frame['file_path'] = test_frame['id'].apply(get_test_file_path)

if CFG.useTestLabels:
    # Pseudo-labelling: append confidently predicted test samples to the training
    # frame, using the earlier submission's prediction as their (soft) target.
    submission_8748['soft_target'] = submission_8748['target']#0.5*(submission_8748['target']+np.round(submission_8748['target']))
    tmp_df_0 = submission_8748.loc[submission_8748['target']<0.25,:]
    tmp_df_1 = submission_8748.loc[submission_8748['target']>0.9,:]
    print('No. test negatives selected: '+str(len(tmp_df_0)))
    print('No. test positives selected: '+str(len(tmp_df_1)))
    
    print('Train size: '+str(len(df_train)))
    # Cap the number of pseudo-labelled rows taken from each class.
    if CFG.SMALL_TRAIN_SET:
        tmp_df_0 = tmp_df_0.head(CFG.BATCH*50)
        tmp_df_1 = tmp_df_1.head(CFG.BATCH*50)
    else:
        tmp_df_0 = tmp_df_0.head(44800) # 128*350
        tmp_df_1 = tmp_df_1.head(44800)
    print(tmp_df_0.head())
    # pd.concat replaces the chained DataFrame.append calls (deprecated in pandas
    # 1.4, removed in 2.0); the row order before shuffling is unchanged.
    df_train = pd.concat([df_train, tmp_df_0, tmp_df_1]).sample(frac=1)
    df_train['target'] = df_train['target'].clip(lower = 0., upper = 1.)
    print('Train size with test: '+str(len(df_train)))


# Two views of the training split that differ only in the label column
# (hard `target` vs blended `soft_target`).
train_paths = df_train['file_path'].values
train_data_retriever = DataRetriever(train_paths, df_train["target"].values)
train_data_retriever_soft = DataRetriever(train_paths, df_train["soft_target"].values)

# Validation always uses the hard labels.
valid_data_retriever = DataRetriever(df_valid['file_path'].values, df_valid["target"].values)

test_data_retriever = TestDataRetriever(submission["file_path"].values)

In [None]:
# Training loaders: shuffled, fixed-size batches (the incomplete last batch is dropped).
train_loader_options = dict(
    batch_size=CFG.BATCH,
    shuffle=True,
    num_workers=12,
    drop_last = True,
)
train_loader = torch_data.DataLoader(train_data_retriever, **train_loader_options)
train_loader_soft = torch_data.DataLoader(train_data_retriever_soft, **train_loader_options)

# Evaluation loaders: deterministic order, larger batches, every sample kept.
eval_loader_options = dict(
    batch_size=512,
    shuffle=False,
    num_workers=8,
)
valid_loader = torch_data.DataLoader(valid_data_retriever, **eval_loader_options)
test_loader = torch_data.DataLoader(test_data_retriever, **eval_loader_options)



# Model

In [None]:
# Different heads

class BasicHead(nn.Module):
    """Classification head: Dropout(0.5) -> Linear(n_features, 256) -> ReLU -> Linear(256, 1)."""

    def __init__(self,n_features):
        super().__init__()
        layers = [
            nn.Dropout(0.5),  # p is the probability of zeroing
            nn.Linear(in_features=n_features, out_features=256, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1, bias=True),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self,x):
        """Map features of shape (..., n_features) to one logit each."""
        logits = self.classifier(x)
        return logits
    
class MultiDropoutHead(nn.Module):
    """Multi-sample dropout head.

    The same linear classifier is applied to five differently dropped-out
    copies of the features (p from 0.3 to 0.7) and the logits are averaged.

    The dropout layers are registered sub-modules, so ``model.eval()`` switches
    them off and inference is deterministic. Building a new ``nn.Dropout``
    inside ``forward`` would leave every layer in training mode regardless of
    the parent's mode. Dropout has no parameters, so the state-dict keys are
    just those of ``classifier``.
    """

    def __init__(self,n_features):
        super().__init__()
        self.classifier = nn.Linear(in_features=n_features, out_features=1, bias=True)
        self.dropouts = nn.ModuleList(
            [nn.Dropout(float(p)) for p in np.linspace(0.3, 0.7, 5)]
        )

    def forward(self,x):
        """Return the mean logit over all dropout samples; shape (..., 1)."""
        logits = [self.classifier(dropout(x)) for dropout in self.dropouts]
        return torch.mean(torch.stack(logits, dim=0), dim=0)

In [None]:

def whiten(signal):
    """Spectrally whiten a batch of signals.

    Each signal is Hann-windowed, its spectrum is divided by its own magnitude
    (leaving phase only), and the result is transformed back and rescaled by
    sqrt(length / 2).

    Args:
        signal: tensor indexed as (batch, channel, time); the transform runs
            along dim 2.

    Returns:
        Real float64 tensor of the same shape, on the same device as ``signal``.
    """
    # From here: https://www.kaggle.com/kevinmcisaac/g2net-spectral-whitening
    length = signal.size(2)
    # Create the window on the input's device so CUDA tensors work too.
    hann = torch.hann_window(
        length, periodic=True, dtype=torch.float64, device=signal.device
    ).view(1,1,-1)
    spec = fft(signal* hann, dim = 2)
    mag = torch.sqrt(torch.real(spec*torch.conj(spec))) 

    return torch.real(ifft(spec/mag)) * np.sqrt(length/2)

def batch_preprocessing(X):
    """Optionally band-pass a batch tensor along dim 2 and return a new float tensor.

    The filtering runs in NumPy, so ``X`` must be a CPU tensor.
    """
    # X = whiten(X)
    waves = X.numpy()
    if CFG.BPfilter:
        waves = filterSig(waves,axis=2).copy()
    return torch.tensor(waves).float()
        

# Head class instantiated by the Model definitions below.
Head = MultiDropoutHead#BasicHead

# Selects which of the Model definitions below is created (1, 2 or 3).
model_no = 1

def Backbone():
    """Create the feature extractor named by CFG.model_name with its classifier removed.

    Returns:
        (model, n_features): the network, whose final classification layer is
        replaced by ``nn.Identity``, and the width of the features it emits.

    Raises:
        ValueError: if CFG.model_name matches none of the supported families.
    """
    if 'tf_efficientnet' in CFG.model_name:
        model = timm.create_model(CFG.model_name, pretrained=CFG.pretrained, 
                                  drop_path_rate = 0.4 if CFG.useTestLabels else 0.,
                                  in_chans = 6 if CFG.useCWT else 3)
        n_features = model.classifier.in_features
        model.classifier = nn.Identity()
    elif 'efficientnet-' in CFG.model_name:
        # Honour CFG.pretrained here as the other branches do.
        if CFG.pretrained:
            model = efficientnet_pytorch.EfficientNet.from_pretrained(CFG.model_name)
        else:
            model = efficientnet_pytorch.EfficientNet.from_name(CFG.model_name)
        n_features = model._fc.in_features
        model._fc = nn.Identity()
    elif 'rexnet_' in CFG.model_name:
        model = timm.create_model(CFG.model_name, pretrained=CFG.pretrained)
        n_features = model.head.fc.in_features
        model.head.fc = nn.Identity()
    else:
        # Fail with a clear message instead of an UnboundLocalError on return.
        raise ValueError('Unsupported CFG.model_name: ' + str(CFG.model_name))
        
    return model, n_features

class Transform:
    """Turns batches of raw multi-channel signals into normalised spectrogram images.

    This is a plain class, not an ``nn.Module``: the CQT (and optional CWT) layers
    are placed on the global ``device`` at construction time and do not follow a
    later ``model.to(...)`` call.
    """
    
    def __init__(self,useCWT = CFG.useCWT):
        """Build the CQT layer, the optional CWT layer, the resize and the normalisers.

        Args:
            useCWT: when True, a CWT image is computed as well and concatenated to
                the CQT image along the channel axis.
        """
        self.q_transform = CQT1992v2(
                **CFG.cqt_params
            ).to(device)
        self.useCWT =  useCWT
        if self.useCWT:
            self.pycwt = CWT(fmin=20, fmax=512, stride = (1,26)).to(device)
        self.nH = 75  # not read anywhere in this class
        self.resize = torchvision.transforms.Resize(CFG.image_size)
        
        if CFG.DENOISE:
            # Masks loaded from disk; transform() divides the spectrogram by masks*20 + 1.
            masks = np.load('../input/g2netmasks/masks.npy')
            self.masks = torch.tensor(masks,dtype = torch.float32).to(device)
            self.masks = (self.masks*20)+1
            self.normalize = torchvision.transforms.Normalize((0.0782925,0.07827569,0.07719089),
                                                              (0.03546559,0.03547181,0.0346077))
            # NOTE: normalize_cwt is not defined on this path, so DENOISE together
            # with useCWT and NORM_MODE 'global_channel' would raise AttributeError.
        else:
            #mean = (-0.49382125+0.5,-0.49382148+0.5,-0.49036058+0.5)
            #std = (0.00531079,0.00531216,0.00816139)
            # Per-channel mean/std constants (presumably measured with the statistics
            # cell further down — TODO confirm).
            mean = (0.00776412,0.00775545,0.01362684)
            std = (0.00917761,0.00914989,0.01805867)
            self.normalize = torchvision.transforms.Normalize(mean,
                                                          std)
            self.normalize_cwt = torchvision.transforms.Normalize((0.00469631,0.00469434,0.00796551),
                                                                  (0.00338371,0.00338833,0.00668454))
            
    def transform(self,x):
        """Convert a batch of 3-channel signals into normalised spectrogram images.

        Args:
            x: tensor indexed as (batch, 3, time).

        Returns:
            Tensor of shape (batch, 3, *CFG.image_size); when ``useCWT`` is set the
            CWT image is appended along dim 1.
        """
        if CFG.NORM_MODE == 'local':
            # Per-sample scaling by the maximum over channels and time.
            x = x/torch.amax(x, dim = (1,2), keepdim = True)
        if self.useCWT:
            x_cwt = self.pycwt(x)
        batch_size = x.size(0)
        # Fold the channels into the batch axis: the CQT layer receives one signal per row.
        x = torch.reshape(x,(batch_size*3,-1))
        x = self.q_transform(x)
        size = list(x.size())
        # Unfold back to (batch, 3, size[1], size[2]).
        x = torch.reshape(x,(batch_size,3,size[1],size[2]))
        x = self.resize(x)
        # Reverse the order along dim 2 (presumably the frequency axis — TODO confirm).
        x = x[:,:,np.arange(CFG.image_size[0]-1,-1,-1),:]
        if CFG.DENOISE:
            x =  x/(self.masks)
        if CFG.SCALE == 'log':
            x = (torch.log10(x) + 1.)/1.5
        else:
            # Linear scale: apply the normalisation selected by CFG.NORM_MODE
            # (an unrecognised mode leaves x unnormalised).
            if CFG.NORM_MODE == 'global':
                x = (x + 0.4926)/np.sqrt(5.33e-5)
            elif CFG.NORM_MODE == 'local':
                # Per-sample min-max scaling.
                xmin = torch.amin(x, dim = (1,2,3), keepdim = True)
                xmax = torch.amax(x, dim = (1,2,3), keepdim = True)
                x = (x-xmin)/(xmax-xmin)
            elif CFG.NORM_MODE == 'local_std':
                # Per-sample standardisation.
                xmean = torch.mean(x, dim = (1,2,3), keepdim = True)
                xstd = torch.std(x, dim = (1,2,3), keepdim = True)
                x = (x-xmean)/(xstd)
            elif CFG.NORM_MODE == 'global_channel':
                x = self.normalize(x)
            elif CFG.NORM_MODE == 'global_channel_local':
                # Fixed per-channel normalisation followed by per-sample standardisation.
                x = self.normalize(x)
                xmean = torch.mean(x, dim = (1,2,3), keepdim = True)
                xstd = torch.std(x, dim = (1,2,3), keepdim = True)
                x = (x-xmean)/(xstd)
                
        if self.useCWT:
            # Same normalisation choices for the CWT image; this happens regardless of
            # CFG.SCALE, and 'global_channel_local' has no CWT branch.
            if CFG.NORM_MODE == 'global':
                x_cwt = (x_cwt - 0.00579486)/np.sqrt(2.57e-5)
            elif CFG.NORM_MODE == 'local':
                xmin = torch.amin(x_cwt, dim = (1,2,3), keepdim = True)
                xmax = torch.amax(x_cwt, dim = (1,2,3), keepdim = True)
                x_cwt = (x_cwt-xmin)/(xmax-xmin)
            elif CFG.NORM_MODE == 'local_std':
                xmean = torch.mean(x_cwt, dim = (1,2,3), keepdim = True)
                xstd = torch.std(x_cwt, dim = (1,2,3), keepdim = True)
                x_cwt = (x_cwt-xmean)/(xstd)
            elif CFG.NORM_MODE == 'global_channel':
                x_cwt = self.normalize_cwt(x_cwt)
            # Stack the CQT and CWT images along the channel axis.
            x = torch.cat((x,x_cwt),dim=1)
        return x
        

if model_no == 1:
    print('Selecting multi channel model ... ')
    class Model(nn.Module):
        """Multi-channel model: spectrogram transform -> CNN backbone -> head.

        With ``get_spectrogram=True`` no backbone or head is created and
        ``forward`` returns the transformed spectrogram images.
        """

        def __init__(self, get_spectrogram = False):
            super().__init__()
            self.get_spectrogram = get_spectrogram
            # Registered CQT layer; forward() uses the one owned by self.transform.
            self.q_transform = CQT1992v2(**CFG.cqt_params)
            self.transform = Transform()
            if self.get_spectrogram:
                return
            self.model, n_features = Backbone()
            self.head = Head(n_features)

        def _set_backbone_trainable(self, trainable):
            """Switch gradient tracking for every backbone parameter."""
            for weight in self.model.parameters():
                weight.requires_grad = trainable

        def freezeModel(self):
            """Stop gradient updates for the backbone (the head stays trainable)."""
            self._set_backbone_trainable(False)

        def unfreezeModel(self):
            """Re-enable gradient updates for the backbone."""
            self._set_backbone_trainable(True)

        def forward(self, x):
            """Return one logit per sample, or the spectrogram images in spectrogram mode."""
            images = self.transform.transform(x)
            if self.get_spectrogram:
                return images
            return self.head(self.model(images))
elif model_no == 2:        
    print('Selecting single channel model ... ')
    
    class Model(nn.Module):
        """Single-channel variant: each of the 3 channels is passed through the
        backbone separately and the features are max-pooled across channels.

        With ``get_spectrogram=True`` no backbone or head is created and
        ``forward`` returns the spectrograms.
        """

        def __init__(self, get_spectrogram = False):
            self.get_spectrogram = get_spectrogram
            super().__init__()
            self.q_transform = CQT1992v2(
                **CFG.cqt_params
            )
            if not self.get_spectrogram:
                self.model, n_features = Backbone()
                self.head = Head(n_features)

        def freezeModel(self):
            """Disable gradients for every backbone parameter."""
            for param in self.model.parameters():
                param.requires_grad = False

        def unfreezeModel(self):
            """Re-enable gradients for every backbone parameter."""
            for param in self.model.parameters():
                param.requires_grad = True

        def forward(self, x):
            """Return one logit per sample, or (batch, 3, H, W) spectrograms in spectrogram mode."""
            # reshape from [batch by chan by time] [(batch x chan) by time]
            batch_size = x.size(0)

            # Scale every channel by its own peak absolute value.
            x = torch.divide(x,torch.max(torch.abs(x),dim=2,keepdims = True)[0])
            x = torch.reshape(x,(batch_size*3,-1))
            x = self.q_transform(x)
            # Drop the last row and the last column of each spectrogram.
            x = x[:,0:-1,0:-1]
            if CFG.SCALE == 'log':
                x = (torch.log10(x) + 1.)/1.5
            else:
                x = torch.clamp(x,max=2.5)-1
            if self.get_spectrogram:
                size = list(x.size())
                x = torch.reshape(x,(batch_size,3,size[1],size[2]))
                return x
            
            # One single-channel image per (sample, channel) pair.
            # NOTE(review): Backbone() builds the timm model with in_chans of 3 or 6 —
            # confirm the backbone accepts 1-channel input here.
            x = torch.unsqueeze(x,1)
            
            # x_mean = torch.mean(x,dim=1,keepdims = True)
            # x = torch.stack([x,x_mean],dim=1)
            
            x = self.model(x)
            size = list(x.size())
            # Regroup the features per sample and max-pool over the channel axis.
            x = torch.reshape(x,(batch_size,-1,size[1]))
            x = torch.max(x,dim=1,keepdims = False)[0]
            out = self.head(x)
            out = out
            return out
        
elif model_no == 3:
    print('Selecting experimental model ... ')
    class Model(nn.Module):
        """Experimental variant: the 3 channels plus their mean image are passed
        through the backbone separately; the head receives the max-pooled
        features concatenated with the features of the mean image.

        With ``get_spectrogram=True`` no backbone or head is created and
        ``forward`` returns the spectrograms.
        """

        def __init__(self, get_spectrogram = False):
            self.get_spectrogram = get_spectrogram
            super().__init__()
            self.q_transform = CQT1992v2(
                **CFG.cqt_params
            )
            if not self.get_spectrogram:
                self.model, n_features = Backbone()
                # The head input is twice as wide: pooled features + mean-image features.
                self.head = Head(2*n_features)

        def freezeModel(self):
            """Disable gradients for every backbone parameter."""
            for param in self.model.parameters():
                param.requires_grad = False

        def unfreezeModel(self):
            """Re-enable gradients for every backbone parameter."""
            for param in self.model.parameters():
                param.requires_grad = True

        def forward(self, x):
            """Return one logit per sample, or (batch, 3, H, W) spectrograms in spectrogram mode."""
            # reshape from [batch by chan by time] [(batch x chan) by time]
            batch_size = x.size(0)

            # Scale every channel by its own peak absolute value.
            x = torch.divide(x,torch.max(torch.abs(x),dim=2,keepdims = True)[0])
            x = torch.reshape(x,(batch_size*3,-1))
            x = self.q_transform(x)
            # Drop the last row and the last column of each spectrogram.
            x = x[:,0:-1,0:-1]
            if CFG.SCALE == 'log':
                x = (torch.log10(x) + 1.)/1.5
            else:
                x = torch.clamp(x,max=2.5)-1
            if self.get_spectrogram:
                size = list(x.size())
                x = torch.reshape(x,(batch_size,3,size[1],size[2]))
                return x
            
            x = torch.unsqueeze(x,1)
            
            # Regroup per sample, append the mean over the 3 channels as a 4th image,
            # then flatten again to one single-channel image per row.
            # NOTE(review): Backbone() builds the timm model with in_chans of 3 or 6 —
            # confirm the backbone accepts 1-channel input here.
            size = list(x.size())
            x = torch.reshape(x,(batch_size,3,size[2],size[3]))
            x_mean = torch.mean(x,dim=1,keepdims = True)
            x = torch.cat([x,x_mean],dim=1)
            x = torch.reshape(x,(batch_size*4,1,size[2],size[3]))
            
            x = self.model(x)
            size = list(x.size())
            x = torch.reshape(x,(batch_size,-1,size[1]))
            # Index 3 along dim 1 holds the features of the mean image.
            x_mean = x[:,3,:].squeeze(1)
            # Max-pool over all 4 images (the mean image included).
            x = torch.max(x,dim=1,keepdims = False)[0]
            # this will be [batch by (2 x n_features)]
            x = torch.cat([x,x_mean],dim=1)
            
            out = self.head(x)
            out = out
            return out
    




In [None]:
# Disabled cell: plots the batch-mean of the saved masks per channel.
# NOTE(review): `targets` and `n` are only defined by later cells, so enabling this
# as-is on a fresh kernel would raise NameError.
if False:
    masks = np.load('../input/g2netmasks/masks.npy')
    masks = masks*0.9+0.1
    for i in range(3):
        plt.figure()
        plt.title('Channel ' + str(i) + '; Target = '+ str(targets[n]))
        plt.imshow(np.mean(masks[:,i,:,:],axis=0).squeeze())
        plt.colorbar()
        #plt.clim([0.1,1])

In [None]:
if True:
    # Estimate spectrogram statistics on noise-only samples (target == 0) from the
    # first 20 training batches; `masks` ends up holding their mean image.
    modelTmp = Model(get_spectrogram = True)
    modelTmp.to(device)  # once, not on every iteration
    n_channels = 3 if not CFG.useCWT else 6
    # The empty float64 array keeps the concatenated result in double precision.
    noise_batches = [np.zeros([0,n_channels,CFG.image_size[0],CFG.image_size[1]])]
    with torch.no_grad():
        for step, batch in enumerate(train_loader,1):
            X = batch["X"]
            y = batch["y"]
            print(torch.max(X))
            X = batch_preprocessing(X)
            X = X.to(device)
            outputs = modelTmp(X)
            print(outputs.size())
            # Collect per batch and concatenate once after the loop; concatenating
            # inside the loop copies everything gathered so far each time.
            noise_batches.append(outputs.cpu().numpy()[y.numpy()==0,:,:,:])
            if step == 20:
                break
    tmp = np.concatenate(noise_batches,axis = 0)
    count = tmp.size
    print('Mean = ' + str(np.sum(tmp)/count))
    print('std = ' + str(np.sqrt(np.sum(tmp**2)/count - (np.sum(tmp)/count)**2)))
    print('Max = ' + str(np.max(tmp)))
    
    # Per-channel statistics.
    print('Mean = ' + str(np.mean(tmp,axis = (0,2,3))))
    print('std = ' + str(np.std(tmp,axis = (0,2,3))))
    print('Max = ' + str(np.max(tmp,axis = (0,2,3))))
    
    for i in range(3):
        plt.figure()
        # The image is an average over noise-only samples, so the title says so
        # rather than quoting the target of one unrelated sample.
        plt.title('Channel ' + str(i) + '; mean over target = 0 samples')
        plt.imshow(np.mean(tmp[:,i,:,:],axis=0).squeeze())
        plt.colorbar()
    masks = np.mean(tmp[:,:,:,:],axis=0,keepdims = True)
    

In [None]:

if True:
    # Visual sanity check: show the spectrogram channels of one randomly chosen
    # sample from the first soft-label training batch.
    modelTmp = Model(get_spectrogram = True)
    modelTmp.to(device)
    for step, batch in enumerate(train_loader_soft,1):
        X = batch["X"]
        print(torch.max(X))
        X = batch_preprocessing(X)
        X = X.to(device)
        targets = batch["y"].to(device)
        with torch.no_grad():
            outputs = modelTmp(X)
        print(outputs.size())
        # Never index past the end of a batch smaller than 32.
        n = np.random.randint(min(32, outputs.size(0)))
        tmp = outputs[n].cpu().numpy()
        count = tmp.size
        tmp_mean = np.sum(tmp)/count
        print('Mean = ' + str(tmp_mean))
        # Square root of the variance, so the printed value really is a standard deviation.
        print('std = ' + str(np.sqrt(np.sum(tmp**2)/count - tmp_mean**2)))
        for i in range(3):
            plt.figure()
            plt.title('Target = '+ str(targets[n]))
            plt.imshow((tmp[i,:,:]).squeeze())
            plt.colorbar()
            if CFG.useCWT:
                # The CWT images sit after the three CQT channels.
                plt.figure()
                plt.title('CWT: Target = '+ str(targets[n]))
                plt.imshow((tmp[i+3,:,:]).squeeze())
                plt.colorbar()
        # Mean over all channels of the chosen sample.
        plt.figure()
        plt.title('Target = '+ str(targets[n]))
        plt.imshow(np.mean(tmp[:,:,:],axis=0).squeeze())
        plt.colorbar()
        print(tmp.shape)
        break

# Loss related functions

In [None]:
class LossMeter:
    """Running arithmetic mean of the scalar values handed to ``update``.

    Attributes:
        avg: mean of all values seen so far (0 before the first update).
        n: number of values seen so far.
    """

    def __init__(self):
        self.n = 0
        self.avg = 0

    def update(self, val):
        """Fold one more value into the running mean."""
        seen = self.n
        self.n = seen + 1
        # new_mean = val / n + (n - 1) / n * old_mean
        self.avg = val / self.n + seen / self.n * self.avg

        
class AccMeter:
    """Running accuracy of logit predictions against (possibly soft) binary targets.

    Targets are rounded to 0/1 and a logit >= 0 counts as a positive
    prediction (equivalent to sigmoid(logit) >= 0.5).

    Attributes:
        avg: fraction of correct predictions over all samples seen so far.
        n: number of samples seen so far.
    """

    def __init__(self):
        self.avg = 0
        self.n = 0
        
    def update(self, y_true, y_pred):
        """Add one batch to the running accuracy.

        Args:
            y_true: tensor of targets in [0, 1]; rounded to the nearest class.
            y_pred: tensor of raw logits with the same number of elements.

        Raises:
            ValueError: if the two tensors hold a different number of elements.
        """
        # Flatten both sides so a (B,) target and a (B, 1) prediction are
        # compared element-wise instead of being broadcast to a B x B grid.
        y_true = y_true.detach().cpu().round().numpy().astype(int).reshape(-1)
        y_pred = (y_pred.detach().cpu().numpy() >= 0).reshape(-1)
        if y_true.shape != y_pred.shape:
            raise ValueError(
                "y_true and y_pred must hold the same number of elements, got {} and {}".format(
                    y_true.size, y_pred.size
                )
            )
        if y_true.size == 0:
            # Nothing to add; also avoids a 0/0 when this is the first batch
            return
        last_n = self.n
        self.n += y_true.size
        true_count = np.sum(y_true == y_pred)
        # incremental update
        self.avg = true_count / self.n + last_n / self.n * self.avg
        


# Trainer-related functions

In [None]:
def get_scheduler(optimizer):
    """Build the learning-rate scheduler selected by ``CFG.scheduler``.

    Args:
        optimizer: the optimizer whose learning rates the scheduler drives.

    Returns:
        A ReduceLROnPlateau, CosineAnnealingLR, CosineAnnealingWarmRestarts or
        ExponentialLR instance configured from the matching ``CFG`` fields.

    Raises:
        ValueError: if ``CFG.scheduler`` names none of the supported schedulers.
    """
    if CFG.scheduler=='ReduceLROnPlateau':
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
    elif CFG.scheduler=='CosineAnnealingLR':
        scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
    elif CFG.scheduler=='CosineAnnealingWarmRestarts':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)
    elif CFG.scheduler == 'ExponentialLR':
        scheduler = ExponentialLR(optimizer, gamma = CFG.gamma)
    else:
        # Fail with a clear message instead of an UnboundLocalError on `scheduler`
        raise ValueError(
            "Unknown CFG.scheduler {!r}; expected one of 'ReduceLROnPlateau', "
            "'CosineAnnealingLR', 'CosineAnnealingWarmRestarts', 'ExponentialLR'".format(CFG.scheduler)
        )
    return scheduler

class Trainer:
    """Training / validation / inference driver for one model.

    The trainer freezes the model on construction (``model.freezeModel()``),
    unfreezes it once ``CFG.unfreezeStep`` optimisation steps have run, steps
    the learning-rate scheduler every ``CFG.schedulerStepFreq`` steps and keeps
    two checkpoints: one for the best validation score and one for the best
    validation ROC AUC.

    Args:
        model: network exposing ``freezeModel()`` / ``unfreezeModel()``.
        device: device the batches are moved to.
        optimizer: optimizer driving ``model``.
        criterion: callable ``criterion(outputs, targets) -> loss``.
        loss_meter: class (not instance) of the running-loss meter.
        score_meter: class (not instance) of the running-score meter.
        use_swa: wrap the optimizer in ``swa.SWA`` when true.
    """

    def __init__(
        self, 
        model, 
        device, 
        optimizer, 
        criterion, 
        loss_meter, 
        score_meter,
        use_swa = False
    ):
        self.model = model
        # freeze model by default
        self.model.freezeModel()
        
        self.device = device
        self.use_swa = use_swa
        # NOTE(review): `import swa` is commented out near the top of the
        # notebook, so use_swa=True needs that import restored first.
        self.optimizer = swa.SWA(optimizer) if self.use_swa else optimizer
        self.criterion = criterion
        self.loss_meter = loss_meter
        self.score_meter = score_meter
        self.scheduler = get_scheduler(optimizer)
        self.learning_rate = self._current_lrs()
        
        self.best_valid_score = -np.inf
        self.best_valid_rocauc = -np.inf
        self.n_patience = 0
        
        self.messages = {
            "epoch": "[Epoch {}: {}] loss: {:.5f}, score: {:.5f}, auc_score: {:.5f}, time: {} s",
            "checkpoint": "The score improved from {:.5f} to {:.5f}. Save model to '{}'",
            "checkpoint_auc": "The rocauc improved from {:.5f} to {:.5f}. Save model to '{}'",
            "patience": "\nValid score didn't improve last {} epochs."
        }
        self.training_step = 0  # optimisation steps taken across all epochs
        self.prevbatch = []     # previous training batch, blended into the next one
        self.epoch = -1 

    def _current_lrs(self):
        """Return the learning rate currently set on each optimizer param group.

        Reading the param groups works for every scheduler type (including
        ReduceLROnPlateau) and always reflects the rate the next step will use,
        which ``scheduler.get_lr()`` does not guarantee outside ``step()``.
        """
        return [group["lr"] for group in self.scheduler.optimizer.param_groups]

    @staticmethod
    def _safe_rocauc(targets, preds):
        """ROC AUC of ``preds`` against ``targets``; NaN when it is undefined.

        ``roc_auc_score`` raises ValueError when, for example, only one class
        is present, which easily happens on a single small batch.
        """
        try:
            return roc_auc_score(targets, preds)
        except ValueError:
            return float("nan")
    
    def fit(self, epochs, train_loader, valid_loader, save_path, patience,train_loader_soft = False):        
        """Train for up to ``epochs`` epochs with early stopping.

        Args:
            epochs: maximum number of epochs.
            train_loader: loader used for epoch 1, and for every epoch when
                ``train_loader_soft`` is falsy.
            valid_loader: loader evaluated after every epoch.
            save_path: checkpoint path for the best validation score; the best
                ROC AUC checkpoint is written next to it with a ``rocauc_``
                file-name prefix.
            patience: stop after this many epochs without improvement of
                either tracked metric.
            train_loader_soft: optional loader used from epoch 2 onwards.

        Returns:
            None.
        """
        # Prefix the file name (not the whole path) so save_path may contain a directory
        rocauc_save_path = os.path.join(
            os.path.dirname(save_path), 'rocauc_' + os.path.basename(save_path)
        )

        for n_epoch in range(1, epochs + 1):
            gc.collect()
            self.epoch = n_epoch
            
            self.info_message("EPOCH: {}", n_epoch)
            
            if self.epoch==1 or (not train_loader_soft):
                train_loss, train_score, train_time = self.train_epoch(train_loader)
            else:
                train_loss, train_score, train_time = self.train_epoch(train_loader_soft)
                
            valid_loss, valid_score, valid_time, valid_rocauc = self.valid_epoch(valid_loader)

            # ReduceLROnPlateau is driven by a metric, so it cannot be stepped
            # inside the training loop like the other schedulers; it is built
            # with mode='min', hence the validation loss.
            if isinstance(self.scheduler, ReduceLROnPlateau):
                self.scheduler.step(valid_loss)
                self.learning_rate = self._current_lrs()
            
            self.info_message(
                self.messages["epoch"], "Train", n_epoch, train_loss, train_score, 0, train_time
            )
            
            self.info_message(
                self.messages["epoch"], "Valid", n_epoch, valid_loss, valid_score, valid_rocauc, valid_time
            )
            
            self.n_patience += 1
            if self.best_valid_score < valid_score:
                self.info_message(
                    self.messages["checkpoint"], self.best_valid_score, valid_score, save_path
                )
                self.best_valid_score = valid_score
                self.save_model(n_epoch, save_path)
                self.n_patience = 0
                
            if self.best_valid_rocauc < valid_rocauc:
                # Report the path that is actually written
                self.info_message(
                    self.messages["checkpoint_auc"], self.best_valid_rocauc, valid_rocauc, rocauc_save_path
                )
                self.best_valid_rocauc = valid_rocauc
                self.save_model(n_epoch, rocauc_save_path)
                self.n_patience = 0
                
            
            if self.n_patience >= patience:
                self.info_message(self.messages["patience"], patience)
                break
        if self.use_swa:
            self.optimizer.bn_update(train_loader, self.model)
            self.optimizer.swap_swa_sgd()
        
    def train_epoch(self, train_loader):
        """Run one training epoch.

        Returns:
            Tuple ``(mean_loss, mean_score, elapsed_seconds)``.
        """
        self.model.train()
        t = time.time()
        train_loss = self.loss_meter()
        train_score = self.score_meter()
        
        for step, batch in enumerate(tqdm(train_loader),1):
            
            if self.training_step == CFG.unfreezeStep:
                self.model.unfreezeModel()
            
            X = batch["X"]
            if self.prevbatch:
                prevX = self.prevbatch["X"]
                prevY = self.prevbatch['y']
                rndNum = np.random.rand()
                # Blend only when both batches have the same shape (the last
                # batch of an epoch may be smaller than the one before it).
                if rndNum<0.5 and prevX.shape == X.shape:
                    # only keep prevX where there is no wave; samples whose
                    # previous target is > 0.5 fall back to the current X.
                    # assumes X is 3-D (batch, channel, time) -- TODO confirm
                    prevX = torch.where(prevY.view(-1,1,1)>0.5,X,prevX)
                    # weight for prevX is at most 0.5; targets stay those of
                    # the current batch
                    X = (1-rndNum)*X + rndNum*prevX 

            self.prevbatch = batch.copy()  
            X = batch_preprocessing(X)
            X = X.to(self.device)
            targets = batch["y"].to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(X).squeeze(1)
            
            loss = self.criterion(outputs, targets)
            loss.backward()

            train_loss.update(loss.detach().item())
            train_score.update(targets, outputs.detach())

            self.optimizer.step()
            
            _loss, _score = train_loss.avg, train_score.avg
            
            # One "{:.7e}" entry per param group, so any number of groups works
            lr_text = "/".join("{:.7e}".format(lr) for lr in self.learning_rate)
            message = 'Train Step {}/{}, train_loss: {:.5f}, train_score: {:.5f}, learning_rate: {}'
            self.info_message(message, step, len(train_loader), _loss, _score, lr_text, end="\r")
            self.training_step += 1
            
            if self.training_step%CFG.schedulerStepFreq==0:
                # Step-based schedulers only; ReduceLROnPlateau is stepped in fit()
                if isinstance(self.scheduler, (CosineAnnealingLR, CosineAnnealingWarmRestarts, ExponentialLR)):
                    self.scheduler.step()
                self.learning_rate = self._current_lrs()
        
        return train_loss.avg, train_score.avg, int(time.time() - t)
    
    def valid_epoch(self, valid_loader,returnPred = False):
        """Evaluate the model on ``valid_loader``.

        Args:
            valid_loader: loader yielding dicts with ``"X"`` and ``"y"``.
            returnPred: when true, return the raw predictions instead of metrics.

        Returns:
            ``(mean_loss, mean_score, elapsed_seconds, rocauc)`` by default, or
            ``(y_pred, tgts)`` (sigmoid probabilities and targets for every
            validation sample) when ``returnPred`` is true. ``rocauc`` is NaN
            if it cannot be computed.
        """
        self.model.eval()
        t = time.time()
        valid_loss = self.loss_meter()
        valid_score = self.score_meter()
        # Accumulated over the WHOLE loader so the final AUC and the returned
        # predictions cover every validation sample, not just the last batch.
        y_pred = []
        tgts = []
        
        for step, batch in enumerate(valid_loader, 1):
            with torch.no_grad():
                X = batch["X"]  
                X = batch_preprocessing(X)
                X = X.to(self.device)
                targets = batch["y"].to(self.device)

                outputs = self.model(X).squeeze(1)
                loss = self.criterion(outputs, targets)

                valid_loss.update(loss.detach().item())
                valid_score.update(targets, outputs)
                # reshape(-1) keeps a 1-D array even for a batch of one
                batch_pred = torch.sigmoid(outputs).cpu().numpy().reshape(-1)
                batch_tgts = batch["y"].numpy().reshape(-1)

            y_pred.extend(batch_pred)
            tgts.extend(batch_tgts)
                    
            # Progress line shows the AUC of the current batch only
            batch_rocauc = self._safe_rocauc(batch_tgts, batch_pred)
            _loss, _score = valid_loss.avg, valid_score.avg
            message = 'Valid Step {}/{}, valid_loss: {:.5f}, valid_score: {:.5f},valid_roc_auc: {:.5f}'
            self.info_message(message, step, len(valid_loader), _loss, _score, batch_rocauc, end="\r")

        if not returnPred:
            rocauc = self._safe_rocauc(tgts, y_pred)
            return valid_loss.avg, valid_score.avg, int(time.time() - t), rocauc
        else:
            return y_pred, tgts
    
    def test_eval(self,test_loader):
        """Predict probabilities for every sample of ``test_loader``.

        Returns:
            ``(y_pred, ids)``: sigmoid probabilities and the matching
            ``batch["id"]`` entries, in loader order.
        """
        # Make sure dropout / batch-norm run in inference mode
        self.model.eval()
        y_pred = []
        ids = []
        for e, batch in enumerate(test_loader):
            print(f"{e}/{len(test_loader)}", end="\r")
            with torch.no_grad():
                X = batch["X"]
                X = batch_preprocessing(X)
                X = X.to(self.device)
                outputs = self.model(X)
                # reshape(-1) keeps a 1-D array even for a batch of one
                y_pred.extend(torch.sigmoid(outputs).cpu().numpy().reshape(-1))
                ids.extend(batch["id"])
        return y_pred, ids
    
    def save_model(self, n_epoch, save_path):
        """Write model/optimizer state, best score and epoch to ``save_path``."""
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "best_valid_score": self.best_valid_score,
                "n_epoch": n_epoch,
            },
            save_path,
        )
    
    @staticmethod
    def info_message(message, *args, end="\n"):
        """Print ``message.format(*args)`` terminated by ``end``."""
        print(message.format(*args), end=end)

In [None]:

# Build the model and, when not training from scratch, restore saved weights.
model = Model()
model.to(device)
if (not CFG.TRAIN) or CFG.WARM_START:
    # Read the file once and share the result under both names the later
    # cells use; map_location lets a checkpoint saved on another device load
    # onto whatever `device` is here.
    startCheckpoint = torch.load("../input/cqtpt2v10/best-model.pth", map_location=device)
    checkpoint = startCheckpoint
    model.load_state_dict(startCheckpoint["model_state_dict"])

# Two param groups: the head uses a learning rate 5x smaller than model.model
optimizer = torch.optim.Adam([{"params": model.model.parameters(), "lr": CFG.lr},
                              {"params": model.head.parameters(), "lr": CFG.lr/5}], 
                             lr=CFG.lr)
criterion = torch_functional.binary_cross_entropy_with_logits

# Trainer receives the meter classes and instantiates them once per epoch
trainer = Trainer(
    model, 
    device, 
    optimizer, 
    criterion, 
    LossMeter, 
    AccMeter
)

if CFG.TRAIN:
    # NOTE(review): Trainer.fit has no return statement, so `history` is None
    history = trainer.fit(
        CFG.EPOCHS, 
        train_loader, 
        valid_loader, 
        "best-model.pth", 
        400,  # early-stopping patience in epochs
        train_loader_soft = train_loader_soft if CFG.useSoftLabels else False
    )
    
    y_pred_val,tgts = trainer.valid_epoch(valid_loader,returnPred = True)

In [None]:
# Validation predictions against their targets. `tgts` and `y_pred_val` are
# only produced by the training cell when CFG.TRAIN is set, so guard the plot
# to keep Restart & Run All working in inference-only mode.
if CFG.TRAIN:
    plt.scatter(tgts, y_pred_val, s=1)
    plt.xlabel('targets')
    plt.ylabel('predictions')

In [None]:
# Predict the test set with the best-score checkpoint and write the submission.
if CFG.TRAIN:
    # map_location lets a checkpoint saved on another device load onto `device`
    checkpoint = torch.load("best-model.pth", map_location=device)

model.load_state_dict(checkpoint["model_state_dict"])
model.eval();
gc.collect()
y_pred, ids = trainer.test_eval(test_loader)
submission = pd.read_csv("../input/g2net-gravitational-wave-detection/sample_submission.csv")
# NOTE(review): the `ids` returned by test_eval are not used; this pairing
# assumes test_loader iterates in sample_submission order -- confirm.
submission = pd.DataFrame({"id": submission['id'].values, "target": y_pred})
submission.to_csv("model_submission.csv", index=False)
submission

In [None]:
# Predict the test set with the best-ROC-AUC checkpoint and write the submission.
# NOTE(review): when CFG.TRAIN is false, `checkpoint` is not reassigned here, so
# this cell reuses whatever checkpoint an earlier cell loaded.
if CFG.TRAIN:
    # map_location lets a checkpoint saved on another device load onto `device`
    checkpoint = torch.load("rocauc_best-model.pth", map_location=device)

model.load_state_dict(checkpoint["model_state_dict"])
model.eval();
gc.collect()
y_pred, ids = trainer.test_eval(test_loader)
submission = pd.read_csv("../input/g2net-gravitational-wave-detection/sample_submission.csv")
# NOTE(review): the `ids` returned by test_eval are not used; this pairing
# assumes test_loader iterates in sample_submission order -- confirm.
submission = pd.DataFrame({"id": submission['id'].values, "target": y_pred})
submission.to_csv("model_submission_rocauc.csv", index=False)
submission