In [1]:

# Standard library
import importlib
import sys

# Third-party
import torch
import torch.nn.functional as F
from torch.backends import cudnn
from torch.utils.data import DataLoader

cudnn.enabled = True

# Make the project packages (voc12, misc, net) importable; the path is
# resolved relative to the notebook's current working directory.
sys.path.append('../')

# Project-local modules (must come after the sys.path tweak above).
import voc12.dataloader
from misc import pyutils, torchutils

# NOTE(review): "cuda:3" assumes at least four visible GPUs — adjust for other machines.
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")


def validate(model, data_loader, device=None):
    """Compute, print and return the mean multilabel soft-margin loss over ``data_loader``.

    Runs the model without gradient tracking in eval mode, then puts it back
    into whatever train/eval mode it was in on entry (even if a batch raises).

    Args:
        model: module mapping a batch of images to per-class logits.
        data_loader: iterable of dicts with ``'img'`` and ``'label'`` tensors;
            labels must be valid targets for ``F.multilabel_soft_margin_loss``.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        The mean of the per-batch losses as a float, or ``nan`` if the loader
        yielded no batches.
    """
    print('validating ... ', flush=True, end='')

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()

    loss_sum = 0.0
    n_batches = 0
    try:
        with torch.no_grad():
            for pack in data_loader:
                img = pack['img'].to(device)
                label = pack['label'].to(device)

                x = model(img)
                loss_sum += F.multilabel_soft_margin_loss(x, label).item()
                n_batches += 1
    finally:
        # Restore the caller's mode rather than forcing train mode.
        model.train(was_training)

    # An empty loader would otherwise divide by zero.
    mean_loss = loss_sum / n_batches if n_batches else float('nan')
    print('loss: %.4f' % mean_loss)

    return mean_loss




# ---- Training configuration (values identical to the original run) ----
SEED = 0
TRAIN_LIST = '/home/maruf/ws2m2/voc12/train.txt'
VAL_LIST = '/home/maruf/ws2m2/voc12/val.txt'
VOC12_ROOT = '/home/amartyadutta/VOC12/AMN/Datasets/VOCdevkit/VOC2012'
BATCH_SIZE = 16
NUM_EPOCHS = 5
NUM_WORKERS = 8
CROP_SIZE = 512
BASE_LR = 0.1
WEIGHT_DECAY = 0.0001
KEEP_PROB = 0.75   # probability that each input element survives the random Bernoulli mask
REG_WEIGHT = 1     # weight of the saliency-smoothness term relative to the classification loss
LOG_EVERY = 20     # print training stats every LOG_EVERY optimizer steps

# Seed the torch RNGs (CPU and CUDA); cuDNN kernels may still be nondeterministic.
torch.manual_seed(SEED)

model = getattr(importlib.import_module('net.resnet50_cam'), 'Net')()

train_dataset = voc12.dataloader.VOC12ClassificationDataset(
    TRAIN_LIST, voc12_root=VOC12_ROOT,
    resize_long=(320, 640), hor_flip=True,
    crop_size=CROP_SIZE, crop_method="random")
train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                               shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
# drop_last=True above, so every epoch has exactly len(train_dataset) // BATCH_SIZE steps.
max_step = (len(train_dataset) // BATCH_SIZE) * NUM_EPOCHS

val_dataset = voc12.dataloader.VOC12ClassificationDataset(
    VAL_LIST, voc12_root=VOC12_ROOT, crop_size=CROP_SIZE)
# Evaluate every validation image: drop_last=True would silently skip up to BATCH_SIZE-1 of them.
val_data_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                             shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, drop_last=False)

# Second parameter group trains at 10x the base learning rate.
param_groups = model.trainable_parameters()
optimizer = torchutils.PolyOptimizer([
    {'params': param_groups[0], 'lr': BASE_LR, 'weight_decay': WEIGHT_DECAY},
    {'params': param_groups[1], 'lr': 10 * BASE_LR, 'weight_decay': WEIGHT_DECAY},
], lr=BASE_LR, weight_decay=WEIGHT_DECAY, max_step=max_step)

model = model.to(device)
model.train()

avg_meter = pyutils.AverageMeter()
timer = pyutils.Timer()

for ep in range(NUM_EPOCHS):
    print('Epoch %d/%d' % (ep + 1, NUM_EPOCHS))

    for step, pack in enumerate(train_data_loader):
        img = pack['img'].to(device)
        label = pack['label'].to(device)

        # Random keep-mask over the input (robustness augmentation). Only the mask
        # needs requires_grad: the saliency below is d(scores)/d(mask); the image
        # gradient was never used, so it is no longer tracked.
        noise = torch.bernoulli(torch.full_like(img, KEEP_PROB))
        noise.requires_grad_(True)

        out = model(img * noise)
        # One score per sample: the summed logits of the classes present in `label`.
        scores = (out * label).sum(-1)

        # create_graph=True so the regulariser built from `saliency` can itself be
        # back-propagated into the model parameters.
        saliency = torch.autograd.grad(
            scores, noise,
            grad_outputs=torch.ones_like(scores),
            retain_graph=True,
            create_graph=True,
        )[0]

        # Finite-difference gradients of the saliency along the last (x) and
        # second-to-last (y) dims; the norm grows with sharp spatial changes.
        grad_x, grad_y = torch.gradient(saliency, dim=(-1, -2))
        loss_reg = torch.linalg.norm(grad_x ** 2 + grad_y ** 2)

        loss_cls = F.multilabel_soft_margin_loss(out, label)

        loss = loss_cls + REG_WEIGHT * loss_reg

        avg_meter.add({'loss_cls': loss_cls.item()})
        avg_meter.add({'loss_sal': loss_reg.item()})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (optimizer.global_step - 1) % LOG_EVERY == 0:
            timer.update_progress(optimizer.global_step / max_step)

            print('step:%5d/%5d' % (optimizer.global_step - 1, max_step),
                  'loss_cls:%.4f' % (avg_meter.pop('loss_cls')),
                  'loss_sal:%.4f' % (avg_meter.pop('loss_sal')),
                  'imps:%.1f' % ((step + 1) * BATCH_SIZE / timer.get_stage_elapsed()),
                  'lr: %.4f' % (optimizer.param_groups[0]['lr']),
                  'etc:%s' % (timer.str_estimated_complete()), flush=True)

    # End of epoch (the original for/else had no `break`, so this always ran).
    validate(model, val_data_loader)
    timer.reset_stage()

        



  from .autonotebook import tqdm as notebook_tqdm


Epoch 1/5
step:    0/  455 loss_cls:0.7128 loss_sal:0.0001 imps:5.1 lr: 0.1000 etc:Fri Nov 11 21:13:58 2022
step:   20/  455 loss_cls:0.2665 loss_sal:0.0005 imps:11.3 lr: 0.0960 etc:Fri Nov 11 21:00:53 2022
step:   40/  455 loss_cls:0.2399 loss_sal:0.0007 imps:11.5 lr: 0.0921 etc:Fri Nov 11 21:00:36 2022
step:   60/  455 loss_cls:0.2326 loss_sal:0.0014 imps:11.4 lr: 0.0880 etc:Fri Nov 11 21:00:42 2022
step:   80/  455 loss_cls:0.2152 loss_sal:0.0031 imps:11.2 lr: 0.0840 etc:Fri Nov 11 21:00:53 2022
validating ... loss: 0.1939
Epoch 2/5
step:  100/  455 loss_cls:0.2065 loss_sal:0.0051 imps:10.0 lr: 0.0800 etc:Fri Nov 11 21:02:03 2022
step:  120/  455 loss_cls:0.1876 loss_sal:0.0087 imps:10.3 lr: 0.0759 etc:Fri Nov 11 21:01:59 2022
step:  140/  455 loss_cls:0.1601 loss_sal:0.0128 imps:10.4 lr: 0.0718 etc:Fri Nov 11 21:01:56 2022
step:  160/  455 loss_cls:0.1641 loss_sal:0.0183 imps:10.4 lr: 0.0677 etc:Fri Nov 11 21:01:54 2022
step:  180/  455 loss_cls:0.1521 loss_sal:0.0165 imps:10.4 lr:

In [2]:
import os

# Persist the trained weights. Create the session directory first so a missing
# folder does not make torch.save fail after a long training run.
CAM_WEIGHTS_PATH = '/home/maruf/ws2m2/sess/resnet50_cam_saliency_norm.pth'
os.makedirs(os.path.dirname(CAM_WEIGHTS_PATH), exist_ok=True)
torch.save(model.state_dict(), CAM_WEIGHTS_PATH)

# Release unused cached GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()