In [1]:
%pip install transformers
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
import torch
 

 

 


Note: you may need to restart the kernel to use updated packages.


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Toy single-turn customer-support dataset. Each record pairs a customer
# message ("customer") with the agent reply written for that message
# ("agent_response"); records are independent exchanges, not one conversation.
# NOTE(review): a few replies contain literal template placeholders
# ("[insert hours here]", "[insert timeframe]"); the fine-tuned model may learn
# to reproduce them verbatim unless they are filled in before training.
sample_data = [
    {"customer": "I'm having trouble logging into my account.", "agent_response": "Sure, could you please provide your email address?"},
    {"customer": "My order hasn't arrived yet.", "agent_response": "Can you please provide your order number so I can check the status?"},
    {"customer": "I need to update my billing information.", "agent_response": "Sure, could you please provide the new billing address?"},
    {"customer": "I accidentally deleted my account. Can you help me recover it?", "agent_response": "Of course, could you please provide your username or email associated with the account?"},
    {"customer": "I received a damaged product. What should I do?", "agent_response": "I apologize for the inconvenience. Could you please provide a photo of the damaged product and your order number?"},
    {"customer": "I'm unable to add items to my cart on the website.", "agent_response": "I'm sorry to hear that. Could you please try clearing your browser's cache and cookies, and then try again?"},
    {"customer": "Do you offer international shipping?", "agent_response": "Yes, we do offer international shipping. Could you please provide your shipping address so I can check the shipping options and costs?"},
    {"customer": "I need assistance with tracking my package.", "agent_response": "Sure, could you please provide your tracking number?"},
    {"customer": "I want to cancel my subscription.", "agent_response": "I'm sorry to hear that. Could you please provide your account details so I can assist you with canceling your subscription?"},
    {"customer": "How can I change my password?", "agent_response": "You can change your password by visiting the account settings page. Would you like me to guide you through it?"},
    {"customer": "I forgot my password. Can you help me reset it?", "agent_response": "Of course! Could you please provide the email address associated with your account so I can send you a password reset link?"},
    {"customer": "What payment methods do you accept?", "agent_response": "We accept various payment methods including credit/debit cards, PayPal, and bank transfers. Is there a specific method you're interested in?"},
    {"customer": "I'm having trouble placing an order on the website.", "agent_response": "I apologize for the inconvenience. Could you please describe the issue you're facing during the checkout process?"},
    {"customer": "Is there a way to expedite shipping for my order?", "agent_response": "Yes, we offer expedited shipping options. Could you please provide your order number so I can check availability and provide you with the options?"},
    {"customer": "Can I return an item if I'm not satisfied with it?", "agent_response": "Yes, you can return items within our return policy period. Could you please provide your order number so I can assist you with the return process?"},
    {"customer": "I need help with installing the software I purchased.", "agent_response": "Sure! Could you please specify the software you purchased and the device you're trying to install it on?"},
    {"customer": "What are your business hours?", "agent_response": "Our business hours are from [insert hours here]. Is there anything specific you need assistance with right now?"},
    {"customer": "I'm experiencing technical issues with your app.", "agent_response": "I'm sorry for the inconvenience. Could you please describe the issue you're encountering so I can assist you further?"},
    {"customer": "I'm interested in your product/service. Can you provide more information?", "agent_response": "Absolutely! Could you please specify which product or service you're interested in, so I can provide you with detailed information?"},
    {"customer": "How can I update my shipping address for an existing order?", "agent_response": "You can update your shipping address by contacting our customer support team with your order number and the new shipping address details."},
    {"customer": "I'm trying to reach your customer support but unable to connect.", "agent_response": "I apologize for the inconvenience. Our lines might be busy. Could you please try again later or provide me with your contact information, and I'll have someone reach out to you?"},
    {"customer": "I'm interested in your product, but I have some specific requirements. Can you accommodate them?", "agent_response": "We strive to meet our customers' needs. Could you please specify your requirements so I can check if we can accommodate them?"},
    {"customer": "I accidentally placed the wrong order. How can I change it?", "agent_response": "I'm sorry to hear that. Could you please provide your order number and the changes you'd like to make, and I'll see what I can do to assist you?"},
    {"customer": "Is there a warranty on the products you sell?", "agent_response": "Yes, most of our products come with a warranty. Could you please specify which product you're inquiring about so I can provide you with warranty details?"},
    {"customer": "I'm having difficulties accessing certain features on your website.", "agent_response": "I apologize for the inconvenience. Could you please specify which features you're having trouble accessing so I can assist you further?"},
    {"customer": "I'm interested in bulk ordering. Do you offer discounts for bulk purchases?", "agent_response": "Yes, we do offer discounts for bulk orders. Could you please provide more details about the quantity you're interested in, so I can provide you with a quote?"},
    {"customer": "How long does it usually take to receive a response from customer support?", "agent_response": "We strive to respond to all inquiries within [insert timeframe]. If you haven't received a response within this time, please let me know, and I'll expedite your request."},
    {"customer": "I received the wrong item in my order. What should I do?", "agent_response": "I apologize for the mistake. Could you please provide your order number and details of the item you received incorrectly, so I can arrange for a replacement or refund?"},
    {"customer": "I'm having trouble applying a discount code during checkout.", "agent_response": "I'm sorry for the inconvenience. Could you please provide the discount code you're trying to apply and describe the issue you're encountering so I can assist you further?"},
    {"customer": "I'm interested in becoming a reseller of your products. Do you have a reseller program?", "agent_response": "Yes, we do have a reseller program. Could you please provide your contact information, and I'll have someone from our sales team reach out to you with more details?"},
]


In [4]:
# Build one (context, target) training pair per sample: the customer message is
# the model input and the agent reply written for that same message is the
# target. Each sample is an independent single-turn exchange, so messages from
# different samples are not concatenated together.
# The "conversation history: " prefix matches the prompt format that
# generate_response builds at inference time, so training and inference
# inputs have the same shape.
CONTEXT_PREFIX = "conversation history: "
train_data = [
    {"context": CONTEXT_PREFIX + conv["customer"], "target": conv["agent_response"]}
    for conv in sample_data
]
assert len(train_data) == len(sample_data), "expected exactly one pair per sample"

# Dataset that tokenizes (context, target) pairs for seq2seq fine-tuning.
class CustomDataset(torch.utils.data.Dataset):
    """Tokenizes ``{"context": str, "target": str}`` records for T5-style training.

    Each item is a dict of 1-D tensors:
      * ``input_ids`` / ``attention_mask`` -- the encoded context, padded or
        truncated to ``max_length``;
      * ``labels`` -- the encoded target, padded or truncated to
        ``target_max_length``, with padding positions set to -100 so the loss
        ignores them.

    Args:
        data: Sequence of dicts with ``"context"`` and ``"target"`` string keys.
        tokenizer: Tokenizer callable returning ``input_ids`` and
            ``attention_mask`` tensors (called with ``return_tensors="pt"``).
        max_length: Pad/truncate length for the context.
        target_max_length: Pad/truncate length for the target. Defaults to
            ``max_length``. Agent replies are short, so a smaller value avoids
            running the decoder over hundreds of masked pad positions.
    """

    def __init__(self, data, tokenizer, max_length=512, target_max_length=None):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.target_max_length = max_length if target_max_length is None else target_max_length

    def __len__(self):
        return len(self.data)

    def _encode(self, text, length):
        """Tokenize ``text`` padded/truncated to exactly ``length`` tokens."""
        return self.tokenizer(text, truncation=True, max_length=length, padding="max_length", return_tensors="pt")

    def __getitem__(self, idx):
        item = self.data[idx]
        input_encoding = self._encode(item["context"], self.max_length)
        target_encoding = self._encode(item["target"], self.target_max_length)

        # Mask padding via the attention mask rather than a hard-coded pad id:
        # this works for any tokenizer and never hides a real token whose id
        # happens to equal the pad id.
        labels = target_encoding["input_ids"].masked_fill(target_encoding["attention_mask"] == 0, -100)

        return {
            "input_ids": input_encoding["input_ids"].flatten(),
            "attention_mask": input_encoding["attention_mask"].flatten(),
            "labels": labels.flatten(),
        }


# Load the pretrained checkpoint and its tokenizer, then wrap the training
# pairs in a Dataset that the Trainer can consume.
model_name = "t5-small"  # other sizes: "t5-base", "t5-large", ...
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)
train_dataset = CustomDataset(data=train_data, tokenizer=tokenizer)



You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# Define training arguments.
# There is no held-out split in this notebook, so periodic evaluation is not
# configured: asking for an eval strategy without an `eval_dataset` makes
# Trainer fail, and the old `evaluation_strategy` keyword was renamed and then
# removed from TrainingArguments. Omitting it keeps evaluation off in every
# transformers version.
training_args = TrainingArguments(
    output_dir='./models',
    per_device_train_batch_size=4,
    num_train_epochs=3,
    logging_dir='./logs',
    # ~30 examples / batch 4 x 3 epochs is only a couple dozen optimizer steps,
    # so an interval of 100 would never log a single loss value.
    logging_steps=5,
    save_steps=1000,
)

# Define trainer. Every example is already padded to a fixed length by
# CustomDataset, so the default collator can stack batches directly.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

# Fine-tune the model
trainer.train()

# Save weights and tokenizer side by side so the checkpoint directory can be
# reloaded on its own (saving the tokenizer explicitly avoids relying on the
# deprecated `Trainer(tokenizer=...)` argument).
FINE_TUNED_DIR = "./fine_tuned_model"
trainer.save_model(FINE_TUNED_DIR)
tokenizer.save_pretrained(FINE_TUNED_DIR)


 88%|████████▊ | 21/24 [06:51<01:03, 21.29s/it]

In [None]:
# Reload the fine-tuned weights from the directory written by the training
# cell, so inference runs on the saved checkpoint.
fine_tuned_model = T5ForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path="./fine_tuned_model"
)

# Function to generate response
def generate_response(user_dialogue, fine_tuned_model, tokenizer, max_length=512,
                      max_output_length=None, num_beams=5):
    """Generate an agent reply for a customer message using beam search.

    Args:
        user_dialogue: Customer message text. It is prefixed with
            ``"conversation history: "`` before tokenization.
        fine_tuned_model: Seq2seq model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        max_length: Truncation length (in tokens) for the encoded prompt; also
            the generation length cap when ``max_output_length`` is None.
        max_output_length: Optional separate cap on the generated length.
        num_beams: Beam width passed to ``generate``.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    input_text = "conversation history: " + user_dialogue
    # A single prompt needs no padding; padding it to max_length would only add
    # masked positions that the encoder still has to process.
    input_encoding = tokenizer(input_text, truncation=True, max_length=max_length, return_tensors="pt")

    # Put the inputs on the model's device so this also works for a model
    # that has been moved to a GPU.
    device = fine_tuned_model.device
    output_ids = fine_tuned_model.generate(
        input_ids=input_encoding["input_ids"].to(device),
        attention_mask=input_encoding["attention_mask"].to(device),
        max_length=max_length if max_output_length is None else max_output_length,
        num_beams=num_beams,
        early_stopping=True,
    )

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Example usage: the prompt is one of the training messages, so the output
# can be compared against its reference agent reply in sample_data.
user_dialogue = "I'm having trouble logging into my account."
response = generate_response(
    user_dialogue=user_dialogue,
    fine_tuned_model=fine_tuned_model,
    tokenizer=tokenizer,
)
print("Generated response:", response)
