## Imports

In [None]:
from scipy import stats
from statsmodels.distributions.empirical_distribution import ECDF

import mxnet as mx
import numpy as np
from mxnet import nd, autograd
from mxnet import gluon

# two customized modules
from labelshift import *
from utils4gluon import *
from data_shift import *

## Configurations

In [None]:
# Dataset to run the experiment on.
# Choices: 'mnist', 'cifar10', 'cifar100'
domain = 'cifar10'

# p_P(y): a label distribution with one probability per class (ten entries).
# The second class gets 0.55 of the mass; every other class gets 0.05.
p_P = [
    .05, .55, .05, .05, .05,
    .05, .05, .05, .05, .05,
]

# p_Q(y): left unset (None) in this configuration.
p_Q = None

## Prepare the data

In [None]:

# Device context handed to every MXNet call below (CPU).
ctx = mx.cpu()

##################################
#  Load the dataset selected by `domain`.
#  Every branch must define num_inputs, num_outputs, dfeat, nclass
#  and the arrays X, y, Xtest, ytest used by the rest of the notebook.
#  Add further branches here to support other datasets.
##################################


if domain == "mnist":
    mnist = mx.test_utils.get_mnist()
    num_inputs = 784
    num_outputs = 10
    dfeat = 784
    nclass = 10
    dataset = mnist
    X = dataset["train_data"]
    y = dataset["train_label"]
    # lastly get the test set
    Xtest = dataset["test_data"]
    ytest = dataset["test_label"]

elif domain == "cifar10":
    num_inputs = 3072
    num_outputs = 10
    dfeat = 3072
    nclass = 10

    def transform(data, label):
        # Per-sample transform given to the gluon datasets: move axis 2 to
        # the front, scale by 1/255, and cast data and label to float32.
        return nd.transpose(data.astype(np.float32), (2,0,1))/255, label.astype(np.float32)
    train_DS = gluon.data.vision.CIFAR10(train=True, transform=transform)
    test_DS = gluon.data.vision.CIFAR10(train=False, transform=transform)

    def transform_data(data):
        # Whole-array counterpart of `transform`: axis 3 moves to position 1
        # and values are scaled by 1/255.
        return nd.transpose(data.astype(np.float32), axes=(0,3,1,2))/255

    def transform_label(label):
        # Cast the labels to float32 (followed by a transpose).
        return nd.transpose(label.astype(np.float32))

    # NOTE(review): `_data` / `_label` are private attributes of the gluon
    # dataset objects; they are read directly to obtain the full arrays.
    X = transform_data(train_DS._data).asnumpy()
    y = transform_label(nd.array(train_DS._label)).asnumpy()

    Xtest = transform_data(test_DS._data).asnumpy()
    ytest = transform_label(nd.array(test_DS._label)).asnumpy()

else:
    # Fail here with a clear message instead of hitting a NameError on `X`
    # further down when no branch has loaded any data.
    raise ValueError(
        "Unsupported domain {!r}: only 'mnist' and 'cifar10' are implemented".format(domain)
    )
    

# Mini-batch size used when iterating over the test set later on.
batch_size = 64
# Number of samples before resampling.
n = X.shape[0]

################################################
#  Set the label distribution at train time
################################################
# Pass the configured label distribution `p_P`. The name `Py` that was
# passed here before is not defined in any cell of this notebook, so the
# call failed with a NameError unless a star import happened to supply it.
X, y = tweak_dist(X, y, 10, 100000, p_P)


In [None]:
# Show the shape of X as returned by tweak_dist.
X.shape

## Random splits

In [None]:
# Shuffle samples and labels with one shared random permutation.
n = X.shape[0]
rand_idx = np.random.permutation(n)
X, y = X[rand_idx, ...], y[rand_idx]

# Split into `num` parts: the first part is the training set, the second
# the validation set (with num = 2 these are the two halves).
num = 2
train_end = n // num
val_end = 2 * n // num

Xtrain, ytrain = X[:train_end, :, :, :], y[:train_end]
Xval, yval = X[train_end:val_end, :, :, :], y[train_end:val_end]

################################################
#  Set the label distribution at test time
################################################
# Xtest, ytest = tweak_dist(X, y, 10, 10000, Py)

In [None]:
# Print the empirical class frequencies (10 classes) of the training,
# validation and test labels, in that order.
for labels in (ytrain, yval, ytest):
    class_counts = nd.one_hot(nd.array(labels), 10).sum(axis=0)
    print(class_counts / len(labels))

In [None]:
# Show the shape of the validation split.
Xval.shape

In [None]:
# Show the shape of the full (permuted) array X.
X.shape

##  Train a Classifier

In [None]:
# Classifier: two ReLU hidden layers of width `num_hidden`, followed by a
# linear output layer with `num_outputs` units.
num_hidden = 256
net = gluon.nn.Sequential()
with net.name_scope():
    for units, activation in (
        (num_hidden, "relu"),
        (num_hidden, "relu"),
        (num_outputs, None),
    ):
        net.add(gluon.nn.Dense(units, activation=activation))

# Xavier initialisation, softmax cross-entropy loss, plain SGD.
net_params = net.collect_params()
net_params.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net_params, 'sgd', {'learning_rate': .1})

epochs = 5

# Training without per-sample weights (weightfunc=None).
weighted_train(
    net, softmax_cross_entropy, trainer,
    Xtrain, ytrain, Xval, yval,
    ctx, dfeat,
    epoch=epochs, weightfunc=None,
)


# Predictions of the trained classifier on the validation set (suffix _s)
# and on the test set (suffix _t). predict_all returns a pair; the second
# element is kept under the *_soft names.
val_pred, val_pred_soft = predict_all(Xval, net, ctx, dfeat)
ypred_t, ypred_t_soft = predict_all(Xtest, net, ctx, dfeat)

# Convert to numpy arrays for later convenience.
# ypred_t_soft is intentionally left unconverted.
ypred_s = val_pred.asnumpy()
ypred_s_soft = val_pred_soft.asnumpy()
ypred_t = ypred_t.asnumpy()

## Estimate Wt and Py

In [None]:
# Estimate the label-shift weights from the validation labels and the
# classifier's predictions on the validation and test sets.
wt = estimate_labelshift_ratio(yval, ypred_s, ypred_t, num_outputs)

# Estimated label distribution derived from the weights and validation labels.
Py_est = estimate_target_dist(wt, yval, num_outputs)

# Label marginals of the test set (Py_true) and of the validation set
# (Py_base), computed from the actual labels.
Py_true = calculate_marginal(ytest, num_outputs)
Py_base = calculate_marginal(yval, num_outputs)

# Ground-truth weights: ratio of the test marginal to the validation marginal.
wt_true = Py_true / Py_base

# Side-by-side columns: estimate next to ground truth.
print(np.concatenate((wt, wt_true), axis=1))
print(np.concatenate((Py_est, Py_true), axis=1))

# Squared error of the weights, normalised by ||wt_true||^2.
print("||wt - wt_true||^2  = " + repr(np.sum((wt-wt_true)**2)/np.linalg.norm(wt_true)**2))
# KL divergence of the estimate from the true test marginal. This used to be
# computed against Py_base, which contradicted the printed label.
print("KL(Py_est|| Py_true) = " + repr(stats.entropy(Py_est,Py_true)))

## Solve weighted ERM and compare to previously trained models

In [None]:
# Iterator over the test set, kept in its original order.
data_test = mx.io.NDArrayIter(Xtest, ytest, batch_size, shuffle=False)

# Test accuracy of the unweighted model.
# (A confusion matrix might be more informative than a single number.)
acc_unweighted = evaluate_accuracy(data_test, net, ctx, dfeat)
print("Accuracy unweighted", acc_unweighted)

# Estimated weights as an NDArray on the compute context.
wt_ndarray = nd.array(wt, ctx=ctx)


def weightfunc(x, y):
    """Return the estimated weight for each label in `y`; `x` is ignored."""
    return wt_ndarray[y.asnumpy().astype(int)]

# Second model with the same architecture as `net`, trained with the
# estimated weights applied through `weightfunc`.
net2 = gluon.nn.Sequential()
with net2.name_scope():
    for units, activation in (
        (num_hidden, "relu"),
        (num_hidden, "relu"),
        (num_outputs, None),
    ):
        net2.add(gluon.nn.Dense(units, activation=activation))

net2_params = net2.collect_params()
net2_params.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
trainer2 = gluon.Trainer(net2_params, 'sgd', {'learning_rate': .1})
epochs2 = 5

# Weighted training
weighted_train(
    net2, softmax_cross_entropy, trainer2,
    Xtrain, ytrain, Xval, yval,
    ctx, dfeat,
    epoch=epochs2, weightfunc=weightfunc,
)

# Rewind the test iterator (already consumed by the unweighted evaluation)
# before scoring the weighted model.
data_test.reset()
acc_weighted = evaluate_accuracy(data_test, net2, ctx, dfeat)

print("Accuracy weighted", acc_weighted)