In [149]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
from tqdm import tqdm
from skimage import io
import matplotlib.pyplot as plt
import cv2
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [150]:
num_class = 597
net = models.resnet18(pretrained=True)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, num_class)
net = net.to(DEVICE)
batch_size = 64
epoches = 5
learning_rate = 1e-4
val = 0.2
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=5e-4)

In [155]:
class OpenImageDataset(Dataset):

	def __init__(self, csvfile, root_dir, transform=None, test=False):
		csv = pd.read_csv(csvfile)
		csv = csv.loc[csv.ImageID.str.startswith('0')]
		self.img_ids = csv.ImageID
		self.YMin = np.array(csv.YMin)
		self.YMax = np.array(csv.YMax)
		self.XMin = np.array(csv.XMin)
		self.XMax = np.array(csv.XMax)
		self.root_dir = root_dir
		if not test:
			self.n2l = dict(enumerate(list(csv.LabelName.unique())))
			self.l2n = dict((v,k) for k,v in self.n2l.items())
			self.labels = csv.LabelName
			self.num_class = csv.LabelName.nunique()
		self.transform = transform
		self.test = test

	def __len__(self):
		return len(self.img_ids)

	def __getitem__(self, idx):
		img_name = self.root_dir+self.img_ids[idx]+'.jpg'
		image = io.imread(img_name)
		print(image.shape)
		if len(image.shape) == 1:
			image = image[0]
		if len(image.shape) == 2:
			image = cv2.cvtColor(image,cv2.COLOR_GRAY2RGB)
		if image.shape[2] == 4:
			image = cv2.cvtColor(image,cv2.COLOR_RGBA2RGB)
		h0, w0 = image.shape[:2]
		scale = 1024/max(h0, w0)
		h, w = int(round(h0*scale)), int(round(w0*scale))
		image = image[
			max(0, int(self.YMin[idx]*h)-1):min(h-1, int(self.YMax[idx]*h)+1),
			max(0, int(self.XMin[idx]*w)-1):min(w-1, int(self.XMax[idx]*w)+1)]
		image = np.transpose(image, (2, 0, 1))
		if self.transform:
			image = self.transform(image)
		print(image.shape)
		if self.test:
			return image
		else:
			return (image, self.l2n[self.labels[idx]])

In [156]:
def get_dataloader():
	normalize = transforms.Normalize(
		mean=[0.485, 0.456, 0.406], 
		std=[0.229, 0.224, 0.225],
	)
	pre_transform = transforms.Compose([
		transforms.ToTensor(),
		transforms.ToPILImage(),
		transforms.RandomResizedCrop(224),
		transforms.RandomHorizontalFlip(),
		transforms.ToTensor(),
		normalize,
	])
	train_dataset = OpenImageDataset(
		csvfile='all/train_bounding_boxes.csv',
		root_dir='train_0/',
		transform=pre_transform,
		test=False,
	)
	'''
	test_dataset = OpenImageDataset(
		csvfile='hw7data/test.csv',
		root_dir='hw7data/images/',
		transform=pre_transform,
		test=True,
	)
	'''
	train_dataloader = DataLoader(
		train_dataset,
		batch_size=batch_size,
		shuffle=True,
	)
	'''
	test_dataloader = DataLoader(
		test_dataset,
		batch_size=batch_size,
		shuffle=False,
	)
	'''
	#return train_dataloader, test_dataloader
	return train_dataloader

In [157]:
train_dataloader = get_dataloader()

In [158]:
print(len(train_dataloader.dataset))
for X, y in train_dataloader:
    print(X.shape)

1093385
(683, 1024)
torch.Size([3, 224, 224])
(1024, 623, 3)
torch.Size([3, 224, 224])
(1024, 768, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(1024, 768, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(699, 1024, 3)
torch.Size([3, 224, 224])
(686, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(1024, 685, 3)
torch.Size([3, 224, 224])
(546, 1024, 3)
torch.Size([3, 224, 224])
(768, 768, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(678, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 

(683, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(678, 1024, 3)
torch.Size([3, 224, 224])
(903, 1024, 3)
torch.Size([3, 224, 224])
(765, 1024, 3)
torch.Size([3, 224, 224])
(686, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(837, 1024, 3)
torch.Size([3, 224, 224])
(679, 1024, 3)
torch.Size([3, 224, 224])
(678, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(1024, 768, 3)
torch.Size([3, 224, 224])
(1024, 683, 3)
torch.Size([3, 224, 224])
(1024, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(374, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(1024, 855, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(681, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(1024, 682, 3)
torch.Size([3, 224, 224])
(768, 768, 3)
t

(768, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(1024, 768, 3)
torch.Size([3, 224, 224])
(812, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(1024, 985, 3)
torch.Size([3, 224, 224])
(672, 1024, 3)
torch.Size([3, 224, 224])
(454, 1024, 3)
torch.Size([3, 224, 224])
(576, 1024, 3)
torch.Size([3, 224, 224])
(681, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(1024, 768, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(682, 1024, 3)
torch.Size([3, 224, 224])
(686, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
t

torch.Size([3, 224, 224])
(680, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(1024, 1024, 3)
torch.Size([3, 224, 224])
(645, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(668, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(714, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(685, 1024, 3)
torch.Size([3, 224, 224])
(684, 1024, 3)
torch.Size([3, 224, 224])
(731, 1024, 3)
torch.Size([3, 224, 224])
(881, 1024, 3)
torch.Size([3, 224, 224])
(685, 1024, 3)
torch.Size([3, 224, 224])
(1024, 682, 3)
torch.Size([3, 224, 224])
(1024, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(1024, 768, 3)
torch.Size([3,

(277, 1024, 3)
torch.Size([3, 224, 224])
(678, 1024, 3)
torch.Size([3, 224, 224])
(681, 1024, 3)
torch.Size([3, 224, 224])
(738, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(681, 1024, 3)
torch.Size([3, 224, 224])
(769, 1024, 3)
torch.Size([3, 224, 224])
(656, 1024, 3)
torch.Size([3, 224, 224])
(678, 1024, 3)
torch.Size([3, 224, 224])
(685, 1024, 3)
torch.Size([3, 224, 224])
(686, 1024, 3)
torch.Size([3, 224, 224])
(1023, 1024, 3)
torch.Size([3, 224, 224])
(679, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(766, 1024, 3)
torch.Size([3, 224, 224])
(768, 1024, 3)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(684, 1024)
torch.Size([3, 224, 224])
(683, 1024, 3)
torch.Size([3, 224, 224])
(681, 1024, 3)
torch.Size([3, 224, 224])
(712, 1024, 3)
tor

KeyboardInterrupt: 

In [125]:
image = io.imread("train_0/0b0941addc9c6d7d.jpg")