In [2]:
import torch
import sys
import os
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
import pickle
import numpy as np
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import time
import logging

In [3]:
torch.set_printoptions(sci_mode=False)
np.set_printoptions(suppress=True)
device = torch.device('cuda:1')

In [4]:
def get_logger(filename, name=None):
    formatter = logging.Formatter(
        " %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
 
    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)
 
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)
 
    return logger

In [4]:
# !jupyter nbconvert --to python Action_Detetction.ipynb

## 子模型定义

### ResNet模型(分析STFT频谱)

In [5]:
class Block(nn.Module):
    def __init__(self,in_size,out_size,stride=1,downsampling=False, expansion = 4):
        super().__init__()
        self.expansion=expansion
        self.downsampling=downsampling
        self.resblock=nn.Sequential(
            nn.Conv2d(in_channels=in_size,out_channels=out_size,kernel_size=1,stride=1, bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_size,out_channels=out_size,kernel_size=3,stride=stride, padding=1,bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_size,out_channels=out_size*self.expansion,kernel_size=1,stride=1, bias=False),
            nn.BatchNorm2d(out_size*self.expansion)
        )
        if self.downsampling:
            self.downsample=nn.Sequential(
                nn.Conv2d(in_channels=in_size,out_channels=out_size*self.expansion,kernel_size=1,stride=stride,bias=False),
                nn.BatchNorm2d(out_size*self.expansion)
            )
        self.relu = nn.ReLU(inplace=True)
    def forward(self,x):
        residual=x
        out=self.resblock(x)
        if self.downsampling:
            residual = self.downsample(x)
        out+=residual
        return self.relu(out)
    
class ResNet(nn.Module):
    def __init__(self,blocks, encoder_len=1024,expansion = 4):
        super().__init__()
        self.expansion = expansion
        self.conv1=nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=64,kernel_size=7,stride=2,padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        self.block1=self.get_block(64,64,blocks[0],1)
        self.block2=self.get_block(256,128,blocks[1],2)
        self.block3=self.get_block(512,256,blocks[2],2)
        self.block4=self.get_block(1024,512,blocks[3],2)
        self.avgpool=nn.AvgPool2d(7, stride=1)
        self.fc=nn.Linear(2048,encoder_len)
    def get_block(self,in_size,out_size,block,stride):
        blocks=[]

        blocks.append(Block(in_size,out_size,stride,downsampling=True))
        for i in range(1,block):
            blocks.append(Block(out_size*self.expansion,out_size))
        return nn.Sequential(*blocks)
    
    def forward(self,x):
        b=x.shape[0]
        x=self.conv1(x)
        x=self.block1(x)
        x=self.block2(x)
        x=self.block3(x)
        x=self.block4(x)
        x=self.avgpool(x)
        x=x.reshape(b,-1)
        return self.fc(x)
    
def ResNet50():
    return ResNet([3, 4, 6, 3])

def ResNet101():
    return ResNet([3, 4, 23, 3])

def ResNet152():
    return ResNet([3, 8, 36, 3])

### LSTM模型（分析MFCC）

In [6]:
class LSTM(nn.Module):
 
    def __init__(self,num_layers=1):
        super().__init__()
        self.lstm= nn.LSTM(24,64,num_layers,dropout=0.1,batch_first=True,bidirectional=True)
        self.fc=nn.Linear(128*num_layers,256*num_layers)

    def forward(self, x):
        output,(h,c)=self.lstm(x)
        h=rearrange(h,'n b d -> b (n d)')
        out=self.fc(h)

        return out

## 训练模型1（Triplet Loss）

### 生成训练数据

In [None]:
class Train_Data1(Dataset):
    def __init__(self,length):
        with open('./model_data/train_high_img.data','rb') as f:
            data=pickle.load(f)
            self.labels=np.array(data['labels'])
            self.images=data['images']
            self.length=length
        with open('./model_data/train_low.data','rb') as f:
            data=pickle.load(f)
            self.mfccs=data['mfccs']
    def __getitem__(self,index):
        anchor=np.random.choice(np.arange(0,len(self.labels)))
        class_a=self.labels[anchor]
        postive=np.random.choice(np.argwhere(self.labels==class_a).T[0])
        negative=np.random.choice(np.argwhere(self.labels!=class_a).T[0])
        return self.images[anchor],self.images[postive],self.images[negative],self.mfccs[anchor],self.mfccs[postive],self.mfccs[negative]
    
    def __len__(self):
        return self.length

In [None]:
dataset_train=Train_Data1(6400)
data_train=DataLoader(dataset_train,batch_size=32,shuffle=True)

### 训练模型定义

In [None]:
class Siamese_Train1(nn.Module):
    def __init__(self,high,low):
        super().__init__()
        self.high=high
        self.low=low
        self.linear=nn.Linear(1536,1024)
    def forward(self, ha,hp,hn,la,lp,ln):
        zha=self.high(ha)
        zhp=self.high(hp)
        zhn=self.high(hn)
        zla=self.low(la)
        zlp=self.low(lp)
        zln=self.low(ln)
        za=self.linear(torch.cat((zha,zla),1))
        zp=self.linear(torch.cat((zhp,zlp),1))
        zn=self.linear(torch.cat((zhn,zln),1))
        return za,zp,zn

### 开始训练

In [None]:
def train(num_epoches,optimizer,model,loss_fn,data_train,save_itercept=5):
    model.train()
    logger1 = get_logger('./logger/train1_Step.log','step')
    logger2 = get_logger('./logger/res1_Epoch.log','epoch')
    logger3 = get_logger('./logger/tensor1_vec.log','vec')
    for epoch in range(1,num_epoches+1):
        loss_sum=0
        begin=time.time()
        logger1.info('=============Epoch:{} Strat!============='.format(epoch))
        for i,(ha,hp,hn,la,lp,ln) in enumerate(data_train):
            ha=ha.to(device)
            hp=hp.to(device)
            hn=hn.to(device)
            la=la.to(device)
            lp=lp.to(device)
            ln=ln.to(device)
            za,zp,zn=model(ha,hp,hn,la,lp,ln)
            loss=loss_fn(za,zp,zn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i+1)%50==0:logger3.info('Epoch:{}'.format(epoch)+str(za))
            logger1.info('Epoch:{} on Step[{}/{}], Loss={:.4f}'.format(epoch ,i+1,len(data_train),float(loss)))
            loss_sum+=float(loss)
        end=time.time()
        logger2.info('Epoch:{}, Loss Mean={:.4f}, Runtime={:.4f}s'.format(epoch,(float(loss_sum)/len(data_train)),(end-begin)))
        logger1.info('=============Epoch:{} Ended!============='.format(epoch))
        if epoch%save_itercept==0:
            torch.save(model.high,'./model/model3.0/high3.0')
            torch.save(model.low,'./model/model3.0/low3.0')
            torch.save(model.linear,'./model/model3.0/linear3.0')
            logger2.info('|| Epoch:{} Model has been saved'.format(epoch))
        if (float(loss_sum)/len(data_train)) <0.002:
            break

In [None]:
high=ResNet50()
low=LSTM(2)
model=Siamese_Train1(high,low).to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=1e-5)
loss_fn = nn.TripletMarginLoss(margin=1.5, p=2)

In [None]:
train(50,optimizer,model,loss_fn,data_train,save_itercept=5)

## 训练模型2(one2one分类测试)

### 生成训练数据

In [None]:
class Train_Data2(Dataset):
    def __init__(self,length):
        with open('./model_data/train_high_img.data','rb') as f:
            data=pickle.load(f)
            self.labels=np.array(data['labels'])
            self.images=data['images']
        with open('./model_data/train_low.data','rb') as f:
            data=pickle.load(f)
            self.mfccs=data['mfccs']
        self.length=length
    def __getitem__(self,index):
        issame=np.random.choice([True,False],p=[0.3,0.7])
        if issame:
            cls=np.random.choice(np.arange(1,8))
            cls_index=np.argwhere(self.labels==cls).T[0]
            x1,x2=np.random.choice(cls_index,size=2,replace=False)
            return self.images[x1],self.images[x2],self.mfccs[x1],self.mfccs[x2],torch.tensor(1.0)
        else:
            x1=np.random.choice(np.arange(0,len(self.labels)))
            cls_index=np.argwhere(self.labels!=self.labels[x1]).T[0]
            x2=np.random.choice(cls_index)
            return self.images[x1],self.images[x2],self.mfccs[x1],self.mfccs[x2],torch.tensor(0.0)
    def __len__(self):
        return self.length

In [None]:
dataset_train=Train_Data2(6400)
data_train=DataLoader(dataset_train,batch_size=48,shuffle=True)

In [None]:
class Siamese_Train2(nn.Module):
    def __init__(self):
        super().__init__()
        self.high=torch.load('./model/model3.0/high3.0')
        self.low=torch.load('./model/model3.0/low3.0')
        self.linear=torch.load('./model/model3.0/linear3.0')
        self.out=torch.load('./model/model3.0/out3.0')
    def forward(self, h1, h2,l1,l2):
        zh1=self.high(h1)
        zh2=self.high(h1)
        zl1=self.low(l1)
        zl2=self.low(l2)
        z1=self.linear(torch.cat((zh1,zl1),1))
        z2=self.linear(torch.cat((zh2,zl2),1))
        x=z1-z2
        x=self.out(x)
        return x.reshape(-1)

### 开始训练

In [None]:
model=Siamese_Train2().to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=1e-4)
loss_fn = nn.BCELoss()

In [None]:
def train(num_epoches,optimizer,model,loss_fn,data_train,save_itercept=5):
    model.train()
    logger1 = get_logger('./logger/train2_Step.log','step')
    logger2 = get_logger('./logger/res2_Epoch.log','epoch')
    logger3 = get_logger('./logger/tensor2_vec.log','vec')
    for epoch in range(1,num_epoches+1):
        loss_sum=0
        begin=time.time()
        logger1.info('=============Epoch:{} Strat!============='.format(epoch))
        for i,(h1, h2,l1,l2,label) in enumerate(data_train):
            h1=h1.to(device)
            h2=h2.to(device)
            l1=l1.to(device)
            l2=l2.to(device)
            label=label.to(device)
            res=model(h1, h2,l1,l2)
            loss=loss_fn(res,label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i+1)%50==0:logger3.info('Epoch:{}'.format(epoch)+str(res))
            logger1.info('Epoch:{} on Step[{}/{}], Loss={:.4f}'.format(epoch ,i+1,len(data_train),float(loss)))
            loss_sum+=float(loss)
        end=time.time()
        logger2.info('Epoch:{}, Loss Mean={:.4f}, Runtime={:.4f}s'.format(epoch,(float(loss_sum)/len(data_train)),(end-begin)))
        logger1.info('=============Epoch:{} Ended!============='.format(epoch))
        if epoch%save_itercept==0:
            torch.save(model.high,'./model/model3.0/high3.0')
            torch.save(model.low,'./model/model3.0/low3.0')
            torch.save(model.linear,'./model/model3.0/linear3.0')
            torch.save(model.out,'./model/model3.0/out3.0')
            logger2.info('|| Epoch:{} Model has been saved'.format(epoch))
        if (float(loss_sum)/len(data_train)) <0.002:
            break

In [None]:
train(200,optimizer,model,loss_fn,data_train,save_itercept=5)

## 测试模型1（错误拒绝率测试）

### 生成测试数据

In [7]:
class Test_Data1(Dataset):
    def __init__(self,length):
        with open('./model_data/test_high_img.data','rb') as f:
            data=pickle.load(f)
            self.labels=np.array(data['labels'])
            self.images=data['images']
        with open('./model_data/test_low.data','rb') as f:
            data=pickle.load(f)
            self.mfccs=data['mfccs']
        self.length=length
    def __getitem__(self,index):
        issame=np.random.choice([True,False],p=[0.3,0.7])
        if issame:
            cls=np.random.choice(np.arange(1,8))
            cls_index=np.argwhere(self.labels==cls).T[0]
            x1,x2=np.random.choice(cls_index,size=2,replace=False)
            return self.images[x1],self.images[x2],self.mfccs[x1],self.mfccs[x2],torch.tensor(1)
        else:
            x1=np.random.choice(np.arange(0,len(self.labels)))
            cls_index=np.argwhere(self.labels!=self.labels[x1]).T[0]
            x2=np.random.choice(cls_index)
            return self.images[x1],self.images[x2],self.mfccs[x1],self.mfccs[x2],torch.tensor(0)
    def __len__(self):
        return self.length

In [8]:
dataset_test=Test_Data1(length=1600)
data_test=DataLoader(dataset_test,batch_size=32,shuffle=True)

### 测试模型定义

In [9]:
class Siamese_Test1(nn.Module):
    def __init__(self):
        super().__init__()
        self.high=torch.load('./model/model3.0/high3.0')
        self.low=torch.load('./model/model3.0/low3.0')
        self.linear=torch.load('./model/model3.0/linear3.0')
        self.out=torch.load('./model/model3.0/out3.0')
    def forward(self, h1, h2,l1,l2):
        zh1=self.high(h1)
        zh2=self.high(h1)
        zl1=self.low(l1)
        zl2=self.low(l2)
        z1=self.linear(torch.cat((zh1,zl1),1))
        z2=self.linear(torch.cat((zh2,zl2),1))
        x=z1-z2
        x=self.out(x).reshape(-1)
        #return x
        return (x>0.6).int()

### 开始测试

In [10]:
model=Siamese_Test1().to(device)

In [13]:
def test(model,data_test,epochs):
    count=0
    total=0
    error_accept=0
    accept=0
    for epoch in range(epochs):
        c=0
        for i,(kh,kl,qh,ql,label) in enumerate(data_test):
            kh=kh.to(device)
            kl=kl.to(device)
            qh=qh.to(device)
            ql=ql.to(device)
            res=model(kh,kl,qh,ql)
            label=label.numpy()
            #print(res.cpu().detach().numpy(),label.numpy())
            res=res.cpu().detach().numpy()
            accept+=np.sum((res==1))
            e_index=np.argwhere(res!=label)
            e_label=label[e_index]
            error_accept+=np.sum(e_label==0)
            res=np.sum((res==label))
            c+=res
        count+=c
        total+=len(dataset_test)
        print('Epoch {}: [{}/{}]'.format(epoch+1,c,len(dataset_test)))
    print('Accuracy: %%%.4f, False Acceptance Rate: %%%.4f' %(count/total*100,error_accept/accept))

In [14]:
test(model,data_test,5)

Epoch 1: [1586/1600]
Epoch 2: [1585/1600]
Epoch 3: [1584/1600]
Epoch 4: [1587/1600]
Epoch 5: [1584/1600]
Accuracy: %99.0750, False Acceptance Rate: %0.0071


## 测试模型2( one way ten shot测试)

### 生成测试数据

In [None]:
class Test_Data2(Dataset):
    def __init__(self,way=5):
        with open('./model_data/query.data','rb') as f:
            data=pickle.load(f)
            self.q_images=data['images']
            self.q_mfccs=data['mfccs']
        with open('./model_data/key.data','rb') as f:
            data=pickle.load(f)
            self.k_images=data['images']
            self.k_mfccs=data['mfccs']
        self.ways=way
    def __getitem__(self,index):
        issame=np.random.choice([True,False])
        if issame:
            cls=np.random.choice(np.arange(0,7))
            q_index=np.random.choice(np.arange(0,10),self.ways,replace=False)
            kh=self.k_images[cls]
            kl=self.k_mfccs[cls]
            qh=self.q_images[cls][q_index]
            ql=self.q_mfccs[cls][q_index]
            return kh,kl,qh,ql,torch.tensor(1)
        else:
            cls1,cls2=np.random.choice(np.arange(0,7),2,replace=False)
            q_index=np.random.choice(np.arange(0,10),self.ways,replace=False)
            kh=self.k_images[cls1]
            kl=self.k_mfccs[cls1]
            qh=self.q_images[cls2][q_index]
            ql=self.q_mfccs[cls2][q_index]
            return kh,kl,qh,ql,torch.tensor(0)

    def __len__(self):
        return 80

In [None]:
dataset_test=Test_Data2(way=1)
data_test=DataLoader(dataset_test,batch_size=8,shuffle=True)

### 测试模型定义

In [None]:
class Siamese_Test2(nn.Module):
    def __init__(self):
        super().__init__()
        self.high=torch.load('./model/model3.0/high3.0')
        self.low=torch.load('./model/model3.0/low3.0')
        self.linear=torch.load('./model/model3.0/linear3.0')
        self.out=torch.load('./model/model3.0/out3.0')
    def forward(self, kh,kl,qh,ql):
        b=kh.shape[0]
        way=qh.shape[1]
        qh=qh.view(-1,3,224,224)
        ql=ql.view(-1,469,24)
        zhk=self.high(kh)
        zlk=self.low(kl)
        zhq=self.high(qh)
        zlq=self.low(ql)
        zk=self.linear(torch.cat((zhk,zlk),1))
        zq=self.linear(torch.cat((zhq,zlq),1))
       
        zq=torch.mean(zq.view(b,way,-1),dim=1)
        x=zq-zk
        x=self.out(x).reshape(-1)
        #return x
        return (x>0.5).int()

##### 开始测试

In [None]:
model=Siamese_Test2().to(device)

In [None]:
def test(model,data_test,epochs):
    count=0
    total=0
    for epoch in range(epochs):
        c=0
        for i,(kh,kl,qh,ql,label) in enumerate(data_test):
            kh=kh.to(device)
            kl=kl.to(device)
            qh=qh.to(device)
            ql=ql.to(device)
            res=model(kh,kl,qh,ql)
            #print(res.cpu().detach().numpy(),label.numpy())
            res=(res.cpu()==label)
            res=res.detach().numpy()
            
            res=(res==1).sum()
            c+=res
        count+=c
        total+=len(dataset_test)
        print('Epoch {}: [{}/{}]'.format(epoch+1,c,len(dataset_test)))
    print('Accuracy: %%%.2f' %(count/total*100))

In [None]:
test(model,data_test,10)

## 历史子模型

### VGG模型(分析STFT频谱)

In [None]:
class VGG(nn.Module):
 
    def __init__(self):
        super().__init__()

        # 3 * 224 * 224
        self.conv1_1 = nn.Conv2d(3, 64, 3) # 64 * 222 * 222
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=(1, 1)) # 64 * 222* 222
        self.maxpool1 = nn.MaxPool2d((2, 2), padding=(1, 1)) # pooling 64 * 112 * 112

        self.conv2_1 = nn.Conv2d(64, 128, 3) # 128 * 110 * 110
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=(1, 1)) # 128 * 110 * 110
        self.maxpool2 = nn.MaxPool2d((2, 2), padding=(1, 1)) # pooling 128 * 56 * 56

        self.conv3_1 = nn.Conv2d(128, 256, 3) # 256 * 54 * 54
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=(1, 1)) # 256 * 54 * 54
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=(1, 1)) # 256 * 54 * 54
        self.maxpool3 = nn.MaxPool2d((2, 2), padding=(1, 1)) # pooling 256 * 28 * 28

        self.conv4_1 = nn.Conv2d(256, 512, 3) # 512 * 26 * 26
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # 512 * 26 * 26
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # 512 * 26 * 26
        self.maxpool4 = nn.MaxPool2d((2, 2), padding=(1, 1)) # pooling 512 * 14 * 14

        self.conv5_1 = nn.Conv2d(512, 512, 3) # 512 * 12 * 12
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # 512 * 12 * 12
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # 512 * 12 * 12
        self.maxpool5 = nn.MaxPool2d((2, 2), padding=(1, 1)) # pooling 512 * 7 * 7
        # view

        self.fc1 = nn.Linear(512 * 7 * 7, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 2048)
        # softmax 1 * 1 * 1000

    def forward(self, x):

        # x.size(0)即为batch_size
        in_size = x.size(0)

        out = self.conv1_1(x) # 222
        out = F.relu(out)
        out = self.conv1_2(out) # 222
        out = F.relu(out)
        out = self.maxpool1(out) # 112

        out = self.conv2_1(out) # 110
        out = F.relu(out)
        out = self.conv2_2(out) # 110
        out = F.relu(out)
        out = self.maxpool2(out) # 56

        out = self.conv3_1(out) # 54
        out = F.relu(out)
        out = self.conv3_2(out) # 54
        out = F.relu(out)
        out = self.conv3_3(out) # 54
        out = F.relu(out)
        out = self.maxpool3(out) # 28

        out = self.conv4_1(out) # 26
        out = F.relu(out)
        out = self.conv4_2(out) # 26
        out = F.relu(out)
        out = self.conv4_3(out) # 26
        out = F.relu(out)
        out = self.maxpool4(out) # 14

        out = self.conv5_1(out) # 12
        out = F.relu(out)
        out = self.conv5_2(out) # 12
        out = F.relu(out)
        out = self.conv5_3(out) # 12
        out = F.relu(out)
        out = self.maxpool5(out) # 7

        # 展平
        out = out.view(in_size, -1)

        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.relu(out)
        out = self.fc3(out)

        return out

### Transformer模型（只包含Encoder）

In [None]:
class Multi_Attention(nn.Module):
    def __init__(self,dim=512,heads=8,head_dim=64):
        super().__init__()
        self.heads=heads
        self.scale=head_dim**-0.5
        self.norm=nn.LayerNorm(dim)
        self.to_qkv=nn.Linear(dim,heads*head_dim*3,bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(heads*head_dim,dim),
            nn.Dropout(0.1)
        )
        
    def forward(self,input):
        x=input
        qkv=self.to_qkv(x).chunk(3,dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        a=F.softmax(dots,dim=-1)
        attn_res=out = torch.einsum('bhij,bhjd->bhid', a, v)
        attn_res=rearrange(attn_res,'b h n d -> b n (h d)')
        x=self.to_out(x)
        x=x+input
        return self.norm(x)

In [None]:
class Fnn(nn.Module):
    def __init__(self,dim=512,FNN_dim=2048):
        super().__init__()
        self.norm=nn.LayerNorm(dim)
        self.mlp=nn.Sequential(
            nn.Linear(dim,FNN_dim),
            nn.ReLU(),
            nn.Linear(FNN_dim,dim),
            nn.Dropout(0.1)
        )
    def forward(self,input):
        x=input
        x=self.mlp(x)
        x=x+input
        x=self.norm(x)
        return x

In [None]:
class Encoder(nn.Module):
    def __init__(self,dim=512,N=6,heads=8,head_dim=64,FNN_dim=2048):
        super().__init__()
        self.blocks = nn.ModuleList([])
        for _ in range(N):
            self.blocks.append(nn.ModuleList([
                Multi_Attention(dim,heads,head_dim),
                Fnn(dim,FNN_dim)
            ]))
    def forward(self,input):
        x=input.to(torch.float32)
        for attention,fnn in self.blocks:
            x=attention(x)
            x=fnn(x)
        return x

In [None]:
class Transformer(nn.Module):
    def __init__(self,input_size,seq_len,dim=512,N=6,heads=8,FNN_dim=2048,head_dim=64):
        super().__init__()
        self.token_embedding=nn.Linear(input_size,dim)
        self.pos_embedding=nn.Parameter(torch.randn(1, seq_len + 1, dim))
        self.cls_token=nn.Parameter(torch.randn(1, 1,dim))
        self.dropout = nn.Dropout(0.1)
        self.encoder=Encoder(dim,N,heads,head_dim,FNN_dim)
    def forward(self,input):
        batch_size=input.shape[0]
        x=input.to(torch.float32)
        x=self.token_embedding(x)
        cls_token=self.cls_token.repeat(batch_size,1,1)
        x=torch.cat((cls_token,x),dim=1)
        x=x+self.pos_embedding
        x = self.dropout(x)
        x = self.encoder(x)
        x=x[:,0]
        return x