In [1]:
import torch
import torch.optim as optim
from tqdm import tqdm
from src import build_dataloader, build_model

In [2]:
def train(model, optimizer, data_loader, epoch, device=None):
    """Run one training epoch and print the mean per-batch loss.

    Args:
        model: module exposing ``loss(images, labels)``, which must return a
            one-element tensor (it is reduced with ``.item()`` and
            back-propagated with ``.backward()``).
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        data_loader: iterable yielding ``(images, labels)`` batches.
        epoch: epoch number, used only for the progress-bar label.
        device: where each batch is moved before the forward pass. Defaults to
            the device of the model's first parameter (``"cuda"`` if the model
            has no parameters), so data always lands next to the weights.

    Returns:
        The mean of the per-batch losses, or ``float("nan")`` when
        ``data_loader`` yields no batches (instead of dividing by zero).
        Note: this is an unweighted mean over batches, so a smaller final
        batch counts as much as a full one.
    """
    model.train()

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cuda")

    running_loss = 0.0
    num_batches = 0
    for images, labels in tqdm(data_loader, desc=f"EPOCH: {epoch}"):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = model.loss(images, labels)

        # Accumulate a plain Python float so no autograd graph is kept alive
        # across iterations.
        running_loss += loss.item()
        num_batches += 1

        loss.backward()
        optimizer.step()

    mean_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Train Loss: {mean_loss:.2f}")
    return mean_loss

@torch.no_grad()
def test(model, data_loader):
    model.eval()

    test_loss= 0
    i = 0
    for data in data_loader:
        images, labels = data
        images = images.cuda()
        labels = labels.cuda()

        loss = model.loss(images, labels)

        i += 1
        test_loss += loss.item()    
        
    print(f"Test Loss: {test_loss / i:.2f}")

In [3]:
# Run configuration — tune here rather than inside the code below.
SEED = 0
BATCH_SIZE = 64  # first argument of build_dataloader; assumed to be the batch size (938 batches/epoch) — confirm in src
NUM_EPOCHS = 30
WEIGHTS_PATH = "weight.pt"

# Seed torch's RNGs (CPU and CUDA) so torch-driven randomness such as weight
# init is repeatable across Restart & Run All. Some cuDNN kernels may still be
# nondeterministic.
torch.manual_seed(SEED)

train_loader, test_loader = build_dataloader(BATCH_SIZE)
model = build_model().cuda()
optimizer = optim.Adam(model.parameters())

for epoch in range(1, NUM_EPOCHS + 1):
    train(model, optimizer, train_loader, epoch)
    test(model, test_loader)

torch.save(model.state_dict(), WEIGHTS_PATH)

EPOCH: 1: 100%|██████████| 938/938 [00:28<00:00, 33.21it/s]


Train Loss: 2.75
Test Loss: 2.54


EPOCH: 2: 100%|██████████| 938/938 [00:26<00:00, 35.18it/s]


Train Loss: 2.54
Test Loss: 2.50


EPOCH: 3: 100%|██████████| 938/938 [00:27<00:00, 33.78it/s]


Train Loss: 2.50
Test Loss: 2.45


EPOCH: 4: 100%|██████████| 938/938 [00:28<00:00, 32.97it/s]


Train Loss: 2.48
Test Loss: 2.46


EPOCH: 5: 100%|██████████| 938/938 [00:27<00:00, 33.93it/s]


Train Loss: 2.46
Test Loss: 2.43


EPOCH: 6: 100%|██████████| 938/938 [00:27<00:00, 33.75it/s]


Train Loss: 2.45
Test Loss: 2.45


EPOCH: 7: 100%|██████████| 938/938 [00:27<00:00, 33.97it/s]


Train Loss: 2.44
Test Loss: 2.44


EPOCH: 8: 100%|██████████| 938/938 [00:26<00:00, 34.75it/s]


Train Loss: 2.44
Test Loss: 2.43


EPOCH: 9: 100%|██████████| 938/938 [00:26<00:00, 35.03it/s]


Train Loss: 2.43
Test Loss: 2.43


EPOCH: 10: 100%|██████████| 938/938 [00:26<00:00, 35.04it/s]


Train Loss: 2.42
Test Loss: 2.42


EPOCH: 11: 100%|██████████| 938/938 [00:26<00:00, 34.75it/s]


Train Loss: 2.42
Test Loss: 2.42


EPOCH: 12: 100%|██████████| 938/938 [00:26<00:00, 35.27it/s]


Train Loss: 2.41
Test Loss: 2.42


EPOCH: 13: 100%|██████████| 938/938 [00:25<00:00, 36.21it/s]


Train Loss: 2.41
Test Loss: 2.42


EPOCH: 14: 100%|██████████| 938/938 [00:26<00:00, 35.89it/s]


Train Loss: 2.41
Test Loss: 2.41


EPOCH: 15: 100%|██████████| 938/938 [00:26<00:00, 36.05it/s]


Train Loss: 2.40
Test Loss: 2.42


EPOCH: 16: 100%|██████████| 938/938 [00:26<00:00, 35.85it/s]


Train Loss: 2.40
Test Loss: 2.42


EPOCH: 17: 100%|██████████| 938/938 [00:25<00:00, 36.25it/s]


Train Loss: 2.40
Test Loss: 2.42


EPOCH: 18: 100%|██████████| 938/938 [00:25<00:00, 36.29it/s]


Train Loss: 2.40
Test Loss: 2.41


EPOCH: 19: 100%|██████████| 938/938 [00:26<00:00, 36.06it/s]


Train Loss: 2.40
Test Loss: 2.41


EPOCH: 20: 100%|██████████| 938/938 [00:25<00:00, 36.10it/s]


Train Loss: 2.39
Test Loss: 2.41


EPOCH: 21: 100%|██████████| 938/938 [00:25<00:00, 36.27it/s]


Train Loss: 2.39
Test Loss: 2.42


EPOCH: 22: 100%|██████████| 938/938 [00:26<00:00, 35.45it/s]


Train Loss: 2.39
Test Loss: 2.42


EPOCH: 23: 100%|██████████| 938/938 [00:25<00:00, 36.31it/s]


Train Loss: 2.39
Test Loss: 2.42


EPOCH: 24: 100%|██████████| 938/938 [00:25<00:00, 36.17it/s]


Train Loss: 2.38
Test Loss: 2.40


EPOCH: 25: 100%|██████████| 938/938 [00:25<00:00, 36.23it/s]


Train Loss: 2.39
Test Loss: 2.41


EPOCH: 26: 100%|██████████| 938/938 [00:25<00:00, 36.23it/s]


Train Loss: 2.38
Test Loss: 2.41


EPOCH: 27: 100%|██████████| 938/938 [00:26<00:00, 36.07it/s]


Train Loss: 2.38
Test Loss: 2.41


EPOCH: 28: 100%|██████████| 938/938 [00:25<00:00, 36.12it/s]


Train Loss: 2.38
Test Loss: 2.42


EPOCH: 29: 100%|██████████| 938/938 [00:26<00:00, 36.00it/s]


Train Loss: 2.38
Test Loss: 2.43


EPOCH: 30: 100%|██████████| 938/938 [00:25<00:00, 36.10it/s]


Train Loss: 2.38
Test Loss: 2.42
