In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from fastai.data.all import *
from fastai.vision.all import *
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image

import cv2
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import glob
import os
#for dirname, _, filenames in os.walk('/kaggle/input'):
#    for filename in filenames:
#        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install timm

In [None]:
# Collect every training file: train/<shard directory>/<file>.
train_files = glob.glob(os.path.join('/kaggle/input/seti-breakthrough-listen/train', '*', '*.*'), recursive=True)
len(train_files)

In [None]:
# Training labels; the `id` and `target` columns are used by later cells.
lbl = pd.read_csv(os.path.join('/kaggle/input/seti-breakthrough-listen', 'train_labels.csv'))
lbl.head()

In [None]:
def get_file_train(_id, root="../input/seti-breakthrough-listen/train"):
    """Return the path of the ``.npy`` file for sample id ``_id``.

    Files are sharded into sub-directories named after the first character
    of the id, i.e. ``<root>/<_id[0]>/<_id>.npy``.

    Parameters
    ----------
    _id : str
        Sample id (as read from the ``id`` column of the labels CSV);
        must be non-empty.
    root : str
        Directory that holds the sharded files. Defaults to the relative
        Kaggle training directory this notebook has always used; pass e.g.
        the test directory to reuse the same layout rule.

    Returns
    -------
    str
        The file path.
    """
    # os.path.join tolerates a trailing slash on `root`.
    return os.path.join(root, _id[0], f"{_id}.npy")
# Resolve every training id to the path of its .npy file.
lbl['img_path'] = lbl['id'].map(get_file_train)

In [None]:
class Signal:
    """Map-style dataset that turns ``.npy`` files listed in a DataFrame into images.

    Each row of ``df`` must provide ``img_path`` (path to a ``.npy`` file) and
    ``target`` (the label). ``__getitem__`` returns ``(image, label)``: the image
    is a float32 tensor with a leading channel axis, resized to 256x256, and
    the label is ``torch.tensor(target)``.
    """

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        label = row.target
        raw = np.load(row.img_path).astype(np.float32)
        # np.vstack concatenates the slices along the first axis into one 2-D array
        # (assumes each file holds a stack of 2-D arrays — TODO confirm),
        # which is then transposed and resized to a fixed 256x256 grid.
        stacked = np.vstack(raw)
        resized = cv2.resize(stacked.transpose((1, 0)), dsize=(256, 256))
        image = torch.tensor(resized).float()[None]  # add the channel axis
        return image, torch.tensor(label)

In [None]:
# Hold out 20% of the labelled samples for validation. `stratify` keeps the
# class ratio of `target` the same in both splits, and `random_state` makes the
# split (and therefore the validation score) reproducible across runs.
train_df, valid_df = train_test_split(lbl, test_size=0.2, random_state=42, stratify=lbl['target'])

In [None]:
# Wrap each split in the `Signal` dataset and plain PyTorch DataLoaders.
train_ds = Signal(train_df)
valid_ds = Signal(valid_df)

bs = 128
# Training batches must be reshuffled every epoch. `drop_last` avoids a tiny final
# batch (e.g. size 1), which BatchNorm layers cannot normalize in training mode.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=bs, shuffle=True, drop_last=True)
# Validation keeps a fixed order and every sample, so each is scored exactly once.
valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=bs)

In [None]:
# Bundle the two PyTorch loaders, in (train, valid) order, so fastai's `Learner` can use them.
dls = DataLoaders(train_dl, valid_dl)

In [None]:
from timm import create_model
from fastai.vision.learner import _update_first_layer

def create_timm_body(arch:str, pretrained=True, cut=None, n_in=3):
    """Create a feature-extractor body from any model in the `timm` library.

    Parameters
    ----------
    arch : str
        timm architecture name passed to `create_model`.
    pretrained : bool
        Load pretrained weights; also forwarded to `_update_first_layer`
        when the first layer is adapted to `n_in` input channels.
    cut : int, callable or None
        Where to cut the timm model. An int keeps `list(model.children())[:cut]`;
        a callable receives the full model and must return the body; `None`
        cuts just before the last child for which `has_pool_type` is true.
    n_in : int
        Number of input channels of the data (this notebook uses 1).

    Returns
    -------
    nn.Module
        The body (an `nn.Sequential` when `cut` is an int or None).

    Raises
    ------
    ValueError
        If `cut` is None and no child contains a pooling layer, or if `cut`
        is neither an int nor a callable.
    """
    # num_classes=0 / global_pool='' ask timm to omit the classifier and global pooling.
    model = create_model(arch, pretrained=pretrained, num_classes=0, global_pool='')
    # Adapt the first layer to `n_in` input channels.
    _update_first_layer(model, n_in, pretrained)
    if cut is None:
        ll = list(enumerate(model.children()))
        cut = next((i for i,o in reversed(ll) if has_pool_type(o)), None)
        if cut is None:
            raise ValueError(f"could not find a pooling layer to cut `{arch}` at; pass `cut` explicitly")
    if isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])
    elif callable(cut): return cut(model)
    # `NamedError` is not defined anywhere, so the old line raised NameError instead.
    else: raise ValueError("cut must be either integer or function")
        
def create_timm_model(arch:str, n_out, cut=None, pretrained=True, n_in=3, init=nn.init.kaiming_normal_, custom_head=None,
                     concat_pool=True, **kwargs):
    """Create an `nn.Sequential(body, head)` model for `arch` from the `timm` library.

    Parameters
    ----------
    arch : str
        timm architecture name.
    n_out : int
        Number of outputs of the head.
    cut : int, callable or None
        Forwarded to `create_timm_body` to choose where the timm model is cut
        (None = automatic).
    pretrained : bool
        Load pretrained weights for the body.
    n_in : int
        Number of input channels.
    init : callable or None
        Initializer applied to the head (`model[1]`); None skips it.
    custom_head : nn.Module or None
        Use this head instead of building one with `create_head`.
    concat_pool : bool
        Forwarded to `create_head`.
    **kwargs
        Extra arguments for `create_head` (e.g. `y_range`).
    """
    body = create_timm_body(arch, pretrained, cut, n_in)
    if custom_head is None:
        # Presumably infers the body's output feature count to size the head.
        nf = num_features_model(nn.Sequential(*body.children()))
        head = create_head(nf, n_out, concat_pool=concat_pool, **kwargs)
    else: head = custom_head
    model = nn.Sequential(body, head)
    # Only the head is initialized; the body keeps its (possibly pretrained) weights.
    if init is not None: apply_init(model[1], init)
    return model

def timm_learner(dls, arch:str, loss_func=None, pretrained=True, cut=None, splitter=None,
                y_range=None, config=None, n_in=3, n_out=None, normalize=True, **kwargs):
    """Build a convnet-style fastai `Learner` from `dls` and a `timm` architecture.

    Parameters
    ----------
    dls : DataLoaders
        Training/validation data.
    arch : str
        timm architecture name.
    loss_func : callable or None
        Loss passed to `Learner`.
    pretrained : bool
        Load pretrained weights and call `learn.freeze()` after creation.
    cut : int, callable or None
        Where to cut the timm model (see `create_timm_body`); None = automatic.
    splitter : callable or None
        Parameter-group splitter for `Learner`; defaults to `default_split`.
    y_range : tuple or None
        Output range for the head; may also be supplied as `config['y_range']`
        (an explicit argument wins).
    config : dict or None
        Extra keyword arguments for `create_timm_model`. Not modified.
    n_in : int
        Number of input channels.
    n_out : int or None
        Number of outputs; inferred with `get_c(dls)` when None.
    normalize : bool
        Accepted for signature compatibility; currently ignored.
    **kwargs
        Forwarded to `Learner` (e.g. `metrics`, `opt_func`).

    Raises
    ------
    AssertionError
        If `n_out` is not given and cannot be inferred from `dls`.
    """
    # Work on a copy so popping `y_range` never mutates the caller's dict.
    config = dict(config) if config is not None else {}
    if n_out is None: n_out = get_c(dls)
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    # Always remove `y_range` from config so it cannot be passed twice below.
    config_y_range = config.pop('y_range', None)
    if y_range is None: y_range = config_y_range
    # Pass the caller's `cut` (previously `default_split` was passed in its place).
    model = create_timm_model(arch, n_out, cut, pretrained, n_in=n_in, y_range=y_range, **config)
    if splitter is None: splitter = default_split
    learn = Learner(dls, model, loss_func=loss_func, splitter=splitter, **kwargs)
    # NOTE(review): freeze presumably leaves only the last parameter group
    # (the head, under default_split) trainable — verify against fastai docs.
    if pretrained: learn.freeze()
    return learn

In [None]:
def roc_auc(preds,targ):
    """Area under the ROC curve of scores ``preds`` against binary labels ``targ``.

    ``preds`` are used as-is: AUC depends only on their ranking, so raw logits
    work as well as probabilities. Both inputs are flattened, which assumes one
    score per sample (e.g. a ``(N, 1)`` model output).

    Returns 0.5 (chance level) when ``targ`` contains a single class, where AUC
    is undefined. Any other error propagates instead of being silently
    reported as 0.5.
    """
    y_true = targ.detach().cpu().reshape(-1)
    if y_true.unique().numel() < 2:
        return 0.5
    # float() so half-precision outputs (to_fp16 training) convert cleanly.
    y_score = preds.detach().cpu().float().reshape(-1)
    return roc_auc_score(y_true.numpy(), y_score.numpy())

In [None]:
import timm
# NOTE(review): exploratory cell. This displays the name of every pretrained timm
# architecture (a very long output). The learner below uses 'wide_resnet50_2';
# consider filtering, e.g. timm.list_models('wide_resnet*', pretrained=True).
timm.list_models(pretrained=True)

In [None]:
# 1 input channel (Signal yields single-channel images) and 1 output logit,
# trained with BCEWithLogitsLossFlat in mixed precision.
# AccumMetric gathers predictions/targets over the whole validation set and calls
# `roc_auc` once. A bare metric function is presumably averaged batch by batch by
# fastai, and a mean of per-batch AUCs is not the AUC of the full set.
learn = timm_learner(dls,'wide_resnet50_2',
                     pretrained=True,n_in=1,n_out=1,metrics=[AccumMetric(roc_auc)], opt_func=ranger,
                     loss_func=BCEWithLogitsLossFlat()).to_fp16()

In [None]:
# Learning-rate range test; its plot/suggestion guides the max LR given to fit_one_cycle.
learn.lr_find()

In [None]:
# One epoch of one-cycle training with max LR 0.019 (presumably read off the lr_find plot).
# No ReduceLROnPlateau here: fit_one_cycle's scheduler sets the LR on every batch,
# so any reduction that callback applied would simply be overwritten.
learn.fit_one_cycle(1, 0.019)

In [None]:
import torch

In [None]:
# Save the weights. NOTE(review): despite the `.hdf5` extension this is a torch
# pickle written by torch.save, not an HDF5 file. `learn.state_dict()` presumably
# forwards to `learn.model` through fastai's attribute delegation — verify.
torch.save(learn.state_dict(), 'model.hdf5')

In [None]:
# Reload the checkpoint saved above. NOTE(review): despite the variable name,
# `torch.load` returns the saved state dict (parameter names -> tensors), not a
# model object; no later cell shown here reads this `model` value.
model=torch.load('model.hdf5')

In [None]:
def save_model(m, p):
    """Write the ``state_dict`` of ``m`` to path ``p`` using ``torch.save``."""
    state = m.state_dict()
    torch.save(state, p)


def load_model(m, p):
    """Load a ``state_dict`` from path ``p`` into ``m`` in place; returns None.

    Tensors are deserialized onto the CPU (``map_location="cpu"``), so a
    checkpoint written on a GPU machine can be read without CUDA;
    ``load_state_dict`` then copies them into ``m``.
    """
    state = torch.load(p, map_location="cpu")
    m.load_state_dict(state)

In [None]:
# Checkpoint the weights to `model.pth` (the file reloaded in the next cell);
# save_model does torch.save(learn.state_dict(), path).
save_model(learn, 'model.pth')

In [None]:
# Restore the saved weights into `learn` (CPU-mapped torch.load, then load_state_dict).
# load_state_dict returns a missing/unexpected-keys report rather than a model,
# so its result is no longer bound to the misleading name `model`.
load_model(learn, 'model.pth')

In [None]:
# Loss curves recorded by the learner during training.
learn.recorder.plot_loss()

In [None]:
# Unlabelled test files, laid out as test/<shard directory>/<file>.
test_files = glob.glob(os.path.join('/kaggle/input/seti-breakthrough-listen/test', '*', '*.*'), recursive=True)
test_files[:10]

In [None]:
# `dls` wraps plain PyTorch loaders, so fastai's `dls.test_dl` would not apply the
# `Signal` preprocessing to these raw file paths. Build the test loader exactly like
# the validation loader instead. `Signal` reads a `target` column, so a dummy 0 label
# is supplied. shuffle=False keeps predictions in `test_files` order, which the
# id column of the submission relies on.
test_df = pd.DataFrame({'img_path': test_files, 'target': 0})
tdl = torch.utils.data.DataLoader(Signal(test_df), batch_size=bs, shuffle=False)

In [None]:
# Predict on the test loader; the returned targets are discarded.
# NOTE(review): fastai presumably applies the loss function's activation (sigmoid for
# BCEWithLogitsLossFlat) in get_preds, so `preds` should hold probabilities — verify.
preds,_ = learn.get_preds(dl=tdl)

In [None]:
# Sample id = file name up to the `.npy` suffix, in the same order as `test_files`.
image_ids = np.array([os.path.basename(path).partition('.npy')[0] for path in test_files])
len(image_ids)

In [None]:
# Sanity-check the prediction array shape before building the submission
# (with n_out=1 it is presumably (number of test files, 1)).
preds.numpy().shape

In [None]:
# Positive-class score per test file. The learner has n_out=1, so `preds` is
# presumably (N, 1) and column index 1 would be out of range. Taking the last
# column also covers a 2-output model whose column 1 is the positive class.
pred_arr = preds.numpy()
scores = pred_arr[:, -1] if pred_arr.ndim == 2 else pred_arr
submission = pd.DataFrame({'id':image_ids, 'target':scores})
submission = submission.sort_values('id')
submission.to_csv('submission.csv', index=False)
submission.head()