In [1]:
import os
import shutil

import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
!ls tmp/document_classification

[34mdata[m[m      train.csv train.pq


In [3]:
# Load the label table: one row per image, with its path (class_folder/file.png)
# in `image` and its label in `class`.
df = pd.read_csv('tmp/document_classification/train.csv')
df

Unnamed: 0,image,class
0,email/doc_000042.png,email
1,email/doc_000046.png,email
2,email/doc_000076.png,email
3,email/doc_000079.png,email
4,email/doc_000111.png,email
...,...,...
160,scientific_publication/doc_000845.png,scientific_publication
161,scientific_publication/doc_000864.png,scientific_publication
162,scientific_publication/doc_000891.png,scientific_publication
163,scientific_publication/doc_000942.png,scientific_publication


In [4]:
df['class'].value_counts()

email                     55
resume                    55
scientific_publication    55
Name: class, dtype: int64

In [5]:
# Hold out exactly 30 rows as the test set (an integer test_size is a row count,
# not a fraction). Stratifying on `class` keeps the three classes balanced in
# both splits, and the fixed random_state makes the split reproducible.
df_train, df_test = train_test_split(df, test_size=30, stratify=df['class'], random_state=123)
df_train.shape, df_test.shape

((135, 2), (30, 2))

In [6]:
df_train['class'].value_counts()

resume                    45
scientific_publication    45
email                     45
Name: class, dtype: int64

In [7]:
df_test['class'].value_counts()

scientific_publication    10
email                     10
resume                    10
Name: class, dtype: int64

In [8]:
!ls tmp/document_classification/data

[34memail[m[m                  [34mresume[m[m                 [34mscientific_publication[m[m


In [9]:
!ls tmp/document_classification/data/email

doc_000042.png doc_000196.png doc_000363.png doc_000550.png doc_000694.png
doc_000046.png doc_000238.png doc_000448.png doc_000558.png doc_000745.png
doc_000076.png doc_000255.png doc_000464.png doc_000577.png doc_000750.png
doc_000079.png doc_000260.png doc_000465.png doc_000586.png doc_000784.png
doc_000111.png doc_000275.png doc_000471.png doc_000591.png doc_000787.png
doc_000115.png doc_000278.png doc_000483.png doc_000596.png doc_000796.png
doc_000133.png doc_000279.png doc_000485.png doc_000612.png doc_000825.png
doc_000142.png doc_000282.png doc_000507.png doc_000637.png doc_000840.png
doc_000148.png doc_000297.png doc_000511.png doc_000650.png doc_000862.png
doc_000165.png doc_000333.png doc_000528.png doc_000655.png doc_000872.png
doc_000195.png doc_000347.png doc_000532.png doc_000675.png doc_000873.png


In [10]:
!ls tmp/document_classification/data/resume

doc_000051.png doc_000175.png doc_000361.png doc_000468.png doc_000639.png
doc_000070.png doc_000191.png doc_000369.png doc_000473.png doc_000674.png
doc_000072.png doc_000223.png doc_000375.png doc_000476.png doc_000727.png
doc_000080.png doc_000248.png doc_000377.png doc_000499.png doc_000734.png
doc_000088.png doc_000264.png doc_000402.png doc_000501.png doc_000752.png
doc_000091.png doc_000281.png doc_000411.png doc_000543.png doc_000760.png
doc_000097.png doc_000286.png doc_000441.png doc_000551.png doc_000763.png
doc_000101.png doc_000294.png doc_000443.png doc_000575.png doc_000802.png
doc_000109.png doc_000301.png doc_000447.png doc_000609.png doc_000809.png
doc_000169.png doc_000344.png doc_000450.png doc_000629.png doc_000824.png
doc_000173.png doc_000353.png doc_000460.png doc_000636.png doc_000847.png


In [11]:
def find_duplicate_files(dir1, dir2):
    """Return the set of entry names that appear in both ``dir1`` and ``dir2``.

    Only the names reported by ``os.listdir`` are compared (top level of each
    directory; file contents are not inspected). An empty set means the two
    directories share no names.
    """
    names_in_first = set(os.listdir(dir1))
    # Keep the names from the second directory that were also seen in the first.
    return {name for name in os.listdir(dir2) if name in names_in_first}

In [12]:
# The images are later copied into flat images/ folders, where a file name that
# exists in two class folders would overwrite the other copy. Check every pair
# of class folders for shared file names; each result should be an empty set.
data_root = 'tmp/document_classification/data/'
class_pairs = [
    ('email', 'resume'),
    ('email', 'scientific_publication'),
    ('scientific_publication', 'resume'),
]
duplicates1, duplicates2, duplicates3 = [
    find_duplicate_files(data_root + first, data_root + second)
    for first, second in class_pairs
]

duplicates1, duplicates2, duplicates3

(set(), set(), set())

In [13]:
def _flatten_image_path(path):
    """Swap the leading class folder for 'images/': 'email/a.png' -> 'images/a.png'."""
    return 'images/' + path.split('/', 1)[-1]

# Build copies of both splits whose `image` column points into a flat images/ folder.
df_train2 = df_train.assign(image=df_train['image'].map(_flatten_image_path))
df_test2 = df_test.assign(image=df_test['image'].map(_flatten_image_path))
df_train2.shape, df_test2.shape

((135, 2), (30, 2))

In [14]:
df_train2.head()

Unnamed: 0,image,class
60,images/doc_000091.png,resume
57,images/doc_000072.png,resume
116,images/doc_000128.png,scientific_publication
92,images/doc_000501.png,resume
144,images/doc_000534.png,scientific_publication


In [15]:
df_test2.head()

Unnamed: 0,image,class
159,images/doc_000832.png,scientific_publication
2,images/doc_000076.png,email
145,images/doc_000584.png,scientific_publication
62,images/doc_000101.png,resume
17,images/doc_000279.png,email


## Hydrogen Torch dataset

In [16]:
# Source: the per-class image folders; the paths in df['image'] are relative to it.
src_dir = 'tmp/document_classification/data/'
# Destination: one flat images/ folder for the Hydrogen Torch dataset.
dst_dir = 'tmp/HT_document_classification/images/'

In [17]:
# shutil.copy does not create missing directories, so make sure the destination
# folder exists; this keeps the cell working on a fresh Restart & Run All.
os.makedirs(dst_dir, exist_ok=True)

# Copy every image (train and test rows alike) into the flat images/ folder.
for rel_path in df['image']:
    shutil.copy(os.path.join(src_dir, rel_path), dst_dir)

In [18]:
!ls -l tmp/HT_document_classification/images/ | wc

     166    1487   11892


In [19]:
# Write the train/test label tables (image paths already rewritten to
# images/<file>) into the Hydrogen Torch dataset folder; index=False leaves the
# row index out of the CSV.
df_train2.to_csv('tmp/HT_document_classification/train.csv', index=False)
df_test2.to_csv('tmp/HT_document_classification/test.csv', index=False)

## Driverless AI

In [20]:
# Source: the per-class image folders; the paths in the `image` column are relative to it.
src_dir = 'tmp/document_classification/data/'
# Destinations: separate TRAIN and TEST dataset folders, each with its own flat images/ folder.
dst_dir_train = 'tmp/DAI_document_classification_TRAIN/images/'
dst_dir_test = 'tmp/DAI_document_classification_TEST/images/'

In [21]:
# Copy each split's images into its own flat images/ folder. The destination is
# created first because shutil.copy does not create missing directories.
for split_df, split_dst in ((df_train, dst_dir_train), (df_test, dst_dir_test)):
    os.makedirs(split_dst, exist_ok=True)
    for rel_path in split_df['image']:
        shutil.copy(os.path.join(src_dir, rel_path), split_dst)

In [22]:
!ls -l tmp/DAI_document_classification_TRAIN/images/ | wc

     136    1217    9732


In [23]:
!ls -l tmp/DAI_document_classification_TEST/images/ | wc

      31     272    2171


In [24]:
# Write each split's label table (image paths already rewritten to
# images/<file>) as data.csv inside its own dataset folder; index=False leaves
# the row index out of the CSV.
df_train2.to_csv('tmp/DAI_document_classification_TRAIN/data.csv', index=False)
df_test2.to_csv('tmp/DAI_document_classification_TEST/data.csv', index=False)