In [1]:
import os
import re
import json
import torch
from torch import nn
from PIL import Image
from ultralytics import YOLO
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from safetensors.torch import save_file, load_file
from transformers.tokenization_utils_base import BatchEncoding
from transformers import (PreTrainedModel, PretrainedConfig, AutoModelForSequenceClassification, 
                          AutoTokenizer, ViTForImageClassification, ViTImageProcessor, TrainingArguments,
                          Trainer)

from sklearn.metrics import f1_score
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


## Config
Updated config for new model structure

In [2]:
class MultimodalConfig(PretrainedConfig):
    model_type = "multimodal"

    def __init__(self,
                 models=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.models = models if models is not None else {}
        self.image_models = [i for i in self.models.keys() if 'prompt' not in i]

        self.resnet_model_paths = {i:v for i,v in self.models.items() if 'resnet' in i.lower()}
        self.vit_model_paths = {i:v for i,v in self.models.items() if 'vit' in i}
        self.yolo_model_paths = {i:v for i,v in self.models.items() if 'yolo' in i}

        self.text_models = [i for i in self.models.keys() if 'prompt' in i]
        self.nlp_transformers_model_paths = {i:v for i,v in self.models.items() if 'prompt' in i}
        
        self.class_names = ["PG", "PG13", "R", "X", "XXX"]
        self.label2id = {label: i for i, label in enumerate(self.class_names)}
        self.id2label = {i: label for label, i in self.label2id.items()}


In [3]:
model_paths = {i: os.path.join('../../multimodalComposite/models/', i) for i in os.listdir('../../multimodalComposite/models/') if 'multimodal' not in i and 'yoloRater' not in i}
print(model_paths)

{'.DS_Store': '../../multimodalComposite/models/.DS_Store', 'resNet50CV': '../../multimodalComposite/models/resNet50CV', 'june_yolo_cls': '../../multimodalComposite/models/june_yolo_cls', 'baseresNet50': '../../multimodalComposite/models/baseresNet50', 'promptRoberta': '../../multimodalComposite/models/promptRoberta', 'resNet18CV': '../../multimodalComposite/models/resNet18CV', 'promptTagBert': '../../multimodalComposite/models/promptTagBert', 'vitRater': '../../multimodalComposite/models/vitRater', 'promptBert': '../../multimodalComposite/models/promptBert', 'baseresNet18': '../../multimodalComposite/models/baseresNet18'}


In [None]:
## From our analysis, these are the 5 models we want to train
## megamix - vitRater	YoloCLS	resnet50CV	promptbertTag	promptRoberta
# model_paths = {key:model_paths[key] for key in ['vitRater', 'june_yolo_cls', 'resNet50CV', 'promptTagBert', 'promptRoberta']}

## top3f1 - vitRater	promptTagBert	promptRoberta
model_paths = {key:model_paths[key] for key in ['vitRater', 'promptTagBert', 'promptRoberta']}

## top3f1 (better pg, pg13 performance) - YOLOCLS	promptTagBert	promptRoberta
# model_paths = {key:model_paths[key] for key in ['june_yolo_cls', 'promptTagBert', 'promptRoberta']}

## top3Acc - vitRater	promptRoberta	YoloCLS
# model_paths = {key:model_paths[key] for key in ['vitRater', 'june_yolo_cls', 'promptRoberta']}

## top3Acc (best R performance) -vitRater	promptRoberta	baseResnet50
# model_paths = {key:model_paths[key] for key in ['vitRater', 'baseresNet50', 'promptRoberta']}


In [5]:
config = MultimodalConfig(models=model_paths)

## Dataset

fixed model paths and dataset to work with dataloader and hf trainer api

In [13]:
class CustomImageTextDataset(Dataset):
    def __init__(self, data_dir, target_transform=None):
        self.data_dir = data_dir        
        self.files = self.get_files()
        self.labels = self.get_files(True)
        self.target_transform = target_transform
        
    def get_files(self, labels = False):
        files = []
        for i in os.listdir(self.data_dir):
            if labels:
                imgs = [i for j in os.listdir(os.path.join(self.data_dir,i)) if j.endswith(".jpg")]
            else:
                imgs = [os.path.join(self.data_dir, i, j.split('.')[0]) for j in os.listdir(os.path.join(self.data_dir,i)) if j.endswith(".jpg")]
            files.extend(imgs)
        return files
    
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        img_path = self.files[idx] + '.jpg'
        txt_path = self.files[idx] + '.txt'
        tag_path = self.files[idx] + '_tags.txt'
        # image = Image.open(img_path)
        try:
            text = open(txt_path, 'r').read()
        except:
            text = ''
        try:
            tags = open(tag_path, 'r').read()
        except:
            tags = None
        label = self.labels[idx]
        if self.target_transform:
            label = self.target_transform[label]
            if isinstance(label, int):
                label = torch.tensor(label)
        return img_path, text, tags, label
    
    def __repr__(self):
        string = f"CustomImageTextDataset(num_samples={len(self.files)}, data_dir='{self.data_dir}')\n"
        return string + f"Unique Labels: {list(set(self.labels))} for phase: {os.path.basename(self.data_dir)}"
    

In [14]:
datadir = "./data/datasets/mayWithTags/train"
validir = "./data/datasets/mayWithTags/val"

In [15]:
training_data = CustomImageTextDataset(datadir, target_transform=config.label2id)
validation_data = CustomImageTextDataset(validir, target_transform=config.label2id)

FileNotFoundError: [Errno 2] No such file or directory: './data/datasets/mayWithTags/train'

In [8]:
# training_data[1000]

## Processor

created processor to simplify dataset AND model code. 

Processor now works like hf processors in that it takes an image path or text from a dataset and, depending on the models in the config/processor, we can process the modalities and serve them to the proper models. 

In [6]:
class MultimodalProcessor:
    def __init__(self, models = None):
        self.models = models if models is not None else {}
        self.image_models = [i for i in self.models.keys() if 'prompt' not in i]
        self.text_models = [i for i in self.models.keys() if 'prompt' in i]
        self.image_processors = self.load_image_processors()
        self.text_processors = self.load_text_processors()
    
    @staticmethod
    def clean_text(text: str) -> str:
        text = str(text)
        cleaned_text = re.sub(r"[():<>[\]]", " ", text)
        cleaned_text = cleaned_text.replace("\n", " ")
        cleaned_text = re.sub(r"\s+", " ", cleaned_text)
        cleaned_text = re.sub(r"\s*,\s*", ", ", cleaned_text)
        return cleaned_text.strip()         
    
    def load_image_processors(self):
        """Use either transforms native to the model architecture or the vit image processor
        NOTE: if you change the transforms an image undergoes during training, you should change them here as well
        """
        image_processors = {}
        for model_name in self.image_models:
            if 'vit' in model_name.lower():
                try:
                    processor = ViTImageProcessor.from_pretrained(self.models[model_name])
                    image_processors[model_name] = processor
                except Exception as e:
                    print(f"Error loading {model_name}: {e}")
            elif 'resnet' in model_name.lower() and 'resnet' not in image_processors.keys():
                try:
                    processor = transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                    ])
                    image_processors['resnet'] = processor
                except Exception as e:
                    print(f"Error loading {model_name}: {e}")
            elif 'resnet' in model_name.lower() and 'resnet' in image_processors.keys():
                continue
            elif 'yolo' and 'cls' in model_name.lower():
                image_processors[model_name] = YOLO(os.path.join(self.models[model_name], 'best_model_params.pt')).transforms
            else:
                #We expect DET Yolos to be None, as they just
                #take in the PIL.Image and return the output
                image_processors[model_name] = None
        return image_processors    

    def load_text_processors(self):
        text_processors = {}
        for model_name in self.text_models:
            try:
                tokenizer = AutoTokenizer.from_pretrained(self.models[model_name])
                text_processors[model_name] = tokenizer
            except Exception as e:
                print(f"Error loading {model_name}: {e}")
        return text_processors
    
    def __call__(self, img_path, text, tags, label):
        inputs = {}
        img = Image.open(img_path)
        
        for model in self.image_models:
            if 'vit' in model.lower():
                inputs[model] = self.image_processors[model](img,
                        return_tensors="pt")['pixel_values']
        
            elif 'resnet' in model.lower():
                inputs[model] = self.image_processors['resnet'](img).unsqueeze(0)
        
            elif 'yolo' and 'cls' in model.lower():
                inputs[model] = self.image_processors[model](img)
            else:
                inputs[model] = img_path
        
        for model in self.text_models:
            if 'tag' in model.lower() and tags:
                text_in = ' '.join([text, tags])
            else:
                text_in = text
            
            text_in = self.clean_text(text_in)
            inputs[model] = self.text_processors[model](text_in, return_tensors="pt", truncation=True, padding='max_length')
        
        inputs['labels'] = label
        return inputs

In [7]:
processor = MultimodalProcessor(models = model_paths)

In [18]:
processor(img_path= training_data[0][0],
          text= training_data[0][1],
          tags= training_data[0][2],
          label= training_data[0][3])

NameError: name 'training_data' is not defined

## Dataloader with collator

In [19]:
def custom_collator(batch):
    img_paths = [item[0] for item in batch]
    texts = [item[1] for item in batch]
    tags = [item[2] for item in batch]
    labels = [item[3] for item in batch]
    
    # Assuming processor is defined elsewhere and processes the data correctly
    processed_batch = [processor(img_paths[i], texts[i], tags[i], labels[i]) for i in range(len(img_paths))]
    
    out = {}
    
    for key in processed_batch[0].keys():
        if isinstance(processed_batch[0][key], torch.Tensor):
            out[key] = torch.stack([i[key] for i in processed_batch])
        elif isinstance(processed_batch[0][key], BatchEncoding):

            tmp = [i[key] for i in processed_batch]
           
            var = {k: [] for k in tmp[0].keys()}
            for k in var.keys():
                var[k] = torch.stack([i[k] for i in tmp])
                
            out[key] = var
        else:
            out[key] = [i[key] for i in processed_batch]
   
    return out

In [20]:
train_dataloader = DataLoader(training_data, batch_size=5, shuffle=True, collate_fn=custom_collator)
val_dataloader = DataLoader(validation_data, batch_size=5, shuffle=True, collate_fn=custom_collator)

NameError: name 'training_data' is not defined

In [14]:
for batch in train_dataloader:
    print(batch)
    print(batch.keys())
    break

{'baseresNet18': tensor([[[[[ 0.2453,  0.2453,  0.2453,  ...,  0.2453,  0.2453,  0.2453],
           [ 0.2453,  0.2453,  0.2453,  ...,  0.2453,  0.2453,  0.2453],
           [ 0.2453,  0.2453,  0.2453,  ...,  0.2453,  0.2453,  0.2453],
           ...,
           [ 0.2453,  0.2453,  0.2453,  ...,  0.2453,  0.2453,  0.2453],
           [ 0.2453,  0.2453,  0.2453,  ...,  0.2453,  0.2453,  0.2453],
           [ 0.2453,  0.2453,  0.2453,  ...,  0.2453,  0.2453,  0.2453]],

          [[ 0.3452,  0.3452,  0.3452,  ...,  0.3452,  0.3452,  0.3452],
           [ 0.3452,  0.3452,  0.3452,  ...,  0.3452,  0.3452,  0.3452],
           [ 0.3452,  0.3452,  0.3452,  ...,  0.3452,  0.3452,  0.3452],
           ...,
           [ 0.3452,  0.3452,  0.3452,  ...,  0.3452,  0.3452,  0.3452],
           [ 0.3452,  0.3452,  0.3452,  ...,  0.3452,  0.3452,  0.3452],
           [ 0.3452,  0.3452,  0.3452,  ...,  0.3452,  0.3452,  0.3452]],

          [[ 0.7925,  0.7925,  0.7925,  ...,  0.7925,  0.7925,  0.7925]

  out['labels'] = torch.tensor(out['label'])


In [15]:
for batch in val_dataloader:
    print(batch)
    print(batch.keys())
    break

{'baseresNet18': tensor([[[[[-1.1247, -1.0390, -1.0390,  ..., -1.0733, -1.1760, -1.1589],
           [-1.1075, -1.0733, -1.0048,  ..., -1.1247, -1.1589, -1.1418],
           [-1.1075, -1.0390, -1.0219,  ..., -1.1075, -1.1247, -1.1247],
           ...,
           [-0.3198, -0.2684, -0.2171,  ..., -1.6555, -0.8335,  1.4783],
           [-0.2856, -0.1486, -0.1828,  ..., -0.9020, -1.0390,  1.3242],
           [-0.3369, -0.2684, -0.2342,  ..., -0.5253, -0.9705,  1.1700]],

          [[-0.7052, -0.7227, -0.6877,  ..., -0.7402, -0.8452, -0.8452],
           [-0.6877, -0.7227, -0.6702,  ..., -0.8102, -0.8452, -0.8277],
           [-0.6527, -0.7052, -0.6877,  ..., -0.8277, -0.8452, -0.8452],
           ...,
           [-0.2150, -0.1800, -0.1099,  ..., -1.0553, -0.6001,  1.8508],
           [-0.1450, -0.0399, -0.0749,  ...,  0.0476, -0.6001,  1.6583],
           [-0.1975, -0.1450, -0.1275,  ...,  0.8354, -0.2675,  1.5182]],

          [[-0.3055, -0.2707, -0.3055,  ..., -0.3578, -0.4624, -0.4624]

  out['labels'] = torch.tensor(out['label'])


## Model

In [8]:
class MultimodalModel(PreTrainedModel):
    config_class = MultimodalConfig

    def __init__(self, config: MultimodalConfig, device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')):
        super().__init__(config)
        self.num_classes = 5  # Example with 5 classes
        self.features = {}

        # Initialize the models and set requires_grad = False to freeze them
        self.resnet_models = nn.ModuleDict()
        if config.resnet_model_paths:
            for model_id, model_path in config.resnet_model_paths.items():
                ##Load model architecture
                model = models.resnet50() if "resnet50" in model_id.lower() else models.resnet18()
                prev_fc = model.fc
                model.fc = nn.Linear(model.fc.in_features, self.num_classes)
                ##Load weights
                model.load_state_dict(torch.load(f"{model_path}/best_model_params.pt", map_location='cpu'))
                model.fc = prev_fc
                ##set output to identity to return layer before prediction
                model.fc = nn.Identity()
                ##freeze model
                for param in model.parameters(): 
                    param.requires_grad = False
                model.to(device)
                ##add model to model dict
                self.resnet_models[model_id] = model
                ##store number of features
                self.features[model_id] = prev_fc.in_features

        self.vit_models = nn.ModuleDict()
        if config.vit_model_paths:
            for model_id, model_path in config.vit_model_paths.items():
                model = ViTForImageClassification.from_pretrained(model_path)
                for param in model.parameters():
                    param.requires_grad = False
                model.to(device)
                self.vit_models[model_id] = model
                self.features[model_id] = model.classifier.in_features

        self.yolo_models = nn.ModuleDict()
        if config.yolo_model_paths:
            for model_id, model_path in config.yolo_model_paths.items():
                if 'cls' in model_id:
                    model = YOLO(os.path.join(model_path, 'best_model_params.pt'))
                    sequential_model = model.model.model
                    self.features[model_id] = sequential_model[-1].linear.in_features
                    
                    self.replace_yolo_last_linear_with_identity(sequential_model)
                    for param in sequential_model.parameters():
                        param.requires_grad = False
                    model.to(device)                        
                    self.yolo_models[model_id] = sequential_model

                elif 'det' in model_id:
                    model = YOLO(os.path.join(model_path, 'best_model_params.pt'))
                    model.to(device)
                    self.yolo_models[model_id] = model
                    self.features[model_id] = None
                    
        self.nlp_transformers_models = nn.ModuleDict()
        if config.nlp_transformers_model_paths:
            for model_id, model_path in config.nlp_transformers_model_paths.items():
                model = AutoModelForSequenceClassification.from_pretrained(model_path)
                for param in model.parameters():
                    param.requires_grad = False
                model.to(device)
                self.nlp_transformers_models[model_id] = model
                self.features[model_id] = self.get_classifier_in_features(model)

        # MLP layers (only this part will be fine-tuned)
        self.mlp = nn.Sequential(
            nn.Linear(sum(self.features.values()), 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, config.num_labels)
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def get_classifier_in_features(self, model):
        """
        Function to get the number of input features to the classifier layer.
        Handles different model structures such as BERT and RoBERTa.
        """
        try:
            # RoBERTa's classifier
            return model.classifier.out_proj.in_features
        except AttributeError:
            pass

        try:
            # BERT's classifier
            return model.classifier.linear.in_features
        except AttributeError:
            pass

        try:
            # DistilBERT's classifier
            return model.classifier.in_features
        except AttributeError:
            pass

        try: 
            #ViT's classifier
            return model.classifier.in_features
        except AttributeError:
            pass

        try: 
            return model.fc.in_features
        except AttributeError:
            pass

        # Add other model types as needed
        raise ValueError("Unknown model structure or unsupported model type.")
    
    def forward(self, **inputs):
        
        self.mlp = self.mlp.to(self.device)
        features = []        
        labels = inputs.pop('labels', None)
        
        if labels is not None:
            labels = labels.to(self.device)
            
        for key in inputs.keys():
            input = inputs[key]
            
            if key in self.resnet_models.keys():
                with torch.no_grad():
                    resnet_features = self.resnet_models[key](input.squeeze(1).to(self.device))
                features.append(resnet_features)
            
            if key in self.vit_models.keys():
                with torch.no_grad():
                    vit_features = self.vit_models[key](pixel_values=input.squeeze(1).to(self.device),
                                            output_hidden_states=True).hidden_states[-1][:, 0, :]
                features.append(vit_features)

            if key in self.yolo_models.keys():
                with torch.no_grad():
                    yolo_features = self.yolo_models[key](input.to(self.device))
                features.append(yolo_features)
            
            if key in self.nlp_transformers_models.keys():
                with torch.no_grad():
                    nlp_features = self.nlp_transformers_models[key](input['input_ids'].squeeze(1).to(self.device),
                    input['attention_mask'].squeeze(1).to(self.device),
                    output_hidden_states=True).hidden_states[-1][:, 0, :]
                features.append(nlp_features)
            
        ##Concat features
        features = torch.cat(features, dim=1)
        logits = self.mlp(features)

        loss = None
        if labels is not None:
            loss = self.loss_fn(logits.view(-1, self.config.num_labels), labels.view(-1))
    
        return {"loss": loss, "logits": logits}

    @staticmethod
    def replace_yolo_last_linear_with_identity(model):
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                parent_name, child_name = name.rsplit('.',1)
                parent = model
                for part in parent_name.split('.'):
                    parent = getattr(parent, part)
                setattr(parent, child_name, nn.Identity())
                return True
        return False
    
    def save_pretrained(self, save_directory, state_dict=None, safe_serialization=False):
        os.makedirs(save_directory, exist_ok=True)
        if state_dict is None:
            state_dict = self.state_dict()
        if safe_serialization:
            save_file(state_dict, os.path.join(save_directory, "model.safetensors"))
        else:
            print("save torch")
            torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
        self.config.save_pretrained(save_directory)
    
    @classmethod
    def from_pretrained(cls, save_directory, **kwargs):
        config = cls.config_class.from_pretrained(save_directory)
        model = cls(config)
        if kwargs.get("safe_serialization", False):
            state_dict = load_file(os.path.join(save_directory, "model.safetensors"))
        else:
            state_dict = torch.load(os.path.join(save_directory, "pytorch_model.bin"), map_location='cpu')
        model.load_state_dict(state_dict)
        return model

In [9]:
model = MultimodalModel(config)

## Check Training

In [18]:
def compute_metrics(eval_pred):
    print("Computing metrics...")
    try:
        logits, labels = eval_pred
        if isinstance(labels, dict):
            labels = labels['label']  # Adjust this key if necessary
        print(f"Logits shape: {logits.shape}, Labels shape: {labels.shape}")
        predictions = np.argmax(logits, axis=-1)
        eval_loss = np.mean(logits - labels)
        accuracy = np.mean(predictions == labels)
        f1 = f1_score(labels, predictions, average='weighted')
        print(f"Computed metrics - Accuracy: {accuracy}, F1: {f1}")
        return {
            "eval_acc": accuracy,
            "eval_f1": f1,
            "eval_loss": eval_loss
        }
    except Exception as e:
        print(f"Error in compute_metrics: {e}")
        import traceback
        traceback.print_exc()
        return {
            "eval_acc": 0,
            "eval_f1": 0,
            "eval_loss": 0
        }

In [19]:
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    warmup_steps=10,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=2,
    evaluation_strategy="steps",
    eval_steps=10,  # Set eval_steps to 10
    save_steps=10,
    save_safetensors=True,  # Ensure this is supported by your Transformers version
    disable_tqdm=False,  # Ensure progress bars are shown
    logging_first_step=True,  # Log the first step
)




In [20]:

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=training_data,
    eval_dataset=validation_data,
    data_collator=custom_collator,
)

In [21]:
trainer.train()

  out['labels'] = torch.tensor(out['label'])
  context_layer = torch.nn.functional.scaled_dot_product_attention(
Could not estimate the number of tokens of the input, floating-point operations will not be computed
  0%|          | 1/668 [00:03<41:11,  3.71s/it]

{'loss': 1.5779, 'grad_norm': 5.221979141235352, 'learning_rate': 5e-06, 'epoch': 0.0}


  0%|          | 2/668 [00:06<36:30,  3.29s/it]

{'loss': 1.6059, 'grad_norm': 5.409246444702148, 'learning_rate': 1e-05, 'epoch': 0.0}


  1%|          | 4/668 [00:12<33:38,  3.04s/it]

{'loss': 1.5862, 'grad_norm': 5.218859672546387, 'learning_rate': 2e-05, 'epoch': 0.01}


  1%|          | 6/668 [00:18<32:35,  2.95s/it]

{'loss': 1.5893, 'grad_norm': 5.059216499328613, 'learning_rate': 3e-05, 'epoch': 0.01}


  1%|          | 8/668 [00:24<32:29,  2.95s/it]

{'loss': 1.5346, 'grad_norm': 5.1202826499938965, 'learning_rate': 4e-05, 'epoch': 0.01}


  1%|▏         | 10/668 [00:30<32:14,  2.94s/it]

{'loss': 1.526, 'grad_norm': 4.642451286315918, 'learning_rate': 5e-05, 'epoch': 0.01}


                                                
  1%|▏         | 10/668 [05:50<32:14,  2.94s/it] 

{'eval_runtime': 320.8396, 'eval_samples_per_second': 11.754, 'eval_steps_per_second': 0.368, 'epoch': 0.01}


  out['labels'] = torch.tensor(out['label'])
  2%|▏         | 12/668 [06:00<13:07:27, 72.02s/it] 

{'loss': 1.4351, 'grad_norm': 4.700843811035156, 'learning_rate': 4.984802431610942e-05, 'epoch': 0.02}


  2%|▏         | 13/668 [06:03<9:17:56, 51.11s/it] 

KeyboardInterrupt: 