In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.utils.data as data

from model import ResidualBlock, ResNet
from train_eval_util import train, evaluate, calculate_accuracy, epoch_time
from getCIFAR10 import train_data, valid_data, test_data

import time

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [2]:
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')  
else:
    device = torch.device('cpu')

print(f"Selected device: {device}")

Selected device: cuda


In [3]:
BATCH_SIZE = 500

train_iterator = DataLoader(train_data, batch_size= BATCH_SIZE, shuffle=True)

valid_iterator =  DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False)

test_iterator =  DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True)

# ResNet 32

In [4]:
model = ResNet(ResidualBlock, [5, 5, 5]).to(device)


total_layers = sum([1 for _ in model.modules() 
    if isinstance(_, nn.Conv2d) or isinstance(_, nn.Linear)]) - 2 # subtract input and output layers
    
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total number of layers: {total_layers}")
print(f"Total number of parameters: {total_params}")

Total number of layers: 32
Total number of parameters: 466906


In [5]:

criterion = nn.CrossEntropyLoss().to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

from torch.optim.lr_scheduler import StepLR
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)


In [6]:
num_epochs = 100
train_acc_history = []
train_loss_history = []
valid_acc_history = []
valid_loss_history = []

for epoch in range(num_epochs):

    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device)
    scheduler.step()
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device)
        
    end_time = time.time()

        
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

    train_acc_history.append(train_acc)
    train_loss_history.append( train_loss)
    valid_acc_history.append(valid_acc)
    valid_loss_history.append(valid_loss)

Epoch: 01 | Epoch Time: 0m 17s
	Train Loss: 1.788 | Train Acc: 32.28%
	 Val. Loss: 1.983 |  Val. Acc: 32.46%
Epoch: 02 | Epoch Time: 0m 16s
	Train Loss: 1.372 | Train Acc: 49.30%
	 Val. Loss: 1.521 |  Val. Acc: 48.74%
Epoch: 03 | Epoch Time: 0m 16s
	Train Loss: 1.121 | Train Acc: 59.64%
	 Val. Loss: 1.044 |  Val. Acc: 62.36%
Epoch: 04 | Epoch Time: 0m 16s
	Train Loss: 0.930 | Train Acc: 66.69%
	 Val. Loss: 1.044 |  Val. Acc: 62.64%
Epoch: 05 | Epoch Time: 0m 16s
	Train Loss: 0.807 | Train Acc: 71.69%
	 Val. Loss: 0.798 |  Val. Acc: 72.98%
Epoch: 06 | Epoch Time: 0m 16s
	Train Loss: 0.652 | Train Acc: 77.09%
	 Val. Loss: 0.621 |  Val. Acc: 78.56%
Epoch: 07 | Epoch Time: 0m 16s
	Train Loss: 0.609 | Train Acc: 78.67%
	 Val. Loss: 0.605 |  Val. Acc: 79.06%
Epoch: 08 | Epoch Time: 0m 16s
	Train Loss: 0.589 | Train Acc: 79.41%
	 Val. Loss: 0.586 |  Val. Acc: 79.14%
Epoch: 09 | Epoch Time: 0m 16s
	Train Loss: 0.574 | Train Acc: 79.90%
	 Val. Loss: 0.583 |  Val. Acc: 80.10%
Epoch: 10 | Epoch T

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

fig,(ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))

ax1.plot(range(num_epochs), train_loss_history, '--r')
ax1.plot(range(num_epochs), valid_loss_history, '-g')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend(['train', 'valid'])

ax2.plot(range(num_epochs), train_acc_history, '--r')
ax2.plot(range(num_epochs), valid_acc_history, '-g')
ax2.set_xlabel('Epochs')
ax2.set_xlabel('Accuracy')
ax1.legend(['train', 'valid'])

In [None]:
torch.save(model, 'resnet32.pt')