In [4]:
from utils import read_data, read_instances
from transformers import AutoModelForSeq2SeqLM, PreTrainedTokenizerFast, AutoTokenizer, AutoConfig
from transformers.models.bart.modeling_bart import BartForConditionalGeneration
import torch.nn.functional as F
import torch

# load data
file_path = "data/task1/train/eLife_train.jsonl"
instances = read_instances(file_path)

# tokenizer = AutoTokenizer.from_pretrained("google/bigbird-pegasus-large-arxiv")
# by default encoder-attention is `block_sparse` with num_random_blocks=3, block_size=64
# model = BigBirdPegasusForConditionalGeneration.from_pretrained("google/bigbird-pegasus-large-arxiv") # 2.31G

Read data from :  data/task1/train/eLife_train.jsonl
The number of data:  4346


## Content Selection by Rouge Scores
select silient sentences

In [26]:
import numpy as np
from rouge import Rouge
from nltk import tokenize
rouge_pltrdy = Rouge()


def get_rouge2recall_scores_nopad(sentences, reference, oracle_type=None):
    # rouge_pltrdy is case sensitive
    reference = reference.lower()
    scores = [None for _ in range(len(sentences))]
    count_nonzero_rouge2recall = 0
    for i, sent in enumerate(sentences):
        sent = sent.lower()
        try:
            rouge_scores = rouge_pltrdy.get_scores(sent, reference)
            scores[i]  = rouge_scores[0]['rouge-2']['r'] # rouge2recall
        except ValueError:
            scores[i] = 0.0
        except RecursionError:
            scores[i] = 0.5 # just assign 0.5 as this sentence is simply too long
        if scores[i] > 0.0: count_nonzero_rouge2recall += 1
    # print('count_nonzero_rouge2recall=', count_nonzero_rouge2recall)
    scores = np.array(scores)
    N = len(scores)

    if oracle_type == 'padlead':
        biases = np.array([(N-i)*1e-12 for i in range(N)])
    elif oracle_type == 'padrand':
        biases = np.random.normal(scale=1e-10,size=(N,))
    else: # no pad 
        return np.array(scores)
    return np.array(scores) + biases

def compress_article(article):
    sentences = tokenize.sent_tokenize(article)
    # print(f'There are {len(sentences)} sentences.')
    reference = summaries[0]

    ## rank by ROUGH
    keep_idx = []
    scores = get_rouge2recall_scores_nopad( sentences, reference, oracle_type='padrand' )
    num_postive = sum(a > 0 for a in scores)
    rank = np.argsort(scores)[::-1][:num_postive] # only consider positive ones

    ## select high-ranked sentences
    keep_idx = []
    total_length = 0
    max_abssum_len = 1024-2
    for sent_i in rank:
        if total_length < max_abssum_len:
            sent = sentences[sent_i]
            total_length += len(bart_tokenizer.encode(sent)[1:-1]) # ignore <s> and </s>
            keep_idx.append(sent_i)
        else:
            break
    assert len(keep_idx) > 0
    ## if found nothing, selecting the top3 longest sentences
    # if len(keep_idx) == 0:
    #     sent_lengths = [len(tokenize.word_tokenize(ssent)) for ssent in sentences]
    #     keep_idx = np.argsort(sent_lengths)[::-1][:3].tolist()
    keep_idx = sorted(keep_idx) # back to original order
    filtered_sentences = [sentences[j] for j in keep_idx]
    filtered_input_text = " ".join(filtered_sentences)
    return filtered_input_text


In [19]:
compressed_articles = []
for i, article in enumerate(articles):
    print(i)
    filtered_input_text = compress_article(article)
    compressed_articles.append(compressed_articles)

There are 126 sentences.
count_nonzero_rouge2recall= 73


In [22]:
# with open(out_path, "w") as f:
#     f.write(filtered_input_text)
# print("write:", out_path)

'However , there is limited information on the timing and the relative magnitudes of maximum and minimum mortality , by local climate , age group , sex and medical cause of death . We used geo-coded mortality data and wavelets to analyse the seasonality of mortality by age group and sex from 1980 to 2016 in the USA and its subnational climatic regions . In adolescents and young adults , especially in males , death rates peaked in June/July and were lowest in December/January , driven by injury deaths . It is well-established that death rates vary throughout the year , and in temperate climates there tend to be more deaths in winter than in summer ( Campbell , 2017; Fowler et al . In a large country like the USA , which possesses distinct climate regions , the seasonality of mortality may vary geographically , due to geographical variations in mortality , localized weather patterns , and regional differences in adaptation measures such as heating , air conditioning and healthcare ( Davi

## Train BART

In [2]:
# load model
# tokenizer = PreTrainedTokenizerFast.from_pretrained("facebook/bart-base") # no <pad> token
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
config = AutoConfig.from_pretrained("facebook/bart-base")

from torch.utils.data import Dataset, DataLoader

# create a dataset class for data loader
class SummarizationDataset(Dataset):
    def __init__(self, instances):
        self.instances = instances
        # sort the batch in the descending order of the number of tokens in the article
        instances.sort(key=lambda x: len(x['article']), reverse=True)

    def __len__(self):
        return len(self.instances)
    
    def __getitem__(self, idx):
        return self.instances[idx]['article'], self.instances[idx]['lay_summary']

def collate_fn(batch):
    """create a collate function for data loader. It will convert a batch (list) of texts with varying lengths into rectangular tensors"""
    # batch is a list of tuples (article, summary)
    articles = [item[0] for item in batch]
    summaries = [item[1] for item in batch]
    
    # tokenize the articles and summaries
    source_encodings = tokenizer(articles, padding="longest", truncation=True, return_tensors='pt')
    target_encodings = tokenizer(summaries, padding="longest", truncation=True, return_tensors='pt')

    batch = {key: source_encodings[key] for key in source_encodings}
    batch['labels'] = torch.tensor([ [-100 if id == tokenizer.pad_token_id else id for id in token_ids ] for token_ids in target_encodings['input_ids']])

    return batch

def sequence_cross_entropy_with_logits(logits, shifted_target_ids, target_mask=None):
    """Cross entropy that accepts logits and target sentence ids.
    Args:
        target_ids: [batch_size x sequence_length]   ground-truth predictions                         <s>   A   B C </s>  
                                                                                                       ^    ^   ^ ^ ^
        logits: [batch_size x sequence_length x vocab_size] logits generated from the input x         </s>  <s> A B C 
        
        shifted_target_mask: [batch_size x sequence_length] mask for the shifted target ids
    """
    # flatten
    logits_flat = logits.view(-1, logits.size(-1))
    targets_flatten = shifted_target_ids.view(-1)
    return F.cross_entropy(logits_flat, targets_flatten, target_mask)

dataset = SummarizationDataset(instances)


In [11]:
from tqdm import tqdm
from transformers import AdamW

    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bart = bart.to(device)
optimizer = AdamW(bart.parameters(), lr=1e-5)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
for batch in tqdm(dataloader):
    torch.cuda.empty_cache() 
    batch = {key: batch[key].to(device) for key in batch}

    # decoder input by shifting the labels to the right as the input: <s> A B C </s> -> </s> <s> A B C 
    # wrong ones: shift the labels to the right: <s> A B C </s> -> <s> A B C 
    # decoder_input_ids=batch['labels'][:, :-1].contiguous()
    # decoder_attention_mask=target_mask[:, :-1].contiguous()
    decoder_input_ids = batch['labels'].new_zeros(batch['labels'].shape)
    decoder_input_ids[:, 1:] = batch['labels'][:, :-1].clone().type(torch.LongTensor).contiguous()
    decoder_input_ids[:, 0] = bart.config.decoder_start_token_id
    decoder_input_ids.masked_fill_(decoder_input_ids == -100, bart.config.pad_token_id)
    outputs = bart(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], decoder_input_ids=decoder_input_ids, return_dict=True, use_cache=False,)
    
    loss = sequence_cross_entropy_with_logits(outputs.logits, batch['labels'].contiguous()) # target_mask.type(torch.LongTensor).contiguous()
    loss.backward()
    optimizer.step()
    print(loss)






  0%|          | 1/2173 [00:00<17:32,  2.06it/s]

tensor(3.5736, device='cuda:0', grad_fn=<NllLossBackward>)


  0%|          | 2/2173 [00:01<18:03,  2.00it/s]

tensor(3.2399, device='cuda:0', grad_fn=<NllLossBackward>)


  0%|          | 3/2173 [00:01<16:57,  2.13it/s]

tensor(3.4055, device='cuda:0', grad_fn=<NllLossBackward>)


  0%|          | 4/2173 [00:01<16:13,  2.23it/s]

tensor(3.4330, device='cuda:0', grad_fn=<NllLossBackward>)


  0%|          | 5/2173 [00:02<16:10,  2.23it/s]

tensor(3.4173, device='cuda:0', grad_fn=<NllLossBackward>)


  0%|          | 6/2173 [00:02<15:59,  2.26it/s]

tensor(3.1338, device='cuda:0', grad_fn=<NllLossBackward>)


  0%|          | 7/2173 [00:03<15:38,  2.31it/s]

tensor(3.0283, device='cuda:0', grad_fn=<NllLossBackward>)


  0%|          | 8/2173 [00:03<15:17,  2.36it/s]

tensor(3.2676, device='cuda:0', grad_fn=<NllLossBackward>)


  0%|          | 9/2173 [00:03<15:13,  2.37it/s]

tensor(3.0552, device='cuda:0', grad_fn=<NllLossBackward>)


  0%|          | 10/2173 [00:04<15:22,  2.35it/s]

tensor(3.6096, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 11/2173 [00:04<15:20,  2.35it/s]

tensor(3.2752, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 12/2173 [00:05<15:33,  2.31it/s]

tensor(3.2852, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 13/2173 [00:05<15:49,  2.27it/s]

tensor(3.2770, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 14/2173 [00:06<16:17,  2.21it/s]

tensor(3.2158, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 15/2173 [00:06<16:04,  2.24it/s]

tensor(3.3716, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 16/2173 [00:07<15:46,  2.28it/s]

tensor(3.0644, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 17/2173 [00:07<15:36,  2.30it/s]

tensor(3.0823, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 18/2173 [00:07<15:22,  2.33it/s]

tensor(3.0612, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 19/2173 [00:08<15:14,  2.35it/s]

tensor(3.3015, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 20/2173 [00:08<15:13,  2.36it/s]

tensor(3.5641, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 21/2173 [00:09<15:38,  2.29it/s]

tensor(3.3743, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 22/2173 [00:09<15:28,  2.32it/s]

tensor(3.3847, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 23/2173 [00:09<15:10,  2.36it/s]

tensor(3.4030, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 24/2173 [00:10<15:06,  2.37it/s]

tensor(3.3244, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 25/2173 [00:10<15:09,  2.36it/s]

tensor(3.3146, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 26/2173 [00:11<15:14,  2.35it/s]

tensor(3.6039, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|          | 27/2173 [00:11<15:10,  2.36it/s]

tensor(3.6006, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|▏         | 28/2173 [00:12<14:54,  2.40it/s]

tensor(3.0269, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|▏         | 29/2173 [00:12<15:05,  2.37it/s]

tensor(3.3275, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|▏         | 30/2173 [00:12<14:50,  2.41it/s]

tensor(3.5449, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|▏         | 31/2173 [00:13<14:24,  2.48it/s]

tensor(3.7461, device='cuda:0', grad_fn=<NllLossBackward>)


  1%|▏         | 32/2173 [00:13<14:20,  2.49it/s]

tensor(3.5761, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 33/2173 [00:14<14:34,  2.45it/s]

tensor(3.4767, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 34/2173 [00:14<14:42,  2.42it/s]

tensor(3.2609, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 35/2173 [00:14<14:37,  2.44it/s]

tensor(3.7554, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 36/2173 [00:15<14:28,  2.46it/s]

tensor(3.5478, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 37/2173 [00:15<14:29,  2.46it/s]

tensor(3.6731, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 38/2173 [00:16<14:50,  2.40it/s]

tensor(3.3076, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 39/2173 [00:16<14:59,  2.37it/s]

tensor(3.3298, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 40/2173 [00:17<15:01,  2.37it/s]

tensor(3.6968, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 41/2173 [00:17<15:18,  2.32it/s]

tensor(3.7190, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 42/2173 [00:17<15:33,  2.28it/s]

tensor(3.5147, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 43/2173 [00:18<15:01,  2.36it/s]

tensor(3.3114, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 44/2173 [00:18<14:36,  2.43it/s]

tensor(3.6724, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 45/2173 [00:19<14:10,  2.50it/s]

tensor(3.4323, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 46/2173 [00:19<14:01,  2.53it/s]

tensor(3.3988, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 47/2173 [00:19<14:09,  2.50it/s]

tensor(3.3981, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 48/2173 [00:20<14:05,  2.51it/s]

tensor(3.2672, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 49/2173 [00:20<14:10,  2.50it/s]

tensor(3.4471, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 50/2173 [00:21<14:06,  2.51it/s]

tensor(3.5446, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 51/2173 [00:21<14:23,  2.46it/s]

tensor(3.6412, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 52/2173 [00:21<14:27,  2.44it/s]

tensor(3.9774, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 53/2173 [00:22<14:08,  2.50it/s]

tensor(4.0256, device='cuda:0', grad_fn=<NllLossBackward>)


  2%|▏         | 54/2173 [00:22<15:26,  2.29it/s]

tensor(4.1483, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 55/2173 [00:23<15:20,  2.30it/s]

tensor(4.3251, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 56/2173 [00:23<14:53,  2.37it/s]

tensor(4.1979, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 57/2173 [00:24<14:56,  2.36it/s]

tensor(3.5535, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 58/2173 [00:24<14:37,  2.41it/s]

tensor(4.4002, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 59/2173 [00:24<14:22,  2.45it/s]

tensor(3.8977, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 60/2173 [00:25<14:21,  2.45it/s]

tensor(3.7476, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 61/2173 [00:25<14:11,  2.48it/s]

tensor(4.0955, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 62/2173 [00:26<14:21,  2.45it/s]

tensor(3.7231, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 63/2173 [00:26<14:03,  2.50it/s]

tensor(3.9349, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 64/2173 [00:26<14:15,  2.46it/s]

tensor(4.0672, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 65/2173 [00:27<14:12,  2.47it/s]

tensor(3.3498, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 66/2173 [00:27<14:08,  2.48it/s]

tensor(3.6647, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 67/2173 [00:28<14:19,  2.45it/s]

tensor(4.0196, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 68/2173 [00:28<14:18,  2.45it/s]

tensor(3.9069, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 69/2173 [00:28<14:22,  2.44it/s]

tensor(3.7091, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 70/2173 [00:29<14:40,  2.39it/s]

tensor(3.4978, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 71/2173 [00:29<14:33,  2.41it/s]

tensor(3.8211, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 72/2173 [00:30<14:23,  2.43it/s]

tensor(4.0533, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 73/2173 [00:30<14:20,  2.44it/s]

tensor(3.8006, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 74/2173 [00:30<14:10,  2.47it/s]

tensor(4.2750, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 75/2173 [00:31<14:04,  2.49it/s]

tensor(4.1118, device='cuda:0', grad_fn=<NllLossBackward>)


  3%|▎         | 76/2173 [00:31<14:08,  2.47it/s]

tensor(3.3973, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▎         | 77/2173 [00:32<14:04,  2.48it/s]

tensor(3.8520, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▎         | 78/2173 [00:32<14:14,  2.45it/s]

tensor(3.8352, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▎         | 79/2173 [00:32<14:00,  2.49it/s]

tensor(3.9003, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▎         | 80/2173 [00:33<13:47,  2.53it/s]

tensor(3.9803, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▎         | 81/2173 [00:33<13:41,  2.55it/s]

tensor(4.2028, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 82/2173 [00:34<14:00,  2.49it/s]

tensor(3.7888, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 83/2173 [00:34<14:05,  2.47it/s]

tensor(4.0451, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 84/2173 [00:35<14:37,  2.38it/s]

tensor(4.1383, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 85/2173 [00:35<14:42,  2.36it/s]

tensor(3.8635, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 86/2173 [00:35<14:41,  2.37it/s]

tensor(4.2951, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 87/2173 [00:36<14:27,  2.40it/s]

tensor(3.7562, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 88/2173 [00:36<14:17,  2.43it/s]

tensor(4.0767, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 89/2173 [00:37<14:09,  2.45it/s]

tensor(3.6683, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 90/2173 [00:37<14:07,  2.46it/s]

tensor(3.5781, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 91/2173 [00:37<14:23,  2.41it/s]

tensor(3.8919, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 92/2173 [00:38<14:24,  2.41it/s]

tensor(3.7572, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 93/2173 [00:38<14:39,  2.37it/s]

tensor(3.7279, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 94/2173 [00:39<14:26,  2.40it/s]

tensor(3.7631, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 95/2173 [00:39<14:13,  2.43it/s]

tensor(3.8865, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 96/2173 [00:40<14:11,  2.44it/s]

tensor(3.6979, device='cuda:0', grad_fn=<NllLossBackward>)


  4%|▍         | 97/2173 [00:40<13:58,  2.47it/s]

tensor(3.5368, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▍         | 98/2173 [00:40<13:52,  2.49it/s]

tensor(3.8725, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▍         | 99/2173 [00:41<13:52,  2.49it/s]

tensor(4.1526, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▍         | 100/2173 [00:41<13:49,  2.50it/s]

tensor(3.8307, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▍         | 101/2173 [00:42<14:09,  2.44it/s]

tensor(3.9147, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▍         | 102/2173 [00:42<14:14,  2.42it/s]

tensor(3.5953, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▍         | 103/2173 [00:42<14:08,  2.44it/s]

tensor(3.5552, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▍         | 104/2173 [00:43<14:15,  2.42it/s]

tensor(3.9316, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▍         | 105/2173 [00:43<14:18,  2.41it/s]

tensor(3.8370, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▍         | 106/2173 [00:44<14:20,  2.40it/s]

tensor(3.9904, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▍         | 107/2173 [00:44<14:18,  2.41it/s]

tensor(3.7620, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▍         | 108/2173 [00:44<14:03,  2.45it/s]

tensor(4.0771, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▌         | 109/2173 [00:45<14:25,  2.38it/s]

tensor(4.0167, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▌         | 110/2173 [00:45<14:24,  2.39it/s]

tensor(3.7944, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▌         | 111/2173 [00:46<14:09,  2.43it/s]

tensor(3.9031, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▌         | 112/2173 [00:46<13:57,  2.46it/s]

tensor(3.4272, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▌         | 113/2173 [00:46<14:01,  2.45it/s]

tensor(3.8988, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▌         | 114/2173 [00:47<14:06,  2.43it/s]

tensor(4.3037, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▌         | 115/2173 [00:47<14:01,  2.45it/s]

tensor(3.9484, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▌         | 116/2173 [00:48<13:44,  2.50it/s]

tensor(3.4392, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▌         | 117/2173 [00:48<13:35,  2.52it/s]

tensor(3.8204, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▌         | 118/2173 [00:48<13:41,  2.50it/s]

tensor(3.8008, device='cuda:0', grad_fn=<NllLossBackward>)


  5%|▌         | 119/2173 [00:49<13:39,  2.51it/s]

tensor(3.6988, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 120/2173 [00:49<13:58,  2.45it/s]

tensor(4.2549, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 121/2173 [00:50<14:17,  2.39it/s]

tensor(4.1715, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 122/2173 [00:50<14:01,  2.44it/s]

tensor(3.9111, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 123/2173 [00:51<14:06,  2.42it/s]

tensor(4.0817, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 124/2173 [00:51<14:00,  2.44it/s]

tensor(3.9385, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 125/2173 [00:51<14:09,  2.41it/s]

tensor(4.3299, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 126/2173 [00:52<14:16,  2.39it/s]

tensor(4.0422, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 127/2173 [00:52<14:15,  2.39it/s]

tensor(4.0650, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 128/2173 [00:53<13:54,  2.45it/s]

tensor(4.1455, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 129/2173 [00:53<13:45,  2.48it/s]

tensor(4.2166, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 130/2173 [00:53<13:39,  2.49it/s]

tensor(4.1566, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 131/2173 [00:54<13:27,  2.53it/s]

tensor(4.3068, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 132/2173 [00:54<13:40,  2.49it/s]

tensor(4.1045, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 133/2173 [00:55<14:08,  2.40it/s]

tensor(4.3184, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 134/2173 [00:55<14:13,  2.39it/s]

tensor(4.1808, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▌         | 135/2173 [00:55<14:17,  2.38it/s]

tensor(4.4756, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▋         | 136/2173 [00:56<14:05,  2.41it/s]

tensor(4.0341, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▋         | 137/2173 [00:56<13:54,  2.44it/s]

tensor(4.0899, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▋         | 138/2173 [00:57<14:09,  2.40it/s]

tensor(4.0562, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▋         | 139/2173 [00:57<14:11,  2.39it/s]

tensor(3.9381, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▋         | 140/2173 [00:58<14:13,  2.38it/s]

tensor(4.1149, device='cuda:0', grad_fn=<NllLossBackward>)


  6%|▋         | 141/2173 [00:58<13:55,  2.43it/s]

tensor(4.5038, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 142/2173 [00:58<13:38,  2.48it/s]

tensor(4.2119, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 143/2173 [00:59<14:03,  2.41it/s]

tensor(4.0325, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 144/2173 [00:59<14:32,  2.32it/s]

tensor(3.9797, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 145/2173 [01:00<14:23,  2.35it/s]

tensor(4.0518, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 146/2173 [01:00<14:00,  2.41it/s]

tensor(3.8026, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 147/2173 [01:00<13:46,  2.45it/s]

tensor(3.6222, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 148/2173 [01:01<13:38,  2.47it/s]

tensor(3.6576, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 149/2173 [01:01<13:26,  2.51it/s]

tensor(4.0185, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 150/2173 [01:02<13:33,  2.49it/s]

tensor(3.6167, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 151/2173 [01:02<13:42,  2.46it/s]

tensor(3.8850, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 152/2173 [01:02<13:41,  2.46it/s]

tensor(3.5681, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 153/2173 [01:03<13:45,  2.45it/s]

tensor(3.7508, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 154/2173 [01:03<14:19,  2.35it/s]

tensor(3.9575, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 155/2173 [01:04<14:28,  2.32it/s]

tensor(3.9749, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 156/2173 [01:04<13:58,  2.41it/s]

tensor(4.1773, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 157/2173 [01:05<13:53,  2.42it/s]

tensor(4.0192, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 158/2173 [01:05<13:59,  2.40it/s]

tensor(3.9635, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 159/2173 [01:05<13:49,  2.43it/s]

tensor(4.1093, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 160/2173 [01:06<14:00,  2.40it/s]

tensor(4.5301, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 161/2173 [01:06<14:15,  2.35it/s]

tensor(4.1472, device='cuda:0', grad_fn=<NllLossBackward>)


  7%|▋         | 162/2173 [01:07<14:12,  2.36it/s]

tensor(3.5658, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 163/2173 [01:07<14:14,  2.35it/s]

tensor(3.8432, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 164/2173 [01:08<14:43,  2.27it/s]

tensor(4.0720, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 165/2173 [01:08<15:01,  2.23it/s]

tensor(3.9273, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 166/2173 [01:09<14:59,  2.23it/s]

tensor(4.0011, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 167/2173 [01:09<14:51,  2.25it/s]

tensor(4.4182, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 168/2173 [01:09<15:11,  2.20it/s]

tensor(4.0583, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 169/2173 [01:10<14:39,  2.28it/s]

tensor(4.0439, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 170/2173 [01:10<14:30,  2.30it/s]

tensor(3.9955, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 171/2173 [01:11<14:11,  2.35it/s]

tensor(4.3677, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 172/2173 [01:11<14:04,  2.37it/s]

tensor(4.1793, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 173/2173 [01:12<14:22,  2.32it/s]

tensor(3.6823, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 174/2173 [01:12<14:25,  2.31it/s]

tensor(3.8910, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 175/2173 [01:12<14:23,  2.31it/s]

tensor(3.9535, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 176/2173 [01:13<14:24,  2.31it/s]

tensor(4.0948, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 177/2173 [01:13<14:56,  2.23it/s]

tensor(3.8423, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 178/2173 [01:14<14:36,  2.28it/s]

tensor(4.1577, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 179/2173 [01:14<14:19,  2.32it/s]

tensor(4.1011, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 180/2173 [01:15<14:31,  2.29it/s]

tensor(4.0415, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 181/2173 [01:15<14:58,  2.22it/s]

tensor(3.9924, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 182/2173 [01:16<15:32,  2.14it/s]

tensor(4.3498, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 183/2173 [01:16<15:36,  2.12it/s]

tensor(3.9746, device='cuda:0', grad_fn=<NllLossBackward>)


  8%|▊         | 184/2173 [01:17<15:43,  2.11it/s]

tensor(3.9299, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▊         | 185/2173 [01:17<15:31,  2.13it/s]

tensor(3.9669, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▊         | 186/2173 [01:17<15:44,  2.10it/s]

tensor(3.8509, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▊         | 187/2173 [01:18<15:33,  2.13it/s]

tensor(4.2816, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▊         | 188/2173 [01:18<15:35,  2.12it/s]

tensor(4.4399, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▊         | 189/2173 [01:19<15:46,  2.10it/s]

tensor(4.2731, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▊         | 190/2173 [01:19<15:49,  2.09it/s]

tensor(4.2859, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 191/2173 [01:20<15:14,  2.17it/s]

tensor(4.4652, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 192/2173 [01:20<15:50,  2.08it/s]

tensor(3.9045, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 193/2173 [01:21<15:50,  2.08it/s]

tensor(4.3270, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 194/2173 [01:21<15:35,  2.12it/s]

tensor(3.9736, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 195/2173 [01:22<15:52,  2.08it/s]

tensor(4.5408, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 196/2173 [01:22<16:19,  2.02it/s]

tensor(4.3122, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 197/2173 [01:23<16:13,  2.03it/s]

tensor(3.9721, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 198/2173 [01:23<16:20,  2.01it/s]

tensor(4.7127, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 199/2173 [01:24<16:08,  2.04it/s]

tensor(4.0956, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 200/2173 [01:24<15:50,  2.08it/s]

tensor(4.8221, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 201/2173 [01:25<15:33,  2.11it/s]

tensor(4.0789, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 202/2173 [01:25<15:19,  2.14it/s]

tensor(4.5677, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 203/2173 [01:26<14:59,  2.19it/s]

tensor(3.7965, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 204/2173 [01:26<15:23,  2.13it/s]

tensor(4.4400, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 205/2173 [01:27<15:17,  2.14it/s]

tensor(4.4693, device='cuda:0', grad_fn=<NllLossBackward>)


  9%|▉         | 206/2173 [01:27<15:15,  2.15it/s]

tensor(4.1682, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|▉         | 207/2173 [01:27<15:12,  2.16it/s]

tensor(3.5407, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|▉         | 208/2173 [01:28<15:00,  2.18it/s]

tensor(3.7665, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|▉         | 209/2173 [01:28<15:13,  2.15it/s]

tensor(4.1893, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|▉         | 210/2173 [01:29<15:26,  2.12it/s]

tensor(4.2527, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|▉         | 211/2173 [01:29<15:52,  2.06it/s]

tensor(4.0583, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|▉         | 212/2173 [01:30<16:17,  2.01it/s]

tensor(4.3753, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|▉         | 213/2173 [01:30<15:54,  2.05it/s]

tensor(4.1015, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|▉         | 214/2173 [01:31<16:12,  2.01it/s]

tensor(3.9415, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|▉         | 215/2173 [01:31<16:33,  1.97it/s]

tensor(4.2433, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|▉         | 216/2173 [01:32<16:13,  2.01it/s]

tensor(3.6932, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|▉         | 217/2173 [01:32<16:06,  2.02it/s]

tensor(4.2374, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|█         | 218/2173 [01:33<16:25,  1.98it/s]

tensor(4.2295, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|█         | 219/2173 [01:33<16:29,  1.97it/s]

tensor(4.3294, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|█         | 220/2173 [01:34<16:08,  2.02it/s]

tensor(4.1497, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|█         | 221/2173 [01:34<16:09,  2.01it/s]

tensor(3.8668, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|█         | 222/2173 [01:35<15:58,  2.04it/s]

tensor(4.7679, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|█         | 223/2173 [01:35<15:53,  2.04it/s]

tensor(4.2168, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|█         | 224/2173 [01:36<15:42,  2.07it/s]

tensor(4.0176, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|█         | 225/2173 [01:36<15:47,  2.06it/s]

tensor(4.2004, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|█         | 226/2173 [01:37<15:38,  2.08it/s]

tensor(4.4181, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|█         | 227/2173 [01:37<15:37,  2.08it/s]

tensor(4.2762, device='cuda:0', grad_fn=<NllLossBackward>)


 10%|█         | 228/2173 [01:38<15:11,  2.13it/s]

tensor(4.2058, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 229/2173 [01:38<15:04,  2.15it/s]

tensor(4.1841, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 230/2173 [01:39<14:56,  2.17it/s]

tensor(3.9579, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 231/2173 [01:39<15:26,  2.10it/s]

tensor(3.7780, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 232/2173 [01:40<15:36,  2.07it/s]

tensor(4.1287, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 233/2173 [01:40<15:29,  2.09it/s]

tensor(4.0317, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 234/2173 [01:41<15:21,  2.10it/s]

tensor(4.3656, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 235/2173 [01:41<15:58,  2.02it/s]

tensor(3.9538, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 236/2173 [01:42<16:27,  1.96it/s]

tensor(4.6222, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 237/2173 [01:42<16:13,  1.99it/s]

tensor(4.0178, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 238/2173 [01:43<16:07,  2.00it/s]

tensor(4.2966, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 239/2173 [01:43<16:12,  1.99it/s]

tensor(4.3855, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 240/2173 [01:44<15:56,  2.02it/s]

tensor(3.7645, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 241/2173 [01:44<15:39,  2.06it/s]

tensor(4.3172, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 242/2173 [01:45<15:47,  2.04it/s]

tensor(3.8439, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 243/2173 [01:45<15:40,  2.05it/s]

tensor(3.5456, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█         | 244/2173 [01:46<15:50,  2.03it/s]

tensor(3.9463, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█▏        | 245/2173 [01:46<15:52,  2.02it/s]

tensor(4.2731, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█▏        | 246/2173 [01:47<15:18,  2.10it/s]

tensor(3.7147, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█▏        | 247/2173 [01:47<15:26,  2.08it/s]

tensor(4.2832, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█▏        | 248/2173 [01:47<15:32,  2.06it/s]

tensor(3.9322, device='cuda:0', grad_fn=<NllLossBackward>)


 11%|█▏        | 249/2173 [01:48<15:09,  2.11it/s]

tensor(4.0872, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 250/2173 [01:48<15:20,  2.09it/s]

tensor(4.1140, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 251/2173 [01:49<15:27,  2.07it/s]

tensor(3.9028, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 252/2173 [01:49<15:00,  2.13it/s]

tensor(3.6843, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 253/2173 [01:50<15:08,  2.11it/s]

tensor(4.0295, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 254/2173 [01:50<15:04,  2.12it/s]

tensor(3.9124, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 255/2173 [01:51<15:11,  2.10it/s]

tensor(4.1550, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 256/2173 [01:51<15:23,  2.08it/s]

tensor(4.3404, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 257/2173 [01:52<15:24,  2.07it/s]

tensor(4.4697, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 258/2173 [01:52<15:13,  2.10it/s]

tensor(4.9446, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 259/2173 [01:53<15:25,  2.07it/s]

tensor(4.6838, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 260/2173 [01:53<15:29,  2.06it/s]

tensor(3.9077, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 261/2173 [01:54<15:10,  2.10it/s]

tensor(4.5854, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 262/2173 [01:54<15:24,  2.07it/s]

tensor(4.0840, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 263/2173 [01:55<15:25,  2.06it/s]

tensor(4.1746, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 264/2173 [01:55<15:22,  2.07it/s]

tensor(4.3546, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 265/2173 [01:56<15:22,  2.07it/s]

tensor(4.3481, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 266/2173 [01:56<15:25,  2.06it/s]

tensor(4.2249, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 267/2173 [01:57<15:14,  2.08it/s]

tensor(4.4138, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 268/2173 [01:57<14:57,  2.12it/s]

tensor(4.0550, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 269/2173 [01:58<15:10,  2.09it/s]

tensor(4.4448, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 270/2173 [01:58<15:15,  2.08it/s]

tensor(4.5065, device='cuda:0', grad_fn=<NllLossBackward>)


 12%|█▏        | 271/2173 [01:59<15:35,  2.03it/s]

tensor(4.4733, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 272/2173 [01:59<15:48,  2.00it/s]

tensor(4.2126, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 273/2173 [02:00<15:43,  2.01it/s]

tensor(4.5446, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 274/2173 [02:00<15:51,  2.00it/s]

tensor(4.0168, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 275/2173 [02:01<15:30,  2.04it/s]

tensor(4.4246, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 276/2173 [02:01<15:51,  1.99it/s]

tensor(4.4018, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 277/2173 [02:02<15:28,  2.04it/s]

tensor(4.7346, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 278/2173 [02:02<15:39,  2.02it/s]

tensor(4.2092, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 279/2173 [02:03<15:29,  2.04it/s]

tensor(4.7053, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 280/2173 [02:03<15:23,  2.05it/s]

tensor(5.1175, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 281/2173 [02:03<15:33,  2.03it/s]

tensor(4.4225, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 282/2173 [02:04<15:44,  2.00it/s]

tensor(4.0411, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 283/2173 [02:04<15:32,  2.03it/s]

tensor(4.4328, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 284/2173 [02:05<15:17,  2.06it/s]

tensor(3.9971, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 285/2173 [02:05<15:40,  2.01it/s]

tensor(4.4240, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 286/2173 [02:06<15:56,  1.97it/s]

tensor(4.2786, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 287/2173 [02:07<15:48,  1.99it/s]

tensor(3.9859, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 288/2173 [02:07<16:12,  1.94it/s]

tensor(4.0766, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 289/2173 [02:08<16:19,  1.92it/s]

tensor(4.2142, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 290/2173 [02:08<15:59,  1.96it/s]

tensor(4.1456, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 291/2173 [02:09<15:39,  2.00it/s]

tensor(4.2248, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 292/2173 [02:09<15:33,  2.01it/s]

tensor(4.1316, device='cuda:0', grad_fn=<NllLossBackward>)


 13%|█▎        | 293/2173 [02:10<15:51,  1.98it/s]

tensor(4.2900, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▎        | 294/2173 [02:10<15:38,  2.00it/s]

tensor(4.1801, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▎        | 295/2173 [02:11<15:25,  2.03it/s]

tensor(3.7988, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▎        | 296/2173 [02:11<16:01,  1.95it/s]

tensor(4.2206, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▎        | 297/2173 [02:12<16:05,  1.94it/s]

tensor(3.9716, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▎        | 298/2173 [02:12<15:57,  1.96it/s]

tensor(4.0538, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▍        | 299/2173 [02:13<15:59,  1.95it/s]

tensor(4.0887, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▍        | 300/2173 [02:13<15:50,  1.97it/s]

tensor(4.1721, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▍        | 301/2173 [02:14<15:31,  2.01it/s]

tensor(4.1143, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▍        | 302/2173 [02:14<15:24,  2.02it/s]

tensor(4.2480, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▍        | 303/2173 [02:15<15:11,  2.05it/s]

tensor(4.0628, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▍        | 304/2173 [02:15<15:15,  2.04it/s]

tensor(4.2234, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▍        | 305/2173 [02:16<15:22,  2.03it/s]

tensor(4.1369, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▍        | 306/2173 [02:16<15:08,  2.05it/s]

tensor(3.8439, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▍        | 307/2173 [02:16<15:03,  2.06it/s]

tensor(4.1556, device='cuda:0', grad_fn=<NllLossBackward>)


 14%|█▍        | 307/2173 [02:17<13:53,  2.24it/s]


KeyboardInterrupt: 

In [28]:
print(outputs.logits.shape)
print(batch['labels'].shape)

torch.Size([2, 673, 50265])
torch.Size([2, 673])


In [5]:
# iterate over epochs
epochs = 10
for epoch in range(epochs):
    pass

tensor(3.5736, device='cuda:0', grad_fn=<NllLossBackward>)