In [1]:
import os
import clip
import torch
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from skimage import io, transform
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader

In [2]:
#Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32",device=device,jit=False) #Must set jit=False for training
checkpoint = torch.load("./model_30_5e7_001_fixed.pt")
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [3]:
import random
images_root_pth = './birds/CUB_200_2011/images/'
text_root_pth = './birds/text/'
imgID_pth_df = pd.read_csv('./birds/CUB_200_2011/images.txt', sep=' ', header=None, names=['img_id', 'sub_pth'])
train_split_df = pd.read_csv('./birds//CUB_200_2011/train_test_split.txt', sep=' ', names=['img_id', 'is_training'])
class_names = pd.read_csv('./birds/CUB_200_2011/classes.txt', sep=' ', header=None, names=['class_id', 'class_name'])
class_names.class_name = class_names.class_name.map(lambda x: x.split('.')[1].lower())

cub_dataset_df = imgID_pth_df.merge(train_split_df, on='img_id', how='inner')


In [4]:


class CUBDataset(Dataset):
    """CUB dataset."""

    def __init__(self, dataframe, img_root_dir, text_root_dir, transform=None, t=None):
        """
        Args:
            dataframe (pd.DataFrame): Dataframe with paths and train/test split information.
            root_dir (string): Root directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.cub_img_df = dataframe
        self.img_root_dir = img_root_dir
        self.text_root_dir = text_root_dir
        self.transform = transform
        self.t = t

    def __len__(self):
        return len(self.cub_img_df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        sub_pth = self.cub_img_df.iloc[idx, 1]
        
        #IMAGE PROCESSING
        img_name = os.path.join(self.img_root_dir,
                                sub_pth)
        image = io.imread(img_name)
        target = int(sub_pth.split('.')[0])
        if self.transform:
            image = self.transform(image)
        
        #TEXT PROCESSING
        
        pth=sub_pth[:-4]+'.txt'
        '''
        text_file_name = os.path.join(self.text_root_dir, pth)
        lines = open(text_file_name).read().splitlines()
        content = random.choice(lines)
        text = clip.tokenize(content)
        
        '''
        text_file_name = os.path.join(self.text_root_dir, pth)
        myfile=open(text_file_name,"r")
        content=myfile.readline()
        text = clip.tokenize(content)
        

        return image, text, target
                                      
train_set = CUBDataset(cub_dataset_df[cub_dataset_df.is_training==1],images_root_pth, text_root_pth,transform=transforms.Compose([transforms.ToPILImage(), preprocess]), t=True)
test_set = CUBDataset(cub_dataset_df[cub_dataset_df.is_training==0], images_root_pth, text_root_pth, transform=transforms.Compose([transforms.ToPILImage(), preprocess]), t=False)

In [5]:
len(train_set)

5994

In [6]:
def get_img_txt_features(dataset,s1 = None):
    all_img_features = []
    all_txt_features = []
    all_labels = []
    
    with torch.no_grad():
        for images, details, labels in tqdm(DataLoader(dataset, batch_size=1,shuffle=True, num_workers=0)):
            #print(images.shape)
            img_features = model.encode_image(images.to(device))
            txt_features = model.encode_text(torch.cat(tuple(details)).to(device))

            all_img_features.append(img_features)
            all_txt_features.append(txt_features)
            all_labels.append(labels)
    #print(all_img_features.shape)
    return all_img_features, all_txt_features, all_labels


In [7]:
train_image_features, train_txt_features, train_labels = get_img_txt_features(train_set)
test_image_features, test_txt_features, test_labels = get_img_txt_features(test_set)

100%|██████████| 5994/5994 [02:31<00:00, 39.66it/s]
100%|██████████| 5794/5794 [02:37<00:00, 36.77it/s]


In [8]:
print(len(train_image_features), len(train_txt_features), len(train_labels))
print(train_image_features[0].shape, train_txt_features[0].shape)

5994 5994 5994
torch.Size([1, 512]) torch.Size([1, 512])


In [9]:
train_image_features = train_image_features.to(torch.float32)
train_txt_features = train_txt_features.to(torch.float32) 
test_image_features = test_image_features.to(torch.float32) 
test_txt_features = test_txt_features.to(torch.float32)
train_labels = train_labels.to(torch.float32) 
train_labels = train_labels.to(torch.float32)

AttributeError: 'list' object has no attribute 'to'

In [10]:
img_temp = torch.cat(train_image_features, axis=0)
txt_temp = torch.cat(train_txt_features, axis=0)
labels_temp = torch.tensor(train_labels)

In [11]:
print(txt_temp.shape)

torch.Size([5994, 512])


In [12]:
class ProcessedDataset(Dataset):
    def __init__(self, image_features, text_features, classes, transform=None):
        """
        Args:
            dataframe (pd.DataFrame): Dataframe with paths and train/test split information.
            root_dir (string): Root directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        
        self.image_features = image_features
        self.text_features = text_features
        self.transform = transform
        self.labels = classes

    def __len__(self):
        return len(self.image_features)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        image_feature = self.image_features[idx]
        text_feature = self.text_features[idx]
        label = self.labels[idx]


        return image_feature, text_feature, label

In [13]:
train_pset = ProcessedDataset(train_image_features, train_txt_features, train_labels)
test_pset = ProcessedDataset(test_image_features, test_txt_features, test_labels)

In [14]:
train_loader = torch.utils.data.DataLoader(train_pset, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_pset, batch_size=16, shuffle=False)

In [15]:
#image model

import torch.nn as nn
import torch.nn.functional as F

class Image_Model(nn.Module):
    def __init__(self):
        super(Image_Model, self).__init__()
        self.img_hid1 = nn.Linear(512, 384)
        #self.img_hid2 = nn.Linear(256, 256)
        self.img_hid3 = nn.Linear(384, 256)
        self.img_hid4 = nn.Linear(256, 128)
        self.img_dropout1 = nn.Dropout(0.5)

    def forward(self, img):
        z1 = F.relu(self.img_hid1(img))
        z1 = self.img_dropout1(z1)
        #z1 = F.relu(self.img_hid2(z1))
        z1 = F.relu(self.img_hid3(z1))
        z1 = F.relu(self.img_hid4(z1))
        # print(z1.shape)
        return z1

img_model = Image_Model()
print(img_model)

Image_Model(
  (img_hid1): Linear(in_features=512, out_features=384, bias=True)
  (img_hid3): Linear(in_features=384, out_features=256, bias=True)
  (img_hid4): Linear(in_features=256, out_features=128, bias=True)
  (img_dropout1): Dropout(p=0.5, inplace=False)
)


In [16]:
#text model
import torch.nn as nn
import torch.nn.functional as F

class Text_Model(nn.Module):
    def __init__(self):
        super(Text_Model, self).__init__()
        self.txt_hid1 = nn.Linear(512, 384)
        #self.txt_hid2 = nn.Linear(256, 256)
        self.txt_hid3 = nn.Linear(384, 256)
        self.txt_hid4 = nn.Linear(256, 128)
        self.txt_dropout1 = nn.Dropout(0.5)

    def forward(self, text):
        z2 = F.relu(self.txt_hid1(text))
        z2 = self.txt_dropout1(z2)
        #z2 = F.relu(self.txt_hid2(z2))
        z2 = F.relu(self.txt_hid3(z2))
        z2 = F.relu(self.txt_hid4(z2))
        # print(z2.shape)
        return z2

txt_model = Text_Model()
print(txt_model)

Text_Model(
  (txt_hid1): Linear(in_features=512, out_features=384, bias=True)
  (txt_hid3): Linear(in_features=384, out_features=256, bias=True)
  (txt_hid4): Linear(in_features=256, out_features=128, bias=True)
  (txt_dropout1): Dropout(p=0.5, inplace=False)
)


In [17]:
class ParentModel(nn.Module):
    def __init__(self, modelA, modelB):
        super(ParentModel, self).__init__()
        # self.modelA = modelA
        # self.modelB = modelB
        self.fc1 = nn.Linear(1024,512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512,200)
        # self.fc3 = nn.Linear(64,200)
         
        self.dropout1=nn.Dropout(0.5)


    def forward(self, x1, x2):
        # x1 = self.modelA(x1)
        # x2 = self.modelB(x2)
        x = torch.cat((x1, x2), dim=1)
        # print(x.shape)
        x=x.float()
        x=self.bn1(F.relu(self.fc1(x)))
        x=self.fc2(x)
        # x = self.dropout1(x)
        # x=self.fc3(x)
        return x
        
img_model = Image_Model()
txt_model = Text_Model()
end_model = ParentModel(img_model, txt_model)
print(end_model)


ParentModel(
  (fc1): Linear(in_features=1024, out_features=512, bias=True)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=512, out_features=200, bias=True)
  (dropout1): Dropout(p=0.5, inplace=False)
)


In [18]:
pytorch_total_params = sum(p.numel() for p in end_model.parameters())
print('Number of parameters: {0}'.format(pytorch_total_params))

Number of parameters: 628424


In [19]:
# specify loss function
end_model.to(device)
criterion = nn.CrossEntropyLoss()
# specify optimizer
optimizer = torch.optim.SGD(end_model.parameters(), lr=0.005)

In [20]:
epochs=30

end_model.to(device)
for epoch in range(0, epochs):
    train_loss=0
    test_loss=0
    train_total=0
    train_correct=0
    
    print("epoch number: {0}".format(epoch))
    end_model.train()
    
    with tqdm(train_loader, unit = 'batch') as tepoch:

        for batch_idx, (train_images, train_details, train_labels) in enumerate(tepoch):
            
            train_img_features = torch.cat(tuple(train_images)).to(device)
            train_txt_features = torch.cat(tuple(train_details)).to(device)
            #print(img_features.shape, txt_features.shape)
            
            #print(train_labels)
            train_labels = torch.cat(tuple(train_labels)).to(device)
            #print(train_labels.shape)
            optimizer.zero_grad()
            output = end_model(train_img_features, train_txt_features)
            #print(output)
            loss = criterion(output, train_labels.long()-1)
            loss.backward()
            optimizer.step()


            train_loss += loss.item()
            _, predicted = output.max(1)
            train_total += train_labels.size(0)
            train_correct += predicted.eq(train_labels.long()-1).sum().item()
            # print(train_correct) 
        print(' train loss: {:.4f} accuracy: {:.4f}'.format(train_loss/(batch_idx+1), 100.*train_correct/train_total))

    with torch.no_grad():
        test_total=0
        test_correct= 0
        
        end_model.eval()
        
        with tqdm(test_loader, unit ="batch") as tepoch:
            for batch_idx ,(text_images, test_details, test_labels) in enumerate(tepoch):


                test_img_features = torch.cat(tuple(text_images)).to(device)
                test_txt_features = torch.cat(tuple(test_details)).to(device)
                test_labels = torch.cat(tuple(test_labels)).to(device)

                y_pred_test = end_model(test_img_features, test_txt_features)
                loss_test = criterion(y_pred_test, test_labels.long()-1)
                test_loss+=loss_test.item()
            
                _, predicted = y_pred_test.max(1)
                test_total += test_labels.size(0)
                test_correct += predicted.eq(test_labels.long()-1).sum().item()
            print('test loss: {:.4f} accuracy: {:.4f}'.format(test_loss/(batch_idx+1), 100.*test_correct/test_total))


epoch number: 0


100%|██████████| 375/375 [00:00<00:00, 411.37batch/s]


 train loss: 4.4226 accuracy: 15.7157


100%|██████████| 363/363 [00:00<00:00, 1563.09batch/s]


test loss: 3.5561 accuracy: 34.9154
epoch number: 1


100%|██████████| 375/375 [00:00<00:00, 392.84batch/s]


 train loss: 3.1261 accuracy: 46.5465


100%|██████████| 363/363 [00:00<00:00, 1566.09batch/s]


test loss: 2.7639 accuracy: 51.2254
epoch number: 2


100%|██████████| 375/375 [00:01<00:00, 344.70batch/s]


 train loss: 2.4216 accuracy: 62.1455


100%|██████████| 363/363 [00:00<00:00, 1515.09batch/s]


test loss: 2.2511 accuracy: 59.3200
epoch number: 3


100%|██████████| 375/375 [00:00<00:00, 473.14batch/s]


 train loss: 1.9570 accuracy: 71.5048


100%|██████████| 363/363 [00:00<00:00, 1592.57batch/s]


test loss: 1.9502 accuracy: 65.1709
epoch number: 4


100%|██████████| 375/375 [00:00<00:00, 378.95batch/s]


 train loss: 1.6281 accuracy: 77.1438


100%|██████████| 363/363 [00:00<00:00, 1620.94batch/s]


test loss: 1.7080 accuracy: 69.4166
epoch number: 5


100%|██████████| 375/375 [00:01<00:00, 340.31batch/s]


 train loss: 1.3851 accuracy: 80.8976


100%|██████████| 363/363 [00:00<00:00, 1514.06batch/s]


test loss: 1.5496 accuracy: 70.7283
epoch number: 6


100%|██████████| 375/375 [00:01<00:00, 337.75batch/s]


 train loss: 1.2051 accuracy: 83.4001


100%|██████████| 363/363 [00:00<00:00, 1605.94batch/s]


test loss: 1.4197 accuracy: 71.9020
epoch number: 7


100%|██████████| 375/375 [00:01<00:00, 327.24batch/s]


 train loss: 1.0594 accuracy: 84.9183


100%|██████████| 363/363 [00:00<00:00, 1559.65batch/s]


test loss: 1.3318 accuracy: 73.1446
epoch number: 8


100%|██████████| 375/375 [00:01<00:00, 328.23batch/s]


 train loss: 0.9390 accuracy: 87.4875


100%|██████████| 363/363 [00:00<00:00, 1576.18batch/s]


test loss: 1.2716 accuracy: 73.8523
epoch number: 9


100%|██████████| 375/375 [00:01<00:00, 337.02batch/s]


 train loss: 0.8517 accuracy: 89.1058


100%|██████████| 363/363 [00:00<00:00, 1556.60batch/s]


test loss: 1.2245 accuracy: 73.8177
epoch number: 10


100%|██████████| 375/375 [00:01<00:00, 334.40batch/s]


 train loss: 0.7775 accuracy: 89.6897


100%|██████████| 363/363 [00:00<00:00, 1583.16batch/s]


test loss: 1.1898 accuracy: 73.6279
epoch number: 11


100%|██████████| 375/375 [00:01<00:00, 340.00batch/s]


 train loss: 0.7052 accuracy: 91.2079


100%|██████████| 363/363 [00:00<00:00, 1579.10batch/s]


test loss: 1.1506 accuracy: 74.4909
epoch number: 12


100%|██████████| 375/375 [00:01<00:00, 350.83batch/s]


 train loss: 0.6432 accuracy: 92.4758


100%|██████████| 363/363 [00:00<00:00, 1582.63batch/s]


test loss: 1.1206 accuracy: 74.3183
epoch number: 13


100%|██████████| 375/375 [00:01<00:00, 342.81batch/s]


 train loss: 0.5898 accuracy: 93.1431


100%|██████████| 363/363 [00:00<00:00, 1578.08batch/s]


test loss: 1.1000 accuracy: 74.4391
epoch number: 14


100%|██████████| 375/375 [00:01<00:00, 332.91batch/s]


 train loss: 0.5466 accuracy: 94.1441


100%|██████████| 363/363 [00:00<00:00, 1578.79batch/s]


test loss: 1.0769 accuracy: 75.1812
epoch number: 15


100%|██████████| 375/375 [00:00<00:00, 383.27batch/s]


 train loss: 0.5030 accuracy: 95.2286


100%|██████████| 363/363 [00:00<00:00, 1586.94batch/s]


test loss: 1.0708 accuracy: 74.4391
epoch number: 16


100%|██████████| 375/375 [00:00<00:00, 468.69batch/s]


 train loss: 0.4761 accuracy: 95.7291


100%|██████████| 363/363 [00:00<00:00, 1550.74batch/s]


test loss: 1.0652 accuracy: 74.5254
epoch number: 17


100%|██████████| 375/375 [00:00<00:00, 457.97batch/s]


 train loss: 0.4444 accuracy: 95.9626


100%|██████████| 363/363 [00:00<00:00, 1461.57batch/s]


test loss: 1.0656 accuracy: 74.7670
epoch number: 18


100%|██████████| 375/375 [00:00<00:00, 465.22batch/s]


 train loss: 0.4133 accuracy: 96.6133


100%|██████████| 363/363 [00:00<00:00, 1270.16batch/s]


test loss: 1.0576 accuracy: 74.7152
epoch number: 19


100%|██████████| 375/375 [00:00<00:00, 430.90batch/s]


 train loss: 0.3818 accuracy: 97.3807


100%|██████████| 363/363 [00:00<00:00, 1572.95batch/s]


test loss: 1.0414 accuracy: 74.0594
epoch number: 20


100%|██████████| 375/375 [00:00<00:00, 376.96batch/s]


 train loss: 0.3595 accuracy: 97.6310


100%|██████████| 363/363 [00:00<00:00, 1537.94batch/s]


test loss: 1.0325 accuracy: 74.1111
epoch number: 21


100%|██████████| 375/375 [00:00<00:00, 459.43batch/s]


 train loss: 0.3285 accuracy: 98.0647


100%|██████████| 363/363 [00:00<00:00, 1554.16batch/s]


test loss: 1.0356 accuracy: 73.7832
epoch number: 22


100%|██████████| 375/375 [00:00<00:00, 455.53batch/s]


 train loss: 0.3085 accuracy: 98.1815


100%|██████████| 363/363 [00:00<00:00, 1562.26batch/s]


test loss: 1.0213 accuracy: 74.1974
epoch number: 23


100%|██████████| 375/375 [00:00<00:00, 444.06batch/s]


 train loss: 0.2921 accuracy: 98.6320


100%|██████████| 363/363 [00:00<00:00, 1303.59batch/s]


test loss: 1.0183 accuracy: 74.4736
epoch number: 24


100%|██████████| 375/375 [00:00<00:00, 419.51batch/s]


 train loss: 0.2739 accuracy: 98.6653


100%|██████████| 363/363 [00:00<00:00, 1512.79batch/s]


test loss: 1.0273 accuracy: 73.8523
epoch number: 25


100%|██████████| 375/375 [00:00<00:00, 493.40batch/s]


 train loss: 0.2594 accuracy: 99.0324


100%|██████████| 363/363 [00:00<00:00, 1528.03batch/s]


test loss: 1.0242 accuracy: 73.9731
epoch number: 26


100%|██████████| 375/375 [00:00<00:00, 387.31batch/s]


 train loss: 0.2468 accuracy: 99.1658


100%|██████████| 363/363 [00:00<00:00, 1517.68batch/s]


test loss: 1.0319 accuracy: 73.4898
epoch number: 27


100%|██████████| 375/375 [00:00<00:00, 420.89batch/s]


 train loss: 0.2322 accuracy: 99.2993


100%|██████████| 363/363 [00:00<00:00, 1583.24batch/s]


test loss: 1.0173 accuracy: 73.9731
epoch number: 28


100%|██████████| 375/375 [00:00<00:00, 432.45batch/s]


 train loss: 0.2204 accuracy: 99.3994


100%|██████████| 363/363 [00:00<00:00, 1553.33batch/s]


test loss: 1.0244 accuracy: 73.8350
epoch number: 29


100%|██████████| 375/375 [00:00<00:00, 420.72batch/s]


 train loss: 0.2044 accuracy: 99.5662


100%|██████████| 363/363 [00:00<00:00, 1566.59batch/s]

test loss: 1.0164 accuracy: 73.8177





In [None]:
# epoch number: 14
# 100%|██████████| 375/375 [00:01<00:00, 332.91batch/s]
#  train loss: 0.5466 accuracy: 94.1441
# 100%|██████████| 363/363 [00:00<00:00, 1578.79batch/s]
# test loss: 1.0769 accuracy: 75.1812
# epoch number: 15