In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision import datasets,models
from PIL import Image

from tensorboardX import SummaryWriter as Writer

import os
import pickle
from glob import glob
import numpy as np
from tqdm import tqdm_notebook as tqdm

In [44]:
torch.set_grad_enabled(True)
PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
print(device)

mps


In [45]:
def savePickle(name, toSave):
    file = open(name, 'wb')
    pickle.dump(toSave, file)
    file.close()

def loadPickle(name):
    file = open(name, 'rb')
    data = pickle.load(file)
    file.close()
    return data

In [46]:
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        val = np.array(val)
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [47]:
numClasses = 2
TrainDatasetPath = "./train"
TestDatasetPath = "./test1"
lr = 0.01
_architecture = "INCEPTIONV3"

if _architecture == "INCEPTIONV3":
    batchSize = 8
    resizeSize = 332
    cropSize = 299
    
modelPath = _architecture

In [48]:
data_transforms = {
    "TRAIN": transforms.Compose([
        transforms.RandomResizedCrop(cropSize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
    ]),
    "VAL": transforms.Compose([
        transforms.Resize(resizeSize),
        transforms.CenterCrop(cropSize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
    ]),
    "TEST": transforms.Compose([
        transforms.Resize(resizeSize),
        transforms.CenterCrop(cropSize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
    ]),
    "IMAGE": transforms.Compose([transforms.Resize(resizeSize),transforms.CenterCrop(cropSize)])
}

In [49]:
class Data(Dataset):
    def __init__(self, path, transform = None):
        self.x = []
        self.y = []
        self.transform = transform
        self.x = glob(path + "/*")
        
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, item):
        img = Image.open(open(self.x[item],'rb'))
        x = self.transform(img)
        y = np.array([1] if 'dog' in self.x[item] else [0]).astype('float32')
        return x,y
    

trainDataset = Data(TrainDatasetPath ,data_transforms["TRAIN"])
trainLoader = DataLoader(trainDataset, batch_size=batchSize, shuffle=False, num_workers=0)
testDataset = Data(TestDatasetPath ,data_transforms["VAL"])
testLoader = DataLoader(testDataset, batch_size=batchSize, shuffle=False, num_workers=0)
print(len(trainLoader),len(testLoader))

3125 1563


In [50]:
def getInceptionV3():
    inception = models.inception_v3(pretrained=True)
    num_features = inception.fc.in_features
    inception.fc = nn.Linear(num_features, numClasses)
    return inception

In [59]:
if os.path.exists(modelPath+".pt"):
    model = torch.load(modelPath+".pt")
    startEpoch,optimizer = loadPickle(modelPath+".info")
    print("loaded existing model")
else:
    print("creating model")
    if _architecture == "INCEPTIONV3":
        model = getInceptionV3().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
criterion.requires_grad = True
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

criterion.to(device)
model.to(device)

creating model


Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stri

In [60]:
def train():
    model.train()
    train_losses = AverageMeter()
    for idx, (x, y) in enumerate(tqdm(trainLoader)):
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)[0]
        loss = criterion(y, torch.argmax(y_pred, dim=1))
        loss.requires_grad = True
        loss.backward()
        optimizer.zero_grad()
        optimizer.step()
        train_losses.update(loss.item(), x.size(0))
    return train_losses.avg
    
    
def validate():
    model.eval()
    validate_losses = AverageMeter()
    for idx, (x, y) in enumerate(tqdm(testLoader)):
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        loss = criterion(y, torch.argmax(y_pred))
        validate_losses.update(loss.item(), x.size(0))
    return validate_losses.avg

In [61]:
writer = Writer("./logs/")
startEpoch = 0
endEpoch = 1
for currentEpoch in range(startEpoch,endEpoch):
    trainLoss = train()
    testLoss   = validate()
    print(currentEpoch,trainLoss, testLoss)
    torch.save(model, modelPath+".pt")
    writer.add_scalar('trainLoss', trainLoss, currentEpoch)
    writer.add_scalar('testLoss', testLoss, currentEpoch)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for idx, (x, y) in enumerate(tqdm(trainLoader)):


  0%|          | 0/3125 [00:00<?, ?it/s]

KeyboardInterrupt: 