In [None]:
import os, time

In [None]:
# Run identifier; also names the TensorBoard log sub-directory under ./runs.
TRIAL_NAME = 'ResNet_baseline'

In [None]:
# Experiment configuration for this trial.
CONF = dict(
    niter=200,                 # number of training epochs (one val pass + one train pass each)
    GPU=0,                     # exported as CUDA_VISIBLE_DEVICES
    BS=128,                    # training batch size
    test_BS=256,               # validation batch size
    N_neg=3,                   # not referenced in this notebook
    name=TRIAL_NAME,
    tb_dir=os.path.join('./runs', TRIAL_NAME),  # TensorBoard log dir (wiped on start)
    nz=64,                     # not referenced in this notebook
    seed=10708,                # seed for random / numpy / torch
    data_dir='/DataSet/COCO',  # COCO root holding <split>/ images and annotations/
    dataType='train2017',      # training split name
    valType='val2017',         # validation split name
    LAMBDA=0.5,                # not referenced in this notebook
    use_super=False,           # label with the 12 supercategories instead of categories
    test_classes=80,           # number of categories sampled as labels / model outputs
)

In [None]:
# Expose only the configured GPU to CUDA. This is set before torch is imported below,
# so the process sees that device as cuda:0.
os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(CONF['GPU'])

In [None]:
# Remove TensorBoard logs left behind by an earlier run of this trial.
# NOTE(review): the sleeps before/after the delete look like a workaround that lets
# a previous writer or viewer release the directory — confirm they are still needed.
import shutil
time.sleep(2)
shutil.rmtree(CONF['tb_dir'], ignore_errors=True)
time.sleep(5)

In [None]:
import texar.torch as tx
import random
import torch
import torch.nn as nn
import torch.nn.parallel
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision
from torch.utils.tensorboard import SummaryWriter
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
from torch import autograd
import multiprocessing
from PIL import Image
from sklearn import metrics

In [None]:
# Select the (only visible) GPU when CUDA is available, else the CPU. The previous
# `if True` always chose CUDA.
# NOTE(review): later cells call `.cuda()` directly, so CPU-only execution still
# needs those calls switched to `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# TensorBoard event writer for this trial (first positional argument is log_dir).
writer = SummaryWriter(CONF['tb_dir'])

In [None]:
# Seed each RNG the notebook draws from: Python's `random`, numpy (used by
# CocoClassification.sample_class) and torch.
_seed = CONF['seed']
random.seed(_seed)
torch.manual_seed(_seed)
np.random.seed(_seed)
# Let cuDNN autotune convolution algorithms. This is faster, but per the PyTorch
# reproducibility notes it can make runs non-bit-identical despite the seeds.
cudnn.benchmark = True

In [None]:
# NOTE(review): convenience alias only; the DataLoader below reads CONF['BS'] directly.
batch_size = CONF["BS"]

In [None]:
# Channel statistics that torchvision's ImageNet-pretrained models expect their
# inputs to be normalized with (the ResNet-18 below uses pretrained=True).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop/rescale to 256x256, mild colour jitter, horizontal
# flip, then tensor conversion and ImageNet normalization.
T = transforms.Compose([
    transforms.RandomResizedCrop((256, 256), scale=(0.3, 1.0), ratio=(0.75, 1.3333333333333333)),
    transforms.ColorJitter(brightness=.1, contrast=.05, saturation=.05, hue=.05),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Evaluation pipeline: deterministic resize to the same 256x256 input size, with the
# same normalization as training.
T_test = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:
from torchvision.datasets.vision import VisionDataset
class CocoClassification(VisionDataset):
    """Multi-label classification view of an MS COCO instances annotation file.

    Each sample is ``(image, targets)``. ``targets`` is a float multi-hot vector with
    one entry per element of ``self.classes``.

    By default ``self.classes`` holds COCO category ids; after construction it holds
    all of them, and :meth:`sample_class` re-selects them. When ``CONF['use_super']``
    is True it instead holds the 12 COCO supercategory names. ``self.ids`` lists the
    image ids the dataset serves.

    Args:
        root (string): Directory containing the split's image files.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    # Label set used when CONF['use_super'] is True.
    SUPER_CLASSES = ("vehicle", "outdoor", "indoor", "person", "appliance", "furniture",
                     "sports", "food", "kitchen", "accessory", "electronic", "animal")

    def _image_ids_for(self, cat_ids):
        """Return the sorted ids of images having at least one annotation in ``cat_ids``."""
        img_ids = set()
        for cat_id in cat_ids:
            img_ids.update(self.coco.getImgIds(catIds=[cat_id]))
        return sorted(img_ids)

    def sample_class(self, k):
        """Choose the label set and rebuild ``self.ids`` to match it.

        Args:
            k (int): number of COCO categories drawn at random, without replacement,
                from ``np.random``, then sorted. Ignored when ``CONF['use_super']`` is
                True; in that case the 12 supercategories are used.
        """
        all_cat_ids = self.coco.getCatIds()
        if CONF['use_super']:
            self.classes = np.array(self.SUPER_CLASSES)
            self.class_description = list(self.SUPER_CLASSES)
            # Every category carries a supercategory, so serve every annotated image.
            # self.ids must be set on this path too, otherwise len()/indexing fail.
            self.ids = self._image_ids_for(all_cat_ids)
            return
        self.classes = np.sort(np.random.choice(all_cat_ids, size=k, replace=False))
        self.class_description = self.coco.loadCats(self.classes)
        self.ids = self._image_ids_for(self.classes)

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
        super(CocoClassification, self).__init__(root, transforms, transform, target_transform)
        from pycocotools.coco import COCO
        self.coco = COCO(annFile)
        # Start with every COCO category as the label set.
        self.sample_class(len(self.coco.getCatIds()))

    def __getitem__(self, index):
        """
        Args:
            index (int): Position in ``self.ids``.

        Returns:
            tuple: ``(image, targets)``. ``targets`` is a float tensor of length
            ``len(self.classes)`` holding 1 where that class (category id, or
            supercategory name in super mode) is annotated in the image and 0
            elsewhere. Both values have been passed through ``self.transforms``
            when it is set.
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        cat_ids = [ann['category_id'] for ann in coco.loadAnns(ann_ids)]
        cats = coco.loadCats(cat_ids)
        if CONF['use_super']:
            present = {x['supercategory'] for x in cats}
        else:
            present = {x['id'] for x in cats}
        # Set lookup instead of a linear scan over a numpy array for every class.
        targets = torch.FloatTensor([1 if c in present else 0 for c in self.classes])
        path = coco.loadImgs(img_id)[0]['file_name']
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transforms is not None:
            img, targets = self.transforms(img, targets)

        return img, targets

    def __len__(self):
        return len(self.ids)


In [None]:

def _make_coco_split(split, transform):
    """Build a CocoClassification for one split (e.g. 'train2017') under CONF['data_dir']."""
    data_dir = CONF['data_dir']
    return CocoClassification(root=f"{data_dir}/{split}",
                              annFile=f"{data_dir}/annotations/instances_{split}.json",
                              transform=transform)

# Augmented training split and deterministically transformed validation split.
clas_set = _make_coco_split(CONF['dataType'], T)
val_set = _make_coco_split(CONF['valType'], T_test)

In [None]:
# Draw this run's label set from the (seeded) numpy RNG.
clas_set.sample_class(CONF['test_classes'])
# val_set chose its own label set (every COCO category) in __init__. Mirror the
# training selection so validation targets line up, class for class, with the
# model's CONF['test_classes'] outputs. Otherwise any test_classes smaller than the
# number of COCO categories yields targets of the wrong width and order. In
# supercategory mode both sets already share the same fixed label list.
if not CONF['use_super']:
    val_set.classes = clas_set.classes
    val_set.class_description = val_set.coco.loadCats(val_set.classes)
    val_set.ids = sorted({img_id
                          for cat_id in val_set.classes
                          for img_id in val_set.coco.getImgIds(catIds=[cat_id])})
train_loader = torch.utils.data.DataLoader(clas_set, batch_size=CONF['BS'], shuffle=True, num_workers=8, pin_memory=True)

In [None]:
# Validation batches in a fixed order (no shuffling), with larger batches than training.
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=CONF['test_BS'],
    shuffle=False,
    num_workers=8,
    pin_memory=False,
)

In [None]:
class Flatten(nn.Module):
    """Collapse every dimension after the first: (B, d1, d2, ...) -> (B, d1*d2*...)."""

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)

In [None]:
# ImageNet-pretrained ResNet-18 whose last child module (presumably torchvision's
# `fc` classifier) is replaced by a flatten + linear head that emits one logit per
# label (CONF['test_classes'] of them).
resnet18 = torchvision.models.resnet18(pretrained=True)
modules = list(resnet18.children())[:-1] + [
    Flatten(),
    nn.Linear(512, CONF['test_classes']),  # 512 must equal the backbone's output width
]
resnet = nn.Sequential(*modules)
resnet.cuda()

In [None]:
# SGD with momentum over every parameter (pretrained backbone and new head alike).
opt = optim.SGD(resnet.parameters(), momentum=0.9, lr=0.02)
# Multiply the LR by 0.2 once the stepped metric has not improved for 10
# consecutive steps; the training loop steps it with the mean validation loss.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, patience=10, factor=0.2)

In [None]:
c = 0  # global step for the per-batch "sup_loss/loss" scalar; +1 per training batch

In [None]:
# Multi-label objective: independent sigmoid + binary cross-entropy per class,
# averaged over all elements (reduction='mean' is the default).
criterion = nn.BCEWithLogitsLoss(reduction='mean')

In [None]:
from tqdm.notebook import tqdm, trange

In [None]:
def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):
    '''
    Compute the example-based (Jaccard) accuracy for the multi-label case, often
    called the "Hamming score".

    For each sample the score is len(T & P) / len(T | P), where T and P are the
    sets of true and predicted labels. A sample whose true and predicted label sets
    are both empty scores 1.
    https://stackoverflow.com/q/32239577/395857

    Args:
        y_true: (n_samples, n_labels) array-like; nonzero entries mark true labels.
        y_pred: (n_samples, n_labels) array-like; nonzero entries mark predicted labels.
        normalize: if True, return the (weighted) mean of the per-sample scores;
            otherwise return their (weighted) sum. This mirrors
            ``sklearn.metrics.accuracy_score``.
        sample_weight: optional (n_samples,) weights applied to the per-sample scores.

    Returns:
        float: the aggregated score.
    '''
    true_mask = np.asarray(y_true) != 0
    pred_mask = np.asarray(y_pred) != 0
    intersection = np.logical_and(true_mask, pred_mask).sum(axis=1)
    union = np.logical_or(true_mask, pred_mask).sum(axis=1)
    # Rows with an empty union are defined to score 1; np.maximum avoids dividing by 0.
    per_sample = np.where(union == 0, 1.0, intersection / np.maximum(union, 1))
    if normalize:
        return np.average(per_sample, weights=sample_weight)
    if sample_weight is not None:
        return np.dot(per_sample, sample_weight)
    return per_sample.sum()

In [None]:
def _log_multilabel_metrics(labels, preds, logits, step):
    """Log sklearn multi-label metrics for one validation pass to ``writer``.

    Args:
        labels: (num_samples, num_classes) multi-hot ground truth.
        preds: (num_samples, num_classes) boolean predictions (sigmoid(logit) > 0.5).
        logits: (num_samples, num_classes) raw model outputs. ROC AUC is computed on
            these continuous scores; thresholded 0/1 predictions would collapse the
            curve to a single operating point.
        step: TensorBoard global step.
    """
    scalars = {
        "subset_acc": metrics.accuracy_score(labels, preds, normalize=True),
        "hamming_loss": metrics.hamming_loss(labels, preds),
        "hamming_score": hamming_score(labels, preds),
        "micro_f1": metrics.f1_score(labels, preds, average='micro'),
        "macro_f1": metrics.f1_score(labels, preds, average='macro'),
        "micro_roc_auc": metrics.roc_auc_score(labels, logits, average='micro'),
        "macro_roc_auc": metrics.roc_auc_score(labels, logits, average='macro'),
        "micro_precision": metrics.precision_score(labels, preds, average='micro'),
        "macro_precision": metrics.precision_score(labels, preds, average='macro'),
        "micro_recall": metrics.recall_score(labels, preds, average='micro'),
        "macro_recall": metrics.recall_score(labels, preds, average='macro'),
        "avg_acc": np.average(np.sum((preds == labels), axis=0) / preds.shape[0]),
    }
    for name, value in scalars.items():
        writer.add_scalar("sup_metrics/" + name, value, global_step=step)


def _evaluate(step):
    """Evaluate ``resnet`` on ``val_loader`` and log validation results at ``step``.

    Returns:
        float: mean of the per-batch validation losses (drives the LR scheduler).
    """
    losses, labels, logits, preds = [], [], [], []
    count = 0
    corrects = torch.zeros(CONF['test_classes'])
    resnet.eval()
    with torch.no_grad():
        for img, lbl in tqdm(val_loader, leave=False, desc="testing"):
            labels.append(lbl.numpy())
            lbl = lbl.cuda()
            count += img.shape[0]
            out = resnet(img.cuda())
            losses.append(criterion(out, lbl).item())
            logits.append(out.cpu().numpy())
            pred = torch.sigmoid(out) > .5
            preds.append(pred.cpu().numpy())
            corrects += torch.sum(torch.eq(pred, lbl), dim=0).cpu()
    # Per-class accuracy: fraction of images whose thresholded prediction matches the label.
    acc = corrects / float(count)
    val_loss = np.average(losses)
    writer.add_scalar("sup_acc/val_avg", torch.mean(acc).item(), global_step=step)
    writer.add_histogram('baseline/acc_val', acc.numpy(), global_step=step)
    writer.add_scalar("sup_loss/val_loss", val_loss, global_step=step)
    writer.flush()
    _log_multilabel_metrics(np.concatenate(labels, axis=0),
                            np.concatenate(preds, axis=0),
                            np.concatenate(logits, axis=0),
                            step)
    return val_loss


def _train_one_epoch(step):
    """Train ``resnet`` for one pass over ``train_loader``.

    Logs the per-batch loss (at the global batch counter ``c``) and the epoch's
    per-class training accuracy (at ``step``).
    """
    global c
    count = 0
    corrects = torch.zeros(CONF['test_classes'])
    resnet.train()
    for img, lbl in tqdm(train_loader, leave=False, desc="training"):
        count += img.shape[0]
        pred = resnet(img.cuda())
        loss = criterion(pred, lbl.cuda())
        with torch.no_grad():
            corrects += torch.sum(torch.eq((torch.sigmoid(pred) > .5).cpu(), lbl), dim=0)
        opt.zero_grad()
        loss.backward()
        opt.step()
        writer.add_scalar("sup_loss/loss", loss.item(), global_step=c)
        c += 1
    acc = corrects / float(count)
    writer.add_scalar("sup_acc/train_avg", torch.mean(acc).item(), global_step=step)
    writer.add_histogram('baseline/acc_train', acc.numpy(), global_step=step)


for it in trange(CONF['niter']):
    # Validate before training, so step `it` reflects the model after `it` epochs;
    # the mean validation loss drives ReduceLROnPlateau.
    scheduler.step(_evaluate(it))
    _train_one_epoch(it)

# The loop validates before each epoch, so the weights produced by the last epoch
# still need one evaluation (logged at step CONF['niter']).
_evaluate(CONF['niter'])

writer.close()