In [67]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer, Trainer, TrainingArguments, AutoModelForSeq2SeqLM
from sklearn.metrics import pairwise
import evaluate
import transformers
from transformers import default_data_collator

In [40]:
class TextEmbeddingDataset(Dataset):
    def __init__(self, texts, emb_g, emb_s):
        """
        :param texts: List of original input texts.
        :param emb_g: Embeddings from encoder g.
        :param emb_s: Embeddings from encoder s.
        """
        self.texts = texts
        self.emb_g = emb_g
        self.emb_s = emb_s

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return {
            'text': self.texts[idx],
            'emb_g': torch.tensor(self.emb_g[idx], dtype=torch.float32),
            'emb_s': torch.tensor(self.emb_s[idx], dtype=torch.float32),
        }

In [41]:
class AlignmentModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(AlignmentModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, emb_s):
        return self.linear(emb_s)


In [42]:
class CosineSimilarityLoss(nn.Module):
    def forward(self, aligned_emb_s, emb_g):
        cos_sim = nn.functional.cosine_similarity(aligned_emb_s, emb_g, dim=-1)
        loss = 1 - cos_sim.mean()
        return loss

In [43]:
def load_encoder_decoder(
    model_name: str, lora: bool = False
) -> transformers.AutoModelForSeq2SeqLM:
    model_kwargs: Dict[str, Any] = {
        "low_cpu_mem_usage": True,
    }
    if lora:
        model_kwargs.update(
            {
                "load_in_8bit": True,
                "device_map": "auto",
            }
        )
    return transformers.AutoModelForSeq2SeqLM.from_pretrained(
        model_name, **model_kwargs
    )

def load_tokenizer(name: str, max_length: int) -> transformers.PreTrainedTokenizer:
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        name,
        padding="max_length",
        truncation="max_length",
        max_length=max_length,
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Disable super annoying warning:
    # https://github.com/huggingface/transformers/issues/22638
    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
    return tokenizer


def load_embedder_and_tokenizer(name:str, **kwargs):
    model_kwargs = {
        "low_cpu_mem_usage": True,  # Not compatible with DeepSpeed
        "output_hidden_states": False,  # True output hidden states, for embedding last and first .
    }
    # TODO: check the configurations for commercial models

    if name=="me5":
        model = transformers.AutoModel.from_pretrained("intfloat/multilingual-e5-base", **model_kwargs)
        tokenizer = transformers.AutoTokenizer.from_pretrained("intfloat/multilingual-e5-base")
        tokenizer.pad_token = tokenizer.eos_token
    else:
        print(f"WARNING: Trying to initialize from unknown embedder {name}")
        model = transformers.AutoModel.from_pretrained(name, **model_kwargs)
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)

    return model, tokenizer

In [186]:
def get_encoder_embeddings(model, tokenizer, input_texts):
    model.eval()

    embeddings = []
    with torch.no_grad():
        for text in input_texts:
            inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=128)
            outputs = model(**inputs)
            embedding = outputs.last_hidden_state.mean(dim=1).squeeze(0).numpy()  # Use mean pooling
            embeddings.append(embedding)
            print(embedding.shape)

    return torch.tensor(embeddings)


In [179]:
def evaluate_bleu(predictions, references):
    bleu_metric = evaluate.load('sacrebleu')
    return bleu_metric.compute(predictions=predictions, references=[[ref] for ref in references])

In [180]:
texts = [
"This is a test sentence.",
"I love natural language processing.",
"Transformers are amazing models!"
]

# Load embeddings from encoder g and encoder s

In [181]:
encoder_model_name = "intfloat/multilingual-e5-small"
encoder_decoder_model_name = "google-t5/t5-small"
encoder, encoder_tokenizer = load_embedder_and_tokenizer("me5")
encoder_decoder = load_encoder_decoder(encoder_decoder_model_name)
encoder_decoder_tokenizer = load_tokenizer(encoder_decoder_model_name, max_length=128)

In [187]:
emb_g = get_encoder_embeddings(encoder_decoder.encoder, encoder_decoder_tokenizer, texts)

(512,)
(512,)
(512,)


In [188]:
emb_g.shape

torch.Size([3, 512])

In [189]:
emb_s = get_encoder_embeddings(encoder, encoder_tokenizer, texts)

(768,)
(768,)
(768,)


In [190]:
emb_s.shape

torch.Size([3, 768])

In [191]:
dataset = TextEmbeddingDataset(texts, emb_g, emb_s)

In [192]:
dataset[0]

  'emb_g': torch.tensor(self.emb_g[idx], dtype=torch.float32),
  'emb_s': torch.tensor(self.emb_s[idx], dtype=torch.float32),


{'text': 'This is a test sentence.',
 'emb_g': tensor([-4.8239e-02,  1.6047e-02, -9.7561e-03, -1.6206e-02, -1.9525e-02,
          6.1420e-02, -2.6809e-02, -3.9371e-02, -1.1036e-01, -1.2500e-01,
         -1.3804e-02,  1.4954e-01,  4.1940e-02, -2.5307e-02, -1.1941e-02,
          7.8106e-02, -3.2313e-02,  2.0682e-02,  4.1168e-02, -4.3143e-02,
          2.5930e-02,  1.3033e-01,  1.4260e-01, -1.1514e-01,  8.2676e-02,
         -1.0572e-01, -1.2880e-01, -1.2978e-01,  1.3664e-01, -2.0178e-01,
          1.3683e-01,  1.4738e-02, -1.1337e-01, -8.0304e-03, -1.3763e-01,
         -1.5196e-02,  5.1717e-02, -7.9046e-02, -1.8986e-01,  7.7120e-02,
         -8.5474e-02,  5.5013e-03,  1.7629e-01, -1.2956e-01, -3.2731e-02,
          1.0586e-01,  9.0880e-02, -1.5749e-02, -3.5015e-02,  2.3650e-02,
          6.6901e-02, -6.6957e-02,  1.4793e-02,  1.0112e-01, -2.1416e-01,
         -9.0710e-03,  6.3410e-02,  2.6716e-02, -3.7014e-02, -1.3567e-01,
          1.0831e-02,  1.7049e-01,  3.2329e-01,  6.2452e-02, -1.76

In [193]:
input_dim = emb_s.shape[1]
output_dim = emb_g.shape[1]
model = AlignmentModel(input_dim, output_dim)

In [194]:
model

AlignmentModel(
  (linear): Linear(in_features=768, out_features=512, bias=True)
)

In [195]:
training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=1,
        logging_dir="./logs",
        logging_steps=10,
         remove_unused_columns=False, #very important.
    use_mps_device=True
    )

In [196]:
class CosineSimilarityLoss(nn.Module):
    def forward(self, aligned_emb_s, emb_g):
        cos_sim = nn.functional.cosine_similarity(aligned_emb_s, emb_g, dim=-1)
        loss = 1 - cos_sim.mean()
        return loss
        
class CustomTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def compute_loss(self, model, inputs, return_outputs=False):
        print(inputs.keys())
        emb_s = inputs['emb_s']
        emb_g = inputs['emb_g']
        
        aligned_emb_s = model(emb_s)
        loss_fn = CosineSimilarityLoss()
        loss = loss_fn(aligned_emb_s, emb_g)
        
        return (loss, aligned_emb_s) if return_outputs else loss


In [197]:
# Create custom trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

In [205]:
trainer.train()

dict_keys(['emb_g', 'emb_s'])


  'emb_g': torch.tensor(self.emb_g[idx], dtype=torch.float32),
  'emb_s': torch.tensor(self.emb_s[idx], dtype=torch.float32),


Step,Training Loss
10,0.1499


dict_keys(['emb_g', 'emb_s'])
dict_keys(['emb_g', 'emb_s'])
dict_keys(['emb_g', 'emb_s'])
dict_keys(['emb_g', 'emb_s'])
dict_keys(['emb_g', 'emb_s'])
dict_keys(['emb_g', 'emb_s'])
dict_keys(['emb_g', 'emb_s'])
dict_keys(['emb_g', 'emb_s'])
dict_keys(['emb_g', 'emb_s'])
dict_keys(['emb_g', 'emb_s'])
dict_keys(['emb_g', 'emb_s'])
dict_keys(['emb_g', 'emb_s'])
dict_keys(['emb_g', 'emb_s'])
dict_keys(['emb_g', 'emb_s'])


TrainOutput(global_step=15, training_loss=0.13779906034469605, metrics={'train_runtime': 0.1395, 'train_samples_per_second': 107.559, 'train_steps_per_second': 107.559, 'total_flos': 0.0, 'train_loss': 0.13779906034469605, 'epoch': 5.0})

In [206]:
model = model.eval()

In [207]:
model

AlignmentModel(
  (linear): Linear(in_features=768, out_features=512, bias=True)
)

In [208]:
emb_s

tensor([[-0.2073,  0.2610, -0.3151,  ..., -0.7859, -0.4563,  0.6098],
        [-0.1964,  0.8787, -0.4009,  ..., -0.8158, -0.4921,  0.8882],
        [-0.1370,  0.2270, -0.0974,  ..., -0.5315, -0.5712,  0.6356]])

In [209]:
emb_s.shape

torch.Size([3, 768])

In [210]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
emb_s = emb_s.to(device)
model.to(device)

AlignmentModel(
  (linear): Linear(in_features=768, out_features=512, bias=True)
)

In [211]:
aligned_emb_s =model(emb_s)

In [212]:
decoder_g = encoder_decoder.decoder

In [213]:
aligned_emb_s.shape

torch.Size([3, 512])

In [214]:
aligned_emb = aligned_emb_s.unsqueeze(0)
print(aligned_emb.shape)
# Create a named tuple to pass encoder outputs manually
from transformers.modeling_outputs import BaseModelOutput
encoder_outputs = BaseModelOutput(last_hidden_state=aligned_emb)
# print(encoder_outputs.shape)

# Provide initial decoder input IDs (start token)
decoder_input_ids = torch.tensor([[encoder_decoder.config.decoder_start_token_id]])

# Generate text based on aligned embeddings
decoder_outputs = encoder_decoder.generate(
    encoder_outputs=encoder_outputs,
    decoder_input_ids=decoder_input_ids,
    max_length=50,  # Limit the generation length
    num_beams=5     # Beam search for better results (optional)
)

# Decode the generated tokens into text
generated_text = encoder_decoder_tokenizer.decode(decoder_outputs[0], skip_special_tokens=True)
print(generated_text)

torch.Size([1, 3, 512])
                                                
