In [45]:
import torch
import torch.nn as nn

In [46]:
# Run configuration.
# torch.cuda.is_available must be *called*: the bare function object is always
# truthy, which silently selected "cuda:0" even on CPU-only machines.
# torch.device (not the torch.cuda.device context manager) is the object that
# Tensor.to()/Module.to() and tensor factories accept.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
EPOCHS = 10
TIMESTEP = 2000
LR = 3e-4  # presumably the optimizer learning rate — confirm at use site
BATCH_SIZE = 10
NUM_WORKERS = 2  # presumably DataLoader worker processes — confirm at use site

In [76]:
class Convolute(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU, optionally followed by Dropout and MaxPool1d.

    Args:
        in_filters: number of input channels of the convolution.
        out_filters: number of output channels of the convolution (also the
            BatchNorm1d feature count).
        kernel_size: width of the 1-D kernel. Conv1d is built with its default
            stride 1 and no padding, so the length shrinks by ``kernel_size - 1``.
        dropout: dropout probability; a Dropout layer is appended only when non-zero.
        maxpool: MaxPool1d kernel size; a pooling layer is appended only when non-zero.
    """

    def __init__(self, in_filters, out_filters, kernel_size, dropout=0.0, maxpool=0):
        super().__init__()
        layers = [
            nn.Conv1d(in_filters, out_filters, kernel_size),
            nn.BatchNorm1d(out_filters),
            nn.ReLU(),
        ]
        # Optional regularisation / downsampling, in this fixed order.
        if dropout != 0:
            layers.append(nn.Dropout(dropout))
        if maxpool != 0:
            layers.append(nn.MaxPool1d(maxpool))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack; ``x`` is assumed to be (batch, channels, length)."""
        return self.conv(x)

class DNF(nn.Module):
    """A Linear layer, optionally preceded by a dim-1/dim-2 swap and followed by Flatten.

    Args:
        in_filters: size of the last input dimension seen by the Linear layer.
        out_filters: output size of the Linear layer.
        do_transpose: if True, swap dims 1 and 2 of the input before the Linear
            layer, so the projection is applied over what was dim 1.
        do_flatten: if True, append ``nn.Flatten()`` (flattens every dim after dim 0).
    """

    def __init__(self, in_filters, out_filters, do_transpose=False, do_flatten=False):
        super().__init__()
        self.do_transpose = do_transpose
        tail = [nn.Flatten()] if do_flatten else []
        self.dnf = nn.Sequential(nn.Linear(in_filters, out_filters), *tail)

    def forward(self, x):
        """Optionally swap dims 1 and 2, then run the Linear (+ Flatten) stack."""
        x = x.transpose(1, 2) if self.do_transpose else x
        return self.dnf(x)
    

In [84]:
# Head 1: four conv blocks, a per-position Linear(324 -> 256) + Flatten, then a
# dense Linear down to 150 features.
# nn.Sequential (instead of a plain Python list) registers every sub-module, so
# .parameters(), .to(DEVICE) and .train()/.eval() actually reach the layers; it
# is still iterable/indexable like the old list.
# Length bookkeeping for a 1925-sample input (Conv1d with no padding, stride 1):
#   1925 -10 -> 1915, -10 -> 1905, MaxPool(4) -> 476, -4 -> 472, -4 -> 468
#   flattened size fed to the last Linear = 468 * 256 = 119808
x1_in = torch.randn(1, 2, 1925)
head1 = nn.Sequential(
    Convolute(2, 576, 11),
    Convolute(576, 484, 11, 0.3, 4),
    Convolute(484, 400, 5),
    Convolute(400, 324, 5, 0.2),
    DNF(324, 256, True, True),
    DNF(119808, 150),
)
x1 = head1(x1_in)
x1.shape

torch.Size([1, 150])

In [85]:
# Head 2: same architecture as head 1, with independently initialised weights.
# nn.Sequential (instead of a plain Python list) registers every sub-module, so
# .parameters(), .to(DEVICE) and .train()/.eval() actually reach the layers; it
# is still iterable/indexable like the old list.
# Length bookkeeping for a 1925-sample input (Conv1d with no padding, stride 1):
#   1925 -10 -> 1915, -10 -> 1905, MaxPool(4) -> 476, -4 -> 472, -4 -> 468
#   flattened size fed to the last Linear = 468 * 256 = 119808
x2_in = torch.randn(1, 2, 1925)
head2 = nn.Sequential(
    Convolute(2, 576, 11),
    Convolute(576, 484, 11, 0.3, 4),
    Convolute(484, 400, 5),
    Convolute(400, 324, 5, 0.2),
    DNF(324, 256, True, True),
    DNF(119808, 150),
)
x2 = head2(x2_in)
x2.shape

torch.Size([1, 150])

In [87]:
# Fuse the two heads' feature vectors along dim 1: (1, 150) + (1, 150) -> (1, 300).
x = torch.cat((x1, x2), 1)
x.shape

torch.Size([1, 300])