In [1]:
!pip install pytorch_lightning wandb

Collecting pytorch_lightning
[?25l  Downloading https://files.pythonhosted.org/packages/ed/af/2f10c8ee22d7a05fe8c9be58ad5c55b71ab4dd895b44f0156bfd5535a708/pytorch_lightning-0.9.0-py3-none-any.whl (408kB)
[K     |████████████████████████████████| 409kB 2.8MB/s 
[?25hCollecting wandb
[?25l  Downloading https://files.pythonhosted.org/packages/2e/d2/f864e4fea30223a694b1454fbe8634eab70d409b5185ec56914bae04d1e8/wandb-0.10.2-py2.py3-none-any.whl (1.6MB)
[K     |████████████████████████████████| 1.6MB 12.8MB/s 
[?25hCollecting tensorboard==2.2.0
[?25l  Downloading https://files.pythonhosted.org/packages/54/f5/d75a6f7935e4a4870d85770bc9976b12e7024fbceb83a1a6bc50e6deb7c4/tensorboard-2.2.0-py3-none-any.whl (2.8MB)
[K     |████████████████████████████████| 2.8MB 19.2MB/s 
Collecting PyYAML>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |████████████████████████████████| 

In [2]:
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torchtext import data
import numpy as np
import pytorch_lightning as pl

from torchsummary import summary
from pytorch_lightning.loggers import WandbLogger

In [6]:
TEXT = data.Field(tokenize='spacy',batch_first=True,include_lengths=True)
LABEL = data.LabelField(dtype = torch.long,batch_first=True)

train_data=data.TabularDataset(path = "./drive/My Drive/train.txt",csv_reader_params={'delimiter':';'},format = 'csv',fields = [('text',TEXT),('label',LABEL)],skip_header = False)
val_data = data.TabularDataset(path = "./drive/My Drive/val.txt",csv_reader_params={'delimiter':';'},format = 'csv',fields = [('text',TEXT),('label',LABEL)],skip_header = False)
test_data = data.TabularDataset(path = "./drive/My Drive/test.txt",csv_reader_params={'delimiter':';'},format = 'csv',fields = [('text',TEXT),('label',LABEL)],skip_header = False)

train_iter, valid_iter = data.BucketIterator.splits(
    (train_data, val_data), 
    batch_size = 64,
    sort_key = lambda x: len(x.text),
    sort_within_batch=True,
    shuffle = True,
    device = 'cuda' if torch.cuda.is_available() else 'cpu')


In [14]:
TEXT.build_vocab(train_data,min_freq=3,vectors = "glove.6B.100d")  
LABEL.build_vocab(train_data)

#No. of unique tokens in text
print("Size of TEXT vocabulary:",len(TEXT.vocab))

#No. of unique tokens in label
print("Size of LABEL vocabulary:",len(LABEL.vocab))

#Commonly used words
print(TEXT.vocab.freqs.most_common(10))  

#Word dictionary
print(TEXT.vocab.stoi)   

.vector_cache/glove.6B.zip: 862MB [06:30, 2.21MB/s]                           
100%|█████████▉| 398911/400000 [00:23<00:00, 19050.73it/s]

Size of TEXT vocabulary: 5226
Size of LABEL vocabulary: 6
[('i', 29007), ('feel', 11183), ('and', 9589), ('to', 8972), ('the', 8370), ('a', 6201), ('that', 5217), ('feeling', 5112), ('of', 4990), ('my', 4283)]


In [8]:
class EmotionClassifier(pl.LightningModule):

  def __init__(self, num_embedding, embedding_dim = 100, hidden_size = 512):
    super(EmotionClassifier,self).__init__()

    self.embedding = nn.Embedding(num_embedding, embedding_dim)

    self.lstm = nn.LSTM(embedding_dim, 
                        hidden_size, 
                        num_layers=2, 
                        bidirectional=True, 
                        dropout=0.15,
                        batch_first=True)
    
    self.fc = nn.Linear(hidden_size * 2, 6)

    self.activation = nn.Softmax(dim = 1)

    self.crit = nn.CrossEntropyLoss();

  def forward(self,x):
    text, text_lengths = x.text
    embed = self.embedding(text)
    packed = nn.utils.rnn.pack_padded_sequence(embed,text_lengths,batch_first=True)
    _, (hidden, _) = self.lstm(packed)
    #hidden = [batch size, num layers * num directions,hid dim]
    #cell = [batch size, num layers * num directions,hid dim]
  
    hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)

    result = self.fc(hidden)
    return self.activation(result)

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    return optimizer
  
  def training_step(self, batch, batch_idx):
    y_hat = self(batch)
    loss = self.crit(y_hat, batch.label)
    result = pl.TrainResult(loss)
    result.log('train_loss', loss)
    _,idx = torch.max(y_hat,dim=1)
    acc = (batch.label == idx).sum().float()/len(batch.label)
    result.log('train_acc',acc)
    return result

  def validation_step(self, batch, batch_idx):
    y_hat = self(batch)
    loss = self.crit(y_hat, batch.label)
    result = pl.EvalResult(checkpoint_on=loss)
    result.log('val_loss', loss)
    # print(len(batch.label),batch.label.shape,y_hat.shape)
    _,idx = torch.max(y_hat,dim=1)
    acc = (batch.label == idx).sum().float()/len(batch.label)
    result.log('val_acc',acc)
    return result

  def test_step(self, batch, batch_idx):
    y_hat = self(batch)
    loss = self.crit(y_hat, batch.label)
    _,idx = torch.max(y_hat,dim=1)
    acc = (batch.label == idx).sum().float()/len(batch.label)
    result = pl.EvalResult()
    result.log('test_loss', loss)
    result.log('test_acc',acc)
    return result

In [13]:
model = EmotionClassifier(len(TEXT.vocab))

wandb = WandbLogger(project = "Emotion_Classification")
trainer = pl.Trainer(gpus=-1,
                     weights_save_path="./drive/My Drive/EC_save", 
                     max_epochs=100,
                    #  limit_train_batches=0.1,
                    #  limit_val_batches=0.1,
                     logger=wandb,
                     early_stop_callback=True,
                     )
trainer.fit(model,train_iter,valid_iter)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Waiting for W&B process to finish, PID 582
[34m[1mwandb[0m: Program ended successfully.





[34m[1mwandb[0m: \ 0.01MB of 0.01MB uploaded (0.00MB deduped)[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: Find user logs for this run at: wandb/run-20200928_132732-2qkf0722/logs/debug.log
[34m[1mwandb[0m: Find internal logs for this run at: wandb/run-20200928_132732-2qkf0722/logs/debug-internal.log
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:   global_step 8749
[34m[1mwandb[0m:    train_loss 1.0597378015518188
[34m[1mwandb[0m:     train_acc 0.984375
[34m[1mwandb[0m:         epoch 34
[34m[1mwandb[0m:         _step 209
[34m[1mwandb[0m:      _runtime 606
[34m[1mwandb[0m:    _timestamp 1601300262
[34m[1mwandb[0m:    valid_loss 1.1236865520477295
[34m[1mwandb[0m:     valid_acc 0.9194999933242798
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:   global_step ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
[34m[1mwandb[0m:    train_loss ██▇▇▆▄▄▃▄▄▂▂▂▂▁▂▂▁▂▁▁▂▂▂▁▂▁▁▁▂▁▁▂▁▁▂▁▁▁▂
[3


  | Name       | Type             | Params
------------------------------------------------
0 | embedding  | Embedding        | 522 K 
1 | lstm       | LSTM             | 8 M   
2 | fc         | Linear           | 6 K   
3 | activation | Softmax          | 0     
4 | crit       | CrossEntropyLoss | 0     





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..





1

In [12]:
trainer=pl.Trainer(resume_from_checkpoint="/content/drive/My Drive/EC_save/Emotion_Classification/1ob1vexv/checkpoints/epoch=77.ckpt",gpus=-1)
model = EmotionClassifier(5226)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


In [15]:
test_iter = data.BucketIterator(test_data,64,sort_key=lambda x: len(x.text),sort_within_batch=True)
trainer.test(model,test_iter)



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.9180, device='cuda:0'),
 'test_loss': tensor(1.1244, device='cuda:0')}
--------------------------------------------------------------------------------



[{'test_acc': 0.9179999828338623, 'test_loss': 1.1244350671768188}]

In [None]:
from torch.utils.data import TensorDataset, DataLoader

model = EmotionClassifier();
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')
state_dict = torch.load('./drive/My Drive/EC_save/EC1 ep acc loss.pt',map_location = 'cuda' if torch.cuda.is_available() else 'cpu');
model.load_state_dict(state_dict)
model.eval();

datal= DataLoader(TensorDataset(test_data,torch.tensor(test_labels)),batch_size=1, pin_memory = True, shuffle = True);
sum = 0;
i = 0;
confusion_matrix = np.zeros((6,6))  # row is truth and col is prediction
with torch.no_grad():
  for i,item in enumerate(datal):
    data = item[0].to('cuda' if torch.cuda.is_available() else 'cpu') 
    label = item[1].to('cuda' if torch.cuda.is_available() else 'cpu')
    output = model(data);
    _,y_pred = torch.max(output,1);
    # print(i,output.tolist(), y_pred.tolist(),label.tolist())
    for k in range(label.shape[0]):
      confusion_matrix[label[k],y_pred[k]] += 1
    sum = sum + (torch.sum(y_pred == label)).item();
    if (i == 500):
      break
  rowsum = np.sum(confusion_matrix,axis = 1);
  rowsum = np.where(rowsum == 0, 1, rowsum)
  confusion_matrix /= rowsum[:,None]
  print("output = ",sum/i);
  from sklearn.metrics import ConfusionMatrixDisplay
  disp = ConfusionMatrixDisplay(confusion_matrix,range(15))
  disp.plot(values_format='.1f')