In [None]:
!pip install pandas torch transformers scikit-learn tqdm faiss-cpu sentence-transformers

In [None]:
import sys
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration
import argparse
from sklearn.model_selection import train_test_split
import torch.optim as optim
from tqdm.auto import tqdm


class RAGDataset(Dataset):
    """
    Dataset of query/context/answer rows tokenized for seq2seq fine-tuning.

    Each item is built from one row of the dataframe: the query and context are
    merged into a single prompt (``"query: ... context: ..."``) and the answer
    becomes the target sequence. Both are padded/truncated to fixed lengths so
    the default DataLoader collate function can stack them.
    """

    # Label value skipped by the cross-entropy loss used in Hugging Face
    # seq2seq models (PyTorch's default ``ignore_index``).
    IGNORE_INDEX = -100

    def __init__(self, dataframe, tokenizer, source_len, target_len):
        """
        Initialize the dataset.

        Args:
            dataframe (pd.DataFrame): Data with 'query', 'context' and 'answer' columns.
            tokenizer: Callable tokenizer returning 'input_ids' and 'attention_mask'
                tensors (e.g. a transformers tokenizer).
            source_len (int): Fixed length of the tokenized input sequence.
            target_len (int): Fixed length of the tokenized target sequence.
        """
        # reset the index so positional DataLoader indices match row labels
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.source_len = source_len
        self.target_len = target_len
        self.query = self.data['query']
        self.context = self.data['context']
        self.answer = self.data['answer']

    def __len__(self):
        """
        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Tokenize and return a single sample.

        Args:
            idx (int): Position of the sample.

        Returns:
            dict: 1-D tensors 'input_ids' and 'attention_mask' (length
            ``source_len``) and 'labels' (length ``target_len``). Padding
            positions in 'labels' are set to -100 so they do not contribute
            to the training loss.
        """
        query = str(self.query[idx])
        context = str(self.context[idx])
        answer = str(self.answer[idx])

        # combine query and context into a single input string
        source_text = f"query: {query} context: {context}"

        # tokenize the input string
        source = self.tokenizer(
            source_text,
            max_length=self.source_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # tokenize the answer string
        target = self.tokenizer(
            answer,
            max_length=self.target_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # squeeze(0) drops only the batch dimension added by return_tensors="pt"
        labels = target["input_ids"].squeeze(0).clone()

        # mask padding so the model is not trained to predict pad tokens
        pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_token_id is not None:
            labels[labels == pad_token_id] = self.IGNORE_INDEX

        return {
            "input_ids": source["input_ids"].squeeze(0),
            "attention_mask": source["attention_mask"].squeeze(0),
            "labels": labels,
        }


def preprocess_data(file_path):
    """
    Load a question/answer CSV and shape it into query/context/answer rows.

    Only the 'question' and 'answer' columns are kept; rows where either is
    missing are discarded. 'question' is exposed as 'query', and 'context' is
    a copy of 'answer' (a placeholder until real contexts are available).

    Args:
        file_path (str): Path to the dataset CSV file.

    Returns:
        pd.DataFrame: Frame with 'query', 'answer' and 'context' columns.
    """
    wanted = ['question', 'answer']

    return (
        pd.read_csv(file_path)[wanted]                                  # keep Q/A only
        .dropna(subset=wanted)                                          # no missing values
        .rename(columns={'question': 'query', 'answer': 'answer'})      # consistent naming
        .assign(context=lambda frame: frame['answer'])                  # answer doubles as context
    )


def train_epoch(model, loader, optimizer, device, epoch, logging_steps):
    """
    Run one full pass over the training data and update the model weights.

    Args:
        model (torch.nn.Module): Model to optimize; must return an object with a
            ``loss`` attribute when called with labels.
        loader (DataLoader): Yields batches with 'input_ids', 'attention_mask'
            and 'labels' tensors.
        optimizer (torch.optim.Optimizer): Optimizer bound to the model parameters.
        device (torch.device): Device the batch tensors are moved to.
        epoch (int): Epoch number, shown in the progress bar description.
        logging_steps (int): The progress bar's loss readout is refreshed every
            this many steps.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.
    """
    batch_keys = ("input_ids", "attention_mask", "labels")

    model.train()  # enable training-mode behaviour (e.g. dropout)
    batches = tqdm(loader, desc=f"Epoch {epoch}", disable=False)
    running_loss = 0

    for step_no, batch in enumerate(batches, start=1):
        # move only the tensors the model consumes onto the target device
        inputs = {key: batch[key].to(device) for key in batch_keys}

        # forward pass; the model computes the loss itself from the labels
        loss = model(**inputs).loss
        batch_loss = loss.item()
        running_loss += batch_loss

        # clear stale gradients, backpropagate, then apply the update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # refresh the displayed loss only on logging steps
        if step_no % logging_steps == 0:
            batches.set_postfix({"loss": batch_loss})

    return running_loss / len(loader)


def _evaluate(model, loader, device):
    """
    Compute the mean loss of the model over a DataLoader without updating weights.

    Args:
        model (torch.nn.Module): Model returning an object with a ``loss`` attribute.
        loader (DataLoader): Yields batches with 'input_ids', 'attention_mask'
            and 'labels' tensors.
        device (torch.device): Device the batch tensors are moved to.

    Returns:
        float: Average per-batch loss, or NaN if the loader yields no batches.
    """
    model.eval()  # disable dropout etc. for a deterministic estimate
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():  # no gradients needed for evaluation
        for batch in loader:
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=batch["labels"].to(device),
            )
            total_loss += outputs.loss.item()
            num_batches += 1

    return total_loss / num_batches if num_batches else float("nan")


def main():
    """
    Fine-tune a T5 model on query/context/answer data and save it to disk.

    Hyperparameters are hard-coded in a local ``Args`` class instead of being
    parsed from the command line, so the function can be called from a Jupyter
    Notebook cell. After every epoch the model is also scored on a held-out
    10% validation split so over-fitting is visible in the printed losses.
    """
    # simulating command-line arguments for Jupyter Notebook
    class Args:
        model_name = "t5-base"
        train_file = "/kaggle/input/layoutlm/medquad.csv"
        output_dir = "rag_model"
        batch_size = 8
        epochs = 3
        lr = 5e-5
        max_input_length = 512
        max_output_length = 150
        device = "cuda"
        logging_steps = 10

    args = Args()  # use the custom Args class to store arguments

    # use the requested accelerator only if it is actually available,
    # otherwise fall back to the CPU
    if args.device == "mps" and torch.backends.mps.is_available():
        device = torch.device("mps")
    elif args.device == "cuda" and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # preprocess the data
    df = preprocess_data(args.train_file)

    # split data into training and validation sets (fixed seed for reproducibility)
    train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)

    # load the tokenizer and model
    tokenizer = T5Tokenizer.from_pretrained(args.model_name, legacy=False)
    model = T5ForConditionalGeneration.from_pretrained(args.model_name).to(device)

    # create DataLoaders for training and validation datasets
    train_dataset = RAGDataset(train_df, tokenizer, args.max_input_length, args.max_output_length)
    val_dataset = RAGDataset(val_df, tokenizer, args.max_input_length, args.max_output_length)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    # optimizer
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    # training loop: one optimisation pass followed by a validation pass per epoch
    for epoch in range(1, args.epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, device, epoch, args.logging_steps)
        print(f"Epoch {epoch} Training Loss: {train_loss:.4f}")

        val_loss = _evaluate(model, val_loader, device)
        print(f"Epoch {epoch} Validation Loss: {val_loss:.4f}")

    # save the model and tokenizer
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f"Model saved to {args.output_dir}")

In [None]:
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer

# define parameters
csv_file = "/kaggle/input/layoutlm/medquad.csv"  # path to your dataset
updated_csv_file = "/kaggle/working/medquad_with_context.csv"  # output dataset path
index_file = "/kaggle/working/context.index"  # path to save the FAISS index

# load the dataset
df = pd.read_csv(csv_file)

# add a 'context' column if it doesn't exist
if 'context' not in df.columns:
    print("No 'context' column found. Creating it from the 'answer' column.")
    if 'answer' not in df.columns:
        raise ValueError("The dataset must have an 'answer' column to create the 'context'.")
    df['context'] = df['answer']  # use 'answer' as the context

# Drop rows without a context BEFORE saving and embedding: empty CSV cells are
# read as NaN (a float, not text) and cannot be embedded as sentences.
# Resetting the index keeps the row positions of the saved CSV identical to
# the vector ids FAISS assigns (0..n-1 in insertion order), which is what the
# positional lookup at retrieval time relies on.
missing_context = df['context'].isna()
if missing_context.any():
    print(f"Dropping {int(missing_context.sum())} rows with missing context.")
    df = df.loc[~missing_context].reset_index(drop=True)

if df.empty:
    raise ValueError("The dataset contains no rows with a usable 'context'.")

# save the updated dataset with the 'context' column
df.to_csv(updated_csv_file, index=False)
print(f"Updated dataset with 'context' column saved to {updated_csv_file}")

# use SentenceTransformer to generate embeddings for the context column
embedder = SentenceTransformer("all-MiniLM-L6-v2")
contexts = df["context"].astype(str).tolist()  # guarantee plain strings for the encoder
context_embeddings = embedder.encode(contexts, convert_to_tensor=False).astype("float32")

# create a FAISS index for the embeddings
dimension = context_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)  # L2 (Euclidean) distance
index.add(context_embeddings)

# save the FAISS index
faiss.write_index(index, index_file)
print(f"FAISS index saved to {index_file}")

In [None]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import faiss
import pandas as pd
from sentence_transformers import SentenceTransformer


def load_index(index_file, csv_file):
    """
    Read the dataset CSV and the FAISS index from disk.

    Args:
        index_file (str): Path to the FAISS index file.
        csv_file (str): Path to the CSV file containing the dataset.

    Returns:
        tuple: ``(index, dataframe)`` - the FAISS index as returned by
        ``faiss.read_index`` and the dataset as returned by ``pd.read_csv``.
    """
    # the CSV is read first, then the index
    dataset = pd.read_csv(csv_file)
    return faiss.read_index(index_file), dataset


def retrieve_context(query, index, df, embedder, top_k=1):
    """
    Retrieve the most relevant context(s) from the FAISS index based on the query.

    The query is embedded, the index is searched for its nearest neighbours,
    and the returned ids are used as row positions into ``df``.

    Args:
        query (str): The user's input question.
        index (faiss.IndexFlatL2): The FAISS index for retrieval.
        df (pd.DataFrame): The dataset to retrieve contexts from; row positions
            must correspond to the ids stored in the index.
        embedder (SentenceTransformer): The embedding model to encode the query.
        top_k (int): Maximum number of contexts to retrieve.

    Returns:
        list: Retrieved contexts in the order returned by the search. May hold
        fewer than ``top_k`` items when the index has fewer vectors.
    """
    query_vector = embedder.encode([query]).astype("float32")  # embed the query
    _, indices = index.search(query_vector, top_k)  # search the FAISS index

    # FAISS fills unused result slots with id -1 when fewer than top_k
    # neighbours exist; skip those, otherwise iloc[-1] would silently return
    # the last row of the dataset as if it had been retrieved.
    return [df.iloc[i]["context"] for i in indices[0] if i >= 0]


# simulated arguments for the Jupyter Notebook
args = {
    "query": "What are the symptoms of diabetes?",
    "model_dir": "rag_model",  # fine-tuned model directory
    "csv_file": "/kaggle/working/medquad_with_context.csv",  # CSV file containing the dataset
    "index_file": "/kaggle/working/context.index",  # FAISS index file
    "device": "cuda",  # device to run the inference on
    "top_k": 1,  # number of top contexts to retrieve
    "max_input_length": 512,  # same input length the model was fine-tuned with
    "max_new_tokens": 150,  # same target length the model was fine-tuned with
}

# use the requested accelerator only if it is available, otherwise fall back to CPU
if args["device"] == "mps" and torch.backends.mps.is_available():
    device = torch.device("mps")
elif args["device"] == "cuda" and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# load the fine-tuned model and tokenizer
tokenizer = T5Tokenizer.from_pretrained(args["model_dir"], legacy=False)
model = T5ForConditionalGeneration.from_pretrained(args["model_dir"]).to(device)

# load the FAISS index and dataset
embedder = SentenceTransformer("all-MiniLM-L6-v2")  # use a sentence embedding model
index, df = load_index(args["index_file"], args["csv_file"])

# retrieve the most relevant context for the input query
contexts = retrieve_context(args["query"], index, df, embedder, args["top_k"])
context_text = " ".join(map(str, contexts))  # tolerate non-string cells read from the CSV
input_text = f"query: {args['query']} context: {context_text}"

# Tokenize the prompt, truncating it to the length used during fine-tuning so
# a long retrieved context cannot produce an input longer than the model saw
# in training.
input_ids = tokenizer.encode(
    input_text,
    max_length=args["max_input_length"],
    truncation=True,
    return_tensors="pt",
).to(device)

# Generate the answer. For an encoder-decoder model the prompt is consumed by
# the encoder and does not count towards the generated sequence, so the output
# budget is expressed directly as a number of new tokens.
outputs = model.generate(
    input_ids,
    max_new_tokens=args["max_new_tokens"],
    num_beams=5,
    no_repeat_ngram_size=2,
)
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
if "." in answer:
    answer = answer[:answer.rfind(".") + 1]  # trim to the last full sentence

# display the results
print(f"Query: {args['query']}")
print(f"Retrieved Context: {context_text}")
print(f"Generated Answer: {answer}")