In [97]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../codes')
from torch.utils.data import Dataset, DataLoader
import pickle
import numpy as np
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
import pandas as pd
import config
import pickle
import config
import networks
import utils
import loss
import trainer
import evaluator
import plots
import torch

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [98]:
from sklearn.ensemble import RandomForestRegressor
from scipy.stats import spearmanr, pearsonr

In [99]:
def get_res_by_location(test_data, min_samples=20):
    """Score predictions per (location, trial) group with Pearson correlation.

    Parameters
    ----------
    test_data : pandas.DataFrame
        Must contain the columns 'Loc_no', 'trial', 'Value' (observed
        target) and 'predicted' (model output).
    min_samples : int, default 20
        A group is scored only when it has strictly more than this many
        rows. The default reproduces the previously hard-coded cut-off.

    Returns
    -------
    pandas.DataFrame
        One row per scored group with the columns 'location', 'trial',
        'num_geno' (number of rows in the group) and 'pcc' (Pearson
        correlation between 'Value' and 'predicted'). Groups are listed
        in order of first appearance of the location, then of the trial.
        When no group qualifies the frame is empty but keeps the columns.
    """
    locations = test_data['Loc_no'].unique()
    trials = test_data['trial'].unique()

    result_dict = {'location': [], 'trial': [], 'num_geno': [], 'pcc': []}

    for location in locations:
        # Filter by location once instead of re-scanning the whole frame
        # for every trial of that location.
        at_location = test_data[test_data['Loc_no'] == location]

        for trial in trials:
            partial_test = at_location[at_location['trial'] == trial]

            # Small groups give unstable correlations, so they are skipped.
            if len(partial_test) <= min_samples:
                continue

            observed = partial_test['Value'].to_numpy().reshape(-1,)
            predicted = partial_test['predicted'].to_numpy().reshape(-1,)
            pcc = pearsonr(observed, predicted)[0]

            result_dict['location'].append(location)
            result_dict['trial'].append(trial)
            result_dict['num_geno'].append(len(partial_test))
            result_dict['pcc'].append(pcc)

    result_df = pd.DataFrame(result_dict)

    return result_df

In [100]:
def get_data(file):
    """Load one pickled data split and return its model inputs and target.

    The pickle at ``file`` is expected to hold a pandas DataFrame with at
    least a 'Value' column and a 'weather' column. Rows whose 'Value' is
    not below 10 are dropped before anything is extracted.

    NOTE(review): pickle.load executes arbitrary code from the file; only
    use this on trusted, locally produced files.

    Returns
    -------
    tuple
        ``(genotypes, weather, target)`` where ``genotypes`` is the array of
        columns ``3:-14`` (presumably the marker columns between the leading
        metadata and trailing environment columns; confirm against the
        frame layout), ``weather`` is the 'weather' column stacked into one
        array, and ``target`` is the 'Value' column.
    """
    with open(file, 'rb') as handle:
        frame = pickle.load(handle)

    # Keep only rows with a target value below 10.
    below_cutoff = frame['Value'] < 10
    kept = frame[below_cutoff]

    target = kept['Value'].to_numpy()
    genotypes = kept.iloc[:, 3:-14].to_numpy()
    # Each cell of 'weather' holds a sequence; stack them into one array.
    weather = np.array(kept['weather'].tolist())

    return genotypes, weather, target

In [101]:
# Training split: genotype matrix, weather array and target vector as returned by get_data.
tr_genotypes, tr_weather, tr_target = get_data(config.training_data)

In [102]:
# Test split, loaded and filtered the same way as the training split.
test_genotypes, test_weather, test_target = get_data(config.test_data)

In [103]:
# Validation split, loaded and filtered the same way as the training split.
val_genotypes, val_weather, val_target = get_data(config.validation_data)

In [104]:
# Second test split read from config.test_unique_env_data.
# NOTE(review): presumably the test rows from environments not seen in training - confirm.
test_genotypes_unique, test_weather_unique, test_target_unique = get_data(config.test_unique_env_data)

## weather representation

In [105]:
# Rebuild the weather network with one input per weather feature and restore its trained weights.
# map_location='cpu' lets a checkpoint saved from a GPU be loaded on a CPU-only machine;
# load_state_dict then copies the values onto whatever device the model's parameters live on.
mdl_avg_over_geno = networks.fc_avg_net_over_geno(num_features = tr_weather.shape[1], hidden_dim=54)
mdl_avg_over_geno.load_state_dict(torch.load(config.model_avg_by_env_path, map_location='cpu'))

<All keys matched successfully>

In [106]:
# Run the weather model on the training weather; the second return value is kept as the
# learned weather representation, `outputs` is not used further in this notebook section.
# NOTE(review): evaluator.eval is project code not shown here - confirm what each return value holds.
outputs, tr_weather_rep = evaluator.eval(torch.tensor(tr_weather), mdl_avg_over_geno)

In [107]:
# Weather representations for both test splits, using the same model as for training.
# `outputs` is overwritten by each call and not used afterwards.
outputs, test_weather_rep = evaluator.eval(torch.tensor(test_weather), mdl_avg_over_geno)
outputs, test_weather_rep_unique = evaluator.eval(torch.tensor(test_weather_unique), mdl_avg_over_geno)

In [108]:
# Weather representation for the validation split.
outputs, val_weather_rep = evaluator.eval(torch.tensor(val_weather), mdl_avg_over_geno)

In [109]:
# Convert the weather representations to NumPy arrays for the later concatenation.
# torch.as_tensor returns a tensor unchanged and wraps an ndarray without copying, so
# re-running this cell once the names already hold arrays is a no-op rather than an
# AttributeError (ndarrays have no .cpu()).
tr_weather_rep = torch.as_tensor(tr_weather_rep).detach().cpu().numpy()
test_weather_rep = torch.as_tensor(test_weather_rep).detach().cpu().numpy()
val_weather_rep = torch.as_tensor(val_weather_rep).detach().cpu().numpy()
test_weather_rep_unique = torch.as_tensor(test_weather_rep_unique).detach().cpu().numpy()
print(tr_weather_rep.shape)
print(test_weather_rep.shape)
print(val_weather_rep.shape)

(62556, 54)
(12684, 54)
(12347, 54)


## genotype representation

In [110]:
# Rebuild the genotype network with one input per genotype column and restore its trained weights.
# map_location='cpu' lets a checkpoint saved from a GPU be loaded on a CPU-only machine;
# load_state_dict then copies the values onto whatever device the model's parameters live on.
mdl_avg_over_env = networks.fc_avg_net(num_features = tr_genotypes.shape[1], hidden_dim=2000)
mdl_avg_over_env.load_state_dict(torch.load(config.model_avg_by_geno_path, map_location='cpu'))

<All keys matched successfully>

In [111]:
# Run the genotype model on the training genotypes; the second return value is kept as the
# learned genotype representation, `outputs` is not used further in this notebook section.
# NOTE(review): evaluator.eval is project code not shown here - confirm what each return value holds.
outputs, tr_geno_rep = evaluator.eval(torch.tensor(tr_genotypes), mdl_avg_over_env)

In [112]:
# Genotype representations for both test splits, using the same model as for training.
# `outputs` is overwritten by each call and not used afterwards.
outputs, test_geno_rep = evaluator.eval(torch.tensor(test_genotypes), mdl_avg_over_env)
outputs, test_geno_rep_unique = evaluator.eval(torch.tensor(test_genotypes_unique), mdl_avg_over_env)

In [113]:
# Genotype representation for the validation split.
outputs, val_geno_rep = evaluator.eval(torch.tensor(val_genotypes), mdl_avg_over_env)

In [114]:
# Convert the genotype representations to NumPy arrays for the later concatenation.
# torch.as_tensor returns a tensor unchanged and wraps an ndarray without copying, so
# re-running this cell once the names already hold arrays is a no-op rather than an
# AttributeError (ndarrays have no .cpu()).
tr_geno_rep = torch.as_tensor(tr_geno_rep).detach().cpu().numpy()
test_geno_rep = torch.as_tensor(test_geno_rep).detach().cpu().numpy()
val_geno_rep = torch.as_tensor(val_geno_rep).detach().cpu().numpy()
test_geno_rep_unique = torch.as_tensor(test_geno_rep_unique).detach().cpu().numpy()
print(tr_geno_rep.shape)
print(test_geno_rep.shape)
print(val_geno_rep.shape)

(62556, 296)
(12684, 296)
(12347, 296)


## Environment-specific prediction

In [115]:
# Build one feature matrix per split by placing the genotype representation columns
# first and the weather representation columns after them (row i of both arrays is
# assumed to describe the same sample, since both come from the same get_data call).
tr_data = np.concatenate((tr_geno_rep, tr_weather_rep), axis =1)
test_data = np.concatenate((test_geno_rep, test_weather_rep), axis=1)
val_data = np.concatenate((val_geno_rep, val_weather_rep), axis = 1)
test_data_unique = np.concatenate((test_geno_rep_unique, test_weather_rep_unique), axis=1)

In [116]:
# Sanity check: the column count should equal the genotype width plus the weather width
# (296 + 54 = 350 in the recorded run).
tr_data.shape

(62556, 350)

In [117]:
# Persist every split as one pickled 2-D array: the concatenated
# [genotype | weather] representation with the target appended as the last column.
# A single loop replaces four copy-pasted dump stanzas so the layout cannot drift
# between splits.
for rep_path, rep_features, rep_target in (
    (config.training_representation, tr_data, tr_target),
    (config.test_representation, test_data, test_target),
    (config.val_representation, val_data, val_target),
    (config.test_representation_unique_env, test_data_unique, test_target_unique),
):
    with open(rep_path, 'wb') as wfile:
        # reshape(-1, 1) turns the 1-D target into a column so it can be appended.
        pickle.dump(np.concatenate((rep_features, rep_target.reshape(-1, 1)), axis=1), wfile)

In [118]:
# Show the path the training representation was written to in the previous cell.
config.training_representation

'../processed_data/tr_rep_learned_v5.pkl'