In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
import torch.nn.functional as F
from tqdm import tqdm
import copy
import collections

In [2]:
def gen_data(n, d, p, scale=1.0):
    """Sample a two-class Gaussian dataset with Bernoulli(p) labels.

    Each label is 1 with probability ``p`` and 0 otherwise. The matching
    feature row is standard-normal noise shifted, in every coordinate, by
    ``+scale`` when the label is 1 and by ``-scale`` when it is 0.

    Args:
        n: number of samples to draw.
        d: number of feature columns.
        p: probability of label 1 (handed to ``torch.tensor``).
        scale: magnitude of the per-class mean shift.

    Returns:
        ``(xs, ys)``: ``xs`` is an ``(n, d)`` float tensor of features and
        ``ys`` is an ``(n, 1)`` tensor of 0./1. labels.
    """
    label_dist = torch.distributions.Bernoulli(torch.tensor(p))
    ys = label_dist.sample((n, 1))
    noise = torch.randn(n, d)
    # Map labels {0, 1} -> signs {-1, +1}; the (n, 1) column broadcasts over d.
    signs = 2 * ys - 1
    xs = noise + scale * signs
    return xs.float(), ys

In [2]:
def chi_sq_dist(x, y):
    """Return the summed relative absolute deviation of ``y`` from ``x``.

    Computes ``sum(|x - y| / |x|)`` over all elements.

    NOTE(review): despite the name, the deviation is not squared, so this is
    not the textbook chi-square divergence ``sum((x - y)**2 / x)``. Confirm
    which one is intended before relying on it.
    """
    relative_gap = torch.div(x - y, x)
    return torch.sum(torch.abs(relative_gap))

In [4]:
def train(model, x, y, classWeights, epochs = 1000):
    """Fit ``model`` with Adam on a class-weighted binary cross-entropy loss.

    Every iteration runs one forward/backward pass over the whole of ``x``
    (no mini-batching) and takes a single Adam step with learning rate 1e-3.

    Args:
        model: module whose output ``model(x)`` is fed straight into
            ``F.binary_cross_entropy`` (so it must already be a probability).
        x: input features.
        y: targets; also cast to ``long`` to look up per-sample weights.
        classWeights: indexable by the integer labels in ``y``
            (presumably a 1-D tensor with one weight per class -- confirm).
        epochs: number of optimisation steps.

    Returns:
        The same ``model`` object, updated in place.
    """
    adam = torch.optim.Adam(model.parameters(), lr=0.001)
    for _ in tqdm(range(epochs)):
        adam.zero_grad()
        predicted_probs = model(x)
        # Look up each sample's weight from its own label.
        per_sample_weight = classWeights[y.long()]
        step_loss = F.binary_cross_entropy(predicted_probs, y, weight=per_sample_weight)
        step_loss.backward()
        adam.step()
    return model

In [12]:
def class_conditional_prob(confusion):
    """Normalise each column of ``confusion`` so that it sums to one.

    Sums over dimension 0 (keeping it for broadcasting) and divides every
    entry by the total of its column. A column that sums to zero yields
    NaN/inf entries.
    """
    column_totals = torch.sum(confusion, dim=0, keepdim=True)
    return torch.div(confusion, column_totals)

def source_target_dist(eps):
    """Return the largest column-wise sum of absolute values of ``eps``.

    Absolute values are summed over dimension 0 and the maximum of those sums
    is returned as a 0-d tensor. For a 2-D ``eps`` this is the matrix 1-norm.

    An earlier, disabled variant compared class-conditional matrices column
    by column with ``chi_sq_dist``; it is no longer used.
    """
    column_mass = torch.sum(torch.abs(eps), dim=0)
    return torch.max(column_mass)

In [32]:
def cdf(mu, sigma, x):
    """Evaluate the Gaussian CDF with mean ``mu`` and std ``sigma`` at ``x``.

    Uses ``0.5 * (1 + erf((x - mu) / (sigma * sqrt(2))))``. The argument of
    ``torch.erf`` must be a tensor, so at least one of the inputs has to be a
    tensor.
    """
    root_two = np.sqrt(2)
    standardized = (x - mu) / (sigma * root_two)
    return torch.erf(standardized).add(1).mul(0.5)

In [50]:
# Source domain: label prior, class-mean magnitude and shared std-dev.
ps, ms, ss = .5, 1., 1.0

# Per-class Gaussian parameters; index 0 is the class centred at -ms.
mus = torch.tensor([-ms, ms])
sigmas = torch.tensor([ss] * 2)

# Target domain: same layout, shifted prior and closer class means.
pt, mt, st = .8, .7, 1.0

mut = torch.tensor([-mt, mt])
sigmat = torch.tensor([st] * 2)

In [57]:
# Confusion matrices computed analytically from the Gaussian CDF.
#
# Layout: entry [i, j] refers to "x falls on side i of the threshold 0"
# (row 0: x < 0, row 1: x >= 0) for the class in column j (column 0 is the
# class centred at -m, column 1 the class centred at +m).
#
# Prior convention: p is P(y = 1), so column 0 is weighted by (1 - p) and
# column 1 by p. This matches muPred, muTrue and wTrue below. The previous
# version weighted the columns the other way round, which only went
# unnoticed because ps = 0.5 and ct is not used further down.

# Source class-conditional matrix: each column sums to one.
ccps = torch.zeros(2, 2)
ccps[0, 0] = cdf(mus[0], sigmas[0], 0.0)
ccps[0, 1] = cdf(mus[1], sigmas[1], 0.0)
ccps[1, 0] = 1 - ccps[0, 0]
ccps[1, 1] = 1 - ccps[0, 1]

# Source joint matrix: class-conditional columns scaled by the source prior.
cs = torch.zeros(2, 2)
cs[:, 0] = ccps[:, 0] * (1 - ps)
cs[:, 1] = ccps[:, 1] * ps


# Target class-conditional matrix, built the same way.
ccpt = torch.zeros(2, 2)
ccpt[0, 0] = cdf(mut[0], sigmat[0], 0.0)
ccpt[0, 1] = cdf(mut[1], sigmat[1], 0.0)
ccpt[1, 0] = 1 - ccpt[0, 0]
ccpt[1, 1] = 1 - ccpt[0, 1]

# Shift in the class-conditional matrix between target and source.
eps = ccpt - ccps

# Target joint matrix: class-conditional columns scaled by the target prior.
ct = torch.zeros(2, 2)
ct[:, 0] = ccpt[:, 0] * (1 - pt)
ct[:, 1] = ccpt[:, 1] * pt

In [58]:
# Marginal of the thresholded prediction on the target domain:
# entry 0 mixes P(x < 0 | class) over the target prior (1 - pt, pt),
# entry 1 is its complement.
muPred = torch.zeros(2)
muPred[0] = (1-pt) * cdf(mut[0], sigmat[0], 0.0) + pt * cdf(mut[1], sigmat[1], 0.0)
muPred[1] = 1 - muPred[0]

In [59]:
# Target label marginal, ordered [class 0, class 1] = [1 - pt, pt].
muTrue = torch.tensor([1 - pt, pt])

In [60]:
# Smallest eigenvalue modulus of the source joint matrix (used by the bounds).
lambdaMin = torch.linalg.eig(cs).eigenvalues.abs().min()
# Weight estimate obtained by inverting the source joint matrix.
cInv = torch.inverse(cs)
wPred = torch.matmul(cInv, muPred)

In [61]:
# Exact importance weights: target prior divided by source prior, per class.
wTrue = torch.tensor([(1-pt) / (1-ps), pt / ps])

In [70]:
# Euclidean error of the estimated weights against the exact ones.
trueDist = (wTrue - wPred).norm()
trueDist

tensor(0.2071)

In [65]:
# Bound "21": sqrt(2) / lambdaMin * (max abs column sum of eps) * ||muTrue||_2.
ourBound21 = (np.sqrt(2.) / lambdaMin) * source_target_dist(eps) * muTrue.norm()
ourBound21

tensor(0.5692)

In [66]:
# Slack of bound "21": how many times larger it is than the actual error.
torch.div(ourBound21, trueDist)

tensor(2.7487)

In [67]:
# Bound "19": (largest eigenvalue modulus of eps) / lambdaMin * ||muTrue||_2.
eps = torch.sub(ccpt, ccps)
lambdaMax = torch.linalg.eig(eps).eigenvalues.abs().max()
ourBound19 = (lambdaMax / lambdaMin) * muTrue.norm()
ourBound19

tensor(0.4025)

In [68]:
# Slack of bound "19": how many times larger it is than the actual error.
torch.div(ourBound19, trueDist)

tensor(1.9436)

In [69]:
# Bound "17": norm of cInv @ eps @ muTrue, evaluated left to right.
ourBound17 = torch.matmul(torch.matmul(cInv, eps), muTrue).norm()
ourBound17

tensor(0.2071)

In [74]:
def _class_conditional_confusion(m, s):
    """Build the 2x2 matrix of P(x < 0 | class) / P(x >= 0 | class) for classes centred at -m and +m."""
    centres = torch.tensor([-m, m])
    spreads = torch.tensor([s, s])
    ccp = torch.zeros(2, 2)
    ccp[0, 0] = cdf(centres[0], spreads[0], 0.0)
    ccp[0, 1] = cdf(centres[1], spreads[1], 0.0)
    ccp[1] = 1 - ccp[0]  # second row is the complement, so columns sum to one
    return ccp


def main(ps, ms, ss, pt, mt, st):
    """Compare the weight-estimation error with three bounds under label shift.

    Both domains are two 1-D Gaussians centred at ``-m`` (class 0) and ``+m``
    (class 1) with a shared std-dev, and the prediction is the side of 0 that
    ``x`` falls on. ``p`` is the prior of class 1 throughout, so priors are
    ordered ``[1 - p, p]``.

    Args:
        ps, ms, ss: source prior of class 1, class-mean magnitude, std-dev.
        pt, mt, st: the same three quantities for the target domain.

    Returns:
        ``(trueDist, ourBound21, ourBound19, ourBound17)`` as 0-d tensors:

        * ``trueDist``: ``||wTrue - wPred||_2`` where ``wPred = cs^-1 muPred``
          and ``wTrue`` is the ratio of target to source priors.
        * ``ourBound21``: ``sqrt(2) / lambdaMin * source_target_dist(eps) *
          ||muTrue||_2``.
        * ``ourBound19``: ``lambdaMax(eps) / lambdaMin * ||muTrue||_2``.
        * ``ourBound17``: ``||cs^-1 eps muTrue||_2``.

        The numeric suffixes presumably refer to equation numbers in an
        accompanying write-up -- confirm.
    """
    ccps = _class_conditional_confusion(ms, ss)
    ccpt = _class_conditional_confusion(mt, st)
    # Shift of the class-conditional matrix between target and source.
    eps = ccpt - ccps

    # Source joint matrix. Column 0 is class 0, whose prior is (1 - ps);
    # column 1 is class 1, whose prior is ps. The earlier version had these
    # two factors swapped, which was only correct for ps == 0.5 and broke the
    # identity wPred - wTrue == cInv @ eps @ muTrue for every other prior.
    cs = torch.zeros(2, 2)
    cs[:, 0] = ccps[:, 0] * (1 - ps)
    cs[:, 1] = ccps[:, 1] * ps

    # Target marginal of the prediction and of the true label.
    muPred = torch.zeros(2)
    muPred[0] = ccpt[0, 0] * (1 - pt) + ccpt[0, 1] * pt
    muPred[1] = 1 - muPred[0]
    muTrue = torch.tensor([1 - pt, pt])

    # Smallest eigenvalue modulus of cs and the resulting weight estimate.
    lambdaMin = torch.linalg.eig(cs)[0].abs().min()
    cInv = torch.inverse(cs)
    wPred = cInv @ muPred

    # Exact importance weights: target prior over source prior, per class.
    wTrue = torch.tensor([(1 - pt) / (1 - ps), pt / ps])

    ourBound21 = np.sqrt(2.)/lambdaMin * source_target_dist(eps) * torch.norm(muTrue)
    lambdaMax = torch.linalg.eig(eps)[0].abs().max()
    ourBound19 = lambdaMax / lambdaMin * torch.norm(muTrue)
    ourBound17 = torch.norm(cInv @ eps @ muTrue)

    trueDist = torch.norm(wTrue - wPred)

    return trueDist, ourBound21, ourBound19, ourBound17

In [75]:
# Source-domain settings: prior, class-mean magnitude, std-dev.
ps, ms, ss = .5, 1., 1.0

# Target-domain settings, same order.
pt, mt, st = .8, .7, 1.0

# Re-run the cell-by-cell computation above through the packaged function.
trueDist, ourBound21, ourBound19, ourBound17 = main(
    ps=ps, ms=ms, ss=ss, pt=pt, mt=mt, st=st
)
print(trueDist, ourBound21, ourBound19, ourBound17)

tensor(0.2071) tensor(0.5692) tensor(0.4025) tensor(0.2071)
