# MCQ Generation with T5 Fine-Tuning and Classical Distractor Pipeline




In [None]:
# Install the NLP / training stack quietly, then download spaCy's small English
# pipeline, which later cells load with spacy.load("en_core_web_sm").
!pip install spacy sense2vec requests nltk sentence-transformers transformers datasets --quiet
!python -m spacy download en_core_web_sm

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.5/54.5 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.6/40.6 kB[0m [31m838.8 kB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m37.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m39.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Fetch the s2v_reddit_2015_md Sense2Vec vectors release, unpack it, and rename the
# folder to ./s2v_old, the path later loaded with Sense2Vec().from_disk("s2v_old").
# NOTE(review): if s2v_old already exists, re-running this cell makes `mv` move the
# freshly extracted folder *inside* s2v_old; delete s2v_old first when re-downloading.
!wget https://github.com/explosion/sense2vec/releases/download/v1.0.0/s2v_reddit_2015_md.tar.gz
!tar -xvf s2v_reddit_2015_md.tar.gz
!mv s2v_reddit_2015_md s2v_old


--2025-02-25 16:39:12--  https://github.com/explosion/sense2vec/releases/download/v1.0.0/s2v_reddit_2015_md.tar.gz
Resolving github.com (github.com)... 140.82.114.3
Connecting to github.com (github.com)|140.82.114.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/50261113/52126080-0993-11ea-8190-8f0e295df22a?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250225%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250225T163912Z&X-Amz-Expires=300&X-Amz-Signature=6403d3d3fc6d6e10b372b367cb08168dea61ba624d978c3b1fa10e6915f38876&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3Ds2v_reddit_2015_md.tar.gz&response-content-type=application%2Foctet-stream [following]
--2025-02-25 16:39:12--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/50261113/52126080-0993-11ea-8190-8f0e295df22a?X-Amz-Algorithm=AWS4-HM

## 1. Imports & Basic Setup

In [None]:
import nltk
import re
import random
import requests
import spacy
import json
import torch
from torch.utils.data import Dataset

from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from sense2vec import Sense2Vec

from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    Trainer,
    TrainingArguments
)

# Download NLTK data.
# Recent NLTK releases load the Punkt sentence tokenizer from 'punkt_tab' and the
# perceptron tagger from 'averaged_perceptron_tagger_eng'; without them
# nltk.sent_tokenize raises LookupError. Older releases simply ignore the extra data.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('wordnet')
nltk.download('averaged_perceptron_tagger')
nltk.download('averaged_perceptron_tagger_eng')

# Load spaCy model
nlp = spacy.load("en_core_web_sm")

# Load SBERT for distractor re-ranking
sbert_model = SentenceTransformer('all-MiniLM-L6-v2')

# Attempt to load local Sense2Vec model. It is optional: on failure s2v is None
# and the sense2vec helper returns no candidates.
try:
    s2v = Sense2Vec().from_disk("s2v_old")  # folder with sense2vec data
except Exception as err:  # not a bare except: Ctrl-C / SystemExit must still stop the cell
    print("Sense2Vec model folder 's2v_old' not found. Distractor generation with sense2vec may be partial.")
    print("  underlying error:", repr(err))
    s2v = None

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

In [None]:
# Mount Google Drive; later cells read and write files under /content/drive/MyDrive.
import google.colab.drive as drive

drive.mount('/content/drive')

## 2. Prepare RACE & SQuAD for T5

We'll load **RACE** (the `all` configuration) and **SQuAD** from Hugging Face and convert each sample into the following format:
```
context => "Generate MCQ: <passage>"
target  => "Question: <question> Answer: <answer>"
```
We'll combine them in a single JSONL file for training.

In [None]:
def _race_answer_text(answer_field, options):
    """
    Map a RACE answer field to the text of the correct option.

    Hugging Face RACE stores the answer as a letter ("A"-"D"); an integer
    index is accepted as well. Falls back to the raw field (stringified) when
    it cannot be mapped to an existing option, and to "" when it is missing.
    """
    if answer_field is None:
        return ""
    index = None
    if isinstance(answer_field, int):
        index = answer_field
    elif isinstance(answer_field, str):
        letter = answer_field.strip().upper()
        if len(letter) == 1 and "A" <= letter <= "Z":
            index = ord(letter) - ord("A")
    if index is not None and 0 <= index < len(options):
        return str(options[index]).strip()
    return str(answer_field).strip()


def process_race(split="train"):
    """
    Load the RACE dataset ("all" config) and convert it to T5 text-to-text samples.

    Args:
        split: Hugging Face split spec, e.g. "train" or "train[:2000]".

    Returns:
        A list of dicts {context, target, domain} where context is
        "Generate MCQ: <article>" and target is
        "Question: <question> Answer: <correct option text>".
    """
    race_ds = load_dataset("race", "all", split=split)
    samples = []
    for item in race_ds:
        article = (item.get("article") or "").strip()
        question = (item.get("question") or "").strip()
        options = item.get("options") or []
        # RACE answers are option letters; train on the option text instead,
        # otherwise the model learns to emit "Answer: A".
        correct_ans = _race_answer_text(item.get("answer"), options)

        context_str = "Generate MCQ: " + article
        target_str = f"Question: {question} Answer: {correct_ans}"

        samples.append({
            "context": context_str,
            "target": target_str,
            "domain": "english"
        })
    return samples

def process_squad(split="train"):
    """
    Load SQuAD from Hugging Face for ``split`` and turn every row into a T5 sample.

    Only the first gold answer of each question is used (empty string when
    there is none). Returns a list of dicts with keys context, target, domain.
    """
    def _to_sample(item):
        answer_texts = item.get("answers", {}).get("text", [])
        first_answer = answer_texts[0].strip() if answer_texts else ""
        passage = item.get("context", "").strip()
        query = item.get("question", "").strip()
        return {
            "context": "Generate MCQ: " + passage,
            "target": f"Question: {query} Answer: {first_answer}",
            "domain": "english"
        }

    return [_to_sample(item) for item in load_dataset("squad", split=split)]

def combine_datasets(race_split="train[:2000]", squad_split="train[:2000]"):
    """
    Return converted RACE samples followed by converted SQuAD samples.

    Both arguments are Hugging Face split specs (slice syntax such as
    "train[:2000]" is allowed), so each source can be sized independently.
    """
    merged = process_race(split=race_split)
    merged.extend(process_squad(split=squad_split))
    return merged

### Create & Save the Unified JSONL

In [None]:
# Combine RACE + SQuAD (using slices for demo)
import os

all_samples = combine_datasets(
    race_split="train[:20000]",  # or 'train' to use the entire dataset
    squad_split="train[:20000]"   # or 'train' to use the entire dataset
)

output_file = "/content/drive/MyDrive/MinorProject/unified_mcq_dataset.jsonl"
# open(..., "w") raises FileNotFoundError when the MinorProject folder does not exist yet.
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "w", encoding="utf-8") as f:
    for sample in all_samples:
        # One JSON object per line (JSONL); keep non-ASCII text readable in the UTF-8 file.
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")

print(f"Unified dataset saved to {output_file} with {len(all_samples)} samples.")

## 3. `MCQDataset` for Fine-Tuning
We'll define a custom `Dataset` that:
- Reads each JSON line,
- Tokenizes `context` as the T5 **input**,
- Tokenizes `target` as the T5 **labels**.

In [None]:
class MCQDataset(Dataset):
    """
    Map-style dataset of T5 (source, target) pairs read from a JSONL file.

    Each non-blank line must be a JSON object with a "context" string (model
    input) and a "target" string (expected model output). Items are tokenized
    lazily in __getitem__ and padded/truncated to fixed lengths.
    """

    def __init__(
        self,
        jsonl_file,
        tokenizer,
        max_source_length=512,
        max_target_length=128
    ):
        """
        Args:
            jsonl_file: Path to the JSONL file.
            tokenizer: Hugging Face-style tokenizer callable returning
                "input_ids" and "attention_mask" tensors.
            max_source_length: Padded/truncated length of the input sequence.
            max_target_length: Padded/truncated length of the label sequence.
        """
        self.samples = []
        with open(jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # tolerate blank lines (e.g. a trailing newline)
                    self.samples.append(json.loads(line))

        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return a dict with 1-D input_ids, attention_mask and labels tensors."""
        sample = self.samples[idx]
        source_text = sample["context"]
        target_text = sample["target"]

        # Tokenize the source
        source_enc = self.tokenizer(
            source_text,
            max_length=self.max_source_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Tokenize the target
        target_enc = self.tokenizer(
            target_text,
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # squeeze(0) drops only the batch dimension added by return_tensors='pt'.
        labels = target_enc['input_ids'].squeeze(0).clone()
        # Padding positions must be -100 so the seq2seq cross-entropy ignores them;
        # otherwise the model is trained to predict <pad> after every answer.
        labels[target_enc['attention_mask'].squeeze(0) == 0] = -100

        return {
            'input_ids': source_enc['input_ids'].squeeze(0),
            'attention_mask': source_enc['attention_mask'].squeeze(0),
            'labels': labels
        }

## 4. Fine-Tune T5 on Our Combined Dataset

In [None]:
# 1) Load T5 tokenizer + model
import inspect

model_name = "t5-base"  # could use 't5-small' for faster training or 't5-large' for better capacity
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)

# 2) Create dataset
jsonl_path = "/content/drive/MyDrive/MinorProject/unified_mcq_dataset.jsonl"
dataset = MCQDataset(jsonl_path, tokenizer)
print("Total samples:", len(dataset))

# 3) Split into train/val (80/20). A seeded generator keeps the validation set
#    identical across re-runs, so eval losses from different runs are comparable.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)
print("Train size:", len(train_dataset), "Validation size:", len(val_dataset))

# 4) Define training arguments.
#    Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
#    and later removed the old name, so pass whichever one this install accepts.
_training_arg_names = inspect.signature(TrainingArguments.__init__).parameters
eval_strategy_key = (
    "eval_strategy" if "eval_strategy" in _training_arg_names else "evaluation_strategy"
)
training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/MinorProject/t5_mcq_finetuned",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=200,
    weight_decay=0.01,
    logging_dir="/content/drive/MyDrive/MinorProject/logs",
    logging_steps=50,
    eval_steps=200,
    save_steps=200,
    save_total_limit=2,
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
    **{eval_strategy_key: "steps"}
)

# 5) Create Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

# 6) Train!
trainer.train()

## 5. Save the Fine-Tuned Model

In [None]:
# Persist the trained weights and the tokenizer side by side in one Drive folder.
finetuned_dir = "/content/drive/MyDrive/MinorProject/t5_mcq_finetuned"
trainer.save_model(finetuned_dir)
tokenizer.save_pretrained(finetuned_dir)
print(f"Model saved in '{finetuned_dir}' folder.")

## 6. Classical Distractor Pipeline
We now define the functions for rule-based **question generation** and for **distractor** generation, **filtering**, and **re-ranking**.
Keep in mind that our T5 model produces text in the format:
```
"Question: <question> Answer: <answer>"
```
In the final step, these approaches can be combined in either of two ways:
1. Let T5 produce a question and a short answer, then generate distractors for that answer.
2. Apply the classical approach directly to each sentence (the code below shows how).

In [None]:
# Utility for text cleanup & sentence splitting
def clean_text(text: str) -> str:
    """Collapse whitespace runs to single spaces, trim both ends, and map en dashes to hyphens."""
    single_line = text.replace("\n", " ")
    normalised = re.sub(r"\s+", " ", single_line).strip()
    return normalised.replace("–", "-")

def split_into_sentences(text: str):
    """Normalise ``text`` with clean_text, split it with NLTK, and drop empty sentences."""
    stripped = (sentence.strip() for sentence in nltk.sent_tokenize(clean_text(text)))
    return [sentence for sentence in stripped if sentence]

def pick_top_sentences(sentences, num=5):
    """Return the ``num`` longest sentences, longest first; equal lengths keep input order."""
    ranked = sorted(sentences, key=len, reverse=True)
    return ranked[:num]

# ---- Basic Rule-based Q Generation (Optional) ----
def generate_question_from_sentence(sentence):
    """
    Turn a declarative sentence into a cloze-style wh-question (rule-based).

    The first named entity, or failing that the first noun chunk longer than
    two characters, is replaced in place by a question word chosen from the
    entity label (e.g. PERSON -> "Who", DATE -> "When", otherwise "What").

    Returns:
        (question, answer_text), or None when no entity/noun chunk is found.
    """
    doc = nlp(sentence)

    ent_label_to_qw = {
        "PERSON": "Who",
        "GPE": "Where",
        "LOC": "Where",
        "ORG": "What",
        "DATE": "When",
        "TIME": "When",
    }

    # 1) Named entities, 2) fallback to noun chunks
    if doc.ents:
        target_span = doc.ents[0]
        question_word = ent_label_to_qw.get(target_span.label_, "What")
    else:
        target_span = next((chunk for chunk in doc.noun_chunks if len(chunk.text) > 2), None)
        question_word = "What"

    if target_span is None:
        return None

    start = target_span.start_char
    end = target_span.end_char
    # A wh-word that replaces a span in the middle of the sentence is not capitalised.
    if sentence[:start].strip():
        question_word = question_word.lower()

    question_sentence = (sentence[:start] + question_word + sentence[end:]).strip()
    if not question_sentence.endswith("?"):
        # Drop any sentence-final punctuation so we don't produce "...!?" or "...;?".
        question_sentence = question_sentence.rstrip(".!;:, ") + "?"

    return (question_sentence, target_span.text.strip())

In [None]:
# Distractor Generation
# 1) Time phrases
def is_time_phrase(text):
    """
    Return True if ``text`` contains a decade phrase like 'late 1990s' or 'early 2000s'.

    Matching is case-insensitive; recognised decades are 1000s-2990s.
    """
    # Single backslashes: inside a raw string, r"\\s" would match a literal backslash.
    pattern = r"(early|late|mid)\s+(1\d\d0s|2\d\d0s)"
    return bool(re.search(pattern, text.lower()))

def generate_time_distractors(time_text, num_distractors=3):
    """
    Build decade-phrase distractors for an answer such as 'late 1990s'.

    Every early/mid/late descriptor is combined with the decade shifted by
    -20..+20 years, the original phrase is excluded, and a random selection of
    ``num_distractors`` phrases of the form 'in the <desc> <decade>s' is returned.
    Falls back to fixed placeholders when no '<desc> <4 digits>s' phrase is found.
    """
    # Single backslashes: inside a raw string, r"\\s" would match a literal backslash.
    pattern = r"(early|late|mid)\s+(\d{4})s"
    match = re.search(pattern, time_text.lower())
    if not match:
        return ["in the early 2000s", "in the mid 1990s", "in the late 1980s"][:num_distractors]

    descriptor = match.group(1)
    decade_str = match.group(2)
    try:
        decade_int = int(decade_str)
    except ValueError:
        decade_int = 1990

    descriptors = ["early", "mid", "late"]
    shifts = [-1, 0, 1, 2, -2]
    candidates = []
    for desc in descriptors:
        for shift in shifts:
            new_decade = decade_int + (shift * 10)
            if desc == descriptor and new_decade == decade_int:
                continue  # this is the correct answer itself
            candidates.append(f"in the {desc} {new_decade}s")

    random.shuffle(candidates)
    return candidates[:num_distractors]

# 2) ConceptNet
def get_conceptnet_candidates(word, language="en", limit=50, timeout=10):
    """
    Fetch terms related to ``word`` from the public ConceptNet API.

    Args:
        word: Query term. Phrases are converted to ConceptNet's lowercase,
            underscore-joined concept form ("ice cream" -> "ice_cream").
        language: ConceptNet language code; only related nodes in this
            language are kept.
        limit: Maximum number of edges to request.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        De-duplicated related terms (underscores replaced by spaces), excluding
        ``word`` itself. Returns [] on any network, HTTP, or JSON failure so a
        flaky API never aborts distractor generation.
    """
    from urllib.parse import quote

    concept = "_".join(word.strip().lower().split())
    if not concept:
        return []
    url = f"http://api.conceptnet.io/c/{language}/{quote(concept)}"
    try:
        resp = requests.get(url, params={"limit": limit}, timeout=timeout)
        if resp.status_code != 200:
            return []
        data = resp.json()
    except (requests.RequestException, ValueError):
        return []
    if not isinstance(data, dict):
        return []

    word_lower = word.strip().lower()
    candidates = set()
    for edge in data.get("edges", []):
        for node in (edge.get("start", {}), edge.get("end", {})):
            # Node terms look like "/c/<lang>/<concept>".
            parts = node.get("term", "").split("/")
            if len(parts) < 4 or parts[2] != language:
                continue
            candidate = parts[3].replace("_", " ").strip()
            if candidate and candidate.lower() != word_lower:
                candidates.add(candidate)
    return list(candidates)

# 3) WordNet
from nltk.corpus import wordnet as wn

def wordnet_candidates(word, pos=wn.NOUN):
    """
    Collect WordNet neighbours of ``word`` for one part of speech.

    Merges the lemmas of every synset of ``word`` (synonyms) with the lemmas of
    every hyponym of those synsets' hypernyms (co-hyponyms). Underscores become
    spaces and ``word`` itself (case-insensitive) is dropped. Returns an
    unordered, de-duplicated list; empty if WordNet has no synsets.
    """
    synsets = wn.synsets(word, pos=pos)
    if not synsets:
        return []

    def _related_lemmas():
        for syn in synsets:
            yield from syn.lemmas()
        for syn in synsets:
            for hyper in syn.hypernyms():
                for hypo in hyper.hyponyms():
                    yield from hypo.lemmas()

    query = word.lower()
    names = (lemma.name().replace("_", " ") for lemma in _related_lemmas())
    return list({name for name in names if name.lower() != query})

# 4) Sense2Vec
def sense2vec_candidates(tagged_word, topn=15):
    """
    Return words that Sense2Vec places closest to ``tagged_word``.

    Args:
        tagged_word: Sense2Vec key in "phrase|SENSE" form, e.g. "dog|NOUN".
        topn: Number of neighbours to request.

    Returns:
        De-duplicated surface forms (sense suffix removed, underscores turned
        into spaces) other than the query phrase; [] when the model is not
        loaded or does not know the key.
    """
    if s2v is None:
        return []
    try:
        neighbours = s2v.most_similar(tagged_word, n=topn)
    except (KeyError, ValueError):
        # sense2vec's most_similar raises ValueError for keys missing from its
        # table; KeyError is caught as well to be safe.
        return []

    query = tagged_word.split("|")[0].replace("_", " ").lower()
    candidates = set()
    for key, _score in neighbours:
        phrase = key.split("|")[0].replace("_", " ")
        if phrase.lower() != query:
            candidates.add(phrase)
    return list(candidates)

# 5) Named Entity Distractors (Placeholders)
def generate_location_distractors(location_text, num=3):
    """
    Return up to ``num`` placeholder city names in random order.

    The correct answer (case-insensitive) is never offered as its own distractor.
    """
    answer = location_text.strip().lower()
    locs = [loc for loc in ["London", "Berlin", "Tokyo", "Sydney", "New York"] if loc.lower() != answer]
    random.shuffle(locs)
    return locs[:num]

def generate_person_distractors(person_text, num=3):
    """
    Return up to ``num`` placeholder person names in random order.

    The correct answer (case-insensitive) is never offered as its own distractor.
    """
    answer = person_text.strip().lower()
    ppl = [p for p in ["Rihanna", "Lady Gaga", "Adele", "Britney Spears", "Taylor Swift"] if p.lower() != answer]
    random.shuffle(ppl)
    return ppl[:num]

# ---- Filtering & Re-ranking ----
def is_synonym_or_lemma(candidate, correct_answer):
    """True if the two strings are equal ignoring case, or share any WordNet synset (any POS)."""
    if candidate.lower() == correct_answer.lower():
        return True
    candidate_senses = wn.synsets(candidate)
    answer_senses = wn.synsets(correct_answer)
    return any(sense in answer_senses for sense in candidate_senses)

def is_partial_match(candidate, correct_answer):
    """
    Return True if, after lowercasing and removing non-word characters, either
    string is a substring of the other (e.g. "photo-synthesis" vs "photosynthesis").
    """
    # Single backslash: inside a raw string, r"\\W+" would match a literal backslash.
    c_clean = re.sub(r"\W+", "", candidate.lower())
    a_clean = re.sub(r"\W+", "", correct_answer.lower())
    return (c_clean in a_clean) or (a_clean in c_clean)

def filter_candidates(candidates, correct_answer, context):
    """
    Drop candidates that would make poor distractors and de-duplicate the rest.

    A candidate is rejected if it is a WordNet synonym of (or equal to) the
    answer, a punctuation-insensitive substring match of it, appears verbatim
    in the context, or contains the answer.
    """
    haystack = context.lower()
    answer_lower = correct_answer.lower()

    def _keep(cand):
        lowered = cand.lower()
        return not (
            is_synonym_or_lemma(cand, correct_answer)
            or is_partial_match(cand, correct_answer)
            or lowered in haystack
            or answer_lower in lowered
        )

    return list({cand for cand in candidates if _keep(cand)})

def threshold_rerank(
    candidates,
    correct_answer,
    context,
    answer_sim_threshold=0.8,
    context_sim_threshold=0.3,
    top_k=3
):
    """
    Keep on-topic candidates that are not near-duplicates of the answer.

    Uses SBERT cosine similarity: a candidate survives when its similarity to
    the answer is below ``answer_sim_threshold`` and its similarity to the
    context is above ``context_sim_threshold``. Survivors are ranked by
    (context_sim - answer_sim), highest first, and the top ``top_k`` returned.
    """
    if not candidates:
        return []
    answer_vec = sbert_model.encode(correct_answer, convert_to_tensor=True)
    context_vec = sbert_model.encode(context, convert_to_tensor=True)
    cand_vecs = sbert_model.encode(candidates, convert_to_tensor=True)

    to_answer = util.cos_sim(cand_vecs, answer_vec).squeeze(dim=1).tolist()
    to_context = util.cos_sim(cand_vecs, context_vec).squeeze(dim=1).tolist()

    scored = [
        (cand, ctx - ans)
        for cand, ans, ctx in zip(candidates, to_answer, to_context)
        if ans < answer_sim_threshold and ctx > context_sim_threshold
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [cand for cand, _ in scored[:top_k]]

# 6) Final distractor function
def spacy_pos_to_wordnet_pos(spacy_pos):
    """Map a fine-grained tag (spaCy token.tag_) to a WordNet POS by first letter; default noun."""
    prefix_to_pos = {"N": wn.NOUN, "V": wn.VERB, "J": wn.ADJ, "R": wn.ADV}
    return prefix_to_pos.get(spacy_pos[:1], wn.NOUN)

def extract_main_token(phrase):
    """
    Pick one representative token of ``phrase`` for single-word lookups.

    Returns the token itself for one-token phrases, otherwise the first
    NOUN/PROPN, otherwise the syntactic head of the whole phrase. An empty
    phrase yields its stripped text ("").
    """
    doc = nlp(phrase)
    if len(doc) == 0:
        return phrase.strip()
    if len(doc) == 1:
        return doc[0].text
    for token in doc:
        if token.pos_ in ("NOUN", "PROPN"):
            return token.text
    # spaCy's Doc has no .root attribute; the full-document Span doc[:] does.
    return doc[:].root.text

def _gather_and_rank_distractors(correct_answer, context, lookup_token, wn_pos, num_distractors):
    """
    Collect lexical candidates for ``correct_answer``, then filter and SBERT-rerank them.

    ConceptNet and WordNet are queried with the whole answer; Sense2Vec with
    ``lookup_token`` tagged by ``wn_pos`` (adverbs/unknown POS use the NOUN sense).
    """
    s2v_sense = {wn.NOUN: "NOUN", wn.VERB: "VERB", wn.ADJ: "ADJ"}.get(wn_pos, "NOUN")
    cnet_cands = get_conceptnet_candidates(correct_answer)
    # WordNet stores multiword lemmas with underscores ("ice_cream"); no-op for single words.
    wn_cands = wordnet_candidates(correct_answer.replace(" ", "_"), pos=wn_pos)
    s2v_cands = sense2vec_candidates(f"{lookup_token}|{s2v_sense}", topn=15)
    all_cands = list(set(cnet_cands + wn_cands + s2v_cands))
    filtered = filter_candidates(all_cands, correct_answer, context)
    return threshold_rerank(filtered, correct_answer, context, top_k=num_distractors)


def generate_best_distractors(correct_answer, context, num_distractors=3):
    """
    Produce up to ``num_distractors`` plausible wrong options for an answer.

    Strategy, in order:
      1. Decade phrases ('late 1990s') -> templated decade variations.
      2. GPE/LOC entities -> placeholder locations; PERSON -> placeholder people.
      3. Everything else (including ORG) -> ConceptNet + WordNet + Sense2Vec
         candidates, filtered and SBERT-reranked. Multiword answers use their
         main token for the Sense2Vec lookup.

    Returns a list that may be shorter than ``num_distractors`` (or empty) when
    too few candidates survive filtering, including for an empty answer.
    """
    correct_answer = correct_answer.strip()
    if not correct_answer:
        return []

    # 1) Time phrase
    if is_time_phrase(correct_answer):
        return generate_time_distractors(correct_answer, num_distractors)

    # 2) Named entity check
    doc = nlp(correct_answer)
    if doc.ents:
        label = doc.ents[0].label_
        if label in ("GPE", "LOC"):
            return generate_location_distractors(correct_answer, num_distractors)
        if label == "PERSON":
            return generate_person_distractors(correct_answer, num_distractors)
        # ORG and other labels fall through: the person placeholders are singer
        # names, which make no sense as distractors for an organisation.

    # 3) Lexical pipeline; single words use their own tag, phrases their main token's.
    if len(correct_answer.split()) == 1:
        lookup_token = correct_answer
        wn_pos = spacy_pos_to_wordnet_pos(doc[0].tag_) if len(doc) == 1 else wn.NOUN
    else:
        lookup_token = extract_main_token(correct_answer)
        token_doc = nlp(lookup_token)
        wn_pos = spacy_pos_to_wordnet_pos(token_doc[0].tag_) if len(token_doc) else wn.NOUN

    return _gather_and_rank_distractors(correct_answer, context, lookup_token, wn_pos, num_distractors)

In [None]:
import re
import random
import requests
import nltk
import spacy

from nltk.corpus import wordnet as wn
from sentence_transformers import SentenceTransformer, util
from sense2vec import Sense2Vec

nltk.download('punkt')
# Recent NLTK releases load the Punkt tokenizer from 'punkt_tab' and the tagger
# from 'averaged_perceptron_tagger_eng'; fetch both names so either version works.
nltk.download('punkt_tab')
nltk.download('wordnet')
nltk.download('averaged_perceptron_tagger')
nltk.download('averaged_perceptron_tagger_eng')

# Load spaCy (for NER and POS)
nlp = spacy.load("en_core_web_sm")

# Load SBERT for re-ranking
sbert_model = SentenceTransformer('all-MiniLM-L6-v2')

# Load Sense2Vec. Loaded unguarded here: a missing 's2v_old' folder fails at this line.
s2v = Sense2Vec().from_disk("s2v_old")

# =====================================================
# 1. Utilities: POS Mapping, Time Detection, etc.
# =====================================================
def spacy_pos_to_wordnet_pos(spacy_pos):
    """
    Translate a spaCy fine-grained tag (token.tag_) into a WordNet POS constant.

    Matching is on the tag's leading letter (N/V/J/R); anything else maps to wn.NOUN.
    """
    for prefix, wordnet_pos in (("N", wn.NOUN), ("V", wn.VERB), ("J", wn.ADJ), ("R", wn.ADV)):
        if spacy_pos.startswith(prefix):
            return wordnet_pos
    return wn.NOUN

def is_time_phrase(text):
    """
    Report whether ``text`` mentions a decade such as 'late 1990s' or 'early 2000s'.

    The check is case-insensitive and accepts early/late/mid followed by a
    decade from the 1000s through the 2990s. spaCy DATE entities are not consulted.
    """
    decade_phrase = r"(early|late|mid)\s+(1\d\d0s|2\d\d0s)"
    return re.search(decade_phrase, text.lower()) is not None

def generate_time_distractors(time_text, num_distractors=3):
    """
    Build decade-phrase distractors for an answer like 'late 1990s'.

    Every early/mid/late descriptor is paired with the decade shifted by
    -20..+20 years, the original phrase is left out, the pool is shuffled, and
    the first ``num_distractors`` phrases ('in the <desc> <decade>s') are returned.
    Without a parsable '<desc> <4 digits>s' phrase, fixed placeholders are used.
    """
    found = re.search(r"(early|late|mid)\s+(\d{4})s", time_text.lower())
    if found is None:
        return ["in the early 2000s", "in the late 1980s", "in the mid 1970s"][:num_distractors]

    orig_desc, orig_decade_str = found.group(1), found.group(2)
    try:
        orig_decade = int(orig_decade_str)
    except ValueError:
        orig_decade = 1990

    pool = [
        f"in the {desc} {orig_decade + step * 10}s"
        for desc in ("early", "mid", "late")
        for step in (-1, 0, 1, 2, -2)
        if not (desc == orig_desc and step == 0)
    ]
    random.shuffle(pool)
    return pool[:num_distractors]

# =====================================================
# 2. Candidate Retrieval from External Sources
# =====================================================
def get_conceptnet_candidates(word, language="en", limit=50, timeout=10):
    """
    Fetch terms related to ``word`` from the public ConceptNet API.

    Args:
        word: Query term. Phrases are converted to ConceptNet's lowercase,
            underscore-joined concept form ("ice cream" -> "ice_cream").
        language: ConceptNet language code; only related nodes in this
            language are kept.
        limit: Maximum number of edges to request.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        De-duplicated related terms (underscores replaced by spaces), excluding
        ``word`` itself. Returns [] on any network, HTTP, or JSON failure so a
        flaky API never aborts distractor generation.
    """
    from urllib.parse import quote

    concept = "_".join(word.strip().lower().split())
    if not concept:
        return []
    url = f"http://api.conceptnet.io/c/{language}/{quote(concept)}"
    try:
        response = requests.get(url, params={"limit": limit}, timeout=timeout)
        if response.status_code != 200:
            return []
        data = response.json()
    except (requests.RequestException, ValueError):
        return []
    if not isinstance(data, dict):
        return []

    word_lower = word.strip().lower()
    candidates = set()
    for edge in data.get("edges", []):
        for node in (edge.get("start", {}), edge.get("end", {})):
            # Node terms look like "/c/<lang>/<concept>".
            parts = node.get("term", "").split("/")
            if len(parts) < 4 or parts[2] != language:
                continue
            candidate = parts[3].replace("_", " ").strip()
            if candidate and candidate.lower() != word_lower:
                candidates.add(candidate)
    return list(candidates)

def wordnet_candidates(word, pos=wn.NOUN):
    """
    Gather WordNet synonyms and co-hyponyms of ``word`` for part of speech ``pos``.

    Synonyms are the lemmas of each synset of ``word``; co-hyponyms are the
    lemmas of each hyponym of those synsets' hypernyms. Underscores become
    spaces, ``word`` itself (case-insensitive) is excluded, and the result is an
    unordered de-duplicated list (empty when WordNet has no synsets).
    """
    lowered = word.lower()
    collected = set()

    def _collect(lemmas):
        for lemma in lemmas:
            name = lemma.name().replace("_", " ")
            if name.lower() != lowered:
                collected.add(name)

    synsets = wn.synsets(word, pos=pos)
    for syn in synsets:  # synonyms
        _collect(syn.lemmas())
    for syn in synsets:  # siblings under a shared hypernym
        for hyper in syn.hypernyms():
            for hypo in hyper.hyponyms():
                _collect(hypo.lemmas())
    return list(collected)

def sense2vec_candidates(tagged_word, topn=15):
    """
    Return words that Sense2Vec places closest to ``tagged_word``.

    Args:
        tagged_word: Sense2Vec key in "phrase|SENSE" form, e.g. "dog|NOUN".
        topn: Number of neighbours to request.

    Returns:
        De-duplicated surface forms (sense suffix removed, underscores turned
        into spaces) other than the query phrase; [] when the model is not
        loaded or does not know the key.
    """
    if s2v is None:
        return []
    try:
        sim_list = s2v.most_similar(tagged_word, n=topn)
    except (KeyError, ValueError):
        # sense2vec's most_similar raises ValueError for keys missing from its
        # table; KeyError is caught as well to be safe.
        return []

    query = tagged_word.split("|")[0].replace("_", " ").lower()
    candidates = set()
    for candidate, _score in sim_list:
        candidate_word = candidate.split("|")[0].replace("_", " ")
        if candidate_word.lower() != query:
            candidates.add(candidate_word)
    return list(candidates)

# =====================================================
# 3. Named-Entity Distractors (e.g., for places)
# =====================================================
def generate_location_distractors(location_text, num_distractors=3):
    """
    Return up to ``num_distractors`` placeholder city names in random order.

    The correct answer (case-insensitive) is never offered as its own distractor.
    You might expand this with a bigger dictionary of cities/countries.
    """
    answer = location_text.strip().lower()
    locations = [
        loc for loc in ["Los Angeles", "London", "Sydney", "Tokyo", "Berlin"]
        if loc.lower() != answer
    ]
    random.shuffle(locations)
    return locations[:num_distractors]

def generate_person_distractors(person_text, num_distractors=3):
    """
    Return up to ``num_distractors`` placeholder person names in random order.

    The correct answer (case-insensitive) is never offered as its own distractor.
    This is a domain-specific placeholder list.
    """
    answer = person_text.strip().lower()
    people = [p for p in ["Rihanna", "Lady Gaga", "Adele", "Britney Spears"] if p.lower() != answer]
    random.shuffle(people)
    return people[:num_distractors]

# =====================================================
# 4. Advanced Filtering
# =====================================================
def is_synonym_or_lemma(candidate, correct_answer):
    """
    Decide whether ``candidate`` is effectively the same answer.

    True when the strings match ignoring case, or when they share at least one
    WordNet synset (all parts of speech considered).
    """
    if candidate.lower() == correct_answer.lower():
        return True
    candidate_senses = wn.synsets(candidate)
    answer_senses = wn.synsets(correct_answer)
    shared = next((sense for sense in candidate_senses if sense in answer_senses), None)
    return shared is not None

def is_partial_match(candidate, correct_answer):
    """
    Return True if, after lowercasing and deleting all non-word characters,
    either string contains the other (e.g. "photo-synthesis" vs "Photosynthesis").

    Related word forms that are not substrings of each other, such as
    "photosynthetic" vs "photosynthesis", are NOT detected.
    """
    def _squash(value):
        return re.sub(r'\W+', '', value.lower())

    cand_key = _squash(candidate)
    ans_key = _squash(correct_answer)
    return cand_key in ans_key or ans_key in cand_key

def filter_candidates(candidates, correct_answer, context):
    """
    Remove unusable distractor candidates and de-duplicate the survivors.

    Rejected: WordNet synonyms of (or equal to) the answer, punctuation-insensitive
    substring matches of the answer, anything appearing verbatim in the context,
    and multi-word candidates that contain the answer.
    """
    ctx = context.lower()
    ans = correct_answer.lower()
    survivors = []
    for cand in candidates:
        lowered = cand.lower()
        rejected = (
            is_synonym_or_lemma(cand, correct_answer)
            or is_partial_match(cand, correct_answer)
            or lowered in ctx
            or ans in lowered
        )
        if not rejected:
            survivors.append(cand)
    return list(set(survivors))

# =====================================================
# 5. SBERT Threshold-based Re-ranking
# =====================================================
def threshold_rerank(
    candidates,
    correct_answer,
    context,
    answer_sim_threshold=0.8,
    context_sim_threshold=0.3,
    top_k=3
):
    """
    Keep candidates that are on-topic but not too close to the answer, best first.

    A candidate survives when its SBERT cosine similarity to ``correct_answer``
    is below ``answer_sim_threshold`` AND its similarity to ``context`` is above
    ``context_sim_threshold``. Survivors are ordered by context_sim - answer_sim
    (descending) and at most ``top_k`` are returned.
    """
    if not candidates:
        return []

    def _embed(payload):
        return sbert_model.encode(payload, convert_to_tensor=True)

    answer_embedding = _embed(correct_answer)
    context_embedding = _embed(context)
    candidate_embeddings = _embed(candidates)

    answer_sims = util.cos_sim(candidate_embeddings, answer_embedding).squeeze(dim=1)
    context_sims = util.cos_sim(candidate_embeddings, context_embedding).squeeze(dim=1)

    kept = []
    for position, candidate in enumerate(candidates):
        a_sim = float(answer_sims[position])
        c_sim = float(context_sims[position])
        if not (a_sim < answer_sim_threshold and c_sim > context_sim_threshold):
            continue
        kept.append((candidate, c_sim - a_sim))

    ranked = sorted(kept, key=lambda item: item[1], reverse=True)
    return [candidate for candidate, _ in ranked[:top_k]]

# =====================================================
# 6. Main "Best" Distractor Generation
# =====================================================
def extract_main_token(phrase):
    """
    For a multiword phrase, pick the first NOUN/PROPN token, or else the
    phrase's syntactic head, as the best single token for Sense2Vec fallback.

    One-token phrases return that token; an empty phrase returns "".
    """
    doc = nlp(phrase)
    if len(doc) == 0:
        return phrase.strip()
    if len(doc) == 1:
        return doc[0].text  # single word anyway

    for token in doc:
        if token.pos_ in ("NOUN", "PROPN"):
            return token.text
    # spaCy's Doc has no .root attribute; the full-document Span doc[:] does.
    return doc[:].root.text

def _pool_filter_rerank(phrase, context, s2v_word, wordnet_pos, num_distractors):
    """Pool ConceptNet/WordNet candidates for *phrase* with Sense2Vec neighbours of *s2v_word*, then filter and SBERT re-rank."""
    # Sense2Vec keys look like "word|NOUN"; anything that is not a verb or
    # adjective falls back to the NOUN sense.
    s2v_sense = {wn.VERB: "VERB", wn.ADJ: "ADJ"}.get(wordnet_pos, "NOUN")

    cnet_cands = get_conceptnet_candidates(phrase)
    wn_cands = wordnet_candidates(phrase, pos=wordnet_pos)
    s2v_cands = sense2vec_candidates(f"{s2v_word}|{s2v_sense}", topn=15)

    all_cands = list(set(cnet_cands + wn_cands + s2v_cands))
    filtered = filter_candidates(all_cands, phrase, context)
    return threshold_rerank(filtered, phrase, context, top_k=num_distractors)


def generate_best_distractors(correct_answer, context, num_distractors=3):
    """
    Generate up to ``num_distractors`` distractors for ``correct_answer``.

    Strategy:
      1) Empty / whitespace-only answers yield no distractors.
      2) Single-word answers: WordNet + ConceptNet + Sense2Vec on the word,
         with the WordNet POS taken from spaCy's tag.
      3) Multiword answers:
         a) time expressions      -> generate_time_distractors
         b) GPE / LOC entities     -> generate_location_distractors
         c) PERSON / ORG entities  -> generate_person_distractors
         d) otherwise WordNet + ConceptNet on the full phrase and Sense2Vec
            on its main token (extract_main_token).
      4) Cases 2 and 3d pool candidates, filter them (filter_candidates) and
         re-rank them with SBERT (threshold_rerank).

    Args:
        correct_answer: answer string to build distractors for.
        context: passage/sentence used for filtering and re-ranking.
        num_distractors: maximum number of distractors to return.

    Returns:
        List of distractor strings (may be shorter than ``num_distractors``).
    """
    tokens = correct_answer.split()
    if not tokens:
        return []

    # -------------- SINGLE-WORD CASE --------------
    if len(tokens) == 1:
        doc = nlp(correct_answer)
        wordnet_pos = wn.NOUN
        if len(doc) == 1:
            wordnet_pos = spacy_pos_to_wordnet_pos(doc[0].tag_)
        return _pool_filter_rerank(
            correct_answer, context, correct_answer, wordnet_pos, num_distractors
        )

    # -------------- MULTIWORD CASE --------------
    # a) Time expression
    if is_time_phrase(correct_answer):
        return generate_time_distractors(correct_answer, num_distractors)

    # b) Named entity check with spaCy (first entity decides)
    doc = nlp(correct_answer)
    if doc.ents:
        label = doc.ents[0].label_
        if label in ("GPE", "LOC"):
            return generate_location_distractors(correct_answer, num_distractors)
        if label in ("PERSON", "ORG"):
            # NOTE(review): ORG entities reuse the person generator — confirm intended.
            return generate_person_distractors(correct_answer, num_distractors)
        # other entity labels fall through to the token-based approach

    # c) WordNet + ConceptNet on the full phrase, Sense2Vec on its main token
    main_token = extract_main_token(correct_answer)
    tok_doc = nlp(main_token)
    wordnet_pos = spacy_pos_to_wordnet_pos(tok_doc[0].tag_) if len(tok_doc) else wn.NOUN
    return _pool_filter_rerank(
        correct_answer, context, main_token, wordnet_pos, num_distractors
    )

# =====================================================
# 7. Example MCQ Pipeline
# =====================================================
def generate_qa(context):
    """
    Return a fixed (question, answer) pair for the demo; *context* is ignored.

    Placeholder only; replace with a real QA model if needed.
    """
    # Hard-coded single-word (country) answer matching the Japan demo context.
    question = "Which country is known as the Land of the Rising Sun?"
    correct_answer = "Japan"
    return question, correct_answer

def generate_mcq(context, num_distractors=3):
    """
    Build one MCQ for *context* from generate_qa + generate_best_distractors.

    Exactly ``num_distractors`` distractors are used: surplus ones are dropped
    and missing ones are padded with a placeholder. Options are shuffled with
    ``random.shuffle``.

    Returns:
        dict with keys "question", "options", "correct_answer".
    """
    question, correct_answer = generate_qa(context)
    # Copy and cap: never mutate a list the generator may hand out, and never
    # offer more options than requested.
    distractors = list(
        generate_best_distractors(correct_answer, context, num_distractors)
    )[:num_distractors]

    missing = num_distractors - len(distractors)
    if missing > 0:
        distractors += ["(No more distractors found)"] * missing

    options = distractors + [correct_answer]
    random.shuffle(options)
    return {
        "question": question,
        "options": options,
        "correct_answer": correct_answer
    }

# =====================================================
# 8. Test / Demo
# =====================================================
if __name__ == "__main__":
    # Demo context for the dummy Japan QA pair.
    context_example = (
        "Japan, an island nation in East Asia, is often referred to as the "
        "'Land of the Rising Sun' because its name in Japanese, Nihon (日本), "
        "means 'origin of the sun.' This name reflects Japan’s position east of "
        "China, where the sun rises earlier."
    )

    mcq = generate_mcq(context_example, num_distractors=3)
    for label, key in (
        ("Question:", "question"),
        ("Options:", "options"),
        ("Correct Answer:", "correct_answer"),
    ):
        print(label, mcq[key])

## 7. End-to-End Generation Demo
We show **two** approaches:

1. **T5 Approach**: Use the fine-tuned T5 to generate `"Question: ... Answer: ..."` from a passage. Then parse out the question and answer, generate **distractors** with the classical pipeline, and form an MCQ.

2. **Classical Only**: A naive method that picks a chunk of a sentence as the answer, turns the sentence into a question by replacing that chunk with a WH-word, and then runs the same distractor pipeline.

### 7.1. T5 Inference + Distractors

In [None]:
# Fine-tuned T5 checkpoint (weights + tokenizer) used for inference below.
_FINETUNED_T5_DIR = "/content/drive/MyDrive/MinorProject/t5_mcq_finetuned"
inference_model = T5ForConditionalGeneration.from_pretrained(_FINETUNED_T5_DIR).to(device)
inference_tokenizer = T5Tokenizer.from_pretrained(_FINETUNED_T5_DIR)

def t5_generate_question_answer(passage, max_length=128):
    """
    Run the fine-tuned T5 on "Generate MCQ: <passage>".

    Decoding uses 4-beam search with early stopping, capped at *max_length*
    tokens.

    Returns:
        The decoded text, expected to look like "Question: ... Answer: ...".
    """
    encoded = inference_tokenizer("Generate MCQ: " + passage, return_tensors="pt").to(device)
    with torch.no_grad():
        generated = inference_model.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=max_length,
            num_beams=4,
            early_stopping=True,
        )
    return inference_tokenizer.decode(generated[0], skip_special_tokens=True)

def parse_t5_output(t5_output):
    """
    Split T5 output of the form "Question: <q> Answer: <a>" into (<q>, <a>).

    Markers are matched case-insensitively, and the answer marker must come
    after the question marker. Both parts are whitespace-stripped and may
    span several lines.

    Returns:
        (question, answer); ("", "") when a marker is missing or the answer
        marker never follows the question marker.
    """
    import re

    # Match on the original text. Locating markers in text.lower() and then
    # slicing the original breaks when lowercasing changes the length
    # (e.g. "İ".lower() is two code points).
    match = re.search(
        r"question:(.*?)answer:(.*)", t5_output, flags=re.IGNORECASE | re.DOTALL
    )
    if match is None:
        return "", ""
    return match.group(1).strip(), match.group(2).strip()

def generate_mcq_with_t5_and_distractors(passage, num_distractors=3):
    """
    Build one MCQ from *passage* with T5 for the question and the classical pipeline for distractors.

    Steps:
      1) T5 produces "Question: ... Answer: ...", parsed by parse_t5_output.
      2) generate_best_distractors builds distractors, using the passage as context.
      3) Exactly ``num_distractors`` distractors are kept (extras dropped,
         gaps padded), and the options are shuffled.

    Returns:
        dict with keys "passage", "question", "options", "correct_answer".
        If parsing fails, the question is "(Could not parse question)" and
        the options are empty.
    """
    t5_output = t5_generate_question_answer(passage)
    question, correct_answer = parse_t5_output(t5_output)

    # If we fail to parse a question or answer, bail out
    if not question or not correct_answer:
        return {
            "passage": passage,
            "question": "(Could not parse question)",
            "options": [],
            "correct_answer": ""
        }

    # Copy and cap so the generator's list is never mutated and no surplus options appear.
    distractors = list(
        generate_best_distractors(correct_answer, passage, num_distractors=num_distractors)
    )[:num_distractors]
    distractors += ["(No more distractors)"] * (num_distractors - len(distractors))

    options = distractors + [correct_answer]
    random.shuffle(options)

    return {
        "passage": passage,
        "question": question,
        "options": options,
        "correct_answer": correct_answer
    }

### 7.2. Classical-Only Approach
If you choose **not** to use T5, you can run the naive approach instead:
- Splitting the text into sentences,
- For each sentence, pick a chunk as the answer,
- Generate a question by replacing that chunk with a WH-word,
- Use the same distractor pipeline.

Below is a function to do exactly that, returning multiple MCQs from a text.

In [None]:
def generate_mcqs_from_text_classical(text, num_questions=5, num_distractors=3):
    """
    Build up to *num_questions* MCQs from *text* without T5.

    Steps:
      1) The text is split into sentences.
      2) The top ``2 * num_questions`` sentences are chosen (pick_top_sentences).
      3) Each sentence becomes a (question, answer) pair via
         generate_question_from_sentence; sentences that yield nothing are skipped.
      4) Distractors come from generate_best_distractors, with the sentence as
         context. Exactly ``num_distractors`` are kept (extras dropped, gaps
         padded), and the options are shuffled.

    Returns:
        List of dicts with keys "context_sentence", "question", "options",
        "correct_answer".
    """
    if num_questions <= 0:
        return []

    sentences = split_into_sentences(text)
    if not sentences:
        return []

    mcqs = []
    for sent in pick_top_sentences(sentences, num=num_questions * 2):
        qa = generate_question_from_sentence(sent)
        if not qa:
            continue
        question, correct_answer = qa

        distractors = list(
            generate_best_distractors(correct_answer, sent, num_distractors=num_distractors)
        )[:num_distractors]
        distractors += ["(No more distractors)"] * (num_distractors - len(distractors))

        options = distractors + [correct_answer]
        random.shuffle(options)

        mcqs.append({
            "context_sentence": sent,
            "question": question,
            "options": options,
            "correct_answer": correct_answer
        })
        if len(mcqs) >= num_questions:
            break
    return mcqs

def simple_evaluation_demo(mcqs):
    """
    Print each MCQ dict for manual inspection.

    For each MCQ, prints:
      - a numbered header;
      - the context sentence, or else the passage, if present;
      - the question and the numbered options;
      - the correct answer;
      - a 60-dash separator.
    """
    source_fields = (("context_sentence", "Context:"), ("passage", "Passage:"))
    for number, item in enumerate(mcqs, 1):
        print(f"MCQ {number}:")
        for key, label in source_fields:
            if key in item:
                print(label, item[key])
                break
        print("Q:", item.get("question", "No question"))
        for pos, choice in enumerate(item.get("options", []), 1):
            print(f"  {pos}) {choice}")
        print("Correct:", item.get("correct_answer", "N/A"))
        print("-" * 60)

## 8. Final Demonstration
We'll show:
- Using **T5** to generate Q&A + classical distractors.
- Using the **classical** approach alone.

In [None]:
if __name__ == "__main__":
    sample_text = (
        "Beyoncé Giselle Knowles-Carter is an American singer, songwriter, and actress. "
        "Born and raised in Houston, Texas, she rose to fame in the late 1990s as the lead singer of the R&B group Destiny's Child. "
        "Often referred to as 'Queen Bey', Beyoncé is one of the world's best-selling recording artists, having sold over 120 million records worldwide. "
        # Trailing space added: adjacent literals are concatenated verbatim,
        # which previously produced "history.She loved".
        "She has won 28 Grammy Awards and is the most-nominated woman in the award's history. "
        "She loved Taylor Swift."
    )

    print("\n=== DEMO 1: T5 + Distractors ===\n")
    # T5 inference + distractors
    mcq_result = generate_mcq_with_t5_and_distractors(sample_text, num_distractors=3)
    # For a formatted view use: simple_evaluation_demo([mcq_result])
    print(mcq_result)

    # Uncomment to also run the classical-only pipeline:
    # print("\n=== DEMO 2: Classical-Only ===\n")
    # classical_mcqs = generate_mcqs_from_text_classical(sample_text, num_questions=3)
    # simple_evaluation_demo(classical_mcqs)

In [None]:

!pip install spacy rake-nltk nltk sentence-transformers openai
!python -m spacy download en_core_web_sm
!python -m nltk.downloader wordnet

In [None]:
!python -m nltk.downloader stopwords
!python -m nltk.downloader punkt_tab

In [None]:
import spacy
from rake_nltk import Rake
from nltk.corpus import wordnet as wn
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import openai  # For GPT-3.5/4 integration (optional)

# Load models
# Models shared by generate_distractors below.
# NOTE(review): `nlp` is also used by earlier cells of this notebook; this
# rebinding replaces it for everything executed afterwards. Confirm the
# earlier cells expect the same "en_core_web_sm" pipeline.
nlp = spacy.load("en_core_web_sm")
# SBERT encoder used for answer/candidate similarity filtering.
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

def _entity_and_keyword_candidates(context, correct_answer, max_keywords=5):
    """Return selected named entities plus the top RAKE keyphrases of *context*, excluding *correct_answer*."""
    doc = nlp(context)
    # NOTE(review): 'PROFESSION' is not a label of spaCy's stock English
    # models; it only matters for custom pipelines.
    entity_labels = {"PERSON", "ORG", "PRODUCT", "NORP", "PROFESSION"}
    entities = [ent.text for ent in doc.ents if ent.label_ in entity_labels]

    rake = Rake()
    rake.extract_keywords_from_text(context)
    answer_lower = correct_answer.strip().lower()
    # get_ranked_phrases_with_scores() yields (score, phrase) tuples, best first.
    keywords = [
        phrase
        for _score, phrase in rake.get_ranked_phrases_with_scores()[:max_keywords]
        if phrase.lower() != answer_lower
    ]
    return entities + keywords


def _wordnet_relatives(correct_answer):
    """Return hypernym and similar-to lemma names (spaces, not underscores) for each WordNet synset of *correct_answer*."""
    relatives = []
    # WordNet stores multiword lemmas with underscores ("ice_cream").
    for syn in wn.synsets(correct_answer.strip().replace(" ", "_")):
        for related in syn.hypernyms() + syn.similar_tos():
            relatives.append(related.lemma_names()[0].replace("_", " "))
    return relatives


def _refinement_prompt(context, question, correct_answer, candidates, num_distractors):
    """Build the LLM prompt that asks for *num_distractors* comma-separated distractors."""
    return f"""
    Generate {num_distractors} plausible MCQ distractors for this question.
    Context: {context}
    Question: {question}
    Correct Answer: {correct_answer}
    Candidate Distractors: {candidates}

    Rules:
    1. Choose/rewrite candidates to be plausible but incorrect
    2. Make them grammatically consistent with the question
    3. Ensure they relate to the context

    Output ONLY comma-separated distractors:
    """


def generate_distractors(context, question, correct_answer, num_distractors=3,
                         min_sim=0.4, max_sim=0.8, llm_refine=None):
    """
    Generate distractors from context entities, RAKE keyphrases and WordNet relatives.

    Pipeline:
      1) Candidates come from spaCy entities, the top RAKE keyphrases of
         *context*, and WordNet hypernyms/similar-tos of *correct_answer*.
      2) Duplicates, empty strings and the answer itself are removed.
      3) Candidates whose SBERT cosine similarity to the answer lies strictly
         within (min_sim, max_sim) are kept.
      4) These are sorted by similarity, highest first, and the top
         ``2 * num_distractors`` are retained.
      5) The retained list is optionally refined by an LLM.

    Args:
        context: passage the question is about.
        question: question text (used only in the LLM prompt).
        correct_answer: the right answer.
        num_distractors: maximum number of distractors returned.
        min_sim, max_sim: exclusive similarity band for keeping a candidate.
        llm_refine: optional callable ``prompt -> str``. When given, it receives
            the refinement prompt and its comma-separated reply becomes the
            distractor list (e.g. a thin wrapper around a chat-completion API;
            the previously sketched ``openai.ChatCompletion`` call is the
            pre-1.0 SDK interface). When ``None`` (default), the top-ranked
            candidates are returned directly.

    Returns:
        Up to ``num_distractors`` unique distractor strings, best first.
    """
    answer_lower = correct_answer.strip().lower()

    pool = _entity_and_keyword_candidates(context, correct_answer) + _wordnet_relatives(correct_answer)
    # Order-preserving de-duplication (a plain set would make results order-random).
    candidates = list(dict.fromkeys(c.strip() for c in pool if c and c.strip()))
    candidates = [c for c in candidates if c.lower() != answer_lower]

    top_candidates = []
    if candidates:  # cosine_similarity raises on an empty candidate matrix
        correct_embedding = sentence_model.encode([correct_answer])
        candidate_embeddings = sentence_model.encode(candidates)
        similarities = cosine_similarity(correct_embedding, candidate_embeddings)[0]
        in_band = [
            (cand, sim) for cand, sim in zip(candidates, similarities)
            if min_sim < sim < max_sim
        ]
        in_band.sort(key=lambda pair: pair[1], reverse=True)
        top_candidates = [cand for cand, _ in in_band[:num_distractors * 2]]

    if llm_refine is not None:
        prompt = _refinement_prompt(context, question, correct_answer, top_candidates, num_distractors)
        final_distractors = [d.strip() for d in llm_refine(prompt).split(",") if d.strip()]
    else:
        final_distractors = top_candidates

    # Unique while keeping the ranking order.
    return list(dict.fromkeys(final_distractors))[:num_distractors]

# Example usage
context =  "Beyoncé Giselle Knowles-Carter is an American singer, songwriter, and actress. Born and raised in Houston, Texas, she rose to fame in the late 1990s as the lead singer of the R&B group Destiny's Child. Often referred to as 'Queen Bey', Beyoncé is one of the world's best-selling recording artists, having sold over 120 million records worldwide.She has won 28 Grammy Awards and is the most-nominated woman in the award's history."

question = "When did Beyonce start becoming popular?"
correct_answer = "in the late 1990s"

distractors = generate_distractors(context, question, correct_answer)
print(f"Correct Answer: {correct_answer}")
print(f"Distractors: {distractors}")

In [None]:
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Stock pretrained "t5-small" checkpoint used by t5_generate_distractors below.
# NOTE(review): this checkpoint is not fine-tuned on a "generate distractors:"
# task, so its output is likely off-task — TODO confirm / swap in a
# fine-tuned checkpoint.
# NOTE(review): the generic module-level names `model` / `tokenizer` rebind
# any earlier globals of the same name in this notebook.
model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")

def t5_generate_distractors(context, question, answer, max_new_tokens=64):
    """
    Ask the module-level T5 ``model``/``tokenizer`` for comma-separated distractors.

    The input "generate distractors: <context> <question> <answer>" is
    truncated to 512 tokens.

    Args:
        context, question, answer: pieces of the prompt.
        max_new_tokens: generation budget. Without an explicit limit,
            ``generate`` falls back to the model's default generation length
            (20 tokens for stock configs — TODO confirm), which cuts lists short.

    Returns:
        List of non-empty, whitespace-stripped pieces of the decoded output,
        split on commas.
    """
    input_text = f"generate distractors: {context} {question} {answer}"
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return [part.strip() for part in decoded.split(",") if part.strip()]

In [None]:
# Same example, now through the plain t5-small model. A space follows every
# period so sentence boundaries ("worldwide. She") are not lost.
context = (
    "Beyoncé Giselle Knowles-Carter is an American singer, songwriter, and actress. "
    "Born and raised in Houston, Texas, she rose to fame in the late 1990s as the lead singer of the R&B group Destiny's Child. "
    "Often referred to as 'Queen Bey', Beyoncé is one of the world's best-selling recording artists, having sold over 120 million records worldwide. "
    "She has won 28 Grammy Awards and is the most-nominated woman in the award's history."
)

question = "When did Beyonce start becoming popular?"
correct_answer = "in the late 1990s"

distractors = t5_generate_distractors(context, question, correct_answer)

In [None]:
# Notebook cell: display the distractor list from the previous cell.
distractors

In [None]:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
from datasets import Dataset, load_dataset
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# 1. Fine-tuned T5 Model
class DistractorGenerator:
    """
    Fine-tunable T5 distractor generator with an SBERT similarity filter.

    Input format (training and inference):
        "generate distractors: <support/context> <question> <correct_answer>"
    Target format:
        the distractors joined with " , ".
    """

    # SciQ stores its three distractors in separate columns (no "distractors" list).
    _DISTRACTOR_KEYS = ("distractor1", "distractor2", "distractor3")

    def __init__(self, model_name="t5-base"):
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.semantic_filter = SentenceTransformer('all-MiniLM-L6-v2')

    @classmethod
    def _example_distractors(cls, example):
        """Return an example's distractors: its "distractors" list if present, else the non-empty SciQ distractor1..3 fields."""
        if "distractors" in example:
            return list(example["distractors"])
        return [example[key] for key in cls._DISTRACTOR_KEYS if example.get(key)]

    def prepare_data(self, dataset_name="sciq"):
        """Load *dataset_name* with `load_dataset` and add "input_text" / "target_text" columns."""
        dataset = load_dataset(dataset_name)

        def format_example(example):
            input_text = f"generate distractors: {example['support']} {example['question']} {example['correct_answer']}"
            target_text = " , ".join(self._example_distractors(example))
            return {"input_text": input_text, "target_text": target_text}

        return dataset.map(format_example, batched=False)

    def train(self, dataset, output_dir="./distractor_t5"):
        """
        Fine-tune ``self.model`` on ``dataset["train"]``, evaluating on ``dataset["validation"]``.

        The model and tokenizer are saved to "<output_dir>/final_model".
        """
        import inspect

        pad_id = self.tokenizer.pad_token_id

        def tokenize_fn(examples):
            model_inputs = self.tokenizer(
                examples["input_text"],
                max_length=512,
                truncation=True,
                padding="max_length"
            )
            labels = self.tokenizer(
                text_target=examples["target_text"],
                max_length=128,
                truncation=True,
                padding="max_length"
            )
            # Padding positions become -100 so the seq2seq loss ignores them.
            model_inputs["labels"] = [
                [tok if tok != pad_id else -100 for tok in seq]
                for seq in labels["input_ids"]
            ]
            return model_inputs

        tokenized_dataset = dataset.map(tokenize_fn, batched=True)

        # Newer transformers renamed `evaluation_strategy` to `eval_strategy`;
        # use whichever keyword the installed version accepts.
        eval_key = (
            "eval_strategy"
            if "eval_strategy" in inspect.signature(TrainingArguments).parameters
            else "evaluation_strategy"
        )
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=3,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=100,
            eval_steps=500,
            **{eval_key: "steps"}
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_dataset["train"],
            eval_dataset=tokenized_dataset["validation"]
        )

        trainer.train()
        final_dir = f"{output_dir}/final_model"
        trainer.save_model(final_dir)
        # Save the tokenizer too, so final_dir is a self-contained checkpoint.
        self.tokenizer.save_pretrained(final_dir)

    def generate(self, context, question, correct_answer, num_distractors=3, temperature=0.7):
        """
        Generate candidates with 5-beam search, split them on commas, and filter semantically.

        Args:
            context, question, correct_answer: prompt pieces.
            num_distractors: maximum number of distractors returned.
            temperature: kept for backward compatibility only. Decoding is
                non-sampling beam search, where temperature has no effect, so
                it is not forwarded to ``generate``.

        Returns:
            Up to ``num_distractors`` distractor strings.
        """
        input_text = f"generate distractors: {context} {question} {correct_answer}"
        # Truncate like training (512 tokens) and place inputs on the model's
        # device (Trainer may have moved the model to a GPU).
        input_ids = self.tokenizer.encode(
            input_text, return_tensors="pt", max_length=512, truncation=True
        ).to(self.model.device)

        self.model.eval()  # disable dropout, e.g. after train()
        outputs = self.model.generate(
            input_ids,
            max_length=128,
            num_return_sequences=5,
            num_beams=5,
            early_stopping=True
        )

        raw_distractors = [self.tokenizer.decode(out, skip_special_tokens=True)
                           for out in outputs]

        # Split, clean, drop empties and the answer itself; keep first-seen order.
        answer_lower = correct_answer.strip().lower()
        pieces = (piece.strip() for text in raw_distractors for piece in text.split(","))
        candidates = list(dict.fromkeys(
            p for p in pieces if p and p.lower() != answer_lower
        ))

        return self.filter_distractors(candidates, correct_answer, num_distractors)

    def filter_distractors(self, candidates, correct_answer, num_distractors,
                           low=0.3, high=0.8, target=0.5):
        """
        Keep candidates whose SBERT similarity to the answer lies strictly within (low, high).

        Survivors are ordered by how close their similarity is to *target*.

        Returns:
            Up to ``num_distractors`` candidates; ``[]`` for no candidates.
        """
        if not candidates:  # cosine_similarity raises on an empty matrix
            return []

        embeddings = self.semantic_filter.encode([correct_answer] + candidates)
        similarities = cosine_similarity(embeddings[0:1], embeddings[1:])[0]
        return self._select_by_similarity(
            candidates, similarities, num_distractors, low, high, target
        )

    @staticmethod
    def _select_by_similarity(candidates, similarities, num_distractors,
                              low=0.3, high=0.8, target=0.5):
        """Return the first *num_distractors* candidates with low < sim < high, ordered closest-to-*target* first."""
        in_band = [
            (cand, sim) for cand, sim in zip(candidates, similarities)
            if low < sim < high
        ]
        in_band.sort(key=lambda pair: abs(pair[1] - target))
        return [cand for cand, _ in in_band[:num_distractors]]

# 2. Usage Pipeline
if __name__ == "__main__":
    # Initialize generator
    dg = DistractorGenerator()

    # To fine-tune on SciQ first, uncomment:
    # dataset = dg.prepare_data()
    # dg.train(dataset)

    # Load the fine-tuned weights from where train() saves them by default:
    # "<output_dir>/final_model" with output_dir="./distractor_t5" (relative,
    # not the filesystem root).
    dg.model = T5ForConditionalGeneration.from_pretrained("./distractor_t5/final_model")

    # Generate distractors
    context = "Last week I talked with some of my students about what they wanted to do after they graduated..."
    question = "We can know from the passage that the author works as a_."
    correct_answer = "teacher"

    distractors = dg.generate(context, question, correct_answer)

    print(f"Correct Answer: {correct_answer}")
    print(f"Plausible Distractors: {distractors}")