In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import json
from tqdm.notebook import tqdm

def calculate_stats(dataset):
    """Estimate the per-channel mean and standard deviation of an image dataset.

    Each batch is flattened to ``(samples, channels, pixels)``. The per-image
    channel means and standard deviations are summed over every sample and
    divided by ``len(dataset)``. The returned ``std`` is therefore the
    *average of the per-image standard deviations*, not the standard deviation
    of all pixels pooled together.

    Args:
        dataset: A sized dataset accepted by ``DataLoader`` whose batches hold
            the image tensor at index 0, with samples on dimension 0 and
            channels on dimension 1 after batching.
            NOTE(review): assumes floating-point image tensors -- confirm
            against the dataset class.

    Returns:
        tuple: ``(mean, std)``, each holding one value per channel.

    Raises:
        ValueError: If ``dataset`` contains no samples.
    """
    total_samples = len(dataset)
    if total_samples == 0:
        # Fail with a clear message instead of the bare ZeroDivisionError the
        # final division would otherwise produce.
        raise ValueError("Cannot calculate statistics for an empty dataset")

    loader = DataLoader(dataset, batch_size=100, num_workers=0, shuffle=False)
    mean = 0.
    std = 0.

    # The context manager closes the progress bar even if loading a sample
    # raises part-way through the loop.
    with tqdm(total=total_samples, desc="Calculating Stats", unit="sample") as pbar:
        for batch in loader:
            images = batch[0]  # Assuming images are always the first element
            batch_samples = images.size(0)
            # Collapse all spatial dimensions so statistics are taken per channel.
            images = images.view(batch_samples, images.size(1), -1)
            mean += images.mean(2).sum(0)
            std += images.std(2).sum(0)

            pbar.update(batch_samples)

    mean /= total_samples
    std /= total_samples

    return mean, std

def save_stats(mean, std, filename):
    """Write training statistics to ``filename`` as a JSON object.

    ``mean`` and ``std`` must expose ``.tolist()``; their list forms are stored
    under the keys ``train_mean`` and ``train_std``. Any existing file at
    ``filename`` is overwritten.
    """
    payload = dict(train_mean=mean.tolist(), train_std=std.tolist())
    with open(filename, 'w') as handle:
        json.dump(payload, handle)

def split_csv(file_path, train_ratio=0.9, random_state=42):
    """Split a CSV file into a training CSV and a validation CSV.

    The rows of ``file_path`` are divided with ``train_test_split`` and written
    without the index column to ``split_train.csv`` and ``split_val.csv`` in
    the current working directory. The output names are fixed and do not
    depend on ``file_path``.

    Args:
        file_path: Location of the CSV file to split; passed straight to
            ``pd.read_csv``, so both strings and path objects work.
        train_ratio: Fraction of rows kept for training. ``1 - train_ratio``
            is handed to ``train_test_split`` as ``test_size``.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        tuple: ``(train_file, val_file)``, the names of the files written.
    """
    # Read the CSV file
    df = pd.read_csv(file_path)

    # Split the data into train and validation sets
    train_data, val_data = train_test_split(
        df, test_size=1 - train_ratio, random_state=random_state
    )

    # Fixed output names: later cells read 'split_train.csv' directly.
    train_file = "split_train.csv"
    val_file = "split_val.csv"

    # Save the split datasets
    train_data.to_csv(train_file, index=False)
    val_data.to_csv(val_file, index=False)

    print(f"Training set saved to: {train_file}")
    print(f"Validation set saved to: {val_file}")

    return train_file, val_file

# Split the project's training CSV; split_csv writes split_train.csv and
# split_val.csv to the working directory and returns those two file names.
file_path = "./COMP90086_2024_Project_train/train.csv"
train_file, val_file = split_csv(file_path)

Training set saved to: split_train.csv
Validation set saved to: split_val.csv


In [5]:
import os
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
import json

def calculate_image_stats(csv_file, csv_name, img_dir, use_quantized):
    """Compute per-channel RGB mean/std over the images listed in a CSV file.

    The first column of ``csv_file`` supplies the image names. Each image is
    read with OpenCV, converted from BGR to RGB, and its per-channel mean and
    standard deviation are accumulated. The final values are the averages of
    those per-image statistics, so ``std`` is the mean of the per-image
    standard deviations rather than the standard deviation of all pixels
    pooled together. No rescaling is applied to the pixel values.

    The result is written to ``dataset_stats/{csv_name}.json``, or
    ``dataset_stats/{csv_name}_quantized.json`` when ``use_quantized`` is
    true, and is also printed.

    Args:
        csv_file: CSV file whose first column holds the image names.
        csv_name: Label used to build the output JSON file name.
        img_dir: Directory containing ``{name}_original.jpg`` files and a
            ``quantized/`` sub-directory with ``{name}_quantized.jpg`` files.
        use_quantized: Read the quantized variant instead of the original.

    Raises:
        ValueError: If none of the listed images could be read.
    """
    # Read the CSV file
    df = pd.read_csv(csv_file)

    # Initialize variables for mean and std calculation
    sum_means = np.zeros(3)
    sum_stds = np.zeros(3)
    count = 0

    # Iterate through all images
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing images"):
        # Use .iloc for positional access: ``row[0]`` on a label-indexed
        # Series is deprecated and emits a FutureWarning for every call site.
        img_name = str(row.iloc[0])
        if use_quantized:
            img_path = os.path.join(img_dir, f"quantized/{img_name}_quantized.jpg")
        else:
            img_path = os.path.join(img_dir, f"{img_name}_original.jpg")

        # Skip augmented images
        if "_flipped" in img_path or "_zoomed" in img_path:
            continue

        # Read and process the image; cv2.imread signals failure with None
        image = cv2.imread(img_path)
        if image is None:
            print(f"Warning: Could not read image {img_path}")
            continue

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Calculate mean and std for each channel
        sum_means += np.mean(image, axis=(0, 1))
        sum_stds += np.std(image, axis=(0, 1))
        count += 1

    if count == 0:
        # Dividing by zero would silently store NaN statistics in the JSON file.
        raise ValueError(f"No readable images found for {csv_file} in {img_dir}")

    # Calculate final mean and std
    final_mean = sum_means / count
    final_std = sum_stds / count

    stats = {
        "mean": final_mean.tolist(),
        "std": final_std.tolist()
    }

    quant = "_quantized" if use_quantized else ""
    output_file = f"dataset_stats/{csv_name}{quant}.json"
    # Create the output directory on demand so a long run is not lost to a
    # FileNotFoundError at the very end.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, "w") as f:
        json.dump(stats, f, indent=4)

    print(f"Image statistics saved to {output_file}")
    print(f"Mean: {final_mean}")
    print(f"Std: {final_std}")

# Inputs: the image directory, the full training CSV, and the training split
# CSV (the file name split_csv writes).
img_dir = './preprocessed_images/train'
main_train = './COMP90086_2024_Project_train/train.csv'
split_train = 'split_train.csv'

# Compute statistics for the quantized images first, then the originals, for
# both the full training set and the training split. Each call writes one JSON
# file under dataset_stats/.
for quant in [True, False]:
    calculate_image_stats(main_train, "full", img_dir, quant)
    calculate_image_stats(split_train, "split", img_dir, quant)


Processing images:   0%|          | 0/7680 [00:00<?, ?it/s]

  img_name = str(row[0])
Processing images: 100%|██████████| 7680/7680 [01:45<00:00, 72.56it/s]


Image statistics saved to dataset_stats/full_quantized.json
Mean: [118.6946178  111.99958266 103.1688242 ]
Std: [68.48892217 57.7008792  47.95330727]


  img_name = str(row[0])
Processing images: 100%|██████████| 6912/6912 [01:35<00:00, 72.32it/s]


Image statistics saved to dataset_stats/split_quantized.json
Mean: [118.68824203 111.95225786 103.06642011]
Std: [68.46407555 57.64863105 47.88215884]


  img_name = str(row[0])
Processing images: 100%|██████████| 7680/7680 [01:41<00:00, 75.45it/s]


Image statistics saved to dataset_stats/full.json
Mean: [119.21526968 112.51073973 103.68285506]
Std: [69.10808986 58.16075209 48.56616013]


  img_name = str(row[0])
Processing images: 100%|██████████| 6912/6912 [01:35<00:00, 72.46it/s]

Image statistics saved to dataset_stats/split.json
Mean: [119.20886585 112.46371131 103.58079217]
Std: [69.08458836 58.10946274 48.49390902]



