In [None]:
# !git clone https://github.com/toyboy12345/deep-matching.git
# from google.colab import drive
# drive.mount('/content/drive')

## (Strategy Proofness - First Order Stochastic Dominance)
$$
\forall i\in W\cup F \ \forall \succ_i \forall \succ_{-i} \forall \succ'_i \forall j \\
\sum_{j'\succeq j}(g_{ij'}(\succ'_i,\succ_{-i})-g_{ij'}(\succ_i,\succ_{-i})) \leq 0
$$

## (Ex-ante Stability)
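A randomized matching $g$ is ex-ante stable iff $\nexists (w,f)\in W\times F$ such that $\exists f'\,[g_{wf'}(\succ)>0\land f\succ_w f']$ and $\exists w'\,[g_{w'f}(\succ)>0\land w\succ_f w']$, i.e. no worker and firm both prefer each other to some partner they are matched with positive probability.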

## (Stability of Deterministic Matching)
$$
\forall (w,f)\in W\times F \ g_{wf}+\sum_{f'\succ_w f}g_{wf'}+\sum_{w'\succ_f w}g_{w'f}\geq 1, \qquad g_{wf}\in\{0,1\}
$$

## (Ex-post Stability)
A randomized matching is **ex-post stable** iff it can be decomposed into (i.e. written as a lottery over) deterministic stable matchings.

## (Fractionally Stable)
$$
\forall (w,f)\in W\times F \ g_{wf}+\sum_{f'\succ_w f}g_{wf'}+\sum_{w'\succ_f w}g_{w'f}\geq 1
$$

### (Violation of Fractional Stability)
$$
\sum_\succ\sum_w\sum_f\max\left\{0,1-g_{wf}(\succ)-\sum_{w'\succ_f w}g_{w'f}(\succ)-\sum_{f'\succ_w f}g_{wf'}(\succ)\right\}
$$

## (Primal)
$$
\begin{align*}
    \min & \sum_\succ\sum_w\sum_f t_{wf}(\succ)\\
    \text{s.t.} & \sum_f g_{wf}(\succ)\leq 1 & \forall\succ\forall w \\
    & \sum_w g_{wf}(\succ)\leq 1 & \forall \succ\forall f\\
    & t_{wf}(\succ)\geq 1-g_{wf}(\succ)-\sum_{w'\succ_f w}g_{w'f}(\succ)-\sum_{f'\succ_w f}g_{wf'}(\succ) & \forall\succ\forall w\forall f\\
    & \sum_{f'\succeq_wf}(g_{wf'}(\succ_w',\succ_{-w})-g_{wf'}(\succ))\leq 0 & \forall\succ\forall w\forall\succ_{w}'\forall f\\
    & \sum_{w'\succeq_fw}(g_{w'f}(\succ_f',\succ_{-f})-g_{w'f}(\succ))\leq 0 & \forall\succ\forall f\forall\succ_{f}'\forall w\\
    & g_{wf}(\succ)\geq 0,\ t_{wf}(\succ)\geq 0 & \forall\succ\forall w \forall f
\end{align*}
$$



## (Dual)
$$
\begin{align*}
    \min & \sum_\succ\left(\sum_wx_w(\succ)+\sum_fy_f(\succ)-\sum_w\sum_fz_{wf}(\succ)\right)\\
    \text{s.t.}  
    & x_w(\succ)+y_f(\succ)-z_{wf}(\succ)-\sum_{f'\prec_wf}z_{wf'}(\succ)-\sum_{w'\prec_fw}z_{w'f}(\succ)+\sum_{\succ_w'}\sum_{f'\prec_w'f}\left(u_{wf'}(\succ_w',\succ_w,\succ_{-w})-u_{wf'}(\succ_w,\succ_w',\succ_{-w})\right)+\sum_{\succ_f'}\sum_{w'\prec_f'w}\left(v_{w'f}(\succ_f',\succ_f,\succ_{-f})-v_{w'f}(\succ_f,\succ_f',\succ_{-f})\right)\geq 0 & \forall\succ\forall w\forall f\\
    & x_w(\succ)\geq 0,\ y_f(\succ)\geq 0,\ 0\leq z_{wf}(\succ)\leq 1 & \forall\succ\forall w\forall f\\
    & u_{wf}(\succ'_w,\succ_w,\succ_{-w})\geq 0 & \forall\succ\forall w\forall\succ_w'\forall f\\
    & v_{wf}(\succ'_f,\succ_f,\succ_{-f})\geq 0 & \forall\succ\forall f\forall\succ_f'\forall w
\end{align*}
$$

In [2]:
import os
import sys
import time
import logging
import argparse
import numpy as np
from random import random
import itertools
from pathlib import Path

sys.path.append(str(Path("primal_dual_matching.ipynb").resolve().parent.parent))

import torch
import torch.nn
from torch import optim
import torch.nn.functional as F

from data import Data
from dual_net import DualNet
from dual_loss import *
from dual_train import *

In [3]:
# --- Dual-network configuration ---------------------------------------------
# Use the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Initial multipliers: a 3x3 array of small values (presumably one per
# constraint entry used by dual_train -- TODO confirm).
lambd = 0.001 * np.ones((3, 3))

cfg = HParams(
    num_agents=3,
    device=device,
    lambd=lambd,
    rho=0.1,
    lagr_iter=1000,
    batch_size=512,
)
# The learning rate is overridden after construction.
cfg.lr = 1e-4

# Seed numpy's global RNG (presumably what Data samples from -- verify).
np.random.seed(cfg.seed)

G = Data(cfg)

In [None]:
# Instantiate the dual network and move its parameters to the selected device.
model = DualNet(cfg)
model = model.to(device)
model

In [None]:
# Run the dual training loop (from dual_train) with this config, sampler and model.
train_dual(cfg, G, model)

In [None]:
# Draw a single preference profile and run the dual network on it.
# Bug fix: q used to be built from P, so the firms' preferences were silently
# replaced by the workers' preferences.
P, Q = G.generate_batch(1)
p, q = torch.Tensor(P).to(device), torch.Tensor(Q).to(device)
model(p, q)

In [None]:
# Unpack the five tensors returned by the dual network
# (presumably the dual variables x, y, z, u, v of the LP above).
dual_outputs = model(p, q)
x, y, z, u, v = dual_outputs

In [None]:
# Inspect compute_uloss for the current u (presumably the loss term tied to u).
compute_uloss(cfg, model, u, p, q)

In [None]:
# Full dual loss for this profile with the configured multipliers and penalty.
compute_loss(
    cfg, model,
    x, y, z, u, v,
    p, q,
    cfg.lambd, cfg.rho,
)

In [None]:
import os
import sys
import time
import logging
import argparse
import numpy as np
from random import random
import itertools

import torch
from torch import nn, optim
import torch.nn.functional as F

sys.path.append("/content/deep-matching")

from data import Data
from primal_net import PrimalNet
from primal_loss import *
from primal_train import *

from baselines import RSD,DA

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
NUM_AGENTS = 3

# Initial Lagrange multipliers for the strategy-proofness constraints.
# This used to be `lambd = cfg.lambd`. That value depended on whatever `cfg`
# was left in the kernel: either the dual config above or a previous primal
# run, whose train_primal call updates that array in place via `lambd += ...`.
# It also raised NameError when the primal section was run on its own.
# Warm-start from a previous config only when one exists, and take a copy so
# the old and new configs never alias.
if "cfg" in globals():
    prev_lambd = cfg.lambd
    if torch.is_tensor(prev_lambd):
        prev_lambd = prev_lambd.detach().cpu().numpy()
    lambd = np.array(prev_lambd, dtype=np.float64)  # np.array copies by default
else:
    # Same initial value as the dual configuration.
    lambd = np.ones((NUM_AGENTS, NUM_AGENTS)) * 0.001

cfg = HParams(num_agents = NUM_AGENTS,
              device = device,
              lambd = lambd,
              rho = 100,
              lagr_iter = 1000,
              batch_size = 512)

cfg.lr = 1e-5

np.random.seed(cfg.seed)

G = Data(cfg)

In [None]:
# Build the primal network and place its parameters on the selected device.
model = PrimalNet(cfg)
model = model.to(device)
model

In [None]:
# Train the primal network (train_primal from primal_train) on batches from G.
train_primal(cfg, G, model)

In [None]:
# Scratch cell: a (2, 3, 3) standard-normal tensor, displayed below.
a = torch.randn(2, 3, 3)
a

In [None]:
# Persist the whole trained module (pickled). Create the target directory first
# so the save does not fail on a checkout where models/primal/ is missing.
MODEL_DIR = "/content/deep-matching/models/primal"
os.makedirs(MODEL_DIR, exist_ok=True)
torch.save(model, os.path.join(MODEL_DIR, "model_1.pth"))

In [None]:
# Load a previously trained model for comparison. map_location remaps every
# stored tensor onto the current device. This covers both the CPU and GPU cases
# in one call, and also a checkpoint saved on CPU but evaluated on GPU.
# NOTE(review): this unpickles a whole nn.Module, so only load trusted files.
# On recent PyTorch versions torch.load defaults to weights_only=True, and this
# call then needs weights_only=False to load a full module.
model2 = torch.load(
    "/content/deep-matching/models/model_sp_verystable.pth",
    map_location=torch.device(device),
)

In [None]:
# Compare strategy-proofness violations of the trained primal model, the
# reference model and the DA / RSD baselines on one fresh batch.
P, Q = G.generate_batch(100)
p, q = torch.Tensor(P).to(device), torch.Tensor(Q).to(device)

mechanisms = [("primal", model), ("reference", model2), ("DA", DA), ("RSD", RSD)]
with torch.no_grad():
    # Evaluate each mechanism exactly once, so the worker- and firm-side checks
    # see the same output. Previously DA/RSD were re-run for every metric.
    outputs = {name: mech(p, q) for name, mech in mechanisms}
    for name, mech in mechanisms:
        spv_w = compute_spv_w(cfg, mech, outputs[name], p, q).sum()
        spv_f = compute_spv_f(cfg, mech, outputs[name], p, q).sum()
        print(f"{name:>10}: spv_w={spv_w.item():.6f}  spv_f={spv_f.item():.6f}")

# Later cells use the primal model's output on this batch.
r, r2 = outputs["primal"], outputs["reference"]

In [None]:
# compute_loss multiplies the violations by lambd. train_primal passes it as a
# float32 tensor on cfg.device. Passing the raw numpy cfg.lambd here can fail
# with a device/dtype mismatch on GPU, so convert it the same way.
lambd_t = torch.as_tensor(cfg.lambd, dtype=torch.float32, device=device)
with torch.no_grad():
    for name, mech in [("primal", model), ("reference", model2), ("DA", DA), ("RSD", RSD)]:
        print(name, compute_loss(cfg, mech, mech(p, q), p, q, lambd_t, cfg.rho))

In [None]:
from baselines import RSD,DA

In [None]:
# RSD baseline loss, with the multipliers converted the way train_primal does.
rsd_out = RSD(p, q)
rsd_lambd = torch.Tensor(cfg.lambd).to(device)
compute_loss(cfg, RSD, rsd_out, p, q, rsd_lambd, cfg.rho)

In [None]:
# Per-pair t tensor from compute_t for the primal model's current output r.
stability_t = compute_t(r, p, q)
stability_t

In [None]:
%%writefile deep-matching/primal_net.py
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


class PrimalNet(nn.Module):
    """Feed-forward network mapping a preference profile to a randomized matching.

    ``forward(p, q)`` takes two tensors reshapeable to
    ``(batch, num_agents, num_agents)``. It returns a non-negative tensor ``r``
    of that shape in which every sum over dim 1 and every sum over dim 2 is at
    most 1. This means ``r`` satisfies both feasibility constraints
    ``sum_f g_wf <= 1`` and ``sum_w g_wf <= 1`` of the primal LP.

    The number of hidden layers is read from ``cfg.num_hidden_layers``
    (default 4, which reproduces the previous fixed architecture).
    """

    def __init__(self, cfg):
        super(PrimalNet, self).__init__()
        self.cfg = cfg
        num_agents = self.cfg.num_agents
        num_hidden_nodes = self.cfg.num_hidden_nodes
        num_hidden_layers = getattr(self.cfg, "num_hidden_layers", 4)

        # Input layer followed by `num_hidden_layers` hidden layers, each with a
        # LeakyReLU. Previously cfg.num_hidden_layers was ignored and four
        # layers were hard-coded. With 4 layers the module indices match the old
        # hand-written Sequential, so existing state_dicts still load.
        layers = [nn.Linear(2 * num_agents * num_agents, num_hidden_nodes), nn.LeakyReLU()]
        for _ in range(num_hidden_layers):
            layers += [nn.Linear(num_hidden_nodes, num_hidden_nodes), nn.LeakyReLU()]
        self.input_block = nn.Sequential(*layers)

        # Output layer: one logit per entry of the num_agents x num_agents matrix.
        self.layer_out = nn.Linear(num_hidden_nodes, num_agents * num_agents)

    def forward(self, p, q):
        n = self.cfg.num_agents
        x = torch.stack([p, q], dim = -1)
        x = x.view(-1, n * n * 2)
        x = self.input_block(x)

        r = self.layer_out(x)
        r = r.view(-1, n, n)
        r = F.softplus(r)
        # Normalising over dim=1 alone bounds only one set of marginals; sums
        # over dim=2 could then exceed 1 (an agent matched with total
        # probability > 1). The element-wise minimum of the two L1-normalised
        # tensors keeps both sets of marginals <= 1.
        r = torch.min(
            F.normalize(r, p = 1, dim = 1, eps=1e-8),
            F.normalize(r, p = 1, dim = 2, eps=1e-8),
        )

        return r

In [None]:
%%writefile deep-matching/primal_loss.py
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from data import Data

def compute_t(r, p, q):
    """Per-entry ``t`` term used in the primal loss (``torch.sum(t)`` in compute_loss).

    Assumes (TODO confirm against Data) that ``r``, ``p`` and ``q`` are
    ``(batch, num_agents, num_agents)`` tensors indexed ``[b, worker, firm]``,
    with larger values meaning "ranked higher".

    Returns ``t`` with the same shape as ``r``::

        t[b,i,c] = r[b,i,c]
                   + sum_j r[b,j,c] * [q[b,i,c] > q[b,j,c]]
                   + sum_a r[b,i,a] * [p[b,i,c] > p[b,i,a]]
                   - 1

    NOTE(review): both sums run over partners that score strictly *lower* than
    the current one. The stability constraint in the markdown sums over
    partners ranked *higher*. If preferences are strict and every row and
    column of ``r`` sums to 1, this equals ``1 - g - sum_higher - sum_higher``,
    i.e. the expression inside ``max{0, .}`` of the violation formula. That
    does not hold in general, and no clamp at 0 is applied here. Confirm this
    is the intended quantity.
    """
    # wp[b, i, a, c] = 1 iff p[b, i, c] > p[b, i, a]
    wp = torch.where(p[:, :, None, :] - p[:, :, :, None]>0,1,0).to(torch.float)
    # wq[b, i, j, c] = 1 iff q[b, i, c] > q[b, j, c]
    wq = torch.where(q[:, :, None, :] - q[:, None, :, :]>0,1,0).to(torch.float)
    t =  r + torch.einsum('bjc,bijc->bic', r, wq) + torch.einsum('bia,biac->bic', r, wp)  - 1
    return t

def compute_spv_w(cfg, model, r, p, q):
    num_agents = cfg.num_agents
    device = cfg.device
    G = Data(cfg)

    P,Q = p.to('cpu').detach().numpy().copy(),q.to('cpu').detach().numpy().copy()
    spv_w = torch.zeros((num_agents,num_agents)).to(device)
    for agent_idx in range(num_agents):
        P_mis, Q_mis = G.generate_all_misreports(P, Q, agent_idx = agent_idx, is_P = True, include_truncation = False)
        p_mis, q_mis = torch.Tensor(P_mis).to(device), torch.Tensor(Q_mis).to(device)
        r_mis = model(p_mis.view(-1, num_agents, num_agents), q_mis.view(-1, num_agents, num_agents))
        r_mis = r_mis.view(p.shape[0],-1,num_agents,num_agents)

        r_mis_agent = r_mis[:,:,agent_idx,:]

        r_agent = r[:,agent_idx,:]
        r_agent = r_agent.repeat(1,r_mis_agent.shape[1]).view(r_mis_agent.shape[0],r_mis_agent.shape[1],r_mis_agent.shape[2])

        for f in range(num_agents):
            mask = torch.where(p[:,agent_idx,:]>=p[:,agent_idx,f].view(-1,1),1,0)
            mask = mask.repeat(1,r_mis_agent.shape[1]).view(r_mis_agent.shape[0],r_mis_agent.shape[1],r_mis_agent.shape[2])
            spv_w[agent_idx,f] = ((r_mis_agent - r_agent)*mask).sum(-1).relu().sum(-1).mean()
    return spv_w

def compute_spv_f(cfg, model, r, p, q):
    """Firm-side strategy-proofness violation for every (worker, firm) pair.

    For each firm ``agent_idx`` and worker ``w``, every misreport of the
    *firm's* preferences (``generate_all_misreports(..., is_P=False)``) is run
    through ``model``. The firm's change in total assignment probability on
    workers whose true score in ``q`` is at least that of ``w`` is computed per
    misreport, clipped at zero, summed over misreports and averaged over the
    batch.

    Returns:
        ``(num_agents, num_agents)`` tensor on ``cfg.device`` indexed
        ``[worker, firm]``.
    """
    num_agents = cfg.num_agents
    device = cfg.device
    G = Data(cfg)

    P, Q = p.to('cpu').detach().numpy().copy(), q.to('cpu').detach().numpy().copy()
    spv_f = torch.zeros((num_agents, num_agents)).to(device)
    for agent_idx in range(num_agents):
        # Bug fix: this used to pass is_P=True. That enumerated misreports of
        # *worker* agent_idx while measuring the gain of *firm* agent_idx, so
        # firm-side manipulations were never examined.
        P_mis, Q_mis = G.generate_all_misreports(P, Q, agent_idx = agent_idx, is_P = False, include_truncation = False)
        p_mis, q_mis = torch.Tensor(P_mis).to(device), torch.Tensor(Q_mis).to(device)
        r_mis = model(p_mis.view(-1, num_agents, num_agents), q_mis.view(-1, num_agents, num_agents))
        r_mis = r_mis.view(p.shape[0], -1, num_agents, num_agents)

        # Column of the firm under every misreport: (batch, n_mis, num_agents).
        r_mis_agent = r_mis[:, :, :, agent_idx]

        # Truthful column, broadcast over the misreport axis.
        r_agent = r[:, :, agent_idx].unsqueeze(1).expand_as(r_mis_agent)
        gain = r_mis_agent - r_agent

        for w in range(num_agents):
            # 1 where the firm truly scores a worker at least as high as w.
            mask = (q[:, :, agent_idx] >= q[:, w, agent_idx].view(-1, 1)).to(gain.dtype).unsqueeze(1)
            spv_f[w, agent_idx] = (gain * mask).sum(-1).relu().sum(-1).mean()
    return spv_f

def compute_loss(cfg, model, r, p, q, lambd, rho):
    """Augmented-Lagrangian training loss for the primal network.

    Args:
        cfg: config passed through to the SP-violation helpers.
        model: mechanism used to evaluate misreports.
        r: the mechanism's output on ``(p, q)``.
        p, q: worker / firm preference tensors.
        lambd: multipliers, broadcastable to ``(num_agents, num_agents)``.
            May be a tensor or array-like (e.g. numpy). It is converted to the
            dtype and device of the constraint violations.
        rho: quadratic penalty weight.

    Returns:
        ``(loss, total_violation)`` where ``total_violation`` is the scalar sum
        of worker- and firm-side SP violations.
    """
    t = compute_t(r,p,q)
    spv_w = compute_spv_w(cfg,model,r,p,q)
    spv_f = compute_spv_f(cfg,model,r,p,q)

    constr_vio = spv_w+spv_f

    # Accept numpy arrays and lists as well as tensors, and keep the multipliers
    # on the same device and dtype as the violations. A CPU numpy array
    # multiplied by CUDA violations would otherwise fail.
    lambd = torch.as_tensor(lambd, dtype=constr_vio.dtype, device=constr_vio.device)

    loss = torch.sum(t) - 2*torch.sum(r) + (constr_vio*lambd).sum() + 0.5*rho*constr_vio.square().sum()

    return loss,constr_vio.sum()

In [None]:
%%writefile deep-matching/primal_train.py
import os
import sys
import time
import logging
import numpy as np

import torch
from torch import nn, optim
import torch.nn.functional as F

from data import Data
from primal_net import PrimalNet
from primal_loss import *

class HParams:
    """Hyper-parameter container for primal training.

    Every keyword argument is stored on the instance under the same name.
    Previously lr, epochs, num_hidden_layers, num_hidden_nodes, print_iter,
    val_iter and num_val_batches were hard-coded and the arguments were
    silently ignored. ``save_iter`` is derived as ``epochs - 1``.

    ``lambd`` defaults to a fresh ``torch.ones((num_agents, num_agents))`` per
    instance. The old default was one shared tensor, which train_primal then
    updated in place, and its (3, 3) shape ignored ``num_agents``.
    """
    def __init__(self, num_agents = 3,
                 batch_size = 1024, num_hidden_layers = 4, num_hidden_nodes = 256, lr = 5e-3, epochs = 50000,
                 print_iter = 100, val_iter = 1000, num_val_batches = 200,
                 prob = 0, corr = 0, seed = 0, device = "cuda:0",
                 lambd = None, rho = 1, lagr_iter = 100):
        self.num_agents = num_agents
        self.batch_size = batch_size
        self.num_hidden_layers = num_hidden_layers
        self.num_hidden_nodes = num_hidden_nodes
        self.lr = lr
        self.epochs = epochs
        self.print_iter = print_iter
        self.val_iter = val_iter
        self.save_iter = self.epochs - 1
        self.num_val_batches = num_val_batches

        # Higher probability => More truncation
        self.prob = prob
        # Correlation of rankings
        self.corr = corr
        # Run seed
        self.seed = seed

        self.device = device

        # Lagrange multipliers: a new tensor per instance, never a shared default.
        if lambd is None:
            lambd = torch.ones((num_agents, num_agents))
        self.lambd = lambd
        self.rho = rho
        self.lagr_iter = lagr_iter

def train_primal(cfg, G, model, include_truncation = False, log_fname = None,
                 model_path = "deep-matching/models/primal/model_tmp.pth"):
    """Train the primal matching network with an augmented Lagrangian.

    Each iteration draws ``cfg.batch_size`` profiles from ``G``, evaluates
    ``compute_loss`` with the current multipliers and ``cfg.rho``, and takes an
    Adam step. The learning rate is halved at iterations 10000 and 25000. Every
    ``cfg.lagr_iter`` iterations the multipliers are increased by
    ``cfg.rho * total_violation``.

    Args:
        cfg: HParams-like config.
        G: sampler exposing ``generate_batch(batch_size) -> (P, Q)``.
        model: network to train; its parameters are updated in place.
        include_truncation: currently unused (kept for API compatibility).
        log_fname: optional path. When given, log records are also written to
            this file (opened in 'w' mode).
        model_path: checkpoint path for ``torch.save(model, ...)``. Its
            directory is created if missing.

    Side effects:
        ``lambd`` aliases ``cfg.lambd``. When that is an array or tensor, the
        ``+=`` update modifies ``cfg.lambd`` in place, so later code reading
        ``cfg.lambd`` sees the updated multipliers.
    """
    # Logger: the root logger, so records reach whatever handlers the host set up.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')

    if not logger.hasHandlers():
        handler = logging.StreamHandler()
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    # The file handler used to reference an undefined `log_fname` (its
    # definition was commented out), raising NameError whenever the root logger
    # had no handlers yet. File logging is now opt-in, and re-running with the
    # same path does not add a duplicate handler.
    if log_fname is not None:
        log_path = os.path.abspath(log_fname)
        already_logging = any(isinstance(h, logging.FileHandler) and h.baseFilename == log_path
                              for h in logger.handlers)
        if not already_logging:
            file_handler = logging.FileHandler(log_path, 'w')
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    # Optimizer
    opt = torch.optim.Adam(model.parameters(), lr = cfg.lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[10000,25000], gamma=0.5)

    # Trainer
    tic = time.time()
    i = 0

    lambd = cfg.lambd

    while i < cfg.epochs:

        # Reset opt
        opt.zero_grad()
        model.train()

        # Inference
        P, Q = G.generate_batch(cfg.batch_size)
        p, q = torch.Tensor(P).to(cfg.device), torch.Tensor(Q).to(cfg.device)
        r = model(p, q)

        loss,constr_vio = compute_loss(cfg,model,r,p,q,torch.Tensor(lambd).to(cfg.device),cfg.rho)

        # The graph is used exactly once, so there is no need to retain it.
        loss.backward()

        # Multiplier update. It runs after backward() so that an in-place `+=`
        # can never modify a tensor autograd still needs. This step's loss uses
        # the pre-update values either way.
        if (i>0) and (i%cfg.lagr_iter == 0):
            lambd += cfg.rho*constr_vio.item()
            print(lambd)

        opt.step()
        scheduler.step()
        t_elapsed = time.time() - tic


        # Validation
        if i% cfg.print_iter == 0 or i == cfg.epochs - 1:
            logger.info("[TRAIN-ITER]: %d, [Time-Elapsed]: %f, [Total-Loss]: %f"%(i, t_elapsed, loss.item()))
            logger.info("[CONSTR-Vio]: %f"%(constr_vio.item()))

        if (i>0) and (i % cfg.save_iter == 0) or i == cfg.epochs - 1:
            # Create the checkpoint directory on first save instead of failing.
            save_dir = os.path.dirname(model_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model, model_path)

        if ((i>0) and (i% cfg.val_iter == 0)) or i == cfg.epochs - 1:
            model.eval()
            with torch.no_grad():
                val_loss = 0
                val_constr_vio = 0
                for j in range(cfg.num_val_batches):
                    P, Q = G.generate_batch(cfg.batch_size)
                    p, q = torch.Tensor(P).to(cfg.device), torch.Tensor(Q).to(cfg.device)
                    r = model(p, q)
                    loss,constr_vio = compute_loss(cfg,model,r,p,q,torch.Tensor(lambd).to(cfg.device),cfg.rho)
                    val_loss += loss.item()
                    val_constr_vio += constr_vio.item()
                logger.info("\t[VAL-ITER]: %d, [LOSS]: %f, [Constr-vio]: %f"%(i, val_loss/cfg.num_val_batches, val_constr_vio/cfg.num_val_batches))

        i += 1