# **Libs**

In [1]:
# Torch
from torchvision import transforms, models

# Other
from sklearn.model_selection import train_test_split

# Utils
from image_classification.model import *

# **Code**

## Fix random seeds

In [None]:
# Seed the RNGs through the project helper from image_classification.model.
# NOTE(review): presumably seeds Python `random`, NumPy and torch, which would
# make DataLoader shuffling and the random preview below reproducible; confirm.
set_all_seeds()

## Data

### **Transformation** and **augmentation**

In [None]:
# Deterministic preprocessing for validation/test images. The training
# pipeline below reuses it. `image_size`, `mean` and `std` are presumably
# provided by the image_classification.model star import.
image_transform = transforms.Compose([
    transforms.Resize(image_size),
    # `Resize` with a single int only rescales the shorter edge, so images
    # with different aspect ratios would give tensors of different shapes
    # and fail to collate into a batch. `CenterCrop` pins the final size;
    # it is a no-op when `image_size` is already an (h, w) pair.
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Training pipeline. It currently applies no augmentation at all and just
# wraps `image_transform`. Add random transforms (flips, crops, colour
# jitter, ...) before `image_transform` here when needed.
image_augmentation = transforms.Compose([
    image_transform,
])

### Reading

In [None]:
# Class names indexed by integer label (fill in for the concrete dataset);
# its length also sets the width of the classifier head.
classes = []

In [None]:
# Parallel per-sample lists to fill in: image file paths and their labels.
image_paths, labels = [], []

### Split

In [None]:
# 80/20 split, stratified so each class keeps (approximately) the same share
# in train and validation. `random_state` fixes the split independently of
# the global seed.
(
    train_image_paths,
    valid_image_paths,
    train_labels,
    valid_labels,
) = train_test_split(
    image_paths,
    labels,
    test_size=0.2,
    random_state=42,
    stratify=labels,
)

### Create **Datasets**

In [None]:
# Whole dataset with plain preprocessing, used for the preview in the
# Visualization section.
dataset = ImageClassificationDataset(
    image_paths,
    labels,
    transform=image_transform,
)

# Training split goes through the augmentation pipeline; validation split
# uses deterministic preprocessing only.
train_set = ImageClassificationDataset(
    train_image_paths,
    train_labels,
    transform=image_augmentation,
)
valid_set = ImageClassificationDataset(
    valid_image_paths,
    valid_labels,
    transform=image_transform,
)

### Create **DataLoader**

In [None]:
batch_size = 24

# Shuffle only the training data. The validation loader keeps a fixed order
# so evaluation is deterministic and batches are identical across epochs.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)

### Visualization

In [None]:
# Preview the un-augmented full dataset using the project helper; `classes`
# is passed so samples can be labelled by name (presumably as a grid; confirm).
show_images(dataset, classes=classes)

## Models

### Score

In [None]:
# Registry of trained models: best score -> fitted ImageClassifier wrapper.
# It is filled after each model's training cell.
scores = {}

### EfficientNet_B0

In [None]:
# ImageNet-pretrained EfficientNet-B0 whose last classifier layer is
# replaced by a fresh Linear layer with one output per class.
model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
head_in_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(head_in_features, len(classes))

model_wrapped = ImageClassifier(model, "EfficientNet_B0")

In [None]:
# Third positional argument of ImageClassifier.fit.
# NOTE(review): presumably the number of training epochs; confirm the signature.
num_epochs = 10
model_wrapped.fit(train_loader, valid_loader, num_epochs)

In [None]:
# Register the trained wrapper under its best score.
# NOTE(review): two models with an identical score share a key, so the later
# one silently replaces the earlier entry.
scores.update({model_wrapped.best_score: model_wrapped})

## Result

In [None]:
# Select the wrapper registered under the largest score and show its name.
# NOTE(review): assumes a higher best_score is better; confirm for the metric.
top_score = max(scores)
best_model_wrapped = scores[top_score]
best_model_wrapped.name

In [None]:
n = 3

# squeeze=False keeps `axes` a 2-D (n, 1) array even when n == 1. Plain
# `plt.subplots(1, 1)` returns a bare Axes, and indexing it would fail.
fig, axes = plt.subplots(n, 1, figsize=(5, 5 * n), squeeze=False)

# Show n random validation images with their true and predicted class.
for ax, idx in zip(axes[:, 0], random.sample(range(len(valid_set)), n)):
    image, label = valid_set[idx]
    prediction = best_model_wrapped.predict(image)

    # Undo normalisation and move channels last (C, H, W -> H, W, C) for imshow.
    ax.imshow(denormalize(image).cpu().numpy().transpose(1, 2, 0))
    ax.axis('off')
    ax.set_title(f"Class: {classes[label]}\nPredict: {classes[prediction]}", fontsize=10)

plt.tight_layout()
plt.show()

## Submission

In [None]:
# Unlabelled test images to fill in; preprocessed like the validation split.
test_image_paths = []
test_set = ImageDataset(test_image_paths, transform=image_transform)

In [None]:
# Predict a class id for every test image, then map each id to its class name.
predict_class_id = best_model_wrapped.predict(test_set)
predict_class_names = list(map(classes.__getitem__, predict_class_id))