In [1]:
import torch 
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch import nn

  warn(


In [3]:
DATA_DIR="./data"
def get_cifar10_data_loader(batch_size, num_workers=4):
    """
    Get the CIFAR10 data loader
    """
    # define the transform
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    # get the training and testing datasets
    train_dataset = CIFAR10(root=DATA_DIR, train=True, transform=transform, download=True)
    test_dataset = CIFAR10(root=DATA_DIR, train=False, transform=transform, download=True)
    train_set, val_set = torch.utils.data.random_split(train_dataset, [45000, 5000])
    # get the data loaders
    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers, drop_last=True)
    val_loader = torch.utils.data.DataLoader(dataset=val_set, batch_size=batch_size, shuffle=True,
                                             num_workers=num_workers, drop_last=False)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False,
                                              num_workers=num_workers, drop_last=False)

    return train_loader, val_loader, test_loader

In [8]:
# Define the CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        
        self.conv_layer = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.fc_layer = nn.Sequential(
            nn.Linear(128*8*8, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512, 10)
        )
    
    def forward(self, x):
        x = self.conv_layer(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layer(x)
        return x


In [9]:
train_loader,val_loader, test_loader = get_cifar10_data_loader(64)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
# Model, Loss and Optimizer
device = "cuda:0" if torch.cuda.is_available() else 0
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [15]:
# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

print("Finished Training")

Epoch 1/50, Loss: 0.4647936224937439
Epoch 2/50, Loss: 0.6260641813278198
Epoch 3/50, Loss: 0.48246100544929504
Epoch 4/50, Loss: 0.6623132228851318
Epoch 5/50, Loss: 0.15517809987068176
Epoch 6/50, Loss: 0.6476132273674011
Epoch 7/50, Loss: 0.40398281812667847
Epoch 8/50, Loss: 0.23638379573822021
Epoch 9/50, Loss: 0.6133595108985901
Epoch 10/50, Loss: 0.4988524913787842
Epoch 11/50, Loss: 0.3220791220664978
Epoch 12/50, Loss: 0.30658406019210815
Epoch 13/50, Loss: 0.2607712745666504
Epoch 14/50, Loss: 0.3048642873764038
Epoch 15/50, Loss: 0.26616334915161133
Epoch 16/50, Loss: 0.22834931313991547
Epoch 17/50, Loss: 0.2074802666902542
Epoch 18/50, Loss: 0.2858315706253052
Epoch 19/50, Loss: 0.2955467402935028
Epoch 20/50, Loss: 0.20158526301383972
Epoch 21/50, Loss: 0.18471454083919525
Epoch 22/50, Loss: 0.2831464409828186
Epoch 23/50, Loss: 0.08453170210123062
Epoch 24/50, Loss: 0.23504959046840668
Epoch 25/50, Loss: 0.15674324333667755
Epoch 26/50, Loss: 0.31135642528533936
Epoch 27

In [21]:
torch.save(model.state_dict(),'simple_cnn.pth')

In [2]:
for name, param in torch.load('simple_cnn.pth').items():
    print(f"Name: {name}")
    print(f"Shape: {param.shape}\n")

Name: conv_layer.0.weight
Shape: torch.Size([32, 3, 3, 3])

Name: conv_layer.0.bias
Shape: torch.Size([32])

Name: conv_layer.1.weight
Shape: torch.Size([32])

Name: conv_layer.1.bias
Shape: torch.Size([32])

Name: conv_layer.1.running_mean
Shape: torch.Size([32])

Name: conv_layer.1.running_var
Shape: torch.Size([32])

Name: conv_layer.1.num_batches_tracked
Shape: torch.Size([])

Name: conv_layer.3.weight
Shape: torch.Size([64, 32, 3, 3])

Name: conv_layer.3.bias
Shape: torch.Size([64])

Name: conv_layer.6.weight
Shape: torch.Size([128, 64, 3, 3])

Name: conv_layer.6.bias
Shape: torch.Size([128])

Name: conv_layer.7.weight
Shape: torch.Size([128])

Name: conv_layer.7.bias
Shape: torch.Size([128])

Name: conv_layer.7.running_mean
Shape: torch.Size([128])

Name: conv_layer.7.running_var
Shape: torch.Size([128])

Name: conv_layer.7.num_batches_tracked
Shape: torch.Size([])

Name: fc_layer.0.weight
Shape: torch.Size([512, 8192])

Name: fc_layer.0.bias
Shape: torch.Size([512])

Name: fc_la

In [23]:
print(torch.load('simple_cnn.pth'))

OrderedDict([('conv_layer.0.weight', tensor([[[[ 2.1116e-01,  4.8455e-02, -2.9921e-01],
          [ 1.9832e-01,  2.3508e-01,  9.9809e-03],
          [ 1.2435e-01, -4.5312e-01, -1.2458e-01]],

         [[-2.4210e-01,  1.0128e-01, -3.0559e-01],
          [-8.1651e-02,  2.7562e-01,  5.9083e-02],
          [ 1.9792e-01, -3.0688e-01,  3.1400e-01]],

         [[-2.2417e-01,  2.8597e-01, -1.2115e-02],
          [-2.5839e-01,  4.2692e-02,  1.8737e-01],
          [-2.0617e-01, -2.2708e-01,  3.1691e-01]]],


        [[[-1.8919e-01, -1.9894e-02, -3.8444e-02],
          [ 5.1361e-02, -3.9159e-01,  2.5619e-01],
          [ 3.7105e-02, -3.2774e-01, -2.8004e-02]],

         [[-1.1920e-01,  3.0491e-01,  1.2978e-01],
          [ 4.7603e-01, -9.0025e-02,  5.4110e-01],
          [ 2.0026e-01, -4.5074e-02, -1.4235e-02]],

         [[-1.4389e-03, -2.6135e-01, -1.4051e-01],
          [ 1.3640e-01, -2.9573e-01, -9.7937e-02],
          [ 2.6921e-02, -1.8648e-01,  7.9275e-02]]],


        [[[ 8.2735e-02,  1.19