In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset


import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from PIL import Image

import numpy as np
import copy
import os
import random
import time
import cv2
from ipywidgets import interact
from collections import namedtuple
import matplotlib.pyplot as plt


# Data inspection and preprocessing (데이터 확인 및 전처리)

In [2]:
# Collect every Cat/Dog image path. Sorting makes the list order (and hence
# the seeded shuffle below) independent of the filesystem's listing order.
cat_dir = './data/Cat/'
dog_dir = './data/Dog/'

cat_image_path = sorted([os.path.join(cat_dir, f) for f in os.listdir(cat_dir)])
dog_image_path = sorted([os.path.join(dog_dir, f) for f in os.listdir(dog_dir)])

image_file_path = [*cat_image_path, *dog_image_path]
# Keep only files OpenCV can actually decode (cv2.imread returns None otherwise).
correct_image_path = [i for i in image_file_path if cv2.imread(i) is not None]

# Seed every RNG the notebook relies on. `random` drives the split below;
# numpy and torch are seeded too, so later stochastic steps (e.g. the
# DataLoader(shuffle=True) batch order) are reproducible across runs.
SEED = 29
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
random.shuffle(correct_image_path)

# Split layout after shuffling: first N_TRAIN -> train, last N_TEST -> test,
# everything in between -> validation.
N_TRAIN = 400
N_TEST = 10
n_val = len(correct_image_path) - N_TRAIN - N_TEST
# Without this guard, a short dataset makes the slices silently overlap
# (test would reuse training images) or leaves validation empty.
assert n_val > 0, (
    f'need more than {N_TRAIN + N_TEST} readable images, '
    f'found {len(correct_image_path)}'
)

train_image_files = correct_image_path[:N_TRAIN]
val_image_files = correct_image_path[N_TRAIN:N_TRAIN + n_val]
test_image_files = correct_image_path[N_TRAIN + n_val:]
print(len(train_image_files), len(val_image_files), len(test_image_files))
print(train_image_files[:5])

400 92 10
['./data/Cat/cat.4.jpg', './data/Dog/dog.20.jpg', './data/Cat/cat.158.jpg', './data/Cat/cat.75.jpg', './data/Cat/cat.152.jpg']


In [3]:
@interact(index=(0, len(train_image_files)-1))
def image_show(index=0):
    """Interactively display training image ``index``.

    Prints the decoded array's shape, then shows the picture in RGB with the
    part of its file name before the first '.' (e.g. 'cat') as the title.
    """
    file_path = train_image_files[index]
    bgr_pixels = cv2.imread(file_path)
    print('image shape: ', bgr_pixels.shape)
    # OpenCV decodes channels as BGR; matplotlib expects RGB.
    rgb_pixels = cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)
    # './data/Cat/cat.4.jpg' -> 'cat.4.jpg' -> 'cat'
    class_name = file_path.split('/')[-1].split('.')[0]
    plt.title(class_name)
    plt.imshow(rgb_pixels)
    plt.tight_layout()
    plt.show()

interactive(children=(IntSlider(value=0, description='index', max=399), Output()), _dom_classes=('widget-inter…

In [4]:
from utils import build_trasnforms, MyDataset

In [5]:
# Input-pipeline settings shared by the datasets and DataLoaders below.
# image_size: side length of the square input (samples come out as 3 x 224 x 224);
# batch_size: number of images per DataLoader batch.
image_size, batch_size = 224, 32

# Per-channel normalization statistics (RGB order). These are the ImageNet
# mean/std values used by torchvision's pretrained models.
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

In [6]:
# Build one dataset per split with identical transform settings.
# NOTE: `build_trasnforms` (sic) is the spelling defined in utils.
# NOTE(review): `phase` presumably picks train-time vs. eval transforms inside
# MyDataset — confirm in utils.
train_dataset = MyDataset(
    train_image_files,
    transforms=build_trasnforms(image_size=image_size, mean=mean, std=std),
    phase='train',
)
val_dataset = MyDataset(
    val_image_files,
    transforms=build_trasnforms(image_size=image_size, mean=mean, std=std),
    phase='val',
)

In [7]:
# Sanity-check one training sample (image tensor size and label). Index the
# dataset once: the original called __getitem__(0) twice, which loaded and
# transformed the same image twice.
sample = train_dataset[0]
print(sample[0].size(), sample[1])

torch.Size([3, 224, 224]) 0


In [11]:
# One loader per split: training batches are shuffled each epoch, while
# validation keeps a fixed order.
dataloaders = {
    'train': DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    'val': DataLoader(val_dataset, batch_size=batch_size, shuffle=False),
}

# Peek at a single training batch to confirm tensor shapes and labels.
image, label = next(iter(dataloaders['train']))
print(image.shape, label)

torch.Size([32, 3, 224, 224]) tensor([0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1,
        0, 0, 1, 1, 1, 1, 1, 1])


# Build the model