In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
from torchvision import models
from tqdm import tqdm, trange
import time

In [None]:
# Define transforms
transform = transforms.Compose([
    transforms.Resize((300, 250)), # Resize images
    transforms.ToTensor(),               
])

# Load datasets
train_dataset = datasets.ImageFolder(
    root="Data/data1a/training",
    transform=transform
)

val_dataset = datasets.ImageFolder(
    root="Data/data1a/validation",
    transform=transform
)

# Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2
)

# debugging 
# print("Classes:", train_dataset.classes)
# print("Train batches:", len(train_loader))
# print("Val batches:", len(val_loader))

In [None]:
class ResNetBinary(nn.Module):
    def __init__(self):
        super().__init__()

        # Load ResNet18 model
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        num_features = self.model.fc.in_features 
        self.model.fc = nn.Linear(num_features, 2)

    def forward(self, x):
        return self.model(x)


In [2]:
# Checking for cuda
print("torch.version:", torch.version)
print("torch.version.cuda:", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("current device index:", torch.cuda.current_device())
    try:
        print("device name:", torch.cuda.get_device_name(0))
    except Exception as e:
        print("get_device_name failed:", e)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

NameError: name 'torch' is not defined

In [None]:
def train(model, dataset, epochs):
    optimizer = torch.optim.Adam(model.parameters())
    loss = nn.CrossEntropyLoss()
    dataloader = train_loader
    model = model.to(device)
    for epoch in trange(epochs):
        start = time.time()
        for (xs, targets) in tqdm(dataloader):
            xs, targets = xs.to(device), targets.to(device)
            ys = model(xs)
            optimizer.zero_grad()
            l = loss(ys, targets)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                acc = (ys.argmax(axis=1) == targets).sum() / xs.shape[0]
        duration = time.time() - start
        print("[%d] acc = %.2f loss = %.4f in %.2f seconds." % (epoch, acc.item(), l.item(), duration))

Using device: cuda


In [None]:
model = ResNetBinary()

train(model, train_dataset, epochs=10)

100%|██████████| 15/15 [02:34<00:00, 10.32s/it]
 10%|█         | 1/10 [02:38<23:47, 158.62s/it]

[0] acc = 0.98 loss = 0.0785 in 154.77 seconds.


100%|██████████| 15/15 [02:35<00:00, 10.36s/it]
 20%|██        | 2/10 [05:19<21:17, 159.71s/it]

[1] acc = 0.92 loss = 0.1388 in 155.34 seconds.


 33%|███▎      | 5/15 [00:56<01:52, 11.30s/it]
 20%|██        | 2/10 [06:15<25:02, 187.80s/it]


KeyboardInterrupt: 

In [None]:
def validate(model, dataloader):
    """Run model on validation DataLoader and return loss/accuracy."""
    model = model.to(device)
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    total = 0
    correct = 0
    total_loss = 0.0
    with torch.no_grad():
        for xs, targets in tqdm(dataloader):
            xs = xs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            ys = model(xs)
            loss = loss_fn(ys, targets)
            total_loss += loss.item() * xs.size(0)
            preds = ys.argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += xs.size(0)
    avg_loss = total_loss / total if total > 0 else 0.0
    acc = correct / total if total > 0 else 0.0
    print(f"Validation - loss: {avg_loss:.4f}, accuracy: {acc:.4f} ({correct}/{total})")
    return {'loss': avg_loss, 'accuracy': acc, 'correct': correct, 'total': total}

In [None]:
try:
    _ = validate(model, val_loader)
except NameError:
    print('Validation DataLoader not found.')

In [None]:
# Saving model weights
torch.save(model.state_dict(), "trained_weights/model.pth")