In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from src.evaluation import evaluate_anmrr
from typing import List, Dict, Any

In [None]:
class TripletRetriever(pl.LightningModule):
    """Embedding model for image retrieval, trained with a triplet margin loss.

    A torchvision backbone is loaded from ``torch.hub`` with pretrained
    weights. Its ``fc`` head is replaced by a linear projection to
    ``last_layer_size`` dimensions, and the parameters of the first nine
    top-level children (indices 0-8) are frozen. For torchvision ResNets this
    presumably leaves only the new ``fc`` head trainable; confirm for other
    backbones.

    Batches are dicts with keys ``'a'`` (anchors), ``'p'`` (positives) and
    ``'n'`` (negatives). Each value is fed directly to the backbone.

    Args:
        model_name: ``pytorch/vision`` hub entry point, e.g. ``"resnet18"``.
            The model must expose a ``torch.nn.Linear`` attribute named ``fc``.
        last_layer_size: dimensionality of the output embedding.
        learning_rate: Adam learning rate (1e-3 matches ``torch.optim.Adam``'s
            default).
        weight_decay: Adam weight decay.
    """

    def __init__(self, model_name: str, last_layer_size: int = 100,
                 learning_rate: float = 1e-3, weight_decay: float = 1e-5):
        super().__init__()
        # Record constructor args so checkpoints can be restored with
        # ``load_from_checkpoint`` and the values are sent to the logger.
        self.save_hyperparameters()
        self.model = torch.hub.load('pytorch/vision', model_name, pretrained=True)
        # Take the head's input width from the backbone instead of hard-coding
        # 512 (resnet18/34 only), so wider backbones such as resnet50 work too.
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, last_layer_size, bias=True)
        self.set_training_model_layers(False, 8)
        self.criterion = torch.nn.TripletMarginLoss()

    def forward(self, x):
        """Return the embedding the backbone produces for ``x``."""
        return self.model(x)

    def _triplet_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Embed the anchor, positive and negative images of ``batch`` and return the triplet loss."""
        anchors = self(batch['a'])
        positives = self(batch['p'])
        negatives = self(batch['n'])
        return self.criterion(anchors, positives, negatives)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the triplet loss of one batch."""
        loss = self._triplet_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``valid_loss``) and return the triplet loss of one batch."""
        loss = self._triplet_loss(batch)
        self.log('valid_loss', loss)
        return loss

    def configure_optimizers(self):
        """Create an Adam optimizer over all backbone parameters.

        Frozen parameters are deliberately included. They receive no gradient,
        so Adam leaves them untouched. If they are later unfrozen via
        ``set_training_model_layers``, this optimizer will start updating them.
        """
        return torch.optim.Adam(
            self.model.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )

    def set_training_model_layers(self, training: bool, up_to_index: int):
        """Set ``requires_grad`` on the leading top-level children of the backbone.

        Args:
            training: value assigned to ``requires_grad`` (``False`` freezes).
            up_to_index: index of the last child to update, inclusive. Later
                children are left untouched; a negative value updates none.
        """
        # NOTE(review): freezing parameters does not stop BatchNorm layers from
        # updating their running statistics in train mode; put them in
        # ``.eval()`` as well if they should be fully frozen.
        for index, child in enumerate(self.model.children()):
            if index > up_to_index:
                break
            child.requires_grad_(training)

In [None]:
from src.data.ucmerced_dataset import UcMercedDataset
from src.settings import TRAIN_DATA_DIRECTORY, TEST_DATA_DIRECTORY

# Run configuration: the knobs a reader is most likely to tune.
SEED = 42
image_size = 224
TRAIN_BATCH_SIZE = 80
TEST_BATCH_SIZE = 100
NUM_WORKERS = 10
MAX_EPOCHS = 2

# Seed the Python, NumPy and torch RNGs so that loader shuffling, any
# randomness inside the dataset, and the freshly initialised ``fc`` head are
# reproducible across Restart & Run All.
pl.seed_everything(SEED)

train_dataset = UcMercedDataset(TRAIN_DATA_DIRECTORY, image_size, train=True)
test_dataset = UcMercedDataset(TEST_DATA_DIRECTORY, image_size, train=False)

train_dataloader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# NOTE(review): the test split doubles as the validation set passed to
# ``trainer.fit``. Any model selection based on ``valid_loss`` therefore looks
# at test data; consider carving a validation split out of the training data.
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=TEST_BATCH_SIZE, num_workers=NUM_WORKERS)

triplet_retriever = TripletRetriever("resnet18")
wandb_logger = WandbLogger('uc_merced_100_1', project='triplet_retrieval')
trainer = pl.Trainer(max_epochs=MAX_EPOCHS, gpus=1, logger=wandb_logger)

In [None]:

# Train the retriever; the second loader is evaluated as the validation set
# (logged as ``valid_loss``) during fitting.
trainer.fit(
    triplet_retriever,
    train_dataloader,
    test_dataloader,
)

In [None]:
# Flush the logger first, then end the W&B run explicitly. Otherwise the run
# stays active in the kernel, and a WandbLogger created later in the same
# session would reuse it instead of starting a fresh run.
wandb_logger.save()
wandb_logger.experiment.finish()