In [1]:
import os
import clip
import torch
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from skimage import io, transform
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader

In [2]:

# Run CLIP on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pretrained ViT-B/32 CLIP: `model` provides encode_image / encode_text,
# `preprocess` is the image transform applied (after ToPILImage) to every sample.
model, preprocess = clip.load('ViT-B/32', device=device)
print(device)

cuda


In [3]:
import random
# All CUB-200-2011 metadata files live under one root; deriving every path from
# it keeps them consistent (the split-file path previously contained a stray '//').
cub_root = './birds/CUB_200_2011/'
images_root_pth = cub_root + 'images/'
text_root_pth = './birds/text/'

imgID_pth_df = pd.read_csv(cub_root + 'images.txt', sep=' ', header=None, names=['img_id', 'sub_pth'])
train_split_df = pd.read_csv(cub_root + 'train_test_split.txt', sep=' ', header=None, names=['img_id', 'is_training'])
class_names = pd.read_csv(cub_root + 'classes.txt', sep=' ', header=None, names=['class_id', 'class_name'])
# Keep the second '.'-separated field of each class name, lower-cased.
class_names['class_name'] = class_names['class_name'].str.split('.').str[1].str.lower()

cub_dataset_df = imgID_pth_df.merge(train_split_df, on='img_id', how='inner')
# An inner merge silently drops images missing from the split file (or duplicates
# rows on repeated ids); every image must have exactly one split entry.
assert len(cub_dataset_df) == len(imgID_pth_df), "train/test split does not cover every image exactly once"


In [4]:


class CUBDataset(Dataset):
    """CUB image/caption dataset.

    Each item is ``(image, text, target)``:
      * ``image``  -- the image read with ``skimage.io.imread``, passed through
        ``transform`` when one is given;
      * ``text``   -- ``clip.tokenize`` of the first line of the matching caption file;
      * ``target`` -- the integer before the first '.' of the image's relative path.
    """

    def __init__(self, dataframe, img_root_dir, text_root_dir, transform=None, t=None):
        """
        Args:
            dataframe (pd.DataFrame): Rows to expose; the column at position 1 must
                hold the image path relative to ``img_root_dir``.
            img_root_dir (string): Root directory with all the images.
            text_root_dir (string): Root directory with the caption ``.txt`` files,
                laid out with the same relative paths as the images.
            transform (callable, optional): Optional transform applied to the image
                array before it is returned.
            t (bool, optional): Stored as ``self.t``; not read by this class.
                NOTE(review): presumably intended to switch between a random caption
                (training) and the first caption (test) -- confirm.
        """
        self.cub_img_df = dataframe
        self.img_root_dir = img_root_dir
        self.text_root_dir = text_root_dir
        self.transform = transform
        self.t = t

    def __len__(self):
        return len(self.cub_img_df)

    def _caption_path(self, sub_pth):
        """Map an image's relative path to its caption file under ``text_root_dir``."""
        stem, _ = os.path.splitext(sub_pth)
        return os.path.join(self.text_root_dir, stem + '.txt')

    def _read_caption(self, sub_pth):
        """Return the first line (trailing newline included) of the image's caption file."""
        # `with` closes the handle deterministically; the previous bare open()
        # relied on garbage collection to release one file handle per sample.
        with open(self._caption_path(sub_pth), "r") as caption_file:
            return caption_file.readline()

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sub_pth = self.cub_img_df.iloc[idx, 1]

        # Image processing: read from disk; the label is the integer prefix of the path.
        image = io.imread(os.path.join(self.img_root_dir, sub_pth))
        target = int(sub_pth.split('.')[0])
        if self.transform:
            image = self.transform(image)

        # Text processing: only the first caption line is used.
        text = clip.tokenize(self._read_caption(sub_pth))

        return image, text, target
                                      
# Both splits use the same pipeline: numpy image -> PIL image -> CLIP preprocessing.
train_set = CUBDataset(
    cub_dataset_df.loc[cub_dataset_df['is_training'] == 1],
    images_root_pth,
    text_root_pth,
    transform=transforms.Compose([transforms.ToPILImage(), preprocess]),
    t=True,
)
test_set = CUBDataset(
    cub_dataset_df.loc[cub_dataset_df['is_training'] == 0],
    images_root_pth,
    text_root_pth,
    transform=transforms.Compose([transforms.ToPILImage(), preprocess]),
    t=False,
)

In [5]:
# The train/test split must partition the dataset: every image in exactly one split.
assert len(train_set) + len(test_set) == len(cub_dataset_df)
len(train_set)

5994

In [6]:
def get_img_txt_features(dataset, s1=None):
    """Encode every sample of ``dataset`` with the global CLIP ``model``.

    Args:
        dataset: A dataset yielding ``(image, tokenized_text, label)`` triples,
            such as ``CUBDataset``.
        s1: Unused; kept only so existing calls continue to work.

    Returns:
        Three parallel lists with one entry per sample, in dataset order:
        image embeddings, text embeddings (tensors on ``device``) and labels.
    """
    all_img_features = []
    all_txt_features = []
    all_labels = []

    # shuffle=False: extraction order only determines how the features are
    # stored, and a fixed order keeps the lists reproducible across runs and
    # aligned with dataset indices. batch_size=1 keeps exactly one list entry
    # per sample, which per-index lookups over these lists rely on.
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    with torch.no_grad():
        for images, details, labels in tqdm(loader):
            img_features = model.encode_image(images.to(device))
            # Each sample's token tensor keeps its own leading axis; concatenating
            # over the batch axis yields the token matrix encode_text consumes.
            txt_features = model.encode_text(torch.cat(tuple(details)).to(device))

            all_img_features.append(img_features)
            all_txt_features.append(txt_features)
            all_labels.append(labels)
    return all_img_features, all_txt_features, all_labels


In [7]:
# Encode both splits once with CLIP (each split takes a few minutes, see the progress bars below).
train_image_features, train_txt_features, train_labels = get_img_txt_features(dataset=train_set)
test_image_features, test_txt_features, test_labels = get_img_txt_features(dataset=test_set)

100%|██████████| 5994/5994 [02:46<00:00, 35.95it/s]
100%|██████████| 5794/5794 [02:28<00:00, 38.92it/s]


In [8]:
# Sanity check: one entry per training sample in every list, plus the per-sample embedding shapes.
print(*(len(seq) for seq in (train_image_features, train_txt_features, train_labels)))
print(train_image_features[0].shape, train_txt_features[0].shape)

5994 5994 5994
torch.Size([1, 512]) torch.Size([1, 512])


In [None]:
# Cast the extracted CLIP embeddings to float32 (CLIP may emit half precision on GPU).
# The extractors return Python lists of per-sample tensors, and lists have no
# .to(), so each element is cast individually; list structure and length are kept.
train_image_features = [feat.to(torch.float32) for feat in train_image_features]
train_txt_features = [feat.to(torch.float32) for feat in train_txt_features]
test_image_features = [feat.to(torch.float32) for feat in test_image_features]
test_txt_features = [feat.to(torch.float32) for feat in test_txt_features]
# Labels are class ids used as CrossEntropyLoss targets (converted with .long()
# in the training loop), so they are deliberately left as integers. The previous
# version cast train_labels twice and never touched test_labels.

In [9]:
# Stack the per-sample tensors into single matrices (N, D) and a label vector (N,).
# torch.cat is used for the labels too: torch.tensor() over a list of tensors is
# slow and its result shape depends on the PyTorch version.
img_temp = torch.cat(train_image_features, dim=0)
txt_temp = torch.cat(train_txt_features, dim=0)
labels_temp = torch.cat(train_labels, dim=0)

In [10]:
# Stacked text-embedding matrix: (number of training samples, embedding width).
print(txt_temp.size())

torch.Size([5994, 512])


In [11]:
class ProcessedDataset(Dataset):
    """Dataset over precomputed features.

    Item ``idx`` is ``(image_features[idx], text_features[idx], classes[idx])``.
    """

    def __init__(self, image_features, text_features, classes, transform=None):
        """
        Args:
            image_features (sequence): Per-sample image embeddings.
            text_features (sequence): Per-sample text embeddings, aligned with
                ``image_features``.
            classes (sequence): Per-sample labels, aligned with ``image_features``.
            transform (callable, optional): Stored as ``self.transform``; not applied
                by ``__getitem__``.

        Raises:
            ValueError: if the three sequences differ in length.
        """
        # Misaligned inputs would otherwise surface only as an IndexError (or,
        # worse, silently mismatched pairs) deep inside a training epoch.
        if not (len(image_features) == len(text_features) == len(classes)):
            raise ValueError(
                "image_features, text_features and classes must have equal lengths, got "
                f"{len(image_features)}, {len(text_features)}, {len(classes)}"
            )

        self.image_features = image_features
        self.text_features = text_features
        self.transform = transform
        self.labels = classes

    def __len__(self):
        return len(self.image_features)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self.image_features[idx], self.text_features[idx], self.labels[idx]

In [12]:
# Wrap the cached CLIP features so the classifier trains without re-encoding images.
train_pset = ProcessedDataset(
    image_features=train_image_features,
    text_features=train_txt_features,
    classes=train_labels,
)
test_pset = ProcessedDataset(
    image_features=test_image_features,
    text_features=test_txt_features,
    classes=test_labels,
)

In [13]:
# Mini-batches of 16; only the training loader reshuffles each epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=train_pset,
    batch_size=16,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_pset,
    batch_size=16,
    shuffle=False,
)

In [14]:
#image model

import torch.nn as nn
import torch.nn.functional as F

class Image_Model(nn.Module):
    """MLP over 512-d image embeddings: 512 -> 384 -> 256 -> 128.

    ReLU follows every linear layer (including the last); dropout (p=0.5)
    follows the first activation.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in this order (hid1, hid3, hid4, dropout1),
        # which fixes the printed summary and the state_dict keys.
        self.img_hid1 = nn.Linear(512, 384)
        self.img_hid3 = nn.Linear(384, 256)
        self.img_hid4 = nn.Linear(256, 128)
        self.img_dropout1 = nn.Dropout(0.5)

    def forward(self, img):
        hidden = self.img_dropout1(F.relu(self.img_hid1(img)))
        hidden = F.relu(self.img_hid3(hidden))
        return F.relu(self.img_hid4(hidden))

img_model = Image_Model()
print(img_model)

Image_Model(
  (img_hid1): Linear(in_features=512, out_features=384, bias=True)
  (img_hid3): Linear(in_features=384, out_features=256, bias=True)
  (img_hid4): Linear(in_features=256, out_features=128, bias=True)
  (img_dropout1): Dropout(p=0.5, inplace=False)
)


In [15]:
#text model
import torch.nn as nn
import torch.nn.functional as F

class Text_Model(nn.Module):
    """MLP over 512-d text embeddings: 512 -> 384 -> 256 -> 128.

    ReLU follows every linear layer (including the last); dropout (p=0.5)
    follows the first activation.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in this order (hid1, hid3, hid4, dropout1),
        # which fixes the printed summary and the state_dict keys.
        self.txt_hid1 = nn.Linear(512, 384)
        self.txt_hid3 = nn.Linear(384, 256)
        self.txt_hid4 = nn.Linear(256, 128)
        self.txt_dropout1 = nn.Dropout(0.5)

    def forward(self, text):
        hidden = self.txt_dropout1(F.relu(self.txt_hid1(text)))
        hidden = F.relu(self.txt_hid3(hidden))
        return F.relu(self.txt_hid4(hidden))

txt_model = Text_Model()
print(txt_model)

Text_Model(
  (txt_hid1): Linear(in_features=512, out_features=384, bias=True)
  (txt_hid3): Linear(in_features=384, out_features=256, bias=True)
  (txt_hid4): Linear(in_features=256, out_features=128, bias=True)
  (txt_dropout1): Dropout(p=0.5, inplace=False)
)


In [16]:
class ParentModel(nn.Module):
    """Fusion classifier over concatenated image and text features.

    ``forward(x1, x2)`` concatenates the inputs along dim 1 (total width must be
    1024, e.g. two 512-d embeddings), casts to float32, and computes
    Linear(1024, 512) -> ReLU -> BatchNorm1d(512) -> Linear(512, 200) logits.

    ``modelA`` / ``modelB`` are accepted but currently ignored (the raw features
    are fused directly), and ``dropout1`` is registered but not applied.
    """

    def __init__(self, modelA, modelB):
        super().__init__()
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 200)
        self.dropout1 = nn.Dropout(0.5)

    def forward(self, x1, x2):
        # .float() lets lower-precision inputs flow into the float32 layers.
        fused = torch.cat((x1, x2), dim=1).float()
        hidden = self.bn1(F.relu(self.fc1(fused)))
        return self.fc2(hidden)
        
# Fresh branch models plus the fusion classifier; print the classifier's layers.
img_model, txt_model = Image_Model(), Text_Model()
end_model = ParentModel(img_model, txt_model)
print(end_model)


ParentModel(
  (fc1): Linear(in_features=1024, out_features=512, bias=True)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=512, out_features=200, bias=True)
  (dropout1): Dropout(p=0.5, inplace=False)
)


In [17]:
# Total scalar parameter count over every parameter tensor of end_model.
pytorch_total_params = sum(map(torch.numel, end_model.parameters()))
print('Number of parameters: {0}'.format(pytorch_total_params))

Number of parameters: 628424


In [18]:
# specify loss function
end_model.to(device)
# Multi-class cross-entropy over the classifier's logits, minimised with plain SGD.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=end_model.parameters(), lr=0.005)

In [19]:
def _to_device_batch(images, details, labels):
    """Flatten a collated batch and move it to ``device``.

    Default collation stacks each per-sample tensor under a new batch axis, so
    merging the first two axes recovers one row per sample (equivalent to
    ``torch.cat(tuple(x))``). Returns image features, text features and
    0-based integer targets.
    """
    img = images.flatten(0, 1).to(device)
    txt = details.flatten(0, 1).to(device)
    # Labels are shifted down by one so they index the classifier's logits from 0.
    targets = labels.flatten(0, 1).to(device).long() - 1
    return img, txt, targets


epochs = 30

end_model.to(device)
for epoch in range(epochs):
    train_loss = 0
    test_loss = 0
    train_total = 0
    train_correct = 0

    print("epoch number: {0}".format(epoch))

    # ---- training pass ----
    # Batch variables use a `batch_` prefix: the previous loop rebound the
    # notebook-level train_txt_features / train_labels / test_txt_features /
    # test_labels, replacing the extracted feature lists with the last batch and
    # breaking any re-run of the dataset/loader cells.
    end_model.train()
    with tqdm(train_loader, unit='batch') as tepoch:
        for batch_images, batch_details, batch_labels in tepoch:
            batch_img, batch_txt, targets = _to_device_batch(batch_images, batch_details, batch_labels)

            optimizer.zero_grad()
            output = end_model(batch_img, batch_txt)
            loss = criterion(output, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = output.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()
        print(' train loss: {:.4f} accuracy: {:.4f}'.format(train_loss / len(train_loader), 100. * train_correct / train_total))

    # ---- evaluation pass ----
    with torch.no_grad():
        test_total = 0
        test_correct = 0

        end_model.eval()

        with tqdm(test_loader, unit="batch") as tepoch:
            for batch_images, batch_details, batch_labels in tepoch:
                batch_img, batch_txt, targets = _to_device_batch(batch_images, batch_details, batch_labels)

                y_pred_test = end_model(batch_img, batch_txt)
                loss_test = criterion(y_pred_test, targets)
                test_loss += loss_test.item()

                _, predicted = y_pred_test.max(1)
                test_total += targets.size(0)
                test_correct += predicted.eq(targets).sum().item()
            print('test loss: {:.4f} accuracy: {:.4f}'.format(test_loss / len(test_loader), 100. * test_correct / test_total))


epoch number: 0


100%|██████████| 375/375 [00:00<00:00, 426.53batch/s]


 train loss: 4.4453 accuracy: 14.8815


100%|██████████| 363/363 [00:00<00:00, 1517.87batch/s]


test loss: 3.5573 accuracy: 35.7957
epoch number: 1


100%|██████████| 375/375 [00:00<00:00, 524.43batch/s]


 train loss: 3.0383 accuracy: 49.4661


100%|██████████| 363/363 [00:00<00:00, 1564.26batch/s]


test loss: 2.8232 accuracy: 41.9572
epoch number: 2


100%|██████████| 375/375 [00:00<00:00, 423.85batch/s]


 train loss: 2.2798 accuracy: 65.7157


100%|██████████| 363/363 [00:00<00:00, 1574.84batch/s]


test loss: 2.1375 accuracy: 62.0642
epoch number: 3


100%|██████████| 375/375 [00:00<00:00, 411.70batch/s]


 train loss: 1.8066 accuracy: 74.2075


100%|██████████| 363/363 [00:00<00:00, 1522.46batch/s]


test loss: 1.8186 accuracy: 66.5516
epoch number: 4


100%|██████████| 375/375 [00:00<00:00, 426.29batch/s]


 train loss: 1.4756 accuracy: 79.7965


100%|██████████| 363/363 [00:00<00:00, 1551.96batch/s]


test loss: 1.6057 accuracy: 69.3649
epoch number: 5


100%|██████████| 375/375 [00:00<00:00, 445.23batch/s]


 train loss: 1.2536 accuracy: 82.9329


100%|██████████| 363/363 [00:00<00:00, 1569.31batch/s]


test loss: 1.4670 accuracy: 71.0045
epoch number: 6


100%|██████████| 375/375 [00:00<00:00, 435.54batch/s]


 train loss: 1.0805 accuracy: 86.0194


100%|██████████| 363/363 [00:00<00:00, 1567.47batch/s]


test loss: 1.3519 accuracy: 72.2644
epoch number: 7


100%|██████████| 375/375 [00:00<00:00, 455.02batch/s]


 train loss: 0.9438 accuracy: 87.3874


100%|██████████| 363/363 [00:00<00:00, 1633.72batch/s]


test loss: 1.2688 accuracy: 73.0929
epoch number: 8


100%|██████████| 375/375 [00:00<00:00, 445.73batch/s]


 train loss: 0.8407 accuracy: 89.3560


100%|██████████| 363/363 [00:00<00:00, 1547.89batch/s]


test loss: 1.2083 accuracy: 73.5243
epoch number: 9


100%|██████████| 375/375 [00:00<00:00, 411.73batch/s]


 train loss: 0.7441 accuracy: 91.2913


100%|██████████| 363/363 [00:00<00:00, 1532.61batch/s]


test loss: 1.1665 accuracy: 73.6624
epoch number: 10


100%|██████████| 375/375 [00:00<00:00, 437.29batch/s]


 train loss: 0.6669 accuracy: 92.1421


100%|██████████| 363/363 [00:00<00:00, 1558.06batch/s]


test loss: 1.1387 accuracy: 73.7142
epoch number: 11


100%|██████████| 375/375 [00:00<00:00, 453.72batch/s]


 train loss: 0.6015 accuracy: 93.4101


100%|██████████| 363/363 [00:00<00:00, 1317.24batch/s]


test loss: 1.0905 accuracy: 74.4736
epoch number: 12


100%|██████████| 375/375 [00:00<00:00, 481.93batch/s]


 train loss: 0.5442 accuracy: 94.4611


100%|██████████| 363/363 [00:00<00:00, 1519.19batch/s]


test loss: 1.0742 accuracy: 74.3528
epoch number: 13


100%|██████████| 375/375 [00:00<00:00, 435.82batch/s]


 train loss: 0.4887 accuracy: 95.4454


100%|██████████| 363/363 [00:00<00:00, 1492.84batch/s]


test loss: 1.0577 accuracy: 74.4736
epoch number: 14


100%|██████████| 375/375 [00:00<00:00, 443.89batch/s]


 train loss: 0.4518 accuracy: 95.8125


100%|██████████| 363/363 [00:00<00:00, 1523.96batch/s]


test loss: 1.0570 accuracy: 73.9903
epoch number: 15


100%|██████████| 375/375 [00:00<00:00, 405.90batch/s]


 train loss: 0.4089 accuracy: 96.8635


100%|██████████| 363/363 [00:00<00:00, 1514.59batch/s]


test loss: 1.0379 accuracy: 74.1111
epoch number: 16


100%|██████████| 375/375 [00:00<00:00, 425.58batch/s]


 train loss: 0.3804 accuracy: 97.4474


100%|██████████| 363/363 [00:00<00:00, 1497.40batch/s]


test loss: 1.0330 accuracy: 74.2320
epoch number: 17


100%|██████████| 375/375 [00:00<00:00, 404.43batch/s]


 train loss: 0.3489 accuracy: 97.4975


100%|██████████| 363/363 [00:00<00:00, 1525.58batch/s]


test loss: 1.0124 accuracy: 74.0421
epoch number: 18


100%|██████████| 375/375 [00:00<00:00, 466.98batch/s]


 train loss: 0.3166 accuracy: 98.2649


100%|██████████| 363/363 [00:00<00:00, 1592.44batch/s]


test loss: 1.0109 accuracy: 74.5944
epoch number: 19


100%|██████████| 375/375 [00:00<00:00, 403.49batch/s]


 train loss: 0.2880 accuracy: 98.4151


100%|██████████| 363/363 [00:00<00:00, 1481.89batch/s]


test loss: 1.0088 accuracy: 74.1457
epoch number: 20


100%|██████████| 375/375 [00:00<00:00, 442.28batch/s]


 train loss: 0.2735 accuracy: 98.7821


100%|██████████| 363/363 [00:00<00:00, 1537.25batch/s]


test loss: 1.0000 accuracy: 73.8523
epoch number: 21


100%|██████████| 375/375 [00:00<00:00, 465.00batch/s]


 train loss: 0.2482 accuracy: 99.1658


100%|██████████| 363/363 [00:00<00:00, 1549.32batch/s]


test loss: 1.0016 accuracy: 74.1974
epoch number: 22


100%|██████████| 375/375 [00:00<00:00, 390.35batch/s]


 train loss: 0.2287 accuracy: 99.2492


100%|██████████| 363/363 [00:00<00:00, 1510.32batch/s]


test loss: 1.0036 accuracy: 73.5934
epoch number: 23


100%|██████████| 375/375 [00:00<00:00, 437.47batch/s]


 train loss: 0.2108 accuracy: 99.3994


100%|██████████| 363/363 [00:00<00:00, 1595.19batch/s]


test loss: 0.9879 accuracy: 73.9213
epoch number: 24


100%|██████████| 375/375 [00:00<00:00, 457.06batch/s]


 train loss: 0.1969 accuracy: 99.5662


100%|██████████| 363/363 [00:00<00:00, 1545.73batch/s]


test loss: 0.9996 accuracy: 73.9386
epoch number: 25


100%|██████████| 375/375 [00:00<00:00, 465.90batch/s]


 train loss: 0.1840 accuracy: 99.5829


100%|██████████| 363/363 [00:00<00:00, 1422.51batch/s]


test loss: 0.9860 accuracy: 74.0594
epoch number: 26


100%|██████████| 375/375 [00:00<00:00, 430.65batch/s]


 train loss: 0.1723 accuracy: 99.6830


100%|██████████| 363/363 [00:00<00:00, 1503.04batch/s]


test loss: 0.9845 accuracy: 73.9558
epoch number: 27


100%|██████████| 375/375 [00:00<00:00, 430.78batch/s]


 train loss: 0.1603 accuracy: 99.8665


100%|██████████| 363/363 [00:00<00:00, 1478.78batch/s]


test loss: 0.9778 accuracy: 73.3690
epoch number: 28


100%|██████████| 375/375 [00:00<00:00, 492.91batch/s]


 train loss: 0.1506 accuracy: 99.8498


100%|██████████| 363/363 [00:00<00:00, 1498.10batch/s]


test loss: 0.9800 accuracy: 73.7832
epoch number: 29


100%|██████████| 375/375 [00:00<00:00, 458.20batch/s]


 train loss: 0.1402 accuracy: 99.9333


100%|██████████| 363/363 [00:00<00:00, 1561.27batch/s]

test loss: 0.9806 accuracy: 73.4898





In [None]:
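# Log excerpt (epoch 11) copied from the 30-epoch training run above, kept for reference.
# Note: the highest test accuracy in that run was 74.5944% at epoch 18, not epoch 11.
# epoch number: 11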
# 100%|██████████| 375/375 [00:00<00:00, 453.72batch/s]
#  train loss: 0.6015 accuracy: 93.4101
# 100%|██████████| 363/363 [00:00<00:00, 1317.24batch/s]
# test loss: 1.0905 accuracy: 74.4736
# epoch number: 12