In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
from torchvision import transforms

from PIL import Image, UnidentifiedImageError

## Remove Corrupted Downloads

In [2]:
def remove_corrupt_images(directory):
    """Delete every file under ``directory`` that Pillow cannot open and verify.

    The tree is walked recursively with ``os.walk``. Each file is opened with
    ``Image.open`` and checked with ``Image.verify``; any file for which either
    step raises is deleted from disk and its path is printed.

    Args:
        directory: Root folder to scan. ``os.walk`` yields nothing for a path
            that does not exist, so a missing folder is silently a no-op.

    Returns:
        None. Removed files are reported on stdout only.
    """
    bad_files = []
    for subdir, _, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(subdir, file)
            try:
                # The context manager releases the file handle even when
                # verify() raises, so handles do not pile up during the scan
                # and an open handle cannot get in the way of os.remove below.
                with Image.open(file_path) as img:
                    img.verify()
            except Exception:
                # Deliberately broad: any failure to open/verify marks the file
                # for deletion (this includes files that are not images at all).
                bad_files.append(file_path)

    # Delete only after the walk has finished so the directory listing is not
    # modified while it is being iterated.
    for file in bad_files:
        os.remove(file)
        print(f"Removed corrupted image: {file}")

# Scrub each dataset split (same order as before: train, test, val) so that
# unreadable files are gone before any dataset is built from these folders.
for split_dir in ('./train', './test', './val'):
    remove_corrupt_images(split_dir)

## Prepare Dataset

In [4]:
# Root folders of the three dataset splits.
train_data_path = './train'
val_data_path = './val'
test_data_path = './test'

# Preprocessing applied to every image in every split:
#   1. resize to a fixed 64x64 grid,
#   2. convert the PIL image to a tensor,
#   3. normalise each of the three channels with the given mean / std
#      (these are the commonly used ImageNet channel statistics).
preprocessing_steps = [
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(preprocessing_steps)

# One ImageFolder per split, all sharing the same transform.
train_data = torchvision.datasets.ImageFolder(root=train_data_path, transform=transform)
val_data = torchvision.datasets.ImageFolder(root=val_data_path, transform=transform)
test_data = torchvision.datasets.ImageFolder(root=test_data_path, transform=transform)

In [5]:
batch_size = 64
num_workers = 4

# Options shared by all three loaders; only the shuffle flag differs per split.
shared_loader_options = dict(batch_size=batch_size, num_workers=num_workers)

# Training batches are shuffled; validation and test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **shared_loader_options)
val_loader = torch.utils.data.DataLoader(val_data, shuffle=False, **shared_loader_options)
test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, **shared_loader_options)

## Fully Connected (FC) Model

In [6]:
class SimpleNet(nn.Module):
    """Three-layer fully connected classifier over flattened inputs.

    Architecture: ``in_features -> 84 -> 50 -> num_classes`` with ReLU after
    the first two layers. The last layer's output is returned as-is (no
    softmax), i.e. raw per-class scores.

    Args:
        num_classes: Size of the output layer.
        in_features: Number of values per sample after flattening. Defaults to
            ``3*64*64``, the size that was previously hard-coded.
    """

    def __init__(self, num_classes=2, in_features=3*64*64):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 84)
        self.fc2 = nn.Linear(84, 50)
        self.fc3 = nn.Linear(50, num_classes)

    def forward(self, x):
        """Flatten ``x`` into rows of ``in_features`` values and classify them.

        Args:
            x: Tensor whose total element count is a multiple of
                ``in_features``.

        Returns:
            Tensor of shape ``(rows, num_classes)`` with raw class scores.
        """
        # reshape (not view) so non-contiguous inputs, e.g. a permuted
        # channels-last batch, are accepted instead of raising; the row width
        # is read from fc1 so it can never drift from the layer definition.
        x = x.reshape(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [7]:
# Pick the best available backend: CUDA GPU first, then Apple-silicon MPS,
# otherwise CPU. The getattr guard keeps this working on PyTorch builds that
# do not ship the `torch.backends.mps` module.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
device

device(type='mps')

In [8]:
# Loss applied to the raw class scores that SimpleNet.forward returns.
criterion = nn.CrossEntropyLoss()

# Build the network with its default settings, then move it to the selected device.
simplenet = SimpleNet()
simplenet = simplenet.to(device)

# Adam over all of the model's parameters.
optimizer = optim.Adam(params=simplenet.parameters(), lr=1e-3)

## Training and Validation Loop

In [None]:
epochs = 20


def train(
    model=simplenet, optimzer=optimizer, loss_fn=criterion,
    train_loader=train_loader, val_loader=val_loader,
    epochs=epochs, device=device):
    """Train ``model`` and report validation metrics after every epoch.

    Each epoch runs one optimisation pass over ``train_loader`` followed by
    one gradient-free pass over ``val_loader``, then prints the mean per-batch
    training loss, the mean per-batch validation loss and validation accuracy.

    Args:
        model: Module to train; expected to already live on ``device``.
        optimzer: Optimizer that updates ``model``'s parameters. The name is
            misspelled; it is kept as-is so existing keyword callers keep
            working.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar loss.
        train_loader: Iterable of ``(inputs, target)`` training batches.
        val_loader: Iterable of ``(inputs, target)`` validation batches.
        epochs: Number of full passes over ``train_loader``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        None. Metrics are printed, one line per epoch.
    """
    # Use the optimizer that was actually passed in. Binding it locally stops
    # the loop from silently falling back to the notebook-global `optimizer`.
    optimizer = optimzer

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        for inputs, target in train_loader:
            inputs = inputs.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            output = model(inputs)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
        # max(..., 1) avoids a ZeroDivisionError for an empty loader.
        train_loss /= max(len(train_loader), 1)

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        num_correct = 0
        num_examples = 0
        # No gradients are needed here; skipping them avoids building an
        # autograd graph for every validation batch.
        with torch.no_grad():
            for inputs, target in val_loader:
                inputs = inputs.to(device)
                target = target.to(device)
                output = model(inputs)

                val_loss += loss_fn(output, target).item()
                # Softmax is monotonic, so the arg-max of the raw scores is
                # already the predicted class.
                predicted = output.argmax(dim=1)
                num_correct += (predicted == target).sum().item()
                num_examples += target.shape[0]
        val_loss /= max(len(val_loader), 1)
        accuracy = num_correct / num_examples if num_examples else float('nan')

        print(f"Epoch: {epoch+1}, Train Loss: {train_loss:.2f}, Val Loss: {val_loss:.2f}, Accuracy: {accuracy:.2f}")


train()

## Test Model

In [19]:
# Class names by output index.
# NOTE(review): assumes this order matches the ImageFolder class indices
# (see `test_data.class_to_idx`) — confirm.
labels = ['cat', 'fish']

# ImageFolder's default PIL loader converts images to RGB before applying the
# transform; do the same here so grayscale or RGBA files do not break the
# 3-channel Normalize step. The context manager closes the file afterwards.
with Image.open('./test/fish/36488930_5d269647b5.jpg') as image_file:
    img = transform(image_file.convert('RGB'))
img = img.unsqueeze(0)  # add the batch dimension the model expects

# Inference only: switch to eval mode and skip gradient tracking.
simplenet.eval()
with torch.no_grad():
    pred = simplenet(img.to(device))
print(pred)
pred = pred.argmax(dim=1)  # arg-max over the class dimension
print(f"Predicted: {labels[pred.item()]}")

tensor([[-6.1190,  6.2526]], device='mps:0', grad_fn=<LinearBackward0>)
Predicted: fish


In [None]:
# Print predicted vs. actual class for every test sample.
# Inference only: eval mode, and no autograd graph is built per batch.
simplenet.eval()
with torch.no_grad():
    for img, target in test_loader:
        pred = simplenet(img.to(device)).argmax(dim=1)
        # Convert to plain Python ints once per batch instead of indexing
        # `labels` with one device tensor element at a time.
        for predicted_idx, actual_idx in zip(pred.tolist(), target.tolist()):
            print(f"Predicted: {labels[predicted_idx]}, Actual: {labels[actual_idx]}")

Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: fish, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: fish, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: fish, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: fish, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: fish, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: fish, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat, Actual: cat
Predicted: cat