In [1]:
from torch.utils.data import Dataset, DataLoader
import torch

In [2]:
import pandas as pd

In [3]:
from transformers import AutoTokenizer



In [4]:
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [5]:
class DADataset(Dataset):
    
    __label_dict = dict()
    
    def __init__(self, tokenizer, data, text_field = "clean_text", label_field="act_label_1", max_len=20):
        
        self.text = list(data[text_field]) #data['train'][text_field]
        self.acts = list(data[label_field]) #['train'][label_field]
        self.tokenizer = tokenizer
        self.max_len = max_len
        
        
        # build/update the label dictionary 
        classes = sorted(set([item for row in self.acts for item in row]))
        
        for cls in classes:
            if cls not in DADataset.__label_dict.keys():
                DADataset.__label_dict[cls]=len(DADataset.__label_dict.keys())
    
    def __len__(self):
        return len(self.text)
    
    def label_dict(self):
        return DADataset.__label_dict
    
    def tokenize(self, input):
        input_encoding = self.tokenizer.encode_plus(
            text=input,
            truncation=True,
            max_length=self.max_len,
            return_attention_mask=True,
            padding="max_length",
        )
        return input_encoding

    
    def __getitem__(self, index):
        
        text = self.text[index]
        act = self.acts[index]
        label = [DADataset.__label_dict[act] for act in self.acts[index]]
        output1 = []
        output2 = []
        sample = []
        for persona in text:
            personalistinp=[]
            personalistatt=[]
            for sent in persona:
                tkr = self.tokenize(sent)
                personalistinp.append(tkr['input_ids'])
                personalistatt.append(tkr['attention_mask'])
            output1.append(personalistinp)
            output2.append(personalistatt)
            
        return {
            "output1":torch.tensor(output1),
            "output2":torch.tensor(output2),
            "label":torch.tensor(label[0], dtype=torch.long),
        }


In [6]:
from config import config

In [7]:
df = pd.read_csv("swda/swda-train.csv", converters={'Text': eval,'DamslActTag': eval})

In [8]:
df.head()

Unnamed: 0,Text,DamslActTag
0,"[[Lucille Hughes.], [Okay,, Lucille, I'm on, o...","[fo_o_fw_""_by_bc, b, sd, b, %, sd, qy, ny, aa,..."
1,"[[Okay,, Lucille, I'm on, on (( )) . -], [All ...","[b, sd, b, %, sd, qy, ny, aa, sd, b]"
2,"[[Lucille, I'm on, on (( )) . -], [All right,,...","[sd, b, %, sd, qy, ny, aa, sd, b, sv]"
3,"[[All right,, # and what, # -], [# And our top...","[b, %, sd, qy, ny, aa, sd, b, sv, sd]"
4,"[[# and what, # -], [# And our top-, # okay, o...","[%, sd, qy, ny, aa, sd, b, sv, sd, sd]"


In [9]:
txt = list(df['Text'])
acts = list(df['DamslActTag'])

In [15]:
label = [train_dataset.label_dict()[act] for act in acts[0]]

NameError: name 'train_dataset' is not defined

In [16]:
torch.tensor([label], dtype=torch.long)

NameError: name 'label' is not defined

In [17]:
tokenizer = AutoTokenizer.from_pretrained('roberta-base')

In [18]:
def tokenize(input):
        input_encoding = tokenizer.encode_plus(
            text=input,
            truncation=True,
            max_length=20,
            return_tensors="pt",
            return_attention_mask=True,
            padding="max_length",
        )
        return input_encoding

In [19]:
class DADataset(Dataset):
    
    __label_dict = dict()
    
    def __init__(self, tokenizer, data, text_field = "clean_text", label_field="act_label_1", max_len=20):
        
        self.text = list(data[text_field]) #data['train'][text_field]
        self.acts = list(data[label_field]) #['train'][label_field]
        self.tokenizer = tokenizer
        self.max_len = max_len
        
        
        # build/update the label dictionary 
        classes = sorted(set([item for row in self.acts for item in row]))
        
        for cls in classes:
            if cls not in DADataset.__label_dict.keys():
                DADataset.__label_dict[cls]=len(DADataset.__label_dict.keys())
    
    def __len__(self):
        return len(self.text)
    
    def label_dict(self):
        return DADataset.__label_dict
    
    def tokenize(self, input):
        input_encoding = self.tokenizer.encode_plus(
            text=input,
            truncation=True,
            max_length=self.max_len,
            return_attention_mask=True,
            padding="max_length",
        )
        return input_encoding

    
    def __getitem__(self, index):
        
        text = self.text[index]
        act = self.acts[index]
        label = [DADataset.__label_dict[act] for act in self.acts[index]]
        inputid = []
        attention = []
        sequencemask_fr = []
        sequencemask_bk = []
        for persona in text:
            start = True
            for sent in persona:
                tkr = self.tokenize(sent)
                inputid.append(tkr['input_ids'])
                attention.append(tkr['attention_mask'])
                sequencemask_bk.append(1)
                if start:
                    sequencemask_fr.append(0)
                    start = False
                else:
                    sequencemask_fr.append(1)
            sequencemask_bk[len(sequencemask_bk)-1]=0

        return {
            "sequencemask_fr": torch.tensor(sequencemask_fr),
            "sequencemask_bk": torch.tensor(sequencemask_bk),
            "inputid":torch.tensor(inputid),
            "attention":torch.tensor(attention),
            "label":torch.tensor(label, dtype=torch.long),
        }


In [20]:
train_dataset = DADataset(tokenizer=tokenizer, data=df, max_len=20, text_field='Text', label_field='DamslActTag')
drop_last = True if len(train_dataset.text) % config['batch_size'] == 1 else False  
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=False, num_workers=config['num_workers'], drop_last=drop_last)

In [21]:
train_dataset.__getitem__(0)

{'sequencemask_fr': tensor([0, 0, 1, 0, 1, 0, 1, 0, 0, 1]),
 'sequencemask_bk': tensor([0, 1, 0, 1, 0, 1, 0, 0, 1, 0]),
 'inputid': tensor([[    0, 20793,  4061,  7799,     4,     2,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
         [    0, 33082,     6,     2,     1,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
         [    0, 20793,  4061,     6,    38,   437,    15,     6,    15, 41006,
          49087,   479,   111,     2,     1,     1,     1,     1,     1,     1],
         [    0,  3684,   235,     6,     2,     1,     1,     1,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
         [    0, 10431,     8,    99,     6,   849,   111,     2,     1,     1,
              1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
         [    0, 10431,   178,    84,   299, 20551,   849,  8578

In [22]:
x = torch.arange(4)
y = torch.ones((10))*2
x

tensor([0, 1, 2, 3])

In [23]:
x.unsqueeze(0).unsqueeze(0).squeeze(0)

tensor([[0, 1, 2, 3]])

In [24]:
for a,b in zip(x,y):
    print(a.size())
    #print(b.size())
    #print(a*b)
    

torch.Size([])
torch.Size([])
torch.Size([])
torch.Size([])


In [25]:
x.flip([0,1])

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [26]:
batch_sequences = torch.empty((0, 10, 20))
batch_sequences

tensor([], size=(0, 10, 20))

In [27]:
count = 0
for inp in train_loader:
    count=count+1
    batch_sequences = torch.empty((0, 10, 20))
    for i,x in enumerate(inp['inputid']):
        batch_sequences = torch.cat((batch_sequences, x.unsqueeze(0)), dim=0)
    #for x,y in zip(inp['inputid'],inp['attention']):
    #    print(y)
    print(batch_sequences.size())
    print(batch_sequences)
    if count>0:
        break

torch.Size([64, 10, 20])
tensor([[[0.0000e+00, 2.0793e+04, 4.0610e+03,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         [0.0000e+00, 3.3082e+04, 6.0000e+00,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         [0.0000e+00, 2.0793e+04, 4.0610e+03,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         ...,
         [0.0000e+00, 9.9040e+03, 4.0000e+00,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         [0.0000e+00, 3.6840e+03, 2.3500e+02,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         [0.0000e+00, 2.4167e+04, 4.0000e+00,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00]],

        [[0.0000e+00, 3.3082e+04, 6.0000e+00,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         [0.0000e+00, 2.0793e+04, 4.0610e+03,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         [0.0000e+00, 3.6840e+03, 2.3500e+02,  ..., 1.0000e+00,
          1.0000e+00, 1.0000e+00],
         ...,
         [0.0000e+00, 3.6840e+03, 2.3500e+02,  ..., 1.

In [28]:
import torch.nn as nn


In [29]:
from transformers import AutoModel

In [30]:
class UtteranceGRU(nn.Module):
    
    def __init__(self, model_name="roberta-base", hidden_size=256, bidirectional=True, num_layers=1):
        super(UtteranceGRU, self).__init__()
        
        
        # embedding layer is replaced by pretrained roberta's embedding
        self.base = AutoModel.from_pretrained(pretrained_model_name_or_path=model_name)
        # freeze the model parameters
        for param in self.base.parameters():
            param.requires_grad = False
        
        #self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size)
        self.gru = nn.GRU(
            input_size=768, 
            hidden_size=hidden_size, 
            num_layers=num_layers, 
            bidirectional=bidirectional,
            batch_first=True
        )
    
    def forward(self, input_ids, attention_mask):
        """
            input_ids.shape = [seq_sentences, seq_len]
            attention_mask.shape = [seq_sentences, seq_len]
        """
        
    
        hidden_states, _ = self.base(input_ids, attention_mask) # hidden_states.shape = [batch, max_len, hidden_size]
                
        outputs, hidden = self.gru(hidden_states)
                
        return torch.cat((hidden[-2,:,:],hidden[-1,:,:]), dim=1)

In [153]:
class PersonaGRU(nn.Module):
    
    def __init__(self, model_name="roberta-base", hidden_size=256, num_layers=2, num_classes=43, device=torch.device("cpu")):
        
        super(PersonaGRU, self).__init__()
        
        self.in_features = 2*hidden_size
        
        self.device = device
        
        # utterance encoder model
        self.utterance_rnn = UtteranceGRU(model_name=model_name, hidden_size=hidden_size)
        
        self.gru = nn.GRU(
            input_size=512,
            hidden_size=256, 
            num_layers=1, 
            bidirectional=True,
            batch_first=True
        )
        
        # initial hidden_states
        self.hx = torch.randn((2, 1, hidden_size), device=self.device)
        
    
    def forward(self, batch_input):
        """
            x.shape = [batch, seq_len, hidden_size]
        """
        
        batch_inputid = batch_input['inputid']
        batch_attention = batch_input['attention']
        batch_sequencemask_fr = batch_input['sequencemask_fr']
        batch_sequencemask_bk = batch_input['sequencemask_bk']
        
        batch_sequences = torch.empty((0, 10, self.in_features), device=self.device)
        for inputid, attention in zip(batch_inputid, batch_attention):
            seq_of_utterance = self.utterance_rnn(input_ids=inputid, attention_mask=attention)
            batch_sequences = torch.cat((batch_sequences, seq_of_utterance.unsqueeze(0)), dim=0)
        
        
        # create an empty feature vector 
        batch_personafeatures = torch.empty((0, 10, 1024), device=self.device)      
        
        for seq, fr_mask, bk_mask in zip(batch_sequences, batch_sequencemask_fr, batch_sequencemask_bk): #seq
            personafeatures_forward = torch.empty((0, self.in_features), device=self.device)
            personafeatures_backword = torch.empty((0, self.in_features), device=self.device)
            seq_tmp = seq.unsqueeze(0)
            seq_rev = seq_tmp.flip([0,1])
            backseq = seq_rev.squeeze()
            hidden = self.hx
            for sent, fr_hid_mask in zip(seq, fr_mask):
                sent = sent.unsqueeze(0)
                sent = sent.unsqueeze(0)
                outputs, hidden = self.gru(sent, torch.where(fr_hid_mask==1, hidden, self.hx))
                personafeatures_forward = torch.cat((personafeatures_forward, outputs.squeeze(0)), dim=0)
            
            hidden = self.hx
            for sent, bk_hid_mask in zip(backseq, bk_mask):
                sent = sent.unsqueeze(0)
                sent = sent.unsqueeze(0)
                outputs, hidden = self.gru(sent, torch.where(bk_hid_mask==1, hidden, self.hx))
                personafeatures_backword = torch.cat((personafeatures_backword, outputs.squeeze(0)), dim=0)
                
            personafeatures_backword = personafeatures_backword.unsqueeze(0).flip([0,1])
            personafeatures = torch.cat((personafeatures_forward, personafeatures_backword.squeeze(0)), dim=1)
            personafeatures.size()
            batch_personafeatures = torch.cat((batch_personafeatures, personafeatures.unsqueeze(0)), dim=0)
            
        return batch_personafeatures

In [154]:
persona = PersonaGRU()

In [155]:
out=[]
for inp in train_loader:
    out = persona(inp)
    break

In [156]:
out.size()

torch.Size([64, 10, 1024])

In [35]:
class Encoder(nn.Module):
    
    def __init__(self, model_name="roberta-base", hidden_size=256, num_layers=2, num_classes=43, device=torch.device("cpu")):
        
        super(Encoder, self).__init__()
        
        self.in_features = 2*hidden_size
        
        self.device = device
                
        # conversaton level rnn
        self.persona = PersonaGRU()

        self.gru = nn.GRU(
            input_size=1024,
            hidden_size=512, 
            num_layers=1, 
            bidirectional=True,
            batch_first=True
        )
        
        # initial hidden_states
        self.hx = torch.randn((2, 1, hidden_size), device=self.device)
        
    
    def forward(self, batch):
        """
            x.shape = [batch, seq_len, hidden_size]
        """
        

        # create an empty feature vector 
        #features = torch.empty((0, self.in_features), device=self.device)
        output = self.persona(batch)
            
        encoded_hidden, hi = self.gru(output)
        
        return encoded_hidden, hi
          

In [36]:
encoder = Encoder()

In [37]:
out=[]
for inp in train_loader:
    out = encoder(inp)
    break

In [38]:
out[0].size()

torch.Size([64, 10, 1024])

In [149]:
class GuidedAttentionDAC(nn.Module):
    
    def __init__(self, model_name="roberta-base", hidden_size=256, num_classes=43, device=torch.device("cpu")):
        
        super(GuidedAttentionDAC, self).__init__()
        
        self.in_features = 2*hidden_size
        
        self.device = device
        
        # utterance encoder model
        self.encoder = Encoder()
        self.decoder = Decoder()
        
    
    def forward(self, batch):
        """
            x.shape = [batch, seq_len, hidden_size]
        """
        
        outputs, hidden = self.encoder(batch)
        hidden = torch.cat((hidden[-2,:,:],hidden[-1,:,:]), dim=1)
        logits = self.decoder(hidden,outputs)
        
        return logits

In [150]:
model = GuidedAttentionDAC()
out=[]
for inp in train_loader:
    out = model(inp)
    break


In [151]:
out.size()

torch.Size([64, 10])

In [152]:
out

tensor([[28., 42., 27., 42., 27., 42., 28., 27., 28., 42.],
        [42., 28., 42., 28., 42., 27., 28., 42., 28., 28.],
        [42., 28., 42., 28., 42., 27., 28., 42., 28., 28.],
        [28., 42., 27., 27., 42., 28., 27., 42., 28., 28.],
        [42., 28., 42., 27., 28., 42., 28., 42., 27., 28.],
        [28., 42., 27., 42., 27., 28., 42., 27., 28., 28.],
        [42., 28., 42., 28., 42., 27., 28., 42., 27., 28.],
        [42., 28., 27., 42., 27., 27., 42., 28., 28., 42.],
        [28., 42., 27., 42., 27., 28., 27., 42., 28., 28.],
        [42., 28., 27., 27., 27., 42., 28., 27., 28., 42.],
        [42., 28., 42., 27., 28., 42., 28., 42., 27., 28.],
        [31., 24., 42., 27., 28., 42., 27., 28., 42., 28.],
        [28., 27., 27., 42., 27., 28., 42., 28., 42., 28.],
        [28., 42., 27., 27., 42., 27., 42., 28., 28., 42.],
        [42., 28., 42., 27., 28., 42., 28., 42., 27., 28.],
        [42., 28., 42., 27., 28., 42., 27., 28., 42., 28.],
        [42., 28., 42., 28., 42., 27., 2

In [141]:
class Decoder(nn.Module):
    
    def __init__(self, num_classes=43, output_size=43, num_layers=2, device=torch.device("cpu")):
        
        super(Decoder, self).__init__()
        self.device = device
        
        self.sostoken = torch.tensor([[43]]*64, dtype=torch.long)
        
        self.embedding = nn.Embedding(44, 128)
        
        
        self.gru = nn.GRU(
            input_size=128,
            hidden_size=1024, 
            num_layers=1, 
            batch_first=True
        )

        # classifier on top of feature extractor
        self.classifier = nn.Sequential(*[
            nn.Linear(in_features=2048, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.LeakyReLU(),
            nn.Linear(in_features=128, out_features=num_classes)
        ])
        
        # initial hidden_states
        self.hx = torch.randn((1, 1, 1024), device=self.device)
        
        
    
    def forward(self, final_hidden, hidden_seq):
        """
            x.shape = [batch, seq_len, hidden_size]
        """
        input = self.hx
        
        inputtoken = self.sostoken
        logitseq = torch.empty((64, 0), device=self.device)
        final_hidden = final_hidden.unsqueeze(0)
        for i, x in enumerate([i for i in range(10)]):
            embedded = self.embedding(inputtoken)
            #embedded = self.dropout(embedded)
            finaloutput, finalhidden = self.gru(embedded, final_hidden) 
            ctx = torch.cat((finaloutput[:,0,:], hidden_seq[:,i,:]), dim=1)
            output = self.classifier(ctx)
            logits = torch.argmax(output, dim=1)
            logits = logits.unsqueeze(1)
            #print(logitseq)
            #print(logits)
            logitseq = torch.cat((logitseq,logits), dim=1)
            #print(logitseq)
            inputtoken = logits
            final_hidden = finalhidden
        
        return logitseq
          

In [142]:
dec = Decoder()

In [143]:
logitseqout = dec(out[1], out[0])

In [144]:
logitseqout

tensor([[13., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 13., 13., 13., 13.],
        [12., 13., 13., 13., 13., 13., 1

In [82]:
x = torch.tensor([[[1,2,3],[4,5,6]]])

In [86]:
x[:,0,:]

tensor([[1, 2, 3]])

In [90]:
x = torch.tensor([1,2,3,4,5,6])

In [94]:
x.unsqueeze(1).size()

torch.Size([6, 1])

In [88]:
class LightningModel(pl.LightningModule):
    
    def __init__(self, config):
        super(LightningModel, self).__init__()
        
        self.config = config
        
        self.model = GuidedAttentionDAC(
            model_name=self.config['model_name'],
            hidden_size=self.config['hidden_size'],
            num_classes=self.config['num_classes'],
            device=self.config['device']
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
        
    def forward(self, batch):
        logits  = self.model(batch)
        return logits
    
    def configure_optimizers(self):
        return optim.Adam(params=self.parameters(), lr=self.config['lr'])
    
    def train_dataloader(self):
        train_data = pd.read_csv(os.path.join(self.config['data_dir'], self.config['dataset'], self.config['dataset']+"_test.csv"), converters={'Text': eval,'DamslActTag': eval})
        train_dataset = DADataset(tokenizer=self.tokenizer, data=train_data, max_len=self.config['max_len'], text_field=self.config['text_field'], label_field=self.config['label_field'])
        drop_last = True if len(train_dataset.text) % self.config['batch_size'] == 1 else False  # Drop last batch if it cointains a single sample (causes error)
        train_loader = DataLoader(dataset=train_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=self.config['num_workers'], drop_last=drop_last)
        return train_loader
    
    def training_step(self, batch, batch_idx):
        
        input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
        logits = self(batch)
        loss = F.cross_entropy(logits, targets)
        
        acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
        f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        wandb.log({"loss":loss, "accuracy":acc, "f1_score":f1})
        return {"loss":loss, "accuracy":acc, "f1_score":f1}
    
    def val_dataloader(self):
        #valid_data = load_dataset("csv", data_files=os.path.join(self.config['data_dir'], self.config['dataset'], self.config['dataset']+"_valid.csv"))
        valid_data = pd.read_csv(os.path.join(self.config['data_dir'], self.config['dataset'], self.config['dataset']+"_validation.csv"), converters={'Text': eval,'DamslActTag': eval}) # valid has ~40k samples this is valid is same as test to run it quickely, test has ~16k samples
        valid_dataset = DADataset(tokenizer=self.tokenizer, data=valid_data, max_len=self.config['max_len'], text_field=self.config['text_field'], label_field=self.config['label_field'])
        drop_last = True if len(valid_dataset.text) % self.config['batch_size'] == 1 else False  # Drop last batch if it cointains a single sample (causes error)
        valid_loader = DataLoader(dataset=valid_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=self.config['num_workers'], drop_last=drop_last)
        return valid_loader
    
    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
        logits = self(batch)
        loss = F.cross_entropy(logits, targets)
        acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
        f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        precision = precision_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        recall = recall_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        return {"val_loss":loss, "val_accuracy":torch.tensor([acc]), "val_f1":torch.tensor([f1]), "val_precision":torch.tensor([precision]), "val_recall":torch.tensor([recall])}
    
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_accuracy'] for x in outputs]).mean()
        avg_f1 = torch.stack([x['val_f1'] for x in outputs]).mean()
        avg_precision = torch.stack([x['val_precision'] for x in outputs]).mean()
        avg_recall = torch.stack([x['val_recall'] for x in outputs]).mean()
        wandb.log({"val_loss":avg_loss, "val_accuracy":avg_acc, "val_f1":avg_f1, "val_precision":avg_precision, "val_recall":avg_recall})
        return {"val_loss":avg_loss, "val_accuracy":avg_acc, "val_f1":avg_f1, "val_precision":avg_precision, "val_recall":avg_recall}
    
    def test_dataloader(self):
        #test_data = load_dataset("csv", data_files=os.path.join(self.config['data_dir'], self.config['dataset'], self.config['dataset']+"_test.csv"))
        test_data = pd.read_csv(os.path.join(self.config['data_dir'], self.config['dataset'], self.config['dataset']+"_test.csv"), converters={'Text': eval,'DamslActTag': eval})
        test_dataset = DADataset(tokenizer=self.tokenizer, data=test_data, max_len=self.config['max_len'], text_field=self.config['text_field'], label_field=self.config['label_field'])
        drop_last = True if len(test_dataset.text) % self.config['batch_size'] == 1 else False  # Drop last batch if it cointains a single sample (causes error)
        test_loader = DataLoader(dataset=test_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=self.config['num_workers'], drop_last=drop_last)
        return test_loader
    
    def test_step(self, batch, batch_idx):
        input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
        logits = self(batch)
        loss = F.cross_entropy(logits, targets)
        acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
        f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        precision = precision_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        recall = recall_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        return {"test_loss":loss, "test_precision":torch.tensor([precision]), "test_recall":torch.tensor([recall]), "test_accuracy":torch.tensor([acc]), "test_f1":torch.tensor([f1])}
    
    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['test_accuracy'] for x in outputs]).mean()
        avg_f1 = torch.stack([x['test_f1'] for x in outputs]).mean()
        avg_precision = torch.stack([x['test_precision'] for x in outputs]).mean()
        avg_recall = torch.stack([x['test_recall'] for x in outputs]).mean()
        return {"test_loss":avg_loss, "test_precision":avg_precision, "test_recall":avg_recall, "test_acc":avg_acc, "test_f1":avg_f1}