In [None]:
from PIL import Image
import os
import torch
import csv
import numpy as np
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, utils
from torchvision.transforms import v2
import torch.nn as nn
import torch.nn.functional as F
import argparse
from datetime import datetime

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

## 1. Create table for dataloaders (all data)


In [None]:
def create_or_load_csv(root_path, sub_path):
    """Create or load a csv file for the data loader. 
    Root path will point towards train, val, or test.
    Sub path will point towards the image generation technique - for example fs for face-swap
    
    returns a nested list representing all data in csv"""

    csv_path = f"{root_path}/{sub_path}_{root_path}_table_mod1.csv"
    
    if os.path.exists(csv_path):
        print(f"Found {csv_path}. Loading data.")
        with open(csv_path) as csvfile:
            csvreader = csv.reader(csvfile, delimiter=" ")
            rows = list(csvreader)
            return rows
    else:
        print(f"Did not find csv data. Writing {csv_path}.")
        out_index = 0
        n_corrupted = 0 # check to see if images load correctly

        with open(csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile, delimiter=' ')
            # Part of csv that is fake (we will use 0 for real, and 1 for fake)
            fake = 1
            for path, dirs, files in os.walk(f"{root_path}/{sub_path}"):
                for name in files:

                    if ".zip" not in name: # I kept the downloaded zip files in their directory locations
                        
                        filepath = f"{path}/{name}"

                        # This is part of the check for corrupted images
                        try:
                            im = Image.open(filepath)
                        except:
                            print(f"{filepath} will not load - marked as corrupted")
                            n_corrupted += 1
                            continue
                        
                        out_index += 1
                        
                        writer.writerow([out_index, filepath, fake])

            fake = 0 # this is a flag
            for path, dirs, files in os.walk(f"{root_path}/real"):
                for name in files:
                    if ".zip" not in name:
                        filepath = f"{path}/{name}"

                        # check for corrupted
                        try:
                            im = Image.open(filepath)
                        except:
                            print(f"{filepath} will not load - marked as corrupted")
                            n_corrupted += 1
                            continue
                        
                        out_index += 1

                        writer.writerow([out_index, filepath, fake])
            
            print("CSV created. Reading and loading data.")
            with open(csv_path) as csvfile:
                csvreader = csv.reader(csvfile, delimiter=" ")
                rows = list(csvreader)
                return rows

In [None]:
# This should create all of the original tables (with all data) if they do not already exist
for root_path in ["train", "test"]:
    for sub_path in ["fe", "fs", "i2i", "t2i"]:
        temp_rows = create_or_load_csv(root_path, sub_path)

In [13]:
import os
import csv
from random import shuffle

## 2. Create a data subset and put it in a table

In [14]:
# Load csv file by name
def load_csv(root_path, sub_path):
    csv_path = f"{root_path}/{sub_path}_{root_path}_table_mod1.csv"
    with open(csv_path) as csvfile:
        csvreader = csv.reader(csvfile, delimiter=" ")
        rows = list(csvreader)
    return rows

def write_csv(rows, outpath):
    with open(outpath, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter=' ')
        for row in rows:
            writer.writerow(row)

In [15]:
# Let's do one by one
for mod_type in ["fe", "fs", "i2i", "t2i"]:
    dsplit_type = "train"
    rows = load_csv(dsplit_type, mod_type)
    
    real = [row for row in rows if row[2] == "0"]
    fake = [row for row in rows if row[2] == "1"]

    shuffle(real)
    shuffle(fake) # just to make sure it is not ordered in any way
    
    # Subset to the amount of 
    real = real[:5_000]
    fake = fake[:5_000]
    
    # concatenate the two
    rows = real + fake
    shuffle(rows)

    for i in range(len(rows)):
        rows[i][0] = i
    
    out_path = f"subset_data/{dsplit_type}/{mod_type}_{dsplit_type}_10k_subset.csv"
    write_csv(rows, out_path)

In [16]:
for mod_type in ["fe", "fs", "i2i", "t2i"]:
    dsplit_type = "test"
    rows = load_csv(dsplit_type, mod_type)
    
    real = [row for row in rows if row[2] == "0"]
    fake = [row for row in rows if row[2] == "1"]

    shuffle(real)
    shuffle(fake) # just to make sure it is not ordered in any way
    
    # # Subset to the amount of 
    real = real[:1_000]
    fake = fake[:1_000]
    
    # # concatenate the two
    rows = real + fake
    shuffle(rows)

    for i in range(len(rows)):
        rows[i][0] = i
    
    out_path = f"subset_data/{dsplit_type}/{mod_type}_{dsplit_type}_2k_subset.csv"
    write_csv(rows, out_path)

[2, 'train/t2i/HPS/id_0418/image_4/1.jpg/align_sd_gene_000_000.png', '1']

## 3. Transfer Images to New Location

In [108]:
import os
import csv
import random
import shutil

In [156]:
# Load csv file by name
def load_new_csv(root_path, sub_path):
    csv_path = f"subset_data/{root_path}/{sub_path}_{root_path}_2k_subset.csv"
    with open(csv_path) as csvfile:
        csvreader = csv.reader(csvfile, delimiter=" ")
        rows = list(csvreader)
    return rows

def write_new_csv(rows, outpath):
    with open(outpath, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter=' ')
        for row in rows:
            writer.writerow(row)

In [160]:
# Copy and rename the files from the old directory structure to the new. 
# 

dsplit_type = "test"
mod_type = "t2i"
rows = load_new_csv(dsplit_type, mod_type)
new_rows = []
for i, row in enumerate(rows):
    prev_path = row[1]
    # print("prev_path: ", prev_path)
    prev_path_list = prev_path.split("/")

    if prev_path_list[1] == "real":
        new_dir = f"subset_data/{dsplit_type}/real"
    else: 
        new_dir = f"subset_data/{dsplit_type}/{mod_type}"
        
    new_path = f"{new_dir}/{prev_path_list[-1]}"

    # new_path = "/".join(["subset_data", prev_path_list[0], prev_path_list[1], prev_path_list[-1]])
    # print("new path: ", new_path)
    
    # Copy the file from the old location to the new location
    # copy_dir = "/".join(["subset_data", prev_path_list[0], prev_path_list[1]])
    # print("copying to: ", new_dir)
    shutil.copy(prev_path, new_dir)

#     newer_numbered_path = "/".join(["subset_data", prev_path_list[0], prev_path_list[1], str(i) + ".png"])
    newer_numbered_path = f"{new_dir}/{mod_type}_{str(i)}.png"
    # print("newer: ", newer_numbered_path)
    
    os.rename(new_path, newer_numbered_path)
    
    row[1] = newer_numbered_path
    # print()
    # print()

outpath = f"subset_data/{dsplit_type}/{mod_type}_{dsplit_type}_2k_subset_update.csv"
print("outpath: ", outpath)

write_new_csv(rows, outpath)

outpath:  subset_data/test/t2i_test_2k_subset_update.csv


In [123]:
# from tqdm import tqdm
# for file in tqdm(os.listdir("subset_data/train/real")):
#     if os.path.isfile("subset_data/train/real/"+file):
#         os.remove("subset_data/train/real/"+file)

100%|██████████| 3191/3191 [00:00<00:00, 3236.09it/s]


In [165]:
len(os.listdir("subset_data/test/real"))

4001

In [153]:
random.choice(os.listdir("subset_data/train/real"))

'fe_7285.png'

In [95]:
row = random.choice(train_rows)

In [96]:
prev_path = row[1]
prev_path

'train/i2i/FreeDoM_I/id_1160/save_sketch_2/3.jpg/FreeDom_Gene.png'

In [97]:
prev_path = prev_path.split("/")
prev_path

['train',
 'i2i',
 'FreeDoM_I',
 'id_1160',
 'save_sketch_2',
 '3.jpg',
 'FreeDom_Gene.png']

In [100]:
i=1
new_path = "/".join(["subset_data", prev_path[0], prev_path[1], str(i) + ".png"])
new_path

'subset_data/train/i2i/1.png'

In [72]:
new_paths = []
for row in train_rows:
    prev_path = row[1]
    prev_path = prev_path.split("/")
    new_path = "/".join(["subset_data", prev_path[0], prev_path[1], prev_path[-1]])
    new_paths.append(new_path)

In [74]:
len(new_paths)

10000

In [75]:
len(set(new_paths))

5030

In [81]:
new_paths[8]

'subset_data/train/fe/0C_interText_optDM_3_alpha=0.5.png'

In [82]:
test = new_paths[8]
test

'subset_data/train/fe/0C_interText_optDM_3_alpha=0.5.png'

In [87]:
test = test.split("/")
test[-1] = "_" + test[-1]
test = "/".join(test)
test

'subset_data/train/fe/_0C_interText_optDM_3_alpha=0.5.png'

In [88]:
def create_no_dupe_string(s, arr):
    while s in arr:
        s_arr = s.split("/")
        s_arr[-1] = "_" + s_arr[-1]
        s = "/".join(s_arr)
    arr.append(s)