In [None]:
# Install required libraries and their versions

# The "+cu116" builds of torch/torchvision are not published on PyPI; they are
# served by the PyTorch wheel index, so pip has to be pointed at it explicitly.
# Both packages are installed in one command so that pip resolves them together.
!pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
!pip install faiss-cpu==1.7.3
!pip install pytorch-lightning==1.9.4
!pip install pytorch-metric-learning==2.0.1
!pip install opencv-python==4.7.0.72
!pip install scikit-image==0.19.3
!conda install -y gdown

In [None]:
# Libraries

import os
import faiss
import torch
import numpy as np
import torch.nn as nn
import torchvision.models
import pytorch_lightning as pl
from typing import Tuple
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from pytorch_metric_learning import losses
from torchvision import transforms as tfm
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers
from glob import glob
from sklearn.neighbors import NearestNeighbors
from collections import defaultdict

# Google libraries

import requests
import zipfile
import gdown

In [None]:
# Download the dataset archives from Google Drive, extract them, then delete the archives

# sheewa dataset

# zip_url_tokyo = 'https://drive.google.com/file/d/1FUSHfFUMcPyXMbzt9hKUoYoLZ_uEHULP/view?usp=drive_link' # tokyo
# zip_url_sf = 'https://drive.google.com/file/d/10wCmksu4w1uMRnmDkTPvuVWq_u5nqHEQ/view?usp=drive_link' # sf
# zip_url_gsv = 'https://drive.google.com/file/d/10QHcLRefihtVIFuLMxcnxDSi_6E4MBJs/view?usp=drive_link' # gsv

# prof dataset

# Google Drive share links of the three zipped datasets.
zip_url_tokyo = 'https://drive.google.com/file/d/15QB3VNKj93027UAQWv7pzFQO1JDCdZj2/view?usp=drive_link' # tokyo
zip_url_sf = 'https://drive.google.com/file/d/1tQqEyt3go3vMh4fj_LZrRcahoTbzzH-y/view?usp=drive_link' # sf
zip_url_gsv = 'https://drive.google.com/file/d/1q7usSe9_5xV5zTfN-1In4DlmF5ReyU_A/view?usp=drive_link' # gsv


# Folder that receives the downloaded archives and their extracted content.
destination_folder = './dataset'

# exist_ok=True replaces the check-then-create pair: it is race-free and keeps
# the cell safe to re-run when the folder already exists.
os.makedirs(destination_folder, exist_ok=True)
    
def download_and_extract(zip_url, destination_folder):
    """Download a zipped dataset from Google Drive and unpack it.

    Args:
        zip_url: Google Drive share link of the form
            ``https://drive.google.com/file/d/<file_id>/view?...``.
        destination_folder: Directory that receives the temporary archive and
            the extracted content. It is created if it does not exist yet.

    Raises:
        RuntimeError: If gdown reports a failed download by returning ``None``.
    """
    # The file ID is the second-to-last "/"-separated component of the share link
    file_id = zip_url.split('/')[-2]

    # Create the direct download URL
    download_url = f'https://drive.google.com/uc?id={file_id}'

    os.makedirs(destination_folder, exist_ok=True)
    output_file = os.path.join(destination_folder, f'{file_id}.zip')

    try:
        downloaded = gdown.download(download_url, output_file, quiet=False, fuzzy=True)
        if downloaded is None:
            # Fail here with a clear message instead of a confusing zipfile error below
            raise RuntimeError(f"Download failed for: {download_url}")

        with zipfile.ZipFile(output_file, 'r') as zip_ref:
            zip_ref.extractall(destination_folder)
    finally:
        # Remove the archive even when the download or the extraction failed,
        # so no partial/corrupt zip is left in the dataset folder
        if os.path.exists(output_file):
            os.remove(output_file)

    print(f"Extraction completed for: {file_id}.zip")
    

# Fetch the three archives one after the other: Tokyo, San Francisco, GSV.
for zip_url in (zip_url_tokyo, zip_url_sf, zip_url_gsv):
    download_and_extract(zip_url, destination_folder)

In [None]:
# Transformation + Converts to RGB Format

def open_image(path):
    """Open the image stored at ``path`` and return it converted to RGB mode."""
    image = Image.open(path)
    return image.convert("RGB")

# Per-channel mean / standard deviation used to standardise the RGB tensors.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing shared by the training and evaluation datasets:
# PIL image -> tensor -> channel-wise normalisation.
_to_tensor = tfm.ToTensor()
_normalize = tfm.Normalize(mean=_NORM_MEAN, std=_NORM_STD)
transform = tfm.Compose([_to_tensor, _normalize])

In [None]:
# Dataset for Training

class TrainDataset(Dataset):
    """
    Training dataset that yields, per item, several images of one place.

    All ``*.jpg`` files below ``dataset_folder`` are grouped by place id, which is
    read from the second-to-last "@"-separated field of the image path. Places
    with fewer than ``min_img_per_place`` images are discarded.

    Args:
        dataset_folder: Root folder searched recursively for ``*.jpg`` files.
        transform: Callable applied to every loaded RGB PIL image.
        img_per_place: Number of images returned per place.
        min_img_per_place: Minimum number of images a place needs to be kept;
            must be at least ``img_per_place``.
    """

    def __init__(self, dataset_folder, transform, img_per_place=4, min_img_per_place=4):
        super().__init__()
        self.dataset_folder = dataset_folder
        self.images_paths = sorted(glob(f"{dataset_folder}/**/*.jpg", recursive=True))
        self.dict_place_paths = defaultdict(list)

        for image_path in self.images_paths:
            # The place id is the second-to-last "@"-separated field of the path
            place_id = image_path.split("@")[-2]
            self.dict_place_paths[place_id].append(image_path)

        assert img_per_place <= min_img_per_place, \
            f"img_per_place should be at most {min_img_per_place}"
        self.img_per_place = img_per_place
        self.transform = transform

        # Keep only places depicted by at least min_img_per_place images.
        # list(...) takes a snapshot of the keys so entries can be deleted while looping.
        for place_id in list(self.dict_place_paths.keys()):
            if len(self.dict_place_paths[place_id]) < min_img_per_place:
                del self.dict_place_paths[place_id]
        self.places_ids = sorted(self.dict_place_paths.keys())
        self.total_num_images = sum(len(paths) for paths in self.dict_place_paths.values())

    def _sample_paths(self, index):
        """Randomly pick ``img_per_place`` distinct image paths of place ``index``."""
        place_id = self.places_ids[index]
        all_paths_from_place_id = self.dict_place_paths[place_id]
        # replace=False: every kept place has >= min_img_per_place >= img_per_place
        # images, so distinct images are always available. Sampling with
        # replacement could return the same image several times for one place.
        return np.random.choice(all_paths_from_place_id, self.img_per_place, replace=False)

    def __getitem__(self, index):
        """Return ``(images, labels)`` for the place at position ``index``.

        ``images`` stacks ``img_per_place`` transformed images along a new first
        dimension; ``labels`` repeats ``index`` ``img_per_place`` times, so the
        dataset index acts as the class label of the place.
        """
        chosen_paths = self._sample_paths(index)
        images = [self.transform(Image.open(path).convert('RGB')) for path in chosen_paths]
        return torch.stack(images), torch.tensor(index).repeat(self.img_per_place)

    def __len__(self):
        """Number of places (not images) in the dataset."""
        return len(self.places_ids)

In [None]:
# DataLoader

# Build the path from the folder the archives were extracted to, instead of a
# hard-coded Kaggle path, so the notebook also runs outside /kaggle/working.
train_dataset = TrainDataset(os.path.join(destination_folder, 'gsv_xs', 'train'), transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, num_workers=2, shuffle=True)
print(f"Train dataset: {len(train_dataset)}")
# print(train_dataset.total_num_images)

In [None]:
# Dataset for Testing

class TestDataset(Dataset):
    """Evaluation dataset made of database images followed by query images.

    Items ``0 .. database_num - 1`` are database images and the remaining
    ``queries_num`` items are queries. For every query, the database images whose
    UTM coordinates lie within ``positive_dist_threshold`` are its positives.

    Args:
        dataset_folder: Folder containing the database and queries sub-folders.
        database_folder: Name of the database sub-folder.
        queries_folder: Name of the queries sub-folder.
        positive_dist_threshold: Radius, in the units of the coordinates encoded
            in the file names, within which a database image counts as positive.

    Raises:
        ValueError: If no ``*.jpg`` file is found for the database or the queries.
    """

    def __init__(self, dataset_folder, database_folder="database", queries_folder="queries", positive_dist_threshold=25):
        super().__init__()
        self.dataset_folder = dataset_folder
        self.database_folder = os.path.join(dataset_folder, database_folder)
        self.queries_folder = os.path.join(dataset_folder, queries_folder)
        self.dataset_name = os.path.basename(dataset_folder)
        self.database_paths = sorted(glob(os.path.join(self.database_folder, "**", "*.jpg"), recursive=True))
        self.queries_paths = sorted(glob(os.path.join(self.queries_folder, "**", "*.jpg"), recursive=True))

        # Fail early with a readable message rather than inside the k-NN search
        if not self.database_paths:
            raise ValueError(f"No database images found in {self.database_folder}")
        if not self.queries_paths:
            raise ValueError(f"No query images found in {self.queries_folder}")

        # Extract the UTM coordinates encoded in the file names
        self.database_utms = self._utms_from_paths(self.database_paths)
        self.queries_utms = self._utms_from_paths(self.queries_paths)

        # Find positives_per_query, which are within positive_dist_threshold
        knn = NearestNeighbors(n_jobs=-1)
        knn.fit(self.database_utms)
        self.positives_per_query = knn.radius_neighbors(self.queries_utms, radius=positive_dist_threshold, return_distance=False)

        # Database images first, then queries: callers rely on this order to
        # split the concatenated descriptors at database_num
        self.images_paths = self.database_paths + self.queries_paths

        self.database_num = len(self.database_paths)
        self.queries_num = len(self.queries_paths)

    @staticmethod
    def _utms_from_paths(paths):
        """Return an ``(N, 2)`` float array with fields 1 and 2 of each file name.

        Only the file name is split on "@", so an "@" appearing in a directory
        name cannot shift the field positions.
        """
        fields = [os.path.basename(path).split("@") for path in paths]
        return np.array([(parts[1], parts[2]) for parts in fields]).astype(float)

    def __getitem__(self, index):
        """Return the normalised image at ``index`` together with ``index`` itself."""
        image_path = self.images_paths[index]
        pil_img = open_image(image_path)
        # Uses the module-level `transform` pipeline
        normalized_img = transform(pil_img)
        return normalized_img, index

    def __len__(self):
        """Total number of images (database + queries)."""
        return len(self.images_paths)

    def __repr__(self):
        return f"< {self.dataset_name} - #q: {self.queries_num}; #db: {self.database_num} >"

    def get_positives(self):
        """Return, for every query, the indices of its positive database images."""
        return self.positives_per_query

In [None]:
# DataLoader Test SF and val SF + Test Tokyo

# Paths are built from the folder the archives were extracted to, instead of a
# hard-coded Kaggle path, so the notebook also runs outside /kaggle/working.
sf_val_dataset = TestDataset(os.path.join(destination_folder, 'sf_xs', 'val'))
sf_test_dataset = TestDataset(os.path.join(destination_folder, 'sf_xs', 'test'))
tokyo_test_dataset = TestDataset(os.path.join(destination_folder, 'tokyo_xs', 'test'))

# shuffle=False: the evaluation code relies on database images coming before queries
val_loader = DataLoader(dataset=sf_val_dataset, batch_size=64, num_workers=2, shuffle=False)
test_loader = DataLoader(dataset=sf_test_dataset, batch_size=64, num_workers=2, shuffle=False)
tokyo_loader = DataLoader(dataset=tokyo_test_dataset, batch_size=64, num_workers=2, shuffle=False)

print(f"val sf: {len(sf_val_dataset)}")
print(f"test sf: {len(sf_test_dataset)}")
print(f"test tokyo: {len(tokyo_test_dataset)}")

In [None]:
# GeM layer

class GeM(nn.Module):
    """Generalized-mean pooling over the last two dimensions of the input.

    Values are clamped to at least ``eps``, raised to the power ``p``, averaged
    over the whole spatial extent and finally raised to ``1 / p``. The exponent
    ``p`` is a learnable one-element parameter initialised to the given value.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.p = nn.Parameter(torch.ones(1) * p)

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        """Apply GeM pooling with exponent ``p``; the two pooled dimensions become size 1."""
        height, width = x.size(-2), x.size(-1)
        powered = x.clamp(min=eps).pow(p)
        # A kernel as large as the feature map yields a single average per channel
        pooled = nn.functional.avg_pool2d(powered, (height, width))
        return pooled.pow(1.0 / p)

    def __repr__(self):
        p_value = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={p_value:.4f}, eps={self.eps})"

In [None]:
# Compute Recall

def compute_recalls(eval_ds: Dataset, queries_descriptors : np.ndarray, database_descriptors : np.ndarray,
                    recall_values: Tuple[int, ...] = (1, 5)):
    """Compute recall@N of retrieving database images for every query.

    A query counts as a hit at N when at least one of its N nearest database
    descriptors (L2 distance) is among its positives.

    Args:
        eval_ds: Dataset exposing ``get_positives()`` (positive database indices
            per query) and ``queries_num``.
        queries_descriptors: Array of shape ``(num_queries, dim)``.
        database_descriptors: Array of shape ``(num_database, dim)``.
        recall_values: Cut-offs N to evaluate; they are processed in ascending order.

    Returns:
        Tuple ``(recalls, recalls_str)``: the recalls in percent, one per cut-off in
        ascending order, and a formatted string such as ``"R@1: 50.0, R@5: 100.0"``.
    """
    # Ascending order is required by the cumulative update `recalls[i:] += 1` below
    recall_values = sorted(recall_values)

    # faiss operates on C-contiguous float32 arrays
    queries_descriptors = np.ascontiguousarray(queries_descriptors, dtype=np.float32)
    database_descriptors = np.ascontiguousarray(database_descriptors, dtype=np.float32)

    # Use a kNN to find predictions
    faiss_index = faiss.IndexFlatL2(database_descriptors.shape[1])
    faiss_index.add(database_descriptors)

    print("Calculating")
    _, predictions = faiss_index.search(queries_descriptors, max(recall_values))

    positives_per_query = eval_ds.get_positives()
    recalls = np.zeros(len(recall_values))
    for query_index, preds in enumerate(predictions):
        for i, n in enumerate(recall_values):
            if np.any(np.isin(preds[:n], positives_per_query[query_index])):
                # A hit at cut-off n is also a hit at every larger cut-off
                recalls[i:] += 1
                break
    # Divide by queries_num and multiply by 100, so the recalls are in percentages
    recalls = recalls / eval_ds.queries_num * 100
    recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(recall_values, recalls)])

    return recalls, recalls_str

In [None]:
# Step 1 - AVG pooling with Adam with ContrastiveLoss

class LightningModel(pl.LightningModule):
    """Step 1 model: truncated ResNet-18, average pooling, contrastive loss, Adam.

    The backbone is an ImageNet-pretrained ResNet-18 cut after ``layer2``. The
    pooled features are projected to ``descriptors_dim`` by a linear layer.
    Validation and test epochs report recall@1 and recall@5.

    Args:
        val_dataset: Evaluation dataset used in ``validation_epoch_end``; it must
            expose ``database_num``, ``queries_num`` and ``get_positives()``.
        test_dataset: Same contract, used in ``test_epoch_end``.
        descriptors_dim: Size of the output descriptors.
    """

    def __init__(self, val_dataset, test_dataset, descriptors_dim=512):
        super().__init__()
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

        # `weights=` replaces the deprecated `pretrained=True` (same IMAGENET1K_V1 weights)
        resnet18 = torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        )

        # Truncate the model at conv3 (end of layer2)
        self.features = torch.nn.Sequential(
            resnet18.conv1,
            resnet18.bn1,
            resnet18.relu,
            resnet18.maxpool,
            resnet18.layer1,
            resnet18.layer2,
        )

        # Add average pooling layer
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))

        # Pooling to 1x1 leaves one value per channel, so the feature size equals
        # the channel count of layer2. It is read from the layer rather than from
        # a dummy forward pass: a forward pass in training mode updates the
        # BatchNorm running statistics (even under torch.no_grad()) and would
        # pollute the pretrained statistics with those of random noise.
        self.output_dim = resnet18.layer2[-1].conv2.out_channels

        # Add a linear layer to match the desired descriptors dimension
        self.fc = torch.nn.Linear(self.output_dim, descriptors_dim)

        # Set the loss function
        self.loss_fn = losses.ContrastiveLoss(pos_margin=0, neg_margin=1)

    def forward(self, images):
        """Map a batch of images to one descriptor per image."""
        x = self.features(images)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        descriptors = self.fc(x)
        return descriptors

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer

    def loss_function(self, descriptors, labels):
        loss = self.loss_fn(descriptors, labels)
        return loss

    def training_step(self, batch, batch_idx):
        images, labels = batch
        # The loader yields (places, images per place, C, H, W); merge the first
        # two dimensions so that every image becomes one sample
        num_places, num_images_per_place, C, H, W = images.shape
        images = images.view(num_places * num_images_per_place, C, H, W)
        labels = labels.view(num_places * num_images_per_place)

        descriptors = self(images)
        loss = self.loss_function(descriptors, labels)

        self.log('loss', loss.item(), logger=True)
        return {'loss': loss}

    def inference_step(self, batch):
        """Return the descriptors of one evaluation batch as a float32 numpy array."""
        images, _ = batch
        descriptors = self(images)
        # detach() keeps the conversion valid even when gradients are enabled
        return descriptors.detach().cpu().numpy().astype(np.float32)

    def validation_step(self, batch, batch_idx):
        return self.inference_step(batch)

    def test_step(self, batch, batch_idx):
        return self.inference_step(batch)

    def validation_epoch_end(self, all_descriptors):
        return self.inference_epoch_end(all_descriptors, self.val_dataset, 'val')

    def test_epoch_end(self, all_descriptors):
        return self.inference_epoch_end(all_descriptors, self.test_dataset, 'test')

    def inference_epoch_end(self, all_descriptors, inference_dataset, split):
        """Compute, print and log recall@1 / recall@5 for one evaluation epoch."""
        all_descriptors = np.concatenate(all_descriptors)
        # The first database_num descriptors belong to the database, the rest to the queries
        queries_descriptors = all_descriptors[inference_dataset.database_num:]
        database_descriptors = all_descriptors[:inference_dataset.database_num]

        recalls, recalls_str = compute_recalls(inference_dataset, queries_descriptors, database_descriptors)

        print(f"Epoch[{self.current_epoch:02d}]): " +
                      f"recalls: {recalls_str}")

        self.log(f'{split}/R@1', recalls[0], prog_bar=False, logger=True)
        self.log(f'{split}/R@5', recalls[1], prog_bar=False, logger=True)

In [None]:
# Step 2 - GeM with Adam with ContrastiveLoss

class LightningModel(pl.LightningModule):
    """Step 2 model: truncated ResNet-18, GeM pooling, contrastive loss, Adam.

    The backbone is an ImageNet-pretrained ResNet-18 cut after ``layer2``. The
    GeM-pooled features are projected to ``descriptors_dim`` by a linear layer.
    Validation and test epochs report recall@1 and recall@5.

    Args:
        val_dataset: Evaluation dataset used in ``validation_epoch_end``; it must
            expose ``database_num``, ``queries_num`` and ``get_positives()``.
        test_dataset: Same contract, used in ``test_epoch_end``.
        descriptors_dim: Size of the output descriptors.
    """

    def __init__(self, val_dataset, test_dataset, descriptors_dim=512):
        super().__init__()
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

        # `weights=` replaces the deprecated `pretrained=True` (same IMAGENET1K_V1 weights)
        resnet18 = torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        )

        # Truncate the model at conv3 (end of layer2)
        self.features = torch.nn.Sequential(
            resnet18.conv1,
            resnet18.bn1,
            resnet18.relu,
            resnet18.maxpool,
            resnet18.layer1,
            resnet18.layer2,
        )

        # Replace average pooling with GeM pooling
        self.gem_pool = GeM()

        # GeM pools every channel to a single value, so the feature size equals
        # the channel count of layer2. It is read from the layer rather than from
        # a dummy forward pass: a forward pass in training mode updates the
        # BatchNorm running statistics (even under torch.no_grad()) and would
        # pollute the pretrained statistics with those of random noise.
        self.output_dim = resnet18.layer2[-1].conv2.out_channels

        # Add a linear layer to match the desired descriptors dimension
        self.fc = torch.nn.Linear(self.output_dim, descriptors_dim)

        # Set the loss function
        self.loss_fn = losses.ContrastiveLoss(pos_margin=0, neg_margin=1)

    def forward(self, images):
        """Map a batch of images to one descriptor per image."""
        x = self.features(images)
        x = self.gem_pool(x)
        x = torch.flatten(x, 1)
        descriptors = self.fc(x)
        return descriptors

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer

    def loss_function(self, descriptors, labels):
        loss = self.loss_fn(descriptors, labels)
        return loss

    def training_step(self, batch, batch_idx):
        images, labels = batch
        # The loader yields (places, images per place, C, H, W); merge the first
        # two dimensions so that every image becomes one sample
        num_places, num_images_per_place, C, H, W = images.shape
        images = images.view(num_places * num_images_per_place, C, H, W)
        labels = labels.view(num_places * num_images_per_place)

        descriptors = self(images)
        loss = self.loss_function(descriptors, labels)

        self.log('loss', loss.item(), logger=True)
        return {'loss': loss}

    def inference_step(self, batch):
        """Return the descriptors of one evaluation batch as a float32 numpy array."""
        images, _ = batch
        descriptors = self(images)
        # detach() keeps the conversion valid even when gradients are enabled
        return descriptors.detach().cpu().numpy().astype(np.float32)

    def validation_step(self, batch, batch_idx):
        return self.inference_step(batch)

    def test_step(self, batch, batch_idx):
        return self.inference_step(batch)

    def validation_epoch_end(self, all_descriptors):
        return self.inference_epoch_end(all_descriptors, self.val_dataset, 'val')

    def test_epoch_end(self, all_descriptors):
        return self.inference_epoch_end(all_descriptors, self.test_dataset, 'test')

    def inference_epoch_end(self, all_descriptors, inference_dataset, split):
        """Compute, print and log recall@1 / recall@5 for one evaluation epoch."""
        all_descriptors = np.concatenate(all_descriptors)
        # The first database_num descriptors belong to the database, the rest to the queries
        queries_descriptors = all_descriptors[inference_dataset.database_num:]
        database_descriptors = all_descriptors[:inference_dataset.database_num]

        recalls, recalls_str = compute_recalls(inference_dataset, queries_descriptors, database_descriptors)

        print(f"Epoch[{self.current_epoch:02d}]): " +
                      f"recalls: {recalls_str}")

        self.log(f'{split}/R@1', recalls[0], prog_bar=False, logger=True)
        self.log(f'{split}/R@5', recalls[1], prog_bar=False, logger=True)

In [None]:
# Step 3 - GeM with Adam with TripletLoss

import torch.nn.functional as F

class TripletLoss(nn.Module):
    """Triplet margin loss on squared Euclidean distances.

    For every (anchor, positive, negative) row the loss is
    ``max(0, d(anchor, positive) - d(anchor, negative) + margin)``, where ``d``
    is the squared L2 distance; the result is averaged over the rows.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        pos_sq_dist = (anchor - positive).pow(2).sum(1)
        neg_sq_dist = (anchor - negative).pow(2).sum(1)
        hinge = F.relu(pos_sq_dist - neg_sq_dist + self.margin)
        return hinge.mean()

class LightningModel(pl.LightningModule):
    """Step 3 model: truncated ResNet-18, GeM pooling, triplet loss, Adam.

    The backbone is an ImageNet-pretrained ResNet-18 cut after ``layer2``. The
    GeM-pooled features are projected to ``descriptors_dim`` by a linear layer.
    Validation and test epochs report recall@1 and recall@5.

    Args:
        val_dataset: Evaluation dataset used in ``validation_epoch_end``; it must
            expose ``database_num``, ``queries_num`` and ``get_positives()``.
        test_dataset: Same contract, used in ``test_epoch_end``.
        descriptors_dim: Size of the output descriptors.
    """

    def __init__(self, val_dataset, test_dataset, descriptors_dim=512):
        super().__init__()
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

        # `weights=` replaces the deprecated `pretrained=True` (same IMAGENET1K_V1 weights)
        resnet18 = torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        )

        # Truncate the model at conv3 (end of layer2)
        self.features = torch.nn.Sequential(
            resnet18.conv1,
            resnet18.bn1,
            resnet18.relu,
            resnet18.maxpool,
            resnet18.layer1,
            resnet18.layer2,
        )

        # Replace average pooling with GeM pooling
        self.gem_pool = GeM()

        # GeM pools every channel to a single value, so the feature size equals
        # the channel count of layer2. It is read from the layer rather than from
        # a dummy forward pass: a forward pass in training mode updates the
        # BatchNorm running statistics (even under torch.no_grad()) and would
        # pollute the pretrained statistics with those of random noise.
        self.output_dim = resnet18.layer2[-1].conv2.out_channels

        # Add a linear layer to match the desired descriptors dimension
        self.fc = torch.nn.Linear(self.output_dim, descriptors_dim)

        # Set the loss function
        self.loss_fn = TripletLoss(margin=1.0)

    def forward(self, images):
        """Map a batch of images to one descriptor per image."""
        x = self.features(images)
        x = self.gem_pool(x)
        x = torch.flatten(x, 1)
        descriptors = self.fc(x)
        return descriptors

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer

    def loss_function(self, descriptors, labels):
        """Triplet loss with one (anchor, positive, negative) triplet per place.

        Args:
            descriptors: Descriptors of all images, ordered place by place.
            labels: Tensor of shape ``(num_places, num_images_per_place)``; only
                its shape is used.
        """
        # Reshape descriptors to (places, images per place, descriptor size)
        num_places, num_images_per_place = labels.shape
        descriptors = descriptors.view(num_places, num_images_per_place, -1)

        if num_places < 2:
            # A single place offers no negative. Return a zero loss that is still
            # connected to the graph so the backward pass remains valid.
            return (descriptors * 0.0).sum()

        # For each place, select an anchor, a positive, and a negative
        anchors = descriptors[:, 0]  # First image of each place as anchor
        positives = descriptors[:, 1]  # Second image of each place as positive

        # Randomly select negatives from OTHER places: a random offset in
        # [1, num_places - 1] added modulo num_places can never map a place to
        # itself, which would make the negative identical to the anchor.
        place_idx = torch.arange(num_places, device=descriptors.device)
        offsets = torch.randint(1, num_places, (num_places,), device=descriptors.device)
        negatives_idx = (place_idx + offsets) % num_places
        negatives = descriptors[negatives_idx, 0]

        loss = self.loss_fn(anchors, positives, negatives)
        return loss

    def training_step(self, batch, batch_idx):
        images, labels = batch
        # The loader yields (places, images per place, C, H, W); merge the first
        # two dimensions so that every image becomes one sample. The labels keep
        # their 2-D shape because loss_function reads the grouping from it.
        num_places, num_images_per_place, C, H, W = images.shape
        images = images.view(num_places * num_images_per_place, C, H, W)

        descriptors = self(images)
        loss = self.loss_function(descriptors, labels)

        self.log('loss', loss.item(), logger=True)
        return {'loss': loss}

    def inference_step(self, batch):
        """Return the descriptors of one evaluation batch as a float32 numpy array."""
        images, _ = batch
        descriptors = self(images)
        # detach() keeps the conversion valid even when gradients are enabled
        return descriptors.detach().cpu().numpy().astype(np.float32)

    def validation_step(self, batch, batch_idx):
        return self.inference_step(batch)

    def test_step(self, batch, batch_idx):
        return self.inference_step(batch)

    def validation_epoch_end(self, all_descriptors):
        return self.inference_epoch_end(all_descriptors, self.val_dataset, 'val')

    def test_epoch_end(self, all_descriptors):
        return self.inference_epoch_end(all_descriptors, self.test_dataset, 'test')

    def inference_epoch_end(self, all_descriptors, inference_dataset, split):
        """Compute, print and log recall@1 / recall@5 for one evaluation epoch."""
        all_descriptors = np.concatenate(all_descriptors)
        # The first database_num descriptors belong to the database, the rest to the queries
        queries_descriptors = all_descriptors[inference_dataset.database_num:]
        database_descriptors = all_descriptors[:inference_dataset.database_num]

        recalls, recalls_str = compute_recalls(inference_dataset, queries_descriptors, database_descriptors)

        print(f"Epoch[{self.current_epoch:02d}]): " +
                      f"recalls: {recalls_str}")

        self.log(f'{split}/R@1', recalls[0], prog_bar=False, logger=True)
        self.log(f'{split}/R@5', recalls[1], prog_bar=False, logger=True)

In [None]:
# Step 3 (alternative loss) - GeM with Adam with MultiSimilarityLoss

from pytorch_metric_learning import losses

class LightningModel(pl.LightningModule):
    """Step 3 model (alternative loss): truncated ResNet-18, GeM pooling, MultiSimilarityLoss, Adam.

    The backbone is an ImageNet-pretrained ResNet-18 cut after ``layer2``. The
    GeM-pooled features are projected to ``descriptors_dim`` by a linear layer.
    Validation and test epochs report recall@1 and recall@5.

    Args:
        val_dataset: Evaluation dataset used in ``validation_epoch_end``; it must
            expose ``database_num``, ``queries_num`` and ``get_positives()``.
        test_dataset: Same contract, used in ``test_epoch_end``.
        descriptors_dim: Size of the output descriptors.
    """

    def __init__(self, val_dataset, test_dataset, descriptors_dim=512):
        super().__init__()
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

        # `weights=` replaces the deprecated `pretrained=True` (same IMAGENET1K_V1 weights)
        resnet18 = torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        )

        # Truncate the model at conv3 (end of layer2)
        self.features = torch.nn.Sequential(
            resnet18.conv1,
            resnet18.bn1,
            resnet18.relu,
            resnet18.maxpool,
            resnet18.layer1,
            resnet18.layer2,
        )

        # Replace average pooling with GeM pooling
        self.gem_pool = GeM()

        # GeM pools every channel to a single value, so the feature size equals
        # the channel count of layer2. It is read from the layer rather than from
        # a dummy forward pass: a forward pass in training mode updates the
        # BatchNorm running statistics (even under torch.no_grad()) and would
        # pollute the pretrained statistics with those of random noise.
        self.output_dim = resnet18.layer2[-1].conv2.out_channels

        # Add a linear layer to match the desired descriptors dimension
        self.fc = torch.nn.Linear(self.output_dim, descriptors_dim)

        # Set the loss function to MultiSimilarityLoss
        self.loss_fn = losses.MultiSimilarityLoss()

    def forward(self, images):
        """Map a batch of images to one descriptor per image."""
        x = self.features(images)
        x = self.gem_pool(x)
        x = torch.flatten(x, 1)
        descriptors = self.fc(x)
        return descriptors

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer

    def loss_function(self, descriptors, labels):
        loss = self.loss_fn(descriptors, labels)
        return loss

    def training_step(self, batch, batch_idx):
        images, labels = batch
        # The loader yields (places, images per place, C, H, W); merge the first
        # two dimensions so that every image becomes one sample
        num_places, num_images_per_place, C, H, W = images.shape
        images = images.view(num_places * num_images_per_place, C, H, W)
        labels = labels.view(num_places * num_images_per_place)

        descriptors = self(images)
        loss = self.loss_function(descriptors, labels)

        self.log('loss', loss.item(), logger=True)
        return {'loss': loss}

    def inference_step(self, batch):
        """Return the descriptors of one evaluation batch as a float32 numpy array."""
        images, _ = batch
        descriptors = self(images)
        # detach() keeps the conversion valid even when gradients are enabled
        return descriptors.detach().cpu().numpy().astype(np.float32)

    def validation_step(self, batch, batch_idx):
        return self.inference_step(batch)

    def test_step(self, batch, batch_idx):
        return self.inference_step(batch)

    def validation_epoch_end(self, all_descriptors):
        return self.inference_epoch_end(all_descriptors, self.val_dataset, 'val')

    def test_epoch_end(self, all_descriptors):
        return self.inference_epoch_end(all_descriptors, self.test_dataset, 'test')

    def inference_epoch_end(self, all_descriptors, inference_dataset, split):
        """Compute, print and log recall@1 / recall@5 for one evaluation epoch."""
        all_descriptors = np.concatenate(all_descriptors)
        # The first database_num descriptors belong to the database, the rest to the queries
        queries_descriptors = all_descriptors[inference_dataset.database_num:]
        database_descriptors = all_descriptors[:inference_dataset.database_num]

        recalls, recalls_str = compute_recalls(inference_dataset, queries_descriptors, database_descriptors)

        print(f"Epoch[{self.current_epoch:02d}]): " +
                      f"recalls: {recalls_str}")

        self.log(f'{split}/R@1', recalls[0], prog_bar=False, logger=True)
        self.log(f'{split}/R@5', recalls[1], prog_bar=False, logger=True)

In [None]:
# Step 4 - GeM with MultiSimilarityLoss with AdamW or SGD (SGD is the active optimizer; AdamW is kept commented out)

from pytorch_metric_learning import losses

class LightningModel(pl.LightningModule):
    """Step 4 model: truncated ResNet-18, GeM pooling, MultiSimilarityLoss, SGD.

    The backbone is an ImageNet-pretrained ResNet-18 cut after ``layer2``. The
    GeM-pooled features are projected to ``descriptors_dim`` by a linear layer.
    Validation and test epochs report recall@1 and recall@5.

    Args:
        val_dataset: Evaluation dataset used in ``validation_epoch_end``; it must
            expose ``database_num``, ``queries_num`` and ``get_positives()``.
        test_dataset: Same contract, used in ``test_epoch_end``.
        descriptors_dim: Size of the output descriptors.
    """

    def __init__(self, val_dataset, test_dataset, descriptors_dim=512):
        super().__init__()
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

        # `weights=` replaces the deprecated `pretrained=True` (same IMAGENET1K_V1 weights)
        resnet18 = torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        )

        # Truncate the model at conv3 (end of layer2)
        self.features = torch.nn.Sequential(
            resnet18.conv1,
            resnet18.bn1,
            resnet18.relu,
            resnet18.maxpool,
            resnet18.layer1,
            resnet18.layer2,
        )

        # Replace average pooling with GeM pooling
        self.gem_pool = GeM()

        # GeM pools every channel to a single value, so the feature size equals
        # the channel count of layer2. It is read from the layer rather than from
        # a dummy forward pass: a forward pass in training mode updates the
        # BatchNorm running statistics (even under torch.no_grad()) and would
        # pollute the pretrained statistics with those of random noise.
        self.output_dim = resnet18.layer2[-1].conv2.out_channels

        # Add a linear layer to match the desired descriptors dimension
        self.fc = torch.nn.Linear(self.output_dim, descriptors_dim)

        # Set the loss function to MultiSimilarityLoss
        self.loss_fn = losses.MultiSimilarityLoss()

    def forward(self, images):
        """Map a batch of images to one descriptor per image."""
        x = self.features(images)
        x = self.gem_pool(x)
        x = torch.flatten(x, 1)
        descriptors = self.fc(x)
        return descriptors

    # Alternative optimizer of this step (AdamW), kept for reference:
#     def configure_optimizers(self):
#         optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5, weight_decay=0.01)
#         return optimizer

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
        return optimizer

    def loss_function(self, descriptors, labels):
        loss = self.loss_fn(descriptors, labels)
        return loss

    def training_step(self, batch, batch_idx):
        images, labels = batch
        # The loader yields (places, images per place, C, H, W); merge the first
        # two dimensions so that every image becomes one sample
        num_places, num_images_per_place, C, H, W = images.shape
        images = images.view(num_places * num_images_per_place, C, H, W)
        labels = labels.view(num_places * num_images_per_place)

        descriptors = self(images)
        loss = self.loss_function(descriptors, labels)

        self.log('loss', loss.item(), logger=True)
        return {'loss': loss}

    def inference_step(self, batch):
        """Return the descriptors of one evaluation batch as a float32 numpy array."""
        images, _ = batch
        descriptors = self(images)
        # detach() keeps the conversion valid even when gradients are enabled
        return descriptors.detach().cpu().numpy().astype(np.float32)

    def validation_step(self, batch, batch_idx):
        return self.inference_step(batch)

    def test_step(self, batch, batch_idx):
        return self.inference_step(batch)

    def validation_epoch_end(self, all_descriptors):
        return self.inference_epoch_end(all_descriptors, self.val_dataset, 'val')

    def test_epoch_end(self, all_descriptors):
        return self.inference_epoch_end(all_descriptors, self.test_dataset, 'test')

    def inference_epoch_end(self, all_descriptors, inference_dataset, split):
        """Compute, print and log recall@1 / recall@5 for one evaluation epoch."""
        all_descriptors = np.concatenate(all_descriptors)
        # The first database_num descriptors belong to the database, the rest to the queries
        queries_descriptors = all_descriptors[inference_dataset.database_num:]
        database_descriptors = all_descriptors[:inference_dataset.database_num]

        recalls, recalls_str = compute_recalls(inference_dataset, queries_descriptors, database_descriptors)

        print(f"Epoch[{self.current_epoch:02d}]): " +
                      f"recalls: {recalls_str}")

        self.log(f'{split}/R@1', recalls[0], prog_bar=False, logger=True)
        self.log(f'{split}/R@5', recalls[1], prog_bar=False, logger=True)

In [None]:
# Model under training: validated on the SF validation split, tested on the SF test split.
model = LightningModel(val_dataset=sf_val_dataset, test_dataset=sf_test_dataset)

# Keep the checkpoint with the highest validation R@1, plus the most recent one.
checkpoint_cb = ModelCheckpoint(
    filename='_epoch({epoch:02d})_R@1[{val/R@1:.4f}]_R@5[{val/R@5:.4f}]',
    auto_insert_metric_name=False,
    monitor='val/R@1',
    mode='max',
    save_top_k=1,
    save_last=True,
    save_weights_only=False,
)

tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/", version="default")

# Trainer settings gathered in one place, then unpacked into the constructor.
trainer_options = dict(
    # hardware / numerics
    accelerator='gpu',
    devices=[0],
    precision=16,
    # schedule
    max_epochs=10,
    check_val_every_n_epoch=1,
    num_sanity_val_steps=0,
    reload_dataloaders_every_n_epochs=1,
    # logging / checkpointing
    default_root_dir='./logs',
    logger=tb_logger,
    callbacks=[checkpoint_cb],
    log_every_n_steps=20,
)
trainer = pl.Trainer(**trainer_options)

In [None]:
# Train from scratch (no checkpoint to resume from); the SF validation loader
# is evaluated by the trainer during fitting.
print("Traning Process\n")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=None)
print("VALIDATING ON SF")

In [None]:
# Evaluate the best checkpoint (highest val/R@1) on the SF test split.
print("-_-_-_-_-_-_-_-_-_TESTING ON SF-_-_-_-_-_-_-_-_-_-_-_\n")
trainer.test(model, dataloaders=test_loader, ckpt_path='best')

In [None]:
# A new LightningModel instance is needed because the test dataset (Tokyo here)
# is stored on the module and used by test_epoch_end.
model = LightningModel(sf_val_dataset, tokyo_test_dataset)

print("-_-_-_-_-_-_-_-_-_TESTING ON TOKYO-_-_-_-_-_-_-_-_-_-_-_")
# ckpt_path='best' restores the trained weights of the best checkpoint into the
# new instance. With ckpt_path=None the freshly constructed model, whose linear
# head was never trained, would be evaluated instead.
trainer.test(model=model, dataloaders=tokyo_loader, ckpt_path='best')