In [21]:
import os
import json

import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import os
import shutil
from pathlib import Path

In [10]:
def read_csv_classes(csv_dir: str, csv_name: str):
  """Read one CSV file and collect the distinct values of its "label" column.

  A one-line summary (file name, row count, number of distinct labels) is
  printed as a side effect.

  Args:
    csv_dir: Directory that contains the CSV file.
    csv_name: File name of the CSV inside ``csv_dir``.

  Returns:
    A ``(frame, label_set)`` tuple: the DataFrame loaded from the file and
    the set of unique values found in its "label" column.
  """
  csv_path = os.path.join(csv_dir, csv_name)
  frame = pd.read_csv(csv_path)

  # Each label appears once in the set regardless of how many rows carry it.
  unique_labels = set(frame["label"].unique())

  num_rows = frame.shape[0]
  num_classes = len(unique_labels)
  print("{} have {} images and {} classes.".format(csv_name, num_rows, num_classes))
  return frame, unique_labels


# Directory that holds the three split CSV files read below.
data_dir = "/hdd/zengs_data/vision_data/imagenetmini-1000"

# Load each split; every call also prints that split's row and class counts.
train_data, train_label = read_csv_classes(data_dir, "train.csv")
val_data, val_label = read_csv_classes(data_dir, "val.csv")
test_data, test_label = read_csv_classes(data_dir, "test.csv")

# Tag every row with the split it was read from, so the origin is still known
# once the three frames are stacked into one.
for split_name, split_frame in (("train", train_data),
                                ("val", val_data),
                                ("test", test_data)):
  split_frame["dataset"] = split_name

# Row-wise stack of all three splits (original per-file indices are kept).
data = pd.concat([train_data, val_data, test_data])

# Sorted list of every label that occurs in any of the splits.
labels = sorted(train_label.union(val_label, test_label))
print("all classes: {}".format(len(labels)))

train.csv have 38400 images and 64 classes.
val.csv have 9600 images and 16 classes.
test.csv have 12000 images and 20 classes.
all classes: 100


In [11]:

# Fraction of each class that is held out for validation.
rate = 0.2

# Per-class pieces are collected here and stacked once after the loop.
split_train_data, split_val_data = [], []
for label in labels:
  class_data = data.loc[data["label"] == label]
  # int() truncates, so any fractional remainder ends up in the validation part.
  num_train_sample = int(len(class_data) * (1 - rate))
  # frac=1 returns every row in shuffled order; the fixed seed makes the
  # shuffle repeatable.
  shuffle_data = class_data.sample(frac=1, random_state=1)
  split_train_data.append(shuffle_data.iloc[:num_train_sample])
  split_val_data.append(shuffle_data.iloc[num_train_sample:])

# Stack the per-class pieces into the final train / validation tables.
new_train_data = pd.concat(split_train_data)
new_val_data = pd.concat(split_val_data)

In [26]:
def copy_split(split_frame, split_name, root_dir):
  """Copy every file listed in ``split_frame`` into a new per-split folder.

  Each row is read through its "filename", "label" and "dataset" columns.
  The source file is ``<root_dir>/<dataset>/<label>/<filename>`` and it is
  copied to ``<root_dir>/new_dataset/<split_name>/<label>/<filename>``.
  Missing destination folders are created on demand.

  Args:
    split_frame: DataFrame with "filename", "label" and "dataset" columns.
    split_name: Folder name of the new split (e.g. 'train' or 'val').
    root_dir: Directory that contains the source folders and receives the
      'new_dataset' folder.

  Raises:
    OSError: Propagated from ``shutil.copyfile`` / ``Path.mkdir`` (for
      example when a source file does not exist).
  """
  root = Path(root_dir)
  target_root = root / 'new_dataset' / split_name
  # Many files share one destination folder; remember the folders already
  # created so mkdir is issued once per folder instead of once per file.
  created_dirs = set()

  rows = zip(split_frame['filename'], split_frame['label'], split_frame['dataset'])
  for filename, label, dataset in rows:
    original_filename = root / dataset / label / filename
    new_filename = target_root / label / filename

    # Use the file's own parent so a filename that contains a sub-folder
    # still gets its folder created.
    target_dir = new_filename.parent
    if target_dir not in created_dirs:
      target_dir.mkdir(parents=True, exist_ok=True)
      created_dirs.add(target_dir)

    shutil.copyfile(original_filename, new_filename)


# For training
copy_split(new_train_data, 'train', data_dir)

# For validation dataset
copy_split(new_val_data, 'val', data_dir)