In [None]:
import os
from itertools import product
from random import shuffle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from KNN_Embeddings import *

In [None]:
# Feature normalization
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset, Dataset
import random

class TripletDataset(Dataset):
    """Serve (anchor, positive, negative) triplets drawn from labelled samples.

    Item ``index`` uses sample ``index`` as the anchor. The positive is a
    random *other* sample carrying the same label and the negative is a random
    sample from a randomly chosen different label.

    Args:
        embeddings: Indexable collection of samples; ``embeddings[i]`` is the
            i-th sample and ``len(embeddings)`` is the dataset size.
        labels: 1-D CPU ``torch.Tensor`` of class labels aligned with
            ``embeddings`` (it is converted with ``labels.numpy()``).
    """

    def __init__(self, embeddings, labels):
        self.embeddings = embeddings
        self.labels = labels
        labels_np = labels.numpy()
        self.labels_set = set(labels_np)
        # Maps each label to the array of sample positions that carry it.
        self.label_to_indices = {label: np.where(labels_np == label)[0]
                                 for label in self.labels_set}
        # Built once so __getitem__ does not rebuild the candidate list for
        # every single sample.
        self._negative_labels = {
            label: [other for other in self.labels_set if other != label]
            for label in self.labels_set
        }

    def __getitem__(self, index):
        """Return ``(anchor, positive, negative, anchor_label, negative_label)``.

        When the anchor is the only sample of its class, the anchor itself is
        returned as the positive (previously this case looped forever).

        Raises:
            ValueError: If the dataset holds no label other than the anchor's,
                so that no negative sample can be drawn.
        """
        anchor = self.embeddings[index]
        anchor_label = self.labels[index].item()

        same_class = self.label_to_indices[anchor_label]
        positive_index = index
        if len(same_class) > 1:
            # Rejection-sample until we hit a different member of the class.
            # Indexing via randrange avoids relying on random.choice accepting
            # NumPy arrays.
            while positive_index == index:
                positive_index = same_class[random.randrange(len(same_class))]

        negative_candidates = self._negative_labels[anchor_label]
        if not negative_candidates:
            raise ValueError(
                "TripletDataset needs at least two distinct labels to sample a negative"
            )
        negative_label = random.choice(negative_candidates)
        other_class = self.label_to_indices[negative_label]
        negative_index = other_class[random.randrange(len(other_class))]

        positive = self.embeddings[positive_index]
        negative = self.embeddings[negative_index]
        return anchor, positive, negative, anchor_label, negative_label

    def __len__(self):
        """Number of anchors, i.e. the number of samples."""
        return len(self.embeddings)


# Standardise the features. The scaler statistics are estimated on the
# training split only and then reused unchanged for the test split.
scaler = StandardScaler()
X_train_normalized = scaler.fit(X_train).transform(X_train)
X_test_normalized = scaler.transform(X_test)

# Features become float tensors, labels become long (integer) tensors.
X_train_tensor = torch.tensor(X_train_normalized, dtype=torch.float)
X_test_tensor = torch.tensor(X_test_normalized, dtype=torch.float)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

# Wrap each split so that every item is served as a triplet.
train_dataset = TripletDataset(X_train_tensor, y_train_tensor)
test_dataset = TripletDataset(X_test_tensor, y_test_tensor)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
)



## Shared Sub Network

In [None]:
class SharedSubNet(nn.Module):
    """Convolutional feature extractor producing 50 features per sample.

    The input goes through two 5x5 convolutions (with one 2x2 max-pool after
    the first), a 1x1 convolution down to 50 channels, and a global average
    pool; the result is flattened to ``(batch, 50)``. ``conv1`` expects
    3 input channels.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32,
                               kernel_size=5, stride=1, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64,
                               kernel_size=5, stride=1, padding=2)
        # 1x1 convolution: mixes channels without touching spatial layout.
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=50,
                               kernel_size=1, stride=1)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return the flattened, globally pooled feature map for ``x``."""
        features = self.pool(F.relu(self.conv1(x)))
        for conv in (self.conv2, self.conv3):
            features = F.relu(conv(features))
        pooled = self.global_pool(features)
        # Collapse everything after the first dimension into one axis.
        return pooled.flatten(start_dim=1)

## Divide and Encode module

In [None]:
class DivideAndEncode(nn.Module):
    """Encode a feature vector with ``num_bits`` independent linear heads.

    Every head is ``nn.Linear(num_slices, bits_per_slice)`` and receives the
    full input, so ``x`` must have ``num_slices`` features. Each head output is
    squashed with a sigmoid and passed through the piecewise threshold; the
    head outputs are concatenated along dim 1, giving
    ``num_bits * bits_per_slice`` columns.
    """

    def __init__(self, num_slices=50, bits_per_slice=1, num_bits=10):
        super().__init__()
        self.num_slices = num_slices
        self.bits_per_slice = bits_per_slice
        heads = [nn.Linear(num_slices, bits_per_slice) for _ in range(num_bits)]
        self.fc_layers = nn.ModuleList(heads)

    def forward(self, x):
        """Return the concatenated, thresholded sigmoid outputs of all heads."""
        encoded = [
            self.piecewise_threshold(torch.sigmoid(head(x)))  # sigmoid keeps values in [0, 1]
            for head in self.fc_layers
        ]
        return torch.cat(encoded, dim=1)

    def piecewise_threshold(self, s, epsilon=0.05):
        """Snap values outside ``[0.5 - epsilon, 0.5 + epsilon]`` to 0 or 1.

        Values below the band become 0, values above it become 1, and values
        inside the band are returned unchanged.
        """
        lower, upper = 0.5 - epsilon, 0.5 + epsilon
        high_saturated = torch.where(s > upper, torch.ones_like(s), s)
        return torch.where(s < lower, torch.zeros_like(s), high_saturated)


## Triplet Ranking Loss

In [None]:
class TripletRankingLoss(nn.Module):
    """Hinge loss on squared Euclidean distances of a triplet.

    For each row the loss is ``max(0, d(a, p) - d(a, n) + margin)`` where
    ``d`` is the squared L2 distance summed over dim 1; the batch mean is
    returned.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    @staticmethod
    def _squared_distance(left, right):
        """Row-wise squared L2 distance between two equally shaped batches."""
        return (left - right).pow(2).sum(1)

    def forward(self, anchor, positive, negative):
        """Return the mean triplet hinge loss over the batch."""
        to_positive = self._squared_distance(anchor, positive)
        to_negative = self._squared_distance(anchor, negative)
        per_triplet = F.relu(to_positive - to_negative + self.margin)
        return per_triplet.mean()


## Training

In [None]:
# class SFHC(nn.Module):
#     def __init__(self, input_dim, num_bits):
#         super(SFHC, self).__init__()
#         self.fc1 = nn.Linear(input_dim, 512)
#         self.relu = nn.ReLU()
#         self.fc2 = nn.Linear(512, num_bits)
    
#     def forward(self, x):
#         x = self.fc1(x)
#         x = self.relu(x)
#         x = self.fc2(x)
#         x = torch.tanh(x)  # Ensuring outputs are bounded between -1 and 1
#         x = F.normalize(x, p=2, dim=1)  # L2 normalization
#         return x
    
class IntegratedModel(nn.Module):
    """End-to-end hashing network: shared sub-net -> divide-and-encode -> MLP head.

    The input is passed to ``SharedSubNet``, whose 50 output features feed
    ``DivideAndEncode``. The encoded vector then goes through a
    ``Linear -> ReLU -> Linear`` head, a ``tanh``, and an L2 normalisation
    along dim 1.

    Args:
        num_bits: Number of encoder heads and width of the returned code.
    """

    def __init__(self, num_bits):
        super(IntegratedModel, self).__init__()
        bits_per_slice = 1
        self.shared_subnet = SharedSubNet()
        self.divide_and_encode = DivideAndEncode(
            num_slices=50, bits_per_slice=bits_per_slice, num_bits=num_bits
        )
        # The encoder emits num_bits * bits_per_slice features. The previous
        # hard-coded 24 only matched num_bits == 24 and raised a shape
        # mismatch for every other code length.
        self.fc = nn.Linear(num_bits * bits_per_slice, 512)
        self.relu = nn.ReLU()
        self.final_fc = nn.Linear(512, num_bits)  # one output feature per code bit

    def forward(self, x):
        """Return unit-norm codes of shape ``(batch, num_bits)`` for ``x``."""
        x = self.shared_subnet(x)
        x = self.divide_and_encode(x)
        x = self.fc(x)
        x = self.relu(x)
        x = self.final_fc(x)
        x = torch.tanh(x)  # bound every component to (-1, 1)
        x = F.normalize(x, p=2, dim=1)  # scale each row to unit L2 norm
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model, optimizer, and loss.
# IntegratedModel.__init__ only accepts `num_bits`; the former `input_dim=...`
# keyword raised a TypeError before training could start.
# NOTE(review): SharedSubNet begins with Conv2d(3, ...), so the model needs
# image-shaped batches, whereas train_loader appears to yield rows of the
# 2-D X_train_tensor. Confirm how the features are meant to reach the model.
model = IntegratedModel(num_bits=12).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = TripletRankingLoss()

for epoch in range(10):
    model.train()
    total_loss = 0.0
    # The two label entries of each batch are not needed for the loss.
    for anchor, positive, negative, _, _ in train_loader:
        anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
        optimizer.zero_grad()
        # The same network embeds all three members of the triplet.
        anchor_output = model(anchor)
        positive_output = model(positive)
        negative_output = model(negative)
        loss = loss_func(anchor_output, positive_output, negative_output)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Average of the per-batch mean losses over the epoch.
    print(f'Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}')

## KNN on hashed embeddings

In [None]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import label_binarize
import numpy as np

def evaluate_model(model, test_loader, device):
    """Embed every anchor served by ``test_loader`` and collect its label.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: Callable module mapping an anchor batch to an embedding batch.
        test_loader: Iterable of 5-tuples whose first element is the anchor
            batch and whose fourth element is the anchor-label tensor.
        device: Device the anchor batches are moved to before the forward pass.

    Returns:
        ``(embeddings, labels)``: the per-batch outputs (moved to CPU) and the
        per-batch labels, each concatenated along dim 0.
    """
    model.eval()
    code_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in test_loader:
            # Only the anchor and its label are used; the rest is discarded.
            anchor_batch, _positive, _negative, anchor_labels, _negative_labels = batch
            codes = model(anchor_batch.to(device))
            code_chunks.append(codes.cpu())
            label_chunks.append(anchor_labels)
    return torch.cat(code_chunks), torch.cat(label_chunks)

# `model` and `device` must already be defined by the training cell above.

# Extract hash codes
train_codes, train_labels = evaluate_model(model, train_loader, device)
test_codes, test_labels = evaluate_model(model, test_loader, device)

# scikit-learn operates on NumPy arrays, so convert the CPU tensors explicitly
# instead of relying on implicit tensor -> array coercion.
train_codes_np, train_labels_np = train_codes.numpy(), train_labels.numpy()
test_codes_np, test_labels_np = test_codes.numpy(), test_labels.numpy()

# Classification with KNN
knn = KNeighborsClassifier(n_neighbors=5, metric='euclidean')
knn.fit(train_codes_np, train_labels_np)
predictions = knn.predict(test_codes_np)
y_pred_proba = knn.predict_proba(test_codes_np)

print(classification_report(test_labels_np, predictions))

# One-vs-rest indicator matrix with exactly one column per entry of
# `knn.classes_`, i.e. aligned with the columns of `y_pred_proba`.
# `label_binarize` is not used here because for two classes it returns a single
# column (for the second class), which the loop below would have paired with
# the probability of the first class.
y_test_binarized = (test_labels_np[:, None] == knn.classes_[None, :]).astype(int)

# Calculate the average precision for each class
average_precisions = []
for i in range(y_test_binarized.shape[1]):  # iterate over classes
    if not y_test_binarized[:, i].any():
        # Average precision is undefined without any positive test sample.
        continue
    average_precisions.append(average_precision_score(y_test_binarized[:, i], y_pred_proba[:, i]))

# Compute the mean of the average precisions
map_score = np.mean(average_precisions) if average_precisions else float('nan')
print(f'Mean Average Precision (MAP): {map_score}')
