In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms.autoaugment import AutoAugment, AutoAugmentPolicy
import pandas as pd
from PIL import Image
import os
import time


In [None]:
class BirdsDataset(Dataset):
    """Image dataset for the birds classifier.

    Two modes, selected by ``csv_file``:

    * Labelled (``csv_file`` given): the CSV is loaded with pandas; column 0
      holds the image file name (relative to ``root_dir``) and column 1 the
      integer class label.
    * Unlabelled (``csv_file`` is None): every regular file directly inside
      ``root_dir`` is an item, in sorted file-name order; the label is an
      empty tensor.

    Each item is ``(image, label, file_name)``. ``image`` is an RGB PIL image,
    or whatever ``transform`` returns when a transform is supplied.
    """

    def __init__(self, root_dir, csv_file=None, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.csv_file = csv_file

        if csv_file is not None:
            self.annotations = pd.read_csv(csv_file)
        else:
            # os.listdir returns entries in arbitrary order; sort so item order
            # (and therefore batch order at inference time) is deterministic.
            self.img_names = sorted(
                f for f in os.listdir(root_dir)
                if os.path.isfile(os.path.join(root_dir, f))
            )

    def __len__(self):
        if self.csv_file is not None:
            return len(self.annotations)
        return len(self.img_names)

    def __getitem__(self, index):
        if self.csv_file is not None:
            img_name = self.annotations.iloc[index, 0]
            y_label = torch.tensor(int(self.annotations.iloc[index, 1]))
        else:
            img_name = self.img_names[index]
            # Placeholder so both modes return a tensor label.
            y_label = torch.tensor([])

        img_path = os.path.join(self.root_dir, img_name)
        # The context manager closes the file handle deterministically;
        # convert() returns an independent, fully loaded image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, y_label, img_name


In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return (mean batch loss, accuracy)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels, _ in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Guard the divisions so an empty loader reports 0 instead of raising.
    return running_loss / max(len(loader), 1), correct / max(total, 1)


def train_classifier(train_gt, train_img_dir, fast_train=True, checkpoint_path='birds_model.ckpt'):
    """Fine-tune an ImageNet-pretrained MobileNetV2 on the birds dataset.

    Args:
        train_gt: Path to the ground-truth CSV read by ``BirdsDataset``
            (column 0 = image file name, column 1 = integer class id).
        train_img_dir: Directory containing the training images.
        fast_train: If True, train for 1 epoch and write no checkpoint;
            otherwise train for 10 epochs and save a checkpoint.
        checkpoint_path: Destination of the checkpoint when ``fast_train`` is
            False. The default ``'birds_model.ckpt'`` is relative to the
            current working directory.

    Returns:
        The trained model on the selected device (left in training mode).
    """
    start_time = time.time()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 256

    # NOTE(review): no transforms.Normalize is applied, although the backbone
    # was pretrained on ImageNet. classify() uses the same un-normalised
    # preprocessing, so any change here must be mirrored there.
    transform_train = transforms.Compose([
        transforms.Resize((224, 224)),
        AutoAugment(AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
    ])

    train_dataset = BirdsDataset(csv_file=train_gt, root_dir=train_img_dir, transform=transform_train)
    # The head contains BatchNorm1d, which raises in training mode on a batch of
    # a single sample. Drop the final batch only when it would be that size.
    drop_last = len(train_dataset) > batch_size and len(train_dataset) % batch_size == 1
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=drop_last,
        pin_memory=device.type == "cuda",  # pinning is only useful for GPU transfers
        prefetch_factor=2,
        num_workers=4,
    )

    model = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)

    # Freeze every feature block except the last six; those and the new head train.
    for param in model.features[:-6].parameters():
        param.requires_grad = False

    # Size the output layer from the largest label in column 1 (the same column
    # BirdsDataset.__getitem__ reads) so CrossEntropyLoss never receives an
    # out-of-range target, even when some ids are absent or ids start at 1.
    num_classes = int(train_dataset.annotations.iloc[:, 1].astype(int).max()) + 1
    # This head must stay identical to the one rebuilt in classify() so the
    # saved state_dict loads there.
    model.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(model.classifier[1].in_features, 512),
        nn.ReLU(),
        nn.BatchNorm1d(512),
        nn.Dropout(0.2),
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.BatchNorm1d(256),
        nn.Dropout(0.2),
        nn.Linear(256, num_classes),
    )

    model.to(device)

    criterion = nn.CrossEntropyLoss()
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=0.001)
    # Learning rate is divided by 10 every 7 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    num_epochs = 1 if fast_train else 10
    for epoch in range(num_epochs):
        epoch_loss, epoch_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')
        scheduler.step()

    end_time = time.time()
    print(f"Training completed in {end_time - start_time:.2f} seconds")
    if not fast_train:
        torch.save({
            'model_state_dict': model.state_dict(),
            'num_classes': num_classes
        }, checkpoint_path)
    return model


In [None]:
def classify(model_path, test_img_dir):
    """Predict a class id for every image file in ``test_img_dir``.

    Args:
        model_path: Checkpoint written by ``train_classifier``, i.e. a dict with
            ``'model_state_dict'`` and ``'num_classes'`` entries.
        test_img_dir: Directory whose regular files are all classified.

    Returns:
        dict mapping image file name -> predicted class id (int).
    """
    start_time = time.time()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Same preprocessing as training minus augmentation (no normalisation in either).
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    test_dataset = BirdsDataset(root_dir=test_img_dir, transform=transform)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=256,
        shuffle=False,
        pin_memory=device.type == "cuda",  # pinning is only useful for GPU transfers
        prefetch_factor=2,
        num_workers=4,
    )

    # NOTE(review): torch.load is pickle-based; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    num_classes = checkpoint['num_classes']

    model = mobilenet_v2(weights=None)
    # Must match the head built in train_classifier, or load_state_dict fails.
    model.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(model.last_channel, 512),
        nn.ReLU(),
        nn.BatchNorm1d(512),
        nn.Dropout(0.2),
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.BatchNorm1d(256),
        nn.Dropout(0.2),
        nn.Linear(256, num_classes),
    )

    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    predictions = {}
    with torch.no_grad():
        for images, _, img_names in test_loader:
            outputs = model(images.to(device))
            # One device->host transfer per batch instead of one .item() per image.
            predicted = outputs.argmax(dim=1).tolist()
            predictions.update(zip(img_names, predicted))

    print(f"Classification completed in {time.time() - start_time:.2f} seconds")
    return predictions


In [None]:
# Input data locations.
train_csv_path = '/kaggle/input/public-tests-2/00_test_img_input/train/gt.csv'
train_images_dir = '/kaggle/input/public-tests-2/00_test_img_input/train/images'
test_images_dir = '/kaggle/input/public-tests-2/00_test_img_input/test/images'
# train_classifier() writes 'birds_model.ckpt' relative to the current working
# directory, so resolve the load path the same way rather than hard-coding
# /kaggle/working (which only matches when the notebook runs from there).
model_save_path = os.path.abspath('birds_model.ckpt')


In [None]:
# Full training run (fast_train=False): 10 epochs, then the checkpoint
# 'birds_model.ckpt' is written.
model = train_classifier(
    train_gt=train_csv_path,
    train_img_dir=train_images_dir,
    fast_train=False,
)


In [None]:
# Map each test image file name to its predicted class id.
predictions = classify(
    model_path=model_save_path,
    test_img_dir=test_images_dir,
)


In [None]:
# for img_name, pred in predictions.items():
#     print(f"Image: {img_name}, Predicted class: {pred}")
