## References
* Jaegle, A., Gimeno, F., Brock, A., Zisserman, A., Vinyals, O., & Carreira, J. (2021). Perceiver: General perception with iterative attention. arXiv preprint arXiv:2103.03206 (https://arxiv.org/abs/2103.03206).

* Jaegle, A., Borgeaud, S., Alayrac, J. B., Doersch, C., Ionescu, C., Ding, D., ... & Carreira, J. (2021). Perceiver IO: A general architecture for structured inputs & outputs. arXiv preprint arXiv:2107.14795 (https://arxiv.org/abs/2107.14795).

In [None]:
import numpy as np
import pandas as pd
from scipy import signal
import os
import matplotlib.pyplot as plt
from glob import glob
import torch
from sklearn.metrics import roc_auc_score
import time
import math

In [None]:
# Training labels; the `id` and `target` columns are used by the cells below.
train_df = pd.read_csv(
    "../input/g2net-gravitational-wave-detection/training_labels.csv"
)
train_df

In [None]:
# Each training sample is stored under train/<c0>/<c1>/<c2>/<id>.npy, where
# c0..c2 are the first three characters of the sample id.
paths = [
    "../input/g2net-gravitational-wave-detection/train/"
    + "/".join(sample_id[:3])
    + "/"
    + sample_id
    + ".npy"
    for sample_id in train_df.id.values
]
train_df['path'] = paths
train_df

In [None]:
def whiten(x):
    """Spectrally whiten every channel (row) of ``x``.

    Each row is transformed with a real FFT. Every frequency bin is divided by
    its own magnitude, which flattens the amplitude spectrum while keeping the
    phase. The result is transformed back and scaled by ``sqrt(n / 2)``, where
    ``n`` is the row length.

    Args:
        x: array of shape ``(channels, n)``. Any number of channels and any
            length ``n`` (odd or even) is accepted.

    Returns:
        A new float array of the same shape. The input is not modified.
        Bins with zero magnitude are set to 0 rather than producing NaN.
    """
    x = np.asarray(x)
    n = x.shape[-1]
    spec = np.fft.rfft(x, axis=-1)
    mag = np.abs(spec)
    # Guard 0/0: a bin with no energy stays 0 instead of becoming NaN.
    unit = np.divide(spec, mag, out=np.zeros_like(spec), where=mag > 0)
    # Passing n explicitly makes odd-length rows round-trip to the same length.
    return np.fft.irfft(unit, n=n, axis=-1) * np.sqrt(n / 2)

In [None]:
def apply_bandpass(x, lf=30, hf=500, order=8, sr=2048):
    """Zero-phase Butterworth band-pass filter applied along the last axis.

    Args:
        x: array whose last axis is time. Any number of leading channels is
            supported.
        lf, hf: pass-band edges in Hz.
        order: Butterworth order passed to ``scipy.signal.butter``.
        sr: sampling rate in Hz.

    Returns:
        A new filtered array, divided by ``sqrt((hf - lf) / (sr / 2))``. That
        factor is the square root of the pass-band fraction of the Nyquist
        band. The input is not modified.
    """
    sos = signal.butter(order, [lf, hf], btype="bandpass", output="sos", fs=sr)
    normalization = np.sqrt((hf - lf) / (sr / 2))
    # Forward-backward filtering (sosfiltfilt) gives zero phase shift.
    return signal.sosfiltfilt(sos, x, axis=-1) / normalization

In [None]:
def preprocess(x):
    """Normalize, taper and band-pass one multi-channel strain sample.

    Steps:
        1. Scale each channel by its own peak absolute value, so it lies in
           [-1, 1]. An all-zero channel is left as zeros instead of becoming
           NaN.
        2. Multiply by a Tukey window (alpha=0.1) spanning the whole record,
           to taper both edges.
        3. Apply ``apply_bandpass`` with its default settings.

    ``whiten`` is not applied here.

    Args:
        x: array of shape ``(channels, n)``; time is the last axis.

    Returns:
        A new float array of the same shape.
    """
    peak = np.max(np.abs(x), axis=-1, keepdims=True)
    x = x / np.where(peak == 0, 1, peak)
    # scipy.signal.tukey was removed in SciPy 1.13; the windows submodule is
    # the supported location (available since SciPy 1.1).
    x = x * signal.windows.tukey(x.shape[-1], 0.1)
    return apply_bandpass(x)

In [None]:
class DataSet:
    """Map-style dataset of preprocessed training samples and their labels.

    Args:
        paths: indexable collection of ``.npy`` file paths.
        target: indexable collection of labels aligned with ``paths``.
            Elements must support ``.astype``, as numpy scalars do.
        index: positions of ``paths``/``target`` that belong to this split.

    Each item is ``(preprocess(np.load(path)) as float32, label as float32)``.
    """

    def __init__(self, paths, target, index):
        self.paths = [paths[k] for k in index]
        self.target = [target[k] for k in index]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        sample = preprocess(np.load(self.paths[index]))
        label = self.target[index]
        return sample.astype(np.float32), label.astype(np.float32)

In [None]:
class SelfAttention(torch.nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values come from one bias-free projection of the input,
    split into three equal parts of width ``out_dim``. Attention logits are
    scaled by ``out_dim ** -0.5``.

    Args:
        dim: input feature width.
        out_dim: width of q/k/v and of the output.

    ``forward`` expects a 3-D ``(batch, seq, dim)`` tensor, because ``torch.bmm``
    is used. It returns ``(batch, seq, out_dim)``.
    """

    def __init__(self, dim, out_dim):
        super().__init__()
        self.dim = out_dim
        self.qkv_weight = torch.nn.Linear(dim, 3 * out_dim, bias=False)

    def forward(self, x):
        query, key, value = self.qkv_weight(x).chunk(3, dim=-1)
        scale = self.dim ** -0.5
        logits = torch.bmm(query, key.transpose(1, 2)) * scale
        weights = logits.softmax(dim=-1)
        return torch.bmm(weights, value)

In [None]:
class MultiheadSelfAttention(torch.nn.Module):
    """Multi-head self-attention built from independent ``SelfAttention`` heads.

    Each of ``heads`` heads projects ``in_dim`` features to ``out_dim // heads``.
    The head outputs are concatenated along the feature axis and mixed by a
    final ``Linear(out_dim, out_dim)``.

    Raises:
        ValueError: if ``heads`` is not positive or ``out_dim`` is not
            divisible by ``heads``. Otherwise the concatenated head width
            would not match the output layer, and the error would only
            surface later, at forward time.
    """

    def __init__(self, in_dim, out_dim, heads):
        super(MultiheadSelfAttention, self).__init__()
        if heads <= 0 or out_dim % heads != 0:
            raise ValueError(
                f"out_dim ({out_dim}) must be divisible by a positive number of heads (got heads={heads})"
            )
        head_out = out_dim // heads
        self.heads = torch.nn.ModuleList([SelfAttention(in_dim, head_out) for _ in range(heads)])
        self.output = torch.nn.Linear(out_dim, out_dim)

    def forward(self, x):
        outs = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.output(outs)

In [None]:
class CrossAttention(torch.nn.Module):
    """Single-head scaled dot-product attention from ``q`` onto ``kv``.

    Queries are a bias-free projection of ``q``. Keys and values come from one
    bias-free projection of ``kv``, split in half. Logits are scaled by
    ``out_dim ** -0.5``.

    Args:
        q_dim: feature width of the query sequence.
        kv_dim: feature width of the key/value sequence.
        out_dim: width of q/k/v and of the output.

    ``forward(q, kv)`` expects 3-D ``(batch, len, features)`` tensors, because
    ``torch.bmm`` is used. It returns ``(batch, len_q, out_dim)``.
    """

    def __init__(self, q_dim, kv_dim, out_dim):
        super().__init__()
        self.dim = out_dim
        self.q_weight = torch.nn.Linear(q_dim, out_dim, bias=False)
        self.kv_weight = torch.nn.Linear(kv_dim, 2 * out_dim, bias=False)

    def forward(self, q, kv):
        query = self.q_weight(q)
        key, value = self.kv_weight(kv).chunk(2, dim=-1)
        scale = self.dim ** -0.5
        logits = torch.bmm(query, key.transpose(1, 2)) * scale
        return torch.bmm(logits.softmax(dim=-1), value)

In [None]:
class MultiheadCrossAttention(torch.nn.Module):
    """Multi-head cross-attention built from independent ``CrossAttention`` heads.

    Each head attends from ``q`` (width ``q_dim``) onto ``kv`` (width
    ``kv_dim``) and produces ``out_dim // heads`` features. The head outputs
    are concatenated and mixed by a final ``Linear(out_dim, out_dim)``.

    Raises:
        ValueError: if ``heads`` is not positive or ``out_dim`` is not
            divisible by ``heads``. Otherwise the concatenated head width
            would not match the output layer.
    """

    def __init__(self, q_dim, kv_dim, out_dim, heads):
        super(MultiheadCrossAttention, self).__init__()
        if heads <= 0 or out_dim % heads != 0:
            raise ValueError(
                f"out_dim ({out_dim}) must be divisible by a positive number of heads (got heads={heads})"
            )
        head_out = out_dim // heads
        self.heads = torch.nn.ModuleList([CrossAttention(q_dim, kv_dim, head_out) for _ in range(heads)])
        self.output = torch.nn.Linear(out_dim, out_dim)

    def forward(self, q, kv):
        outs = torch.cat([head(q, kv) for head in self.heads], dim=-1)
        return self.output(outs)

In [None]:
class Residual(torch.nn.Module):
    """Wrap ``fn`` with an identity skip connection: ``forward(x) = x + fn(x)``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return x + self.fn(x)

In [None]:
class Feedforward(torch.nn.Module):
    """Position-wise MLP: Linear(embed_dim→hidden_dim), GELU, Linear(hidden_dim→embed_dim).

    The layers live in ``self.hidden`` (a ``Sequential`` at indices 0, 1, 2), so
    parameter names are ``hidden.0.*`` and ``hidden.2.*``.
    """

    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        expand = torch.nn.Linear(embed_dim, hidden_dim)
        activation = torch.nn.GELU()
        project = torch.nn.Linear(hidden_dim, embed_dim)
        self.hidden = torch.nn.Sequential(expand, activation, project)

    def forward(self, x):
        return self.hidden(x)

In [None]:
class PerceiverEncoder(torch.nn.Module):
    """Perceiver-style cross-attention from a learned latent array onto the inputs.

    ``length`` learned latent vectors, each ``embed_dim`` wide, query the input
    sequence with multi-head cross-attention. The attention output is added back
    onto the latents, then passed through a residual feed-forward layer.

    Args:
        in_dim: feature width of the input sequence.
        embed_dim: latent width.
        hidden_dim: feed-forward hidden width.
        heads: number of attention heads.
        length: number of latent vectors.

    ``forward`` returns ``(batch, length, embed_dim)``.
    """

    def __init__(self, in_dim, embed_dim, hidden_dim, heads, length):
        super().__init__()
        # Construction order is kept so seeded parameter init is unchanged.
        self.attention = MultiheadCrossAttention(embed_dim, in_dim, embed_dim, heads)
        self.feedforward = Residual(Feedforward(embed_dim, hidden_dim))
        self.initial_latent = torch.nn.Parameter(torch.randn(length, embed_dim))

    def forward(self, x):
        batch = x.shape[0]
        latents = self.initial_latent.unsqueeze(0).repeat(batch, 1, 1)
        attended = self.attention(latents, x)
        return self.feedforward(attended + latents)

In [None]:
class Encoder(torch.nn.Module):
    """Perceiver cross-attention encoder followed by ``depth`` latent self-attention blocks.

    Each block applies residual multi-head self-attention, then a residual
    feed-forward layer, to the latent array produced by ``PerceiverEncoder``.
    """

    def __init__(self, in_dim, embed_dim, hidden_dim, heads, depth, encode_length):
        super().__init__()
        # Keep construction order (encode, feedforwards, attentions) so seeded
        # parameter initialization is unchanged.
        self.encode = PerceiverEncoder(in_dim, embed_dim, hidden_dim, heads, encode_length)
        self.feedforwards = torch.nn.ModuleList(
            [Residual(Feedforward(embed_dim, hidden_dim)) for _ in range(depth)]
        )
        self.attentions = torch.nn.ModuleList(
            [Residual(MultiheadSelfAttention(embed_dim, embed_dim, heads)) for _ in range(depth)]
        )

    def forward(self, x):
        latents = self.encode(x)
        for attention_block, ff_block in zip(self.attentions, self.feedforwards):
            latents = ff_block(attention_block(latents))
        return latents

In [None]:
class Recognizer(torch.nn.Module):
    """Pool encoder latents into one binary score per sample.

    Mean-pools and max-pools over dim 1, concatenates the two pooled vectors
    (width ``2 * embed_dim``), and applies a linear layer to get one logit.

    Returns:
        ``(probability, logit)``, each of shape ``(batch,)``, including when
        ``batch == 1``. ``probability`` is the sigmoid of ``logit``. The
        training loop feeds the raw logit to ``BCEWithLogitsLoss``.
    """

    def __init__(self, embed_dim):
        super(Recognizer, self).__init__()
        self.logit = torch.nn.Linear(embed_dim * 2, 1)
        self.out = torch.nn.Sigmoid()

    def forward(self, x):
        mean_pool = torch.mean(x, dim=1)
        max_pool = torch.max(x, dim=1)[0]
        logit = self.logit(torch.cat([mean_pool, max_pool], dim=1))
        # squeeze(-1), not squeeze(): a bare squeeze also drops the batch axis
        # when batch == 1. The resulting 0-d tensor fails the loss shape check
        # and cannot be np.concatenate'd with other prediction batches.
        logit = logit.squeeze(-1)
        return self.out(logit), logit

In [None]:
class CNN1d(torch.nn.Module):
    """Three-stage 1-D convolutional feature extractor for a single channel.

    Each stage is BatchNorm1d -> Conv1d (stride 1, no padding) -> GELU ->
    MaxPool1d(4, stride 4). Channel widths grow
    ``embed_dim // 4 -> embed_dim // 2 -> embed_dim``.

    Args:
        embed_dim: number of output channels.
        input_length: length of the input signal. It is only used to compute
            ``self.length``, the output sequence length. The default 4096
            gives 61, the value that used to be hard-coded.

    Input is ``(batch, 1, input_length)``; output is
    ``(batch, embed_dim, self.length)``.

    Raises:
        ValueError: if ``input_length`` is too short to survive all
            conv/pool stages.
    """

    def __init__(self, embed_dim, input_length=4096):
        super(CNN1d, self).__init__()
        self.CE = torch.nn.Sequential(torch.nn.BatchNorm1d(1),
                                    torch.nn.Conv1d(1, embed_dim//4, 16, stride=1, padding=0),
                                    torch.nn.GELU(),
                                    torch.nn.MaxPool1d(4, stride=4, padding=0),
                                    torch.nn.BatchNorm1d(embed_dim//4),
                                    torch.nn.Conv1d(embed_dim//4, embed_dim//2, 8, stride=1, padding=0),
                                    torch.nn.GELU(),
                                    torch.nn.MaxPool1d(4, stride=4, padding=0),
                                    torch.nn.BatchNorm1d(embed_dim//2),
                                    torch.nn.Conv1d(embed_dim//2, embed_dim, 8, stride=1, padding=0),
                                    torch.nn.GELU(),
                                    torch.nn.MaxPool1d(4, stride=4, padding=0))
        self.length = self._output_length(input_length)

    @staticmethod
    def _scalar(v):
        """Return the first element of a 1-D size tuple, or ``v`` itself if it is an int."""
        return v[0] if isinstance(v, tuple) else v

    def _output_length(self, length):
        """Propagate a sequence length through every Conv1d / MaxPool1d in ``self.CE``."""
        for layer in self.CE:
            if isinstance(layer, (torch.nn.Conv1d, torch.nn.MaxPool1d)):
                k, s, p, d = (self._scalar(v) for v in
                              (layer.kernel_size, layer.stride, layer.padding, layer.dilation))
                # PyTorch output-size formula (floor mode) for both layer types.
                length = (length + 2 * p - d * (k - 1) - 1) // s + 1
                if length <= 0:
                    raise ValueError("input_length is too short for CNN1d's conv/pool stack")
        return length

    def forward(self, x):
        return self.CE(x)

In [None]:
class Embed(torch.nn.Module):
    """Turn a multi-detector strain batch into a token sequence for the encoder.

    Steps:
        1. Each input channel goes separately through the shared ``CNN1d``
           feature extractor.
        2. Each channel's features are standardized per sample over the
           time and feature axes. No epsilon is added, so a constant feature
           map would divide by zero.
        3. The per-channel token sequences are concatenated along the time
           axis.
        4. Optionally, a positional code is appended to every token:
           sinusoidal for ``'fix'``, learned for ``'trainable'``. A one-hot
           detector id is appended with it.
        5. A bias-free linear layer projects each token to ``embed_dim``.

    Args:
        enc_dim: CNN output channels (feature width per token).
        pe_dim: width of the positional code. Must be even for ``'fix'``.
        embed_dim: output token width.
        pe: ``'fix'``, ``'trainable'``, or any other value for no positional code.

    When a positional code is used, the positional table has rows for exactly 3
    channels. The input's channel axis (dim 1) must therefore have size 3.
    """

    def __init__(self, enc_dim, pe_dim, embed_dim, pe='fix'):
        super(Embed, self).__init__()
        dim = enc_dim
        self.CNN = CNN1d(dim)
        max_len = self.CNN.length
        self.pe_mode = pe
        n_det = 3
        # One-hot detector id: rows [d*max_len, (d+1)*max_len) belong to channel d.
        detector = torch.zeros(max_len * n_det, n_det)
        for d in range(n_det):
            detector[d * max_len:(d + 1) * max_len, d] = 1
        if pe == 'fix':
            position = torch.arange(max_len).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, pe_dim, 2) * (-math.log(10000.0) / pe_dim))
            table = torch.zeros(max_len, pe_dim)
            table[:, 0::2] = torch.sin(position * div_term)
            table[:, 1::2] = torch.cos(position * div_term)
            # Non-persistent buffer: moves with .to(device), so there is no
            # per-step host->device copy, and it stays out of state_dict exactly
            # like the old plain-tensor attribute.
            self.register_buffer('pe', torch.cat([table.repeat((n_det, 1)), detector], dim=-1),
                                 persistent=False)
            out_dim = dim + pe_dim + n_det
        elif pe == 'trainable':
            self.tpe = torch.nn.Parameter(torch.randn(max_len, pe_dim))
            self.register_buffer('detector', detector, persistent=False)
            # The table is rebuilt from self.tpe in forward(). Concatenating it
            # once here would freeze its values, so optimizer updates to tpe
            # would never reach the output.
            self.pe = None
            out_dim = dim + pe_dim + n_det
        else:
            self.pe = None
            out_dim = dim

        self.embed = torch.nn.Linear(out_dim, embed_dim, bias=False)

    def _position_features(self):
        """Return the ``(3 * length, pe_dim + 3)`` positional/detector table, or None."""
        if self.pe_mode == 'trainable':
            return torch.cat([self.tpe.repeat((3, 1)), self.detector], dim=-1)
        return self.pe

    def forward(self, x):
        ss = []
        for channel in x.unbind(dim=1):
            fts = self.CNN(channel.unsqueeze(1)).transpose(1, 2)
            std, mean = torch.std_mean(fts, dim=(1, 2), unbiased=False, keepdim=True)
            ss.append((fts - mean) / std)
        x = torch.cat(ss, dim=1)
        pos = self._position_features()
        if pos is not None:
            x = torch.cat([x, pos.to(x.device).expand(x.shape[0], -1, -1)], dim=-1)
        return self.embed(x)

In [None]:
class Model(torch.nn.Module):
    """Full detector model: ``Embed`` -> ``Encoder`` -> ``Recognizer``.

    ``forward`` returns the ``Recognizer`` output, a ``(probability, logit)``
    pair.
    """

    def __init__(self, enc_dim, pe_dim, embed_dim, hidden_dim, heads, depth, encode_length, pe='fix'):
        super().__init__()
        # Construction order kept so seeded parameter initialization is unchanged.
        self.embed = Embed(enc_dim, pe_dim, embed_dim, pe=pe)
        self.encoder = Encoder(embed_dim, embed_dim, hidden_dim, heads, depth, encode_length)
        self.recognizer = Recognizer(embed_dim)

    def forward(self, x):
        tokens = self.embed(x)
        latents = self.encoder(tokens)
        return self.recognizer(latents)

In [None]:
# --- reproducibility ---
seed = 2434

# --- architecture ---
enc_dim = 128               # CNN1d output channels per token
pe_dim = 128                # width of the positional code appended to tokens
embed_dim = 256             # token / latent width inside the encoder
hidden_dim = 4 * 256        # feed-forward hidden width
heads = 4                   # attention heads
depth = 3                   # self-attention + feed-forward blocks after cross-attention
encode_length = 64          # number of learned latent vectors
pe = 'fix'                  # 'fix' (sinusoidal), 'trainable', or anything else for none

# --- optimisation ---
lr = 1e-4
weight_decay = 1e-5
train_batch_size = 256
test_batch_size = 512
optimizer = 'Adam'          # informational: the training cell constructs torch.optim.Adam directly
grad_clip = 1000            # max global gradient norm for clip_grad_norm_

In [None]:
np.random.seed(seed)
torch.manual_seed(seed)
# Random ~90% / ~10% train/validation split, reproducible via the seed above.
in_train = np.random.rand(train_df.shape[0]) < 0.9
train_index = np.flatnonzero(in_train)
val_index = np.flatnonzero(~in_train)
train_dataset = DataSet(train_df.path.values, train_df.target.values, train_index)
val_dataset = DataSet(train_df.path.values, train_df.target.values, val_index)

In [None]:
epochs = 10
# Fall back to CPU so the notebook still runs (slowly) without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# os.cpu_count() can return None when the count cannot be determined.
num_worker = os.cpu_count() or 0
model = Model(enc_dim, pe_dim, embed_dim, hidden_dim, heads, depth, encode_length, pe).to(device)
best_model = Model(enc_dim, pe_dim, embed_dim, hidden_dim, heads, depth, encode_length, pe).to(device)
best_model.load_state_dict(model.state_dict())
# Summed loss; divided by the sample count below to report a per-sample mean.
criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')
optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size,
                                         shuffle=True, drop_last=True, num_workers=num_worker,
                                         pin_memory=(device == 'cuda'))
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=test_batch_size,
                                         shuffle=False, drop_last=False, num_workers=num_worker,
                                         pin_memory=(device == 'cuda'))


def _run_epoch(loader, train):
    """Run one pass over ``loader`` and return ``(mean loss per sample, ROC AUC)``.

    When ``train`` is True, the model is in training mode and each batch takes
    an optimizer step with gradient-norm clipping. Otherwise the model is in
    eval mode and gradients are disabled.
    """
    model.train(train)
    total_loss = 0.0
    count = 0
    preds = []
    targets = []
    with torch.set_grad_enabled(train):
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            pred, logit = model(data)
            loss = criterion(logit, target)
            if train:
                optim.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optim.step()
            total_loss += loss.item()
            preds.append(pred.detach().cpu().numpy())
            targets.append(target.cpu().numpy())
            count += data.shape[0]
    auc = roc_auc_score(np.concatenate(targets, axis=0), np.concatenate(preds, axis=0))
    return total_loss / count, auc


start_time = time.time()
train_loss_list = []
train_auc_list = []
val_loss_list = []
val_auc_list = []
best_auc = 0
for epoch in range(epochs):
    train_loss, train_auc = _run_epoch(train_dataloader, train=True)
    train_loss_list.append(train_loss)
    train_auc_list.append(train_auc)

    val_loss, val_auc = _run_epoch(val_dataloader, train=False)
    val_loss_list.append(val_loss)
    val_auc_list.append(val_auc)

    spent_time = time.time() - start_time
    print(f'epoch: {epoch} train loss: {train_loss} train auc: {train_auc} val loss: {val_loss} val auc: {val_auc} time: {spent_time/60} min')
    # Keep a copy of the weights with the best validation AUC seen so far.
    if val_auc >= best_auc:
        best_auc = val_auc
        best_model.load_state_dict(model.state_dict())
    # Wall-clock budget in seconds (about 6.9 hours).
    if spent_time >= 25000:
        print('time over')
        break

In [None]:
# Per-sample mean BCE loss per epoch; labelled so the two curves can be told apart.
plt.plot(train_loss_list, label='train')
plt.plot(val_loss_list, label='val')
plt.xlabel('epoch')
plt.ylabel('mean BCE loss')
plt.legend()

In [None]:
# ROC AUC per epoch; labelled so the two curves can be told apart.
plt.plot(train_auc_list, label='train')
plt.plot(val_auc_list, label='val')
plt.xlabel('epoch')
plt.ylabel('ROC AUC')
plt.legend()

In [None]:
# Save the checkpoint with the best validation AUC. These are the weights
# (best_model) used for the test-set predictions, not the last-epoch weights.
torch.save(best_model.state_dict(), "model")

In [None]:
# Every test file four directory levels below test/; its id is the filename
# up to the first '.'.
paths = glob("../input/g2net-gravitational-wave-detection/test/*/*/*/*")
ids = [p.rsplit("/", 1)[-1].split(".")[0] for p in paths]
test_df = pd.DataFrame({"path": paths, "id": ids})
test_df['target'] = 0.0  # placeholder, overwritten by inference
test_df = test_df.set_index('id').sort_index()
test_df

In [None]:
class TestDataSet:
    """Map-style dataset of preprocessed test samples paired with their ids.

    Args:
        paths: indexable collection of ``.npy`` file paths.
        ids: indexable collection of sample ids aligned with ``paths``.

    Each item is ``(preprocess(np.load(path)) as float32, id)``.
    """

    def __init__(self, paths, ids):
        self.paths = paths
        self.ids = ids

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        sample = preprocess(np.load(self.paths[index]))
        return sample.astype(np.float32), self.ids[index]

In [None]:
# One item per test file, keyed by the (sorted) id index of test_df.
test_dataset = TestDataSet(paths=test_df.path.values, ids=test_df.index.values)

In [None]:
# Use the declared test_batch_size instead of a duplicated literal 512.
# No shuffling and no dropped samples, so every test id is scored exactly once.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_worker,
    pin_memory=True,
)

In [None]:
# Score the test set with the best-validation checkpoint and write the
# sigmoid probabilities into test_df['target'], matched by id.
best_model.eval()
with torch.no_grad():
    for batch, batch_ids in test_dataloader:
        probs, _ = best_model(batch.to(device))
        test_df.loc[list(batch_ids), 'target'] = probs.cpu().numpy()

In [None]:
# Display the test predictions (one row per test id) before writing the submission.
test_df

In [None]:
# Submission file: the id index plus the predicted probability column.
test_df[['target']].to_csv('submission.csv')