In [3]:
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import BertTokenizer
from PIL import Image
import json





In [1]:
!pip install torch torchvision
!pip install accelerate -U
!pip install transformers[torch]
! pip install datasets
! pip install --upgrade tqdm
!pip install pytorch-lightning

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import models
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from PIL import Image
import json
import os
from sklearn.metrics import accuracy_score, f1_score

# Define the dataset class
class MemeDataset(torch.utils.data.Dataset):
    def __init__(self, image_dir, annotation_file, tokenizer, label_mapping, transform=None):
        self.image_dir = image_dir
        with open(annotation_file, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.tokenizer = tokenizer
        self.label_mapping = label_mapping
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image_path = os.path.join(self.image_dir, item['image'])
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        text = item['text']
        inputs = self.tokenizer(text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)
        label_indices = [self.label_mapping[label] for label in item['labels']]
        labels = torch.zeros(len(self.label_mapping))
        labels[label_indices] = 1
        return image, input_ids, attention_mask, labels
    
def create_label_mapping(annotation_file):
    all_labels = set()
    with open(annotation_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    for item in data:
        all_labels.update(item['labels'])
    return {label: i for i, label in enumerate(sorted(all_labels))}

# Define the model architecture
class MemeClassifier(nn.Module):
    def __init__(self, text_model, num_labels):
        super(MemeClassifier, self).__init__()
        self.text_model = text_model
        self.image_model = models.resnet152(pretrained=True)
        self.image_model.fc = nn.Identity()
        self.classifier = nn.Linear(self.text_model.config.hidden_size + 2048, num_labels)

    def forward(self, images, input_ids, attention_mask):
        with torch.no_grad():
            image_features = self.image_model(images)
            text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        combined_features = torch.cat((text_features, image_features), dim=1)
        logits = self.classifier(combined_features)
        return logits

# Load the pre-trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text_model = BertModel.from_pretrained('bert-base-uncased')

# Paths to your data
image_dir = 'C:\\Users\\harih\\Downloads\\27\\train_images\\train_images'
annotation_file = 'C:\\Users\\harih\\Downloads\\27\\semeval2024_dev_release\\semeval2024_dev_release\\subtask2a\\preprocessed_train_v2.json'

# Create the label mapping
label_mapping = create_label_mapping(annotation_file)
print("Label mapping:", label_mapping)

# Create the dataset and dataloader
train_dataset = MemeDataset(image_dir, annotation_file, tokenizer, label_mapping)
#train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
#print("Train dataloader created.")

# Load the pre-trained model
model = MemeClassifier(BertModel.from_pretrained('bert-base-uncased'), 22)
model_state_dict = torch.load('C:\\Users\\harih\\Downloads\\27\\meme_model_10e.pth', map_location=device)
model.load_state_dict(model_state_dict)
model.to(device)
model.eval()
print("Pre-trained model loaded and set to evaluation mode.")

# Evaluate the model on the training set
#print("Evaluating the model on the training set:")
#evaluate(model, train_loader, device, label_mapping)

# Evaluate the model on the validation set
#val_image_dir = 'C:\\Users\\harih\\Downloads\\27\\dev_images\\dev_images'
#val_annotation_file = 'C:\\Users\\harih\\Downloads\\27\\semeval2024_dev_release\\semeval2024_dev_release\\subtask2a\\dev_subtask2a_en.json'
#val_dataset = MemeDataset(val_image_dir, val_annotation_file, tokenizer, label_mapping)
#val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
#print("Validation dataloader created.")

#print("Evaluating the model on the validation set:")
#evaluate(model, val_loader, device, label_mapping)

Using device: cuda
Label mapping: {'Appeal to (Strong) Emotions': 0, 'Appeal to authority': 1, 'Appeal to fear/prejudice': 2, 'Bandwagon': 3, 'Black-and-white Fallacy/Dictatorship': 4, 'Causal Oversimplification': 5, 'Doubt': 6, 'Exaggeration/Minimisation': 7, 'Flag-waving': 8, 'Glittering generalities (Virtue)': 9, 'Loaded Language': 10, "Misrepresentation of Someone's Position (Straw Man)": 11, 'Name calling/Labeling': 12, 'Obfuscation, Intentional vagueness, Confusion': 13, 'Presenting Irrelevant Data (Red Herring)': 14, 'Reductio ad hitlerum': 15, 'Repetition': 16, 'Slogans': 17, 'Smears': 18, 'Thought-terminating cliché': 19, 'Transfer': 20, 'Whataboutism': 21}




Pre-trained model loaded and set to evaluation mode.


In [3]:
import torch
from PIL import Image
from torchvision import transforms
from transformers import BertTokenizer, BertModel

In [5]:
def detect_meme_emotion(image_path, text):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Preprocess the image
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)

    # Preprocess the text
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Make the prediction
    with torch.no_grad():
        logits = model(image, input_ids, attention_mask)
        predictions = torch.sigmoid(logits).round().squeeze().cpu().numpy()

    # Get the predicted labels
    reversed_label_mapping = {v: k for k, v in label_mapping.items()}
    predicted_labels = [reversed_label_mapping[i] for i, p in enumerate(predictions) if p == 1]

    return predicted_labels

In [7]:
image_path = 'C:\\Users\\harih\\Downloads\\27\\dev_images\\dev_images\\prop_meme_2704.png'
text = "STOP CALLING IT A CONCENTRATION CAMP\\n\\nITS JUST AN AUSTRALIAN GOVERNMENT-RUN FACILITY WHERE YOU ARE DETAINED WITHOUT A TRIAL, RELOCATED BY THE MILITARY, CANNOT LEAVE WHENEVER YOU WANT, AND WILL BE PURSUED AND ARRESTED FOR TRYING TO ESCAPE, IN THE NAME OF SAFETY\\n\\nTHE ONLY REASON PEOPLE AREN'T SHARING UNDERCOVER VIDEOS OF AMERICANS TRAPPED IN MANDATORY \\QUARANTINE CAMPS\\ IS THE 2ND AMENDMENT"    
predicted_labels = detect_meme_emotion(image_path, text)
print(predicted_labels)

cuda
['Black-and-white Fallacy/Dictatorship', 'Loaded Language']


In [9]:
image_path = 'C:\\Users\\harih\\Downloads\\27\\dev_images\\dev_images\\prop_meme_2704.png'
text = "KNOCK KNOCK\\n\\nWHO'S THERE?\\n\\nWHO'S WHERE?"
predicted_labels = detect_meme_emotion(image_path, text)
print(predicted_labels)

cuda
['Smears']


In [None]:
from flask import Flask, request, jsonify
from flask_cors import CORS
app = Flask(__name__)
CORS(app)

@app.route('/detect_emotion', methods=['POST'])
def detect_emotion():
    # Check if the request contains both image file and text
    if 'image' not in request.files or 'text' not in request.form:
        return jsonify({'error': 'Please provide both image and text'}), 400

    image_file = request.files['image']
    print("Image file", image_file)
    print("Image file name", image_file.filename)
    text = request.form['text']

    # Check if image file and text are not empty
    if image_file.filename == '' or text.strip() == '':
        return jsonify({'error': 'Image file or text is empty'}), 400

    # Save the image file temporarily
    image_path = image_file.filename
    image_file.save(image_path)

    try:
        # Detect emotions in the meme
        predicted_labels = detect_meme_emotion(image_path, text)
        return jsonify({'predicted_labels': predicted_labels}), 200
    except Exception as e:
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    app.run()


 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit


Image file <FileStorage: 'prop_meme_9787.png' ('image/png')>
Image file name prop_meme_9787.png
cuda


127.0.0.1 - - [07/May/2024 13:17:29] "POST /detect_emotion HTTP/1.1" 200 -


Image file <FileStorage: 'prop_meme_9787.png' ('image/png')>
Image file name prop_meme_9787.png
cuda


127.0.0.1 - - [07/May/2024 13:17:54] "POST /detect_emotion HTTP/1.1" 200 -


Image file <FileStorage: 'prop_meme_9730.png' ('image/png')>
Image file name prop_meme_9730.png
cuda


127.0.0.1 - - [07/May/2024 13:21:10] "POST /detect_emotion HTTP/1.1" 200 -


Image file <FileStorage: 'prop_meme_9730.png' ('image/png')>
Image file name prop_meme_9730.png
cuda


127.0.0.1 - - [07/May/2024 13:25:21] "POST /detect_emotion HTTP/1.1" 200 -


Image file <FileStorage: 'prop_meme_9730.png' ('image/png')>
Image file name prop_meme_9730.png
cuda


127.0.0.1 - - [07/May/2024 13:47:58] "POST /detect_emotion HTTP/1.1" 200 -


Image file <FileStorage: 'prop_meme_10413.png' ('image/png')>
Image file name prop_meme_10413.png
cuda


127.0.0.1 - - [07/May/2024 13:48:40] "POST /detect_emotion HTTP/1.1" 200 -


Image file <FileStorage: 'prop_meme_10413.png' ('image/png')>
Image file name prop_meme_10413.png
cuda


127.0.0.1 - - [07/May/2024 13:49:32] "POST /detect_emotion HTTP/1.1" 200 -


In [1]:
!pip install Flask-CORS


Defaulting to user installation because normal site-packages is not writeable
