In [2]:
from sklearn.model_selection import train_test_split

In [3]:
import module.input_mutation_path as imp
from module.make_dataset import base

# Parameters that control filtering of the mutation paths and the train/valid/test split.
data_config = dict(
    train_end=37,             # largest path length that contributes to training (see dataset_by_ts)
    test_start=34,            # smallest path length that contributes to testing; overlap with train_end => mixed split
    ylen=1,                   # not read in this section — TODO confirm where it is consumed
    val_ratio=0.2,            # test_size of the train -> (train, valid) split
    mix_ratio=0.2,            # test_size for lengths inside the train/test overlap
    frag_len=10,              # window length passed to fragmentation
    max_co_occur=20,          # max comma-separated mutations allowed in one path step
    nmax=1000,                # passed to imp.input as nmax (loading stops at this many samples)
    nmax_per_strain=1000000,  # passed to imp.input as nmax_per_strain (presumably a per-strain cap)
)

# Input locations for the dataset.
dataset_config = dict(
    strains=['B.1.1.7','P.1','BA.2','BA.1.1','BA.1','B.1.617.2','B.1.351','B.1.1.529'],
    usher_dir='../usher_output/',                           # root directory read by imp.input
    bunpu_csv="table_heatmap/250621/table_set/table_set.csv",  # not read in this section
    codon_csv='meta_data/codon_mutation4.csv',              # not read in this section
    cache_dir='../cache',                                   # directory for caching feature data
)

# Load the raw mutation paths for the configured strains.
# imp.input yields three aligned sequences: sample names, lengths, and paths.
# Each path is a list of steps; a step is one or more mutations joined by commas
# (e.g. 'G28882A,G28883C').
names, lengths, base_HGVS_paths = imp.input(
    dataset_config['strains'],
    dataset_config['usher_dir'],
    nmax=data_config['nmax'],
    nmax_per_strain=data_config['nmax_per_strain'],
)

print("元データ", len(base_HGVS_paths))  # number of raw paths loaded

def filter_co_occur(data,sample_name,data_len,max_co_occur,out_num=None):
    """Keep only paths whose widest step has at most ``max_co_occur`` mutations.

    A path is a sequence of steps, and each step is a string of mutations joined
    by commas (e.g. ``'G28882A,G28883C'`` counts as 2). A path is kept when none
    of its steps has more than ``max_co_occur`` comma-separated entries.

    Args:
        data: list of paths.
        sample_name: names aligned index-by-index with ``data``.
        data_len: lengths aligned index-by-index with ``data``.
        max_co_occur: maximum number of mutations allowed in a single step.
        out_num: optional cap on the number of kept paths. ``None`` means no cap;
            ``0`` or a negative value yields empty results.

    Returns:
        Tuple ``(kept_paths, kept_names, kept_lengths)`` in input order.

    Raises:
        IndexError: if a kept path's index is past the end of ``sample_name``
            or ``data_len``.
    """
    kept_paths = []
    kept_names = []
    kept_lengths = []
    for i, path in enumerate(data):
        # Check the cap before examining the path so that out_num=0 returns nothing.
        if out_num is not None and len(kept_paths) >= out_num:
            break
        # ''.split(',') == [''], so an empty step still counts as one entry;
        # an empty path has no steps and counts as 0.
        widest_step = max((len(step.split(',')) for step in path), default=0)
        if widest_step <= max_co_occur:
            kept_paths.append(path)
            kept_names.append(sample_name[i])
            kept_lengths.append(data_len[i])
    return kept_paths, kept_names, kept_lengths

def unique_path(data):
    """Drop duplicate paths, keeping the first occurrence of each in input order.

    Paths are compared element-wise, so a list and a tuple with the same items
    count as duplicates. Every returned path is a new list, and path elements
    must be hashable.
    """
    seen = set()
    deduped = []
    for path in data:
        key = tuple(path)
        if key in seen:
            continue
        seen.add(key)
        deduped.append(list(key))
    return deduped

# Drop paths that contain a step with more than max_co_occur simultaneous mutations.
# The filtered names/lengths get their own names (stay aligned with filted_data)
# rather than both being unpacked into one throwaway variable.
filted_data, filted_names, filted_lengths = filter_co_occur(
    base_HGVS_paths, names, lengths, data_config['max_co_occur']
)
print("共起フィルタ", len(filted_data))

# Remove duplicate paths, keeping first occurrences in order.
data = unique_path(filted_data)
print("ユニークパス", len(data))

def data_by_ts(data):
    """Group paths by their length, which this notebook uses as the time-step key.

    Returns a plain dict mapping length -> list of paths of that length. Keys
    appear in order of first occurrence; the path objects are not copied.
    """
    groups = {}
    for path in data:
        groups.setdefault(len(path), []).append(path)
    return groups

def fragmentation(data,frag_len,end_opt=False):
    """Cut each path into contiguous windows of length ``frag_len``.

    Args:
        data: list of paths (any sliceable sequences).
        frag_len: window length.
        end_opt: if True, emit only the last window of each path (its final
            ``frag_len`` steps). Otherwise emit every sliding window with stride 1.

    Returns:
        Flat list of windows in path order. Paths shorter than ``frag_len``
        produce no windows in either mode.
    """
    frag_data = []
    for path in data:
        last_start = len(path) - frag_len
        if last_start < 0:
            # Too short for a full window. A negative start would index from the
            # end of the path and yield a truncated, mispositioned fragment.
            continue
        first_start = last_start if end_opt else 0
        for start in range(first_start, last_start + 1):
            frag_data.append(path[start:start + frag_len])
    return frag_data

def dataset_by_ts(data, train_end, test_start, mix_ratio=0.2, val_ratio=0.2, frag_len=None, unique=False,
                  random_state=None):
    """Split paths into train / validation / per-length test sets, using path length as time.

    Paths are grouped by length with ``data_by_ts``. Each length ``k`` is assigned as follows:

    * Disjoint ranges (``train_end < test_start``): ``k <= train_end`` goes to
      train and ``k >= test_start`` goes to ``test[k]``. Lengths in between are
      dropped.
    * Overlapping ranges (``train_end >= test_start``):
      - ``k < test_start`` goes to train.
      - ``k > train_end`` goes to ``test[k]``.
      - Each ``k`` in ``[test_start, train_end]`` is split with
        ``train_test_split(test_size=mix_ratio)``. A length with only one path
        cannot be split and goes entirely to ``test[k]``.

    The training paths are then split into train / valid with ``val_ratio``.
    This happens before fragmentation, so windows cut from one path never land
    in both sets.

    Args:
        data: list of paths.
        train_end: largest length used for training.
        test_start: smallest length used for testing.
        mix_ratio: test fraction for lengths inside the overlap range.
        val_ratio: validation fraction of the training paths.
        frag_len: if given, cut train paths into every window of this length and
            valid/test paths into their final window (see ``fragmentation``).
        unique: if True, drop duplicate paths/windows from every output.
        random_state: forwarded to every ``train_test_split`` call. Pass an int
            for reproducible splits; ``None`` leaves them unseeded.

    Returns:
        ``(train, valid, test)``: ``train`` and ``valid`` are lists, and ``test``
        maps length -> list with keys in ascending order.
    """
    train = []
    test = {}
    data_ts = data_by_ts(data)
    keys = sorted(data_ts)
    print("keys:", keys)

    overlapping = train_end >= test_start

    for k in keys:
        items = data_ts[k]
        if not overlapping:
            if k <= train_end:
                train.extend(items)
            if test_start <= k:
                test[k] = items.copy()
        elif k < test_start:
            train.extend(items)
        elif k > train_end:
            test[k] = items.copy()
        elif len(items) < 2:
            # train_test_split raises when its train side would be empty, which is
            # always the case for a single sample; keep the lone path for testing.
            test[k] = items.copy()
        else:
            train_part, test[k] = train_test_split(items, test_size=mix_ratio, random_state=random_state)
            train.extend(train_part)

    # Always hold out validation paths so `valid` is defined even when frag_len is None.
    train, valid = train_test_split(train, test_size=val_ratio, random_state=random_state)

    if frag_len is not None:
        train = fragmentation(train, frag_len=frag_len)
        valid = fragmentation(valid, frag_len=frag_len, end_opt=True)
        for k in test:
            test[k] = fragmentation(test[k], frag_len=frag_len, end_opt=True)

    if unique:
        train = unique_path(train)
        valid = unique_path(valid)
        for k in test:
            test[k] = unique_path(test[k])

    return train, valid, test

# Build the train / valid / test splits; `test` is keyed by path length.
train, valid, test = dataset_by_ts(
    data,
    train_end=data_config['train_end'],
    test_start=data_config['test_start'],
    mix_ratio=data_config['mix_ratio'],
    val_ratio=data_config['val_ratio'],
    frag_len=data_config['frag_len'],
    unique=True,
)

print("train", len(train))
print("valid", len(valid))
for k in sorted(test):
    print(f"test:{k} {len(test[k])}")

[INFO] import: ../usher_output/B.1.1.7/0/mutation_paths_B.1.1.7.tsv
[INFO] 指定されたnmax=1000に達しました。
[INFO] B.1.1.7のデータを読み込みました: 1000 サンプル
[INFO] 読み込み完了: 1000 サンプル
元データ 1000
共起フィルタ 1000
ユニークパス 979
keys: [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 38]
train 2683
valid 194
test:34 4
test:35 2
test:36 1
test:38 2


In [5]:
# Show one raw path: a list of steps, each one or more mutations joined by commas.
print(base_HGVS_paths[0])

['C14408T', 'A23403G', 'C241T', 'C3037T', 'G28881A', 'G28882A,G28883C', 'C884T', 'T884C,G28280C,A28281T,T28282A', 'A23063T', 'C15279T', 'A28111G,C28977T', 'T23063A', 'C913T', 'C5986T', 'T6954C,A23063T,C23271A,C23604A,C23709T', 'T24506G,G24914C', 'T16176C', 'C14676T,C27972T', 'C3267T', 'C5388A', 'G28048T', 'A12097G', 'G3259T', 'C7029T', 'T1054C', 'C7749T', 'C203A', 'C21588T', 'C28312T']
