In [17]:
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

In [18]:
# Build the tokenizer and register the dialogue control tokens on it.
# NOTE(review): the tokenizer is loaded from "facebook/bart-large" while the
# model cell loads "facebook/bart-base" — confirm both share one vocabulary.
add_special_tokens = {
    'additional_special_tokens': [
        '<query>',
        '<response>',
        '<latent>',
        '<persona>',
    ],
}
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
# Kept as the cell's last expression so the number of added tokens is displayed.
tokenizer.add_special_tokens(add_special_tokens)

4

In [19]:
# Shared settings plus the ids of every control token used in this notebook.
args = {
   "num_latent": 10,
   "num_latent2": 10,
}

# Resolve the added control tokens explicitly instead of relying on
# `tokenizer.encode(list_of_tokens)` happening to wrap them in <s> ... </s>.
special_tokens = add_special_tokens['additional_special_tokens']
special_ids = tokenizer.convert_tokens_to_ids(special_tokens)

# A token that was never registered silently maps to <unk>; fail loudly instead.
missing_tokens = [
    tok for tok, tok_id in zip(special_tokens, special_ids)
    if tok_id is None or tok_id == tokenizer.unk_token_id
]
if missing_tokens:
    raise ValueError(f"Special tokens missing from tokenizer vocabulary: {missing_tokens}")

# Same layout the original `encode` call produced: bos, the control tokens, eos.
encoded_tokens = [tokenizer.bos_token_id, *special_ids, tokenizer.eos_token_id]
args['bos'], args["query"], args["response"], args["latent"], args["persona"], args["eos"] = encoded_tokens
# Ask the tokenizer for the padding id rather than hard-coding it.
args["pad"] = tokenizer.pad_token_id

In [20]:
# Count the visible CUDA devices and pick the compute device accordingly.
num_gpus = torch.cuda.device_count()

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using {num_gpus} GPUs")

Using 0 GPUs


In [21]:
import pickle
# Read the pre-tokenized examples from disk.
# NOTE(review): pickle.load can execute arbitrary code, so only point this at a
# trusted file. This is also the *test* split, which later cells wrap as
# `train_dataset` and train on — confirm that is intended.
with open('./Synthetic-Persona-Chat/prepared_data/test_data.pkl', 'rb') as pkl_file:
    data_set = pickle.load(pkl_file)

# Show the first record to eyeball its structure.
print(data_set[0])


{'input_ids': [50267, 50268, 100, 95, 2162, 10, 1518, 92, 790, 4, 2, 100, 101, 7, 3836, 23, 5, 950, 4, 2, 100, 422, 10, 2335, 41227, 334, 4, 2, 100, 33, 10, 380, 4045, 13495, 4, 2, 100, 101, 602, 8, 6016, 842, 462, 16731, 4, 2, 50265, 30086, 6, 38, 437, 646, 44518, 112, 18, 766, 8174, 653, 18, 110, 766, 116, 2], 'labels': [50266, 30086, 6, 38, 437, 646, 44518, 132, 18, 766, 8174, 85, 18, 2579, 7, 972, 47, 4, 2]}


In [22]:
from torch.utils.data import DataLoader, Dataset, RandomSampler
# Convert to torch dataset
class CNNDailyMailDataset(Dataset):
    """Map-style dataset over a sequence of pre-tokenized examples.

    Each element of ``data`` must be a mapping with ``'input_ids'`` and
    ``'labels'`` entries holding integer token ids.

    NOTE(review): despite the class name, this notebook feeds it the
    Synthetic-Persona-Chat pickle; the name is kept so existing callers work.
    """

    def __init__(self, data):
        """Store the example sequence; nothing is copied or converted here."""
        self.data = data

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of 1-D integer tensors.

        Returns:
            dict with keys ``'input_ids'``, ``'decoder_input_ids'`` and
            ``'labels'``. ``'decoder_input_ids'`` is the very same tensor as
            ``'labels'`` (no shifting is applied here).
        """
        item = self.data[idx]
        # Pin the dtype: torch.tensor([]) would otherwise default to float32,
        # which breaks padding and embedding lookups for an empty sequence.
        input_ids = torch.tensor(item['input_ids'], dtype=torch.long)
        labels = torch.tensor(item['labels'], dtype=torch.long)

        return {
            'input_ids': input_ids,
            # Unshifted alias of the labels, as in the original implementation.
            'decoder_input_ids': labels,
            'labels': labels
        }


In [26]:
from torch.nn.utils.rnn import pad_sequence

class CollateFn:
    """Pad a list of variable-length examples into one rectangular batch.

    Args:
        pad_token_id: id used to right-pad ``input_ids`` and
            ``decoder_input_ids``.
    """

    def __init__(self, pad_token_id):
        self.pad_token_id = pad_token_id

    def __call__(self, batch):
        """Collate ``batch`` (a list of dicts of 1-D tensors).

        Returns:
            dict with ``'input_ids'``, ``'attention_mask'``,
            ``'decoder_input_ids'`` and ``'labels'``, each padded to the
            longest sequence in the batch (batch dimension first).
        """
        input_seqs = [item['input_ids'] for item in batch]
        decoder_seqs = [item['decoder_input_ids'] for item in batch]
        label_seqs = [item['labels'] for item in batch]

        input_ids = pad_sequence(input_seqs, batch_first=True, padding_value=self.pad_token_id)
        # 1 for real tokens, 0 for padding. Built from the true lengths (not
        # from `input_ids != pad`) so a genuine token equal to the pad id is
        # never masked out. Without a mask the encoder attends to padding.
        attention_mask = pad_sequence(
            [seq.new_ones(seq.shape) for seq in input_seqs],
            batch_first=True,
            padding_value=0,
        )
        decoder_input_ids = pad_sequence(decoder_seqs, batch_first=True, padding_value=self.pad_token_id)
        # -100 is the ignore index of the cross-entropy loss.
        labels = pad_sequence(label_seqs, batch_first=True, padding_value=-100)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'decoder_input_ids': decoder_input_ids,
            'labels': labels
        }


In [30]:
# Wrap the loaded examples so they can be indexed as tensors.
train_dataset = CNNDailyMailDataset(data_set)

# The padding id comes from the shared `args` dict; the collator pads each
# batch with it.
pad_token_id = args['pad']
pad_fn = CollateFn(pad_token_id)

# Draw examples in random order, two per batch.
sampler = RandomSampler(train_dataset)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=2,
    sampler=sampler,
    collate_fn=pad_fn,
)

In [31]:
# Sanity check: pull one collated batch, move it to the device, and print the
# tensors themselves so the padding can be inspected by eye.
for batch in train_loader:
    input_ids, decoder_input_ids, labels = (
        batch[key].to(device)
        for key in ('input_ids', 'decoder_input_ids', 'labels')
    )

    print(f"Input ID: {input_ids}")
    print(f"Decoder Input IDs : {decoder_input_ids}")
    print(f"Labels: {labels}")
    # Only the first batch is needed.
    break

Input ID: tensor([[50267, 50268,   100,   524,    10,  9613,     4,     2,  2387,  2674,
          1971,    16,   295,   853, 40949,     4,     2,   100,   101,   878,
             4,     2,   100,   173,    23,  4716,  2793,     4,     2, 50265,
           100,   101,     7,   213,     7,     5,  4105,     6,   310,   569,
           426,     6,     8,  1166,     4,     2, 50266,   100,   101,     7,
           213,     7,     5,  4105,   350,   328,    38,    67,   101,     7,
           213,  5651,     8,  7358,     4,     2, 50265,  7516,     6,  3035,
           328,    38,   657,  5651,   350,     4,     2, 50266,  5096,   350,
           328,     2, 50265,  2847,     6,    99,   109,    47,   206,     9,
          1261,    98,   444,   116,     2,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1],
        [50267, 50268,   100,   101,     7, 29716, 15146,    13,  7272,    11,
             5,  1098,     4,     2,   100,   10

In [32]:
# Load the pretrained BART weights for fine-tuning.
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
# Batches are moved to `device` before every forward pass, so the weights must
# live there too; otherwise the model fails with a device mismatch whenever
# CUDA is available.
model.to(device)
# Grow the embedding matrix so the special tokens added to the tokenizer have
# rows of their own. Kept as the last expression so the new size is displayed.
model.resize_token_embeddings(len(tokenizer))

BartScaledWordEmbedding(50269, 768, padding_idx=1)

In [None]:
from tqdm import tqdm
import torch

# Fine-tune the model on `train_loader` with AdamW, reporting per-batch loss
# and perplexity on the progress bar and the mean loss after each epoch.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

num_epochs = 1  # for example

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    model.train()
    total_loss = 0.0

    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc="Training")

    for batch_idx, batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        # Without an attention mask the encoder attends to the padding that the
        # collator appended. Use the collator's mask when it supplies one,
        # otherwise derive it from the padding id used to build the batch.
        attention_mask = batch.get('attention_mask')
        if attention_mask is None:
            attention_mask = input_ids.ne(pad_token_id).long()
        else:
            attention_mask = attention_mask.to(device)

        # Only `labels` is passed: the model builds its own decoder inputs by
        # shifting the labels right.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

        # `.mean()` is a no-op for a scalar loss but keeps this correct if the
        # model is ever wrapped so that it returns one loss per replica.
        loss = outputs.loss.mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_value = loss.item()
        total_loss += loss_value

        # Perplexity of the current batch is exp(mean token cross-entropy).
        perplexity = torch.exp(loss.detach()).item()

        progress_bar.set_postfix({
            'Batch': batch_idx+1,
            'Loss': f"{loss_value:.4f}",
            'Perplexity': f"{perplexity:.4f}"
        })

    # Guard against an empty loader so the summary never divides by zero.
    avg_loss = total_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1} completed | Average Loss: {avg_loss:.4f}")



Epoch 1/3
Batch 10 | Loss: 4.2847 | Perplexity: 72.5843


KeyboardInterrupt: 

In [None]:
# Save the fine-tuned model and its tokenizer side by side.
# NOTE(review): the directory names say "gpt2" although the model is BART; they
# are left unchanged here in case other code loads from these paths.
save_path = "gpt2-multi-gpu" if num_gpus > 1 else "gpt2-single-gpu"

# Unwrap only when the model is really wrapped (e.g. by DataParallel). Having
# several GPUs visible does not mean a `.module` attribute exists, and reaching
# for it unconditionally raises AttributeError.
model_to_save = model.module if hasattr(model, "module") else model
model_to_save.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

print(f"Model saved at {save_path}")


In [70]:
# Generate a reply for one hand-written prompt with beam search.
query = "<latent><persona> I love hiking and books. <query> What's your hobby?"
inputs = tokenizer(query, return_tensors='pt').to(device)

# The training loop leaves the model in train mode; switch to eval so dropout
# is disabled and generation is deterministic for a given input.
model.eval()

output_ids = model.generate(
    input_ids=inputs['input_ids'],
    # Forward the tokenizer's mask explicitly instead of letting generate()
    # guess it from the padding id.
    attention_mask=inputs['attention_mask'],
    max_length=50,
    num_beams=4,
    early_stopping=True
)

# Control tokens such as <response> are stripped from the decoded text.
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print("Response:", response)



Response: I like to read books.
