# Transfer learning with PyTorch
We're going to train a neural network to classify dogs and cats.

## Init, helpers, utils, ...

In [1]:
from pprint import pprint
import random
import datetime
import time

from IPython.core.debugger import set_trace

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.models as models
import os

import torchvision
from torchvision.datasets.folder import ImageFolder, default_loader

%matplotlib inline

In [4]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

device(type='cuda', index=1)

In [6]:
# Training helpers
def get_trainable(model_params):
    return (p for p in model_params if p.requires_grad)


def get_frozen(model_params):
    return (p for p in model_params if not p.requires_grad)


def all_trainable(model_params):
    return all(p.requires_grad for p in model_params)


def all_frozen(model_params):
    return all(not p.requires_grad for p in model_params)


def freeze_all(model_params):
    for param in model_params:
        param.requires_grad = False

## Transforms

In [7]:
from torchvision import transforms

IMG_SIZE = 224  #224  #defined by NN model input
_mean = [0.485, 0.456, 0.406]
_std = [0.229, 0.224, 0.225]


train_trans = transforms.Compose([
    transforms.Resize((IMG_SIZE,IMG_SIZE)),  #256  #(IMG_SIZE, IMG_SIZE)  # some images are pretty small
    #transforms.RandomCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(.3, .3, .3),
    transforms.ToTensor(),
    transforms.Normalize(_mean, _std),
])
val_trans = transforms.Compose([
    transforms.Resize((IMG_SIZE,IMG_SIZE)),  #256  #(IMG_SIZE, IMG_SIZE)
    #transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_mean, _std),
])

## Dataset

In [9]:
# change current working path to the root folder of this project

root_path = os.path.abspath(os.pardir)

os.chdir(root_path)

BATCH_SIZE = 128  #2  #256  #512  #32  #220 for resnet152 on Dell Presison 5520 laptop, 400 for resnet18

n_classes = 2

from torch.utils.data import DataLoader

train_dl = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)
val_dl = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

In [13]:
# add our new model for three inputs

from nets.ResNet_ronorigin import *
model = resnet_18()

for param in model.parameters():
    param.requires_grad = True
    
model = model.to(device)
criterion = nn.CrossEntropyLoss()

# Predict with Trained Model

Pickle the trained model and predict image with it. 

In [None]:
# save the trained model weights
import os

root_path = os.path.abspath(os.pardir)

model_weights_path = root_path+'/data/saved_model_weights/resnet18_whole'

# !!!only use it when you WANT to save a trained model weights!!!
torch.save(model.state_dict(), model_weights_path)

In [None]:
# load the trained model
from nets.ResNet_ronorigin import *
model = resnet_18()

for param in model.parameters():
    param.requires_grad = True
    
model = model.to(device)

model.load_state_dict(torch.load(model_weights_path))

In [None]:
# use trained model to predict a pair of people in a new image

# test data set
test_ds = ImageFolder(root_path+"/data/raw/DUI/test", transform=val_trans, loader=default_loader)
print(f'len(test_ds) = {len(test_ds)}. ')

test_dl = DataLoader(
    test_ds,xx
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

print(f'test_ds[99]: \n{test_ds[99]}')
print(f'test_ds[99][1]: \n{test_ds[99][1]}')

In [None]:
# predict, WIP

# Eval
model.eval()  # IMPORTANT

with torch.no_grad():  # IMPORTANT
    for x in imagefolder:
        canvas, all_predicts = model()
        pred = model(x)

In [5]:
import os

root_path = os.path.abspath(os.pardir)

os.chdir(root_path)

from multipersonpose.predict_boxes import predict_boxes as pd


weight_name = 'multipersonpose/network/weight/pose_model.pth'
test_image = 'multipersonpose/readme/ski.jpg'

canvas, all_predicts = predict_boxes(test_image, weight_name)

ModuleNotFoundError: No module named 'network'