In [3]:
import os
import cv2
import numpy as np
import pandas as pd
from glob import glob
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoFeatureExtractor, SwinForImageClassification
import random

# Configuration setup
class CFG:
    seed = 42
    batch_size = 8
    num_epochs = 20  # Increased epochs for better training
    model_names = ["microsoft/swin-tiny-patch4-window7-224", "microsoft/swin-base-patch4-window12-384"]
    input_sizes = [(224, 224), (384, 384)]
    learning_rate = 1e-4
    lr_scheduler_step = 5  # Step size for LR scheduler
    lr_scheduler_gamma = 0.1  # Multiplicative factor for LR decay

# Set seed for reproducibility
def set_seed(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')

set_seed(CFG.seed)

# Define paths (update these paths according to your local setup)
train_image_path = './rgb'
leaf_count_path = 'leaf_count.csv'
test_image_path = '../Test_data'

# Load the leaf count CSV without headers
leaf_count_df = pd.read_csv(leaf_count_path, header=None)
print("Leaf count data without headers:")
print(leaf_count_df.head())

# Assuming the first column is the image filename and the second column is the leaf count
leaf_count_df.columns = ['image', 'leaf_count']
print("Leaf count data with assigned headers:")
print(leaf_count_df.head())

# Custom dataset class for loading images and leaf counts
class LeafCountDataset(Dataset):
    def __init__(self, image_files, leaf_counts, transform=None):
        self.image_files = image_files
        self.leaf_counts = leaf_counts
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img = cv2.imread(self.image_files[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform:
            img = self.transform(img)
        
        leaf_count = torch.tensor(self.leaf_counts[idx], dtype=torch.float32)
        return img, leaf_count

# Get image files and corresponding leaf counts
image_files = sorted(glob(os.path.join(train_image_path, '*.png')))
leaf_counts = leaf_count_df['leaf_count'].values

# Split into training and validation sets
train_img_files, val_img_files, train_leaf_counts, val_leaf_counts = train_test_split(image_files, leaf_counts, test_size=0.2, random_state=CFG.seed)

# Define transformations and create datasets and dataloaders
train_datasets = []
val_datasets = []
train_loaders = []
val_loaders = []

for input_size, model_name in zip(CFG.input_sizes, CFG.model_names):
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
    ])
    train_dataset = LeafCountDataset(train_img_files, train_leaf_counts, transform)
    val_dataset = LeafCountDataset(val_img_files, val_leaf_counts, transform)

    train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False)

    train_datasets.append(train_dataset)
    val_datasets.append(val_dataset)
    train_loaders.append(train_loader)
    val_loaders.append(val_loader)

# Load models
models = []
optimizers = []
schedulers = []

for model_name in CFG.model_names:
    model = SwinForImageClassification.from_pretrained(model_name)
    model.classifier = nn.Linear(model.classifier.in_features, 1)  # Change the output layer to single regression value
    model = model.to('cuda')
    models.append(model)
    optimizer = optim.Adam(model.parameters(), lr=CFG.learning_rate)
    optimizers.append(optimizer)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=CFG.lr_scheduler_step, gamma=CFG.lr_scheduler_gamma)
    schedulers.append(scheduler)

# Define loss function and optimizer
criterion = nn.MSELoss()

# Training function
def train_model(models, train_loaders, val_loaders, criterion, optimizers, schedulers, num_epochs=10):
    for epoch in range(num_epochs):
        for model, train_loader, val_loader, optimizer, scheduler in zip(models, train_loaders, val_loaders, optimizers, schedulers):
            model.train()
            train_loss = 0.0

            for images, leaf_counts in tqdm(train_loader):
                images = images.to('cuda')
                leaf_counts = leaf_counts.to('cuda')

                optimizer.zero_grad()
                outputs = model(images).logits.squeeze(1)  # Get the logits and squeeze the output to match leaf_counts shape
                loss = criterion(outputs, leaf_counts)
                loss.backward()
                optimizer.step()

                train_loss += loss.item() * images.size(0)

            train_loss = train_loss / len(train_loader.dataset)
            scheduler.step()

            model.eval()
            val_loss = 0.0
            preds = []

            with torch.no_grad():
                for images, leaf_counts in val_loader:
                    images = images.to('cuda')
                    leaf_counts = leaf_counts.to('cuda')

                    outputs = model(images).logits.squeeze(1)
                    loss = criterion(outputs, leaf_counts)
                    val_loss += loss.item() * images.size(0)

                    preds.extend(outputs.cpu().numpy())

            val_loss = val_loss / len(val_loader.dataset)
            val_mae = mean_absolute_error(val_leaf_counts, preds)

            print(f'Epoch {epoch+1}/{num_epochs}, Model: {model_name}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val MAE: {val_mae:.4f}')

train_model(models, train_loaders, val_loaders, criterion, optimizers, schedulers, num_epochs=CFG.num_epochs)

# Save the trained models
for model, model_name in zip(models, CFG.model_names):
    sanitized_model_name = model_name.replace("/", "_")
    torch.save(model.state_dict(), f'{sanitized_model_name}_leaf_count.pth')

# Function to predict and save results for the test set
def predict_and_save_results(models, test_image_path, output_csv_path):
    test_image_files = sorted(glob(os.path.join(test_image_path, '*.png')))
    results = []

    # Initialize feature extractors and transforms for each model
    feature_extractors = [AutoFeatureExtractor.from_pretrained(model_name) for model_name in CFG.model_names]
    transforms_list = [transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
    ]) for input_size, feature_extractor in zip(CFG.input_sizes, feature_extractors)]

    for img_file in tqdm(test_image_files):
        img = cv2.imread(img_file)
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        model_preds = []

        for model, transform in zip(models, transforms_list):
            img_transformed = transform(img_rgb)
            img_transformed = img_transformed.unsqueeze(0).to('cuda')

            with torch.no_grad():
                output = model(img_transformed).logits.cpu().numpy()[0, 0]
                model_preds.append(output)

        # Average predictions from all models and convert to integer
        avg_pred = np.mean(model_preds)
        results.append([os.path.basename(img_file), int(avg_pred)])

    # Save results to CSV
    results_df = pd.DataFrame(results, columns=['image', 'leaf_count'])
    results_df.to_csv(output_csv_path, index=False)

# Predict and save results for the test set
output_csv_path = 'leaf_count_predictions.csv'
predict_and_save_results(models, test_image_path, output_csv_path)

print(f"Predicted leaf counts saved to {output_csv_path}")



> SEEDING DONE
Leaf count data without headers:
           0  1
0  00001.png  6
1  00002.png  6
2  00003.png  6
3  00004.png  9
4  00005.png  7
Leaf count data with assigned headers:
       image  leaf_count
0  00001.png           6
1  00002.png           6
2  00003.png           6
3  00004.png           9
4  00005.png           7


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.37it/s]


Epoch 1/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 314.7393, Val Loss: 210.8609, Val MAE: 9.8476


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.69it/s]


Epoch 1/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 227.7129, Val Loss: 98.7308, Val MAE: 5.6275


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.45it/s]


Epoch 2/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 169.4867, Val Loss: 103.3170, Val MAE: 5.7385


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.69it/s]


Epoch 2/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 83.9297, Val Loss: 77.7683, Val MAE: 5.5132


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.42it/s]


Epoch 3/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 87.0735, Val Loss: 79.2054, Val MAE: 5.1886


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.67it/s]


Epoch 3/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 80.4751, Val Loss: 83.7389, Val MAE: 6.4636


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.46it/s]


Epoch 4/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 74.1781, Val Loss: 75.5837, Val MAE: 5.1310


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.69it/s]


Epoch 4/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 67.9567, Val Loss: 73.0018, Val MAE: 5.0291


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.36it/s]


Epoch 5/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 71.6880, Val Loss: 74.4512, Val MAE: 4.9258


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.63it/s]


Epoch 5/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 66.0474, Val Loss: 55.4830, Val MAE: 4.2748


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.36it/s]


Epoch 6/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 65.4921, Val Loss: 74.4795, Val MAE: 4.9761


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.65it/s]


Epoch 6/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 52.3544, Val Loss: 36.6657, Val MAE: 3.5295


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.47it/s]


Epoch 7/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 65.3976, Val Loss: 73.6092, Val MAE: 4.9690


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.63it/s]


Epoch 7/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 44.6263, Val Loss: 34.3216, Val MAE: 3.4867


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.39it/s]


Epoch 8/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 66.1318, Val Loss: 73.8043, Val MAE: 5.0684


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.63it/s]


Epoch 8/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 45.5771, Val Loss: 30.6722, Val MAE: 3.2346


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.40it/s]


Epoch 9/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 66.8283, Val Loss: 73.0219, Val MAE: 4.8958


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.63it/s]


Epoch 9/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 40.0696, Val Loss: 30.3510, Val MAE: 3.0435


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.24it/s]


Epoch 10/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 62.1929, Val Loss: 72.0899, Val MAE: 4.9763


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.65it/s]


Epoch 10/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 37.4481, Val Loss: 30.1623, Val MAE: 3.0552


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.38it/s]


Epoch 11/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 64.2261, Val Loss: 72.2020, Val MAE: 4.9659


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.65it/s]


Epoch 11/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 37.3413, Val Loss: 31.6280, Val MAE: 3.2305


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.27it/s]


Epoch 12/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 60.9335, Val Loss: 70.2587, Val MAE: 4.7713


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.64it/s]


Epoch 12/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 36.1070, Val Loss: 29.7477, Val MAE: 3.2897


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.30it/s]


Epoch 13/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 62.8724, Val Loss: 71.0761, Val MAE: 4.8601


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.61it/s]


Epoch 13/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 36.5880, Val Loss: 27.9277, Val MAE: 2.8856


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.21it/s]


Epoch 14/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 61.9965, Val Loss: 72.1079, Val MAE: 5.0062


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.66it/s]


Epoch 14/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 37.2901, Val Loss: 29.9963, Val MAE: 3.2192


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.19it/s]


Epoch 15/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 60.0574, Val Loss: 69.8050, Val MAE: 4.7691


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.65it/s]


Epoch 15/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 36.6846, Val Loss: 29.8946, Val MAE: 3.0834


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.28it/s]


Epoch 16/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 61.6622, Val Loss: 70.8664, Val MAE: 5.0113


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.67it/s]


Epoch 16/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 38.4431, Val Loss: 28.9947, Val MAE: 2.9612


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.30it/s]


Epoch 17/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 61.4796, Val Loss: 69.8635, Val MAE: 4.7724


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.62it/s]


Epoch 17/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 35.9261, Val Loss: 28.5490, Val MAE: 3.0948


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.39it/s]


Epoch 18/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 63.4836, Val Loss: 69.0650, Val MAE: 4.7453


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.68it/s]


Epoch 18/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 38.0803, Val Loss: 28.5137, Val MAE: 2.9250


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.27it/s]


Epoch 19/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 63.5447, Val Loss: 70.6239, Val MAE: 4.8539


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.62it/s]


Epoch 19/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 37.7125, Val Loss: 26.7612, Val MAE: 2.8519


100%|███████████████████████████████████████████| 20/20 [00:03<00:00,  6.45it/s]


Epoch 20/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 68.0519, Val Loss: 70.6490, Val MAE: 4.9234


100%|███████████████████████████████████████████| 20/20 [00:05<00:00,  3.67it/s]


Epoch 20/20, Model: microsoft/swin-base-patch4-window12-384, Train Loss: 36.0058, Val Loss: 28.6489, Val MAE: 2.9828


100%|███████████████████████████████████████████| 68/68 [00:03<00:00, 20.07it/s]

Predicted leaf counts saved to leaf_count_predictions.csv



