In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from collections import Counter

In [None]:
# --- Device & reproducibility ---
# Seed torch's global RNG so the freshly initialised classifier head, the training
# shuffle order and the random flips repeat between runs. GPU kernels (cuDNN) may
# still introduce small run-to-run differences.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [None]:
from google.colab import drive
import zipfile
import os

# Mount Google Drive so the dataset archive stored there becomes reachable.
drive.mount('/content/drive')

# Archive on Drive, and the local directory it is unpacked into.
zip_path = '/content/drive/MyDrive/anime_dataset.zip'
extract_path = '/content/anime_dataset'

# Unpack the whole archive locally; the ImageFolder cells read from this directory.
with zipfile.ZipFile(zip_path, 'r') as archive:
    archive.extractall(extract_path)

# Sanity check: show the top-level entries that were extracted.
print(os.listdir(extract_path))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
['train', 'test']


In [None]:
from PIL import ImageFile
# Global Pillow switch: allow decoding of truncated image files instead of raising
# an error when one is loaded. Affects every Image.open in this kernel.
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [None]:
from PIL import Image

In [None]:
def pil_loader_rgb(path: str):
    """Open the image file at ``path`` and return it as an RGB PIL image.

    Whatever mode the file is stored in (e.g. palette or with transparency) is
    converted to 3-channel RGB right away, so every sample reaching the
    transforms has the same channel layout. The conversion runs while the file
    handle is still open, because ``Image.open`` reads pixel data lazily.
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert("RGB")

# ImageNet channel statistics: the ResNet18 backbone used below is pretrained on
# ImageNet, so inputs are normalised the same way.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: fixed 224x224 resize plus random horizontal flips as light augmentation.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# Evaluation pipeline: identical preprocessing without the random augmentation.
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

In [None]:
# --- Datasets ---
# ImageFolder derives label indices from the sorted sub-directory names of EACH root
# independently. Train and test must therefore contain exactly the same class folders,
# otherwise the test labels would not line up with the model's output indices.
train_dataset = datasets.ImageFolder("/content/anime_dataset/train", transform=train_transforms, loader=pil_loader_rgb)
test_dataset = datasets.ImageFolder("/content/anime_dataset/test", transform=test_transforms, loader=pil_loader_rgb)

if test_dataset.class_to_idx != train_dataset.class_to_idx:
    only_train = sorted(set(train_dataset.classes) - set(test_dataset.classes))
    only_test = sorted(set(test_dataset.classes) - set(train_dataset.classes))
    raise ValueError(
        "Train/test class folders differ, so label indices would not line up: "
        f"{len(only_train)} class(es) only in train (e.g. {only_train[:5]}), "
        f"{len(only_test)} only in test (e.g. {only_test[:5]})"
    )

In [None]:
# Both loaders share batch size and worker count. Only the training split is shuffled;
# the test loader yields samples in test_dataset.samples order.
loader_kwargs = dict(batch_size=32, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:

print("Number of classes:", len(train_dataset.classes))

# Count images per class. ImageFolder.samples holds (path, class_index) pairs.
train_labels = [label for _, label in train_dataset.samples]
class_counts = Counter(train_labels)
total_images = sum(class_counts.values())
print("Total images:", total_images)
# The smallest and largest class sizes summarise how balanced the training set is.
if class_counts:
    print("Images per class: min", min(class_counts.values()), "max", max(class_counts.values()))

Number of classes: 493
Total images: 8140


In [None]:
# Alias (not a copy) of the class-name list ImageFolder built from the train
# sub-directory names; position i is the name used for label index i.
temp = train_dataset.classes

In [None]:
import json

# Serialise all class names as one JSON array. ensure_ascii=False prints non-ASCII
# characters as-is rather than as \uXXXX escapes.
class_names_json = json.dumps(temp, ensure_ascii=False)
print(class_names_json)


["20th Century Boys", "3gatsu no Lion", "5toubun no Hanayome", "86", "91 Days", "Accel World", "Acchi Kocchi", "Ajin Part 2 OVA", "AkaKill Gekijou", "Akagami no Shirayukihime", "Akame ga Kill", "Akatsuki no Yona", "Akira", "Aku no Hana", "AldnoahZero", "Amagi Brilliant Park", "Angel Beats", "Angel Beats Specials", "Ano Hi Mita Hana no Namae wo Bokutachi wa Mada Shiranai", "Another", "Ansatsu Kyoushitsu", "Ao Haru Ride", "Ao no Exorcist", "Ao no Hako", "Aria the Animation", "Arifureta Shokugyou de Sekai Saikyou", "Ashita no Joe", "Asobi Asobase", "Ayakashi", "Azumanga Daiou The Animation", "Baccano", "Bakemonogatari", "Bakuman", "Banana Fish", "Barakamon", "Beastars", "Beelzebub", "Berserk", "Bishoujo Senshi Sailor Moon", "Black Cat TV", "Black Clover", "Black Lagoon", "Black Lagoon Omake", "BlackRock Shooter OVA", "Bleach", "Bleach Movie 2", "Blood", "Blood Lad", "Blue Lock", "Blue Period", "Bocchi the Rock", "Boku dake ga Inai Machi", "Boku no Hero Academia", "Boku no Hero Academia 2n

In [None]:
from torchvision.models import ResNet18_Weights

In [None]:
# Start from ResNet18 pretrained on ImageNet, then replace its final fully-connected
# layer with a freshly initialised one sized to this dataset's classes.
model = models.resnet18(weights=ResNet18_Weights.DEFAULT)
num_features = model.fc.in_features  # width of the feature vector feeding the classifier
num_classes = len(train_dataset.classes)
model.fc = nn.Linear(num_features, num_classes)
model = model.to(device)

In [None]:
# Multi-class classification loss computed on the raw logits.
criterion = nn.CrossEntropyLoss()

# Adam over every parameter (the pretrained backbone is fine-tuned, not frozen).
learning_rate = 1e-4
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
from timeit import default_timer as timer
from tqdm import tqdm

In [None]:
import os

# Destination for the per-epoch checkpoints. exist_ok=True makes re-running this
# cell harmless when the folder already exists.
checkpoint_dir = "/content/checkpoints2"
os.makedirs(checkpoint_dir, exist_ok=True)


In [None]:
start_time = timer()
# --- Training loop ---
num_epochs = 20
# Mean training loss of every completed epoch, in order, so the loss curve can be
# plotted afterwards without copying numbers out of the printed log.
train_losses = []
for epoch in tqdm(range(num_epochs)):
    model.train()  # training mode: BatchNorm layers use/update batch statistics
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean; scaling by batch size gives a per-sample
        # sum, so the epoch average stays exact when the last batch is smaller.
        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    train_losses.append(epoch_loss)

    end_time = timer()
    elapsed_time = end_time - start_time
    # tqdm.write prints above the progress bar instead of breaking it up.
    tqdm.write(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
    tqdm.write(f"Elapsed time: {elapsed_time:.2f} seconds")

    # --- Save checkpoint after each epoch ---
    torch.save(model.state_dict(), f"/content/checkpoints2/resnet18_epoch{epoch+1}.pth")

  5%|▌         | 1/20 [03:52<1:13:44, 232.87s/it]

Epoch [1/20], Loss: 5.7532
Elapsed time: 232.78 seconds


 10%|█         | 2/20 [07:44<1:09:37, 232.10s/it]

Epoch [2/20], Loss: 4.6326
Elapsed time: 464.29 seconds


 15%|█▌        | 3/20 [11:34<1:05:32, 231.35s/it]

Epoch [3/20], Loss: 3.8203
Elapsed time: 694.79 seconds


 20%|██        | 4/20 [15:24<1:01:30, 230.68s/it]

Epoch [4/20], Loss: 3.1026
Elapsed time: 924.46 seconds


In [None]:
# --- Save final model ---
# Only the weights (state_dict) are stored, not the optimizer state.
final_model_path = "/content/resnet18_final2.pth"
torch.save(model.state_dict(), final_model_path)
print("Training complete.")

In [None]:
model.eval()  # inference mode: BatchNorm layers use their running statistics
correct = 0
total = 0
# (image path, true class name, predicted class name) for every wrong prediction.
misclassified = []

with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(test_loader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)

        # Recover this batch's file paths from its position in the dataset. This is
        # only valid because test_loader does not shuffle.
        start_idx = batch_idx * test_loader.batch_size
        end_idx = start_idx + images.size(0)
        paths = [path for path, _ in test_loader.dataset.imgs[start_idx:end_idx]]

        class_names = test_loader.dataset.classes
        for path, pred, label in zip(paths, preds.tolist(), labels.tolist()):
            if pred != label:
                misclassified.append((path, class_names[label], class_names[pred]))

        # Update counters for accuracy
        correct += (preds == labels).sum().item()
        total += labels.size(0)

if total == 0:
    raise ValueError("test_loader yielded no samples; cannot compute accuracy")
accuracy = correct / total
print(f"Test Accuracy: {accuracy*100:.2f}%")
print(f"Misclassified: {len(misclassified)} of {total}")


In [None]:
from google.colab import drive

# Mount Google Drive again. If it is already mounted, drive.mount only reports that
# (as shown in the output of the earlier mount cell).
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

In [None]:
from google.colab import files

# Write the current weights (state_dict only) to the working directory.
# NOTE(review): `files` is imported but never used here; presumably a
# files.download(...) call was intended. This name ("final1") also differs from the
# "/content/resnet18_final2.pth" saved after training. Confirm which file is wanted.
weights_filename = 'resnet18_final1.pth'
torch.save(model.state_dict(), weights_filename)


In [None]:
## I forgot to record the per-epoch losses in a list during training; they were only printed after each epoch, so I will use those printed values to visualize the training loss.
