In [1]:
from torch import nn, optim
import torch.nn.functional as F
from torchvision import models, transforms
from torch.autograd import Variable
import torch
import requests
import json
import numpy as np
from os import path
import os
import re
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision import transforms, utils
from tqdm import tqdm_notebook as tqdm

PATH = '/shared/HARRISON'
MODEL_PATH = 'vgg_hashnet_r3_v6.pth'
pretrained = True

In [2]:
class VGG_HASHNET(nn.Module):

    def __init__(self):
        super(VGG_HASHNET, self).__init__()
        self.object_extractor = models.vgg16(pretrained=True)
        self.object_extractor.classifier = nn.Sequential(*list(self.object_extractor.classifier.children())[:-2])
        self.object_extractor.eval()

        self.background_extractor = models.__dict__['alexnet'](num_classes=365)
        checkpoint = torch.load('/shared/alexnet_places365.pth.tar', map_location=lambda storage, loc: storage)
        state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
        self.background_extractor.load_state_dict(state_dict)
        self.background_extractor.classifier = nn.Sequential(*list(self.background_extractor.classifier.children())[:-1])
        self.background_extractor.eval()

        self.fc1 = nn.Linear(4096 * 2, 4096)
#         self.fc1 = nn.DataParallel(self.fc1)
        self.fc2 = nn.Linear(4096,4096)
#         self.fc2 = nn.DataParallel(self.fc2)

        self.output = nn.Linear(4096,997)
        if pretrained:
            loaded = torch.load(MODEL_PATH)
            self.load_state_dict({ str.replace(k,'module.', ''): v for k,v in loaded.items() })
#         print(self.object_extractor)
#         print("============================")
#         print(self.background_extractor)
    def forward(self, x):
        obj_feat = self.object_extractor(x)
#         print(obj_feat.size())
        scene_feat = self.background_extractor(x)
#         print(scene_feat.size())
        feats = torch.cat((obj_feat, scene_feat), dim=1)
#         print(feats.size())
        output = self.fc1(feats)
        output = self.fc2(output)
        output = self.output(output)
#         print('Layer:', output)
        return torch.sigmoid(output)
model = VGG_HASHNET()
model = model.cuda()
model = nn.DataParallel(model)
# model.load_state_dict(weight_dict)
# torch.save(model.state_dict(), "./TEHBESTMODEL.pth")


In [3]:
class HARRISON_DATASET(Dataset):
    def __init__(self):
        # reading either data or image labels, filepaths
        #read data_list
        with open(path.join(PATH, 'data_list.txt')) as f:
            self.images = [l.strip() for l in f.readlines()]
#         print('SELF IMAGES: {}'.format(self.images[0]))

        # read hashtags
        with open(path.join(PATH, 'tag_list.txt')) as f:
            self.tags = [l.strip() for l in f.readlines()]

        self.tags = [x.split(' ') for x in self.tags]
#         print(self.tags)
        self.word_map = {}
        curr_size = 0

        for labels in self.tags:
            for tag in labels:
                if tag not in self.word_map:
                    self.word_map[tag] = curr_size
                    curr_size += 1

        self.num_labels = len(self.word_map.keys())
        self.tags = [self._convert_onehot(x) for x in self.tags]

        self.normalize = transforms.Normalize(
           mean=[0.485, 0.456, 0.406],
           std=[0.229, 0.224, 0.225]
        )
        self.preprocess = transforms.Compose([
           transforms.Resize(224),
           transforms.CenterCrop(224),
           transforms.ToTensor(),
           self.normalize
        ])

    def __getitem__(self, index):
#         print(self.word_map)
        #read_file from disk
        img_path = os.path.join(PATH, self.images[index])
        img = Image.open(img_path)
        img_tensor = self.preprocess(img).cuda()
        return(img_tensor, self.tags[index])

    def __len__(self):
        return len(self.images)

    def _convert_onehot(self, tags):
        one_hot = np.zeros(self.num_labels, dtype=np.float32)
        for tag in tags:
            one_hot[self.word_map[tag]] = 1
        return torch.from_numpy(one_hot).cuda()


In [4]:
from torch.utils.data.sampler import SubsetRandomSampler
dset = HARRISON_DATASET()
batch_size = 72
validation_split = .2
shuffle_dataset = True

dataset_size = len(dset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
random_seed = 42
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dset, batch_size=batch_size,
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dset,
                                                batch_size=batch_size,
                                                sampler=valid_sampler)

In [8]:
criterion = nn.BCELoss()
# optimizer = optim.SGD(model.parameters(), lr=0.008, momentum=1.1)
# optimizer = optim.SGD(model.parameters(), lr=0.004, momentum=1)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.7)
# optimizer = optim.Adam(model.parameters(), lr=0.001)
EPOCHS = 50

def get_intersection(a, b):
    a = a.cpu()
    b = b.cpu()
#     print(a,b)

    values_output, idx_output = torch.topk(a, 5)
    values_label, idx_labels = torch.topk(b, 5)

    precision = []
    recall = []
    accuracy = []
    
    for i in range(len(idx_output)):
    #     print(idx_output[i], idx_labels[i])
        intersection = np.intersect1d(idx_output[i].numpy(), idx_labels[i].numpy())
        accuracy.append(min(len(intersection), 1))

#         print(idx_output[i], idx_labels[i])
        intersection = [x for x in intersection if b[i].numpy()[x] > 0]
        unique, counts = np.unique(intersection, return_counts=True)
        int_count = len(counts[counts > 0])
        precision.append(int_count / 5)
        
        
        np_values_label = values_label[i].numpy()
        recall.append(int_count / len(np_values_label[np_values_label > 0]))
#         input()
        if(recall[-1] > 1):
            print(idx_output[i], idx_labels[i])
            print(int_count)
            print(intersection)
            print(len(values_label[i].numpy()[np.where(values_label[i].numpy() > 0)]))
            print(values_output[i], values_label[i])
            print(precision[-1])
            input()
    return (precision, recall, accuracy)


In [None]:
# Train model
BATCHES = 100
previous_lr = 0.0
break_out = False

for ep in tqdm(range(EPOCHS)):
    if break_out:
        print('Breaking out!')
        break

    running_loss = 0.0
    running_precision = 0.0
    running_recall = 0.0
    running_accuracy = 0.0
    total_predictions = 0

    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
#         print('Labels: {}'.format(labels))
        
        optimizer.zero_grad()
        
        #forward + back + opt3
        outputs = model.forward(inputs)
        
        precision, recall, accuracy = get_intersection(outputs, labels)
        total_predictions += len(precision)
        
        running_precision += sum(precision)
        running_recall += sum(recall)
        running_accuracy += sum(accuracy)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % BATCHES == (BATCHES-1):    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.4f | precision: %.4f | recall: %.4f | accuracy: %.4f' %
                  (ep + 1, i + 1, running_loss / BATCHES, running_precision / total_predictions, 
                  running_recall / total_predictions, running_accuracy / total_predictions))
            if abs(running_loss - previous_lr)/running_loss < 0.001:
                break_out = True
                break
            previous_lr = running_loss
            running_loss = 0.0
            running_recall = 0.0
            running_accuracy = 0.0
            running_precision = 0.0
            total_predictions = 0


In [None]:
# Save model (optional)
# torch.save(model.state_dict(), 'vgg_hashnet_r3_v6.pth')

In [9]:
bench_loss = 0.0
bench_precision = 0.0
bench_recall = 0.0
bench_accuracy = 0.0
bench_total_predictions = 0

for i, data in tqdm(enumerate(validation_loader, 0)):
    inputs, labels = data
    outputs = model.forward(inputs)
    
    precision, recall, accuracy = get_intersection(outputs, labels)
    bench_total_predictions += len(precision)
        
    bench_precision += sum(precision)
    bench_recall += sum(recall)
    bench_accuracy += sum(accuracy)


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [10]:
print('Overall results: ')
print('# of samples: {}'.format(bench_total_predictions))
print('Precision @ 5: {:4f}'.format(bench_precision / bench_total_predictions))
print('Recall @ 5: {:4f}'.format(bench_recall / bench_total_predictions))
print('Accuracy: {:4f}'.format(bench_accuracy / bench_total_predictions))

Overall results: 
# of samples: 11476
Precision @ 5: 0.141774
Recall @ 5: 0.233670
Accuracy: 0.537034


In [None]:
%matplotlib inline

import torchvision.utils
import matplotlib.pyplot as plt

def imshow(img):
    img = img / 2 + 0.5 # unnormalize
    img = img.cpu()
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(validation_loader)
images, labels = dataiter.next()
classes = dset.word_map
classes = {val:key for (key, val) in classes.items()}
# classes[123]
print(images.size())

    # print images
    # print(output)
    # print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
    # output = model(images)
    # _, predicted = torch.max(outputs, 5)
    # print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
    # for j in range(4))) 

In [None]:
output = model(images[0:2])
print(torch.topk(output,5))
imshow(images[-2])

In [None]:
results = torch.topk(output, 5)[-1].cpu().numpy()[-2]
print(results)
for i in results:
    print(classes[i])