# Bird Classification using a Weakly-Supervised Data Augmentation Network.

### Let's import some stuff to make this work

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.datasets as datasets
from glob import glob
import cv2
import datetime as dt
import h5py
# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
import pandas as pd
import os
# Fixed seed so torch-driven randomness (e.g. the dataset split) is repeatable.
torch.manual_seed(0)
# One class name per line, kept as a single-column DataFrame (column 0).
# The file is read line by line because current pandas releases reject
# sep="\n" in read_csv; blank lines are dropped, as read_csv did by default.
with open('../input/birds21wi/birds/names.txt', encoding='utf-8') as names_file:
    names = pd.DataFrame([line.rstrip('\r\n') for line in names_file if line.strip()])

cuda:0


## Preprocessing:

For preprocessing, we take the images in both the train and test folders and pack them into compressed HDF5 files. To do this, we iterate through the images and resize each one to 128x128 using OpenCV. Although our network uses 299x299 inputs, storing 128x128 images at this stage proved efficient for loading the whole dataset into memory. Moreover, the use of HDF5 files improves the performance of our data loaders, since they read from memory instead of opening individual image files on disk.

In [None]:
def proc_images(isTest, to_rgb=False):
    """
    Resize every image of one split to 128x128 and save the split as one HDF5 file.

    Parameters
        isTest: truthy  -> read the test folder, write 'birds21hdf5_test.h5'
                falsy   -> read the train folder, write 'birds21hdf5.h5'
        to_rgb: when True, convert each image from OpenCV's BGR channel order
                to RGB before storing it. The default (False) keeps the BGR
                order that cv2.imread returns, so files stay compatible with
                ones produced earlier.
                NOTE(review): the loaders feed these arrays to ToPILImage and
                ImageNet RGB statistics, so RGB is probably what is wanted;
                confirm before regenerating the data.
    Output file layout
        images: (N, 128, 128, 3) uint8, gzip level 9
        labels: (1, N) class ids looked up in labels.csv by file name (train only)
        path:   (1, N) file names as byte strings (test only)
    Raises
        IOError  if an image cannot be decoded
        KeyError if a training image has no row in labels.csv
    """
    start = dt.datetime.now()
    split = 'test' if isTest else 'train'
    images = glob('../input/birds21wi/birds/' + split + '/**/*.jpg')
    images.extend(glob('../input/birds21wi/birds/' + split + '/**/*.JPG'))

    # Size of data
    NUM_IMAGES = len(images)
    HEIGHT = 128
    WIDTH = 128
    CHANNELS = 3

    # Build the file-name -> class map once instead of scanning the whole
    # label table for every image. drop_duplicates keeps the first row per
    # path, which is the row the original per-image lookup selected.
    label_by_name = {}
    if not isTest:
        labels = pd.read_csv('../input/birds21wi/birds/labels.csv')
        first_rows = labels.drop_duplicates(subset="path")
        label_by_name = dict(zip(first_rows["path"].to_numpy(), first_rows["class"].to_numpy()))

    # Preallocate the image array so the resized images are stored only once.
    x_data = np.empty((NUM_IMAGES, HEIGHT, WIDTH, CHANNELS), dtype=np.uint8)
    y_data = []
    paths = []
    for i, img in enumerate(images):
        # Images
        image = cv2.imread(img)
        if image is None:
            # cv2.imread signals an unreadable file by returning None.
            raise IOError("Could not read image: {}".format(img))
        image = cv2.resize(image, (WIDTH, HEIGHT), interpolation=cv2.INTER_CUBIC)
        if to_rgb:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        x_data[i] = image

        # Labels (train) or file names (test)
        base = os.path.basename(img)
        if isTest:
            paths.append(base)
        else:
            if base not in label_by_name:
                raise KeyError("No label found for image: {}".format(base))
            y_data.append(label_by_name[base])

        end = dt.datetime.now()
        print("\r", i, ": ", (end - start).seconds, "seconds", end="")

    # The output file is opened only after every image was processed, so a
    # failure above does not leave a truncated file behind.
    fname = 'birds21hdf5_test.h5' if isTest else 'birds21hdf5.h5'
    with h5py.File(fname, 'w') as hf:
        hf.create_dataset(
            name='images',
            data=x_data,
            shape=(NUM_IMAGES, HEIGHT, WIDTH, CHANNELS),
            maxshape=(NUM_IMAGES, HEIGHT, WIDTH, CHANNELS),
            compression="gzip",
            compression_opts=9)
        if isTest:
            hf.create_dataset(
                name='path',
                data=np.array(paths, dtype='S'),
                shape=(1, NUM_IMAGES),
                maxshape=(1, NUM_IMAGES)
            )
        else:
            hf.create_dataset(
                name='labels',
                data=y_data,
                shape=(1, NUM_IMAGES),
                maxshape=(1, NUM_IMAGES),
                compression="gzip",
                compression_opts=9)
#proc_images(True)

In [2]:
# Data loader
class H5Dataset(torch.utils.data.Dataset):
    """
    In-memory dataset backed by an HDF5 file with 'images' and 'labels' datasets.

    Parameters
        h5_file: path of the HDF5 file to read
        transform: optional callable applied to each image in __getitem__
    """

    def __init__(self, h5_file, transform=None):
        self.transform = transform
        self.h5_path = h5_file
        # Everything is copied into memory, so the file is closed right away
        # instead of keeping an open handle on the dataset object.
        with h5py.File(h5_file, 'r') as source:
            self.images = source['images'][:]
            raw_labels = source['labels'][:]
        # Labels are read as a 2-D array and flattened to one label per image.
        self.labels = torch.LongTensor(raw_labels).transpose(0, 1)
        self.labels = torch.flatten(self.labels)

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        """Return (image, label) for index idx, with the transform applied if set."""
        data = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            data = self.transform(data)
        return (data, label)
#dataset = H5Dataset('data_2.h5')

In [3]:
# H5Dataset for test dataset since we need information about the paths of the images when we make our final predictions
class H5DatasetTest(torch.utils.data.Dataset):
    """
    In-memory dataset for unlabeled images, backed by an HDF5 file with
    'images' and 'path' datasets.

    __getitem__ returns the sample index in place of a label; target_to_path
    turns that index back into the stored file name.

    Parameters
        h5_file: path of the HDF5 file to read
        transform: optional callable applied to each image in __getitem__
    """

    def __init__(self, h5_file, transform=None):
        self.transform = transform
        self.h5_path = h5_file
        # Everything is copied into memory, so the file is closed right away
        # instead of keeping an open handle on the dataset object.
        with h5py.File(h5_file, 'r') as source:
            self.images = source['images'][:]
            self.paths = np.array(source['path'][:]).transpose()

    def __len__(self):
        return self.paths.shape[0]

    def __getitem__(self, idx):
        """Return (image, index tensor) for index idx."""
        data = self.images[idx]

        if self.transform:
            data = self.transform(data)
        return (data, torch.tensor(idx))

    def target_to_path(self, target):
        """Return the file name stored for a one-element index tensor."""
        return str(self.paths[target.item()].item().decode('utf-8'))


To train our network we need feedback on its accuracy over data it has never seen. Therefore, we partition the images in the training folder into a roughly 80%/20% split between training data and held-out "test" (validation) data. (For our experiments, the actual test data is used only for the final predictions.)

In [4]:
class _TransformedSubset(torch.utils.data.Dataset):
    """Apply a transform to the samples of another dataset (e.g. a random_split subset)."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        data, label = self.subset[idx]
        if self.transform:
            data = self.transform(data)
        return (data, label)


def get_birds21wi_data(h5_path='../input/birds21hdf5/birds21hdf5.h5', val_fraction=0.2,
                       batch_size=20, num_workers=8):
    """
    Build the training and validation loaders from the training HDF5 file.

    Parameters
        h5_path: HDF5 file read through H5Dataset
        val_fraction: share of the images held out for validation (rounded up)
        batch_size, num_workers: passed to both DataLoaders
    Returns
        {'train': DataLoader, 'test': DataLoader}; 'test' is the held-out
        validation split, not the competition test set.
    """
    resize = (299, 299)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(resize),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = H5Dataset(h5_path)
    # Derive the split sizes from the data; for the 38562-image file this gives
    # 30849 / 7713. round() guards against float noise pushing ceil one too high.
    num_val = int(np.ceil(round(len(dataset) * val_fraction, 6)))
    num_train = len(dataset) - num_val
    train_set, val_set = torch.utils.data.random_split(dataset, [num_train, num_val])

    # Both subsets point at the same underlying dataset, so assigning
    # `subset.dataset.transform` twice would leave only the last transform in
    # effect for both splits. Each split gets its own transform wrapper instead.
    train_set = _TransformedSubset(train_set, transform_train)
    val_set = _TransformedSubset(val_set, transform_test)

    trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                              num_workers=num_workers)
    # Accuracy does not depend on sample order, so validation is not shuffled.
    testloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False,
                                             num_workers=num_workers)
    return {'train': trainloader, 'test': testloader}
data = get_birds21wi_data()

In [None]:
def get_actual_label(output_label, class_to_idx):
    """
    Reverse lookup in a class -> index mapping.

    Returns the first key (converted to int) whose value equals output_label.
    Raises ValueError when no value matches.
    """
    pairs = list(class_to_idx.items())
    position = [index for _, index in pairs].index(output_label)
    matching_key = pairs[position][0]
    return int(matching_key)
# Bare expression: as the last line of a notebook cell it echoes the training DataLoader.
data['train']

In [None]:
# Pull one batch from the training loader for a quick visual sanity check.
dataiter = iter(data['train'])
# The builtin next() works on every PyTorch version; the iterator's own
# `.next()` method is not available in current releases.
images, labels = next(dataiter)
images = images[:8]
print(labels[0])
print(images.size())

def imshow(img, denormalize=True):
    """
    Show one (3, H, W) image tensor with matplotlib.

    With denormalize=True (default) the Normalize step of the data pipeline is
    undone first and values are clamped to [0, 1], so the picture is not
    distorted by matplotlib clipping out-of-range values.
    """
    if denormalize:
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        img = (img.cpu() * std + mean).clamp(0, 1)
    npimg = img.numpy()
    # matplotlib expects channels last: (C, H, W) -> (H, W, C)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# show images
imshow(images[0])
# print labels; `.iloc[row, 0]` picks the name string itself rather than
# formatting a whole one-element Series.
print("Labels:" + ' '.join('%9s' % names.iloc[labels[j].item(), 0] for j in range(1)))

flat = torch.flatten(images, 1)
print(images.size())
print(flat.size())

## Network

Here we use a Weakly-Supervised Data Augmentation Network (WS-DAN), where data augmentation is driven by attention maps rather than by randomness alone. In our WS-DAN network, a convolutional neural network outputs feature maps, and attention maps are derived from these features (in the code below they are taken directly from the first M channels of the feature maps; a separate attention convolution is defined but currently bypassed). The feature matrix we use for our predictions is then produced by a Bilinear Attention Pooling (BAP) layer. Our features come from an InceptionV3 network that was previously trained on the iNaturalist dataset.

In [6]:
# Small constant added inside the square roots below (BAP sign-sqrt and the
# WSDAN attention weights) so their argument is never exactly zero.
EPSILON = 1e-12


class BasicConv2d(nn.Module):
    """Bias-free convolution followed by batch norm (eps=0.001) and an in-place ReLU."""

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # Extra keyword arguments (kernel_size, stride, ...) go straight to nn.Conv2d.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)

# Bilinear Attention Pooling
class BAP(nn.Module):
    """
    Bilinear Attention Pooling.

    Combines (B, C, H, W) feature maps with (B, M, AH, AW) attention maps into
    one (B, M * C) feature matrix, then applies a signed square root and L2
    normalization. pool='GAP' averages over the spatial positions; pool='GMP'
    takes the spatial maximum instead.
    """

    def __init__(self, pool='GAP'):
        super().__init__()
        assert pool in ['GAP', 'GMP']
        # None marks average pooling, which forward() computes with an einsum.
        self.pool = nn.AdaptiveMaxPool2d(1) if pool == 'GMP' else None

    def forward(self, features, attentions):
        batch, _, height, width = features.size()
        _, num_maps, att_height, att_width = attentions.size()

        # Bring the attention maps to the spatial size of the features.
        if (att_height, att_width) != (height, width):
            attentions = F.interpolate(attentions, size=(height, width), mode='bilinear', align_corners=True)

        if self.pool is None:
            # (B, M, C): attention-weighted spatial mean of every feature channel.
            pooled = torch.einsum('imjk,injk->imn', (attentions, features)) / float(height * width)
            pooled = pooled.view(batch, -1)
        else:
            per_map = [
                self.pool(features * attentions[:, k:k + 1, ...]).view(batch, -1)
                for k in range(num_maps)
            ]
            pooled = torch.cat(per_map, dim=1)

        # sign-sqrt
        signed_root = torch.sign(pooled) * torch.sqrt(torch.abs(pooled) + EPSILON)

        # l2 normalization over the flattened (M * C) dimension
        return F.normalize(signed_root, dim=-1)

# WS-DAN Implementation: Weakly Supervised Data Augmentation Network for Fine-Grained
# Visual Classification.
class WSDAN(nn.Module):
    """
    WS-DAN classifier on top of an InceptionV3 feature extractor.

    Parameters
        M: number of attention maps
        num_classes: size of the classification layer (555 by default)
        weights_path: InceptionV3 checkpoint whose 'state_dict' entry is loaded
            into the backbone; pass None to skip loading.
    """

    def __init__(self, M=32, num_classes=555,
                 weights_path='../input/v3-inat2018/iNat_2018_InceptionV3.pth.tar'):
        super(WSDAN, self).__init__()
        self.num_classes = num_classes
        self.M = M

        # Build InceptionV3 without ImageNet weights. The classifier is resized
        # to 8142 outputs before loading, presumably so that the checkpoint's
        # state dict matches the model.
        model = models.inception_v3(pretrained=False)
        model.fc = nn.Linear(2048, 8142)
        model.aux_logits = False
        if weights_path is not None:
            model.load_state_dict(torch.load(weights_path, map_location='cpu')['state_dict'])
        model.eval()

        # Convolutional trunk up to Mixed_7c. The two pooling layers are created
        # here; they carry no weights.
        self.inception = nn.Sequential(
            model.Conv2d_1a_3x3,
            model.Conv2d_2a_3x3,
            model.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
            model.Conv2d_3b_1x1,
            model.Conv2d_4a_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
            model.Mixed_5b,
            model.Mixed_5c,
            model.Mixed_5d,
            model.Mixed_6a,
            model.Mixed_6b,
            model.Mixed_6c,
            model.Mixed_6d,
            model.Mixed_6e,
            model.Mixed_7a,
            model.Mixed_7b,
            model.Mixed_7c,
        )

        self.num_features = 2048

        # Attention head. forward() currently bypasses it and uses the first M
        # feature channels instead; it is kept so saved state dicts still load.
        self.attentions = BasicConv2d(self.num_features, self.M, kernel_size=1)

        # Bilinear Attention Pooling
        self.bap = BAP(pool='GAP')

        # Classification Layer
        self.fc = nn.Linear(self.M * self.num_features, self.num_classes, bias=False)

    def forward(self, x, training):
        """
        Parameters
            x: image batch
            training: True  -> also sample two attention maps per image
                      False -> return the mean attention map
        Returns
            p: (B, num_classes) class scores
            feature_matrix: (B, M * C) pooled features
            attention_map: (B, 2, H, W) when training, else (B, 1, H, W)
        """
        batch_size = x.size(0)
        feature_maps = self.inception(x)

        # The first M feature channels serve as the attention maps.
        attention_maps = feature_maps[:, :self.M, ...]
        feature_matrix = self.bap(feature_maps, attention_maps)

        # Classification (the normalized features are scaled by 100 first)
        p = self.fc(feature_matrix * 100.)

        # Generate Attention Map
        if training:
            # Sample two maps per image with probability proportional to the
            # square root of each map's total activation.
            attention_map = []
            for i in range(batch_size):
                attention_weights = torch.sqrt(attention_maps[i].sum(dim=(1, 2)).detach() + EPSILON)
                attention_weights = F.normalize(attention_weights, p=1, dim=0)
                k_index = np.random.choice(self.M, 2, p=attention_weights.cpu().numpy())
                attention_map.append(attention_maps[i, k_index, ...])
            attention_map = torch.stack(attention_map)  # (B, 2, H, W) - one for cropping, the other for dropping
        else:
            # Object Localization Am = mean(Ak)
            attention_map = torch.mean(attention_maps, dim=1, keepdim=True)  # (B, 1, H, W)

        return p, feature_matrix, attention_map
#net=WSDAN()        

In [7]:
##################################
# augment function
##################################
def _attention_threshold(atten_map, theta):
    """Return theta (or a uniform draw from a (low, high) tuple) times the map's maximum."""
    factor = random.uniform(*theta) if isinstance(theta, tuple) else theta
    return factor * atten_map.max()


def batch_augment(images, attention_map, mode='crop', theta=0.5, padding_ratio=0.1):
    """
    Attention-guided augmentation of an image batch.

    Parameters
        images: (B, C, H, W) batch
        attention_map: (B, 1, h, w) attention maps, resized to (H, W) internally
        mode: 'crop' - cut out the padded bounding box of the region whose
                       attention reaches the threshold and resize it to (H, W)
              'drop' - zero every pixel whose attention reaches the threshold
        theta: threshold as a fraction of each map's maximum; a (low, high)
               tuple draws the fraction uniformly per image
        padding_ratio: margin added around the crop box, as a fraction of H / W
    Returns
        (B, C, H, W) augmented batch
    Raises
        ValueError for an unknown mode
    """
    if mode not in ('crop', 'drop'):
        # Fail loudly instead of silently returning None.
        raise ValueError("mode must be 'crop' or 'drop', got {!r}".format(mode))

    batches, _, imgH, imgW = images.size()

    if mode == 'crop':
        crop_images = []
        for batch_index in range(batches):
            atten_map = attention_map[batch_index:batch_index + 1]
            theta_c = _attention_threshold(atten_map, theta)

            crop_mask = F.interpolate(atten_map, size=(imgH, imgW), mode='bilinear', align_corners=True) >= theta_c
            nonzero_indices = torch.nonzero(crop_mask[0, 0, ...])
            if nonzero_indices.numel() == 0:
                # Nothing reaches the threshold (e.g. NaNs in the map): keep the whole image.
                height_min, height_max, width_min, width_max = 0, imgH, 0, imgW
            else:
                rows, cols = nonzero_indices[:, 0], nonzero_indices[:, 1]
                height_min = max(int(rows.min().item() - padding_ratio * imgH), 0)
                width_min = max(int(cols.min().item() - padding_ratio * imgW), 0)
                # Slice ends are exclusive, so +1 keeps the last selected
                # row/column inside the crop and the crop is never empty.
                height_max = min(int(rows.max().item() + padding_ratio * imgH) + 1, imgH)
                width_max = min(int(cols.max().item() + padding_ratio * imgW) + 1, imgW)

            crop_images.append(
                F.interpolate(images[batch_index:batch_index + 1, :, height_min:height_max, width_min:width_max],
                              size=(imgH, imgW), mode='bilinear', align_corners=True))
        return torch.cat(crop_images, dim=0)

    # mode == 'drop'
    drop_masks = []
    for batch_index in range(batches):
        atten_map = attention_map[batch_index:batch_index + 1]
        theta_d = _attention_threshold(atten_map, theta)
        # True where the pixel is kept (attention below the threshold).
        drop_masks.append(F.interpolate(atten_map, size=(imgH, imgW), mode='bilinear', align_corners=True) < theta_d)
    drop_masks = torch.cat(drop_masks, dim=0)
    return images * drop_masks.float()
    

##############################################
# Center Loss for Attention Regularization
##############################################
class CenterLoss(nn.Module):
    """Sum of squared differences between outputs and targets, divided by the batch size."""

    def __init__(self):
        super().__init__()
        self.l2_loss = nn.MSELoss(reduction='sum')

    def forward(self, outputs, targets):
        batch_size = outputs.size(0)
        squared_error = self.l2_loss(outputs, targets)
        return squared_error / batch_size

def generate_heatmap(attention_maps):
    """
    Turn the first channel of (B, K, H, W) attention maps into a (B, 3, H, W)
    false-color image: R = a, G = a where a < 0.5 and 1 - a elsewhere, B = 1 - a.
    """
    intensity = attention_maps[:, 0, ...]
    below_half = (intensity < 0.5).float()
    at_least_half = (intensity >= 0.5).float()

    red = intensity
    green = intensity * below_half + (1. - intensity) * at_least_half
    blue = 1. - intensity
    return torch.stack([red, green, blue], dim=1)

# Transform instance (not the class, despite the name) that converts a tensor or array to a PIL image.
ToPILImage = transforms.ToPILImage()
# Per-channel statistics with the same values as the Normalize transforms used
# for the data loaders, shaped (1, 3, 1, 1) - presumably to broadcast against
# (B, 3, H, W) batches when undoing the normalization.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

### Training

The training function is similar to past versions with some updates. In particular, there is now a `schedule` parameter to handle learning-rate scheduling and a `checkpoint_path` parameter that, when provided, turns on periodic saving of training checkpoints (written as `checkpoint-<epoch>.pkl` in the working directory).

The core of training is the same though, get a batch, run the model forward, calculate loss, run it backward, update.

In [14]:
def train(net, dataloader, optimizer=None, epochs=1, start_epoch=0, lr=0.01, momentum=0.9, decay=0.0005,
          verbose=1, print_every=10, state=None, scheduler=None, schedule=None, checkpoint_path=None, beta=5e-2,
          val_loader=None):
    """
    Train a WS-DAN style network with attention cropping and dropping.

    Every batch is passed through the network three times (raw, attention-cropped
    and attention-dropped images). The loss is the mean of the three
    cross-entropies plus a center loss that pulls each feature matrix towards
    its class center.

    Parameters
        net: model whose forward(x, training) returns
             (logits, feature_matrix, attention_map)
        dataloader: training batches of (images, labels)
        optimizer: optimizer to use; when None an SGD optimizer is built from
                   lr, momentum and decay
        epochs: epoch index at which training stops (exclusive)
        start_epoch: first epoch to run; replaced by state['epoch'] when state is given
        verbose, print_every: print the running mean loss every print_every batches
        state: checkpoint dict produced by this function, used to resume
        scheduler: optional object whose step() is called after every batch
        schedule: {epoch: learning rate} applied at the start of those epochs
        checkpoint_path: any truthy value enables checkpoints after epoch 1 and
                   after every 5th epoch.
                   NOTE(review): the value itself is not used as a location;
                   files are written as 'checkpoint-<epoch>.pkl' in the working
                   directory, which is where the resume cell loads them from.
        beta: update rate of the class feature centers
        val_loader: loader evaluated after every epoch; defaults to data['test']
    Returns
        (losses, test_accuracies): per-batch losses and per-epoch accuracies
    """
    schedule = {} if schedule is None else schedule
    net = net.to(device)
    net.train()
    losses = []
    criterion = nn.CrossEntropyLoss()
    center_loss = CenterLoss()
    test_accuracies = []
    # Honor an optimizer supplied by the caller instead of always replacing it
    # (a scheduler passed alongside would otherwise step a different optimizer).
    if optimizer is None:
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=decay)
    if val_loader is None:
        val_loader = data['test']

    # feature_center: size of (#classes, #attention_maps * #channel_features),
    # taken from the network when it exposes them.
    num_classes = getattr(net, 'num_classes', 555)
    num_attentions = getattr(net, 'M', 32)
    feature_center = torch.zeros(num_classes, num_attentions * net.num_features).to(device)

    # Load previous training state
    if state:
        net.load_state_dict(state['net'])
        optimizer.load_state_dict(state['optimizer'])
        start_epoch = state['epoch']
        losses = state['losses']
        feature_center = state['feature_center'].to(device)

    # Fast forward lr schedule through already trained epochs
    for epoch in range(start_epoch):
        if epoch in schedule:
            print ("Learning rate: %f"% schedule[epoch])
            for g in optimizer.param_groups:
                g['lr'] = schedule[epoch]

    for epoch in range(start_epoch, epochs):
        sum_loss = 0.0

        # Update learning rate when scheduled
        if epoch in schedule:
            print("entering schedule change\n")
            print ("Learning rate: %f"% schedule[epoch])
            for g in optimizer.param_groups:
                g['lr'] = schedule[epoch]

        for i, batch in enumerate(dataloader, 0):
            inputs, labels = batch[0].to(device), batch[1].to(device)

            optimizer.zero_grad()

            # Raw Image. i.e., no cropping or dropping, the original input image
            output_raw, feature_matrix, attention_map = net(inputs, True)

            # Update Feature Center. The normalized centers are captured before
            # the in-place update, so the center loss below uses the old centers.
            feature_center_batch = F.normalize(feature_center[labels], dim=-1)
            feature_center[labels] += beta * (feature_matrix.detach() - feature_center_batch)

            ##################################
            # Attention Cropping (first sampled attention map)
            ##################################
            with torch.no_grad():
                crop_images = batch_augment(inputs, attention_map[:, :1, :, :], mode='crop', theta=(0.4, 0.6), padding_ratio=0.1)

            output_crop, _, _ = net(crop_images, True)

            ##################################
            # Attention Dropping (second sampled attention map)
            ##################################
            with torch.no_grad():
                drop_images = batch_augment(inputs, attention_map[:, 1:, :, :], mode='drop', theta=(0.2, 0.5))

            output_drop, _, _ = net(drop_images, True)

            loss = criterion(output_raw, labels) / 3. + \
             criterion(output_crop, labels) / 3. + \
             criterion(output_drop, labels) / 3. + \
             center_loss(feature_matrix, feature_center_batch)

            loss.backward()  # autograd magic, computes all the partial derivatives
            optimizer.step() # takes a step in gradient direction

            losses.append(loss.item())
            sum_loss += loss.item()
            if scheduler:
                scheduler.step()
            if i % print_every == print_every-1:    # print every print_every mini-batches
                if verbose:
                    print('[%d, %5d] loss: %.3f' % (epoch, i + 1, sum_loss / print_every))
                sum_loss = 0.0
        test_accuracies.append(accuracy(net, val_loader))
        print(test_accuracies[-1])
        # accuracy() switches the network to eval mode; switch back.
        net.train()
        if checkpoint_path and (((epoch + 1) % 5 == 0) or (epoch == 0)):
            print("reached checkpoint")
            checkpoint = {'epoch': epoch+1, 'net': net.state_dict(), 'optimizer': optimizer.state_dict(), 'losses': losses, 'feature_center':feature_center}
            torch.save(checkpoint, 'checkpoint-%d.pkl'%(epoch+1))
        print("Finished epoch")
    return losses, test_accuracies

def accuracy(net, dataloader):
    """
    Return the fraction of correctly classified samples in dataloader.

    Each batch is scored twice: on the raw images and on a crop around the
    mean attention map. The two sets of logits are averaged before the
    prediction is taken. The network is left in eval mode.
    """
    net = net.to(device)
    net.eval()
    num_correct, num_seen = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            images = batch[0].to(device)
            labels = batch[1].to(device)

            raw_logits, _, attention_maps = net(images, False)
            refined = batch_augment(images, attention_maps, mode='crop', theta=0.1, padding_ratio=0.05)
            crop_logits, _, _ = net(refined, False)

            averaged = (raw_logits + crop_logits) / 2.
            predicted = torch.max(averaged.data, 1)[1]
            num_seen += labels.size(0)
            num_correct += (predicted == labels).sum().item()
    return num_correct / num_seen

def smooth(x, size):
    """Moving average of x over a window of `size` samples (only fully overlapping windows)."""
    window = np.ones(size) / size
    averaged = np.convolve(x, window, mode='valid')
    return averaged

Let's try finetuning the entire model using the pretrained weights. Then we will try it as a feature extractor.

In [None]:

# Resume fine-tuning from the checkpoint saved after epoch 10 and continue up
# to (not including) epoch 40. train() takes its start epoch from the loaded
# state, and it only tests checkpoint_path for truthiness: checkpoints are
# written as 'checkpoint-<epoch>.pkl' in the working directory.
model = WSDAN()
pretrain_losses, accuracies = train(model, data['train'],start_epoch=10, epochs=40, 
                                    state=torch.load('checkpoint-10.pkl'),
                                    schedule={0:0.001, 5:0.0001, 10:0.00001, 20:0.000001},
                                    print_every=100, checkpoint_path='lol')
# data['test'] is the held-out part of the training images (validation split).
print("Testing accuracy: %f" % accuracy(model, data['test']))
# Plot the per-batch losses as a moving average over 50 batches.
plt.plot(smooth(pretrain_losses,50))

In [None]:
# Preprocessing for the unlabeled test images: resize and normalize only, no random flip.
transform_act_test = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
])
# The test dataset yields (image, sample index); the index is mapped back to a file name later.
test_dataset = H5DatasetTest('../input/birds21-train/birds21hdf5_test.h5', transform=transform_act_test)
# One image per batch, in stored order.
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=8)
def predict(net, dataloader, ofname):
    """
    Write class predictions for every image of dataloader to a CSV file.

    Parameters
        net: model whose forward(x, training) returns (logits, _, attention_maps)
        dataloader: yields (images, sample indices); its dataset must provide
                    target_to_path(index) to recover the file name
        ofname: output CSV path; rows are 'test/<file name>,<class id>' below a
                'path,class' header

    Predictions average the logits of the raw image and of a crop around the
    mean attention map, as in accuracy().
    """
    # Resolve names through the loader's own dataset rather than a global, so
    # the function works for whichever loader is passed in.
    dataset = dataloader.dataset
    net.to(device)
    net.eval()
    # The context manager closes the file even if inference fails midway.
    with open(ofname, 'w') as out, torch.no_grad():
        out.write("path,class\n")
        for i, (images, labels) in enumerate(dataloader, 0):
            if i%100 == 0:
                print(i)
            images, labels = images.to(device), labels.to(device)
            output_raw, _, attention_maps = net(images, False)

            crop_images = batch_augment(images, attention_maps, mode='crop', theta=0.1, padding_ratio=0.05)
            output_crop, _, _ = net(crop_images, False)

            outputs = (output_raw + output_crop) / 2.
            _, predicted = torch.max(outputs.data, 1)
            # One row per sample, so any batch size works (not only 1).
            for sample_index, predicted_class in zip(labels, predicted):
                path = dataset.target_to_path(sample_index)
                out.write("test/{},{}\n".format(path.split('/')[-1], predicted_class.item()))
# Write the predictions for the test images to preds.csv.
predict(model, testloader, "preds.csv")