In [None]:


import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, ViTModel
from transformers import BertTokenizer, BertModel

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [142]:
# --- Cross Attention ---
class CrossAttentionBlock(nn.Module):
    """Pre-norm multi-head cross-attention with a residual connection on the query.

    The query and key streams are layer-normalised, attention is computed with
    ``nn.MultiheadAttention`` (``batch_first=True``), and the attended output is
    passed through a linear projection and dropout before being added back to
    the un-normalised query.

    Args:
        dim_q: Feature size of the query stream. It is also the attention
            embedding size and the feature size of the output.
        dim_k: Feature size of the key/value stream. May differ from ``dim_q``.
        num_heads: Number of attention heads (``dim_q`` must be divisible by it).
        dropout: Dropout probability applied to the projected attention output.
    """

    def __init__(self, dim_q, dim_k, num_heads=8, dropout=0.1):
        super().__init__()
        # kdim/vdim let the key/value stream have a different width than the
        # query. When dim_k == dim_q the module keeps its fused in_proj_weight,
        # so parameter names and shapes are unchanged for that case.
        self.attn = nn.MultiheadAttention(
            embed_dim=dim_q,
            num_heads=num_heads,
            kdim=dim_k,
            vdim=dim_k,
            batch_first=True,
        )
        self.norm_q = nn.LayerNorm(dim_q)
        self.norm_k = nn.LayerNorm(dim_k)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(dim_q, dim_q)

    def forward(self, q, k, v, key_padding_mask=None):
        """Attend from ``q`` over ``k``/``v`` and return the updated query.

        Args:
            q: Query tensor, batch-first, last dimension ``dim_q``.
            k: Key tensor, batch-first, last dimension ``dim_k``.
            v: Value tensor, batch-first, last dimension ``dim_k``.
            key_padding_mask: Optional mask forwarded unchanged to
                ``nn.MultiheadAttention`` (True marks key positions to ignore).

        Returns:
            Tuple ``(q_out, attn_weights)``: the residual-updated query and the
            attention weights returned by ``nn.MultiheadAttention``.
        """
        q_norm, k_norm = self.norm_q(q), self.norm_k(k)
        # Only q and k are normalised; v is used exactly as passed in.
        attn_out, attn_weights = self.attn(
            q_norm, k_norm, v, key_padding_mask=key_padding_mask, need_weights=True
        )
        # Residual is taken on the raw (un-normalised) query.
        q_out = q + self.dropout(self.proj(attn_out))
        return q_out, attn_weights


# --- Symmetric Multimodal Classifier ---
class SymmetricMultimodalClassifier(nn.Module):
    """Text+image classifier with bidirectional (symmetric) cross-attention.

    Text tokens are encoded with a BERT model and image patches with a ViT
    model. Two stacks of ``CrossAttentionBlock`` then let the text attend over
    the image tokens and the image attend over the text tokens. The first
    token of each attended sequence is used as that modality's embedding.

    Args:
        text_model: Name or path passed to ``BertModel.from_pretrained``.
        image_model: Name or path passed to ``ViTModel.from_pretrained``.
        hidden_dim: Feature size used by the cross-attention blocks, the
            projection heads and the classifier input.
        num_classes: Number of output logits.
        use_contrastive: When True, ``forward`` additionally returns the
            L2-normalised text and image embeddings.
    """

    def __init__(self, 
                 text_model='bert-base-uncased', 
                 image_model='google/vit-base-patch16-224',
                 hidden_dim=768, 
                 num_classes=2, 
                 use_contrastive=True):
        super().__init__()

        self.text_encoder = BertModel.from_pretrained(text_model)
        self.image_encoder = ViTModel.from_pretrained(image_model)
        self.use_contrastive = use_contrastive

        # nn.ModuleList instead of nn.Sequential: each block takes (q, k, v,
        # mask) and returns a tuple, which nn.Sequential cannot forward.
        # Parameter names ("text_to_img.0.*", ...) are the same either way.
        self.text_to_img = nn.ModuleList([
            CrossAttentionBlock(hidden_dim, hidden_dim),
            CrossAttentionBlock(hidden_dim, hidden_dim),
        ])
        self.img_to_text = nn.ModuleList([
            CrossAttentionBlock(hidden_dim, hidden_dim),
            CrossAttentionBlock(hidden_dim, hidden_dim),
        ])

        self.text_proj = nn.Linear(hidden_dim, hidden_dim)
        self.img_proj = nn.Linear(hidden_dim, hidden_dim)

        self.classifier = nn.Linear(hidden_dim, num_classes)

        # Kept for Grad-CAM: attention weights from the most recent forward pass.
        self.last_attn_t2i = None
        self.last_attn_i2t = None

    @staticmethod
    def _cross_attend(blocks, query, context, key_padding_mask=None):
        """Run ``query`` through every block, attending over a fixed ``context``.

        Returns the final query and the attention weights of the last block.
        """
        attn_weights = None
        for block in blocks:
            query, attn_weights = block(
                query, context, context, key_padding_mask=key_padding_mask
            )
        return query, attn_weights

    def forward(self, input_ids, attention_mask, images):
        """Classify a batch of (text, image) pairs.

        Args:
            input_ids: Token ids passed to the text encoder.
            attention_mask: Text attention mask; positions equal to 0 are
                treated as padding when the image attends over the text.
            images: Image batch passed to the image encoder.

        Returns:
            ``(logits, text_emb_n, img_emb_n, attn_t2i, attn_i2t)`` when
            ``use_contrastive`` is True, otherwise ``(logits, attn_t2i, attn_i2t)``.
        """
        # --- Encode text & image ---
        text_feat = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False
        ).last_hidden_state  # [B, L_t, D]

        img_feat = self.image_encoder(images).last_hidden_state  # [B, L_i, D]
        key_padding_mask = (attention_mask == 0)

        # --- Symmetric Cross-Attention ---
        # Text queries attend over image tokens (no padding on the image side).
        text_cross, attn_t2i = self._cross_attend(self.text_to_img, text_feat, img_feat)
        # Image queries attend over text tokens; padded text positions are masked.
        img_cross, attn_i2t = self._cross_attend(
            self.img_to_text, img_feat, text_feat, key_padding_mask
        )

        # Save for Grad-CAM visualization
        self.last_attn_t2i = attn_t2i.detach()
        self.last_attn_i2t = attn_i2t.detach()

        # --- Use the first (CLS) token of each sequence only ---
        text_emb = text_cross[:, 0]
        img_emb = img_cross[:, 0, :]

        # --- Normalize for contrastive loss ---
        text_emb_n = F.normalize(self.text_proj(text_emb), p=2, dim=-1)
        img_emb_n = F.normalize(self.img_proj(img_emb), p=2, dim=-1)

        # --- Classification ---
        # Only the text-side embedding (already attended over the image) feeds
        # the classifier; the image-side embedding is used for the contrastive
        # output only.
        logits = self.classifier(text_emb)

        if self.use_contrastive:
            return logits, text_emb_n, img_emb_n, attn_t2i, attn_i2t
        else:
            return logits, attn_t2i, attn_i2t


# --- Contrastive Loss (Symmetric) ---
def contrastive_loss(text_emb, img_emb, temperature=0.05):
    """Symmetric cross-entropy loss over a text-image similarity matrix.

    Pairwise dot products between every text and image embedding are divided
    by ``temperature``. The matching pair for row ``i`` is column ``i``, and
    cross-entropy is computed once over the matrix and once over its
    transpose; the two values are averaged.

    Args:
        text_emb: 2-D tensor of text embeddings, one row per sample.
        img_emb: 2-D tensor of image embeddings, one row per sample.
        temperature: Divisor applied to the similarities before the softmax.

    Returns:
        Scalar tensor holding the averaged loss.
    """
    scaled_sim = (text_emb @ img_emb.T) / temperature
    n_pairs = scaled_sim.size(0)
    targets = torch.arange(n_pairs, device=scaled_sim.device)

    row_loss = F.cross_entropy(scaled_sim, targets)
    col_loss = F.cross_entropy(scaled_sim.T, targets)
    return (row_loss + col_loss) / 2


In [143]:
class MultimodalNewsDataset(Dataset):
    """Dataset of (title text, image, binary label) samples.

    Rows are read from a tab-separated file that must provide the columns
    ``clean_title``, ``id`` and ``2_way_label``. The image for a row is looked
    up as ``<image_dir>/<id>.jpg``.

    Args:
        txt_file: Path to the tab-separated metadata file.
        image_dir: Directory containing the ``<id>.jpg`` images.
        tokenizer: Callable tokenizer invoked with ``max_length``, ``padding``,
            ``truncation`` and ``return_tensors='pt'``; must return a mapping
            with ``input_ids`` and ``attention_mask``.
        max_len: Length every text is padded/truncated to.
        transform: Optional callable applied to the PIL image. When omitted,
            ``transforms.ToTensor()`` is applied instead.
    """

    def __init__(self, txt_file, image_dir, tokenizer, max_len=128, transform=None):
        self.data = pd.read_csv(txt_file, sep='\t')
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the metadata file."""
        return len(self.data)

    def _load_image(self, img_path):
        """Load ``img_path`` as RGB, or return a black placeholder if it cannot be read.

        Only file-access and decoding failures are turned into a placeholder,
        so that one missing or corrupt image does not break a batch.
        Programming errors (e.g. a missing import) still propagate.
        """
        try:
            # The context manager releases the file handle; convert() returns
            # an independent, fully loaded image.
            with Image.open(img_path) as img:
                return img.convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError):
            return Image.new("RGB", (224, 224), (0, 0, 0))

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, image, label)`` for row ``idx``."""
        row = self.data.iloc[idx]
        text = str(row['clean_title'])
        img_id = str(row['id'])
        label = torch.tensor(int(row['2_way_label']), dtype=torch.long)

        img_path = os.path.join(self.image_dir, f"{img_id}.jpg")
        image = self._load_image(img_path)

        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)

        encoding = self.tokenizer(
            text,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # The tokenizer returns a batch of one; drop the batch dimension.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        return input_ids, attention_mask, image, label


In [145]:
# Build the tokenizer, image transforms, train/validation datasets and loaders.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

dataset_path = '/repo/project_deepfake/project/fakeddit_dataset'

# NOTE(review): these are the ImageNet mean/std values; confirm they match the
# preprocessing the ViT checkpoint and the saved model were trained with.
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

dataset = MultimodalNewsDataset(
    txt_file=dataset_path+'/text/text.txt',
    image_dir=dataset_path+'/images',
    tokenizer=tokenizer,
    transform=transform_train 
)

# A second dataset object over the same file for validation. Subset only keeps
# a reference to its dataset, so assigning `val_set.dataset.transform` on a
# shared object would also remove the augmentation from the training split.
val_dataset = MultimodalNewsDataset(
    txt_file=dataset_path+'/text/text.txt',
    image_dir=dataset_path+'/images',
    tokenizer=tokenizer,
    transform=transform_val
)

# Split indices (stratified by label; both datasets share the same row order).
train_idx, val_idx = train_test_split(
    range(len(dataset)),
    test_size=0.2,
    random_state=42,
    stratify=dataset.data['2_way_label']
)

train_set = Subset(dataset, train_idx)
val_set   = Subset(val_dataset, val_idx)

train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
val_loader = DataLoader(val_set, batch_size=8, shuffle=False)

In [147]:
# Restore the trained weights into a freshly built model.
# NOTE(review): the error recorded in this cell's output shows that the
# checkpoint holds keys such as "text_pool.proj.*" and "text_to_img.attn.*"
# (un-indexed), while the model defined in this notebook expects
# "text_to_img.0.*" / "text_to_img.1.*" and has no pooling modules. The
# checkpoint therefore comes from a different model definition; the class
# used at training time has to be restored for this cell to succeed.
checkpoint = torch.load(
    "checkpoints/multimodal_model.pth",
    map_location=device,
)
model = SymmetricMultimodalClassifier()
model = model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


RuntimeError: Error(s) in loading state_dict for SymmetricMultimodalClassifier:
	Missing key(s) in state_dict: "text_to_img.0.attn.in_proj_weight", "text_to_img.0.attn.in_proj_bias", "text_to_img.0.attn.out_proj.weight", "text_to_img.0.attn.out_proj.bias", "text_to_img.0.norm_q.weight", "text_to_img.0.norm_q.bias", "text_to_img.0.norm_k.weight", "text_to_img.0.norm_k.bias", "text_to_img.0.proj.weight", "text_to_img.0.proj.bias", "text_to_img.1.attn.in_proj_weight", "text_to_img.1.attn.in_proj_bias", "text_to_img.1.attn.out_proj.weight", "text_to_img.1.attn.out_proj.bias", "text_to_img.1.norm_q.weight", "text_to_img.1.norm_q.bias", "text_to_img.1.norm_k.weight", "text_to_img.1.norm_k.bias", "text_to_img.1.proj.weight", "text_to_img.1.proj.bias", "img_to_text.0.attn.in_proj_weight", "img_to_text.0.attn.in_proj_bias", "img_to_text.0.attn.out_proj.weight", "img_to_text.0.attn.out_proj.bias", "img_to_text.0.norm_q.weight", "img_to_text.0.norm_q.bias", "img_to_text.0.norm_k.weight", "img_to_text.0.norm_k.bias", "img_to_text.0.proj.weight", "img_to_text.0.proj.bias", "img_to_text.1.attn.in_proj_weight", "img_to_text.1.attn.in_proj_bias", "img_to_text.1.attn.out_proj.weight", "img_to_text.1.attn.out_proj.bias", "img_to_text.1.norm_q.weight", "img_to_text.1.norm_q.bias", "img_to_text.1.norm_k.weight", "img_to_text.1.norm_k.bias", "img_to_text.1.proj.weight", "img_to_text.1.proj.bias". 
	Unexpected key(s) in state_dict: "text_pool.proj.weight", "text_pool.proj.bias", "img_pool.proj.weight", "img_pool.proj.bias", "text_to_img.attn.in_proj_weight", "text_to_img.attn.in_proj_bias", "text_to_img.attn.out_proj.weight", "text_to_img.attn.out_proj.bias", "text_to_img.norm_q.weight", "text_to_img.norm_q.bias", "text_to_img.norm_k.weight", "text_to_img.norm_k.bias", "img_to_text.attn.in_proj_weight", "img_to_text.attn.in_proj_bias", "img_to_text.attn.out_proj.weight", "img_to_text.attn.out_proj.bias", "img_to_text.norm_q.weight", "img_to_text.norm_q.bias", "img_to_text.norm_k.weight", "img_to_text.norm_k.bias". 