In [139]:
import torch
import torch.nn as nn

In [2]:
%%writefile dualbranchconvolution.py
import torch
import torch.nn as nn
class DualConv(nn.Module):
    def __init__(self, in_channels:int):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=3,stride=1, padding=1),
            nn.BatchNorm1d(in_channels),
            nn.ReLU(),
            nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=3,stride=1, padding=1),
            nn.BatchNorm1d(in_channels),
            nn.ReLU(),
            nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=3,stride=1, padding=1),
            nn.BatchNorm1d(in_channels),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=5,stride=1, padding=2),
            nn.BatchNorm1d(in_channels),
            nn.ReLU(),
            nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=5,stride=1, padding=2),
            nn.BatchNorm1d(in_channels),
            nn.ReLU(),
            nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=5,stride=1, padding=2),
            nn.BatchNorm1d(in_channels),
            nn.ReLU()
        )
    def forward(self, X):
        out1 = self.conv1(X)
        out2 = self.conv2(X)
        conc = torch.concat((out1,out2), dim =1)
        return conc        

Writing dualbranchconvolution.py


In [141]:
a= torch.rand((625,32,10))
model = nn.Sequential(
            DualConv(32),
            DualConv(64),
            DualConv(128)
)
y = model(a)
y.shape, a.shape

(torch.Size([625, 256, 10]), torch.Size([625, 32, 10]))

In [5]:
%%writefile SicsbdMOSFET_model.py
import torch
import torch.nn as nn
from dualbranchconvolution import DualConv
class SicsbdMOSFET(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=4, out_features=64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64,128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128,320),
            nn.BatchNorm1d(320),
            nn.ReLU()
        )
        self.reshape = (64,5)
        self.transposed_cnn = nn.Sequential(
            nn.ConvTranspose1d(in_channels=64,out_channels=32,kernel_size=1),
            nn.BatchNorm1d(32),
            nn.ReLU()
        )
        self.dualcnn = nn.Sequential(
            DualConv(32),
            DualConv(64),
            DualConv(128)
        )
        self.cnn = nn.Sequential(
            nn.Conv1d(256, 128, 3, 1, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 64, 3, 1, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 32, 3, 1, 1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
        )
        
        self.fc2 = nn.Sequential(
         nn.Flatten(),
         nn.Linear(32*5, 64),
         nn.BatchNorm1d(64),
         nn.ReLU(),
         nn.Linear(64, 32),
         nn.BatchNorm1d(32),
         nn.ReLU(),
         nn.Linear(32, 4),            
        )
    def forward(self, X):
        X = self.fc1(X)
        print("After fc1",X.shape)
        X = X.view(X.shape[0], *self.reshape)     
        print("before transposed",X.shape)   
        X = self.transposed_cnn(X)
        print("after transposed",X.shape)
        X = self.dualcnn(X)
        print("after dualcnn",X.shape)
        X = self.cnn(X)
        print("after cnn",X.shape)
        X = self.fc2(X)
        print("Final shape: ",X.shape)
        return X

Overwriting SicsbdMOSFET_model.py


In [163]:
model = SicsbdMOSFET()

In [164]:
X = torch.rand((625,4))

In [168]:
y_pred = model(X)

After fc1 torch.Size([625, 320])
before transposed torch.Size([625, 64, 5])
after transposed torch.Size([625, 32, 5])
after dualcnn torch.Size([625, 256, 5])
after cnn torch.Size([625, 32, 5])
Final shape:  torch.Size([625, 4])


In [169]:
y_pred.shape

torch.Size([625, 4])