In [None]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

# Reproducibility: triplet sampling draws from NumPy's global RNG, while weight
# initialisation and DataLoader shuffling draw from torch's, so seed them all
# before anything random happens.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Data transformation: convert images to tensors, then normalise each of the
# three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset (downloaded into ./data when not already present)
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Custom dataset to load triplets
class TripletDataset(Dataset):
    """Serve (anchor, positive, negative) triplets drawn from a labelled dataset.

    The wrapped ``dataset`` must expose ``targets`` (one label per sample),
    ``data`` and integer indexing that returns ``(image, label)``.

    Each item is the flat 5-tuple
    ``(anchor_img, positive_img, negative_img, anchor_label, negative_label)``,
    which is the layout the training and evaluation loops in this notebook
    unpack. The positive shares the anchor's label and the negative comes from
    a different, randomly chosen label. Sampling uses NumPy's global RNG.

    Raises:
        ValueError: from ``np.random.choice`` when the dataset holds only a
            single class, because no negative label can be drawn.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.train_labels = np.array(dataset.targets)
        self.train_data = dataset.data
        self.labels_set = set(self.train_labels)
        # Pre-compute, per label, the indices of all samples carrying it.
        self.label_to_indices = {label: np.where(self.train_labels == label)[0]
                                 for label in self.labels_set}

    def __getitem__(self, index):
        img1, label1 = self.dataset[index]

        # Positive: another sample of the same class. When the anchor is the
        # only member of its class there is no "other" sample, so reuse the
        # anchor instead of rejection-sampling forever.
        same_class = self.label_to_indices[label1]
        positive_index = index
        if len(same_class) > 1:
            while positive_index == index:
                positive_index = np.random.choice(same_class)

        # Negative: any sample of a uniformly chosen different class.
        negative_label = np.random.choice(list(self.labels_set - {label1}))
        negative_index = np.random.choice(self.label_to_indices[negative_label])

        img2 = self.dataset[positive_index][0]
        img3 = self.dataset[negative_index][0]
        # Flat tuple so the default collate function yields five batched fields.
        return img1, img2, img3, int(label1), int(negative_label)

    def __len__(self):
        return len(self.dataset)

# DataLoaders for training and evaluation. Both wrap TripletDataset so every
# batch has the same layout; the evaluation loader keeps dataset order.
train_loader = DataLoader(TripletDataset(train_dataset), batch_size=64, shuffle=True)
test_loader = DataLoader(TripletDataset(test_dataset), batch_size=64, shuffle=False)


In [None]:
class DNNH(nn.Module):
    """Small convolutional hashing network emitting ``hash_bits`` values in [-1, 1].

    Two conv/ReLU/max-pool stages feed a two-layer fully connected head with
    tanh activations. The head's input width is fixed at ``64 * 8 * 8``, i.e.
    it assumes the convolutional stack ends in a (64, 8, 8) feature map.

    Args:
        hash_bits: number of output units; the hidden layer is 50 times wider.
    """

    def __init__(self, hash_bits):
        super().__init__()

        conv_stack = [
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        flat_width = 64 * 8 * 8  # Assuming output from CNN is (64, 8, 8)
        hidden_width = 50 * hash_bits
        self.divide_and_encode = nn.Sequential(
            nn.Linear(flat_width, hidden_width),
            nn.Tanh(),
            nn.Linear(hidden_width, hash_bits),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return a (batch, hash_bits) tensor of tanh activations for image batch ``x``."""
        feature_maps = self.features(x)
        batch_size = feature_maps.size(0)
        flattened = feature_maps.view(batch_size, -1)
        return self.divide_and_encode(flattened)

# Build a DNNH instance that emits 24-bit codes
hash_bits = 24
model = DNNH(hash_bits)




## Shared Sub Network

In [None]:
class SharedSubNet(nn.Module):
    """Convolutional feature extractor that maps an RGB image batch to 50 features.

    Three blocks of (spatial conv -> 1x1 conv -> stride-2 max-pool) are followed
    by a final pair of convolutions without pooling and a global average pool.
    Every convolution is followed by ReLU. The output has shape (batch, 50).
    """

    def __init__(self):
        super().__init__()

        # Block 1: large 5x5 filter with stride 1 to suit small CIFAR-10 images
        self.conv1 = nn.Conv2d(3, 96, kernel_size=5, stride=1, padding=2)  # RGB input
        self.conv2 = nn.Conv2d(96, 96, kernel_size=1, stride=1, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Block 2
        self.conv3 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Block 3
        self.conv5 = nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(384, 384, kernel_size=1, stride=1, padding=0)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Block 4: no pooling, so the feature map does not shrink any further
        self.conv7 = nn.Conv2d(384, 1024, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(1024, 50, kernel_size=1, stride=1, padding=0)  # 50 filters, chosen arbitrarily

        # Collapse each of the 50 channels to a single value
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Run the four blocks and return the globally pooled (batch, 50) features."""
        pooled_blocks = (
            (self.conv1, self.conv2, self.pool1),
            (self.conv3, self.conv4, self.pool2),
            (self.conv5, self.conv6, self.pool3),
        )
        for spatial_conv, pointwise_conv, pool in pooled_blocks:
            x = F.relu(spatial_conv(x))
            x = F.relu(pointwise_conv(x))
            x = pool(x)

        x = F.relu(self.conv8(F.relu(self.conv7(x))))

        pooled = self.global_pool(x)
        return pooled.view(pooled.size(0), -1)

## Divide and Encode module

In [None]:
class DivideAndEncode(nn.Module):
    """Bank of small linear heads, each producing a thresholded sigmoid output.

    ``num_bits`` independent ``Linear(num_slices, bits_per_slice)`` heads are
    applied to the same input; their outputs are concatenated along dim 1,
    giving ``num_bits * bits_per_slice`` columns.

    Args:
        num_slices: input feature width expected by every head.
        bits_per_slice: output width of every head.
        num_bits: number of heads.
    """

    def __init__(self, num_slices=50, bits_per_slice=1, num_bits=24):
        super().__init__()
        self.num_slices = num_slices
        self.bits_per_slice = bits_per_slice
        heads = [nn.Linear(num_slices, bits_per_slice) for _ in range(num_bits)]
        self.fc_layers = nn.ModuleList(heads)

    def forward(self, x):
        """Apply every head to ``x`` and concatenate the thresholded activations."""
        encoded = [
            # Sigmoid squashes to [0, 1] before the piecewise threshold
            self.piecewise_threshold(F.sigmoid(head(x)))
            for head in self.fc_layers
        ]
        return torch.cat(encoded, dim=1)

    def piecewise_threshold(self, s, epsilon=0.05):
        """Snap values below ``0.5 - epsilon`` to 0 and above ``0.5 + epsilon`` to 1.

        Values inside the band are returned unchanged.
        """
        lower = 0.5 - epsilon
        upper = 0.5 + epsilon
        # The two conditions are mutually exclusive, so order does not matter.
        snapped_high = torch.where(s > upper, torch.ones_like(s), s)
        return torch.where(s < lower, torch.zeros_like(s), snapped_high)


## Triplet Ranking Loss

In [None]:
class TripletRankingLoss(nn.Module):
    """Hinge loss on squared-L2 distances between anchor, positive and negative.

    Per sample the loss is ``max(0, d(a, p) - d(a, n) + margin)`` where ``d``
    is the squared Euclidean distance summed over dim 1; the batch mean is
    returned.

    Args:
        margin: required gap between negative and positive distances.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the mean triplet hinge loss over the batch as a scalar tensor."""
        sq_dist_pos = torch.square(anchor - positive).sum(dim=1)
        sq_dist_neg = torch.square(anchor - negative).sum(dim=1)
        violation = sq_dist_pos - sq_dist_neg + self.margin
        per_sample = F.relu(violation)
        return per_sample.mean()


In [None]:
class IntegratedModel(nn.Module):
    """End-to-end hashing network: shared conv features -> divide-and-encode -> FC head.

    The ``SharedSubNet`` features are passed through ``DivideAndEncode`` to get
    ``num_bits`` values, which a two-layer fully connected head maps back to
    ``num_bits`` outputs. The result goes through tanh and is L2-normalised
    along dim 1.

    Args:
        num_bits: length of the produced code.
    """

    def __init__(self, num_bits):
        super(IntegratedModel, self).__init__()
        self.shared_subnet = SharedSubNet()
        self.divide_and_encode = DivideAndEncode(num_slices=50, bits_per_slice=1, num_bits=num_bits)
        # Input width follows num_bits (DivideAndEncode emits one value per bit)
        # instead of a hard-coded 24, so other code lengths work too.
        self.fc = nn.Linear(num_bits, 512)
        self.relu = nn.ReLU()
        self.final_fc = nn.Linear(512, num_bits)  # Output bits as the number of output features

    def forward(self, x):
        """Return unit-norm (batch, num_bits) codes for the image batch ``x``."""
        x = self.shared_subnet(x)
        x = self.divide_and_encode(x)
        x = self.fc(x)
        x = self.relu(x)
        x = self.final_fc(x)
        x = torch.tanh(x)
        x = F.normalize(x, p=2, dim=1)
        return x


## Training

In [None]:
# Use the GPU when one is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Model, optimizer, and loss
model = IntegratedModel(num_bits=24).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = TripletRankingLoss()

for epoch in range(10):
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        # Each batch carries three image tensors plus two fields unused here
        anchor, positive, negative, _, _ = batch
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)

        optimizer.zero_grad()
        # Same network embeds all three images, in anchor/positive/negative order
        anchor_output, positive_output, negative_output = (
            model(images) for images in (anchor, positive, negative)
        )
        loss = loss_func(anchor_output, positive_output, negative_output)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report the mean per-batch loss for the epoch
    print(f'Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}')

## KNN on hashed embeddings

In [None]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import label_binarize
import numpy as np

def evaluate_model(model, test_loader, device):
    """Embed the anchor images of every batch and collect their labels.

    Args:
        model: module put into eval mode and called on each anchor batch.
        test_loader: iterable of 5-field batches; field 0 is the anchor
            tensor and field 3 its label tensor.
        device: device the anchors are moved to before the forward pass.

    Returns:
        ``(embeddings, labels)``: the per-batch model outputs (moved to CPU)
        and the per-batch labels, each concatenated along dim 0.
    """
    model.eval()
    embedding_batches, label_batches = [], []
    with torch.no_grad():  # inference only, no autograd bookkeeping
        for anchor, _, _, label_a, _ in test_loader:
            batch_embedding = model(anchor.to(device))
            embedding_batches.append(batch_embedding.cpu())
            label_batches.append(label_a)
    all_embeddings = torch.cat(embedding_batches)
    all_labels = torch.cat(label_batches)
    return all_embeddings, all_labels



# Extract hash codes
train_codes, train_labels = evaluate_model(model, train_loader, device)
test_codes, test_labels = evaluate_model(model, test_loader, device)

# The network emits real-valued, L2-normalised codes. Hamming distance counts
# differing coordinates, which is meaningless on continuous values (almost every
# coordinate differs), so binarise by sign first. scikit-learn is also handed
# plain NumPy arrays rather than torch tensors.
train_bits = (train_codes > 0).numpy().astype(np.uint8)
test_bits = (test_codes > 0).numpy().astype(np.uint8)
train_targets = train_labels.numpy()
test_targets = test_labels.numpy()

# Classification with KNN
knn = KNeighborsClassifier(n_neighbors=10, metric='hamming')
knn.fit(train_bits, train_targets)
predictions = knn.predict(test_bits)
y_pred_proba = knn.predict_proba(test_bits)

print(classification_report(test_targets, predictions))

# Binarize the labels for a one-vs-rest computation. Using knn.classes_ keeps
# the columns in the same order as the columns of predict_proba.
y_test_binarized = label_binarize(test_targets, classes=knn.classes_)

# Calculate the average precision for each class
average_precisions = []
for i in range(y_test_binarized.shape[1]):  # iterate over classes
    average_precisions.append(average_precision_score(y_test_binarized[:, i], y_pred_proba[:, i]))

# Compute the mean of the per-class average precisions
map_score = np.mean(average_precisions)
print(f'Mean Average Precision (MAP): {map_score}')
