In [1]:
import torch 
from torch import nn
from torch.nn import functional as F
import numpy as np
import pandas as pd
from vocab import Vocab
from dataprocessor import DataProcessor
from conv_model import ConvProtein
from utils import *
from tqdm import trange
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

In [2]:
def get_data(filename):
    """Load sequences and their targets from a CSV file.

    Parameters
    ----------
    filename : str or path-like
        CSV file that contains a ``protein_sequence`` column and a ``tm``
        column.

    Returns
    -------
    tuple
        ``(sequences, targets)``: the ``.values`` arrays of the
        ``protein_sequence`` and ``tm`` columns, in file row order.

    Raises
    ------
    KeyError
        If either required column is missing from the file.
    """
    frame = pd.read_csv(filename)
    sequences = frame['protein_sequence'].values
    targets = frame['tm'].values
    return sequences, targets

# Fixed seed for the train/test partition. Without random_state,
# train_test_split reshuffles on every run, so the held-out set (and every
# metric computed on it) changes after each kernel restart.
SPLIT_SEED = 42

# Full dataset: sequences and their 'tm' targets.
x, y = get_data('./data/train_fixed.csv')

# Target bounds are taken over the whole dataset, before the split, so both
# data processors below receive the same [min_v, max_v] pair.
# np.min/np.max reduce the array in C instead of iterating it in Python as
# the builtin min()/max() do.
min_v, max_v = np.min(y), np.max(y)

# Hold out 20% of the rows for evaluation; seeded so the split is repeatable.
trainx, testx, trainy, testy = train_test_split(
    x, y, test_size=0.2, random_state=SPLIT_SEED
)

vocab = Vocab()
# NOTE(review): presumably populates the vocabulary from its default file;
# confirm in vocab.py.
vocab.from_file()

# Vectors loaded from disk; passed to ConvProtein when the model is built.
vectors = load_ptvectors('./vectors.npy')

# One processor per partition, sharing the vocabulary and the target bounds.
train_datafeeder = DataProcessor(vocab, trainx, trainy, [min_v, max_v])
test_datafeeder = DataProcessor(vocab, testx, testy, [min_v, max_v])

In [3]:
# Training configuration, named so the loop below carries no magic numbers.
TRAIN_SEED = 42    # seed for the numpy / torch global RNGs
N_STEPS = 1000     # number of update steps
BATCH_SIZE = 100   # samples drawn per update step
EVAL_EVERY = 100   # evaluate on the held-out set every this many steps

# Seed the RNGs this notebook imports so a fresh run starts from the same
# state. NOTE(review): whether ConvProtein / DataProcessor draw from these
# generators is not visible here; confirm in conv_model.py / dataprocessor.py.
np.random.seed(TRAIN_SEED)
torch.manual_seed(TRAIN_SEED)

model = ConvProtein(vectors, lr=1e-4)

for step in trange(N_STEPS):
    # Use batch_* names: rebinding x / y here would overwrite the full
    # dataset loaded earlier in the notebook, so re-running the split cell
    # afterwards would silently split a single mini-batch.
    batch_x, batch_y = train_datafeeder.sample(BATCH_SIZE)
    model.updates(batch_x, batch_y)

    if step % EVAL_EVERY == 0:
        predictions, trues = [], []
        # Use eval_* names for the same reason: testx / testy hold the
        # held-out partition and must not be clobbered by loop variables.
        # NOTE(review): the meaning of export's (10, 100) arguments is not
        # visible here; kept exactly as in the original call.
        for eval_x, eval_y in test_datafeeder.export(10, 100):
            predictions.extend(model.predicts(eval_x))
            trues.extend(eval_y)
        # Pearson correlation between predictions and true targets.
        perf = np.corrcoef(predictions, trues)[0, 1]
        print('correlation on test dataset is ', perf)


  0%|          | 1/1000 [00:06<1:41:45,  6.11s/it]

correlation on test dataset is  0.009639192364105927


 10%|█         | 100/1000 [02:21<20:52,  1.39s/it]

   100--0.056


 10%|█         | 101/1000 [02:27<40:46,  2.72s/it]

correlation on test dataset is  0.09695485016915464


 20%|██        | 200/1000 [04:41<18:08,  1.36s/it]

   200--0.032


 20%|██        | 201/1000 [04:47<36:03,  2.71s/it]

correlation on test dataset is  0.5058574103441116


 30%|███       | 300/1000 [07:03<17:07,  1.47s/it]

   300--0.036


 30%|███       | 301/1000 [07:09<33:51,  2.91s/it]

correlation on test dataset is  0.5828308433172184


 40%|████      | 400/1000 [09:26<13:26,  1.34s/it]

   400--0.039


 40%|████      | 401/1000 [09:32<26:39,  2.67s/it]

correlation on test dataset is  0.6283654854511584


 46%|████▌     | 456/1000 [10:54<15:21,  1.69s/it]