In [2]:
import torch 
import torch.nn as nn  
import torch.optim as optim 
import torch.utils.data 
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision 
from torchvision import transforms
from PIL import Image, ImageFile 
import cv2

In [5]:
def check_image(path):
    try:
        im = Image.open(path)
        return True 
    except:
        return False

In [3]:
img_transforms = transforms.Compose([
    transforms.Resize((64,64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])
])

In [8]:
train_data_path="./train/"
train_data = torchvision.datasets.ImageFolder(root=train_data_path, transform=img_transforms, is_valid_file=check_image)

In [9]:
val_data_path = "./val/"
val_data = torchvision.datasets.ImageFolder(root=val_data_path, transform=img_transforms, is_valid_file=check_image)

In [10]:
test_data_path="./test/"
test_data=torchvision.datasets.ImageFolder(root=test_data_path,transform=img_transforms, is_valid_file=check_image)

In [11]:
batch_size = 64
train_data_loader=torch.utils.data.DataLoader(train_data, batch_size=batch_size)
val_data_loader=torch.utils.data.DataLoader(val_data, batch_size=batch_size)
test_data_loader=torch.utils.data.DataLoader(test_data, batch_size=batch_size)


In [12]:
class SimpleNet(nn.Module):

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(12288, 84)
        self.fc2 = nn.Linear(84, 50)
        self.fc3 = nn.Linear(50,2)
    
    def forward(self, x):
        x = x.view(-1, 12288)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [14]:
simplenet = SimpleNet()
optimizer = optim.Adam(simplenet.parameters(), lr=0.001)
if torch.cuda.is_available():
    device = torch.device("cuda") 
else:
    device = torch.device("cpu")

simplenet.to(device)


SimpleNet(
  (fc1): Linear(in_features=12288, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=50, bias=True)
  (fc3): Linear(in_features=50, out_features=2, bias=True)
)

In [15]:
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
    for epoch in range(epochs):
        training_loss = 0.0
        valid_loss = 0.0 
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item() * inputs.size(0)
        training_loss /= len(train_loader.dataset)

        model.eval()
        num_correct = 0
        num_examples = 0
        for batch in val_loader:
            inputs, targets = batch
            inputs  = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            valid_loss += loss.data.item()  * inputs.size(0) 
            print('inputs size 0: ' , inputs.size(0))
            correct = torch.eq(torch.max(F.softmax(output), dim=1)[1], targets).view(-1)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
        valid_loss /= len(val_loader.dataset)

        print('Epoch: {}, Training loss: {:.2f}, Val loss: {:2f}, acc: {:.2f}'.format(epoch, training_loss, valid_loss, num_correct/ num_examples))

In [17]:
train(simplenet, optimizer,torch.nn.CrossEntropyLoss(), train_data_loader,val_data_loader, epochs=5, device=device)


inputs size 0:64
inputs size 0:24
Epoch: 0, Training loss: 0.27, Val loss: 0.421630, acc: 0.78
inputs size 0:64
inputs size 0:24
Epoch: 1, Training loss: 0.22, Val loss: 0.404057, acc: 0.80
inputs size 0:64
inputs size 0:24
Epoch: 2, Training loss: 0.18, Val loss: 0.399196, acc: 0.80
inputs size 0:64
inputs size 0:24
Epoch: 3, Training loss: 0.15, Val loss: 0.410392, acc: 0.80
inputs size 0:64
inputs size 0:24
Epoch: 4, Training loss: 0.12, Val loss: 0.411403, acc: 0.82


In [21]:
labels = ['cat', 'fish']
img = Image.open("./val/cat/99029168_940da3a1e5.jpg")
img = img_transforms(img).to(device)

pred = F.softmax(simplenet(img))
pred = pred.argmax()
print(labels[pred])

fish
