In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random


In [None]:
import os
import sys
from google.colab import drive

# Attach Google Drive at /content/drive. force_remount=True makes re-running this
# cell remount Drive instead of keeping an existing mount.
drive.mount(
    '/content/drive',
    force_remount=True,
)



Mounted at /content/drive


In [None]:
# Project folder on Drive. It is listed for a quick sanity check and appended to
# sys.path so that Python modules stored in it can be imported.
path = '/content/drive/My Drive/ECE661/Project/rotnet'

print(sorted(os.listdir(path)))
# Guard the append so re-running this cell does not keep adding duplicate entries.
if path not in sys.path:
    sys.path.append(path)

[]


In [None]:
class RotNetDataset(Dataset):
    """Self-supervised rotation-prediction wrapper around an image dataset.

    Each item of the wrapped ``dataset`` is expected to be an ``(img, target)``
    pair. The original target is discarded. The image is rotated
    counter-clockwise by one of ``rotation_angles``, and the returned label is
    the index of that angle (0 -> 0°, 1 -> 90°, 2 -> 180°, 3 -> 270°).

    Args:
        dataset: indexable dataset yielding ``(img, target)`` pairs.
        seed: if ``None`` (default), a fresh rotation is drawn from the global
            ``random`` module on every access, so the same index can receive a
            different label each epoch. If set, the rotation for a given index
            is fixed (derived from ``seed`` and the index). This makes evaluation
            sets reproducible regardless of global RNG state or DataLoader workers.
    """

    def __init__(self, dataset, seed=None):
        self.dataset = dataset
        self.rotation_angles = [0, 90, 180, 270]
        self.seed = seed

    def __len__(self):
        return len(self.dataset)

    def _pick_label(self, idx):
        """Return the rotation label (an index into ``rotation_angles``) for ``idx``."""
        n_angles = len(self.rotation_angles)
        if self.seed is None:
            # Draw the angle index directly (same distribution as random.choice).
            return random.randrange(n_angles)
        # random.Random hashes str seeds deterministically (not via hash()), so
        # the label for (seed, idx) is stable across processes and runs.
        return random.Random(f"{self.seed}:{idx}").randrange(n_angles)

    def __getitem__(self, idx):
        img, _ = self.dataset[idx]
        label = self._pick_label(idx)
        angle = self.rotation_angles[label]
        quarter_turns, remainder = divmod(angle, 90)
        if (torch.is_tensor(img) and img.ndim >= 2
                and img.shape[-1] == img.shape[-2] and remainder == 0):
            # A multiple of 90° on a square tensor is an exact pixel permutation.
            # torch.rot90 over the last two dims rotates counter-clockwise, like
            # transforms.functional.rotate, but needs no resampling grid.
            rotated_img = torch.rot90(img, quarter_turns, dims=(-2, -1))
        else:
            rotated_img = transforms.functional.rotate(img, angle)
        return rotated_img, label


In [None]:
class RotNet(nn.Module):
    """Randomly initialised ResNet-18 with a small classification head for rotation prediction.

    torchvision's ResNet-18 is used unchanged except for its final ``fc`` layer.
    That layer is replaced by a ``num_classes``-way linear head (default 4, one
    per rotation). The whole network, including that head, is ``self.encoder``.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        # weights=None -> random init, no pretrained weights. This replaces the
        # deprecated pretrained=False keyword.
        self.encoder = torchvision.models.resnet18(weights=None)
        # Size the head from the backbone instead of hard-coding 512.
        self.encoder.fc = nn.Linear(self.encoder.fc.in_features, num_classes)

    def forward(self, x):
        """Return unnormalised rotation logits, one row per image in the batch ``x``."""
        return self.encoder(x)


In [None]:
# Reproducibility: seed every RNG the pipeline touches. Python's `random` picks
# RotNetDataset's rotations; torch drives weight init, shuffling and worker seeds.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# CIFAR-10 images are 32x32. Upsample to 224 to match the resolution ResNet-18 is
# usually trained at (note: this costs roughly (224/32)^2 = 49x more compute per image).
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])

# Training split; the class labels are replaced by rotation labels in RotNetDataset.
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(RotNetDataset(train_data), batch_size=128, shuffle=True)

# Model, loss, optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RotNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)




In [None]:
def train(model, dataloader, optimizer, criterion, num_epochs, save_path="checkpoints/",
          save_every=20, device=None):
    """Train ``model``, printing per-epoch loss/accuracy and saving periodic checkpoints.

    Args:
        model: network mapping an input batch to per-class scores. The argmax
            over dim 1 is taken as the prediction.
        dataloader: iterable of ``(inputs, labels)`` batches that supports ``len()``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(outputs, labels)``.
        num_epochs: number of passes over ``dataloader``.
        save_path: checkpoint directory, created if missing. Files are written to
            ``os.path.join(save_path, "model_epoch_<N>.pt")``, whether or not
            ``save_path`` ends with a separator.
        save_every: save ``model.state_dict()`` every this many epochs. A falsy
            value disables checkpointing.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter.
    """
    if device is None:
        device = next(model.parameters()).device
    if save_every:
        os.makedirs(save_path, exist_ok=True)

    for epoch in range(1, num_epochs + 1):
        model.train()
        epoch_loss = 0.0
        correct, total = 0, 0

        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # max(..., 1) prevents ZeroDivisionError on an empty dataloader.
        avg_loss = epoch_loss / max(len(dataloader), 1)
        accuracy = 100 * correct / max(total, 1)
        print(f"Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f} Accuracy: {accuracy:.2f}%")

        if save_every and epoch % save_every == 0:
            # os.path.join supplies the separator. Plain string concatenation wrote
            # the file beside the directory when save_path lacked a trailing "/".
            ckpt_file = os.path.join(save_path, f"model_epoch_{epoch}.pt")
            torch.save(model.state_dict(), ckpt_file)
            print(f"✅ Saved checkpoint at epoch {epoch}")


In [None]:
# os.path.join(path, "") guarantees a trailing separator, so checkpoints are
# written inside the project folder rather than next to it.
train(model, train_loader, optimizer, criterion, num_epochs=200, save_path=os.path.join(path, ""))

Epoch 1/200, Loss: 0.9098 Accuracy: 62.29%
Epoch 2/200, Loss: 0.8253 Accuracy: 65.74%
Epoch 3/200, Loss: 0.7485 Accuracy: 69.33%
Epoch 4/200, Loss: 0.6783 Accuracy: 72.93%
Epoch 5/200, Loss: 0.6123 Accuracy: 75.89%
Epoch 6/200, Loss: 0.5676 Accuracy: 77.94%
Epoch 7/200, Loss: 0.5199 Accuracy: 79.97%
Epoch 8/200, Loss: 0.4808 Accuracy: 81.33%
Epoch 9/200, Loss: 0.4457 Accuracy: 82.97%
Epoch 10/200, Loss: 0.4140 Accuracy: 84.27%
Epoch 11/200, Loss: 0.3897 Accuracy: 85.39%
Epoch 12/200, Loss: 0.3664 Accuracy: 86.26%
Epoch 13/200, Loss: 0.3388 Accuracy: 87.28%
Epoch 14/200, Loss: 0.3189 Accuracy: 88.17%
Epoch 15/200, Loss: 0.2981 Accuracy: 88.96%
Epoch 16/200, Loss: 0.2771 Accuracy: 89.82%
Epoch 17/200, Loss: 0.2640 Accuracy: 90.30%
Epoch 18/200, Loss: 0.2421 Accuracy: 91.20%
Epoch 19/200, Loss: 0.2248 Accuracy: 91.83%
Epoch 20/200, Loss: 0.2094 Accuracy: 92.41%
✅ Saved checkpoint at epoch 20
Epoch 21/200, Loss: 0.1934 Accuracy: 93.08%
Epoch 22/200, Loss: 0.1780 Accuracy: 93.70%
Epoch 23/2

In [None]:
# Save the trained network's weights into the project folder on Drive. A bare
# relative path would land in the Colab VM's working directory, which is lost
# when the session ends.
# Note: model.encoder still includes the 4-way rotation `fc` head. Downstream
# feature extraction needs to drop or replace that layer after loading.
encoder_file = os.path.join(path, 'rotnet_encoder.pth')
torch.save(model.encoder.state_dict(), encoder_file)
print(f"Saved encoder weights to {encoder_file}")


In [None]:
# Held-out CIFAR-10 split, preprocessed exactly like the training data and wrapped
# so each image is paired with a rotation label instead of its class.
test_data = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(
    RotNetDataset(test_data),
    batch_size=128,
    shuffle=False,
)


In [None]:
def evaluate_rotnet(model, dataloader, device, angles=(0, 90, 180, 270)):
    """Print overall and per-rotation accuracy of ``model`` on ``dataloader``.

    Args:
        model: rotation classifier. Its argmax over dim 1 is the predicted label.
        dataloader: iterable of ``(imgs, labels)`` batches, where label ``i``
            corresponds to ``angles[i]``.
        device: device the batches are moved to (should match the model's).
        angles: display names for the label indices, in label order.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    num_classes = len(angles)
    total, correct = 0, 0
    # Per-class tallies as CPU tensors, updated with bincount instead of a
    # Python loop over every sample.
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            predicted = model(imgs).argmax(dim=1)
            hits = predicted == labels

            total += labels.size(0)
            correct += hits.sum().item()
            class_total += torch.bincount(labels, minlength=num_classes).cpu()
            class_correct += torch.bincount(labels[hits], minlength=num_classes).cpu()

    if total == 0:
        raise ValueError("evaluate_rotnet: dataloader yielded no samples")

    overall_acc = 100 * correct / total
    print(f"\nTest Accuracy: {overall_acc:.2f}%")

    for i, angle in enumerate(angles):
        if class_total[i] > 0:
            acc = 100 * class_correct[i].item() / class_total[i].item()
            print(f"Class {angle}° Accuracy: {acc:.2f}%")


In [None]:
# RotNetDataset picks each test image's rotation when the item is accessed, so fix
# the RNGs first to make the reported accuracy reproducible across re-runs.
EVAL_SEED = 0
random.seed(EVAL_SEED)
torch.manual_seed(EVAL_SEED)  # DataLoader derives worker seeds from torch's RNG

print("Evaluating RotNet on test set...")
evaluate_rotnet(model, test_loader, device)


Evaluating RotNet on test set...

Test Accuracy: 80.82%
Class 0° Accuracy: 80.54%
Class 90° Accuracy: 78.62%
Class 180° Accuracy: 83.09%
Class 270° Accuracy: 81.04%
