In [None]:
import os
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
from data.cifar import CIFAR10NEW, CIFAR100
from data.mnist import MNIST
#from model import CNN
import argparse, sys
import numpy as np
import datetime
import shutil
#from net_10 import Net
#import PreResNet_3
import matplotlib.pyplot as plt
import scipy.stats as stats
from torchvision import datasets
import cv2
#model.eval()

In [None]:
class Autoencoder(nn.Module):
    """Small convolutional autoencoder.

    The encoder is three stride-2 ``Conv2d`` layers (3 -> 12 -> 24 -> 48
    channels), each followed by ReLU. The decoder mirrors it with three
    stride-2 ``ConvTranspose2d`` layers (48 -> 24 -> 12 -> 3); its last
    activation is a sigmoid instead of a ReLU.

    With kernel 4 / stride 2 / padding 1 every encoder stage halves the
    spatial size and every decoder stage doubles it, so a
    ``[batch, 3, 32, 32]`` input gives a ``[batch, 48, 4, 4]`` code and a
    ``[batch, 3, 32, 32]`` reconstruction.
    """

    def __init__(self):
        super().__init__()
        widths = (3, 12, 24, 48)

        # Encoder: Conv2d + ReLU per stage (Sequential indices 0..5).
        down = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            down.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            down.append(nn.ReLU())
        self.encoder = nn.Sequential(*down)

        # Decoder: ConvTranspose2d + ReLU per stage, walking the widths backwards.
        mirrored = widths[::-1]
        up = []
        for c_in, c_out in zip(mirrored[:-1], mirrored[1:]):
            up.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            up.append(nn.ReLU())
        # The final activation is a sigmoid, so reconstructions lie in (0, 1).
        up[-1] = nn.Sigmoid()
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        """Return ``(encoded, decoded)`` for the input batch ``x``."""
        code = self.encoder(x)
        reconstruction = self.decoder(code)
        return code, reconstruction

In [None]:
class UNet(nn.Module):
    def __init__(
        self,
        in_channels=3,
        n_classes=2,
        depth=5,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode='upconv',
    ):
        """
        U-Net style encoder/decoder, adapted from
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597

        This variant differs from a plain segmentation U-Net: the output
        head is a 1x1 convolution to a fixed 3 channels followed by a
        sigmoid, and ``forward`` also returns the bottleneck feature map.

        Args:
            in_channels (int): number of input channels
            n_classes (int): accepted for interface compatibility but
                             NOT used; the output head always produces
                             3 channels (see ``self.last``).
            depth (int): depth of the network
            wf (int): number of filters in the first layer is 2**wf
            padding (bool): if True, apply padding such that the input shape
                            is the same as the output.
                            This may introduce artifacts
            batch_norm (bool): Use BatchNorm after layers with an
                               activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                           'upconv' will use transposed convolutions for
                           learned upsampling.
                           'upsample' will use bilinear upsampling.
        """
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        prev_channels = in_channels

        # Contracting path: channel width doubles at every level.
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(
                UNetConvBlock(prev_channels, 2 ** (wf + i), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        # Expanding path: one up-block per level except the deepest one.
        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(
                UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        # NOTE(review): output width is hard-coded to 3 and ignores n_classes;
        # left as-is because the training code compares the output against
        # the input with an MSE loss.
        self.last = nn.Conv2d(prev_channels, 3, kernel_size=1)
        # torch.sigmoid replaces the deprecated F.sigmoid (same computation).
        self.out = torch.sigmoid

    def forward(self, x):
        """Return ``(feature, output)``.

        ``feature`` is the output of the deepest down-block (no pooling
        applied after it); ``output`` is the sigmoid of the 3-channel head.
        """
        blocks = []
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path) - 1:
                # Keep the pre-pooling activation for the skip connection.
                blocks.append(x)
                x = F.max_pool2d(x, 2)
        feature = x
        for i, up in enumerate(self.up_path):
            # Skip connections are consumed deepest-first.
            x = up(x, blocks[-i - 1])

        outputlast = self.last(x)
        output = self.out(outputlast)
        return feature, output


class UNetConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU and optional BatchNorm.

    Layer order inside ``self.block``:
    Conv -> ReLU -> [BatchNorm] -> Conv -> ReLU -> Dropout2d(0.15) -> [BatchNorm].

    Args:
        in_size: number of input channels.
        out_size: number of output channels of both convolutions.
        padding: truthy -> each conv pads by 1 (spatial size kept);
                 falsy -> no padding (each conv trims 2 pixels per axis).
        batch_norm: whether to insert the BatchNorm2d layers.
    """

    def __init__(self, in_size, out_size, padding, batch_norm):
        super().__init__()
        pad = int(padding)

        layers = [
            nn.Conv2d(in_size, out_size, kernel_size=3, padding=pad),
            nn.ReLU(),
        ]
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_size))

        layers.extend([
            nn.Conv2d(out_size, out_size, kernel_size=3, padding=pad),
            nn.ReLU(),
            nn.Dropout2d(p=0.15),
        ])
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_size))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the whole block to ``x``."""
        return self.block(x)


class UNetUpBlock(nn.Module):
    """Decoder stage: upsample, concatenate the skip tensor, convolve.

    Args:
        in_size: channels of the incoming (coarser) feature map.
        out_size: channels produced by the upsampling step and by the block.
        up_mode: 'upconv' (2x2 transposed conv, stride 2) or 'upsample'
                 (bilinear x2 upsampling followed by a 1x1 conv).
        padding, batch_norm: forwarded to ``UNetConvBlock``.
    """

    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super().__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )

        # The conv block is built with in_size input channels; it receives the
        # concatenation of the upsampled map and the cropped skip tensor.
        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Crop the spatial centre of ``layer`` to ``target_size`` = (height, width)."""
        target_h, target_w = target_size[0], target_size[1]
        _, _, height, width = layer.size()
        top = (height - target_h) // 2
        left = (width - target_w) // 2
        return layer[:, :, top:top + target_h, left:left + target_w]

    def forward(self, x, bridge):
        """Upsample ``x``, join it with ``bridge`` along channels, then convolve."""
        upsampled = self.up(x)
        skip = self.center_crop(bridge, upsampled.shape[2:])
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv_block(merged)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    conv = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return conv


class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm stages with a skip connection.

    The skip path is the identity unless the stride or the channel count
    changes, in which case it becomes a strided 1x1 conv plus BatchNorm.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        shape_preserved = stride == 1 and in_planes == out_planes
        if shape_preserved:
            # Empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return F.relu(out)


class PreActBlock(nn.Module):
    """Pre-activation variant of ``BasicBlock`` (BatchNorm + ReLU before each conv).

    The skip path branches off *after* the first BatchNorm + ReLU; it is the
    identity unless the stride or channel count changes, in which case it is
    a strided 1x1 conv (without BatchNorm).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)

        shape_preserved = stride == 1 and in_planes == out_planes
        if shape_preserved:
            # Empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
            )

    def forward(self, x):
        """Return ``conv2(relu(bn2(conv1(a)))) + shortcut(a)`` with ``a = relu(bn1(x))``."""
        activated = F.relu(self.bn1(x))
        skip = self.shortcut(activated)
        out = self.conv1(activated)
        out = self.conv2(F.relu(self.bn2(out)))
        out += skip
        return out


class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 (strided) -> 1x1 convolutions.

    The last 1x1 conv widens to ``expansion * planes`` channels. The skip
    path is the identity unless stride or channel count changes, in which
    case it is a strided 1x1 conv plus BatchNorm.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        shape_preserved = stride == 1 and in_planes == out_planes
        if shape_preserved:
            # Empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        out = self.bn3(self.conv3(hidden))
        out += self.shortcut(x)
        return F.relu(out)


class PreActBottleneck(nn.Module):
    """Pre-activation variant of ``Bottleneck`` (BatchNorm + ReLU before each conv).

    The skip path branches off after the first BatchNorm + ReLU; it is the
    identity unless stride or channel count changes, in which case it is a
    strided 1x1 conv (without BatchNorm).
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)

        shape_preserved = stride == 1 and in_planes == out_planes
        if shape_preserved:
            # Empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
            )

    def forward(self, x):
        """Run the three pre-activated convs and add the skip path."""
        activated = F.relu(self.bn1(x))
        skip = self.shortcut(activated)
        out = self.conv1(activated)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        out += skip
        return out


class ResNet(nn.Module):
    """ResNet backbone with a 3x3 stem, four residual stages and a linear head.

    Args:
        block: residual block class; must expose an ``expansion`` attribute
               and accept ``(in_planes, planes, stride)``.
        num_blocks: sequence of four ints, blocks per stage.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count; updated by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        per_block_strides = [stride] + [1] * (num_blocks - 1)
        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x, lin=0, lout=5):
        """Run a contiguous slice of the network.

        Sections are numbered 0 (stem), 1-4 (``layer1``-``layer4``) and
        5 (pool + linear head). Section ``k`` in 0..4 runs when
        ``lin <= k < lout + 1``; the head runs whenever ``lout > 4``.
        The defaults run the whole network.
        """
        out = x
        if lin < 1 and lout > -1:
            out = F.relu(self.bn1(self.conv1(out)))

        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        for k, stage in enumerate(stages, start=1):
            if lin < k + 1 and lout > k - 1:
                out = stage(out)

        if lout > 4:
            # Fixed 4x4 average pool, then flatten to (batch, features).
            pooled = F.avg_pool2d(out, 4)
            flat = pooled.view(pooled.size(0), -1)
            out = self.linear(flat)
        return out


def ResNet18(num_classes=10):
    """Build a ``ResNet`` of ``PreActBlock`` units with [2, 2, 2, 2] blocks per stage."""
    return ResNet(block=PreActBlock, num_blocks=[2, 2, 2, 2], num_classes=num_classes)

def ResNet34(num_classes=10):
    """Build a ``ResNet`` of ``BasicBlock`` units with [3, 4, 6, 3] blocks per stage."""
    return ResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3], num_classes=num_classes)

def ResNet50(num_classes=10):
    """Build a ``ResNet`` of ``Bottleneck`` units with [3, 4, 6, 3] blocks per stage."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 6, 3], num_classes=num_classes)

def ResNet101(num_classes=10):
    """Build a ``ResNet`` of ``Bottleneck`` units with [3, 4, 23, 3] blocks per stage."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 23, 3], num_classes=num_classes)

def ResNet152(num_classes=10):
    """Build a ``ResNet`` of ``Bottleneck`` units with [3, 8, 36, 3] blocks per stage."""
    return ResNet(block=Bottleneck, num_blocks=[3, 8, 36, 3], num_classes=num_classes)


def test():
    """Smoke test: push one random 3x32x32 image through ``ResNet18`` and print the output size."""
    net = ResNet18()
    # Tensors can be fed to modules directly; the deprecated
    # torch.autograd.Variable wrapper is no longer needed.
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())

In [None]:
# Per-channel normalisation constants (presumably CIFAR-10 training-set
# statistics -- TODO confirm). Later cells index these lists per channel.
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]

# Train and test pipelines are identical in this cell: tensor conversion
# followed by per-channel normalisation. No random crop / flip is applied.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Training split with symmetric label noise at a rate of 0.5.
train_dataset = CIFAR10NEW(
    root='./data/',
    download=True,
    train=True,
    transform=transform_train,
    noise_type='symmetric',
    noise_rate=0.5,
)


In [None]:
# One sample per batch: the distance-analysis cell below compares the
# single-element `target` and `true_target` tensors of each batch directly.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Rebuild the normalised training images by hand: NHWC -> NCHW, scale by
# 1/255, then apply the per-channel (value - mean) / std normalisation.
# assumes train_data is an (N, H, W, C) numpy array with C == len(mean) -- TODO confirm
x = train_dataset.train_data.transpose(0, 3, 1, 2) / 255
# Broadcasting over the channel axis replaces the per-image / per-channel
# Python loops; the arithmetic per element is unchanged.
channel_mean = np.asarray(mean, dtype=x.dtype).reshape(1, -1, 1, 1)
channel_std = np.asarray(std, dtype=x.dtype).reshape(1, -1, 1, 1)
x = (x - channel_mean) / channel_std
x[25972]

In [None]:
# Peek at the first few batches only; printing every batch of the loader
# floods the notebook output.
N_PREVIEW_BATCHES = 3
for batch_idx, batch in enumerate(train_loader):
    if batch_idx >= N_PREVIEW_BATCHES:
        break
    # Index by position instead of unpacking a fixed-length tuple: another
    # cell unpacks four items per batch from this loader, so the exact
    # tuple length is not assumed here (data first, sample index third).
    data, index = batch[0], batch[2]
    print(data, index)

In [None]:
def getIdx(ref_meta, ref_batch, num_class):
    """Return, for each batch item, the index of the most similar reference.

    Similarity is the cosine similarity between the flattened feature of a
    batch item and the flattened feature of each of the first ``num_class``
    entries of ``ref_meta``.

    Args:
        ref_meta: tensor whose first axis indexes the references; only the
            first ``num_class`` entries are used.
        ref_batch: tensor whose first axis indexes the batch items. Each item
            is expected to have the same shape as a reference entry.
        num_class: number of reference entries to compare against.

    Returns:
        Integer numpy array of shape ``(len(ref_batch),)`` holding the argmax
        over references for every batch item.

    Raises:
        IndexError: if ``ref_meta`` has fewer than ``num_class`` entries.
    """
    meta = ref_meta.detach().cpu().numpy()
    batch = ref_batch.detach().cpu().numpy()
    if len(meta) < num_class:
        raise IndexError(
            "ref_meta has {} entries but num_class is {}".format(len(meta), num_class)
        )

    # Flatten every feature to a vector and accumulate in float64
    # (the removed `np.float` alias was plain Python float, i.e. float64).
    meta_flat = meta[:num_class].reshape(num_class, int(np.prod(meta.shape[1:])))
    meta_flat = meta_flat.astype(np.float64)
    batch_flat = batch.reshape(len(batch), int(np.prod(batch.shape[1:])))
    batch_flat = batch_flat.astype(np.float64)

    # Cosine similarity of every (batch item, reference) pair in one shot.
    # A zero-norm feature yields NaN, as a plain division would.
    dots = batch_flat @ meta_flat.T
    batch_norms = np.linalg.norm(batch_flat, axis=1)[:, None]
    meta_norms = np.linalg.norm(meta_flat, axis=1)[None, :]
    similarity = dots / batch_norms / meta_norms

    return np.argmax(similarity, axis=1)

def getIdx1(ref_meta, ref_batch, num_class):
    """Return the cosine-similarity matrix between batch items and references.

    Each feature is flattened to a vector; entry ``[i, j]`` of the result is
    the cosine similarity between batch item ``i`` and reference ``j``.

    Args:
        ref_meta: tensor whose first axis indexes the references; only the
            first ``num_class`` entries are used.
        ref_batch: tensor whose first axis indexes the batch items. Each item
            is expected to have the same shape as a reference entry.
        num_class: number of reference entries to compare against.

    Returns:
        float64 numpy array of shape ``(len(ref_batch), num_class)``.

    Raises:
        IndexError: if ``ref_meta`` has fewer than ``num_class`` entries.
    """
    meta = ref_meta.detach().cpu().numpy()
    batch = ref_batch.detach().cpu().numpy()
    if len(meta) < num_class:
        raise IndexError(
            "ref_meta has {} entries but num_class is {}".format(len(meta), num_class)
        )

    # Flatten every feature to a vector and accumulate in float64
    # (the removed `np.float` alias was plain Python float, i.e. float64).
    meta_flat = meta[:num_class].reshape(num_class, int(np.prod(meta.shape[1:])))
    meta_flat = meta_flat.astype(np.float64)
    batch_flat = batch.reshape(len(batch), int(np.prod(batch.shape[1:])))
    batch_flat = batch_flat.astype(np.float64)

    # A zero-norm feature yields NaN, as a plain division would.
    dots = batch_flat @ meta_flat.T
    batch_norms = np.linalg.norm(batch_flat, axis=1)[:, None]
    meta_norms = np.linalg.norm(meta_flat, axis=1)[None, :]
    return dots / batch_norms / meta_norms

In [None]:
# --- Device and pretrained encoder -----------------------------------------
GPU_INDEX = 1      # CUDA device used when a GPU is available
NUM_CLASSES = 10   # number of reference (meta) samples compared against
PROGRESS_EVERY = 1000

use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.set_device(GPU_INDEX)
device = torch.device("cuda" if use_cuda else "cpu")

# NOTE(review): './model_UNet1' is a whole pickled model object (see
# torch.save in mainUNet); only load files you trust. Recent PyTorch releases
# may require weights_only=False to unpickle full models -- confirm for the
# installed version.
encoder = torch.load('./model_UNet1', map_location=device).to(device)
encoder.eval()

# --- Reference (meta) samples ----------------------------------------------
# NHWC -> NCHW, scale by 1/255, then per-channel (value - mean) / std.
# assumes metadata is an (N, H, W, C) numpy array with C == len(mean) -- TODO confirm
metadata = train_dataset.metadata.transpose(0, 3, 1, 2) / 255
channel_mean = np.asarray(mean, dtype=metadata.dtype).reshape(1, -1, 1, 1)
channel_std = np.asarray(std, dtype=metadata.dtype).reshape(1, -1, 1, 1)
metadata = (metadata - channel_mean) / channel_std
metadata = torch.from_numpy(metadata).float().to(device)
metadata = metadata.detach()
metatarget = torch.from_numpy(train_dataset.metatarget).float().to(device)
print(metatarget)

# --- Similarity gap per training sample ------------------------------------
# For every sample: (highest similarity to any reference) minus (similarity
# to the reference at the sample's given label), split by whether the given
# label equals the true label.
distance_true = []
distance_false = []
with torch.no_grad():
    # The encoder is in eval mode and the references never change, so their
    # features are computed once instead of once per sample.
    ref_meta, _ = encoder(metadata)
    for batch_idx, (data, target, index, true_target) in enumerate(train_loader):
        if batch_idx % PROGRESS_EVERY == 0:
            print(batch_idx)
        ref_batch, _ = encoder(data.to(device))
        similarities = getIdx1(ref_meta, ref_batch, NUM_CLASSES)[0]
        # .item() requires single-element tensors, i.e. a loader with batch_size=1.
        given_label = int(target.item())
        true_label = int(true_target.item())
        gap = similarities.max() - similarities[given_label]
        if given_label == true_label:
            distance_true.append(gap)
        else:
            distance_false.append(gap)
        
        
        

In [None]:
import matplotlib.pyplot as plt  
# Overlay both distributions; transparency keeps the first histogram visible
# underneath the second one.
fig, ax = plt.subplots()
ax.hist(distance_true, bins=100, range=[0, 1], alpha=0.6, label='given label == true label')
ax.hist(distance_false, bins=100, range=[0, 1], alpha=0.6, label='given label != true label')
ax.set_xlabel('max similarity - similarity at given label')
ax.set_ylabel('number of samples')
ax.set_title('Similarity gap by label correctness')
ax.legend()
plt.show()

In [None]:
def trainUNet(args, model, device, train_loader, optimizer, epoch):
    """Train ``model`` for one epoch to reconstruct its input (MSE loss).

    Args:
        args: object with a ``log_interval`` attribute (batches between log lines).
        model: module whose forward returns ``(features, reconstruction)``.
        device: device the input batches are moved to.
        train_loader: iterable of batches; the first element of every batch is
            the input tensor. Extra elements (labels, indices, ...) are
            ignored, so 2-, 3- and 4-item batches are all accepted.
        optimizer: optimizer over the model parameters.
        epoch: epoch number, used only for logging.

    Returns:
        The loss of the last batch as a float, or NaN if the loader was empty.
    """
    model.train()
    criterion = nn.MSELoss()
    loss = None
    for batch_idx, batch in enumerate(train_loader):
        # Only the images are needed: the reconstruction target is the input.
        data = batch[0].to(device)
        optimizer.zero_grad()
        _, output = model(data)
        loss = criterion(output, data)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    return float('nan') if loss is None else loss.item()

In [None]:
# Inspect the shape of the dataset's `metadata` array (the reference samples
# that the distance-analysis cell feeds through the encoder).
train_dataset.metadata.shape

In [None]:
def train(args, model, device, train_loader, optimizer, epoch):
    """Train ``model`` for one epoch to reconstruct its input (RMSE loss).

    Identical to an MSE reconstruction epoch except that the loss is the
    square root of the mean squared error.

    Args:
        args: object with a ``log_interval`` attribute (batches between log lines).
        model: module whose forward returns ``(features, reconstruction)``.
        device: device the input batches are moved to.
        train_loader: iterable of batches; the first element of every batch is
            the input tensor. Extra elements (labels, indices, ...) are
            ignored, so 2-, 3- and 4-item batches are all accepted.
        optimizer: optimizer over the model parameters.
        epoch: epoch number, used only for logging.

    Returns:
        The loss of the last batch as a float, or NaN if the loader was empty.
    """
    model.train()
    criterion = nn.MSELoss()
    loss = None
    for batch_idx, batch in enumerate(train_loader):
        # Only the images are needed: the reconstruction target is the input.
        data = batch[0].to(device)
        optimizer.zero_grad()
        _, output = model(data)
        # NOTE(review): the gradient of sqrt is unbounded at an exactly zero MSE.
        loss = torch.sqrt(criterion(output, data))
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    return float('nan') if loss is None else loss.item()
    

In [None]:
def mainUNet(argv=None):
    """Train the UNet autoencoder on clean CIFAR data and save it to './model_UNet1'.

    Args:
        argv: optional list of command-line style arguments. ``None`` (the
            default) makes argparse read ``sys.argv``, exactly as before.
            Inside a notebook pass ``[]`` (or explicit flags): the kernel's
            own ``sys.argv`` contains arguments this parser does not know.

    Only ``--seed``, ``--dataset``, ``--root-dir``, ``--noise_rate``,
    ``--batch-size``, ``--no-cuda``, ``--lr``, ``--momentum``, ``--n_epoch``
    and ``--log-interval`` influence this function; the remaining flags are
    parsed but not read here.

    Raises:
        ValueError: if ``--dataset`` is neither 'cifar10' nor 'cifar100'.
    """
    parser = argparse.ArgumentParser(description='Train a UNet autoencoder on CIFAR')

    parser.add_argument('--result_dir', type = str, help = 'dir to save result txt files', default = 'results/')
    parser.add_argument('--noise_rate', type = float, help = 'corruption rate, should be less than 1', default = 0.5)
    parser.add_argument('--forget_rate', type = float, help = 'forget rate', default = None)
    parser.add_argument('--noise_type', type = str, help='[pairflip, symmetric]', default='symmetric')
    parser.add_argument('--num_gradual', type = int, default = 10, help='how many epochs for linear drop rate, can be 5, 10, 15. This parameter is equal to Tk for R(T) in Co-teaching paper.')
    parser.add_argument('--exponent', type = float, default = 1, help='exponent of the forget rate, can be 0.5, 1, 2. This parameter is equal to c in Tc for R(T) in Co-teaching paper.')
    parser.add_argument('--top_bn', action='store_true')
    parser.add_argument('--dataset', type = str, help = 'mnist, cifar10, or cifar100', default = 'cifar10')
    parser.add_argument('--n_epoch', type=int, default=300)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--print_freq', type=int, default=10)
    parser.add_argument('--num_workers', type=int, default=2, help='how many subprocesses to use for data loading')
    parser.add_argument('--num_iter_per_epoch', type=int, default=400)
    parser.add_argument('--epoch_decay_start', type=int, default=80)
    parser.add_argument('--eps', type=float, default=9.9)

    # "%(default)s" lets argparse print the real default, so the help text
    # cannot drift away from the value actually used.
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--test-batch-size', type=int, default=4000, metavar='N',
                        help='input batch size for testing (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: %(default)s)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--noise-level', type=float, default=80.0,
                        help='percentage of noise added to the data (values from 0. to 100.), default: 80.')
    # NOTE(review): the default is a machine-specific absolute path; pass
    # --root-dir explicitly on any other machine.
    parser.add_argument('--root-dir', type=str, default='/home/iedl/w00536717/data', help='path to CIFAR dir where cifar-10-batches-py/ and cifar-100-python/ are located. If the datasets are not downloaded, they will automatically be and extracted to this path (default: %(default)s)')
    args = parser.parse_args(argv)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # --- Data -----------------------------------------------------------------
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    dataset_classes = {'cifar10': CIFAR10NEW, 'cifar100': CIFAR100}
    if args.dataset not in dataset_classes:
        # Fail early with a clear message instead of a NameError further down.
        raise ValueError(
            "unsupported dataset {!r}; expected one of {}".format(
                args.dataset, sorted(dataset_classes)))
    # The autoencoder is trained on clean labels regardless of --noise_type.
    trainset = dataset_classes[args.dataset](
        root=args.root_dir, train=True, download=True, transform=transform_train,
        noise_type='clean', noise_rate=args.noise_rate)

    # NOTE(review): the worker count is fixed at 16; --num_workers is not consulted.
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=16,
                                               pin_memory=True)

    # --- Model and optimiser --------------------------------------------------
    # Honour --no-cuda and fall back to the CPU when no GPU is present.
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    cnn = UNet().to(device)
    optimizer = torch.optim.SGD(cnn.parameters(), lr=args.lr, weight_decay=1e-5, momentum=args.momentum)

    # --- Training -------------------------------------------------------------
    # Training is capped at 199 epochs (the former `if epoch < 200` guard);
    # epochs beyond that were no-ops, so they are simply not iterated.
    last_epoch = min(args.n_epoch, 199)
    for epoch in range(1, last_epoch + 1):
        trainUNet(args, cnn, device, train_loader, optimizer, epoch)

    # The whole module object is pickled (not only its state_dict), which is
    # what the analysis cell's torch.load('./model_UNet1') expects.
    torch.save(cnn, './model_UNet1')

In [None]:
def main():
    """Train a ResNet18 classifier on CIFAR-10/100 using a previously saved encoder.

    Steps, in order:
      1. Parse options from ``sys.argv`` and seed PyTorch.
      2. Build the project's noisy-label datasets (``CIFAR10NEW`` / ``CIFAR100``)
         and the torchvision CIFAR train/test sets with normalisation.
      3. Load the encoder checkpoint from ``./model_UNet`` and create
         ``ResNet18(num_classes)`` on the selected device.
      4. For every epoch call ``train(...)`` followed by ``test(...)``.
      5. Save the trained classifier with ``torch.save``.

    The device is CUDA when it is available and ``--no-cuda`` was not given,
    otherwise CPU.

    Raises:
        ValueError: if ``--dataset`` is neither ``cifar10`` nor ``cifar100``.
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')

    parser.add_argument('--result_dir', type = str, help = 'dir to save result txt files', default = 'results/')
    parser.add_argument('--noise_rate', type = float, help = 'corruption rate, should be less than 1', default = 0.5)
    parser.add_argument('--forget_rate', type = float, help = 'forget rate', default = None)
    parser.add_argument('--noise_type', type = str, help='[pairflip, symmetric]', default='symmetric')
    parser.add_argument('--num_gradual', type = int, default = 10, help='how many epochs for linear drop rate, can be 5, 10, 15. This parameter is equal to Tk for R(T) in Co-teaching paper.')
    parser.add_argument('--exponent', type = float, default = 1, help='exponent of the forget rate, can be 0.5, 1, 2. This parameter is equal to c in Tc for R(T) in Co-teaching paper.')
    parser.add_argument('--top_bn', action='store_true')
    # Only the two CIFAR variants are handled below, so the help text lists just those.
    parser.add_argument('--dataset', type = str, help = 'cifar10 or cifar100', default = 'cifar10')
    parser.add_argument('--n_epoch', type=int, default=300)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--print_freq', type=int, default=50)
    parser.add_argument('--num_workers', type=int, default=2, help='how many subprocesses to use for data loading')
    parser.add_argument('--num_iter_per_epoch', type=int, default=400)
    parser.add_argument('--epoch_decay_start', type=int, default=80)
    parser.add_argument('--eps', type=float, default=9.9)

    # ``%(default)s`` lets argparse print the real default, so the help text
    # cannot drift away from the value actually used.
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--test-batch-size', type=int, default=4000, metavar='N',
                        help='input batch size for testing (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--noise-level', type=float, default=80.0,
                        help='percentage of noise added to the data (values from 0. to 100.), default: 80.')
    parser.add_argument('--root-dir', type=str, default='/home/iedl/w00536717/data', help='path to CIFAR dir where cifar-10-batches-py/ and cifar-100-python/ are located. If the datasets are not downloaded, they will automatically be downloaded and extracted to this path, default: %(default)s')
    args = parser.parse_args()

    # Fail early and clearly instead of hitting a NameError further down when
    # none of the dataset branches has run.
    supported_datasets = ('cifar10', 'cifar100')
    if args.dataset not in supported_datasets:
        raise ValueError(
            'Unsupported dataset {!r}; expected one of: {}'.format(
                args.dataset, ', '.join(supported_datasets)))

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Per-dataset settings. Note that ``n_epoch`` and ``epoch_decay_start`` are
    # overwritten here, so the corresponding command-line values are ignored.
    if args.dataset == 'cifar10':
        num_classes = 10
        args.top_bn = False
        args.epoch_decay_start = 80
        # NOTE(review): a single epoch looks like a debugging leftover - confirm.
        args.n_epoch = 1
        noisy_dataset_cls = CIFAR10NEW
        clean_dataset_cls = datasets.CIFAR10
    else:  # 'cifar100' (validated above)
        num_classes = 100
        args.top_bn = False
        args.epoch_decay_start = 100
        args.n_epoch = 200
        noisy_dataset_cls = CIFAR100
        clean_dataset_cls = datasets.CIFAR100

    # NOTE(review): the noisy-label datasets and ``metadata`` are not consumed
    # by the training below. They are kept because their constructors
    # (download / label corruption) are defined outside this cell and may have
    # side effects; remove them once that is confirmed unnecessary.
    train_dataset = noisy_dataset_cls(root='./data/',
                                      download=True,
                                      train=True,
                                      transform=transforms.ToTensor(),
                                      noise_type=args.noise_type,
                                      noise_rate=args.noise_rate)
    test_dataset = noisy_dataset_cls(root='./data/',
                                     download=True,
                                     train=False,
                                     transform=transforms.ToTensor(),
                                     noise_type=args.noise_type,
                                     noise_rate=args.noise_rate)

    # Per-channel normalisation statistics applied to both splits.
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # These torchvision datasets are the ones actually fed to the model.
    trainset = clean_dataset_cls(root=args.root_dir, train=True, download=True, transform=transform_train)
    testset = clean_dataset_cls(root=args.root_dir, train=False, transform=transform_test)

    metadata = train_dataset.metadata

    # Honour --no-cuda and fall back to CPU when CUDA is unavailable.
    use_cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device("cuda" if use_cuda else "cpu")

    # NOTE(review): the worker count is fixed at 16; --num_workers is not used here.
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=16,
                                               pin_memory=use_cuda)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, num_workers=16,
                                              pin_memory=use_cuda)

    # Define models.
    # torch.load unpickles the file, so only point it at checkpoints you trust.
    # map_location places the encoder on the same device as the classifier.
    encoder = torch.load('./model_UNet', map_location=device)
    cnn = ResNet18(num_classes).to(device)

    optimizer = torch.optim.SGD(cnn.parameters(), lr=args.lr, weight_decay=1e-5, momentum=args.momentum)

    for epoch in range(1, args.n_epoch + 1):
        # Epochs from 200 onwards are skipped, exactly as in the original loop.
        if epoch < 200:
            train(args, cnn, encoder, device, train_loader, optimizer, epoch)
            test(args, cnn, device, test_loader)

    # NOTE(review): this writes the classifier to the same path the encoder was
    # loaded from, replacing the encoder checkpoint on disk. The path is left
    # unchanged to preserve existing behaviour - confirm whether a separate
    # file name was intended.
    torch.save(cnn, './model_UNet')

In [None]:
# Select GPU 1 only when the host exposes at least two CUDA devices; calling
# set_device(1) otherwise raises. With fewer devices the default device is used.
if torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)
# Replace the kernel's command line with a single placeholder so that
# sys.argv[1:] is empty - presumably so argparse inside mainUNet() falls back
# to its defaults (mainUNet's parser is defined outside this cell).
sys.argv = ['-f']
mainUNet()