In [None]:
!pip install datasets

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    AdamW,
    get_linear_schedule_with_warmup
)
from datasets import load_dataset
import numpy as np
from tqdm.auto import tqdm
import wandb
import random

In [3]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed()

In [4]:
class BoolQDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        passage = item['passage']
        answer = 'yes' if item['answer'] else 'no'

        input_text = f"question: {question} context: {passage}"

        # Tokenize inputs and targets
        inputs = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        targets = self.tokenizer(
            answer,
            max_length=8,  # Short max_length for yes/no answers
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'labels': targets['input_ids'].squeeze(),
        }

In [None]:
model_name = 't5-base'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

In [6]:
dataset = load_dataset("boolq")
print(f"Training set size: {len(dataset['train'])}")
print(f"Validation set size: {len(dataset['validation'])}")

README.md:   0%|          | 0.00/6.57k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/3.69M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9427 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3270 [00:00<?, ? examples/s]

Training set size: 9427
Validation set size: 3270


In [7]:
train_dataset = BoolQDataset(dataset['train'], tokenizer)
val_dataset = BoolQDataset(dataset['validation'], tokenizer)

In [8]:
batch_size = 8
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
print(f"Using device: {device}")

learning_rate = 3e-4
epochs = 3
warmup_steps = 600

optimizer = AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

In [10]:
def train_epoch(model, dataloader, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc='Training')

    for batch in progress_bar:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        loss = outputs.loss
        total_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        progress_bar.set_postfix({'loss': loss.item()})

    return total_loss / len(dataloader)

In [11]:
def evaluate(model, dataloader, device):
    model.eval()
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Evaluating'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=8,
                num_beams=4
            )

            decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

            total_correct += sum(1 for pred, label in zip(decoded_preds, decoded_labels)
                               if pred.lower() == label.lower())
            total_samples += len(decoded_preds)

    accuracy = total_correct / total_samples
    return accuracy

In [12]:
best_accuracy = 0
for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")

    # Train
    train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, device)

    # Evaluate
    val_accuracy = evaluate(model, val_dataloader, device)

    print(f"Train Loss: {train_loss:.4f}")
    print(f"Validation Accuracy: {val_accuracy:.4f}")

    # Save best model
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        torch.save(model.state_dict(), "best_model.pt")

print(f"\nBest Validation Accuracy: {best_accuracy:.4f}")


Epoch 1/3


Training:   0%|          | 0/1179 [00:00<?, ?it/s]

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Evaluating:   0%|          | 0/409 [00:00<?, ?it/s]

Train Loss: 0.6603
Validation Accuracy: 0.7829

Epoch 2/3


Training:   0%|          | 0/1179 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/409 [00:00<?, ?it/s]

Train Loss: 0.0503
Validation Accuracy: 0.8073

Epoch 3/3


Training:   0%|          | 0/1179 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/409 [00:00<?, ?it/s]

Train Loss: 0.0205
Validation Accuracy: 0.8180

Best Validation Accuracy: 0.8180


In [13]:
def load_trained_model(model_path):
    model = T5ForConditionalGeneration.from_pretrained('t5-base')

    model.load_state_dict(torch.load(model_path))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    model.eval()

    return model

trained_model = load_trained_model('/content/best_model.pt')

  model.load_state_dict(torch.load(model_path))


In [14]:
def evaluate_example(passage, question, model, tokenizer):
    device = next(model.parameters()).device

    # Format input
    input_text = f"question: {question} context: {passage}"

    inputs = tokenizer(
        input_text,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=8,
            num_beams=4
        )

    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return prediction.lower()

In [15]:
def evaluate_batch(examples, model, tokenizer, batch_size=8):
    results = []
    for i in range(0, len(examples), batch_size):
        batch = examples[i:i+batch_size]

        for example in batch:
            prediction = evaluate_example(
                example['passage'],
                example['question'],
                model,
                tokenizer
            )
            results.append({
                'question': example['question'],
                'passage': example['passage'],
                'predicted_answer': prediction
            })

    return results

In [16]:
test_cases = {
    'short_context': [
        {
            'passage': "The Earth is the third planet from the Sun and the only place known to harbor life.",
            'question': "is earth the closest planet to sun",
            'expected': "no"
        },
        {
            'passage': "Coffee is a brewed drink prepared from roasted coffee beans.",
            'question': "is coffee made from tea leaves",
            'expected': "no"
        },
        {
            'passage': "Mount Everest is the highest mountain peak in the world.",
            'question': "is mount everest the tallest mountain",
            'expected': "yes"
        },
        {
            'passage': "Dolphins are marine mammals known for their intelligence and playful behavior.",
            'question': "are dolphins fish",
            'expected': "no"
        },
        {
            'passage': "The Sahara is the world's largest hot desert, covering most of North Africa.",
            'question': "is sahara in africa",
            'expected': "yes"
        },
        {
            'passage': "The piano is a musical instrument played by pressing keys that hammer strings.",
            'question': "is piano a string instrument",
            'expected': "no"
        },
        {
            'passage': "Tigers are the largest living cat species and a member of the genus Panthera.",
            'question': "is tiger the biggest cat",
            'expected': "yes"
        },
        {
            'passage': "Gold is a chemical element with the symbol Au and atomic number 79.",
            'question': "is gold a metal",
            'expected': "yes"
        },
        {
            'passage': "The moon is Earth's only natural satellite and orbits the planet.",
            'question': "does earth have multiple moons",
            'expected': "no"
        },
        {
            'passage': "Penguins are aquatic, flightless birds that live almost exclusively in the Southern Hemisphere.",
            'question': "can penguins fly",
            'expected': "no"
        }
    ],

    'long_context': [
        {
            'passage': """The Industrial Revolution was a period of major industrialization and innovation during the late 18th and early 19th century. The Industrial Revolution began in Great Britain, and quickly spread throughout Western Europe and North America. This time period saw the mechanization of agriculture and textile manufacturing and a revolution in power, including steam ships and railroads, that affected social, cultural and economic conditions.""",
            'question': "did industrial revolution start in america",
            'expected': "no"
        },
        {
            'passage': """The human brain is the central organ of the human nervous system, and with the spinal cord makes up the central nervous system. The brain consists of the cerebrum, the brainstem and the cerebellum. It controls most of the activities of the body, processing, integrating, and coordinating the information it receives from the sense organs, and making decisions as to the instructions sent to the rest of the body.""",
            'question': "is the brain part of the nervous system",
            'expected': "yes"
        },
        {
            'passage': """Climate change is a long-term change in the average weather patterns that have come to define Earth's local, regional and global climates. These changes have a broad range of observed effects that are synonymous with the term. Changes include sea levels rising, glaciers melting, arctic ice declining, and increased intensity of extreme weather events.""",
            'question': "is climate change only about temperature",
            'expected': "no"
        },
        {
            'passage': """The Internet is a global network of billions of computers and other electronic devices. With the Internet, it's possible to access almost any information, communicate with anyone else in the world, and do much more. You can do all of this by connecting a computer to the Internet, which is also called going online. The Internet has revolutionized the computer and communications world like nothing before.""",
            'question': "is internet limited to computers only",
            'expected': "no"
        },
        {
            'passage': """Photosynthesis is the process by which plants and other organisms convert light energy into chemical energy that can be released to fuel the organisms' activities. This chemical energy is stored in carbohydrate molecules, such as sugars, which are synthesized from carbon dioxide and water. Oxygen is also released as a byproduct of photosynthesis.""",
            'question': "does photosynthesis produce oxygen",
            'expected': "yes"
        },
        {
            'passage': """The Renaissance was a period in European history marking the transition from the Middle Ages to modernity and covering the 15th and 16th centuries. It occurred after the Crisis of the Late Middle Ages and was associated with great social change. In addition to the standard periodization, proponents of a 'long Renaissance' may put its beginning in the 14th century.""",
            'question': "was renaissance in the 20th century",
            'expected': "no"
        },
        {
            'passage': """Democracy is a form of government in which the people have the authority to deliberate and decide legislation, or to choose governing officials to do so. Who is considered part of 'the people' and how authority is shared among or delegated by the people has changed throughout the centuries.""",
            'question': "is democracy rule by people",
            'expected': "yes"
        },
        {
            'passage': """The Great Wall of China is a series of fortifications that were built across the historical northern borders of ancient Chinese states and Imperial China as protection against various nomadic groups from the Eurasian Steppe. Several walls were built from as early as the 7th century BCE, with selective stretches later joined together by Qin Shi Huang.""",
            'question': "was great wall built in single dynasty",
            'expected': "no"
        },
        {
            'passage': """Artificial intelligence (AI) is intelligence demonstrated by machines, unlike the natural intelligence displayed by humans and animals, which involves consciousness and emotionality. The distinction between the former and the latter categories is often revealed by the acronym chosen. 'Strong' AI is usually labelled as AGI while attempts to emulate 'natural' intelligence are usually characterized as ABI.""",
            'question': "is ai same as human intelligence",
            'expected': "no"
        },
        {
            'passage': """The theory of evolution by natural selection was proposed by Charles Darwin and Alfred Russel Wallace in the mid-19th century and was set out in detail in Darwin's book On the Origin of Species. Evolution by natural selection was first demonstrated by the observation that more offspring are often produced than can possibly survive.""",
            'question': "did darwin propose evolution theory",
            'expected': "yes"
        }
    ],

    'simple_yes': [
        {
            'passage': "Lions are the second-largest living big cat after the tiger.",
            'question': "is lion a big cat",
            'expected': "yes"
        },
        {
            'passage': "Water boils at 100 degrees Celsius at sea level.",
            'question': "does water boil at 100 celsius",
            'expected': "yes"
        },
        {
            'passage': "The Sun is a star at the center of our Solar System.",
            'question': "is the sun a star",
            'expected': "yes"
        },
        {
            'passage': "Birds lay eggs to reproduce and most species can fly.",
            'question': "do birds lay eggs",
            'expected': "yes"
        },
        {
            'passage': "Humans need oxygen to breathe and survive.",
            'question': "do humans need oxygen",
            'expected': "yes"
        },
        {
            'passage': "Plants need sunlight for photosynthesis.",
            'question': "do plants need sunlight",
            'expected': "yes"
        },
        {
            'passage': "Diamond is the hardest natural material known.",
            'question': "is diamond very hard",
            'expected': "yes"
        },
        {
            'passage': "The Earth rotates on its axis causing day and night.",
            'question': "does earth rotate",
            'expected': "yes"
        },
        {
            'passage': "Mathematics is a subject that deals with numbers and calculations.",
            'question': "is math about numbers",
            'expected': "yes"
        },
        {
            'passage': "Dogs are domesticated animals often kept as pets.",
            'question': "are dogs pets",
            'expected': "yes"
        }
    ],

    'simple_no': [
        {
            'passage': "Mercury is the closest planet to the Sun in our solar system.",
            'question': "is venus closest to sun",
            'expected': "no"
        },
        {
            'passage': "Whales are mammals, not fish, despite living in water.",
            'question': "are whales fish",
            'expected': "no"
        },
        {
            'passage': "The Antarctic is located at the South Pole.",
            'question': "is antarctic in north pole",
            'expected': "no"
        },
        {
            'passage': "Vegetables are plant-based foods, not derived from animals.",
            'question': "are vegetables from animals",
            'expected': "no"
        },
        {
            'passage': "The Mona Lisa was painted by Leonardo da Vinci.",
            'question': "did picasso paint mona lisa",
            'expected': "no"
        },
        {
            'passage': "English is the primary language of the United Kingdom.",
            'question': "is french the main language of uk",
            'expected': "no"
        },
        {
            'passage': "Cars run on engines powered by fuel or electricity.",
            'question': "do cars run on water",
            'expected': "no"
        },
        {
            'passage': "Insects have six legs and three body segments.",
            'question': "do insects have four legs",
            'expected': "no"
        },
        {
            'passage': "The Great Wall of China is located in China.",
            'question': "is great wall in india",
            'expected': "no"
        },
        {
            'passage': "Humans have 206 bones in their adult body.",
            'question': "do humans have 106 bones",
            'expected': "no"
        }
    ],

    'complex': [
        {
            'passage': """Quantum entanglement is a physical phenomenon that occurs when pairs or groups of particles are generated, interact, or share spatial proximity in ways such that the quantum state of each particle cannot be described independently of the state of the others, even when the particles are separated by a large distance.""",
            'question': "can quantum states be described independently in entanglement",
            'expected': "no"
        },
        {
            'passage': """DNA replication is the biological process of producing two identical replicas of DNA from one original DNA molecule. This process occurs in all living organisms and is the basis for biological inheritance. The process starts when DNA double helix is unwound by helicase.""",
            'question': "is dna replication a single step process",
            'expected': "no"
        },
        {
            'passage': """Black holes are regions of spacetime where gravity is so strong that nothing – no particles or even electromagnetic radiation such as light – can escape from it. The theory of general relativity predicts that a sufficiently compact mass can deform spacetime to form a black hole.""",
            'question': "can light escape black holes",
            'expected': "no"
        },
        {
            'passage': """Neural networks are computing systems inspired by the biological neural networks that constitute animal brains. Such systems learn to perform tasks by considering examples, generally without being programmed with task-specific rules. They automatically generate identifying characteristics from the examples they process.""",
            'question': "do neural networks need explicit programming",
            'expected': "no"
        },
        {
            'passage': """The endocrine system is a chemical messenger system comprising feedback loops of hormones released by internal glands of an organism directly into the circulatory system, regulating distant target organs. In humans, the major endocrine glands are the thyroid gland and the adrenal glands.""",
            'question': "does endocrine system use nerves for signaling",
            'expected': "no"
        },
        {
            'passage': """Superconductivity is a set of physical properties observed in certain materials where electrical resistance vanishes and magnetic flux fields are expelled from the material. Any material exhibiting these properties is a superconductor. Unlike an ordinary metallic conductor, whose resistance decreases gradually as its temperature is lowered even down to near absolute zero, a superconductor has a characteristic critical temperature below which the resistance drops abruptly to zero.""",
            'question': "do superconductors have resistance at critical temperature",
            'expected': "no"
        },
        {
            'passage': """The citric acid cycle, also known as the Krebs cycle or the tricarboxylic acid cycle, is a series of chemical reactions in aerobic cellular respiration that generate energy through the oxidation of acetyl-CoA derived from carbohydrates, fats, and proteins.""",
            'question': "is krebs cycle related to energy production",
            'expected': "yes"
        },
        {
            'passage': """CRISPR gene editing is a genetic engineering technique in molecular biology by which the genomes of living organisms may be modified. It is based on a simplified version of the bacterial CRISPR-Cas9 antiviral defense system. By delivering the Cas9 nuclease complexed with a synthetic guide RNA into a cell, the cell's genome can be cut at a desired location.""",
            'question': "is crispr a natural bacterial process",
            'expected': "yes"
        },
        {
            'passage': """The Standard Model of particle physics is a theory describing three of the four known fundamental forces in the universe (electromagnetic, weak, and strong interactions), as well as classifying all known elementary particles. It was developed in stages throughout the latter half of the 20th century.""",
            'question': "does standard model explain all fundamental forces",
            'expected': "no"
        },
        {
            'passage': """Blockchain technology is a structure that stores transactional records, also known as the block, of the public in several databases, known as the 'chain,' in a network connected through peer-to-peer nodes. This storage is referred to as a 'digital ledger.' Every transaction in this ledger is authorized by the digital signature of the owner.""",
            'question': "is blockchain centrally controlled",
            'expected': "no"
        }
    ]
}


def evaluate_categories(test_cases, model, tokenizer):
    results = {}

    for category, cases in test_cases.items():
        correct = 0
        total = len(cases)

        print(f"\nEvaluating {category}...")
        for case in tqdm(cases):
            prediction = evaluate_example(
                case['passage'],
                case['question'],
                model,
                tokenizer
            )
            if prediction.lower() == case['expected'].lower():
                correct += 1

        accuracy = correct / total
        results[category] = {
            'correct': correct,
            'total': total,
            'accuracy': accuracy
        }

    return results

print("Starting categorical evaluation...")
category_results = evaluate_categories(test_cases, trained_model, tokenizer)

print("\nResults by Category:")
print("=" * 50)
for category, metrics in category_results.items():
    print(f"\n{category.replace('_', ' ').title()}:")
    print(f"Accuracy: {metrics['accuracy']:.2%}")
    print(f"Correct: {metrics['correct']}/{metrics['total']}")

total_correct = sum(r['correct'] for r in category_results.values())
total_cases = sum(r['total'] for r in category_results.values())
overall_accuracy = total_correct / total_cases

print("\nOverall Statistics:")
print("=" * 50)
print(f"Total Accuracy: {overall_accuracy:.2%}")
print(f"Total Correct: {total_correct}/{total_cases}")

# Additional analyses
print("\nDetailed Analysis:")
print("=" * 50)
print(f"Yes/No Distribution Accuracy:")
yes_accuracy = category_results['simple_yes']['accuracy']
no_accuracy = category_results['simple_no']['accuracy']
print(f"Yes Questions: {yes_accuracy:.2%}")
print(f"No Questions: {no_accuracy:.2%}")

print(f"\nContext Length Impact:")
short_accuracy = category_results['short_context']['accuracy']
long_accuracy = category_results['long_context']['accuracy']
print(f"Short Context: {short_accuracy:.2%}")
print(f"Long Context: {long_accuracy:.2%}")

print(f"\nComplexity Impact:")
complex_accuracy = category_results['complex']['accuracy']
simple_avg = (yes_accuracy + no_accuracy) / 2
print(f"Simple Questions: {simple_avg:.2%}")
print(f"Complex Questions: {complex_accuracy:.2%}")

Starting categorical evaluation...

Evaluating short_context...


  0%|          | 0/10 [00:00<?, ?it/s]


Evaluating long_context...


  0%|          | 0/10 [00:00<?, ?it/s]


Evaluating simple_yes...


  0%|          | 0/10 [00:00<?, ?it/s]


Evaluating simple_no...


  0%|          | 0/10 [00:00<?, ?it/s]


Evaluating complex...


  0%|          | 0/10 [00:00<?, ?it/s]


Results by Category:

Short Context:
Accuracy: 80.00%
Correct: 8/10

Long Context:
Accuracy: 80.00%
Correct: 8/10

Simple Yes:
Accuracy: 90.00%
Correct: 9/10

Simple No:
Accuracy: 90.00%
Correct: 9/10

Complex:
Accuracy: 60.00%
Correct: 6/10

Overall Statistics:
Total Accuracy: 80.00%
Total Correct: 40/50

Detailed Analysis:
Yes/No Distribution Accuracy:
Yes Questions: 90.00%
No Questions: 90.00%

Context Length Impact:
Short Context: 80.00%
Long Context: 80.00%

Complexity Impact:
Simple Questions: 90.00%
Complex Questions: 60.00%


In [None]:
for i, result in enumerate(results, 1):
    print(f"\nTest Case {i}:")
    print("-" * 50)
    print(f"Question: {result['question']}")
    print(f"Context: {result['passage'][:100]}...")  # Show first 100 chars of passage
    print(f"Predicted Answer: {result['predicted_answer']}")
    print("-" * 50)