In [None]:
import torch
import torch.nn as nn

from encoders import CLIPResNetEncoder, CLIPViTEncoder
from transformers import AutoTokenizer, AutoModel

In [None]:
class CLIPGPT2Encoder(nn.Module):
    """Text encoder that maps token ids to unit-norm embeddings.

    A pretrained causal transformer (GPT-2 by default) produces per-token
    hidden states; these are mean-pooled over the non-masked positions,
    passed through ``LayerNorm -> Linear`` and L2-normalised.

    Args:
        projection_dim: Size of the output embedding.
        model_name: Hugging Face model identifier passed to
            ``AutoModel``/``AutoTokenizer``. Defaults to ``'gpt2'``.
    """

    def __init__(self, projection_dim=512, model_name='gpt2'):
        super().__init__()

        self.transformer = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # GPT-2 ships without a padding token, so batched tokenisation with
        # padding fails unless one is assigned. Reusing EOS is the usual
        # workaround; padded positions are excluded by the attention mask.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Read the hidden width from the loaded model instead of assuming
        # 768, so larger variants (e.g. 'gpt2-medium') work as well.
        hidden_size = self.transformer.config.hidden_size

        self.projection = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, projection_dim)
        )

        self.max_length = 77

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token ids into L2-normalised features.

        Args:
            input_ids: Integer tensor of token ids, indexed as
                ``(batch, sequence)``.
            attention_mask: Optional tensor with the same shape as
                ``input_ids``; non-zero marks positions to include in the
                pooling. When ``None`` every position is treated as valid.

        Returns:
            Tensor of shape ``(batch, projection_dim)``. Rows have unit L2
            norm unless the projected vector is (numerically) zero.
        """
        if attention_mask is None:
            # The parameter defaults to None, but pooling needs a real mask.
            attention_mask = torch.ones_like(input_ids)

        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Masked mean pooling over the sequence dimension.
        last_hidden_state = outputs.last_hidden_state
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        # A fully masked row has a zero numerator; clamping the denominator
        # turns 0/0 (NaN) into 0 without affecting rows with real tokens.
        eps = torch.finfo(last_hidden_state.dtype).eps
        pooled_output = summed / mask.sum(dim=1).clamp_min(eps)

        projected = self.projection(pooled_output)
        # Guard the normalisation against a zero-norm vector.
        norm = projected.norm(dim=-1, keepdim=True)
        text_features = projected / norm.clamp_min(torch.finfo(projected.dtype).eps)

        return text_features