Data Loading and Model Creation

In [32]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets, models

# Standard ImageNet per-channel statistics, the same normalization used for
# torchvision's ImageNet-pretrained weights. Shared by both pipelines below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop/rescale to 224x224 plus random horizontal
# flips as augmentation, then tensor conversion and normalization.
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Validation pipeline: deterministic. Resize (int size -> smaller edge to 256),
# center-crop 224x224, then the same tensor conversion and normalization.
val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Seed PyTorch's global RNG so DataLoader shuffling, the random augmentations
# and the classifier head initialized in the next step are repeatable across
# Restart & Run All. (Some GPU kernels may still be nondeterministic.)
SEED = 42
torch.manual_seed(SEED)

# Load the datasets. ImageFolder derives the class -> index mapping from the
# sub-directory names of each root independently.
train_dataset = datasets.ImageFolder('./data/train', transform=train_transforms)
val_dataset = datasets.ImageFolder('./data/val', transform=val_transforms)

# If ./data/val is missing a class folder (or has an extra one), every label
# index after it shifts and validation metrics become meaningless without any
# error. Fail loudly instead.
assert val_dataset.classes == train_dataset.classes, (
    'class folders differ between ./data/train and ./data/val: '
    f'{sorted(set(train_dataset.classes) ^ set(val_dataset.classes))}'
)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

# One output logit per class folder found under ./data/train.
num_classes = len(train_dataset.classes)

# ImageNet-pretrained ResNet-18 with its final fully-connected layer replaced
# by a freshly initialized head sized to our classes.
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.to(device)


Training the model

In [33]:
# Multi-class objective: cross-entropy on the raw logits produced by model.fc.
criterion = nn.CrossEntropyLoss()

# Adam over every parameter (pretrained backbone + new head), i.e. full fine-tuning.
# NOTE(review): 1e-3 is on the high side for fine-tuning pretrained weights;
# consider a smaller lr if training loss plateaus.
learning_rate = 1e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)


In [37]:
from PIL import Image

num_epochs = 10
# Where the best-validation-accuracy weights are written; the inference step
# loads 'pokemon_model.pth', so keep these in sync.
MODEL_PATH = 'pokemon_model.pth'


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``; return the mean per-sample loss."""
    model.train()
    running_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Assumes criterion uses the default 'mean' reduction: weight by batch
        # size so a short final batch doesn't skew the epoch average.
        running_loss += loss.item() * inputs.size(0)
    return running_loss / len(loader.dataset)


def _evaluate(model, loader, criterion, device):
    """Return ``(mean per-sample loss, accuracy)`` on ``loader`` without updating weights."""
    model.eval()
    running_loss = 0.0
    correct = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
    n_samples = len(loader.dataset)
    return running_loss / n_samples, correct / n_samples


best_val_acc = float('-inf')
for epoch in range(num_epochs):
    epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

    # Validation phase
    val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
    print(f'Validation Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}')

    # Persist the weights whenever validation accuracy improves. This lets the
    # inference step, which loads 'pokemon_model.pth', work on a fresh kernel,
    # and it gets the best epoch rather than whatever the last one was.
    # Note: the in-memory `model` still ends with the last epoch's weights.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), MODEL_PATH)


Epoch 1/10, Loss: 0.5425
Validation Loss: 0.1345, Accuracy: 0.9585
Epoch 2/10, Loss: 0.5493
Validation Loss: 0.1299, Accuracy: 0.9598
Epoch 3/10, Loss: 0.4999
Validation Loss: 0.1072, Accuracy: 0.9669
Epoch 4/10, Loss: 0.4987
Validation Loss: 0.1393, Accuracy: 0.9549
Epoch 5/10, Loss: 0.5295
Validation Loss: 0.1361, Accuracy: 0.9604
Epoch 6/10, Loss: 0.4806
Validation Loss: 0.0873, Accuracy: 0.9751
Epoch 7/10, Loss: 0.4588
Validation Loss: 0.1206, Accuracy: 0.9617
Epoch 8/10, Loss: 0.4700
Validation Loss: 0.0760, Accuracy: 0.9770
Epoch 9/10, Loss: 0.4686
Validation Loss: 0.0913, Accuracy: 0.9728
Epoch 10/10, Loss: 0.4792
Validation Loss: 0.0936, Accuracy: 0.9731


Inference

In [None]:
# Restore trained weights into the ResNet built earlier. Assumes
# 'pokemon_model.pth' holds a state_dict saved from this same architecture and
# class count. map_location remaps tensors onto the current device, so a
# checkpoint saved from a GPU run still loads on a CPU-only machine.
# torch.load unpickles its input: only load checkpoints you produced yourself.
model.load_state_dict(torch.load('pokemon_model.pth', map_location=device))
# Inference mode for BatchNorm; the trailing ';' suppresses the long module repr.
model.eval();

In [52]:
from PIL import Image

def predict(image_path, model, transform, class_names, device=None):
    """Classify a single image file and return its predicted class name.

    Args:
        image_path: Path to the image file to classify.
        model: Trained classifier. Switched to eval mode by this function.
        transform: Preprocessing applied to the RGB PIL image; expected to
            return a single image tensor (a batch dimension is added here).
        class_names: Sequence mapping output index -> class name
            (e.g. ``train_dataset.classes``).
        device: Device to run inference on. Defaults to the device of the
            model's first parameter (CPU if the model has none).

    Returns:
        The entry of ``class_names`` whose logit is highest.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # PNGs are frequently RGBA or palette ("P") images. The 3-channel Normalize
    # in the transform fails on those, so force RGB. The context manager
    # closes the underlying file handle.
    with Image.open(image_path) as img:
        image = transform(img.convert('RGB')).unsqueeze(0).to(device)

    # Don't depend on an earlier cell having called model.eval(). In train mode,
    # BatchNorm would normalize with this single image's statistics.
    model.eval()
    with torch.no_grad():
        outputs = model(image)
    predicted_index = outputs.argmax(dim=1)[0].item()
    return class_names[predicted_index]

# Deterministic preprocessing for inference; mirrors the validation pipeline.
# Resize the smaller edge to 256, center-crop 224x224, then apply ImageNet
# normalization.
infer_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Output index -> class name, exactly as ImageFolder assigned on the training split.
class_names = train_dataset.classes

image_path = './mew.png'
predicted_class = predict(image_path, model, infer_transform, class_names)
print(f'Predicted Class: {predicted_class}')


FileNotFoundError: [Errno 2] No such file or directory: './mew'

Integrate with React Native: Export the Model to ONNX

In [None]:
import torch.onnx

# Convert PyTorch model to ONNX format
# Export the trained network to ONNX for on-device use (e.g. an ONNX runtime
# inside the React Native app). The graph contains only the network, so:
#   - the app must reproduce the preprocessing: resize 256, center-crop 224,
#     ImageNet mean/std normalization;
#   - outputs are raw logits, and the index -> name mapping
#     (train_dataset.classes) has to be shipped alongside the model.
model.eval()  # BatchNorm in inference mode; don't rely on an earlier cell

# Example input fixing the traced shape: (batch, 3, 224, 224).
dummy_input = torch.randn(1, 3, 224, 224, device=device)

# Explicit tensor names give the mobile runtime stable keys to feed and read.
# The dynamic batch axis allows batch sizes other than 1. verbose=False keeps
# the full graph dump out of the notebook output.
torch.onnx.export(
    model,
    dummy_input,
    "pokemon_model.onnx",
    input_names=["input"],
    output_names=["logits"],
    dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
    verbose=False,
)
