In [3]:
import torch
import torchvision
import os,sys
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

from argparse import Namespace
from src.models.resnet3 import resnet18
import numpy as np
from src.exp_datasets.dataloader import dataset_wrapper, randomly_produce_valid_set
import logging

In [4]:
args = {"data_dir": "data/", "valid_count": 100, "meta_lr": 30,  "save_path": "output/","cuda":True, "lr":0.1, "batch_size":128, "test_batch_size":128, "epochs":100, "do_train":True, "use_pretrained_model":True, "lr_decay":True, "metric": "accuracy"}
args = Namespace(**args)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
args.logger = logging

args.logger.info("start")

INFO:root:start


Download real labeling errors

In [5]:
os.system("wget \"http://www.yliuu.com/web-cifarN/files/CIFAR-N.zip\" -P " + args.data_dir)
# os.rename("CIFAR-N.zip", os.path.join(args.data_dir, "CIFAR-N.zip"))
os.system("unzip " + args.data_dir + "/CIFAR-N.zip -o -d" + args.data_dir)


--2023-02-05 00:18:34--  http://www.yliuu.com/web-cifarN/files/CIFAR-N.zip
Resolving www.yliuu.com (www.yliuu.com)... 2607:f8b0:4004:c08::80, 172.253.63.128
Connecting to www.yliuu.com (www.yliuu.com)|2607:f8b0:4004:c08::80|:80... 

Archive:  data//CIFAR-N.zip


connected.
HTTP request sent, awaiting response... 200 OK
Length: 438874 (429K) [application/zip]
Saving to: ‘data/CIFAR-N.zip.15’

     0K .......... .......... .......... .......... .......... 11% 1.28M 0s
    50K .......... .......... .......... .......... .......... 23% 4.98M 0s
   100K .......... .......... .......... .......... .......... 34% 9.00M 0s
   150K .......... .......... .......... .......... .......... 46% 8.37M 0s
   200K .......... .......... .......... .......... .......... 58% 7.92M 0s
   250K .......... .......... .......... .......... .......... 69% 8.03M 0s
   300K .......... .......... .......... .......... .......... 81% 8.03M 0s
   350K .......... .......... .......... .......... .......... 93% 8.14M 0s
   400K .......... .......... ........                        100% 6.29M=0.09s

2023-02-05 00:18:45 (4.76 MB/s) - ‘data/CIFAR-N.zip.15’ saved [438874/438874]

caution: filename not matched:  -o


2816

In [6]:
!jupyter nbextension enable --py widgetsnbextension

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


Prepare a dataset and model

In [7]:
valid_count = args.valid_count

transform_train_list = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
]
transform_train = transforms.Compose(transform_train_list)
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

trainset = torchvision.datasets.CIFAR10(
    root=os.path.join(args.data_dir, 'CIFAR-10'),
    train=True,
    download=True,
    transform=transform_train,
)
testset = torchvision.datasets.CIFAR10(
    root=os.path.join(args.data_dir, 'CIFAR-10'),
    train=False,
    download=True,
    transform=transform_test,
)

# create a wrapper of the dataset to return the indices of each sample during training

trainset = dataset_wrapper(np.copy(trainset.data), np.copy(trainset.targets), transform_train)
testset = dataset_wrapper(np.copy(testset.data), np.copy(testset.targets), transform_test)
validset, testset = randomly_produce_valid_set(testset, rate = 0.4)

origin_labels = np.copy(trainset.targets)

trainset.targets = torch.load(os.path.join(args.data_dir, "CIFAR-N/CIFAR-10_human.pt"))['worse_label']

net = resnet18(num_classes=10)


Files already downloaded and verified
Files already downloaded and verified


Pretrain a model without reweighting

In [8]:
from src.main.main_train import basic_train
trainloader = DataLoader(
        trainset,
        batch_size=args.batch_size,
        pin_memory=True,
        shuffle=True
    )
validloader = DataLoader(validset, batch_size=args.test_batch_size, shuffle=False, num_workers=2, pin_memory=False)
testloader = DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, num_workers=2, pin_memory=False)

criterion = torch.nn.CrossEntropyLoss()


optimizer = torch.optim.SGD(net.parameters(), lr=args.lr,
                momentum=0.9, weight_decay=5e-4, nesterov=True)
optimizer.param_groups[0]['initial_lr'] = args.lr

mile_stones_epochs = [100, 110]
gamma = 0.2
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=mile_stones_epochs,
    last_epoch=-1,
    gamma=gamma,
)
if args.cuda:
    net = net.cuda()
args.logger.info("start training")
# basic_train(
#             trainloader,
#             validloader,
#             testloader,
#             criterion,
#             args,
#             net,
#             optimizer,
#             scheduler=scheduler,
#         )

INFO:root:start training


selecting meta samples with RBC

In [10]:
from src.main.main_train import load_checkpoint2
from src.main.find_valid_set import get_representative_valid_ids_rbc, get_representative_valid_ids_gbc

#load the pretrained models for selecting meta samples
args.prev_save_path = args.save_path

net = load_checkpoint2(args, net)

#label_aware is false for noisy labels while label_aware is true for imbalanced dataset
args.label_aware=False
args.all_layer=False
args.model_prov_period=20
args.do_train=False
args.bias_classes=False
args.weight_by_norm=True

valid_ids, new_valid_representations = get_representative_valid_ids_rbc(criterion, optimizer, trainloader, args, net, valid_count)


INFO:root:==> Loading cached model...
INFO:root:==> Loading cached model successfully
391it [00:17, 22.33it/s]
INFO:root:max norm of the representation:5.245381
INFO:root:min norm of the representation:0.402533
INFO:root:extra representation starting from epoch 89
INFO:root:==> Loading cached model at epoch 89
INFO:root:==> Loading cached model successfully
391it [00:20, 18.94it/s]
INFO:root:max norm of the representation:5.058860
INFO:root:min norm of the representation:0.475458
INFO:root:==> Loading cached model at epoch 109
INFO:root:==> Loading cached model at epoch 129
INFO:root:==> Loading cached model at epoch 149
INFO:root:==> Loading cached model at epoch 169


running k-means on cuda..


100%|██████████| 99/99 [01:34<00:00,  1.05it/s]
[running kmeans]: 152it [03:34,  1.41s/it, center_shift=0.000039, iteration=152, tol=0.000100]
INFO:root:cluster count before and after:(100,92)


max dist sample to assigned cluster mean:: 0.5743220448493958
min dist sample to other cluster mean:: 0.19798117876052856
running k-means on cuda..


100%|██████████| 91/91 [01:20<00:00,  1.13it/s]
[running kmeans]: 119it [02:37,  1.32s/it, center_shift=0.000054, iteration=119, tol=0.000100]
INFO:root:cluster count before and after:(92,92)
INFO:root:unique cluster count::92


max dist sample to assigned cluster mean:: 0.5697579979896545
min dist sample to other cluster mean:: 0.19198131561279297
