# Jupyter notebook do trenowania klasyfikatora zachowań

Notatnik jest oparty na szkielecie klasyfikatora przedstawionym podczas laboratoriów z przedmiotu GSN (Głębokie Sieci Neuronowe w Mediach Cyfrowych).

# Instalacja niestandardowych bibliotek (Colab)

In [None]:
# download the library used to load video data stored as frames (.jpg)
!git clone https://github.com/RaivoKoot/Video-Dataset-Loading-Pytorch.git
# rename the checkout so that it can be imported as the package
# Video_Dataset_Loading_Pytorch (hyphens are not valid in a Python module name)
!mv Video-Dataset-Loading-Pytorch Video_Dataset_Loading_Pytorch

In [None]:
# install PyTorch Lightning and Weights & Biases (W&B)
!pip install pytorch-lightning --quiet
!pip install wandb --quiet

# Import bibliotek

In [1]:
# standardowe pakiety
import os
import numpy as np

# Pytorch
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import models
from torchvision.datasets import ImageFolder
from torchvision.transforms import Resize, Compose, ToTensor, Normalize, Grayscale 
from torchvision.transforms import RandomHorizontalFlip, RandomVerticalFlip, RandomRotation, InterpolationMode

# Pytorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import torchmetrics

# Weights and Biases
import wandb

# sklearn
from sklearn.metrics import ConfusionMatrixDisplay

# matplotlib
import matplotlib.pyplot as plt

# Custom Video Dataset Loading function
# https://github.com/RaivoKoot/Video-Dataset-Loading-Pytorch
from Video_Dataset_Loading_Pytorch.video_dataset import VideoFrameDataset, ImglistToTensor

# Inceptionv4 implementation
# https://github.com/zhulf0804/Inceptionv4_and_Inception-ResNetv2.PyTorch
from inceptionv4_model.model.inceptionv4 import Inceptionv4

# Definicja funkcji oraz procedur trenujących/testujących

In [2]:
# Dataset download helper used when training in Colab.
# NOTE: the `!` lines are IPython shell escapes, so this function only works
# inside a notebook / IPython session, not as plain Python.
# data_dir: directory that receives the train/ and test/ folders.
def custom_rats_loader(data_dir):
  # fetch the dataset by its id with gdown
  # (presumably this is RatDataset.zip, which is unpacked below - confirm)
  !gdown 1Fn8LqXnsLnwRvIU9Zbpdlf5zIhwcrYnX
  # create the target directory
  !mkdir $data_dir
  # unpack quietly into the current working directory
  !unzip -q RatDataset.zip
  # move the train/ and test/ folders into data_dir
  # (presumably both come from the archive - confirm)
  !mv train $data_dir
  !mv test $data_dir
  # remove the RatDataset/ directory if present
  !rm -rf RatDataset/

# https://stackoverflow.com/questions/74920920/pytorch-apply-data-augmentation-on-training-data-after-random-split
class TransformDataset(Dataset):
    """Dataset view that applies a transform to the samples of another dataset.

    Useful after ``random_split``: the split subsets share one underlying
    dataset, so per-subset transforms have to be applied by a wrapper.

    Args:
        base_dataset: Indexable dataset yielding ``(sample, label)`` pairs.
        transformations: Callable applied to ``sample`` on every access.
    """

    def __init__(self, base_dataset, transformations):
        super().__init__()
        self.transformations = transformations
        self.base = base_dataset

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        sample_count = len(self.base)
        return sample_count

    def __getitem__(self, idx):
        """Return ``(transformed_sample, label)`` for position ``idx``."""
        sample, target = self.base[idx]
        transformed = self.transformations(sample)
        return transformed, target

class RatDataModule(pl.LightningDataModule):
    """Lightning data module for the rat-behaviour clip dataset.

    Every sample is read by ``VideoFrameDataset`` as a single segment of 121
    frames named ``000.jpg``, ``001.jpg``, ... and is described by an
    ``annotations.txt`` file located in the dataset root.

    Args:
        downloader: Callable taking ``data_dir`` that downloads the dataset.
            Only used when the Colab lines in ``prepare_data`` are enabled.
        data_dir: Directory the downloader populates.
        train_dataset_path: Root of the training set (one sub-directory per
            class plus ``annotations.txt``).
        test_dataset_path: Root of the test set, same layout.
        val_dataset_path: Optional root of a separate validation set with the
            same layout. When ``None``, 10% of the training set is split off
            for validation.
        batch_size: Batch size used by all three dataloaders.
    """

    def __init__(self, downloader=custom_rats_loader,
                 data_dir: str = 'ratset_optic/',
                 train_dataset_path: str = 'ratset_optic/train',
                 test_dataset_path: str = 'ratset_optic/test',
                 val_dataset_path=None,
                 batch_size: int = 10):
        super().__init__()
        self.downloader = downloader
        self.data_dir = data_dir
        self.train_dataset_path = train_dataset_path
        self.test_dataset_path = test_dataset_path
        self.val_dataset_path = val_dataset_path
        # Assigned here rather than in prepare_data(), so the dataloaders work
        # even when that hook has not run in the current process.
        self.batch_size = batch_size
        self.image_size = (200, 200)
        # Not used by this class any more; kept so existing code that reads
        # the attribute keeps working.
        self.imagenet_transform = Compose([Resize(self.image_size), ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

        # transform pipeline without augmentation
        self.greyscale_transform = Compose([ImglistToTensor(),
                                            Resize(self.image_size, antialias=None),
                                            Grayscale(num_output_channels=1)])
        # transform pipeline with augmentation
        self.augmentation_transform = Compose([ImglistToTensor(),
                                               Resize(self.image_size, antialias=None),
                                               Grayscale(num_output_channels=1),
                                               RandomHorizontalFlip(),
                                               RandomVerticalFlip(),
                                               RandomRotation(30, interpolation=InterpolationMode.BILINEAR)])
        self.data_prepared = False

    def prepare_data(self):
        """Discover the class names; runs its body only once per instance.

        Sets ``label_names`` (sorted class directory names) and
        ``num_classes``. The ``data_prepared`` flag makes repeated calls
        (the notebook calls this manually and the Trainer calls it again)
        a no-op.
        """
        if self.data_prepared:
            return

        ## Uncomment when running in Colab
        #!rm -rf $self.data_dir
        # download
        #self.downloader(self.data_dir)

        # Only real, non-hidden directories are classes. Stray files or
        # folders such as ".ipynb_checkpoints" would otherwise inflate
        # num_classes and shift the label names.
        self.label_names = sorted(
            name for name in os.listdir(self.train_dataset_path)
            if not name.startswith('.')
            and os.path.isdir(os.path.join(self.train_dataset_path, name))
        )
        self.num_classes = len(self.label_names)
        print(f"Dataset prepared, classes: {self.num_classes}")
        print(self.label_names)
        self.data_prepared = True

    def _make_video_dataset(self, root_path, transform=None):
        """Build a 121-frame, single-segment VideoFrameDataset rooted at ``root_path``."""
        return VideoFrameDataset(
            root_path=root_path,
            annotationfile_path=os.path.join(root_path, "annotations.txt"),
            num_segments=1,
            frames_per_segment=121,
            imagefile_template='{:03d}.jpg',
            transform=transform,
            test_mode=False,
        )

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        ``'fit'``, ``'validate'`` and ``None`` build the train/validation
        datasets; ``'test'`` and ``None`` build the test dataset.
        """
        if stage in ('fit', 'validate', None):
            # No transform on the raw dataset: train and validation need
            # different pipelines, applied by the TransformDataset wrappers.
            dataset = self._make_video_dataset(self.train_dataset_path)

            if self.val_dataset_path is None:
                train_dataset_size = int(len(dataset) * 0.9)
                raw_train_dataset, raw_val_dataset = random_split(
                    dataset, [train_dataset_size, len(dataset) - train_dataset_size])

                self.train_dataset = TransformDataset(raw_train_dataset, self.augmentation_transform)
                self.val_dataset = TransformDataset(raw_val_dataset, self.greyscale_transform)
            else:
                # The training set must still be transformed to tensors,
                # otherwise the DataLoader receives lists of PIL images.
                self.train_dataset = TransformDataset(dataset, self.augmentation_transform)
                # Validation clips use the same frame-folder layout and
                # preprocessing as the test set, so the model gets the
                # 121-frame greyscale input it expects.
                self.val_dataset = self._make_video_dataset(
                    self.val_dataset_path, transform=self.greyscale_transform)

        if stage == 'test' or stage is None:
            self.test_dataset = self._make_video_dataset(
                self.test_dataset_path, transform=self.greyscale_transform)

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the test dataloader."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

In [3]:
# PredictionData - data container that simplifies handling of test results.
# Stores predictions, ground truth and optionally an image;
# used to collect the data that is sent to W&B.
class PredictionData():
    """Container for ``[ground_truth, prediction]`` pairs and optional images.

    Attributes are plain lists:
      values - one ``[ground_truth, prediction]`` entry per appended sample,
      images - images passed to ``append`` (only for samples that had one),
      names  - ``[label_name, prediction_name]`` pairs, filled in externally.

    Args:
        values: Optional initial content of ``values``.
        images: Optional initial content of ``images``.
        names: Optional initial content of ``names``.
    """

    def __init__(self, values=None, images=None, names=None):
        # Copy the inputs so instances never share list objects.
        self.values = [] if values is None else list(values)
        self.images = [] if images is None else list(images)
        self.names = [] if names is None else list(names)

    def append(self, values, image=None):
        """Store one ``values`` entry and, when given, its image."""
        self.values.append(values)
        if image is not None:
            self.images.append(image)

    def copy(self):
        """Return a new PredictionData holding shallow copies of the lists."""
        return PredictionData(self.values, self.images, self.names)

    def append_with_image(self):
        """Return ``[image, label_name, prediction_name]`` rows, one per image.

        Row ``i`` pairs ``images[i]`` with ``names[i]``, so ``names`` must be
        filled and aligned with ``images`` first; otherwise IndexError is
        raised.
        """
        return [[image, self.names[i][0], self.names[i][1]]
                for i, image in enumerate(self.images)]

# SummaryLogger - collects statistics about training and testing
class SummaryLogger(Callback):
    """Lightning callback that logs class examples and test summaries to W&B.

    It reads ``trainer.datamodule`` (``label_names``, ``train_dataset``,
    ``test_dataset``) and the ``test_data``, ``wrong_data`` and ``errors``
    attributes of the LightningModule under test.
    """

    def __init__(self):
        super().__init__()

    def generate_label_names(self, trainer, data_list):
        """Fill ``data_list.names`` with class names for every entry of ``data_list.values``.

        Each entry of ``values`` is a pair of class indices; both are mapped
        through ``trainer.datamodule.label_names``.
        """
        label_names = trainer.datamodule.label_names
        data_list.names = [[label_names[values[0]], label_names[values[1]]]
                           for values in data_list.values]

    def get_class_examples(self, dataset, label_names):
        """Return one ``[wandb.Image, class_name]`` row per class, in class order.

        The first sample seen for each class is used, whatever order the
        dataset yields its samples in. Iteration stops as soon as every class
        has an example, because each access loads and transforms a whole clip.
        Classes that never occur in ``dataset`` are omitted.
        """
        if not label_names:
            return []

        examples = {}
        for im, label in dataset:
            label = int(label)
            if label in examples or not 0 <= label < len(label_names):
                continue
            # im[0]: first element of the sample along its leading dimension
            examples[label] = [wandb.Image(im[0]), label_names[label]]
            if len(examples) == len(label_names):
                break
        return [examples[label] for label in sorted(examples)]

    def on_train_start(self, trainer, pl_module):
        """Log one example per class of the training dataset."""
        example_list = self.get_class_examples(trainer.datamodule.train_dataset,
                                               trainer.datamodule.label_names)
        columns = ["image", "label"]

        test_example_table = wandb.Table(data=example_list, columns=columns)
        trainer.logger.experiment.log({"class_example_trainset" : test_example_table})

    def on_test_start(self, trainer, pl_module):
        """Log one example per class of the test dataset."""
        example_list = self.get_class_examples(trainer.datamodule.test_dataset,
                                               trainer.datamodule.label_names)
        columns = ["image", "label"]

        test_example_table = wandb.Table(data=example_list, columns=columns)
        trainer.logger.experiment.log({"class_example_testset" : test_example_table})

    def on_test_end(self, trainer, pl_module):
        """Log everything collected during the test run, then reset the collectors.

        Logged items:
          - a table of all predictions,
          - a table of wrong predictions with the misclassified image,
          - a bar chart (and its table) of wrong predictions per class,
          - confusion matrices (W&B default, W&B sklearn) and a matplotlib one.
        """
        test_data = pl_module.test_data
        wrong_data = pl_module.wrong_data
        errors = pl_module.errors
        label_names = trainer.datamodule.label_names

        self.generate_label_names(trainer, test_data)
        self.generate_label_names(trainer, wrong_data)
        wrong_data_appended = wrong_data.append_with_image()

        columns = ["label", "class_prediction"]
        test_table = wandb.Table(data=test_data.names, columns=columns)
        trainer.logger.experiment.log({"predictions" : test_table})

        wrong_columns = ["image" ,"label", "class_prediction"]
        wrong_table = wandb.Table(data=wrong_data_appended, columns=wrong_columns)
        trainer.logger.experiment.log({"wrong_predictions" : wrong_table})

        # Reset every collector, including the per-class error counters, so a
        # later test run with the same module starts from zero. The local
        # references above still hold this run's data.
        pl_module.test_data, pl_module.wrong_data = PredictionData(), PredictionData()
        pl_module.errors = np.zeros_like(errors)

        data = [[name, error] for (name, error) in zip(label_names, errors)]
        table = wandb.Table(data=data, columns=["class_name", "errors"])
        trainer.logger.experiment.log({"class_error_chart" : wandb.plot.bar(table, "class_name",
                    "errors", title="Errors by class")})

        ground_truth = [row[0] for row in test_data.values]
        predictions = [row[1] for row in test_data.values]

        trainer.logger.experiment.log({"conf_mat" : wandb.plot.confusion_matrix(probs=None,
                                      y_true=ground_truth, preds=predictions,
                                      class_names=label_names)})

        trainer.logger.experiment.log({"conf_mat_sklearn" : wandb.sklearn.plot_confusion_matrix(ground_truth,
                                      predictions, label_names)})

        ConfusionMatrixDisplay.from_predictions(y_true=ground_truth, y_pred=predictions, display_labels=label_names)
        plt.xlabel("Predykcja klasy")
        plt.ylabel("Prawdziwa klasa")
        plt.show()

# Model dwustrumieniowy

In [4]:
class TwoStreamModel(nn.Module):
    """Two-stream classifier built from two nearly identical DenseNet-121 networks.

    The spatial stream receives slice 0 of dimension 1 of the input, the
    temporal stream receives slices 1..120. Their feature vectors are
    concatenated and classified by a single linear layer.

    Args:
        num_classes: Number of output classes.
    """

    def __init__(self,num_classes: int):
        super().__init__()

        # temporal stream: 120 input channels
        self.model_temporal = self._densenet_backbone(in_channels=120)
        # Inceptionv4 variant for the temporal stream (comment out the line
        # above and enable the lines marked with a double comment sign)
        ##self.model_temporal = Inceptionv4(in_channels=120)
        # out = 1536
        ##self.model_temporal.linear = nn.Identity()

        # spatial stream: a single input channel
        self.model_spatial = self._densenet_backbone(in_channels=1)

        # classifier for the concatenated features (open question from the
        # author: would performance be better removing the linear layer from densenet?)
        # in_features = DenseNet - 2048, Inception - 2560
        self.final_classifier = nn.Linear(in_features=2048, out_features=num_classes)

    @staticmethod
    def _densenet_backbone(in_channels):
        """Return a DenseNet-121 feature extractor taking ``in_channels`` input channels."""
        backbone = models.densenet121()
        # replace the input layer so that it accepts the supplied data format
        backbone.features.conv0 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7),
                                            stride=(2, 2), padding=(3, 3), bias=False)
        # drop the classifier built into densenet; out = 1024
        backbone.classifier = nn.Identity()
        return backbone

    def forward(self, x):
        """Classify a batch; ``x`` is indexed along dimension 1 as described on the class."""
        # split the data into the two streams
        first_slice = x[:, 0].unsqueeze(1)
        remaining_slices = x[:, 1:121]

        # spatial features first, then temporal ones
        spatial_features = self.model_spatial(first_slice)
        temporal_features = self.model_temporal(remaining_slices)

        # concatenate into one spatio-temporal vector and classify it
        combined = torch.cat((spatial_features, temporal_features), dim=1)
        return self.final_classifier(combined)

# Klasyfikator

In [5]:
# LightningModule that bundles the model with the
# training/validation/test procedures
class RatBehaviorClassifier(pl.LightningModule):
    """Lightning wrapper around ``TwoStreamModel`` with cross-entropy training.

    During testing it also fills ``test_data`` (all predictions),
    ``wrong_data`` (misclassified samples with an image) and ``errors``
    (count of wrong predictions per true class).

    Args:
        num_classes: Number of output classes.
        learning_rate: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, num_classes: int, learning_rate: float=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.num_classes = num_classes
        self.learning_rate = learning_rate

        self.model = TwoStreamModel(num_classes)

        self.test_data, self.wrong_data = PredictionData(), PredictionData()
        self.errors = np.zeros(num_classes)

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        return self.model(x)

    def compute_loss(self, x, y):
        """Return the cross-entropy between outputs ``x`` and targets ``y``."""
        return F.cross_entropy(x, y)

    def common_step(self, batch, batch_idx):
        """Return ``(loss, outputs, targets)`` for one batch."""
        frames, targets = batch
        # drop dimension 2 when it has size 1
        # (presumably the single greyscale channel - confirm)
        logits = self(torch.squeeze(frames, dim=2))
        return self.compute_loss(logits, targets), logits, targets

    def _loss_acc_preds(self, batch, batch_idx):
        """Return ``(loss, accuracy, targets, predictions)`` for one batch."""
        loss, logits, targets = self.common_step(batch, batch_idx)
        predicted = torch.argmax(logits, dim=1)
        accuracy = torchmetrics.functional.accuracy(
            predicted, targets, num_classes=self.num_classes, task="multiclass")
        return loss, accuracy, targets, predicted

    def common_train_step(self, batch, batch_idx):
        """Return ``(loss, accuracy, targets, predictions)`` for a training batch."""
        return self._loss_acc_preds(batch, batch_idx)

    def common_valid_step(self, batch, batch_idx):
        """Return ``(loss, accuracy, targets, predictions)`` for a validation batch."""
        return self._loss_acc_preds(batch, batch_idx)

    def common_test_step(self, batch, batch_idx):
        """Evaluate a test batch, record its predictions and return ``(loss, accuracy)``."""
        loss, accuracy, targets, predicted = self.common_valid_step(batch, batch_idx)
        frames, _ = batch

        for position, (target, guess) in enumerate(zip(targets.tolist(), predicted.tolist())):
            self.test_data.append(values=[target, guess])
            if target == guess:
                continue
            # misclassified: count it for the true class and keep an image
            self.errors[target] += 1
            self.wrong_data.append(values=[target, guess],
                                   image=wandb.Image(frames[position][0]))

        return loss, accuracy

    def training_step(self, batch, batch_idx):
        """Log step/epoch loss and accuracy; return the loss."""
        loss, accuracy, *_ = self.common_train_step(batch, batch_idx)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', accuracy, on_step=True, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy; return the loss."""
        loss, accuracy, *_ = self.common_valid_step(batch, batch_idx)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', accuracy, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log test loss and accuracy; return the loss."""
        loss, accuracy = self.common_test_step(batch, batch_idx)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', accuracy, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [6]:
# Callbacks used during training

# directory and file-name template for the saved checkpoints
MODEL_CKPT_PATH = 'model/'
MODEL_CKPT = 'model-{epoch:02d}-{val_loss:.8f}'

# Monitor 'val_loss' (logged by the classifier's validation_step) and keep the
# five checkpoints with the lowest value.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=MODEL_CKPT_PATH,
    filename=MODEL_CKPT,
    save_top_k=5,
    mode='min')

# Wczytanie datasetu oraz trening i test

In [None]:
# DataModule
dm = RatDataModule()
dm.prepare_data() # data download (when enabled), establishes the number of classes
dm.setup()

In [None]:
# Model initialisation - from scratch or from a checkpoint
model = RatBehaviorClassifier(num_classes=dm.num_classes)
#model = RatBehaviorClassifier.load_from_checkpoint("best_models/model_threeclass_aug.ckpt")

# W&B logger initialisation
wandb_logger = WandbLogger(name="three-class-inception-optical-aug", project='rat-behavior-analysis', job_type='train')

# Trainer initialisation
trainer = pl.Trainer(accelerator="auto", max_epochs=100, logger=wandb_logger,
                     callbacks=[checkpoint_callback, SummaryLogger()], log_every_n_steps=1)

In [None]:
# Train the model
torch.set_float32_matmul_precision("high")
trainer.fit(model, dm)

In [10]:
# Load the best training result - ONLY AFTER TRAINING FOR AT LEAST ONE EPOCH
# (the path comes from the checkpoint callback used during training)
best_model = RatBehaviorClassifier.load_from_checkpoint(checkpoint_callback.best_model_path)

In [None]:
# Test - to be used after training has finished
# Model evaluation
trainer.test(model=best_model, datamodule=dm)

# Close the W&B run
wandb.finish()

# Test wytrenowanego modelu wczytanego z checkpoint (do użycia tylko przy pominięciu treningu)

In [10]:
# DataModule
dm = RatDataModule()
# prepare_data() discovers the class names (label_names / num_classes)
dm.prepare_data() 
dm.setup()

Dataset prepared, classes: 3
['grooming', 'idle', 'reeling']


In [None]:
# Load the best training result from a folder for repeated tests
checkpoint_model = RatBehaviorClassifier.load_from_checkpoint("best_models/model_threeclass_aug.ckpt")

# separate W&B run for the test-only session
wandb_logger = WandbLogger(name="test_discard", project='rat-behavior-analysis', job_type='test')

# Trainer
trainer = pl.Trainer(accelerator="auto", max_epochs=100, logger=wandb_logger,
                     callbacks=[SummaryLogger()])

# Model evaluation
trainer.test(model=checkpoint_model, datamodule=dm)

# Close the W&B run
wandb.finish()