In [1]:
pip install pytorch_metric_learning

Collecting pytorch_metric_learning
  Downloading pytorch_metric_learning-2.5.0-py3-none-any.whl.metadata (17 kB)
Downloading pytorch_metric_learning-2.5.0-py3-none-any.whl (119 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.1/119.1 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pytorch_metric_learning
Successfully installed pytorch_metric_learning-2.5.0
Note: you may need to restart the kernel to use updated packages.


In [3]:
import os
import math
import timm
import random
import torch
import wandb
import cv2 as cv
import pandas as pd
import numpy as np
import pickle
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn import Module
from torchvision.models import vgg16, VGG16_Weights, ResNet50_Weights
from torchvision import transforms
import torchvision
from PIL import Image
from enum import Enum
from pytorch_metric_learning import miners, losses
import albumentations as A
#from torchgeo import datasets

###To work with

In [4]:
# Authenticate with Weights & Biases before any run is created.
wandb.login()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [21]:
# Start the W&B run and record the experiment's hyper-parameters as run config.
# These values are descriptive only: nothing below reads them back, so they must
# be kept in sync by hand with the optimizer / loss / miner that are actually built.
# NOTE(review): 'miner' is logged as 'semi-hard samples', but run() constructs
# miners.TripletMarginMiner(..., type_of_triplets="hard") - confirm which is intended.
wandb.init(project='Test1', config={
    'learning_rate': 1e-3,
    'margin': 0.2,
    'net': 'TimmMobileNet large',
    'optimizer': 'AdamW',
    'train dataset': 'DenseUAV',
    'test dataset': 'DenseUAV',
    'miner': 'semi-hard samples'
})

[34m[1mwandb[0m: Currently logged in as: [33mmc1granec2003[0m ([33mdiploma_work[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [23]:
# Keys naming augmentation pipelines. In the code visible in this notebook only
# GEOMETRIC_DOUBLE and COLOR_DOUBLE are produced by make_train_aug(); the other
# names are declared but not referenced.
GEOMETRIC_DOUBLE = 'GEOMETRIC_DOUBLE'
GEOMETRIC_SINGLE = 'GEOMETRIC_SINGLE'
FINE_SINGLE = 'FINE_SINGLE'
COLOR_DOUBLE = 'COLOR_DOUBLE'
COLOR_SINGLE = 'COLOR_SINGLE'
RANDOM_CROP_SINGLE = 'RANDOM_CROP_SINGLE'
RANDOM_CROP_DOUBLE = 'RANDOM_CROP_DOUBLE'

def make_train_aug(size=(512, 512)):
    """Build the two albumentations pipelines used for training triplets.

    Args:
        size: ``(height, width)`` that the geometric pipeline pads images up to.

    Returns:
        dict with two ``A.Compose`` objects:
        ``GEOMETRIC_DOUBLE`` - geometric transforms applied jointly to ``image``
        and an extra ``positive`` image target;
        ``COLOR_DOUBLE`` - photometric transforms applied jointly to ``image``
        and an extra ``negative`` image target.
    """
    pad_height, pad_width = size

    geometric_pipeline = A.Compose(
        [
            A.Flip(p=0.75),
            A.Transpose(p=0.5),
            A.RandomRotate90(p=0.75),
            # A.ShiftScaleRotate(scale_limit=(-0.5, 0.5), shift_limit=0, rotate_limit=45, p=0.9),
            A.Perspective(p=0.25),
            # Always pad (constant border, mode 0) so outputs are at least `size`.
            A.PadIfNeeded(min_height=pad_height, min_width=pad_width, always_apply=True, border_mode=0),
        ],
        additional_targets={'positive': 'image'},
    )

    # With probability 0.75, exactly one noise / artefact transform is applied.
    noise_or_artifact = A.OneOf(
        [
            A.GaussNoise(p=1),
            A.Emboss(p=1),
            A.Sharpen(p=1),
            A.ImageCompression(p=1),
        ],
        p=0.75,
    )

    # With probability 0.75, exactly one small-kernel blur is applied.
    any_blur = A.OneOf(
        [
            A.Blur(blur_limit=3, p=1),
            A.GaussianBlur(blur_limit=3, p=1),
            A.MedianBlur(blur_limit=3, p=1),
            A.MotionBlur(blur_limit=3, p=1),
        ],
        p=0.75,
    )

    color_pipeline = A.Compose(
        [
            A.Sharpen(alpha=(0.05, 0.1), lightness=(0.1, 0.5), p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.05, contrast_limit=0.05, p=1.0),
            A.RGBShift(),
            A.CLAHE(p=0.2),
            A.RandomGamma(p=1),
            A.HueSaturationValue(p=1),
            A.ChannelShuffle(p=0.2),
            noise_or_artifact,
            any_blur,
        ],
        additional_targets={'negative': 'image'},
    )

    return {
        GEOMETRIC_DOUBLE: geometric_pipeline,
        COLOR_DOUBLE: color_pipeline,
    }

In [5]:
from geopy.distance import geodesic

def parse_coordinates(line):
    """Extract the two coordinate values from one whitespace-separated line.

    The second and third fields are expected to carry a one-character prefix
    (the first character of each is dropped) followed by a float.

    Returns:
        ``(easting, northing)`` as floats.
    """
    fields = line.split()
    # Drop the single leading marker character of each field before parsing.
    return float(fields[1][1:]), float(fields[2][1:])

def calculate_distance_between_coordinates(file_path, indexes_true, indexes_pred):
    """Mean geodesic distance (metres) between true and predicted locations.

    Args:
        file_path: text file with one coordinate record per line, in the format
            understood by ``parse_coordinates``.
        indexes_true: line indices of the ground-truth records; must expose
            ``.shape[0]`` (used as the divisor).
        indexes_pred: line indices of the predicted records, paired
            element-wise with ``indexes_true``.

    Returns:
        Sum of the pairwise geodesic distances divided by ``indexes_true.shape[0]``.
    """
    with open(file_path, 'r') as file:
        records = file.readlines()

    def pair_error(idx_true, idx_pred):
        record_true = records[idx_true].strip()
        record_pred = records[idx_pred].strip()
        # parse_coordinates yields (easting, northing); geodesic wants (lat, lon).
        lon_true, lat_true = parse_coordinates(record_true)
        lon_pred, lat_pred = parse_coordinates(record_pred)
        return geodesic((lat_true, lon_true), (lat_pred, lon_pred)).meters

    total = sum(pair_error(t, p) for t, p in zip(indexes_true, indexes_pred))
    return total / indexes_true.shape[0]

# Coordinate file passed to calculate_distance_between_coordinates() during evaluation.
geo_data_path = "/kaggle/input/dataset/DenseUAV/Dense_GPS_test.txt"

In [6]:
class Mode(Enum):
    """Which split a dataset instance serves."""

    # Training split.
    TRAIN = 1
    # Validation / test split.
    VAL = 2

class DenseUAVDataset(Dataset):
    """Triplet dataset pairing DenseUAV drone queries with satellite references.

    ``__getitem__(i)`` returns ``(anchor, positive, negative, anchor_label,
    positive_label, negative_label)``: the i-th query image, the i-th reference
    image, a randomly drawn reference image from a different index group, and
    the integer group label of each (``index // GROUP_SIZE``).

    Fixes relative to the previous revision:
      * negatives are now drawn strictly outside the anchor's index group (the
        old rejection test only excluded the middle index of the group, so the
        negative could be the positive itself or a same-class image);
      * the random reference choice for H100/H90 really picks between the
        current and the ``_old`` tile (both branches used to be identical);
      * augmentation pipelines are built once and reused instead of being
        rebuilt twice per sample;
      * ``apply_color_transfer`` no longer divides by zero on flat channels.
    """

    # Flight heights whose images are listed per location folder in TRAIN mode.
    HEIGHTS = (100, 90, 80)
    # Consecutive indices are grouped by this many items to form one class label.
    GROUP_SIZE = 3

    def __init__(self, root, mode):
        """Index the image paths of one split.

        Args:
            root: dataset root directory; ``'/test/'`` is appended for
                ``Mode.VAL`` and ``'/train/'`` otherwise.
            mode: a ``Mode`` member.
        """
        self.root = root
        self.mode = mode
        # When True, __getitem__ skips tensor conversion and returns raw images.
        self.is_debug = False
        self._train_aug = None

        if self.mode == Mode.VAL:
            self.root += '/test/'
            self.query_path = 'query_drone/'
            self.ref_path = 'gallery_satellite/'
        else:
            self.root += '/train/'
            self.query_path = 'drone/'
            self.ref_path = 'satellite/'

        query_dir = self.root + self.query_path
        ref_dir = self.root + self.ref_path

        self.query_folders = sorted(os.listdir(query_dir))
        self.ref_folders = sorted(os.listdir(ref_dir))

        self.query_img_path = []
        self.ref_img_path = []

        if self.mode == Mode.TRAIN:
            for folder_name in self.query_folders:
                folder_path = os.path.join(query_dir, folder_name)
                self.query_img_path.extend(
                    os.path.join(folder_path, f"H{height}.JPG") for height in self.HEIGHTS
                )

            for folder_name in self.ref_folders:
                folder_path = os.path.join(ref_dir, folder_name)
                for height in self.HEIGHTS:
                    # Coin flip, made once at construction time, between the
                    # current tile and its "_old" counterpart.
                    # NOTE(review): assumes H100.tif / H90.tif exist next to the
                    # *_old.tif files, as H80.tif already did - confirm.
                    suffix = "" if random.randint(0, 1) == 0 else "_old"
                    self.ref_img_path.append(os.path.join(folder_path, f"H{height}{suffix}.tif"))
        else:
            # Skip the leading gallery folders so that reference i lines up with
            # query i. NOTE(review): relies on query folder names being integers
            # equal to their position in the sorted gallery listing - confirm.
            self.ref_folders = self.ref_folders[int(self.query_folders[0]):]

            for folder_name in self.query_folders:
                folder_path = os.path.join(query_dir, folder_name)
                self.query_img_path.append(os.path.join(folder_path, "H100.JPG"))

            for folder_name in self.ref_folders:
                folder_path = os.path.join(ref_dir, folder_name)
                self.ref_img_path.append(os.path.join(folder_path, "H100.tif"))

    def _augmentations(self):
        """Return the training pipelines, building them on first use only."""
        if getattr(self, '_train_aug', None) is None:
            self._train_aug = make_train_aug()
        return self._train_aug

    def get_image(self, path):
        """Open the image at ``path`` with PIL and return it."""
        image = Image.open(path)

        return image

    def generate_ref_image(self, index):
        """Pick the positive reference for ``index`` and a random negative.

        The negative is drawn uniformly from every reference index outside the
        group ``[g, g + GROUP_SIZE)`` that contains ``index``
        (e.g. indices 3, 4, 5 draw from 0, 1, 2, 6, 7, 8, ...).

        Returns:
            ``(positive_path, negative_path, negative_label)`` where
            ``negative_label`` is an int tensor holding ``negative_index // GROUP_SIZE``.

        Raises:
            ValueError: if no reference lies outside the anchor's group.
        """
        positive_path = self.ref_img_path[index]

        total = len(self.ref_img_path)
        group_start = (index // self.GROUP_SIZE) * self.GROUP_SIZE
        group_end = min(group_start + self.GROUP_SIZE, total)  # exclusive
        group_len = group_end - group_start

        candidates = total - group_len
        if candidates <= 0:
            raise ValueError("no reference image outside the anchor's group to use as a negative")

        # Sample from the list with the anchor's group cut out, then map the
        # draw back to a real index by jumping over that group.
        rand_idx = random.randrange(candidates)
        if rand_idx >= group_start:
            rand_idx += group_len

        negative_path = self.ref_img_path[rand_idx]
        return positive_path, negative_path, torch.as_tensor(rand_idx // self.GROUP_SIZE, dtype=torch.int)

    def apply_color_transfer(self, image_target, image_source):
        """Match the per-channel mean/std of ``image_source`` to ``image_target``.

        Statistics are taken over the first two axes. The result is clipped to
        [0, 255] and returned as ``uint8``.
        """
        mean_target, mean_source = np.mean(image_target, axis=(0, 1)), np.mean(image_source, axis=(0, 1))
        std_target, std_source = np.std(image_target, axis=(0, 1)), np.std(image_source, axis=(0, 1))

        # A flat source channel has zero spread; dividing by it would produce
        # NaN/inf. Its centred values are all zero anyway, so any finite divisor
        # yields the target mean for that channel.
        safe_std_source = np.where(std_source == 0, 1.0, std_source)

        colored_image = (image_source - mean_source) * (std_target / safe_std_source) + mean_target
        colored_image = np.clip(colored_image, 0, 255).astype(np.uint8)
        return colored_image

    def custom_transform(self, image, is_ref=False):
        """Convert an image to a normalised 224x224 tensor.

        ``is_ref`` is accepted for interface compatibility and is not used.
        """
        image = transforms.ToTensor()(image)
        image = transforms.Resize((224, 224))(image)
        image = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)

        return image

    def __getitem__(self, index):
        """Return the triplet and the three group labels for ``index``."""
        group = index // self.GROUP_SIZE
        anchor_label = torch.as_tensor(group, dtype=torch.int)
        positive_label = torch.as_tensor(group, dtype=torch.int)

        anchor = self.get_image(self.query_img_path[index])
        positive_path, negative_path, negative_label = self.generate_ref_image(index)

        positive = self.get_image(positive_path)
        negative = self.get_image(negative_path)

        if self.mode == Mode.TRAIN:
            pipelines = self._augmentations()

            # Same geometric transform for anchor and positive. Only the anchor
            # is resized here - presumably the satellite tiles are already
            # 512x512; TODO confirm.
            sample = pipelines[GEOMETRIC_DOUBLE](
                image=cv.resize(np.array(anchor), (512, 512)),
                positive=np.array(positive),
            )
            anchor, positive = sample['image'], sample['positive']

            # Same photometric transform for positive and negative.
            sample = pipelines[COLOR_DOUBLE](image=np.array(positive), negative=np.array(negative))
            positive, negative = sample['image'], sample['negative']

        if not self.is_debug:
            anchor = self.custom_transform(anchor)
            positive = self.custom_transform(positive, is_ref=True)
            negative = self.custom_transform(negative, is_ref=True)

        return anchor, positive, negative, anchor_label, positive_label, negative_label

    def __len__(self):
        """Number of query images."""
        return len(self.query_img_path)


In [7]:
# Build both splits from the same dataset root; the class appends '/train/' or
# '/test/' itself depending on the mode.
train_data = DenseUAVDataset('/kaggle/input/dataset/DenseUAV', mode=Mode.TRAIN)
val_data = DenseUAVDataset('/kaggle/input/dataset/DenseUAV', mode=Mode.VAL)

In [32]:
# Two loaders over the same training set: one sequential, one shuffled. The
# training loop in run() zips them and concatenates each pair of batches, so an
# optimisation step sees 32 sequential + 32 shuffled samples.
train_dataloader1 = DataLoader(train_data, batch_size=32, shuffle=False, num_workers=0, pin_memory=True)
train_dataloader2 = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=0, pin_memory=True)
# Large validation batch: retrieval accuracy in run() is computed within a batch,
# so the batch is also the gallery being searched.
val_dataloader = DataLoader(val_data, batch_size=777, shuffle=False, num_workers=0, pin_memory=True)

In [9]:
class TimmMobilenet(torch.nn.Module):
    """Embedding network built on a pretrained timm backbone.

    The forward pass runs the backbone's feature extractor, global pooling,
    ``conv_head`` and ``act2``, then flattens and L2-normalises the result so
    every embedding has unit length.
    """

    def __init__(self, timm_name='mobilenetv3_large_100'):
        """Create the pretrained backbone named ``timm_name``."""
        super().__init__()
        self.source_model = timm.create_model(timm_name, pretrained=True)

    def forward(self, x):
        """Return one L2-normalised embedding row per input image."""
        backbone = self.source_model

        pooled = backbone.global_pool(backbone.forward_features(x))
        head = backbone.act2(backbone.conv_head(pooled))

        flat = torch.flatten(head, start_dim=1)
        return F.normalize(flat, p=2, dim=1)

In [10]:
# Prefer the GPU but fall back to CPU so the notebook still runs (slowly) on a
# kernel without CUDA instead of failing at `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = TimmMobilenet().to(device)

# Plain triplet loss over the (anchor, positive, negative) triplets produced by
# the dataset; run() reports it next to the mined-triplet loss it optimises.
criterion = torch.nn.TripletMarginLoss(margin=0.2)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)

# with wandb.restore('epoch_33.pth', run_path='diploma_work/Test1/uz7rwknw') as io:
#   name = io.name

# model.load_state_dict(torch.load(name, map_location=device))

model.safetensors:   0%|          | 0.00/22.1M [00:00<?, ?B/s]

In [11]:
from typing import List

def accuracy(dists: torch.Tensor, labels: "Union[torch.Tensor, List[int]]", top_k=(1,), samples_per_class=3) -> torch.Tensor:
    """Top-k retrieval accuracy from a query-to-gallery distance matrix.

    Args:
        dists: ``[B, N]`` matrix; ``dists[i, j]`` is the distance from query
            ``i`` to gallery item ``j`` (smaller means closer).
        labels: for every query, the gallery index that counts as correct.
            A tensor or a plain sequence of ints.
        top_k: the ``k`` values to evaluate.
        samples_per_class: unused; kept for interface compatibility.

    Returns:
        1-D tensor with one accuracy per entry of ``top_k``: the fraction of
        queries whose correct index is among their ``k`` nearest gallery items.
    """
    batch_size = dists.size(0)
    # topk() raises when k exceeds the number of gallery items (e.g. a short
    # last batch with top_k containing 10), so never ask for more than exist.
    maxk = min(max(top_k), dists.size(1))

    # Accept list labels and make sure they live on the same device as dists.
    labels = torch.as_tensor(labels, device=dists.device)

    _, y_pred = dists.topk(k=maxk, dim=1, largest=False)  # [B, N] -> [B, maxk]
    y_pred = y_pred.t()  # [B, maxk] -> [maxk, B]

    labels_reshaped = labels.view(1, -1).expand_as(y_pred)  # B -> [1, B] -> [maxk, B]
    correct = (y_pred == labels_reshaped)

    list_topk_accs = []
    for k in top_k:
        # Indices within one query's top-k are distinct, so each column holds
        # at most one hit and the plain sum counts correctly retrieved queries.
        hits = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)  # [k * B] -> [1]
        list_topk_accs.append(hits / batch_size)

    return torch.stack(list_topk_accs).reshape(-1)

In [12]:
from pytorch_metric_learning.losses import BaseMetricLossFunction
from pytorch_metric_learning.reducers import AvgNonZeroReducer
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
# taken from 
# https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/src/pytorch_metric_learning/losses/triplet_margin_loss.py
# to slightly modify smooth_loss
class TripletMarginLoss(BaseMetricLossFunction):
    """Triplet margin loss adapted from pytorch-metric-learning.

    Differences visible in this copy: ``smooth_loss`` defaults to ``True`` and
    the smooth variant uses ``softplus`` with ``beta=3``.

    Args:
        margin: value added to the anchor-positive / anchor-negative margin
            before the activation.
        swap: if True, the anchor-negative distance is replaced by the smaller
            of the anchor-negative and positive-negative distances.
        smooth_loss: use ``softplus(violation, beta=3)`` instead of ``relu``.
        triplets_per_anchor: forwarded to ``lmu.convert_to_triplets``.
    """

    def __init__(
        self,
        margin=0.2,
        swap=False,
        smooth_loss=True,
        triplets_per_anchor="all",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.margin, self.swap = margin, swap
        self.smooth_loss, self.triplets_per_anchor = smooth_loss, triplets_per_anchor
        self.add_to_recordable_attributes(list_of_names=["margin"], is_stat=False)

    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        """Return the per-triplet loss dict expected by the base class."""
        c_f.labels_or_indices_tuple_required(labels, indices_tuple)
        triplets = lmu.convert_to_triplets(
            indices_tuple, labels, ref_labels, t_per_anchor=self.triplets_per_anchor
        )
        anchors, positives, negatives = triplets

        # Nothing to penalise when no triplet is available.
        if len(anchors) == 0:
            return self.zero_losses()

        dist_mat = self.distance(embeddings, ref_emb)
        ap_dists = dist_mat[anchors, positives]
        an_dists = dist_mat[anchors, negatives]
        if self.swap:
            an_dists = self.distance.smallest_dist(an_dists, dist_mat[positives, negatives])

        violation = self.distance.margin(ap_dists, an_dists) + self.margin
        per_triplet = (
            torch.nn.functional.softplus(violation, beta=3)
            if self.smooth_loss
            else torch.nn.functional.relu(violation)
        )

        return {
            "loss": {
                "losses": per_triplet,
                "indices": triplets,
                "reduction_type": "triplet",
            }
        }

    def get_default_reducer(self):
        """Average over the non-zero per-triplet losses."""
        return AvgNonZeroReducer()

In [30]:
def run(train_dataloader, val_dataloader, model, criterion, optimizer, save_path,
        num_epochs=50, samples_per_class=3, log_freq=26, extra_train_dataloader=None):
    """Train ``model`` on mined triplets and evaluate it after every epoch.

    Per epoch:
      1. one pass over the training batches (one optimiser step per batch);
      2. the weights are saved to ``epoch_<n>.pth`` in the working directory;
      3. a no-grad pass over ``val_dataloader`` for losses / triplet accuracy;
      4. a second no-grad pass computing in-batch top-1/5/10 retrieval accuracy
         and the mean geodesic error (read from the module-level ``geo_data_path``);
      5. the epoch averages are sent to ``wandb.log`` and the checkpoint is uploaded.

    Args:
        train_dataloader: primary training loader (now actually used; the
            previous revision silently read the ``train_dataloader1`` global).
        val_dataloader: validation loader; each batch is its own retrieval gallery.
        model: embedding network; all tensors are moved to its device.
        criterion: loss called as ``criterion(anchor, positive, negative)``;
            reported only, not optimised.
        optimizer: optimiser stepping on the mined-triplet loss.
        save_path: NOTE(review): accepted but not used - checkpoints keep their
            historical ``epoch_<n>.pth`` names so existing ``wandb.restore``
            references stay valid.
        num_epochs: number of epochs.
        samples_per_class: divisor mapping sample indices to lines of the
            coordinate file.
        log_freq: print every ``log_freq`` steps; also passed to ``wandb.watch``.
        extra_train_dataloader: optional second training loader whose batches
            are concatenated with those of ``train_dataloader``. When omitted,
            a module-level ``train_dataloader2`` is used if one exists (the
            notebook's historical behaviour); otherwise training uses
            ``train_dataloader`` alone.
    """
    # Follow the model's own device rather than depending on a notebook global.
    model_device = next(model.parameters()).device

    if extra_train_dataloader is None:
        # Backward compatibility with the cell that defines train_dataloader2.
        extra_train_dataloader = globals().get('train_dataloader2')

    # Built once: both objects are stateless apart from per-call statistics.
    # NOTE(review): the W&B config logs the miner as 'semi-hard samples' while
    # "hard" triplets are mined here - confirm which is intended.
    miner_func = miners.TripletMarginMiner(margin=0.2, type_of_triplets="hard")
    loss_func = TripletMarginLoss(margin=0.2)
    top_k = (1, 5, 10)

    def mean_over(total, count):
        """Average that tolerates an empty loader / zero contributing steps."""
        return total / max(count, 1)

    def iter_train_batches():
        """Yield training batches, merging the two loaders when both are set."""
        if extra_train_dataloader is None:
            yield from train_dataloader
            return
        for batch_a, batch_b in zip(train_dataloader, extra_train_dataloader):
            yield [torch.cat((t_a, t_b), dim=0) for t_a, t_b in zip(batch_a, batch_b)]

    def process_batch(batch, train):
        """Forward one batch; optimise when ``train``. Returns plain numbers."""
        anchor, positive, negative, anchor_label, positive_label, _ = (
            t.to(model_device) for t in batch
        )

        if train:
            optimizer.zero_grad()

        # Call the module (not .forward) so registered hooks run.
        anchor_emb = model(anchor)      # b * output_size
        positive_emb = model(positive)  # b * output_size
        negative_emb = model(negative)  # b * output_size

        anchor_to_positive = F.pairwise_distance(anchor_emb, positive_emb)  # b
        anchor_to_negative = F.pairwise_distance(anchor_emb, negative_emb)  # b

        # Share of triplets whose positive is closer to the anchor than the negative.
        acc = (anchor_to_positive < anchor_to_negative).float().mean().item()
        default_loss = criterion(anchor_emb, positive_emb, negative_emb)

        # Mine triplets between anchors and the positives of the whole batch.
        mined = miner_func(anchor_emb, anchor_label, positive_emb, positive_label)
        loss = loss_func(anchor_emb, anchor_label, mined, positive_emb, positive_label)

        if train:
            loss.backward()
            optimizer.step()

        return loss.item(), default_loss.item(), acc, miner_func.num_triplets

    wandb.watch(model, log_freq=log_freq)

    for epoch in range(num_epochs):
        print(f'EPOCH #{epoch + 1} ---------------')

        # ---- training ------------------------------------------------------
        train_miner_loss, train_default_loss, train_acc, train_mined_samples = 0.0, 0.0, 0.0, 0
        train_steps, loss0_cnt = 0, 0
        model.train()
        for idx, batch in enumerate(iter_train_batches()):
            miner_loss, default_loss, acc, num_triplets = process_batch(batch, train=True)
            train_steps += 1
            train_miner_loss += miner_loss
            if miner_loss == 0.0:
                loss0_cnt += 1
            train_default_loss += default_loss
            train_mined_samples += num_triplets
            train_acc += acc
            if idx % log_freq == 0:
                print("Train\nStep #{}, miner_loss: {}, default_loss: {}, acc: {}\n".format((idx + 1) // log_freq, miner_loss, default_loss, acc))
                print("Triplets mined: {}".format(num_triplets))

        checkpoint_name = f'epoch_{epoch + 1}.pth'
        torch.save(model.state_dict(), checkpoint_name)

        # ---- validation losses ---------------------------------------------
        val_miner_loss, val_default_loss, val_acc, val_mined_samples = 0.0, 0.0, 0.0, 0
        val_steps = 0
        model.eval()
        for idx, batch in enumerate(val_dataloader):
            with torch.no_grad():
                miner_loss, default_loss, acc, num_triplets = process_batch(batch, train=False)
            val_steps += 1
            val_miner_loss += miner_loss
            val_default_loss += default_loss
            val_mined_samples += num_triplets
            val_acc += acc
            if idx % log_freq == 0:
                print("Test\nStep #{}, miner_loss: {}, default_loss: {}, acc: {}\n".format((idx + 1) // log_freq, miner_loss, default_loss, acc))
                print("Triplets mined: {}".format(num_triplets))

        # Train average is taken over the steps that had a non-zero mined loss.
        print('Avg train miner loss: ', mean_over(train_miner_loss, train_steps - loss0_cnt))
        print('Avg val miner loss: ', mean_over(val_miner_loss, val_steps))

        # ---- retrieval metrics ---------------------------------------------
        topk_sum = torch.zeros(len(top_k))
        distance_sum = 0.0
        samples_seen = 0
        for batch in val_dataloader:
            query, ref = batch[0].to(model_device), batch[1].to(model_device)

            with torch.no_grad():
                dist = torch.cdist(model(query), model(ref))  # [b, b]

            # Within a batch, query i matches reference i.
            local_true = torch.arange(query.shape[0], device=dist.device)
            local_pred = torch.argmin(dist, dim=1)

            batch_topk = accuracy(dist, local_true, top_k=top_k)
            print(batch_topk)
            topk_sum += batch_topk.detach().cpu()

            # Coordinates are looked up by dataset-wide index, so shift the
            # batch-local indices by the number of samples already evaluated.
            distance_sum += calculate_distance_between_coordinates(
                geo_data_path,
                ((local_true + samples_seen) // samples_per_class).cpu(),
                ((local_pred + samples_seen) // samples_per_class).cpu(),
            )
            samples_seen += query.shape[0]

        topk_mean = mean_over(topk_sum, val_steps)
        distance_mean = mean_over(distance_sum, val_steps)
        print('topk: ', topk_mean)
        print('distance error:', distance_mean)

        # ---- logging ---------------------------------------------------------
        wandb.save(checkpoint_name)
        wandb.log({"train_loss": mean_over(train_default_loss, train_steps),
                   "val_loss": mean_over(val_default_loss, val_steps),
                   "train_miner_loss": mean_over(train_miner_loss, train_steps),
                   "val_miner_loss": mean_over(val_miner_loss, val_steps),
                   "train_acc": mean_over(train_acc, train_steps),
                   "val_acc": mean_over(val_acc, val_steps),
                   "top1": topk_mean[0].item(),
                   "top5": topk_mean[1].item(),
                   "top10": topk_mean[2].item(),
                   "distance error": float(distance_mean),
                   "avg_train_mined_samples": mean_over(train_mined_samples, train_steps),
                   "avg_val_mined_samples": mean_over(val_mined_samples, val_steps)})




In [None]:
# Launch training with the default 50 epochs.
# NOTE(review): the last argument is bound to run()'s `save_path`, which run()
# does not use - checkpoints are written as epoch_<n>.pth in the working directory.
run(train_dataloader1, val_dataloader, model, criterion, optimizer, 'Models/timm_model')