In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm
# Xception is part of the timm library since pretrainedmodels is deprecated
from timm import create_model  

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

In [None]:
# Labelled index: trim whitespace around each relative path, then root it under data/.
train_data = pd.read_csv("./data/train_images.csv")
stripped_paths = train_data['image_path'].map(lambda raw_path: raw_path.strip())
train_data['image_path'] = 'data/' + stripped_paths

# Unlabelled index used for inference.
# NOTE(review): unlike the training paths, these are neither stripped nor
# prefixed with 'data/' -- confirm the CSV already holds usable paths.
test_data = pd.read_csv("./data/test_images_path.csv")

# Hold out 20% of the labelled rows for validation; the fixed seed keeps the split repeatable.
train_df, val_df = train_test_split(
    train_data,
    test_size=0.2,
    random_state=42,
)

print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}")

In [None]:
class BirdieDataset(Dataset):
    """Labelled image dataset backed by a DataFrame.

    The first column of ``dataframe`` is read as the image path and the
    second column as the label. Items are ``(image, label)`` pairs where the
    image is loaded as RGB and optionally passed through ``transform``.
    """

    def __init__(self, dataframe, transform=None):
        self.transform = transform
        self.data = dataframe

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        path_cell = self.data.iloc[index, 0]
        picture = Image.open(path_cell).convert("RGB")

        # Subtract 1 so the returned label is 0-based (presumably the stored
        # labels start at 1 -- verify against the CSV).
        target = int(self.data.iloc[index, 1]) - 1

        picture = self.transform(picture) if self.transform else picture
        return picture, target

class TestDataset(Dataset):
    """Unlabelled image dataset used for inference.

    Each item is the image referenced by the ``image_path`` column of one row
    of ``dataframe``, loaded as RGB and passed through ``transform`` when one
    is given.

    Args:
        dataframe: DataFrame that contains an ``image_path`` column.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, dataframe, transform=None):
        self.data = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Pick the single path at position ``index``. Indexing the column
        # without a row selector would hand the entire Series to Image.open.
        # ``iloc`` is positional, so it works whatever the frame's index is.
        img_path = self.data['image_path'].iloc[index]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

In [None]:
# transformations
# Xception requires input size 299x299
INPUT_SIZE = (299, 299)

# Per-channel normalisation statistics shared by both pipelines.
# NOTE(review): confirm these match the statistics the pretrained checkpoint
# expects (e.g. via the model's pretrained config).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light augmentation (flip, small rotation,
# contrast jitter), then tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Evaluation pipeline: same resize and normalisation, no augmentation.
test_transform = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

In [None]:
# Mini-batch size shared by every loader below.
BATCH_SIZE = 32

# datasets: only the training split goes through the augmenting pipeline;
# validation and test images use the plain evaluation pipeline.
train_dataset = BirdieDataset(train_df, transform=train_transform)
val_dataset = BirdieDataset(val_df, transform=test_transform)
test_dataset = TestDataset(test_data, transform=test_transform)

# dataloaders: shuffle the training data only, so validation and test
# batches come out in row order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

In [None]:
# model setup
# BirdieDataset turns each stored label into ``label - 1``, so the classifier
# needs one output for every label value up to the largest one. Derive that
# from the full labelled table: counting distinct labels in the training split
# alone undercounts when the split leaves a class out or ids are not contiguous,
# which would produce out-of-range targets for the loss.
num_classes = int(train_data['label'].astype(int).max())
# NOTE(review): newer timm releases may expose this architecture under a
# different name -- confirm 'xception' resolves in the installed version.
model = create_model('xception', pretrained=True, num_classes=num_classes)
model = model.to(device)

# loss, optimizer, and scheduler
criterion = nn.CrossEntropyLoss()
# AdamW applies decoupled weight decay (not a classic L2 penalty added to the loss).
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# mode='max' because the scheduler is stepped with validation accuracy:
# halve the learning rate once it has stopped improving for more than 2 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=5,
                save_path="best_xception_model.pth"):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    For every epoch the model is trained on ``train_loader``, scored on
    ``val_loader`` with ``evaluate_model``, and the scheduler is stepped with
    the validation accuracy. Whenever that accuracy beats the best seen so
    far, the model's ``state_dict`` is written to ``save_path``.

    Args:
        model: Network to optimise; batches are moved to the global ``device``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        scheduler: Scheduler whose ``step`` accepts the validation accuracy.
        epochs: Number of passes over ``train_loader``.
        save_path: Destination file for the best checkpoint.

    Returns:
        The best validation accuracy (in percent) observed during training.
    """
    best_acc = 0.0
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        # Report the mean per-batch loss so the number does not scale with the
        # loader length, and avoid dividing by zero on an empty loader.
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        train_acc = 100. * correct / total if total else 0.0
        val_acc = evaluate_model(model, val_loader)
        scheduler.step(val_acc)

        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), save_path)
    print("Training complete. Best Validation Accuracy:", best_acc)
    return best_acc

def evaluate_model(model, loader):
    """Return the top-1 accuracy of ``model`` on ``loader``, in percent.

    The model is switched to eval mode (and left in it) and batches are moved
    to the global ``device``. Gradients are disabled during the pass.

    Args:
        model: Network whose output's dimension 1 holds the per-class scores.
        loader: Iterable of ``(images, labels)`` batches.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when ``loader``
        yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    # An empty loader has no accuracy to report; return 0.0 instead of
    # raising ZeroDivisionError.
    if total == 0:
        return 0.0
    return 100. * correct / total

In [None]:
# Kick off training for 10 epochs with the objects built in the cells above.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=10,
)