In [2]:
from __future__ import print_function
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
import os
import argparse
import numpy as np
from Models import *
from torchvision.models import vgg19_bn

from MulticoreTSNE import MulticoreTSNE as TSNE

import data_loader_cifar as dataloader

import time
import datetime
import gc

#40 asym, 0.04 xi, eta 10, nc 20, nv 40

In [2]:
# Run configuration. The notebook hands parse_args an explicit argument list,
# so the values in that list take precedence over the defaults declared here.
parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
parser.add_argument('--batch_size', default=64, type=int, help='train batchsize')
parser.add_argument('--lr', '--learning_rate', default=0.02, type=float, help='initial learning rate')
parser.add_argument('--noise_mode',  default='sym_0.5', help = 'aggre,worst,rand1,rand2,rand3,noisy100, or sym_0.x, asym_0.x')
parser.add_argument('--num_epochs', default=300, type=int)
# Number of initial epochs that use the 'warmup' loader mode.
parser.add_argument('--t_w', default=10, type=int)
# Fraction of lowest-scoring samples flagged in the per-epoch OOD mask.
parser.add_argument('--nR', default=0.04, type=float)
# Per-class fraction of highest mean-confidence samples used as reference points.
parser.add_argument('--nc', default=0.2, type=float)
# Upper bound (--nv) and per-epoch increment (--nvs) of the relabelled fraction.
parser.add_argument('--nv', default=0.8, type=float)
parser.add_argument('--nvs', default=0.1, type=float)
parser.add_argument('--id', default='')
# type=int so that a seed supplied as text is converted before it reaches
# torch.manual_seed, which does not accept strings.
parser.add_argument('--seed', default=123, type=int)
parser.add_argument('--gpuid', default=0, type=int)
parser.add_argument('--data_path', default='./cifar-10', type=str, help='path to dataset')
parser.add_argument('--dataset', default='cifar10', type=str)
args = parser.parse_args(args = ['--data_path', 'data/CIFAR10',
                                 '--dataset', 'cifar10',
                                 '--noise_mode','asym_0.4',
                                 '--t_w', '10',
                                 '--batch_size','64',
                                 '--lr','0.02',
                                 '--num_epochs','150',
                                 '--nR', '0.1',
                                 '--nc','0.1',
                                 '--nv','0.2',
                                 '--nvs','0.02'])

In [3]:
# Select the active CUDA device and seed the random number generators.
torch.cuda.set_device(args.gpuid)
random.seed(args.seed)
# The training cell draws from np.random.rand, so NumPy's global generator
# has to be seeded as well for those draws to repeat between runs.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

In [4]:
# Sizes shared by the cells below: number of training / test samples, number
# of classes (10 for 'cifar10', 100 for any other dataset name), width of the
# feature vector returned by the network, and the warm-up length in epochs.
samples, test_samples = 50000, 10000
n_class = 10 if args.dataset == 'cifar10' else 100
feature_num = 512
t_w = args.t_w

In [5]:
def test(epoch,net,):
    """Evaluate ``net`` on the module-level ``test_loader``.

    Writes one ``Epoch/Accuracy`` line to ``test_log`` and returns a tuple
    ``(acc, lossb, features)``: the accuracy in percent, the value of
    ``relevant_hard_np`` on the collected features, and the
    ``(test_samples, feature_num)`` feature array filled by sample index.
    """
    net.eval()
    n_correct, n_seen = 0, 0
    features = np.zeros((test_samples, feature_num))
    with torch.no_grad():
        for inputs, targets, ind in test_loader:
            rows = ind.numpy()
            inputs = inputs.cuda()
            targets = targets.cuda()
            feature, output = net(inputs)
            predicted = torch.max(output, 1)[1]

            # Store each sample's feature at its dataset index.
            features[rows] = feature.detach().cpu().numpy()

            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).cpu().sum().item()
    acc = 100.*n_correct/n_seen

    test_log.write('Epoch:%d   Accuracy:%.2f\n'%(epoch,acc))
    test_log.flush()

    return acc, relevant_hard_np(features), features


def linear_rampup(current, warm_up, rampup_length=16, lambda_u=None):
    """Linearly ramp a weight from 0 up to ``lambda_u``.

    The ramp position is ``(current - warm_up) / rampup_length`` clipped to
    ``[0, 1]``; the return value is that position times ``lambda_u``.

    Args:
        current: current epoch (or step) counter.
        warm_up: counter value at which the ramp starts.
        rampup_length: number of counter units over which the ramp reaches 1.
        lambda_u: maximum weight. When omitted, ``args.lambda_u`` is read.
            The argument parser in this notebook does not define that option,
            so callers should pass ``lambda_u`` explicitly.

    Returns:
        The ramped weight as a Python float.
    """
    if lambda_u is None:
        lambda_u = args.lambda_u
    ramp = np.clip((current-warm_up) / rampup_length, 0.0, 1.0)
    return lambda_u*float(ramp)


class NegEntropy(object):
    """Negative entropy of the softmax distribution, averaged over the batch.

    Calling the object with logits of shape ``(batch, classes)`` returns the
    scalar ``mean_i sum_k p_ik * log p_ik``.
    """
    def __call__(self,outputs):
        # log_softmax keeps log p finite when a probability underflows to 0;
        # taking softmax(...).log() would give -inf there and 0 * -inf = NaN.
        log_probs = F.log_softmax(outputs, dim=1)
        probs = log_probs.exp()
        return torch.mean(torch.sum(probs*log_probs, dim=1))
    

class Orthogonal_loss(nn.Module):
    """Penalty built from the scatter matrix of a batch of features.

    For an input ``x`` of shape ``(n, m)`` the forward pass centres the
    columns, forms ``cov = e.T @ e`` and returns
    ``(cov_m - cov_i) / (m_nonz * n)`` where ``cov_m`` is the sum over rows of
    the largest squared off-diagonal entry of ``cov`` and ``cov_i`` is the
    trace of ``cov``.
    """
    def __init__(self,):
        super(Orthogonal_loss, self).__init__()

    def forward(self, x, ):
        n = x.size(0)
        m = x.size(1)

        # Build the identity on the input's device so the loss also works for
        # CPU tensors and for GPUs other than the current default device.
        I = torch.eye(m, device=x.device)
        e = x - x.mean(dim=0, keepdim=True)
        # NOTE(review): this counts columns whose *centred* sum is non-zero.
        # Centred columns sum to zero up to rounding, so the count can be 0
        # and the division below can produce inf/nan — confirm the intended
        # quantity (possibly the number of non-constant features).
        m_nonz = (e.sum(dim = 0) != 0).sum()

        cov = e.T @ e

        cov2 = cov ** 2

        # Zero the diagonal, then pick the largest squared entry in each row.
        select_i = torch.argmax(cov2 - cov2 * I, dim = 1)
        cov_m = (F.one_hot(select_i, num_classes = m) * cov2).sum()
        cov_i = (I * cov).sum()

        result = (cov_m-cov_i) / (m_nonz * n)
        return result
    
def relevant_hard_np(x,):
    """Mean over features of the largest squared correlation with another feature.

    ``x`` is a 2-D array with one row per sample and one column per feature.
    Columns are centred, the squared correlation between every pair of columns
    is computed (NaN entries, e.g. from zero-variance columns, are set to 0),
    the diagonal is zeroed, and the row-wise maxima are averaged.
    """
    n_features = x.shape[1]
    centered = x - x.mean(axis = 0, keepdims = True)

    scatter = centered.T @ centered
    sq_norms = np.square(centered).sum(axis = 0, keepdims = True)

    # scatter_ij / sqrt(|e_i|^2 * |e_j|^2), then squared.
    corr_sq = np.square(scatter / (sq_norms.T @ sq_norms) ** 0.5)
    corr_sq[np.isnan(corr_sq)] = 0.0

    off_diag = corr_sq - corr_sq * np.eye(n_features)
    return np.mean(off_diag.max(axis = -1))
    
class MSELoss(object):
    """Per-sample mean squared error between softmax probabilities and targets.

    ``targets`` may be a 1-D tensor of class indices (expanded to one-hot rows
    with ``n_class`` columns) or a tensor already shaped like the
    probabilities. The result has one value per sample; no batch reduction is
    applied.
    """
    def __call__(self, logits, targets,):
        if targets.dim() == 1:
            targets = F.one_hot(targets, num_classes=n_class)

        residual = logits.softmax(dim=1) - targets
        return (residual ** 2).mean(dim = -1)
    
def create_model():
    """Build a ``ResNet18`` with ``n_class`` outputs and move it to CUDA."""
    return ResNet18(num_classes=n_class).cuda()

In [6]:
# open(..., 'w') does not create missing directories, so make sure the
# output folder exists before opening the log files.
os.makedirs('./checkpoint', exist_ok=True)
# Build the common file-name prefix once so that both logs carry the same
# date/hour stamp even if the hour changes between the two open() calls.
log_prefix = './checkpoint/SNRLNL_%s_%s_%s'%(
    args.dataset,args.noise_mode,str(datetime.date.today())+'_'+str(time.localtime().tm_hour))
# Per-epoch training statistics and per-epoch test accuracy, respectively.
stats_log=open(log_prefix+'_stats.txt','w')
test_log=open(log_prefix+'_acc.txt','w')

In [7]:
# Noise-mode names. The real-world ones select a '*_human.pt' noise file when
# the data loader is built.
real_world_noise_types = [
    'aggre', 'worst', 'rand1', 'rand2', 'rand3', 'noisy100',
]
synthetic_noise_types = [
    'sym', 'asym',
]

In [8]:
# Choose the noise-label file: real-world noise modes use the '*_human.pt'
# file for the dataset, every other mode uses '<noise_mode>.json'.
if args.noise_mode in real_world_noise_types:
    if args.dataset == 'cifar10':
        noise_file = '%s/CIFAR-10_human.pt'%(args.data_path,)
    else:
        noise_file = '%s/CIFAR-100_human.pt'%(args.data_path,)
else:
    noise_file = '%s/%s.json'%(args.data_path,args.noise_mode)

loader = dataloader.cifar_dataloader(
    args.dataset,
    noise_mode=args.noise_mode,
    batch_size=args.batch_size,
    num_workers=5,
    root_dir=args.data_path,
    log=stats_log,
    noise_file=noise_file,
)

print('| Building net')
net = create_model()
cudnn.benchmark = True

# SGD with momentum; the learning rate is multiplied by 0.1 at epochs 50 and 100.
opt = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
sch = optim.lr_scheduler.MultiStepLR(opt, milestones=[50, 100], gamma=0.1)

# Loss objects: per-sample cross entropy, batch-mean cross entropy, and the
# softmax MSE defined above.
CE = nn.CrossEntropyLoss(reduction='none')
CEloss = nn.CrossEntropyLoss()
MSE = MSELoss()

loss_ortho = Orthogonal_loss()

all_loss = [[],[]] # save the history of losses from two networks

traindataset, trainloader = loader.run('warmup')
testdataset, test_loader = loader.run('test')

| Building net


In [None]:
# Label arrays, all indexed by the sample index that the loaders return.
train_Y = np.array(traindataset.train_label)           # presumably the clean labels — TODO confirm in data_loader_cifar
test_Y = np.array(testdataset.test_label)
noisy_Y = np.array(traindataset.noise_label)           # given noisy labels; never modified below
revised_Y = np.array(traindataset.noise_label)         # labels used for training; rewritten after each post-warm-up epoch
revised_Y_before = np.array(traindataset.noise_label)  # revised_Y as it was during the previous epoch

# Histories kept across epochs.
Yt_list = [np.array(traindataset.noise_label)]
acc_list = []
loss_sep_list = [[]]
loss_train_list = []
Py_temp_list = []

score = np.random.rand(samples,)
score_uncertainty = np.random.rand(samples,)

# Builtin bool: the np.bool alias was removed in NumPy 1.24.
OOD_mask_before = np.zeros((samples,),bool)

start_time = time.time()

for epoch in range(args.num_epochs):

    net.train()
    feature_num = 512
    # Warm-up epochs and regular epochs use different loader modes.
    if epoch < t_w:
        _, trainloader = loader.run('warmup')
    else:
        _, trainloader = loader.run('train')

    # Per-epoch accumulators; the *_temp arrays are filled by sample index.
    loss_train = 0
    acc_train = 0
    acc_train_ori = 0
    loss_train_ori = 0
    Py_temp = np.zeros((samples,),dtype=np.float32)
    Pred_temp = np.zeros((samples,),dtype=np.float32)
    Probs_temp = np.zeros((samples,n_class),dtype=np.float32)
    Logits_temp = np.zeros((samples,n_class),dtype=np.float32)
    Pred_other_temp = np.zeros((samples,),dtype=np.float32)

    feature_temp = np.zeros((samples, feature_num), dtype = np.float32)

    # Score = probability assigned to the training label in the previous
    # epoch, once at least two epochs have been recorded; random before that.
    if len(Py_temp_list) > 1:
        score = Py_temp_list[-1]
    else:
        score = np.random.rand(samples,)

    # Among the samples not flagged in the previous epoch, flag those whose
    # score lies below the args.nR quantile of that group.
    OOD_mask = np.logical_and(score < np.sort(score[~OOD_mask_before])[int(len(score[~OOD_mask_before]) * args.nR)], ~OOD_mask_before)

    Y_onehot = np.eye(n_class)[revised_Y].astype(np.float32)
    Y_onehot_0 = np.eye(n_class)[noisy_Y].astype(np.float32)

    for batch_id, (X_data, targets, ind) in enumerate(trainloader):
        ind = ind.numpy()

        # Labels are looked up by sample index rather than taken from the loader.
        Y_data = np.array(revised_Y[ind]).astype(np.int64)
        Y_data_ori = np.array(train_Y[ind]).astype(np.int64)
        Y_data_before = np.array(revised_Y_before[ind]).astype(np.int64)
        temp_X = X_data.cuda()
        opt.zero_grad()
        Y_GPU = torch.from_numpy(Y_data).cuda()
        Y_GPU_ori = torch.from_numpy(Y_data_ori).cuda()
        Y_GPU_before = torch.from_numpy(Y_data_before).cuda()

        y_onehot = F.one_hot(Y_GPU.view(-1,),num_classes=n_class)
        y_onehot_before = F.one_hot(Y_GPU_before.view(-1,),num_classes=n_class)

        feature, logits = net(temp_X)

        probs = logits.softmax(1)
        Py = torch.sum(y_onehot * probs, dim = -1)   # probability of the current training label
        Pred = probs.argmax(-1)
        # Zero the logit of the previous-epoch label and take the argmax of
        # what remains.
        # NOTE(review): the entry is set to 0, not -inf, so that class still
        # wins whenever all other logits are negative — confirm this is intended.
        logits_other = logits - logits * y_onehot_before
        Pred_other = torch.argmax(logits_other,dim=-1)

        if epoch < args.t_w:
            loss = CEloss(logits,Y_GPU.view(-1,))
        else:
            # After warm-up, flagged samples are trained towards Pred_other
            # instead of their current label.
            Y_GPU = torch.where(torch.from_numpy(OOD_mask[ind]).cuda(), Pred_other, Y_GPU)
            loss = CEloss(logits,Y_GPU.view(-1,))

        Py_temp[ind] = Py.cpu().detach().numpy()
        Pred_temp[ind] = Pred.cpu().detach().numpy()

        Probs_temp[ind] = probs.cpu().detach().numpy()
        Logits_temp[ind] = logits.cpu().detach().numpy()
        feature_temp[ind] = feature.cpu().detach().numpy()

        Pred_other_temp[ind] = Pred_other.cpu().detach().numpy()

        # Accuracy against the labels actually trained on, and against train_Y.
        correct = (Pred == Y_GPU).sum().item()
        correct_ori = (Pred == Y_GPU_ori).sum().item()

        loss_train += loss.item()
        acc_train += correct
        acc_train_ori += correct_ori

        loss_sep_list[-1].append(loss.item())

        loss.backward()

        opt.step()
    sch.step()

    loss_train/=(batch_id+1)
    acc_train/=samples
    acc_train_ori/=samples

    print('epoch %d train complete'%epoch)
    acc_eval, lossb_eval, feature_val = test(epoch, net)

    # Feature-correlation statistic on the first 10000 training samples.
    loss_b = relevant_hard_np(feature_temp[:10000])

    acc_list.append(acc_eval)
    Py_temp_list.append(Py_temp)

    revised_Y_before = revised_Y.copy()
    OOD_mask_before = OOD_mask.copy()

    if epoch < t_w:
        # No relabelling during warm-up.
        select = np.zeros((samples,),dtype = bool)
        score = np.random.rand(samples,)
    else:
        nC_points = []
        clean_mask = np.zeros((samples,),dtype=bool)
        # Average label probability over all epochs recorded so far.
        Py_mean = np.zeros((samples,))
        for j in range(len(Py_temp_list)):
            Py_mean+=Py_temp_list[j]
        Py_mean/=len(Py_temp_list)
        # Reference set: for each noisy-label class, the samples whose mean
        # label probability is in the top args.nc fraction of that class.
        # NOTE(review): if int(c_n * args.nc) is 0 the index becomes -0 == 0
        # and the whole class is selected; an empty class raises IndexError.
        for j in range(n_class):
            class_mask = noisy_Y == j
            c_n = class_mask.sum()
            c_th = np.sort(Py_mean[class_mask])[-int(c_n * args.nc)]
            nC_points.append(np.where(np.logical_and(Py_mean>=c_th,class_mask))[0])
        nC_points = np.concatenate(nC_points)
        L_batch = 1000
        Y_onehot = np.eye(n_class)[revised_Y]
        zC = torch.from_numpy(feature_temp[nC_points].astype(np.float32)).cuda()
        fC = torch.from_numpy(Probs_temp[nC_points].astype(np.float32)).cuda()
        yC = torch.from_numpy(Y_onehot[nC_points].astype(np.float32)).cuda()
        fCcyC = fC - yC
        lr = 1e-6
        learning_risk = np.zeros((samples,))
        # Per-sample score, computed in chunks of L_batch samples:
        #   sum_k [ ((z_i z_C^T + 1) @ (f_C - y_C))_k * (y_i - f_i)_k ] * 4 * lr / |C|
        for j in range(int(np.ceil(samples/L_batch))):
            i_ind = np.arange(j*L_batch, min(samples,(j+1)*L_batch))
            zi = torch.from_numpy(feature_temp[i_ind].astype(np.float32)).cuda()
            fi = torch.from_numpy(Probs_temp[i_ind].astype(np.float32)).cuda()
            yi = torch.from_numpy(Y_onehot[i_ind].astype(np.float32)).cuda()
            zixzC = zi @ zC.transpose(1,0)
            part_1_1 = (zixzC + 1) @ fCcyC
            part1 = (part_1_1 * (yi-fi)).sum(dim=-1,keepdim=True)*4*lr/len(nC_points)
            part1_all = part_1_1 * (fi-torch.ones_like(yi))*4*lr/len(nC_points)

            learning_risk[i_ind] = part1.cpu().detach().numpy().ravel()
        # Fraction of samples to relabel: grows by args.nvs per epoch after
        # warm-up and is capped at args.nv.
        # r_ = min(0.1 + 0.1 * (epoch - t_w), args.nv)
        r_ = min(args.nvs + args.nvs * (epoch - t_w), args.nv)
        # Disabled alternative: per-class selection.
        # select = np.zeros((samples,),dtype=np.bool8)
        # for j in range(n_class):
        #     class_mask = revised_Y == j
        #     c_n = class_mask.sum()
        #     c_th = np.sort(learning_risk[class_mask])[-min(int(c_n * r_),c_n)]
        #     select[np.logical_and(learning_risk>=c_th, class_mask)] = True
        # NOTE(review): if int(samples * r_) is 0 the index is 0 and every
        # sample is selected.
        th = np.sort(learning_risk)[-min(int(samples * r_),samples)]
        select = learning_risk >= th

        # Selected samples take the network's prediction from this epoch;
        # all others fall back to the original noisy label.
        revised_Y=np.where(select.ravel(), Pred_temp.ravel(), noisy_Y.ravel()).astype(int)

    # Remaining disagreement with train_Y, overall and for the worst class.
    is_noise = revised_Y != train_Y
    max_noised_class = -999
    for j_ in range(n_class):
        class_mask = train_Y == j_
        noise_n = np.logical_and(class_mask, is_noise).sum()
        if noise_n > max_noised_class:
            max_noised_class = noise_n


    Yt_remain_noise = np.sum(is_noise)
    end_time = time.time()
    print_str = 'loss_b:%.4f, loss_b_eval:%.4f, train loss:%.4f, train acc:%.4f, train acc ori:%.4f,\
          eval acc:%.4f, time elapsed:%.4f, epoch %d train cleaned, %d samples changed,\
          total remain noise:%.4d, max class noise:%d'%(loss_b,
                                                        lossb_eval,
                                                        loss_train,
                                                        acc_train,
                                                        acc_train_ori, 
                                                        acc_eval,
                                                        end_time - start_time,
                                                        epoch,
                                                        np.sum(revised_Y!=noisy_Y),
                                                        Yt_remain_noise,
                                                        max_noised_class)
    print(print_str)
    stats_log.write(print_str+'\n')
    stats_log.flush()
    # loss_sep_list.append([])
    Yt_list.append(revised_Y)

    gc.collect()