In [9]:
from typing import List

import torch
import torch.nn as nn
from typing import Type, Any, Callable, Union, List, Optional

import faiss
import numpy as np

In [43]:
# !wget https://huggingface.co/stanfordnlp/glove/resolve/main/glove.6B.zip?download=true

In [3]:
# !unzip glove.6B.zip

/bin/bash: line 1: unzip: command not found


In [4]:
# import zipfile
# with zipfile.ZipFile('glove.6B.zip', 'r') as zip_ref:
#     zip_ref.extractall('.')

In [5]:
# Read the glove file
def read_glove_file(glove_file):
    counter = 0
    with open(glove_file, 'r') as f:
        word_to_vec = {}
        for line in f:
            line = line.strip().split()
            word = line[0]
            vec = list(map(float, line[1:]))
            if word.isalpha():
                word_to_vec[word] = vec
            counter += 1
            if counter == 100:
                break
    return word_to_vec


file = 'glove.6B.50d.txt'
word_to_vec = read_glove_file(file)

In [6]:
word_to_vec.keys()

dict_keys(['the', 'of', 'to', 'and', 'in', 'a', 'for', 'that', 'on', 'is', 'was', 'said', 'with', 'he', 'as', 'it', 'by', 'at', 'from', 'his', 'an', 'be', 'has', 'are', 'have', 'but', 'were', 'not', 'this', 'who', 'they', 'had', 'i', 'which', 'will', 'their', 'or', 'its', 'one', 'after', 'new', 'been', 'also', 'we', 'would', 'two', 'more', 'first', 'about', 'up', 'when', 'year', 'there', 'all', 'out', 'she', 'other', 'people', 'her', 'percent', 'than', 'over', 'into', 'last', 'some', 'government', 'time', 'you', 'years', 'if', 'no', 'world', 'can', 'three', 'do', 'president', 'only', 'state', 'million', 'could', 'us', 'most', 'against'])

In [7]:
embeddings_list = [word_to_vec[key] for key in word_to_vec.keys()]
word_list = list(word_to_vec.keys())

# Map word to a integer
word_to_int = {word: i for i, word in enumerate(word_list)}
int_to_word = {i: word for i, word in enumerate(word_list)}

In [10]:
# Convert the embeddings list to tensor
embeddings_tensor = torch.tensor(embeddings_list)
# Convert the word list to tensor
word_tensor = torch.tensor([word_to_int[word] for word in word_list])

In [14]:
# Assuming embeddings_list is a list of embeddings
embeddings_np = np.array(embeddings_list)

# Now you can convert its dtype to float32 if needed
embeddings_np = embeddings_np.astype("float32")

In [16]:
# embeddings = embeddings_tensor.numpy()

# create a random matrix of 50x50 with float32
embeddings = embeddings_np

# check for float32
assert embeddings.dtype == 'float32'

dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)

k = 5  # Number of nearest neighbors to retrieve

In [17]:
query_embedding = embeddings[0]
query_embedding = query_embedding.astype('float32')
query_embedding = query_embedding.reshape(1, -1)
query_embedding.shape

(1, 50)

In [18]:
k = 5  # Number of nearest neighbors to retrieve
distances, indices = index.search(query_embedding, k)

In [19]:
distances, indices

(array([[0.       , 3.7542362, 4.7945795, 4.809632 , 4.9880333]],
       dtype=float32),
 array([[ 0, 33,  1,  4,  8]]))

In [20]:
nearest_words = [int_to_word[idx] for idx in indices[0]]
nearest_words

['the', 'which', 'of', 'in', 'on']

These are the nearest words to the word 'the' in the glove embeddings. So when we finish the training, we need these words to also be the same as the words in the embeddings

----

In [21]:
# -*- coding: utf-8 -*-
"""
Directly borrowed from https://github.com/RAIVNLab/MRL/blob/main/MRL.py
"""

'''
Loss function for Matryoshka Representation Learning 
'''

class Matryoshka_CE_Loss(nn.Module):
	def __init__(self, relative_importance: List[float]=None, **kwargs):
		super(Matryoshka_CE_Loss, self).__init__()
		self.criterion = nn.CrossEntropyLoss(**kwargs)
		# relative importance shape: [G]/
		self.relative_importance = relative_importance

	def forward(self, output, target):
		# output shape: [G granularities, N batch size, C number of classes]
		# target shape: [N batch size]

		# Calculate losses for each output and stack them. This is still O(N)
		losses = torch.stack([self.criterion(output_i, target) for output_i in output])
		
		# Set relative_importance to 1 if not specified
		rel_importance = torch.ones_like(losses) if self.relative_importance is None else torch.tensor(self.relative_importance)
		
		# Apply relative importance weights
		weighted_losses = rel_importance * losses
		return weighted_losses.sum()

class MRL_Linear_Layer(nn.Module):
	def __init__(self, nesting_list: List, num_classes=1000, efficient=False, **kwargs):
		super(MRL_Linear_Layer, self).__init__()
		self.nesting_list = nesting_list
		self.num_classes = num_classes # Number of classes for classification
		self.efficient = efficient
		if self.efficient:
			setattr(self, f"nesting_classifier_{0}", nn.Linear(nesting_list[-1], self.num_classes, **kwargs))		
		else:	
			for i, num_feat in enumerate(self.nesting_list):
				setattr(self, f"nesting_classifier_{i}", nn.Linear(num_feat, self.num_classes, **kwargs))	

	def reset_parameters(self):
		if self.efficient:
			self.nesting_classifier_0.reset_parameters()
		else:
			for i in range(len(self.nesting_list)):
				getattr(self, f"nesting_classifier_{i}").reset_parameters()


	def forward(self, x):
		nesting_logits = ()
		for i, num_feat in enumerate(self.nesting_list):
			if self.efficient:
				if self.nesting_classifier_0.bias is None:
					nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()), )
				else:
					nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()) + self.nesting_classifier_0.bias, )
			else:
				nesting_logits +=  (getattr(self, f"nesting_classifier_{i}")(x[:, :num_feat]),)

		return nesting_logits


class FixedFeatureLayer(nn.Linear):
    '''
    For our fixed feature baseline, we just replace the classification layer with the following. 
    It effectively just look at the first "in_features" for the classification. 
    '''

    def __init__(self, in_features, out_features, **kwargs):
        super(FixedFeatureLayer, self).__init__(in_features, out_features, **kwargs)

    def forward(self, x):
        if not (self.bias is None):
            out = torch.matmul(x[:, :self.in_features], self.weight.t()) + self.bias
        else:
            out = torch.matmul(x[:, :self.in_features], self.weight.t())
        return out

In [22]:
from torch.utils.data import TensorDataset, DataLoader

# Create a dataset and dataloader
dataset = TensorDataset(embeddings_tensor, word_tensor)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [39]:
model = nn.Sequential(
    nn.Linear(50, 50),  # Reduce dimension from 50 to 25
    MRL_Linear_Layer(nesting_list=[25], num_classes=1000, efficient=True),
)

# Define the loss function
criterion = Matryoshka_CE_Loss()

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train the model
n_epochs = 1000
for epoch in range(n_epochs):
    for embeddings, word in data_loader:
        # Forward pass
        outputs = model(embeddings)
        loss = criterion(outputs, word)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{n_epochs}], Loss: {loss.item():.4f}')

Epoch [100/1000], Loss: 2.7660
Epoch [200/1000], Loss: 0.8505
Epoch [300/1000], Loss: 0.3562
Epoch [400/1000], Loss: 0.1535
Epoch [500/1000], Loss: 0.0628
Epoch [600/1000], Loss: 0.0391
Epoch [700/1000], Loss: 0.0299
Epoch [800/1000], Loss: 0.0146
Epoch [900/1000], Loss: 0.0152
Epoch [1000/1000], Loss: 0.0102


In [40]:
# Build the index, again
embeddings = embeddings_tensor.numpy()
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)

# Get the nearest neighbors
query_embedding = embeddings[0]
query_embedding = query_embedding.astype('float32')
query_embedding = query_embedding.reshape(1, -1)
distances, indices = index.search(query_embedding, k)

# Get the words for the indices
nearest_words = [int_to_word[idx] for idx in indices[0]]
nearest_words

['the', 'which', 'of', 'in', 'on']

In [42]:
# Now, we will only use the first 25 dimensions of the embeddings
embeddings = embeddings_tensor[:, :25].numpy()
print(embeddings.shape)
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)

# Get the nearest neighbors
query_embedding = embeddings[0]
query_embedding = query_embedding.astype('float32')
query_embedding = query_embedding.reshape(1, -1)
distances, indices = index.search(query_embedding, k)

# Get the words for the indices
nearest_words = [int_to_word[idx] for idx in indices[0]]
nearest_words

(83, 25)


['the', 'also', 'on', 'one', 'as']