In [4]:

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

from tqdm import tqdm
from model import Net
from data_loader import get_train_loader
from data_loader import get_test_loader

# Run-wide configuration: device selection, RNG seed and hyper-parameters.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Seed torch's RNG before any Net() is constructed so initial weights repeat run to run.
torch.manual_seed(10)
# NOTE(review): batch_size is never passed to get_train_loader()/get_test_loader() below.
# The recorded run shows 782 batches per epoch; if the train set is 50k images
# (presumably CIFAR-10, given the 3x32x32 summary input) that implies a batch size
# of 64, not 128 — confirm what data_loader actually uses.
batch_size = 128
EPOCHS=30  # number of train/test passes in the training loop
# NOTE(review): LAMBDA is not referenced anywhere in the visible code — presumably
# meant as a regularization weight; wire it in or remove it.
LAMBDA=0.1
# NOTE(review): kwargs is never passed to the loaders below, so num_workers /
# pin_memory here have no effect unless data_loader builds the same settings itself.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}



# This instance only feeds the summary() layer table printed in the output;
# the training cell constructs a fresh Net() and trains that one instead.
model = Net().to(device)
summary(model, input_size=(3, 32, 32))



# Metric histories mutated by train() (one entry per batch) and
# test() (one entry per call, i.e. once per epoch in the loop).
train_losses = []
test_losses = []
train_acc = []
test_acc = []

# Loader construction is delegated to the project's data_loader module.
train_loader = get_train_loader()
test_loader = get_test_loader()

def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch and record per-batch metrics.

    After every batch this appends the batch loss (as a Python float) to the
    module-level ``train_losses`` list and the running training accuracy
    (percent) to ``train_acc``. It also updates a tqdm progress bar.

    Args:
        model: network to train; switched to training mode here.
        device: device each batch is moved to before the forward pass.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index (not used inside; kept for call compatibility).
    """
    model.train()
    # The loss module is stateless, so build it once per epoch, not once per batch.
    criterion = nn.CrossEntropyLoss()
    pbar = tqdm(train_loader)
    correct = 0
    processed = 0
    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        y_pred = model(data)

        # Calculate loss and backpropagate
        loss = criterion(y_pred, target)
        loss.backward()
        optimizer.step()

        # Record a plain float. Appending the tensor itself would keep it (and
        # its grad_fn graph nodes) alive on the device for every batch of every
        # epoch. A float also matches test(), which stores floats in test_losses.
        loss_value = loss.item()
        train_losses.append(loss_value)

        pred = y_pred.argmax(dim=1, keepdim=True)  # index of the highest-scoring class
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        accuracy = 100 * correct / processed  # running accuracy over this epoch so far
        pbar.set_description(desc=f'Loss={loss_value} Batch_id={batch_idx} Accuracy={accuracy:0.2f}')
        train_acc.append(accuracy)


def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and record epoch-level metrics.

    Prints the average per-sample loss and the accuracy, then appends them to
    the module-level ``test_losses`` and ``test_acc`` lists.

    The loss is cross-entropy, matching the ``nn.CrossEntropyLoss`` that
    ``train`` optimizes:
    - If the model outputs raw logits, this is the correct loss (``F.nll_loss``
      on logits is not).
    - If the model already outputs log-softmax, the result equals
      ``F.nll_loss``, because log-softmax is idempotent.

    Args:
        model: network to evaluate; switched to eval mode here.
        device: device each batch is moved to.
        test_loader: iterable of ``(data, target)`` batches.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0  # samples actually evaluated (robust to drop_last / subset samplers)
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum over the batch here; divided by the sample count after the loop.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)  # index of the highest-scoring class
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)

    denom = max(total, 1)  # avoid ZeroDivisionError on an empty loader
    test_loss /= denom
    accuracy = 100. * correct / denom
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, total, accuracy))

    test_acc.append(accuracy)


from torch.optim.lr_scheduler import StepLR  # only referenced by the commented-out scheduler lines
# Fresh, untrained model for this run; the instance built for summary() is discarded.
model = Net().to(device)
# SGD with classical (non-Nesterov) momentum. The lr stays at 0.01 for every epoch
# because the scheduler below is disabled.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9,nesterov=False)
# Disabled step decay: if re-enabled, StepLR multiplies the lr by gamma=0.1 every 8 epochs.
# NOTE(review): in the recorded run, train accuracy keeps rising (~92%) while test
# accuracy plateaus near 80% from about epoch 21 — consider re-enabling lr decay
# or adding regularization.
#scheduler = StepLR(optimizer, step_size=8, gamma=0.1)
for epoch in range(EPOCHS):
    print('Epoch:', epoch+1)  # 1-based for display; `epoch` itself stays 0-based
    train(model, device, train_loader, optimizer, epoch)
    #scheduler.step()
    test(model, device, test_loader)  # evaluate after every training epoch

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 30, 30]             864
       BatchNorm2d-2           [-1, 32, 30, 30]              64
              ReLU-3           [-1, 32, 30, 30]               0
           Dropout-4           [-1, 32, 30, 30]               0
            Conv2d-5           [-1, 64, 28, 28]          18,432
       BatchNorm2d-6           [-1, 64, 28, 28]             128
              ReLU-7           [-1, 64, 28, 28]               0
           Dropout-8           [-1, 64, 28, 28]               0
            Conv2d-9          [-1, 128, 26, 26]          73,728
      BatchNorm2d-10          [-1, 128, 26, 26]             256
             ReLU-11          [-1, 128, 26, 26]               0
          Dropout-12          [-1, 128, 26, 26]               0
           Conv2d-13          [-1, 230, 24, 24]         264,960
      BatchNorm2d-14          [-1, 230,

  0%|          | 0/782 [00:00<?, ?it/s]

Epoch: 1


Loss=1.848878026008606 Batch_id=781 Accuracy=35.26: 100%|██████████| 782/782 [00:35<00:00, 22.11it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 1.4330, Accuracy: 4714/10000 (47.14%)

Epoch: 2


Loss=1.0872299671173096 Batch_id=781 Accuracy=51.79: 100%|██████████| 782/782 [00:35<00:00, 22.09it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 1.1855, Accuracy: 5801/10000 (58.01%)

Epoch: 3


Loss=1.0600894689559937 Batch_id=781 Accuracy=59.39: 100%|██████████| 782/782 [00:34<00:00, 22.40it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 1.0358, Accuracy: 6382/10000 (63.82%)

Epoch: 4


Loss=0.9389859437942505 Batch_id=781 Accuracy=64.95: 100%|██████████| 782/782 [00:35<00:00, 22.29it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.9185, Accuracy: 6821/10000 (68.21%)

Epoch: 5


Loss=0.7230312824249268 Batch_id=781 Accuracy=68.91: 100%|██████████| 782/782 [00:35<00:00, 22.00it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.8741, Accuracy: 7019/10000 (70.19%)

Epoch: 6


Loss=0.6350700855255127 Batch_id=781 Accuracy=71.98: 100%|██████████| 782/782 [00:34<00:00, 22.50it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.7758, Accuracy: 7359/10000 (73.59%)

Epoch: 7


Loss=1.2888493537902832 Batch_id=781 Accuracy=74.12: 100%|██████████| 782/782 [00:35<00:00, 22.06it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.7384, Accuracy: 7473/10000 (74.73%)

Epoch: 8


Loss=0.9594882130622864 Batch_id=781 Accuracy=76.11: 100%|██████████| 782/782 [00:35<00:00, 22.10it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.7729, Accuracy: 7356/10000 (73.56%)

Epoch: 9


Loss=0.42239847779273987 Batch_id=781 Accuracy=77.48: 100%|██████████| 782/782 [00:35<00:00, 22.16it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6728, Accuracy: 7707/10000 (77.07%)

Epoch: 10


Loss=0.6624289751052856 Batch_id=781 Accuracy=79.11: 100%|██████████| 782/782 [00:34<00:00, 22.41it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.7020, Accuracy: 7603/10000 (76.03%)

Epoch: 11


Loss=0.4350656569004059 Batch_id=781 Accuracy=80.39: 100%|██████████| 782/782 [00:35<00:00, 22.11it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6906, Accuracy: 7662/10000 (76.62%)

Epoch: 12


Loss=1.133341908454895 Batch_id=781 Accuracy=81.23: 100%|██████████| 782/782 [00:35<00:00, 22.18it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6620, Accuracy: 7778/10000 (77.78%)

Epoch: 13


Loss=0.5381175875663757 Batch_id=781 Accuracy=82.36: 100%|██████████| 782/782 [00:35<00:00, 22.21it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6698, Accuracy: 7760/10000 (77.60%)

Epoch: 14


Loss=0.351249098777771 Batch_id=781 Accuracy=83.19: 100%|██████████| 782/782 [00:35<00:00, 22.07it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6938, Accuracy: 7738/10000 (77.38%)

Epoch: 15


Loss=0.5164002180099487 Batch_id=781 Accuracy=84.20: 100%|██████████| 782/782 [00:35<00:00, 22.14it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6310, Accuracy: 7897/10000 (78.97%)

Epoch: 16


Loss=0.9458195567131042 Batch_id=781 Accuracy=84.80: 100%|██████████| 782/782 [00:35<00:00, 22.25it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6846, Accuracy: 7747/10000 (77.47%)

Epoch: 17


Loss=1.722427248954773 Batch_id=781 Accuracy=85.77: 100%|██████████| 782/782 [00:35<00:00, 22.11it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6486, Accuracy: 7936/10000 (79.36%)

Epoch: 18


Loss=0.9374287724494934 Batch_id=781 Accuracy=86.32: 100%|██████████| 782/782 [00:34<00:00, 22.43it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6784, Accuracy: 7853/10000 (78.53%)

Epoch: 19


Loss=1.0039093494415283 Batch_id=781 Accuracy=86.80: 100%|██████████| 782/782 [00:35<00:00, 22.31it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6710, Accuracy: 7910/10000 (79.10%)

Epoch: 20


Loss=0.05501816049218178 Batch_id=781 Accuracy=87.42: 100%|██████████| 782/782 [00:35<00:00, 22.11it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6569, Accuracy: 7958/10000 (79.58%)

Epoch: 21


Loss=0.5705156922340393 Batch_id=781 Accuracy=88.27: 100%|██████████| 782/782 [00:34<00:00, 22.47it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6401, Accuracy: 8010/10000 (80.10%)

Epoch: 22


Loss=0.4717254936695099 Batch_id=781 Accuracy=88.77: 100%|██████████| 782/782 [00:34<00:00, 22.44it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6737, Accuracy: 7933/10000 (79.33%)

Epoch: 23


Loss=0.1527709811925888 Batch_id=781 Accuracy=89.25: 100%|██████████| 782/782 [00:34<00:00, 22.49it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6697, Accuracy: 7958/10000 (79.58%)

Epoch: 24


Loss=0.2850653827190399 Batch_id=781 Accuracy=89.67: 100%|██████████| 782/782 [00:34<00:00, 22.38it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6761, Accuracy: 8012/10000 (80.12%)

Epoch: 25


Loss=0.2851686477661133 Batch_id=781 Accuracy=90.33: 100%|██████████| 782/782 [00:35<00:00, 22.30it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6842, Accuracy: 7967/10000 (79.67%)

Epoch: 26


Loss=0.14348618686199188 Batch_id=781 Accuracy=90.53: 100%|██████████| 782/782 [00:34<00:00, 22.64it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6916, Accuracy: 7955/10000 (79.55%)

Epoch: 27


Loss=0.9021440744400024 Batch_id=781 Accuracy=90.94: 100%|██████████| 782/782 [00:35<00:00, 22.21it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6732, Accuracy: 7983/10000 (79.83%)

Epoch: 28


Loss=0.6575213074684143 Batch_id=781 Accuracy=91.02: 100%|██████████| 782/782 [00:34<00:00, 22.73it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6785, Accuracy: 8013/10000 (80.13%)

Epoch: 29


Loss=0.7011525630950928 Batch_id=781 Accuracy=91.51: 100%|██████████| 782/782 [00:35<00:00, 22.29it/s]
  0%|          | 0/782 [00:00<?, ?it/s]


Test set: Average loss: 0.6728, Accuracy: 8081/10000 (80.81%)

Epoch: 30


Loss=0.6145514845848083 Batch_id=781 Accuracy=92.04: 100%|██████████| 782/782 [00:34<00:00, 22.54it/s]



Test set: Average loss: 0.7502, Accuracy: 7941/10000 (79.41%)

