In [65]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BartTokenizer, BartForConditionalGeneration

# Assuming you have a dataset with source documents, positive summaries, and negative summaries
# Define your dataset class
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Define your custom DataCollator
class CustomDataCollator:
    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model

    def __call__(self, examples):
        positive_documents = [example['positive_document'] for example in examples]
        negative_documents = [example['negative_document'] for example in examples]
        source_summaries = [example['source_summary'] for example in examples]

        # Tokenize and create input tensors
        inputs = self.tokenizer(
            positive_documents,
            # negative_summaries,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512  # Adjust as needed
        )
        
        # Tokenize and create input tensors
        negative_inputs = self.tokenizer(
            negative_documents,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512  # Adjust as needed
        )
        
        with tokenizer.as_target_tokenizer():
            labels = self.tokenizer(
                source_summaries,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512  # Adjust as needed
            )
        
        inputs["negative_input_ids"] = negative_inputs["input_ids"]
        inputs["negative_attention_mask"] = negative_inputs["attention_mask"]
        inputs["labels"] = labels["input_ids"]

        return inputs

# Load BART model and tokenizer
model_name = "facebook/bart-large-cnn"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration .from_pretrained(model_name)

In [93]:
# Load your dataset and create DataLoader
data = [{'positive_document': 'positive_document_1', 'negative_document': 'negative_document_1_source_summary', 'source_summary': 'source_summary_1 with a Hugging Face datasets'},
       {'positive_document': 'positive_document_2', 'negative_document': 'negative_document_2_source_summary', 'source_summary': 'source_summary_2 with a Hugging '},
       {'positive_document': 'positive_document_3', 'negative_document': 'negative_document_3_source_summary', 'source_summary': 'source_summary_3 with a'},
       {'positive_document': 'positive_document_4', 'negative_document': 'negative_document_4_source_summary', 'source_summary': 'source_summary_4 with a Hugging Face'}]
custom_dataset = CustomDataset(data)
custom_data_collator = CustomDataCollator(tokenizer, model)
dataloader = DataLoader(custom_dataset, batch_size=2, collate_fn=custom_data_collator, shuffle=True)

# Define your contrastive loss function
contrastive_loss = torch.nn.TripletMarginLoss(margin=1.0)

# Training loop
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

epochs = 3

for epoch in range(epochs):
    for batch in dataloader:
        inputs = batch
        print(inputs)
        # break

        # Forward pass
        outputs = model(**inputs, output_hidden_states=True)
        
        print(outputs.encoder_last_hidden_state.shape)
        break
        # Extract embeddings
        source_embeddings = outputs.encoder_last_hidden_state[0]
        positive_embeddings = outputs.encoder_last_hidden_state[1]
        negative_embeddings = outputs.encoder_last_hidden_state[2]

        # Compute contrastive loss
        loss = contrastive_loss(source_embeddings, positive_embeddings, negative_embeddings)
        print(loss)

        # Backpropagation and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

{'input_ids': tensor([[    0, 22173,  1215, 43017,  1215,   246,     2],
        [    0, 22173,  1215, 43017,  1215,   306,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1]]), 'negative_input_ids': tensor([[    0, 33407,  1215, 43017,  1215,   246,  1215, 17747,  1215, 48600,
             2],
        [    0, 33407,  1215, 43017,  1215,   306,  1215, 17747,  1215, 48600,
             2]]), 'negative_attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[    0, 17747,  1215, 48600,  1215,   246,    19,    10,     2,     1,
             1,     1],
        [    0, 17747,  1215, 48600,  1215,   306,    19,    10, 30581,  3923,
         12346,     2]])}


TypeError: BartForConditionalGeneration.forward() got an unexpected keyword argument 'negative_input_ids'

In [87]:
inputs['input_ids'].shape

torch.Size([2, 7])

In [88]:
inputs['negative_input_ids'][0].shape

torch.Size([11])

In [89]:
inputs['negative_input_ids'][1].shape

torch.Size([11])

In [90]:
inputs['input_ids'][0]

tensor([    0, 22173,  1215, 43017,  1215,   134,     2])

In [91]:
tokenizer.decode(inputs['input_ids'][0])

'<s>positive_document_1</s>'

In [92]:
tokenizer.decode(inputs['input_ids'][1])

'<s>positive_document_4</s>'