In [1]:
from datasets import load_dataset
from torch.utils.data import random_split, DataLoader
from transformers import DistilBertTokenizer
import torch

# Load the dataset
ds = load_dataset("stanfordnlp/imdb")
ds.pop('unsupervised')

train_ds = ds["train"]
test_ds = ds["test"]

# Split the train dataset into train and validation sets
train_size = int(0.8 * len(train_ds))
val_size = len(train_ds) - train_size
train_ds, val_ds = random_split(train_ds, [train_size, val_size])

# Initialize the tokenizer
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Tokenize function
def tokenize_function(example):
    return tokenizer(
        example["text"],
        padding="max_length",
        truncation=True,
        max_length=256  # Adjust max_length for LSTM input size (e.g., 128, 256, etc., based on GPU memory)
    )

# Tokenize datasets
train_ds = train_ds.dataset.map(tokenize_function, batched=True)
val_ds = val_ds.dataset.map(tokenize_function, batched=True)
test_ds = test_ds.map(tokenize_function, batched=True)

# Keep only input_ids and attention_mask in the dataset for LSTM
def format_for_lstm(batch):
    return {
        "input_ids": torch.tensor(batch["input_ids"]),
        "attention_mask": torch.tensor(batch["attention_mask"]),
        "label": torch.tensor(batch["label"])
    }

train_ds = train_ds.map(format_for_lstm, batched=True)
val_ds = val_ds.map(format_for_lstm, batched=True)
test_ds = test_ds.map(format_for_lstm, batched=True)

# Set the format for PyTorch compatibility
train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
val_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
test_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# Create DataLoaders
batch_size = 1024
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=6)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=6)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=6)



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"

class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, num_layers=2, dropout=0.2):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(input_size=embed_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_dim, output_dim)
            
    def forward(self, x, attention_mask=None):
        embedded = self.embedding(x)
        if attention_mask is not None:
            packed_embedded = nn.utils.rnn.pack_padded_sequence(
                embedded, attention_mask.sum(1).cpu(), batch_first=True, enforce_sorted=False
            )
            packed_output, (hidden, _) = self.lstm(packed_embedded)
            output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        else:
            output, (hidden, _) = self.lstm(embedded)

        hidden = hidden[-1] 
        out = self.fc(hidden)
        return out

vocab_size = tokenizer.vocab_size  # from DistilBERT tokenizer
embed_dim = 256  # size of the word embeddings
hidden_dim = 512  # hidden dimension for the LSTM
output_dim = 1  # binary classification (positive/negative sentiment)

model = LSTMModel(vocab_size, embed_dim, hidden_dim, output_dim).to(device)
print(model)        

LSTMModel(
  (embedding): Embedding(30522, 256)
  (lstm): LSTM(256, 512, num_layers=2, batch_first=True, dropout=0.2)
  (fc): Linear(in_features=512, out_features=1, bias=True)
)


In [3]:
import torch.optim as optim
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
model = LSTMModel(vocab_size, embed_dim, hidden_dim, output_dim).to(device)
model.train()

# Define the loss function and optimizer
loss_fn = torch.nn.BCEWithLogitsLoss()  # Use for binary classification
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.1)
scaler = torch.cuda.amp.GradScaler()

num_epochs = 50
max_grad_norm = 0.1  # Set your max_grad_norm value

for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    
    # Wrap train_loader in tqdm to monitor training progress
    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch [{epoch + 1}/{num_epochs}] - Training", leave=False)
    
    for batch in train_loader_tqdm:
        inputs, attention_mask, labels = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['label'].float().to(device)
        
        optimizer.zero_grad()

        # Forward pass with mixed precision
        with torch.cuda.amp.autocast():
            outputs = model(inputs, attention_mask)
            loss = loss_fn(outputs.squeeze(), labels)

        # Backward pass
        scaler.scale(loss).backward()

        # Clip gradients
        scaler.unscale_(optimizer)  # Unscale before gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Clip gradients

        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        train_loader_tqdm.set_postfix(loss=loss.item())  # Display the current batch loss
        
    # Calculate average training loss for the epoch
    avg_loss = running_loss / len(train_loader)

    # Validation phase
    model.eval()
    running_val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        val_loader_tqdm = tqdm(val_loader, desc=f"Epoch [{epoch + 1}/{num_epochs}] - Validation", leave=False)
        
        for batch in val_loader_tqdm:
            inputs, attention_mask, labels = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['label'].float().to(device)

            with torch.cuda.amp.autocast():
                outputs = model(inputs, attention_mask)
                loss = loss_fn(outputs.squeeze(), labels)
                
            running_val_loss += loss.item()

            # Calculate accuracy
            predicted = (outputs.squeeze() > 0).long()
            total += labels.size(0)
            correct += (predicted == labels.long()).sum().item()
        
            val_loader_tqdm.set_postfix(loss=loss.item())  # Display the current validation batch loss
    
    # Calculate average validation loss and accuracy
    avg_val_loss = running_val_loss / len(val_loader)
    accuracy = correct / total * 100

    print(f"Epoch [{epoch + 1}/{num_epochs}], "
          f"Training Loss: {avg_loss:.4f}, "
          f"Validation Loss: {avg_val_loss:.4f}, "
          f"Validation Accuracy: {accuracy:.2f}%")

print("Training complete.")

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
                                                                                

Epoch [1/50], Training Loss: 0.6865, Validation Loss: 0.6579, Validation Accuracy: 61.03%


                                                                                

Epoch [2/50], Training Loss: 0.6385, Validation Loss: 0.6189, Validation Accuracy: 67.11%


                                                                                

Epoch [3/50], Training Loss: 0.5548, Validation Loss: 0.5251, Validation Accuracy: 74.61%


                                                                                

Epoch [4/50], Training Loss: 0.5213, Validation Loss: 0.5151, Validation Accuracy: 74.94%


                                                                                

Epoch [5/50], Training Loss: 0.4842, Validation Loss: 0.4767, Validation Accuracy: 77.54%


                                                                                

Epoch [6/50], Training Loss: 0.4498, Validation Loss: 0.4391, Validation Accuracy: 80.31%


                                                                                

Epoch [7/50], Training Loss: 0.4197, Validation Loss: 0.4021, Validation Accuracy: 82.66%


                                                                                

Epoch [8/50], Training Loss: 0.4066, Validation Loss: 0.3924, Validation Accuracy: 83.35%


                                                                                

Epoch [9/50], Training Loss: 0.3889, Validation Loss: 0.3720, Validation Accuracy: 84.71%


                                                                                

Epoch [10/50], Training Loss: 0.3774, Validation Loss: 0.3272, Validation Accuracy: 86.65%


                                                                                

Epoch [11/50], Training Loss: 0.4068, Validation Loss: 0.3193, Validation Accuracy: 86.90%


                                                                                

Epoch [12/50], Training Loss: 0.3634, Validation Loss: 0.3130, Validation Accuracy: 87.09%


                                                                                

Epoch [13/50], Training Loss: 0.3548, Validation Loss: 0.2993, Validation Accuracy: 88.13%


                                                                                

Epoch [14/50], Training Loss: 0.3080, Validation Loss: 0.2964, Validation Accuracy: 88.36%


                                                                                

Epoch [15/50], Training Loss: 0.3268, Validation Loss: 0.3424, Validation Accuracy: 84.81%


                                                                                

Epoch [16/50], Training Loss: 0.3134, Validation Loss: 0.2607, Validation Accuracy: 89.81%


                                                                                

Epoch [17/50], Training Loss: 0.2953, Validation Loss: 0.2766, Validation Accuracy: 89.07%


                                                                                

Epoch [18/50], Training Loss: 0.2888, Validation Loss: 0.2979, Validation Accuracy: 87.59%


                                                                                

Epoch [19/50], Training Loss: 0.2833, Validation Loss: 0.2558, Validation Accuracy: 90.13%


                                                                                

Epoch [20/50], Training Loss: 0.2733, Validation Loss: 0.2606, Validation Accuracy: 89.37%


                                                                                

Epoch [21/50], Training Loss: 0.2726, Validation Loss: 0.2701, Validation Accuracy: 88.56%


                                                                                

Epoch [22/50], Training Loss: 0.2739, Validation Loss: 0.2783, Validation Accuracy: 88.35%


                                                                                

Epoch [23/50], Training Loss: 0.2820, Validation Loss: 0.2727, Validation Accuracy: 89.12%


                                                                                

Epoch [24/50], Training Loss: 0.2554, Validation Loss: 0.2750, Validation Accuracy: 89.00%


                                                                                

Epoch [25/50], Training Loss: 0.2478, Validation Loss: 0.2406, Validation Accuracy: 90.90%


                                                                                

Epoch [26/50], Training Loss: 0.2486, Validation Loss: 0.2103, Validation Accuracy: 91.68%


                                                                                

Epoch [27/50], Training Loss: 0.2403, Validation Loss: 0.2852, Validation Accuracy: 88.04%


                                                                                

Epoch [28/50], Training Loss: 0.2310, Validation Loss: 0.1889, Validation Accuracy: 92.94%


                                                                                

Epoch [29/50], Training Loss: 0.2251, Validation Loss: 0.1949, Validation Accuracy: 93.01%


                                                                                

Epoch [30/50], Training Loss: 0.2148, Validation Loss: 0.2540, Validation Accuracy: 89.83%


                                                                                

Epoch [31/50], Training Loss: 0.2165, Validation Loss: 0.1830, Validation Accuracy: 93.49%


                                                                                

Epoch [32/50], Training Loss: 0.2072, Validation Loss: 0.2010, Validation Accuracy: 92.07%


                                                                                

Epoch [33/50], Training Loss: 0.2024, Validation Loss: 0.2283, Validation Accuracy: 90.39%


                                                                                

Epoch [34/50], Training Loss: 0.1981, Validation Loss: 0.1657, Validation Accuracy: 93.64%


                                                                                

Epoch [35/50], Training Loss: 0.1943, Validation Loss: 0.1772, Validation Accuracy: 93.82%


                                                                                

Epoch [36/50], Training Loss: 0.1946, Validation Loss: 0.2538, Validation Accuracy: 89.66%


                                                                                

Epoch [37/50], Training Loss: 0.1873, Validation Loss: 0.1550, Validation Accuracy: 94.88%


                                                                                

Epoch [38/50], Training Loss: 0.1786, Validation Loss: 0.1621, Validation Accuracy: 93.83%


                                                                                

Epoch [39/50], Training Loss: 0.1791, Validation Loss: 0.1498, Validation Accuracy: 94.42%


                                                                                

Epoch [40/50], Training Loss: 0.1695, Validation Loss: 0.2334, Validation Accuracy: 90.21%


                                                                                

Epoch [41/50], Training Loss: 0.1725, Validation Loss: 0.1321, Validation Accuracy: 95.26%


                                                                                

Epoch [42/50], Training Loss: 0.1616, Validation Loss: 0.1626, Validation Accuracy: 94.28%


                                                                                

Epoch [43/50], Training Loss: 0.1623, Validation Loss: 0.1737, Validation Accuracy: 93.24%


                                                                                

Epoch [44/50], Training Loss: 0.1610, Validation Loss: 0.1203, Validation Accuracy: 96.12%


                                                                                

Epoch [45/50], Training Loss: 0.1602, Validation Loss: 0.1073, Validation Accuracy: 96.56%


                                                                                

Epoch [46/50], Training Loss: 0.1498, Validation Loss: 0.2247, Validation Accuracy: 90.27%


                                                                                

Epoch [47/50], Training Loss: 0.1485, Validation Loss: 0.0900, Validation Accuracy: 97.33%


                                                                                

Epoch [48/50], Training Loss: 0.1329, Validation Loss: 0.1151, Validation Accuracy: 96.18%


                                                                                

Epoch [49/50], Training Loss: 0.1310, Validation Loss: 0.1698, Validation Accuracy: 93.64%


                                                                                

Epoch [50/50], Training Loss: 0.1287, Validation Loss: 0.0901, Validation Accuracy: 97.30%
Training complete.




In [4]:
model.eval()  # Set the model to evaluation mode
running_test_loss = 0.0
correct = 0
total = 0
    
# Define the loss function for evaluation
loss_fn = torch.nn.BCEWithLogitsLoss()

with torch.no_grad():
    test_loader_tqdm = tqdm(test_loader, desc="Testing", leave=False)
        
    for batch in test_loader_tqdm:
        inputs, attention_mask, labels = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['label'].float().to(device)

        # Forward pass with mixed precision
        with torch.cuda.amp.autocast():
            outputs = model(inputs, attention_mask)
            loss = loss_fn(outputs.squeeze(), labels)
                
        running_test_loss += loss.item()

        # Calculate accuracy
        predicted = (outputs.squeeze() > 0).long()
        total += labels.size(0)
        correct += (predicted == labels.long()).sum().item()

        test_loader_tqdm.set_postfix(loss=loss.item())  # Display the current test batch loss
    
# Calculate average test loss and accuracy
avg_test_loss = running_test_loss / len(test_loader)
accuracy = correct / total * 100

print(f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {accuracy:.2f}%")

  with torch.cuda.amp.autocast():
                                                                                

Test Loss: 0.5061, Test Accuracy: 84.04%


