In [3]:
# Work from the project root so the relative 'EmoDB/wav' paths below resolve.
# Set the MTP_DIR environment variable instead of editing a hard-coded path.
import os
os.chdir(os.environ.get('MTP_DIR', 'F:/MTP'))
print(os.getcwd())

F:\MTP


In [4]:
import os
# Names of every entry in the EmoDB audio directory (listing order is filesystem-dependent).
path = [entry for entry in os.listdir('EmoDB/wav')]

In [5]:
# Lookup tables keyed by pieces of an EmoDB file name: the emotion letter at
# index 5 and the two-character speaker code at the start.

# Emotion letter -> class index.
emotion_code = dict(
    W=0,  # anger
    L=1,  # boredom
    E=2,  # disgust
    A=3,  # fear
    F=4,  # happy
    T=5,  # sad
    N=6,  # neutral
)

# The ten speaker codes, in class-index order.
EMODB_SPEAKERS = ['03', '08', '09', '10', '11', '12', '13', '14', '15', '16']

# Speaker code -> speaker class index (0..9).
speaker_code = {code: idx for idx, code in enumerate(EMODB_SPEAKERS)}

# Speaker code -> binary gender label.
# NOTE(review): presumably 0 = male, 1 = female per the EmoDB speaker list — verify.
gender_code = dict(zip(EMODB_SPEAKERS, [0, 1, 1, 0, 0, 0, 1, 1, 0, 1]))

In [6]:
# Speaker class index -> speaker code: the inverse of `speaker_code`, derived
# from it so the two tables cannot drift apart when a speaker is added/changed.
spk_id = {idx: code for code, idx in speaker_code.items()}

In [7]:
import librosa
import numpy as np

In [8]:
import torchaudio
import torch

In [9]:
from torch.utils.data import Dataset, DataLoader

In [10]:
class TrainDataset(Dataset):
    """Chunked MFCC training set built from every EmoDB speaker except two held-out ones.

    For each training speaker (indices 0..9 of ``spk_id`` minus ``test_key`` and
    ``valid_key``), every file whose name starts with that speaker's code is
    loaded with librosa and converted to MFCCs. The MFCCs are transposed to
    ``(frames, n_mfcc)`` and cut into 100-frame windows with a 50-frame hop
    (see ``chunk``). Each window is paired with a 7-way one-hot emotion label
    looked up from character 5 of the file name.

    Args:
        dir_path: directory containing the EmoDB audio files.
        test_key: ``spk_id`` index of the speaker reserved for testing.
        valid_key: ``spk_id`` index of the speaker reserved for validation.
        transform: optional callable applied to each window in ``__getitem__``
            (previously accepted but ignored); ``None`` leaves windows unchanged.
    """

    def __init__(self, dir_path, test_key, valid_key, transform=None):
        self.dir_path = dir_path
        # Sorted: os.listdir order is filesystem-dependent, which made the
        # sample order (index -> window mapping) vary between machines.
        self.files = sorted(os.listdir(self.dir_path))
        self.test_key = test_key
        self.valid_key = valid_key
        self.transform = transform
        self.melspecs, self.Y = self.loadData(dir_path)
        print(len(self.melspecs), len(self.Y))

    def loadData(self, dir_path):
        """Return ``(windows, one_hot_labels)`` for all training-speaker files in ``dir_path``."""
        held_out = (self.test_key, self.valid_key)
        train_keys = [spk_id[i] for i in range(10) if i not in held_out]
        files = sorted(os.listdir(dir_path))
        melspecs = []
        Y = []
        for key in train_keys:
            for file in files:
                if file[:2] != key:
                    continue
                # os.path.join: the old `dir_path + file` broke silently when
                # dir_path had no trailing separator.
                # NOTE(review): no `sr` is passed, so librosa resamples to its
                # default rate; hop_length=160 / win_length=320 look like
                # 10 ms / 20 ms settings for 16 kHz audio — confirm the intent.
                r, sr = librosa.load(os.path.join(dir_path, file), res_type='kaiser_fast')
                mfc = librosa.feature.mfcc(y=r, sr=sr, n_fft=512, hop_length=160, win_length=320)
                windows = self.chunk(torch.Tensor(mfc))
                melspecs.extend(windows)
                # One independent one-hot label per window. Utterances shorter
                # than one window contribute no samples at all.
                for _ in windows:
                    y = torch.zeros(7, dtype=int)
                    y[emotion_code[file[5]]] = 1
                    Y.append(y)
        return melspecs, Y

    def chunk(self, melspec, size=100, hop=50):
        """Split a ``(features, frames)`` tensor into ``(size, features)`` windows.

        Windows start every ``hop`` frames. A trailing partial window is
        dropped, so an input with fewer than ``size`` frames yields ``[]``.
        """
        frames = melspec.transpose(0, 1)
        last_start = frames.size(0) - size
        return [frames[start:start + size, :] for start in range(0, last_start + 1, hop)]

    def __len__(self):
        return len(self.melspecs)

    def __getitem__(self, idx):
        window = self.melspecs[idx]
        if self.transform is not None:
            window = self.transform(window)
        return window, self.Y[idx]

In [11]:
# Training windows from every speaker except spk_id[3] (test) and spk_id[4] (validation).
ds = TrainDataset(dir_path="EmoDB/wav/", test_key=3, valid_key=4)

2755 2755


In [12]:
# Shuffled mini-batches of 4 training windows.
train_dataloader = DataLoader(dataset=ds, batch_size=4, shuffle=True)

In [13]:
class ValidDataset(Dataset):
    """Chunked MFCC validation set: every EmoDB file of one held-out speaker.

    Files whose two-character prefix equals ``spk_id[valid_key]`` are loaded and
    converted to MFCCs. The MFCCs are transposed to ``(frames, n_mfcc)`` and cut
    into 100-frame windows with a 50-frame hop (see ``chunk``). Each window gets
    a 7-way one-hot emotion label looked up from character 5 of the file name.

    Args:
        dir_path: directory containing the EmoDB audio files.
        valid_key: ``spk_id`` index of the validation speaker.
        transform: optional callable applied to each window in ``__getitem__``
            (previously accepted but ignored); ``None`` leaves windows unchanged.
    """

    def __init__(self, dir_path, valid_key, transform=None):
        self.dir_path = dir_path
        # Sorted for a filesystem-independent sample order.
        self.files = sorted(os.listdir(self.dir_path))
        self.valid_key = valid_key
        self.transform = transform
        self.melspecs, self.Y = self.loadData(dir_path)
        print(len(self.melspecs), len(self.Y))

    def loadData(self, dir_path):
        """Return ``(windows, one_hot_labels)`` for the validation speaker's files."""
        key = spk_id[self.valid_key]
        melspecs = []
        Y = []
        for file in sorted(os.listdir(dir_path)):
            if file[:2] != key:
                continue
            # os.path.join works with or without a trailing separator on dir_path.
            # NOTE(review): no `sr` is passed, so librosa resamples to its default
            # rate; the 160/320-sample hop/window look like 16 kHz settings — confirm.
            r, sr = librosa.load(os.path.join(dir_path, file), res_type='kaiser_fast')
            melspec = librosa.feature.mfcc(y=r, sr=sr, n_fft=512, hop_length=160, win_length=320)
            windows = self.chunk(torch.Tensor(melspec))
            melspecs.extend(windows)
            for _ in windows:
                y = torch.zeros(7, dtype=int)
                y[emotion_code[file[5]]] = 1
                Y.append(y)
        return melspecs, Y

    def chunk(self, melspec, size=100, hop=50):
        """Split a ``(features, frames)`` tensor into ``(size, features)`` windows.

        Windows start every ``hop`` frames. A trailing partial window is
        dropped, so an input with fewer than ``size`` frames yields ``[]``.
        """
        frames = melspec.transpose(0, 1)
        last_start = frames.size(0) - size
        return [frames[start:start + size, :] for start in range(0, last_start + 1, hop)]

    def __len__(self):
        return len(self.melspecs)

    def __getitem__(self, idx):
        window = self.melspecs[idx]
        if self.transform is not None:
            window = self.transform(window)
        return window, self.Y[idx]

In [14]:
# Validation windows from speaker spk_id[4].
validset = ValidDataset(dir_path="EmoDB/wav/", valid_key=4)

357 357


In [15]:
# Mini-batches of 4 validation windows (shuffled, as before).
val_dataloader = DataLoader(dataset=validset, batch_size=4, shuffle=True)

In [16]:
# Number of validation batches: ceil(len(validset) / batch_size), since the
# loader keeps the final partial batch.
len(val_dataloader)

90

In [17]:
class TestDataset(Dataset):
    """Whole-utterance MFCC test set for one held-out EmoDB speaker.

    Unlike the training and validation sets, utterances are not chunked. Each
    item is the full ``(frames, n_mfcc)`` MFCC tensor of one file whose prefix
    equals ``spk_id[test_key]``, plus a 7-way one-hot emotion label looked up
    from character 5 of the file name. Items now are torch tensors (previously
    numpy arrays), matching the other datasets.

    Args:
        dir_path: directory containing the EmoDB audio files.
        test_key: ``spk_id`` index of the test speaker.
        transform: optional callable applied to each feature tensor in
            ``__getitem__`` (previously accepted but ignored).
    """

    def __init__(self, dir_path, test_key, transform=None):
        self.dir_path = dir_path
        # Sorted for a filesystem-independent, reproducible evaluation order.
        self.files = sorted(os.listdir(self.dir_path))
        self.test_key = test_key
        self.transform = transform
        self.melspecs, self.Y = self.loadData(dir_path)
        print(len(self.melspecs), len(self.Y))

    def loadData(self, dir_path):
        """Return ``(utterance_features, one_hot_labels)`` for the test speaker's files."""
        key = spk_id[self.test_key]
        melspecs = []
        Y = []
        for file in sorted(os.listdir(dir_path)):
            if file[:2] != key:
                continue
            # os.path.join works with or without a trailing separator on dir_path.
            r, sr = librosa.load(os.path.join(dir_path, file), res_type='kaiser_fast')
            melspec = librosa.feature.mfcc(y=r, sr=sr, n_fft=512, hop_length=160, win_length=320)
            # Full utterance as (frames, n_mfcc). Lengths differ between files,
            # presumably why the test loader uses batch_size=1.
            melspecs.append(torch.Tensor(melspec).transpose(0, 1))
            y = torch.zeros(7, dtype=int)
            y[emotion_code[file[5]]] = 1
            Y.append(y)
        return melspecs, Y

    def __len__(self):
        return len(self.melspecs)

    def __getitem__(self, idx):
        features = self.melspecs[idx]
        if self.transform is not None:
            features = self.transform(features)
        return features, self.Y[idx]

In [18]:
# Whole-utterance test items for speaker spk_id[3].
testset = TestDataset(dir_path="EmoDB/wav/", test_key=3)

38 38


In [19]:
# Disabled debug snippet: print each label batch of the test loader.
# NOTE: `testloader` is only created in the following cell, so run that cell
# first if re-enabling this.
# for i, data in enumerate(testloader, 0):
#         # get the inputs; data is a list of [inputs, labels]
#         inputs, labels = data
#         print(labels)

In [20]:
# One utterance per batch, in dataset order, for evaluation.
testloader = DataLoader(dataset=testset, batch_size=1, shuffle=False)

In [19]:
import torch.nn as nn
import torch.nn.functional as F

class TDNN(nn.Module):
    """Time-delay (frame-splicing) layer built from ``F.unfold`` and a shared Linear.

    The input ``x`` has shape ``(batch, time, input_dim)``. ``F.unfold``
    gathers ``context_size`` frames spaced ``dilation`` apart into one vector
    of length ``input_dim * context_size``, advancing ``stride`` frames per
    output step. ``kernel`` maps each spliced vector to ``output_dim``. ReLU,
    optional dropout and optional batch norm follow.

    Args:
        input_dim: size of the last input axis (features per frame).
        output_dim: features per output frame.
        context_size: frames spliced into each output step.
        stride: frames advanced between output steps. Previously this value was
            stored but ignored (the unfold always used 1); the default of 1
            keeps existing behavior.
        dilation: spacing, in frames, between spliced frames.
        batch_norm: apply ``BatchNorm1d`` over the output features when True.
        dropout_p: dropout probability; any falsy value disables dropout.

    ``forward`` returns ``(batch, new_t, output_dim)`` with
    ``new_t = (time - dilation * (context_size - 1) - 1) // stride + 1``.
    """

    def __init__(
        self,
        input_dim=23,
        output_dim=512,
        context_size=5,
        stride=1,
        dilation=1,
        batch_norm=False,
        dropout_p=0.2,
    ):
        super(TDNN, self).__init__()
        self.context_size = context_size
        self.stride = stride
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dilation = dilation
        self.dropout_p = dropout_p
        self.batch_norm = batch_norm

        self.kernel = nn.Linear(input_dim * context_size, output_dim)
        self.nonlinearity = nn.ReLU()
        if self.batch_norm:
            self.bn = nn.BatchNorm1d(output_dim)
        if self.dropout_p:
            self.drop = nn.Dropout(p=self.dropout_p)

    def forward(self, x):
        _, _, d = x.shape
        assert (d == self.input_dim), 'Input dimension was wrong. Expected ({}), got ({})'.format(self.input_dim, d)
        # unfold needs 4-D (N, C, H, W) input: time becomes H, features become W.
        x = x.unsqueeze(1)

        # The kernel spans the whole feature axis, so only the time axis slides.
        x = F.unfold(
            x,
            (self.context_size, self.input_dim),
            stride=(self.stride, self.input_dim),
            dilation=(self.dilation, 1),
        )  # (batch, input_dim * context_size, new_t)

        x = x.transpose(1, 2)
        x = self.kernel(x.float())
        x = self.nonlinearity(x)

        if self.dropout_p:
            x = self.drop(x)

        if self.batch_norm:
            # BatchNorm1d normalises dim 1, so move features there and back.
            x = x.transpose(1, 2)
            x = self.bn(x)
            x = x.transpose(1, 2)

        return x
import torch.nn as nn
# from models.tdnn import TDNN
import torch
import torch.nn.functional as F

class X_vector(nn.Module):
    """x-vector style utterance classifier.

    Five TDNN frame-level layers are followed by statistics pooling (mean and
    standard deviation over time), two segment-level Linear layers and a Linear
    output. ``forward`` takes ``(batch, time, input_dim)`` features and returns
    unnormalised logits of shape ``(batch, num_classes)``.
    """

    def __init__(self, input_dim=20, num_classes=7):
        super(X_vector, self).__init__()
        # Frame-level layers. Construction order is unchanged because it fixes
        # parameter registration order and the RNG draws used for initialisation.
        self.tdnn1 = TDNN(input_dim=input_dim, output_dim=512, context_size=5, dilation=1, dropout_p=0.5)
        self.tdnn2 = TDNN(input_dim=512, output_dim=512, context_size=3, dilation=1, dropout_p=0.5)
        self.tdnn3 = TDNN(input_dim=512, output_dim=512, context_size=2, dilation=2, dropout_p=0.5)
        self.tdnn4 = TDNN(input_dim=512, output_dim=512, context_size=1, dilation=1, dropout_p=0.5)
        self.tdnn5 = TDNN(input_dim=512, output_dim=512, context_size=1, dilation=3, dropout_p=0.5)
        # Segment-level layers; 1024 = 512-d mean concatenated with 512-d std.
        self.segment6 = nn.Linear(1024, 512)
        self.segment7 = nn.Linear(512, 512)
        self.output = nn.Linear(512, num_classes)

    def forward(self, inputs):
        hidden = inputs
        for frame_layer in (self.tdnn1, self.tdnn2, self.tdnn3, self.tdnn4, self.tdnn5):
            hidden = frame_layer(hidden)
        # Statistics pooling over the time axis (dim 1).
        pooled = torch.cat([hidden.mean(dim=1), hidden.std(dim=1)], dim=1)
        # NOTE(review): no activation between segment6 and segment7, so the two
        # layers compose to a single affine map — confirm this is intended.
        embedding = self.segment7(self.segment6(pooled))
        return self.output(embedding)

In [20]:
# Smoke test: a random batch of 4 windows x 100 frames x 20 features (X_vector's
# default input_dim) must map to 4 x 7 emotion logits. `model` built here is the
# instance trained below.
input_feats = [4,100,20]
dummy_input = torch.rand(input_feats)  # not `input`, which would shadow the builtin
model = X_vector()
out = model(dummy_input)
assert out.shape == (4, 7), out.shape

In [21]:
# Logits from the smoke-test forward pass: one row per input window, one column
# per emotion class.
out

tensor([[ 0.0170, -0.0404, -0.0174, -0.0133,  0.0375,  0.0010,  0.0251],
        [ 0.0170, -0.0417, -0.0166, -0.0157,  0.0377, -0.0012,  0.0280],
        [ 0.0174, -0.0404, -0.0149, -0.0133,  0.0371, -0.0003,  0.0263],
        [ 0.0173, -0.0411, -0.0154, -0.0155,  0.0387, -0.0016,  0.0273]],
       grad_fn=<AddmmBackward>)

In [22]:
# Re-initialise every sub-layer that can reset itself.
# `children()` only yields direct sub-modules. The TDNN wrappers define no
# `reset_parameters`, so their inner `kernel` Linears (and BatchNorm, if
# enabled) were never reset. `modules()` recurses into them.
for layer in model.modules():
    if hasattr(layer, 'reset_parameters'):
        layer.reset_parameters()

In [26]:
import torch.optim as optim
# Adam over the X-vector model's parameters, with the loss used for emotion classes.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.98),
    eps=1e-9,
    weight_decay=0.0,
)
criterion = nn.CrossEntropyLoss()

In [50]:
def train(train_dataloader, epoch, net=None, opt=None, loss_fn=None):
    """Run one training epoch and print the training accuracy.

    Previously a leftover debugging `return` (a sum over all parameters) sat
    before the accuracy print, making it unreachable. That print is live again.

    Args:
        train_dataloader: iterable of ``(inputs, one_hot_labels)`` batches.
        epoch: epoch index, used only in the printed message.
        net: model to train; defaults to the global ``model``.
        opt: optimizer; defaults to the global ``optimizer``.
        loss_fn: loss on (logits, class indices); defaults to the global ``criterion``.

    Returns:
        ``(mean_loss, accuracy)``: mean batch loss (``nan`` for an empty loader)
        and accuracy in percent (``0.0`` for an empty loader).
    """
    net = model if net is None else net
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    net.train()
    loss_sum = 0.0
    n_batches = 0
    correct = 0
    total = 0
    for inputs, one_hot in train_dataloader:
        # CrossEntropyLoss expects class indices, not one-hot vectors.
        labels = torch.argmax(one_hot, dim=1)
        opt.zero_grad()
        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()
        # Running sum instead of re-averaging a growing list every step.
        loss_sum += loss.item()
        n_batches += 1
        total += labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
    mean_loss = loss_sum / n_batches if n_batches else float('nan')
    accuracy = 100 * correct / total if total else 0.0
    print("train accuracy after epoch " + str(epoch) + " is " + str(accuracy))
    return mean_loss, accuracy
            
def validation(valid_dataloader, epoch, net=None, opt=None, loss_fn=None, checkpoint_dir=''):
    """Evaluate on the validation loader, print accuracy and save a checkpoint.

    Args:
        valid_dataloader: iterable of ``(inputs, one_hot_labels)`` batches.
        epoch: epoch index, used in the message and checkpoint file name.
        net: model to evaluate; defaults to the global ``model``.
        opt: optimizer whose state is checkpointed; defaults to the global ``optimizer``.
        loss_fn: loss on (logits, class indices); defaults to the global ``criterion``.
        checkpoint_dir: directory for the checkpoint file; ``''`` (default)
            writes to the current directory as before.

    Returns:
        ``(mean_loss, accuracy)``: mean batch loss (``nan`` for an empty loader)
        and accuracy in percent (``0.0`` instead of ZeroDivisionError when empty).
    """
    net = model if net is None else net
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    net.eval()
    loss_sum = 0.0
    n_batches = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, one_hot in valid_dataloader:
            labels = torch.argmax(one_hot, dim=1)
            outputs = net(inputs)
            loss_sum += loss_fn(outputs, labels).item()
            n_batches += 1
            total += labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    accuracy = 100 * correct / total if total else 0.0
    print("valid accuracy after epoch " + str(epoch) + " is " + str(accuracy))
    mean_loss = loss_sum / n_batches if n_batches else float('nan')
    # NOTE: despite the "best" in the name, a checkpoint is written every epoch.
    model_save_path = os.path.join(checkpoint_dir, 'best_check_point_' + str(epoch) + '_' + str(mean_loss))
    state_dict = {'model': net.state_dict(), 'optimizer': opt.state_dict(), 'epoch': epoch}
    torch.save(state_dict, model_save_path)
    return mean_loss, accuracy


In [51]:
# Train for a fixed number of epochs; each epoch ends with validation (which
# also writes a checkpoint).
NUM_EPOCHS = 100
for epoch in range(NUM_EPOCHS):
    train(train_dataloader, epoch)
    validation(val_dataloader, epoch)

train accuracy after epoch 0 is 55.35390199637023
valid accuracy after epoch 0 is 33.05322128851541
train accuracy after epoch 1 is 56.152450090744104
valid accuracy after epoch 1 is 24.089635854341736
train accuracy after epoch 2 is 56.91470054446461
valid accuracy after epoch 2 is 32.212885154061624
train accuracy after epoch 3 is 56.8421052631579
valid accuracy after epoch 3 is 22.689075630252102
train accuracy after epoch 4 is 57.168784029038115
valid accuracy after epoch 4 is 22.689075630252102
train accuracy after epoch 5 is 57.53176043557169
valid accuracy after epoch 5 is 21.568627450980394
train accuracy after epoch 6 is 57.313974591651544
valid accuracy after epoch 6 is 21.288515406162464
train accuracy after epoch 7 is 58.87477313974592
valid accuracy after epoch 7 is 36.97478991596638
train accuracy after epoch 8 is 56.95099818511797
valid accuracy after epoch 8 is 31.092436974789916
train accuracy after epoch 9 is 55.82577132486389
valid accuracy after epoch 9 is 26.050420

train accuracy after epoch 80 is 59.201451905626136
valid accuracy after epoch 80 is 21.568627450980394
train accuracy after epoch 81 is 56.91470054446461
valid accuracy after epoch 81 is 23.249299719887954
train accuracy after epoch 82 is 58.69328493647913
valid accuracy after epoch 82 is 21.568627450980394
train accuracy after epoch 83 is 59.382940108892925
valid accuracy after epoch 83 is 23.80952380952381
train accuracy after epoch 84 is 57.56805807622504
valid accuracy after epoch 84 is 21.568627450980394
train accuracy after epoch 85 is 57.85843920145191
valid accuracy after epoch 85 is 23.529411764705884
train accuracy after epoch 86 is 57.56805807622504
valid accuracy after epoch 86 is 22.969187675070028
train accuracy after epoch 87 is 58.87477313974592
valid accuracy after epoch 87 is 23.529411764705884
train accuracy after epoch 88 is 58.22141560798548
valid accuracy after epoch 88 is 20.168067226890756
train accuracy after epoch 89 is 56.91470054446461
valid accuracy after 

In [52]:
# Evaluate the X-vector model on the test speaker, one full utterance per batch.
# Prints (true label, predicted label) per utterance, then the hit count and
# the accuracy in percent.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch_inputs, one_hot in testloader:
        labels = one_hot.argmax(dim=1)
        predicted = model(batch_inputs).max(dim=1).indices
        print(labels, predicted)
        total += labels.shape[0]
        correct += int((predicted == labels).sum())
print(correct, total)
print(100 * correct / total)

tensor([3]) tensor([5])
tensor([6]) tensor([5])
tensor([0]) tensor([2])
tensor([3]) tensor([5])
tensor([4]) tensor([2])
tensor([1]) tensor([5])
tensor([6]) tensor([5])
tensor([0]) tensor([2])
tensor([4]) tensor([2])
tensor([6]) tensor([5])
tensor([0]) tensor([5])
tensor([0]) tensor([4])
tensor([3]) tensor([2])
tensor([1]) tensor([5])
tensor([5]) tensor([5])
tensor([0]) tensor([2])
tensor([3]) tensor([5])
tensor([3]) tensor([5])
tensor([1]) tensor([5])
tensor([5]) tensor([5])
tensor([0]) tensor([2])
tensor([3]) tensor([2])
tensor([2]) tensor([5])
tensor([4]) tensor([2])
tensor([1]) tensor([5])
tensor([3]) tensor([5])
tensor([1]) tensor([5])
tensor([6]) tensor([5])
tensor([0]) tensor([4])
tensor([1]) tensor([5])
tensor([5]) tensor([5])
tensor([0]) tensor([4])
tensor([3]) tensor([5])
tensor([1]) tensor([5])
tensor([0]) tensor([3])
tensor([4]) tensor([3])
tensor([1]) tensor([5])
tensor([0]) tensor([4])
3 38
7.894736842105263


In [55]:
class MTL_Dataset(Dataset):
    """Chunked mel-spectrogram training set with emotion and gender labels.

    For each training speaker (indices 0..9 of ``spk_id`` minus ``test_key`` and
    ``valid_key``), every matching file is loaded and converted to a 24-band
    mel spectrogram. The spectrogram is transposed to ``(frames, 24)`` and cut
    into 100-frame windows with a 50-frame hop (see ``chunk``). Each window's
    label is a tuple: a 7-way one-hot emotion (from character 5 of the file
    name) and a 2-way one-hot gender (from ``gender_code``).

    Args:
        dir_path: directory containing the EmoDB audio files.
        test_key: ``spk_id`` index of the speaker reserved for testing.
        valid_key: ``spk_id`` index of the speaker reserved for validation.
        transform: optional callable applied to each window in ``__getitem__``
            (previously accepted but ignored); ``None`` leaves windows unchanged.
    """

    def __init__(self, dir_path, test_key, valid_key, transform=None):
        self.dir_path = dir_path
        # Sorted for a filesystem-independent sample order.
        self.files = sorted(os.listdir(self.dir_path))
        self.test_key = test_key
        self.valid_key = valid_key
        self.transform = transform
        self.melspecs, self.Y = self.loadData(dir_path)
        print(len(self.melspecs), len(self.Y))

    def loadData(self, dir_path):
        """Return ``(windows, [(one_hot_emotion, one_hot_gender), ...])`` for training speakers."""
        held_out = (self.test_key, self.valid_key)
        train_keys = [spk_id[i] for i in range(10) if i not in held_out]
        files = sorted(os.listdir(dir_path))
        melspecs = []
        Y = []
        for key in train_keys:
            for file in files:
                if file[:2] != key:
                    continue
                # os.path.join works with or without a trailing separator on dir_path.
                r, sr = librosa.load(os.path.join(dir_path, file), res_type='kaiser_fast')
                # NOTE(review): the power mel spectrogram is used as-is (no log/dB
                # compression), unlike the MFCC pipeline — confirm this is intended.
                melspec = librosa.feature.melspectrogram(y=r, sr=sr, n_fft=512, hop_length=160, win_length=320, n_mels=24)
                windows = self.chunk(torch.Tensor(melspec))
                melspecs.extend(windows)
                for _ in windows:
                    ye = torch.zeros(7, dtype=int)
                    ye[emotion_code[file[5]]] = 1
                    yg = torch.zeros(2, dtype=int)
                    yg[gender_code[file[:2]]] = 1
                    Y.append((ye, yg))
        return melspecs, Y

    def chunk(self, melspec, size=100, hop=50):
        """Split a ``(features, frames)`` tensor into ``(size, features)`` windows.

        Windows start every ``hop`` frames. A trailing partial window is
        dropped, so an input with fewer than ``size`` frames yields ``[]``.
        """
        frames = melspec.transpose(0, 1)
        last_start = frames.size(0) - size
        return [frames[start:start + size, :] for start in range(0, last_start + 1, hop)]

    def __len__(self):
        return len(self.melspecs)

    def __getitem__(self, idx):
        window = self.melspecs[idx]
        if self.transform is not None:
            window = self.transform(window)
        return window, self.Y[idx]

In [56]:
# Multitask training windows from every speaker except spk_id[0] (test) and spk_id[1] (validation).
trmt = MTL_Dataset(dir_path="EmoDB/wav/", test_key=0, valid_key=1)

2629 2629


In [62]:
# Shuffled mini-batches of 4 multitask training windows.
trainmt_dataloader = DataLoader(dataset=trmt, batch_size=4, shuffle=True)

In [60]:
class MTLVal_Dataset(Dataset):
    """Chunked mel-spectrogram validation set (emotion + gender) for one held-out speaker.

    Files whose two-character prefix equals ``spk_id[valid_key]`` are turned
    into 24-band mel spectrograms. These are transposed to ``(frames, 24)`` and
    cut into 100-frame windows with a 50-frame hop (see ``chunk``). Each window
    is labelled with a one-hot emotion (7-way) and a one-hot gender (2-way).

    Args:
        dir_path: directory containing the EmoDB audio files.
        valid_key: ``spk_id`` index of the validation speaker.
        transform: optional callable applied to each window in ``__getitem__``
            (previously accepted but ignored); ``None`` leaves windows unchanged.
    """

    def __init__(self, dir_path, valid_key, transform=None):
        self.dir_path = dir_path
        # Sorted for a filesystem-independent sample order.
        self.files = sorted(os.listdir(self.dir_path))
        self.valid_key = valid_key
        self.transform = transform
        self.melspecs, self.Y = self.loadData(dir_path)
        print(len(self.melspecs), len(self.Y))

    def loadData(self, dir_path):
        """Return ``(windows, [(one_hot_emotion, one_hot_gender), ...])`` for the validation speaker."""
        key = spk_id[self.valid_key]
        melspecs = []
        Y = []
        for file in sorted(os.listdir(dir_path)):
            if file[:2] != key:
                continue
            # os.path.join works with or without a trailing separator on dir_path.
            r, sr = librosa.load(os.path.join(dir_path, file), res_type='kaiser_fast')
            melspec = librosa.feature.melspectrogram(y=r, sr=sr, n_fft=512, hop_length=160, win_length=320, n_mels=24)
            windows = self.chunk(torch.Tensor(melspec))
            melspecs.extend(windows)
            for _ in windows:
                ye = torch.zeros(7, dtype=int)
                ye[emotion_code[file[5]]] = 1
                yg = torch.zeros(2, dtype=int)
                yg[gender_code[file[:2]]] = 1
                Y.append((ye, yg))
        return melspecs, Y

    def chunk(self, melspec, size=100, hop=50):
        """Split a ``(features, frames)`` tensor into ``(size, features)`` windows.

        Windows start every ``hop`` frames. A trailing partial window is
        dropped, so an input with fewer than ``size`` frames yields ``[]``.
        """
        frames = melspec.transpose(0, 1)
        last_start = frames.size(0) - size
        return [frames[start:start + size, :] for start in range(0, last_start + 1, hop)]

    def __len__(self):
        return len(self.melspecs)

    def __getitem__(self, idx):
        window = self.melspecs[idx]
        if self.transform is not None:
            window = self.transform(window)
        return window, self.Y[idx]

In [61]:
# Multitask validation windows from speaker spk_id[1].
vmt = MTLVal_Dataset(dir_path="EmoDB/wav/", valid_key=1)

399 399


In [63]:
# Mini-batches of 4 multitask validation windows (shuffled, as before).
validmt_dataloader = DataLoader(dataset=vmt, batch_size=4, shuffle=True)

In [87]:
class MTLTest_Dataset(Dataset):
    """Whole-utterance mel-spectrogram test set (emotion + gender) for one held-out speaker.

    Utterances are not chunked. Each item is the full ``(frames, 24)``
    mel-spectrogram tensor of one file whose prefix equals
    ``spk_id[test_key]``, with a ``(one_hot_emotion, one_hot_gender)`` label.
    Items now are torch tensors (previously numpy arrays), matching the
    chunked datasets.

    Args:
        dir_path: directory containing the EmoDB audio files.
        test_key: ``spk_id`` index of the test speaker.
        transform: optional callable applied to each feature tensor in
            ``__getitem__`` (previously accepted but ignored).
    """

    def __init__(self, dir_path, test_key, transform=None):
        self.dir_path = dir_path
        # Sorted for a filesystem-independent, reproducible evaluation order.
        self.files = sorted(os.listdir(self.dir_path))
        self.test_key = test_key
        self.transform = transform
        self.melspecs, self.Y = self.loadData(dir_path)
        print(len(self.melspecs), len(self.Y))

    def loadData(self, dir_path):
        """Return ``(utterance_features, [(one_hot_emotion, one_hot_gender), ...])``."""
        key = spk_id[self.test_key]
        melspecs = []
        Y = []
        for file in sorted(os.listdir(dir_path)):
            if file[:2] != key:
                continue
            # os.path.join works with or without a trailing separator on dir_path.
            r, sr = librosa.load(os.path.join(dir_path, file), res_type='kaiser_fast')
            melspec = librosa.feature.melspectrogram(y=r, sr=sr, n_fft=512, hop_length=160, win_length=320, n_mels=24)
            # Full utterance as (frames, n_mels); lengths differ between files.
            melspecs.append(torch.Tensor(melspec).transpose(0, 1))
            ye = torch.zeros(7, dtype=int)
            ye[emotion_code[file[5]]] = 1
            yg = torch.zeros(2, dtype=int)
            yg[gender_code[file[:2]]] = 1
            Y.append((ye, yg))
        return melspecs, Y

    def __len__(self):
        return len(self.melspecs)

    def __getitem__(self, idx):
        features = self.melspecs[idx]
        if self.transform is not None:
            features = self.transform(features)
        return features, self.Y[idx]

In [88]:
# Whole-utterance multitask test items for speaker spk_id[0].
testmt = MTLTest_Dataset(dir_path="EmoDB/wav/", test_key=0)

49 49


In [89]:
# One utterance per batch. Evaluation needs no shuffling; shuffle=False keeps
# the per-utterance results reproducible and matches `testloader`.
testmt_dataloader = DataLoader(testmt, batch_size=1, shuffle=False)

In [71]:
class multitask(nn.Module):
    """x-vector backbone with two Linear heads: emotion and gender.

    The five TDNN frame-level layers, statistics pooling (mean and standard
    deviation over time) and two segment Linear layers feed a shared embedding.
    ``forward`` takes ``(batch, time, input_dim)`` features and returns
    ``(emotion_logits, gender_logits)`` with shapes ``(batch, em_classes)``
    and ``(batch, gen_classes)``.
    """

    def __init__(self, input_dim=24, em_classes=7, gen_classes=2):
        super(multitask, self).__init__()
        # Frame-level layers. Construction order is unchanged because it fixes
        # parameter registration order and the RNG draws used for initialisation.
        self.tdnn1 = TDNN(input_dim=input_dim, output_dim=512, context_size=5, dilation=1, dropout_p=0.5)
        self.tdnn2 = TDNN(input_dim=512, output_dim=512, context_size=3, dilation=1, dropout_p=0.5)
        self.tdnn3 = TDNN(input_dim=512, output_dim=512, context_size=2, dilation=2, dropout_p=0.5)
        self.tdnn4 = TDNN(input_dim=512, output_dim=512, context_size=1, dilation=1, dropout_p=0.5)
        self.tdnn5 = TDNN(input_dim=512, output_dim=512, context_size=1, dilation=3, dropout_p=0.5)
        # Segment-level layers; 1024 = 512-d mean concatenated with 512-d std.
        self.segment6 = nn.Linear(1024, 512)
        self.segment7 = nn.Linear(512, 512)
        self.emotion = nn.Linear(512, em_classes)
        self.gender = nn.Linear(512, gen_classes)

    def forward(self, inputs):
        hidden = inputs
        for frame_layer in (self.tdnn1, self.tdnn2, self.tdnn3, self.tdnn4, self.tdnn5):
            hidden = frame_layer(hidden)
        # Statistics pooling over the time axis (dim 1).
        pooled = torch.cat([hidden.mean(dim=1), hidden.std(dim=1)], dim=1)
        # NOTE(review): no activation between segment6 and segment7 — confirm intended.
        embedding = self.segment7(self.segment6(pooled))
        return self.emotion(embedding), self.gender(embedding)

In [131]:
# Smoke test: a random batch of 4 windows x 100 frames x 24 mel bands must map to
# 4 x 7 emotion logits and 4 x 2 gender logits. `mt_model` built here is the
# instance trained below.
input_feats = [4,100,24]
dummy_input = torch.rand(input_feats)  # not `input`, which would shadow the builtin
mt_model = multitask()
out = mt_model(dummy_input)
assert out[0].shape == (4, 7) and out[1].shape == (4, 2), (out[0].shape, out[1].shape)
print(out)

(tensor([[-0.0098, -0.0171, -0.0305, -0.0194, -0.0344,  0.0426,  0.0026],
        [-0.0093, -0.0194, -0.0295, -0.0188, -0.0358,  0.0425,  0.0020],
        [-0.0084, -0.0177, -0.0309, -0.0175, -0.0381,  0.0416,  0.0036],
        [-0.0101, -0.0168, -0.0309, -0.0189, -0.0360,  0.0429,  0.0039]],
       grad_fn=<AddmmBackward>), tensor([[0.0078, 0.0154],
        [0.0082, 0.0159],
        [0.0083, 0.0154],
        [0.0078, 0.0159]], grad_fn=<AddmmBackward>))


In [83]:
def get_mt_optimizer(net):
    """Return the Adam optimizer for the multitask model ``net``, creating it on demand.

    The cell-level ``optimizer`` is built from ``model.parameters()`` (the
    single-task X-vector), so stepping it never updated ``mt_model``. This keeps
    a separate optimizer in the global ``mt_optimizer``. The optimizer is
    rebuilt only when it is not bound to ``net``'s parameters (e.g. after
    ``mt_model`` is re-created), so Adam's state persists across epochs.
    """
    global mt_optimizer
    current = globals().get('mt_optimizer')
    first_param = next(net.parameters())
    bound = current is not None and any(p is first_param for p in current.param_groups[0]['params'])
    if not bound:
        # Same hyper-parameters as the single-task optimizer.
        current = optim.Adam(net.parameters(), lr=0.001, weight_decay=0.0, betas=(0.9, 0.98), eps=1e-9)
        mt_optimizer = current
    return current


def trainmt(train_dataloader, epoch, net=None, opt=None, loss_fn=None):
    """Run one multitask training epoch on the summed emotion + gender loss.

    Args:
        train_dataloader: iterable of ``(inputs, [one_hot_emotion, one_hot_gender])``
            batches (the collated form of ``MTL_Dataset`` items).
        epoch: epoch index, used only in progress messages.
        net: model to train; defaults to the global ``mt_model``.
        opt: optimizer; defaults to ``get_mt_optimizer(net)``.
        loss_fn: per-head loss on (logits, class indices); defaults to ``criterion``.

    Returns:
        Mean summed batch loss, or ``nan`` for an empty loader.
    """
    net = mt_model if net is None else net
    opt = get_mt_optimizer(net) if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    net.train()  # previously `model.train()`: the wrong model
    loss_sum = 0.0
    n_batches = 0
    for i, (inputs, labels) in enumerate(train_dataloader):
        # One-hot -> class indices, into new names rather than overwriting the
        # collated label list in place.
        em_labels = torch.argmax(labels[0], dim=1)
        gen_labels = torch.argmax(labels[1], dim=1)
        opt.zero_grad()
        em_out, gen_out = net(inputs)
        loss = loss_fn(em_out, em_labels) + loss_fn(gen_out, gen_labels)
        loss.backward()
        opt.step()
        loss_sum += loss.item()
        n_batches += 1
        if i % 100 == 0:
            print('Iteration - {} Epoch - {} Total training loss - {} '.format(i, epoch, loss_sum / n_batches))
    return loss_sum / n_batches if n_batches else float('nan')


def validationmt(valid_dataloader, epoch, net=None, opt=None, loss_fn=None, checkpoint_dir=''):
    """Evaluate the multitask model, print loss/accuracies and save a checkpoint.

    Args:
        valid_dataloader: iterable of ``(inputs, [one_hot_emotion, one_hot_gender])`` batches.
        epoch: epoch index, used in messages and the checkpoint file name.
        net: model to evaluate; defaults to the global ``mt_model``.
        opt: optimizer whose state is checkpointed; defaults to ``get_mt_optimizer(net)``.
        loss_fn: per-head loss; defaults to the global ``criterion``.
        checkpoint_dir: directory for the checkpoint; ``''`` means the current directory.

    Returns:
        ``(mean_loss, emotion_accuracy, gender_accuracy)``; the loss is ``nan``
        and the accuracies (percent) are ``0.0`` for an empty loader.
    """
    net = mt_model if net is None else net
    opt = get_mt_optimizer(net) if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    net.eval()  # previously `model.eval()`, which left mt_model's dropout active
    loss_sum = 0.0
    n_batches = 0
    em_correct = 0
    gen_correct = 0
    total = 0
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(valid_dataloader):
            em_labels = torch.argmax(labels[0], dim=1)
            gen_labels = torch.argmax(labels[1], dim=1)
            em_out, gen_out = net(inputs)
            loss = loss_fn(em_out, em_labels) + loss_fn(gen_out, gen_labels)
            loss_sum += loss.item()
            n_batches += 1
            total += em_labels.size(0)
            em_correct += (em_out.argmax(dim=1) == em_labels).sum().item()
            gen_correct += (gen_out.argmax(dim=1) == gen_labels).sum().item()
            if i % 100 == 0:
                print('Iteration - {} Epoch - {} Loss - {}'.format(i, epoch, loss_sum / n_batches))

    mean_loss = loss_sum / n_batches if n_batches else float('nan')
    em_acc = 100 * em_correct / total if total else 0.0
    gen_acc = 100 * gen_correct / total if total else 0.0
    print('Total validation loss {} after {} epochs'.format(mean_loss, epoch))
    print('emotion accuracy {} gender accuracy {} after {} epochs'.format(em_acc, gen_acc, epoch))
    # Saved every epoch. The "mt_" prefix keeps these multitask checkpoints
    # (previously saving the wrong model) apart from the X-vector ones.
    model_save_path = os.path.join(checkpoint_dir, 'mt_best_check_point_' + str(epoch) + '_' + str(mean_loss))
    state_dict = {'model': net.state_dict(), 'optimizer': opt.state_dict(), 'epoch': epoch}
    torch.save(state_dict, model_save_path)
    return mean_loss, em_acc, gen_acc


In [None]:
# Multitask training: each epoch is followed by validation (which also writes
# a checkpoint).
NUM_MT_EPOCHS = 100
for epoch in range(NUM_MT_EPOCHS):
    trainmt(trainmt_dataloader, epoch)
    validationmt(validmt_dataloader, epoch)

In [None]:
# Evaluate the multitask model on the test speaker: count emotion and gender
# hits over the test utterances and report both accuracies.
# Previously this called `model.eval()` (the single-task model), leaving
# mt_model's dropout active, and the accuracy computation was commented out.
mt_model.eval()
em_correct = 0
gen_correct = 0
total = 0
with torch.no_grad():
    for inputs, (em_one_hot, gen_one_hot) in testmt_dataloader:
        em_labels = torch.argmax(em_one_hot, dim=1)
        gen_labels = torch.argmax(gen_one_hot, dim=1)
        em_out, gen_out = mt_model(inputs)
        em_correct += (em_out.argmax(dim=1) == em_labels).sum().item()
        gen_correct += (gen_out.argmax(dim=1) == gen_labels).sum().item()
        total += em_labels.size(0)
print(em_correct, gen_correct, total)
if total:
    print('emotion accuracy', 100 * em_correct / total)
    print('gender accuracy', 100 * gen_correct / total)