# Learning under Label Shift - CIFAR10, Softmax responses

Choose one label to knock out during training. Then at test time all labels are uniform.

In [12]:
import numpy as np
import mxnet as mx
from mxnet import nd, autograd, gluon
import matplotlib.pyplot as plt

First let's get some data...

In [13]:
num_classes = 10

# num_train = 90000
# num_valid = 1000
# num_test = 1000

# py_train = [.5, .5]
# py_test = [.2, .8]

knockout_label = 0
knockout_train_number = 2000
knockout_valid_number = 500

In [14]:
def transform(data, label):
    return nd.transpose(data.astype(np.float32), (2,0,1))/255, label.astype(np.float32)
train_DS = gluon.data.vision.CIFAR10(train=True, transform=transform)
test_DS = gluon.data.vision.CIFAR10(train=False, transform=transform)


In [15]:
train_DS._data.shape

(50000, 32, 32, 3)

In [16]:
train_DS._label

array([6, 9, 9, ..., 9, 1, 1], dtype=int32)

In [17]:
def transform_data(data):
    return nd.transpose(data.astype(np.float32), axes=(0,3,1,2))/255

def transform_label(label):
    return nd.transpose(label.astype(np.float32))

In [18]:
# get the train data
X_train_raw = transform_data(train_DS._data[:-10000])
Y_train_raw = transform_label(nd.array(train_DS._label[:-10000]))

# get the validation data
X_valid_raw = transform_data(train_DS._data[-10000:])
Y_valid_raw = transform_label(nd.array(train_DS._label[-10000:]))


train_counts = nd.sum(nd.one_hot(Y_train_raw, 10), axis=0)
valid_counts = nd.sum(nd.one_hot(Y_valid_raw, 10), axis=0)

num_train = 40000-(int(train_counts[knockout_label].asscalar()) - knockout_train_number)
X_train = nd.zeros((num_train, 3, 32, 32))
Y_train = nd.zeros(num_train)

num_valid = 10000-(int(valid_counts[knockout_label].asscalar()) - knockout_valid_number)
X_valid = nd.zeros((num_valid, 3, 32, 32))
Y_valid = nd.zeros(num_valid)


##########################################################
#   Construct a training set that has only ``knockout_train_number`` 
#   examples corresponding to knockout_label
##########################################################
knockout_count = 0
j = 0
for i in range(len(X_train_raw)):
    if int(Y_train_raw[i].asscalar()) == knockout_label:
#         print("knockout_label_called")
        knockout_count += 1
        if knockout_count <= knockout_train_number:
            X_train[j] = X_train_raw[i]
            Y_train[j] = Y_train_raw[i]
            j += 1
            print("added knockout label, kockout count: %s, j: %s, i: %s" % (knockout_count, j, i))
    else:
#         print("non_knockout_called")
        X_train[j] = X_train_raw[i]
        Y_train[j] = Y_train_raw[i]
        j += 1
            
    if i % 1000==0:
        print(i)
            

##########################################################
#   Construct a training set that has only ``knockout_valid_number`` 
#   examples corresponding to knockout_label
##########################################################

knockout_count = 0
j = 0
for i in range(len(X_valid_raw)):
    if int(Y_valid_raw[i].asscalar()) == knockout_label:
#         print("knockout_label_called")
        knockout_count += 1
        if knockout_count <= knockout_valid_number:
            X_valid[j] = X_valid_raw[i]
            Y_valid[j] = Y_valid_raw[i]
            j += 1
            print("added knockout label, kockout count: %s, j: %s, i: %s" % (knockout_count, j, i))
    else:
#         print("non_knockout_called")
        X_valid[j] = X_valid_raw[i]
        Y_valid[j] = Y_valid_raw[i]
        j += 1
            
    if i % 1000==0:
        print(i)
            
            
# get the test data 
X_test = transform_data(test_DS._data)
Y_test = transform_label(nd.array(test_DS._label))

num_train = len(X_train)
num_valid = len(X_valid)
num_test = len(Y_test)



0
added knockout label, kockout count: 1, j: 30, i: 29
added knockout label, kockout count: 2, j: 31, i: 30
added knockout label, kockout count: 3, j: 36, i: 35
added knockout label, kockout count: 4, j: 50, i: 49
added knockout label, kockout count: 5, j: 78, i: 77
added knockout label, kockout count: 6, j: 94, i: 93
added knockout label, kockout count: 7, j: 116, i: 115
added knockout label, kockout count: 8, j: 117, i: 116
added knockout label, kockout count: 9, j: 130, i: 129
added knockout label, kockout count: 10, j: 166, i: 165
added knockout label, kockout count: 11, j: 180, i: 179
added knockout label, kockout count: 12, j: 186, i: 185
added knockout label, kockout count: 13, j: 190, i: 189
added knockout label, kockout count: 14, j: 200, i: 199
added knockout label, kockout count: 15, j: 214, i: 213
added knockout label, kockout count: 16, j: 221, i: 220
added knockout label, kockout count: 17, j: 224, i: 223
added knockout label, kockout count: 18, j: 234, i: 233
added knock

added knockout label, kockout count: 194, j: 1927, i: 1926
added knockout label, kockout count: 195, j: 1936, i: 1935
added knockout label, kockout count: 196, j: 1951, i: 1950
added knockout label, kockout count: 197, j: 1955, i: 1954
added knockout label, kockout count: 198, j: 1961, i: 1960
added knockout label, kockout count: 199, j: 1978, i: 1977
added knockout label, kockout count: 200, j: 1981, i: 1980
added knockout label, kockout count: 201, j: 1989, i: 1988
added knockout label, kockout count: 202, j: 2000, i: 1999
2000
added knockout label, kockout count: 203, j: 2007, i: 2006
added knockout label, kockout count: 204, j: 2011, i: 2010
added knockout label, kockout count: 205, j: 2015, i: 2014
added knockout label, kockout count: 206, j: 2022, i: 2021
added knockout label, kockout count: 207, j: 2028, i: 2027
added knockout label, kockout count: 208, j: 2035, i: 2034
added knockout label, kockout count: 209, j: 2036, i: 2035
added knockout label, kockout count: 210, j: 2054, 

added knockout label, kockout count: 374, j: 3790, i: 3789
added knockout label, kockout count: 375, j: 3816, i: 3815
added knockout label, kockout count: 376, j: 3832, i: 3831
added knockout label, kockout count: 377, j: 3843, i: 3842
added knockout label, kockout count: 378, j: 3853, i: 3852
added knockout label, kockout count: 379, j: 3860, i: 3859
added knockout label, kockout count: 380, j: 3881, i: 3880
added knockout label, kockout count: 381, j: 3885, i: 3884
added knockout label, kockout count: 382, j: 3892, i: 3891
added knockout label, kockout count: 383, j: 3894, i: 3893
added knockout label, kockout count: 384, j: 3901, i: 3900
added knockout label, kockout count: 385, j: 3904, i: 3903
added knockout label, kockout count: 386, j: 3907, i: 3906
added knockout label, kockout count: 387, j: 3910, i: 3909
added knockout label, kockout count: 388, j: 3919, i: 3918
added knockout label, kockout count: 389, j: 3939, i: 3938
added knockout label, kockout count: 390, j: 3941, i: 39

added knockout label, kockout count: 569, j: 5641, i: 5640
added knockout label, kockout count: 570, j: 5643, i: 5642
added knockout label, kockout count: 571, j: 5650, i: 5649
added knockout label, kockout count: 572, j: 5658, i: 5657
added knockout label, kockout count: 573, j: 5660, i: 5659
added knockout label, kockout count: 574, j: 5666, i: 5665
added knockout label, kockout count: 575, j: 5674, i: 5673
added knockout label, kockout count: 576, j: 5685, i: 5684
added knockout label, kockout count: 577, j: 5697, i: 5696
added knockout label, kockout count: 578, j: 5700, i: 5699
added knockout label, kockout count: 579, j: 5703, i: 5702
added knockout label, kockout count: 580, j: 5704, i: 5703
added knockout label, kockout count: 581, j: 5720, i: 5719
added knockout label, kockout count: 582, j: 5739, i: 5738
added knockout label, kockout count: 583, j: 5744, i: 5743
added knockout label, kockout count: 584, j: 5758, i: 5757
added knockout label, kockout count: 585, j: 5766, i: 57

added knockout label, kockout count: 737, j: 7506, i: 7505
added knockout label, kockout count: 738, j: 7510, i: 7509
added knockout label, kockout count: 739, j: 7512, i: 7511
added knockout label, kockout count: 740, j: 7521, i: 7520
added knockout label, kockout count: 741, j: 7528, i: 7527
added knockout label, kockout count: 742, j: 7529, i: 7528
added knockout label, kockout count: 743, j: 7531, i: 7530
added knockout label, kockout count: 744, j: 7538, i: 7537
added knockout label, kockout count: 745, j: 7544, i: 7543
added knockout label, kockout count: 746, j: 7548, i: 7547
added knockout label, kockout count: 747, j: 7551, i: 7550
added knockout label, kockout count: 748, j: 7574, i: 7573
added knockout label, kockout count: 749, j: 7582, i: 7581
added knockout label, kockout count: 750, j: 7589, i: 7588
added knockout label, kockout count: 751, j: 7593, i: 7592
added knockout label, kockout count: 752, j: 7600, i: 7599
added knockout label, kockout count: 753, j: 7613, i: 76

added knockout label, kockout count: 938, j: 9363, i: 9362
added knockout label, kockout count: 939, j: 9374, i: 9373
added knockout label, kockout count: 940, j: 9376, i: 9375
added knockout label, kockout count: 941, j: 9380, i: 9379
added knockout label, kockout count: 942, j: 9385, i: 9384
added knockout label, kockout count: 943, j: 9386, i: 9385
added knockout label, kockout count: 944, j: 9391, i: 9390
added knockout label, kockout count: 945, j: 9402, i: 9401
added knockout label, kockout count: 946, j: 9414, i: 9413
added knockout label, kockout count: 947, j: 9418, i: 9417
added knockout label, kockout count: 948, j: 9431, i: 9430
added knockout label, kockout count: 949, j: 9436, i: 9435
added knockout label, kockout count: 950, j: 9439, i: 9438
added knockout label, kockout count: 951, j: 9450, i: 9449
added knockout label, kockout count: 952, j: 9465, i: 9464
added knockout label, kockout count: 953, j: 9472, i: 9471
added knockout label, kockout count: 954, j: 9494, i: 94

added knockout label, kockout count: 1132, j: 11217, i: 11216
added knockout label, kockout count: 1133, j: 11219, i: 11218
added knockout label, kockout count: 1134, j: 11225, i: 11224
added knockout label, kockout count: 1135, j: 11230, i: 11229
added knockout label, kockout count: 1136, j: 11238, i: 11237
added knockout label, kockout count: 1137, j: 11239, i: 11238
added knockout label, kockout count: 1138, j: 11241, i: 11240
added knockout label, kockout count: 1139, j: 11252, i: 11251
added knockout label, kockout count: 1140, j: 11256, i: 11255
added knockout label, kockout count: 1141, j: 11264, i: 11263
added knockout label, kockout count: 1142, j: 11277, i: 11276
added knockout label, kockout count: 1143, j: 11281, i: 11280
added knockout label, kockout count: 1144, j: 11284, i: 11283
added knockout label, kockout count: 1145, j: 11298, i: 11297
added knockout label, kockout count: 1146, j: 11303, i: 11302
added knockout label, kockout count: 1147, j: 11305, i: 11304
added kn

added knockout label, kockout count: 1303, j: 12855, i: 12854
added knockout label, kockout count: 1304, j: 12884, i: 12883
added knockout label, kockout count: 1305, j: 12911, i: 12910
added knockout label, kockout count: 1306, j: 12912, i: 12911
added knockout label, kockout count: 1307, j: 12916, i: 12915
added knockout label, kockout count: 1308, j: 12917, i: 12916
added knockout label, kockout count: 1309, j: 12923, i: 12922
added knockout label, kockout count: 1310, j: 12937, i: 12936
added knockout label, kockout count: 1311, j: 12945, i: 12944
added knockout label, kockout count: 1312, j: 12961, i: 12960
added knockout label, kockout count: 1313, j: 12968, i: 12967
added knockout label, kockout count: 1314, j: 12992, i: 12991
added knockout label, kockout count: 1315, j: 12994, i: 12993
13000
added knockout label, kockout count: 1316, j: 13009, i: 13008
added knockout label, kockout count: 1317, j: 13012, i: 13011
added knockout label, kockout count: 1318, j: 13017, i: 13016
ad

added knockout label, kockout count: 1485, j: 14728, i: 14727
added knockout label, kockout count: 1486, j: 14733, i: 14732
added knockout label, kockout count: 1487, j: 14755, i: 14754
added knockout label, kockout count: 1488, j: 14779, i: 14778
added knockout label, kockout count: 1489, j: 14783, i: 14782
added knockout label, kockout count: 1490, j: 14797, i: 14796
added knockout label, kockout count: 1491, j: 14807, i: 14806
added knockout label, kockout count: 1492, j: 14808, i: 14807
added knockout label, kockout count: 1493, j: 14849, i: 14848
added knockout label, kockout count: 1494, j: 14857, i: 14856
added knockout label, kockout count: 1495, j: 14868, i: 14867
added knockout label, kockout count: 1496, j: 14878, i: 14877
added knockout label, kockout count: 1497, j: 14899, i: 14898
added knockout label, kockout count: 1498, j: 14910, i: 14909
added knockout label, kockout count: 1499, j: 14946, i: 14945
added knockout label, kockout count: 1500, j: 14949, i: 14948
added kn

added knockout label, kockout count: 1661, j: 16597, i: 16596
added knockout label, kockout count: 1662, j: 16625, i: 16624
added knockout label, kockout count: 1663, j: 16626, i: 16625
added knockout label, kockout count: 1664, j: 16633, i: 16632
added knockout label, kockout count: 1665, j: 16635, i: 16634
added knockout label, kockout count: 1666, j: 16638, i: 16637
added knockout label, kockout count: 1667, j: 16648, i: 16647
added knockout label, kockout count: 1668, j: 16678, i: 16677
added knockout label, kockout count: 1669, j: 16687, i: 16686
added knockout label, kockout count: 1670, j: 16705, i: 16704
added knockout label, kockout count: 1671, j: 16722, i: 16721
added knockout label, kockout count: 1672, j: 16734, i: 16733
added knockout label, kockout count: 1673, j: 16735, i: 16734
added knockout label, kockout count: 1674, j: 16753, i: 16752
added knockout label, kockout count: 1675, j: 16764, i: 16763
added knockout label, kockout count: 1676, j: 16777, i: 16776
added kn

added knockout label, kockout count: 1837, j: 18452, i: 18451
added knockout label, kockout count: 1838, j: 18506, i: 18505
added knockout label, kockout count: 1839, j: 18509, i: 18508
added knockout label, kockout count: 1840, j: 18526, i: 18525
added knockout label, kockout count: 1841, j: 18539, i: 18538
added knockout label, kockout count: 1842, j: 18546, i: 18545
added knockout label, kockout count: 1843, j: 18573, i: 18572
added knockout label, kockout count: 1844, j: 18579, i: 18578
added knockout label, kockout count: 1845, j: 18592, i: 18591
added knockout label, kockout count: 1846, j: 18600, i: 18599
added knockout label, kockout count: 1847, j: 18604, i: 18603
added knockout label, kockout count: 1848, j: 18606, i: 18605
added knockout label, kockout count: 1849, j: 18615, i: 18614
added knockout label, kockout count: 1850, j: 18625, i: 18624
added knockout label, kockout count: 1851, j: 18631, i: 18630
added knockout label, kockout count: 1852, j: 18658, i: 18657
added kn

21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
0
added knockout label, kockout count: 1, j: 13, i: 12
added knockout label, kockout count: 2, j: 19, i: 18
added knockout label, kockout count: 3, j: 23, i: 22
added knockout label, kockout count: 4, j: 105, i: 104
added knockout label, kockout count: 5, j: 114, i: 113
added knockout label, kockout count: 6, j: 126, i: 125
added knockout label, kockout count: 7, j: 132, i: 131
added knockout label, kockout count: 8, j: 144, i: 143
added knockout label, kockout count: 9, j: 151, i: 150
added knockout label, kockout count: 10, j: 160, i: 159
added knockout label, kockout count: 11, j: 191, i: 190
added knockout label, kockout count: 12, j: 194, i: 193
added knockout label, kockout count: 13, j: 195, i: 194
added knockout label, kockout count: 14, j: 198, i: 197
added knockout label, kockout count: 15, j: 212, i: 211
added knockout label, kockout count: 16, j: 242, i: 241
add

added knockout label, kockout count: 199, j: 1933, i: 1932
added knockout label, kockout count: 200, j: 1961, i: 1960
added knockout label, kockout count: 201, j: 1972, i: 1971
added knockout label, kockout count: 202, j: 1976, i: 1975
2000
added knockout label, kockout count: 203, j: 2011, i: 2010
added knockout label, kockout count: 204, j: 2020, i: 2019
added knockout label, kockout count: 205, j: 2027, i: 2026
added knockout label, kockout count: 206, j: 2040, i: 2039
added knockout label, kockout count: 207, j: 2045, i: 2044
added knockout label, kockout count: 208, j: 2058, i: 2057
added knockout label, kockout count: 209, j: 2070, i: 2069
added knockout label, kockout count: 210, j: 2074, i: 2073
added knockout label, kockout count: 211, j: 2084, i: 2083
added knockout label, kockout count: 212, j: 2104, i: 2103
added knockout label, kockout count: 213, j: 2106, i: 2105
added knockout label, kockout count: 214, j: 2111, i: 2110
added knockout label, kockout count: 215, j: 2149, 

added knockout label, kockout count: 375, j: 3758, i: 3757
added knockout label, kockout count: 376, j: 3772, i: 3771
added knockout label, kockout count: 377, j: 3776, i: 3775
added knockout label, kockout count: 378, j: 3778, i: 3777
added knockout label, kockout count: 379, j: 3784, i: 3783
added knockout label, kockout count: 380, j: 3812, i: 3811
added knockout label, kockout count: 381, j: 3815, i: 3814
added knockout label, kockout count: 382, j: 3829, i: 3828
added knockout label, kockout count: 383, j: 3830, i: 3829
added knockout label, kockout count: 384, j: 3832, i: 3831
added knockout label, kockout count: 385, j: 3839, i: 3838
added knockout label, kockout count: 386, j: 3850, i: 3849
added knockout label, kockout count: 387, j: 3868, i: 3867
added knockout label, kockout count: 388, j: 3876, i: 3875
added knockout label, kockout count: 389, j: 3908, i: 3907
added knockout label, kockout count: 390, j: 3910, i: 3909
added knockout label, kockout count: 391, j: 3925, i: 39

In [19]:
nd.sum(nd.one_hot(Y_train_raw, 10), axis=0)


[ 3986.  3986.  4048.  3984.  4003.  3975.  4020.  4023.  3997.  3978.]
<NDArray 10 @cpu(0)>

In [20]:
nd.sum(nd.one_hot(Y_train, 10), axis=0)


[ 2000.  3986.  4048.  3984.  4003.  3975.  4020.  4023.  3997.  3978.]
<NDArray 10 @cpu(0)>

In [21]:
nd.sum(nd.one_hot(Y_valid_raw, 10), axis=0)


[ 1014.  1014.   952.  1016.   997.  1025.   980.   977.  1003.  1022.]
<NDArray 10 @cpu(0)>

In [22]:
nd.sum(nd.one_hot(Y_valid, 10), axis=0)


[  500.  1014.   952.  1016.   997.  1025.   980.   977.  1003.  1022.]
<NDArray 10 @cpu(0)>

## Now we'll train a convolutional neural network to predict Y. 

In [23]:
data_ctx = mx.gpu(0)
model_ctx = mx.gpu(0)
num_fc = 512


# build a convnet
net = gluon.nn.Sequential()
with net.name_scope():
    net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
    net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))            
    net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
    net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
    # The Flatten layer collapses all axis, except the first one, into one axis.
    net.add(gluon.nn.Flatten())
    net.add(gluon.nn.Dense(num_fc, activation="relu"))
    net.add(gluon.nn.Dense(num_classes))
    
net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=model_ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
ce_loss = gluon.loss.SoftmaxCrossEntropyLoss()

# Now build up dataloaders

In [24]:
batch_size = 128

train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(nd.array(X_train), nd.array(Y_train)),
                                      batch_size=batch_size, shuffle=True)

valid_data = gluon.data.DataLoader(gluon.data.ArrayDataset(nd.array(X_valid), nd.array(Y_valid)),
                                      batch_size=batch_size, shuffle=True)

test_data = gluon.data.DataLoader(gluon.data.ArrayDataset(nd.array(X_test), nd.array(Y_test)),
                                      batch_size=batch_size, shuffle=True)


In [25]:
# net(X_train[:10].as_in_context(model_ctx))

### Code to calculate accuracy

In [26]:
def evaluate_accuracy(data_iterator, net):
    acc = mx.metric.Accuracy()
    for i, (data, label) in enumerate(data_iterator):
        data = data.as_in_context(model_ctx)
        label = label.as_in_context(model_ctx)
        output = net(data)
        predictions = nd.argmax(output, axis=1)
        acc.update(preds=predictions, labels=label)
    return acc.get()[1]

In [27]:
epochs = 1
loss_sequence = []

for e in range(epochs):
    cumulative_loss = 0
    for i, (data, label) in enumerate(train_data):
        data = data.as_in_context(model_ctx)
        label = label.as_in_context(model_ctx)
        with autograd.record():
            output = net(data)
            loss = ce_loss(output, label)
        loss.backward()
        trainer.step(batch_size)
        cumulative_loss += nd.sum(loss).asscalar()
    train_accuracy = evaluate_accuracy(train_data, net)
    valid_accuracy = evaluate_accuracy(valid_data, net)
    print("Epoch %s, loss: %s, train_acc: %s, valid_acc: %s" % (e, cumulative_loss, train_accuracy, valid_accuracy))
    loss_sequence.append(cumulative_loss)

Epoch 0, loss: 77855.4344635, train_acc: 0.418582627453, valid_acc: 0.417878979549


# Now let's go through the valdiation set and put together the response matrix

In [28]:
response_matrix = np.zeros((10,10))

In [29]:
valid_preds = nd.softmax(net(X_valid.as_in_context(model_ctx)), axis=1)

In [30]:
nd.sum(valid_preds, axis=0)


[  538.06610107  1251.02954102   816.19494629   874.35144043  1212.98925781
   815.58850098  1030.9543457   1119.10693359  1012.80169678   814.91723633]
<NDArray 10 @gpu(0)>

In [31]:
test_output = net(X_test.as_in_context(model_ctx))
test_output


[[-1.11545908 -0.40086365  0.54802281 ..., -1.27542853 -0.15521327
  -1.78242874]
 [ 0.57163537  3.16792321 -0.62265599 ..., -1.95094502  2.9572196
   2.29805064]
 [ 0.05822547  2.18377542 -0.34445405 ..., -1.08171535  1.93548775
   1.62874877]
 ..., 
 [-1.64128232 -1.41522229  0.25768721 ...,  1.88782036 -1.90832806
  -1.00214326]
 [-1.51265574  0.47482574 -0.26147765 ...,  0.11907473 -0.87166446
  -0.64124072]
 [-0.63573653 -1.36788404  0.5114311  ...,  2.7820015  -1.23431695
  -0.31894213]]
<NDArray 10000x10 @gpu(0)>

In [33]:
test_preds = nd.softmax(test_output, axis=1)
test_preds


[[ 0.01969173  0.04023729  0.10392616 ...,  0.0167807   0.05144146
   0.01010701]
 [ 0.03150987  0.42266867  0.00954492 ...,  0.00252874  0.34236756
   0.17710027]
 [ 0.04337798  0.36339921  0.02899933 ...,  0.01387393  0.28350061
   0.20861186]
 ..., 
 [ 0.00705084  0.0088393   0.04709265 ...,  0.24038698  0.00539839
   0.01336028]
 [ 0.01615646  0.1178958   0.05645804 ...,  0.08260334  0.03067079
   0.03861863]
 [ 0.02029894  0.00976126  0.06392665 ...,  0.61912465  0.01115612
   0.02786477]]
<NDArray 10000x10 @gpu(0)>

In [34]:
test_response = nd.sum(test_preds, axis=0) / num_test

In [35]:
test_response


[ 0.06297506  0.12905765  0.08826275  0.09007443  0.12581889  0.08445676
  0.10623354  0.11893882  0.110544    0.08363811]
<NDArray 10 @gpu(0)>

In [36]:
valid_preds


[[ 0.02264453  0.70477009  0.01202571 ...,  0.00287019  0.09316307
   0.13859841]
 [ 0.08244438  0.21720351  0.02259534 ...,  0.01590944  0.30350691
   0.26457042]
 [ 0.00326736  0.00427032  0.05778188 ...,  0.11016581  0.0014704
   0.00437582]
 ..., 
 [ 0.06685457  0.18505482  0.01442392 ...,  0.0187577   0.14846069
   0.53434575]
 [ 0.10901469  0.09017766  0.05149648 ...,  0.01673063  0.62943751
   0.04501303]
 [ 0.06093625  0.0673751   0.06718838 ...,  0.27986681  0.0430745
   0.09846281]]
<NDArray 9486x10 @gpu(0)>

In [37]:
nd.sum(nd.one_hot(Y_train_raw, 10), axis=0)


[ 3986.  3986.  4048.  3984.  4003.  3975.  4020.  4023.  3997.  3978.]
<NDArray 10 @cpu(0)>

### Building the reponse matrix

Now we will build the response matric based on the classifier's output on the validation set. 

In [38]:
YV = Y_valid
valid_counts = np.zeros(num_classes)

valid_preds_np = valid_preds.asnumpy()


for i in range(num_valid):
    response_matrix[:,int(Y_valid[i].asscalar())] += valid_preds_np[i] 
    valid_counts[int(Y_valid[i].asscalar())] += 1

In [39]:
for j in range(num_classes):
    if valid_counts[j] > 0:
        response_matrix[:,j] /= valid_counts[j]

In [40]:
response_matrix

array([[ 0.20052022,  0.04551991,  0.0606538 ,  0.03605141,  0.04762776,
         0.03204594,  0.02138741,  0.03700266,  0.10543017,  0.0529241 ],
       [ 0.12787846,  0.39904336,  0.0623686 ,  0.06901757,  0.05511797,
         0.05559825,  0.05350413,  0.0618634 ,  0.17833173,  0.24391538],
       [ 0.09375897,  0.03907244,  0.14874066,  0.09612488,  0.12017256,
         0.10097513,  0.09341054,  0.0833439 ,  0.05251801,  0.04058273],
       [ 0.04748553,  0.04427382,  0.09907979,  0.16230617,  0.09065969,
         0.15294355,  0.12272688,  0.08663925,  0.03937209,  0.05374173],
       [ 0.08131776,  0.0532345 ,  0.17941268,  0.13288551,  0.24071207,
         0.13907348,  0.19335631,  0.14888579,  0.04787005,  0.04602099],
       [ 0.04657561,  0.03714042,  0.09879018,  0.14530414,  0.08098333,
         0.17973241,  0.08605921,  0.09016802,  0.03936549,  0.03530448],
       [ 0.03382321,  0.04952437,  0.12150172,  0.14717638,  0.14748955,
         0.12939198,  0.26857379,  0.07575817

## Invert the response matrix and estimate the test set marginal P(Y)

In [41]:
R_inv = np.linalg.inv(response_matrix)

In [42]:
R_inv.dot(test_response.asnumpy())

array([ 0.09070917,  0.09629855,  0.14153385,  0.08344577,  0.06510781,
        0.10336234,  0.11087846,  0.10648564,  0.10255169,  0.09962674])

In [8]:
X_valid


[[[[ 1.          0.98823529  0.99215686 ...,  0.64705884  0.95294118
     0.99607843]
   [ 1.          0.98823529  0.99607843 ...,  0.50980395  0.88235295
     0.99215686]
   [ 1.          0.99607843  0.97254902 ...,  0.5529412   0.86274511
     0.99215686]
   ..., 
   [ 0.9137255   0.84705883  0.94509804 ...,  0.03529412  0.07058824
     0.66274512]
   [ 1.          1.          0.99215686 ...,  0.08235294  0.44313726
     0.92156863]
   [ 1.          0.98431373  0.99215686 ...,  0.67450982  0.90196079
     0.96862745]]

  [[ 1.          0.98823529  0.98823529 ...,  0.69411767  0.96470588
     0.99215686]
   [ 1.          0.98823529  0.99607843 ...,  0.56470591  0.90980393  1.        ]
   [ 1.          0.99607843  0.96862745 ...,  0.60784316  0.89019608  1.        ]
   ..., 
   [ 0.91764706  0.84705883  0.94509804 ...,  0.04313726  0.07450981
     0.67058825]
   [ 1.          1.          0.99215686 ...,  0.09019608  0.4509804
     0.92941177]
   [ 1.          0.98431373  0.99215686 ..

In [9]:
X_test


[[[[ 0.61960787  0.62352943  0.64705884 ...,  0.53725493  0.49411765
     0.45490196]
   [ 0.59607846  0.59215689  0.62352943 ...,  0.53333336  0.49019608
     0.46666667]
   [ 0.59215689  0.59215689  0.61960787 ...,  0.54509807  0.50980395
     0.47058824]
   ..., 
   [ 0.26666668  0.16470589  0.12156863 ...,  0.14901961  0.05098039
     0.15686275]
   [ 0.23921569  0.19215687  0.13725491 ...,  0.10196079  0.11372549
     0.07843138]
   [ 0.21176471  0.21960784  0.17647059 ...,  0.09411765  0.13333334
     0.08235294]]

  [[ 0.43921569  0.43529412  0.45490196 ...,  0.37254903  0.35686275
     0.33333334]
   [ 0.43921569  0.43137255  0.44705883 ...,  0.37254903  0.35686275
     0.34509805]
   [ 0.43137255  0.42745098  0.43529412 ...,  0.38431373  0.37254903
     0.34901962]
   ..., 
   [ 0.48627451  0.39215687  0.34509805 ...,  0.38039216  0.25098041
     0.33333334]
   [ 0.45490196  0.40000001  0.33333334 ...,  0.32156864  0.32156864
     0.25098041]
   [ 0.41960785  0.41176471  0.34

In [11]:
net(X_test.as_in_context(model_ctx))


[[ 0.17089634  0.01088437 -0.04569125 ..., -0.04264998 -0.02200803
  -0.01031495]
 [ 0.26343119  0.00868653 -0.04735368 ...,  0.10443804 -0.06067207
  -0.02821601]
 [ 0.20072933  0.00649943 -0.03450613 ...,  0.0231309  -0.01013924
  -0.02070815]
 ..., 
 [ 0.12646179  0.00889461 -0.04903109 ..., -0.0093085  -0.05755803
  -0.0453046 ]
 [ 0.23281378 -0.00636021 -0.10480515 ..., -0.00398986 -0.04914022
  -0.01008645]
 [ 0.16446826  0.00741851 -0.03781474 ..., -0.00831909 -0.04310116
  -0.02940574]]
<NDArray 10000x10 @gpu(0)>