In [1]:
import pandas as pd
import numpy as np
import os
import json
from collections import defaultdict
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

from data_prep_scripts.data_countries import get_country_of_origin_data 
from data_prep_scripts.data_manipulation import df_replaceColVals_vars,process_remaining_categ_cols
from data_prep_scripts.process_repetitive_cols import get_repetitive_cols

## Load and preprocess raw data

In [2]:
# Name of the binary target column (it comes from the additional-info workbook).
col_to_predict = 'Triple Negative'

# Load the raw, de-identified patient table.
medData = pd.read_excel('Merged File 2.5.19 De-identified.xlsx')
# Fix a numeric code that should have been a country name.
# The result is assigned back: calling replace(..., inplace=True) on a column
# selection acts on a temporary object and is a silent no-op under pandas
# Copy-on-Write.
medData['Country of Origin for Father'] = (
    medData['Country of Origin for Father'].replace(76, 'Ireland')
)
# Keep only female breast-cancer patients. This filter runs BEFORE lower-casing,
# so it compares against the original capitalised 'Yes' / 'Female'.
medData = medData[(medData['Breast Cancer?'] == 'Yes') & (medData['Gender'] == 'Female')]
# Lower-case every string cell; non-string cells are left untouched.
medData = medData.applymap(lambda s: s.lower() if isinstance(s, str) else s)

# Attach the target column (spreadsheet columns A and DZ) by patient id.
# The target is merged after lower-casing, so its values keep their original case.
predCol = pd.read_excel('Additional BCD Info 3.1.19 De-identified.xlsx', usecols="A,DZ")
medData = medData.merge(predCol, left_on='ID #', right_on='ID #')
# Drop patients whose target is unknown. `.copy()` makes the result an
# independent frame, so the in-place edits in later cells do not act on a slice.
medData = medData[medData[col_to_predict] != 'Unknown'].copy()

In [3]:
# Positional column ranges of `medData` that hold repeated groups of fields.
# Each entry is (start, stop, third_value); the cells in this notebook only use
# start/stop as a half-open column slice.
# NOTE(review): the third value looks like the number of fields per repetition -- confirm.
ind_repititive_blocks = [
    (47, 89, 7),
    (89, 117, 7),
    (117, 152, 7),
    (163, 191, 7),
    (197, 213, 4),
    (214, 249, 5),
    (294, 348, 6),
    (349, 356, 7),
    (357, 364, 7),
    (427, 435, 8),
]

# Column positions that must never be dropped by the frequency-based filter.
ind_not_del = [*range(22, 43), 274]

# Column positions to drop unconditionally; a tuple is a half-open (start, stop) range.
ind_to_del = [
    2, 3, 4,
    (8, 11),
    13,
    451,
    (197, 213),
    (349, 356),
    (386, 414),
    (448, 469),
]

In [4]:
def _rep_block_to_numpy(block):
    """Turn one (categorical, continuous) block into a pair of numpy values.

    The categorical part is cast to int64. The continuous part is cast to
    float32 when it is a DataFrame with at least one column; otherwise an
    empty list is used as a placeholder.
    """
    has_cont = isinstance(block[1], pd.DataFrame) and block[1].shape[1] > 0
    categ_values = block[0].astype('int64').values
    if has_cont:
        return (categ_values, block[1].astype('float32').values)
    return (categ_values, [])


# Group the columns that carry the same kind of information several times.
rep_cols = get_repetitive_cols(medData)
# One list per group; falsy (empty) blocks are skipped.
rep_cols_np = [
    [_rep_block_to_numpy(block) for block in group if block]
    for group in rep_cols
]

In [5]:
# get features for country of origin columns
# The result is consumed later as a sequence of blocks where block[0] is cast to
# int64 (categorical features) and block[1] is cast to float32 (continuous features).
country_of_origin_data = get_country_of_origin_data(medData)

In [6]:
# Inspect element [1] of the second country block (the part that is later cast to float32).
country_of_origin_data[1][1]

Unnamed: 0,Pop. Density (per sq. mi.),Literacy (%),Phones (per 1000),Birthrate,Deathrate,imr20052010,AvgYrsSchool10,GDPPerCap2010,WHR score,HDI2013,...,Phones (per 1000)_null,Birthrate_null,Deathrate_null,imr20052010_null,AvgYrsSchool10_null,GDPPerCap2010_null,WHR score_null,HDI2013_null,country_longitude_null,country_latitude_null
0,31.0,97.0,898.0,14.14,8.26,6.869402,13.27,43887.913415,7.119,0.913742,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,31.0,97.0,898.0,14.14,8.26,6.869402,13.27,43887.913415,7.119,0.913742,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,31.0,97.0,898.0,14.14,8.26,6.869402,13.27,43887.913415,7.119,0.913742,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,,,,,,29.583241,6.90,4878.795941,4.885,0.700154,...,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,,,,,,5.159836,,18907.150302,6.298,,...,1.0,1.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0
5,31.0,97.0,898.0,14.14,8.26,6.869402,13.27,43887.913415,7.119,0.913742,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,,,,,,10.681814,11.73,6351.230075,5.716,0.778303,...,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7,31.0,97.0,898.0,14.14,8.26,6.869402,13.27,43887.913415,7.119,0.913742,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8,31.0,97.0,898.0,14.14,8.26,6.869402,13.27,43887.913415,7.119,0.913742,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
9,31.0,97.0,898.0,14.14,8.26,6.869402,13.27,43887.913415,7.119,0.913742,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [7]:
# mixed_cols=[]
# for ii, (a,b) in enumerate(medData.dtypes.iteritems()):
#     if b == object:
#         types = set([type(k) for k in list(medData[a].dropna() ) ])
#         if str in types and len(types)>1: 
#             print(ii,a,b)

In [8]:
# Positional indices of columns that should be purely numeric but contain
# string values in the raw data.
# Per the original author's note, the helper stores those string values in a
# json file where their replacement values can be specified; the returned
# dict is passed to `medData.replace` in the next cell.
mixedCols_write_strs = [92,121,126,131,136,141,146,151,152,192,294,303,416,417,418,446,449]
dict_replaceColVals = df_replaceColVals_vars(medData,mixedCols_write_strs,str_vals=True)

In [9]:
# Apply the string -> value replacements to the frame in place.
# NOTE(review): in-place mutation means re-running earlier cells is needed to recover the raw values.
medData.replace(dict_replaceColVals,inplace=True)

In [10]:
# Replace outlier values of the continuous columns with the values given in
# the json file, where the replacements can be specified.
dict_replaceColVals_cont = df_replaceColVals_vars(medData, str_vals=False, cont_vals=True)
for k, v in dict_replaceColVals_cont.items():
    if not v:
        # nothing to replace for this column
        continue
    # The mapping keys are converted with float() so they can match the numeric
    # cells of the column.
    k_l = [float(key) for key in v]
    v_l = list(v.values())
    # Assign back instead of `medData[k].replace(..., inplace=True)`: the in-place
    # form works on a column selection and does not reliably write through to
    # `medData` (it is a no-op under pandas Copy-on-Write).
    medData[k] = medData[k].replace(k_l, v_l)
        

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  keepdims=keepdims)
  arrmean, rcount, out=arrmean, casting='unsafe', subok=False)
  ret = ret.dtype.type(ret / rcount)


In [11]:
# A categorical column is kept only if its most frequent value occurs at least
# `col_min_max_count` times AND at least `col_min_max_count` rows hold some
# other value (or are missing).
col_min_max_count = 250
# Only used by the disabled "replace rare values" step below.
val_count_threshold = 50
col_values_replace_OTH = defaultdict(list)

# Start with the columns that are dropped unconditionally. A tuple in
# `ind_to_del` is a half-open positional range, an int is a single position.
cols_to_del = []
for c in ind_to_del:
    if isinstance(c, tuple):
        cols_to_del += list(medData.columns[c[0]:c[1]])
    else:
        cols_to_del.append(medData.columns[c])

# `.items()` replaces `.iteritems()`, which was removed in pandas 2.0.
for ii, (a, b) in enumerate(medData.dtypes.items()):
    # do not delete the columns in repetitive blocks
    in_rep_block = any(block[0] <= ii < block[1] for block in ind_repititive_blocks)
    if in_rep_block or b != object or ii in ind_not_del:
        continue

    col_counts = medData[a].value_counts()
    if col_counts.empty:
        # all-null column: there is no "most frequent value" to inspect
        cols_to_del.append(a)
        continue
    col_max_count = col_counts.iloc[0]
    # the most frequent value of a column should occur at least 'col_min_max_count' times
    if col_max_count < col_min_max_count: cols_to_del.append(a)
    elif medData.shape[0]-col_max_count < col_min_max_count: cols_to_del.append(a)
#         # for the columns that we preserve, we replace values if their 
#         # frequency is not above the given threshold 'val_count_threshold'.
#         else:
#             vals_to_OTH = list(col_counts[col_counts<val_count_threshold].index)
#             if vals_to_OTH: 
#                 col_values_replace_OTH[a] = vals_to_OTH
# print(len(col_values_replace_OTH.keys()))
# print(col_values_replace_OTH.keys())
# print(len(cols_to_del))
# print(cols_to_del)

In [12]:
# Datetime columns are dropped for now; they will be added in a later iteration.
cols_to_del.extend(medData.select_dtypes(include=['datetime']))
# Names of all columns that belong to a repetitive block (these are handled by
# the 'get_repetitive_cols()' function).
cols_repetitive = [
    col_name
    for rep_block in ind_repititive_blocks
    for col_name in medData.columns[rep_block[0]:rep_block[1]]
]


In [13]:
def _remaining_cols_of_dtype(dtype_name):
    """List the columns of `medData` with the given dtype that are still usable.

    A column is usable when it is not scheduled for deletion, is not part of a
    repetitive block, is not the target column, and is not entirely null.
    """
    excluded = cols_to_del + cols_repetitive + [col_to_predict]
    kept = []
    for col_name in medData.select_dtypes(include=[dtype_name]):
        if col_name in excluded:
            continue
        if medData[col_name].isnull().all():
            continue
        kept.append(col_name)
    return kept


remaining_obj_cols = _remaining_cols_of_dtype('object')
remaining_float_cols = _remaining_cols_of_dtype('float64')

In [14]:
# Encode the remaining object columns; the result is later cast to int64 and its
# per-column maxima are used to size the embedding tables.
medData_categ = process_remaining_categ_cols(medData[remaining_obj_cols])
# Remaining continuous columns, used as-is (they are scaled in a later cell).
medData_float = medData[remaining_float_cols]

In [15]:
#rep_cols
#country_of_origin_data

In [16]:
# for i,k in enumerate(medData.columns):
#     if k in remaining_obj_cols: print(i)

## Create model's inputs

In [17]:
# Shuffle the patient positions with a fixed seed, then use the first 80% for
# training and the remaining 20% for validation.
n_patients = medData.shape[0]
temp_indices = np.arange(n_patients)
np.random.seed(0)
np.random.shuffle(temp_indices)

split_at = int(n_patients * 0.8)
train_ind = temp_indices[:split_at]
valid_ind = temp_indices[split_at:]

Single, unique columns

In [18]:
# Continuous fields: the scaler statistics come from the training rows only.
float_fields_scaler = StandardScaler().fit(medData_float.iloc[train_ind])


def _scaled_float_rows(row_positions):
    """Standardise the selected rows, then map NaN to 0.0 and cast to float32."""
    scaled = float_fields_scaler.transform(medData_float.iloc[row_positions])
    return np.nan_to_num(scaled).astype('float32')


train_x_f = _scaled_float_rows(train_ind)
valid_x_f = _scaled_float_rows(valid_ind)

In [19]:
# categorical variables
train_x_c = medData_categ.iloc[train_ind].values.astype('int64')
valid_x_c = medData_categ.iloc[valid_ind].values.astype('int64')
# Per-column maximum code over ALL rows (train and validation); the model sizes
# each embedding table as max + 1.
categ_vars_max_vals = list(medData_categ.values.max(axis=0))

Blocks of columns that repeat several times (e.g. `1. relative with condition x`, `1. relative age`, `2. relative with condition x`, `2. relative age`, etc.). They are processed separately, in a way that allows parameter sharing.

In [20]:
train_x_rep_list_np = []
valid_x_rep_list_np = []
std_scalers_f_rep = []
categ_data_rep_maxvals = []  # redundant but useful for now
rep_blocks_dims_list = []  # list of tuples of (n_rep, c_n_fields, f_n_fields, c_max_vals)

x_rep_list = []
for group in rep_cols_np:
    # Stack the repetitions of the group: arrays come out as (n_rep, n_patients, n_fields).
    rep_cols_c, rep_cols_f = list(zip(*group))
    rep_cols_c, rep_cols_f = np.array(rep_cols_c), np.array(rep_cols_f)

    # categorical values
    g_n_rep, g_n_batch, g_n_fields_c = rep_cols_c.shape
    # patient-major layout, flattened to get the maximum code of every field
    rep_cols_c = rep_cols_c.transpose((1, 0, 2)).reshape(-1, g_n_fields_c)
    categ_data_rep_maxvals.append(list(rep_cols_c.max(axis=0)))
    rep_cols_c = rep_cols_c.reshape(g_n_batch, g_n_rep, g_n_fields_c)

    # cont values
    if rep_cols_f.ndim == 3:
        g_n_fields_f = rep_cols_f.shape[-1]
        rep_blocks_dims_list.append((g_n_rep, g_n_fields_c, g_n_fields_f, categ_data_rep_maxvals[-1]))
        # (n_rep, n_patients, n_fields) -> (n_patients, n_rep, n_fields)
        rep_cols_f = rep_cols_f.transpose((1, 0, 2))
        g_scaler = StandardScaler()
        # Select the training PATIENTS first and only then flatten the
        # repetitions. Indexing the already flattened
        # (n_patients * n_rep, n_fields) array with patient indices would pick
        # unrelated rows and let validation patients into the fit.
        g_scaler.fit(rep_cols_f[train_ind].reshape(-1, g_n_fields_f))
        std_scalers_f_rep.append(g_scaler)
        rep_cols_f = np.nan_to_num(g_scaler.transform(rep_cols_f.reshape(-1, g_n_fields_f)))
        # extra axis of size 1: the channel dimension expected by the conv layers
        rep_cols_f = rep_cols_f.reshape(g_n_batch, 1, g_n_rep, g_n_fields_f)
    else:
        # the group has no continuous fields
        rep_blocks_dims_list.append((g_n_rep, g_n_fields_c, 0, categ_data_rep_maxvals[-1]))
        std_scalers_f_rep.append(None)
        rep_cols_f = None

    train_x_rep_list_np.append((
        rep_cols_c[train_ind],
        rep_cols_f[train_ind] if isinstance(rep_cols_f, np.ndarray) else None
        ))
    valid_x_rep_list_np.append((
        rep_cols_c[valid_ind],
        rep_cols_f[valid_ind] if isinstance(rep_cols_f, np.ndarray) else None
        ))

Processing the country-of-origin columns: `Country of Origin for Patient`, `Country of Origin for Mother`, `Country of Origin for Father`

We learn one embedding for each country and use the same country embedding for each of these columns. Note that the countries are featurized, and represented by categorical variables (e.g. continent, continent_subregion, development_level, income_group, etc.) as well as some continuous variables (population_density, birth_rate, death_rate, human_development_index, etc.). We also retain the country_name as a categorical feature if a given country has more samples in our dataset than the given threshold. Other countries are represented solely by the features mentioned above.

In [21]:
def _country_block_to_numpy(block):
    """Return (int64 categorical values, float32 continuous values) for one country block."""
    categ_values = block[0].astype('int64').values
    cont_values = block[1].astype('float32').values
    return (categ_values, cont_values)


# One (categorical, continuous) numpy pair per country-of-origin block.
country_of_origin_np = list(map(_country_block_to_numpy, country_of_origin_data))

In [22]:
# NOTE(review): these statements are repeated verbatim at the start of the
# following cell, which recomputes everything from `country_of_origin_np`;
# this cell looks like a leftover and does not seem to be needed.
country_cols_c, country_cols_f = list(zip(*country_of_origin_np))
country_cols_c, country_cols_f = np.array(country_cols_c), np.array(country_cols_f)
# categorical values
# stacked shape is unpacked as (n_rep, n_batch, n_fields), then flattened patient-major
g_n_rep, g_n_batch, g_n_fields_c = country_cols_c.shape
country_cols_c = country_cols_c.transpose((1,0,2)).reshape(-1,g_n_fields_c)

In [23]:
# Stack the country blocks: arrays come out as (n_rep, n_patients, n_fields).
country_cols_c, country_cols_f = list(zip(*country_of_origin_np))
country_cols_c, country_cols_f = np.array(country_cols_c), np.array(country_cols_f)

# categorical values
g_n_rep, g_n_batch, g_n_fields_c = country_cols_c.shape
# patient-major layout, flattened to get the maximum code of every field
country_cols_c = country_cols_c.transpose((1, 0, 2)).reshape(-1, g_n_fields_c)
country_data_maxvals = list(country_cols_c.max(axis=0))
country_cols_c = country_cols_c.reshape(g_n_batch, g_n_rep, g_n_fields_c)

# cont values
g_n_fields_f = country_cols_f.shape[-1]
# (n_rep, c_n_fields, f_n_fields, c_max_vals): the first four arguments of Conv_On_Blocks
country_block_dims = (g_n_rep, g_n_fields_c, g_n_fields_f, country_data_maxvals)
# (n_rep, n_patients, n_fields) -> (n_patients, n_rep, n_fields)
country_cols_f = country_cols_f.transpose((1, 0, 2))
g_scaler = StandardScaler()
# Select the training PATIENTS first and only then flatten the repetitions.
# Indexing the already flattened (n_patients * n_rep, n_fields) array with
# patient indices would pick unrelated rows and let validation patients into
# the fit.
g_scaler.fit(country_cols_f[train_ind].reshape(-1, g_n_fields_f))
std_scalers_f_country = g_scaler
country_cols_f = np.nan_to_num(g_scaler.transform(country_cols_f.reshape(-1, g_n_fields_f)))
# extra axis of size 1: the channel dimension expected by the conv layers
country_cols_f = country_cols_f.reshape(g_n_batch, 1, g_n_rep, g_n_fields_f)

train_x_country_np = (
    country_cols_c[train_ind],
    country_cols_f[train_ind]
    )
valid_x_country_np = (
    country_cols_c[valid_ind],
    country_cols_f[valid_ind]
    )

predicted variable

In [24]:
# Binary target: 1 where the target column equals 'Yes', 0 for every other value
# (rows with 'Unknown' were removed when the data was loaded).
y = (medData[col_to_predict] == 'Yes').values.astype('int64')
# Split the labels with the same positional indices used for the features.
y_train, y_valid = y[train_ind], y[valid_ind]

## Define the dataset and prepare the dataloaders

In [25]:
class NYULH_DS(Dataset):
    """Dataset that bundles every input group of one patient with its label.

    Parameters
    ----------
    x_f : array of continuous fields, indexed by row.
    x_c : array of categorical fields, indexed by row.
    x_rep_list_np : sequence of (categorical, continuous) pairs, one per
        repetitive block; the continuous part may be something other than an
        ndarray (e.g. None) when the block has no continuous fields.
    x_country_np : (categorical, continuous) pair for the country block.
    y : labels, indexed by row.

    All arrays must have the same number of rows; this is asserted on
    construction.
    """

    def __init__(self, x_f, x_c, x_rep_list_np, x_country_np, y):
        super().__init__()
        n_rows = x_f.shape[0]
        # note: despite its name this attribute holds the dataset length
        self.batch_size = n_rows
        self.x_rep_list_np = x_rep_list_np
        self.x_country_np = x_country_np

        assert n_rows == x_c.shape[0] == x_country_np[0].shape[0] == x_country_np[1].shape[0], 'number of rows do not match'
        for ii, rep_pair in enumerate(x_rep_list_np):
            assert n_rows == rep_pair[0].shape[0],\
            f'number of rows do not match for the categorical data at index {ii}'
            if not isinstance(rep_pair[1], np.ndarray):
                # no continuous part for this block, nothing more to check
                continue
            assert n_rows == rep_pair[1].shape[0], \
            f'number of rows do not match for the cont. data at index {ii}'

        self.x_f = x_f
        self.x_c = x_c
        self.x_f_rep_list_np = x_rep_list_np
        self.y = y

    def __len__(self):
        """Number of samples in the dataset."""
        return self.batch_size

    def __getitem__(self, i):
        """Return ((x_f, x_c, rep_blocks, country), y) for sample `i`.

        `rep_blocks` is a list of (categorical, continuous) pairs; a block
        without continuous data gets an empty array as its second element.
        """
        float_row = self.x_f[i]
        categ_row = self.x_c[i]

        rep_rows = []
        for rep_pair in self.x_rep_list_np:
            categ_part = rep_pair[0][i]
            if isinstance(rep_pair[1], np.ndarray):
                cont_part = rep_pair[1][i]
            else:
                cont_part = np.array(())
            rep_rows.append((categ_part, cont_part))

        country_row = (self.x_country_np[0][i], self.x_country_np[1][i])
        inputs = (float_row, categ_row, rep_rows, country_row)
        return (inputs, self.y[i])

In [26]:
batch_size = 32

ds_train = NYULH_DS(train_x_f, train_x_c, train_x_rep_list_np, train_x_country_np, y_train)
ds_valid = NYULH_DS(valid_x_f, valid_x_c, valid_x_rep_list_np, valid_x_country_np, y_valid)

# Balance classes: every sample is weighted by the inverse frequency of its
# class within its own split, so both classes are drawn about equally often.
weights_weights_train = 1/torch.tensor([(y_train==0).sum(), (y_train==1).sum()], dtype=torch.float)
weights_weights_valid = 1/torch.tensor([(y_valid==0).sum(), (y_valid==1).sum()], dtype=torch.float)
samples_weights_train = weights_weights_train[y_train]
samples_weights_valid = weights_weights_valid[y_valid]

# The second argument of WeightedRandomSampler is `num_samples`, the number of
# draws per pass over the loader. It used to be `batch_size`, so every pass
# produced exactly one batch; it is now the size of the split, so one pass is a
# full (resampled) epoch.
# The weights are passed as a plain list: the sampler builds its own tensor,
# and handing it a tensor triggers a copy-construct warning in some versions.
t_sampler = torch.utils.data.sampler.WeightedRandomSampler(
    samples_weights_train.tolist(), num_samples=len(ds_train), replacement=True
)
v_sampler = torch.utils.data.sampler.WeightedRandomSampler(
    samples_weights_valid.tolist(), num_samples=len(ds_valid), replacement=True
)
train_loader = DataLoader(ds_train, batch_size=batch_size, sampler=t_sampler, num_workers=1)
valid_loader = DataLoader(ds_valid, batch_size=batch_size, sampler=v_sampler, num_workers=1)

  self.weights = torch.tensor(weights, dtype=torch.double)


## Create the model classes

In [27]:
class Conv_On_Blocks(nn.Module):
    '''
    This module is used for inputs that consist of blocks of repetitive information, for example:
    `1. relative with condition x`, `1. relative age`, `2. relative with condition x`, `2. relative age`,..
    In such cases, we should be indifferent to where the information is given, and use it the same way
    regardless of the block that it is given in. We induce this behavior by using convolution operations
    and by doing parameter sharing.
    Input of the model is a tuple of (`categorical vars`, `continuous vars`)

    Parameters
    ----------
    n_rep : number of repetitions of the block.
    c_n_fields : number of categorical fields per repetition (0 if none).
    f_n_fields : number of continuous fields per repetition (0 if none).
    c_max_vals : maximum code of every categorical field; each field gets an
        embedding table with `max + 1` rows.
    n_hid_func : optional callable mapping an input width to a hidden width.
    n_emb_func : optional callable mapping a vocabulary size to an embedding width.
    country_data : when True the repetitions are NOT averaged; the features of
        all repetitions are concatenated and fed to the final linear layer.

    Forward input
    -------------
    x[0] : integer tensor reshaped to (-1, c_n_fields); with `n_rep` rows per sample.
    x[1] : float tensor of shape (batch, 1, n_rep, f_n_fields); ignored when
        `f_n_fields == 0`.

    Returns a tensor of shape (batch, n_hidden_2).
    '''
    def __init__(
        self, n_rep, c_n_fields, f_n_fields, c_max_vals, n_hid_func=None, n_emb_func=None, country_data=False
    ):
        super(Conv_On_Blocks, self).__init__()
        assert n_rep>=0 and c_n_fields >= 0 and f_n_fields>=0 and c_n_fields+f_n_fields> 0, 'conv_on_block: invalid input'
        self.n_rep = n_rep
        self.c_n_fields = c_n_fields
        self.f_n_fields = f_n_fields
        self.categ_input, self.float_input = c_n_fields > 0, f_n_fields > 0

        # default sizing rules: embedding width ~ sqrt(vocabulary), hidden width ~ half the input
        self.n_emb_func = n_emb_func if n_emb_func else lambda x: int(np.sqrt(x))+1
        self.n_hid_func = n_hid_func if n_hid_func else lambda x: int(np.round(x/2+1))

        self.country_data = country_data

        if self.categ_input:
            emb_dims = [self.n_emb_func(k+1) for k in c_max_vals]
            self.embeddings = nn.ModuleList(
                [nn.Embedding(k+1, dim) for k, dim in zip(c_max_vals, emb_dims)]
            )
            self.c_n_hidden = sum(emb_dims)
        else:
            self.c_n_hidden = 0

        in_width = self.f_n_fields + self.c_n_hidden
        self.n_hidden_1 = max(3, self.n_hid_func(in_width))
        self.n_hidden_2 = max(3, self.n_hid_func(self.n_hidden_1))

        # Kernels span the full feature width and a single repetition, so the
        # same weights are applied to every repetition (parameter sharing).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=self.n_hidden_1, kernel_size=(1, in_width))
        self.conv2 = nn.Conv2d(in_channels=1, out_channels=self.n_hidden_2, kernel_size=(1, self.n_hidden_1))
        self.fc1 = nn.Linear(
            self.n_hidden_2 if not country_data else self.n_hidden_2*self.n_rep,
            self.n_hidden_2 )

    def forward(self, x):
        if self.categ_input:
            # embed every categorical field and concatenate the embeddings per repetition
            x_c = x[0].reshape(-1, self.c_n_fields)
            x_c = torch.cat(
                [self.embeddings[k](x_c[:, k]) for k in range(self.c_n_fields)],
                dim=1
            ).reshape(-1, 1, self.n_rep, self.c_n_hidden)

            if self.float_input: x = torch.cat([x_c, x[1]], dim=-1)
            else: x = x_c
        else:
            x = x[1]

        # (batch, 1, n_rep, width) -> (batch, n_hidden_1, n_rep, 1)
        x = F.leaky_relu(self.conv1(x), negative_slope=0.2, inplace=True)
        # move the channels to the width axis so conv2 can mix them: (batch, 1, n_rep, n_hidden_1)
        x = x.transpose(1, 3)
        x = F.leaky_relu(self.conv2(x), 0.2, True).view(-1, self.n_hidden_2, self.n_rep)

        if self.country_data:
            # keep the repetitions apart: their order carries meaning for country data
            # (bug fix: these attributes were previously referenced without `self.`)
            return self.fc1(x.reshape(-1, self.n_hidden_2 * self.n_rep))
        # average over the repetitions so that their order does not matter
        return self.fc1(x.mean(dim=-1))

In [28]:
class Net(nn.Module):
    '''
    Encoder network over all input groups of a patient, with either a
    classification head or a decoder head on top of the latent code.

    Parameters
    ----------
    n_float_fields : number of continuous single fields (width of x[0]).
    n_categ_fields : number of categorical single fields (columns of x[1]).
    categ_vars_max_vals : maximum code of every categorical single field; each
        field gets an embedding table with `max + 1` rows.
    rep_blocks_dims_list : one (n_rep, c_n_fields, f_n_fields, c_max_vals) tuple
        per repetitive block; each is passed to `Conv_On_Blocks`.
    country_block_dims : the same kind of tuple for the country block.
    latent_size : width of the latent code z.
    x_drop_p, h_drop_p : dropout probabilities for the input and hidden layers.
    z_noise : scale of the Gaussian noise added to z while training.
    train_mode : set to False to switch off dropout and latent noise.
    n_emb_func : optional callable mapping a vocabulary size to an embedding width.
    autoencoder : if True the decoder head is used, otherwise the classifier head.

    Forward input
    -------------
    x[0] : continuous single fields, (batch, n_float_fields)
    x[1] : categorical single fields, (batch, n_categ_fields)
    x[2] : sequence of inputs, one per repetitive block (see Conv_On_Blocks)
    x[3] : input of the country block (see Conv_On_Blocks)

    Returns (output, z): `output` is the reconstruction in autoencoder mode and
    the single-logit prediction otherwise; `z` is the noise-free latent code.
    '''
    def __init__(
        self, n_float_fields, n_categ_fields,categ_vars_max_vals, rep_blocks_dims_list, 
        country_block_dims, latent_size,
        x_drop_p = 0.2, h_drop_p=0.5, z_noise=1., train_mode=True, n_emb_func=None, autoencoder=False
    ):
        super(Net, self).__init__()
        self.n_float_fields = n_float_fields
        self.n_categ_fields = n_categ_fields
        self.categ_vars_max_vals = categ_vars_max_vals
        self.rep_blocks_dims_list = rep_blocks_dims_list
        self.country_block_dims = country_block_dims
        self.latent_size = latent_size
        self.x_drop_p = x_drop_p
        self.h_drop_p = h_drop_p
        self.z_noise = z_noise
        self.train_mode = train_mode
        self.autoencoder = autoencoder

        self.n_emb_func = n_emb_func if n_emb_func else lambda x: int(np.sqrt(x))+1

        # one embedding table per categorical single field
        self.x_embeddings = nn.ModuleList(
            [nn.Embedding(k+1, self.n_emb_func(k+1)) for k in categ_vars_max_vals]
        )
        self.c_n_hidden = sum([self.n_emb_func(k+1) for k in categ_vars_max_vals])

        # one shared-weight conv module per repetitive block
        self.x_rep_module = nn.ModuleList(
            [Conv_On_Blocks(*k) for k in rep_blocks_dims_list]
        )
        self.rep_n_hidden = sum([k.fc1.out_features for k in self.x_rep_module])

        # the country block uses a narrower hidden-size rule and keeps its repetitions apart
        self.x_country_module = Conv_On_Blocks(
            *country_block_dims,
            n_hid_func=lambda x: int(np.round(x/4+1)),
            n_emb_func=None,
            country_data=True,
        )
        self.country_n_hidden = self.x_country_module.fc1.out_features

        # Width of the concatenated encoder input. The country features are part
        # of what `forward` concatenates, so they must be counted here as well
        # (they were missing before).
        self.n_hidden_0 = (
            n_float_fields + self.c_n_hidden + self.rep_n_hidden + self.country_n_hidden
        )
        self.n_hidden_1 = int(self.n_hidden_0/2)+1
        self.n_hidden_2 = int((self.n_hidden_1+latent_size)/2)
        self.n_hidden_y0 = int(latent_size*3/5)
        self.n_hidden_d1 = int((n_float_fields+n_categ_fields)*.7)

        self.enc_linear1 = nn.Linear(self.n_hidden_0, self.n_hidden_1)
        self.enc_linear2 = nn.Linear(self.n_hidden_1, self.n_hidden_2)
        self.enc_linear3 = nn.Linear(self.n_hidden_2, self.latent_size)

        self.dec_linear1 = nn.Linear(self.latent_size, self.n_hidden_d1)
        self.dec_linear2 = nn.Linear(self.n_hidden_d1, n_float_fields+n_categ_fields)

        self.y_linear1 = nn.Linear(self.latent_size, self.n_hidden_y0)
        self.y_linear2 = nn.Linear(self.n_hidden_y0, 1)

    def forward(self, x, verbose=False):
        x_float, x_categ, x_rep_blocks, x_country = x[0], x[1], x[2], x[3]
        # Dropout and latent noise are active only when the `train_mode` flag is
        # set AND the module is in training mode, so `model.eval()` disables them too.
        stochastic = bool(self.train_mode) and self.training

        features = [x_float]
        features += [self.x_embeddings[k](x_categ[:, k]) for k in range(self.n_categ_fields)]
        # each repetitive block goes through its own module
        features += [module(block) for module, block in zip(self.x_rep_module, x_rep_blocks)]
        features.append(self.x_country_module(x_country))
        x = torch.cat(features, dim=-1)

        # Out-of-place dropout/activations: mutating an activation output in
        # place can invalidate tensors that autograd saved for the backward pass.
        x = F.leaky_relu(self.enc_linear1(F.dropout(x, self.x_drop_p, stochastic)), .2)
        x = F.leaky_relu(self.enc_linear2(F.dropout(x, self.h_drop_p, stochastic)), .2)
        z = self.enc_linear3(F.dropout(x, self.h_drop_p, stochastic))
        x = z + torch.randn_like(z, requires_grad=False)*self.z_noise if stochastic else z

        if self.autoencoder:
            x = F.leaky_relu(self.dec_linear1(x), .2)
            x = self.dec_linear2(F.dropout(x, self.h_drop_p, stochastic))
        else:
            x = F.leaky_relu(self.y_linear1(x), .2)
            x = self.y_linear2(F.dropout(x, self.h_drop_p, stochastic))
        return x, z
        

In [33]:
# for x,y in train_loader:
#     break
# temp_b_conv = Conv_On_Blocks(7,5,3,categ_data_rep_maxvals[0])
# temp_b_conv(x[2][0]).shape

# tempnet = Net(16,81,categ_vars_max_vals,rep_blocks_dims_list,country_block_dims,20)
# tempnet

In [30]:
# The field counts are derived from the prepared inputs (they used to be the
# hard-coded values 16 and 81), so the model stays consistent when the column
# selection changes.
model = Net(
    n_float_fields=train_x_f.shape[1],
    n_categ_fields=len(categ_vars_max_vals),
    categ_vars_max_vals=categ_vars_max_vals,
    rep_blocks_dims_list=rep_blocks_dims_list,
    country_block_dims=country_block_dims,
    latent_size=15
)

In [31]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): no visible cell moves `model` or the batches to this device -- confirm it happens later.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [32]:
# Adam over all model parameters with a fixed step size.
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)