In [None]:
import os
import time

import pandas as pd
import torch
import torchvision.transforms.v2 as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.io import ImageReadMode, read_image
# resnet50 / convnext_base are builder functions, not modules: `import torchvision.models.resnet50`
# raises ModuleNotFoundError, so bind the same names via `from ... import ... as ...`.
from torchvision.models import convnext_base as convnext
from torchvision.models import resnet50 as resnet

In [None]:
# Training configuration.
BATCH_SIZE = 8
EPOCHS = 32
# Seed torch's global RNG so shuffling, v2 random augmentations and weight init are
# repeatable across Restart & Run All (cuDNN kernels may still add nondeterminism on GPU).
SEED = 42
torch.manual_seed(SEED)

In [None]:
import os
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader

class LeafClassificationDataset(Dataset):
    """Dataset over the image files found directly inside one directory.

    Every item is returned with the same label, ``default_label``; this class does
    not read per-image labels.
    NOTE(review): with a single shared label, CrossEntropyLoss has nothing to
    discriminate — presumably real per-image labels still need to be wired in.

    Args:
        dir: directory whose regular files are the images. Subdirectories and
            dotfiles (e.g. ``.DS_Store``) are skipped.
        transforms: optional callable applied to each decoded image.
        default_label: label returned for every item.
    """

    def __init__(self, dir, transforms=None, default_label=1):
        # List `dir` itself (not the CWD) and test the *joined* path, so the dataset
        # does not depend on the notebook's working directory. Sorting makes the
        # index -> file mapping deterministic across runs and platforms.
        self.files = [
            os.path.join(dir, name)
            for name in sorted(os.listdir(dir))
            if not name.startswith(".") and os.path.isfile(os.path.join(dir, name))
        ]
        self.transforms = transforms
        self.default_label = default_label

    def __len__(self):
        """Return the number of image files discovered in ``dir``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th file.

        The file is decoded as 3-channel RGB so grayscale or RGBA files still match
        a 3-channel model input; ``transforms`` (if any) is then applied.
        """
        image = read_image(self.files[idx], mode=ImageReadMode.RGB)
        label = self.default_label
        if self.transforms:
            image = self.transforms(image)
        return image, label

In [None]:
# Augmentation pipeline applied to each decoded image.
# read_image yields uint8 tensors, and v2.ToTensor only converts PIL images / ndarrays,
# so tensors used to reach the model still as uint8 (conv layers reject that dtype).
# ToImage + ToDtype(float32, scale=True) is the v2 replacement for ToTensor.
IMAGE_SIZE = (224, 224)  # assumption: a fixed size so default_collate can stack a batch — tune to the data
tf = transforms.Compose([
    transforms.ToImage(),
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(.5),  # NOTE(review): argument is in degrees (±0.5°) — confirm this is intended
    transforms.RandomPerspective(distortion_scale=0.6),
    transforms.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
])
dataset = LeafClassificationDataset("_data/urban_street/images", transforms=tf)
# NOTE(review): with num_workers > 0 on spawn-based platforms (macOS/Windows), a Dataset
# class defined inside the notebook may fail to unpickle in workers — use 0 there if so.
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=BATCH_SIZE//2)

In [None]:
# ConvNeXt-Base trained from scratch: no pretrained weights are loaded.
model = convnext(weights=None)

In [None]:
# Classification loss plus plain SGD with momentum over all model parameters.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=1e-3,
    momentum=0.9,
)

In [None]:
import pandas as pd

# Empty table for logging training loss, one column each for step, epoch and loss.
loss_df = pd.DataFrame(columns=["step", "epoch", "loss"])

In [None]:
# Train for EPOCHS passes over `dataloader`, log the mean batch loss per epoch into
# `loss_df`, then save the full model object under out/resnet_leaf_classifier/.
model.train()
device = next(model.parameters()).device  # send each batch to wherever the model lives
loss_records = []
step = 0  # cumulative number of training samples seen
for epoch in range(EPOCHS):
    running_loss = 0.0
    num_batches = 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        # `len(batch)` was always 2 (the (inputs, labels) pair); count samples instead.
        step += labels.size(0)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Mean over batches, so the value is comparable regardless of dataset size.
    epoch_loss = running_loss / max(num_batches, 1)
    loss_records.append({"step": step, "epoch": epoch, "loss": epoch_loss})
    print(f"Epoch {epoch + 1}/{EPOCHS} loss: {epoch_loss:.4f}")

loss_df = pd.DataFrame(loss_records, columns=["step", "epoch", "loss"])
print("Finished Training, Saving Model")
# NOTE(review): the directory name says "resnet" but the model built above is ConvNeXt.
save_dir = "out/resnet_leaf_classifier"
os.makedirs(save_dir, exist_ok=True)  # torch.save does not create missing directories
torch.save(model, os.path.join(save_dir, f"{time.time()}.pt"))