In [None]:
import itertools
import os
import random
import threading
from concurrent.futures import ThreadPoolExecutor
from glob import glob

import pandas as pd
import PIL
import PIL.Image
import numpy as np
import torch
from torchvision import transforms as T
from tqdm import tqdm
import cv2

In [None]:
class Transform_rmbg:
    """Blank out bright background pixels of an image.

    A pixel counts as foreground when its first channel is below ``thresh``.
    The resulting binary mask is median-filtered to remove speckle, and every
    pixel outside the filtered mask is set to 0.

    Args:
        thresh: first-channel intensity below which a pixel counts as
            foreground (default 240).
        medianblur: aperture size passed to ``cv2.medianBlur``; OpenCV
            requires an odd value greater than 1.
        keep_largest_component: if True, additionally zero every foreground
            region except the largest connected component of the filtered
            mask. The previous implementation computed this mask on every
            call but never returned it, so the default (False) keeps the
            previous output while skipping that wasted work.
    """

    def __init__(self, thresh=240, medianblur=7, keep_largest_component=False):
        self.thresh = thresh
        self.medianblur = medianblur
        self.keep_largest_component = keep_largest_component

    def __call__(self, img):
        """Return a ``PIL.Image`` with background pixels set to 0.

        ``img`` may be a PIL image or an array. It is converted with
        ``np.array`` and indexed as ``[:, :, 0]``, so a 3-D (H, W, C) layout
        is assumed (presumably RGB; confirm for other inputs).
        """
        img = np.array(img)
        # Foreground = first channel darker than the threshold, as a 0/255 mask.
        img_bin = (img[:, :, 0] < self.thresh).astype(np.uint8) * 255
        img_blur = cv2.medianBlur(img_bin, self.medianblur)
        img_filter = img.copy()
        img_filter[img_blur == 0] = 0

        if self.keep_largest_component:
            _, labels, stats, _ = cv2.connectedComponentsWithStats(img_blur)
            areas = np.array(stats)[:, cv2.CC_STAT_AREA]
            areas[0] = 0  # label 0 is the background; never select it
            # With no foreground at all, argmax would pick the background
            # label; img_filter is already all zeros in that case.
            if areas.max() > 0:
                img_filter[labels != np.argmax(areas)] = 0

        return PIL.Image.fromarray(img_filter)

def seed_everything(seed):
    """Seed every RNG this notebook touches so reruns are reproducible.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), exports ``PYTHONHASHSEED``, and puts cuDNN in
    deterministic mode.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker
    # processes); the current process's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose kernels by timing them, which can vary
    # from run to run and defeats `deterministic`; keep it off.
    torch.backends.cudnn.benchmark = False

def RMBG(img, transform, count, angle, distance):
    """Composite one foreground image onto an indoor background and save it.

    Background pixels of ``img`` are blanked by ``transform``. Every pixel
    whose channels then sum to 0 is replaced by the matching pixel of a
    background photo, and the result is written to ``<save_dir>/<count>.jpg``.

    Args:
        img: source image (PIL image or (H, W, 3) array; assumed RGB).
        transform: callable returning the background-removed image. If None,
            ``Transform_rmbg()`` is used.
        count: output index; also selects the background
            ``img_bg_dirs[count % len(img_bg_dirs)]``.
        angle, distance: unused; kept so the call signature matches ``MixUp``.

    Reads the module-level globals ``img_bg_dirs`` and ``save_dir`` when it
    runs, so callers executing it in a thread must not reassign them until
    the thread has finished.
    """
    if transform is None:
        transform = Transform_rmbg()
    composite = np.array(transform(img), dtype=np.uint8)
    height, width = composite.shape[:2]

    with PIL.Image.open(img_bg_dirs[count % len(img_bg_dirs)]) as bg_file:
        # Match the background to the foreground size (PIL takes (w, h)) so the
        # pixel-wise replacement below cannot fail on a shape mismatch.
        bg_img = np.array(bg_file.convert('RGB').resize((width, height)))

    empty = composite.sum(axis=-1) == 0
    composite[empty] = bg_img[empty]
    PIL.Image.fromarray(composite).save(os.path.join(save_dir, f'{count}.jpg'))

def shift_image(X, dx, dy):
    """Translate an image by (dx, dy) pixels, filling uncovered pixels with 0.

    Content pushed past an edge is discarded rather than wrapped around to
    the opposite edge (unlike a plain ``np.roll``).

    Args:
        X: image-like array (anything ``np.asanyarray`` accepts) with rows on
            axis 0 and columns on axis 1.
        dx: horizontal shift in pixels; positive moves content right.
        dy: vertical shift in pixels; positive moves content down.

    Returns:
        A new array with the same shape and dtype; ``X`` is not modified.
    """
    def _span(shift, size):
        # (destination, source) slices along one axis for a non-wrapping shift.
        if shift >= 0:
            return slice(shift, size), slice(0, max(size - shift, 0))
        return slice(0, max(size + shift, 0)), slice(-shift, size)

    src = np.asanyarray(X)
    out = np.zeros_like(src)
    row_dst, row_src = _span(dy, src.shape[0])
    col_dst, col_src = _span(dx, src.shape[1])
    out[row_dst, col_dst] = src[row_src, col_src]
    return out

def MixUp(img_indxs, count, angle, distance):
    """Blend several foreground images around a circle onto one background.

    Each image ``img_path[i]`` for ``i`` in ``img_indxs`` goes through the
    global ``transforms``. It is then shifted by ``distance`` pixels in a
    direction that advances by ``2*pi/len(img_indxs)`` per image, starting one
    step after ``angle``. Non-zero pixels of later images overwrite earlier
    ones. The blend is cleaned again with ``Transform_rmbg``, every pixel whose
    channels sum to 0 is filled from a background photo, and the result is
    saved to ``<save_dir>/<count>.jpg``.

    Args:
        img_indxs: non-empty sequence of indices into the global ``img_path``.
            All referenced images are assumed to share one shape (TODO
            confirm for the dataset).
        count: output index; also selects ``img_bg_dirs[count % len(img_bg_dirs)]``.
        angle: starting direction in radians.
        distance: shift magnitude in pixels; each axis offset is truncated
            toward zero.

    Raises:
        ValueError: if ``img_indxs`` is empty.

    Reads the globals ``img_path``, ``transforms``, ``img_bg_dirs`` and
    ``save_dir`` when it runs.
    """
    if len(img_indxs) == 0:
        raise ValueError("MixUp needs at least one image index")

    layers = []
    for idx in img_indxs:
        with PIL.Image.open(img_path[idx]) as src:
            layers.append(np.array(transforms(src)))

    blend = np.zeros(layers[0].shape)
    step = 2 * np.pi / len(layers)
    for layer in layers:
        angle += step
        # int() truncates toward zero, giving whole-pixel offsets.
        dx, dy = int(distance * np.cos(angle)), int(distance * np.sin(angle))
        shifted = shift_image(layer, dx, dy)
        mask = shifted != 0
        blend[mask] = shifted[mask]

    blend = np.array(Transform_rmbg()(PIL.Image.fromarray(blend.astype(np.uint8))))
    height, width = blend.shape[:2]

    with PIL.Image.open(img_bg_dirs[count % len(img_bg_dirs)]) as bg_file:
        # Match the background to the blend size (PIL takes (w, h)) so the
        # pixel-wise replacement below cannot fail on a shape mismatch.
        bg_img = np.array(bg_file.convert('RGB').resize((width, height)))

    empty = blend.sum(axis=-1) == 0
    blend[empty] = bg_img[empty]
    PIL.Image.fromarray(blend).save(os.path.join(save_dir, f'{count}.jpg'))


In [None]:
seed_everything(107)

# Stage 1: background-replaced copy of every training image.
save_dir = "./train_rmbg/"
os.makedirs(save_dir, exist_ok=True)

train = pd.read_csv("train.csv", index_col=0)
img_path = list(train.loc[:, 'img_path'])
img_label = np.array(train.loc[:, 'A':])
# glob() order depends on the filesystem; sort first so the seeded shuffle
# gives the same background assignment on every machine.
img_bg_dirs = sorted(glob('./indoorCVPR_09/images/*/*.jpg'))
np.random.shuffle(img_bg_dirs)

transforms = T.Compose([
                Transform_rmbg()
                ])

# Empty frame with train.csv's columns; one row is added per generated image.
df = pd.read_csv("./train.csv", index_col=0)
df = df.iloc[:0, :]

count = 0
futures = []
# RMBG reads the globals `save_dir` and `img_bg_dirs` when it runs, and the
# next stage reassigns both, so every job must finish before this stage ends.
# A bounded pool also avoids spawning one OS thread per image.
with ThreadPoolExecutor() as pool:
    for i in tqdm(range(len(img_label))):
        with PIL.Image.open(img_path[i]) as img:
            img_arr = np.array(img)
        # angle/distance are unused by RMBG but still drawn, so the global
        # NumPy RNG stream consumed by later cells is unchanged.
        futures.append(pool.submit(RMBG, img_arr, transforms, count,
                                   np.random.rand() * 2*np.pi, 20 + 50*np.random.rand()))
        save_dir_ = os.path.join(save_dir, f'{count}.jpg')
        df.loc[f'{count}.jpg', 'img_path':] = [save_dir_] + list(img_label[i])
        count += 1
# The pool has drained; re-raise any worker error instead of listing an image
# in the CSV that was never written.
for fut in futures:
    fut.result()

df.to_csv("train_rmbg.csv")

# Stage 2: two-image mixups composited onto indoor backgrounds.
save_dir = "./train_mixup/"
os.makedirs(save_dir, exist_ok=True)

train = pd.read_csv("train.csv", index_col=0)
img_path = list(train.loc[:, 'img_path'])
img_label = np.array(train.loc[:, 'A':])

# Sort before the seeded shuffle so the result does not depend on glob() order.
img_bg_dirs = sorted(glob('./indoorCVPR_09/images/*/*.jpg'))
np.random.shuffle(img_bg_dirs)

# Permute paths and labels together so pair selection is not tied to CSV order.
shuffle_inx = list(range(len(img_path)))
np.random.shuffle(shuffle_inx)
img_path = np.array(img_path)[shuffle_inx]
img_label = img_label[shuffle_inx]

transforms = T.Compose([
                Transform_rmbg()
                ])
# Empty frame with train.csv's columns; one row is added per generated image.
df = pd.read_csv("./train.csv", index_col=0)
df = df.iloc[:0, :]

count = 0
futures = []
all_pairs = itertools.combinations(range(len(img_path)), r=2)
# A bounded pool instead of one OS thread per mixup; leaving the `with` block
# waits for every job to finish.
with ThreadPoolExecutor() as pool:
    # Only every 100th pair is considered (pair positions 0, 100, 200, ...).
    for indxs in itertools.islice(all_pairs, 0, None, 100):
        labels = img_label[indxs[0]] + img_label[indxs[1]]
        # Skip pairs sharing a label (assuming 0/1 label columns) or with more
        # than 9 labels in total.
        if np.max(labels) >= 2 or np.sum(labels) > 9:
            continue
        indxs = list(indxs)
        np.random.shuffle(indxs)  # MixUp draws later images on top
        futures.append(pool.submit(MixUp, indxs, count,
                                   np.random.rand() * 2*np.pi, 20+50*np.random.rand()))
        df.loc[f'{count}.jpg', 'img_path':] = [os.path.join(save_dir, f'{count}.jpg')] + list(labels)
        count += 1
        if count % 100 == 0:
            print(f"\r{count}", end="")
# Re-raise any worker error instead of listing an image that was never written.
for fut in futures:
    fut.result()

# "./train_mixup/" -> "train_mixup.csv"
df.to_csv(os.path.basename(os.path.normpath(save_dir)) + '.csv')