In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
from classifier import train_and_evaluate, inference, evaluate_validation,  inference, train_and_evaluate, evaluate_validation, prepare_validation_output
from classifier import MLP, Model_Dataset
import torch
from sklearn.model_selection import train_test_split
from pytorch_optimizer import Lookahead
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
seed_value = 0  # You can use any seed value
from sklearn.metrics import classification_report
# Set seed for CPU
import numpy as np
torch.manual_seed(seed_value)

# Set seed for GPU if available
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd
  from .autonotebook import tqdm as notebook_tqdm


# Prepare Dataset

In [4]:
df = pd.read_csv('image_new.csv')
del df['Unnamed: 0']
df['material_id'] = "image" + df.index.astype(str)
# df.loc[df['label'].isin(['Gore_Violence', 'horror']), 'label'] = 'Negative'

In [5]:
df['label'].unique()

array([0, 1])

In [6]:
df['label'].value_counts()

label
1    7986
0    2536
Name: count, dtype: int64

In [7]:
def update_label(row):
    if str(row)=='1':
        return "Positive"
    else:
        return "Negative"

In [7]:
df['label'] = df['label'].apply(update_label)

In [8]:
df.head(3)

Unnamed: 0,image,temp,label,material_id
0,gorey06ac982080e86f913e1ce19d00970315.jpg,gor,Negative,image0
1,gorey0_Terrifier-2.jpg,gor,Negative,image1
2,gorey1.jpg,gor,Negative,image2


In [9]:
df[df['label']=='Positive']

Unnamed: 0,image,temp,label,material_id
2536,pos0 - Copy.jpg,pos,Positive,image2536
2537,pos0.jpg,pos,Positive,image2537
2538,pos00.jpg,pos,Positive,image2538
2539,posb.jpg,pos,Positive,image2539
2540,posb21.jpg,pos,Positive,image2540
...,...,...,...,...
10517,pos_p_sddefault.jpg,pos,Positive,image10517
10518,pos_p_sea-cliff-hotel-general-fc98067.jpg,pos,Positive,image10518
10519,pos_p_secret-romantic-date-night-kissing-coupl...,pos,Positive,image10519
10520,pos_p__bkuO7brYtpSveutq5DWqIL2qQJ1Adz6o_9rBeFT...,pos,Positive,image10520


In [10]:
df[df['label']=='Negative']

Unnamed: 0,image,temp,label,material_id
0,gorey06ac982080e86f913e1ce19d00970315.jpg,gor,Negative,image0
1,gorey0_Terrifier-2.jpg,gor,Negative,image1
2,gorey1.jpg,gor,Negative,image2
3,gorey1498729.jpg,gor,Negative,image3
4,gorey1lqljQa.jpg,gor,Negative,image4
...,...,...,...,...
2531,horror_h_WM-Film-The-Scariest-Moment-from-Ever...,hor,Negative,image2531
2532,horror_h_zEqyD0SBt6HL7W9JQoWwtd5Do1T.jpg,hor,Negative,image2532
2533,horror_h_zombie-horror-monster-blood-face-260n...,hor,Negative,image2533
2534,horror_h_zombie-woman-bloody-face-portrait-hor...,hor,Negative,image2534


In [11]:
df.sample(5)

Unnamed: 0,image,temp,label,material_id
1024,horror5992_movie_posters.jpg,hor,Negative,image1024
220,gore_violence20_train to busan gore scene.jpg,gor,Negative,image220
3266,pos_1497_random human.jpg,pos,Positive,image3266
8664,pos_495_random sport.jpg,pos,Positive,image8664
1025,horror5993_movie_posters.jpg,hor,Negative,image1025


In [12]:
df_label = df.groupby("material_id")["label"].unique().reset_index().rename(columns = {"label" : "label_list"})
df["label"].unique()

array(['Negative', 'Positive'], dtype=object)

In [13]:
from sklearn.preprocessing import MultiLabelBinarizer
mlb = MultiLabelBinarizer(sparse_output=True)
mlb.fit([df["label"].unique().tolist()])

df_label = df_label.join(pd.DataFrame.sparse.from_spmatrix(
            mlb.transform(df_label.pop('label_list')),
            index = df_label.index,
            columns = mlb.classes_))

In [14]:
unique_labels = list(mlb.classes_)
unique_labels = sorted([x for x in unique_labels ])
one_hot_class_dict = {value: index for index, value in enumerate(unique_labels)}
df_label["target_encoding"] = df_label[unique_labels].values.tolist()

In [15]:
unique_labels = list(mlb.classes_)
unique_labels = sorted([x for x in unique_labels ])
one_hot_class_dict = {value: index for index, value in enumerate(unique_labels)}

In [16]:
one_hot_class_dict

{'Negative': 0, 'Positive': 1}

In [17]:
df = df_label

In [18]:
def update_label(row):
    if row['Negative'] == 1:
        return 'Negative'
    elif row['Positive'] == 1:
        return 'Positive'

    else:
        return None  # or any other default label you prefer

df['label'] = df.apply(update_label, axis=1)

In [19]:
from sklearn.model_selection import train_test_split


train_data, remaining_data = train_test_split(df, test_size=0.4, random_state=42,  stratify=df['label'])

validation_data, test_data = train_test_split(remaining_data, test_size=0.5, random_state=42, stratify=remaining_data['label'])


print("Train set size:", len(train_data))
print("Validation set size:", len(validation_data))
print("Test set size:", len(test_data))


Train set size: 6313
Validation set size: 2104
Test set size: 2105


In [20]:
df

Unnamed: 0,material_id,Negative,Positive,target_encoding,label
0,image0,1,0,"[1, 0]",Negative
1,image1,1,0,"[1, 0]",Negative
2,image10,1,0,"[1, 0]",Negative
3,image100,1,0,"[1, 0]",Negative
4,image1000,1,0,"[1, 0]",Negative
...,...,...,...,...,...
10517,image9995,0,1,"[0, 1]",Positive
10518,image9996,0,1,"[0, 1]",Positive
10519,image9997,0,1,"[0, 1]",Positive
10520,image9998,0,1,"[0, 1]",Positive


### Data Loaders

In [21]:
params = {'batch_size': 32, 
          'num_workers':2,  
          'shuffle': True}

params_inference = {'batch_size': 32, 
          'num_workers': 2, 
          'shuffle': True}

In [22]:
from torch.utils.data import DataLoader
train_generator = DataLoader(Model_Dataset(train_data,
                              embedding_directory = "./image_embeddings/",
                              id_column = "material_id",
                              target = "target_encoding",
                              return_labels = True),
                              **params)

val_generator = DataLoader(Model_Dataset(validation_data,
                                embedding_directory = "./image_embeddings/",
                                id_column = "material_id",
                                target = "target_encoding",
                                return_labels = True),
                                **params_inference)

inference_generator = DataLoader(Model_Dataset(test_data,
                                embedding_directory = "./image_embeddings/",
                                id_column = "material_id",
                                target = "target_encoding",
                                return_labels = True),
                                **params_inference)

### Best Model (Focal Loss a=1, g=2 , SGD)

In [99]:
sample = torch.load('./image_embeddings/image0.pt')
sample = torch.from_numpy(sample)
sample.shape

torch.Size([1024])

In [100]:
import torch
from classifier import Model_Dataset
from classifier import MLP, train_and_evaluate

In [101]:
model_args = {
        "batch_size" : 32, 
        "early_stopping_rounds": 20,
        "num_workers":0,
        "max_epochs" :500,
        "dimensions" :sample.shape[0],
        "hidden_dim" :768,
        "learning_rate" :1e-2,
        }

In [102]:
model = MLP(dimensions = model_args['dimensions'], hidden_dim = model_args['hidden_dim']
                                                , number_of_classes = 2, multilabel_bool=False).to(device)

In [103]:
optimizer = torch.optim.SGD(model.parameters(),lr=model_args['learning_rate'])
# optimizer = torch.optim.Adam(model.parameters(), lr=model_args['learning_rate'])
optimizer = Lookahead(optimizer, k=10, alpha=0.5)
scaler = GradScaler()

In [104]:
import torch.nn.functional as F 
# criterion = nn.CrossEntropyLoss()
class FocalLoss(nn.Module):
    '''
    FocalLoss - nn.Module 
    
    Focal loss applies a modulating term to the cross entropy loss in order to focus learning on hard misclassified examples.
    It is a dynamically scaled cross entropy loss, where the scaling factor decays to zero as confidence in the correct class increases.
    Link - https://paperswithcode.com/method/focal-loss#:~:text=Focal%20loss%20applies%20a%20modulating,in%20the%20correct%20class%20increases.
    '''
    
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        bce_loss = F.binary_cross_entropy(inputs,  targets)
        loss = self.alpha * (1 - torch.exp(-bce_loss)) ** self.gamma * bce_loss
        return loss

criterion = FocalLoss()

In [105]:
updated_model_args = {
        "batch_size" : 32, 
        "early_stopping_rounds": 20,
        "num_workers":0,
        "max_epochs" :1000,
        "dimensions" :sample.shape[0],
        "hidden_dim" :768,
        "learning_rate" : 1e-2,
        "optimizer": optimizer,
        "criterion": criterion,
        "scaler":GradScaler()
        }

In [106]:
model.to(device)

MLP(
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=768, bias=True)
    (1): Dropout(p=0.3, inplace=True)
    (2): LeakyReLU(negative_slope=0.01, inplace=True)
    (3): Linear(in_features=768, out_features=768, bias=True)
    (4): Dropout(p=0.3, inplace=True)
    (5): LeakyReLU(negative_slope=0.01, inplace=True)
    (6): Linear(in_features=768, out_features=2, bias=True)
  )
)

In [107]:
best_model, criterion, training_losses, eval_losses = train_and_evaluate(model,
                                                                         train_generator,
                                                                         val_generator,
                                                                         updated_model_args,
                                                                         device)

starting


Epoch: 1 Training Loss:0.127935 Eval Loss:0.092631:   0%|          | 1/1000 [00:03<1:02:35,  3.76s/it]

Found a better loss!
Epoch: 1 Training Loss:0.127935 Eval Loss:0.092631


Epoch: 2 Training Loss:0.079299 Eval Loss:0.065365:   0%|          | 2/1000 [00:07<1:04:09,  3.86s/it]

Found a better loss!
Epoch: 2 Training Loss:0.079299 Eval Loss:0.065365


Epoch: 3 Training Loss:0.058640 Eval Loss:0.048144:   0%|          | 3/1000 [00:11<1:03:29,  3.82s/it]

Found a better loss!
Epoch: 3 Training Loss:0.058640 Eval Loss:0.048144


Epoch: 4 Training Loss:0.044652 Eval Loss:0.037874:   0%|          | 4/1000 [00:15<1:02:32,  3.77s/it]

Found a better loss!
Epoch: 4 Training Loss:0.044652 Eval Loss:0.037874


Epoch: 5 Training Loss:0.034503 Eval Loss:0.029172:   0%|          | 5/1000 [00:18<1:02:09,  3.75s/it]

Found a better loss!
Epoch: 5 Training Loss:0.034503 Eval Loss:0.029172


Epoch: 6 Training Loss:0.027088 Eval Loss:0.021908:   1%|          | 6/1000 [00:22<1:02:48,  3.79s/it]

Found a better loss!
Epoch: 6 Training Loss:0.027088 Eval Loss:0.021908


Epoch: 7 Training Loss:0.022282 Eval Loss:0.017694:   1%|          | 7/1000 [00:26<1:03:35,  3.84s/it]

Found a better loss!
Epoch: 7 Training Loss:0.022282 Eval Loss:0.017694


Epoch: 8 Training Loss:0.017545 Eval Loss:0.014309:   1%|          | 8/1000 [00:31<1:06:39,  4.03s/it]

Found a better loss!
Epoch: 8 Training Loss:0.017545 Eval Loss:0.014309


Epoch: 9 Training Loss:0.014982 Eval Loss:0.012046:   1%|          | 9/1000 [00:35<1:08:17,  4.13s/it]

Found a better loss!
Epoch: 9 Training Loss:0.014982 Eval Loss:0.012046


Epoch: 10 Training Loss:0.012577 Eval Loss:0.010384:   1%|          | 10/1000 [00:39<1:08:27,  4.15s/it]

Found a better loss!
Epoch: 10 Training Loss:0.012577 Eval Loss:0.010384


Epoch: 11 Training Loss:0.010516 Eval Loss:0.009233:   1%|          | 11/1000 [00:43<1:06:45,  4.05s/it]

Found a better loss!
Epoch: 11 Training Loss:0.010516 Eval Loss:0.009233


Epoch: 12 Training Loss:0.009591 Eval Loss:0.007834:   1%|          | 12/1000 [00:47<1:05:17,  3.97s/it]

Found a better loss!
Epoch: 12 Training Loss:0.009591 Eval Loss:0.007834


Epoch: 13 Training Loss:0.008516 Eval Loss:0.007302:   1%|▏         | 13/1000 [00:51<1:05:10,  3.96s/it]

Found a better loss!
Epoch: 13 Training Loss:0.008516 Eval Loss:0.007302


Epoch: 14 Training Loss:0.007761 Eval Loss:0.006458:   1%|▏         | 14/1000 [00:55<1:06:13,  4.03s/it]

Found a better loss!
Epoch: 14 Training Loss:0.007761 Eval Loss:0.006458


Epoch: 15 Training Loss:0.007017 Eval Loss:0.006173:   2%|▏         | 15/1000 [00:59<1:07:47,  4.13s/it]

Found a better loss!
Epoch: 15 Training Loss:0.007017 Eval Loss:0.006173


Epoch: 16 Training Loss:0.006420 Eval Loss:0.005257:   2%|▏         | 16/1000 [01:03<1:04:07,  3.91s/it]

Found a better loss!
Epoch: 16 Training Loss:0.006420 Eval Loss:0.005257


Epoch: 17 - no improvement005978 Eval Loss:0.005387:   2%|▏         | 16/1000 [01:07<1:04:07,  3.91s/it]
Epoch: 17 - no improvement005978 Eval Loss:0.005387:   2%|▏         | 16/1000 [01:07<1:04:07,  3.91s/it]
Epoch: 18 Training Loss:0.005677 Eval Loss:0.004858:   2%|▏         | 18/1000 [01:10<1:03:41,  3.89s/it]                           

Found a better loss!
Epoch: 18 Training Loss:0.005677 Eval Loss:0.004858


Epoch: 19 Training Loss:0.005394 Eval Loss:0.004695:   2%|▏         | 19/1000 [01:15<1:06:17,  4.05s/it]

Found a better loss!
Epoch: 19 Training Loss:0.005394 Eval Loss:0.004695


Epoch: 20 Training Loss:0.005097 Eval Loss:0.004245:   2%|▏         | 20/1000 [01:19<1:04:15,  3.93s/it]

Found a better loss!
Epoch: 20 Training Loss:0.005097 Eval Loss:0.004245


Epoch: 21 Training Loss:0.005087 Eval Loss:0.004143:   2%|▏         | 21/1000 [01:23<1:05:11,  4.00s/it]

Found a better loss!
Epoch: 21 Training Loss:0.005087 Eval Loss:0.004143


Epoch: 22 - no improvement004836 Eval Loss:0.004313:   2%|▏         | 21/1000 [01:27<1:05:11,  4.00s/it]
Epoch: 22 - no improvement004836 Eval Loss:0.004313:   2%|▏         | 21/1000 [01:27<1:05:11,  4.00s/it]
Epoch: 23 Training Loss:0.004464 Eval Loss:0.003794:   2%|▏         | 23/1000 [01:31<1:09:21,  4.26s/it]                           

Found a better loss!
Epoch: 23 Training Loss:0.004464 Eval Loss:0.003794


Epoch: 24 - no improvement004195 Eval Loss:0.004027:   2%|▏         | 23/1000 [01:38<1:09:21,  4.26s/it]
Epoch: 24 - no improvement004195 Eval Loss:0.004027:   2%|▏         | 23/1000 [01:38<1:09:21,  4.26s/it]
Epoch: 25 Training Loss:0.004091 Eval Loss:0.003756:   2%|▎         | 25/1000 [01:45<1:28:40,  5.46s/it]                           

Found a better loss!
Epoch: 25 Training Loss:0.004091 Eval Loss:0.003756


Epoch: 26 Training Loss:0.003826 Eval Loss:0.003209:   3%|▎         | 26/1000 [01:51<1:32:17,  5.69s/it]

Found a better loss!
Epoch: 26 Training Loss:0.003826 Eval Loss:0.003209


Epoch: 27 - no improvement004114 Eval Loss:0.003690:   3%|▎         | 26/1000 [01:58<1:32:17,  5.69s/it]
Epoch: 27 - no improvement004114 Eval Loss:0.003690:   3%|▎         | 26/1000 [01:58<1:32:17,  5.69s/it]
Epoch: 28 - no improvement004063 Eval Loss:0.003263:   3%|▎         | 27/1000 [02:03<1:36:30,  5.95s/it]                           
Epoch: 28 - no improvement004063 Eval Loss:0.003263:   3%|▎         | 27/1000 [02:03<1:36:30,  5.95s/it]
Epoch: 29 - no improvement004021 Eval Loss:0.003265:   3%|▎         | 28/1000 [02:09<1:32:23,  5.70s/it]                           
Epoch: 29 - no improvement004021 Eval Loss:0.003265:   3%|▎         | 28/1000 [02:09<1:32:23,  5.70s/it]
Epoch: 30 - no improvement003596 Eval Loss:0.003266:   3%|▎         | 29/1000 [02:15<1:34:40,  5.85s/it]                           
Epoch: 30 - no improvement003596 Eval Loss:0.003266:   3%|▎         | 29/1000 [02:15<1:34:40,  5.85s/it]
Epoch: 31 Training Loss:0.003684 Eval Loss:0.003003:   3%|▎         | 31/1000 [

Found a better loss!
Epoch: 31 Training Loss:0.003684 Eval Loss:0.003003


Epoch: 32 - no improvement003736 Eval Loss:0.003156:   3%|▎         | 31/1000 [02:27<1:33:38,  5.80s/it]
Epoch: 32 - no improvement003736 Eval Loss:0.003156:   3%|▎         | 31/1000 [02:27<1:33:38,  5.80s/it]
Epoch: 33 Training Loss:0.003319 Eval Loss:0.002862:   3%|▎         | 33/1000 [02:32<1:32:00,  5.71s/it]                           

Found a better loss!
Epoch: 33 Training Loss:0.003319 Eval Loss:0.002862


Epoch: 34 Training Loss:0.003503 Eval Loss:0.002775:   3%|▎         | 34/1000 [02:38<1:35:08,  5.91s/it]

Found a better loss!
Epoch: 34 Training Loss:0.003503 Eval Loss:0.002775


Epoch: 35 - no improvement003252 Eval Loss:0.003104:   3%|▎         | 34/1000 [02:44<1:35:08,  5.91s/it]
Epoch: 35 - no improvement003252 Eval Loss:0.003104:   3%|▎         | 34/1000 [02:44<1:35:08,  5.91s/it]
Epoch: 36 - no improvement003250 Eval Loss:0.002793:   4%|▎         | 35/1000 [02:51<1:36:16,  5.99s/it]                           
Epoch: 36 - no improvement003250 Eval Loss:0.002793:   4%|▎         | 35/1000 [02:51<1:36:16,  5.99s/it]
Epoch: 37 Training Loss:0.003142 Eval Loss:0.002720:   4%|▎         | 37/1000 [02:57<1:38:23,  6.13s/it]                           

Found a better loss!
Epoch: 37 Training Loss:0.003142 Eval Loss:0.002720


Epoch: 38 Training Loss:0.003229 Eval Loss:0.002533:   4%|▍         | 38/1000 [03:04<1:41:40,  6.34s/it]

Found a better loss!
Epoch: 38 Training Loss:0.003229 Eval Loss:0.002533


Epoch: 39 - no improvement003039 Eval Loss:0.002845:   4%|▍         | 38/1000 [03:10<1:41:40,  6.34s/it]
Epoch: 39 - no improvement003039 Eval Loss:0.002845:   4%|▍         | 38/1000 [03:10<1:41:40,  6.34s/it]
Epoch: 40 Training Loss:0.003217 Eval Loss:0.002524:   4%|▍         | 40/1000 [03:16<1:40:48,  6.30s/it]                           

Found a better loss!
Epoch: 40 Training Loss:0.003217 Eval Loss:0.002524


Epoch: 41 - no improvement002958 Eval Loss:0.002530:   4%|▍         | 40/1000 [03:22<1:40:48,  6.30s/it]
Epoch: 41 - no improvement002958 Eval Loss:0.002530:   4%|▍         | 40/1000 [03:22<1:40:48,  6.30s/it]
Epoch: 42 - no improvement003003 Eval Loss:0.002842:   4%|▍         | 41/1000 [03:27<1:39:38,  6.23s/it]                           
Epoch: 42 - no improvement003003 Eval Loss:0.002842:   4%|▍         | 41/1000 [03:27<1:39:38,  6.23s/it]
Epoch: 43 - no improvement003107 Eval Loss:0.002581:   4%|▍         | 42/1000 [03:33<1:33:34,  5.86s/it]                           
Epoch: 43 - no improvement003107 Eval Loss:0.002581:   4%|▍         | 42/1000 [03:33<1:33:34,  5.86s/it]
Epoch: 44 - no improvement002983 Eval Loss:0.002528:   4%|▍         | 43/1000 [03:40<1:32:43,  5.81s/it]                           
Epoch: 44 - no improvement002983 Eval Loss:0.002528:   4%|▍         | 43/1000 [03:40<1:32:43,  5.81s/it]
Epoch: 45 - no improvement003043 Eval Loss:0.002723:   4%|▍         | 44/1000 [

Epoch 00046: reducing learning rate of group 0 to 1.0000e-03.


Epoch: 47 Training Loss:0.002789 Eval Loss:0.002516:   5%|▍         | 47/1000 [04:00<1:40:02,  6.30s/it]                           

Found a better loss!
Epoch: 47 Training Loss:0.002789 Eval Loss:0.002516


Epoch: 48 - no improvement003018 Eval Loss:0.002640:   5%|▍         | 47/1000 [04:06<1:40:02,  6.30s/it]
Epoch: 48 - no improvement003018 Eval Loss:0.002640:   5%|▍         | 47/1000 [04:06<1:40:02,  6.30s/it]
Epoch: 49 Training Loss:0.002714 Eval Loss:0.002231:   5%|▍         | 49/1000 [04:13<1:43:23,  6.52s/it]                           

Found a better loss!
Epoch: 49 Training Loss:0.002714 Eval Loss:0.002231


Epoch: 50 - no improvement003013 Eval Loss:0.002811:   5%|▍         | 49/1000 [04:19<1:43:23,  6.52s/it]
Epoch: 50 - no improvement003013 Eval Loss:0.002811:   5%|▍         | 49/1000 [04:19<1:43:23,  6.52s/it]
Epoch: 51 - no improvement002763 Eval Loss:0.002469:   5%|▌         | 50/1000 [04:25<1:42:18,  6.46s/it]                           
Epoch: 51 - no improvement002763 Eval Loss:0.002469:   5%|▌         | 50/1000 [04:25<1:42:18,  6.46s/it]
Epoch: 52 - no improvement002699 Eval Loss:0.002726:   5%|▌         | 51/1000 [04:32<1:40:42,  6.37s/it]                           
Epoch: 52 - no improvement002699 Eval Loss:0.002726:   5%|▌         | 51/1000 [04:32<1:40:42,  6.37s/it]
Epoch: 53 - no improvement002764 Eval Loss:0.002374:   5%|▌         | 52/1000 [04:38<1:41:22,  6.42s/it]                           
Epoch: 53 - no improvement002764 Eval Loss:0.002374:   5%|▌         | 52/1000 [04:38<1:41:22,  6.42s/it]
Epoch: 54 - no improvement002670 Eval Loss:0.002518:   5%|▌         | 53/1000 [

Epoch 00055: reducing learning rate of group 0 to 1.0000e-04.


Epoch: 56 - no improvement002709 Eval Loss:0.002580:   6%|▌         | 55/1000 [04:58<1:39:04,  6.29s/it]                           
Epoch: 56 - no improvement002709 Eval Loss:0.002580:   6%|▌         | 55/1000 [04:58<1:39:04,  6.29s/it]
Epoch: 57 - no improvement002928 Eval Loss:0.003033:   6%|▌         | 56/1000 [05:04<1:41:56,  6.48s/it]                           
Epoch: 57 - no improvement002928 Eval Loss:0.003033:   6%|▌         | 56/1000 [05:04<1:41:56,  6.48s/it]
Epoch: 58 - no improvement003467 Eval Loss:0.002792:   6%|▌         | 57/1000 [05:10<1:41:44,  6.47s/it]                           
Epoch: 58 - no improvement003467 Eval Loss:0.002792:   6%|▌         | 57/1000 [05:10<1:41:44,  6.47s/it]
Epoch: 59 - no improvement002804 Eval Loss:0.002276:   6%|▌         | 58/1000 [05:16<1:40:53,  6.43s/it]                           
Epoch: 59 - no improvement002804 Eval Loss:0.002276:   6%|▌         | 58/1000 [05:16<1:40:53,  6.43s/it]
Epoch: 60 - no improvement002747 Eval Loss:0.002482:

Epoch 00061: reducing learning rate of group 0 to 1.0000e-05.


Epoch: 62 - no improvement002866 Eval Loss:0.002544:   6%|▌         | 61/1000 [05:35<1:35:51,  6.13s/it]                           
Epoch: 62 - no improvement002866 Eval Loss:0.002544:   6%|▌         | 61/1000 [05:35<1:35:51,  6.13s/it]
Epoch: 63 - no improvement002809 Eval Loss:0.002496:   6%|▌         | 62/1000 [05:41<1:41:19,  6.48s/it]                           
Epoch: 63 - no improvement002809 Eval Loss:0.002496:   6%|▌         | 62/1000 [05:41<1:41:19,  6.48s/it]
Epoch: 64 - no improvement002751 Eval Loss:0.002576:   6%|▋         | 63/1000 [05:47<1:38:49,  6.33s/it]                           
Epoch: 64 - no improvement002751 Eval Loss:0.002576:   6%|▋         | 63/1000 [05:47<1:38:49,  6.33s/it]
Epoch: 65 - no improvement002982 Eval Loss:0.002528:   6%|▋         | 64/1000 [05:53<1:34:48,  6.08s/it]                           
Epoch: 65 - no improvement002982 Eval Loss:0.002528:   6%|▋         | 64/1000 [05:53<1:34:48,  6.08s/it]
Epoch: 66 - no improvement002733 Eval Loss:0.002633:

Epoch 00067: reducing learning rate of group 0 to 1.0000e-06.


Epoch: 68 - no improvement002964 Eval Loss:0.002741:   7%|▋         | 67/1000 [06:12<1:35:55,  6.17s/it]                           
Epoch: 68 - no improvement002964 Eval Loss:0.002741:   7%|▋         | 67/1000 [06:12<1:35:55,  6.17s/it]
Epoch: 69 - no improvement002692 Eval Loss:0.002342:   7%|▋         | 68/1000 [06:17<1:35:49,  6.17s/it]                           
Epoch: 69 - no improvement002692 Eval Loss:0.002342:   7%|▋         | 68/1000 [06:17<1:35:49,  6.17s/it]
Epoch: 69 Training Loss:0.002692 Eval Loss:0.002342:   7%|▋         | 68/1000 [06:17<1:26:18,  5.56s/it]

No improvement in over 20 epochs, early break





In [108]:
val_predictions, val_truths = inference(val_generator,
                                        best_model,
                                        criterion,
                                        device,
                                        validation = True)

In [109]:
prediction_array_res, classification_report_dict, threshold_chosen = evaluate_validation(val_predictions,
                                                                                    val_truths,
                                                                                    one_hot_class_dict,
                                                                                    num_of_class = 2,
                                                                                    multilabel_bool = False)
print('Threshold chosen',round(threshold_chosen,4))

              precision    recall  f1-score   support

    Negative       0.91      0.94      0.93       507
    Positive       0.98      0.97      0.98      1597

    accuracy                           0.96      2104
   macro avg       0.95      0.96      0.95      2104
weighted avg       0.96      0.96      0.96      2104

Threshold chosen 0.6


In [110]:
import numpy as np
from sklearn.metrics import classification_report
def evaluate_validation(predictions, truths, one_hot_class_dict, multilabel_bool, num_of_class):
    
    if multilabel_bool:
        prediction_thresh = torch.tensor(np.where(predictions >= threshold_chosen, 1, 0))
        class_report_dict = classification_report(truths, prediction_thresh, target_names= one_hot_class_dict.keys(), output_dict = True)
        print(classification_report(truths, prediction_thresh, target_names= one_hot_class_dict.keys()))
        return prediction_thresh.numpy(), class_report_dict, threshold_chosen
    
    else:
        if num_of_class ==2:
            #binary
            indices = np.where(truths == 1)
            indices = indices[1]
            pred = predictions[:, 1]

            predictions_thresh = torch.tensor(np.where(pred >= 0.5, 1, 0))
            class_report_dict = classification_report(indices, predictions_thresh, target_names= one_hot_class_dict.keys(), output_dict = True)
            print(classification_report(indices, predictions_thresh, target_names= one_hot_class_dict.keys()))     
            return predictions_thresh.numpy(), class_report_dict, threshold_chosen
        
            
        else: 
            prediction_array_argmax = torch.argmax(torch.tensor(predictions), dim = 1)
            prediction_array_res = torch.zeros_like(torch.tensor(predictions)).scatter_(1, prediction_array_argmax.unsqueeze(1), 1.)
            class_report_dict = classification_report(truths, prediction_array_res, target_names= one_hot_class_dict.keys(), output_dict = True)
            print(classification_report(truths, prediction_array_res, target_names= one_hot_class_dict.keys()))

            return prediction_array_res.numpy(), class_report_dict, threshold_chosen

In [111]:
inf_predictions, inf_truths = inference(inference_generator,
                                        best_model,
                                        criterion,
                                        device,
                                        validation = True)

prediction_array_res, classification_report_dict, threshold_chosen = evaluate_validation(inf_predictions,
                                                                                    inf_truths,
                                                                                    one_hot_class_dict,
                                                                                    False, num_of_class = 2)

              precision    recall  f1-score   support

    Negative       0.91      0.94      0.93       507
    Positive       0.98      0.97      0.98      1598

    accuracy                           0.96      2105
   macro avg       0.95      0.96      0.95      2105
weighted avg       0.96      0.96      0.96      2105

