In [1]:
from models.siamese import SiameseNetwork
from dataset.triplets import generate_dataset, TripletFaceDataset

import torch
from torch.optim import SGD
from torch.nn import BCELoss
from torch.utils.data import DataLoader
from torchvision import transforms

In [2]:
data_folder_path = r'F:\ML Data\105_classes_pins_dataset'
data_df = generate_dataset(data_folder_path)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((160, 160))
])
dataset = TripletFaceDataset(
    triplets_dataframe=data_df,
    weights_path="models/TrainedWeights",
    transform=transform
)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=True
)

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = SiameseNetwork().to(device)
optimizer = SGD(model.parameters(), lr=0.01)
criterion = BCELoss()

In [4]:
def train_step(batch):
    model.train()
    model.encoder.eval()
    optimizer.zero_grad()
    anchor, positive, negative = batch[0].to(device), batch[1].to(device), batch[2].to(device)
    y_pos = model(anchor, positive)
    y_pos_true = torch.tensor([1], dtype=torch.float)
    loss_p = criterion(y_pos, y_pos_true)
    loss_p.backward()
    optimizer.step()
    y_neg = model(anchor, negative)
    y_neg_true = torch.tensor([0], dtype=torch.float)
    loss_n = criterion(y_neg, y_neg_true)
    loss_n.backward()
    optimizer.step()
    return (loss_p.item() + loss_n.item())/2

def validation_step(batch):
    model.eval()
    with torch.no_grad():
        anchor, positive, negative = batch[0].to(device), batch[1].to(device), batch[2].to(device)
        y_p = model(anchor, positive)
        y_n = model(anchor, negative)
        return y_p.item(), y_n.item()

In [5]:
running_loss = 0.0
for i, batch in enumerate(dataloader):
    current_loss = train_step(batch)
    running_loss += current_loss
    if i%10 == 9:
        print(f"Loss: {running_loss/10}")
        running_loss = 0

Loss: 0.695321723818779
Loss: 0.6964830219745636
Loss: 0.6980624437332154
Loss: 0.7000533699989319


KeyboardInterrupt: 