In [2]:
import torch
from torch import nn 
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from collections import namedtuple
from sklearn.metrics import classification_report

In [12]:
from torch._C import device
from torch.nn.modules.batchnorm import BatchNorm2d
from collections import namedtuple
def get_classes():
  classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
  return classes

# khai bao tap train va test theo cau truc de de su dung
TrainTest = namedtuple('TrainTest', ['train', 'test'])

def prepare_data():
  transform = transforms.Compose([
    transforms.ToTensor()
  ])
  trainset = torchvision.datasets.CIFAR10(root='.\data', download=True, train=True, transform=transform)
  testset = torchvision.datasets.CIFAR10(root='.\data', download=True, train=False, transform=transform)
  return TrainTest(train=trainset, test=testset)

# cung cap du lieu theo tung batch, moi lan dua 1 batch gom 128 anh vao de training
def prepare_loader(datasets):
  trainloader = DataLoader(dataset=datasets.train, batch_size=128, shuffle=True, num_workers=4)
  testloader = DataLoader(dataset=datasets.test, batch_size=128, shuffle=False, num_workers=4)
  return TrainTest(train=trainloader, test=testloader)

class VGG16(nn.Module):
  def __init__(self):
    super().__init__() 
    self.features = self._make_features()
    self.classification_head = nn.Linear(in_features=512, out_features=10)
  
  def forward(self, x):  # anh dau vao la x
    out = self.features(x) # dau ra la anh khi ta cho qua cac layers
    out = out.view(out.size(0), -1) # dau vao cua lop full 4-conected la anh 128x512x1x1
    # ta can chuyen anh ve dang 128x512, cac chieu phia sau bi gop lai
    
    out = self.classification_head(out) #128x10
    return out

  def _make_features(self):
    config = [64,64,'MP',128,128,'MP',256,256,256,'MP',512,512,512,'MP',512,512,512,'MP']
    layers = [] # tao mang chua cac layers
    c_in=3 # anh mau RGB
    for c in config:
      if c == 'MP':
        layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
      else:
        layers += [nn.Conv2d(in_channels=c_in, 
                             out_channels=c,
                             kernel_size=3,
                             stride=1,
                             padding=1), #padding=1 lam cho kich thuoc ma tran dau ra khong doi
                   nn.BatchNorm2d(num_features=c),
                   nn.ReLU6(inplace=True)] # lam` viec truc tiep tren dau ra
        c_in = c # moi lop se lam thay doi so kenh cua anh
    return nn.Sequential(*layers) # tao ra network gom cac layers

def imshow(images, labels, predicted, target_names):
  img = torchvision.utils.make_grid(images)
  plt.imshow(img.permute(1,2,0).cpu().numpy())
  [print(target_names[c], end=' ') for c in list(labels.cpu().numpy())]
  print() # nhan that 
  [print(target_names[c], end=' ') for c in list(predicted.cpu().numpy())]
  print() # nhan du doan

def train_epoch(epoch, model, loader, loss_func, optimizer, device):
  model.train() 
  running_loss = 0.0
  reporting_steps = 60 # moi 60 vong lap report 1 lan
  for i, (images, labels) in enumerate(loader):
    #model.train()
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    loss = loss_func(outputs, labels)

    optimizer.zero_grad()
    loss.backward() # tinh dao ham tu ham loi
    optimizer.step()

    # quan sat qua trinh training
    running_loss += loss.item()
    if i % reporting_steps == reporting_steps-1:
      print(f"Epoch {epoch} step {i} ave_loss {running_loss/reporting_steps:.4f}")
      running_loss = 0.0

def test_epoch(epoch, model, loader, device):
  ytrue = []
  ypred = []
  with torch.no_grad():
    model.eval()
  
    for i, (images, labels) in enumerate(loader):
      images, labels = images.to(device), labels.to(device)
      outputs = model(images)
      _, predicted = torch.max(outputs, dim=1)

      ytrue += list(labels.cpu().numpy())
      ypred += list(predicted.cpu().numpy())

  return ypred, ytrue

def main():
  classes = get_classes()
  datasets = prepare_data()

  loaders = prepare_loader(datasets)

  # train bang gpu
  device = torch.device("cuda:0")
  model = VGG16().to(device)
  
  loss_func = nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
  for epoch in range(10):
    train_epoch(epoch, model, loaders.train, loss_func, optimizer, device)
    ypred, ytrue = test_epoch(epoch, model, loaders.test, device)
    print(classification_report(ytrue, ypred))

main()

Files already downloaded and verified
Files already downloaded and verified


  cpuset_checked))


Epoch 0 step 59 ave_loss 1.9708
Epoch 0 step 119 ave_loss 1.5822
Epoch 0 step 179 ave_loss 1.3671
Epoch 0 step 239 ave_loss 1.2290
Epoch 0 step 299 ave_loss 1.1308
Epoch 0 step 359 ave_loss 1.0406
              precision    recall  f1-score   support

           0       0.90      0.35      0.51      1000
           1       0.79      0.83      0.81      1000
           2       0.46      0.54      0.50      1000
           3       0.35      0.37      0.36      1000
           4       0.59      0.66      0.62      1000
           5       0.46      0.73      0.56      1000
           6       0.78      0.71      0.75      1000
           7       0.82      0.60      0.69      1000
           8       0.85      0.69      0.76      1000
           9       0.72      0.84      0.78      1000

    accuracy                           0.63     10000
   macro avg       0.67      0.63      0.63     10000
weighted avg       0.67      0.63      0.63     10000



  cpuset_checked))


Epoch 1 step 59 ave_loss 0.8696
Epoch 1 step 119 ave_loss 0.8416
Epoch 1 step 179 ave_loss 0.7882
Epoch 1 step 239 ave_loss 0.7671
Epoch 1 step 299 ave_loss 0.7569
Epoch 1 step 359 ave_loss 0.7241
              precision    recall  f1-score   support

           0       0.78      0.73      0.75      1000
           1       0.76      0.94      0.84      1000
           2       0.43      0.78      0.55      1000
           3       0.52      0.59      0.55      1000
           4       0.80      0.53      0.64      1000
           5       0.81      0.43      0.56      1000
           6       0.89      0.66      0.76      1000
           7       0.87      0.70      0.78      1000
           8       0.80      0.90      0.85      1000
           9       0.82      0.84      0.83      1000

    accuracy                           0.71     10000
   macro avg       0.75      0.71      0.71     10000
weighted avg       0.75      0.71      0.71     10000



  cpuset_checked))


Epoch 2 step 59 ave_loss 0.5995
Epoch 2 step 119 ave_loss 0.5940
Epoch 2 step 179 ave_loss 0.5712
Epoch 2 step 239 ave_loss 0.5750
Epoch 2 step 299 ave_loss 0.5477
Epoch 2 step 359 ave_loss 0.5655
              precision    recall  f1-score   support

           0       0.72      0.84      0.78      1000
           1       0.95      0.75      0.84      1000
           2       0.79      0.58      0.67      1000
           3       0.76      0.39      0.51      1000
           4       0.75      0.80      0.78      1000
           5       0.65      0.79      0.71      1000
           6       0.72      0.91      0.81      1000
           7       0.81      0.87      0.83      1000
           8       0.97      0.62      0.76      1000
           9       0.62      0.95      0.75      1000

    accuracy                           0.75     10000
   macro avg       0.77      0.75      0.74     10000
weighted avg       0.77      0.75      0.74     10000



  cpuset_checked))


Epoch 3 step 59 ave_loss 0.4594
Epoch 3 step 119 ave_loss 0.4367
Epoch 3 step 179 ave_loss 0.4682
Epoch 3 step 239 ave_loss 0.4645
Epoch 3 step 299 ave_loss 0.4439
Epoch 3 step 359 ave_loss 0.4394
              precision    recall  f1-score   support

           0       0.83      0.80      0.82      1000
           1       0.92      0.91      0.91      1000
           2       0.73      0.67      0.70      1000
           3       0.75      0.44      0.55      1000
           4       0.55      0.93      0.69      1000
           5       0.87      0.45      0.59      1000
           6       0.73      0.89      0.81      1000
           7       0.74      0.86      0.80      1000
           8       0.89      0.89      0.89      1000
           9       0.88      0.86      0.87      1000

    accuracy                           0.77     10000
   macro avg       0.79      0.77      0.76     10000
weighted avg       0.79      0.77      0.76     10000



  cpuset_checked))


Epoch 4 step 59 ave_loss 0.3449
Epoch 4 step 119 ave_loss 0.3535
Epoch 4 step 179 ave_loss 0.3592
Epoch 4 step 239 ave_loss 0.3567
Epoch 4 step 299 ave_loss 0.3729
Epoch 4 step 359 ave_loss 0.3527
              precision    recall  f1-score   support

           0       0.81      0.80      0.80      1000
           1       0.96      0.84      0.90      1000
           2       0.68      0.69      0.68      1000
           3       0.67      0.54      0.59      1000
           4       0.60      0.87      0.71      1000
           5       0.94      0.36      0.52      1000
           6       0.61      0.95      0.75      1000
           7       0.90      0.73      0.80      1000
           8       0.82      0.94      0.88      1000
           9       0.89      0.87      0.88      1000

    accuracy                           0.76     10000
   macro avg       0.79      0.76      0.75     10000
weighted avg       0.79      0.76      0.75     10000



  cpuset_checked))


Epoch 5 step 59 ave_loss 0.2474
Epoch 5 step 119 ave_loss 0.2627
Epoch 5 step 179 ave_loss 0.2818
Epoch 5 step 239 ave_loss 0.2862
Epoch 5 step 299 ave_loss 0.2807
Epoch 5 step 359 ave_loss 0.3001
              precision    recall  f1-score   support

           0       0.81      0.82      0.81      1000
           1       0.92      0.91      0.91      1000
           2       0.71      0.79      0.75      1000
           3       0.57      0.79      0.66      1000
           4       0.79      0.82      0.80      1000
           5       0.80      0.67      0.73      1000
           6       0.94      0.75      0.83      1000
           7       0.96      0.76      0.85      1000
           8       0.82      0.95      0.88      1000
           9       0.95      0.82      0.89      1000

    accuracy                           0.81     10000
   macro avg       0.83      0.81      0.81     10000
weighted avg       0.83      0.81      0.81     10000



  cpuset_checked))


Epoch 6 step 59 ave_loss 0.1823
Epoch 6 step 119 ave_loss 0.2183
Epoch 6 step 179 ave_loss 0.2389
Epoch 6 step 239 ave_loss 0.2122
Epoch 6 step 299 ave_loss 0.2414
Epoch 6 step 359 ave_loss 0.2557
              precision    recall  f1-score   support

           0       0.80      0.89      0.84      1000
           1       0.94      0.91      0.93      1000
           2       0.62      0.84      0.72      1000
           3       0.64      0.73      0.68      1000
           4       0.84      0.75      0.79      1000
           5       0.92      0.54      0.68      1000
           6       0.93      0.82      0.87      1000
           7       0.87      0.87      0.87      1000
           8       0.89      0.93      0.91      1000
           9       0.88      0.91      0.89      1000

    accuracy                           0.82     10000
   macro avg       0.83      0.82      0.82     10000
weighted avg       0.83      0.82      0.82     10000



  cpuset_checked))


Epoch 7 step 59 ave_loss 0.1357
Epoch 7 step 119 ave_loss 0.1677
Epoch 7 step 179 ave_loss 0.1778
Epoch 7 step 239 ave_loss 0.1922
Epoch 7 step 299 ave_loss 0.1869
Epoch 7 step 359 ave_loss 0.2004
              precision    recall  f1-score   support

           0       0.81      0.88      0.85      1000
           1       0.95      0.89      0.92      1000
           2       0.80      0.69      0.74      1000
           3       0.62      0.73      0.67      1000
           4       0.78      0.85      0.81      1000
           5       0.67      0.79      0.73      1000
           6       0.91      0.83      0.87      1000
           7       0.96      0.75      0.84      1000
           8       0.92      0.91      0.91      1000
           9       0.92      0.90      0.91      1000

    accuracy                           0.82     10000
   macro avg       0.83      0.82      0.82     10000
weighted avg       0.83      0.82      0.82     10000



  cpuset_checked))


Epoch 8 step 59 ave_loss 0.1308
Epoch 8 step 119 ave_loss 0.1149
Epoch 8 step 179 ave_loss 0.1343
Epoch 8 step 239 ave_loss 0.1428
Epoch 8 step 299 ave_loss 0.1621
Epoch 8 step 359 ave_loss 0.1720
              precision    recall  f1-score   support

           0       0.77      0.91      0.83      1000
           1       0.87      0.95      0.91      1000
           2       0.91      0.63      0.74      1000
           3       0.69      0.68      0.68      1000
           4       0.80      0.83      0.81      1000
           5       0.68      0.82      0.74      1000
           6       0.86      0.89      0.87      1000
           7       0.87      0.84      0.86      1000
           8       0.96      0.79      0.87      1000
           9       0.90      0.89      0.90      1000

    accuracy                           0.82     10000
   macro avg       0.83      0.82      0.82     10000
weighted avg       0.83      0.82      0.82     10000



  cpuset_checked))


Epoch 9 step 59 ave_loss 0.1035
Epoch 9 step 119 ave_loss 0.1092
Epoch 9 step 179 ave_loss 0.1086
Epoch 9 step 239 ave_loss 0.1226
Epoch 9 step 299 ave_loss 0.1365
Epoch 9 step 359 ave_loss 0.1385
              precision    recall  f1-score   support

           0       0.86      0.85      0.85      1000
           1       0.88      0.94      0.91      1000
           2       0.75      0.77      0.76      1000
           3       0.54      0.83      0.65      1000
           4       0.89      0.72      0.80      1000
           5       0.90      0.56      0.69      1000
           6       0.85      0.88      0.87      1000
           7       0.89      0.85      0.87      1000
           8       0.91      0.90      0.91      1000
           9       0.91      0.90      0.90      1000

    accuracy                           0.82     10000
   macro avg       0.84      0.82      0.82     10000
weighted avg       0.84      0.82      0.82     10000

