In [1]:
import torch
import numpy as np
import pickle
import json
import csv
import imageio
import os
import pandas as pd
from torchvision import datasets, models, transforms
from PIL import Image
from torch import nn
import scipy
import torch.nn as nn
import torch.optim as optim
import torch.functional as F
from torch.utils.data import Dataset, DataLoader
import time
import random
from torch.optim.lr_scheduler import StepLR
import shutil
from datetime import date
import scipy.io as sio

In [2]:
#Create alexnet features
alexnet = models.alexnet(pretrained=True)
for param in alexnet.parameters():
    param.requires_grad = False

In [3]:
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

alexnet.classifier[5].register_forward_hook(get_activation('img_feats'))

<torch.utils.hooks.RemovableHandle at 0x7fe067f47f40>

In [4]:
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                         [0.229, 0.224, 0.225])
])

In [5]:
subjects = {1:'art', 2:'biology', 3:'geography', 4:'history', 5:'literature', 6:'media',\
            7:'music', 8:'royalty', 9:'sport', 10:'warfare'}

In [6]:
raw_features = sio.loadmat('./wikipedia_dataset/raw_features.mat')

img_features, text_features = np.float32(raw_features['I_tr']), np.float32(raw_features['T_tr'])

concepts = []
#stores tuple (text name, img name, concept)
text_img_names = []

infile = open('./wikipedia_info/trainset_txt_img_cat.list', 'r')
data = infile.readlines()
for line in data:
    line_arr = line.strip().split('\t')
    text_img_names.append((line_arr[0], line_arr[1], subjects[int(line_arr[2])]))
    concepts.append(int(line_arr[2])-1)

In [7]:
#Make all text and img pairs
#Store img feats,text name,text features,img label,text label,label_similarity

pairs = []

for i in range(len(img_features)):
    arr = np.arange(len(img_features))
    np.random.shuffle(arr)
    
    img = Image.open('./wikipedia_dataset/images/' + text_img_names[i][2] + '/' + text_img_names[i][1]+'.jpg')
    img = img.convert('RGB')
    img = transform(img)
    img = img.reshape(1,3,224,224)

    output = alexnet(img)
    img_feats = activation['img_feats'][0].numpy()
    
    for j in arr[:100]:
        pairs.append((text_img_names[i][2] + '/' + text_img_names[i][1]+'.jpg',\
                      text_img_names[j][2] + '/' + text_img_names[j][0], img_feats,\
                      text_features[j], concepts[i],concepts[j], np.float32(concepts[i]==concepts[j])))
    
    
    pairs.append((text_img_names[i][2] + '/' + text_img_names[i][1]+'.jpg',\
                  text_img_names[i][2] + '/' + text_img_names[i][0], img_feats,\
                          text_features[i], concepts[i],concepts[i], 1.))
    
    if(not i%100):
        print(i)

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600




1700
1800
1900
2000
2100


In [8]:
print(pairs[0])

('media/ceb47321a83dd824cec2d5d3f2034765.jpg', 'literature/7dde4239baf7365dbedb1a0503f9d322-2', array([0.       , 0.       , 0.718269 , ..., 2.4364696, 0.       ,
       0.       ], dtype=float32), array([0.07198273, 0.3217547 , 0.03789706, 0.03646101, 0.03835768,
       0.08873656, 0.03675003, 0.17545001, 0.07450258, 0.11810768],
      dtype=float32), 5, 4, 0.0)


In [20]:
class DataLoader():
    """
    DataLoader class for image features of the WIKI dataset.
    """
    
    def __init__(self, img_text_pairs, train=True):
        #If train is false, we are testing
        self.all_pairs = img_text_pairs
        
        self.num_img_features = len(img_text_pairs[0][0])
        self.num_text_features = len(img_text_pairs[0][2])
         
    def __len__(self):
        return len(self.all_pairs)
        
    
    def __getitem__(self, idx):
        pair_info = self.all_pairs[idx]
        return [{'img':pair_info[0], 'text': pair_info[1], 'img_feats': pair_info[2],
                 'text_feats': pair_info[3], 'img_label': pair_info[4], 'text_label':pair_info[5],
                 'similarity': pair_info[6]}]

In [21]:
class Model(nn.Module):
    
    def __init__(self, hashing_length=32, num_of_concepts=10):
        super(Model, self).__init__()
        
        self.img_hashing_layer = nn.Linear(4096, hashing_length)
        nn.init.xavier_normal_(self.img_hashing_layer.weight)
        
        self.text_hashing_layer = nn.Linear(10, hashing_length)
        nn.init.xavier_normal_(self.text_hashing_layer.weight)
        
        self.img_output = nn.Linear(hashing_length, num_of_concepts)
        nn.init.xavier_normal_(self.img_output.weight)
        
        self.text_output = nn.Linear(hashing_length, num_of_concepts)
        nn.init.xavier_normal_(self.text_output.weight)
        
        self.num_of_concepts = num_of_concepts
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)
        
    def hash_vector(self, features, img=True):
        if img:
            hash_code = self.img_hashing_layer(features)
        else:
            hash_code = self.text_hashing_layer(features)
            
        hash_code = self.tanh(hash_code)
        return hash_code
        
    def forward(self, img_features, text_features):
        img_hash_code = self.img_hashing_layer(img_features)
        img_hash_code = self.tanh(img_hash_code)
        concept_vector = self.img_output(img_hash_code)
        img_output = self.softmax(concept_vector)

        text_hash_code = self.text_hashing_layer(text_features)
        text_hash_code = self.tanh(text_hash_code)
        concept_vector = self.text_output(text_hash_code)
        text_output = self.softmax(concept_vector)
        
        return img_output, text_output, img_hash_code, text_hash_code

In [22]:
class Loss_Correction(nn.Module):
    def __init__(self,):
        super(Loss_Correction, self).__init__()

    def forward(self, img_hash_batch, text_hash_batch):
        #Quantization
        loss = torch.norm(torch.abs(img_hash_batch) - torch.ones(img_hash_batch.shape), p='fro')**2
        loss += torch.norm(torch.abs(text_hash_batch) - torch.ones(text_hash_batch.shape), p='fro')**2

        #Bit Balancing
        loss += torch.norm(torch.matmul(torch.ones((1,img_hash_batch.shape[0])), img_hash_batch), p=2)**2
        loss += torch.norm(torch.matmul(torch.ones((1,img_hash_batch.shape[0])), text_hash_batch), p=2)**2

        return loss/2*(img_hash_batch.shape[0])

In [23]:
def transform_to_binary(hash_code):
    binary_code = np.ones(len(hash_code), dtype=np.int8)
    for i, el in enumerate(hash_code):
        if (el<0):
            binary_code[i] = -1
    
    return binary_code

In [24]:
def loss_fn_1(img_hash_code, text_hash_code, similarity):    
    #Inter Modality Similarity preserving loss
    loss1 = nn.MSELoss()
    
    return loss1(torch.dot(img_hash_code, text_hash_code)/len(img_hash_code), similarity)

In [25]:
def loss_fn_2(img_output, text_output, img_label, text_label): 
    #Label Preserving loss
    loss2 = nn.CrossEntropyLoss()
        
    return (loss2(img_output, img_label) + loss2(text_output, text_label))

In [26]:
def hamming_distance(img_hash_vec, text_hash_vec):
    hm_dist = 0

    for i in range(len(img_hash_vec)):
        if (img_hash_vec[i]!=text_hash_vec[i]):
            hm_dist+=1
    
    return hm_dist

In [27]:
BATCH_SIZE = 64
NUM_OF_EPOCHS = 300

In [28]:
train_dataset = DataLoader(pairs)

train_size = int(0.8 * len(train_dataset))
valid_size = len(train_dataset) - train_size

train_ds, valid_ds = torch.utils.data.random_split(train_dataset, [train_size, valid_size])

train_loader = torch.utils.data.DataLoader(train_ds,batch_size=BATCH_SIZE, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_ds,batch_size=BATCH_SIZE, shuffle=True)

In [31]:
#TRAINING LOOP

device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

model = Model(hashing_length=32, num_of_concepts=10)
model = model.to(device)

model.train()

optimizer = optim.Adagrad(
    [
        {"params": model.img_hashing_layer.parameters(), "lr": 0.01},
        {"params": model.text_hashing_layer.parameters(), "lr": 0.01},
        {"params": model.img_output.parameters(), "lr": 0.001},
        {"params": model.text_output.parameters(), "lr": 0.001},
    ],
    lr=0.00001,
    )

# scheduler = StepLR(optimizer, step_size=1, gamma=0.5)

alpha, beta = 1, 0.5

for epoch in range(NUM_OF_EPOCHS):
    total_epoch_loss = 0
    for i, data_row in enumerate(train_loader):
        optimizer.zero_grad()

        text_name_batch, batch_img_feats, batch_text_feats, batch_img_label, batch_text_label, batch_similarity =\
        data_row[0]['text'], data_row[0]['img_feats'], data_row[0]['text_feats'],\
        data_row[0]['img_label'], data_row[0]['text_label'], data_row[0]['similarity']
        
        img_output, text_output, img_hash_batch, text_hash_batch = model(batch_img_feats, batch_text_feats) 
        
        for j in range(img_output.shape[0]):
            if (j==0):
                loss1 = loss_fn_1(img_hash_batch[j], text_hash_batch[j], batch_similarity[j])
            else:
                loss1 += loss_fn_1(img_hash_batch[j], text_hash_batch[j], batch_similarity[j]) 

        loss2 = loss_fn_2(img_output, text_output, batch_img_label, batch_text_label)

        correction = Loss_Correction()
        loss3 = correction(img_hash_batch, text_hash_batch)

        loss = alpha*(loss1+loss2)
#         loss = alpha*(loss1+loss2) + beta*loss3

        try:
            loss.backward()
        except:
            ''

        optimizer.step()
        
        total_epoch_loss += loss.item()

    print("Epoch %d, Average loss per batch : %0.3f"%(epoch, total_epoch_loss/len(train_loader)))
    
    if(epoch%10==0):
        torch.save(model.state_dict(), "./saved_models/epoch" + str(epoch) + ".pth")

Epoch 0, Average loss per batch : 10.542
Epoch 1, Average loss per batch : 9.872
Epoch 2, Average loss per batch : 9.461
Epoch 3, Average loss per batch : 9.173


KeyboardInterrupt: 

In [32]:
#LOAD TRAINED MODEL

model = Model(hashing_length=32, num_of_concepts=10)
model = model.to('cpu')

model.eval()
model.load_state_dict(torch.load('./saved_models/epoch230.pth', map_location='cpu'))

<All keys matched successfully>

In [35]:
#TESTING LOOP

img_hash = {}
text_hash = {}

alpha, beta = 1, 0.5

total_loss = 0

for i, data_row in enumerate(valid_loader):
    img_name_batch, text_name_batch, batch_img_feats, batch_text_feats, batch_img_label, batch_text_label, batch_similarity =\
    data_row[0]['img'], data_row[0]['text'], data_row[0]['img_feats'], data_row[0]['text_feats'],\
    data_row[0]['img_label'], data_row[0]['text_label'], data_row[0]['similarity']

    img_output, text_output, img_hash_batch, text_hash_batch = model(batch_img_feats, batch_text_feats) 

    for i in range(len(img_name_batch)):
        img_hash[img_name_batch[i]] = transform_to_binary(img_hash_batch[i].detach().numpy())
        text_hash[text_name_batch[i]] = transform_to_binary(text_hash_batch[i].detach().numpy())
    
    for j in range(img_output.shape[0]):
        if (j==0):
            loss1 = loss_fn_1(img_hash_batch[j], text_hash_batch[j], batch_similarity[j])
        else:
            loss1 += loss_fn_1(img_hash_batch[j], text_hash_batch[j], batch_similarity[j]) 

    loss2 = loss_fn_2(img_output, text_output, batch_img_label, batch_text_label)

    correction = Loss_Correction()
    loss3 = correction(img_hash_batch, text_hash_batch)

    loss = alpha*(loss1+loss2)
#         print(loss.item())
#         loss = alpha*(loss1+loss2) + beta*loss3

    total_loss+=loss.item()

print("Average validation loss per batch : %0.3f"% (total_loss/len(valid_loader)))

Average validation loss per batch : 7.977


In [49]:
for i, data_row in enumerate(valid_loader):
    img_name_batch, text_name_batch, batch_img_feats, batch_text_feats, batch_img_label, batch_text_label, batch_similarity =\
    data_row[0]['img'], data_row[0]['text'], data_row[0]['img_feats'], data_row[0]['text_feats'],\
    data_row[0]['img_label'], data_row[0]['text_label'], data_row[0]['similarity']

    img_output, text_output, img_hash_batch, text_hash_batch = model(batch_img_feats, batch_text_feats) 

    print(len(np.where(np.abs(img_hash_batch.detach().numpy())<0.99)[0]))
    break

print("Average validation loss per batch : %0.3f"% (total_loss/len(valid_loader)))

500
Average validation loss per batch : 7.977


In [36]:
#PRECISION@K
K = 100

In [37]:
#Image-Query-Text
img_cnt, text_cnt = 0., 0.

for query_img in img_hash.keys():
    
    img_hm_dist = {}
    text_hm_dist = {}

    for text in text_hash.keys():
        text_hm_dist[text] = hamming_distance(img_hash[query_img], text_hash[text])

    #Closest-5 texts
    text_response = sorted(text_hm_dist.items(), key=lambda item: item[1])

    for img in img_hash.keys():
        img_hm_dist[img] = hamming_distance(img_hash[query_img], img_hash[img])

    #Closest-5 images
    img_response = sorted(img_hm_dist.items(), key=lambda item: item[1])
    
    for j in range(K):
        if text_response[j][0].split('/')[0]==query_img.split('/')[0]:
            text_cnt+=0.01

        if img_response[j][0].split('/')[0]==query_img.split('/')[0]:
            img_cnt+=0.01 
            
print("Average Precision@%d for Image-Query-Text is %0.2f"%(K, text_cnt/len(img_hash.keys())))
print("Average Precision@%d for Image-Query-Image is %0.2f"%(K, img_cnt/len(img_hash.keys())))

Average Precision@100 for Image-Query-Text is 0.55
Average Precision@100 for Image-Query-Image is 0.58


In [38]:
#Image-Query-Text
img_cnt, text_cnt = 0., 0.

for query_text in text_hash.keys():
    
    img_hm_dist = {}
    text_hm_dist = {}

    for text in text_hash.keys():
        text_hm_dist[text] = hamming_distance(text_hash[query_text], text_hash[text])

    #Closest-5 texts
    text_response = sorted(text_hm_dist.items(), key=lambda item: item[1])

    for img in img_hash.keys():
        img_hm_dist[img] = hamming_distance(text_hash[query_text], img_hash[img])

    #Closest-5 images
    img_response = sorted(img_hm_dist.items(), key=lambda item: item[1])
    
    for j in range(K):
        if text_response[j][0].split('/')[0]==query_text.split('/')[0]:
            text_cnt+=0.01

        if img_response[j][0].split('/')[0]==query_text.split('/')[0]:
            img_cnt+=0.01 
            
print("Average Precision@%d for Text-Query-Image is %0.2f"%(K, img_cnt/len(text_hash.keys())))
print("Average Precision@%d for Text-Query-Text is %0.2f"%(K, text_cnt/len(text_hash.keys())))

Average Precision@100 for Text-Query-Image is 0.62
Average Precision@100 for Text-Query-Text is 0.60
