In [12]:
import json
import os
import sys

# Make the project's model definitions importable before the local import below.
sys.path.append("../models")

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdm

from model_v1 import CnnModel


In [13]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device


'cpu'

In [14]:
# Preprocessing + augmentation pipeline for training images, applied in order:
# fixed-size resize, random flip, small random rotation, brightness/contrast
# jitter, and finally conversion to a tensor.
trainTransform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
    ]
)


In [15]:
# Training split: images are read from ../data/train and pass through the
# random augmentation pipeline (trainTransform); batches are reshuffled.
trainData = datasets.ImageFolder("../data/train", transform=trainTransform)
trainLoader = DataLoader(trainData, batch_size=16, shuffle=True)

# Test split: use a deterministic pipeline (same resize + tensor conversion,
# no random flip / rotation / colour jitter). Re-using trainTransform here
# would randomly perturb the held-out images, so evaluation results would
# change from run to run and would not reflect the unmodified test data.
testTransform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])
testData = datasets.ImageFolder("../data/test", transform=testTransform)
testLoader = DataLoader(testData, batch_size=16, shuffle=False)


In [16]:
# Build the network and move its parameters onto the selected device.
# The bare name on the last line makes the notebook display the architecture.
model = CnnModel().to(device)
model


CnnModel(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (adapt): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc1): Linear(in_features=32, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=2, bias=True)
)

In [17]:
# Optimiser: Adam over every parameter of the model, learning rate 3e-4.
opt = optim.Adam(params=model.parameters(), lr=3e-4)

# Objective: cross-entropy between the model's class scores and integer labels.
lossFunc = nn.CrossEntropyLoss()


In [18]:
# Train the model for a fixed number of epochs and record per-epoch metrics.
# `tqdm` is imported in the notebook's import cell.
epochs = 20
history = {}  # 0-based epoch index -> {"loss": ..., "accuracy": ...}

for e in range(epochs):
    print(f"Epoch {e+1}/{epochs}")
    loop = tqdm(trainLoader)

    # Explicitly select training mode at the start of every epoch. In a
    # notebook another cell (e.g. an evaluation pass) may have left the model
    # in eval mode; re-running this cell would otherwise train in that mode.
    model.train()

    total = 0        # number of samples seen this epoch
    correct = 0      # number of samples whose predicted class matched the label
    runningLoss = 0  # sum of the per-batch loss values (not a per-sample mean)

    for imgs, labels in loop:
        imgs = imgs.to(device)
        labels = labels.to(device)

        # Clear gradients from the previous step before the new forward/backward pass.
        opt.zero_grad()

        out = model(imgs)
        loss = lossFunc(out, labels)

        loss.backward()
        opt.step()

        runningLoss += loss.item()
        # Predicted class = index of the largest score along dim 1.
        _, preds = torch.max(out, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    # Running accuracy over the epoch, measured on the (augmented) training
    # batches while the weights were still being updated.
    accuracy = correct / total
    history[e] = {"loss": runningLoss, "accuracy": accuracy}

    print(f"Loss: {runningLoss:.4f}  |  Accuracy: {accuracy:.4f}")

print("\nTraining Finished.\n")


Epoch 1/20


100%|██████████| 501/501 [05:47<00:00,  1.44it/s]


Loss: 346.6987  |  Accuracy: 0.5146
Epoch 2/20


100%|██████████| 501/501 [04:10<00:00,  2.00it/s]


Loss: 342.3407  |  Accuracy: 0.5585
Epoch 3/20


100%|██████████| 501/501 [03:48<00:00,  2.19it/s]


Loss: 334.8769  |  Accuracy: 0.5785
Epoch 4/20


100%|██████████| 501/501 [03:35<00:00,  2.32it/s]


Loss: 331.6152  |  Accuracy: 0.5889
Epoch 5/20


100%|██████████| 501/501 [03:28<00:00,  2.40it/s]


Loss: 330.0074  |  Accuracy: 0.5960
Epoch 6/20


100%|██████████| 501/501 [03:29<00:00,  2.39it/s]


Loss: 327.7197  |  Accuracy: 0.6037
Epoch 7/20


100%|██████████| 501/501 [03:23<00:00,  2.46it/s]


Loss: 327.8332  |  Accuracy: 0.6035
Epoch 8/20


100%|██████████| 501/501 [03:29<00:00,  2.40it/s]


Loss: 326.5068  |  Accuracy: 0.6022
Epoch 9/20


100%|██████████| 501/501 [04:43<00:00,  1.76it/s]


Loss: 325.3484  |  Accuracy: 0.6070
Epoch 10/20


100%|██████████| 501/501 [04:14<00:00,  1.97it/s]


Loss: 325.2306  |  Accuracy: 0.6069
Epoch 11/20


100%|██████████| 501/501 [03:50<00:00,  2.18it/s]


Loss: 324.3153  |  Accuracy: 0.6092
Epoch 12/20


100%|██████████| 501/501 [03:54<00:00,  2.13it/s]


Loss: 322.9247  |  Accuracy: 0.6106
Epoch 13/20


100%|██████████| 501/501 [03:46<00:00,  2.21it/s]


Loss: 321.0922  |  Accuracy: 0.6152
Epoch 14/20


100%|██████████| 501/501 [03:46<00:00,  2.21it/s]


Loss: 319.9861  |  Accuracy: 0.6264
Epoch 15/20


100%|██████████| 501/501 [03:32<00:00,  2.36it/s]


Loss: 320.2644  |  Accuracy: 0.6180
Epoch 16/20


100%|██████████| 501/501 [03:36<00:00,  2.32it/s]


Loss: 318.0156  |  Accuracy: 0.6280
Epoch 17/20


100%|██████████| 501/501 [03:33<00:00,  2.34it/s]


Loss: 318.8208  |  Accuracy: 0.6234
Epoch 18/20


100%|██████████| 501/501 [03:47<00:00,  2.21it/s]


Loss: 315.7007  |  Accuracy: 0.6322
Epoch 19/20


100%|██████████| 501/501 [03:44<00:00,  2.23it/s]


Loss: 314.7510  |  Accuracy: 0.6352
Epoch 20/20


100%|██████████| 501/501 [03:58<00:00,  2.10it/s]


Loss: 315.0750  |  Accuracy: 0.6315

Training Finished.



In [19]:
# Pick the epoch with the highest recorded training accuracy; when several
# epochs tie, the earliest one recorded in `history` is chosen.
accuracy_by_epoch = {idx: stats["accuracy"] for idx, stats in history.items()}
best_epoch = max(accuracy_by_epoch, key=accuracy_by_epoch.get)
best_acc = accuracy_by_epoch[best_epoch]

# history is keyed by 0-based epoch index, so report it 1-based.
print(f"Best Accuracy: {best_acc:.4f} at epoch {best_epoch+1}")


Best Accuracy: 0.6352 at epoch 19


In [20]:
# Persist the weights as a state_dict (the final-epoch parameters currently
# held by `model`), then report where the file was written.
save_path = "../models/model_v1.pth"
torch.save(obj=model.state_dict(), f=save_path)
print("Model saved to:", save_path)


Model saved to: ../models/model_v1.pth


In [21]:
# Write the per-epoch training history to disk as JSON.
# `json` and `os` are imported in the notebook's import cell.
metrics_path = "../results/metrics_v1.json"

# open(..., "w") does not create missing parent directories, so make sure the
# results folder exists (e.g. on a fresh checkout) before writing.
os.makedirs(os.path.dirname(metrics_path), exist_ok=True)

with open(metrics_path, "w") as f:
    # json.dump serialises the integer epoch keys as strings ("0", "1", ...).
    json.dump(history, f, indent=4)

print("Training metrics saved to:", metrics_path)


Training metrics saved to: ../results/metrics_v1.json


In [None]:
#