In [64]:
# %load_ext autoreload
%autoreload 2
import torch
from datasets import TripletAudio

K, MAX_CLOSE_NEG, MAX_FAR_NEG = 5, 15, 15
BATCH_SIZE = 1

triplet_train_dataset = TripletAudio(True, K, MAX_CLOSE_NEG, MAX_FAR_NEG)
triplet_test_dataset = TripletAudio(False, K, MAX_CLOSE_NEG, MAX_FAR_NEG)
triplet_train_loader = torch.utils.data.DataLoader(triplet_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
triplet_test_loader = torch.utils.data.DataLoader(triplet_test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [169]:
%autoreload 2
# Set up the network and training parameters
from networks import AnchorNet, EmbeddingNet, TripletNet
from losses import TripletLoss
import torch.optim as optim
from torch.optim import lr_scheduler

INPUT_D, OUTPUT_D = 192, 128
MARGIN, LEARNING_RATE, N_EPOCHS, LOG_INT = 0.5, 1e-3, 5, 1000

#define model
anchor_net = AnchorNet(triplet_train_dataset.get_dataset(), INPUT_D, OUTPUT_D)
embedding_net = EmbeddingNet(anchor_net)
model = TripletNet(embedding_net)

loss_fn = TripletLoss(MARGIN)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = lr_scheduler.StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)

In [170]:
%autoreload 2
from trainer import fit
fit(triplet_train_loader, triplet_test_loader, model, loss_fn, optimizer, scheduler, N_EPOCHS, {}, LOG_INT)

Epoch: 1/5. Train set: Average loss: 1.4918
Epoch: 1/5. Validation set: Average loss: 0.7237
Epoch: 2/5. Train set: Average loss: 0.4837
Epoch: 2/5. Validation set: Average loss: 0.3618
Epoch: 3/5. Train set: Average loss: 0.3329
Epoch: 3/5. Validation set: Average loss: 0.2764
Epoch: 4/5. Train set: Average loss: 0.2873
Epoch: 4/5. Validation set: Average loss: 0.2419
Epoch: 5/5. Train set: Average loss: 0.2637
Epoch: 5/5. Validation set: Average loss: 0.2237


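### Online Triplet Selection

Instead of sampling fixed triplets up front with `TripletAudio`, the cell below draws class-balanced batches with `BalancedBatchSampler` (25 classes × `K` samples per batch) so that triplets can be selected online from within each batch.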

In [None]:
# %load_ext autoreload
%autoreload 2
import torch
from datasets import TripletAudio

train_dataset = AudioTrainDataset()
test_dataset = AudioTestDataset()

train_batch_sampler = BalancedBatchSampler(train_dataset.train_labels, n_classes=25, n_samples=K)
test_batch_sampler = BalancedBatchSampler(test_dataset.test_labels, n_classes=25, n_samples=K)

online_train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_batch_sampler)
online_test_loader = torch.utils.data.DataLoader(test_dataset, batch_sampler=test_batch_sampler)