In [1]:
import torch
from torch import nn
from PIL import Image

from torch.nn import TransformerEncoder, TransformerEncoderLayer
from transformers import BertTokenizer, BertModel, ViTFeatureExtractor, ViTModel
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision.transforms import Resize
from tqdm import tqdm

In [2]:
device = "cuda:1" if torch.cuda.is_available() else "cpu"

In [3]:
import numpy as np
import pandas as pd
import json
from tqdm import tqdm
from time import sleep


In [4]:
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# bert_model = BertModel.from_pretrained("bert-base-uncased").to(device)

image_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
# image_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)



In [5]:
def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result

In [6]:
class MyModel(nn.Module):
    def __init__(self,device, d_model, concat_size, num_classes, nlayers, nhead, d_hid, dropout=0.5):
        
        super(MyModel, self).__init__()
        
        self.device = device
        
        
        self.bert_model = BertModel.from_pretrained("bert-base-uncased").to(self.device)
        
        self.vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(self.device)
        
        self.d_model = d_model
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        self.class_token = nn.Parameter(torch.rand(1, d_model))
        
        self.linear_layer = nn.Sequential(nn.Linear(concat_size,1), nn.Sigmoid())
    
    def preprocess(self,image, encoded_text_input):
        
        #input_img = self.vit_feature_extractor(image, return_tensors='pt').to(device)

        with torch.no_grad():
            text_embeds = self.bert_model.embeddings(input_ids = encoded_text_input['input_ids'], 
                                            token_type_ids = encoded_text_input['token_type_ids']).to(device)
            image_embeds = self.vit_model.embeddings(image['pixel_values']).to(device)

        return image_embeds, text_embeds
        
    def forward(self, images, texts):
        image_patch_embeds, text_embeddings = self.preprocess(images, texts)
        
        concat_embeds = torch.cat([image_patch_embeds, text_embeddings], 1)
        concat_embeds = torch.stack([torch.vstack((self.class_token, concat_embeds[i])) for i in range(len(concat_embeds))])
        pos_embed = get_positional_embeddings(concat_embeds.shape[1], self.d_model).repeat(concat_embeds.shape[0], 1, 1)
        concat_embeds+= pos_embed.to(self.device)
        
        logits = self.transformer_encoder(concat_embeds)
        logits = logits[:,0,:]
        
        preds = self.linear_layer(logits)
        
        return preds

In [7]:

class MyDataset(Dataset):
    
    def __init__(self, annotations, img_dir):
        self.labels = pd.read_csv(annotations)
        self.img_dir = img_dir 
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        img_name = self.labels.iloc[idx,1]
        img_path = self.img_dir + img_name[0] + "/" + \
                img_name[1] + "/" + img_name[2] + "/" + img_name[3] + "/" + img_name
        
        im = read_image(img_path)
        
        recipe = self.labels.iloc[idx,3]
        
        label = self.labels.iloc[idx, 4]
        
        #t_embeds, img_embeds = self.preprocess(im, recipe)
        return im, recipe, label

In [8]:
def collate_function(batch):
    imgs = []
    texts = []
    labels = []
    for img, text, label in batch:
        
        imgs.append(img)
        texts.append(text)
        labels.append(label)
    
    encoded_text_input = bert_tokenizer(texts, return_tensors='pt',padding='max_length', truncation=True).to(device)
    imgs_feats = image_feature_extractor(imgs, return_tensors='pt')
    labels = torch.tensor(labels)
    return imgs_feats,encoded_text_input,labels
    
        

In [9]:
train_d = MyDataset('train_sampled.csv','/freespace/local/sk2381/im2recipe-Pytorch/data/train/')

In [10]:
test_d = MyDataset('test.csv','/freespace/local/sk2381/im2recipe-Pytorch/data/test/')

In [11]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_d, collate_fn=collate_function,batch_size=128, shuffle=False)
test_dataloader = DataLoader(test_d, collate_fn=collate_function,batch_size=128, shuffle=False)

In [12]:

transformer_model = MyModel(device, 768,768,2,1,2,512,0.5).to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [13]:
device

'cuda:1'

In [14]:
loss_fn = nn.BCELoss()
optimizer = torch.optim.SGD(transformer_model.parameters(), lr=1e-3)

In [18]:
# def train_full(epoch, dataloader, model, loss_fxn, optimizer, PATH):
#     i=0
#     k = 0
#     with tqdm(dataloader, unit="batch") as tepoch:
#         for img,recipe, y in tepoch:
#             if i<=1000:
#                 i+=1
#                 k+=1
#             else:
#                 tepoch.set_description(f"Epoch {epoch}")

#                 img = img.to(device)
#                 recipe = recipe.to(device)

#                 y = torch.reshape(y, (y.shape[0],1))
#                 y = y.float()
#                 y = y.to(device)

#                 pred = model(img, recipe)
#                 loss = loss_fn(pred, y)
#                 optimizer.zero_grad()
#                 loss.backward()
#                 optimizer.step()
#                 LOSS = loss.item()

#                 tepoch.set_postfix(loss=loss.item())
#                 if i%500 == 0:
#                     torch.save({'epoch': epoch,'batch': k,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': LOSS,}, PATH)
#                 k+=1
#         torch.save({'epoch': epoch,'batch': k,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': LOSS,}, PATH + "_final")

                
def train_full(epoch, dataloader, model, loss_fxn, optimizer, PATH):
    i=0
    with tqdm(dataloader, unit="batch") as tepoch:
        for img,recipe, y in tepoch:
            tepoch.set_description(f"Epoch {epoch}")

            img = img.to(device)
            recipe = recipe.to(device)

            y = torch.reshape(y, (y.shape[0],1))
            y = y.float()
            y = y.to(device)

            pred = model(img, recipe)
            loss = loss_fn(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            LOSS = loss.item()

            tepoch.set_postfix(loss=loss.item())
            if i%500 == 0:
                torch.save({'epoch': epoch,'batch': i,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': LOSS,}, PATH)
            i+=1
        torch.save({'epoch': epoch,'batch': i,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': LOSS,}, PATH)
        
    

In [19]:
# checkpoint = torch.load('models_ins/model.pt')

In [20]:
# transformer_model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [21]:
epochs = 1
PATH = 'models_ins2/model.pt'
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_full(t,train_dataloader, transformer_model, loss_fn, optimizer, PATH)
    
    #test(test_dataloader, model, loss_fn)
print("Done!")

  0%|                                               | 0/2200 [00:00<?, ?batch/s]

Epoch 1
-------------------------------


Epoch 0:  17%|██▎           | 372/2200 [59:18<5:26:34, 10.72s/batch, loss=0.692]Corrupt JPEG data: 22 extraneous bytes before marker 0xd9
Epoch 0:  28%|███▍        | 620/2200 [1:42:16<3:32:49,  8.08s/batch, loss=0.744]Corrupt JPEG data: premature end of data segment
Epoch 0:  44%|█████▎      | 963/2200 [2:27:40<2:44:07,  7.96s/batch, loss=0.685]Corrupt JPEG data: premature end of data segment
Epoch 0:  75%|████████▎  | 1652/2200 [3:58:51<1:12:03,  7.89s/batch, loss=0.699]Corrupt JPEG data: bad Huffman code
Epoch 0:  96%|████████████▌| 2123/2200 [5:01:16<09:50,  7.66s/batch, loss=0.698]Corrupt JPEG data: bad Huffman code
Epoch 0: 100%|█████████████| 2200/2200 [5:11:27<00:00,  8.49s/batch, loss=0.726]


Done!
