In [1]:
# !pip install scikit-learn 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
import numpy as np 
from sklearn.metrics import accuracy_score
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Run configuration.
SEED = 42                     # fixes the train/val/test split across sessions
BATCH_SIZE = 4
IMAGE_SIZE = (224, 224)       # CustomNet's default fc1 size assumes 224x224 input
DATA_DIR = './classification/'

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])
dataset = ImageFolder(root=DATA_DIR, transform=transform)

# Fractions: index 0 -> train, index 1 -> test; validation takes the remainder
# (so it also absorbs the rounding left over by int()).
split_size = [0.8, 0.1, 0.1]
train_size = int(split_size[0] * len(dataset))
test_size = int(split_size[1] * len(dataset))
val_size = len(dataset) - train_size - test_size

# A seeded generator makes random_split deterministic. Without it every kernel
# session draws a new split, so a checkpoint trained earlier gets evaluated on
# test images it may have trained on. Checkpoints produced with an unseeded (or
# differently seeded) split must be retrained for test metrics to be meaningful.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

# Only training benefits from shuffling; evaluation metrics do not depend on order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

ValueError: Could not find the operator torchvision::nms. Please make sure you have already registered the operator and (if registered from C++) loaded it via torch.ops.load_library.

In [None]:
class CustomNet(nn.Module):
    """Small CNN classifier: two (conv -> ReLU -> 2x2 max-pool) stages, then two linear layers.

    Args:
        num_classes: Number of output logits produced by ``fc2``.
        input_size: Spatial size of the input images, either an int (square
            images) or an ``(height, width)`` pair. The 3x3/padding-1 convolutions
            keep the spatial size and each 2x2/stride-2 max-pool halves it (with
            flooring), so ``fc1`` receives ``128 * (h // 4) * (w // 4)`` features.
            Defaults to 224, which reproduces the original ``128 * 56 * 56``.
    """

    def __init__(self, num_classes, input_size=224):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size
        # Two max-pools, each flooring by 2: (s // 2) // 2 == s // 4.
        flat_features = 128 * (height // 4) * (width // 4)
        self.fc1 = nn.Linear(flat_features, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map an image batch ``x`` of shape (N, 3, H, W) to logits of shape (N, num_classes).

        H and W must match the ``input_size`` given at construction, otherwise the
        flattened features will not fit ``fc1``.
        """
        x = self.maxpool(self.relu(self.conv1(x)))
        x = self.maxpool(self.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
def train_step(model, dataloader, device, optimizer, criterion):
    """Train ``model`` for one pass over ``dataloader`` and report the mean batch loss.

    The model is switched to training mode. Each ``(inputs, labels)`` batch is
    moved to ``device`` and scored with ``criterion(model(inputs), labels)``. Stale
    gradients are cleared, the loss is backpropagated and ``optimizer`` takes a step.

    Returns:
        Sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    model.train()
    batch_losses = []
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        loss = criterion(model(inputs), labels)
        batch_losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return sum(batch_losses) / len(dataloader)

In [None]:
def eval_step(model, dataloader, device, criterion):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Args:
        model: Classifier; its output is assumed to be (batch, num_classes)
            logits, since the predicted class is the argmax over dim 1.
        dataloader: Iterable of ``(inputs, labels)`` batches that supports ``len()``.
        device: Device each batch is moved to before the forward pass.
        criterion: Loss function called as ``criterion(outputs, labels)``.

    Returns:
        ``(epoch_loss, targets, predictions)``. ``epoch_loss`` is the sum of
        per-batch losses divided by the number of batches, or ``nan`` if
        ``dataloader`` is empty. ``targets`` and ``predictions`` are flat lists
        of true and predicted class indices (NumPy scalars).
    """
    model.eval()
    total_loss = 0.0
    targets = []
    predictions = []
    # Autograd is disabled for the whole loop, so no gradients are produced and
    # there is nothing to zero.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()
            # softmax is monotonic, so argmax over the raw logits picks the same
            # class without the extra normalisation pass.
            predicted = torch.argmax(outputs, dim=1)
            targets.extend(labels.cpu().numpy())
            predictions.extend(predicted.cpu().numpy())

    num_batches = len(dataloader)
    # An empty loader has no defined mean loss; nan also never wins a "< best" comparison.
    epoch_loss = total_loss / num_batches if num_batches else float("nan")
    return epoch_loss, targets, predictions

In [None]:
# Seed weight initialisation (and training-loader shuffling) so a retrain is reproducible.
torch.manual_seed(0)

# The output head is sized from the class folders ImageFolder found, so it always
# matches the data. Any checkpoint loaded into this model must use the same class count.
model = CustomNet(num_classes=len(dataset.classes)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Set TRAIN = True to (re)train. The weights with the best validation loss are
# saved to MODEL_PATH, the same file the evaluation cell loads.
TRAIN = False
EPOCHS = 50
MODEL_PATH = "classification1.pth"

best_loss = np.inf
best_epoch = 0

if TRAIN:
    for epoch in range(EPOCHS):
        train_loss = train_step(model, train_loader, device, optimizer, criterion)
        val_loss, _, _ = eval_step(model, val_loader, device, criterion)
        print(f"\nEpoch: {epoch+1} | Training loss: {train_loss} | Validation Loss: {val_loss}")
        if val_loss < best_loss:
            torch.save(model.state_dict(), MODEL_PATH)
            best_loss = val_loss
            best_epoch = epoch + 1

In [None]:
# weights_only=True limits unpickling to tensors and plain containers, so loading
# a tampered checkpoint file cannot execute arbitrary code.
loaded_state_dict = torch.load("classification1.pth", map_location=device, weights_only=True)
model.load_state_dict(loaded_state_dict)

test_loss, targets, predictions = eval_step(model, test_loader, device, criterion)
accuracy = accuracy_score(targets, predictions)
# NOTE(review): a perfect test accuracy is suspicious. If the train/val/test split
# was not seeded when this checkpoint was trained, test images may have been part
# of its training set. Verify before trusting this number.
print(f"Test loss: {test_loss:.4f}")
print(f"Accuracy: {accuracy}")

Accuracy: 1.0
