In [1]:
import data_loader
import torch
from loss import pairwise_similarity, NT_xent
from models.resnet import ResNet18,ResNet50
from models.projector import Projector
import torch.optim as optim
from torchlars import LARS
from scheduler import GradualWarmupScheduler
from attack_lib import FastGradientSignUntargeted,RepresentationAdv
import os
import time
import argparse
#from utils import progress_bar, checkpoint

In [2]:
import torch
# Reads the current CPU RNG state and immediately writes the same state back,
# so the generator is left exactly as it was; this does not seed anything.
# NOTE(review): if a reproducible run is the goal, torch.manual_seed(<seed>)
# is probably what was intended - confirm.
torch.set_rng_state(torch.get_rng_state())

In [None]:
# Command-line options for the training run.
parser = argparse.ArgumentParser(description='PyTorch RoCL training')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--epoch', type=int, default=1000)
# parse_known_args() rather than parse_args(): inside a Jupyter kernel sys.argv
# carries the kernel's own arguments (e.g. "-f <connection file>"), which
# parse_args() rejects with SystemExit. Unrecognised arguments are ignored here.
args, _unknown_args = parser.parse_known_args()

# Forwarded to pairwise_similarity() as its multi_gpu flag.
multi_gpu = True
# Number of GPUs; used as the distributed world size.
ngpu = 2

In [None]:
# Join the distributed process group. The group size is the configured GPU
# count and this process's rank is taken from --local_rank; the rendezvous
# details are read from environment variables (init_method='env://').
world_size = ngpu
torch.distributed.init_process_group(
    backend='nccl',
    init_method='env://',
    world_size=world_size,
    rank=args.local_rank,
)

In [2]:
# Build the data pipeline for this process. get_loader() is a project helper;
# it returns the training sampler (its set_epoch() is called once per epoch in
# train()), the training loader and the test loader.
train_sampler , train_loader, test_loader = data_loader.get_loader(local_rank=args.local_rank)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Encoder and projection head used for contrastive training.
model = ResNet18(num_classes=10, contrastive_learning=True)
projector = Projector(expansion=1)

# Move each module to THIS process's GPU. A bare .cuda() uses the current
# default device (GPU 0 unless changed), which disagrees with the
# device_ids=[args.local_rank] passed to DistributedDataParallel below for
# every rank other than 0.
model.cuda(args.local_rank)
projector.cuda(args.local_rank)

# Swap BatchNorm layers for their synchronised counterpart before wrapping.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    find_unused_parameters=True,
)
projector = torch.nn.parallel.DistributedDataParallel(
    projector,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
    find_unused_parameters=True,
)

Projector(
  (linear_1): Linear(in_features=512, out_features=2048, bias=True)
  (linear_2): Linear(in_features=2048, out_features=128, bias=True)
)

In [4]:
# --- Settings handed to RepresentationAdv / RepAttack.get_loss ---
epsilon = 8 / 255
alpha = 2 / 255
max_iters = 7
loss_type = "sim"
regularize_type = 'other'
random_start = True
lamda = 512.0  # passed as `weight` to RepAttack.get_loss
advtrain_type = "Rep"  # Rep/None

# --- Optimiser and schedule settings ---
lr = 0.1
weight_decay = 1e-6
epochs = args.epoch
lr_multiplier = 15.0  # passed as `multiplier` to GradualWarmupScheduler

# --- Passed to pairwise_similarity ---
temperature = 0.5


In [5]:
 # Attack object built on the encoder and projector; train() calls its
 # get_loss() once per batch.
 RepAttack = RepresentationAdv(
     model,
     projector,
     epsilon=epsilon,
     alpha=alpha,
     min_val=0.0,
     max_val=1.0,
     max_iters=max_iters,
     _type="linf",
     loss_type=loss_type,
     regularize=regularize_type,
 )

In [6]:
# One flat list holding the encoder's parameters followed by the projector's,
# so a single optimiser updates both.
model_params = [*model.parameters(), *projector.parameters()]

In [7]:
# Plain SGD with momentum, wrapped by LARS; the wrapper is what the schedulers
# and the training loop use as `optimizer`.
base_optimizer = optim.SGD(
    model_params,
    lr=lr,
    momentum=0.9,
    weight_decay=weight_decay,
)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)

In [8]:
# Cosine annealing over the whole run, placed behind a 10-epoch warm-up
# scheduler that hands over to it afterwards.
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
scheduler_warmup = GradualWarmupScheduler(
    optimizer,
    multiplier=lr_multiplier,
    total_epoch=10,
    after_scheduler=scheduler_cosine,
)

In [9]:
def checkpoint(model, acc, epoch, optimizer, save_name_add='', save_dir='./checkpoint'):
    """Serialise a training snapshot to ``<save_dir>/ckpt_<save_name_add>``.

    Args:
        model: Module whose ``state_dict()`` is stored under ``'model'``.
        acc: Value stored under ``'acc'`` (the caller in this file passes the
            epoch's training loss).
        epoch: Epoch index stored under ``'epoch'``.
        optimizer: Optimiser whose ``state_dict()`` is stored under
            ``'optimizer_state'``.
        save_name_add: Suffix appended to the ``ckpt_`` file-name prefix.
        save_dir: Directory that receives the file; created if missing.
            Defaults to ``./checkpoint``, the previously hard-coded location.
    """
    # Save checkpoint.
    print('Saving..')
    state = {
        'epoch': epoch,
        'acc': acc,
        'model': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'rng_state': torch.get_rng_state(),
    }

    # exist_ok=True avoids the check-then-create race of isdir() + mkdir()
    # when more than one process reaches this point, and also creates any
    # missing parent directories of a custom save_dir.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state, os.path.join(save_dir, 'ckpt_' + save_name_add))

In [10]:
def _report_progress(batch_idx, num_batches, message):
    """Show running statistics for the current batch.

    Uses ``progress_bar`` when the notebook has one in scope (its import is
    commented out at the top of the file); otherwise prints a line every 50
    batches and on the final batch.
    """
    reporter = globals().get('progress_bar')
    if reporter is not None:
        reporter(batch_idx, num_batches, message)
    elif batch_idx % 50 == 0 or batch_idx + 1 == num_batches:
        print('[%d/%d] %s' % (batch_idx + 1, num_batches, message))


def train(epoch):
    """Run one training epoch over ``train_loader``.

    Relies on the notebook-level objects ``model``, ``projector``,
    ``RepAttack``, ``optimizer``, ``scheduler_warmup``, ``train_sampler`` and
    the hyper-parameters defined in the configuration cell.

    Args:
        epoch: Epoch index; passed to the sampler's ``set_epoch`` and printed.

    Returns:
        Tuple ``(mean_total_loss, mean_similarity_loss)`` averaged over the
        batches processed in this epoch.
    """
    print('\nEpoch: %d' % epoch)

    model.train()
    projector.train()

    train_sampler.set_epoch(epoch)
    scheduler_warmup.step()

    use_adv = advtrain_type != 'None'
    # Only one process reports progress. The original referenced the
    # undefined names `ngpus_per_node` and `args.advtrain_type`; the notebook
    # defines `ngpu` and a module-level `advtrain_type` instead.
    is_reporting_process = (args.local_rank % ngpu == 0)
    num_batches = len(train_loader)

    total_loss = 0
    reg_simloss = 0
    reg_loss = 0
    batches_seen = 0

    # The first and last items of each batch are not used by the training
    # step, so they are not moved to the GPU.
    for batch_idx, (_ori, inputs_1, inputs_2, _label) in enumerate(train_loader):
        inputs_1, inputs_2 = inputs_1.cuda(), inputs_2.cuda()

        attack_target = inputs_2

        advinputs, adv_loss = RepAttack.get_loss(
            original_images=inputs_1,
            target=attack_target,
            optimizer=optimizer,
            weight=lamda,
            random_start=random_start,
        )
        reg_loss += adv_loss.detach()

        if use_adv:
            inputs = torch.cat((inputs_1, inputs_2, advinputs))
        else:
            inputs = torch.cat((inputs_1, inputs_2))

        outputs = projector(model(inputs))
        similarity, gathered_outputs = pairwise_similarity(
            outputs, temperature=temperature, multi_gpu=multi_gpu, adv_type=advtrain_type)

        simloss = NT_xent(similarity, advtrain_type)
        loss = simloss + adv_loss if use_adv else simloss

        optimizer.zero_grad()
        loss.backward()
        total_loss += loss.detach()
        reg_simloss += simloss.detach()

        optimizer.step()
        batches_seen += 1

        if is_reporting_process:
            if 'Rep' in advtrain_type:
                message = ('Loss: %.3f | SimLoss: %.3f | Adv: %.2f'
                           % (total_loss / batches_seen,
                              reg_simloss / batches_seen,
                              reg_loss / batches_seen))
            else:
                # The second value is the similarity loss; the original label
                # called it "Adv".
                message = ('Loss: %.3f | SimLoss: %.3f'
                           % (total_loss / batches_seen,
                              reg_simloss / batches_seen))
            _report_progress(batch_idx, num_batches, message)

    # Average over the number of batches. The original divided by the last
    # batch index, which is one too small and is zero for a single batch.
    # max(..., 1) keeps an empty loader from raising.
    denominator = max(batches_seen, 1)
    return (total_loss / denominator, reg_simloss / denominator)


In [11]:
def test(epoch, train_loss):
    """Switch both networks to eval mode and write the epoch's checkpoints.

    Only the process whose ``args.local_rank`` is a multiple of ``ngpu``
    writes files: one checkpoint for the encoder and one for the projector.

    Args:
        epoch: Epoch index, stored in the checkpoints and used in file names.
        train_loss: Stored in the checkpoints under ``'acc'``.
    """
    model.eval()
    projector.eval()
    # `ngpu` is the GPU count the notebook defines; the original referenced
    # the undefined name `ngpus_per_node`.
    if args.local_rank % ngpu == 0:
        checkpoint(model, train_loss, epoch, optimizer, save_name_add='_epoch_' + str(epoch))
        checkpoint(projector, train_loss, epoch, optimizer, save_name_add=('_projector_epoch_' + str(epoch)))

In [12]:
# Full schedule: one training pass followed by checkpointing for each epoch,
# with the total wall-clock time reported at the end.
start_time = time.time()
for epoch in range(epochs):
    train_loss, reg_loss = train(epoch)
    test(epoch, train_loss)
end_time = time.time()
elapsed_seconds = end_time - start_time
print("Time taken for {} epoch {}".format(epochs, elapsed_seconds))


Epoch: 0


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ..\torch\csrc\utils\python_arg_parser.cpp:882.)
  p.grad.add_(weight_decay, p.data)


Saving..
Saving..
Time taken for 1 epoch 570.3546478748322
