In [None]:
#imports
import torch
import torchvision
from torchvision import datasets, models, transforms
from path import *
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import os

# Per-channel ImageNet statistics expected by the pretrained ResNet backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

dataTransforms = {
    # Heavy random augmentation on the training images so the model sees a
    # different variant of each picture every epoch (reduces overfitting).
    'train': transforms.Compose([
        transforms.RandomRotation(90),
        transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),

    # Deterministic preprocessing for validation: resize, centre-crop to 224.
    'val': transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
}

# Training data: one sub-folder per class under <path>/train.
# os.path.join instead of a hard-coded "\\" so the notebook also runs off Windows.
trainset = datasets.ImageFolder(root=os.path.join(path, "train"), transform=dataTransforms["train"])
trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True)

# Validation data under <path>/val. Evaluation metrics do not depend on
# sample order, so keep it unshuffled for reproducible runs.
testset = datasets.ImageFolder(root=os.path.join(path, "val"), transform=dataTransforms["val"])
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False)
 





In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Pull one augmented training batch (image tensors + class indices) to preview below.
images, iclasses = next(iter(trainloader))

#view some sample images
def imshow(im, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show a normalised image tensor with matplotlib.

    Undoes ``transforms.Normalize(mean, std)`` and clips to [0, 1] so the
    picture is drawn with its original colours.

    Args:
        im: 3-D image tensor laid out as (channels, height, width), e.g. the
            output of ``torchvision.utils.make_grid``. It may live on any
            device and may require grad.
        mean: per-channel mean that was subtracted during normalisation
            (ImageNet statistics by default).
        std: per-channel std that was divided out during normalisation
            (ImageNet statistics by default).
    """
    # matplotlib expects an HWC numpy array; detach/cpu so GPU tensors
    # (e.g. a batch already moved to `device`) can be displayed too.
    img = im.detach().cpu().numpy().transpose((1, 2, 0))
    img = np.asarray(std) * img + np.asarray(mean)
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    
# Tile the sampled batch into one image grid and display it un-normalised.
imshow(torchvision.utils.make_grid(images))

In [None]:
import torch.nn as nn
import torch.nn.functional as F

#define the network
class Network(nn.Module):
    """Fully connected classification head for the ResNet-50 backbone.

    Maps an ``in_features``-wide feature vector through three ReLU hidden
    layers (512 -> 128 -> 32), each followed by dropout, to ``num_classes``
    log-probabilities (to be paired with ``nn.NLLLoss``).

    Args:
        in_features: width of the incoming feature vector (2048 for ResNet-50).
        num_classes: number of output classes.
        p: dropout probability applied after every hidden layer.
    """

    def __init__(self, in_features=2048, num_classes=3, p=0.15):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 32)
        self.fc4 = nn.Linear(32, num_classes)
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        """Return log-probabilities of shape (batch, num_classes)."""
        # Flatten everything but the batch dimension. Unlike `.view`, this
        # also accepts non-contiguous inputs.
        x = x.flatten(1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.dropout(F.relu(self.fc3(x)))
        return F.log_softmax(self.fc4(x), dim=1)

# Instantiate the classification head; it is attached as `model.fc` below.
net = Network()

In [None]:
# ImageNet-pretrained ResNet-50 used as a frozen feature extractor.
model = models.resnet50(pretrained=True)
model.requires_grad_(False)  # freeze every backbone parameter in one call
# Swap in our head; it is attached after the freeze, so its params stay trainable.
model.fc = net
model = model.to(device)  # Module.to moves in place and returns the same module
model

In [None]:
import torch.optim as optim
# Negative log-likelihood, matching the log_softmax output of the head.
criterion = nn.NLLLoss()
# Adam (momentum plus per-parameter adaptive step sizes); only the new head is optimised.
optimizer = optim.Adam(params=model.fc.parameters(), lr=1e-3)

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`; return the mean per-sample loss."""
    model.train()
    running_loss = 0.0
    n_samples = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean, so re-weight by batch size
        running_loss += loss.item() * images.size(0)
        n_samples += images.size(0)
    return running_loss / max(n_samples, 1)


@torch.no_grad()
def _evaluate(model, loader, criterion, device):
    """Return (accuracy, mean per-sample loss) of `model` over `loader`."""
    model.eval()
    running_loss = 0.0
    n_correct = 0
    n_samples = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        output = model(images)
        running_loss += criterion(output, labels).item() * images.size(0)
        # output holds log-probabilities; argmax is the predicted class
        n_correct += (output.argmax(dim=1) == labels).sum().item()
        n_samples += images.size(0)
    if n_samples == 0:
        return 0.0, 0.0
    # divide by samples (not batches) so a short final batch is weighted correctly
    return n_correct / n_samples, running_loss / n_samples


epochs = 5
for epoch in range(epochs):
    tLoss = _train_one_epoch(model, trainloader, criterion, optimizer, device)
    acc, eLoss = _evaluate(model, testloader, criterion, device)
    print('Epoch: {} Train Loss: {:.2f} Accuracy: {:.2f} Test Loss: {:.2f}'.format(epoch, tLoss, acc, eLoss))
    
    