In [1]:
import logging
import os
import time
from IPython.display import Audio

from IDRnD.utils import *
from IDRnD.augmentations import *
from IDRnD.dataset import *
from IDRnD.resnet import *
from IDRnD.nasnet_mobile import NASNetAMobile
from IDRnD.focalloss import FocalLoss
from IDRnD.callbacks import *
from IDRnD.pipeline import *

import numpy as np
import torch
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR, ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms

from sklearn.model_selection import StratifiedKFold, train_test_split

%reload_ext autoreload
%autoreload 2
%matplotlib inline

# Seed the RNGs through the project helper before any data/model code runs.
# (Which RNGs it covers is defined in the IDRnD package, not in this notebook.)
seed_everything(0)

# basicConfig with `filename=` creates a FileHandler that opens the file right
# away, so a missing "logs/" directory would raise FileNotFoundError on a
# fresh checkout. Create the directory first; this is a no-op when it exists.
os.makedirs("logs", exist_ok=True)
logging.basicConfig(level=logging.DEBUG, filename="logs/logs.log",
                    filemode="w+")

In [2]:
# Load the training set and hold out 20% of it, stratified on the labels.
X, y = get_train_data()
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# mask.npy selects the held-out rows that are moved back into training.
# NOTE(review): presumably a boolean array with one entry per validation row,
# tied to the random_state=42 split above — confirm.
mask = np.load("IDRnD/data/mask.npy")
not_mask = np.invert(mask)

X_good, y_good = X_val[not_mask], y_val[not_mask]
X_bad, y_bad = X_val[mask], y_val[mask]

# Training grows by the masked rows; validation keeps only the rest.
X_train_new = np.concatenate((X_train, X_bad))
y_train_new = np.concatenate((y_train, y_bad))
X_val_new, y_val_new = X_good, y_good

In [3]:
# Per-sample preprocessing handed to the datasets; Compose applies the steps
# in list order.
# NOTE(review): `librosa` is not imported by name in this notebook — it is
# presumably re-exported by one of the `from IDRnD... import *` lines; confirm.
post_transform_steps = [
    librosa.power_to_db,
    PadOrClip(320),  # presumably fixes the time axis to 320 frames — TODO confirm
    Normalize_predef(-29.6179, 16.6342),  # presumably precomputed mean / std — TODO confirm
    ToTensor(),  # the name provided by the star imports above, not transforms.ToTensor
]
post_transform = transforms.Compose(post_transform_steps)

In [4]:
# Training setup: datasets, loaders, model, loss and optimizer.
batch_size = 125
mel_dir = "../data/files/raw_mels/"
loader_workers = 16

# Both datasets read from the same mel directory and share one transform.
train_dataset = SimpleMelDataset(X_train_new, y_train_new, mel_dir, post_transform)
valid_dataset = SimpleMelDataset(X_val_new, y_val_new, mel_dir, post_transform)

# Validation is not shuffled and uses a batch three times the training size.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=loader_workers,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=3 * batch_size,
    shuffle=False,
    num_workers=loader_workers,
)

# Single-output NASNet-A Mobile, replicated across GPUs 0 and 1.
# (To resume from a checkpoint, load a state dict into `model` before wrapping.)
model = NASNetAMobile(num_classes=1).cuda()
model_dst = torch.nn.DataParallel(model, device_ids=[0, 1]).cuda()

# Focal loss is the active criterion (nn.BCEWithLogitsLoss was the alternative tried).
criterion = FocalLoss(gamma=2, reduce='mean').cuda()
# The optimizer is built from the unwrapped model's parameters.
optimizer = Adam(params=model.parameters(), lr=1e-4)

In [5]:
# Callbacks for the training loop.
tb_logger = TensorBoardCallback(compute_eer)
# NOTE(review): `saver` is constructed but never passed to Train below, and its
# path still says resnet_34 although the model here is NASNetAMobile — confirm
# whether it is intentionally unused.
saver = SaveEveryEpoch("models/resnet_34_better_val.pt")
# NOTE(review): the list presumably holds epoch milestones for gradient
# accumulation — confirm against AccumulateGradient.
acumulator = AccumulateGradient([30, 50, 70, 100])
best = SaveBestEpoch("models/nasnet_mobile.pt", compute_eer)

# ReduceLROnPlateau is wrapped so the Train loop can drive it as a callback.
scheduler = ReduceLROnPlateau(optimizer, patience=1, verbose=True)
scheduler_call = EpochScheduler(scheduler)

train_callbacks = [tb_logger, acumulator, best, scheduler_call]
hm = Train(callbacks=train_callbacks)

# Train the DataParallel-wrapped model for up to 150 epochs.
hm.fit(
    train_loader,
    valid_loader,
    model_dst,
    criterion,
    optimizer,
    epoches=150,
)

torch.Size([63, 7392])
torch.Size([62, 7392])




torch.Size([63, 7392])
torch.Size([62, 7392])
torch.Size([62, 7392])
torch.Size([63, 7392])
torch.Size([63, 7392])
torch.Size([62, 7392])
torch.Size([62, 7392])
torch.Size([63, 7392])
torch.Size([63, 7392])
torch.Size([62, 7392])
torch.Size([62, 7392])
torch.Size([63, 7392])
torch.Size([63, 7392])
torch.Size([62, 7392])
torch.Size([62, 7392])
torch.Size([63, 7392])


Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/opt/conda/lib/python3.6/multiprocessing/

KeyboardInterrupt: 