<a href="https://colab.research.google.com/github/shannn1/goodRAG/blob/main/train_2_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from datasets import Dataset
from transformers import (
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    RagConfig,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

In [None]:
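## TODO: data preparation — build the knowledge-base passages dataset and its FAISS index, then fill in their paths (passages_path / index_path) used by the retriever setup.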



In [None]:


# Build a RAG-Sequence model from a DPR question encoder and a user-chosen
# seq2seq generator, backed by a custom (local) passages dataset + FAISS index.
#
# Placeholders to replace before running:
#   "generator_model_path"      -> a seq2seq checkpoint, e.g. "google-t5/t5-small"
#   "path/to/your/passages"     -> the saved passages dataset (knowledge base)
#   "path/to/your/index.faiss"  -> the FAISS index built over those passages
QUESTION_ENCODER_NAME = "facebook/dpr-question_encoder-single-nq-base"
GENERATOR_NAME = "generator_model_path"  # e.g. "google-t5/t5-small"

# The question encoder must be a bare DPR question encoder. Loading the full
# "facebook/rag-sequence-nq" checkpoint with AutoModel yields a complete
# RagModel whose .config is a RagConfig, which is not a valid question-encoder
# config for RagConfig.from_question_encoder_generator_configs.
question_encoder = AutoModel.from_pretrained(QUESTION_ENCODER_NAME)
generator = AutoModelForSeq2SeqLM.from_pretrained(GENERATOR_NAME)

# Each sub-model gets its own matching tokenizer. The tokenizer bundled with
# "facebook/rag-sequence-nq" carries a BART generator vocabulary, which would
# not match an arbitrary generator such as T5.
question_encoder_tokenizer = AutoTokenizer.from_pretrained(QUESTION_ENCODER_NAME)
generator_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_NAME)

rag_config = RagConfig.from_question_encoder_generator_configs(
    question_encoder.config,
    generator.config,
    n_docs=5,  # number of passages retrieved per question; tunable
    index_name="custom",  # use a local passages dataset + FAISS index
    passages_path="path/to/your/passages",  # knowledge-base path
    index_path="path/to/your/index.faiss",  # FAISS index file path
)

# Build the retriever from the same config and tokenizers as the model, so that
# retrieved passages are re-tokenized with the generator's own vocabulary.
retriever = RagRetriever(
    rag_config,
    question_encoder_tokenizer=question_encoder_tokenizer,
    generator_tokenizer=generator_tokenizer,
)

model = RagSequenceForGeneration(
    config=rag_config,
    question_encoder=question_encoder,
    generator=generator,
    retriever=retriever,
)

# Composite tokenizer: calling it encodes questions with the question-encoder
# vocabulary; tokenizer.generator encodes answers and decodes generated ids.
tokenizer = RagTokenizer(
    question_encoder=question_encoder_tokenizer,
    generator=generator_tokenizer,
)

In [None]:
# Toy training data, one question/answer pair per row. The column names must
# match the keys read in preprocess_function and removed after tokenization.
data = {
    "question": ["How many people live in Paris?", "What is the capital of France?"],
    "answer": ["10 million", "Paris"],
}
dataset = Dataset.from_dict(data)

MAX_QUESTION_LENGTH = 512
MAX_ANSWER_LENGTH = 512


def preprocess_function(examples):
    """Tokenize a batch of question/answer pairs for RAG fine-tuning.

    Args:
        examples: a batch dict from ``Dataset.map(batched=True)`` with
            "question" and "answer" columns (lists of strings).

    Returns:
        The question-encoder encoding of the questions (``input_ids``,
        ``attention_mask``, ...) plus a ``labels`` entry holding the generator
        token ids of the answers. Sequences are truncated but not padded;
        the data collator pads each batch.
    """
    model_inputs = tokenizer.question_encoder(
        examples["question"],
        max_length=MAX_QUESTION_LENGTH,
        truncation=True,
    )
    # Answers are generation targets, so they use the generator's vocabulary.
    targets = tokenizer.generator(
        examples["answer"],
        max_length=MAX_ANSWER_LENGTH,
        truncation=True,
    )
    model_inputs["labels"] = targets["input_ids"]
    return model_inputs


tokenized_train = dataset.map(preprocess_function, batched=True)
tokenized_train = tokenized_train.remove_columns(["question", "answer"])
# No set_format("torch") needed: the collator below returns torch tensors.

# RagTokenizer has no .pad(), so inputs are padded with the question-encoder
# tokenizer. RagSequenceForGeneration feeds labels to the generator as
# decoder_input_ids and masks padding by comparing against the generator's
# pad_token_id, so labels must be padded with that id rather than the
# collator's default of -100.
collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer.question_encoder,
    model=model,
    label_pad_token_id=tokenizer.generator.pad_token_id,
)

In [None]:
# Fine-tune the RAG model on the tokenized dataset and save the result.
# No eval_dataset is passed, so evaluation stays disabled (the default).
# Requesting step-wise evaluation without an eval_dataset makes the Trainer fail.
training_args = Seq2SeqTrainingArguments(
    output_dir="./rag_output",
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    save_steps=10,
    logging_dir="./logs",
    # Only takes effect once an eval_dataset is given to the trainer.
    predict_with_generate=True,
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    tokenizer=tokenizer,
    data_collator=collator,
)
trainer.train()
trainer.save_model("./rag_trained_model")

In [None]:
## Inference: score a reference answer and generate a fresh one.

model.eval()
# The Trainer may have moved the model to an accelerator; keep inputs on the
# same device as the model.
device = model.device

# Calling the RagTokenizer encodes with the question-encoder vocabulary.
inputs = tokenizer("How many people live in Paris?", return_tensors="pt").to(device)
# Targets are generator tokens, so encode them with the generator tokenizer;
# tokenizer(text_target=...) would still use the question-encoder vocabulary.
targets = tokenizer.generator(
    "In Paris, there are 10 million people.", return_tensors="pt"
).to(device)
input_ids = inputs["input_ids"]
labels = targets["input_ids"]

with torch.no_grad():
    # Forward pass with labels: outputs.loss scores the reference answer.
    outputs = model(
        input_ids=input_ids,
        attention_mask=inputs["attention_mask"],
        labels=labels,
    )
    # Retrieve passages and generate an answer for the question.
    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=inputs["attention_mask"],
    )
answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)