In [5]:
import numpy as np
import torch
import pandas as pd
from torch import nn

## Gold standard baseline

In [6]:
# Deserialize the object stored in encodings.pt.
# NOTE: torch.load unpickles the file, so only point it at files you trust.
encodings = torch.load(f='../data/interim/encodings.pt')
# Read the gold-standard CSV into a DataFrame.
df = pd.read_csv(filepath_or_buffer="../data/processed/gold_standard.csv")

In [7]:
class Model(nn.Module):
    """Multi-layer RNN followed by a linear layer applied to every time step.

    Args:
        input_size: Number of features per time step of the input.
        output_size: Number of features produced by the linear layer.
        hidden_dim: Size of the RNN hidden state.
        n_layers: Number of stacked RNN layers.
    """

    def __init__(self, input_size, output_size, hidden_dim, n_layers):
        super().__init__()

        # Kept as attributes because forward/init_hidden need them for reshaping.
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        # Recurrent layer; batch_first=True means inputs are (batch, seq, feature).
        self.rnn = nn.RNN(input_size, hidden_dim, n_layers, batch_first=True)
        # Fully connected layer mapping each hidden vector to the output size.
        self.fc = nn.Linear(hidden_dim, output_size)

    def forward(self, x):
        """Run the RNN over ``x`` and project every time step through ``fc``.

        Args:
            x: Input tensor laid out as (batch, seq, input_size).

        Returns:
            A tuple ``(out, hidden)`` where ``out`` has shape
            (batch * seq, output_size) and ``hidden`` is the final hidden state
            of shape (n_layers, batch, hidden_dim).
        """
        batch_size = x.size(0)

        # Start every forward pass from a fresh all-zero hidden state.
        hidden = self.init_hidden(batch_size)

        out, hidden = self.rnn(x, hidden)

        # Flatten (batch, seq, hidden_dim) to (batch * seq, hidden_dim) so the
        # linear layer is applied independently to every time step.
        out = out.contiguous().view(-1, self.hidden_dim)
        out = self.fc(out)

        return out, hidden

    def init_hidden(self, batch_size):
        """Return an all-zero initial hidden state for ``batch_size`` sequences.

        The tensor is allocated from the RNN's own weights, so it always lives
        on the same device and has the same dtype as the model. This removes
        the dependency on a notebook-level ``device`` global, which raised
        NameError whenever the model was used before that global was defined.

        Returns:
            Tensor of shape (n_layers, batch_size, hidden_dim).
        """
        reference = self.rnn.weight_ih_l0
        return reference.new_zeros(self.n_layers, batch_size, self.hidden_dim)

## Snorkel end model with RoBERTa-LSTM input module and multitask output head

In [9]:
# Read the array saved in snorkel_proba.npy (presumably label-model
# probabilities, judging by the file name -- TODO confirm).
probas = np.load(file='../data/interim/snorkel_proba.npy')

In [None]:
from metal.end_model import EndModel

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the end model with a fixed seed on the chosen device.
# NOTE(review): [1000, 10, 2] is presumably the list of layer dimensions
# (input, hidden, output) -- confirm against the MeTaL EndModel signature.
end_model = EndModel(
    [1000, 10, 2],
    seed=123,
    device=device,
)

In [None]:
# Inputs: the `sentences` column of the gold-standard DataFrame.
# NOTE(review): `encodings` is loaded earlier but never used; confirm whether
# the encoded tensors, rather than raw sentences, should be the model inputs.
X = df['sentences']

# `Y_train_ps` was never assigned in this notebook, so the cell raised
# NameError on a fresh kernel. The array loaded into `probas` is otherwise
# unused and is taken here as the training targets.
# NOTE(review): confirm that `probas` is row-aligned with `df`.
Y_train_ps = probas

end_model.train_model(
    (X, Y_train_ps),
    lr=0.01,
    l2=0.01,
    batch_size=256,
    n_epochs=5,
    checkpoint_metric='accuracy',
    checkpoint_metric_mode='max',
)