In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import time
import os
import numpy as np

transform = transforms.Compose(
    [
     transforms.RandomHorizontalFlip(),
     transforms.RandomGrayscale(),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

transform1 = transforms.Compose(
    [
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform1)
testloader = torch.utils.data.DataLoader(testset, batch_size=50,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')



class Net(nn.Module):


    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3,64,3,padding=1)
        self.conv2 = nn.Conv2d(64,64,3,padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU()

        self.conv3 = nn.Conv2d(64,128,3,padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3,padding=1)
        self.pool2 = nn.MaxPool2d(2, 2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU()

        self.conv5 = nn.Conv2d(128,128, 3,padding=1)
        self.conv6 = nn.Conv2d(128, 128, 3,padding=1)
        self.conv7 = nn.Conv2d(128, 128, 1,padding=1)
        self.pool3 = nn.MaxPool2d(2, 2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()

        self.conv8 = nn.Conv2d(128, 256, 3,padding=1)
        self.conv9 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv10 = nn.Conv2d(256, 256, 1, padding=1)
        self.pool4 = nn.MaxPool2d(2, 2, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.ReLU()

        self.conv11 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv12 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv13 = nn.Conv2d(512, 512, 1, padding=1)
        self.pool5 = nn.MaxPool2d(2, 2, padding=1)
        self.bn5 = nn.BatchNorm2d(512)
        self.relu5 = nn.ReLU()

        self.fc14 = nn.Linear(512*4*4,1024)
        self.drop1 = nn.Dropout2d()
        self.fc15 = nn.Linear(1024,1024)
        self.drop2 = nn.Dropout2d()
        self.fc16 = nn.Linear(1024,10)


    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.pool1(x)
        x = self.bn1(x)
        x = self.relu1(x)


        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pool2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.pool3(x)
        x = self.bn3(x)
        x = self.relu3(x)

        x = self.conv8(x)
        x = self.conv9(x)
        x = self.conv10(x)
        x = self.pool4(x)
        x = self.bn4(x)
        x = self.relu4(x)

        x = self.conv11(x)
        x = self.conv12(x)
        x = self.conv13(x)
        x = self.pool5(x)
        x = self.bn5(x)
        x = self.relu5(x)
        # print(" x shape ",x.size())
        x = x.view(-1,512*4*4)
        x = F.relu(self.fc14(x))
        x = self.drop1(x)
        x = F.relu(self.fc15(x))
        x = self.drop2(x)
        x = self.fc16(x)

        return x

    def train_sgd(self,device):
        optimizer = optim.Adam(self.parameters(), lr=0.0001)

        path = 'weights.tar'
        initepoch = 0

        if os.path.exists(path) is not True:
            loss = nn.CrossEntropyLoss()
#             loss = loss_new()
            # optimizer = optim.SGD(self.parameters(),lr=0.01)

        else:
            checkpoint = torch.load(path)
            self.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            initepoch = checkpoint['epoch']
            loss = checkpoint['loss']




        for epoch in range(initepoch,20):  # loop over the dataset multiple times
            timestart = time.time()

            running_loss = 0.0
            total = 0
            correct = 0
            for i, data in enumerate(trainloader, 0):
                # get the inputs
                inputs, labels = data
                inputs, labels = inputs.to(device),labels.to(device,dtype= torch.int64)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = self(inputs)
                l = loss_new(outputs, labels)
                l.backward()
                optimizer.step()

                # print statistics
                running_loss += l.item()
                # print("i ",i)
                if i % 500 == 499:  # print every 500 mini-batches
                    print('[%d, %5d] loss: %.4f' %
                          (epoch, i, running_loss / 500))
                    running_loss = 0.0
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    print('Accuracy of the network on the %d tran images: %.3f %%' % (total,
                            100.0 * correct / total))
                    total = 0
                    correct = 0
                    torch.save({'epoch':epoch,
                                'model_state_dict':net.state_dict(),
                                'optimizer_state_dict':optimizer.state_dict(),
                                'loss':loss
                                },path)

            print('epoch %d cost %3f sec' %(epoch,time.time()-timestart))

        print('Finished Training')

    def test(self,device):
        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = self(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print('Accuracy of the network on the 10000 test images: %.3f %%' % (
                100.0 * correct / total))



In [None]:
def loss_new(outputs, labels):
    # sfm=nn.Softmax(1)
    # pred=sfm(outputs)
    # pred = pred.detach()#detach使得pred requires_grad=False,并且不影响outputs

    pred = F.softmax(outputs, dim=1)####是否应该detach????????????/

    # criterion = nn.CrossEntropyLoss()
    # Lo=criterion(labels_update,labels)

    Le = -torch.mean(torch.sum(F.log_softmax(outputs, dim=1) * pred, dim=1))
    # Lc=criterion(labels_update,pred)-criterion(outputs,pred)
    print((torch.log(labels) * pred).shape)
    Lc = -torch.mean(torch.sum(torch.log(labels) * pred, dim=1)) - Le
    loss_total = Lc /class_num
    return loss_total

In [17]:
import torch
from torch import nn
from torch import optim
import numpy as np
import copy as cp

class PENCIL():
    def __init__(self, all_labels_tensor, n_samples, n_classes, n_epochs, lrs, alpha, beta, gamma, K=10, save_losses=False, use_KL=True):
        '''
        all_labels_tensor: torch tensor, 1-D tensor of labels indexed as in the training dataset object
        n_samples: int, length of training dataset
        n_epochs: list of positive ints, number of epochs of phases in form [n_epochs_i for i in range(3)]
        lrs: list of floats, learnings rates for phases in form [lr_i for i in range(3)]
        alpha: coefficient for lo loss
        beta: coefficient for le loss
        gamma: coefficient for label estimate update
        K: int, learning rate multiplier for label estimate updates
        save_losses: bool, whether to save losses into list of lists of form [[lc,lo,le] for e in *phase 2 epochs*]
        use_KL: bool, whether to use KL loss or crossentropy for phase 3
        '''
        
        self.save_losses = save_losses
        self.use_KL = use_KL
        self.n_epochs = n_epochs
        self.lrs = lrs
        self.n_classes = n_classes
        
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.K = K
        
        self.CELoss = nn.CrossEntropyLoss()
        self.KLLoss = nn.KLDivLoss(reduction='mean') #PENCIL official implementation uses mean, not batchmean
        self.softmax = nn.Softmax(dim=1)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        
        self._init_y_tilde(all_labels_tensor)
        self.y_prev = None
        self.losses = []
        
    def _init_y_tilde(self, all_labels_tensor):
        '''
        all_labels_tensor: torch tensor, 1-D tensor of labels indexed as in the training dataset object
        '''
        labels_temp = torch.zeros(all_labels_tensor.size(0), self.n_classes).scatter_(1, all_labels_tensor.view(-1, 1).long(), self.K)
        self.y_tilde = labels_temp.numpy()
        
    def set_lr(self, optimizer, epoch):
        '''
        Call before inner training loop to update lr based on PENCIL phase
        '''
        lr = -1
        if epoch == 0: lr = self.lrs[0] # Phase 1
        elif epoch == self.n_epochs[0]: lr = self.lrs[1] # Phase 2
        elif epoch == self.n_epochs[0]+self.n_epochs[1]: lr = self.lrs[2] # Phase 3 
        elif epoch == self.n_epochs[0]+self.n_epochs[1]+self.n_epochs[2]//3: # Phase 3 first decay
            lr = self.lrs[2]/10
        elif epoch == self.n_epochs[0]+self.n_epochs[1]+2*self.n_epochs[2]//3: # Phase 3 second decay
            lr = self.lrs[2]/100
            
        if lr!=-1:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
    
    def get_loss(self, epoch, outputs, labels, indices):
        '''
        outputs: un-normalized logits 
        labels: cuda tensor of noisy labels
        indices: cpu tensor of indices for current batch
        '''
        # Calculate loss based on current phase
        if epoch < self.n_epochs[0]: #Phase 1
            lc = self.CELoss(outputs, labels)
        else:
            self.y_prev = cp.deepcopy(self.y_tilde[indices,:]) #Get unnormalized label estimates
            self.y_prev = torch.tensor(self.y_prev).float()
            self.y_prev = self.y_prev.cuda()
            self.y_prev.requires_grad = True
            # obtain label distributions (y_hat)
            y_h = self.softmax(self.y_prev)
            if epoch<self.n_epochs[0]+self.n_epochs[1] or self.use_KL: # During phase 1. 
                lc = self.KLLoss(self.logsoftmax(self.y_prev),self.softmax(outputs))
            else: # During phase 2 use CE if self.use_KL=False
                lc = self.CELoss(self.softmax(outputs),self.softmax(y_h))
            lo = self.CELoss(y_h, labels) # lo is compatibility loss
            le = - torch.mean(torch.mul(self.softmax(outputs), self.logsoftmax(outputs))) # le is entropy loss
        # Compute total loss
        if epoch < self.n_epochs[0]:
            loss = lc
        elif epoch < self.n_epochs[0]+self.n_epochs[1]:
            loss = lc + self.alpha * lo + self.beta * le
            if self.save_losses: self.losses.append([lc.item(),lo.item(),le.item()])
        else:
            loss = lc
        return loss
    
    def update_y_tilde(self, epoch, indices):
        '''
        Call this after the backward pass over the loss
        ''' 
        # If in phase 2, update y estimate
        if epoch >= self.n_epochs[0] and epoch < self.n_epochs[0]+self.n_epochs[1]:
            # update y_tilde by back-propagation
            self.y_tilde[indices]+=-self.gamma*self.y_prev.grad.data.cpu().numpy()

In [19]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = Net()
net = net.to(device)
net.train_sgd(device)
net.test(device)

RuntimeError: The size of tensor a (100) must match the size of tensor b (10) at non-singleton dimension 1

In [None]:
def loss_new(outputs, labels):
    # sfm=nn.Softmax(1)
    # pred=sfm(outputs)
    # pred = pred.detach()#detach使得pred requires_grad=False,并且不影响outputs

    pred = F.softmax(outputs, dim=1)####是否应该detach????????????/

    # criterion = nn.CrossEntropyLoss()
    # Lo=criterion(labels_update,labels)

    Le = -torch.mean(torch.sum(F.log_softmax(outputs, dim=1) * pred, dim=1))
    # Lc=criterion(labels_update,pred)-criterion(outputs,pred)
    print((torch.log(labels) * pred).shape)
    Lc = -torch.mean(torch.sum(torch.log(labels) * pred, dim=1)) - Le
    loss_total = Lc /class_num
    return loss_total