In [None]:
import pickle
import itertools
import numpy as np

import torch
from torch import nn, optim
import torch.nn.functional as F

In [None]:
# Evaluation configuration.
# Use the GPU the notebook was written for when it exists; otherwise fall back
# to CPU so the notebook still runs on machines with fewer than three GPUs.
device = "cuda:2" if torch.cuda.device_count() > 2 else "cpu"

batch_size = 1024              # markets sampled per validation batch
num_validation_batches = 32    # batches averaged in every validate() call
num_agents = 4                 # agents on each side of a market
prob = 0.2                     # fraction of sampled preference lists that are truncated

# Data sampling draws from NumPy's global RNG; torch is seeded too for completeness.
np.random.seed(42)
torch.manual_seed(42)

In [None]:
def generate_permutation_array(N, num_agents):
    """Return a float array of shape [N, num_agents] whose rows are independent
    random permutations of 0 .. num_agents-1, drawn (row by row, in order) from
    NumPy's global RNG."""
    rows = [np.random.permutation(num_agents) for _ in range(N)]
    # reshape keeps the 2-D shape even when N == 0 (np.array([]) would be 1-D).
    return np.array(rows, dtype=float).reshape(N, num_agents)


class Data(object):
    """
    A class for generating data for the matching problem.

    Every preference list is a row of per-partner scores divided by
    `num_agents`: a larger score means more preferred, 0 is the value of
    remaining single, and a negative score marks a partner ranked below
    remaining single.
    """
    def __init__(self, num_agents, prob, corr):
        self.num_agents = num_agents   # agents on each side of a market
        self.prob = prob               # default truncation probability
        self.corr = corr               # default correlation level (see generate_batch)


    def sample_ranking(self, N, prob):
        """
        Samples ranked lists
        Arguments
            N: Number of samples
            prob: Probability of truncation; exactly int(N * prob) of the N
                  lists (chosen without replacement) are truncated
        Returns:
            Ranked lists of shape [N, num_agents]
        """
        N_trunc = int(N * prob)
        # Random permutations shifted to scores 1 .. num_agents.
        P = generate_permutation_array(N, self.num_agents) + 1

        if N_trunc > 0:
            # Choose indices to truncate
            idx = np.random.choice(N, N_trunc, replace = False)

            # Choose a position to truncate
            trunc = np.random.randint(self.num_agents, size = N_trunc)

            # Normalize so preference to remain single has 0 payoff: the chosen
            # entry is zeroed and the row is shifted down by that entry's old
            # score, so the zeroed entry and every partner scored below it end
            # up negative.
            swap_vals = P[idx, trunc]
            P[idx, trunc] = 0
            P[idx] = P[idx] - swap_vals[:, np.newaxis]

        return P/self.num_agents


    def generate_all_ranking(self, include_truncation = True):
        """
        Generates all possible rankings
        Arguments
            include_truncation: Whether to include truncations or only generate complete rankings
        Returns:
            Rankings of shape [m, num_agents]
                where m = num_agents! if complete,
                      (num_agents + 1)! if truncations are included
        """
        # Truthiness test (not `is False`) so that e.g. 0 also selects complete rankings.
        if not include_truncation:
            M = np.array(list(itertools.permutations(np.arange(self.num_agents)))) + 1.0
        else:
            # Permute num_agents + 1 slots; the last slot stands for "remain
            # single", and subtracting its value gives that option payoff 0.
            M = np.array(list(itertools.permutations(np.arange(self.num_agents + 1))))
            M = (M - M[:, -1:])[:, :-1]

        return M/self.num_agents


    def generate_batch(self, batch_size, prob = None, corr = None):
        """
        Samples a batch of data for training
        Arguments
            batch_size: number of markets
            prob: probability of truncation (defaults to self.prob)
            corr: probability that an agent's list is replaced by a list shared
                  by all agents on the same side of that market (defaults to self.corr)
        Returns
            P: Men's preferences, [batch_size, num_agents, num_agents]
                P_{ij}: How much Man-i prefers to be with Woman-j
            Q: Women's preferences, [batch_size, num_agents, num_agents]
                Q_{ij}: How much Woman-j prefers to be with Man-i
        """
        if corr is None: corr = self.corr
        if prob is None: prob = self.prob

        N = batch_size * self.num_agents

        P = self.sample_ranking(N, prob)
        Q = self.sample_ranking(N, prob)

        P = P.reshape(-1, self.num_agents, self.num_agents)
        Q = Q.reshape(-1, self.num_agents, self.num_agents)

        if corr > 0.00:
            # One common list per market and side, broadcast over that side's agents.
            P_common = self.sample_ranking(batch_size, prob).reshape(batch_size, 1, self.num_agents)
            Q_common = self.sample_ranking(batch_size, prob).reshape(batch_size, 1, self.num_agents)

            P_idx = np.random.binomial(1, corr, [batch_size, self.num_agents, 1])
            Q_idx = np.random.binomial(1, corr, [batch_size, self.num_agents, 1])

            P = P * (1 - P_idx) + P_common * P_idx
            Q = Q * (1 - Q_idx) + Q_common * Q_idx

        # Rows of Q were women's lists; transpose so Q[b, i, j] is woman j's score for man i.
        Q = Q.transpose(0, 2, 1)

        return P, Q


    def compose_misreport(self, P, Q, M, agent_idx, is_P = True):
        """ Composes mis-report
        Arguments:
            P: Men's preference, [Batch_size, num_agents, num_agents]
            Q: Women's preference [Batch_size, num_agents, num_agents]
            M: Ranked List of mis_reports
                    either [num_misreports, num_agents]
                    or [batch_size, num_misreports, num_agents]
            agent_idx: Agent-idx that is mis-reporting
            is_P: if True, Men[agent-idx] misreporting
                    else, Women[agent-idx] misreporting

        Returns:
            P_mis, Q_mis: [batch-size, num_misreports, num_agents, num_agents]

        """
        num_misreports = M.shape[-2]
        P_mis = np.tile(P[:, None, :, :], [1, num_misreports, 1, 1])
        Q_mis = np.tile(Q[:, None, :, :], [1, num_misreports, 1, 1])

        # A man's list is a row of P; a woman's list is a column of Q.
        if is_P: P_mis[:, :, agent_idx, :] = M
        else: Q_mis[:, :, :, agent_idx] = M

        return P_mis, Q_mis


    def generate_all_misreports(self, P, Q, agent_idx, is_P, include_truncation = True):
        """ Generates all mis-reports
        Arguments:
            P: Men's preference, [Batch_size, num_agents, num_agents]
            Q: Women's preference [Batch_size, num_agents, num_agents]
            agent_idx: Agent-idx that is mis-reporting
            is_P: if True, Men[agent-idx] misreporting
                    else, Women[agent-idx] misreporting
            include_truncation: Whether to truncate preference or submit complete preferences

        Returns:
            P_mis, Q_mis: [batch-size, M, num_agents, num_agents]
                where M = (num_agents + 1)! if truncations are included
                      M = (num_agents)! if preferences are complete
        """
        M = self.generate_all_ranking(include_truncation = include_truncation)
        P_mis, Q_mis = self.compose_misreport(P, Q, M, agent_idx, is_P)

        return P_mis, Q_mis


    def sample_misreports(self, P, Q, num_misreports_per_sample, agent_idx, is_P, prob = None):
        """ Samples misreports
        Arguments:
            P: Men's preference, [Batch_size, num_agents, num_agents]
            Q: Women's preference [Batch_size, num_agents, num_agents]
            num_misreports_per_sample: Number of misreports per sample
            agent_idx: Agent-idx that is mis-reporting
            is_P: if True, Men[agent-idx] misreporting
                    else, Women[agent-idx] misreporting
            prob: probability of truncation (defaults to self.prob)

        Returns:
            P_mis, Q_mis: [batch-size, num_misreports_per_sample, num_agents, num_agents]
        """
        if prob is None: prob = self.prob

        N = P.shape[0]
        M = self.sample_ranking(N * num_misreports_per_sample, prob).reshape(N, num_misreports_per_sample, -1)
        P_mis, Q_mis = self.compose_misreport(P, Q, M, agent_idx, is_P)

        return P_mis, Q_mis

In [None]:
class Net(nn.Module):
    """ Neural Network Module for Matching

    Maps preference tensors p (men) and q (women), each of shape
    [batch, num_agents, num_agents], to a matching r of the same shape whose
    entry (i, j) is the weight placed on pairing man i with woman j.
    """
    def __init__(self, num_agents=4, num_hidden_nodes=256):
        super(Net, self).__init__()
        self.num_agents = num_agents
        self.num_hidden_nodes = num_hidden_nodes

        # Input layer followed by four hidden layers, each Linear + PReLU.
        # Appending in this exact order keeps the Sequential indices -- and
        # hence the state_dict keys of saved checkpoints -- as input_block.0 .. input_block.9.
        widths = [2 * num_agents ** 2] + [num_hidden_nodes] * 5
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.PReLU())
        self.input_block = nn.Sequential(*layers)

        # Output heads: one extra row (r) / extra column (c) for "remain single".
        self.layer_out_r = nn.Linear(num_hidden_nodes, (num_agents + 1) * num_agents)
        self.layer_out_c = nn.Linear(num_hidden_nodes, num_agents * (num_agents + 1))


    def forward(self, p, q):
        n = self.num_agents

        # Only the non-negative (acceptable) part of each report is fed in.
        p, q = F.relu(p), F.relu(q)

        features = torch.stack((p, q), dim=-1).view(-1, 2 * n * n)
        hidden = self.input_block(features)

        # Acceptability masks plus an always-on "single" slot: an extra
        # trailing row for the p-mask (dim -2), an extra column for q (dim -1).
        accept_p = F.pad((p > 0).to(p.dtype), (0, 0, 0, 1, 0, 0), mode='constant', value=1)
        accept_q = F.pad((q > 0).to(q.dtype), (0, 1, 0, 0, 0, 0), mode='constant', value=1)

        scores_r = F.softplus(self.layer_out_r(hidden).view(-1, n + 1, n)) * accept_p
        scores_c = F.softplus(self.layer_out_c(hidden).view(-1, n, n + 1)) * accept_q

        # L1-normalise r over dim 1 (men + single) and c over dim 2 (women + single).
        scores_r = F.normalize(scores_r, p = 1, dim = 1, eps=1e-8)
        scores_c = F.normalize(scores_c, p = 1, dim = 2, eps=1e-8)

        # Drop the "single" slots and keep the element-wise minimum of both views.
        return torch.min(scores_r[:, :-1, :], scores_c[:, :, :-1])

In [None]:
# Stability Violation
def compute_st(r, p, q):
    """Stability violation ("regret") of matching r, averaged per agent.

    Arguments
        r: matching, [batch, n, n]; r[b, i, c] is the weight on man i with woman c
        p: men's preferences, [batch, n, n]; p[b, i, c] is man i's score for woman c
        q: women's preferences, [batch, n, n]; q[b, j, c] is woman c's score for man j
    Returns
        Scalar tensor: sum over (i, c) of rgt_1 * rgt_2, averaged over the batch
        and divided by n (taken from r's shape rather than a global).
    """
    n = r.shape[-1]
    # wp[b, i, a, c] = how much man i prefers woman c over woman a (clipped at 0).
    wp = F.relu(p[:, :, None, :] - p[:, :, :, None])
    # wq[b, i, j, c] = how much woman c prefers man i over man j (clipped at 0).
    wq = F.relu(q[:, :, None, :] - q[:, None, :, :])
    # Unassigned mass: t per woman (column), s per man (row).
    t = 1 - torch.sum(r, dim = 1, keepdim = True)
    s = 1 - torch.sum(r, dim = 2, keepdim = True)
    rgt_1 = torch.einsum('bjc,bijc->bic', r, wq) + t * F.relu(q)
    rgt_2 = torch.einsum('bia,biac->bic', r, wp) + s * F.relu(p)
    # A pair (i, c) contributes only if both sides have something to gain.
    regret = rgt_1 * rgt_2
    return regret.sum(-1).sum(-1).mean() / n

# IR Violation
def compute_ir(r, p, q):
    """Individual-rationality violation of matching r, averaged per agent.

    Weight placed on a partner whom an agent scores below zero (below remaining
    single) is charged by how negative that score is, for both sides.

    Arguments
        r: matching, [batch, n, n]
        p: men's preferences, [batch, n, n]
        q: women's preferences, [batch, n, n]
    Returns
        Scalar tensor: total violation averaged over the batch, divided by n
        (taken from r's shape rather than a global).
    """
    n = r.shape[-1]
    ir = r * F.relu(-q) + r * F.relu(-p)
    return ir.sum(-1).sum(-1).mean() / n

# FOSD Violation
def _fosd_violation(r, p, q, P, Q, data_generator, model, agent_idx, is_P, discount):
    """Worst-case (over all mis-reports) FOSD gain of one agent, averaged over the batch.

    is_P=True scores man `agent_idx` (row agent_idx of r / p); otherwise woman
    `agent_idx` (column agent_idx of r / q).
    """
    n = r.shape[1]
    P_mis, Q_mis = data_generator.generate_all_misreports(
        P, Q, agent_idx = agent_idx, is_P = is_P, include_truncation = True)
    p_mis = torch.as_tensor(P_mis, dtype=r.dtype, device=r.device)
    q_mis = torch.as_tensor(Q_mis, dtype=r.dtype, device=r.device)

    # Run every (batch x mis-report) profile in one forward pass, then restore
    # the [batch, num_misreports, n, n] layout.
    out_shape = P_mis.shape if is_P else Q_mis.shape
    r_mis = model(p_mis.view(-1, n, n), q_mis.view(-1, n, n)).view(*out_shape)

    if is_P:
        mis_alloc = r_mis[:, :, agent_idx, :]   # [B, M, n]
        true_alloc = r[:, None, agent_idx, :]   # [B, 1, n]
        prefs = p[:, agent_idx, :]              # [B, n]
    else:
        mis_alloc = r_mis[:, :, :, agent_idx]
        true_alloc = r[:, None, :, agent_idx]
        prefs = q[:, :, agent_idx]

    # Only partners the agent truly finds acceptable (true score > 0) count.
    r_diff = (mis_alloc - true_alloc) * (prefs[:, None, :] > 0).float()

    # Order partners from most to least preferred under the true report; the
    # cumulative sum is then the FOSD gap at each prefix of that ranking.
    _, order = torch.sort(-prefs)
    order = order[:, None, :].repeat(1, r_mis.size(1), 1)
    fosd_viol = torch.cumsum(torch.gather(r_diff, -1, order) * discount, -1)

    # Worst prefix, then worst mis-report, then mean over the batch.
    return F.relu(fosd_viol).max(-1)[0].max(-1)[0].mean(-1)


def compute_ic_FOSD(r, p, q, P, Q, data_generator, r_mult = 1, model = None):
    """Incentive-compatibility violation measured by first-order stochastic dominance.

    For every agent on both sides, all (truncated) mis-reports are enumerated via
    data_generator.generate_all_misreports and fed through `model`; the gain in
    cumulative (rank-ordered) assignment over truthful reporting is penalised.

    Arguments
        r: truthful matching produced by the model, [batch, n, n]
        p, q: men's / women's preference tensors, [batch, n, n]
        P, Q: the same preferences as NumPy arrays (used to build mis-reports)
        data_generator: object providing generate_all_misreports (e.g. Data)
        r_mult: geometric discount applied to rank position k as r_mult ** k
        model: network used to evaluate mis-reports; when None, the
               notebook-global `model` is used (the original behaviour)
    Returns
        Scalar tensor: 0.5 * (mean over men + mean over women) of per-agent violations.
    """
    if model is None:
        # Backward-compatible fallback to the notebook-global network.
        model = globals()["model"]

    n = r.shape[1]
    discount = torch.as_tensor(r_mult ** np.arange(n), dtype=r.dtype, device=r.device)

    viol_P, viol_Q = [], []
    for agent_idx in range(n):
        viol_P.append(_fosd_violation(r, p, q, P, Q, data_generator, model, agent_idx, True, discount))
        viol_Q.append(_fosd_violation(r, p, q, P, Q, data_generator, model, agent_idx, False, discount))

    return (torch.stack(viol_P).mean() + torch.stack(viol_Q).mean()) * 0.5

In [None]:
def torch_var(x, target_device=None):
    """Copy array-like `x` into a new tensor of the default float dtype.

    Uses `torch.tensor` rather than the legacy `torch.Tensor` constructor, which
    treats an int argument as a *size* and returns uninitialised memory.

    Arguments
        x: NumPy array, nested sequence or scalar
        target_device: device for the result; defaults to the notebook-global `device`
    Returns
        A freshly allocated tensor (never shares memory with `x`).
    """
    if target_device is None:
        target_device = device
    return torch.tensor(x, dtype=torch.get_default_dtype(), device=target_device)

In [None]:
def validate(model, data_generator, verbose=True):
    """Evaluate `model` on freshly sampled validation markets.

    Arguments
        model: network mapping preference tensors (p, q) to a matching r
        data_generator: Data-like object whose generate_batch(batch_size) returns (P, Q)
        verbose: print per-batch losses (True keeps the original output)
    Returns
        (val_st_loss, val_ic_loss, welfare), each averaged over
        `num_validation_batches` batches of `batch_size` markets.

    NOTE(review): compute_ic_FOSD is not handed `model` here, so its mis-report
    evaluations use the notebook-global `model`; confirm that is the same network.
    """
    model.eval()
    with torch.no_grad():
        val_st_loss = 0.0
        val_ic_loss = 0.0
        welfare = 0.0

        for j in range(num_validation_batches):
            P, Q = data_generator.generate_batch(batch_size)
            p, q = torch_var(P), torch_var(Q)
            r = model(p, q)

            # Mean per-agent welfare: each market has n men and n women, so the
            # normaliser is (markets * n * 2), read off the batch actually drawn.
            R = r.detach().cpu().numpy()
            num_markets, n = P.shape[0], P.shape[1]
            welfare += (P * R + Q * R).sum()/(num_markets * n * 2)

            st_loss = compute_st(r, p, q)
            ic_loss = compute_ic_FOSD(r, p, q, P, Q, data_generator)

            val_st_loss += st_loss.item()
            val_ic_loss += ic_loss.item()

            if verbose:
                print("[Batch]", j, "[ST_LOSS]", st_loss.item(), "[IC_LOSS]", ic_loss.item())

        val_st_loss = val_st_loss/num_validation_batches
        val_ic_loss = val_ic_loss/num_validation_batches
        welfare = welfare/num_validation_batches

        print("[ST_LOSS]", val_st_loss, "[IC_LOSS]", val_ic_loss, "[Welfare]", welfare)

        return val_st_loss, val_ic_loss, welfare

In [None]:
# Build the network once on `device`; checkpoints are loaded into it later.
model = Net().to(device)

# Checkpoint directories (indexed by get_all_results' `m`) and file names
# (one per lambda value 0.0 .. 1.0).
model_path = ["./models_10k/" + sub + "/" for sub in ("mixed", "c00", "c25", "c50", "c75")]
model_name = [str(lam) + ".pt" for lam in (0.0, 0.25, 0.5, 0.75, 1.0)]

In [None]:
def get_all_results(m, corr_values=(0.00, 0.25, 0.50, 0.75)):
    """Evaluate every checkpoint in model directory `m` at several correlation levels.

    Arguments
        m: index into `model_path` selecting the checkpoint directory
        corr_values: correlation levels used to generate validation data
    Returns
        ST_LOSS, IC_LOSS, WF: nested lists indexed [corr][checkpoint], with one
        checkpoint per entry of `model_name` (the lambda values).
    """
    ST_LOSS = []
    IC_LOSS = []
    WF = []
    for corr in corr_values:
        data_generator = Data(num_agents = num_agents, prob = prob, corr = corr)

        ST_loss = []
        IC_loss = []
        wel = []

        # Iterate the checkpoint names directly so no parallel list can drift out of sync.
        for name in model_name:
            # torch.load unpickles the file: only load checkpoints you trust.
            state = torch.load(model_path[m] + name, map_location = device)
            model.load_state_dict(state)
            st, ic, wf = validate(model, data_generator)
            ST_loss.append(st)
            IC_loss.append(ic)
            wel.append(wf)

        ST_LOSS.append(ST_loss)
        IC_LOSS.append(IC_loss)
        WF.append(wel)

    return ST_LOSS, IC_LOSS, WF

In [None]:
# Evaluate every model directory and pickle its metrics under ./models_10k/.
# Position i in result_name matches position i in model_path (get_all_results' index).
result_name = ["mixed", "c00", "c25", "c50", "c75"]
assert len(result_name) == len(model_path), "result_name and model_path must line up"

for i, name in enumerate(result_name):
    ST_LOSS, IC_LOSS, WF = get_all_results(i)
    result = {"ST_LOSS" : ST_LOSS, "IC_LOSS" : IC_LOSS, "WF" : WF}

    print(name, "DONE.")
    print("ST_LOSS :", ST_LOSS)
    print("IC_LOSS :", IC_LOSS)
    print("WF :", WF)

    with open("./models_10k/" + name + '_results.pkl', 'wb') as f:
        pickle.dump(result, f)

In [None]:
# Reload the pickled metrics and print the ST / IC losses for each model directory.
# NOTE: pickle.load can execute arbitrary code -- only open result files you produced.
result_name = ["mixed", "c00", "c25", "c50", "c75"]
for name in result_name:
    with open("./models_10k/" + name + '_results.pkl', 'rb') as f:
        loaded_data = pickle.load(f)

    print(name + '_results.pkl')
    print(loaded_data["ST_LOSS"])
    print(loaded_data["IC_LOSS"])
    print("\n")