In [None]:
import os
import json
import random
import collections

import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
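# Things to try:
# - GeM (generalized-mean) pooling instead of the default global pooling
# - soften the labels when samples are mixed (mixup)
# - change the stride of the first conv layer
# - increase the number of input channels with a convolution layer placed before the backbone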

# Import libraries

In [None]:
!pip install efficientnet_pytorch -qq

!pip install -q nnAudio -qq
import torch
from nnAudio.Spectrogram import CQT1992v2, CQT2010v2

import time

import torch
from torch import nn
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as torch_functional
from torch.autograd import Variable
import efficientnet_pytorch
from tqdm.auto import tqdm
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from torchaudio.functional import lfilter
from torch.fft import fft, rfft, ifft
import numpy as np



from sklearn.metrics import roc_auc_score

from sklearn.model_selection import StratifiedKFold

import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm

In [None]:
#sys.path.append('../input/pytorch-swa')
#import swa

# Load data

In [None]:
# Competition sample submission (one row per test id) and the training labels.
submission = pd.read_csv("../input/g2net-gravitational-wave-detection/sample_submission.csv")
train_df = pd.read_csv("../input/g2net-gravitational-wave-detection/training_labels.csv")
# Predictions for the training set read from a separate dataset; its 'target'
# column is blended into the soft labels in the next cell.
train_df_pred = pd.read_csv("../input/train-pred-cqt-v10/train_preds_CQT_V10.csv")

In [None]:
# NOTE(review): this assignment aligns the two frames on their index, not on
# 'id' -- it presumably relies on both CSVs listing samples in the same order;
# verify before trusting the soft labels.
train_df['preds'] = train_df_pred['target']
# Blend weight between the earlier model's prediction and the hard label.
weight = 0.5
train_df['soft_target'] = train_df['preds']*weight + train_df['target']*(1-weight)
train_df

In [None]:
train_df_pred

# Define config

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Notebook-wide configuration namespace; attributes are read directly as
# CFG.<name> and some are overridden right after the class definition.
class CFG:
    
    # When True the training cell calls trainer.fit(); when False the start
    # checkpoint is loaded and used for inference only.
    TRAIN = True
    
    # Load the weights of a previous run before training.
    WARM_START = True
    
    # Number of epochs passed to Trainer.fit (may be overridden below).
    EPOCHS = 7
    
    # Base learning rate: used for the backbone; the head uses lr/10.
    lr = 1e-2
    
    # Stratified K-fold setup: number of folds and the fold held out for validation.
    n_fold = 5
    fold = 0
    
    # scheduler_params
    scheduler='CosineAnnealingLR'
    T_max=3 # CosineAnnealingLR
    T_0=3 # CosineAnnealingWarmRestarts
    min_lr=1e-6
    # The scheduler is stepped every `schedulerStepFreq` training steps (not per epoch).
    schedulerStepFreq = 400
    
    # A mid-epoch validation pass runs every `valStepFreq` training steps
    # (the trainer only does this once the epoch number exceeds 2).
    valStepFreq= 1600
    
    
    # Parameters CWT
    # NOTE(review): cwt_params is not read anywhere in the visible notebook.
    cwt_params = {'fs':2048, 'lower_freq':10, 'upper_freq': 500, 
                  'n_scales':81, 'wavelet_width':1, 'stride':12, 'border_crop':0, 'train_width':True}
    
    # Keyword arguments forwarded to nnAudio's CQT1992v2 inside Model.
    cqt_params = {'sr':2048, 'fmin':20, 'fmax':512, 'hop_length':32, 'n_bins':69}
    # cqt_params = {'sr':2048, 'fmin':20, 'fmax':512, 'hop_length':32, 'bins_per_octave': 25, 'norm':1}
    
    # Enables the band-pass filter applied by filterSig / batch_preprocessing.
    BPfilter = True
    
    
    
    
    # Post Proc Option
    PREPROC = 'Q_transform'
    
    # batch size
    BATCH = 64
    BATCH_VAL = 128*4
    
    # scale:linear or log
    SCALE = 'linear'
    
    # DEBUG: train on a 1000-row sample for 2 epochs (see overrides below).
    DEBUG = False
    
    # SMALL_TRAIN_SET: train on BATCH*500 sampled rows (see overrides below).
    SMALL_TRAIN_SET = True
    
    # Render example spectrograms before training.
    VISUALIZE = True
    
    # Random state used for sampling and for the stratified fold split.
    seed = 42
    
    model_name = 'tf_efficientnet_b0' #'tf_efficientnet_b4' #'efficientnet-b7'
    # Whether timm should load pretrained weights for the backbone.
    pretrained = False
    unfreezeStep = 100 # set 0 for no freezing
    
    # Train on the blended 'soft_target' labels after the first epoch.
    useSoftLabels = True
    # Add confidently-predicted test samples to the training set (pseudo-labels).
    useTestLabels = True
    

# Without pretrained weights there is little reason to keep the backbone
# frozen, so it is released after only 10 training steps.
if not CFG.pretrained:
    CFG.unfreezeStep = 10
    
# Mode-dependent overrides. The branches are mutually exclusive and checked in
# this order, so WARM_START only shortens training when neither DEBUG nor
# SMALL_TRAIN_SET is set.
if CFG.DEBUG:
    CFG.EPOCHS = 2
    train_df = train_df.sample(n=1000, random_state=CFG.seed).reset_index(drop=True)
elif CFG.SMALL_TRAIN_SET:
    CFG.EPOCHS = 4
    train_df = train_df.sample(n=CFG.BATCH*500, random_state=CFG.seed).reset_index(drop=True)
    CFG.valStepFreq = 100
elif CFG.WARM_START:
    CFG.EPOCHS = 3
    
    

In [None]:
# Exploratory: build the backbone once just to display its default global
# pooling layer. `model` is rebound to the real Model() in the training cell.
model = timm.create_model(CFG.model_name, pretrained=CFG.pretrained)
model.global_pool


In [None]:
import random

def set_seed(seed):
    """Seed every random number generator the notebook relies on.

    Seeds Python's `random`, NumPy and PyTorch (CPU), exports PYTHONHASHSEED,
    and -- when CUDA is available -- seeds all GPUs and asks cuDNN for
    deterministic kernels.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


set_seed(42)

# Data retrieving and related functions

In [None]:
from scipy import signal 

from scipy import signal

# 8th-order Butterworth band-pass (25-500 Hz) designed for 2048 Hz sampling.
bHP, aHP = signal.butter(8, (25, 500), btype='bandpass', fs=2048)


def filterSig(waves, a=aHP, b=bHP, axis=1):
    """Band-pass filter `waves` along `axis` with zero phase shift.

    With the default coefficients this is a 25-500 Hz Butterworth band-pass
    applied forwards and backwards (`filtfilt`). The input is returned
    untouched when CFG.BPfilter is disabled.
    """
    if CFG.BPfilter:
        # filtfilt rather than lfilter: lfilter introduces a larger spike around 20hz
        return signal.filtfilt(b, a, waves, axis=axis)
    return waves

class DataRetriever(torch_data.Dataset):
    """Map-style dataset that loads one labelled .npy sample per index.

    Each item is a dict with the scaled waveform tensor under "X" and the
    float target under "y".
    """

    def __init__(self, paths, targets):
        self.paths = paths
        self.targets = targets

    def __len__(self):
        return len(self.paths)

    def __get_qtransform(self, x):
        # Despite the name no Q-transform is computed here (the model does
        # that); each row is only divided by its maximum along axis 1.
        scaled = x / np.max(x, axis=1, keepdims=True)
        if CFG.DEBUG:
            scaled = filterSig(scaled).copy()
        # scaled = scaled / np.max(np.abs(scaled),axis=1,keepdims = True)
        # scaled is [chan x time]
        return torch.tensor(scaled).float()

    def __getitem__(self, index):
        #file_path = convert_image_id_2_path(self.paths[index])
        waveform = np.load(self.paths[index])
        features = self.__get_qtransform(waveform)
        label = torch.tensor(self.targets[index], dtype=torch.float)
        return {"X": features, "y": label}
    
    
class TestDataRetriever(torch_data.Dataset):
    """Map-style dataset for unlabelled samples.

    Each item is a dict with the scaled waveform tensor under "X" and the
    sample's path under "id".
    """

    def __init__(self, paths):
        self.paths = paths

        # NOTE(review): q_transform is stored but not used by any method of
        # this class.
        if CFG.PREPROC == 'Q_transform':
            self.q_transform = CQT1992v2(
                sr=2048, fmin=20, fmax=1024, hop_length=32
            )
        else:
            self.q_transform = None

    def __len__(self):
        return len(self.paths)

    def __get_qtransform(self, x):
        # Only divides each row by its maximum along axis 1; the actual
        # transform is done by the model.
        scaled = x / np.max(x, axis=1, keepdims=True)
        if CFG.DEBUG:
            scaled = filterSig(scaled).copy()
        # scaled = scaled / np.max(np.abs(scaled),axis=1,keepdims = True)
        # scaled is [chan x time]
        return torch.tensor(scaled).float()

    def __getitem__(self, index):
        # file_path = convert_image_id_2_path(self.paths[index], is_train=False)
        sample_path = self.paths[index]
        features = self.__get_qtransform(np.load(sample_path))
        return {"X": features, "id": self.paths[index]}

In [None]:
# Build the train/validation split. The K-fold branch is hard-wired on; the
# else branch (a plain 80/20 stratified split) is kept as an alternative.
if True:
    Fold = StratifiedKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
    for n, (train_index, val_index) in enumerate(Fold.split(train_df, train_df['target'])):
        # val_index holds row positions; using them with .loc assumes
        # train_df carries a default 0..n-1 index -- TODO confirm.
        train_df.loc[val_index, 'fold'] = int(n)
    train_df['fold'] = train_df['fold'].astype(int)
    display(train_df.groupby(['fold', 'target']).size())

    # Rows of the held-out fold form the validation set, the rest the training set.
    df_train = train_df.loc[train_df['fold'] != CFG.fold,:]
    df_valid= train_df.loc[train_df['fold'] == CFG.fold,:]
else:
    df_train, df_valid = sk_model_selection.train_test_split(
    train_df, 
    test_size=0.2, 
    random_state=42, 
    stratify=train_df["target"],
    )




In [None]:
def get_train_file_path(image_id):
    """Return the path of a training sample.

    Files are sharded into nested directories named after the first three
    characters of the id.
    """
    template = "../input/g2net-gravitational-wave-detection/train/{}/{}/{}/{}.npy"
    first, second, third = image_id[0], image_id[1], image_id[2]
    return template.format(first, second, third, image_id)

def get_test_file_path(image_id):
    """Return the path of a test sample.

    Files are sharded into nested directories named after the first three
    characters of the id.
    """
    template = "../input/g2net-gravitational-wave-detection/test/{}/{}/{}/{}.npy"
    first, second, third = image_id[0], image_id[1], image_id[2]
    return template.format(first, second, third, image_id)

# Resolve every sample id to the .npy file that stores it. `assign` returns
# new frames, so nothing is written through the fold slices of train_df
# (which would raise pandas' SettingWithCopyWarning).
df_train = df_train.assign(file_path=df_train['id'].apply(get_train_file_path))
df_valid = df_valid.assign(file_path=df_valid['id'].apply(get_train_file_path))

submission['file_path'] = submission['id'].apply(get_test_file_path)
# Test-set predictions of an earlier model, used as pseudo-labels below.
submission_869 = pd.read_csv('../input/grnet869/model_submission_B0_no_pretrain.csv')
submission_869['file_path'] = submission_869['id'].apply(get_test_file_path)

if CFG.useTestLabels:
    # Soft pseudo-label: average of the prediction and its rounded (hard) value.
    submission_869['soft_target'] = 0.5*(submission_869['target']+np.round(submission_869['target']))
    # Keep only the confident predictions at either end.
    tmp_df_0 = submission_869.loc[submission_869['target']<0.2,:]
    tmp_df_1 = submission_869.loc[submission_869['target']>0.9,:]
    print('No. test negatives selected: '+str(len(tmp_df_0)))
    print('No. test positives selected: '+str(len(tmp_df_1)))
    
    print('Train size: '+str(len(df_train)))
    # Cap how many pseudo-labelled rows of each class are added.
    if CFG.SMALL_TRAIN_SET:
        tmp_df_0 = tmp_df_0.head(CFG.BATCH*50)
        tmp_df_1 = tmp_df_1.head(CFG.BATCH*50)
    else:
        tmp_df_0 = tmp_df_0.head(38400) # 128*300
        tmp_df_1 = tmp_df_1.head(38400)
    print(tmp_df_0.head())
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0;
    # the result (row order, index, columns) is the same. sample(frac=1)
    # then shuffles the combined rows.
    df_train = pd.concat([df_train, tmp_df_0, tmp_df_1]).sample(frac=1)
    df_train['target'] = df_train['target'].clip(lower = 0., upper = 1.)
    print('Train size with test: '+str(len(df_train)))


# Training set with the hard 'target' labels.
train_data_retriever = DataRetriever(
    df_train['file_path'].values, 
    df_train["target"].values, 
)

# Same files, but labelled with the blended 'soft_target' column.
train_data_retriever_soft = DataRetriever(
    df_train['file_path'].values, 
    df_train["soft_target"].values, 
)

# Held-out fold, always scored against the hard labels.
valid_data_retriever = DataRetriever(
    df_valid['file_path'].values, 
    df_valid["target"].values,
)

# Unlabelled test set used to produce the submission.
test_data_retriever = TestDataRetriever(
    submission["file_path"].values, 
)

# Test files paired with the earlier model's predictions as targets.
test_data_retriever_withlabels = DataRetriever(
    submission_869['file_path'].values, 
    submission_869["target"].values,
)

In [None]:
# Shuffled training loader over the hard labels.
train_loader = torch_data.DataLoader(
    train_data_retriever,
    batch_size=CFG.BATCH,
    shuffle=True,
    num_workers=12,
)

# Shuffled training loader over the soft labels.
train_loader_soft = torch_data.DataLoader(
    train_data_retriever_soft,
    batch_size=CFG.BATCH,
    shuffle=True,
    num_workers=12,
)

# Validation is not shuffled and uses the larger evaluation batch size.
valid_loader = torch_data.DataLoader(
    valid_data_retriever, 
    batch_size=CFG.BATCH_VAL,
    shuffle=False,
    num_workers=8,
)

# shuffle=False matters here: the submission cell pairs the predictions with
# the ids of the sample submission purely by position.
test_loader = torch_data.DataLoader(
    test_data_retriever,
    batch_size=CFG.BATCH_VAL,
    shuffle=False,
    num_workers=8,
)

# Shuffled loader over pseudo-labelled test data (used by the visualisation cell).
test_loader_withlabels = torch_data.DataLoader(
    test_data_retriever_withlabels,
    batch_size=CFG.BATCH,
    shuffle=True,
    num_workers=12,
)



# Model

In [None]:
# Different heads

class BasicHead(nn.Module):
    """Classification head: dropout -> 256-unit hidden layer -> single logit."""

    def __init__(self, n_features):
        super().__init__()
        layers = [
            nn.Dropout(0.5),
            nn.Linear(n_features, 256, bias=True),
            nn.ReLU(),
            # nn.Dropout(0.5), # p is probability of zeroing
            nn.Linear(256, 1, bias=True),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.classifier(x)
        return logits
    
class MultiDropoutHead(nn.Module):
    """Multi-sample dropout head.

    The same two-layer classifier is applied to five differently dropped-out
    copies of the features (p = 0.3 ... 0.7) and the five logits are averaged.

    The dropout layers are registered as submodules so that they follow
    `model.train()` / `model.eval()`. Creating `nn.Dropout` objects inside
    `forward` (as was done before) leaves them permanently in training mode,
    which kept dropout active -- and predictions random -- during validation
    and inference. Dropout has no parameters, so the state-dict layout is
    unchanged and existing checkpoints still load.

    Args:
        n_features: size of the feature vector produced by the backbone.
    """

    def __init__(self, n_features):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(in_features=n_features, out_features=256, bias=True),
            nn.ReLU(),
            # nn.Dropout(0.5), # p is probability of zeroing
            nn.Linear(in_features=256, out_features=1, bias=True),
        )
        self.dropouts = nn.ModuleList(
            [nn.Dropout(float(p)) for p in np.linspace(0.3, 0.7, 5)]
        )

    def forward(self, x):
        """Return the mean logit over all dropout rates, shape (batch, 1)."""
        logits = [self.classifier(drop(x)) for drop in self.dropouts]
        return torch.mean(torch.stack(logits, dim=0), dim=0)
    
class GeM(nn.Module):
    """Generalized-mean pooling over the last two dimensions.

    The exponent `p` is a learnable parameter; inputs are clamped to at least
    `eps` before being raised to it.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        pooled = self.gem(x, p=self.p, eps=self.eps)
        # Collapse everything after the batch dimension.
        return pooled.flatten(start_dim=1, end_dim=-1)

    def gem(self, x, p=3, eps=1e-6):
        # Average x**p over the full spatial window, then take the p-th root.
        window = (x.size(-2), x.size(-1))
        powered = x.clamp(min=eps).pow(p)
        return torch.nn.functional.avg_pool2d(powered, window).pow(1. / p)

    def __repr__(self):
        exponent = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={exponent:.4f}, eps={self.eps})"

In [None]:

def whiten(signal):
    """Spectrally whiten a batch of signals along dimension 2.

    The signal is Hann-windowed, every FFT bin is divided by its own
    magnitude (leaving only phase information), and the result is transformed
    back and rescaled by sqrt(length / 2).

    From here: https://www.kaggle.com/kevinmcisaac/g2net-spectral-whitening

    Args:
        signal: 3-D tensor; the transform runs over `signal.size(2)` samples.

    Returns:
        Real tensor with the same shape as `signal` (float64, because the
        window is created in double precision).
    """
    length = signal.size(2)
    # Build the window on the input's device so CUDA tensors work as well.
    hann = torch.hann_window(
        length, periodic=True, dtype=float, device=signal.device
    ).view(1, 1, -1)
    spec = fft(signal * hann, dim=2)
    mag = torch.sqrt(torch.real(spec * torch.conj(spec)))

    # dim=2 is the last axis of a 3-D input (the previous implicit default).
    return torch.real(ifft(spec / mag, dim=2)) * np.sqrt(length / 2)

def batch_preprocessing(X):
    """Optionally band-pass filter a CPU batch and return it as float32.

    The tensor is converted to NumPy, filtered along axis 2 when
    CFG.BPfilter is set, and wrapped in a new float tensor.
    """
    # X = whiten(X)
    samples = X.numpy()
    if CFG.BPfilter:
        # .copy() materialises the filter output before it is turned back
        # into a tensor.
        samples = filterSig(samples, axis=2).copy()
    return torch.tensor(samples).float()
        

# Classification head class instantiated by every Model variant below.
Head = MultiDropoutHead#BasicHead

# Selects which Model definition is created further down:
# 1 = multi channel, 2 = single channel, 3 = experimental.
model_no = 1

def Backbone():
    """Create the feature extractor named by CFG.model_name.

    The network's final classification layer is replaced by `nn.Identity`,
    so the returned model outputs feature vectors.

    Returns:
        (model, n_features): the backbone and the `in_features` of the layer
        that was removed.

    Raises:
        ValueError: if CFG.model_name matches none of the supported families
            (previously this surfaced as an UnboundLocalError).
    """
    if 'tf_efficientnet' in CFG.model_name:
        model = timm.create_model(CFG.model_name, pretrained=CFG.pretrained)
        n_features = model.classifier.in_features
        model.classifier = nn.Identity()
        # Swap timm's global pooling for learnable GeM pooling.
        model.global_pool = GeM()
    elif 'efficientnet-' in CFG.model_name:
        # NOTE(review): from_pretrained always loads pretrained weights and
        # does not consult CFG.pretrained.
        model = efficientnet_pytorch.EfficientNet.from_pretrained(CFG.model_name)
        n_features = model._fc.in_features
        model._fc = nn.Identity()
    elif 'rexnet_' in CFG.model_name:
        model = timm.create_model(CFG.model_name, pretrained=CFG.pretrained)
        n_features = model.head.fc.in_features
        model.head.fc = nn.Identity()
    else:
        raise ValueError(
            "Unsupported CFG.model_name {!r}: expected a name containing "
            "'tf_efficientnet', 'efficientnet-' or 'rexnet_'".format(CFG.model_name)
        )

    return model, n_features

def batch_postCQTprocessing(x):
    """Rescale CQT magnitudes according to CFG.SCALE.

    'log' maps x to (log10(x) + 1) / 1.5; any other setting clamps x at 2.5
    and subtracts 1.
    """
    if CFG.SCALE != 'log':
        #x = torch.divide(x,torch.mean(x,dim=2,keepdims = True))
        return torch.clamp(x, max=2.5) - 1
    return (torch.log10(x) + 1.) / 1.5

def batch_postCQTprocessing_experimental(x):
    """Normalise each row by its mean, then standardise each sample.

    Every row (dim 1 of the 3-D input) is divided by its mean over the last
    dimension; afterwards each sample is shifted and scaled to zero mean and
    unit standard deviation over its last two dimensions.
    """
    spec = x.unsqueeze(1)
    spec = spec / spec.mean(dim=3, keepdim=True)
    # x = torch.nn.functional.conv2d(x,torch.ones(1,1,5,1).to(device),padding = (2,1))
    spread, centre = torch.std_mean(spec, dim=(2, 3), keepdim=True)
    # x = (torch.clamp(x, max=2.5)-1)
    return ((spec - centre) / spread).squeeze(1)


# Exactly one of the three Model definitions below is created, chosen by
# `model_no`. All variants share the same front end: per-channel scaling,
# a CQT applied to every channel separately, and batch_postCQTprocessing.
if model_no == 1:
    print('Selecting multi channel model ... ')
    class Model(nn.Module):
        """Multi channel model: the three spectrograms of a sample are stacked
        as the channels of one image and fed to the backbone together."""

        def __init__(self, get_spectrogram = False):
            # With get_spectrogram=True no backbone/head is built and forward
            # returns spectrograms instead of logits.
            self.get_spectrogram = get_spectrogram
            super().__init__()
            self.q_transform = CQT1992v2(
                **CFG.cqt_params
            )
            if not self.get_spectrogram:
                
                self.model, n_features = Backbone()
                self.head = Head(n_features)

        def freezeModel(self):
            # Stop gradient updates for the backbone (the head stays trainable).
            for param in self.model.parameters():
                param.requires_grad = False

        def unfreezeModel(self):
            # Re-enable gradient updates for the backbone.
            for param in self.model.parameters():
                param.requires_grad = True

        def forward(self, x):
            # reshape from [batch by chan by time] [(batch x chan) by time]
            batch_size = x.size(0)

            # Scale every channel by its largest absolute value along dim 2.
            x = torch.divide(x,torch.max(torch.abs(x),dim=2,keepdims = True)[0])
            # The factor 3 hard-codes three channels per sample.
            x = torch.reshape(x,(batch_size*3,-1))
            x = self.q_transform(x)
            # Drop the last index along both spectrogram dimensions.
            x = x[:,0:-1,0:-1]
            if self.get_spectrogram:
                # Visualisation mode: return the standard and the experimental
                # post-processing side by side, both as (batch, 3, H, W).
                x_expt = batch_postCQTprocessing_experimental(x)
                size = list(x_expt.size())
                x_expt = torch.reshape(x_expt,(batch_size,3,size[1],size[2]))
                x = batch_postCQTprocessing(x)
                size = list(x.size())
                x = torch.reshape(x,(batch_size,3,size[1],size[2]))
                return x, x_expt
            x = batch_postCQTprocessing(x)
            size = list(x.size())
            # Regroup the per-channel spectrograms into a 3-channel image.
            x = torch.reshape(x,(batch_size,3,size[1],size[2]))
            x = self.model(x)
            out = self.head(x)
            return out
elif model_no == 2:        
    print('Selecting single channel model ... ')
    
    class Model(nn.Module):
        """Single channel model: every channel's spectrogram is passed through
        the backbone on its own and the features are max-pooled per sample."""

        def __init__(self, get_spectrogram = False):
            self.get_spectrogram = get_spectrogram
            super().__init__()
            self.q_transform = CQT1992v2(
                **CFG.cqt_params
            )
            if not self.get_spectrogram:
                # NOTE(review): Backbone() builds the network with library
                # defaults; confirm it accepts the 1-channel input used in forward.
                self.model, n_features = Backbone()
                self.head = Head(n_features)

        def freezeModel(self):
            # Stop gradient updates for the backbone (the head stays trainable).
            for param in self.model.parameters():
                param.requires_grad = False

        def unfreezeModel(self):
            # Re-enable gradient updates for the backbone.
            for param in self.model.parameters():
                param.requires_grad = True

        def forward(self, x):
            # reshape from [batch by chan by time] [(batch x chan) by time]
            batch_size = x.size(0)

            # Scale every channel by its largest absolute value along dim 2.
            x = torch.divide(x,torch.max(torch.abs(x),dim=2,keepdims = True)[0])
            x = torch.reshape(x,(batch_size*3,-1))
            x = self.q_transform(x)
            # Drop the last index along both spectrogram dimensions.
            x = x[:,0:-1,0:-1]
            x = batch_postCQTprocessing(x)
            if self.get_spectrogram:
                # Unlike the multi channel variant this returns a single tensor.
                size = list(x.size())
                x = torch.reshape(x,(batch_size,3,size[1],size[2]))
                return x
            
            # Add a singleton channel dimension: (batch*3, 1, H, W).
            x = torch.unsqueeze(x,1)
            
            # x_mean = torch.mean(x,dim=1,keepdims = True)
            # x = torch.stack([x,x_mean],dim=1)
            
            x = self.model(x)
            size = list(x.size())
            # Group the feature vectors by sample and keep the element-wise
            # maximum over each sample's channels.
            x = torch.reshape(x,(batch_size,-1,size[1]))
            x = torch.max(x,dim=1,keepdims = False)[0]
            out = self.head(x)
            out = out
            return out
        
elif model_no == 3:
    print('Selecting experimental model ... ')
    class Model(nn.Module):
        """Experimental model: the three channel spectrograms plus their mean
        are processed as four single-channel images; the head receives the
        element-wise maximum of the four feature vectors concatenated with
        the features of the mean image."""

        def __init__(self, get_spectrogram = False):
            self.get_spectrogram = get_spectrogram
            super().__init__()
            self.q_transform = CQT1992v2(
                **CFG.cqt_params
            )
            if not self.get_spectrogram:
                self.model, n_features = Backbone()
                # Twice the feature size: max-pooled features + mean-image features.
                self.head = Head(2*n_features)

        def freezeModel(self):
            # Stop gradient updates for the backbone (the head stays trainable).
            for param in self.model.parameters():
                param.requires_grad = False

        def unfreezeModel(self):
            # Re-enable gradient updates for the backbone.
            for param in self.model.parameters():
                param.requires_grad = True

        def forward(self, x):
            # reshape from [batch by chan by time] [(batch x chan) by time]
            batch_size = x.size(0)

            # Scale every channel by its largest absolute value along dim 2.
            x = torch.divide(x,torch.max(torch.abs(x),dim=2,keepdims = True)[0])
            x = torch.reshape(x,(batch_size*3,-1))
            x = self.q_transform(x)
            # Drop the last index along both spectrogram dimensions.
            x = x[:,0:-1,0:-1]
            x = batch_postCQTprocessing(x)
            if self.get_spectrogram:
                size = list(x.size())
                x = torch.reshape(x,(batch_size,3,size[1],size[2]))
                return x
            
            x = torch.unsqueeze(x,1)
            
            size = list(x.size())
            x = torch.reshape(x,(batch_size,3,size[2],size[3]))
            # Append the mean over the three channels as a fourth image, then
            # flatten back to single-channel images: (batch*4, 1, H, W).
            x_mean = torch.mean(x,dim=1,keepdims = True)
            x = torch.cat([x,x_mean],dim=1)
            x = torch.reshape(x,(batch_size*4,1,size[2],size[3]))
            
            x = self.model(x)
            size = list(x.size())
            x = torch.reshape(x,(batch_size,-1,size[1]))
            # Index 3 is the mean image appended above.
            x_mean = x[:,3,:].squeeze(1)
            # The maximum is taken over all four images, mean image included.
            x = torch.max(x,dim=1,keepdims = False)[0]
            # this will be [batch by (2 x n_features)]
            x = torch.cat([x,x_mean],dim=1)
            
            out = self.head(x)
            out = out
            return out
    




In [None]:

# Sanity-check plot: spectrograms of one random sample from the first batch of
# the soft-label training loader. Relies on the `model_no == 1` Model, whose
# spectrogram mode returns a pair of tensors.
if CFG.VISUALIZE:
    import matplotlib.pyplot as plt
    # Spectrogram-only model: no backbone or head is created.
    modelTmp = Model(get_spectrogram = True)
    for step, batch in enumerate(train_loader_soft,1):
        X = batch["X"]
        X = batch_preprocessing(X)
        modelTmp.to(device)
        X = X.to(device)
        targets = batch["y"].to(device)
        outputs, outputs_expt = modelTmp(X)
        # Pick one of the first 32 samples of the batch at random.
        n = np.random.randint(32)
        tmp = outputs[n].cpu().numpy()
        tmp_expt = outputs_expt[n].cpu().numpy()
        # One figure per channel: standard post-processing on the left,
        # experimental post-processing on the right.
        for i in range(3):
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (13,4))
            fig.suptitle('Target = '+ str(targets[n]))
            plt1 = ax1.imshow((tmp[i,:,:]).squeeze())
            #plt.colorbar(plt1)
            plt2 = ax2.imshow((tmp_expt[i,:,:]).squeeze())
            #plt.colorbar(plt2)
        # Mean of the three channel spectrograms.
        plt.figure()
        plt.title('Target = '+ str(targets[n]))
        plt.imshow(np.mean(tmp[:,:,:],axis=0).squeeze())
        plt.colorbar()
        print(tmp.shape)
        # Only the first batch is visualised.
        break

In [None]:
# Same sanity-check plot for one random sample of the pseudo-labelled test
# loader. Reuses `modelTmp` created by the previous visualisation cell.
if CFG.VISUALIZE:
    for step, batch in enumerate(test_loader_withlabels,1):
        X = batch["X"]
        X = batch_preprocessing(X)
        modelTmp.to(device)
        X = X.to(device)
        targets = batch["y"].to(device)
        outputs, outputs_expt = modelTmp(X)
        # Pick one of the first 32 samples of the batch at random.
        n = np.random.randint(32)
        tmp = outputs[n].cpu().numpy()
        tmp_expt = outputs_expt[n].cpu().numpy()
        # One figure per channel: standard vs experimental post-processing.
        for i in range(3):
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (13,4))
            fig.suptitle('Target = '+ str(targets[n]))
            plt1 = ax1.imshow((tmp[i,:,:]).squeeze())
            #plt.colorbar(plt1)
            plt2 = ax2.imshow((tmp_expt[i,:,:]).squeeze())
            #plt.colorbar(plt2)
        # Final figure: the three channels shown together as one colour image
        # (channel axis moved last); experimental on the left, standard on the right.
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (13,4))
        fig.suptitle('Target = '+ str(targets[n]))
        plt1 = ax1.imshow(np.transpose(tmp_expt[:,:,:]+1,(1,2,0)).squeeze())
        #plt.colorbar(plt1)
        plt2 = ax2.imshow(np.transpose(tmp[:,:,:],(1,2,0)).squeeze())
        plt.title('Target = '+ str(targets[n]))
        print(tmp.shape)
        # Only the first batch is visualised.
        break

# Loss related functions

In [None]:
class LossMeter:
    """Running mean of the values passed to `update` (one count per call)."""

    def __init__(self):
        self.avg = 0
        self.n = 0

    def update(self, val):
        self.n += 1
        count = self.n
        # incremental update: weight the new value by 1/count and the
        # previous mean by (count - 1)/count
        self.avg = val / count + (count - 1) / count * self.avg

        
class AccMeter:
    """Running accuracy over all samples seen so far.

    Targets are rounded to 0/1 and a logit >= 0 counts as a positive
    prediction.
    """

    def __init__(self):
        self.avg = 0
        self.n = 0

    def update(self, y_true, y_pred):
        labels = y_true.cpu().round().numpy().astype(int)
        positives = y_pred.cpu().numpy() >= 0
        hits = np.sum(labels == positives)

        seen_before = self.n
        self.n = seen_before + len(labels)
        # incremental update: new hits plus the previous mean re-weighted by
        # the share of samples it was computed on
        self.avg = hits / self.n + seen_before / self.n * self.avg
        


# Trainer related functions

In [None]:
def get_scheduler(optimizer):
    """Build the learning-rate scheduler selected by CFG.scheduler.

    Supported names: 'ReduceLROnPlateau', 'CosineAnnealingLR' and
    'CosineAnnealingWarmRestarts'.

    Args:
        optimizer: the optimizer whose learning rates are scheduled.

    Returns:
        The configured scheduler instance.

    Raises:
        ValueError: if CFG.scheduler is not one of the supported names
            (previously this surfaced as an UnboundLocalError).
    """
    if CFG.scheduler=='ReduceLROnPlateau':
        # CFG does not necessarily define factor/patience/eps; fall back to
        # PyTorch's own defaults instead of failing with AttributeError.
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=getattr(CFG, 'factor', 0.1),
            patience=getattr(CFG, 'patience', 10),
            verbose=True,
            eps=getattr(CFG, 'eps', 1e-8),
        )
    elif CFG.scheduler=='CosineAnnealingLR':
        scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
    elif CFG.scheduler=='CosineAnnealingWarmRestarts':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)
    else:
        raise ValueError("Unknown CFG.scheduler: {!r}".format(CFG.scheduler))
    return scheduler

class Trainer:
    """Drives training, validation, checkpointing and test inference.

    Args:
        model: network exposing `freezeModel()` / `unfreezeModel()`; its
            backbone is frozen on construction and released once
            `CFG.unfreezeStep` training steps have run.
        device: device the batches are moved to.
        optimizer: optimizer to step; its parameter groups are what the
            progress line reports as learning rates.
        criterion: callable `(logits, targets) -> loss`.
        loss_meter: class (not instance) of a running-loss meter.
        score_meter: class (not instance) of a running-score meter.
        use_swa: wrap the optimizer in `swa.SWA`. The `swa` module must be
            importable for this to work.
    """

    def __init__(
        self, 
        model, 
        device, 
        optimizer, 
        criterion, 
        loss_meter, 
        score_meter,
        use_swa = False
    ):
        self.model = model
        # freeze model by default
        self.model.freezeModel()
        
        self.device = device
        self.use_swa = use_swa
        self.optimizer = swa.SWA(optimizer) if self.use_swa else optimizer
        self.criterion = criterion
        self.loss_meter = loss_meter
        self.score_meter = score_meter
        self.scheduler = get_scheduler(optimizer)
        self.learning_rate = self._current_lrs()
        
        self.best_valid_score = -np.inf
        self.n_patience = 0
        
        self.messages = {
            "epoch": "[Epoch {}: {}] loss: {:.5f}, score: {:.5f}, auc_score: {:.5f}, time: {} s",
            "checkpoint": "The score improved from {:.5f} to {:.5f}. Save model to '{}'",
            "patience": "\nValid score didn't improve last {} epochs."
        }
        self.training_step = 0
        self.prevbatch = []
        self.epoch = -1
        # Set by fit(); used for mid-epoch validation and checkpointing.
        self.valid_loader = None
        self.save_path = None

    def _current_lrs(self):
        """Learning rate of every optimizer parameter group, as applied right now."""
        return [group["lr"] for group in self.optimizer.param_groups]

    def _checkpoint_if_improved(self, valid_score, n_epoch, save_path):
        """Save the model when `valid_score` beats the best so far, else count patience."""
        if self.best_valid_score < valid_score:
            self.info_message(
                self.messages["checkpoint"], self.best_valid_score, valid_score, save_path
            )
            self.best_valid_score = valid_score
            self.save_model(n_epoch, save_path)
            self.n_patience = 0
        else:
            self.n_patience += 1

    @staticmethod
    def _safe_auc(targets, predictions):
        """ROC AUC, or NaN when it is undefined (e.g. only one class seen so far)."""
        try:
            return roc_auc_score(targets, predictions)
        except ValueError:
            return float("nan")
    
    def fit(self, epochs, train_loader, valid_loader, save_path, patience,train_loader_soft = False):
        """Train for up to `epochs` epochs with early stopping.

        Args:
            epochs: maximum number of epochs.
            train_loader: loader used for the first epoch (and for all epochs
                when no soft loader is given).
            valid_loader: loader used for validation.
            save_path: where the best checkpoint is written.
            patience: stop once this many validations in a row did not
                improve the score.
            train_loader_soft: optional loader used from the second epoch on.
        """
        # Remember these so train_epoch can validate/checkpoint mid-epoch
        # without relying on notebook globals.
        self.valid_loader = valid_loader
        self.save_path = save_path

        for n_epoch in range(1, epochs + 1):
            
            self.epoch = n_epoch
            
            self.info_message("EPOCH: {}", n_epoch)
            
            if self.epoch==1 or (not train_loader_soft):
                train_loss, train_score, train_time = self.train_epoch(train_loader)
            else:
                print('Using soft labels for this epoch... ')
                train_loss, train_score, train_time = self.train_epoch(train_loader_soft)
                
            valid_loss, valid_score, valid_time, valid_rocauc = self.valid_epoch(valid_loader)
            
            self.info_message(
                self.messages["epoch"], "Train", n_epoch, train_loss, train_score, 0, train_time
            )
            
            self.info_message(
                self.messages["epoch"], "Valid", n_epoch, valid_loss, valid_score, valid_rocauc, valid_time
            )

            self._checkpoint_if_improved(valid_score, n_epoch, save_path)
            
            if self.n_patience >= patience:
                self.info_message(self.messages["patience"], patience)
                break
        if self.use_swa:
            self.optimizer.bn_update(train_loader, self.model)
            self.optimizer.swap_swa_sgd()
        
    def train_epoch(self, train_loader):
        """Run one pass over `train_loader`.

        Returns:
            (mean loss, mean score, elapsed seconds).
        """
        self.model.train()
        t = time.time()
        train_loss = self.loss_meter()
        train_score = self.score_meter()
        
        for step, batch in enumerate(tqdm(train_loader),1):
            
            if self.training_step == CFG.unfreezeStep:
                self.model.unfreezeModel()
            
            X = batch["X"]
            targets = batch["y"]
            if self.prevbatch:
                prevX = self.prevbatch["X"]
                prevY = self.prevbatch['y']
                rndNum = np.random.rand()
                # Batches of different size (typically the last batch of an
                # epoch) cannot be mixed element-wise, so they are left as is.
                if rndNum<0.5 and prevX.shape == X.shape:
                    # only keep prevX where there is no wave
                    prevX = torch.where(prevY.view(-1,1,1)>0.5,X,prevX)
                    
                    # weight for prevX is at most 0.5, and not replaced
                    # when there is a wave
                    X = (1-rndNum)*X + rndNum*prevX 
                    # /2 to prevent target 0 from exceeding 0.5
                    # targets = torch.clamp(((1-rndNum)*targets + rndNum*prevY/2)*1.5,max = 1.)

            self.prevbatch = batch.copy()  
            X = batch_preprocessing(X)
            X = X.to(self.device)
            targets = targets.to(self.device)
            
            self.optimizer.zero_grad()
            outputs = self.model(X).squeeze(1)
            
            loss = self.criterion(outputs, targets)
            loss.backward()

            train_loss.update(loss.detach().item())
            train_score.update(targets, outputs.detach())

            self.optimizer.step()
            
            _loss, _score = train_loss.avg, train_score.avg
            
            # One "x.xxxxx" entry per optimizer parameter group, '/'-separated.
            lr_text = "/".join("{:.5f}".format(lr) for lr in self.learning_rate)
            message = 'Train Step {}/{}, train_loss: {:.5f}, train_score: {:.5f}, learning_rate: {}'
            self.info_message(message, step, len(train_loader), _loss, _score, lr_text, end="\r")
            self.training_step += 1
            
            if self.training_step%CFG.schedulerStepFreq==0:
                if isinstance(self.scheduler, (CosineAnnealingLR, CosineAnnealingWarmRestarts)):
                    self.scheduler.step()
                # Read the rates from the optimizer: scheduler.get_lr() is not
                # meant to be called outside scheduler.step().
                self.learning_rate = self._current_lrs()

            if (
                self.training_step%CFG.valStepFreq==0
                and self.epoch>2
                and self.valid_loader is not None
            ):
                valid_loss, valid_score, valid_time, valid_rocauc = self.valid_epoch(self.valid_loader)
                self.info_message(
                    self.messages["epoch"], "Valid", self.epoch, valid_loss, valid_score, valid_rocauc, valid_time
                )
                self._checkpoint_if_improved(valid_score, self.epoch, self.save_path)
                # valid_epoch switched the model to eval mode; go back to
                # training mode for the remaining steps of this epoch.
                self.model.train()
        
        return train_loss.avg, train_score.avg, int(time.time() - t)
    
    def valid_epoch(self, valid_loader,returnPred = False):
        """Evaluate the model on `valid_loader`.

        Returns:
            `(mean loss, mean score, elapsed seconds, ROC AUC)` computed over
            the whole loader, or `(predictions, targets)` for every sample
            when `returnPred` is True.
        """
        self.model.eval()
        t = time.time()
        valid_loss = self.loss_meter()
        valid_score = self.score_meter()
        # Accumulated over all batches so the AUC and the returned
        # predictions cover the full validation set.
        y_pred = []
        tgts = []
        rocauc = float("nan")
        
        for step, batch in enumerate(valid_loader, 1):
            with torch.no_grad():
                X = batch["X"]  
                X = batch_preprocessing(X)
                X = X.to(self.device)
                targets = batch["y"].to(self.device)

                outputs = self.model(X).squeeze(1)
                loss = self.criterion(outputs, targets)

                valid_loss.update(loss.detach().item())
                valid_score.update(targets, outputs)
                # reshape(-1) keeps a 1-D array even for a single-sample batch.
                y_pred.extend(torch.sigmoid(outputs).cpu().numpy().reshape(-1))
                tgts.extend(batch["y"].numpy().reshape(-1))
                    
            rocauc = self._safe_auc(tgts, y_pred)
            _loss, _score = valid_loss.avg, valid_score.avg
            message = 'Valid Step {}/{}, valid_loss: {:.5f}, valid_score: {:.5f},valid_roc_auc: {:.5f}'
            self.info_message(message, step, len(valid_loader), _loss, _score, rocauc, end="\r")
        if not returnPred:
            return valid_loss.avg, valid_score.avg, int(time.time() - t), rocauc
        else:
            return y_pred, tgts
    
    def test_eval(self,test_loader):
        """Predict probabilities for every sample of `test_loader`.

        Returns:
            `(predictions, ids)` in loader order.
        """
        # Make sure dropout / batch-norm are in inference mode.
        self.model.eval()
        y_pred = []
        ids = []
        for e, batch in enumerate(test_loader):
            print(f"{e}/{len(test_loader)}", end="\r")
            with torch.no_grad():
                X = batch["X"]
                X = batch_preprocessing(X)
                X = X.to(self.device)
                outputs = self.model(X)
                y_pred.extend(torch.sigmoid(outputs).cpu().numpy().reshape(-1))
                ids.extend(batch["id"])
        return y_pred, ids
    
    def save_model(self, n_epoch, save_path):
        """Write model/optimizer state plus bookkeeping to `save_path`."""
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "best_valid_score": self.best_valid_score,
                "n_epoch": n_epoch,
            },
            save_path,
        )
    
    @staticmethod
    def info_message(message, *args, end="\n"):
        """Print `message` formatted with `args`."""
        print(message.format(*args), end=end)

In [None]:

model = Model()
model.to(device)
if (not CFG.TRAIN) or CFG.WARM_START:
    # Load the start checkpoint once (it used to be read from disk twice) and
    # map it onto the active device so a GPU-saved file also loads on CPU.
    startCheckpoint = torch.load(
        "../input/exptv2v4-best-model-fold0b0/best-model.pth", map_location=device
    )
    # `checkpoint` is what the later reload cell falls back to when CFG.TRAIN is False.
    checkpoint = startCheckpoint
    model.load_state_dict(startCheckpoint["model_state_dict"])

# Two parameter groups: the backbone trains at CFG.lr, the head at a tenth of it.
optimizer = torch.optim.Adam([{"params": model.model.parameters(), "lr": CFG.lr},
                              {"params": model.head.parameters(), "lr": CFG.lr/10}], 
                             lr=CFG.lr)
criterion = torch_functional.binary_cross_entropy_with_logits

# The meters are passed as classes; the trainer instantiates fresh ones per epoch.
trainer = Trainer(
    model, 
    device, 
    optimizer, 
    criterion, 
    LossMeter, 
    AccMeter
)

if CFG.TRAIN:
    # fit() has no return statement, so `history` is None.
    history = trainer.fit(
        CFG.EPOCHS, 
        train_loader, 
        valid_loader, 
        "best-model.pth", 
        400,
        train_loader_soft = train_loader_soft if CFG.useSoftLabels else False
    )
    
    # Validation predictions of the model as it stands after the last epoch
    # (the best checkpoint is only reloaded in a later cell).
    y_pred_val,tgts = trainer.valid_epoch(valid_loader,returnPred = True)

In [None]:
# `tgts` and `y_pred_val` are only produced when the training cell ran with
# CFG.TRAIN enabled; skip the plot otherwise instead of raising NameError.
if CFG.TRAIN:
    plt.scatter(tgts,y_pred_val,1)
    plt.xlabel('targets')
    plt.ylabel('predictions')

In [None]:
# After training, switch to the best checkpoint written by the trainer.
# Without training, `checkpoint` is the start checkpoint loaded earlier.
if CFG.TRAIN:
    # map_location keeps the load working when the active device differs
    # from the one the checkpoint was saved on.
    checkpoint = torch.load("best-model.pth", map_location=device)

model.load_state_dict(checkpoint["model_state_dict"])
model.eval();

In [None]:
import gc

# Release unreferenced objects before the long inference pass.
gc.collect()



# Predict the whole test set; `ids` holds the file paths in loader order.
y_pred, ids = trainer.test_eval(test_loader)

In [None]:
# Re-read the sample submission for its id column. Predictions are matched to
# ids purely by position, which relies on test_loader iterating the test
# files in this same order (it is built with shuffle=False).
submission = pd.read_csv("../input/g2net-gravitational-wave-detection/sample_submission.csv")
submission = pd.DataFrame({"id": submission['id'].values, "target": y_pred})
submission.to_csv("model_submission.csv", index=False)

In [None]:
submission