GCL
===

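**Generalized Contrastive Learning for Multi-Modal Retrieval and Ranking**
* Paper: https://arxiv.org/abs/2404.08535
* Model: `Marqo/marqo-gcl-e5-large-v2-130`

This notebook encodes two product-search queries and two product passages with the model above, mean-pools the token states into one vector per text, and compares queries to passages by cosine similarity.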

![GCL Overview](../assets/gcl_overview.png)

In [1]:
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

# Use the CUDA GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def average_pool(
        last_hidden_states: Tensor,
        attention_mask: Tensor
    ) -> Tensor:
    """Mean-pool token embeddings over the non-padding positions.

    Args:
        last_hidden_states: Per-token hidden states, expected as
            ``(batch, seq_len, hidden)``; pooling is over dim 1.
        attention_mask: ``(batch, seq_len)`` mask, non-zero for real tokens
            and zero for padding.

    Returns:
        ``(batch, hidden)`` tensor with the mean of the unmasked token
        vectors of each sequence. A sequence whose mask is all zeros yields
        a zero vector instead of NaN.
    """
    # Zero out padded positions so they do not contribute to the sum.
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0
    )
    # Number of real tokens per sequence; clamp to 1 so an all-padding row
    # divides 0 by 1 (-> zeros) instead of 0 by 0 (-> NaN).
    token_counts = attention_mask.sum(dim=1).clamp(min=1)
    return last_hidden.sum(dim=1) / token_counts[..., None]


# Hugging Face model id shared by the tokenizer and the encoder so the two
# can never drift apart.
MODEL_NAME = 'Marqo/marqo-gcl-e5-large-v2-130'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Put the encoder in eval mode (inference only) and move it to the selected device.
model_new = AutoModel.from_pretrained(MODEL_NAME).eval().to(device)

In [2]:
# Each input text should start with "query: " or "passage: ".
# For tasks other than retrieval, you can simply use the "query: " prefix.
queries = [
    'query: Espresso Pitcher with Handle',
    'query: Women’s designer handbag sale'
]

passages = [
    "passage: Dianoo Espresso Steaming Pitcher, Espresso Milk Frothing Pitcher Stainless Steel",
    "passage: Coach Outlet Eliza Shoulder Bag - Black - One Size"
]

# Tokenize queries and passages as one padded batch. Queries come first, so
# rows [:len(queries)] are queries and the remaining rows are passages.
# Inputs longer than 77 tokens are truncated.
batch_dict = tokenizer(
    queries + passages,
    max_length=77,
    padding=True,
    truncation=True,
    return_tensors='pt'
).to(device)

# Inference only: disable autograd so no computation graph is built or kept alive.
with torch.no_grad():
    outputs = model_new(**batch_dict)
outputs.keys()

odict_keys(['last_hidden_state', 'pooler_output'])

In [3]:
# Shape of the per-token hidden states for the whole batch.
print(outputs.last_hidden_state.size())

torch.Size([4, 23, 1024])


In [4]:
# Collapse the token dimension: mean over non-padding tokens gives one vector per input text.
embeddings = average_pool(
    last_hidden_states=outputs.last_hidden_state,
    attention_mask=batch_dict['attention_mask'],
)
print(embeddings.size())

torch.Size([4, 1024])


In [5]:
# L2-normalise every embedding so that a plain dot product equals cosine similarity.
embeddings = F.normalize(embeddings, p=2, dim=1)

# The batch was tokenised with all queries first, followed by the passages.
query_embeddings, passage_embeddings = embeddings[:len(queries)], embeddings[len(queries):]

# scores[i][j] is the cosine similarity between query i and passage j.
scores = torch.matmul(query_embeddings, passage_embeddings.transpose(0, 1))
print([[round(score, 4) for score in row] for row in scores.tolist()])

[[0.6319, 0.1453], [0.0052, 0.7808]]
