Here we present an example that **recovers smoothed labels from an untrained LeNet on CIFAR-10**.

Due to time constraints, instead of 1000 examples we run 100 experiments here (10 samples for each of the 10 classes), and obtain **100%** accuracy.

In [2]:
import torch
import random
from typing import OrderedDict
from exp import cross_entropy_for_onehot
from recovering import label_recovery
from tqdm import tqdm
import numpy as np
from datetime import datetime
import time
# --- Reproducibility: seed every RNG this cell draws from -------------------
seed=2023
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --- Tunables ----------------------------------------------------------------
N_CLASSES=10      # number of classes iterated over (rows of `datalist` that are used)
N_PER_CLASS=10    # number of samples drawn for every class
# Machine-specific prefix prepended to the recovery object's data directory;
# adjust it to wherever the dataset lives on your machine.
DATA_ROOT='/home/yanbo.wang/'

CONFIG=OrderedDict(device=torch.device('cpu'),
    dataset="cifar10",
    network="lenet",
    opt="lbfgs",
    type='label_smooth',
    pretrained=False,
    criterion=cross_entropy_for_onehot,
    lr=0.5,
    bound=100,
    iteration=200,
    initia=1.,
    coefficient=4)
test=label_recovery(CONFIG)
test.datadir=DATA_ROOT+test.datadir
datalist=np.load('additional_files/mixup_list_cifar10.npy')
# Alternative index lists, kept for reference:
# datalist=np.load('additional_files/mixup_list_imagenet.npy',allow_pickle=True)
# datalist=np.load('additional_files/dataset_cifar100.csv',allow_pickle=True)

# Result table, one row per (class, trial). Columns:
#   0 index | 1 prob | 2 featureloss | 3 real_scalar | 4 reco_scalar
#   5 scalar_loss | 6 success | 7 time
# The table is sized to exactly the trials that are run, so no all-zero
# placeholder rows remain that could be mistaken for results with status 0.
exp=np.zeros((N_CLASSES,N_PER_CLASS,8))
for i in tqdm(range(N_CLASSES)):
    # np.random.choice samples with replacement by default, so one class may
    # contribute the same index more than once.
    choice_index=np.random.choice(datalist[i],N_PER_CLASS)
    for i_exp, ind in enumerate(choice_index):
        prob=random.uniform(0,0.5)
        #prob=0
        test.setup(ind,prob)
        exp[i,i_exp,0],exp[i,i_exp,1]=ind,prob
        start_time=time.time()
        exp[i,i_exp,6]=test.label_reco()
        # A status of -1 from label_reco() triggers the pso() fallback search.
        if exp[i,i_exp,6]==-1:
            exp[i,i_exp,6]=test.pso()
        # The elapsed time covers label_reco() plus the optional pso() fallback.
        exp[i,i_exp,7]=time.time()-start_time
        exp[i,i_exp,3],exp[i,i_exp,4],exp[i,i_exp,5]=test.ground_truth,test.scalar,test.ground_truth-test.scalar
        # The feature loss is only recorded when the final status is 1 or 0.
        if exp[i,i_exp,6] ==1 or exp[i,i_exp,6] == 0:
            exp[i,i_exp,2]=((test.recover_tensor-test.net.temp)**2).sum()

  0%|          | 0/10 [00:00<?, ?it/s]

flip!
epoch is 6
epoch is 3
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6


 10%|█         | 1/10 [00:04<00:37,  4.12s/it]

epoch is 3
flip!
epoch is 6
epoch is 3
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
epoch is 3
epoch is 3
flip!
epoch is 6


 20%|██        | 2/10 [00:07<00:31,  3.88s/it]

epoch is 3
flip!
epoch is 6
epoch is 3
epoch is 3
epoch is 3
flip!
epoch is 6
flip!
epoch is 7
skip!
epoch is 7
epoch is 4
flip!
epoch is 6


 30%|███       | 3/10 [00:11<00:26,  3.77s/it]

epoch is 3
epoch is 3
epoch is 3
epoch is 3
epoch is 3
flip!
epoch is 6
epoch is 3
flip!
epoch is 6
flip!
epoch is 6
epoch is 3


 40%|████      | 4/10 [00:14<00:21,  3.57s/it]

epoch is 3
epoch is 3
flip!
epoch is 6
epoch is 3
flip!
epoch is 6
epoch is 3
flip!
epoch is 6
skip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
flip!
epoch is 29
flip!
epoch is 6
flip!
epoch is 6


 50%|█████     | 5/10 [00:19<00:19,  3.86s/it]

flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
epoch is 3
flip!
epoch is 6
epoch is 3
flip!
epoch is 6
flip!
epoch is 6


 60%|██████    | 6/10 [00:23<00:15,  3.90s/it]

epoch is 3
skip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
flip!
epoch is 30
skip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
flip!
epoch is 29
epoch is 3
skip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
flip!
epoch is 30
epoch is 3
epoch is 3
skip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
flip!
epoch is 31
flip!
epoch is 7
skip!
skip!
skip!
flip!
epoch is 18
skip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
flip!


 70%|███████   | 7/10 [00:29<00:14,  4.87s/it]

epoch is 29
flip!
epoch is 6
epoch is 4
epoch is 3
epoch is 3
epoch is 3
epoch is 3
flip!
epoch is 6
epoch is 3
epoch is 3


 80%|████████  | 8/10 [00:33<00:08,  4.34s/it]

epoch is 3
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
epoch is 3
epoch is 3
epoch is 3
flip!
epoch is 6


 90%|█████████ | 9/10 [00:36<00:04,  4.09s/it]

epoch is 3
epoch is 3
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
epoch is 3
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6


100%|██████████| 10/10 [00:40<00:00,  4.07s/it]

flip!
epoch is 6





Under these settings, the gradient-based algorithm (L-BFGS) is able to find the global optimum.

Below is a small experiment for mixup. It takes about 18 minutes to recover 135 examples.

When L-BFGS does not reach the optimal scalar, PSO is called to search iteratively. Adjusting the search bound, the seed population, and the width of each search range changes both the running time and the search accuracy (and error). It is a time–accuracy trade-off.

In [3]:
from itertools import combinations
# --- Reproducibility: re-seed every RNG this cell draws from -----------------
seed=2023
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# --- Tunables ----------------------------------------------------------------
# Number of experiments per class pair. It is fine to change this value; the
# result table below is sized from it. With 3 we get 3*45=135 experiments.
N_TRIALS=3
# Machine-specific prefix prepended to the recovery object's data directory;
# adjust it to wherever the dataset lives on your machine.
DATA_ROOT='/home/yanbo.wang/'

CONFIG=OrderedDict(device=torch.device('cpu'),
    dataset="cifar10",
    network="lenet",
    opt="lbfgs",
    type='mixup',
    pretrained=False,
    criterion=cross_entropy_for_onehot,
    lr=0.5,
    bound=100,
    iteration=200,
    initia=1.,
    coefficient=2)
test=label_recovery(CONFIG)
test.datadir=DATA_ROOT+test.datadir
mixup_list=np.load('additional_files/mixup_list_cifar10.npy')
# Upper bound for the random per-class position, taken from the loaded table
# instead of a hard-coded 1000.
n_per_class=mixup_list.shape[1]
combination_list=list(combinations(range(10),2))

# Result table, one row per (class pair, trial). Columns:
#   0 index_a | 1 index_b | 2 prob | 3 featureloss | 4 real_scalar
#   5 reco_scalar | 6 scalar_loss | 7 success | 8 time
# index_a / index_b are the positions inside each class row of `mixup_list`,
# not the dataset indices stored there.
exp=np.zeros((len(combination_list),N_TRIALS,9))
for i in tqdm(range(len(combination_list))):
    item=combination_list[i]
    for ii in range(N_TRIALS):
        # random.randint is inclusive on both ends, hence the -1.
        ind=[random.randint(0,n_per_class-1),random.randint(0,n_per_class-1)]
        prob=random.uniform(0,1)
        # Mix one sample from each class of the pair with weights (1-prob, prob).
        test.setup([mixup_list[item[0],ind[0]],mixup_list[item[1],ind[1]]],[1-prob,prob])
        exp[i,ii,0],exp[i,ii,1],exp[i,ii,2]=ind[0],ind[1],prob
        start_time=time.time()
        exp[i,ii,7]=test.label_reco()
        # A status of -1 from label_reco() triggers the pso() fallback search.
        if exp[i,ii,7]==-1:
            exp[i,ii,7]=test.pso(50)
        # The elapsed time covers label_reco() plus the optional pso() fallback.
        exp[i,ii,8]=time.time()-start_time
        exp[i,ii,4],exp[i,ii,5],exp[i,ii,6]=test.ground_truth,test.scalar,test.ground_truth-test.scalar
        # The feature loss is only recorded when the final status is 1 or 0.
        if exp[i,ii,7] ==1 or exp[i,ii,7] == 0:
            exp[i,ii,3]=((test.recover_tensor-test.net.temp)**2).sum()

  0%|          | 0/45 [00:00<?, ?it/s]

flip!
epoch is 6
flip!
epoch is 6


  2%|▏         | 1/45 [00:01<00:46,  1.06s/it]

epoch is 3
flip!
epoch is 6
flip!
epoch is 6


  4%|▍         | 2/45 [00:02<00:54,  1.26s/it]

flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6


  7%|▋         | 3/45 [00:03<00:48,  1.16s/it]

epoch is 3
flip!
epoch is 6
flip!
epoch is 6


  9%|▉         | 4/45 [00:04<00:46,  1.13s/it]

epoch is 3
flip!
epoch is 6
epoch is 3


 11%|█         | 5/45 [00:05<00:46,  1.17s/it]

flip!
epoch is 6
epoch is 3
epoch is 3
skip!


 13%|█▎        | 6/45 [00:06<00:43,  1.12s/it]

flip!
epoch is 10
flip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
scalar is tensor(-129.6024, requires_grad=True) while gt is -2.5163229
out of bound!
ground_truth: -2.5163228511810303
searching from 0.7 to 6.0!
[6.] [12.90142536]
searching from -6.3 to -1.0!
[-2.51632335] [8.97517683e-13]
successfully find the ground_truth [-2.51632335]
flip!
epoch is 6


 16%|█▌        | 7/45 [02:15<27:09, 42.88s/it]

epoch is 3
flip!
epoch is 6
flip!
epoch is 6


 18%|█▊        | 8/45 [02:16<18:15, 29.62s/it]

flip!
epoch is 6
epoch is 3
epoch is 3


 20%|██        | 9/45 [02:17<12:23, 20.66s/it]

flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6


 22%|██▏       | 10/45 [02:18<08:31, 14.60s/it]

epoch is 3
flip!
epoch is 6
flip!
epoch is 6


 24%|██▍       | 11/45 [02:19<05:55, 10.45s/it]

epoch is 3
epoch is 3
epoch is 3


 27%|██▋       | 12/45 [02:20<04:08,  7.52s/it]

flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6


 29%|██▉       | 13/45 [02:22<02:59,  5.62s/it]

flip!
epoch is 6
epoch is 3
flip!
epoch is 6
skip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
flip!


 31%|███       | 14/45 [02:23<02:13,  4.32s/it]

epoch is 29
flip!
epoch is 6
flip!
epoch is 6


 33%|███▎      | 15/45 [02:24<01:41,  3.37s/it]

epoch is 3
flip!
epoch is 6
flip!
epoch is 6


 36%|███▌      | 16/45 [02:25<01:19,  2.75s/it]

flip!
epoch is 6
flip!
epoch is 6
epoch is 3


 38%|███▊      | 17/45 [02:26<01:02,  2.23s/it]

flip!
epoch is 6
epoch is 3
flip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
scalar is tensor(-129.5270, requires_grad=True) while gt is -2.0672958
out of bound!
ground_truth: -2.067295789718628
searching from 0.7 to 6.0!
[6.] [12.40940857]
searching from -6.3 to -1.0!
[-2.06728735] [1.35953809e-11]
successfully find the ground_truth [-2.06728735]


 40%|████      | 18/45 [04:35<18:06, 40.24s/it]

flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
flip!


 42%|████▏     | 19/45 [04:36<12:22, 28.55s/it]

epoch is 6
flip!
epoch is 6
flip!
epoch is 7


 44%|████▍     | 20/45 [04:38<08:29, 20.39s/it]

flip!
epoch is 6
epoch is 3
skip!
flip!
epoch is 10
skip!


 47%|████▋     | 21/45 [04:39<05:52, 14.68s/it]

flip!
epoch is 10
flip!
epoch is 6
epoch is 3
flip!
skip!
scalar is tensor(-118.8092, requires_grad=True) while gt is -4.2833943
out of bound!
ground_truth: -4.2833943367004395
searching from 0.7 to 6.0!
[6.] [13.31690311]
searching from -6.3 to -1.0!


 49%|████▉     | 22/45 [06:48<18:48, 49.06s/it]

[-4.28339318] [9.24150321e-16]
successfully find the ground_truth [-4.28339318]
epoch is 3
epoch is 3
flip!


 51%|█████     | 23/45 [06:49<12:41, 34.63s/it]

epoch is 6
flip!
epoch is 6
flip!
epoch is 6


 53%|█████▎    | 24/45 [06:50<08:35, 24.55s/it]

epoch is 3
epoch is 3
flip!
epoch is 6


 56%|█████▌    | 25/45 [06:51<05:49, 17.49s/it]

flip!
epoch is 6
epoch is 3
epoch is 3


 58%|█████▊    | 26/45 [06:52<03:57, 12.49s/it]

flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6


 60%|██████    | 27/45 [06:53<02:42,  9.05s/it]

epoch is 3
epoch is 3
epoch is 4


 62%|██████▏   | 28/45 [06:54<01:51,  6.56s/it]

epoch is 3
epoch is 3
flip!
epoch is 6


 64%|██████▍   | 29/45 [06:55<01:18,  4.91s/it]

flip!
epoch is 6
flip!
epoch is 6
epoch is 3


 67%|██████▋   | 30/45 [06:56<00:56,  3.74s/it]

epoch is 3
flip!
epoch is 6
flip!
epoch is 6


 69%|██████▉   | 31/45 [06:57<00:42,  3.02s/it]

flip!
epoch is 6
skip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
flip!
epoch is 29
epoch is 3


 71%|███████   | 32/45 [06:58<00:31,  2.41s/it]

epoch is 3
epoch is 3
epoch is 3


 73%|███████▎  | 33/45 [06:59<00:22,  1.92s/it]

flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6


 76%|███████▌  | 34/45 [07:00<00:18,  1.67s/it]

flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6


 78%|███████▊  | 35/45 [07:01<00:14,  1.44s/it]

epoch is 3
skip!
flip!
epoch is 10
skip!
flip!
epoch is 10


 80%|████████  | 36/45 [07:03<00:13,  1.48s/it]

flip!
epoch is 6
flip!
epoch is 6
flip!
epoch is 6
flip!
skip!
scalar is tensor(-192.0335, requires_grad=True) while gt is -7.2275977
out of bound!
ground_truth: -7.227597713470459
searching from 0.7 to 6.0!
[6.] [15.23548412]
searching from -6.3 to -1.0!
[-6.3] [0.02103045]
searching from 5.7 to 16.0!
[16.] [1.88450933]
searching from -16.3 to -6.0!


 82%|████████▏ | 37/45 [11:28<10:45, 80.70s/it]

[-7.22757834] [2.70919528e-12]
successfully find the ground_truth [-7.22757834]
flip!
epoch is 6
epoch is 3
flip!


 84%|████████▍ | 38/45 [11:29<06:37, 56.85s/it]

epoch is 6
epoch is 3
flip!
epoch is 6


 87%|████████▋ | 39/45 [11:30<04:00, 40.09s/it]

flip!
epoch is 6
epoch is 3
epoch is 3
skip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
flip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
scalar is tensor(-130.0161, requires_grad=True) while gt is -6.4449916
out of bound!
ground_truth: -6.444991588592529
searching from 0.7 to 6.0!
[3.87603519] [1.33605087]
searching from -6.3 to -1.0!
[-6.3] [0.00056831]
searching from 5.7 to 16.0!
[16.] [0.59575492]
searching from -16.3 to -6.0!


 89%|████████▉ | 40/45 [15:57<09:00, 108.09s/it]

[-6.44498116] [9.10413888e-13]
successfully find the ground_truth [-6.44498116]
skip!
flip!
epoch is 10
skip!
flip!
epoch is 10
skip!
flip!


 91%|█████████ | 41/45 [15:59<05:04, 76.23s/it] 

epoch is 10
flip!
epoch is 6
skip!
flip!
epoch is 10
skip!


 93%|█████████▎| 42/45 [16:01<02:41, 53.80s/it]

flip!
epoch is 10
epoch is 3
epoch is 3
flip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
skip!
scalar is tensor(-129.5073, requires_grad=True) while gt is -2.0390737
out of bound!
ground_truth: -2.0390737056732178
searching from 0.7 to 6.0!
[6.] [12.89249039]
searching from -6.3 to -1.0!


 96%|█████████▌| 43/45 [18:08<02:31, 75.82s/it]

[-2.03905833] [3.33778873e-11]
successfully find the ground_truth [-2.03905833]
flip!
epoch is 6
flip!
epoch is 6


 98%|█████████▊| 44/45 [18:09<00:53, 53.44s/it]

flip!
epoch is 6
epoch is 3
epoch is 3


100%|██████████| 45/45 [18:10<00:00, 24.23s/it]

flip!
epoch is 6





In [6]:
# Number of entries in the status column (index 7) of `exp` that are negative.
count=int(np.count_nonzero(exp[:,:,7]<0))

In [7]:
# Display how many entries of exp[:,:,7] were negative.
count

0