### InfoNCE Contrastive Loss
https://arxiv.org/pdf/1807.03748.pdf

InfoNCE loss implementation:
- Anchor with single positive and multiple negatives.
- Positive is most similar 
- Negatives are least similar

In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

In [2]:
from utils.returns_data_class import ReturnsData
PERIOD = 20
data = ReturnsData(
    daily_returns_path="../Data/returns_df_611.csv",
    extras_path="../Data/historical_stocks.csv",
)
data.change_returns_period(PERIOD)
X = data.returns_df.values.T

In [3]:
X.shape

(611, 234)

In [19]:
num_TS = X.shape[0]
num_pos_samples = 1
period = 10
stride = 3
num_neg_samples = 10

print(f"Context Size: {num_pos_samples}, Period: {period}, Stride: {stride}")
print(f"Number of Negative Samples: {num_neg_samples}")

from utils.context import get_tgt_context_euclidean_multiprocess
positive_tgt_context_sets = get_tgt_context_euclidean_multiprocess(ts_array=X, m=period, k=num_pos_samples, stride=stride, z_normalize=False, verbose=False)
negative_tgt_context_sets = get_tgt_context_euclidean_multiprocess(ts_array=X, m=period, k=num_neg_samples, stride=stride, z_normalize=False, top_k=False, verbose=False)
print(f"Number (anchor, positive, negative) samples: {len(positive_tgt_context_sets)}")

Context Size: 1, Period: 10, Stride: 3
Number of Negative Samples: 10
Number (anchor, positive, negative) samples: 45825


In [20]:
index_samples = []
for pos, neg in zip(positive_tgt_context_sets,negative_tgt_context_sets):
    index_samples.append((pos[0], pos[1][0], neg[1]))

In [21]:
import torch
from torch.utils.data import Dataset, DataLoader


class NPairDataset(Dataset):
    def __init__(self, index_samples):
        self.index_samples = index_samples

    def __len__(self):
        return len(self.index_samples)

    def __getitem__(self, idx):
        anchor_idx, positive_idx, negative_indices = self.index_samples[idx]
        # Convert negative indices list to tensor
        negative_indices_tensor = torch.tensor(negative_indices)
        return anchor_idx, positive_idx, negative_indices_tensor


def normalize_embeddings(embeddings):
    with torch.no_grad():  # Ensure no gradients are computed in this block
        norms = embeddings.weight.norm(dim=1, keepdim=True)
        embeddings.weight.data = embeddings.weight.data / norms
    return embeddings


def info_nce_loss(
    anchor_embeddings: torch.Tensor,
    positive_embeddings: torch.Tensor,
    negative_embeddings: torch.Tensor,
) -> torch.Tensor:
    """See Equation 6.20 in Bishop Deep Learning (called NCE)

    Args:
        anchor_embeddings (torch.Tensor): shape (batch_size, embedding_dim)
        positive_embeddings (torch.Tensor): (batch_size, embedding_dim)
        negative_embeddings (torch.Tensor): (batch_size, num_negative_samples, embedding_dim)

    Returns:
        torch.Tensor: _description_
    """
    numerator = torch.exp(
        torch.einsum("bd,bd->b", anchor_embeddings, positive_embeddings)
    )
    denominator = numerator + torch.sum(
        torch.exp(torch.einsum("bd,bnd->bn", anchor_embeddings, negative_embeddings)),
        dim=1,
    )
    loss = torch.sum(-torch.log(numerator / denominator))
    return loss

In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Parameters
num_items = num_TS  # Number of items in your dataset
embedding_size = 16  # Size of each embedding
learning_rate = 0.01
epochs = 100
batch_size = 2048

# Initialize the embedding matrix
embeddings = nn.Embedding(num_embeddings=num_items, embedding_dim=embedding_size)

# Optimizer
optimizer = optim.Adam(embeddings.parameters(), lr=learning_rate)

# Prepare your index_samples as a tensor
# index_samples = torch.tensor([(35, 12, 98), (47, 12, 4), ...])  # Your index tuples as a tensor
# Create the dataset and data loader
npair_dataset = NPairDataset(index_samples)
data_loader = DataLoader(npair_dataset, batch_size=batch_size, shuffle=True)

# Training loop with batching
for epoch in range(epochs):
    total_loss = 0
    # normalise the embeddings to prevent degenerate solution
    embeddings = normalize_embeddings(embeddings)

    for anchor_idx, positive_idx, negative_indices in data_loader:

        # Get the embeddings for anchor, positive, and negative
        anchor_embeddings = embeddings(anchor_idx) # shape: (batch_size, embedding_dim)
        positive_embeddings = embeddings(positive_idx) # shape: (batch_size, embedding_dim)
        negative_embeddings = embeddings(negative_indices) # shape: (batch_size, num_neg_samples, embedding_dim)

        # Compute the loss
        loss = info_nce_loss(anchor_embeddings, positive_embeddings, negative_embeddings)
        total_loss += loss.item()

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if (epoch % 10 == 0) | (epoch==epochs-1):
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(data_loader)}')

embeddings = normalize_embeddings(embeddings)

Epoch [1/100], Loss: 4740.7499203889265
Epoch [11/100], Loss: 2416.5103759765625
Epoch [21/100], Loss: 2421.149169921875
Epoch [31/100], Loss: 2415.4263332201085
Epoch [41/100], Loss: 2418.4303509256115
Epoch [51/100], Loss: 2412.4856142790422
Epoch [61/100], Loss: 2418.1203002929688
Epoch [71/100], Loss: 2415.030082370924


KeyboardInterrupt: 

### Get training data

In [24]:
from utils.sector_classification import get_sector_score

get_sector_score(embeddings.weight.detach().numpy(), sectors=data.sectors, top_k_accuracy=True)


Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.


Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.



Precision Score: 0.41
Recall Score: 0.28
F1 Score: 0.26
Accuracy Score: 0.28
Accuracy Score Top-3: 0.6



Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.



In [25]:
from utils.visualisation_functions import pca_plot_from_embeddings

pca_plot_from_embeddings(
    embedding_matrix=embeddings.weight.detach().numpy(),
    sectors=data.sectors,
    tickers=data.tickers,
    industries=data.industries,
    names=data.names,
    dimensions=2,
    reduced=True,
    method="PCA",
    return_df=False,
    rand_state=None,
)