In [1]:
# Utils
from classification.model import *

## Зафиксируем **seed**

In [2]:
# Project helper pulled in by the star import above, called with its defaults.
# NOTE(review): presumably seeds every RNG the notebook relies on — confirm in classification.model.
set_all_seeds()

## Data

In [None]:
# Normalization constants: fed to transforms.Normalize and passed again to the plotting helpers
mean = 0.5
std = 0.5

### **Classification dataset**

In [None]:
from PIL import Image
from torchvision import transforms

class ClassificationDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields ``(image_tensor, label_tensor)`` pairs.

    Args:
        image_paths: indexable collection of image file paths.
        labels: indexable collection of integer class ids, aligned with ``image_paths``.
        transform: callable applied to the PIL image to produce the returned tensor.
        augmentation: optional callable applied to the PIL image *before* ``transform``.

    Raises:
        ValueError: if ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform, augmentation=None):
        # Fail fast: a mismatch would otherwise surface as an IndexError in the middle
        # of an epoch, or silently ignore the surplus labels.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length, "
                f"got {len(image_paths)} and {len(labels)}"
            )

        self.image_paths = image_paths
        self.labels = labels

        self.transform = transform
        self.augmentation = augmentation

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(transformed image, label as torch.long tensor)``."""
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        # Image.open is lazy; convert() forces the pixel data to load and returns an
        # independent RGB image, so the file can be closed right after.
        with Image.open(image_path) as image_file:
            image_pil = image_file.convert("RGB")

        if self.augmentation is not None:
            image_pil = self.augmentation(image_pil)

        image_tensor = self.transform(image_pil)
        label_tensor = torch.tensor(label, dtype=torch.long)

        return image_tensor, label_tensor

In [None]:
# Deterministic preprocessing applied to every image: resize, convert to a tensor,
# then normalize with the notebook-level mean/std.
# The Resize step must stay at index 0 — later cells replace transform.transforms[0] in place.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Train-only augmentation pipeline; left empty here, so it currently applies no operations.
augmentation = transforms.Compose([
    
])

### Preparation

In [4]:
# Class names indexed by label id; fill in before building models (len(classes) sizes the output heads)
classes = []

In [None]:
# Parallel lists to be filled in: image file paths and their integer class ids
image_paths = []
labels = []

### **Split**

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the samples for validation; random_state pins the split so it is reproducible.
# NOTE(review): stratification is not enabled — consider it if the classes are imbalanced.
train_data, valid_data, train_labels, valid_labels = train_test_split(image_paths, labels, test_size=0.2, random_state=42) # optional: stratify=labels

### Create **Datasets**

In [None]:
# Full, un-augmented dataset over every sample (used by the visualization cell below)
dataset = ClassificationDataset(image_paths, labels, transform)

# Only the training split receives the augmentation pipeline; validation sees plain preprocessing.
# All three datasets share the same `transform` object.
train_set = ClassificationDataset(train_data, train_labels, transform, augmentation)
valid_set = ClassificationDataset(valid_data, valid_labels, transform)

### Create **DataLoader**

In [None]:
# Number of samples per batch for both loaders
batch_size = 24

# Training batches are reshuffled every epoch
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Validation keeps a fixed order: shuffling brings no benefit for evaluation, makes
# per-sample outputs non-reproducible and needlessly draws from the global RNG
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)

### ***Visualization***

In [None]:
# Project helper from the star import, called on the full un-augmented dataset.
# NOTE(review): presumably plots labelled samples and uses mean/std to undo normalization — confirm.
show_classification(dataset, mean, std, classes=classes)

## **Models**

In [11]:
from torchvision import models
from transformers import AutoModelForImageClassification

from torch_lr_finder import LRFinder

In [None]:
def find_lr(model):
    """Run a learning-rate range test on ``model`` and plot the result.

    A throwaway Adam optimizer starting at ``lr=1e-7`` is swept up to ``end_lr=1``
    over 100 iterations of the notebook-level ``train_loader`` using cross-entropy
    loss. ``lr_finder.reset()`` is called at the end so the sweep is undone.

    Args:
        model: ``nn.Module`` to probe. It must own at least one parameter.
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-7)

    # A plain nn.Module (e.g. a torchvision model) has no `.device` attribute, so
    # take the device from the parameters instead; this works for every model type.
    device = next(model.parameters()).device

    lr_finder = LRFinder(model, optimizer, loss_fn, device=device)
    lr_finder.range_test(train_loader, end_lr=1, num_iter=100)
    lr_finder.plot()
    lr_finder.reset()

In [None]:
class CustomOutput(nn.Module):
    """Wrap a model so that ``forward`` returns a post-processed output.

    The wrapper is otherwise transparent: attributes it does not own are read from
    and written to the wrapped model.

    Args:
        model: the wrapped ``nn.Module``.
        output_transform: callable applied to the wrapped model's output. The default
            extracts ``.logits`` from the returned object.
    """

    # Attributes that always live on the wrapper itself
    _OWN_ATTRIBUTES = ('model', 'output_transform')

    def __init__(self, model, output_transform=lambda out: out.logits):
        super().__init__()
        self.model = model
        self.output_transform = output_transform

    def forward(self, x):
        """Run the wrapped model on ``x`` and apply ``output_transform`` to its output."""
        return self.output_transform(self.model(x))

    def __getattr__(self, name):
        # Only reached when normal lookup fails. The wrapper's own attributes are
        # resolved by nn.Module (which finds `model` among the registered submodules);
        # everything else is forwarded to the wrapped model.
        if name in self._OWN_ATTRIBUTES:
            return super().__getattr__(name)
        return getattr(self.model, name)

    def __setattr__(self, name, value):
        # Keep the wrapper's own state on the wrapper. Besides the two explicit
        # attributes this covers whatever nn.Module.__init__ placed in the instance
        # dict — notably `training`: forwarding it left `wrapper.training` stuck at
        # True after `.eval()`. train()/eval() still reach the wrapped model because
        # it is a registered child.
        if name in self._OWN_ATTRIBUTES or name in self.__dict__:
            super().__setattr__(name, value)
        else:
            setattr(self.model, name, value)

### *Score*

In [None]:
# Registry of trained wrappers, keyed by each wrapper's best score
scores = {}

### **Model**: `google/vit-base-patch16-224`

In [None]:
# Input resolution for this model
image_size = (224, 224)
# Replace the Resize step in place: `transform` is the object shared by the datasets built
# earlier, so they all pick up the new size without being recreated.
transform.transforms[0] = transforms.Resize(image_size)

In [None]:
# Pretrained ViT with a classification head sized to len(classes); ignore_mismatched_sizes
# allows the checkpoint's head shape to differ. CustomOutput's default transform returns `.logits`.
model = CustomOutput(AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224", num_labels=len(classes), ignore_mismatched_sizes=True))
optimizer = optim.Adam(model.parameters(), lr=5e-5)

# Classifier is a project helper from the star import.
# NOTE(review): arguments look like (model, display name, optimizer) — confirm its signature.
model_wrapped = Classifier(model, "Google-VitBase", optimizer)

In [None]:
# NOTE(review): the third argument is presumably the number of epochs — confirm against Classifier.fit.
model_wrapped.fit(train_loader, valid_loader, 15)

In [None]:
# Register the wrapper under its best score. Caveat: the score is the dict key, so a
# wrapper with an identical best_score replaces an earlier entry.
scores[model_wrapped.best_score] = model_wrapped

### **Model**: `EfficientNet_B0`

In [None]:
# Input resolution for this model
image_size = (224, 224)
# Replace the Resize step in place: `transform` is the object shared by the datasets built
# earlier, so they all pick up the new size without being recreated.
transform.transforms[0] = transforms.Resize(image_size)

In [None]:
# ImageNet-pretrained EfficientNet-B0 from torchvision
model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
# Swap the final classifier layer for a fresh Linear with one output per class
model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, len(classes))

# No optimizer is passed here.
# NOTE(review): Classifier presumably falls back to a default optimizer — confirm.
model_wrapped = Classifier(model, "EfficientNet_B0")

In [None]:
# NOTE(review): the third argument is presumably the number of epochs — confirm against Classifier.fit.
model_wrapped.fit(train_loader, valid_loader, 10)

In [None]:
# Register the wrapper under its best score. Caveat: the score is the dict key, so a
# wrapper with an identical best_score replaces an earlier entry.
scores[model_wrapped.best_score] = model_wrapped

## Result

In [None]:
# `scores` is keyed by best_score, so max(scores) picks the largest key.
# NOTE(review): this assumes a higher score is better — confirm the metric. Raises ValueError if `scores` is empty.
best_model_wrapped = scores[max(scores)]
# Last expression of the cell: displays the winning wrapper's name
best_model_wrapped.name

In [None]:
# Preview random validation samples with their true and predicted classes.
# Samples fill a grid row by row; raise `n_cols` to spread them across columns as well.

n = 3
n_cols = 1
fig_image_size = 5

# random.sample raises when asked for more items than exist, so cap at the set size
n_shown = min(n, len(valid_set))
# Ceiling division; keep at least one row so plt.subplots always gets a valid grid
n_rows = max(1, -(-n_shown // n_cols))

# squeeze=False always returns a 2-D array of axes, so indexing also works for a 1x1 grid
fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_image_size * n_cols, fig_image_size * n_rows), squeeze=False)
axes = axes.ravel()

# Hide every frame up front so unused grid cells stay blank
for ax in axes:
    ax.axis('off')

for i, idx in enumerate(random.sample(range(len(valid_set)), n_shown)):
    image, label = valid_set[idx]
    prediction = best_model_wrapped.predict(image)

    ax = axes[i]
    # transpose(1, 2, 0) moves the first axis last, the layout imshow expects for colour images
    ax.imshow(denormalize(image, mean, std).cpu().numpy().transpose(1, 2, 0))
    ax.set_title(f"Class: {classes[label]}\nPredict: {classes[prediction]}", fontsize=10)

plt.tight_layout()
plt.show()

## Submission

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Unlabeled map-style dataset: yields only the transformed image for each path.

    Args:
        image_paths: indexable collection of image file paths.
        transform: callable applied to the PIL image to produce the returned item.
    """

    def __init__(self, image_paths, transform):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return the number of images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it after ``transform``."""
        # Force RGB exactly like ClassificationDataset does: grayscale or RGBA files
        # would otherwise produce 1- or 4-channel inputs. convert() also loads the pixel
        # data, so the lazily opened file can be closed right away.
        with Image.open(self.image_paths[idx]) as image_file:
            image_pil = image_file.convert("RGB")
        return self.transform(image_pil)
In [None]:
# Directory holding the test images; must be filled in before running this cell
test_dir = ""

# os.listdir returns entries in arbitrary order; sort so names and predictions line up
# identically on every run
test_image_names = sorted(os.listdir(test_dir))
# os.path.join avoids the stray leading "/" that plain string formatting yields for an empty directory
test_image_paths = [os.path.join(test_dir, image_name) for image_name in test_image_names]

test_set = Dataset(test_image_paths, transform)

In [None]:
# NOTE(review): predict() is handed a whole dataset here and is presumably expected to
# return one class id per item — confirm against Classifier.predict.
predict_class_id = best_model_wrapped.predict(test_set)
# Map each predicted id to its class name
predict_class_names = [classes[class_id] for class_id in predict_class_id]

In [None]:
# decode and save...