In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BartTokenizer, BartForConditionalGeneration

# Load BART model and tokenizer
model_name = "facebook/bart-large"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration .from_pretrained(model_name)

Downloading (…)okenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

In [7]:
# device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [8]:
# Assuming you have a dataset with source documents, positive summaries, and negative summaries
# Define your dataset class
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Define your custom DataCollator
class CustomDataCollator:
    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model

    def __call__(self, examples):
        positive_documents = [example['positive_document'] for example in examples]
        negative_documents = [example['negative_document'] for example in examples]
        source_summaries = [example['source_summary'] for example in examples]

        # Tokenize and create input tensors
        inputs = self.tokenizer(
            positive_documents,
            # negative_summaries,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512  # Adjust as needed
        )
        
        # Tokenize and create input tensors
        negative_inputs = self.tokenizer(
            negative_documents,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512  # Adjust as needed
        )
        
        with tokenizer.as_target_tokenizer():
            labels = self.tokenizer(
                source_summaries,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512  # Adjust as needed
            )
            
        # inputs["input_ids"] = tokenizer.pad
        # negative_inputs["input_ids"] = tokenizer.pad
        
        batch = tokenizer.pad(encoded_inputs={"input_ids": inputs["input_ids"].squeeze().tolist() + negative_inputs["input_ids"].squeeze().tolist()}, padding=True, return_tensors='pt')
        # batch["decoder_input_ids"] = torch.stack((inputs["labels"], inputs["labels"]))
        # batch["decoder_attention_mask"] = torch.stack((inputs["labels"], inputs["decoder_attention_mask"]))
        batch["decoder_input_ids"] = tokenizer.pad(encoded_inputs={"input_ids": labels["input_ids"].squeeze().tolist()+labels["input_ids"].squeeze().tolist()}, padding=False, return_tensors='pt')["input_ids"]
        labels["input_ids"] = [[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]]
        batch["labels"] =  tokenizer.pad(encoded_inputs={"input_ids": labels["input_ids"]+labels["input_ids"]}, padding=True, return_tensors='pt')["input_ids"]

        return batch

In [13]:
# Load your dataset and create DataLoader
data = [{'positive_document': 'positive_document_1 with a Hugging ', 'negative_document': 'negative_document_1_source_summary', 'source_summary': 'source_summary_1 with a Hugging Face datasets'},
       {'positive_document': 'positive_document_2 with a Hugging Face datasets', 'negative_document': 'negative_document_2_source_summary', 'source_summary': 'source_summary_2 with a Hugging '},
       {'positive_document': 'positive_document_3 with a Hugging Face', 'negative_document': 'negative_document_3_source_summary', 'source_summary': 'source_summary_3 with a'},
       {'positive_document': 'positive_document_4 with a', 'negative_document': 'negative_document_4_source_summary', 'source_summary': 'source_summary_4 with a Hugging Face'}]
custom_dataset = CustomDataset(data)
custom_data_collator = CustomDataCollator(tokenizer, model)
dataloader = DataLoader(custom_dataset, batch_size=2, collate_fn=custom_data_collator, shuffle=True)

In [14]:
epochs = 2
for epoch in range(epochs):
    for batch in dataloader:
        print(batch['input_ids'].shape)
        print('-'*100)
        print(batch['decoder_input_ids'].shape)
        print('='*100)
        break

torch.Size([4, 13])
----------------------------------------------------------------------------------------------------
torch.Size([4, 12])
torch.Size([4, 13])
----------------------------------------------------------------------------------------------------
torch.Size([4, 13])


In [11]:
# Define your contrastive loss function
# contrastive_loss = torch.nn.TripletMarginLoss(margin=1.0)
cosine_loss = torch.nn.CosineEmbeddingLoss()

# Training loop
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

epochs = 3

for epoch in range(epochs):
    for batch in dataloader:
        inputs = batch
        # print(inputs)
        # break

        # Forward pass
        outputs = model(**inputs, output_hidden_states=True)
        
        print(outputs.encoder_last_hidden_state.shape)
        # break
        # Extract embeddings
        positive_embeddings_1 = outputs.encoder_last_hidden_state[0]
        print(positive_embeddings_1.shape)
        positive_embeddings_2 = outputs.encoder_last_hidden_state[1]
        print(positive_embeddings_2.shape)
        negative_embeddings_1 = outputs.encoder_last_hidden_state[2]
        print(negative_embeddings_1.shape)
        negative_embeddings_2 = outputs.encoder_last_hidden_state[3]
        print(negative_embeddings_2.shape)
        # break
        # Compute contrastive loss
        loss_1 = cosine_loss(positive_embeddings_1, negative_embeddings_1, torch.ones(positive_embeddings_1.size(dim=0)))
        loss_2 = cosine_loss(positive_embeddings_1, negative_embeddings_1, torch.ones(positive_embeddings_2.size(dim=0)))
        loss = (loss_1 + loss_2) / 2
        print("loss 1: ", loss_1)
        print('-'*100)
        print("loss 2: ", loss_2)
        print('-'*100)
        print("loss: ", loss)
        print('='*100)
        # break

        # Backpropagation and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

torch.Size([4, 11, 1024])
torch.Size([11, 1024])
torch.Size([11, 1024])
torch.Size([11, 1024])
torch.Size([11, 1024])
loss 1:  tensor(2.0591e-07, grad_fn=<MeanBackward0>)
----------------------------------------------------------------------------------------------------
loss 2:  tensor(2.0591e-07, grad_fn=<MeanBackward0>)
----------------------------------------------------------------------------------------------------
loss:  tensor(2.0591e-07, grad_fn=<DivBackward0>)
torch.Size([4, 11, 1024])
torch.Size([11, 1024])
torch.Size([11, 1024])
torch.Size([11, 1024])
torch.Size([11, 1024])
loss 1:  tensor(2.8014e-06, grad_fn=<MeanBackward0>)
----------------------------------------------------------------------------------------------------
loss 2:  tensor(2.8014e-06, grad_fn=<MeanBackward0>)
----------------------------------------------------------------------------------------------------
loss:  tensor(2.8014e-06, grad_fn=<DivBackward0>)
torch.Size([4, 11, 1024])
torch.Size([11, 1024])

In [36]:
positive_embeddings_1.size(dim=1)

1024

In [124]:
# Tokenize and create input tensors
inputs = tokenizer(
    [data[0]['positive_document'], data[1]['positive_document']],
    # negative_summaries,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512  # Adjust as needed
)

# Tokenize and create input tensors
negative_inputs = tokenizer(
    [data[0]['negative_document'], data[1]['negative_document']],
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512  # Adjust as needed
)

In [109]:
inputs["input_ids"].squeeze().tolist() + negative_inputs["input_ids"].squeeze().tolist()

[[0, 22173, 1215, 43017, 1215, 134, 2],
 [0, 22173, 1215, 43017, 1215, 176, 2],
 [0, 33407, 1215, 43017, 1215, 134, 1215, 17747, 1215, 48600, 2],
 [0, 33407, 1215, 43017, 1215, 176, 1215, 17747, 1215, 48600, 2]]

In [108]:
negative_inputs["input_ids"].squeeze().tolist()

[[0, 33407, 1215, 43017, 1215, 134, 1215, 17747, 1215, 48600, 2],
 [0, 33407, 1215, 43017, 1215, 176, 1215, 17747, 1215, 48600, 2]]

In [101]:
c = tokenizer.truncate_sequences(inputs["input_ids"], negative_inputs["input_ids"], truncation_strategy='longest_first')

In [102]:
c

(tensor([[    0, 22173,  1215, 43017,  1215,   134,     2],
         [    0, 22173,  1215, 43017,  1215,   176,     2]]),
 tensor([[    0, 33407,  1215, 43017,  1215,   134,  1215, 17747,  1215, 48600,
              2],
         [    0, 33407,  1215, 43017,  1215,   176,  1215, 17747,  1215, 48600,
              2]]),
 [])

In [103]:
d = inputs

In [104]:
d

{'input_ids': tensor([[    0, 22173,  1215, 43017,  1215,   134,     2],
        [    0, 22173,  1215, 43017,  1215,   176,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1]])}

In [105]:
{"input_ids": [inputs["input_ids"].squeeze().tolist()[0], negative_inputs["input_ids"].squeeze().tolist()[0]]}

{'input_ids': [[0, 22173, 1215, 43017, 1215, 134, 2],
  [0, 33407, 1215, 43017, 1215, 134, 1215, 17747, 1215, 48600, 2]]}

In [111]:
batch = tokenizer.pad(encoded_inputs={"input_ids": inputs["input_ids"].squeeze().tolist() + negative_inputs["input_ids"].squeeze().tolist()}, padding=True, return_tensors='pt')

In [112]:
batch

{'input_ids': tensor([[    0, 22173,  1215, 43017,  1215,   134,     2,     1,     1,     1,
             1],
        [    0, 22173,  1215, 43017,  1215,   176,     2,     1,     1,     1,
             1],
        [    0, 33407,  1215, 43017,  1215,   134,  1215, 17747,  1215, 48600,
             2],
        [    0, 33407,  1215, 43017,  1215,   176,  1215, 17747,  1215, 48600,
             2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [80]:
batch = tokenizer.pad(encoded_inputs={"input_ids": [inputs["input_ids"].squeeze().tolist(), negative_inputs["input_ids"].squeeze().tolist()]}, padding=True, return_tensors='pt')

In [None]:
batch

In [87]:
inputs['input_ids'].shape

torch.Size([2, 7])

In [88]:
inputs['negative_input_ids'][0].shape

torch.Size([11])

In [89]:
inputs['negative_input_ids'][1].shape

torch.Size([11])

In [90]:
inputs['input_ids'][0]

tensor([    0, 22173,  1215, 43017,  1215,   134,     2])

In [91]:
tokenizer.decode(inputs['input_ids'][0])

'<s>positive_document_1</s>'

In [92]:
tokenizer.decode(inputs['input_ids'][1])

'<s>positive_document_4</s>'