In [1]:
import torch
import torch.optim as optim
from tqdm import tqdm
from src import build_dataloader, build_model

In [2]:
def train(model, optimizer, data_loader, epoch):
    """Run one training epoch and report the mean per-batch loss.

    Args:
        model: module exposing ``loss(images, labels)``. It is assumed to return a
            scalar tensor, because this function calls ``.backward()`` and
            ``.item()`` on it.
        optimizer: optimizer over ``model``'s parameters; zeroed and stepped once
            per batch.
        data_loader: iterable yielding ``(images, labels)`` pairs.
        epoch: epoch number, used only in the progress-bar label.

    Returns:
        The mean of the per-batch losses. This value is also printed. Returns
        ``float("nan")`` when ``data_loader`` yields no batches.
    """
    model.train()

    # Move batches to wherever the model lives instead of hard-coding `.cuda()`,
    # so the same loop also runs on CPU-only machines. Fall back to CUDA-if-
    # available when the model has no parameters to inspect.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    total_loss = 0.0
    num_batches = 0
    for images, labels in tqdm(data_loader, desc=f"EPOCH: {epoch}"):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = model.loss(images, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # NOTE(review): this is an unweighted mean over batches, so a smaller final
    # batch counts as much as a full one. Weight by batch size if a per-sample
    # mean is wanted. Guard the empty-loader case, which used to raise
    # ZeroDivisionError.
    mean_loss = total_loss / num_batches if num_batches else float("nan")
    print(f"Train Loss: {mean_loss:.2f}")
    return mean_loss

@torch.no_grad()
def test(model, data_loader):
    model.eval()

    test_loss= 0
    i = 0
    for data in data_loader:
        images, labels = data
        images = images.cuda()
        labels = labels.cuda()

        loss = model.loss(images, labels)

        i += 1
        test_loss += loss.item()    
        
    print(f"Test Loss: {test_loss / i:.2f}")

In [3]:
# Run configuration: every knob a reader might want to tune lives here.
SEED = 0
BATCH_SIZE = 64
NUM_EPOCHS = 50
WEIGHTS_PATH = "weight.pt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (CPU and all CUDA devices) before building the data
# pipeline and model, so weight initialisation and any torch-driven shuffling
# are reproducible on Restart & Run All.
torch.manual_seed(SEED)

train_loader, test_loader = build_dataloader(BATCH_SIZE)
model = build_model().to(DEVICE)
optimizer = optim.AdamW(model.parameters())

# Epochs are numbered from 1 for the progress-bar label.
for epoch in range(1, NUM_EPOCHS + 1):
    train(model, optimizer, train_loader, epoch)
    test(model, test_loader)

# Saves the final-epoch weights (not necessarily the lowest test loss).
torch.save(model.state_dict(), WEIGHTS_PATH)

EPOCH: 1: 100%|██████████| 938/938 [00:29<00:00, 31.30it/s]


Train Loss: 3.07
Test Loss: 2.75


EPOCH: 2: 100%|██████████| 938/938 [00:29<00:00, 31.96it/s]


Train Loss: 2.74
Test Loss: 2.68


EPOCH: 3: 100%|██████████| 938/938 [00:28<00:00, 32.68it/s]


Train Loss: 2.70
Test Loss: 2.66


EPOCH: 4: 100%|██████████| 938/938 [00:28<00:00, 32.81it/s]


Train Loss: 2.67
Test Loss: 2.65


EPOCH: 5: 100%|██████████| 938/938 [00:28<00:00, 32.95it/s]


Train Loss: 2.65
Test Loss: 2.64


EPOCH: 6: 100%|██████████| 938/938 [00:28<00:00, 33.14it/s]


Train Loss: 2.64
Test Loss: 2.64


EPOCH: 7: 100%|██████████| 938/938 [00:28<00:00, 32.58it/s]


Train Loss: 2.63
Test Loss: 2.61


EPOCH: 8: 100%|██████████| 938/938 [00:28<00:00, 32.82it/s]


Train Loss: 2.62
Test Loss: 2.60


EPOCH: 9: 100%|██████████| 938/938 [00:28<00:00, 32.97it/s]


Train Loss: 2.61
Test Loss: 2.60


EPOCH: 10: 100%|██████████| 938/938 [00:28<00:00, 33.36it/s]


Train Loss: 2.60
Test Loss: 2.60


EPOCH: 11: 100%|██████████| 938/938 [00:27<00:00, 33.72it/s]


Train Loss: 2.60
Test Loss: 2.62


EPOCH: 12: 100%|██████████| 938/938 [00:27<00:00, 33.55it/s]


Train Loss: 2.60
Test Loss: 2.59


EPOCH: 13: 100%|██████████| 938/938 [00:28<00:00, 33.47it/s]


Train Loss: 2.59
Test Loss: 2.61


EPOCH: 14: 100%|██████████| 938/938 [00:28<00:00, 33.33it/s]


Train Loss: 2.59
Test Loss: 2.59


EPOCH: 15: 100%|██████████| 938/938 [00:28<00:00, 33.35it/s]


Train Loss: 2.58
Test Loss: 2.61


EPOCH: 16: 100%|██████████| 938/938 [00:28<00:00, 33.46it/s]


Train Loss: 2.58
Test Loss: 2.59


EPOCH: 17: 100%|██████████| 938/938 [00:28<00:00, 33.27it/s]


Train Loss: 2.58
Test Loss: 2.60


EPOCH: 18: 100%|██████████| 938/938 [00:28<00:00, 33.41it/s]


Train Loss: 2.57
Test Loss: 2.59


EPOCH: 19: 100%|██████████| 938/938 [00:27<00:00, 33.83it/s]


Train Loss: 2.57
Test Loss: 2.59


EPOCH: 20: 100%|██████████| 938/938 [00:28<00:00, 33.14it/s]


Train Loss: 2.57
Test Loss: 2.59


EPOCH: 21: 100%|██████████| 938/938 [00:28<00:00, 32.91it/s]


Train Loss: 2.57
Test Loss: 2.58


EPOCH: 22: 100%|██████████| 938/938 [00:28<00:00, 32.74it/s]


Train Loss: 2.57
Test Loss: 2.60


EPOCH: 23: 100%|██████████| 938/938 [00:28<00:00, 32.94it/s]


Train Loss: 2.56
Test Loss: 2.59


EPOCH: 24: 100%|██████████| 938/938 [00:28<00:00, 33.29it/s]


Train Loss: 2.56
Test Loss: 2.60


EPOCH: 25: 100%|██████████| 938/938 [00:28<00:00, 33.31it/s]


Train Loss: 2.56
Test Loss: 2.58


EPOCH: 26: 100%|██████████| 938/938 [00:28<00:00, 32.71it/s]


Train Loss: 2.56
Test Loss: 2.59


EPOCH: 27: 100%|██████████| 938/938 [00:28<00:00, 32.81it/s]


Train Loss: 2.56
Test Loss: 2.59


EPOCH: 28: 100%|██████████| 938/938 [00:28<00:00, 33.06it/s]


Train Loss: 2.56
Test Loss: 2.58


EPOCH: 29: 100%|██████████| 938/938 [00:27<00:00, 33.54it/s]


Train Loss: 2.55
Test Loss: 2.59


EPOCH: 30: 100%|██████████| 938/938 [00:27<00:00, 33.64it/s]


Train Loss: 2.55
Test Loss: 2.59


EPOCH: 31: 100%|██████████| 938/938 [00:28<00:00, 33.31it/s]


Train Loss: 2.55
Test Loss: 2.58


EPOCH: 32: 100%|██████████| 938/938 [00:27<00:00, 33.58it/s]


Train Loss: 2.55
Test Loss: 2.58


EPOCH: 33: 100%|██████████| 938/938 [00:28<00:00, 33.28it/s]


Train Loss: 2.55
Test Loss: 2.59


EPOCH: 34: 100%|██████████| 938/938 [00:28<00:00, 32.38it/s]


Train Loss: 2.55
Test Loss: 2.58


EPOCH: 35: 100%|██████████| 938/938 [00:28<00:00, 32.45it/s]


Train Loss: 2.55
Test Loss: 2.59


EPOCH: 36: 100%|██████████| 938/938 [00:28<00:00, 32.87it/s]


Train Loss: 2.55
Test Loss: 2.60


EPOCH: 37: 100%|██████████| 938/938 [00:28<00:00, 32.72it/s]


Train Loss: 2.55
Test Loss: 2.57


EPOCH: 38: 100%|██████████| 938/938 [00:28<00:00, 33.29it/s]


Train Loss: 2.55
Test Loss: 2.58


EPOCH: 39: 100%|██████████| 938/938 [00:28<00:00, 32.91it/s]


Train Loss: 2.55
Test Loss: 2.58


EPOCH: 40: 100%|██████████| 938/938 [00:29<00:00, 32.24it/s]


Train Loss: 2.54
Test Loss: 2.59


EPOCH: 41: 100%|██████████| 938/938 [00:28<00:00, 33.18it/s]


Train Loss: 2.54
Test Loss: 2.59


EPOCH: 42: 100%|██████████| 938/938 [00:28<00:00, 32.61it/s]


Train Loss: 2.54
Test Loss: 2.59


EPOCH: 43: 100%|██████████| 938/938 [00:28<00:00, 32.72it/s]


Train Loss: 2.55
Test Loss: 2.60


EPOCH: 44: 100%|██████████| 938/938 [00:28<00:00, 33.11it/s]


Train Loss: 2.54
Test Loss: 2.58


EPOCH: 45: 100%|██████████| 938/938 [00:28<00:00, 32.82it/s]


Train Loss: 2.54
Test Loss: 2.58


EPOCH: 46: 100%|██████████| 938/938 [00:28<00:00, 33.26it/s]


Train Loss: 2.54
Test Loss: 2.58


EPOCH: 47: 100%|██████████| 938/938 [00:28<00:00, 32.70it/s]


Train Loss: 2.54
Test Loss: 2.58


EPOCH: 48: 100%|██████████| 938/938 [00:29<00:00, 32.06it/s]


Train Loss: 2.54
Test Loss: 2.59


EPOCH: 49: 100%|██████████| 938/938 [00:28<00:00, 32.66it/s]


Train Loss: 2.54
Test Loss: 2.60


EPOCH: 50: 100%|██████████| 938/938 [00:28<00:00, 33.09it/s]


Train Loss: 2.54
Test Loss: 2.60
