Comparing changes

base repository: bhanML/Co-teaching
base: master
head repository: gist-ailab/Co-teaching
compare: master
  • 2 commits
  • 18 files changed
  • 1 contributor

Commits on Jan 16, 2024

  1. 45a6fc7

Commits on Feb 13, 2024

  1. add codis

    birdomi committed Feb 13, 2024
    a3547ff
Binary file added __pycache__/loss.cpython-37.pyc
Binary file added __pycache__/loss_codis.cpython-37.pyc
Binary file added __pycache__/model.cpython-37.pyc
Binary file added data/__pycache__/__init__.cpython-37.pyc
Binary file added data/__pycache__/cifar.cpython-37.pyc
Binary file added data/__pycache__/mnist.cpython-37.pyc
Binary file added data/__pycache__/utils.cpython-37.pyc
8 changes: 4 additions & 4 deletions data/utils.py
@@ -99,7 +99,7 @@ def multiclass_noisify(y, P, random_state=0):
""" Flip classes according to transition probability matrix T.
It expects a number between 0 and the number of classes - 1.
"""
print np.max(y), P.shape[0]
print(np.max(y), P.shape[0])
assert P.shape[0] == P.shape[1]
assert np.max(y) < P.shape[0]

@@ -108,7 +108,7 @@ def multiclass_noisify(y, P, random_state=0):
assert (P >= 0.0).all()

m = y.shape[0]
print m
print(m)
new_y = y.copy()
flipper = np.random.RandomState(random_state)

@@ -142,7 +142,7 @@ def noisify_pairflip(y_train, noise, random_state=None, nb_classes=10):
assert actual_noise > 0.0
print('Actual noise %.2f' % actual_noise)
y_train = y_train_noisy
print P
print(P)

return y_train, actual_noise

@@ -167,7 +167,7 @@ def noisify_multiclass_symmetric(y_train, noise, random_state=None, nb_classes=1
assert actual_noise > 0.0
print('Actual noise %.2f' % actual_noise)
y_train = y_train_noisy
print P
print(P)

return y_train, actual_noise
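
Note: the hunks above only modernize print statements, but they sit inside the label-noise utilities (multiclass_noisify, noisify_pairflip, noisify_multiclass_symmetric) that flip labels according to a transition probability matrix P, as the docstring in the first hunk describes. As a quick reference, here is a minimal sketch of a pair-flip style matrix that satisfies the preconditions asserted above; the 3-class size and 0.2 noise rate are illustrative, not values taken from this repository.

    import numpy as np
    from numpy.testing import assert_array_almost_equal

    # Illustrative 3-class pair-flip matrix with noise rate 0.2: class i keeps its
    # label with probability 0.8 and flips to class (i + 1) % 3 with probability 0.2.
    P = np.array([[0.8, 0.2, 0.0],
                  [0.0, 0.8, 0.2],
                  [0.2, 0.0, 0.8]])

    # Checks mirroring the asserts in the hunks above, plus the row-stochastic
    # property a transition probability matrix needs.
    assert P.shape[0] == P.shape[1]
    assert (P >= 0.0).all()
    assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[0]))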

42 changes: 38 additions & 4 deletions loss.py
@@ -7,18 +7,19 @@
# Loss functions
def loss_coteaching(y_1, y_2, t, forget_rate, ind, noise_or_not):
loss_1 = F.cross_entropy(y_1, t, reduce = False)
ind_1_sorted = np.argsort(loss_1.data).cuda()
ind_1_sorted = torch.argsort(loss_1)
loss_1_sorted = loss_1[ind_1_sorted]

loss_2 = F.cross_entropy(y_2, t, reduce = False)
ind_2_sorted = np.argsort(loss_2.data).cuda()
ind_2_sorted = torch.argsort(loss_2)
loss_2_sorted = loss_2[ind_2_sorted]

remember_rate = 1 - forget_rate
num_remember = int(remember_rate * len(loss_1_sorted))

pure_ratio_1 = np.sum(noise_or_not[ind[ind_1_sorted[:num_remember]]])/float(num_remember)
pure_ratio_2 = np.sum(noise_or_not[ind[ind_2_sorted[:num_remember]]])/float(num_remember)
# print(noise_or_not, ind)
pure_ratio_1 = torch.sum(noise_or_not[ind[ind_1_sorted[:num_remember]]])/float(num_remember)
pure_ratio_2 = torch.sum(noise_or_not[ind[ind_2_sorted[:num_remember]]])/float(num_remember)

ind_1_update=ind_1_sorted[:num_remember]
ind_2_update=ind_2_sorted[:num_remember]
@@ -29,3 +30,36 @@ def loss_coteaching(y_1, y_2, t, forget_rate, ind, noise_or_not):
return torch.sum(loss_1_update)/num_remember, torch.sum(loss_2_update)/num_remember, pure_ratio_1, pure_ratio_2


# Loss functions
def loss_multiteaching(y_, t, forget_rate, ind, noise_or_not):
loss_ = dict()
loss_mean = dict()
pure_ratio = dict()

for n in range(len(y_)):
loss_[n] = F.cross_entropy(y_[n], t, reduce = False)

remember_rate = 1 - forget_rate
num_remember = int(remember_rate * len(loss_[n]))

# print(noise_or_not, ind)
for n in range(len(y_)):
# Calculate Loss for n-th model
loss_sum = torch.zeros_like(loss_[n])
other_idx = list(range(len(y_)))
other_idx.pop(n)
for j in other_idx:
loss_sum += loss_[j]
ind_sorted = torch.argsort(loss_sum)
loss_sorted = loss_[n][ind_sorted]

pure_ratio_ = torch.sum(noise_or_not[ind[ind_sorted[:num_remember]]])/float(num_remember)

ind_update=ind_sorted[:num_remember]

# exchange
loss_update = F.cross_entropy(y_[n][ind_update], t[ind_update])
loss_mean[n] = loss_update / num_remember
pure_ratio[n] = pure_ratio_

return loss_mean, pure_ratio
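
For readers of the new loss_multiteaching above: it takes a sequence of logit tensors y_ (one per network), the targets t, the current forget_rate, the dataset indices ind for the batch, and the noise_or_not tensor, and returns dicts of per-network losses and pure ratios keyed by network index. A minimal, self-contained usage sketch follows; the toy linear models, batch, and forget rate are illustrative and not part of this repository.

    import torch
    import torch.nn as nn
    from loss import loss_multiteaching

    torch.manual_seed(0)
    batch, num_classes, num_nets = 8, 10, 3
    images = torch.randn(batch, 32)
    labels = torch.randint(0, num_classes, (batch,))
    ind = torch.arange(batch)            # dataset indices of this batch
    noise_or_not = torch.ones(batch)     # 1 where the label is clean, so pure ratios are clean fractions

    nets = [nn.Linear(32, num_classes) for _ in range(num_nets)]
    optims = [torch.optim.Adam(net.parameters(), lr=1e-3) for net in nets]

    outputs = [net(images) for net in nets]
    loss_mean, pure_ratio = loss_multiteaching(outputs, labels, 0.2, ind, noise_or_not)

    for n, opt in enumerate(optims):     # one update per network on its own selected loss
        opt.zero_grad()
        loss_mean[n].backward()
        opt.step()
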
64 changes: 64 additions & 0 deletions loss_codis.py
@@ -0,0 +1,64 @@
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import math
from numpy.testing import assert_array_almost_equal
import warnings

warnings.filterwarnings('ignore')

def kl_loss_compute(pred, soft_targets, reduce=True):

kl = F.kl_div(F.log_softmax(pred, dim=1),F.softmax(soft_targets, dim=1), reduce=False)

if reduce:
return torch.mean(torch.sum(kl, dim=1))
else:
return torch.sum(kl, 1)


def js_loss_compute(pred, soft_targets, reduce=True):

pred_softmax = F.softmax(pred, dim=1)
targets_softmax = F.softmax(soft_targets, dim=1)
mean = (pred_softmax + targets_softmax) / 2
kl_1 = F.kl_div(F.log_softmax(pred, dim=1), mean, reduce=False)
kl_2 = F.kl_div(F.log_softmax(soft_targets, dim=1), mean, reduce=False)
js = (kl_1 + kl_2) / 2

if reduce:
return torch.mean(torch.sum(js, dim=1))
else:
return torch.sum(js, 1)

def loss_ours(y_1, y_2, t, forget_rate, ind, noise_or_not, co_lambda=0.1):

loss_1 = F.cross_entropy(y_1, t, reduction='none') - co_lambda * js_loss_compute(y_1, y_2,reduce=False)
ind_1_sorted = np.argsort(loss_1.cpu().data).cuda()
loss_1_sorted = loss_1[ind_1_sorted]

loss_2 = F.cross_entropy(y_2, t, reduction='none') - co_lambda * js_loss_compute(y_1, y_2,reduce=False)
ind_2_sorted = np.argsort(loss_2.cpu().data).cuda()
loss_2_sorted = loss_2[ind_2_sorted]

remember_rate = 1 - forget_rate
num_remember = int(remember_rate * len(loss_1_sorted))

ind_1_update=ind_1_sorted[:num_remember].cpu()
ind_2_update=ind_2_sorted[:num_remember].cpu()
if len(ind_1_update) == 0:
ind_1_update = ind_1_sorted.cpu().numpy()
ind_2_update = ind_2_sorted.cpu().numpy()
num_remember = ind_1_update.shape[0]

pure_ratio_1 = np.sum(noise_or_not[ind[ind_1_sorted.cpu()[:num_remember]]])/float(num_remember)
pure_ratio_2 = np.sum(noise_or_not[ind[ind_2_sorted.cpu()[:num_remember]]])/float(num_remember)

loss_1_update = loss_1[ind_2_update]
loss_2_update = loss_2[ind_1_update]


return torch.sum(loss_1_update)/num_remember, torch.sum(loss_2_update)/num_remember, pure_ratio_1, pure_ratio_2
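
The new loss_ours above follows the co-teaching recipe but subtracts a co_lambda-weighted Jensen-Shannon agreement term (js_loss_compute) from each network's per-sample cross-entropy before the small-loss selection, and then lets each network learn from the indices its peer selected. A brief call sketch follows; the toy networks and hyper-parameters are illustrative. It assumes a CUDA device (loss_ours itself calls .cuda()) and NumPy arrays for ind and noise_or_not, which is what the np.sum call inside loss_ours suggests; the training script for this loss is not shown in this diff.

    import numpy as np
    import torch
    import torch.nn as nn
    from loss_codis import loss_ours

    torch.manual_seed(0)
    batch, num_classes = 8, 10
    images = torch.randn(batch, 32).cuda()
    labels = torch.randint(0, num_classes, (batch,)).cuda()
    ind = np.arange(batch)                      # dataset indices of this batch
    noise_or_not = np.ones(batch, dtype=bool)   # True where the label is clean

    net1 = nn.Linear(32, num_classes).cuda()
    net2 = nn.Linear(32, num_classes).cuda()

    loss_1, loss_2, pure_ratio_1, pure_ratio_2 = loss_ours(
        net1(images), net2(images), labels,
        forget_rate=0.2, ind=ind, noise_or_not=noise_or_not, co_lambda=0.1)
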
42 changes: 28 additions & 14 deletions main.py → main_co.py
@@ -24,7 +24,7 @@
parser.add_argument('--num_gradual', type = int, default = 10, help='how many epochs for linear drop rate, can be 5, 10, 15. This parameter is equal to Tk for R(T) in Co-teaching paper.')
parser.add_argument('--exponent', type = float, default = 1, help='exponent of the forget rate, can be 0.5, 1, 2. This parameter is equal to c in Tc for R(T) in Co-teaching paper.')
parser.add_argument('--top_bn', action='store_true')
parser.add_argument('--dataset', type = str, help = 'mnist, cifar10, or cifar100', default = 'mnist')
parser.add_argument('--dataset', type = str, help = 'mnist, cifar10, or cifar100', default = 'cifar10')
parser.add_argument('--n_epoch', type=int, default=200)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--print_freq', type=int, default=50)
@@ -71,15 +71,15 @@
args.top_bn = False
args.epoch_decay_start = 80
args.n_epoch = 200
train_dataset = CIFAR10(root='./data/',
train_dataset = CIFAR10(root='/SSDc/yyg/cifar10/',
download=True,
train=True,
transform=transforms.ToTensor(),
noise_type=args.noise_type,
noise_rate=args.noise_rate
)

test_dataset = CIFAR10(root='./data/',
test_dataset = CIFAR10(root='/SSDc/yyg/cifar10/',
download=True,
train=False,
transform=transforms.ToTensor(),
@@ -93,15 +93,15 @@
args.top_bn = False
args.epoch_decay_start = 100
args.n_epoch = 200
train_dataset = CIFAR100(root='./data/',
train_dataset = CIFAR100(root='/SSDc/yyg/cifar100/',
download=True,
train=True,
transform=transforms.ToTensor(),
noise_type=args.noise_type,
noise_rate=args.noise_rate
)

test_dataset = CIFAR100(root='./data/',
test_dataset = CIFAR100(root='/SSDc/yyg/cifar100/',
download=True,
train=False,
transform=transforms.ToTensor(),
@@ -115,6 +115,7 @@
forget_rate=args.forget_rate

noise_or_not = train_dataset.noise_or_not
noise_or_not = torch.tensor(noise_or_not).cuda()

# Adjust learning rate and betas for Adam Optimizer
mom1 = 0.9
@@ -159,13 +160,13 @@ def accuracy(logit, target, topk=(1,)):

res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res

# Train the Model
def train(train_loader,epoch, model1, optimizer1, model2, optimizer2):
print 'Training %s...' % model_str
print('Training %s...' % model_str)
pure_ratio_list=[]
pure_ratio_1_list=[]
pure_ratio_2_list=[]
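
A note on the view-to-reshape change in accuracy() in the hunk above: a common reason for this edit is that, in newer PyTorch versions, the correct tensor can be non-contiguous, so .view(-1) raises a RuntimeError while .reshape(-1) simply copies when needed. A tiny illustration of that behavior (the shapes here are illustrative, not taken from accuracy()):

    import torch

    x = torch.arange(6).reshape(2, 3).t()   # transposed, hence non-contiguous
    print(x.is_contiguous())                # False
    print(x.reshape(-1))                    # fine: reshape copies when it has to
    # x.view(-1) would raise "view size is not compatible with input tensor's size and stride"
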
@@ -176,7 +177,7 @@ def train(train_loader,epoch, model1, optimizer1, model2, optimizer2):
train_correct2=0

for i, (images, labels, indexes) in enumerate(train_loader):
ind=indexes.cpu().numpy().transpose()
ind=indexes.T
if i>args.num_iter_per_epoch:
break

@@ -203,17 +204,30 @@ def train(train_loader,epoch, model1, optimizer1, model2, optimizer2):
optimizer2.zero_grad()
loss_2.backward()
optimizer2.step()

# print(loss_1.d)
if (i+1) % args.print_freq == 0:
print ('Epoch [%d/%d], Iter [%d/%d] Training Accuracy1: %.4F, Training Accuracy2: %.4f, Loss1: %.4f, Loss2: %.4f, Pure Ratio1: %.4f, Pure Ratio2 %.4f'
%(epoch+1, args.n_epoch, i+1, len(train_dataset)//batch_size, prec1, prec2, loss_1.data[0], loss_2.data[0], np.sum(pure_ratio_1_list)/len(pure_ratio_1_list), np.sum(pure_ratio_2_list)/len(pure_ratio_2_list)))
%(
epoch+1,
args.n_epoch,
i+1,
len(train_dataset)//batch_size,
prec1,
prec2,
loss_1.item(),
loss_2.item(),
(np.sum(pure_ratio_1_list)/len(pure_ratio_1_list)),
(np.sum(pure_ratio_2_list)/len(pure_ratio_2_list)))
)

train_acc1=float(train_correct)/float(train_total)
train_acc2=float(train_correct2)/float(train_total2)
return train_acc1, train_acc2, pure_ratio_1_list, pure_ratio_2_list

# Evaluate the Model
def evaluate(test_loader, model1, model2):
print 'Evaluating %s...' % model_str
print('Evaluating %s...' % model_str)
model1.eval() # Change model to 'eval' mode.
correct1 = 0
total1 = 0
@@ -243,7 +257,7 @@ def evaluate(test_loader, model1, model2):

def main():
# Data Loader (Input Pipeline)
print 'loading dataset...'
print('loading dataset...')
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
num_workers=args.num_workers,
@@ -256,15 +270,15 @@ def main():
drop_last=True,
shuffle=False)
# Define models
print 'building model...'
print('building model...')
cnn1 = CNN(input_channel=input_channel, n_outputs=num_classes)
cnn1.cuda()
print cnn1.parameters
print(cnn1.parameters)
optimizer1 = torch.optim.Adam(cnn1.parameters(), lr=learning_rate)

cnn2 = CNN(input_channel=input_channel, n_outputs=num_classes)
cnn2.cuda()
print cnn2.parameters
print(cnn2.parameters)
optimizer2 = torch.optim.Adam(cnn2.parameters(), lr=learning_rate)

mean_pure_ratio1=0
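
As context for the --num_gradual (Tk) and --exponent (c) options near the top of this file's diff: their help strings describe the gradual drop-rate schedule R(T) from the Co-teaching paper, where the forget rate ramps up over the first Tk epochs and then stays at forget_rate. One plausible construction consistent with those help strings is sketched below; the concrete values are illustrative, and the exact schedule used by this script is not visible in this diff.

    import numpy as np

    forget_rate, n_epoch, num_gradual, exponent = 0.2, 200, 10, 1
    rate_schedule = np.ones(n_epoch) * forget_rate
    rate_schedule[:num_gradual] = np.linspace(0, forget_rate ** exponent, num_gradual)
    # Epoch t would then pass rate_schedule[t] as forget_rate to the loss functions above.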