In [13]:
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
import torch.nn.functional as F
from tqdm import tqdm
import copy
import collections

In [14]:
def source_target_dist(eps):
    """Half of the largest absolute column sum of ``eps`` (its matrix 1-norm / 2).

    In main1/main2 ``eps = ccpt - ccps``, where column k of ccps/ccpt is the
    predicted-label distribution for true label k. The result is therefore the
    largest total-variation distance between a source column and the
    corresponding target column.
    """
    per_column_l1 = torch.sum(torch.abs(eps), dim=0)
    return torch.max(per_column_l1) / 2

In [15]:
def cdf(mu, sigma, x):
    """Normal CDF P(X <= x) for X ~ N(mu, sigma**2), with sigma the standard deviation.

    Uses Phi(z) = (1 + erf((x - mu) / (sigma * sqrt(2)))) / 2. At least one
    argument must be a torch tensor, because torch.erf does not accept plain
    floats. Tensor arguments are combined elementwise with broadcasting.
    """
    standardized = (x - mu) / (sigma * np.sqrt(2))
    return (torch.erf(standardized) + 1) * 0.5

In [16]:
# Global source/target settings for interactive inspection.
# NOTE(review): main1/main2 rebuild their tensors from their own arguments,
# so mus/sigmas/mut/sigmat defined here are not read by them.

# Source domain: class means -ms / +ms, common std ss, P(y=1) = ps.
ps, ms, ss = .5, 1., 1.0
mus, sigmas = torch.tensor([-ms, ms]), torch.tensor([ss, ss])

# Target domain: class means -mt / +mt, common std st, P(y=1) = pt.
pt, mt, st = .8, 2., 1.0
mut, sigmat = torch.tensor([-mt, mt]), torch.tensor([st, st])

In [17]:
def main1(ps, ms, ss, pt, mt, st, dtype=None):
    '''
    Compare the error of confusion-matrix importance weights with a bound,
    when both the label prior and the class conditionals shift.

    Source: x|y=0 is N(-ms, ss), x|y=1 is N(ms, ss), y=1 with probability ps.
    Target: x|y=0 is N(-mt, st), x|y=1 is N(mt, st), y=1 with probability pt.
    The classifier predicts y=0 when x < 0 and y=1 otherwise.

    Args:
        ps, pt: P(y=1) in the source / target domain (ps must lie in (0, 1),
            since wTrue divides by ps and 1 - ps).
        ms, mt: magnitude of the class means in the source / target domain.
        ss, st: class standard deviation in the source / target domain.
        dtype: floating dtype for all tensors. None uses torch's default
            dtype, matching the original behaviour. Pass torch.float64 when
            pt is very close to ps; otherwise float32 rounding dominates
            trueDist.

    Returns:
        (trueDist, ourBound21): 0-dim tensors holding ||wTrue - wPred||_2 and
        2*sqrt(2)/lambdaMin * source_target_dist(eps) * ||muTrue||_2.
    '''
    if dtype is None:
        dtype = torch.get_default_dtype()

    def confusion_columns(means, stds):
        # Column k is the predicted-label distribution for true label k:
        # row 0 is P(x < 0 | y=k) = P(yhat=0 | y=k), row 1 is its complement.
        below = cdf(means, stds, 0.0)
        return torch.stack([below, 1 - below])

    ccps = confusion_columns(torch.tensor([-ms, ms], dtype=dtype),
                             torch.tensor([ss, ss], dtype=dtype))
    ccpt = confusion_columns(torch.tensor([-mt, mt], dtype=dtype),
                             torch.tensor([st, st], dtype=dtype))

    # Joint source confusion matrix: cs[i, j] = P_s(yhat=i, y=j).
    cs = ccps * torch.tensor([1 - ps, ps], dtype=dtype)

    muTrue = torch.tensor([1 - pt, pt], dtype=dtype)
    # Target predicted-label marginal: muPred[i] = sum_j ccpt[i, j] * P_t(y=j).
    muPred = ccpt @ muTrue
    # Solve cs @ wPred = muPred directly rather than forming cs^{-1}.
    wPred = torch.linalg.solve(cs, muPred)
    wTrue = torch.tensor([(1 - pt) / (1 - ps), pt / ps], dtype=dtype)

    eps = ccpt - ccps
    # NOTE(review): this is the smallest |eigenvalue| of cs. The operator norm
    # of cs^{-1} is governed by the smallest singular value
    # (torch.linalg.svdvals), which can be smaller for a non-symmetric cs.
    # Confirm which quantity the bound requires.
    lambdaMin = torch.linalg.eigvals(cs).abs().min()
    ourBound21 = 2 * np.sqrt(2.) / lambdaMin * source_target_dist(eps) * torch.norm(muTrue)

    trueDist = torch.norm(wTrue - wPred)
    return trueDist, ourBound21

In [18]:
# Source domain, fixed across the sweep.
ps, ms, ss = .5, 1., 1.0

# Target domain; pt is swept in the loop below.
# NOTE(review): the recorded output of this cell prints mt = 2.0, so it was
# produced before mt was set to 1. With mt == ms and st == ss, the source and
# target conditionals coincide: eps is zero, ourBound is 0, and the printed
# ratio degenerates (0 or NaN). Re-run, or restore mt, to refresh the output.
mt, st = 1., 1.0

for pt in [.5000001, .50001, .501, .51, .6, .7, .8, .9, .95, .99, .999, .9999, .9999999]:
    trueDist, ourBound = main1(ps, ms, ss, pt, mt, st)
    ratio = ourBound / trueDist
    print(pt, mt, st, trueDist, ourBound, ratio)

0.5000001 2.0 1.0 tensor(5.9605e-08) tensor(0.7963) tensor(13359573.)
0.50001 2.0 1.0 tensor(1.1127e-05) tensor(0.7963) tensor(71565.4922)
0.501 2.0 1.0 tensor(0.0011) tensor(0.7963) tensor(707.0856)
0.51 2.0 1.0 tensor(0.0113) tensor(0.7965) tensor(70.7257)
0.6 2.0 1.0 tensor(0.1126) tensor(0.8121) tensor(7.2111)
0.7 2.0 1.0 tensor(0.2252) tensor(0.8576) tensor(3.8079)
0.8 2.0 1.0 tensor(0.3378) tensor(0.9286) tensor(2.7487)
0.9 2.0 1.0 tensor(0.4505) tensor(1.0198) tensor(2.2638)
0.95 2.0 1.0 tensor(0.5068) tensor(1.0713) tensor(2.1140)
0.99 2.0 1.0 tensor(0.5518) tensor(1.1149) tensor(2.0205)
0.999 2.0 1.0 tensor(0.5619) tensor(1.1250) tensor(2.0020)
0.9999 2.0 1.0 tensor(0.5630) tensor(1.1260) tensor(2.0002)
0.9999999 2.0 1.0 tensor(0.5631) tensor(1.1261) tensor(2.)


In [8]:
def main2(ps, ms, ss, pt, mt, st, dtype=None):
    '''
    Label-shift experiment where the target class conditionals are mixtures.

    Source: x|y=0 is N(-ms, ss), x|y=1 is N(ms, ss), y=1 with probability ps.
    Target: x|y=0 is .75 * N(-mt, st) + .25 * N(mt, st),
            x|y=1 is .25 * N(-mt, st) + .75 * N(mt, st),
            y=1 with probability pt.
    The classifier predicts y=0 when x < 0 and y=1 otherwise.

    NOTE(review): a later cell redefines ``main2`` with a different signature,
    which silently shadows this version. Consider renaming one of them.

    Args:
        ps, pt: P(y=1) in the source / target domain (ps must lie in (0, 1),
            since wTrue divides by ps and 1 - ps).
        ms, mt: magnitude of the Gaussian component means.
        ss, st: Gaussian component standard deviations.
        dtype: floating dtype for all tensors. None uses torch's default
            dtype, matching the original behaviour. Pass torch.float64 when
            pt is very close to ps; otherwise float32 rounding dominates
            trueDist.

    Returns:
        (trueDist, ourBound): 0-dim tensors holding ||wTrue - wPred||_2 and
        2*sqrt(2)/lambdaMinC * source_target_dist(eps) * ||muTrue||_2.
    '''
    if dtype is None:
        dtype = torch.get_default_dtype()

    def confusion_columns(below):
        # Row 0 holds P(yhat=0 | y=k) and row 1 its complement, so column k is
        # the predicted-label distribution for true label k.
        return torch.stack([below, 1 - below])

    mus = torch.tensor([-ms, ms], dtype=dtype)
    sigmas = torch.tensor([ss, ss], dtype=dtype)
    mut = torch.tensor([-mt, mt], dtype=dtype)
    sigmat = torch.tensor([st, st], dtype=dtype)

    ccps = confusion_columns(cdf(mus, sigmas, 0.0))

    # Row k gives the weights of N(-mt, st) and N(mt, st) in the target x|y=k.
    mixture = torch.tensor([[.75, .25], [.25, .75]], dtype=dtype)
    ccpt = confusion_columns(mixture @ cdf(mut, sigmat, 0.0))

    # Joint source confusion matrix: cs[i, j] = P_s(yhat=i, y=j).
    cs = ccps * torch.tensor([1 - ps, ps], dtype=dtype)

    muTrue = torch.tensor([1 - pt, pt], dtype=dtype)
    # Target predicted-label marginal: muPred[i] = sum_j ccpt[i, j] * P_t(y=j).
    muPred = ccpt @ muTrue
    # Solve cs @ wPred = muPred directly rather than forming cs^{-1}.
    wPred = torch.linalg.solve(cs, muPred)
    wTrue = torch.tensor([(1 - pt) / (1 - ps), pt / ps], dtype=dtype)

    eps = ccpt - ccps
    # NOTE(review): this is the smallest |eigenvalue| of cs. The operator norm
    # of cs^{-1} is governed by the smallest singular value
    # (torch.linalg.svdvals), which can be smaller for a non-symmetric cs.
    # Confirm which quantity the bound requires.
    lambdaMinC = torch.linalg.eigvals(cs).abs().min()
    ourBound = 2 * np.sqrt(2.) / lambdaMinC * source_target_dist(eps) * torch.norm(muTrue)

    trueDist = torch.norm(wTrue - wPred)
    return trueDist, ourBound

In [10]:
# Source domain, fixed across the sweep.
ps = .5
ms = 1.
ss = 1.0

# Target mixture components; pt is swept in the loop below.
# (The unused `arr = []` and the immediately-overwritten `pt = .5` were removed.)
mt = 1.
st = 1.0

for pt in [.5000001, .50001, .501, .51, .6, .7, .8, .9, .99, .9999, .9999999]:
    trueDist, ourBound = main2(ps, ms, ss, pt, mt, st)
    print(pt, mt, st, trueDist, ourBound, ourBound/trueDist)

0.5000001 1.0 1.0 tensor(1.3328e-07) tensor(1.0000) tensor(7503001.5000)
0.50001 1.0 1.0 tensor(1.4161e-05) tensor(1.0000) tensor(70614.8047)
0.501 1.0 1.0 tensor(0.0014) tensor(1.0000) tensor(707.2441)
0.51 1.0 1.0 tensor(0.0141) tensor(1.0002) tensor(70.7258)
0.6 1.0 1.0 tensor(0.1414) tensor(1.0198) tensor(7.2111)
0.7 1.0 1.0 tensor(0.2828) tensor(1.0770) tensor(3.8079)
0.8 1.0 1.0 tensor(0.4243) tensor(1.1662) tensor(2.7487)
0.9 1.0 1.0 tensor(0.5657) tensor(1.2806) tensor(2.2638)
0.99 1.0 1.0 tensor(0.6930) tensor(1.4001) tensor(2.0205)
0.9999 1.0 1.0 tensor(0.7070) tensor(1.4141) tensor(2.0002)
0.9999999 1.0 1.0 tensor(0.7071) tensor(1.4142) tensor(2.0000)


In [11]:
def main2(ps, ms1, ms2, ss1, ss2, pt, mt1, mt2, st1, st2, dtype=None):
    '''
    Label-shift experiment with arbitrary class means and standard deviations.

    Source: x|y=0 is N(ms1, ss1), x|y=1 is N(ms2, ss2), y=1 with probability ps.
    Target: x|y=0 is N(mt1, st1), x|y=1 is N(mt2, st2), y=1 with probability pt.
    The classifier thresholds at the midpoint of the source means,
    (ms1 + ms2) / 2, and predicts y=0 below it.

    Requires ms1 < ms2; otherwise the classifier's y=0 region does not
    contain class 0's mean. The midpoint threshold is the source Bayes rule
    only when ps = .5 and ss1 == ss2. Other values are accepted, but the
    classifier is then not Bayes-optimal on the source.

    NOTE(review): this cell redefines (and shadows) an earlier ``main2`` that
    has a different signature. Consider renaming one of them.

    Args:
        ps, pt: P(y=1) in the source / target domain (ps must lie in (0, 1),
            since wTrue divides by ps and 1 - ps).
        ms1, ms2, ss1, ss2: source class means / standard deviations.
        mt1, mt2, st1, st2: target class means / standard deviations.
        dtype: floating dtype for all tensors. None uses torch's default
            dtype, matching the original behaviour. Pass torch.float64 when
            pt is very close to ps; otherwise float32 rounding dominates
            trueDist.

    Returns:
        (trueDist, ourBound): 0-dim tensors holding ||wTrue - wPred||_2 and
        2*sqrt(2)/lambdaMin * source_target_dist(eps) * ||muTrue||_2.
    '''
    if dtype is None:
        dtype = torch.get_default_dtype()

    mus = torch.tensor([ms1, ms2], dtype=dtype)
    sigmas = torch.tensor([ss1, ss2], dtype=dtype)
    mut = torch.tensor([mt1, mt2], dtype=dtype)
    sigmat = torch.tensor([st1, st2], dtype=dtype)

    # The threshold depends only on the source means and is applied
    # unchanged to the target.
    decisionBoundary = (ms1 + ms2) / 2

    def confusion_columns(means, stds):
        # Column k is the predicted-label distribution for true label k:
        # row 0 is P(x < decisionBoundary | y=k), row 1 is its complement.
        below = cdf(means, stds, decisionBoundary)
        return torch.stack([below, 1 - below])

    ccps = confusion_columns(mus, sigmas)
    ccpt = confusion_columns(mut, sigmat)

    # Joint source confusion matrix: cs[i, j] = P_s(yhat=i, y=j).
    cs = ccps * torch.tensor([1 - ps, ps], dtype=dtype)

    muTrue = torch.tensor([1 - pt, pt], dtype=dtype)
    # Target predicted-label marginal: muPred[i] = sum_j ccpt[i, j] * P_t(y=j).
    muPred = ccpt @ muTrue
    # Solve cs @ wPred = muPred directly rather than forming cs^{-1}.
    wPred = torch.linalg.solve(cs, muPred)
    wTrue = torch.tensor([(1 - pt) / (1 - ps), pt / ps], dtype=dtype)

    eps = ccpt - ccps
    # NOTE(review): this is the smallest |eigenvalue| of cs. The operator norm
    # of cs^{-1} is governed by the smallest singular value
    # (torch.linalg.svdvals), which can be smaller for a non-symmetric cs.
    # Confirm which quantity the bound requires.
    lambdaMin = torch.linalg.eigvals(cs).abs().min()
    ourBound = 2 * np.sqrt(2.) / lambdaMin * source_target_dist(eps) * torch.norm(muTrue)

    trueDist = torch.norm(wTrue - wPred)
    return trueDist, ourBound

In [12]:
# Source: x|y=0 ~ N(ms1, ss1), x|y=1 ~ N(ms2, ss2), P(y=1) = ps.
ps = .5
ms1, ms2 = -5., 1.
ss1, ss2 = 1.0, 1.0

# Target: x|y=0 ~ N(mt1, st1), x|y=1 ~ N(mt2, st2); P(y=1) = pt is swept below.
# (The dead `pt = .9` and the pre-loop main2 call, whose result the loop
# immediately overwrote, were removed.)
mt1, mt2 = -10., 10.
st1, st2 = 10.0, 10.0

for pt in [.5000001, .50001, .501, .51, .6, .7, .8, .9, .99, .9999, .9999999]:
    trueDist, ourBound21 = main2(ps, ms1, ms2, ss1, ss2, pt, mt1, mt2, st1, st2)
    print(pt, ourBound21/trueDist)

0.5000001 tensor(6.1517)
0.50001 tensor(6.1521)
0.501 tensor(6.1932)
0.51 tensor(6.5948)
0.6 tensor(19.0099)
0.7 tensor(19.4888)
0.8 tensor(7.1034)
0.9 tensor(4.6895)
0.99 tensor(3.7729)
0.9999 tensor(3.7029)
0.9999999 tensor(3.7022)
