In [113]:
import os
import json
from PIL import Image
from tqdm import tqdm

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from transformers import AutoImageProcessor, AutoModel

from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

In [114]:
CONFIG = {
    'train_dir': '/kaggle/input/one-piece-classification-2025/splitted/',
    'test_dir': '/kaggle/input/one-piece-classification-2025/splitted/test',
    'labels_json': '/kaggle/input/one-piece-classification-2025/labels.json',
    'train_annotations': '/kaggle/input/one-piece-classification-2025/train_annotations.csv',
    'submission_csv': '/kaggle/input/one-piece-classification-2025/submission.csv',
    'output_dir': '/kaggle/working/',
    'backbone_name': 'facebook/vit-mae-base',
    'img_size': 224,
    'batch_size': 32,
    'epochs': 10,
    'learning_rate': 1e-3,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_workers': 4,
    'validation_split': 0.2,
    'random_seed': 42,
    'classifier_hidden': 512,
}

In [115]:
torch.manual_seed(CONFIG['random_seed'])
np.random.seed(CONFIG['random_seed'])

In [116]:
with open(CONFIG['labels_json'], 'r') as f:
    label_map = json.load(f)
train_df = pd.read_csv(CONFIG['train_annotations'])
submission_df = pd.read_csv(CONFIG['submission_csv'])

id_to_name = {int(k): v for k, v in label_map.items()}
name_to_id = {v: int(k) for k, v in label_map.items()}
num_classes = len(id_to_name)

In [117]:
def fix_path(x, root):
    x = x.replace("\\", "/")
    return f"{root}{x}"


train_df["file_path"] = train_df["image_path"].apply(
    lambda x: fix_path(x, CONFIG["train_dir"])
)

In [192]:
class CharacterDataset(Dataset):
    def __init__(self, df, img_dir, augment_transforms=None, is_test=False):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.augment_transforms = augment_transforms
        self.is_test = is_test
        self.extensions = ["png", "jpg", "jpeg"]

    def __len__(self):
        return len(self.df)

    def _resolve_path(self, base_path_without_ext):
        for ext in self.extensions:
            candidate = f"{base_path_without_ext}.{ext}"
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(f"No image found for {base_path_without_ext}")

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        if self.is_test:
            base = os.path.join(self.img_dir, str(row['id']))
        else:
            base = os.path.splitext(row['file_path'])[0]
        img_path = self._resolve_path(base)
        image = Image.open(img_path).convert("RGB")
            
        image = self.augment_transforms(image)

        if self.is_test:
            return image, row['id']
        else:
            label = int(row['label'])
            return image, label

In [119]:
train_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    transforms.RandomResizedCrop(CONFIG['img_size'], scale=(0.8, 1.0), ratio=(0.9, 1.1)),
])
val_aug = transforms.Compose([
    transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
])

In [174]:
class ProcessorCollator:
    def __init__(self, processor, is_test=False):
        self.processor = processor
        self.is_test = is_test

    def __call__(self, batch):
        images = [item[0] for item in batch]
        targets = [item[1] for item in batch]
        processed = self.processor(images=images, return_tensors='pt')
        pixel_values = processed['pixel_values']
        if self.is_test:
            return pixel_values, targets
        else:
            return pixel_values, torch.tensor(targets, dtype=torch.long)

In [121]:
class MAEClassifier(nn.Module):
    def __init__(self, backbone_model, feat_dim, num_classes, classifier_hidden=None):
        super().__init__()
        self.backbone = backbone_model
        self.classifier = nn.Sequential(
            nn.Linear(feat_dim, classifier_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(classifier_hidden, num_classes)
        )

    def forward(self, pixel_values, return_feats=False):
        outputs = self.backbone(pixel_values=pixel_values, return_dict=True)
        patches = outputs.last_hidden_state[:, 1:, :]
        emb = patches.mean(dim=1)
        logits = self.classifier(emb)
        return logits

In [132]:
def train_one_epoch(model, backbone, train_loader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc='Train', leave=False)
    for pixel_values, labels in pbar:
        pixel_values = pixel_values.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        
        logits = model(pixel_values)

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, preds = torch.max(logits, dim=1)
        total += labels.size(0)
        correct += preds.eq(labels).sum().item()
        pbar.set_postfix({'loss': running_loss / (len(pbar)), 'acc': 100. * correct / total})

    return running_loss / len(train_loader), 100. * correct / total

def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Val ', leave=False)
        for pixel_values, labels in pbar:
            pixel_values = pixel_values.to(device)
            labels = labels.to(device)

            logits = model(pixel_values)
            loss = criterion(logits, labels)

            running_loss += loss.item()
            _, preds = torch.max(logits, dim=1)
            total += labels.size(0)
            correct += preds.eq(labels).sum().item()
            pbar.set_postfix({'loss': running_loss / (len(pbar)), 'acc': 100. * correct / total})

    return running_loss / len(val_loader), 100. * correct / total

In [133]:
train_data_df, val_data_df = train_test_split(
    train_df,
    test_size=CONFIG['validation_split'],
    random_state=CONFIG['random_seed'],
    stratify=train_df['label']
)
print(f"Train samples: {len(train_data_df)}, Val samples: {len(val_data_df)}")

Train samples: 2332, Val samples: 584


In [124]:
processor = AutoImageProcessor.from_pretrained(CONFIG['backbone_name'])
backbone = AutoModel.from_pretrained(CONFIG['backbone_name'])
backbone.to(CONFIG['device'])
backbone.eval()

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


ViTMAEModel(
  (embeddings): ViTMAEEmbeddings(
    (patch_embeddings): ViTMAEPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
  )
  (encoder): ViTMAEEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTMAELayer(
        (attention): ViTMAEAttention(
          (attention): ViTMAESelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
          )
          (output): ViTMAESelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTMAEIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTMAEOutput(
          (dense): Linear(i

In [193]:
train_dataset = CharacterDataset(train_data_df, CONFIG['train_dir'], augment_transforms=train_aug, is_test=False)
val_dataset = CharacterDataset(val_data_df, CONFIG['train_dir'], augment_transforms=val_aug, is_test=False)
test_dataset = CharacterDataset(submission_df, CONFIG['test_dir'], augment_transforms=val_aug, is_test=True)

In [194]:
train_collator = ProcessorCollator(processor, is_test=False)
val_collator = ProcessorCollator(processor, is_test=False)
test_collator = ProcessorCollator(processor, is_test=True)

In [195]:
train_loader = DataLoader(
    train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=CONFIG['num_workers'],
    collate_fn=train_collator, pin_memory=True,
)
val_loader = DataLoader(
    val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=CONFIG['num_workers'],
    collate_fn=val_collator, pin_memory=True,
)
test_loader = DataLoader(
    test_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=CONFIG['num_workers'],
    collate_fn=test_collator, pin_memory=True,
)

In [137]:
model = MAEClassifier(
    backbone_model=backbone, feat_dim=768, num_classes=num_classes,
    classifier_hidden=CONFIG['classifier_hidden']).to(CONFIG['device'])

In [138]:
for param in model.backbone.parameters():
    param.requires_grad = False

for param in model.classifier.parameters():
    param.requires_grad = True

In [144]:
params = model.classifier.parameters()
optimizer = torch.optim.AdamW(params, lr=CONFIG['learning_rate'])

criterion = nn.CrossEntropyLoss()
best_val_acc = 0.0
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

for epoch in range(CONFIG['epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['epochs']}")

    train_loss, train_acc = train_one_epoch(model, backbone, train_loader, optimizer, criterion, CONFIG['device'])
    val_loss, val_acc = validate(model, val_loader, criterion, CONFIG['device'])

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val   Acc: {val_acc:.2f}%")


Epoch 1/10


                                                                             

Train Loss: 0.3668, Train Acc: 87.86%
Val Loss: 0.9385, Val   Acc: 71.23%

Epoch 2/10


                                                                             

Train Loss: 0.3091, Train Acc: 90.44%
Val Loss: 0.9727, Val   Acc: 74.14%

Epoch 3/10


                                                                             

Train Loss: 0.3339, Train Acc: 89.28%
Val Loss: 0.9936, Val   Acc: 71.06%

Epoch 4/10


                                                                             

Train Loss: 0.3076, Train Acc: 90.48%
Val Loss: 0.9577, Val   Acc: 72.43%

Epoch 5/10


                                                                             

Train Loss: 0.3251, Train Acc: 89.97%
Val Loss: 0.9973, Val   Acc: 72.43%

Epoch 6/10


                                                                             

Train Loss: 0.3343, Train Acc: 89.79%
Val Loss: 0.9551, Val   Acc: 72.60%

Epoch 7/10


                                                                             

Train Loss: 0.3131, Train Acc: 90.14%
Val Loss: 1.0382, Val   Acc: 72.77%

Epoch 8/10


                                                                             

Train Loss: 0.3187, Train Acc: 89.62%
Val Loss: 0.9273, Val   Acc: 73.80%

Epoch 9/10


                                                                             

Train Loss: 0.2931, Train Acc: 91.42%
Val Loss: 1.0634, Val   Acc: 71.40%

Epoch 10/10


                                                                             

Train Loss: 0.3097, Train Acc: 90.01%
Val Loss: 0.9955, Val   Acc: 73.63%




In [201]:
model.eval()
all_ids = []
all_preds = []
with torch.no_grad():
    for pixel_values, ids in tqdm(test_loader, desc='Predict'):
        pixel_values = pixel_values.to(CONFIG['device'])
        logits = model(pixel_values)
        _, preds = torch.max(logits, dim=1)
        all_ids.extend(ids)
        all_preds.extend(preds.cpu().numpy().tolist())

submission = pd.DataFrame({'id': all_ids, 'label': all_preds})
submission_path = os.path.join(CONFIG['output_dir'], 'submission.csv')
submission.to_csv(submission_path, index=False)

Predict: 100%|██████████| 27/27 [00:05<00:00,  4.55it/s]
