In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torchvision.models import resnet50
from torchsummary import summary
from sklearn.model_selection import train_test_split

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = CIFAR10(root='./data', train=False, transform=transform, download=True)

# Split the dataset into train and validation sets
train_indices, val_indices = train_test_split(list(range(len(train_dataset))), test_size=0.2, random_state=42)

train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

train_loader = DataLoader(dataset=train_dataset, batch_size=128, sampler=train_sampler, num_workers=2)
val_loader = DataLoader(dataset=train_dataset, batch_size=128, sampler=val_sampler, num_workers=2)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=2)

# Load pre-trained ResNet-50 without top (classification) layer
resnet = resnet50(pretrained=True, progress=True)
resnet = nn.Sequential(*list(resnet.children())[:-1])  # Remove the last fully connected layer

# Freeze the parameters in the ResNet-50 model
for param in resnet.parameters():
    param.requires_grad = False

# Extract features from the images using ResNet-50
def extract_features(loader):
    features = []
    labels = []

    resnet.eval()
    with torch.no_grad():
        for images, batch_labels in loader:
            images = images.to(device)
            batch_labels = batch_labels.to(device)

            # Obtain feature representations
            features.append(resnet(images).squeeze())
            labels.append(batch_labels)

    features = torch.cat(features, dim=0)
    labels = torch.cat(labels, dim=0)
    return features, labels

# Extract features for training, validation, and testing
train_features, train_labels = extract_features(train_loader)
val_features, val_labels = extract_features(val_loader)
test_features, test_labels = extract_features(test_loader)

# Reshape the features for LSTM input
train_features = train_features.view(train_features.size(0), 1, -1)
val_features = val_features.view(val_features.size(0), 1, -1)
test_features = test_features.view(test_features.size(0), 1, -1)

# Define RNN model architecture
class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNNModel, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        _, (h_n, _) = self.rnn(x)
        out = self.fc(h_n[-1])
        return out

input_size = train_features.size(2)
hidden_size = 64
output_size = 10

model = RNNModel(input_size, hidden_size, output_size).to(device)
summary(model, input_size=(1, 1, input_size))

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
best_val_accuracy = 0.0

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_correct = 0

    for features, labels in train_loader:
        features = features.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(features)
        loss = criterion(outputs.squeeze(), labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_correct += (predicted == labels).sum().item()

    train_loss /= len(train_loader.dataset)

    # Validation
    model.eval()
    val_loss = 0.0
    val_correct = 0

    with torch.no_grad():
        for features, labels in val_loader:
            features = features.to(device)
            labels = labels.to(device)

            outputs = model(features)
            loss = criterion(outputs.squeeze(), labels)

            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == labels).sum().item()

    val_loss /= len(val_loader.dataset)
    val_accuracy = val_correct / len(val_loader.dataset)

    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'Train Loss: {train_loss:.4f}, Train Accuracy: {100.0 * train_correct / len(train_loader.dataset):.2f}%, '
          f'Val Loss: {val_loss:.4f}, Val Accuracy: {100.0 * val_accuracy:.2f}%')

    # Save the model if it has the best validation accuracy
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model.pt')

# Test the model
model.load_state_dict(torch.load('best_model.pt'))
model.eval()
test_correct = 0

with torch.no_grad():
    for features, labels in test_loader:
        features = features.to(device)
        labels = labels.to(device)

        outputs = model(features)
        _, predicted = torch.max(outputs, 1)
        test_correct += (predicted == labels).sum().item()

test_accuracy = test_correct / len(test_loader.dataset)
print(f'Test Accuracy: {100.0 * test_accuracy:.2f}%')


Files already downloaded and verified
Files already downloaded and verified




AssertionError: LSTM: Expected input to be 2-D or 3-D but received 4-D tensor