### A

In [1]:
# Training hyperparameters (tune these before running the cells below)
EPOCH = 30        # number of full passes over the training split
BATCH_SIZE = 16   # samples per training mini-batch
LR = 1e-2         # learning rate handed to the optimizer

In [2]:
import os
import shutil

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from PIL import Image
from torch.utils.data import Dataset ,random_split
from torchvision import transforms

import jajucha2

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

folder_path = "checkpoints/"


def clear_folder(path):
    """Delete every entry inside ``path`` while keeping the folder itself.

    Files and symbolic links are unlinked; sub-directories are removed
    recursively. A failure on one entry is reported and does not stop the
    remaining entries from being processed.

    Args:
        path: Directory whose contents should be removed.

    Returns:
        List of the paths that could not be deleted (empty on full success).
    """
    failed = []
    for filename in os.listdir(path):
        file_path = os.path.join(path, filename)
        try:
            # Check if it is a file or directory and remove accordingly
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)  # Remove the file or symbolic link
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)  # Remove the directory
        except OSError as e:
            # Only filesystem errors are treated as "could not delete";
            # programming errors (e.g. a missing import) must surface.
            print(f"Failed to delete {file_path}. Reason: {e}")
            failed.append(file_path)
    return failed


# Start every run from an empty checkpoint folder so old epochs do not mix
# with the ones written by this run.
if os.path.isdir(folder_path):
    undeleted = clear_folder(folder_path)
    if undeleted:
        print(f"{len(undeleted)} item(s) in the folder '{folder_path}' could not be deleted.")
    else:
        print(f"All items in the folder '{folder_path}' have been deleted.")
else:
    print(f"Folder '{folder_path}' does not exist.")

All items in the folder 'checkpoints/' have been deleted.


In [3]:
class CustomDataset(Dataset):
    """Image-classification dataset read from a folder-per-class layout.

    Every sub-directory of ``root_dir`` is one class; every regular file
    inside a class directory is one sample. Class indices follow the
    alphabetical order of the class directory names.

    Args:
        root_dir: Directory that contains one sub-directory per class.
        transform: Optional callable applied to each loaded PIL image.

    Attributes set in ``__init__``:
        classes: Sorted list of class directory names.
        class_to_idx: Mapping from class name to integer label.
        data: List of ``(image_path, label)`` tuples.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Only directories are classes; stray files next to the class folders
        # would otherwise be listed as classes and break _load_data().
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.data = self._load_data()

    def _load_data(self):
        """Return ``(image_path, label)`` pairs for every file of every class."""
        data = []
        for cls in self.classes:
            class_path = os.path.join(self.root_dir, cls)
            # os.listdir() order is arbitrary; sort so that sample indices are
            # stable from one run to the next.
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                if not os.path.isfile(img_path):
                    continue  # nested directories are not images
                data.append((img_path, self.class_to_idx[cls]))
        return data

    def __len__(self):
        """Number of samples found under ``root_dir``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB and, when a transform was supplied,
        passed through it before being returned.
        """
        img_path, label = self.data[idx]
        # The context manager releases the file handle once the pixels have
        # been copied into the converted image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Import the ResNet-18 model
# NOTE(review): no cell visible in this notebook uses this pretrained instance
# (the network that is trained is built separately with pretrained=False) —
# confirm whether it is still needed.
resnet18 = models.resnet18(pretrained=True)  # Set pretrained=True to load weights pre-trained on ImageNet

# Prepare dataset and preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Change input size
    transforms.ToTensor(),
])

# Set the directory
current_directory = os.getcwd()
file_path = "data"

custom_dataset = CustomDataset(root_dir=file_path, transform=transform)

train_size = int(0.8 * len(custom_dataset))  # 80% for training
test_size = len(custom_dataset) - train_size  # 20% for testing

# Seed the split so the same samples land in the train/test subsets on every
# run (for a given dataset ordering); otherwise test accuracy is not
# comparable between runs.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = random_split(
    custom_dataset, [train_size, test_size], generator=split_generator
)

TEST_BATCH_SIZE = 8

trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
# Evaluation does not depend on sample order, so the test loader is not shuffled.
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False, num_workers=2)

In [4]:
# Build the classifier: a torchvision ResNet-18 created without pretrained
# weights, whose final fully connected layer is swapped for one sized to the
# project's class list.
net = models.resnet18(pretrained=False)
num_ftrs = net.fc.in_features

# The class list is looked up by the name of the directory that contains the
# current working directory.
project_name = str(os.path.basename(os.path.dirname(os.getcwd())))
class_names = jajucha2.ai.get_classes(project_name)

net.fc = nn.Linear(num_ftrs, int(len(class_names)))
net = net.to(device)

In [5]:
# Loss: cross-entropy over the class logits.
criterion = nn.CrossEntropyLoss()

# Optimizer: SGD with momentum and L2 weight decay over all network parameters.
sgd_options = dict(lr=LR, momentum=0.9, weight_decay=5e-4)
optimizer = optim.SGD(net.parameters(), **sgd_options)

In [None]:
from tqdm import tqdm
from IPython.display import display, HTML

# Inject CSS that sets the height of `.output_scroll` to auto (presumably so
# the long per-epoch log is not collapsed into a scroll box — verify in the
# target notebook front-end).
display(HTML("<style>.output_scroll { height: auto; }</style>"))

# Make sure the checkpoint directory exists before the first save.
os.makedirs('checkpoints', exist_ok=True)

for epoch in range(EPOCH):
    print('\nEpoch: %d' % (epoch + 1))

    # ---------------- training pass ----------------
    net.train()
    sum_loss = 0.0
    correct = 0
    total = 0

    # Use tqdm to create a progress bar for each epoch
    with tqdm(total=len(trainloader), desc=f"Epoch {epoch + 1}/{EPOCH}", unit="batch") as pbar:
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Forward & backward
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Running loss / accuracy, kept as plain Python numbers so the
            # counters do not turn into tensors.
            sum_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Update the tqdm progress bar
            pbar.set_postfix({
                'Loss': f'{sum_loss / (i + 1):.3f}',
                'Acc': f'{100. * correct / total:.3f}%'
            })
            pbar.update(1)

    # Switch to eval mode BEFORE scripting: the scripted module records the
    # training flag, and the checkpoint is meant for inference. The next epoch
    # re-enables training mode via net.train() above.
    net.eval()
    model_scripted = torch.jit.script(net)
    model_scripted.save(os.path.join('checkpoints', f'resnet_epoch_{epoch + 1}.pt'))

    # ---------------- evaluation pass ----------------
    # Get accuracy with the test dataset in each epoch
    print('Waiting Test...')
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total > 0:
        print('Test\'s accuracy is: %.3f%%' % (100. * correct / total))
    else:
        # An empty test split would otherwise divide by zero.
        print('Test set is empty; accuracy not computed.')

print('Training has finished, total epoch is %d' % EPOCH)
