In [2]:
import os
import torch
import numpy as np
import pickle
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import LambdaLR, StepLR

In [3]:
#@title

import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = "clip/bpe_simple_vocab_16e6.txt.gz"):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text


In [4]:
# from clip import clip
# from clip import model

# model, preprocess = clip.load("ViT-B/32", device='cuda', jit=False)

model = torch.jit.load("../checkpoints/model.pt").cuda().eval()
input_resolution = model.input_resolution.item()
context_length = model.context_length.item()
vocab_size = model.vocab_size.item()

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408


# Step 1: Load LAD Dataset
## Option 1 Load Dataset from Scratch

In [5]:
file_root = '/media/hxd/82231ee6-d2b3-4b78-b3b4-69033720d8a8/MyDatasets/LAD'
data_root = file_root + '/LAD_annotations/'
img_root = file_root + '/LAD_images/'

In [6]:
# load attributes list
attributes_list_path = data_root + 'attribute_list.txt'
fsplit = open(attributes_list_path, 'r', encoding='UTF-8')
lines_attribute = fsplit.readlines()
fsplit.close()
list_attribute = list()
list_attribute_value = list()
for each in lines_attribute:
    tokens = each.split(', ')
    list_attribute.append(tokens[0])
    list_attribute_value.append(tokens[1])

In [7]:
# load label list
label_list_path = data_root + 'label_list.txt'
fsplit = open(label_list_path, 'r', encoding='UTF-8')
lines_label = fsplit.readlines()
fsplit.close()
list_label = dict()
list_label_value = list()
for each in lines_label:
    tokens = each.split(', ')
    list_label[tokens[0]]=tokens[1]
    list_label_value.append(tokens[1])

In [8]:
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image

preprocess = Compose([
    Resize((224, 224), interpolation=Image.BICUBIC),
    CenterCrop((224, 224)),
    ToTensor()
])

image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()



In [9]:
# load all the labels, attributes, images data from the LAD dataset
attributes_per_class_path = data_root + 'attributes.txt'
fattr = open(attributes_per_class_path, 'r', encoding='UTF-8')
lines_attr = fattr.readlines()
fattr.close()
images = list()
attr = list()
labels = list()
for each in lines_attr:
    tokens = each.split(', ')
    labels.append(list_label[tokens[0]])
    img_path = tokens[1]
    image = preprocess(Image.open(os.path.join(img_root, img_path)).convert("RGB"))
    images.append(image)
    attr_r = list(map(int, tokens[2].split()[1:-1]))
    attr.append([val for i,val in enumerate(list_attribute_value) if attr_r[i] == 1])

In [10]:
# Dump processed image and text to local
with open('../checkpoints/data_img_raw.pkl', 'wb') as file:
    pickle.dump(images, file)
with open('../checkpoints/data_txt_raw.pkl', 'wb') as file:
    pickle.dump({'label': labels, 'att': attr}, file)

## Option 2 Load LAD Dataset from Saved Files

In [11]:
with open('../checkpoints/data_img_raw.pkl', 'rb') as file:
    images = pickle.load(file)
with open('../checkpoints/data_txt_raw.pkl', 'rb') as file:
    b = pickle.load(file)
    
labels = b['label']
attr = b['att']

# Step 2: Obtain the Image and Text Features
## Option 1 Load CLIP to obtain features

In [12]:
# normalize images
image_input = torch.tensor(np.stack(images)).cuda()
image_input -= image_mean[:, None, None]
image_input /= image_std[:, None, None]

In [13]:
# Convert labels to tokens
tokenizer_label = SimpleTokenizer()
text_tokens = [tokenizer_label.encode(desc) for desc in labels]

sot_token = tokenizer_label.encoder['<|startoftext|>']
eot_token = tokenizer_label.encoder['<|endoftext|>']

text_inputs_label = torch.zeros(len(text_tokens), model.context_length, dtype=torch.long)
for i, tokens in enumerate(text_tokens):
    tokens = [sot_token] + tokens + [eot_token]
    text_inputs_label[i, :len(tokens)] = torch.tensor(tokens)
text_inputs_label = text_inputs_label.cuda()

In [14]:
# Convert attributes to tokens
tokenizer_att = SimpleTokenizer()
text_tokens = [[tokenizer_att.encode(desc) for desc in att] for att in attr]

sot_token = tokenizer_att.encoder['<|startoftext|>']
eot_token = tokenizer_att.encoder['<|endoftext|>']
text_inputs_att = list()

for j, tokens_img in enumerate(text_tokens):
    text_input = torch.zeros(len(tokens_img), model.context_length, dtype=torch.long)
    for i, tokens in enumerate(tokens_img):
        tokens = [sot_token] + tokens + [eot_token]
        text_input[i, :len(tokens)] = torch.tensor(tokens)
    text_inputs_att.append(text_input.cuda())


In [15]:
# Load CLIP model

In [16]:
with torch.no_grad():
    image_features = model.encode_image(image_input).float()

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /opt/conda/conda-bld/pytorch_1623448278899/work/aten/src/ATen/native/BinaryOps.cpp:467.)
  image_features = model.encode_image(image_input).float()


In [17]:
with torch.no_grad():
    label_fea = model.encode_text(text_inputs_label.cuda()).float()

In [18]:
with torch.no_grad():
    text_feature = list()
    for txt in text_inputs_att:
        if len(txt) == 0:
            text_feature.append(torch.empty(0, 512).cuda())
        else:
            text_feature.append(model.encode_text(txt).float())

In [19]:
image_features /= image_features.norm(dim=-1, keepdim=True)

label_fea /= label_fea.norm(dim=-1, keepdim=True)

text_feature = torch.stack([torch.mean(item,0) for item in text_feature])
text_feature /= text_feature.norm(dim=-1, keepdim=True)

In [20]:
# Save image and text features
with open('../checkpoints/data_txt_feature.pkl', 'wb') as file:
    pickle.dump({'label': label_fea, 'att': text_feature}, file)
with open('../checkpoints/data_img_feature.pkl', 'wb') as file:
    pickle.dump(image_features, file)

# Option 2 Load saved image and text features

In [21]:
with open('../checkpoints/data_txt_feature.pkl', 'rb') as file:
    b = pickle.load(file)

label_fea = b['label']
text_feature = b['att']

In [22]:
with open('../checkpoints/data_img_feature.pkl', 'rb') as file:
    image_features = pickle.load(file)

# Construct the dataloader for classification

In [23]:
from torch.utils.data import Dataset
from sklearn import preprocessing

class Dataset(Dataset):

    def __init__(self, image_features, text_feature, labels, data_indx):
        self.image_features = image_features
        self.text_feature = text_feature
        self.labels = labels
        self.data_indx = data_indx
#         self.imgs = image_input
#         self.attr = attr

    def __len__(self):
        return len(self.image_features)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = {'image': self.image_features[idx], 
                  'attribute': self.text_feature[idx], 
                  'label': self.labels[idx],
                  'data_indx': self.data_indx[idx]
#                   'imgs': self.imgs[idx],
#                   'attr': self.attr[idx]
                 }

        return sample
   

le = preprocessing.LabelEncoder()
le.fit(labels)
class_list = list(le.classes_)
labels_list = torch.tensor(le.transform(labels)).cuda()

attr_ = [';'.join(attr[0]) for item in attr]
data_indx = list(range(4600))
# dataset = Dataset(image_features, text_feature, labels_list, torch.tensor(np.stack(images)).cuda(), attr_)
dataset = Dataset(image_features, text_feature, labels_list, data_indx)
train_set, test_set = torch.utils.data.random_split(dataset,[4600-500,500])
trainloader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True)

In [24]:
import torch.nn as nn
from torch.utils.data import DataLoader
# defining the model architecture
class Net(nn.Module):   
  def __init__(self):
      super(Net, self).__init__()

      self.linear_layers = nn.Sequential(
          nn.Linear(1024, 512),
          nn.Linear(512, 230)
      )

  # Defining the forward pass    
  def forward(self, x, t):
      con = torch.cat((x, t), 1)
      out = self.linear_layers(con)
      return out

In [25]:
model = Net().cuda()
error = nn.CrossEntropyLoss().cuda()
learning_rate = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

In [26]:
num_epochs = 30
# Lists for visualization of loss and accuracy 
epoch_list = []
train_accuracy_list = []
train_loss_list = []
valid_accuracy_list = []
valid_loss_list = []
PATH = "../checkpoints/cnn.pth"

for epoch in range(num_epochs):
    correct = 0
    running_loss = 0
    model.train()
    for data in trainloader:
        # Transfering images and labels to GPU if available
#         image_batch, text_batch, label_batch, im_batch, att_batch = data['image'], data['attribute'], data['label'], data['imgs'], data['attr']
        image_batch, text_batch, label_batch, idx_batch = data['image'], data['attribute'], data['label'], data['data_indx']
        # Forward pass 
        outputs = model(image_batch, text_batch)
        #CrossEntropyLoss expects floating point inputs and long labels.
        loss = error(outputs, label_batch)
        # Initializing a gradient as 0 so there is no mixing of gradient among the batches
        optimizer.zero_grad()
        #Propagating the error backward
        loss.backward()
        # Optimizing the parameters
        optimizer.step()
    
        predictions = torch.max(outputs, 1)[1].cuda()
        correct += (predictions == label_batch).sum()
        running_loss += loss.item()

    train_loss_list.append(float(running_loss) / float(len(trainloader.dataset)))
    train_accuracy_list.append(float(correct) / float(len(trainloader.dataset)))
    
    # test on validation set
    correct = 0
    running_loss = 0
    with torch.no_grad():
        for data in testloader:
            image_batch, text_batch, label_batch, idx_batch = data['image'], data['attribute'], data['label'], data['data_indx']

            outputs = model(image_batch, text_batch)

            predictions = torch.max(outputs, 1)[1].cuda()
            correct += (predictions == label_batch).sum()
            running_loss += loss.item()


                               
    valid_loss_list.append(float(running_loss) / float(len(testloader.dataset)))
    valid_accuracy_list.append(float(correct) / float(len(testloader.dataset)))
                               
    print("Epoch: {}, train_loss: {}, train_accuracy: {}%, test_loss: {}, test_accuracy: {}%".format(epoch, 
                                                      train_loss_list[-1], 
                                                      train_accuracy_list[-1], 
                                                      valid_loss_list[-1], 
                                                      valid_accuracy_list[-1]))
            
    
                          
    epoch_list.append(epoch)     
    scheduler.step()
    
    if (epoch % 10) == 0:
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, PATH)

Epoch: 0, train_loss: 0.047759719650919845, train_accuracy: 0.2975609756097561%, test_loss: 0.0199383544921875, test_accuracy: 0.626%
Epoch: 1, train_loss: 0.010961451014367546, train_accuracy: 0.7914634146341464%, test_loss: 0.011954001426696777, test_accuracy: 0.792%
Epoch: 2, train_loss: 0.005725246373473145, train_accuracy: 0.8939024390243903%, test_loss: 0.0018081858158111573, test_accuracy: 0.792%
Epoch: 3, train_loss: 0.002696266144616302, train_accuracy: 0.9490243902439024%, test_loss: 6.169255450367928e-05, test_accuracy: 0.858%
Epoch: 4, train_loss: 0.0013520190019796535, train_accuracy: 0.9782926829268292%, test_loss: 0.0012208878993988036, test_accuracy: 0.864%
Epoch: 5, train_loss: 0.0012840903841140794, train_accuracy: 0.9819512195121951%, test_loss: 0.0001842034310102463, test_accuracy: 0.864%
Epoch: 6, train_loss: 0.0005731607787311077, train_accuracy: 0.9917073170731707%, test_loss: 0.0013498475551605225, test_accuracy: 0.884%
Epoch: 7, train_loss: 0.000886186328375848

In [27]:
m = nn.Softmax()

In [29]:
data = next(iter(testloader))
image_batch, text_batch, label_batch, idx_batch = data['image'], data['attribute'], data['label'], data['data_indx']
outputs = model(image_batch, text_batch)

for id in range(64):
    plt.imshow(images[idx_batch[id]].cpu().detach().permute(1, 2, 0))
    plt.show()
    print(m(outputs[id]).cpu().topk(3, dim=-1))
    top3 = m(outputs[id]).cpu().topk(3, dim=-1).indices
    print([class_list[i] for i in top3])
    print(attr[idx_batch[id]])