In [1]:
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

In [2]:
# Load the BART model and tokenizer
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
add_special_tokens = {'additional_special_tokens': ['<query>', '<response>', '<latent>', '<persona>']}
tokenizer.add_special_tokens(add_special_tokens)

4

In [3]:
args = {
   "num_latent": 10,
   "num_latent2": 10,
}

In [4]:
encoded_tokens = tokenizer.encode(add_special_tokens['additional_special_tokens'])
args['bos'], args["query"], args["response"], args["latent"], args["persona"], args["eos"] = encoded_tokens
args["pad"] = 1

In [5]:
# Load the model with the updated tokenizer
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
model.resize_token_embeddings(len(tokenizer))

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


BartScaledWordEmbedding(50269, 768, padding_idx=1)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(50269, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(50269, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_lay

In [None]:
import pickle
# Load the pickle file
with open('./Synthetic-Persona-Chat/prepared_data/test_data.pkl', 'rb') as f:
    data_set = pickle.load(f)

# Now `data_set` contains your loaded data
print(data_set[0])


{'input_ids': [50267, 50268, 100, 95, 2162, 10, 1518, 92, 790, 4, 2, 100, 101, 7, 3836, 23, 5, 950, 4, 2, 100, 422, 10, 2335, 41227, 334, 4, 2, 100, 33, 10, 380, 4045, 13495, 4, 2, 100, 101, 602, 8, 6016, 842, 462, 16731, 4, 2, 50265, 30086, 6, 38, 437, 646, 44518, 112, 18, 766, 8174, 653, 18, 110, 766, 116, 2], 'labels': [50266, 30086, 6, 38, 437, 646, 44518, 132, 18, 766, 8174, 85, 18, 2579, 7, 972, 47, 4, 2]}


In [26]:
torch.tensor(data_set[0]['labels'])

tensor([50266, 30086,     6,    38,   437,   646, 44518,   132,    18,   766,
         8174,    85,    18,  2579,     7,   972,    47,     4,     2])

In [39]:

# Convert to torch dataset
class CNNDailyMailDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        input_ids = torch.tensor(item['input_ids'])
        # attention_mask = torch.tensor(item['attention_mask'])
        labels = torch.tensor(item['labels'])
         

        decoder_input_ids = labels
        # labels = labels[1:]

        # decoder_input_ids = labels[:-1]
        # labels = labels[1:]

        return {
            'input_ids': input_ids,
            # 'attention_mask': attention_mask,
            'decoder_input_ids': decoder_input_ids,
            'labels': labels
        }


In [40]:
from torch.nn.utils.rnn import pad_sequence

class CollateFn:
    def __init__(self, pad_token_id):
        self.pad_token_id = pad_token_id

    def __call__(self, batch):
        input_ids = [item['input_ids'] for item in batch]
        decoder_input_ids = [item['decoder_input_ids'] for item in batch]
        labels = [item['labels'] for item in batch]

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
        decoder_input_ids = pad_sequence(decoder_input_ids, batch_first=True, padding_value=self.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)  # ignore index

        return {
            'input_ids': input_ids,
            'decoder_input_ids': decoder_input_ids,
            'labels': labels
        }


In [41]:
train_dataset = CNNDailyMailDataset(data_set)

In [61]:
pad_token_id = args['pad']

pad_fn = CollateFn(pad_token_id)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=2,collate_fn=pad_fn)

In [62]:
for batch in train_loader:
    # Move the batch to the device
    input_ids = batch['input_ids'].to(device)
    decoder_input_ids = batch['decoder_input_ids'].to(device)
    labels = batch['labels'].to(device)

    # Print the shapes of the tensors
    print(f"Input ID: {input_ids}")
    print(f"Decoder Input IDs : {decoder_input_ids}")
    print(f"Labels: {labels}")
    break

Input ID: tensor([[50267, 50268,   100,   173,    25,    10,  3254,     4,     2,   100,
           524,    10,  2602, 37958,     4,     2,   100,   524,  2997,    19,
            10,  1159,     4,     2,   100,   657,     7,  7142,     4,     2,
           100,   101,  2600,     4,     2, 50265,  1711,    18,   205,     7,
          1798,     4,     2, 50266,  2847,     6,    99,    18,   110,  2674,
           689,     7,  7142,   116,     2, 50265,   100,   657,  6836, 18236,
             4,    38,   146,    10,  1266, 33362,   741,  8982,  4977,   242,
             4,     2, 50266,  7516,     6,    14,  4428, 10964,     4,    38,
           657, 18236,   350,     4,     2, 50265,  5096,   350,     4,    38,
           115,  3529,    24,   358,   183,     4,     2,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1],
        [50267, 50268,   100,   657,     5,  1971,  6187,     4,     2, 10285,
           688,   

In [None]:
from tqdm import tqdm
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

num_epochs = 1  # for example

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    model.train()
    total_loss = 0.0

    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc="Training")

    for batch_idx, batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids=input_ids, labels=labels)

        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

        # Calculate perplexity for the current batch
        perplexity = torch.exp(loss).item()

        # Update progress bar description
        progress_bar.set_postfix({
            'Batch': batch_idx+1,
            'Loss': f"{loss.item():.4f}",
            'Perplexity': f"{perplexity:.4f}"
        })

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} completed | Average Loss: {avg_loss:.4f}")



Epoch 1/3
Batch 10 | Loss: 4.2847 | Perplexity: 72.5843


KeyboardInterrupt: 

In [70]:
query = "<latent><persona> I love hiking and books. <query> What's your hobby?"
inputs = tokenizer(query, return_tensors='pt').to(device)

output_ids = model.generate(
    input_ids=inputs['input_ids'],
    max_length=50,
    num_beams=4,
    early_stopping=True
)

response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print("Response:", response)



Response: I like to read books.
