In [227]:
# %load_ext autoreload
%autoreload 2
import torch
from datasets import TripletAudio

K, MAX_CLOSE_NEG, MAX_FAR_NEG = 5, 15, 15
BATCH_SIZE = 1

triplet_train_dataset = TripletAudio(True, K, MAX_CLOSE_NEG, MAX_FAR_NEG)
triplet_test_dataset = TripletAudio(False, K, MAX_CLOSE_NEG, MAX_FAR_NEG)
triplet_train_loader = torch.utils.data.DataLoader(triplet_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
triplet_test_loader = torch.utils.data.DataLoader(triplet_test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [261]:
%autoreload 2
# Set up the network and training parameters
from networks import AnchorNet, EmbeddingNet, TripletNet
from losses import TripletLoss
import torch.optim as optim
from torch.optim import lr_scheduler

INPUT_D, OUTPUT_D = 192, 128
MARGIN, LEARNING_RATE, N_EPOCHS, LOG_INT = 0.5, 1e-3, 5, 1000

#define model
anchor_net = AnchorNet(triplet_train_dataset.get_dataset(), INPUT_D, OUTPUT_D)
embedding_net = EmbeddingNet(anchor_net)
model = TripletNet(embedding_net)

loss_fn = TripletLoss(MARGIN)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = lr_scheduler.StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)

In [263]:
%autoreload 2
from trainer import fit
fit(triplet_train_loader, triplet_test_loader, model, loss_fn, optimizer, scheduler, N_EPOCHS, {}, LOG_INT)



KeyboardInterrupt: 

### Online Selection

In [358]:
# %load_ext autoreload
%autoreload 2
import torch
from datasets import AudioTrainDataset, AudioTestDataset
from datasets import BalancedBatchSampler

K = 5
train_dataset = AudioTrainDataset(K)
test_dataset = AudioTestDataset(K)

train_batch_sampler = BalancedBatchSampler(train_dataset)
test_batch_sampler = BalancedBatchSampler(test_dataset)

online_train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_batch_sampler)
online_test_loader = torch.utils.data.DataLoader(test_dataset, batch_sampler=test_batch_sampler)

In [355]:
# Loss, optimizer and LR schedule for training with in-batch triplet selection.
from networks import EmbeddingNet, AnchorNet
from losses import OnlineTripletLoss
from utils import SemihardNegativeTripletSelector  # strategy for picking triplets within a minibatch
from metrics import AverageNonzeroTripletsMetric

# Epoch count and logging interval for this stage (rebinds the notebook-level names).
N_EPOCHS = 20
LOG_INT = 25

# This cell does not define or import `embedding_net`, `MARGIN`, `LEARNING_RATE`,
# `optim` or `lr_scheduler`; they must already exist in the kernel namespace.
model = embedding_net
loss_fn = OnlineTripletLoss(
    MARGIN,
    SemihardNegativeTripletSelector(MARGIN, train_dataset.KNN),
)
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-4,
)
scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1, last_epoch=-1)

In [362]:
# Run training with the batch-sampled loaders, reporting AverageNonzeroTripletsMetric.
# `fit` is not imported in this cell; it must already be bound in the kernel namespace.
fit(
    online_train_loader,
    online_test_loader,
    model,
    loss_fn,
    optimizer,
    scheduler,
    N_EPOCHS,
    {},
    LOG_INT,
    metrics=[AverageNonzeroTripletsMetric()],
)

len of triplets 425
len of triplets 543
len of triplets 525
len of triplets 394
len of triplets 443
len of triplets 430
len of triplets 467
len of triplets 389


KeyboardInterrupt: 