In [21]:
import pandas as pd
import numpy as np
from sklearn.metrics import jaccard_score
import os
import time
from tqdm import tnrange, tqdm_notebook
# https://nbviewer.jupyter.org/github/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb?source=post_page-------

In [22]:

WD = os.getcwd()
DATA_DIR = os.path.join(WD, 'data','mpst-movie-plot-synopses-with-tags','mpst_full_data.csv')

In [23]:
data = pd.read_csv(DATA_DIR)
data = data.drop(['synopsis_source'],axis=1)
data.shape

(14828, 5)

In [24]:
data.head()

Unnamed: 0,imdb_id,title,plot_synopsis,tags,split
0,tt0057603,I tre volti della paura,Note: this synopsis is for the orginal Italian...,"cult, horror, gothic, murder, atmospheric",train
1,tt1733125,Dungeons & Dragons: The Book of Vile Darkness,"Two thousand years ago, Nhagruul the Foul, a s...",violence,train
2,tt0033045,The Shop Around the Corner,"Matuschek's, a gift store in Budapest, is the ...",romantic,test
3,tt0113862,Mr. Holland's Opus,"Glenn Holland, not a morning person by anyone'...","inspiring, romantic, stupid, feel-good",train
4,tt0086250,Scarface,"In May 1980, a Cuban man named Tony Montana (A...","cruelty, murder, dramatic, cult, violence, atm...",val


In [25]:
split = data['tags'].str.split(', ')
lens = split.str.len()


In [26]:
temp_df = pd.DataFrame({'imdb_id': np.repeat(data['imdb_id'].values, lens), 
                        'category': np.concatenate(split),
                       'values': 1})

print(temp_df['category'].unique())
print(len(temp_df['category'].unique()))

temp_df = temp_df.pivot(index='imdb_id', columns='category', values='values').fillna(0).reset_index()



['cult' 'horror' 'gothic' 'murder' 'atmospheric' 'violence' 'romantic'
 'inspiring' 'stupid' 'feel-good' 'cruelty' 'dramatic' 'action' 'revenge'
 'sadist' 'queer' 'flashback' 'mystery' 'suspenseful' 'neo noir' 'prank'
 'psychedelic' 'tragedy' 'autobiographical' 'home movie'
 'good versus evil' 'depressing' 'realism' 'boring' 'haunting'
 'sentimental' 'paranormal' 'historical' 'storytelling' 'comedy' 'fantasy'
 'philosophical' 'adult comedy' 'cute' 'entertaining' 'bleak' 'humor'
 'plot twist' 'christian film' 'pornographic' 'insanity' 'brainwashing'
 'sci-fi' 'dark' 'claustrophobic' 'psychological' 'melodrama'
 'historical fiction' 'absurd' 'satire' 'alternate reality'
 'alternate history' 'comic' 'grindhouse film' 'thought-provoking'
 'clever' 'western' 'blaxploitation' 'whimsical' 'intrigue' 'allegory'
 'anti war' 'avant garde' 'suicidal' 'magical realism' 'non fiction']
71


In [27]:
data_separate = data.merge(temp_df, how='left', on='imdb_id')
data_separate.head()

Unnamed: 0,imdb_id,title,plot_synopsis,tags,split,absurd,action,adult comedy,allegory,alternate history,...,sentimental,storytelling,stupid,suicidal,suspenseful,thought-provoking,tragedy,violence,western,whimsical
0,tt0057603,I tre volti della paura,Note: this synopsis is for the orginal Italian...,"cult, horror, gothic, murder, atmospheric",train,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,tt1733125,Dungeons & Dragons: The Book of Vile Darkness,"Two thousand years ago, Nhagruul the Foul, a s...",violence,train,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
2,tt0033045,The Shop Around the Corner,"Matuschek's, a gift store in Budapest, is the ...",romantic,test,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,tt0113862,Mr. Holland's Opus,"Glenn Holland, not a morning person by anyone'...","inspiring, romantic, stupid, feel-good",train,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,tt0086250,Scarface,"In May 1980, a Cuban man named Tony Montana (A...","cruelty, murder, dramatic, cult, violence, atm...",val,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0


In [28]:
train_df = data_separate[data_separate['split'] == 'train']
val_df = data_separate[data_separate['split'] == 'val']
test_df = data_separate[data_separate['split'] == 'test']

train_df.shape, val_df.shape, test_df.shape

((9489, 76), (2373, 76), (2966, 76))

# Pre-processing

In [44]:
import re

In [89]:
def clean_str(string):
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip()

class Vocab:
    def __init__(self):
        self.word2index = {}
        self.index2word = {}
        self.word2count = {}
        self.n_words = 1 # 0 is reserved for none

    def addSentence(self, sentence):
        sentence = clean_str(sentence)
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1
            
def loadGloveModel(gloveFile):
    print("Loading Glove Model")
    f = open(gloveFile,'r', encoding="utf8")
    model = {}
    for line in f:
        splitLine = line.split()
        word = splitLine[0]
        embedding = np.array([float(val) for val in splitLine[1:]])
        model[word] = embedding
    print("Done.",len(model)," words loaded!")
    return model

In [90]:
# build vocab
vocab = Vocab()
for i in train_df['plot_synopsis']:
    vocab.addSentence(i)
print('Vocab Size:', vocab.n_words)

Vocab Size: 121301


In [91]:
# build word embeddings
glove = loadGloveModel(os.path.join(WD,'glove.6B','glove.6B.300d.txt'))

Loading Glove Model
Done. 400000  words loaded!


In [133]:
# create token-embedding mapping
weights_matrix = np.zeros((vocab.n_words+1, 300))
for word, i in vocab.word2index.items():
    if word in glove:
        weights_matrix[i] = glove[word]

# Data Loader

In [322]:
import torch
# from pytorch_transformers import *
# from pytorch_transformers.modeling_bert import BertPreTrainedModel
# from pytorch_transformers.optimization import AdamW

from torch.utils.data import Dataset, DataLoader
from torch.nn import BCEWithLogitsLoss

import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"

In [331]:
class MPSTDataset(Dataset):

    def __init__(self, dataframe, max_seq_length, transform=None):
        self.dataframe = dataframe
        self.transform = transform
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.dataframe)
    
    def prepare_sample_features(self, sample):
        sample_clean = clean_str(sample)
        
        tokenized_sample = sample_clean.split(' ')[:self.max_seq_length]
        
        
    
        input_ids = [vocab.word2index[x] if (x in vocab.word2index) else 0 for x in tokenized_sample]

        # Zero-pad up to the sequence length.
        padding = [0] * (self.max_seq_length - len(input_ids))
        input_ids += padding
        
        assert len(input_ids) == self.max_seq_length
        
        return input_ids


    def __getitem__(self, idx):
        sample = self.dataframe.iloc[idx]['plot_synopsis']
        label = self.dataframe.iloc[idx][5:]
        
        input_ids = self.prepare_sample_features(sample)
        
        return torch.tensor(input_ids), torch.tensor(label)
        

In [404]:
class XML_CNN(torch.nn.Module):
    def __init__(self, weights_matrix, label_size=71):
        super(XML_CNN, self).__init__()
        weight_tensor = torch.tensor(weights_matrix)
        num_embeddings, embedding_dim = weight_tensor.size()
        
        self.embedding = torch.nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.load_state_dict({'weight': weight_tensor})
        self.embedding.weight.requires_grad = False
        
        self.cnnk2 = torch.nn.Conv2d(1,3,(2,embedding_dim))
        self.cnnk4 = torch.nn.Conv2d(1,3,(4,embedding_dim))
        self.cnnk8 = torch.nn.Conv2d(1,3,(8,embedding_dim))
        
        self.linear = torch.nn.Linear(2989, 512)
        self.dropout = torch.nn.Dropout(0.5)
        self.output = torch.nn.Linear(512,label_size)
        
    def kmax_pooling(self, x, dim, k):
        index = x.topk(k, dim = dim)[1].sort(dim = dim)[0]
        return x.gather(dim, index)
    
    def conv_and_pool(self, conv, x):
        x = F.relu(conv(x)).squeeze(3)
        B, C, W = x.size() 
        x = self.kmax_pooling(x.view(B,-1), dim = 1, k=(C*W)//3)
        return(x)
        
    def forward(self,x):
        x = self.embedding(x) #  (N, W, D)
#         print(x.size())
        x = x.unsqueeze(1) # (N, Ci, W, D)
        
        x2 = self.conv_and_pool(self.cnnk2, x) # (B, W)
        x4 = self.conv_and_pool(self.cnnk4, x)
        x8 = self.conv_and_pool(self.cnnk8, x)
        
        x = torch.cat([x2,x4,x8],1)
        x = self.dropout(x)
        x = F.relu(self.linear(x))
        x = self.output(x)

#         print('cnn2', x2.size())
#         print('cnn4', x4.size())
#         print('cnn8', x8.size())
        return(x)
      

In [405]:
# metric
def precision_k(pred, label, k=[1, 3, 5]):
    batch_size = pred.shape[0]
    
    precision = []
    for _k in k:
        p = 0
        for i in range(batch_size):
            p += label[i, pred[i, :_k]].mean()
        precision.append(p*100/batch_size)
    
    return precision

def ndcg_k(pred, label, k=[1, 3, 5]):
    batch_size = pred.shape[0]
    
    ndcg = []
    for _k in k:
        score = 0
        rank = np.log2(np.arange(2, 2 + _k))
        for i in range(batch_size):
            l = label[i, pred[i, :_k]]
            n = l.sum()
            if(n == 0):
                continue
            
            dcg = (l/rank).sum()
            label_count = label[i].sum()
            norm = 1 / np.log2(np.arange(2, 2 + np.min((_k, label_count))))
            norm = norm.sum()
            score += dcg/norm
            
        ndcg.append(score*100/batch_size)
    
    return ndcg

In [448]:
def train_model(dataloaders, model, optimizer, criterion, scheduler, num_epochs=2):
    since = time.time()
    step_sizes = {'train': len(dataloaders['train']), 
                     'valid': len(dataloaders['valid'])}
    
    weight = torch.tensor([0.5, 1.5])

    for epoch in tnrange(int(num_epochs), desc="Epoch"):
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0
            running_acc = 0
            
            r_p1, r_p3, r_p5 = 0,0,0
            r_ndcg1, r_ndcg3, r_ndcg5 = 0,0,0
        
            for step, batch in enumerate(tqdm_notebook(dataloaders[phase], desc=phase)):
                batch = tuple(t.to(device) for t in batch)
                input_ids, label_ids = batch
                
                logits = model(input_ids)
                sigmoid = logits.sigmoid()

                loss = criterion(sigmoid, label_ids)
                
#                 label_weight = weight[label_ids.data.view(-1).long()].view_as(label_ids).to(device)
#                 weighted_loss = loss * label_weight
#                 weighted_loss_average = weighted_loss.mean()
#                 running_loss += weighted_loss_average.item()

                running_loss += loss.item()
                
#                 sigmoid_numpy = sigmoid.detach().cpu().numpy()
#                 labels_numpy = label_ids.detach().cpu().numpy()
#                 acc = jaccard_score(labels_numpy, sigmoid_numpy.round(), average='samples')
#                 running_acc += acc
#                 print(weighted_loss_average.item(), acc, sigmoid_numpy.round().sum(axis=1))

               

                logits_cpu = logits.data.cpu()
                labels_cpu = label_ids.data.cpu()
            
#                 print(logits_cpu.topk(k=5)[0].numpy())
#                 print(logits_cpu.topk(k=5)[1].numpy())
            
                _p1,_p3,_p5=precision_k(logits_cpu.topk(k=5)[1].numpy(), labels_cpu.numpy(), k=[1,3,5])
                r_p1+= _p1
                r_p3+= _p3
                r_p5+= _p5
                
                _ndcg1,_ndcg3,_ndcg5=ndcg_k(logits_cpu.topk(k=5)[1].numpy(), labels_cpu.numpy(), k=[1,3,5])
                r_ndcg1 += _ndcg1
                r_ndcg3 += _ndcg3
                r_ndcg5 += _ndcg5
                
                if phase == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
#                     scheduler.step()
            
            if phase == 'train':
                train_loss = running_loss / step_sizes[phase]
                
                r_p1 = r_p1 / step_sizes[phase]
                r_p3 = r_p3 / step_sizes[phase]
                r_p5 = r_p5 / step_sizes[phase]
                
                r_ndcg1 = r_ndcg1 / step_sizes[phase]
                r_ndcg3= r_ndcg3 / step_sizes[phase]
                r_ndcg5 = r_ndcg5 / step_sizes[phase]
                
                print("precision@1 : %.4f , precision@3 : %.4f , precision@5 : %.4f "%(r_p1,r_p3,r_p5))
                print("ndcg@1 : %.4f , ndcg@3 : %.4f , ndcg@5 : %.4f "%(r_ndcg1,r_ndcg3,r_ndcg5))
            else:
                valid_loss = running_loss / step_sizes[phase]
                
                r_p1 = r_p1 / step_sizes[phase]
                r_p3 = r_p3 / step_sizes[phase]
                r_p5 = r_p5 / step_sizes[phase]
                
                r_ndcg1 = r_ndcg1 / step_sizes[phase]
                r_ndcg3= r_ndcg3 / step_sizes[phase]
                r_ndcg5 = r_ndcg5 / step_sizes[phase]
                
                print("precision@1 : %.4f , precision@3 : %.4f , precision@5 : %.4f "%(r_p1,r_p3,r_p5))
                print("ndcg@1 : %.4f , ndcg@3 : %.4f , ndcg@5 : %.4f "%(r_ndcg1,r_ndcg3,r_ndcg5))

                
        print('Epoch [{}/{}] train loss: {:.4f} valid loss: {:.4f} '.format(
                epoch+1, num_epochs,train_loss, valid_loss))
                

            
    return model

In [445]:
train_ds = MPSTDataset(train_df, 1000)
train_dl = DataLoader(train_ds,batch_size=32, shuffle=True)

val_ds = MPSTDataset(val_df, 1000)
val_dl = DataLoader(val_ds,batch_size=32, shuffle=True)

dloaders = {'train':train_dl, 'valid':val_dl}

In [446]:
model = XML_CNN(weights_matrix=weights_matrix)
model.to(device)
criterion = torch.nn.BCELoss(reduce=True)

# optimizer = torch.optim.Adamax(model.parameters(), lr=0.001)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)



In [447]:
start_time = time.time()
model = train_model(dloaders, model, optimizer,criterion, scheduler=None, num_epochs=10)
print('Training time: {:10f} minutes'.format((time.time()-start_time)/60))

HBox(children=(IntProgress(value=0, description='Epoch', max=10, style=ProgressStyle(description_width='initia…

HBox(children=(IntProgress(value=0, description='train', max=297, style=ProgressStyle(description_width='initi…

precision@1 : 36.9256 , precision@3 : 28.9158 , precision@5 : 24.2293 
ndcg@1 : 36.9256 , ndcg@3 : 38.2709 , ndcg@5 : 41.3068 


HBox(children=(IntProgress(value=0, description='valid', max=75, style=ProgressStyle(description_width='initia…

precision@1 : 40.3500 , precision@3 : 31.7250 , precision@5 : 25.8017 
ndcg@1 : 40.3500 , ndcg@3 : 42.3261 , ndcg@5 : 44.7706 
Epoch [1/10] train loss: 0.0810 valid loss: 0.0750 


HBox(children=(IntProgress(value=0, description='train', max=297, style=ProgressStyle(description_width='initi…

precision@1 : 41.0979 , precision@3 : 31.3305 , precision@5 : 25.9143 
ndcg@1 : 41.0979 , ndcg@3 : 42.5119 , ndcg@5 : 45.4195 


HBox(children=(IntProgress(value=0, description='valid', max=75, style=ProgressStyle(description_width='initia…

precision@1 : 39.9250 , precision@3 : 30.3194 , precision@5 : 25.9650 
ndcg@1 : 39.9250 , ndcg@3 : 40.1382 , ndcg@5 : 43.7190 
Epoch [2/10] train loss: 0.0737 valid loss: 0.0754 


HBox(children=(IntProgress(value=0, description='train', max=297, style=ProgressStyle(description_width='initi…

precision@1 : 43.2350 , precision@3 : 32.6135 , precision@5 : 26.5726 
ndcg@1 : 43.2350 , ndcg@3 : 44.6177 , ndcg@5 : 47.3390 


HBox(children=(IntProgress(value=0, description='valid', max=75, style=ProgressStyle(description_width='initia…

precision@1 : 47.8833 , precision@3 : 33.7778 , precision@5 : 27.3533 
ndcg@1 : 47.8833 , ndcg@3 : 47.1790 , ndcg@5 : 49.8195 
Epoch [3/10] train loss: 0.0721 valid loss: 0.0725 


HBox(children=(IntProgress(value=0, description='train', max=297, style=ProgressStyle(description_width='initi…

precision@1 : 46.6522 , precision@3 : 33.9547 , precision@5 : 27.2309 
ndcg@1 : 46.6522 , ndcg@3 : 47.0140 , ndcg@5 : 49.3582 


HBox(children=(IntProgress(value=0, description='valid', max=75, style=ProgressStyle(description_width='initia…

precision@1 : 47.0750 , precision@3 : 33.5444 , precision@5 : 27.2850 
ndcg@1 : 47.0750 , ndcg@3 : 46.8793 , ndcg@5 : 49.6336 
Epoch [4/10] train loss: 0.0707 valid loss: 0.0725 


HBox(children=(IntProgress(value=0, description='train', max=297, style=ProgressStyle(description_width='initi…

precision@1 : 48.4972 , precision@3 : 34.9164 , precision@5 : 27.8133 
ndcg@1 : 48.4972 , ndcg@3 : 48.6566 , ndcg@5 : 50.8335 


HBox(children=(IntProgress(value=0, description='valid', max=75, style=ProgressStyle(description_width='initia…

precision@1 : 47.9917 , precision@3 : 34.7167 , precision@5 : 27.9450 
ndcg@1 : 47.9917 , ndcg@3 : 47.9699 , ndcg@5 : 50.4864 
Epoch [5/10] train loss: 0.0692 valid loss: 0.0727 


HBox(children=(IntProgress(value=0, description='train', max=297, style=ProgressStyle(description_width='initi…

precision@1 : 49.9164 , precision@3 : 35.9050 , precision@5 : 28.7127 
ndcg@1 : 49.9164 , ndcg@3 : 49.8788 , ndcg@5 : 52.3346 


HBox(children=(IntProgress(value=0, description='valid', max=75, style=ProgressStyle(description_width='initia…

precision@1 : 47.4667 , precision@3 : 34.1444 , precision@5 : 27.3267 
ndcg@1 : 47.4667 , ndcg@3 : 47.2537 , ndcg@5 : 49.6044 
Epoch [6/10] train loss: 0.0673 valid loss: 0.0727 


HBox(children=(IntProgress(value=0, description='train', max=297, style=ProgressStyle(description_width='initi…

precision@1 : 51.2422 , precision@3 : 36.8551 , precision@5 : 29.4816 
ndcg@1 : 51.2422 , ndcg@3 : 51.2850 , ndcg@5 : 53.7394 


HBox(children=(IntProgress(value=0, description='valid', max=75, style=ProgressStyle(description_width='initia…

precision@1 : 49.6583 , precision@3 : 34.3472 , precision@5 : 27.2733 
ndcg@1 : 49.6583 , ndcg@3 : 48.1743 , ndcg@5 : 50.3248 
Epoch [7/10] train loss: 0.0651 valid loss: 0.0738 


HBox(children=(IntProgress(value=0, description='train', max=297, style=ProgressStyle(description_width='initi…

precision@1 : 51.9131 , precision@3 : 37.8086 , precision@5 : 30.3233 
ndcg@1 : 51.9131 , ndcg@3 : 52.4049 , ndcg@5 : 54.9959 


HBox(children=(IntProgress(value=0, description='valid', max=75, style=ProgressStyle(description_width='initia…

precision@1 : 49.4500 , precision@3 : 34.1028 , precision@5 : 27.0400 
ndcg@1 : 49.4500 , ndcg@3 : 48.1710 , ndcg@5 : 50.2875 
Epoch [8/10] train loss: 0.0631 valid loss: 0.0747 


HBox(children=(IntProgress(value=0, description='train', max=297, style=ProgressStyle(description_width='initi…

precision@1 : 53.4097 , precision@3 : 38.9817 , precision@5 : 31.3131 
ndcg@1 : 53.4097 , ndcg@3 : 53.8931 , ndcg@5 : 56.4740 


HBox(children=(IntProgress(value=0, description='valid', max=75, style=ProgressStyle(description_width='initia…

precision@1 : 48.1417 , precision@3 : 33.9194 , precision@5 : 27.0467 
ndcg@1 : 48.1417 , ndcg@3 : 47.3307 , ndcg@5 : 49.7012 
Epoch [9/10] train loss: 0.0604 valid loss: 0.0754 


HBox(children=(IntProgress(value=0, description='train', max=297, style=ProgressStyle(description_width='initi…

precision@1 : 55.8929 , precision@3 : 40.2353 , precision@5 : 32.3131 
ndcg@1 : 55.8929 , ndcg@3 : 56.0303 , ndcg@5 : 58.7521 


HBox(children=(IntProgress(value=0, description='valid', max=75, style=ProgressStyle(description_width='initia…

precision@1 : 47.9250 , precision@3 : 33.2917 , precision@5 : 26.4667 
ndcg@1 : 47.9250 , ndcg@3 : 46.8679 , ndcg@5 : 49.0094 
Epoch [10/10] train loss: 0.0583 valid loss: 0.0767 
Training time:   5.166269 minutes


In [434]:
test_ds = MPSTDataset(test_df, 1000)
test_dl = DataLoader(test_ds,batch_size=3, shuffle=True)

In [435]:
test_sample = iter(test_dl).next()

In [436]:
test_sample[0]

tensor([[21643,    38,     6,  ...,  1539,  5562,    38],
        [  162,   369,     6,  ...,     0,     0,     0],
        [  607,     6,   201,  ...,  2654,    35,   140]])

In [437]:
pred = model(test_sample[0].to(device)).topk(5)
pred

torch.return_types.topk(
values=tensor([[ 0.1738,  0.0228, -0.6677, -1.2264, -1.4117],
        [-1.0175, -1.3914, -1.7474, -2.2155, -2.3684],
        [-0.0828, -0.1114, -0.2452, -0.2955, -0.8152]], device='cuda:0',
       grad_fn=<TopkBackward>),
indices=tensor([[43, 68, 56, 17,  1],
        [43, 68, 28, 57, 52],
        [20, 43, 68, 17, 57]], device='cuda:0'))

In [442]:
test_sample[1][[[43, 68, 56, 17,  1],
        [43, 68, 28, 57, 52],
        [20, 43, 68, 17, 57]]]

IndexError: too many indices for tensor of dimension 2

In [443]:
test_sample[1]

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0.]])

In [425]:
pred[1][1]

tensor([43, 68, 20, 17,  1], device='cuda:0')

In [439]:
test_sample[1][1][[43, 68, 28, 57, 52]]

tensor([1., 1., 0., 0., 0.])

In [433]:
test_sample[1][1]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])