In [1]:
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, transforms
from torchvision.models import inception_v3, Inception_V3_Weights
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
import matplotlib.pyplot as plt
import time
import os
import copy
import random
from PIL import Image
from torch.utils.data import (TensorDataset, 
                              Dataset, 
                              Subset,
                              random_split,
                              DataLoader,
                              RandomSampler, 
                              SequentialSampler, 
                              )
from transformers import BertTokenizer, BertForSequenceClassification
from models import initialize_vision_model, initialize_language_model
from GarbageUtils import GarbageDataset, split_dataset, GarbageImageFolder, append_value



device = torch.device("cuda" if torch.cuda.is_available else "cpu")

In [2]:
print(f'Device: {device}')
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

torch.cuda.empty_cache()

Device: cuda
PyTorch Version:  1.13.1+cu117
Torchvision Version:  0.14.1+cu117


In [3]:
data_dir = "./data"
vision_model_name = "inception"
language_model_name = "bert-base-uncased"
num_classes = 4
batch_size = 16
epochs = 5
feature_extract = True

#Note: this won't work with efficientnet, can't multiply the matrices.
out_features = 2052

In [4]:
if language_model_name == "bert-base-uncased":
    out_features = 2052

In [5]:
vision_model, input_size, vision_out_features = initialize_vision_model(vision_model_name, num_classes, feature_extract, 
                                                                        use_pretrained=True)
language_model, tokenizer = initialize_language_model(language_model_name, num_classes, multimodal=True)


Initializing InceptionV3 with weights=Inception_V3_Weights.DEFAULT...
Input size = 299


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at


Initializing Bert-Base-Uncased...


In [6]:
vision_model

Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stri

In [7]:
def validate_image(path):
    try:
        im = Image.open(path)
        return True
    except:
        return False

In [8]:
image_dataset = GarbageImageFolder(data_dir, is_valid_file=validate_image)

In [9]:
image_dataset

Dataset GarbageImageFolder
    Number of datapoints: 5312
    Root location: ./data

In [10]:
classes =  image_dataset.classes
num_classes = len(classes)
print(classes)
print(f'Num of Classes: {num_classes}')

['black', 'blue', 'green', 'other']
Num of Classes: 4


In [11]:
image_dataset.class_to_idx

{'black': 0, 'blue': 1, 'green': 2, 'other': 3}

In [12]:
a = slice(-3, -1)
image_dataset[0:6]

[(<PIL.Image.Image image mode=RGB size=800x800 at 0x22DDCDE8100>, 0),
 (<PIL.Image.Image image mode=RGB size=800x800 at 0x22DEA0008B0>, 0),
 (<PIL.Image.Image image mode=RGB size=800x800 at 0x22DEA000C10>, 0),
 (<PIL.Image.Image image mode=RGB size=800x800 at 0x22DEA000D90>, 0),
 (<PIL.Image.Image image mode=RGB size=800x800 at 0x22DEA000AC0>, 0),
 (<PIL.Image.Image image mode=RGB size=1734x1301 at 0x22DEA000760>, 0)]

In [13]:
image_dataset[a]

[(<PIL.Image.Image image mode=RGB size=1155x1600 at 0x22D8011ACA0>, 3),
 (<PIL.Image.Image image mode=RGB size=2615x3044 at 0x22D8011ACD0>, 3)]

In [14]:
train_set, val_set, test_set = split_dataset(image_dataset.imgs, test_size=0.2)

In [15]:
def get_dataloaders(input_size, train_set, val_set, test_set):
    from torchvision import datasets, transforms
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    
    train_set = GarbageDataset(train_set, is_subset=False, transform=data_transforms['train'])
    val_set = GarbageDataset(val_set,  is_subset=False, transform=data_transforms['val'])
    test_set = GarbageDataset(test_set,  is_subset=False, transform=data_transforms['val'])
    
    print("Loading data...")
    print(f'Train set size: {len(train_set)}')
    print(f'Val set size: {len(val_set)}')
    print(f'Test set size: {len(test_set)}')
    
    dataloaders_dict = {
        'train': DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers=4, drop_last=True),
        'val': DataLoader(val_set, batch_size = batch_size, shuffle=False, num_workers=4, drop_last=True)
    }
    
    test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4, drop_last=True)
    
#     print("Loading Datasets and Initializing DataLoaders...")
    return test_dataloader, dataloaders_dict

test_dataloader, dataloaders_dict = get_dataloaders(input_size, train_set, val_set, test_set)

Loading data...
Train set size: 3187
Val set size: 1062
Test set size: 1062


In [16]:
class MultiModalGarbageModel(torch.nn.Module):
    def __init__(self, num_classes, text_module, vision_module, text_module_name, vision_module_name,
                 out_features_combined, dropout_p=None):
        super(MultiModalGarbageModel, self).__init__()
        self.text_module = text_module
        self.vision_module = vision_module
        self.text_module_name = text_module_name
        self.vision_module_name = vision_module_name
#         self.fc = torch.nn.Linear(out_features_combined, int(out_features_combined/2))
#         self.fc2 = torch.nn.Linear(int(out_features_combined/2), int(out_features_combined/4))
#         self.fc3 = torch.nn.Linear(int(out_features_combined/4), num_classes)
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(out_features_combined, int(out_features_combined/2)),
            nn.ReLU(),
            nn.Linear(int(out_features_combined/2), num_classes),
        )

    def forward(self, vision_data, text_data, attention_mask):
        text_out, vision_out = 0, 0
        
        # get output from vision model
        if self.vision_module_name == "inception":
            self.vision_module.aux_logits = False
            vision_out = self.vision_module(vision_data)
        else:
            vision_out = self.vision_module(vision_data)

        # get output from text model    
        if self.text_module_name == "bert-base-uncased":
            out = self.text_module(text_data, attention_mask)
            text_out = out[0]
        else:
            text_out = self.text_module(text_data, attention_mask)
            
        combined = torch.cat((vision_out, text_out), dim=1)
        combined = combined.view(combined.size(0), -1)
        logits = self.linear_relu_stack(combined)
#         combined = self.fc(combined)
        
        return logits

In [17]:
model = MultiModalGarbageModel(num_classes, language_model, vision_model, language_model_name,
                                          vision_model_name, out_features)

In [18]:
def train_model(model, dataloaders, criterion, optimizer, epochs=25, is_inception=False, result_dict=None):
    since = time.time()
    
    val_acc_history = list()
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    for epoch in range(1, epochs+1):
#         print(f'Epoch {epoch+1}/{epochs}')
#         print("-" * 10)
        
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            
            #Iterate over the data # image_file, label, input_ids, attention_mask, file_name
            for inputs, labels, in_ids, att_mask in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                in_ids = in_ids.to(device)
                att_mask = att_mask.to(device)
                
                optimizer.zero_grad() # to zero the parameter gradients
                
                # forward, track history if only in train mode
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs, in_ids, att_mask)
                    loss = criterion(outputs, labels)
#                     if is_inception and phase == "train":
#                         outputs, aux_outputs = model(inputs)
#                         loss1 = criterion(outputs, labels)
#                         loss2 = criterion(aux_outputs, labels)
#                         loss = loss1 + 0.4*loss2
#                     else:
#                         outputs = model(inputs)
#                         loss = criterion(outputs, labels)
                    
                    _, preds = torch.max(outputs, 1)
                    
                    # backward, optimize only if in train mode
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                        
                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
#             print(f'{phase} Loss: {epoch_loss} Accuracy: {epoch_acc}')
            print('-' * 59)
            print('| Epoch {:3d}/{:3d} | {} Loss: {:8.3f} | {} Accuracy {:8.3f} |'.format(
                epoch, epochs, phase, epoch_loss, phase, epoch_acc))
            print('-' * 59)
            
            if result_dict is not None:
                append_value(result_dict, "Epoch", epoch)
                append_value(result_dict, phase+" Accuracy", epoch_acc)
                append_value(result_dict, phase+" Loss", epoch_loss)
            
            # deepcopy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)
        print()
        
    time_elapsed = time.time() - since
    print('Training Complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best Validation Accuracy: {:.04f}'.format(best_acc))
    
    model.load_state_dict(best_model_wts)
    
    return model, val_acc_history

In [19]:
# def train_model(model, dataloaders, criterion, optimizer, epochs=25, is_inception=False):
#     since = time.time()
    
#     val_acc_history = list()
    
#     best_model_wts = copy.deepcopy(model.state_dict())
#     best_acc = 0.0
    
#     for epoch in range(epochs):
#         print(f'Epoch {epoch}/{epochs - 1}')
#         print("-" * 10)
        
#         for phase in ["train", "val"]:
#             if phase == "train":
#                 model.train()
#             else:
#                 model.eval()
#             running_loss = 0.0
#             running_corrects = 0
            
#             #Iterate over the data # image_file, label, input_ids, attention_mask, file_name
#             for inputs, labels, _, _ in dataloaders[phase]:
#                 inputs = inputs.to(device)
#                 labels = labels.to(device)
                
#                 optimizer.zero_grad() # to zero the parameter gradients
                
#                 # forward, track history if only in train mode
#                 with torch.set_grad_enabled(phase == "train"):
#                     if is_inception and phase == "train":
#                         outputs, aux_outputs = model(inputs)
#                         loss1 = criterion(outputs, labels)
#                         loss2 = criterion(aux_outputs, labels)
#                         loss = loss1 + 0.4*loss2
#                     else:
#                         outputs = model(inputs)
#                         loss = criterion(outputs, labels)
                    
#                     _, preds = torch.max(outputs, 1)
                    
#                     # backward, optimize only if in train mode
#                     if phase == "train":
#                         loss.backward()
#                         optimizer.step()
                        
#                 # statistics
#                 running_loss += loss.item() * inputs.size(0)
#                 running_corrects += torch.sum(preds == labels.data)
                
#             epoch_loss = running_loss / len(dataloaders[phase].dataset)
#             epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
#             print(f'{phase} Loss: {epoch_loss} Accuracy: {epoch_acc}')
            
#             # deepcopy the model
#             if phase == 'val' and epoch_acc > best_acc:
#                 best_acc = epoch_acc
#                 best_model_wts = copy.deepcopy(model.state_dict())
#             if phase == 'val':
#                 val_acc_history.append(epoch_acc)
#         print()
        
#     time_elapsed = time.time() - since
#     print('Training Complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
#     print('Best Validation Acuuracy: {:.04f}'.format(best_acc))
    
#     model.load_state_dict(best_model_wts)
    
#     return model, val_acc_history

In [20]:
# send model to device: gpu or cpu
model = model.to(device)

params_to_update = model.parameters()
print("Parameters to learn:")
if feature_extract:
    params_to_update = list()
    for name, param in model.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model.named_parameters():
        if param.requires_grad == True:
            print("\t", name)
            
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)

Parameters to learn:
	 text_module.bert.embeddings.word_embeddings.weight
	 text_module.bert.embeddings.position_embeddings.weight
	 text_module.bert.embeddings.token_type_embeddings.weight
	 text_module.bert.embeddings.LayerNorm.weight
	 text_module.bert.embeddings.LayerNorm.bias
	 text_module.bert.encoder.layer.0.attention.self.query.weight
	 text_module.bert.encoder.layer.0.attention.self.query.bias
	 text_module.bert.encoder.layer.0.attention.self.key.weight
	 text_module.bert.encoder.layer.0.attention.self.key.bias
	 text_module.bert.encoder.layer.0.attention.self.value.weight
	 text_module.bert.encoder.layer.0.attention.self.value.bias
	 text_module.bert.encoder.layer.0.attention.output.dense.weight
	 text_module.bert.encoder.layer.0.attention.output.dense.bias
	 text_module.bert.encoder.layer.0.attention.output.LayerNorm.weight
	 text_module.bert.encoder.layer.0.attention.output.LayerNorm.bias
	 text_module.bert.encoder.layer.0.intermediate.dense.weight
	 text_module.bert.encode

In [None]:
result_dict = {}
criterion = nn.CrossEntropyLoss()
model, hist = train_model(model, dataloaders_dict, criterion, optimizer_ft, epochs = 100, result_dict=result_dict,)

-----------------------------------------------------------
| Epoch   1/100 | train Loss:    1.235 | train Accuracy    0.459 |
-----------------------------------------------------------
-----------------------------------------------------------
| Epoch   1/100 | val Loss:    1.125 | val Accuracy    0.521 |
-----------------------------------------------------------

-----------------------------------------------------------
| Epoch   2/100 | train Loss:    1.084 | train Accuracy    0.551 |
-----------------------------------------------------------
-----------------------------------------------------------
| Epoch   2/100 | val Loss:    0.971 | val Accuracy    0.628 |
-----------------------------------------------------------

-----------------------------------------------------------
| Epoch   3/100 | train Loss:    0.980 | train Accuracy    0.597 |
-----------------------------------------------------------
-----------------------------------------------------------
| Epoch   3

In [None]:
# for x in range(1, 1062):
#     if 1062%x == 0:
#         print(f'rem: {1062%x} --> {x}')

In [None]:
model_copy = model

In [None]:
PATH = "model.pth"

In [None]:
torch.save(model.state_dict(), PATH)

In [None]:
model = MultiModalGarbageModel(num_classes, language_model, vision_model, language_model_name,
                                          vision_model_name, out_features)
model.load_state_dict(torch.load(PATH))
model.to(device)
model.eval()