### Load data

In [3]:
import json
# TODO: absolute, machine-specific path; consider making this configurable.
processed_file = 'C://Users//Vinny//work//amr_files//dsk_processed_filtered//dsk_processed_filtered.json'
# JSON text is UTF-8 by spec (RFC 8259); state it explicitly instead of relying on
# the platform's locale encoding (e.g. cp1252 on Windows).
# `data` supplies the 'data_list', 'kmer_idx' and 'strain_idx' entries used below.
with open(processed_file, encoding='utf-8') as json_file:
    data = json.load(json_file)

In [4]:
from scipy.sparse import coo_matrix
# data_list holds (row, col, value) triples: the row indexes a k-mer, the col a strain.
i, j, k = zip(*data['data_list'])
n_kmers = len(data['kmer_idx'])
n_strains = len(data['strain_idx'])
coo = coo_matrix((k, (i, j)), shape=(n_kmers, n_strains))

In [5]:
# Display the sparse k-mer x strain matrix; its repr reports shape, dtype and stored-element count.
coo

<263453x144 sparse matrix of type '<class 'numpy.int32'>'
	with 21496883 stored elements in COOrdinate format>

In [6]:
# (number of k-mers, number of strains)
coo.shape

(263453, 144)

In [7]:
# Strain file paths ordered by their column index in `coo`, so list position n matches column n.
sorted_strains = sorted(data['strain_idx'].items(), key=lambda path_and_col: path_and_col[1])
sorted_strains[:3]

[('/srv/data/amr/sample_data/strain_kmers/GCF_000660705.1_Myco_tube_TKK-01-0030_V1_genomic.txt',
  0),
 ('/srv/data/amr/sample_data/strain_kmers/GCF_000659825.1_Myco_tube_TKK_02_0059_V1_genomic.txt',
  1),
 ('/srv/data/amr/sample_data/strain_kmers/GCF_000659225.1_Myco_tube_TKK_02_0005_V1_genomic.txt',
  2)]

### Load prediction labels

In [8]:
import pandas as pd

# Per-strain pyrazinamide labels. The file has no header row, so pandas numbers
# the columns 0, 1, 2, ... and later cells address them by position.
outcome_file = 'C://Users//Vinny//work//amr//TBmetadata//pyrazinamide.csv'
df = pd.read_csv(filepath_or_buffer=outcome_file, header=None)

In [9]:
import os

# Attach a resistance label to every strain column of `coo`, in column order.
# The label CSV is headerless: column 0 is the strain identifier, column 2 the label.
# Build the (identifier, label) pairs once rather than re-scanning the frame per strain;
# str()/int() also work whether pandas yields numpy or plain Python scalars.
label_rows = [(str(strain), int(label)) for strain, label in zip(df[0], df[2])]

pred_labels = []
unmatched = []
for strain_name, _ in sorted_strains:
    basename = os.path.basename(strain_name)
    # First identifier contained in the file name wins (substring match).
    # NOTE(review): substring matching can hit the wrong strain if one identifier is a
    # substring of another strain's file name — confirm identifiers are unambiguous.
    match = next(((s, lab) for s, lab in label_rows if s in basename), None)
    if match is None:
        unmatched.append(basename)
    else:
        pred_labels.append(match)

# Every label must line up with a column of `coo`; silently skipping a strain would
# shift all later labels onto the wrong rows of the feature matrix.
if unmatched:
    raise ValueError(f"No label found for {len(unmatched)} strain(s): {unmatched}")

In [10]:
# One row per matched strain: its identifier and resistance label.
pred_df = pd.DataFrame.from_records(pred_labels, columns=['strain', 'resistance'])
pred_df

Unnamed: 0,strain,resistance
0,660705.1,0
1,659825.1,1
2,659225.1,1
3,657135.1,1
4,680395.1,1
...,...,...
139,660785.1,1
140,659905.1,1
141,680635.1,1
142,659485.1,1


### Train model and score k-mers

In [23]:
# can change model here depending on what we want to use
from sklearn.linear_model import LogisticRegressionCV
# Samples are strains: `coo` is (n_kmers, n_strains), so transpose to (n_strains, n_kmers).
strain_features = coo.transpose()
# Fail loudly if labels and strain rows are out of step (e.g. a strain lacked a label).
assert strain_features.shape[0] == len(pred_df), (
    f"{strain_features.shape[0]} strains in the matrix but {len(pred_df)} labels"
)
# liblinear shuffles the data internally; a fixed random_state makes coefficients reproducible.
clf = LogisticRegressionCV(cv=5, penalty='l1', solver='liblinear', random_state=0)
clf.fit(strain_features, pred_df['resistance'])
# Absolute coefficient per k-mer feature (the L1 penalty tends to zero out most of them).
scores = abs(clf.coef_[0])

In [32]:
# Mapping k-mer sequence -> row index in `coo`, plus the k-mers in the dict's iteration order.
kmer_idx = data['kmer_idx']
kmers = [kmer for kmer in kmer_idx]

In [53]:
# One row per k-mer: its sequence, its row index in `coo`, and |coefficient|.
# kmer_idx values index the rows of `coo` (its row count is len(kmer_idx)), and those rows
# become the feature columns after the transpose in the fit, so look each score up by
# that index rather than by the k-mer's position in the dict's iteration order.
results = [(kmer, idx, scores[idx]) for kmer, idx in kmer_idx.items()]

results_df = pd.DataFrame(results, columns=['kmer', 'idx', 'score'])
results_df

Unnamed: 0,kmer,idx,score
0,AAAATACGAGCTCGCTCTTTACGCTGAGCTT,0,0.0
1,AAAATCATCGCCCACACAGCTCTCGCGGAGG,1,0.0
2,AAAAGGCGGCACCAACGGCAACGGCGGCAGC,2,0.0
3,AAACCGGCGAAGCCGCCGGTGCCGCCGTTGC,3,0.0
4,AAACGCCGTTCCTGGACCTCACCCTCACCGG,4,0.0
...,...,...,...
263448,CCGCTAGCCCCGCAGTTGACCGCCACCGCCA,263448,0.0
263449,CTAGCCCCGCAGTTGACCGCCACCGCCACCG,263449,0.0
263450,CGCTAGCCCCGCAGTTGACCGCCACCGCCAC,263450,0.0
263451,TAGCCCCGCAGTTGACCGCCACCGCCACCGC,263451,0.0


In [54]:
# Write the k-mer scores to CSV (pandas writes the row index as the first column).
prediction_path = 'prediction.csv'
results_df.to_csv(prediction_path)

### The same pipeline as a reusable function

In [15]:
import random
from sklearn.linear_model import LogisticRegressionCV
def classifier(X, y, randomise=False, save_filename=None, cv=5, penalty='l1',
               kmer_idx=None, random_state=None):
    """Fit a cross-validated penalised logistic regression and score every k-mer.

    Parameters
    ----------
    X : array-like or sparse matrix, shape (n_strains, n_kmers)
        Feature matrix. Column ``j`` must correspond to the k-mer whose index is ``j``.
    y : array-like of length n_strains
        Class labels for each strain (e.g. resistance 0/1).
    randomise : bool
        If True, fit on a shuffled copy of ``y`` (a permutation baseline).
        The caller's ``y`` is never modified.
    save_filename : str or None
        If given, the results table is also written there with ``DataFrame.to_csv``.
    cv, penalty :
        Passed to ``LogisticRegressionCV``; the solver is fixed to 'liblinear'.
    kmer_idx : dict or None
        Mapping k-mer -> feature index. Defaults to the notebook-global ``data['kmer_idx']``.
    random_state : int or None
        Seeds both the label shuffle and liblinear's internal shuffling, for reproducible runs.

    Returns
    -------
    pandas.DataFrame
        Columns 'kmer', 'idx', 'score' (absolute coefficient), one row per k-mer.
    """
    if kmer_idx is None:
        # Fall back to the notebook global; pass kmer_idx explicitly to avoid hidden state.
        kmer_idx = data['kmer_idx']
    if randomise:
        # random.shuffle works in place and returns None (and would mutate the caller's
        # labels), so draw a shuffled copy instead.
        rng = random.Random(random_state)
        y = rng.sample(list(y), k=len(y))
    # fit model
    clf = LogisticRegressionCV(cv=cv, penalty=penalty, solver='liblinear',
                               random_state=random_state)
    clf.fit(X, y)
    scores = abs(clf.coef_[0])
    # Index scores by each k-mer's feature index, not by dict iteration position.
    results = [(kmer, idx, scores[idx]) for kmer, idx in kmer_idx.items()]
    results_df = pd.DataFrame(results, columns=['kmer', 'idx', 'score'])
    if save_filename is not None:
        results_df.to_csv(save_filename)
    return results_df

In [16]:
# Strains become samples (rows) and k-mers become features (columns).
X = coo.T
y = pred_df['resistance']
classifier(X, y)

Unnamed: 0,kmer,idx,score
0,AAAATACGAGCTCGCTCTTTACGCTGAGCTT,0,0.0
1,AAAATCATCGCCCACACAGCTCTCGCGGAGG,1,0.0
2,AAAAGGCGGCACCAACGGCAACGGCGGCAGC,2,0.0
3,AAACCGGCGAAGCCGCCGGTGCCGCCGTTGC,3,0.0
4,AAACGCCGTTCCTGGACCTCACCCTCACCGG,4,0.0
...,...,...,...
263448,CCGCTAGCCCCGCAGTTGACCGCCACCGCCA,263448,0.0
263449,CTAGCCCCGCAGTTGACCGCCACCGCCACCG,263449,0.0
263450,CGCTAGCCCCGCAGTTGACCGCCACCGCCAC,263450,0.0
263451,TAGCCCCGCAGTTGACCGCCACCGCCACCGC,263451,0.0
