In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.optim import Adam
from torch import nn, tensor
import torch.nn.functional as F
import torchaudio
from torchaudio.transforms import MelSpectrogram
from torch.utils.data import Dataset, DataLoader
import os
from tqdm.auto import tqdm
from sklearn.metrics import confusion_matrix
from voice_feedback.audio import AudioFile
import numpy as np
from pickle_cache import PickleCache
from voice_feedback.utils import two_buttons, iter_sequence, new_item_path
from IPython.display import display, clear_output


pcache = PickleCache()

In [28]:
def ls(path):
    return [os.path.join(path, d) for d in os.listdir(path)]

class AudioDataset(Dataset):
    def __init__(self, dir_path):
        good = ls(f'{dir_path}/good')
        bad = ls(f'{dir_path}/bad')
        self.files = good + bad
        self.labels = torch.cat([
            torch.zeros((len(good),), dtype=torch.long),
            torch.ones((len(bad),), dtype=torch.long)
        ])
        _, rate = torchaudio.load(self.files[0])
        self.spectrogram = MelSpectrogram(rate)
        bad_proportion = len(bad) / len(self.files)
        self.weights = tensor([bad_proportion, 1.-bad_proportion])
        self.samples = [
            self.spectrogram(torchaudio.load(f)[0][0])
            for f in self.files
        ]

    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, k):
        return {
            'samples': self.samples[k],
            'labels': self.labels[k]
        }

In [29]:
dataset = AudioDataset('../data/stutter')
loader = DataLoader(dataset, batch_size=128, shuffle=True)

In [90]:
class AudioClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=3),
            nn.Dropout(p=0.1),
            
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=4),
            nn.Dropout(p=0.1),
            
            nn.Conv2d(64, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=4),
            nn.Dropout(p=0.1)            
        )
        
        self.fc = nn.Sequential(
            nn.Linear(768, 2),
            nn.LogSoftmax()
        )
        
    def forward(self, x):
        conv_output = self.layers(x.unsqueeze(dim=1)).squeeze(dim=1)
        return self.fc(torch.flatten(conv_output, 1)).squeeze(dim=1)
    
    def predict(self, x):
        return torch.argmax(self(x), dim=1)
    
    def confusion(self, loader):
        self.eval()
        
        all_y_pred = []
        all_y_true = []
        for batch in loader:
            all_y_pred.append(model.predict(batch['samples'].to('cuda')).to('cpu'))
            all_y_true.append(batch['labels'])        
            
        cmat = confusion_matrix(
            torch.cat(all_y_true).detach().numpy(), 
            torch.cat(all_y_pred).detach().numpy())
        
        self.train()
        return torch.from_numpy(cmat).double()

In [91]:
model = AudioClassifier()
#model = torch.load('../data/checkpoints/second.pt')
#model = model.to('cpu')
model = model.to('cuda')

In [92]:
model.train()
adam = Adam(model.parameters(), lr=0.001)
loss_fn = nn.NLLLoss(weight=dataset.weights)

for i in tqdm(range(100)):
    for batch in loader:
        y_pred = model(batch['samples'].to('cuda'))        
        loss = loss_fn(y_pred.to('cpu'), batch['labels'])
        adam.zero_grad()
        loss.backward()
        adam.step()   
        
    confusion = model.confusion(loader)
    confusion = F.normalize(confusion, p=1, dim=1)
    if torch.all(torch.diagonal(confusion) > 0.9):
        print(i, confusion)
        break
    
    if i % 5 == 0:
        print(i, confusion)
    
#     y_bad = [1,:]
#     y_bad = y_bad / np.sum(y_bad)
#     if y_bad[1] >= .8:
#         break

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

0 tensor([[0.0045, 0.9955],
        [0.0000, 1.0000]], dtype=torch.float64)
5 tensor([[0.6964, 0.3036],
        [0.0400, 0.9600]], dtype=torch.float64)
10 tensor([[0.8929, 0.1071],
        [0.1200, 0.8800]], dtype=torch.float64)
12 tensor([[0.9152, 0.0848],
        [0.0800, 0.9200]], dtype=torch.float64)



In [None]:
torch.save(model, '../data/checkpoints/second.pt')

In [93]:
zoom = AudioFile('../data/zoom.wav')

In [100]:
total_duration = len(zoom.wav) / zoom.rate
spec = MelSpectrogram(zoom.rate)
intvls, clips = [], []
clip_duration = 2.
clip_samples = int(zoom.rate * clip_duration)
me_segs = pcache.get('me_segs')

for seg in me_segs:
    for start in np.arange(seg.start, seg.end, 1.):
        end = min(start + clip_duration, total_duration)
        clip = zoom.interval(start, end)
        intvls.append((start, end))
        t = tensor(clip, dtype=torch.float32)
        clips.append(F.pad(t, (0, clip_samples - len(t))))
        
samples = torch.stack(clips)

In [109]:
model = model.to('cpu')
#model.eval()
x = spec(samples) #.to('cuda')
y_pred = []
for i in range(0, len(samples), 100):
    x2 = x[i:i+100]
    y_pred.append(model(x2).to('cpu'))
y_pred = torch.cat(y_pred)
idxs = torch.argsort(y_pred[:, 1], descending=True)
#_ = model.train()

In [113]:
for i in idxs[:20]:
    print(torch.exp(y_pred[i]))
    display(zoom.display(zoom.interval(*intvls[i])))

tensor([5.4699e-04, 9.9945e-01], grad_fn=<ExpBackward>)


tensor([6.7783e-04, 9.9932e-01], grad_fn=<ExpBackward>)


tensor([0.0010, 0.9990], grad_fn=<ExpBackward>)


tensor([0.0016, 0.9984], grad_fn=<ExpBackward>)


tensor([0.0050, 0.9950], grad_fn=<ExpBackward>)


tensor([0.0051, 0.9949], grad_fn=<ExpBackward>)


tensor([0.0057, 0.9943], grad_fn=<ExpBackward>)


tensor([0.0057, 0.9943], grad_fn=<ExpBackward>)


tensor([0.0061, 0.9939], grad_fn=<ExpBackward>)


tensor([0.0070, 0.9930], grad_fn=<ExpBackward>)


tensor([0.0075, 0.9925], grad_fn=<ExpBackward>)


tensor([0.0085, 0.9915], grad_fn=<ExpBackward>)


tensor([0.0094, 0.9906], grad_fn=<ExpBackward>)


tensor([0.0106, 0.9894], grad_fn=<ExpBackward>)


tensor([0.0117, 0.9883], grad_fn=<ExpBackward>)


tensor([0.0146, 0.9854], grad_fn=<ExpBackward>)


tensor([0.0159, 0.9841], grad_fn=<ExpBackward>)


tensor([0.0182, 0.9818], grad_fn=<ExpBackward>)


tensor([0.0199, 0.9801], grad_fn=<ExpBackward>)


tensor([0.0212, 0.9788], grad_fn=<ExpBackward>)


In [None]:
def prompt(i, next_panel):
    intvl = zoom.interval(*intvls[i])
    display(zoom.display(intvl, autoplay=True))
    
    def callback(is_bad):
        zoom.write(new_item_path('../data/stutter', 'bad' if is_bad else 'good'), intvl)
        next_panel()
    
    print(y_pred[i])
    two_buttons(yes='Bad', no='Good', callback=callback)

iter_sequence(idxs[20:40], prompt)