In [8]:
from transformers import T5Tokenizer, T5Model
import torch

class T5LanguageModel:
    """Extract mean-pooled text vectors from a pretrained T5 model.

    Every extraction method processes the texts one at a time and returns a list
    with one NumPy array per input text. Each array is the selected hidden-state
    tensor averaged over the token axis (``dim=1``). The batch axis of size 1 is
    kept, so with Hugging Face's ``(batch, seq_len, d_model)`` hidden states each
    array has shape ``(1, d_model)``.
    """

    def __init__(self, model_name='t5-base', device=None, max_length=512):
        """Load the tokenizer and model and switch the model to inference mode.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``.
            device: Torch device (string or ``torch.device``). ``None`` selects
                ``'cuda'`` when available, otherwise ``'cpu'``. It is resolved
                here, at construction time, rather than once when the class is
                defined.
            max_length: Maximum number of tokens kept per text; longer inputs are
                truncated.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.max_length = max_length
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5Model.from_pretrained(model_name).to(self.device)
        # eval() disables dropout so repeated calls produce the same vectors.
        self.model.eval()

    def _tokenize(self, text):
        """Tokenize one text into PyTorch tensors placed on ``self.device``."""
        return self.tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.max_length,
        ).to(self.device)

    def _decoder_start_ids(self):
        """Return a ``(1, 1)`` tensor that holds only T5's decoder start token.

        Tokenizing the string ``'<pad>'`` is not equivalent. The T5 tokenizer
        appends the ``</s>`` EOS token, which would give the decoder a second,
        spurious position that then gets averaged into the representation.
        """
        start_id = self.model.config.decoder_start_token_id
        if start_id is None:
            # T5 uses the pad token as its decoder start token.
            start_id = self.tokenizer.pad_token_id
        return torch.tensor([[start_id]], dtype=torch.long, device=self.device)

    def _encode_first_encoder_layer(self, texts):
        """Mean-pool the output of the first encoder block for each text.

        Args:
            texts: Iterable of strings.

        Returns:
            List of NumPy arrays, one per text, each shaped ``(1, d_model)``.
        """
        encodings = []
        for text in texts:
            encoded_input = self._tokenize(text)
            with torch.no_grad():
                output = self.model.encoder(
                    input_ids=encoded_input.input_ids,
                    attention_mask=encoded_input.attention_mask,
                    output_hidden_states=True,
                )
            # hidden_states[0] is the embedding output; [1] follows the first encoder block.
            first_encoder_layer = output.hidden_states[1]
            encodings.append(first_encoder_layer.mean(dim=1).cpu().numpy())
        return encodings

    def _encode_first_decoder_layer(self, texts):
        """Mean-pool the first decoder block's output for each text.

        The decoder input is only the decoder start token, with each text
        attended to through the encoder.

        Args:
            texts: Iterable of strings.

        Returns:
            List of NumPy arrays, one per text, each shaped ``(1, d_model)``.
        """
        # The decoder input is the same for every text, so build it once.
        decoder_input_ids = self._decoder_start_ids()
        encodings = []
        for text in texts:
            encoded_input = self._tokenize(text)
            with torch.no_grad():
                output = self.model(
                    input_ids=encoded_input.input_ids,
                    attention_mask=encoded_input.attention_mask,
                    decoder_input_ids=decoder_input_ids,
                    output_hidden_states=True,
                )
            # decoder_hidden_states[0] is the decoder embedding output; [1] follows the first decoder block.
            first_decoder_layer = output.decoder_hidden_states[1]
            encodings.append(first_decoder_layer.mean(dim=1).cpu().numpy())
        return encodings

    def _final_representation(self, texts):
        """Mean-pool the encoder's last hidden state for each text.

        Only the encoder is run. The representation is the encoder's final
        hidden state, so a decoder forward pass would be wasted work.

        Args:
            texts: Iterable of strings.

        Returns:
            List of NumPy arrays, one per text, each shaped ``(1, d_model)``.
        """
        encodings = []
        for text in texts:
            encoded_input = self._tokenize(text)
            with torch.no_grad():
                output = self.model.encoder(
                    input_ids=encoded_input.input_ids,
                    attention_mask=encoded_input.attention_mask,
                )
            encodings.append(output.last_hidden_state.mean(dim=1).cpu().numpy())
        return encodings

# Demo: build the extractor, then print each kind of embedding for two sample sentences.
t5_model = T5LanguageModel(model_name='t5-base')
texts = ["Hello, world!", "How are you?"]
for label, extract in (
    ("Encodings from the first encoder layer:", t5_model._encode_first_encoder_layer),
    ("Encodings from the first decoder layer:", t5_model._encode_first_decoder_layer),
    ("Final representations:", t5_model._final_representation),
):
    print(label, extract(texts))


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Encodings from the first encoder layer: [array([[ 4.94703293e-01, -1.68299675e+01, -1.26231556e+01,
         1.53170133e+00, -9.66626167e+00,  1.54503298e+01,
        -7.07317657e+01,  1.62899816e+00, -1.24201660e+01,
        -1.39704609e+01,  2.60452104e+00,  2.25709724e+00,
         2.75690613e+01,  9.08774376e+00,  4.52202654e+00,
        -1.45248294e+00,  1.74275265e+01,  4.28111553e+00,
         3.63735771e+00,  1.96072788e+01,  7.05145931e+00,
         1.14767551e+01,  6.00450630e+01, -2.52803349e+00,
         1.58835983e+01, -3.66621161e+00,  6.91194820e+00,
        -1.32199202e+01, -9.56407642e+00,  8.30852890e+00,
        -2.01139374e+01, -2.58071098e+01, -2.17643623e+01,
        -1.16452563e+00, -5.41466999e+00,  7.29531384e+00,
         4.96910477e+00,  1.47953596e+01, -9.61669636e+00,
        -9.13374615e+00, -8.85727596e+00,  3.55617595e+00,
        -2.82051921e+00, -3.27776289e+00, -3.73824577e+01,
         5.36954260e+00, -1.25858583e+01,  1.96886673e+01,
        -1.4300