In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
class CombinedAECF(nn.Module):
    """Chain an autoencoder with a classifier that also receives the raw input.

    The autoencoder is applied to ``x[:, :, 4:]`` (indices 4 onward of the
    last dimension); the classifier is then called as
    ``classifier(encoding, x)`` with the untouched input alongside.
    """

    def __init__(self, autoencoder, classifier):
        super().__init__()
        self.encoder = autoencoder
        self.cf = classifier

    def forward(self, x):
        # Columns 0:4 are skipped by the encoder; the classifier still sees them via x.
        dynamic_part = x[:, :, 4:]
        return self.cf(self.encoder(dynamic_part), x)

In [None]:
class CombinedAESVM(nn.Module):
    """Chain a torch autoencoder with a NumPy-based probabilistic classifier.

    ``classifier`` can be any object exposing ``predict_proba`` on a NumPy
    array (presumably a scikit-learn SVM, given the name -- TODO confirm).
    The autoencoder is applied to ``x[:, :, 4:]``.
    """

    def __init__(self, autoencoder, classifier):
        super(CombinedAESVM, self).__init__()
        self.encoder = autoencoder
        self.cf = classifier

    def forward(self, x):
        """Return ``classifier.predict_proba`` as a float32 tensor on ``x``'s device.

        The classifier runs outside autograd, so no gradient flows back through
        this module.
        """
        encoded = self.encoder(x[:, :, 4:])
        # Detach before copying to host so the NumPy copy does not track gradients.
        encoded = encoded.detach().cpu().numpy()
        probs = self.cf.predict_proba(encoded)
        # Return on the input's device rather than a hard-coded 'cuda', so the
        # module also works on CPU-only machines and on any GPU index.
        return torch.as_tensor(probs, dtype=torch.float32, device=x.device)

In [None]:
class CombinedAESVMRule(nn.Module):
    """Autoencoder features plus rule-based static features, fed to a NumPy classifier.

    ``StaticFeatureEncoder`` one-hot encodes the static columns of the input.
    The autoencoder processes ``x[:, :, 4:]``. The two encodings are
    concatenated along dim 1 and passed to ``classifier.predict_proba``.
    """

    def __init__(self, autoencoder, classifier):
        super(CombinedAESVMRule, self).__init__()
        self.encoder = autoencoder
        self.cf = classifier
        self.rule = StaticFeatureEncoder()

    def forward(self, x):
        """Return class probabilities as a float32 tensor on ``x``'s device.

        ``self.encoder`` must return a 2-D (batch, features) tensor so it can be
        concatenated with the static encoding along dim 1.
        """
        # `condition` is the per-class mask built from the neighbour-lane flags.
        static_out, condition = self.rule(x.clone())
        encoded = self.encoder(x[:, :, 4:])
        # Align devices before concatenating the learned and static features.
        features = torch.cat((encoded, static_out.to(encoded.device)), dim=1)
        # predict_proba works on NumPy: detach before the host copy.
        probs = self.cf.predict_proba(features.detach().cpu().numpy())
        # Return on the input's device instead of a hard-coded 'cuda'.
        out = torch.as_tensor(probs, dtype=torch.float32, device=x.device)
        # NOTE: masking the probabilities by `condition` is currently disabled
        # (it was the commented-out `out * condition`); raw probabilities are returned.
        return out

In [None]:
class StaticFeatureEncoder(nn.Module):
    """ Encodes categorical features

    Reads the first four columns of ``x[:, -1]`` (the last index along dim 1,
    presumably the last time step) as integer codes:

    * column 0 -- vehicle class code in 1..3, one-hot encoded (3 values);
    * column 1 -- not used here (presumably the lane-type code, see ``lane_enc``);
    * column 2 -- left-neighbour-lane flag in {0, 1}, one-hot encoded (2 values);
    * column 3 -- right-neighbour-lane flag in {0, 1}, one-hot encoded (2 values).

    ``forward`` returns ``(features, mask)``:

    * ``features`` is a ``(batch, 7)`` one-hot tensor.
    * ``mask`` is a ``(batch, 3)`` 0/1 tensor. A left flag of 0 zeroes mask
      column 0, and a right flag of 0 zeroes mask column 2 (presumably the
      "change left" / "change right" classes -- TODO confirm).

    Both outputs use the default floating dtype and live on the input's device.
    """

    def __init__(self, n_features=3+2*2):
        """
        Args:
            n_features: width of the encoded output. Must equal 7 (3 vehicle
                classes + 2 left-flag + 2 right-flag columns); any other value
                makes ``encode_batch`` raise ``ValueError``.
        """
        super(StaticFeatureEncoder, self).__init__()

        self.n_features = n_features

        # One-hot lookup tables. Fitting a one-hot encoder on a sorted list of
        # distinct categories yields the identity matrix, so torch.eye gives the
        # same float64 tables without an sklearn round-trip.
        # Buffers follow .to(device); persistent=False keeps state_dict unchanged.
        # Lane type codes 1..4 (original note: 'travel', 'through', 'express', 'aux').
        self.register_buffer('lane_enc', torch.eye(4, dtype=torch.float64), persistent=False)
        # Vehicle class codes 1..3.
        self.register_buffer('v_enc', torch.eye(3, dtype=torch.float64), persistent=False)
        # Neighbour-lane flag codes 0..1.
        self.register_buffer('lane_neigh', torch.eye(2, dtype=torch.float64), persistent=False)

        # Row = neighbour flag; a 0 flag disables one output class.
        left_lane_exists = torch.ones(2, 3)
        left_lane_exists[0, 0] = 0
        self.register_buffer('left_lane_exists', left_lane_exists, persistent=False)

        right_lane_exists = torch.ones(2, 3)
        right_lane_exists[0, -1] = 0
        self.register_buffer('right_lane_exists', right_lane_exists, persistent=False)

    def get_lane_enc(self, lane_n):
        """One-hot row for lane type code ``lane_n`` (1-based)."""
        return self.lane_enc[lane_n-1, :]

    def get_left_lane_enc(self, lane_n):
        """(one-hot row, class mask row) for left-neighbour flag ``lane_n``."""
        return self.lane_neigh[lane_n, :], self.left_lane_exists[lane_n, :]

    def get_right_lane_enc(self, lane_n):
        """(one-hot row, class mask row) for right-neighbour flag ``lane_n``."""
        return self.lane_neigh[lane_n, :], self.right_lane_exists[lane_n, :]

    def get_v_enc(self, v_n):
        """One-hot row for vehicle class code ``v_n`` (1-based)."""
        return self.v_enc[v_n-1, :]

    def encode(self, batch):
        """Encode one sample's codes; returns a ``(1, 7)`` row and a ``(3,)`` mask."""
        y, ex = self.encode_batch(batch.unsqueeze(0))
        return y, ex[0]

    def encode_batch(self, tensor):
        """Vectorized encoding of a ``(batch, >=4)`` integer code tensor.

        Returns:
            ``(y, ex)``: ``y`` of shape ``(batch, n_features)``, ``ex`` of shape
            ``(batch, 3)``, both in the default float dtype on ``tensor``'s device.

        Raises:
            ValueError: if ``n_features`` does not match the encoded width (7).
        """
        device = tensor.device
        dtype = torch.get_default_dtype()
        codes = tensor.long()
        left = codes[:, 2]
        right = codes[:, 3]

        # Tables are tiny; .to() is a no-op when they already live on `device`.
        v_onehot = self.v_enc.to(device)[codes[:, 0] - 1]
        neigh = self.lane_neigh.to(device)
        y = torch.cat((v_onehot, neigh[left], neigh[right]), dim=1).to(dtype)
        if y.size(1) != self.n_features:
            raise ValueError(
                f"n_features={self.n_features} does not match encoded width {y.size(1)}"
            )

        ex = (self.left_lane_exists.to(device)[left]
              * self.right_lane_exists.to(device)[right]).to(dtype)
        return y, ex

    def forward(self, x):
        """Encode the static codes in ``x[:, -1, 0:4]``; see the class docstring."""
        # Float codes are truncated to integers before lookup.
        static_x = x[:, -1, 0:4].type(torch.int32)
        out, ex = self.encode_batch(static_x)
        return out, ex