In [1]:
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from tqdm.notebook import tqdm
import numpy as np
import json
import os
import os.path as osp

In [2]:
from dataset import MultiModalDataset
from model import FoodItemTagModel

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models, transforms

In [4]:
# Root of the food_item_tag project; `data/` and `imgs/` are expected beneath it.
# Set the FOOD_ITEM_TAG_BASE environment variable to point at your own copy
# instead of editing a machine-specific absolute path; the original author's
# location is kept only as the fallback.
base_path = os.environ.get(
    "FOOD_ITEM_TAG_BASE",
    "C:\\Users\\Mercedez\\Downloads\\santhosh\\food_item_tag",
)

In [5]:
# Per-channel RGB mean / std of the ImageNet training set: the standard
# normalisation constants used with ImageNet-pretrained torchvision backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Image preprocessing pipelines, keyed by data split.
#   'train': random-resized 224x224 crop + random horizontal flip (augmentation)
#   'val'  : deterministic resize (shorter side to 256) then 224x224 centre crop
# Both end with tensor conversion and ImageNet normalisation.
image_transforms = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
)


In [6]:
def train_epoch(model, dataloader, loss_func, optim, device = 'cpu'):
    """Run one optimisation pass over ``dataloader`` and return the mean loss.

    Each batch is a mapping whose entries ``image``, ``price``, ``label``,
    ``input_ids``, ``attention_mask`` and ``token_type_ids`` are moved to
    ``device``; the three tokenizer fields are regrouped into a dict and passed
    to the model as its ``text`` argument.

    Args:
        model: module called as ``model(image, text, price)``.
        dataloader: iterable of batch mappings as described above.
        loss_func: criterion called as ``loss_func(output, label)``. Assumed to
            return a mean-reduced scalar, since it is re-weighted by batch size.
        optim: optimizer over ``model``'s parameters (this parameter shadows
            the ``torch.optim`` module alias inside this function only).
        device: device the batch tensors are moved to.

    Returns:
        float: sample-weighted mean training loss over the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    n_samples = 0
    for batch in tqdm(dataloader):
        image = batch['image'].to(device)
        price = batch['price'].to(device)
        label = batch['label'].to(device)
        text = {
            key: batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'token_type_ids')
        }
        optim.zero_grad()
        # Force autograd on even if the caller disabled it globally.
        with torch.set_grad_enabled(True):
            output = model(image, text, price)
            loss = loss_func(output, label)
            loss.backward()
            optim.step()
        batch_size = label.size(0)
        # Weight by batch size so a short final batch doesn't skew the mean.
        running_loss += loss.item() * batch_size
        n_samples += batch_size
    if n_samples == 0:
        raise ValueError("train_epoch: dataloader yielded no samples")
    epoch_loss = running_loss / n_samples
    print(f"Training Loss - {epoch_loss}")
    return epoch_loss

def validate_epoch(model, dataloader, loss_func, device = 'cpu'):
    """Evaluate ``model`` on ``dataloader`` without gradients; return mean loss.

    Batches have the same layout as in ``train_epoch``: mapping entries
    ``image``, ``price``, ``label`` plus the tokenizer fields ``input_ids``,
    ``attention_mask`` and ``token_type_ids``, which are grouped into the
    ``text`` dict passed to ``model(image, text, price)``.

    Args:
        model: module to evaluate; switched to eval mode.
        dataloader: iterable of batch mappings.
        loss_func: criterion called as ``loss_func(output, label)``. Assumed to
            return a mean-reduced scalar, since it is re-weighted by batch size.
        device: device the batch tensors are moved to.

    Returns:
        float: sample-weighted mean validation loss.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    n_samples = 0
    for batch in tqdm(dataloader):
        image = batch['image'].to(device)
        price = batch['price'].to(device)
        label = batch['label'].to(device)
        text = {
            key: batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'token_type_ids')
        }
        with torch.no_grad():
            output = model(image, text, price)
            loss = loss_func(output, label)
        batch_size = label.size(0)
        running_loss += loss.item() * batch_size
        n_samples += batch_size
    if n_samples == 0:
        raise ValueError("validate_epoch: dataloader yielded no samples")
    epoch_loss = running_loss / n_samples
    print(f"Validation Loss - {epoch_loss}")
    return epoch_loss

In [7]:
def train_image_model(device='cpu', epochs = 10, batch_size=24, lr=0.003):
    """Train ``FoodItemTagModel`` and checkpoint the best-validation epoch.

    Per-class positive weights are read from
    ``<base_path>/data/class_weights.json`` and used as ``pos_weight`` for
    ``BCEWithLogitsLoss``. Training uses SGD (momentum 0.9) with the learning
    rate multiplied by 0.95 after every epoch. Whenever the validation loss
    improves on the best seen so far, the weights are written to
    ``best_model.pth`` in the current working directory.

    Args:
        device: torch device for the model, loss weights and batches.
        epochs: number of training epochs.
        batch_size: mini-batch size for both data loaders (default 24).
        lr: initial SGD learning rate (default 0.003).

    Returns:
        The model holding the weights from the FINAL epoch, which is not
        necessarily the best one. Load ``best_model.pth`` for that.
    """
    # NOTE(review): the meaning of (512, 512, 48) depends on FoodItemTagModel's
    # signature, which is not visible here. 48 presumably is the number of tags
    # and must match the length of class_weights.json -- confirm.
    model = FoodItemTagModel(512, 512, 48).to(device)

    # Build paths portably instead of hard-coding Windows separators.
    data_dir = osp.join(base_path, "data")
    image_dir = osp.join(base_path, "imgs")

    # Pre-computed class weights. json.load accepts the file whether it is
    # written on a single line or pretty-printed across several.
    with open(osp.join(data_dir, "class_weights.json")) as fh:
        weights = json.load(fh)
    # pos_weight must be floating point; an all-integer JSON list would
    # otherwise produce an int64 tensor.
    weights = torch.tensor(weights, dtype=torch.float32, device=device)
    loss_func = nn.BCEWithLogitsLoss(pos_weight=weights)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.95)

    training_dataset = MultiModalDataset(
        osp.join(data_dir, "training_data.csv"), image_dir, image_transforms["train"])
    validation_dataset = MultiModalDataset(
        osp.join(data_dir, "validation_data.csv"), image_dir, image_transforms["val"])

    training_dataloader = torch.utils.data.DataLoader(
        training_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    # Order does not affect the validation loss; keep evaluation deterministic.
    validation_dataloader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    best_score = float("inf")
    for epoch in range(epochs):
        print(f"Epoch - {epoch+1}")
        train_epoch(model, training_dataloader, loss_func, optimizer, device)
        score = validate_epoch(model, validation_dataloader, loss_func, device)
        scheduler.step()
        if score < best_score:
            # Track the best score, otherwise every epoch counts as "better"
            # and best_model.pth just ends up holding the latest weights.
            best_score = score
            # Save to a per-epoch file first, then move it over the previous
            # best, so an interrupted save never truncates best_model.pth.
            tmp_path = f"model_{epoch}.pth"
            torch.save(model.state_dict(), tmp_path)
            os.replace(tmp_path, "best_model.pth")

    return model

In [None]:
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = train_image_model(epochs = 20, device=device)
# Persist the final-epoch weights; train_image_model separately writes its own
# best_model.pth checkpoint during training.
torch.save(model.state_dict(), "model.pth")

Some weights of the model checkpoint at cahya/bert-base-indonesian-522M were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Asking to truncate to max_le

Epoch - 1


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.6044637018346214


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.5336963149028912
Epoch - 2


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.17234062908912937


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.09926054097488245
Epoch - 3


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.10258827897883978


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.0986476982709833
Epoch - 4


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.1020623714521176


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.09845744428826654
Epoch - 5


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.10161212145231811


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.09847134906414705
Epoch - 6


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.10121503627473451


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.09835176985950446
Epoch - 7


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.10124152375012552


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.09817874254671632
Epoch - 8


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.10100065956178092


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.0981305985927677
Epoch - 9


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.1008837329324409


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.09808149623747775
Epoch - 10


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.10057230927746061


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.09806346770798494
Epoch - 11


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.1007575128528159


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.09797884565578696
Epoch - 12


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.10078865978633164


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.09799759070211589
Epoch - 13


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.10059132563645555


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.09789924321880908
Epoch - 14


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.10048047487254651


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.09786901574796035
Epoch - 15


  0%|          | 0/331 [00:00<?, ?it/s]

Training Loss - 0.10056939990455188


  0%|          | 0/83 [00:00<?, ?it/s]

Validation Loss - 0.09784142547866764
Epoch - 16


  0%|          | 0/331 [00:00<?, ?it/s]