In [22]:
import csv, time, random, numpy as np, pandas as pd
from types import SimpleNamespace
import torch
from torch import nn
from torch.optim import AdamW
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from itertools import cycle
from transformers import BertModel, BertTokenizer
import pandas as pd

In [23]:
def seed_everything(seed=10002):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer handed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Turn off cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

## Dataset classes used to build the DataLoaders

In [24]:
def preprocess_string(s):
    """Lower-case ``s``, split ``. ? , '`` off the preceding token and collapse whitespace."""
    text = s.lower()
    for mark in ('.', '?', ',', '\''):
        # Put a space in front of the mark so it becomes its own whitespace-separated token.
        text = text.replace(mark, ' ' + mark)
    tokens = text.split()
    return ' '.join(tokens)


class SentenceClassificationDataset(Dataset):
    """Single-sentence classification examples plus the batching logic for a DataLoader.

    Every example is indexed as ``(sentence, label, sentence_id)``.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize a list of examples with padding and truncation enabled.

        Returns:
            ``(token_ids, attention_mask, labels, sents, sent_ids)`` where the first
            three are LongTensors and the last two are plain Python lists.
        """
        sents, shifted_labels, sent_ids = [], [], []
        for example in data:
            sents.append(example[0])
            # Labels are shifted down by one so that the classes are 0-indexed.
            shifted_labels.append(example[1] - 1)
            sent_ids.append(example[2])

        encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])
        labels = torch.LongTensor(shifted_labels)

        return token_ids, attention_mask, labels, sents, sent_ids

    def collate_fn(self, all_data):
        """Collate function for ``DataLoader``: returns the padded batch as a dict."""
        field_names = ('token_ids', 'attention_mask', 'labels', 'sents', 'sent_ids')
        return dict(zip(field_names, self.pad_data(all_data)))

class SentencePairDataset(Dataset):
    """Sentence-pair examples plus the batching logic for a DataLoader.

    Every example is indexed as ``(sentence1, sentence2, label, sentence_id)``.
    With ``isRegression=True`` the labels are batched as a FloatTensor,
    otherwise as a LongTensor.
    """

    def __init__(self, dataset, isRegression =False):
        self.dataset = dataset
        self.isRegression = isRegression
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def _encode(self, sentences):
        """Tokenize one side of the pairs; returns (token_ids, token_type_ids, attention_mask)."""
        encoding = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True)
        return (
            torch.LongTensor(encoding['input_ids']),
            torch.LongTensor(encoding['token_type_ids']),
            torch.LongTensor(encoding['attention_mask']),
        )

    def pad_data(self, data):
        """Tokenize both sides of every pair independently and tensorize the labels."""
        first, second, raw_labels, sent_ids = [], [], [], []
        for example in data:
            first.append(example[0])
            second.append(example[1])
            raw_labels.append(example[2])
            sent_ids.append(example[3])

        # Each side is padded on its own, so the two sides may have different widths.
        token_ids, token_type_ids, attention_mask = self._encode(first)
        token_ids2, token_type_ids2, attention_mask2 = self._encode(second)

        label_tensor_type = torch.FloatTensor if self.isRegression else torch.LongTensor
        labels = label_tensor_type(raw_labels)

        return (token_ids, token_type_ids, attention_mask,
                token_ids2, token_type_ids2, attention_mask2,
                labels, sent_ids)

    def collate_fn(self, all_data):
        """Collate function for ``DataLoader``: returns the padded batch as a dict."""
        field_names = (
            'token_ids_1', 'token_type_ids_1', 'attention_mask_1',
            'token_ids_2', 'token_type_ids_2', 'attention_mask_2',
            'labels', 'sent_ids',
        )
        return dict(zip(field_names, self.pad_data(all_data)))

def _read_labelled_sentences(filename, label_to_index):
    """Read ``(sentence, label, id)`` rows from a TSV file, registering labels not seen before."""
    rows = []
    # newline='' lets the csv module do its own newline handling (see the csv docs).
    with open(filename, 'r', newline='') as fp:
        for record in csv.DictReader(fp, delimiter='\t'):
            sent = record['sentence'].lower().strip()
            sent_id = record['id'].lower().strip()
            label = int(record['sentiment'].strip())
            # First-seen order defines the index; len() is taken before the insertion.
            label_to_index.setdefault(label, len(label_to_index))
            rows.append((sent, label, sent_id))
    return rows


def load_multitask_data( sentiment_filename, paraphrase_filename, similarity_filename, emotion_filename, split='train'):
    """Load the four tab-separated task files used for multitask training.

    Args:
        sentiment_filename: TSV with ``id``, ``sentence`` and ``sentiment`` columns.
        paraphrase_filename: TSV with ``id``, ``sentence1``, ``sentence2`` and
            ``is_duplicate`` columns. Malformed rows are skipped and counted.
        similarity_filename: TSV with ``id``, ``sentence1``, ``sentence2`` and
            ``similarity`` columns.
        emotion_filename: TSV with ``id``, ``sentence`` and ``sentiment`` columns.
        split: Only used in the progress messages.

    Returns:
        ``(sentiment_data, num_labels, paraphrase_data, similarity_data, emotion_data)``.
        ``num_labels`` maps every distinct integer label to its first-seen index.
        It is shared by the sentiment and the emotion file, sentiment first.
    """
    num_labels = {}

    sentiment_data = _read_labelled_sentences(sentiment_filename, num_labels)
    print(f"Loaded {len(sentiment_data)} {split} examples from {sentiment_filename}")

    emotion_data = _read_labelled_sentences(emotion_filename, num_labels)
    print(f"Loaded {len(emotion_data)} {split} examples from {emotion_filename}")

    paraphrase_data = []
    skipped_rows = 0
    with open(paraphrase_filename, 'r', newline='') as fp:
        for record in csv.DictReader(fp, delimiter='\t'):
            try:
                sent_id = record['id'].lower().strip()
                paraphrase_data.append((preprocess_string(record['sentence1']),
                                        preprocess_string(record['sentence2']),
                                        int(float(record['is_duplicate'])), sent_id))
            except (KeyError, ValueError, TypeError, AttributeError, OverflowError):
                # Best-effort load: a row with missing columns or a non-numeric label
                # is dropped, but counted so that the loss of data is visible.
                skipped_rows += 1

    print(f"Loaded {len(paraphrase_data)} {split} examples from {paraphrase_filename}")
    if skipped_rows:
        print(f"Skipped {skipped_rows} malformed rows in {paraphrase_filename}")

    similarity_data = []
    with open(similarity_filename, 'r', newline='') as fp:
        for record in csv.DictReader(fp, delimiter='\t'):
            sent_id = record['id'].lower().strip()
            similarity_data.append((preprocess_string(record['sentence1']),
                                    preprocess_string(record['sentence2']),
                                    float(record['similarity']), sent_id))

    print(f"Loaded {len(similarity_data)} {split} examples from {similarity_filename}")
    return sentiment_data, num_labels, paraphrase_data, similarity_data, emotion_data

## BERT multi-task model class used for training

In [25]:
class Bert_MultiTask(nn.Module):
    """A ``bert-base-uncased`` encoder shared by four task-specific feed-forward heads.

    Heads: sentiment (5 classes), emotion (14 classes), paraphrase (1 logit) and
    similarity (1 value squashed into the interval (1, 5)). Every head consists of
    ``config.n_hidden_layers`` hidden ``Linear`` + ReLU layers followed by an output
    ``Linear``, with a dropout in front of every linear layer. The encoder and
    the heads are created in float16.

    Args:
        config: Object exposing ``hidden_dropout_prob`` and ``n_hidden_layers``.
    """

    def __init__(self,config):
        super().__init__()
        self.model = BertModel.from_pretrained("bert-base-uncased", torch_dtype=torch.float16)
        self.model.to("cpu")
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        BERT_HIDDEN_SIZE = 768

        N_SENTIMENT_CLASSES = 5
        N_EMOTION_CLASSES = 14

        def make_dropouts():
            return nn.ModuleList([nn.Dropout(config.hidden_dropout_prob)
                                  for _ in range(config.n_hidden_layers + 1)])

        def make_linears(out_features):
            hidden = [nn.Linear(BERT_HIDDEN_SIZE, BERT_HIDDEN_SIZE, dtype=torch.float16)
                      for _ in range(config.n_hidden_layers)]
            return nn.ModuleList(hidden + [nn.Linear(BERT_HIDDEN_SIZE, out_features, dtype=torch.float16)])

        # The heads are built in a fixed order (sentiment, emotion, paraphrase,
        # similarity) so that seeded weight initialisation and the order of
        # model.parameters() stay the same from run to run.

        # Sentiment classification head.
        self.dropout_sentiment = make_dropouts()
        self.linear_sentiment = make_linears(N_SENTIMENT_CLASSES)
        self.last_linear_sentiment = None

        # Emotion detection head.
        self.dropout_emotion = make_dropouts()
        self.linear_emotion = make_linears(N_EMOTION_CLASSES)
        self.last_linear_emotion = None

        # Paraphrase detection head (single logit).
        self.dropout_paraphrase = make_dropouts()
        self.linear_paraphrase = make_linears(1)

        # Semantic textual similarity head (single value).
        self.dropout_similarity = make_dropouts()
        self.linear_similarity = make_linears(1)

    @staticmethod
    def _apply_head(dropouts, linears, x):
        """Run ``x`` through one head: (dropout, linear, ReLU) per hidden layer, then dropout + linear.

        The input of every linear layer is cast to that layer's weight dtype, so a
        head can be fed embeddings of another floating dtype. The original code
        called ``x.to(torch.float16)`` without keeping the result, which did nothing.
        """
        output_index = len(linears) - 1
        for i in range(output_index):
            x = dropouts[i](x)
            x = x.to(linears[i].weight.dtype)
            x = F.relu(linears[i](x))

        x = dropouts[-1](x)
        x = x.to(linears[-1].weight.dtype)
        return linears[-1](x)

    def forward(self, input_ids, attention_mask, task_id):
        """Encode a batch and return the hidden state at sequence position 0.

        ``task_id`` is accepted so that all callers share one signature; it is not
        used inside this method.
        """
        with torch.autocast(device_type='cpu', dtype=torch.float16):
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        # Position 0 of the last hidden state is the [CLS] embedding.
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        return cls_embedding

    def last_layers_sentiment(self, x):
        """Map encoder embeddings to sentiment logits."""
        return self._apply_head(self.dropout_sentiment, self.linear_sentiment, x)

    def predict_sentiment(self, input_ids, attention_mask):
        """Encode the batch and return sentiment logits."""
        x = self.forward(input_ids, attention_mask, task_id=0)
        return self.last_layers_sentiment(x)

    def last_layers_emotion(self, x):
        """Map encoder embeddings to emotion logits."""
        return self._apply_head(self.dropout_emotion, self.linear_emotion, x)

    def predict_emotion(self, input_ids, attention_mask):
        """Encode the batch and return emotion logits."""
        x = self.forward(input_ids, attention_mask, task_id=3)
        return self.last_layers_emotion(x)

    def get_similarity_paraphrase_embeddings(self, input_ids_1, attention_mask_1, input_ids_2, attention_mask_2, task_id):
        """Join both sentences of every pair into one sequence and encode it.

        The sequence is ``input_ids_1, [SEP], input_ids_2, [SEP]`` concatenated along
        dim 1, and the two added separators are marked as attended.

        NOTE(review): ``input_ids_1`` / ``input_ids_2`` are used exactly as given. If
        they come straight from the tokenizer they presumably already carry their own
        special tokens and padding, which then end up in the middle of the joined
        sequence - confirm that this layout is intended. No token_type_ids are
        passed to the encoder.
        """
        sep_token_id = torch.tensor([self.tokenizer.sep_token_id], dtype=torch.long, device=input_ids_1.device)
        batch_sep_token_id = sep_token_id.repeat(input_ids_1.shape[0], 1)
        sep_mask = torch.ones_like(batch_sep_token_id)

        input_id = torch.cat((input_ids_1, batch_sep_token_id, input_ids_2, batch_sep_token_id), dim=1)
        attention_mask = torch.cat((attention_mask_1, sep_mask, attention_mask_2, sep_mask), dim=1)
        return self.forward(input_id, attention_mask, task_id=task_id)

    def last_layers_paraphrase(self, x):
        """Map pair embeddings to a single paraphrase logit per example."""
        return self._apply_head(self.dropout_paraphrase, self.linear_paraphrase, x)

    def predict_paraphrase(self, input_ids_1, attention_mask_1, input_ids_2, attention_mask_2):
        """Encode the sentence pairs and return paraphrase logits."""
        x = self.get_similarity_paraphrase_embeddings(input_ids_1, attention_mask_1, input_ids_2, attention_mask_2, task_id=1)
        return self.last_layers_paraphrase(x)

    def last_layers_similarity(self, x):
        """Map pair embeddings to a similarity score; the sigmoid output is rescaled to (1, 5)."""
        preds = self._apply_head(self.dropout_similarity, self.linear_similarity, x)
        preds = torch.sigmoid(preds) * 4 + 1
        return preds

    def predict_similarity(self,input_ids_1, attention_mask_1,input_ids_2, attention_mask_2):
        """Encode the sentence pairs and return similarity scores."""
        x = self.get_similarity_paraphrase_embeddings(input_ids_1, attention_mask_1, input_ids_2, attention_mask_2, task_id=2)
        return self.last_layers_similarity(x)

## Objects and Scheduler

In [26]:
class ObjectsGroup:
    """Container for the objects a training step works on.

    Attributes are the model, the optimizer, an optional scaler and ``loss_sum``,
    a running loss total that starts at 0.
    """

    def __init__(self, model, optimizer, scaler = None):
        self.model, self.optimizer, self.scaler = model, optimizer, scaler
        self.loss_sum = 0

class Scheduler:
    """Hands out batches from the four task dataloaders and runs one training step on them.

    Args:
        dataloaders: Mapping with the keys ``'sst'``, ``'para'``, ``'sts'`` and ``'emt'``;
            every value is an iterable of batches (normally a ``DataLoader``).
        reset: When true, the iterators and step counters are created right away.
    """

    def __init__(self, dataloaders, reset=True):
        self.dataloaders = dataloaders
        self.names = list(dataloaders.keys())
        if reset:
            self.reset()

    def reset(self):
        """Restart all four iterators and zero the per-task step counters."""
        self.sst_iter = iter(self.dataloaders['sst'])
        self.para_iter = iter(self.dataloaders['para'])
        self.sts_iter = iter(self.dataloaders['sts'])
        self.emt_iter = iter(self.dataloaders['emt'])
        self.steps = {'sst': 0,  'para':0, 'sts':0, 'emt':0}

    def _next_batch(self, name, iter_attr):
        """Return the next batch of ``name``, starting a new pass when the current one is exhausted.

        A new pass is started with a fresh ``iter()`` call on the dataloader.
        ``itertools.cycle`` is deliberately not used: it stores every item of the
        first pass and replays those stored items forever, so a shuffling loader
        would never be reshuffled and all batches would be kept in memory.
        """
        try:
            return next(getattr(self, iter_attr))
        except StopIteration:
            fresh_iter = iter(self.dataloaders[name])
            setattr(self, iter_attr, fresh_iter)
            return next(fresh_iter)

    def get_SST_batch(self):
        """Next sentiment batch."""
        return self._next_batch('sst', 'sst_iter')

    def get_EMT_batch(self):
        """Next emotion batch."""
        return self._next_batch('emt', 'emt_iter')

    def get_Paraphrase_batch(self):
        """Next paraphrase batch."""
        return self._next_batch('para', 'para_iter')

    def get_STS_batch(self):
        """Next similarity batch."""
        return self._next_batch('sts', 'sts_iter')

    def get_batch(self, name: str):
        """Next batch of the task called ``name``; raises ``ValueError`` for an unknown name."""
        if name == "sst": return self.get_SST_batch()
        elif name == "para": return self.get_Paraphrase_batch()
        elif name == "sts": return self.get_STS_batch()
        elif name == "emt": return self.get_EMT_batch()
        raise ValueError(f"Unknown batch name: {name}")

    def process_named_batch(self, objects_group: 'ObjectsGroup', args: dict, name: str, prev=None, val=0, apply_optimization: bool = True):
        """Accumulate gradients over several batches of one task and optionally step the optimizer.

        Args:
            objects_group: Holder of the model, the optimizer and the running ``loss_sum``.
            args: Settings dict; ``args['gradient_accumulations_<name>']`` is the number of
                batches whose gradients are accumulated before the optimizer step.
            name: Task key, one of ``'sst'``, ``'para'``, ``'sts'``, ``'emt'``.
            prev: Earlier loss values of this task, used only when the new loss is NaN.
            val: Divisor applied to ``sum(prev)`` for the NaN fallback. With a falsy
                ``val`` the mean of ``prev`` is used (0.0 when ``prev`` is empty).
            apply_optimization: When false, gradients are left in place for the caller.

        Returns:
            The summed loss of the processed batches as a tensor. If that sum is NaN,
            a tensor holding the fallback value described above.

        Raises:
            ValueError: For an unknown ``name`` or an accumulation count below 1.
        """
        if name == "sst":
            process_fn, accumulation_key = process_sentiment_batch, 'gradient_accumulations_sst'
        elif name == "para":
            process_fn, accumulation_key = process_paraphrase_batch, 'gradient_accumulations_para'
        elif name == "sts":
            process_fn, accumulation_key = process_similarity_batch, 'gradient_accumulations_sts'
        elif name == "emt":
            process_fn, accumulation_key = process_emotion_batch, 'gradient_accumulations_emt'
        else:
            raise ValueError(f"Unknown batch name: {name}")

        gradient_accumulations = args[accumulation_key]
        if gradient_accumulations < 1:
            raise ValueError(f"{accumulation_key} must be at least 1, got {gradient_accumulations}")

        loss_of_batch = 0
        for _ in range(gradient_accumulations):
            # A new batch per accumulation step: running the same batch again would
            # only repeat its gradient instead of enlarging the effective batch.
            loss_of_batch += process_fn(self.get_batch(name), objects_group, args)

        self.steps[name] += 1

        loss_is_nan = bool(torch.isnan(loss_of_batch).item())
        if loss_is_nan:
            # The gradients belonging to a NaN loss are unusable; stepping with them
            # would write NaN into the weights and the optimizer state.
            objects_group.optimizer.zero_grad()
            objects_group.loss_sum = 0
        elif apply_optimization:
            step_optimizer(objects_group, args, step=self.steps[name])

        if not loss_is_nan:
            return loss_of_batch

        history = [] if prev is None else prev
        if val:
            fallback = np.sum(history) / val
        elif len(history):
            fallback = np.mean(history)
        else:
            fallback = 0.0
        # Returned as a tensor so that both return paths support ``.item()``.
        return torch.tensor(float(fallback))


class RoundRobinScheduler(Scheduler):
    """Scheduler that serves the tasks one after another in the key order of ``dataloaders``."""

    def __init__(self, dataloaders):
        super().__init__(dataloaders, reset=False)
        self.reset()

    def reset(self):
        """Go back to the first task, clear the loss history and restart the iterators."""
        self.index = 0
        # Per-task record of returned losses; it feeds the NaN fallback of
        # process_named_batch.
        self.loss_history = {name: [] for name in self.names}
        return super().reset()

    def process_one_batch(self, epoch: int, num_epochs: int, objects_group: 'ObjectsGroup', args: dict):
        """Process one batch of the task whose turn it is and advance to the next task.

        ``epoch`` and ``num_epochs`` are accepted but not used.

        Returns:
            ``(task_name, loss)``.
        """
        name = self.names[self.index]
        self.index = (self.index + 1) % len(self.names)

        history = self.loss_history[name]
        # process_named_batch needs the loss history and a divisor for it; the
        # original call left both out and raised TypeError.
        loss = self.process_named_batch(objects_group, args, name,
                                        prev=history, val=max(len(history), 1))
        history.append(float(loss))
        return name, loss

def process_sentiment_batch(batch, objects_group: ObjectsGroup, args: dict):
    device = 'cpu'
    model, scaler = objects_group.model, objects_group.scaler

    with torch.autocast(device_type='cpu', dtype=torch.float16):
        b_ids, b_mask, b_labels = (batch['token_ids'], batch['attention_mask'], batch['labels'])
        b_ids, b_mask, b_labels = b_ids.to(device), b_mask.to(device), b_labels.to(device)

        embeddings = model.forward(b_ids, b_mask, task_id=0)
        logits = model.last_layers_sentiment(embeddings)
        
        loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args['sst_batch_size']
        loss_value = loss.item()        
        objects_group.loss_sum += loss_value

        loss.backward()
        return loss

def process_emotion_batch(batch, objects_group: ObjectsGroup, args: dict):
    device = 'cpu'
    model, scaler = objects_group.model, objects_group.scaler

    with torch.autocast(device_type='cpu', dtype=torch.float16):
        b_ids, b_mask, b_labels = (batch['token_ids'], batch['attention_mask'], batch['labels'])
        b_ids, b_mask, b_labels = b_ids.to(device), b_mask.to(device), b_labels.to(device)

        embeddings = model.forward(b_ids, b_mask, task_id=3)
        logits = model.last_layers_emotion(embeddings)
        
        loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args['emt_batch_size']
        loss_value = loss.item()        
        objects_group.loss_sum += loss_value

        loss.backward()
        return loss

def process_paraphrase_batch(batch, objects_group: ObjectsGroup, args: dict):
    device = 'cpu'
    model, scaler = objects_group.model, objects_group.scaler

    with torch.autocast(device_type='cpu', dtype=torch.float16):
        b_ids_1, b_mask_1, b_ids_2, b_mask_2, b_labels = (batch['token_ids_1'], batch['attention_mask_1'], batch['token_ids_2'], batch['attention_mask_2'], batch['labels'])
        b_ids_1, b_mask_1, b_ids_2, b_mask_2, b_labels = b_ids_1.to(device), b_mask_1.to(device), b_ids_2.to(device), b_mask_2.to(device), b_labels.to(device)

        embeddings = model.get_similarity_paraphrase_embeddings(b_ids_1, b_mask_1, b_ids_2, b_mask_2, task_id=1)
        preds = model.last_layers_paraphrase(embeddings)
        loss = F.binary_cross_entropy_with_logits(preds.view(-1), b_labels.float(), reduction='sum') / args['para_batch_size']
        loss_value = loss.item()
        objects_group.loss_sum += loss_value        
        loss.backward()        
        return loss

def process_similarity_batch(batch, objects_group: ObjectsGroup, args: dict):
    device = 'cpu'
    model, scaler = objects_group.model, objects_group.scaler

    with torch.autocast(device_type='cpu', dtype=torch.float16):
        b_ids_1, b_mask_1, b_ids_2, b_mask_2, b_labels = (batch['token_ids_1'], batch['attention_mask_1'], batch['token_ids_2'], batch['attention_mask_2'], batch['labels'])
        b_ids_1, b_mask_1, b_ids_2, b_mask_2, b_labels = b_ids_1.to(device), b_mask_1.to(device), b_ids_2.to(device), b_mask_2.to(device), b_labels.to(device)
        embeddings = model.get_similarity_paraphrase_embeddings(b_ids_1, b_mask_1, b_ids_2, b_mask_2, task_id=2)
        preds = model.last_layers_similarity(embeddings)
        loss = F.mse_loss(preds.view(-1), b_labels.view(-1), reduction='sum') / args['sts_batch_size']
        loss_value = loss.item()
        objects_group.loss_sum += loss_value
        loss.backward()
        return loss

def step_optimizer(objects_group: ObjectsGroup, args: dict, step: int, total_nb_batches = None):
    """Apply the accumulated gradients, clear them and hand back the running loss.

    ``args``, ``step`` and ``total_nb_batches`` are accepted but not used here.

    Returns:
        The value ``objects_group.loss_sum`` had before it was reset to 0.
    """
    opt, _unused_scaler = objects_group.optimizer, objects_group.scaler
    opt.step()
    opt.zero_grad()

    accumulated_loss, objects_group.loss_sum = objects_group.loss_sum, 0
    torch.cuda.empty_cache()
    return accumulated_loss

In [27]:
# Seed every RNG before any weights are created or any data is shuffled.
seed_value = 4269
seed_everything(seed=seed_value)

In [28]:
# Run configuration; read by the data loading, model construction and training cells.
args = dict(
    # Tab-separated training files, one per task.
    sst_file='./train_data/train_sentiment.csv',
    para_file='./train_data/train_paraphase.csv',
    sts_file='./train_data/train_similarity.csv',
    emt_file='./train_data/train_emotion.csv',
    # Batch size per task.
    para_batch_size=2,
    sst_batch_size=2,
    sts_batch_size=2,
    emt_batch_size=2,
    # Task heads and optimisation.
    hidden_layers=2,
    hidden_drp_prob=0.2,
    lr=1e-6,
    epochs=6,
    patience=0.2,
    option='individual_pertrin',
    num_batches_per_epoch=2,
    # Gradient accumulation count per task.
    gradient_accumulations_sst=3,
    gradient_accumulations_sts=3,
    gradient_accumulations_para=3,
    gradient_accumulations_emt=3,
)

In [29]:
# Everything in this notebook runs on the CPU.
device = torch.device('cpu')

# Read the four training files; num_labels is the label -> index mapping built while loading.
(sst_train_data, num_labels, para_train_data,
 sts_train_data, emt_train_data) = load_multitask_data(
    sentiment_filename=args['sst_file'],
    paraphrase_filename=args['para_file'],
    similarity_filename=args['sts_file'],
    emotion_filename=args['emt_file'],
    split='train',
)

Loaded 70 train examples from ./train_data/train_sentiment.csv
Loaded 70 train examples from ./train_data/train_emotion.csv
Loaded 70 train examples from ./train_data/train_paraphase.csv
Loaded 70 train examples from ./train_data/train_similarity.csv


In [30]:
def _shuffled_loader(dataset, batch_size):
    """Shuffling DataLoader that batches with the dataset's own collate function."""
    return DataLoader(dataset, shuffle=True, batch_size=batch_size, collate_fn=dataset.collate_fn)


# Sentiment: single sentences with a class label.
sst_train_data = SentenceClassificationDataset(sst_train_data)
sst_train_dataloader = _shuffled_loader(sst_train_data, args['sst_batch_size'])

# Emotion: single sentences with a class label.
emt_train_data = SentenceClassificationDataset(emt_train_data)
emt_train_dataloader = _shuffled_loader(emt_train_data, args['emt_batch_size'])

# Paraphrase: sentence pairs with an integer label.
para_train_data = SentencePairDataset(para_train_data)
para_train_dataloader = _shuffled_loader(para_train_data, args['para_batch_size'])

# Similarity: sentence pairs with a float target, hence isRegression=True.
sts_train_data = SentencePairDataset(sts_train_data, isRegression=True)
sts_train_dataloader = _shuffled_loader(sts_train_data, args['sts_batch_size'])

In [31]:
# Model hyper-parameters, exposed as attributes because the model reads them as config.<name>.
config = SimpleNamespace(
    hidden_dropout_prob=args['hidden_drp_prob'],
    num_labels=5,
    hidden_size=768,
    data_dir='.',
    option=args['option'],
    n_hidden_layers=args['hidden_layers'],
)

model = Bert_MultiTask(config).to(device)
optimizer = AdamW(model.parameters(), lr=args['lr'])
# No gradient scaler is used in this notebook.
scaler = None

In [32]:
# Record the device in args as well.
args['device'] = device

# Bundle the model, optimizer and scaler for the training helpers.
objects_group = ObjectsGroup(model=model, optimizer=optimizer, scaler=scaler)

# The key order here is the order in which the round-robin scheduler visits the tasks.
dataloaders = dict(
    sst=sst_train_dataloader,
    para=para_train_dataloader,
    sts=sts_train_dataloader,
    emt=emt_train_dataloader,
)
scheduler = RoundRobinScheduler(dataloaders)

In [33]:
# Loss recorded for every processed step, per task.
total_loss = {'sst': [], 'para': [], 'sts': [], 'emt': []}
n_batches = 0

# A non-positive setting means "as many steps as the sentiment loader has batches".
num_batches_per_epoch = args['num_batches_per_epoch'] if args['num_batches_per_epoch'] > 0 else len(sst_train_dataloader)

# Per-task bookkeeping. Every task has its own AdamW instance over ALL model
# parameters, so each optimizer keeps its own optimizer state.
task_heads = {
    'sst': model.linear_sentiment,
    'para': model.linear_paraphrase,
    'sts': model.linear_similarity,
    'emt': model.linear_emotion,
}
infos = {
    task_name: {'num_batches': num_batches_per_epoch, 'best_dev_acc': 0, 'best_model': None,
                'layer': head, 'optimizer': AdamW(model.parameters(), lr=args['lr']),
                "last_improv": -1, 'first': True, 'first_loss': True}
    for task_name, head in task_heads.items()
}

total_num_batches = {'sst': 0, 'para': 0, 'sts': 0, 'emt': 0}
for epoch in range(args['epochs']):
    for task in ['sst', 'sts', 'para', 'emt']:
        model.train()
        objects_group.optimizer = infos[task]['optimizer']
        task_losses = total_loss[task]
        for i in tqdm(range(infos[task]['num_batches']), desc=task + ' epoch ' + str(epoch), smoothing=0):
            # prev / val feed the NaN fallback: the divisor is the number of losses
            # recorded for this task (at least 1), so the fallback is their mean and
            # can never divide by zero. The previous divisor was a running sum
            # 0, 1, 3, 6, ... shared by all tasks.
            loss = scheduler.process_named_batch(name=task, objects_group=objects_group, args=args,
                                                 prev=task_losses, val=max(len(task_losses), 1))
            task_losses.append(float(loss.item()))

            total_num_batches[task] += 1
            n_batches += 1

sst epoch 0:   0%|          | 0/2 [00:00<?, ?it/s]

sst epoch 0: 100%|██████████| 2/2 [02:59<00:00, 89.84s/it]
sts epoch 0: 100%|██████████| 2/2 [02:40<00:00, 80.48s/it]
para epoch 0: 100%|██████████| 2/2 [04:42<00:00, 141.21s/it]
emt epoch 0: 100%|██████████| 2/2 [01:57<00:00, 58.95s/it]
sst epoch 1: 100%|██████████| 2/2 [04:00<00:00, 120.21s/it]
sts epoch 1: 100%|██████████| 2/2 [02:15<00:00, 67.53s/it]
para epoch 1: 100%|██████████| 2/2 [03:29<00:00, 104.69s/it]
emt epoch 1: 100%|██████████| 2/2 [02:33<00:00, 76.79s/it]
sst epoch 2: 100%|██████████| 2/2 [15:55<00:00, 477.76s/it]
sts epoch 2: 100%|██████████| 2/2 [02:04<00:00, 62.43s/it]
para epoch 2: 100%|██████████| 2/2 [03:51<00:00, 115.63s/it]
emt epoch 2: 100%|██████████| 2/2 [01:36<00:00, 48.34s/it]
sst epoch 3: 100%|██████████| 2/2 [04:12<00:00, 126.27s/it]
sts epoch 3: 100%|██████████| 2/2 [01:48<00:00, 54.23s/it]
para epoch 3: 100%|██████████| 2/2 [04:36<00:00, 138.17s/it]
emt epoch 3: 100%|██████████| 2/2 [01:55<00:00, 57.99s/it]
sst epoch 4: 100%|██████████| 2/2 [03:50<00:0

In [37]:
# Save the whole module object (not just a state_dict). Loading this file
# unpickles the object, so the Bert_MultiTask class must be defined at load time.
torch.save(model, 'multimodel.pt')