# **Predicting on test set**

The purpose of this notebook is to use `svm_train_2` to predict m6A RNA modifications in all direct RNA-Seq data sets from the SG-NEx data.

### **Load model and data**

In [1]:
```python
import pickle
import pandas as pd

# Load a pickled object (here, the trained SVM model).
# Only unpickle files from a trusted source: pickle can execute arbitrary code on load.
def load_pkl(file_path):
    try:
        with open(file_path, 'rb') as file:
            return pickle.load(file)
    except Exception as e:
        # Re-raise with context instead of calling sys.exit(), which inside a
        # notebook only surfaces a bare SystemExit and hides the traceback.
        raise RuntimeError(f"Error loading .pkl file {file_path!r}") from e

# Load the input data from a CSV file (swap the reader for JSON, text, etc.)
def load_input(input_file):
    try:
        return pd.read_csv(input_file)
    except Exception as e:
        raise RuntimeError(f"Error loading input file {input_file!r}") from e

pkl_file = 'svm_train_2.pkl'   # trained SVM model
input_file = 'test_set.csv'    # records to predict on

loaded_model = load_pkl(pkl_file)
input_data = load_input(input_file)
```
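Later cells assume that the unpickled object provides `predict_proba` and that the CSV contains the site identifiers plus the nine feature columns. A quick check right after loading can catch a wrong file early. The sketch below is a minimal, hypothetical helper (`check_inputs`, `FEATURES` and `REQUIRED_COLUMNS` are illustrative names, not part of the original notebook); the required columns are taken from the `test_set.csv` header shown in the next section.

```python
import pandas as pd

# Hypothetical column lists, copied from the header of test_set.csv
FEATURES = [f"feature_{i}" for i in range(1, 10)]
REQUIRED_COLUMNS = ["gene_id", "sequence", "transcript_id", "transcript_position", *FEATURES]

def check_inputs(model, data: pd.DataFrame) -> None:
    """Raise ValueError if the model or the input table cannot be used downstream."""
    # The prediction cell calls predict_proba, so the model must expose it.
    if not hasattr(model, "predict_proba"):
        raise ValueError(f"{type(model).__name__} has no predict_proba method")
    # Every column used by the aggregation and prediction steps must be present.
    missing = [col for col in REQUIRED_COLUMNS if col not in data.columns]
    if missing:
        raise ValueError(f"Input data is missing columns: {missing}")

# Usage in this notebook:
# check_inputs(loaded_model, input_data)
```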

### **Data manipulation**

This section reshapes the input records into the form the model expects: one row of nine averaged features for each (`transcript_id`, `transcript_position`) pair.

In [2]:
```python
# View the first five raw records
input_data.head()
```

|   | gene_id | sequence | transcript_id | transcript_position | feature_1 | feature_2 | feature_3 | feature_4 | feature_5 | feature_6 | feature_7 | feature_8 | feature_9 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | ENSG00000004059 | AAGACCA | ENST00000000233 | 244 | 0.00299 | 2.06 | 125.0 | 0.0177 | 10.4 | 122.0 | 0.0093 | 10.9 | 84.1 |
| 1 | ENSG00000004059 | AAGACCA | ENST00000000233 | 244 | 0.00631 | 2.53 | 125.0 | 0.00844 | 4.67 | 126.0 | 0.0103 | 6.3 | 80.9 |
| 2 | ENSG00000004059 | AAGACCA | ENST00000000233 | 244 | 0.00465 | 3.92 | 109.0 | 0.0136 | 12.0 | 124.0 | 0.00498 | 2.13 | 79.6 |
| 3 | ENSG00000004059 | AAGACCA | ENST00000000233 | 244 | 0.00398 | 2.06 | 125.0 | 0.0083 | 5.01 | 130.0 | 0.00498 | 3.78 | 80.4 |
| 4 | ENSG00000004059 | AAGACCA | ENST00000000233 | 244 | 0.00664 | 2.92 | 120.0 | 0.00266 | 3.94 | 129.0 | 0.013 | 7.15 | 82.2 |


**Treat each unique `transcript_id` and `transcript_position` pair as one site**

Here, all records that share the same `transcript_id` and `transcript_position` are aggregated into a single record, so each candidate site is represented exactly once.

As in the training and testing of the SVM model, each feature column is averaged within each (`transcript_id`, `transcript_position`) group.
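To make the aggregation concrete, here is a toy example with made-up values (not taken from `test_set.csv`): three records, two of which come from the same site. It uses pandas named aggregation, which for this purpose produces the same columns as the dictionary passed to `.agg` in the notebook's aggregation cell.

```python
import pandas as pd

# Toy records (made-up values): two at position 244 and one at 261 of the same transcript
reads = pd.DataFrame({
    "gene_id": ["G1", "G1", "G1"],
    "sequence": ["AAGACCA", "AAGACCA", "CAAACTG"],
    "transcript_id": ["T1", "T1", "T1"],
    "transcript_position": [244, 244, 261],
    "feature_1": [1.0, 3.0, 5.0],
})

# Collapse the records at each (transcript_id, transcript_position) site into one row:
# metadata keeps its first value, numeric features are averaged.
sites = reads.groupby(["transcript_id", "transcript_position"], as_index=False).agg(
    gene_id=("gene_id", "first"),
    sequence=("sequence", "first"),
    feature_1=("feature_1", "mean"),
)
print(sites)
```

Position 244 ends up with `feature_1` equal to the mean of 1.0 and 3.0, while position 261 keeps its single value.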

In [3]:
```python
# List of the feature columns to be aggregated
features = [f"feature_{i}" for i in range(1, 10)]
print(features)
```

['feature_1', 'feature_2', 'feature_3', 'feature_4', 'feature_5', 'feature_6', 'feature_7', 'feature_8', 'feature_9']


In [4]:
```python
# Group by transcript_id and transcript_position and average every feature column
agg_data = (
    input_data.groupby(["transcript_id", "transcript_position"], as_index=False)
    .agg({
        # gene_id and sequence are not averaged: keep the first value in each group
        "gene_id": "first",
        "sequence": "first",
        **{col: "mean" for col in features},  # mean of each feature
    })
)

# Sanity check
agg_data.head()
```

|   | transcript_id | transcript_position | gene_id | sequence | feature_1 | feature_2 | feature_3 | feature_4 | feature_5 | feature_6 | feature_7 | feature_8 | feature_9 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | ENST00000000233 | 244 | ENSG00000004059 | AAGACCA | 0.008264 | 4.223784 | 123.702703 | 0.009373 | 7.382162 | 125.913514 | 0.007345 | 4.386989 | 80.57027 |
| 1 | ENST00000000233 | 261 | ENSG00000004059 | CAAACTG | 0.006609 | 3.216424 | 109.681395 | 0.006813 | 3.226535 | 107.889535 | 0.00771 | 3.016599 | 94.290698 |
| 2 | ENST00000000233 | 316 | ENSG00000004059 | GAAACAG | 0.00757 | 2.940541 | 105.475676 | 0.007416 | 3.642703 | 98.947027 | 0.007555 | 2.087146 | 89.364324 |
| 3 | ENST00000000233 | 332 | ENSG00000004059 | AGAACAT | 0.01062 | 6.47635 | 129.355 | 0.008632 | 2.8992 | 97.8365 | 0.006101 | 2.23652 | 89.154 |
| 4 | ENST00000000233 | 368 | ENSG00000004059 | AGGACAA | 0.010701 | 6.415051 | 117.924242 | 0.011479 | 5.870303 | 121.954545 | 0.010019 | 4.260253 | 85.178788 |
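Beyond inspecting `head()`, the aggregation can be checked programmatically: after grouping, every (`transcript_id`, `transcript_position`) pair should appear exactly once, and the number of rows should equal the number of distinct sites in the raw records. A minimal sketch with a hypothetical helper `check_aggregation`:

```python
import pandas as pd

def check_aggregation(raw: pd.DataFrame, agg: pd.DataFrame,
                      keys=("transcript_id", "transcript_position")) -> None:
    """Raise ValueError if agg is not exactly one row per distinct site in raw."""
    keys = list(keys)
    # After groupby, every site should appear exactly once...
    if agg.duplicated(subset=keys).any():
        raise ValueError("aggregated table has duplicated sites")
    # ...and no site from the raw records should be lost. A mismatch can also
    # mean missing keys, which groupby drops by default.
    n_sites = len(raw[keys].drop_duplicates())
    if len(agg) != n_sites:
        raise ValueError(f"expected {n_sites} sites, got {len(agg)}")

# Usage in this notebook:
# check_aggregation(input_data, agg_data)
```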


### **Predictions of m6A labels**

The prediction step should return a table with the following format:

| transcript_id | transcript_position | score |
|---|---|---|

- `transcript_id`: the ID of the transcript
- `transcript_position`: the position of the site within the transcript
- `score`: the predicted probability that the site carries an m6A modification
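A small format check can confirm that a result matches this specification before it is saved. The sketch below is a hypothetical helper (`validate_result` is an illustrative name); it requires the exact column order listed above, so it should run before any extra columns such as `gene_id` or `sequence` are added.

```python
import pandas as pd

def validate_result(result: pd.DataFrame) -> None:
    """Check the transcript_id | transcript_position | score layout."""
    expected = ["transcript_id", "transcript_position", "score"]
    if list(result.columns) != expected:
        raise ValueError(f"columns {list(result.columns)} != {expected}")
    # score is a probability, so it must lie in [0, 1]; NaN also fails this check.
    if not result["score"].between(0, 1).all():
        raise ValueError("score values outside [0, 1] or missing")
```

Requiring an exact column list is a design choice; a looser version could check only that the three columns are present.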

In [5]:
```python
# Select the nine averaged feature columns by name (more robust than agg_data.iloc[:, 4:13])
y_prob = loaded_model.predict_proba(agg_data[features])

# Keep the site identifiers and the probability of the m6A class.
# Column 1 of y_prob is assumed to be label 1; see loaded_model.classes_.
result = agg_data[["transcript_id", "transcript_position"]].copy()
result["score"] = y_prob[:, 1]

# Save the output as a csv file
# result.to_csv("test_result.csv", index=False)
result.head()
```

|   | transcript_id | transcript_position | score |
|---|---|---|---|
| 0 | ENST00000000233 | 244 | 0.038028 |
| 1 | ENST00000000233 | 261 | 0.010155 |
| 2 | ENST00000000233 | 316 | 0.03962 |
| 3 | ENST00000000233 | 332 | 0.211524 |
| 4 | ENST00000000233 | 368 | 0.564735 |
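The prediction cell takes `y_prob[:, 1]` as the m6A score. In scikit-learn, the columns of `predict_proba` follow the order of the classifier's `classes_` attribute (labels in sorted order), so column 1 is the positive class only when the labels are, for example, 0 and 1. Assuming the model is a scikit-learn classifier and the m6A label is `1`, a hypothetical helper can look the column up instead of hard-coding it:

```python
import numpy as np

def positive_class_scores(model, X, positive_label=1):
    """Return P(positive_label) for each row of X.

    Assumes a scikit-learn-style classifier whose predict_proba columns
    follow model.classes_.
    """
    classes = list(model.classes_)
    if positive_label not in classes:
        raise ValueError(f"{positive_label!r} not in model.classes_ {classes}")
    # Pick the probability column that belongs to the positive label.
    return np.asarray(model.predict_proba(X))[:, classes.index(positive_label)]

# Usage in this notebook:
# result["score"] = positive_class_scores(loaded_model, agg_data[features])
```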


To include `gene_id` and `sequence` in the output, insert them as the first two columns of `result`:

In [6]:
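```python
# Insert sequence first, then gene_id, so gene_id ends up as the first column.
# insert() aligns on the index, which result shares with agg_data.
result.insert(0, 'sequence', agg_data["sequence"])
result.insert(0, 'gene_id', agg_data["gene_id"])
result.head()
```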

|   | gene_id | sequence | transcript_id | transcript_position | score |
|---|---|---|---|---|---|
| 0 | ENSG00000004059 | AAGACCA | ENST00000000233 | 244 | 0.038028 |
| 1 | ENSG00000004059 | CAAACTG | ENST00000000233 | 261 | 0.010155 |
| 2 | ENSG00000004059 | GAAACAG | ENST00000000233 | 316 | 0.03962 |
| 3 | ENSG00000004059 | AGAACAT | ENST00000000233 | 332 | 0.211524 |
| 4 | ENSG00000004059 | AGGACAA | ENST00000000233 | 368 | 0.564735 |
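`DataFrame.insert` aligns the new columns on the index. That works here because `result` and `agg_data` share the same row index, but if `result` is ever filtered or reordered, joining on the site keys is safer. A sketch using `DataFrame.merge` with `validate="one_to_one"` (`add_site_metadata` and `KEYS` are hypothetical names):

```python
import pandas as pd

KEYS = ["transcript_id", "transcript_position"]

def add_site_metadata(result: pd.DataFrame, agg: pd.DataFrame) -> pd.DataFrame:
    """Attach gene_id and sequence to result by joining on the site keys."""
    meta = agg[["gene_id", "sequence", *KEYS]]
    # how="right" keeps every scored site in result's row order;
    # validate="one_to_one" raises MergeError if either side repeats a site.
    merged = meta.merge(result, on=KEYS, how="right", validate="one_to_one")
    return merged[["gene_id", "sequence", *KEYS, "score"]]
```

In this notebook it would be applied to the three-column `result` in place of the two `insert` calls.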
