In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.preprocessing import normalize
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity
import open_clip
from transformers import BertTokenizer, BertModel
from torch_geometric.nn import GATConv
from sklearn.neighbors import kneighbors_graph

device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MisogynyDataset(Dataset):
    def __init__(self, data, label_map, transform=None):
        self.data = data.reset_index(drop=True)
        self.label_map = label_map
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        image = Image.open(row["image_path"]).convert("RGB")
        image = self.transform(image)
        label = self.label_map[row["image_label"]]
        caption = row["image_caption"]
        return image, caption, label

class MisogynyDataLoader:
    def __init__(self, csv_file="data_csv.csv", batch_size=16, test_size=0.2, random_state=42,
                 train_transform=None, test_transform=None, num_workers=0, pin_memory=False):
        data = pd.read_csv(csv_file)
        label_map = {"kitchen":0, "shopping":1, "working":2, "leadership":3}

        train_df, test_df = train_test_split(
            data,
            test_size=test_size,
            random_state=random_state,
            shuffle=True,
            stratify=data["image_label"]
        )

        self.train_dataset = MisogynyDataset(train_df, label_map, transform=train_transform)
        self.test_dataset = MisogynyDataset(test_df, label_map, transform=test_transform)

        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=False,
                                       num_workers=num_workers, pin_memory=pin_memory)
        self.test_loader = DataLoader(self.test_dataset, batch_size=batch_size, shuffle=False,
                                      num_workers=num_workers, pin_memory=pin_memory)

In [3]:
class BERTEmbedder(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model_bert = BertModel.from_pretrained("bert-base-uncased")
        self.model_bert.eval()

    def forward(self, input_text):
        inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model_bert(**inputs)
        token_embeddings = outputs.last_hidden_state
        attention_mask = inputs["attention_mask"]
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sentence_embeddings = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1)
        embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return embeddings

class OpenClipVitEmbedder(nn.Module):
    def __init__(self, device=None):
        super().__init__()
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            model_name="ViT-B-32", pretrained="openai"
        )
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

    def forward(self, image_tensor):
        image_tensor = image_tensor.to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_tensor)
        image_features = F.normalize(image_features, p=2, dim=-1)
        return image_features

In [4]:
class LDALayer(nn.Module):
    def __init__(self, mean, coef):
        super().__init__()
        self.register_buffer("mean", torch.tensor(mean, dtype=torch.float32))
        self.register_buffer("weight", torch.tensor(coef, dtype=torch.float32))

    def forward(self, x):
        x = x - self.mean
        x = torch.matmul(x, self.weight.T)
        return x

In [5]:
class GraphModule(nn.Module):
    """
    Self-contained GAT module with stored training graph.
    Forward function can take new embeddings and return contextualized embeddings.
    """
    def __init__(self, node_features, k=20, hidden_dim=32, out_dim=64, heads=4, dropout=0.2):
        super().__init__()
        self.register_buffer("node_features", torch.tensor(node_features, dtype=torch.float32))
        self.k = k
        self.edge_index = self._build_knn_graph(node_features, k)
        in_dim = node_features.shape[1]
        self.gat1 = GATConv(in_dim, hidden_dim, heads=heads, concat=True, dropout=dropout)
        self.gat2 = GATConv(hidden_dim*heads, out_dim, heads=1, concat=False, dropout=dropout)

    def _build_knn_graph(self, embeddings, k):
        embeddings_norm = normalize(embeddings, axis=1)
        knn = kneighbors_graph(embeddings_norm, n_neighbors=k, mode='connectivity', include_self=True)
        knn = 0.5 * (knn + knn.T)
        coo = knn.tocoo()
        return torch.tensor([coo.row, coo.col], dtype=torch.long)

    def forward(self, new_node_features, k=None):
        """
        new_node_features: tensor of shape (num_new_nodes, feature_dim)
        Returns: contextualized embeddings for new nodes
        """
        k = k or self.k
        device = self.node_features.device
        training_nodes = self.node_features
        new_nodes = new_node_features.to(device)

        # Combine training nodes + new nodes
        all_nodes = torch.cat([training_nodes, new_nodes], dim=0)

        # Build edges for new nodes
        num_training = training_nodes.shape[0]
        num_new = new_nodes.shape[0]
        edge_rows = []
        edge_cols = []

        # For each new node, find top-k neighbors in training nodes
        sim = cosine_similarity(new_nodes.cpu().numpy(), training_nodes.cpu().numpy())  # (num_new, num_training)
        for i in range(num_new):
            topk_idx = np.argsort(sim[i])[-k:]
            new_idx = num_training + i
            edge_rows.extend([new_idx]*k + topk_idx.tolist())
            edge_cols.extend(topk_idx.tolist() + [new_idx]*k)  # bidirectional

        # Combine with existing edges
        existing_edges = self.edge_index
        new_edges = torch.tensor([edge_rows, edge_cols], dtype=torch.long).to(device)
        combined_edge_index = torch.cat([existing_edges.to(device), new_edges], dim=1)

        # Apply GAT layers
        x = self.gat1(all_nodes, combined_edge_index)
        x = F.elu(x)
        x = self.gat2(x, combined_edge_index)

        # Return only new nodes' embeddings
        return x[-num_new:]

    def save(self, path):
        torch.save({
            "gat_state_dict": self.state_dict(),
            "node_features": self.node_features,
            "edge_index": self.edge_index
        }, path)
        print(f"GraphModule saved at {path}")

    @classmethod
    def load(cls, path, k=20, hidden_dim=32, out_dim=64, heads=4, dropout=0.2, device=None):
        checkpoint = torch.load(path, map_location=device or "cpu")
        node_features = checkpoint["node_features"].cpu().numpy()
        model = cls(node_features, k=k, hidden_dim=hidden_dim, out_dim=out_dim, heads=heads, dropout=dropout)
        model.load_state_dict(checkpoint["gat_state_dict"])
        if device:
            model.to(device)
        print(f"GraphModule loaded from {path}")
        return model

In [6]:
def collect_embeddings(dataloader, text_model, image_model, device):
    text_embeddings = []
    image_embeddings = []

    for images, captions, _ in dataloader:
        captions = list(captions)
        with torch.no_grad():
            text_emb = text_model(captions).to("cpu")
            image_emb = image_model(images).to("cpu")
        text_embeddings.append(text_emb.numpy())
        image_embeddings.append(image_emb.numpy())

    text_embeddings = np.vstack(text_embeddings)
    image_embeddings = np.vstack(image_embeddings)
    return text_embeddings, image_embeddings

In [None]:
dataloaders = MisogynyDataLoader()
train_loader = dataloaders.train_loader

text_model = BERTEmbedder().to(device)
image_model = OpenClipVitEmbedder(device=device)

text_train_emb, image_train_emb = collect_embeddings(train_loader, text_model, image_model, device)
lda_mean = np.load("weights/combined_lda_mean.npy")
lda_coef = np.load("weights/combined_lda_coef.npy")
combined_lda_layer = LDALayer(lda_mean, lda_coef)

combined_raw = np.concatenate([text_train_emb, image_train_emb], axis=1)
combined_tensor = torch.tensor(combined_raw, dtype=torch.float32)
combined_lda_emb = combined_lda_layer(combined_tensor).numpy()

gat_module = GraphModule(node_features=combined_lda_emb, k=20, hidden_dim=32, out_dim=64)

gat_module.save("graph_module.pth")

gat_module_loaded = GraphModule.load("graph_module.pth", device=device)

new_emb = np.random.rand(1, combined_lda_emb.shape[1]).astype(np.float32)
new_emb_tensor = torch.tensor(new_emb)

contextualized_emb = gat_module_loaded(new_emb_tensor, k=5)
print("Contextualized embedding shape:", contextualized_emb.shape)
print("Embedding vector:", contextualized_emb)

Loading weights: 100%|██████████| 199/199 [00:00<00:00, 1416.26it/s, Materializing param=pooler.dense.weight]                               
BertModel LOAD REPORT from: bert-base-uncased
Key                                        | Status     |  | 
-------------------------------------------+------------+--+-
cls.predictions.bias                       | UNEXPECTED |  | 
cls.seq_relationship.weight                | UNEXPECTED |  | 
cls.predictions.transform.LayerNorm.weight | UNEXPECTED |  | 
cls.predictions.transform.dense.bias       | UNEXPECTED |  | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED |  | 
cls.predictions.transform.dense.weight     | UNEXPECTED |  | 
cls.seq_relationship.bias                  | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


GraphModule saved at graph_module.pth
GraphModule loaded from graph_module.pth
Contextualized embedding shape: torch.Size([1, 64])
Embedding vector: tensor([[-0.7221, -0.6699,  0.3172, -0.3804,  1.4258,  0.2117,  0.0096, -0.3153,
         -1.2866, -2.6263, -1.5839, -1.7813, -0.5408,  1.2618,  0.6201, -0.5782,
          0.1267,  0.1445,  0.6805,  0.4058, -1.4602, -0.3465,  0.7490, -0.0436,
          1.4061,  0.1863,  0.9807, -0.7317, -2.6265,  1.5070, -0.9574, -1.9253,
         -0.7596,  1.0118,  1.2317,  0.4074,  0.7122,  1.4086, -0.5095, -1.5471,
          1.5757,  0.1525, -1.3396,  2.3373,  0.3033, -0.7901, -1.6310, -0.1629,
          0.8914,  0.2795, -0.9303, -0.6516,  0.3851, -1.8809,  0.3712,  0.4218,
         -1.7391,  1.4416, -1.0993,  1.7109,  0.6776,  2.2506, -2.2052,  0.8949]],
       grad_fn=<SliceBackward0>)


  return torch.tensor([coo.row, coo.col], dtype=torch.long)
