In [2]:
import os
import sys
import time
import copy

from typing import Dict, List, Tuple, Optional, Set
import torch
from torch import nn  
from torch.utils import data  
from torch.utils.data import DataLoader, Dataset

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

In [7]:
class MLP(nn.Module):
    '''Fully connected network: input_dim -> hidden_dim[0] -> ... -> output_dim.

    Parameters
    ----------
    input_dim : int
        Number of input features; ``forward`` flattens every sample to this size.
    output_dim : int
        Width of the final linear layer (no activation after it).
    hidden_dim : sequence of int, optional
        Widths of the hidden layers. ``None`` means ``[512, 512]``; an empty
        sequence gives a single linear map input_dim -> output_dim.
    act_func : nn.Module, optional
        Activation placed after every hidden layer. ``None`` means a fresh
        ``nn.ReLU()``. The same instance is reused after each hidden layer, so an
        activation with parameters (e.g. ``nn.PReLU``) shares them across layers.

    All linear layers use Xavier-uniform weights and zero biases.
    '''
    def __init__(self, input_dim, output_dim=1, hidden_dim=None, act_func=None):
        super(MLP, self).__init__()
        # Build defaults per instance: a list / module written in the signature is
        # created once at definition time and shared by every MLP.
        hidden_dim = [512, 512] if hidden_dim is None else list(hidden_dim)
        if act_func is None:
            act_func = nn.ReLU()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.act_func = act_func

        layers = []
        fan_in = input_dim
        for width in hidden_dim:
            layers.append(self._xavier_linear(fan_in, width))
            layers.append(act_func)
            fan_in = width
        # Output layer; when there are no hidden layers fan_in is still input_dim.
        layers.append(self._xavier_linear(fan_in, output_dim))

        self._main = nn.Sequential(*layers)

    @staticmethod
    def _xavier_linear(in_features, out_features):
        '''nn.Linear with Xavier-uniform weight and zero bias.'''
        layer = nn.Linear(in_features, out_features)
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)
        return layer

    def forward(self, x):
        '''Flatten each sample to ``input_dim`` features and apply the network.

        ``reshape`` (rather than ``view``) also accepts non-contiguous inputs such
        as transposed or sliced tensors; it still returns a view when possible.
        '''
        out = x.reshape(x.shape[0], self.input_dim)
        return self._main(out)

class TwoInputMLPWrapper(nn.Module):
    '''Adapter letting a single-input network score a pair of inputs.

    ``forward(x, y)`` concatenates ``x`` and ``y`` along dim 1 and feeds the
    joint tensor to the wrapped ``func``.
    '''
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x, y):
        joint = torch.cat((x, y), dim=1)
        return self.func(joint)
    
def l2_normalizer(x, dim=1, eps=1e-12):
    '''Scale ``x`` to unit L2 norm along ``dim``.

    Args:
        x: tensor to normalise.
        dim: dimension along which the L2 norm is taken.
        eps: lower bound on the norm, so all-zero slices come back as zeros
            instead of NaN (0 / 0). Same default as torch.nn.functional.normalize.

    Returns:
        Tensor with the same shape as ``x``.
    '''
    # Clamp the squared norm (not the norm) so the sqrt gradient stays finite at 0.
    sq_norm = torch.square(x).sum(dim=dim, keepdim=True).clamp_min(eps * eps)
    return x / torch.sqrt(sq_norm)

In [None]:
class BilinearCritic(nn.Module):
    '''Two-tower critic with a learnable temperature.

    encoder_x : dx -> feature_dim
    encoder_y : dy -> feature_dim
    u_func    : 2*feature_dim -> 1, called as u_func(hx, hy) on the
                L2-normalised embeddings
    tau       : initial temperature, stored as a learnable 1-element Parameter
    '''

    def __init__(self,
                 encoder_x: nn.Module,
                 encoder_y: nn.Module,
                 u_func: nn.Module,
                 tau: Optional[float] = 1.):
        super().__init__()
        # Registration order (encoder_x, encoder_y, u_func, tau) fixes the
        # parameter / state_dict order.
        self.encoder_x = encoder_x
        self.encoder_y = encoder_y
        self.u_func = u_func
        self.tau = nn.Parameter(torch.Tensor([tau]))

    def forward(self, x, y, tau=None):
        '''Embed and score a batch of (x, y) pairs.

        Returns ``(hx / sqrt(tau), hy / sqrt(tau), u)``: unit-norm embeddings
        scaled so that ``hx @ hy.t()`` is cosine similarity divided by tau, and
        ``u = u_func(hx, hy)`` computed on the unscaled unit-norm embeddings.
        A ``tau`` argument overrides the learned temperature.
        '''
        scale = torch.sqrt(self.tau if tau is None else tau)
        emb_x = self.norm(self.encoder_x(x))
        emb_y = self.norm(self.encoder_y(y))
        score_u = self.u_func(emb_x, emb_y)
        return emb_x / scale, emb_y / scale, score_u

    def norm(self, z):
        '''L2-normalise each row of ``z`` (dim=1).'''
        return torch.nn.functional.normalize(z, dim=1)

In [None]:
class BilinearFLO(nn.Module):
    '''FLO (Fenchel-Legendre Optimization) mutual-information loss around a critic.

    For each sample the loss is
        u + exp(-u + g0 - g) - 1
    where g is the critic score of the positive pair (x_i, y_i), g0 the score of
    a negative pair (or the log-mean-exp over several negatives) and u an
    auxiliary per-sample value. ``forward`` returns the batch mean, i.e. a
    negative-MI loss that can be minimized directly; ``MI`` returns the estimate.

    Critic modes (see ``PMI``):
      * ``u_func`` given: ``critic(x, y)`` -> g and ``u_func(x, y)`` -> u;
      * ``critic(x, y)`` returns a tuple ``(hx, hy, u)``: g / g0 are the
        diagonal / off-diagonal entries of ``hx @ hy.t()`` (in-batch negatives);
      * otherwise ``critic(x, y)`` returns g and u stacked along dim 1.
    '''

    def __init__(self,
         critic: nn.Module,
         u_func: Optional[nn.Module] = None,
         K: Optional[int] = None,
         args: Optional[Dict] = None,
         cuda: Optional[int] = None) -> None:

        super(BilinearFLO, self).__init__()
        self.critic = critic
        self.u_func = u_func
        self.K = K  # default K used by forward when none is passed
        # NOTE(review): `args` and `cuda` are accepted but currently unused.

    def forward(self, x, y, y0=None, K=None):
        '''Mean FLO loss over the batch.

        x:    n x p
        y:    n x d true
        y0:   n x d fake (if None, a random row permutation of y is used)
        K:    number of samples per anchor; defaults to self.K
        returns negative MI (i.e., can be directly used as loss for minimization)
        '''
        if K is None:
            K = self.K
        return self.PMI(x, y, y0, K).mean()

    def MI(self, x, y, K=10):
        '''MI estimate: negated loss averaged over K random shuffles of y as negatives.'''
        mi = 0
        for _ in range(K):
            y0 = self._shuffled(y)
            mi += self.forward(x, y, y0)
        return -mi / K

    def PMI(self, x, y, y0=None, K=None):
        '''Per-sample FLO loss (``forward`` averages it).

        x:    n x p
        y:    n x d true; row i is the positive partner of x[i]
        y0:   n x d fake; if None, a random row permutation of y is used
        K:    1 positive + (K - 1) negatives per anchor. The first negative block
              scores y0, the others score fresh permutations of y; requires K >= 2.
              With K=None a single negative (y0) is used. The tuple critic ignores
              K and uses all in-batch negatives.
        '''
        if self.u_func is not None:
            u = self.u_func(x, y)
            # Positive score is needed in both branches (it was missing when K was set).
            g = self.critic(x, y)
            if K is not None:
                g0 = self._stacked_negatives(lambda y_neg: self.critic(x, y_neg), y, y0, K)
                return self._flo_multi(u, g, g0, K - 1)
            g0 = self.critic(x, y0 if y0 is not None else self._shuffled(y))
            return u + torch.exp(-u + g0 - g) - 1

        # one func mode
        gu = self.critic(x, y)
        if isinstance(gu, tuple):
            hx, hy, u = gu
            n = hx.size(0)
            similarity_matrix = hx @ hy.t()
            # Build the mask on the same device as the scores.
            pos_mask = torch.eye(n, dtype=torch.bool, device=similarity_matrix.device)
            g = similarity_matrix[pos_mask].view(n, -1)
            g0 = similarity_matrix[~pos_mask].view(n, -1)
            return self._flo_multi(u, g, g0, n - 1)

        # Reuse the (x, y) evaluation instead of running the critic a second time.
        g, u = torch.chunk(gu, 2, dim=1)
        first_half = lambda y_neg: torch.chunk(self.critic(x, y_neg), 2, dim=1)[0]
        if K is not None:
            g0 = self._stacked_negatives(first_half, y, y0, K)
            return self._flo_multi(u, g, g0, K - 1)
        g0 = first_half(y0 if y0 is not None else self._shuffled(y))
        return u + torch.exp(-u + g0 - g) - 1

    @staticmethod
    def _shuffled(y):
        '''Return y with its rows randomly permuted (in-batch negatives).'''
        return y[torch.randperm(y.size()[0])]

    def _stacked_negatives(self, score_fn, y, y0, K):
        '''Concatenate K - 1 negative score blocks along dim 1.

        The first block scores ``y0`` (or a shuffle of y if y0 is None); each
        further block scores a fresh shuffle of y.
        '''
        if K < 2:
            raise ValueError(f'K must be >= 2 (1 positive + at least 1 negative), got {K}')
        first = y0 if y0 is not None else self._shuffled(y)
        scores = [score_fn(first)]
        for _ in range(K - 2):
            scores.append(score_fn(self._shuffled(y)))
        return torch.cat(scores, dim=1)

    @staticmethod
    def _flo_multi(u, g, g0, num_neg):
        '''u + exp(-u + logsumexp(g0) - g) / num_neg - 1, logsumexp taken row-wise.

        Dividing by num_neg turns the logsumexp into a log-mean-exp over negatives.
        '''
        g0_logsumexp = torch.logsumexp(g0, 1).view(-1, 1)
        return u + torch.exp(-u + g0_logsumexp - g) / num_neg - 1

In [None]:
# Experiment configuration.
# NOTE(review): `p` and `k` were not defined anywhere before this cell, so a
# fresh "Restart & Run All" raised NameError here. The values below are
# placeholders — TODO set them to the values used in the actual experiment.
p = 20   # input dimension of both x and y (each encoder takes p features)
k = 128  # batch size; also scales the epoch budget below

torch.manual_seed(0)  # reproducible network initialisation

args = {
    'lr': 1e-3,
    'latent_dim': 100,
    # 50 epochs at batch size 16, scaled by sqrt(k / 16) for other batch sizes
    'num_epochs': int(50*np.sqrt(k/16)),
    'input_dim': 2*p,
    'output_dim': 2,
    'batch_size': k,
    'feature_dim': 512,
}

encoder_x = MLP(p, output_dim=args['feature_dim'])
encoder_y = MLP(p, output_dim=args['feature_dim'])
# u_func scores the concatenated [hx, hy] embeddings: 2*feature_dim -> 1
u_func = TwoInputMLPWrapper(MLP(2*args['feature_dim'], hidden_dim=[128]))
critic = BilinearCritic(encoder_x, encoder_y, u_func)
# No separate u_func: BilinearFLO uses the (hx, hy, u) tuple returned by the critic.
model = BilinearFLO(critic)

In [15]:
def LossFLO(feat1, feat2, u_func, inv_temp=1., feat2_transposed=False, normalizer=None):
    '''Contrastive FLO loss between two batches of features.

    Official implementation of
    Qing Guo, et al. Tight Mutual Information Estimation With Contrastive Fenchel-Legendre Optimization
    NeurIPS 2022
    https://arxiv.org/abs/2107.01131

    Pairing convention: (feat1[i], feat2[i]) are positive pairs and every
    (feat1[i], feat2[j]) with i != j is a negative pair. bs1 <= bs2 is required;
    when bs1 < bs2 the extra feat2 rows only serve as additional negatives
    (e.g., momentum contrastive (MoCo) negatives).

    Args:
        feat1: bs1 x dim tensor.
        feat2: bs2 x dim tensor (dim x bs2 if ``feat2_transposed`` is True).
        u_func: callable ``u_func(a, b)`` taking feat1 and the bs1 positive rows
            of feat2 (both bs1 x dim, after normalisation); expected to return
            the per-sample u as a bs1 x 1 tensor.
        inv_temp: inverse temperature multiplying every similarity score.
        feat2_transposed: True if feat2 is already laid out as dim x bs2.
        normalizer: optional ``normalizer(x, dim=...)`` applied to both feature
            sets along the feature dimension before computing similarities.

    Returns:
        (loss, res): ``loss`` is the mean of ``res['loss_vec']``; ``res`` also holds
        ``'similarity'`` (bs1 x bs2, before ``inv_temp``) and ``'u'``.
    '''
    assert feat1.dim() == 2, 'input1 dimension should be batch_size x feature_dim'
    assert feat2.dim() == 2, 'input2 dimension should be batch_size x feature_dim (or transpose)'

    if not feat2_transposed:
        feat2 = feat2.t()  # from here on feat2 is dim x bs2

    assert feat1.size(1) == feat2.size(0), 'The feature dimension should match for input1 and input2'

    n1 = feat1.size(0)
    n2 = feat2.size(1)
    assert n1 <= n2, 'Size of input2 should not be less than input1'
    # Every anchor needs a negative; with n2 == 1 the (n2 - 1) divisor below is 0.
    assert n2 >= 2, 'input2 should contain at least 2 samples (positive + negatives)'

    # Normalize features if a normalizer is specified
    if normalizer is not None:
        feat1 = normalizer(feat1, dim=1)
        feat2 = normalizer(feat2, dim=0)

    similarity = feat1 @ feat2  # bs1 x bs2

    # Positives lie on the main diagonal of the leading bs1 x bs1 block;
    # torch.eye(n1, n2) is exactly that mask (all False in the extra columns).
    mask = torch.eye(n1, n2, dtype=torch.bool, device=similarity.device)

    g = similarity[mask].view(n1, -1) * inv_temp    # bs1 x 1
    g0 = similarity[~mask].view(n1, -1) * inv_temp  # bs1 x (bs2 - 1)

    u = u_func(feat1, feat2[:, :n1].t())

    # logsumexp / (n2 - 1) inside the exp is a log-mean-exp over the negatives.
    g0_logsumexp = torch.logsumexp(g0, dim=1, keepdim=True)
    loss_vec = u + torch.exp(-u + g0_logsumexp - g) / (n2 - 1) - 1
    loss = loss_vec.mean()

    res = {'loss_vec': loss_vec, 'similarity': similarity, 'u': u}
    return loss, res

In [16]:
# Toy check of LossFLO: bs1 anchors vs bs2 >= bs1 candidates; the extra
# bs2 - bs1 rows of x2 act as additional negatives.
torch.manual_seed(0)  # fixes both the toy data and u_func's random init

bs1 = 10
bs2 = 15
dim = 20
assert bs1 <= bs2, 'LossFLO requires bs1 <= bs2'

# u_func sees [feat1, feat2] concatenated, hence 2*dim inputs.
u_func = TwoInputMLPWrapper(MLP(2*dim, hidden_dim=[128]))

x1 = torch.randn(bs1, dim)
x2 = torch.randn(bs2, dim)

In [17]:
loss_flo, res = LossFLO(x1, x2, u_func, normalizer=l2_normalizer)
# Sanity checks: one loss term per anchor in x1, and a finite scalar loss.
assert res['loss_vec'].shape == (bs1, 1)
assert torch.isfinite(loss_flo)
loss_flo