# Import Libraries

In [2]:
import torch, torchvision
from torchvision.transforms import ToTensor, Resize
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import DataLoader
import torch.nn as nn
from torch import optim
from tqdm import tqdm
import copy
import numpy as np
from torch.utils.data import Dataset
import requests
from io import BytesIO

# Constants

In [3]:
TRAIN_PATH = "../data/train"
VAL_PATH = "../data/val"
NUM_BATCH = 32  # batch size used by both DataLoaders
EPOCHS = 5
LEARNING_RATE = 1e-3
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42

# Seed torch's global RNG so torch-driven randomness (e.g. initialization of newly
# created layers, DataLoader shuffling) is reproducible across Restart & Run All.
torch.manual_seed(SEED)

# Transforms

In [7]:
# Preprocessing applied to every image (HxWxC uint8 array -> CxHxW float tensor).
# ToTensor scales pixel values to [0, 1]; Normalize then applies the ImageNet
# per-channel mean/std that torchvision's pretrained ResNet weights expect.
transform = transforms.Compose([
    ToTensor(),
    Resize((500, 500)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Dataset & DataLoader

#### Dataset Class

In [8]:
class CatDogDataset(Dataset):
    """Binary cat/dog image dataset read from a single flat directory.

    The label is taken from the filename prefix before the first '.':
    'cat' -> 0, anything else -> 1 (e.g. 'cat.12.jpg' -> 0, 'dog.7.jpg' -> 1).

    Args:
        train_dir: directory containing the image files (also used for validation data).
        transform: optional callable applied to the HxWx3 uint8 numpy image.
    """

    def __init__(self, train_dir, transform=None):
        self.train_dir = train_dir
        self.transform = transform
        # sorted(): os.listdir order is arbitrary, so sort for a deterministic
        # index -> file mapping. Skip subdirectories and hidden files (e.g. .DS_Store),
        # which are not images and would otherwise be mislabeled as class 1.
        self.images = [
            name
            for name in sorted(os.listdir(train_dir))
            if not name.startswith(".") and os.path.isfile(os.path.join(train_dir, name))
        ]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return (image, label) for the file at position `index`."""
        filename = self.images[index]
        image_path = os.path.join(self.train_dir, filename)
        label = 0 if filename.split(".")[0] == "cat" else 1

        # `with` closes the file handle; convert("RGB") guarantees 3 channels even
        # for grayscale, palette or RGBA files, as the ResNet input layer requires.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))

        if self.transform is not None:
            image = self.transform(image)

        return image, label

#### Prepare data and dataloader

In [9]:
train_data = CatDogDataset(TRAIN_PATH, transform)
val_data = CatDogDataset(VAL_PATH, transform)

# Shuffle training batches: sample order comes from the directory listing, which
# commonly groups files by name (all "cat.*" before "dog.*"), so unshuffled batches
# may contain a single class and make the gradient updates badly biased.
train_dl = DataLoader(train_data, batch_size=NUM_BATCH, shuffle=True)
# Validation order does not affect accuracy, so keep it deterministic.
val_dl = DataLoader(val_data, batch_size=NUM_BATCH)

# Initialize ResNet18 

#### Import Pretrained model

In [12]:
# Load ResNet-18 with ImageNet-pretrained weights and print its layer structure.
# NOTE: re-running this cell creates a fresh model and discards the weight
# freezing and fc-head replacement done in the cells below - re-run those too.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in favour of
# `weights=...`; confirm the installed version before switching.
model = torchvision.models.resnet18(pretrained=True)
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

#### Freeze the pretrained weights

In [11]:
# Disable gradients on every existing parameter so the pretrained backbone stays
# fixed; layers created after this cell keep requires_grad=True by default.
# Trailing ';' stops Jupyter from displaying the returned module.
model.requires_grad_(False);

#### Replace the last fully connected layer with a 2-class (cat/dog) head for fine-tuning

In [13]:
# Swap ImageNet's classifier for a 2-way cat/dog head. When this cell runs after
# the freeze cell, this new head holds the only trainable parameters.
# The trailing Softmax makes the model output class probabilities, so a training
# loss must consume log-probabilities (e.g. NLLLoss on log(output));
# nn.CrossEntropyLoss expects raw logits and would apply softmax a second time.
model.fc = nn.Sequential(
    nn.Linear(in_features=512, out_features=2),
    nn.Softmax(dim=1),
)
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

# Functions

#### Validation Function

In [14]:
def validate(model, data, device=None):
    """Return the classification accuracy of `model` on `data`, in percent.

    Args:
        model: classifier whose output has one score per class along dim 1.
        data: iterable of (images, labels) batches, e.g. a DataLoader.
        device: device to evaluate on. Defaults to the device of the model's
            first parameter (CPU if it has none), so inputs and labels always
            end up on the same device as the model.

    Returns:
        Accuracy as a float percentage in [0, 100].

    Raises:
        ValueError: if `data` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()  # BatchNorm/Dropout in inference mode for a faithful accuracy
    total = 0
    correct = 0
    try:
        with torch.no_grad():  # evaluation needs no autograd graph
            for images, labels in data:
                outputs = model(images.to(device))
                preds = outputs.argmax(dim=1)
                # labels must live on the same device as preds for the comparison
                correct += int((preds == labels.to(device)).sum())
                total += labels.size(0)
    finally:
        model.train(was_training)  # restore the caller's train/eval mode

    if total == 0:
        raise ValueError("validate() received no samples")
    return correct * 100 / total

#### Train Function

In [18]:
def train(num_epoch=EPOCHS, lr=LEARNING_RATE, device=DEVICE):
    """Fine-tune the global `model` on `train_dl`, keeping the best epoch on `val_dl`.

    Args:
        num_epoch: number of passes over `train_dl`.
        lr: Adam learning rate.
        device: device the model and batches are moved to.

    Returns:
        A deep copy of the model from the epoch with the highest validation
        accuracy, or a copy of the current model if `num_epoch` <= 0.
    """
    cnn = model.to(device)  # nn.Module.to moves the global `model` in place; same object
    # The notebook's fc head ends in nn.Softmax, so `cnn` already outputs probabilities.
    # nn.CrossEntropyLoss expects logits and would softmax them again; NLLLoss on
    # log-probabilities is the equivalent of cross-entropy on logits.
    nll = nn.NLLLoss()
    # Only hand trainable parameters to the optimizer (the backbone is frozen).
    optimizer = optim.Adam((p for p in cnn.parameters() if p.requires_grad), lr=lr)

    max_accuracy = float("-inf")  # first epoch is always recorded as the best so far
    best_model = None

    for epoch in range(num_epoch):
        cnn.train()  # training-mode BatchNorm/Dropout for the update pass
        for images, labels in tqdm(train_dl, desc=f"epoch {epoch + 1}/{num_epoch}"):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            probs = cnn(images)
            # clamp_min avoids log(0) = -inf when a probability underflows
            loss = nll(torch.log(probs.clamp_min(1e-12)), labels)
            loss.backward()
            optimizer.step()
        accuracy = float(validate(cnn, val_dl))
        if accuracy > max_accuracy:
            best_model = copy.deepcopy(cnn)
            max_accuracy = accuracy
            print("saving best model with accuracy: ", accuracy)
        print("Epoch: ", epoch+1, "Accuracy: ", accuracy, "%")

    if best_model is None:  # num_epoch <= 0: no epoch ran
        best_model = copy.deepcopy(cnn)
    return best_model


# Train the model

In [19]:
# resnet = train()

# Save Best Model

In [20]:
# torch.save(resnet.state_dict(), "ResNet_CatDog_v2.pth")
# OR
# model_scripted = torch.jit.script(resnet) 
# model_scripted.save('CatDogModel.pt')

# Inference

#### Main Function

In [1]:
def inference(path, model, device=None, preprocess=None):
    """Download an image from a URL and run it through `model`.

    Args:
        path: image URL, fetched with `requests.get`.
        model: network to evaluate. It is switched to eval mode for the forward
            pass and restored to its previous mode afterwards.
        device: device for the input tensor. Defaults to the device of the model's
            first parameter (CPU if it has none), so a GPU-resident model works
            without passing it explicitly.
        preprocess: callable mapping an HxWx3 uint8 array to a CxHxW tensor.
            Defaults to the notebook's global `transform`.

    Returns:
        The model output for a batch containing the single image, or False if the
        URL could not be fetched or the response is not a decodable image.
    """
    if preprocess is None:
        preprocess = transform
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    try:
        resp = requests.get(path, timeout=10)
        # Treat 4xx/5xx as a failed fetch instead of trying to decode an error page.
        resp.raise_for_status()
    except requests.RequestException:
        return False

    try:
        with Image.open(BytesIO(resp.content)) as img:
            # convert("RGB") so grayscale/palette/RGBA images still yield 3 channels
            image = np.array(img.convert("RGB"))
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        return False

    was_training = model.training
    model.eval()  # use BatchNorm running statistics, not single-image batch statistics
    try:
        with torch.no_grad():
            batch = preprocess(image).unsqueeze(0)  # add batch dimension of size 1
            return model(batch.to(device))
    finally:
        model.train(was_training)

#### Run Inference

In [None]:
path = input("insert the image url: ").strip()  # input() already returns str
pred = inference(path, model)
if torch.is_tensor(pred):
    # pred holds one row (batch of one image) of per-class scores; take the argmax
    # with torch so this also works when pred lives on a GPU.
    pred_idx = int(pred.argmax(dim=1).item())
    pred_label = "cat" if pred_idx == 0 else "dog"
    prob = pred[0, pred_idx].item()  # plain float, not a tensor repr
    print(f"Predicted: {pred_label}, Prob: {prob * 100:.2f}%")
else:
    print("can not get the url!!!")