# Transformer Design & Architecture

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer, ViTModel, Wav2Vec2Model, Wav2Vec2Tokenizer
import torchvision.transforms as T
from PIL import Image
import librosa

class MultimodalTransformer(nn.Module):
    """Classify a (text, image, audio) triple with a fused transformer.

    Each modality is encoded by its own pretrained backbone, linearly
    projected to ``hidden_dim``, concatenated along the sequence axis and
    passed through an ``nn.Transformer``. The first fused position (the first
    token produced by the text encoder) feeds a linear classifier.

    Args:
        text_model_name: Checkpoint name for ``BertTokenizer`` / ``BertModel``.
        image_model_name: Checkpoint name for ``ViTModel``.
        audio_model_name: Checkpoint name for ``Wav2Vec2Tokenizer`` /
            ``Wav2Vec2Model``.
        hidden_dim: Shared width every modality is projected to; also the
            ``d_model`` of the fusion transformer.
        n_heads: Number of attention heads in the fusion transformer.
        n_layers: Number of encoder layers and of decoder layers in the
            fusion transformer.
        num_classes: Size of the classifier output (defaults to 10, the value
            that used to be hard-coded).
    """

    def __init__(self, text_model_name, image_model_name, audio_model_name,
                 hidden_dim=768, n_heads=8, n_layers=6, num_classes=10):
        super().__init__()

        # Tokenizers and models for text, image, and audio
        self.text_tokenizer = BertTokenizer.from_pretrained(text_model_name)
        self.text_model = BertModel.from_pretrained(text_model_name)

        self.image_model = ViTModel.from_pretrained(image_model_name)
        self.image_transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        self.audio_tokenizer = Wav2Vec2Tokenizer.from_pretrained(audio_model_name)
        self.audio_model = Wav2Vec2Model.from_pretrained(audio_model_name)

        # Linear projection layers to align dimensions for fusion
        self.text_proj = nn.Linear(self.text_model.config.hidden_size, hidden_dim)
        self.image_proj = nn.Linear(self.image_model.config.hidden_size, hidden_dim)
        self.audio_proj = nn.Linear(self.audio_model.config.hidden_size, hidden_dim)

        # Transformer layers for cross-modal fusion
        self.transformer = self._build_fusion_transformer(hidden_dim, n_heads, n_layers)

        # Final classification layer
        self.classifier = nn.Linear(hidden_dim, num_classes)

    @staticmethod
    def _build_fusion_transformer(hidden_dim, n_heads, n_layers):
        """Return the fusion ``nn.Transformer`` configured for batch-first input.

        ``forward`` concatenates features on ``dim=1`` and reads position 0
        with ``[:, 0, :]``, i.e. tensors are laid out (batch, sequence,
        feature). ``nn.Transformer`` defaults to sequence-first, which would
        make attention run over the batch axis instead of over the tokens, so
        ``batch_first=True`` is required.
        """
        return nn.Transformer(
            d_model=hidden_dim,
            nhead=n_heads,
            num_encoder_layers=n_layers,
            num_decoder_layers=n_layers,
            batch_first=True,
        )

    def preprocess_text(self, text):
        """Tokenize ``text`` and return ``(input_ids, attention_mask)`` tensors."""
        encoding = self.text_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        return encoding["input_ids"], encoding["attention_mask"]

    def preprocess_image(self, image_path):
        """Load the image at ``image_path`` and return a normalized batch of one."""
        # The context manager releases the file handle; convert("RGB") forces
        # three channels so grayscale/RGBA files match the 3-channel Normalize.
        with Image.open(image_path) as image:
            rgb_image = image.convert("RGB")
        return self.image_transform(rgb_image).unsqueeze(0)  # Add batch dimension

    def preprocess_audio(self, audio_path):
        """Load the audio at ``audio_path`` resampled to 16 kHz and return its input values."""
        audio, _ = librosa.load(audio_path, sr=16000)
        encoding = self.audio_tokenizer(audio, return_tensors="pt", padding=True)
        return encoding["input_values"]

    def forward(self, text, image_path, audio_path):
        """Return class logits for one (text, image, audio) sample.

        Args:
            text: Raw text. The image and audio branches each produce a batch
                of one, so a single string is expected here.
            image_path: Path of the image file to load.
            audio_path: Path of the audio file to load.

        Returns:
            Tensor of logits taken from the first fused position, one row per
            batch element and ``num_classes`` columns.
        """
        # Preprocessing yields CPU tensors; move them to wherever the weights
        # live so the model also works after ``.to("cuda")``.
        device = next(self.parameters()).device

        # Preprocess the inputs
        text_input_ids, text_attention_mask = self.preprocess_text(text)
        text_input_ids = text_input_ids.to(device)
        text_attention_mask = text_attention_mask.to(device)
        image_input = self.preprocess_image(image_path).to(device)
        audio_input = self.preprocess_audio(audio_path).to(device)

        # Extract features from each modality
        text_features = self.text_model(
            input_ids=text_input_ids, attention_mask=text_attention_mask
        ).last_hidden_state
        image_features = self.image_model(pixel_values=image_input).last_hidden_state
        audio_features = self.audio_model(input_values=audio_input).last_hidden_state

        # Project the features into a common dimension
        text_proj = self.text_proj(text_features)
        image_proj = self.image_proj(image_features)
        audio_proj = self.audio_proj(audio_features)

        # Concatenate the projected features along the sequence dimension
        multimodal_features = torch.cat((text_proj, image_proj, audio_proj), dim=1)

        # Use Transformer layers to fuse the multimodal features; the same
        # sequence serves as both source and target.
        fused_features = self.transformer(multimodal_features, multimodal_features)

        # Classify from the first fused position (the text [CLS] slot)
        logits = self.classifier(fused_features[:, 0, :])

        return logits

# Hyperparameter Tuning using Optuna


In [None]:
import torch
import torch.nn as nn
import optuna
from transformers import BertModel, ViTModel, Wav2Vec2Model
import torch.optim as optim
from sklearn.metrics import accuracy_score

# Define the objective function for Optuna to minimize
def objective(trial):
    """Run one illustrative training step and return its loss for Optuna.

    Samples a learning rate and a weight decay on a log scale, builds a fresh
    ``MultimodalTransformer``, performs a single AdamW step on one dummy
    sample and reports the loss computed before that step.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        The cross-entropy loss of the single training step, as a float.
    """
    # Step 1: Define the hyperparameters to tune.
    # suggest_float(..., log=True) replaces the deprecated suggest_loguniform
    # and samples from the same log-uniform range.
    learning_rate = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)

    # Step 2: Initialize the model (the class defined in this file is
    # MultimodalTransformer; no "MultimodalTransformerTuner" exists here)
    model = MultimodalTransformer(
        text_model_name="bert-base-uncased",
        image_model_name="google/vit-base-patch16-224",
        audio_model_name="facebook/wav2vec2-base-960h"
    )

    # Step 3: Define the optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Step 4: Define loss function (CrossEntropyLoss for classification tasks)
    criterion = nn.CrossEntropyLoss()

    # Step 5: Training loop (replace with actual training data)
    # Dummy data for illustration
    text_input = "This is an example text input."
    image_path = "example_image.jpg"
    audio_path = "example_audio.wav"

    labels = torch.tensor([0])  # Dummy labels for one batch
    model.train()

    optimizer.zero_grad()
    logits = model(text_input, image_path, audio_path)

    # Calculate loss
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()

    # Here, you can add more training steps, batch-wise iteration, etc.

    # Step 6: Return the performance metric to Optuna
    return loss.item()

# Run the Optuna study to find the best hyperparameters
# Launch the hyperparameter search. The study minimizes the value returned by
# `objective`, which is the training loss.
study = optuna.create_study(direction="minimize")
study.optimize(
    objective,
    n_trials=10,  # Raise this for a broader search
)

# Pull out the winning configuration and report it
best_params = study.best_params
print(f"Best hyperparameters: {best_params}")

In [None]:
# Initialize final model with the best hyperparameters.
# The model class defined in this file is MultimodalTransformer; the previous
# name "MultimodalTransformerTuner" is not defined anywhere here.
final_model = MultimodalTransformer(
    text_model_name="bert-base-uncased",
    image_model_name="google/vit-base-patch16-224",
    audio_model_name="facebook/wav2vec2-base-960h"
)

# Use best hyperparameters for optimizer ("lr" and "weight_decay" are the
# parameter names sampled inside `objective`)
final_optimizer = optim.AdamW(
    final_model.parameters(),
    lr=best_params["lr"],
    weight_decay=best_params["weight_decay"],
)

# Train the model with the optimal settings.