# Set up the training data for the model

In [1]:
import sys

# Add the parent directory to the module search path. Presumably this is what
# makes the project-local `adna` package importable from this notebook's
# location -- TODO confirm against the repository layout.
sys.path.append('..')

In [2]:
import sqlite3
from collections import defaultdict
from dataclasses import dataclass, field

import pandas as pd
from Bio.SeqIO.FastaIO import SimpleFastaParser
from sklearn.model_selection import train_test_split

from adna.pylib import consts, utils

In [3]:
# Directory whose `*.gz` FASTA files supply the record titles treated as positives.
GOOD_DIR = consts.DATA_DIR / 'raw' / 'mostly_on_target'
# Directory whose `*.gz` FASTA files supply every raw read to be labelled.
RAW_DIR = consts.DATA_DIR / 'raw' / 'raw_data'
# Not referenced by any cell visible in this notebook; presumably the location
# of the reference mitogenome -- TODO confirm whether it is still needed.
MT_DIR = consts.DATA_DIR / 'raw' / 'reference_mitogenome'

## Get the positives

In [4]:
# Collect the title of every FASTA record found under GOOD_DIR; membership in
# this set is what later marks a raw read as a positive.
GOOD = set()
for good_path in GOOD_DIR.glob('*.gz'):
    with utils.open_file(good_path) as handle:
        # SimpleFastaParser yields (title, sequence) pairs; only the title is kept.
        GOOD.update(title for title, _sequence in SimpleFastaParser(handle))

len(GOOD)

2553721

## Read raw data

In [5]:
@dataclass
class SeqCounts:
    """Tallies kept for one distinct sequence while the raw FASTA files are read.

    Attributes:
        ids: Title (spaces replaced by underscores) of every raw record that
            carried this sequence, in the order the records were read.
        label_0: Number of those records whose title is not in GOOD.
        label_1: Number of those records whose title is in GOOD.
        rev_0: Number of those records whose title with '_(reversed)' appended
            is not in GOOD.
        rev_1: Number of those records whose title with '_(reversed)' appended
            is in GOOD.
    """
    ids: list[str] = field(default_factory=list)
    label_0: int = 0
    label_1: int = 0
    rev_0: int = 0
    rev_1: int = 0

In [6]:
# Keyed by sequence string; looking up an unseen sequence creates a fresh,
# zeroed SeqCounts for it.
SEQS = defaultdict(SeqCounts)

In [7]:
# Walk every raw FASTA record and, per distinct sequence, tally how many of its
# copies are (or are not) listed in GOOD -- both under the record's own title
# and under the title with the '_(reversed)' suffix.
for raw_path in RAW_DIR.glob('*.gz'):
    with utils.open_file(raw_path) as handle:
        for title, sequence in SimpleFastaParser(handle):
            id_ = title.replace(' ', '_')
            rev = f'{id_}_(reversed)'

            tally = SEQS[sequence]
            tally.ids.append(id_)

            if id_ not in GOOD:
                tally.label_0 += 1
            else:
                tally.label_1 += 1

            if rev not in GOOD:
                tally.rev_0 += 1
            else:
                tally.rev_1 += 1

len(SEQS)

478722

## Remove duplicate sequences with conflicting labels

In [8]:
# Drop every sequence whose duplicate reads disagree: some copies are listed
# in GOOD and others are not, either under their own title or under the
# '_(reversed)' title.
#
# Iterate over a snapshot of the items: deleting from a dict while iterating
# over its live view raises "RuntimeError: dictionary changed size during
# iteration" as soon as the first conflicting sequence is removed.
for seq, count in list(SEQS.items()):
    label_conflict = count.label_1 > 0 and count.label_0 > 0
    rev_conflict = count.rev_1 > 0 and count.rev_0 > 0
    if label_conflict or rev_conflict:
        del SEQS[seq]

len(SEQS)

478722

## Create data frame

In [9]:
# One output row per distinct sequence. The first title seen stands in for the
# whole group of duplicates, and 'dups' records how many reads shared the
# sequence. 'split' is left blank here and filled in after partitioning.
RECS = [
    {
        'id': tally.ids[0],
        'seq': sequence,
        'label': int(tally.label_1 > 0),
        'rev': int(tally.rev_1 > 0),
        'dups': len(tally.ids),
        'split': '',
    }
    for sequence, tally in SEQS.items()
]

In [10]:
# Materialise the per-sequence records as a table and preview the first rows.
df = pd.DataFrame(data=RECS)
df.head(n=5)

Unnamed: 0,id,seq,label,rev,dups,split
0,A00916:157:HLNFGDSX2:2:1101:8377:1000_1:N:0:CG...,GGGTGCACTAATAACTAGCTCAGTGTGTCTACGCCAAATTGACCTA...,1,0,1,
1,A00916:157:HLNFGDSX2:2:1101:12825:1000_1:N:0:C...,GCATTTCATCAAACTGCGACAAAATCCCATTCCACCCCTACTTCTC...,1,0,1,
2,A00916:157:HLNFGDSX2:2:1101:13675:1000_1:N:0:C...,TTTTTTGGCCTTCAAGGATGAATTAATGATACGGTTTCGGGTGTAA...,0,0,1,
3,A00916:157:HLNFGDSX2:2:1101:18539:1000_1:N:0:C...,CTATTCTTCTACCTACGCCTGGCGTACTGCTCCACTATCACACTTT...,0,0,7,
4,A00916:157:HLNFGDSX2:2:1101:20943:1000_1:N:0:C...,TTTACTGCCTATTTTATCAATTGTCACGAAACAACGTTCCACTTAA...,0,0,3,


## Split the data

In [11]:
# Two-stage split: about 60% of the rows go to training, and the remaining
# rows are halved into validation and test (about 20% of the total each).
# The fixed random_state values keep the partition reproducible across runs.
train, other = train_test_split(df, train_size=0.6, random_state=23)
val, test = train_test_split(other, train_size=0.5, random_state=45)

In [12]:
# Record in the 'split' column which partition each row belongs to.
#
# `.assign` is used instead of attribute assignment (`frame.split = ...`)
# for two reasons: attribute assignment only sets a column when one of that
# name already exists (otherwise it silently creates a plain Python attribute),
# and it writes into frames derived from `df`, which can trigger pandas'
# SettingWithCopyWarning. `.assign` returns a new frame with the column set,
# so this cell is also safe to re-run.
train = train.assign(split='train')
val = val.assign(split='val')
test = test.assign(split='test')

## Write data to database

In [13]:
# Write the three partitions into a single 'seqs' table: the first write
# replaces any existing table, the following two append to it.
#
# Using a sqlite3 connection as a context manager only commits (or rolls back
# on error) the transaction -- it does not close the connection. Close it
# explicitly so the database file handle is released when the cell finishes.
cxn = sqlite3.connect(consts.SQL)
try:
    with cxn:
        train.to_sql('seqs', cxn, if_exists='replace', index=False)
        val.to_sql('seqs', cxn, if_exists='append', index=False)
        test.to_sql('seqs', cxn, if_exists='append', index=False)
finally:
    cxn.close()