In [None]:
"""
General Approach:
    Model: Pretrained resnet34d
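    Data: Stack of the 3 signal channels (channels 0, 2 and 4 of each sample), concatenated vertically into one single-channel image
    Albumentations: Resize(512,256), HorizontalFlip, VerticalFlip, Rotate(limit=(-10,10))
    Training: MixUp (alpha=0.4), CosineAnnealing, Adam (weight_decay: 1e-5)
    Testing: test-time augmentation (TTA), averaging several randomly augmented prediction passes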
"""

In [None]:
import os
import glob
import random

import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping

from sklearn.metrics import roc_auc_score
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

# vision-models
!pip install timm
import timm

# config
!pip install omegaconf
from omegaconf import OmegaConf

# logging
import wandb
from pytorch_lightning.loggers import WandbLogger
from tqdm import tqdm

In [None]:
class SetiDataset(Dataset):
    """Dataset over the SETI ``.npy`` samples listed in ``df.file_path``.

    ``cfg.mode`` selects the channel layout returned by ``__getitem__``:
    ``'3ch'`` keeps every other channel (0, 2, 4, ...), ``'6ch'`` keeps all of
    them, and the ``*_stacked`` variants additionally concatenate the kept
    channels along the first spatial axis into a single-channel image.
    With ``train=True`` the ``df.target`` column is also returned.
    """

    def __init__(self, df, cfg, train=True, transform=None):
        super().__init__()
        self.df = df
        self.cfg = cfg
        self.train = train
        self.transform = transform

        self.file_names = df.file_path.values
        if train:
            # labels only exist for the training CSV
            self.targets = df.target.values

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, target)`` when ``train`` else ``image``; image is a CHW float32 tensor."""
        sample = np.load(self.file_names[idx]).astype(np.float32)

        mode = self.cfg.mode
        if mode in ('3ch_stacked', '3ch'):
            # keep channels 0, 2, 4, ...
            sample = sample[::2]
        if mode in ('3ch_stacked', '6ch_stacked'):
            # glue channels end-to-end along axis 0, then add a leading channel axis
            sample = np.concatenate(sample, axis=0)[np.newaxis]

        # channel-first -> channel-last, the layout the albumentations transforms receive
        sample = np.moveaxis(sample, 0, -1)
        if self.transform:
            sample = self.transform(image=sample)['image']

        # back to channel-first for the network
        tensor = torch.from_numpy(sample).permute(2, 0, 1)

        if not self.train:
            return tensor
        return tensor, torch.tensor(self.targets[idx]).float()
    
class SetiDataModule(pl.LightningDataModule):
    """Data module for the SETI Breakthrough Listen competition files.

    ``cfg`` must provide ``train_split``, ``batch_size``, ``test_batch_size``
    and the ``mode`` read by SetiDataset. ``transform`` is applied to every
    labelled sample (train *and* validation split); ``test_transform`` is
    applied to the test set.
    """
    def __init__(self, cfg, transform=None, test_transform=None):
        super(SetiDataModule, self).__init__()

        self.cfg = cfg
        self.transform = transform
        self.test_transform = test_transform
        
    def prepare_data(self):
        """Load the label / sample-submission CSVs and add a ``file_path`` column to each."""
        self.train_df = pd.read_csv('../input/seti-breakthrough-listen/train_labels.csv')
        self.test_df = pd.read_csv('../input/seti-breakthrough-listen/sample_submission.csv')

        # Samples live in a sub-folder named after the first character of their id.
        def get_train_file_path(image_id):
            return "../input/seti-breakthrough-listen/train/{}/{}.npy".format(image_id[0], image_id)

        def get_test_file_path(image_id):
            return "../input/seti-breakthrough-listen/test/{}/{}.npy".format(image_id[0], image_id)

        self.train_df['file_path'] = self.train_df['id'].apply(get_train_file_path)
        self.test_df['file_path'] = self.test_df['id'].apply(get_test_file_path)


    def setup(self, stage=None):
        """Randomly split the labelled data into train/val and build the test dataset."""
        # NOTE(review): both splits share one SetiDataset, so validation samples
        # also receive the random training augmentations in `self.transform`.
        dataset = SetiDataset(self.train_df, self.cfg, train=True, transform=self.transform)

        n_total = len(dataset)
        n_train = int(n_total * self.cfg.train_split + 0.5)
        # random_split raises ValueError unless the lengths sum to len(dataset).
        # Rounding train and val sizes independently can overshoot (e.g. 5 samples
        # at 0.5 -> 3 + 3), so the validation size takes the remainder instead.
        n_val = n_total - n_train

        self.train_dataset, self.val_dataset = random_split(dataset, [n_train, n_val])
        self.test_dataset = SetiDataset(self.test_df, self.cfg, train=False, transform=self.test_transform)
    
    def train_dataloader(self):
        """Shuffled loader over the training split."""
        train_loader = DataLoader(self.train_dataset, batch_size=self.cfg.batch_size, num_workers=8, shuffle=True)
        return train_loader

    def val_dataloader(self):
        """Ordered loader over the validation split."""
        val_loader = DataLoader(self.val_dataset, batch_size=self.cfg.batch_size, num_workers=8, shuffle=False)
        return val_loader
    
    def test_dataloader(self):
        """Ordered loader over the test set; order matches ``test_df`` rows."""
        test_loader = DataLoader(self.test_dataset, batch_size=self.cfg.test_batch_size, num_workers=8, shuffle=False)
        return test_loader

In [None]:
class SetiModel(nn.Module):
    """timm backbone with a single-output head for binary classification.

    ``cfg.model`` names the timm architecture; ``cfg.mode`` (the same setting
    SetiDataset reads) decides how many input channels the first layer takes.
    """

    # Input channels SetiDataset yields for each supported cfg.mode.
    IN_CHANS_BY_MODE = {'3ch_stacked': 1, '6ch_stacked': 1, '3ch': 3, '6ch': 6}

    def __init__(self, cfg):
        super(SetiModel, self).__init__()
        
        self.cfg = cfg

        # An unknown mode used to fall through every branch and crash with
        # UnboundLocalError on `in_chans`; fail early with a clear message.
        try:
            in_chans = self.IN_CHANS_BY_MODE[cfg.mode]
        except KeyError:
            raise ValueError(
                f"Unknown cfg.mode {cfg.mode!r}; expected one of {sorted(self.IN_CHANS_BY_MODE)}"
            ) from None

        self.net = timm.create_model(self.cfg.model, pretrained=True, num_classes=1, in_chans=in_chans)

    def forward(self, input):
        """Return the backbone's raw logits (no sigmoid is applied here)."""
        return self.net(input)

class SetiLightningModule(pl.LightningModule): 
    """Lightning wrapper that trains ``model`` (raw-logit output) as a binary classifier.

    Reads ``lr``, ``weight_decay``, ``mixup``, ``p_mixup`` and ``alpha`` from
    ``cfg``. Every ``trainer.test`` run writes sigmoid probabilities to a new
    ``submission{N}.csv`` (see ``test_epoch_end``); the notebook relies on this
    to collect one file per TTA pass.
    """
    def __init__(self, model, cfg): 
        super(SetiLightningModule, self).__init__() 
        self.cfg = cfg
        # Plain attribute so Trainer(auto_lr_find=True) can overwrite it.
        self.lr = cfg.lr
        
        self.model = model
        
        # Both criteria take raw logits and apply the sigmoid internally.
        self.loss = nn.BCEWithLogitsLoss()
        self.val_loss = nn.BCEWithLogitsLoss() 

    def forward(self, X):
        """Return the wrapped model's raw logits for ``X``."""
        return self.model(X)

    @staticmethod
    def _batch_roc_auc(y, y_pred, fallback=1):
        """ROC-AUC of one batch from targets ``y`` and logits ``y_pred``.

        roc_auc_score raises ValueError when the batch contains a single class;
        ``fallback`` (1, as before) is returned then, so per-batch values are
        biased upward. ``val_epoch_roc_auc`` is the unbiased whole-set metric.
        """
        try:
            return roc_auc_score(y.detach().cpu().numpy().reshape(-1),
                                 y_pred.sigmoid().detach().cpu().numpy().reshape(-1))
        except ValueError:
            return fallback
    
    def predict(self, dataloader):
        """Return sigmoid probabilities (1-D numpy array) for every batch in ``dataloader``.

        Batches must be bare image tensors, as yielded for the unlabelled test
        set. Runs in eval mode without gradients and restores the previous
        train/eval mode afterwards; an empty loader yields an empty array.
        """
        # The previous body unpacked the single forward output into two names
        # and never returned the collected predictions.
        was_training = self.training
        self.eval()
        preds = []
        try:
            with torch.no_grad():
                for batch in tqdm(dataloader):
                    logits = self.forward(batch.to(self.device))
                    preds.append(logits.sigmoid()[:, 0].cpu())
        finally:
            self.train(was_training)
        if not preds:
            return np.empty(0, dtype=np.float32)
        return torch.cat(preds).numpy()
  
    def configure_optimizers(self): 
        """Adam with cosine annealing; T_max=3 is fixed, independent of cfg.epochs."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.cfg.weight_decay)
        #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=1, min_lr=1e-6, verbose=True)
        #scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=6, T_mult=1, eta_min=1e-6, last_epoch=-1)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=3, eta_min=1e-6, last_epoch=-1)
        
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}            
        
    def mixup_data(self, X, y, alpha=1.0):
        """Mix each sample with a random partner from the same batch.

        Returns ``(mixed_X, y_a, y_b, lam)`` with ``lam ~ Beta(alpha, alpha)``
        (``lam = 1`` when ``alpha <= 0``, i.e. no mixing).
        """
        # Thanks go to Salmon (https://www.kaggle.com/micheomaano/efficientnet-b4-mixup-cv-0-98-lb-0-97)
        # for the code for mixup
        
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1

        batch_size = X.size()[0]
        # Build the permutation on X's device instead of hard-coding .cuda(),
        # so mixup also works on CPU.
        index = torch.randperm(batch_size, device=X.device)

        mixed_X = lam * X + (1 - lam) * X[index, :]
        y_a, y_b = y, y[index]
        return mixed_X, y_a, y_b, lam      
        
    def mixup_criterion(self, criterion, pred, y_a, y_b, lam):
        """Loss of ``pred`` against the lam-weighted pair of MixUp targets."""
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

    def training_step(self, train_batch, batch_idx): 
        X, y = train_batch 
        y_col = y.unsqueeze(1)  # (batch,) -> (batch, 1) to match the logits

        # np.random.choice is only drawn when mixup is enabled (short-circuit).
        use_mixup = self.cfg.mixup and np.random.choice(2, p=[1-self.cfg.p_mixup, self.cfg.p_mixup])
        if use_mixup:
            X, y_a, y_b, lam = self.mixup_data(X, y_col, alpha=self.cfg.alpha)
            y_pred = self.forward(X)
            loss = self.mixup_criterion(self.loss, y_pred, y_a, y_b, lam)
        else:
            y_pred = self.forward(X)
            loss = self.loss(y_pred, y_col)

        # NOTE: on mixed batches this scores predictions for mixed inputs
        # against the un-mixed labels, so it is only a rough progress signal.
        roc_auc = self._batch_roc_auc(y, y_pred)
        
        self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=False)
        self.log('roc_auc', roc_auc, on_step=True, on_epoch=False, prog_bar=True)

        return loss 
  
    def validation_step(self, valid_batch, batch_idx): 
        X, y = valid_batch 
        y_pred = self.forward(X)
        # BCEWithLogitsLoss expects logits; passing y_pred.sigmoid() (as before)
        # applied the sigmoid twice and distorted the monitored val_loss.
        val_loss = self.val_loss(y_pred, y.unsqueeze(1).to(torch.float32))
        roc_auc = self._batch_roc_auc(y, y_pred)
            
        self.log('val_loss', val_loss, on_step=True, on_epoch=True, prog_bar=False)
        self.log('val_roc_auc', roc_auc, on_step=True, on_epoch=True, prog_bar=True)

        # Collected by validation_epoch_end to score the whole validation set.
        return {'y': y.detach().cpu(), 'y_pred': y_pred.sigmoid()[:, 0].detach().cpu()}

    def validation_epoch_end(self, outputs):
        """Log ROC-AUC over all validation predictions as ``val_epoch_roc_auc``.

        Averaging per-batch AUCs (what ``val_roc_auc`` does on epoch) is not the
        AUC of the whole set. Nothing is logged when it is undefined (one class).
        """
        if not outputs:
            return
        y = torch.cat([out['y'] for out in outputs]).numpy()
        y_pred = torch.cat([out['y_pred'] for out in outputs]).numpy()
        try:
            roc_auc = roc_auc_score(y, y_pred)
        except ValueError:
            # e.g. during Lightning's short sanity-check run
            return
        self.log('val_epoch_roc_auc', roc_auc, prog_bar=True)

    def test_step(self, test_batch, batch_idx):
        """Return 1-D sigmoid probabilities for a batch of bare image tensors."""
        y_pred = self.forward(test_batch).sigmoid()
        return y_pred[:,0]
    
    def test_epoch_end(self, outputs):
        """Write this test run's probabilities to the next free ``submission{N}.csv``."""
        y_preds = torch.cat(outputs).detach().cpu().numpy()
        df = pd.DataFrame({'target':y_preds})
        # Count only numbered TTA files: 'submission*.csv' also matches the final
        # submission.csv, which shifted the numbering when the notebook was re-run.
        N = len(glob.glob('submission[0-9]*.csv'))
        df.target.to_csv(f'submission{N}.csv')

In [None]:
def mergeTTAs(submissions, column='target'):
    """Average per-row predictions from several TTA submission CSVs.

    Parameters
    ----------
    submissions : iterable of str or path-like
        CSV files that each contain ``column``; all must have the same number
        of rows in the same order.
    column : str, default 'target'
        Name of the prediction column to average.

    Returns
    -------
    numpy.ndarray
        1-D array with the row-wise mean across the files.

    Raises
    ------
    ValueError
        If ``submissions`` is empty, or the files have different row counts.
    """
    preds = [pd.read_csv(path)[column].to_numpy() for path in submissions]
    if not preds:
        raise ValueError('mergeTTAs: no submission files to merge')
    return np.mean(np.stack(preds), axis=0)

In [None]:
def set_seed(seed = 0):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Returns a fresh ``np.random.RandomState(seed)`` for callers that want a
    local NumPy generator.
    """
    # global generators
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # trade cuDNN autotuning for reproducible kernel selection
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # only affects processes started after this point; the running
    # interpreter's hash seed was fixed at startup
    os.environ['PYTHONHASHSEED'] = str(seed)

    return np.random.RandomState(seed)

random_state = set_seed(42)

In [None]:
# Experiment configuration, wrapped in OmegaConf for attribute access (cfg.lr, ...).
cfg = OmegaConf.create(dict(
    epochs=4,
    train_split=0.9,      # fraction of train_labels.csv used for training; the rest is validation
    batch_size=24,
    test_batch_size=128,
    lr=5e-4,
    weight_decay=1e-5,

    mixup=True,
    alpha=0.4,            # MixUp coefficient is drawn from Beta(alpha, alpha)
    p_mixup=0.99,         # probability that a training batch is mixed

    # Read by the Trainer cell; it must exist because recent OmegaConf
    # versions raise on attribute access to a missing key.
    auto_find_lr=False,
    mode='3ch_stacked',   # '6ch_stacked', '3ch_stacked', '6ch', '3ch'
    model='resnet34d',    # e.g. 'efficientnet_b2'

    N_TTAs=4,             # number of trainer.test passes averaged into the submission
    debug=False,
))

In [None]:
def _build_augmentations():
    """Resize to height 512 x width 256, then random flips and a +/-10 degree rotation.

    The same random pipeline serves training and, for test-time augmentation,
    the test set; a separate instance is built for each.
    """
    return A.Compose([
        A.Resize(height=512, width=256),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=(-10, 10)),
    ])


transform = _build_augmentations()
test_transform = _build_augmentations()

In [None]:
# Data module (train/val + TTA test augmentations) and the Lightning-wrapped timm model.
datamodule = SetiDataModule(cfg, transform=transform, test_transform=test_transform)
model = SetiLightningModule(model=SetiModel(cfg), cfg=cfg)

In [None]:
#datamodule.prepare_data()
#datamodule.setup()

In [None]:
# Weights & Biases logger; log_model=False keeps checkpoints from being uploaded.
# NOTE(review): the run name says ResNet18 while cfg.model is 'resnet34d'; confirm the intended tag.
wandb_logger = WandbLogger(name='ResNet18_1224',
                           project='Seti', 
                           offline=False, 
                           log_model=False)

# NOTE(review): constructed but not passed to the Trainer's `callbacks` below,
# so early stopping is currently disabled; add it there to enable it.
early_stop = EarlyStopping(monitor='val_loss', 
                           min_delta=0.00, 
                           patience=3, 
                           verbose=True,
                           mode='min')

lr_monitor = LearningRateMonitor(logging_interval='step')

# `.get` with a default keeps this cell working when `auto_find_lr` is absent
# from cfg: recent OmegaConf versions raise on attribute access to a missing key.
auto_find_lr = cfg.get('auto_find_lr', False)

trainer = pl.Trainer(gpus=1, 
                     max_epochs=cfg.epochs, 
                     progress_bar_refresh_rate=1,
                     callbacks=[lr_monitor],
                     logger=wandb_logger,
                     auto_lr_find=auto_find_lr)

if auto_find_lr:
    # LR finder; its suggestion is written to the module's `lr` attribute
    trainer.tune(model, datamodule=datamodule)
trainer.fit(model, datamodule)

In [None]:
# Each trainer.test pass re-draws the random test_transform augmentations, and
# test_epoch_end numbers its output file from the count of existing submission
# CSVs. Remove files left by an earlier run first (including the final
# submission.csv, which this pattern also matches and which is rewritten at the
# end) so this run's passes are numbered 0..N_TTAs-1 and only they get merged.
for stale_path in glob.glob('submission*.csv'):
    os.remove(stale_path)

for _ in range(cfg.N_TTAs):
    trainer.test(model, datamodule=datamodule)

In [None]:
# Only the numbered TTA files (submission0.csv, submission1.csv, ...): the bare
# pattern 'submission*.csv' also matches the final submission.csv written below,
# which would be averaged in on a re-run. Sorted for a deterministic read order.
submissions = sorted(glob.glob('submission[0-9]*.csv'))
final_submission = mergeTTAs(submissions)

In [None]:
# .copy() makes `submission` an independent frame; assigning into a column
# subset of test_df otherwise triggers pandas' SettingWithCopyWarning.
# Row order matches test_df because the test DataLoader does not shuffle.
submission = datamodule.test_df[['id', 'target']].copy()
submission['target'] = final_submission

In [None]:
# Reload the individual TTA passes to compare them with the averaged submission.
sub0, sub1, sub2, sub3 = [
    pd.read_csv(path).target.values
    for path in ('submission0.csv', 'submission1.csv', 'submission2.csv', 'submission3.csv')
]
sub = submission.target.values

In [None]:
# Rows where the averaged prediction exactly equals each individual TTA pass.
tuple(np.sum(sub == tta) for tta in (sub0, sub1, sub2, sub3))

In [None]:
# Final submission file: `id` and `target` columns, without the pandas index.
submission.to_csv(path_or_buf='submission.csv', index=False)

In [None]:
"""
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available")
else:
    device = torch.device("cpu")
    print("GPU not available, CPU used")


def predict(model, dataloader):
    preds = []
    model = model.to(device)
    for batch in tqdm(iter(dataloader)):        
        with torch.no_grad():
            pred = model(batch.to(device))
        preds.append(pred.detach().cpu().numpy())
        
    return np.concatenate(preds)[:,0]
"""

In [None]:
#test_loader = DataLoader(datamodule.test_dataset, batch_size=128, num_workers=8, shuffle=False)

In [None]:
#preds = predict(model, test_loader)

#submission = datamodule.test_df[['id', 'target']]
#submission.target = preds

In [None]:
# Preview the first 128 rows of the final submission.
submission.iloc[:128]