In [1]:
import pandas as pd
import numpy as np
from sklearn.metrics import jaccard_score
import os
import time
from tqdm import tnrange, tqdm_notebook
# https://nbviewer.jupyter.org/github/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb?source=post_page-------

In [2]:

WD = os.getcwd()
DATA_DIR = os.path.join(WD, 'data','mpst-movie-plot-synopses-with-tags','mpst_full_data.csv')

In [3]:
data = pd.read_csv(DATA_DIR)
data = data.drop(['synopsis_source'],axis=1)
data.shape

(14828, 5)

In [4]:
data.head()

Unnamed: 0,imdb_id,title,plot_synopsis,tags,split
0,tt0057603,I tre volti della paura,Note: this synopsis is for the orginal Italian...,"cult, horror, gothic, murder, atmospheric",train
1,tt1733125,Dungeons & Dragons: The Book of Vile Darkness,"Two thousand years ago, Nhagruul the Foul, a s...",violence,train
2,tt0033045,The Shop Around the Corner,"Matuschek's, a gift store in Budapest, is the ...",romantic,test
3,tt0113862,Mr. Holland's Opus,"Glenn Holland, not a morning person by anyone'...","inspiring, romantic, stupid, feel-good",train
4,tt0086250,Scarface,"In May 1980, a Cuban man named Tony Montana (A...","cruelty, murder, dramatic, cult, violence, atm...",val


In [5]:
split = data['tags'].str.split(', ')
lens = split.str.len()


In [6]:
 np.concatenate(split)

array(['cult', 'horror', 'gothic', ..., 'anti war', 'murder',
       'christian film'], dtype='<U18')

In [7]:
temp_df = pd.DataFrame({'imdb_id': np.repeat(data['imdb_id'].values, lens), 
                        'category': np.concatenate(split),
                       'values': 1})

print(temp_df['category'].unique())
print(len(temp_df['category'].unique()))

temp_df = temp_df.pivot(index='imdb_id', columns='category', values='values').fillna(0).reset_index()



['cult' 'horror' 'gothic' 'murder' 'atmospheric' 'violence' 'romantic'
 'inspiring' 'stupid' 'feel-good' 'cruelty' 'dramatic' 'action' 'revenge'
 'sadist' 'queer' 'flashback' 'mystery' 'suspenseful' 'neo noir' 'prank'
 'psychedelic' 'tragedy' 'autobiographical' 'home movie'
 'good versus evil' 'depressing' 'realism' 'boring' 'haunting'
 'sentimental' 'paranormal' 'historical' 'storytelling' 'comedy' 'fantasy'
 'philosophical' 'adult comedy' 'cute' 'entertaining' 'bleak' 'humor'
 'plot twist' 'christian film' 'pornographic' 'insanity' 'brainwashing'
 'sci-fi' 'dark' 'claustrophobic' 'psychological' 'melodrama'
 'historical fiction' 'absurd' 'satire' 'alternate reality'
 'alternate history' 'comic' 'grindhouse film' 'thought-provoking'
 'clever' 'western' 'blaxploitation' 'whimsical' 'intrigue' 'allegory'
 'anti war' 'avant garde' 'suicidal' 'magical realism' 'non fiction']
71


In [8]:
data_separate = data.merge(temp_df, how='left', on='imdb_id')
data_separate.head()

Unnamed: 0,imdb_id,title,plot_synopsis,tags,split,absurd,action,adult comedy,allegory,alternate history,...,sentimental,storytelling,stupid,suicidal,suspenseful,thought-provoking,tragedy,violence,western,whimsical
0,tt0057603,I tre volti della paura,Note: this synopsis is for the orginal Italian...,"cult, horror, gothic, murder, atmospheric",train,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,tt1733125,Dungeons & Dragons: The Book of Vile Darkness,"Two thousand years ago, Nhagruul the Foul, a s...",violence,train,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
2,tt0033045,The Shop Around the Corner,"Matuschek's, a gift store in Budapest, is the ...",romantic,test,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,tt0113862,Mr. Holland's Opus,"Glenn Holland, not a morning person by anyone'...","inspiring, romantic, stupid, feel-good",train,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,tt0086250,Scarface,"In May 1980, a Cuban man named Tony Montana (A...","cruelty, murder, dramatic, cult, violence, atm...",val,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0


In [9]:
train_df = data_separate[data_separate['split'] == 'train']
val_df = data_separate[data_separate['split'] == 'val']
test_df = data_separate[data_separate['split'] == 'test']

train_df.shape, val_df.shape, test_df.shape

((9489, 76), (2373, 76), (2966, 76))

# BERT Modeling

In [10]:
import torch
from pytorch_transformers import *
from pytorch_transformers.modeling_bert import BertPreTrainedModel
from pytorch_transformers.optimization import AdamW

from torch.utils.data import Dataset, DataLoader
from torch.nn import BCEWithLogitsLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"

In [11]:
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
# bert_config = BertConfig.from_pretrained('bert-base-uncased')

In [23]:
class MPSTDataset(Dataset):

    def __init__(self, dataframe, tokenizer, max_seq_length, transform=None):
        self.dataframe = dataframe
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.dataframe)
    
    def get_sample_features(self, sample):
        tokenized_sample = self.tokenizer.tokenize(sample)
        
        tokenized_sample = ["[CLS]"] + tokenized_sample[:self.max_seq_length-2] + ["[SEP]"]
    
        input_ids = self.tokenizer.convert_tokens_to_ids(tokenized_sample)
        segment_ids = [0] * len(input_ids)
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (self.max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding
        
        assert len(input_ids) == self.max_seq_length
        assert len(input_mask) == self.max_seq_length
        assert len(segment_ids) == self.max_seq_length
        
        return input_ids, input_mask, segment_ids


    def __getitem__(self, idx):
        sample = self.dataframe.iloc[idx]['plot_synopsis']
        label = self.dataframe.iloc[idx][5:]
        
        input_ids, input_mask, segment_ids = self.get_sample_features(sample)
        
        return torch.tensor(input_ids), torch.tensor(input_mask), torch.tensor(segment_ids), torch.tensor(label)
        

In [82]:
class BertForMultiLabelClassification(torch.nn.Module):
    def __init__(self, num_labels=71):
        super().__init__()
        self.num_labels = num_labels
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = torch.nn.Dropout(0.1)
        self.layer = torch.nn.Linear(768,300)
        self.classifier = torch.nn.Linear(300, num_labels)
        self.batchnorm = torch.nn.BatchNorm1d(71)
        
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                position_ids=None, head_mask=None):
        
        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
                            attention_mask=attention_mask, head_mask=head_mask)

        # pooled output
        pooled_output = outputs[1]
        
        x = self.dropout(pooled_output)
        x = torch.nn.functional.relu(self.layer(x))
        logits = self.classifier(x)
        logits = self.batchnorm(logits)

        if labels is not None:
            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
            return loss, logits
        else:
            return logits
        
    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False
    
    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True

In [83]:
def train_model(dataloaders, model, optimizer, criterion, scheduler, num_epochs=2):
    since = time.time()
    step_sizes = {'train': len(dataloaders['train']), 
                     'valid': len(dataloaders['valid'])}

    for epoch in tnrange(int(num_epochs), desc="Epoch"):
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0
            running_acc = 0
        
            for step, batch in enumerate(tqdm_notebook(dataloaders[phase], desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                
                logits = model(input_ids, segment_ids, input_mask)

                loss = criterion(logits.view(-1, 71), label_ids.view(-1,71))
                print(loss.item())
                
                running_loss += loss.item()
                
                logits_numpy = logits.sigmoid().detach().cpu().numpy()
                labels_numpy = label_ids.detach().cpu().numpy()
    
                acc = jaccard_score(labels_numpy, logits_numpy.round(), average='samples')
                running_acc += acc
        
                print(acc, logits_numpy.round().sum(axis=1))
    
                if phase == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
#                     scheduler.step()
            
            if phase == 'train':
                train_loss = running_loss / step_sizes[phase]
                train_acc = running_acc / step_sizes[phase]
            else:
                valid_loss = running_loss / step_sizes[phase]
                valid_acc = running_acc / step_sizes[phase]
                
                print('Epoch [{}/{}] train loss: {:.4f} acc: {:.4f} ' 
              'valid loss: {:.4f} acc: {:.4f}'.format(
                epoch+1, num_epochs,
                train_loss, train_acc, 
                valid_loss, valid_acc))
            
    return model

In [84]:
train_ds = MPSTDataset(train_df, bert_tokenizer, 128)
train_dl = DataLoader(train_ds,batch_size=16, shuffle=True)

val_ds = MPSTDataset(val_df, bert_tokenizer, 256)
val_dl = DataLoader(val_ds,batch_size=16, shuffle=True)

dloaders = {'train':train_dl, 'valid':val_dl}

In [85]:
# EPOCHS = 10
# LEARNING_RATE = 3e-4
# ADAM_EPSILON = 1e-6
# WARMUP_STEPS = 0

# t_total= len(train_dl) * EPOCHS

# model = BertForMultiLabelClassification.from_pretrained("bert-base-uncased")
# model.to(device)

# criterion = BCEWithLogitsLoss()

# param_optimizer = list(model.named_parameters())
# no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
# optimizer_grouped_parameters = [
#     {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
#     {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
#     ]
# optimizer = AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE, eps=ADAM_EPSILON)
# scheduler = WarmupLinearSchedule(optimizer, warmup_steps=WARMUP_STEPS, t_total=t_total)

In [86]:
model = BertForMultiLabelClassification()
model.freeze_bert_encoder()

criterion = BCEWithLogitsLoss()

optimizer = torch.optim.Adamax(model.parameters(), lr=0.001)


In [87]:
start_time = time.time()
model = train_model(dloaders, model, optimizer,criterion, scheduler=None, num_epochs=10)
print('Training time: {:10f} minutes'.format((time.time()-start_time)/60))

HBox(children=(IntProgress(value=0, description='Epoch', max=10, style=ProgressStyle(description_width='initia…

HBox(children=(IntProgress(value=0, description='Iteration', max=594, style=ProgressStyle(description_width='i…

0.7350518107414246
0.03767708236434049 [34. 36. 36. 40. 39. 39. 39. 35. 42. 36. 40. 33. 31. 34. 29. 46.]
0.7378048300743103
0.03138755665736216 [49. 35. 37. 36. 38. 26. 36. 33. 35. 31. 33. 35. 32. 39. 40. 38.]
0.738446831703186
0.020230884288497137 [39. 38. 40. 36. 37. 39. 29. 27. 35. 40. 46. 31. 38. 33. 38. 34.]
0.7371724247932434
0.03755437205747567 [39. 33. 39. 31. 42. 34. 33. 39. 38. 34. 32. 29. 32. 39. 35. 42.]
0.7320803999900818
0.048841332647838756 [36. 35. 34. 32. 42. 31. 35. 39. 33. 32. 39. 35. 42. 35. 38. 42.]
0.731555163860321
0.04516910286768498 [41. 36. 31. 38. 37. 41. 44. 38. 27. 37. 44. 33. 38. 36. 39. 39.]
0.7309052348136902
0.06826564425097967 [33. 40. 40. 40. 39. 35. 33. 38. 34. 36. 32. 38. 34. 31. 34. 40.]
0.7297363877296448
0.03960218374133391 [41. 29. 34. 35. 35. 31. 41. 37. 34. 38. 35. 35. 42. 37. 38. 38.]
0.7282686829566956
0.046056920517279394 [34. 36. 36. 41. 37. 32. 38. 33. 46. 34. 36. 34. 33. 35. 34. 36.]
0.730772852897644
0.034662813160349144 [30. 34. 36. 48

0.0649018182110507 [26. 35. 27. 27. 26. 32. 27. 32. 35. 31. 27. 31. 27. 34. 27. 24.]
0.6870490908622742
0.050078872485251995 [27. 26. 33. 31. 27. 31. 35. 31. 27. 31. 34. 34. 30. 28. 28. 34.]
0.6947110891342163
0.03289951344039219 [25. 35. 26. 30. 36. 28. 29. 31. 37. 34. 30. 26. 26. 34. 27. 31.]
0.6882666349411011
0.055136066532684117 [34. 24. 23. 38. 27. 31. 36. 29. 31. 30. 33. 35. 29. 36. 24. 30.]
0.6938909888267517
0.031175659055028337 [26. 25. 26. 33. 23. 30. 34. 31. 29. 28. 32. 28. 23. 30. 30. 33.]
0.6927849054336548
0.043764053024154974 [28. 28. 23. 28. 32. 35. 23. 24. 29. 33. 28. 23. 30. 24. 34. 26.]
0.6911174654960632
0.03622180218467526 [36. 31. 35. 33. 34. 35. 29. 30. 33. 25. 36. 26. 29. 39. 26. 31.]
0.6945680379867554
0.02332625776062562 [37. 37. 36. 32. 26. 34. 27. 33. 26. 32. 26. 33. 30. 25. 35. 39.]
0.6977989673614502
0.04232766792435317 [30. 35. 27. 28. 34. 33. 35. 30. 30. 29. 26. 28. 32. 25. 30. 25.]
0.6886199712753296
0.05138351525221993 [36. 24. 36. 29. 31. 30. 28. 35.

0.6575514078140259
0.046194764459875325 [27. 21. 18. 30. 23. 17. 22. 18. 26. 26. 19. 28. 25. 19. 21. 20.]
0.6611804962158203
0.03405497621610425 [32. 26. 27. 19. 29. 33. 31. 19. 27. 33. 26. 30. 24. 30. 33. 28.]
0.659112811088562
0.044923081145907234 [12. 12. 29. 16. 22. 21. 20. 25. 19. 25. 30. 15. 22. 22. 19. 18.]
0.6574290990829468
0.04615119718599733 [24. 32. 28. 24. 35. 15. 30. 25. 24. 20. 23. 23. 24. 33. 24. 21.]
0.6552160382270813
0.06596436650524526 [23. 23. 33. 34. 27. 20. 28. 25. 14. 22. 23. 25. 22. 24. 25. 19.]
0.6551298499107361
0.026617024219394908 [27. 20. 15. 30. 28. 25. 26. 20. 17. 23. 18. 25. 22. 17. 18. 22.]
0.6506446003913879
0.06884263514935021 [21. 14. 20. 28. 25. 25. 23. 22. 21. 27. 21. 20. 19. 17. 20. 20.]
0.6558706760406494
0.06374366869232487 [20. 16. 26. 24. 24. 21. 30. 24. 17. 21. 17. 21. 15. 15. 21. 30.]
0.6595443487167358
0.03571464862062839 [27. 21. 24. 28. 24. 32. 16. 17. 32. 29. 26. 23. 14. 29. 17. 24.]
0.6647953987121582
0.0614264480859555 [29. 31. 27. 25

0.05577821457457007 [14. 11. 13. 10. 20. 14. 17.  7.  8. 26. 10. 13. 17. 12. 19. 22.]
0.626634955406189
0.07437472147950089 [22. 14. 18. 27. 10. 13. 18. 12.  5. 24. 12. 16. 13. 12.  7. 26.]
0.6310697197914124
0.0563484219311875 [25. 21. 27. 26. 15.  9. 23. 27. 24. 12. 24. 19. 15. 25. 27. 26.]
0.6282371878623962
0.034048134775505466 [19. 25. 16. 20. 18. 16. 16. 27. 21. 15. 22. 23. 17. 15. 20. 25.]
0.6226719617843628
0.0521950146659899 [23. 26. 16. 14. 16. 13. 14. 11. 19. 18. 13. 25. 20. 22. 23. 26.]
0.6247724890708923
0.05959363186425207 [ 9. 16. 15.  9. 20. 24. 27. 16. 15.  8. 23. 18. 19. 20. 19. 23.]
0.6253793835639954
0.061084547132641505 [24. 14. 13. 26.  9. 15. 26. 16. 26. 13. 24. 28. 23. 27. 16. 20.]
0.6246854066848755
0.04427892369068839 [16. 10. 11.  6. 12.  7. 22. 21. 12. 15. 26. 16. 14. 14. 18. 10.]
0.622319757938385
0.04718772873664178 [26.  4. 20. 13. 27. 19. 25. 20. 14. 23. 26. 19. 18. 26. 22. 20.]
0.6198201775550842
0.06322859760643336 [13. 16. 12. 19. 24. 20. 13. 14. 19. 

0.03177656116794544 [ 7. 22. 14. 17.  8.  8. 16. 22.  6. 21. 19. 23. 13. 12. 15. 10.]
0.5876682996749878
0.16516053391053392 [10.  4. 23.  5.  6.  6.  7.  1.  8.  8.  6.  5.  6.  8. 10.  1.]
0.5912442803382874
0.09129670676545676 [ 7. 19. 17.  5. 15. 13. 16. 14. 20. 10. 17.  9.  9.  7. 11.  7.]
0.5975375175476074
0.0528310057997558 [12. 10. 11. 21. 14.  8. 12.  8.  6. 18. 10. 10. 14. 22.  8. 16.]
0.5912365913391113
0.07880824838776783 [21. 11. 18. 15. 19. 11. 18. 18. 20. 18. 17. 17. 15.  8. 16.  9.]
0.593020498752594
0.04419910128356336 [ 8. 19. 13. 14. 14. 15. 11. 19. 17. 17. 16. 19. 19. 20. 19. 15.]
0.6013560891151428
0.0464499407219028 [18. 12. 20. 13. 14. 18.  6. 19. 15. 16. 15. 20. 17. 13. 20. 17.]
0.585673987865448
0.05454468971701257 [11. 17. 15. 19. 21. 15. 11. 14. 16. 14. 13. 15.  8. 19. 23. 16.]
0.5867551565170288
0.06055063329679698 [ 8. 17. 18. 21. 10. 13. 17.  6. 13. 10. 19. 24.  9. 18. 13. 11.]
0.5912320017814636
0.11071047008547008 [ 4. 11.  8.  7. 21.  9.  6.  9.  8.  4

0.14909569597069597 [ 5.  2.  9.  2.  4.  5.  2. 10. 20.  5.  2.  3.  7. 14. 12.  9.]
0.5668050050735474
0.04910609578759734 [11.  6.  9. 12.  8. 13.  9.  7. 10. 20. 10. 19.  8. 15. 20. 16.]
0.5567354559898376
0.09384920634920635 [ 3.  2.  4. 21.  4.  5. 12.  4.  6.  9.  5.  3.  6.  6.  2.  3.]
0.5592153668403625
0.06778163625590096 [ 8.  4.  9.  9. 12.  6. 12. 13. 12.  7. 10. 11. 17. 10.  8. 14.]
0.5577774047851562
0.06575126262626263 [ 4.  4.  7.  9. 10. 14.  2.  5.  5. 10. 18.  7. 14.  6. 10.  6.]
0.5527938008308411
0.2103422619047619 [ 4.  6.  5.  7.  5.  3.  4.  2.  3.  3.  6.  3. 21.  1.  1.  6.]
0.5517638921737671
0.16443452380952378 [ 3.  5.  5.  2.  5.  4.  3.  4.  5.  6.  1.  5.  2. 20.  5.  1.]
0.5552122592926025
0.08871527777777777 [ 6.  1.  6.  2.  3.  8. 20.  2.  3.  4. 19.  6.  5.  6.  3.  2.]
0.5575916767120361
0.10193903318903319 [ 1.  7.  7. 20.  8. 10.  7. 10.  5. 13.  9.  4.  7.  1.  2.  7.]
0.5641413331031799
0.1830305829228243 [ 5.  0. 18.  0.  2.  7.  5.  6.  2. 

0.15051174315880197 [ 5.  4.  8.  6.  7.  2.  5.  7.  1.  5.  4.  3. 13.  5. 12.  9.]
0.5373710989952087
0.09541396103896102 [ 1.  2.  5.  2.  3.  3. 17.  5.  4.  5.  9.  4.  9.  4.  6.  3.]
0.5320295691490173
0.12997855392156862 [ 3.  8.  3.  9.  4.  4.  1.  5. 15. 14. 10.  4.  6.  3.  2.  2.]
0.5387548804283142
0.07828033625730996 [ 6.  1.  3.  3.  3.  6.  7.  2.  4.  8.  4. 18. 12.  5.  3.  5.]
0.5355974435806274
0.08527800324675325 [ 7.  7.  8. 12. 11.  4.  4.  5.  7.  8.  7.  9.  3. 22. 16.  8.]
0.5315403342247009
0.16141098484848485 [ 4.  1.  3.  3. 19.  1.  0.  3.  4.  3.  4.  0.  3.  4.  7.  4.]
0.5346589088439941
0.09972758042610984 [ 6.  9.  4. 13.  7.  6.  5.  5.  7. 10. 16.  7.  2. 12.  6.  9.]
0.5311154723167419
0.1038500816993464 [ 2. 14.  5.  8.  4.  2.  2.  1.  1.  3. 18.  5.  8.  2.  3.  2.]
0.526311457157135
0.08506944444444443 [ 2. 16.  6.  2.  6. 18.  1.  1.  1.  3.  2.  3.  1.  6.  5.  6.]
0.5277138352394104
0.09065934065934066 [ 5.  7.  2.  3.  5.  6.  5. 12.  2. 

0.0751758658008658 [11.  7.  3.  6.  4. 15.  3.  0.  4.  4.  5.  5.  4. 10.  5.  5.]
0.5005397796630859
0.1282894736842105 [ 4.  2.  1.  0.  1.  1.  2.  0.  5.  2. 19.  1.  4.  4.  0.  9.]
0.5070752501487732
0.11536553724053725 [ 3. 14.  4.  4.  2.  3.  1.  3.  3. 16. 11.  3.  0.  2.  0.  0.]
0.501966655254364
0.15328724717311673 [ 1.  2.  5.  1.  1.  1.  3.  0. 19. 13.  1.  3.  5.  0.  1.  2.]
0.5023537874221802
0.09989541708291708 [ 5.  2. 10.  7.  7.  6.  3.  0.  4.  7. 10.  9.  5.  6.  7.  7.]
0.5016303658485413
0.10141369047619048 [ 5.  4.  4.  2. 16.  4.  5.  6.  1.  1.  6.  4.  6.  3.  1.  0.]
0.5078551769256592
0.17332589285714284 [15.  4.  3.  3.  6.  2.  3.  1.  7.  1.  3.  8.  4.  4.  5.  1.]
0.49942028522491455
0.10364583333333333 [ 4.  3.  3.  3.  2.  2.  1.  1.  3.  0.  4.  2.  2. 17.  2. 11.]
0.5001271367073059
0.139484126984127 [ 3.  3.  5.  3.  3.  3.  8.  2.  7. 17.  2.  8.  1.  0.  1.  4.]
0.4964638352394104
0.1326388888888889 [ 1.  6.  3.  8.  6.  2. 16.  4.  2.  1.

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 71])

In [27]:
single_sample = train_ds[0]
single_sample

(tensor([  101,  3602,  1024,  2023, 19962, 22599,  2003,  2005,  1996,  8917,
         13290,  3059,  2713,  2007,  1996,  9214,  1999,  2023,  3056,  2344,
          1012, 11235,  6382,  7245, 13999,  2093,  5469,  7122,  1997,  1996,
          6097,  7875,  2890,  1998,  1996, 11189,  2124,  2004,  1996,  1005,
          2093,  5344,  1997,  3571,  1005,  1012,  1996,  7026,  7352,  2100,
          1006, 15954, 21442, 19562,  1007,  2003,  2019,  8702,  1010,  2152,
          1011, 21125, 24262,  2655,  1011,  2611,  2040,  5651,  2000,  2014,
         22445,  1010,  8102,  4545,  2044,  2019,  3944,  2041,  2043,  2016,
          3202,  4152,  2022, 13462,  2011,  1037,  2186,  1997,  4326,  3042,
          4455,  1012,  1996, 20587,  2574,  4453,  2370,  2004,  3581,  1010,
          2014,  4654,  1011, 14255,  8737,  2040,  2038,  3728,  6376,  2013,
          3827,  1012, 26851,  2003, 10215,  2005,  2009,  2001,  2014, 10896,
          2008,  5565,  1996,  2158,  1999,  7173,  

In [34]:
model = BertForMultiLabelClassification()
model.freeze_bert_encoder()

criterion = BCEWithLogitsLoss()

optimizer = torch.optim.Adamax(model.parameters(), lr=0.001)


In [46]:
for i in range(100):
    single_sample = tuple(t.to(device) for t in single_sample)
    input_ids, input_mask, segment_ids, label_ids = single_sample
    
    logits = model(input_ids.unsqueeze(0), segment_ids.unsqueeze(0), input_mask.unsqueeze(0))
    
    loss = criterion(logits.view(-1, 71), label_ids.view(-1,71))
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print(loss)

tensor(0.0382, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0359, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0321, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0316, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0307, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0257, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0274, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0278, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0286, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0272, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0264, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0244, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0252, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0268, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0251, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0240, grad_fn=<BinaryCrossEntropyWithLogitsBac