In [1]:
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn

In [2]:
torch.cuda.is_available()

True

In [3]:
class RankNet(nn.Module):
    def __init__(self, n_input: int, n_output: int=1, n_hidden: int=10) -> None:
        super(RankNet, self).__init__()
        self.linear1 = nn.Linear(n_input, n_hidden)
        self.linear2 = nn.Linear(n_hidden, n_output)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.sigmoid(self.linear1(x))
        x = self.linear2(x)
        return x

In [4]:
def pairwise_loss(s_i: torch.Tensor, s_j: torch.Tensor, S_ij: int, sigma: float=1.0) -> float:
    C = torch.log1p(torch.exp(-sigma * (s_i - s_j)))

    if S_ij == -1:
        C += sigma * (s_i - s_j)
    elif S_ij == 0:
        C += 0.5 * sigma * (s_i - s_j)
    elif S_ij == 1:
        pass
    else:
        raise ValueError
    
    return C


In [5]:
def make_dataset(N_train: torch.Tensor, N_valid: torch.Tensor, D: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    ws = torch.randn(D, 1)

    X_train = torch.randn(N_train, D, requires_grad=True)
    X_valid = torch.randn(N_valid, D, requires_grad=True)

    ys_train_score = torch.mm(X_train, ws)
    ys_valid_score = torch.mm(X_valid, ws)

    bins = np.array([-2, -1, 0, 1])
    ys_train_rel = torch.Tensor(
        np.digitize(ys_train_score.clone().detach().numpy(), bins=bins)
    )
    ys_valid_rel = torch.Tensor(
        np.digitize(ys_valid_score.clone().detach().numpy(), bins=bins)
    )

    return X_train.cuda(), X_valid.cuda(), ys_train_rel.cuda(), ys_valid_rel.cuda()

In [16]:
def swapped_pairs(ys_pred: torch.Tensor, ys_target: torch.Tensor) -> torch.Tensor:
    N = ys_target.shape[0]
    swapped = 0

    for i in range(N-1):
        for j in range(i+1, N):
            if ys_target[i] < ys_target[j]:
                if ys_pred[i] > ys_pred[j]:
                    swapped += 1
            elif ys_target[i] > ys_target[j]:
                if ys_pred[i] < ys_pred[j]:
                    swapped += 1
    
    return swapped

In [17]:
N_train = 500
N_valid = 100

D = 50
epochs = 10
batch_size = 16

n_sampling_combs = 50

In [18]:
X_train, X_valid, ys_train, ys_valid = make_dataset(N_train, N_valid, D)

In [19]:
print(X_train.shape)
print(ys_train.shape)
print(X_valid.shape)
print(ys_valid.shape)

torch.Size([500, 50])
torch.Size([500, 1])
torch.Size([100, 50])
torch.Size([100, 1])


In [20]:
model = RankNet(D).cuda()
optimizer = torch.optim.Adam(model.parameters())

In [21]:
for epoch in range(epochs):
    indices = torch.randperm(N_train)

    X_train = X_train[indices]
    ys_train = ys_train[indices]

    cur_batch = 0
    for _ in range(N_train // batch_size):
        X_batch = X_train[cur_batch: cur_batch+batch_size]
        ys_batch = ys_train[cur_batch: cur_batch+batch_size]
        cur_batch += batch_size

        if len(X_batch) == 0:
            continue

        optimizer.zero_grad()
        batch_loss = torch.zeros(1).cuda()

        y_pred = model(X_batch)

        for _ in range(n_sampling_combs):
            i, j = np.random.choice(range(batch_size), 2)

            s_i = y_pred[i]
            s_j = y_pred[j]

            if ys_batch[i] > ys_batch[j]:
                S_ij = 1
            elif ys_batch[i] == ys_batch[j]:
                S_ij = 0
            else:
                S_ij = -1
            
            loss = pairwise_loss(s_i, s_j, S_ij)
            batch_loss += loss

        batch_loss.backward(retain_graph=True)
        optimizer.step()

    # validation
    with torch.no_grad():
        y_pred = model(X_valid)
        valid_swapped_pairs = swapped_pairs(y_pred, ys_valid)
        print(f"epoch {epoch+1:03d}: {valid_swapped_pairs}/{N_valid * (N_valid-1)//2}")

epoch 001: 1199/4950
epoch 002: 738/4950
epoch 003: 508/4950
epoch 004: 335/4950
epoch 005: 257/4950
epoch 006: 200/4950
epoch 007: 173/4950
epoch 008: 153/4950
epoch 009: 136/4950
epoch 010: 119/4950
