In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchtext.datasets import SST
from torchtext.data import Field, BucketIterator
from torchtext.vocab import GloVe
TEXT = Field(lower=True,fix_length=200,batch_first=True)
LABEL = Field(sequential=False)
train,valid,test = SST.splits(TEXT,LABEL)
TEXT.build_vocab(train,vectors=GloVe(name='6B',dim=100),max_size=20000,min_freq=10)
LABEL.build_vocab(train)
train_iter, valid_iter, test_iter = BucketIterator.splits((train,valid,test),batch_size=16)


In [6]:
# Define the model architecture
class TextClassifier(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, output_dim):
        super(TextClassifier, self).__init__()
        self.embedding = nn.Embedding.from_pretrained(TEXT.vocab.vectors)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, num_layers=2, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        embedded = self.embedding(x)
        _, hidden = self.rnn(embedded)
        hidden = hidden[-1]  # Take the last hidden state from all layers
        output = self.fc(hidden)
        return output

In [7]:
# Initialize the model
embedding_dim = 100
hidden_dim = 128
output_dim = len(LABEL.vocab)
model = TextClassifier(embedding_dim, hidden_dim, output_dim)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())


In [8]:
# Training loop
def train_model(model, iterator, optimizer, criterion):
    model.train()
    total_loss = 0.0
    total_correct = 0
    
    for batch in iterator:
        optimizer.zero_grad()
        text, label = batch.text, batch.label
        predictions = model(text)
        loss = criterion(predictions, label)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        total_correct += (predictions.argmax(1) == label).sum().item()
    
    return total_loss / len(iterator), total_correct / len(iterator.dataset)


In [9]:
# Evaluation loop
def evaluate_model(model, iterator, criterion):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    
    with torch.no_grad():
        for batch in iterator:
            text, label = batch.text, batch.label
            predictions = model(text)
            loss = criterion(predictions, label)
            
            total_loss += loss.item()
            total_correct += (predictions.argmax(1) == label).sum().item()
    
    return total_loss / len(iterator), total_correct / len(iterator.dataset)

In [10]:
# Train the model
num_epochs = 10

for epoch in range(num_epochs):
    train_loss, train_acc = train_model(model, train_iter, optimizer, criterion)
    valid_loss, valid_acc = evaluate_model(model, valid_iter, criterion)
    
    print(f'Epoch: {epoch+1}/{num_epochs}')
    print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}')
    print(f'Valid Loss: {valid_loss:.4f} | Valid Acc: {valid_acc:.4f}')
    print()

# Test the model
test_loss, test_acc = evaluate_model(model, test_iter, criterion)
print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}')

Epoch: 1/10
Train Loss: 1.0621 | Train Acc: 0.4082
Valid Loss: 1.0677 | Valid Acc: 0.4033

Epoch: 2/10
Train Loss: 1.0519 | Train Acc: 0.4147
Valid Loss: 1.0625 | Valid Acc: 0.3887

Epoch: 3/10
Train Loss: 1.0507 | Train Acc: 0.4120
Valid Loss: 1.0605 | Valid Acc: 0.4033

Epoch: 4/10
Train Loss: 1.0497 | Train Acc: 0.4163
Valid Loss: 1.0609 | Valid Acc: 0.4033

Epoch: 5/10
Train Loss: 1.0494 | Train Acc: 0.4109
Valid Loss: 1.0634 | Valid Acc: 0.4033

Epoch: 6/10
Train Loss: 1.0596 | Train Acc: 0.4141
Valid Loss: 1.0702 | Valid Acc: 0.3887

Epoch: 7/10
Train Loss: 1.0604 | Train Acc: 0.4123
Valid Loss: 1.1208 | Valid Acc: 0.2707

Epoch: 8/10
Train Loss: 1.0581 | Train Acc: 0.4101
Valid Loss: 1.0764 | Valid Acc: 0.4005

Epoch: 9/10
Train Loss: 1.0565 | Train Acc: 0.4059
Valid Loss: 1.0766 | Valid Acc: 0.3933

Epoch: 10/10
Train Loss: 1.0579 | Train Acc: 0.4108
Valid Loss: 1.0655 | Valid Acc: 0.4042

Test Loss: 1.0493 | Test Acc: 0.4208
