<a href="https://colab.research.google.com/github/sameekshya1999/GRU/blob/main/transformer_one_nextword.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:


from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from datasets import load_dataset
import torch
import textwrap
import warnings
warnings.filterwarnings('ignore')

def pretty_print(s):
    """Print the decoded text for the token ids ``s`` under an "Output:" header.

    The ids are decoded with the module-level ``tokenizer`` (special tokens
    skipped) and the resulting text is wrapped to 80 columns.
    """
    rule = '-' * 80
    print("Output:\n" + rule)
    decoded = tokenizer.decode(s, skip_special_tokens=True)
    wrapped = textwrap.fill(decoded, 80)
    print(wrapped)

# Checkpoint name shared by the tokenizer and the language model
model_to_use = "gpt2-large"
print("Using model: ", model_to_use)

tokenizer = GPT2TokenizerFast.from_pretrained(model_to_use)

# The tokenizer's EOS id is handed to the model as its pad token id
model = GPT2LMHeadModel.from_pretrained(
    model_to_use,
    pad_token_id=tokenizer.eos_token_id,
    output_scores=True,
)

# Show the resolved configuration of the loaded model
print(model.config)

# Import necessary libraries
import torch
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from datasets import load_dataset  # Ensure `datasets` library is installed

# Training split of the raw WikiText-2 corpus (swap in another dataset here)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

# Fall back to the EOS token for padding when the tokenizer has no pad token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Join every row into one long string and encode it in a single call
texts = dataset["text"]
corpus = "\n\n".join(texts)
# NOTE(review): truncation=True is given without max_length, so the encoding
# is presumably capped at the tokenizer's model maximum length and only the
# head of the corpus survives -- confirm this is intended.
tokenized_texts = tokenizer(
    corpus,
    return_tensors="pt",
    truncation=True,
    padding=True,
)

# Build (context, next token) pairs with a window that slides one token at a time
seq_length = 50  # Number of context tokens per sample
input_ids = tokenized_texts["input_ids"][0]

# Every position that has a full seq_length-token context in front of it
target_positions = range(seq_length, len(input_ids))

# X[k] holds the seq_length tokens preceding a target position, y[k] the token at it
X = [input_ids[pos - seq_length:pos] for pos in target_positions]
y = [input_ids[pos] for pos in target_positions]

# 70 / 15 / 15 split: hold out 30% first, then halve the held-out part
# NOTE(review): the windows in X are built with a stride of one token, so
# neighbouring samples overlap heavily; train_test_split shuffles by default,
# which presumably places near-duplicate windows in different splits --
# confirm whether this leakage is acceptable.
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.3, random_state=42
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42
)


def _to_padded_batch(sequences):
    """Stack token sequences into one batch-first tensor, padded with the pad token id."""
    as_tensors = [torch.tensor(seq) for seq in sequences]
    return pad_sequence(as_tensors, batch_first=True, padding_value=tokenizer.pad_token_id)


# Inputs become padded 2-D tensors, targets become 1-D tensors
X_train, X_val, X_test = (_to_padded_batch(part) for part in (X_train, X_val, X_test))
y_train, y_val, y_test = (torch.tensor(part) for part in (y_train, y_val, y_test))

# Report how many samples landed in each split
print(f"Training set: {len(X_train)} samples")
print(f"Validation set: {len(X_val)} samples")
print(f"Test set: {len(X_test)} samples")

# Define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Define loss function
# NOTE(review): loss_fn is never called below; every loop uses the loss the
# model returns for `labels=batch_X`. The name is kept so it stays defined.
loss_fn = torch.nn.CrossEntropyLoss()

# Training loop
epochs = 3
batch_size = 32
# One entry per epoch: mean loss and last-position accuracy for train / validation
train_losses, val_losses, train_accs, val_accs = [], [], [], []

for epoch in range(epochs):
    model.train()
    epoch_loss, correct, total = 0, 0, 0

    # Training step
    for i in range(0, len(X_train), batch_size):
        batch_X = X_train[i:i+batch_size]
        batch_y = y_train[i:i+batch_size]

        # Forward pass. The loss is the model's own language-modelling loss
        # over batch_X; batch_y only enters the accuracy computed below.
        outputs = model(batch_X, labels=batch_X)
        loss = outputs.loss
        logits = outputs.logits[:, -1, :]  # Logits of the last token
        predictions = logits.argmax(dim=-1)

        # Accuracy calculation
        correct += (predictions == batch_y).sum().item()
        total += batch_y.size(0)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # outputs.loss is already averaged over the batch, so weight it by the
        # batch size; dividing the sum by `total` then yields a true mean
        # (a short final batch is weighted correctly as well).
        epoch_loss += loss.item() * batch_y.size(0)

    train_losses.append(epoch_loss / total)
    train_accs.append(correct / total)

    # Validation step (no gradient tracking, no parameter updates)
    model.eval()
    val_loss, val_correct, val_total = 0, 0, 0
    with torch.no_grad():
        for i in range(0, len(X_val), batch_size):
            batch_X = X_val[i:i+batch_size]
            batch_y = y_val[i:i+batch_size]

            # Forward pass
            outputs = model(batch_X, labels=batch_X)
            # Same size-weighted accumulation as in the training step
            val_loss += outputs.loss.item() * batch_y.size(0)
            logits = outputs.logits[:, -1, :]
            predictions = logits.argmax(dim=-1)

            val_correct += (predictions == batch_y).sum().item()
            val_total += batch_y.size(0)

    val_losses.append(val_loss / val_total)
    val_accs.append(val_correct / val_total)

    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_losses[-1]:.4f}, "
          f"Val Loss: {val_losses[-1]:.4f}, Train Acc: {train_accs[-1]:.4f}, "
          f"Val Acc: {val_accs[-1]:.4f}")

# Test evaluation after training
model.eval()
test_loss, test_correct, test_total = 0, 0, 0
with torch.no_grad():
    for i in range(0, len(X_test), batch_size):
        batch_X = X_test[i:i+batch_size]
        batch_y = y_test[i:i+batch_size]

        # Forward pass
        outputs = model(batch_X, labels=batch_X)
        # Size-weighted so the division by test_total below is a true mean
        test_loss += outputs.loss.item() * batch_y.size(0)
        logits = outputs.logits[:, -1, :]
        predictions = logits.argmax(dim=-1)

        test_correct += (predictions == batch_y).sum().item()
        test_total += batch_y.size(0)

# Calculate final test metrics
test_loss /= test_total
test_accuracy = test_correct / test_total

print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

import matplotlib.pyplot as plt

# Final test metrics, drawn as horizontal reference lines in the figures below
test_loss_value = test_loss  # Test loss from evaluation
test_accuracy_value = test_accuracy  # Test accuracy from evaluation

# One entry per figure:
# (title, y-axis label, (train curve, label), (validation curve, label), (test level, label))
figure_specs = (
    (
        "Loss vs. Epochs",
        "Loss",
        (train_losses, "Train Loss"),
        (val_losses, "Validation Loss"),
        (test_loss_value, "Test Loss"),
    ),
    (
        "Accuracy vs. Epochs",
        "Accuracy",
        (train_accs, "Train Accuracy"),
        (val_accs, "Validation Accuracy"),
        (test_accuracy_value, "Test Accuracy"),
    ),
)

for heading, y_label, train_series, val_series, test_line in figure_specs:
    train_curve, train_label = train_series
    val_curve, val_label = val_series
    test_level, test_label = test_line

    plt.figure(figsize=(10, 5))
    plt.plot(train_curve, label=train_label, marker='o')
    plt.plot(val_curve, label=val_label, marker='o')
    # Dashed red horizontal line at the test metric
    plt.axhline(y=test_level, color='r', linestyle='--', label=test_label)
    plt.title(heading)
    plt.xlabel("Epochs")
    plt.ylabel(y_label)
    plt.legend()
    plt.grid(True)
    plt.show()
