In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
!pip install transformers==4.19.2
!pip install pytorch_lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


## T5 fine-tuning


In [4]:
import argparse
import glob
import os
import json
import time
import logging
import random
import re
from itertools import chain
from string import punctuation

import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl


from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [5]:
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

set_seed(42)

## Model


In [6]:
class T5FineTuner(pl.LightningModule):
  def __init__(self, hp):
    super(T5FineTuner, self).__init__()
    self.hp = hp
    
    self.model = T5ForConditionalGeneration.from_pretrained(hp.model_name_or_path)
    self.tokenizer = T5Tokenizer.from_pretrained(hp.tokenizer_name_or_path)
  
  def is_logger(self):
    return self.trainer.global_rank <= 0
  
  def forward(
      self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, labels=None
  ):
    return self.model(
        input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        labels=labels,
    )

  def _step(self, batch):
    labels = batch["target_ids"].to(device)
    labels[labels[:, :] == self.tokenizer.pad_token_id] = -100

    outputs = self(
        input_ids=batch["source_ids"].to(device),
        attention_mask=batch["source_mask"].to(device),
        labels=labels,
        decoder_attention_mask=batch['target_mask'].to(device)
    )

    loss = outputs[0]

    return loss

  def training_step(self, batch, batch_idx):
    loss = self._step(batch)

    tensorboard_logs = {"train_loss": loss}
    return {"loss": loss, "log": tensorboard_logs}
  
  def training_epoch_end(self, outputs):
    avg_train_loss = torch.stack([x["loss"] for x in outputs]).mean()
    tensorboard_logs = {"avg_train_loss": avg_train_loss}
    # return {"avg_train_loss": avg_train_loss, "log": tensorboard_logs, 'progress_bar': tensorboard_logs}

  def validation_step(self, batch, batch_idx):
    loss = self._step(batch)
    return {"val_loss": loss}
  
  def validation_epoch_end(self, outputs):
    avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
    tensorboard_logs = {"val_loss": avg_loss}
    return {"avg_val_loss": avg_loss, "log": tensorboard_logs, 'progress_bar': tensorboard_logs}

  def configure_optimizers(self):
    "Prepare optimizer and schedule (linear warmup and decay)"

    model = self.model
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": self.hp.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=self.hp.learning_rate, eps=self.hp.adam_epsilon)
    self.opt = optimizer
    return [optimizer]
  
  # def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, on_tpu = None, second_order_closure=None):
  #   if self.trainer.use_tpu:
  #     xm.optimizer_step(optimizer)
  #   else:
  #     optimizer.step()
  #   optimizer.zero_grad()
  #   self.lr_scheduler.step()
  
  def get_tqdm_dict(self):
    tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}

    return tqdm_dict

  def train_dataloader(self):
    train_dataset = get_dataset(self.tokenizer, train_df, self.hp)
    dataloader = DataLoader(train_dataset, batch_size=self.hp.train_batch_size, drop_last=True, shuffle=True, num_workers=2)
    t_total = (
        (len(dataloader.dataset) // (self.hp.train_batch_size * max(1, self.hp.n_gpu)))
        // self.hp.gradient_accumulation_steps
        * float(self.hp.num_train_epochs)
    )
    scheduler = get_linear_schedule_with_warmup(
        self.opt, num_warmup_steps=self.hp.warmup_steps, num_training_steps=t_total
    )
    self.lr_scheduler = scheduler
    return dataloader

  def val_dataloader(self):
    val_dataset = get_dataset(self.tokenizer, val_df, self.hp)
    return DataLoader(val_dataset, batch_size=self.hp.eval_batch_size, num_workers=2)

In [7]:
logger = logging.getLogger(__name__)

class LoggingCallback(pl.Callback):
  def on_validation_end(self, trainer, pl_module):
    logger.info("***** Validation results *****")
    if pl_module.is_logger():
      metrics = trainer.callback_metrics
      # Log results
      for key in sorted(metrics):
        if key not in ["log", "progress_bar"]:
          logger.info("{} = {}\n".format(key, str(metrics[key])))

  def on_test_end(self, trainer, pl_module):
    logger.info("***** Test results *****")

    if pl_module.is_logger():
      metrics = trainer.callback_metrics

      # Log and save results to file
      output_test_results_file = os.path.join(pl_module.hp.output_dir, "test_results.txt")
      with open(output_test_results_file, "w") as writer:
        for key in sorted(metrics):
          if key not in ["log", "progress_bar"]:
            logger.info("{} = {}\n".format(key, str(metrics[key])))
            writer.write("{} = {}\n".format(key, str(metrics[key])))

In [8]:
args_dict = dict(
    data_dir="", # path for data files
    output_dir="", # path to save the checkpoints
    model_name_or_path='t5-small',
    tokenizer_name_or_path='t5-small',
    max_seq_length=512,
    learning_rate=3e-4,
    weight_decay=0.0,
    adam_epsilon=1e-8,
    warmup_steps=0,
    train_batch_size=8,
    eval_batch_size=8,
    num_train_epochs=2,
    gradient_accumulation_steps=16,
    n_gpu=1,
    early_stop_callback=False,
    fp_16=False, # if you want to enable 16-bit training then install apex and set this to true
    opt_level='O1', # you can find out more on optimisation levels here https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
    max_grad_norm=1.0, # if you enable 16-bit training then set this to a sensible value, 0.5 is a good default
    seed=42,
)

### Download and view data

In [None]:
from google.colab import files
files.upload()         # expire any previous token(s) and upload recreated token

In [10]:
! pip install -q kaggle

!rm -r ~/.kaggle
!mkdir ~/.kaggle
!mv ./kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
! kaggle competitions download -c jigsaw-toxic-comment-classification-challenge
! rm -r data
! mkdir data
! unzip jigsaw-toxic-comment-classification-challenge.zip -d data
! unzip data/test.csv.zip -d data
! unzip data/train.csv.zip -d data
! unzip data/test_labels.csv.zip -d data

jigsaw-toxic-comment-classification-challenge.zip: Skipping, found more recently modified local copy (use --force to force download)
Archive:  jigsaw-toxic-comment-classification-challenge.zip
  inflating: data/sample_submission.csv.zip  
  inflating: data/test.csv.zip       
  inflating: data/test_labels.csv.zip  
  inflating: data/train.csv.zip      
Archive:  data/test.csv.zip
  inflating: data/test.csv           
Archive:  data/train.csv.zip
  inflating: data/train.csv          
Archive:  data/test_labels.csv.zip
  inflating: data/test_labels.csv    


In [11]:
train_df = pd.read_csv('/content/data/train.csv')
test_df = pd.read_csv('/content/data/test.csv')
test_labels_df = pd.read_csv('/content/data/test_labels.csv')

In [12]:
# preprocessing
def clean_data(df):
    df['comment_text'] = df['comment_text'].str.replace('-', ' ')
    df['comment_text'] = df["comment_text"].str.replace('\n', ' ')
    df['comment_text'] = df['comment_text'].apply(lambda x: x.lower())
    df['comment_text'] = df['comment_text'].apply(lambda x: re.sub(r'\s([?.!"](?:\s|$))', r'\1', x))
    df['comment_text'] = df['comment_text'].apply(lambda x: " ".join(x.split()))
    return df

def get_texts(df):
    df = clean_data(df)
    df['comment_text'] = "multilabel classification: " + df['comment_text']
    return df['comment_text'].values.tolist()

#can be rewritten
def get_labels(df):
    labels_list = [' '.join(x.lower().split()) for x in df.columns.to_list()[3:]]
    labels_matrix = np.array([labels_list] * len(df))

    mask = df.iloc[:, 3:].values.astype(bool)
    labels = []
    for l, m in zip(labels_matrix, mask):
        x = l[m]
        if len(x) > 0:
            labels.append(' , '.join(x.tolist()))
        else:
            labels.append('none')
    return labels


In [13]:
#joining test data and test labels 

test_df['toxic'] = test_labels_df['toxic']
test_df['severe'] = test_labels_df['severe_toxic']
test_df['obscene'] = test_labels_df['obscene']
test_df['threat'] = test_labels_df['threat']
test_df['insult'] = test_labels_df['insult']
test_df['identity'] = test_labels_df['identity_hate']

# remove -1 labels
test_df = test_df[test_df['toxic'] > -1]

#rename labels in training data
train_df.rename(columns={'severe_toxic':'severe'}, inplace=True)
train_df.rename(columns={'identity_hate':'identity'}, inplace=True)


In [58]:
total = len(test_df['toxic'])
x = 0
for i in test_df['toxic']:
    x+=1

print(x, total, x/total*100)

63978 63978 100.0


In [14]:
# train_df['target'] = get_labels(train_df)
# test_df['target'] = get_labels(test_df)
# train_df = train_df[['id', 'comment_text', 'target']]
# test_df = test_df[['id', 'comment_text', 'target']]

In [15]:
# train_df = clean_data(train_df)
# test_df = clean_data(test_df)

In [16]:
# train_df = train_df[:20000]
train_df.head()

Unnamed: 0,id,comment_text,toxic,severe,obscene,threat,insult,identity
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0


In [17]:
# test_df = test_df[:5000]
test_df.head()

Unnamed: 0,id,comment_text,toxic,severe,obscene,threat,insult,identity
5,0001ea8717f6de06,Thank you for understanding. I think very high...,0,0,0,0,0,0
7,000247e83dcc1211,:Dear god this site is horrible.,0,0,0,0,0,0
11,0002f87b16116a7f,"""::: Somebody will invariably try to add Relig...",0,0,0,0,0,0
13,0003e1cccfd5a40a,""" \n\n It says it right there that it IS a typ...",0,0,0,0,0,0
14,00059ace3e3e9a53,""" \n\n == Before adding a new product to the l...",0,0,0,0,0,0


In [18]:
SEED = 42
from sklearn.model_selection import train_test_split
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state = SEED)
train_df.shape, train_df.shape

((127656, 8), (127656, 8))

In [19]:
val_df.shape

(31915, 8)

### Dataset

In [20]:
source_max_length = 512
target_max_length = 24
model_name = 't5-small'

In [21]:
class T5Dataset(Dataset):
    def __init__(self, tokenizer, df, max_len):
        super(T5Dataset, self).__init__()

        self.texts = get_texts(df)
        self.labels = get_labels(df)
        self.tokenizer = tokenizer
        self.src_max_length = source_max_length
        self.tgt_max_length = target_max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        source_token_df = self.tokenizer.encode_plus(
            self.texts[index], 
            max_length=self.src_max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=False,
            return_tensors='pt'
        )
        source_input_ids = source_token_df['input_ids'].squeeze()
        source_attention_masks = source_token_df['attention_mask'].squeeze()

        target_token_df = self.tokenizer.encode_plus(
            self.labels[index], 
            max_length=self.tgt_max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=False,
            return_tensors='pt'
        )

        target_input_ids = target_token_df['input_ids'].squeeze()
        target_attention_masks = target_token_df['attention_mask'].squeeze()

        return { 
            'source_ids': source_input_ids.long(),
            'source_mask': source_attention_masks.long(),
            'target_ids': target_input_ids.long(),
            'target_mask': target_attention_masks.long()
        }

In [22]:
!pip install sentencepiece

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [23]:
tokenizer = T5Tokenizer.from_pretrained('t5-small')

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [24]:
dataset = T5Dataset(tokenizer, train_df, 512)
len(dataset)

127656

In [25]:
data = dataset[42]
print(tokenizer.decode(data['source_ids'], skip_special_tokens=True, clean_up_tokenization_spaces = True))
print(tokenizer.decode(data['target_ids'], skip_special_tokens=True, clean_up_tokenization_spaces = True))

multilabel classification: you a joker. you point of view is absolutely not neutral. i am going to report you. you rather stop you activity since it is just what the last world need. you show completely lack of history saying that: 1) recovered territories are “supposed” be historical polish territories 2) assuming that the nazi party are free responsibility of wwii results. finally you erase valuable lings for objective true. you have no arguments look for mediator commission. andrew
none


### Train

In [26]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [27]:
!mkdir -p t5_classification

In [28]:
args_dict.update({'data_dir': '/content/data', 'output_dir': 't5_classification', 'num_train_epochs':2})
args = argparse.Namespace(**args_dict)
print(args_dict)

{'data_dir': '/content/data', 'output_dir': 't5_classification', 'model_name_or_path': 't5-small', 'tokenizer_name_or_path': 't5-small', 'max_seq_length': 512, 'learning_rate': 0.0003, 'weight_decay': 0.0, 'adam_epsilon': 1e-08, 'warmup_steps': 0, 'train_batch_size': 8, 'eval_batch_size': 8, 'num_train_epochs': 2, 'gradient_accumulation_steps': 16, 'n_gpu': 1, 'early_stop_callback': False, 'fp_16': False, 'opt_level': 'O1', 'max_grad_norm': 1.0, 'seed': 42}


In [29]:
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=args.output_dir, filename="checkpoint", monitor="val_loss", mode="min", save_top_k=5
)

train_params = dict(
    accumulate_grad_batches=args.gradient_accumulation_steps,
    gpus=args.n_gpu,
    max_epochs=args.num_train_epochs,
    # early_stop_callback=False,
    precision= 16 if args.fp_16 else 32,
    # amp_level=args.opt_level,
    gradient_clip_val=args.max_grad_norm,
    checkpoint_callback=checkpoint_callback,
    callbacks=[LoggingCallback()],
)

In [30]:
def get_dataset(tokenizer, df, args):
  return T5Dataset(tokenizer, df, args.max_seq_length)

In [31]:
model = T5FineTuner(args)
model.to(device)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


T5FineTuner(
  (model): T5ForConditionalGeneration(
    (shared): Embedding(32128, 512)
    (encoder): T5Stack(
      (embed_tokens): Embedding(32128, 512)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=512, out_features=512, bias=False)
                (k): Linear(in_features=512, out_features=512, bias=False)
                (v): Linear(in_features=512, out_features=512, bias=False)
                (o): Linear(in_features=512, out_features=512, bias=False)
                (relative_attention_bias): Embedding(32, 8)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): T5LayerFF(
              (DenseReluDense): T5DenseReluDense(
                (wi): Linear(in_features=512, out_features=2048, bias=False)
                (wo): Linear(in_feature

In [32]:
trainer = pl.Trainer(**train_params)

  f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [33]:
trainer.fit(model)

Missing logger folder: /content/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 60.5 M
-----------------------------------------------------
60.5 M    Trainable params
0         Non-trainable params
60.5 M    Total params
242.026   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

### Evaluation

In [34]:
import textwrap
from tqdm.auto import tqdm
from sklearn import metrics

In [35]:
dataset = T5Dataset(tokenizer, test_df, 512)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [36]:
it = iter(loader)

In [37]:
batch = next(it)
batch["source_ids"].shape

torch.Size([32, 512])

In [38]:
model.to(device);

In [39]:
outs = model.model.generate(input_ids=batch['source_ids'].to(device), 
                              attention_mask=batch['source_mask'].to(device), 
                              max_length=2)

dec = [tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces = True) for ids in outs]

texts = [tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces = True) for ids in batch['source_ids']]
targets = [tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces = True) for ids in batch['target_ids']]

In [40]:
for i in range(32):
    c = texts[i]
    lines = textwrap.wrap("text:\n%s\n" % c, width=100)
    print("\n".join(lines))
    print("\nActual sentiment: %s" % targets[i])
    print("predicted sentiment: %s" % dec[i])
    print("=====================================================================\n")

text: multilabel classification: " *not much compromise there, i think i was reverted in pretty much
everything...let's try again, again bit by bit: 1 intro, no need for ""his"" consecutively (""he is
known for his technical abilities, especially x and y"", suffices like that); 2 wikilinks, one thing
is the competition and the other is the season of competition, so don't remove both the wikilinks
please (it's ""qualified for"", one wikilink, the uefa cup winners' cup another wikilink, don't glue
them please but i was wrong in the name of the competition, i admit it and i apologize); 3 you
continue to write ""2007 08 season"" even though ""2007 08"" is enough, really fail to grasp that
(2007 08 season, then 2009 10 season, then 2010 11 season, gets a bit tiring to read); 4 apparently
we can't have substitute linked even though it's a wp article, fair enough i quit; 5 the int.goals
chart: for the world cup qualifiers we write qualification, for european championship qualifiers we
write q

#### Test Metrics

In [41]:
dataset = T5Dataset(tokenizer, test_df, 512)
loader = DataLoader(dataset, batch_size=32, num_workers=2)
model.model.eval()
outputs = []
targets = []
for batch in tqdm(loader):
  outs = model.model.generate(input_ids=batch['source_ids'].cuda(), 
                              attention_mask=batch['source_mask'].cuda(), 
                              max_length=2)

  dec = [tokenizer.decode(ids, skip_special_tokens = True, clean_up_tokenization_spaces = True) for ids in outs]
  target = [tokenizer.decode(ids, skip_special_tokens = True, clean_up_tokenization_spaces = True) for ids in batch["target_ids"]]
  
  outputs.extend(dec)
  targets.extend(target)

  0%|          | 0/2000 [00:00<?, ?it/s]

In [42]:
labels = ['toxic', 'severe', 'obscene', 'threat', 'insult',  'identity']

In [43]:
for i, out in enumerate(outputs):
  if out not in labels:
    print(i, out)

[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
58869 none
58870 none
58871 none
58872 none
58873 
58874 none
58875 none
58876 none
58877 none
58878 none
58879 none
58880 none
58881 none
58882 none
58883 none
58884 none
58885 none
58886 none
58887 none
58888 none
58889 none
58890 none
58891 none
58892 none
58893 none
58894 none
58895 none
58896 none
58897 none
58898 none
58899 none
58900 none
58901 none
58902 none
58903 none
58904 none
58905 none
58906 none
58907 none
58908 none
58909 none
58910 none
58911 none
58912 none
58913 none
58914 none
58915 none
58916 none
58917 none
58918 none
58919 none
58920 none
58921 none
58922 none
58923 none
58924 
58925 none
58926 none
58927 none
58928 none
58929 none
58930 
58931 
58932 none
58933 none
58934 none
58935 none
58936 none
58937 none
58939 none
58940 none
58941 none
58942 none
58943 none
58944 none
58945 none
58946 none
58947 none
58948 none
58949 none
58950 none
58951 none
58952 none
58953 none
58954 none

In [44]:
metrics.accuracy_score(targets, outputs)

0.8932289224420895

In [45]:
print(metrics.classification_report(targets, outputs))

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                                           precision    recall  f1-score   support

                                                0.00      0.00      0.00         0
                                 identity       0.05      0.02      0.03        81
                                   insult       0.22      0.39      0.28       603
                         insult, identity       0.00      0.00      0.00        85
                                     none       0.98      0.96      0.97     59445
                                  obscene       0.00      0.00      0.00       903
                        obscene, identity       0.00      0.00      0.00        20
                          obscene, insult       0.00      0.00      0.00      1947
                obscene, insult, identity       0.00      0.00      0.00       362
                          obscene, threat       0.00      0.00      0.00         6
                  obscene, threat, insult       0.00      0.00      0.00        65
   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [46]:
model_save_name = 'classifier_v2.pt'
path = F"/content/drive/MyDrive/{model_save_name}" 
torch.save(model.state_dict(), path)

Now lets plot  the confusion matrix and see for which classes our model is getting confused

In [47]:
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt

In [48]:
cm = metrics.confusion_matrix(targets, outputs)

In [49]:
print(cm)

[[    0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0]
 [    4     2    11     0    64     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0]
 [   86     1   238     0   277     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     1     0     0]
 [   10     3    16     0    56     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0]
 [ 1971    20   527     0 56882     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0    45     0     0]
 [  692     0     8     0   193     0     0     0     0     0     0     0
     10     0     0     0     0     0     0     0     0     0     0     0]
 [    7     2     0     0    10     0     0     0     0     0     0     0
      1     0     0     0     0 