In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

import torchvision.transforms as transforms
import torchvision.datasets as datasets

import numpy as np
from tqdm import tqdm

In [4]:
# Dataset location and number of target classes used further down the notebook.
root = './data/'
num_classes = 10

# Untransformed CIFAR-10 training split; `download=True` fetches it into
# `root` when it is not already present.
train_data = datasets.CIFAR10(
	root=root,
	train=True,
	download=True,
)

Files already downloaded and verified


We compute the mean and standard deviation of each channel (RGB images have 3 channels).  
We divide by 255 because the raw pixel values are in the range 0-255, and we want them in the range 0-1.

In [5]:
# Collapse the first three axes of the raw pixel array so one value per
# channel remains, then rescale from the 0-255 range to 0-1.
means, stds = (
	reduce(axis=(0, 1, 2)) / 255
	for reduce in (train_data.data.mean, train_data.data.std)
)
print(f'Mean {means}, STD {stds}')

Mean [0.49139968 0.48215841 0.44653091], STD [0.24703223 0.24348513 0.26158784]


In [6]:
# Convert images to tensors, then standardise each channel with the statistics
# computed above.
# Optional: rotation & flip for train
transform = transforms.Compose([
	transforms.ToTensor(),
	transforms.Normalize(means, stds),
])

# Re-open both official splits with the preprocessing attached
# (training split first, then the test split).
train_data, test_data = (
	datasets.CIFAR10(root=root, train=is_train, transform=transform)
	for is_train in (True, False)
)

In [7]:
# Extract validation data
# NOTE: despite its name, `valid_ratio` is the fraction of samples KEPT for
# training (0.9 -> 90% train / 10% validation). The name is left unchanged so
# any other cell that reads it keeps working.
valid_ratio = 0.9
n_train_examples = int(len(train_data) * valid_ratio)
n_valid_examples = len(train_data) - n_train_examples

# Seed the split so the same samples land in the validation set after every
# kernel restart; otherwise validation numbers (and which checkpoint counts as
# "best") are not comparable between runs.
split_seed = 0
# NOTE: this cell rebinds `train_data`, so running it twice in a row splits the
# already-reduced training subset again. Re-run the dataset cell first.
train_data, valid_data = data.random_split(
	train_data,
	[n_train_examples, n_valid_examples],
	generator=torch.Generator().manual_seed(split_seed))
print(f'#train={len(train_data)}, #valid={len(valid_data)}, #test={len(test_data)}')

#train=45000, #valid=5000, #test=10000


In [8]:
batch_size = 256

# Only the training loader shuffles; the validation and test loaders iterate
# their datasets in the default order.
train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader, test_loader = (
	data.DataLoader(split, batch_size=batch_size)
	for split in (valid_data, test_data)
)

In [12]:
class LeNet5(nn.Module):
	"""LeNet-5 style classifier for 3-channel images.

	Two 5x5 convolutions (each followed by 2x2 average pooling and ReLU) feed
	three fully connected layers. `fc1` expects 16 * 5 * 5 input features,
	which is what a 3x32x32 input yields after the two conv/pool stages.

	Args:
		output_dim: number of output logits (classes).
	"""

	def __init__(self, output_dim):
		super().__init__()

		# Convolutional feature extractor (no padding, stride 1).
		self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
		self.conv2 = nn.Conv2d(6, 16, kernel_size=5)

		# Classifier head.
		self.fc1 = nn.Linear(16 * 5 * 5, 120)
		self.fc2 = nn.Linear(120, 84)
		self.fc3 = nn.Linear(84, output_dim)

		# Parameter-free layers shared by both convolutional stages.
		self.pool = nn.AvgPool2d(2, 2)
		self.relu = nn.ReLU(inplace=True)

	def forward(self, x):
		"""Return the raw (un-normalised) class logits for a batch `x`."""
		# conv -> average pool -> ReLU, twice.
		for conv in (self.conv1, self.conv2):
			x = self.relu(self.pool(conv(x)))

		# Keep the batch dimension, flatten everything else.
		x = torch.flatten(x, 1)

		# Hidden layers use ReLU (orig LeNet5 used sigmoid).
		for hidden in (self.fc1, self.fc2):
			x = F.relu(hidden(x))

		return self.fc3(x)

In [15]:
class AlexNet(nn.Module):
	"""AlexNet-style classifier scaled down for small 3-channel images.

	Five 3x3 convolutions (96, 256, 384, 384, 256 channels) are followed by
	three fully connected layers with dropout in front of the first two.
	`fc1` expects 256 * 2 * 2 input features, which is what a 3x32x32 input
	yields: the stride-2 first convolution and the three 2x2 max-pools each
	halve the spatial size (32 -> 16 -> 8 -> 4 -> 2).

	Args:
		output_dim: number of output logits (classes).
	"""

	def __init__(self, output_dim):
		super().__init__()

		# Only the first convolution strides; all of them pad by 1.
		self.conv1 = nn.Conv2d(3, 96, kernel_size=3, stride=2, padding=1)
		self.conv2 = nn.Conv2d(96, 256, kernel_size=3, padding=1)
		self.conv3 = nn.Conv2d(256, 384, kernel_size=3, padding=1)
		self.conv4 = nn.Conv2d(384, 384, kernel_size=3, padding=1)
		self.conv5 = nn.Conv2d(384, 256, kernel_size=3, padding=1)

		# Classifier head.
		self.fc1 = nn.Linear(256 * 2 * 2, 4096)
		self.fc2 = nn.Linear(4096, 4096)
		self.fc3 = nn.Linear(4096, output_dim)

		# Parameter-free layers reused throughout forward().
		self.pool = nn.MaxPool2d(2, 2)
		self.relu = nn.ReLU(inplace=True)
		self.dropout = nn.Dropout(0.5)

	def forward(self, x):
		"""Return the raw (un-normalised) class logits for a batch `x`."""
		# (convolution, whether a max-pool follows it) for each stage.
		stages = (
			(self.conv1, True),
			(self.conv2, True),
			(self.conv3, False),
			(self.conv4, False),
			(self.conv5, True),
		)
		for conv, downsample in stages:
			x = conv(x)
			if downsample:
				x = self.pool(x)
			x = self.relu(x)

		# Keep the batch dimension, flatten everything else.
		x = torch.flatten(x, 1)

		# Dropout is applied to the input of each hidden layer.
		for hidden in (self.fc1, self.fc2):
			x = self.relu(hidden(self.dropout(x)))

		return self.fc3(x)

In [17]:
# Choose which architecture to train; must be one of the names handled below.
model_name = 'AlexNet'
if model_name == 'LeNet5':
	model = LeNet5(output_dim=num_classes)
elif model_name == 'AlexNet':
	model = AlexNet(output_dim=num_classes)
else:
	# Fail loudly: otherwise `model` would be undefined, or silently keep the
	# value left in the kernel by an earlier run of this cell.
	raise ValueError(
		f"Unknown model_name {model_name!r}; expected 'LeNet5' or 'AlexNet'")

def count_parameters(model):
	"""Return how many scalar values `model` can train.

	Sums the element counts of every parameter whose `requires_grad` flag is
	set; parameters without it are skipped.
	"""
	total = 0
	for param in model.parameters():
		if param.requires_grad:
			total += param.numel()
	return total

def get_device():
	"""Return the CUDA device when PyTorch reports one as available, else the CPU."""
	if torch.cuda.is_available():
		return torch.device('cuda')
	return torch.device('cpu')

# Report the size of the selected model, formatted with thousands separators.
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 24,342,410 trainable parameters


In [18]:
lr = 1e-4
device = get_device()
# Move the model to the target device BEFORE building the optimizer, so the
# optimizer is constructed over the parameters that are actually trained.
# (Previously `device` was computed but never used, so everything ran on the
# CPU even when a GPU was available.)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()


def _model_device(model):
	"""Return the device of the model's first parameter (CPU if it has none)."""
	first_param = next(model.parameters(), None)
	return first_param.device if first_param is not None else torch.device('cpu')


def calc_acc(y_pred, y):
	"""Return top-1 accuracy of a batch as a 0-dim float tensor.

	Args:
		y_pred: per-class scores; the class is taken as the argmax over dim 1.
		y: target class indices, one per row of `y_pred`.
	"""
	top_pred = y_pred.argmax(1, keepdim=True)
	correct = top_pred.eq(y.view_as(top_pred)).sum()
	acc = correct.float() / y.shape[0]
	return acc


def train(model, loader, optimizer, criterion, device=None):
	"""Run one optimisation pass over `loader`.

	Args:
		model: network to train; switched to training mode.
		loader: iterable of `(x, y)` batches.
		optimizer: optimizer stepped once per batch.
		criterion: loss called as `criterion(y_pred, y)`. The epoch loss is
			sample-weighted on the assumption that it returns the batch MEAN
			(the default reduction of `nn.CrossEntropyLoss`).
		device: where batches are moved before the forward pass. Defaults to
			the device the model's parameters live on.

	Returns:
		`(mean_loss, mean_accuracy)` averaged over samples, so a shorter final
		batch is not over-weighted.

	Raises:
		ZeroDivisionError: if `loader` yields no samples.
	"""
	if device is None:
		device = _model_device(model)

	loss_sum, acc_sum, n_seen = 0.0, 0.0, 0

	model.train()

	for (x, y) in tqdm(loader):
		# Batches arrive on the CPU from the DataLoader; the model may not be.
		x, y = x.to(device), y.to(device)

		optimizer.zero_grad()

		y_pred = model(x)
		loss = criterion(y_pred, y)
		loss.backward()
		optimizer.step()

		acc = calc_acc(y_pred, y)

		# Weight each batch by its size before accumulating.
		n_batch = y.shape[0]
		loss_sum += loss.item() * n_batch
		acc_sum += acc.item() * n_batch
		n_seen += n_batch
	return loss_sum / n_seen, acc_sum / n_seen


def evaluate(model, loader, criterion, device=None):
	"""Compute loss and accuracy over `loader` without updating the model.

	Args:
		model: network to evaluate; switched to evaluation mode.
		loader: iterable of `(x, y)` batches.
		criterion: loss called as `criterion(y_pred, y)`; assumed to return the
			batch mean (see `train`).
		device: where batches are moved before the forward pass. Defaults to
			the device the model's parameters live on.

	Returns:
		`(mean_loss, mean_accuracy)` averaged over samples.

	Raises:
		ZeroDivisionError: if `loader` yields no samples.
	"""
	if device is None:
		device = _model_device(model)

	loss_sum, acc_sum, n_seen = 0.0, 0.0, 0

	model.eval()

	# No gradients are needed for evaluation.
	with torch.no_grad():

		for (x, y) in tqdm(loader):
			x, y = x.to(device), y.to(device)

			y_pred = model(x)
			loss = criterion(y_pred, y)
			acc = calc_acc(y_pred, y)

			n_batch = y.shape[0]
			loss_sum += loss.item() * n_batch
			acc_sum += acc.item() * n_batch
			n_seen += n_batch
	return loss_sum / n_seen, acc_sum / n_seen


epochs = 25
# Lowest validation loss seen so far; starts at infinity so the first epoch
# always produces a checkpoint.
best_valid_loss = float('inf')

for epoch in range(epochs):
	# One pass over the training set, then score the held-out validation set.
	train_loss, train_acc = train(model, train_loader, optimizer, criterion)
	valid_loss, valid_acc = evaluate(model, valid_loader, criterion)

	# Keep only the weights of the epoch with the best validation loss.
	if valid_loss < best_valid_loss:
		best_valid_loss = valid_loss
		torch.save(model.state_dict(), f'{model_name}.pt')

	print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
	print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

100%|██████████| 176/176 [02:39<00:00,  1.11it/s]
100%|██████████| 20/20 [00:06<00:00,  3.12it/s]


	Train Loss: 1.863 | Train Acc: 29.79%
	 Val. Loss: 1.554 |  Val. Acc: 41.16%


100%|██████████| 176/176 [03:02<00:00,  1.04s/it]
100%|██████████| 20/20 [00:06<00:00,  2.99it/s]


	Train Loss: 1.476 | Train Acc: 45.14%
	 Val. Loss: 1.353 |  Val. Acc: 49.56%


100%|██████████| 176/176 [02:57<00:00,  1.01s/it]
100%|██████████| 20/20 [00:07<00:00,  2.60it/s]


	Train Loss: 1.304 | Train Acc: 52.49%
	 Val. Loss: 1.258 |  Val. Acc: 54.25%


100%|██████████| 176/176 [02:55<00:00,  1.00it/s]
100%|██████████| 20/20 [00:06<00:00,  3.24it/s]


	Train Loss: 1.177 | Train Acc: 57.20%
	 Val. Loss: 1.180 |  Val. Acc: 56.48%


100%|██████████| 176/176 [03:01<00:00,  1.03s/it]
100%|██████████| 20/20 [00:07<00:00,  2.50it/s]


	Train Loss: 1.072 | Train Acc: 61.39%
	 Val. Loss: 1.071 |  Val. Acc: 60.91%


100%|██████████| 176/176 [02:53<00:00,  1.01it/s]
100%|██████████| 20/20 [00:07<00:00,  2.71it/s]


	Train Loss: 0.997 | Train Acc: 64.23%
	 Val. Loss: 1.035 |  Val. Acc: 62.30%


100%|██████████| 176/176 [02:54<00:00,  1.01it/s]
100%|██████████| 20/20 [00:06<00:00,  3.08it/s]


	Train Loss: 0.914 | Train Acc: 67.24%
	 Val. Loss: 0.988 |  Val. Acc: 64.42%


100%|██████████| 176/176 [02:46<00:00,  1.06it/s]
100%|██████████| 20/20 [00:07<00:00,  2.63it/s]


	Train Loss: 0.830 | Train Acc: 70.22%
	 Val. Loss: 0.947 |  Val. Acc: 67.33%


100%|██████████| 176/176 [02:53<00:00,  1.02it/s]
100%|██████████| 20/20 [00:06<00:00,  3.11it/s]


	Train Loss: 0.747 | Train Acc: 73.40%
	 Val. Loss: 0.930 |  Val. Acc: 68.02%


100%|██████████| 176/176 [02:50<00:00,  1.03it/s]
100%|██████████| 20/20 [00:06<00:00,  3.05it/s]


	Train Loss: 0.688 | Train Acc: 75.65%
	 Val. Loss: 0.916 |  Val. Acc: 69.46%


100%|██████████| 176/176 [02:53<00:00,  1.01it/s]
100%|██████████| 20/20 [00:07<00:00,  2.63it/s]


	Train Loss: 0.598 | Train Acc: 78.86%
	 Val. Loss: 0.914 |  Val. Acc: 69.95%


100%|██████████| 176/176 [02:53<00:00,  1.01it/s]
100%|██████████| 20/20 [00:06<00:00,  2.92it/s]


	Train Loss: 0.529 | Train Acc: 81.39%
	 Val. Loss: 0.910 |  Val. Acc: 70.38%


100%|██████████| 176/176 [02:52<00:00,  1.02it/s]
100%|██████████| 20/20 [00:06<00:00,  2.90it/s]


	Train Loss: 0.449 | Train Acc: 84.20%
	 Val. Loss: 0.943 |  Val. Acc: 71.25%


100%|██████████| 176/176 [02:54<00:00,  1.01it/s]
100%|██████████| 20/20 [00:08<00:00,  2.42it/s]


	Train Loss: 0.385 | Train Acc: 86.37%
	 Val. Loss: 0.993 |  Val. Acc: 70.31%


100%|██████████| 176/176 [03:01<00:00,  1.03s/it]
100%|██████████| 20/20 [00:06<00:00,  3.00it/s]


	Train Loss: 0.314 | Train Acc: 89.13%
	 Val. Loss: 1.022 |  Val. Acc: 70.63%


100%|██████████| 176/176 [02:51<00:00,  1.03it/s]
100%|██████████| 20/20 [00:06<00:00,  3.11it/s]


	Train Loss: 0.260 | Train Acc: 90.83%
	 Val. Loss: 1.106 |  Val. Acc: 70.71%


100%|██████████| 176/176 [02:49<00:00,  1.04it/s]
100%|██████████| 20/20 [00:06<00:00,  3.00it/s]


	Train Loss: 0.208 | Train Acc: 92.66%
	 Val. Loss: 1.163 |  Val. Acc: 70.51%


100%|██████████| 176/176 [03:01<00:00,  1.03s/it]
100%|██████████| 20/20 [00:07<00:00,  2.67it/s]


	Train Loss: 0.169 | Train Acc: 94.14%
	 Val. Loss: 1.287 |  Val. Acc: 70.71%


100%|██████████| 176/176 [03:01<00:00,  1.03s/it]
100%|██████████| 20/20 [00:07<00:00,  2.55it/s]


	Train Loss: 0.143 | Train Acc: 95.01%
	 Val. Loss: 1.357 |  Val. Acc: 70.62%


100%|██████████| 176/176 [02:45<00:00,  1.07it/s]
100%|██████████| 20/20 [00:06<00:00,  3.18it/s]


	Train Loss: 0.100 | Train Acc: 96.64%
	 Val. Loss: 1.428 |  Val. Acc: 70.58%


100%|██████████| 176/176 [02:48<00:00,  1.05it/s]
100%|██████████| 20/20 [00:06<00:00,  3.25it/s]


	Train Loss: 0.099 | Train Acc: 96.49%
	 Val. Loss: 1.331 |  Val. Acc: 71.00%


100%|██████████| 176/176 [02:51<00:00,  1.03it/s]
100%|██████████| 20/20 [00:06<00:00,  3.18it/s]


	Train Loss: 0.080 | Train Acc: 97.20%
	 Val. Loss: 1.500 |  Val. Acc: 71.12%


100%|██████████| 176/176 [02:56<00:00,  1.00s/it]
100%|██████████| 20/20 [00:08<00:00,  2.30it/s]


	Train Loss: 0.071 | Train Acc: 97.56%
	 Val. Loss: 1.624 |  Val. Acc: 70.99%


100%|██████████| 176/176 [02:46<00:00,  1.06it/s]
100%|██████████| 20/20 [00:06<00:00,  3.20it/s]


	Train Loss: 0.063 | Train Acc: 97.84%
	 Val. Loss: 1.658 |  Val. Acc: 71.00%


100%|██████████| 176/176 [02:42<00:00,  1.08it/s]
100%|██████████| 20/20 [00:06<00:00,  3.22it/s]

	Train Loss: 0.069 | Train Acc: 97.56%
	 Val. Loss: 1.541 |  Val. Acc: 71.03%



