Skip to content

ppriyank/-Online-Soft-Mining-and-Class-Aware-Attention-Pytorch

Repository files navigation

Online-Soft-Mining-and-Class-Aware-Attention

Implementation of Weighted Contrastive Loss from

Deep Metric Learning by Online Soft Mining and Class-Aware Attention (https://arxiv.org/pdf/1811.01459v2.pdf)
Xinshao Wang, Yang Hua1, Elyor Kodirov, Guosheng Hu, Neil M. Robertson

Use (person re-id) Pytorch:


criterion_osm_caa = OSM_CAA_Loss()
if use_gpu:
   imgs, pids = imgs.cuda(), pids.cuda()
imgs, pids = Variable(imgs), Variable(pids)
outputs, features = model(imgs)
if use_gpu:
  loss = criterion_osm_caa(features, pids , model.module.classifier.weight.t())         
else:
  loss = criterion_osm_caa(features, pids , model.classifier.weight.t())         

Tensorflow:

sess =tf.Session()
x = tf.random.uniform([32,200]) #(batch size= 32, embedding dim= 200)
embd  =tf.random.uniform([200, 10]) #(embedding dim= 200 , num of classes = 10)

loss = OSM_CAA_Loss()
osm_loss = loss.forward

loss_val = osm_loss(x , labels , embd)
sess.run(loss_val)

If you find any deviation from the paper, please let me know (raise issue) I will make the necessary changes.

Comments

dist refers to the pairwise distance between normalized feature vectors, of the shape n x n, dij = dist[i][j]

A refers to the pairwise attention score Aij = min(ai , aj)

About

(Pytorch and Tensorflow) Implementation of Weighted Contrastive Loss (Deep Metric Learning by Online Soft Mining and Class-Aware Attention)

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages