In [1]:
# Utils
from classification.model import *

## Зафиксируем **seed**

In [2]:
# Seed the RNGs for reproducibility. set_all_seeds comes from the star import of
# classification.model; presumably it seeds random/numpy/torch — confirm there.
set_all_seeds()

## Data

### **Classification dataset**

In [None]:
from PIL import Image
import torchvision.transforms as T

class ImageClassificationDataset(torch.utils.data.Dataset):
    """Map-style dataset of image files with optional integer class labels.

    ``__getitem__`` returns a dict shaped for the training loop:
    ``{'args': [image_tensor]}`` plus ``'labels'`` (a ``torch.long`` scalar)
    when labels were provided. ``get_item`` returns the raw PIL image (no
    augmentation, no transform) for inspection/visualization.

    Args:
        image_paths: sequence of image file paths, indexed in parallel with ``labels``.
        labels: optional sequence of integer class ids, same length as ``image_paths``.
        augmentation: optional callable applied to the RGB PIL image before ``transform``.
    """

    # Shared preprocessing: resize to 224x224, convert to a [0, 1] float tensor,
    # then normalize every channel as (x - 0.5) / 0.5.
    # NOTE(review): 0.5/0.5 presumably matches the ViT checkpoint; torchvision
    # ImageNet weights (e.g. EfficientNet_B0) were trained with other mean/std — confirm intended.
    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=0.5, std=0.5),
    ])

    def __init__(self, image_paths, labels=None, augmentation=None):
        self.image_paths = image_paths
        self.labels = labels

        self.augmentation = augmentation

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``{'args': [image_tensor]}`` plus ``'labels'`` when labels exist."""
        # Read the image. The context manager closes the file deterministically;
        # convert() returns a new, fully loaded image that outlives the handle.
        with Image.open(self.image_paths[idx]) as image_file:
            image_pil = image_file.convert("RGB")

        # Apply augmentations, if any (on the PIL image, before tensor conversion).
        if self.augmentation is not None:
            image_pil = self.augmentation(image_pil)

        # Resize, convert to a tensor and normalize.
        image_tensor = self.transform(image_pil)
        result = {'args': [image_tensor]}

        # Attach the label, if available.
        if self.labels is not None:
            result['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)

        return result

    def get_item(self, idx):
        """Return the raw image (original mode, no augmentation/transform) and its label.

        Returns:
            ``{'image': PIL.Image}`` plus ``'label'`` (the raw label value) when labels exist.
        """
        with Image.open(self.image_paths[idx]) as image_pil:
            # Force the lazy decode now so the image stays usable after the file is closed.
            image_pil.load()

        result = {'image': image_pil}
        if self.labels is not None:
            result['label'] = self.labels[idx]

        return result

In [None]:
# Training-only augmentations, applied to the PIL image before resize/normalize.
# Empty by default (identity); add your own if desired, e.g. T.RandomHorizontalFlip().
augmentation = T.Compose(transforms=[])

### Preparation

In [4]:
# Class names, indexed by integer label id.
classes = []

In [None]:
# Parallel lists: labels[i] is the integer class id of the image at image_paths[i].
image_paths, labels = [], []

### **Split**

In [None]:
from sklearn.model_selection import train_test_split

# 80/20 train/validation split with a fixed random_state for reproducibility.
# Consider passing stratify=labels to keep class proportions equal in both parts.
train_image_paths, valid_image_paths, train_labels, valid_labels = train_test_split(
    image_paths,
    labels,
    test_size=0.2,
    random_state=42,
)

### Create **Datasets**

In [None]:
# Full, unaugmented dataset (used below for visualization).
dataset = ImageClassificationDataset(image_paths, labels)

# Augment the training split only; validation stays deterministic.
train_set = ImageClassificationDataset(train_image_paths, train_labels, augmentation=augmentation)
valid_set = ImageClassificationDataset(valid_image_paths, valid_labels, augmentation=None)

### Create **DataLoader**

In [None]:
batch_size, num_workers = 32, 4

# Shuffle only the training data; validation order stays fixed across epochs.
train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
valid_loader = DataLoader(
    dataset=valid_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

### ***Visualization***

In [None]:
# Visual sanity check of the full (unaugmented) dataset with class names.
# NOTE(review): show_classification comes from the star import of classification.model;
# presumably it plots sample images with their labels — confirm there.
show_classification(dataset, classes=classes)

## **Models**

In [11]:
from torchvision import models
from transformers import AutoModelForImageClassification

from torch_lr_finder import LRFinder

In [None]:
def find_lr(model_wrapped, data_loader=None, end_lr=1, num_iter=100):
    """Run the learning-rate range test (torch_lr_finder) and plot loss vs. LR.

    The sweep starts at lr=1e-7 with Adam and cross-entropy loss. Afterwards
    ``lr_finder.reset()`` restores the model and optimizer to their pre-test state.

    Args:
        model_wrapped: module to probe; must expose ``parameters()`` and ``device``
            and return logits when called on an image batch.
        data_loader: DataLoader yielding ``{'args': [images], 'labels': labels}``
            batches. Defaults to the notebook-level ``train_loader``.
        end_lr: upper bound of the learning-rate sweep.
        num_iter: number of sweep iterations.
    """
    # torch_lr_finder is already a dependency of this notebook.
    from torch_lr_finder import TrainDataLoaderIter

    class DictBatchIter(TrainDataLoaderIter):
        """Adapts this notebook's dict batches to LRFinder's (inputs, labels) pairs."""

        def inputs_labels_from_batch(self, batch_data):
            return batch_data['args'][0], batch_data['labels']

    if data_loader is None:
        data_loader = train_loader

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model_wrapped.parameters(), lr=1e-7)

    lr_finder = LRFinder(model_wrapped, optimizer, loss_fn, device=model_wrapped.device)
    lr_finder.range_test(DictBatchIter(data_loader), end_lr=end_lr, num_iter=num_iter)
    lr_finder.plot()
    # Restore the model and optimizer to their state before the range test.
    lr_finder.reset()

In [None]:
class CustomOutput(nn.Module):
    """Wrap a model and post-process its output; otherwise act as a transparent proxy.

    ``forward`` returns ``output_transform(model(*args, **kwargs))``. By default
    that is ``out.logits``, e.g. to unwrap the output object of a Hugging Face
    ``AutoModelForImageClassification`` model into a plain logits tensor.

    Attribute reads the wrapper cannot resolve fall through to the wrapped model.
    Attribute writes are forwarded to the wrapped model as well, except for the
    wrapper's own fields and attributes ``nn.Module`` already stores on the
    wrapper itself (e.g. ``training``), which stay local.
    """

    # Fields that always live on the wrapper, never on the wrapped model.
    _OWN_ATTRS = ('model', 'output_transform')

    def __init__(self, model, output_transform=lambda out: out.logits):
        super().__init__()
        self.model = model
        self.output_transform = output_transform

    def forward(self, *args, **kwargs):
        return self.output_transform(self.model(*args, **kwargs))

    def __getattr__(self, name):
        # Only called when normal lookup fails; `model` lives in nn.Module._modules.
        if name in self._OWN_ATTRS:
            return super().__getattr__(name)
        return getattr(self.model, name)

    def __setattr__(self, name, value):
        # Attributes already in the wrapper's __dict__ (e.g. `training`, written by
        # .train()/.eval()) must be updated here; forwarding them would leave the
        # wrapper's own copy stale (wrapper.eval() would keep training == True).
        if name in self._OWN_ATTRS or name in self.__dict__:
            super().__setattr__(name, value)
        else:
            setattr(self.model, name, value)

### *Score*

In [None]:
# Maps a trained wrapper's best_score -> wrapper; the best model is the max key.
# NOTE: two models with exactly equal scores would overwrite each other.
scores = {}

### **Model**: `google/vit-base-patch16-224`

In [None]:
# Hugging Face ViT with a classification head sized to len(classes);
# ignore_mismatched_sizes=True lets the pretrained head be re-initialized.
# CustomOutput unwraps `.logits` so the model returns a plain tensor.
model = CustomOutput(
    AutoModelForImageClassification.from_pretrained(
        "google/vit-base-patch16-224",
        num_labels=len(classes),
        ignore_mismatched_sizes=True,
    )
)
optimizer = optim.Adam(model.parameters(), lr=5e-5)

model_wrapped = Classifier(model, "Google-VitBase", optimizer)

In [None]:
# Fine-tune the ViT model. The third positional argument is presumably the number
# of epochs — see Classifier.fit in classification.model.
model_wrapped.fit(train_loader, valid_loader, 15)

In [None]:
# Register the trained ViT under its best validation score.
scores.update({model_wrapped.best_score: model_wrapped})

### **Model**: `EfficientNet_B0`

In [None]:
model = models.efficientnet_b0(
    weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1,
)
# Replace the ImageNet output layer with one sized to our label set.
model.classifier[-1] = nn.Linear(
    in_features=model.classifier[-1].in_features,
    out_features=len(classes),
)

model_wrapped = Classifier(model, "EfficientNet_B0")

In [None]:
# Fine-tune EfficientNet_B0. The third positional argument is presumably the number
# of epochs — see Classifier.fit in classification.model.
model_wrapped.fit(train_loader, valid_loader, 10)

In [None]:
# Register the trained EfficientNet under its best validation score.
scores.update({model_wrapped.best_score: model_wrapped})

## Result

In [None]:
# The wrapper with the highest best_score wins (the keys of `scores` are the scores).
best_model_wrapped = scores[max(scores.keys())]
best_model_wrapped.name

In [None]:
n = (2, 4)  # grid shape as (rows, cols)
fig_image_size = 5  # figure inches per grid cell

# Undo the dataset's Normalize so images display with their original colors.
# The constants are read from the dataset's own transform, so they always match
# the normalization actually applied (mean=0.5, std=0.5 in ImageClassificationDataset).
_normalize = next(
    t for t in ImageClassificationDataset.transform.transforms if isinstance(t, T.Normalize)
)
_norm_mean = torch.as_tensor(_normalize.mean, dtype=torch.float32).view(-1, 1, 1)
_norm_std = torch.as_tensor(_normalize.std, dtype=torch.float32).view(-1, 1, 1)

n_rows, n_cols = n
# squeeze=False keeps `axes` 2-D even for a single row or column.
fig, axes = plt.subplots(
    n_rows, n_cols,
    figsize=(fig_image_size * n_cols, fig_image_size * n_rows),
    squeeze=False,
)

for i in range(n_rows):
    for j in range(n_cols):
        idx = random.randrange(len(valid_set))  # sampled with replacement
        batch = valid_set[idx]
        # Add a leading batch dimension of 1 before calling predict().
        batch['args'][0] = batch['args'][0].unsqueeze(0)
        prediction = best_model_wrapped.predict(batch)

        image = batch['args'][0].squeeze(0).detach().cpu() * _norm_std + _norm_mean
        ax = axes[i][j]
        # Clamp: imshow clips float RGB outside [0, 1] and warns.
        ax.imshow(image.clamp(0, 1).permute(1, 2, 0).numpy())
        ax.axis('off')
        ax.set_title(f"Class: {classes[batch['labels']]}\nPredict: {classes[prediction]}", fontsize=10)

plt.tight_layout()
plt.show()

## Submission

In [None]:
test_dir = ""  # directory containing the unlabeled test images (fill in)

# os.listdir returns entries in arbitrary order; sort so the prediction/submission
# order is deterministic across machines and runs.
test_image_names = sorted(os.listdir(test_dir))
# os.path.join uses the platform separator and avoids doubled "/" when test_dir ends with one.
test_image_paths = [os.path.join(test_dir, image_name) for image_name in test_image_names]

test_set = ImageClassificationDataset(test_image_paths)

In [None]:
# Predict a class id for every test image, then translate ids into class names.
predict_class_id = best_model_wrapped.predict(test_set)
predict_class_names = list(map(classes.__getitem__, predict_class_id))