# General

## CNN - RNN

### Implementation

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset, WeightedRandomSampler
import torchvision.transforms as transforms

corrupted_images = [
    'Baroque/rembrandt_woman-standing-with-raised-hands.jpg',
    'Post_Impressionism/vincent-van-gogh_l-arlesienne-portrait-of-madame-ginoux-1890.jpg'
]

class TRAIN_GeneralArtGANDataset(Dataset):
    def __init__(self, data_path):
        super().__init__()
        self.data_path = data_path
        self.artist_classes_csv = pd.read_csv(data_path / f'artist_class.txt', sep=" ", names=['label', 'name'])
        self.genre_classes_csv = pd.read_csv(data_path / f'genre_class.txt', sep=" ", names=['label', 'name'])
        self.style_classes_csv = pd.read_csv(data_path / f'style_class.txt', sep=" ", names=['label', 'name'])

        # --------------------------- Cleaning data
        
        artist_csv = pd.read_csv(data_path / f'artist_train.csv', names=['filename', 'label'])
        genre_csv = pd.read_csv(data_path / f'genre_train.csv', names=['filename', 'label'])
        style_csv = pd.read_csv(data_path / f'style_train.csv', names=['filename', 'label'])
        artist_csv = artist_csv.query("filename not in @corrupted_images")
        genre_csv = genre_csv.query("filename not in @corrupted_images")
        style_csv = style_csv.query("filename not in @corrupted_images")
        
        import re
        import unicodedata
        def process_filename(f):
            dirname, filename = f.split('/', 1)
            normalized = unicodedata.normalize('NFKD', filename)
            ascii_filename = ''
            for char in normalized:
                if ord(char) < 128 and ord(char) != 39:
                    ascii_filename += char
                else:
                    replacements = {
                        'ä': 'a', 'ö': 'o', 'ü': 'u', 'ß': 'ss',
                        'á': 'a', 'é': 'e', 'í': 'i', 'ó': 'o', 'ú': 'u',
                        'à': 'a', 'è': 'e', 'ì': 'i', 'ò': 'o', 'ù': 'u',
                        'â': 'a', 'ê': 'e', 'î': 'i', 'ô': 'o', 'û': 'u',
                    }
                    ascii_filename += replacements.get(char, '_') # Replace with _ 
        
            ascii_filename = re.sub(r'[^a-zA-Z0-9_\-\.]', '_', ascii_filename)
            return dirname + "/" + ascii_filename
        
        artist_csv.loc[:, "filename"] = artist_csv["filename"].map(process_filename) # Pandas 3.0
        genre_csv.loc[:, "filename"] = genre_csv["filename"].map(process_filename) # Pandas 3.0
        style_csv.loc[:, "filename"] = style_csv["filename"].map(process_filename) # Pandas 3.0
        artist_genre = artist_csv.merge(genre_csv, how='outer', on='filename') # OUTER JOIN
        self.data_csv = artist_genre.merge(style_csv, how='outer', on='filename') # OUTER JOIN
        self.data_csv = self.data_csv.rename(columns={'label_x': 'artist', 'label_y': 'genre', 'label': 'style'})
        # Add dummy class for genre and style
        self.data_csv['artist'] = self.data_csv['artist'].fillna(len(self.artist_classes_csv))
        self.data_csv['genre'] = self.data_csv['genre'].fillna(len(self.genre_classes_csv))
        self.data_csv['style'] = self.data_csv['style'].fillna(len(self.style_classes_csv))
        
        self.imgs_path = self.data_csv["filename"].tolist()
        labels_artist = self.data_csv["artist"].tolist()
        labels_genre = self.data_csv["genre"].tolist()
        labels_style = self.data_csv["style"].tolist()
        self.labels = [(labels_artist[idx], labels_genre[idx], labels_style[idx]) for idx in range(len(self.imgs_path))]

        # --------------------------- Custom transforms for train dataset
        self.train_transforms = transforms.Compose([
            transforms.RandomResizedCrop((224,224)),
            transforms.RandomHorizontalFlip(0.3),
            transforms.RandomVerticalFlip(0.3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        
    def __len__(self):
        return len(self.imgs_path)

    def get_num_classes(self):
        return (len(self.artist_classes_csv)+1, 
                len(self.genre_classes_csv)+1, 
                len(self.style_classes_csv)+1)

    def get_label_artist(self, label):
        return self.artist_classes_csv['name'][label]
        
    def get_label_genre(self, label):
        if label == len(self.genre_classes_csv):
            return "unknown genre"
        return self.genre_classes_csv['name'][label]
        
    def get_label_style(self, label):
        if label == len(self.style_classes_csv):
            return "unknown style"
        return self.style_classes_csv['name'][label]
    
    def transform(self, img):
        img = self.train_transforms(img)
        return img
    
    def __getitem__(self, idx):
        img = Image.open(self.data_path / self.imgs_path[idx])
        label = self.labels[idx] # (artist, genre, style)
        img = self.transform(img)
        return {
            'images': img,
            'labels': torch.tensor(label) # (3)
        }


class VAL_GeneralArtGANDataset(Dataset):
    def __init__(self, data_path):
        super().__init__()
        self.data_path = data_path
        self.artist_classes_csv = pd.read_csv(data_path / f'artist_class.txt', sep=" ", names=['label', 'name'])
        self.genre_classes_csv = pd.read_csv(data_path / f'genre_class.txt', sep=" ", names=['label', 'name'])
        self.style_classes_csv = pd.read_csv(data_path / f'style_class.txt', sep=" ", names=['label', 'name'])

        # --------------------------- Cleaning data
        
        artist_csv = pd.read_csv(data_path / f'artist_val.csv', names=['filename', 'label'])
        genre_csv = pd.read_csv(data_path / f'genre_val.csv', names=['filename', 'label'])
        style_csv = pd.read_csv(data_path / f'style_val.csv', names=['filename', 'label'])
        artist_csv = artist_csv.query("filename not in @corrupted_images")
        genre_csv = genre_csv.query("filename not in @corrupted_images")
        style_csv = style_csv.query("filename not in @corrupted_images")
        
        import re
        import unicodedata
        def process_filename(f):
            dirname, filename = f.split('/', 1)
            normalized = unicodedata.normalize('NFKD', filename)
            ascii_filename = ''
            for char in normalized:
                if ord(char) < 128 and ord(char) != 39:
                    ascii_filename += char
                else:
                    replacements = {
                        'ä': 'a', 'ö': 'o', 'ü': 'u', 'ß': 'ss',
                        'á': 'a', 'é': 'e', 'í': 'i', 'ó': 'o', 'ú': 'u',
                        'à': 'a', 'è': 'e', 'ì': 'i', 'ò': 'o', 'ù': 'u',
                        'â': 'a', 'ê': 'e', 'î': 'i', 'ô': 'o', 'û': 'u',
                    }
                    ascii_filename += replacements.get(char, '_') # Replace with _ 
        
            ascii_filename = re.sub(r'[^a-zA-Z0-9_\-\.]', '_', ascii_filename)
            return dirname + "/" + ascii_filename
        
        artist_csv.loc[:, "filename"] = artist_csv["filename"].map(process_filename) # Pandas 3.0
        genre_csv.loc[:, "filename"] = genre_csv["filename"].map(process_filename) # Pandas 3.0
        style_csv.loc[:, "filename"] = style_csv["filename"].map(process_filename) # Pandas 3.0
        artist_genre = artist_csv.merge(genre_csv, how='outer', on='filename') # OUTER JOIN
        self.data_csv = artist_genre.merge(style_csv, how='outer', on='filename') # OUTER JOIN
        self.data_csv = self.data_csv.rename(columns={'label_x': 'artist', 'label_y': 'genre', 'label': 'style'})
        # Add dummy class for genre and style
        self.data_csv['artist'] = self.data_csv['artist'].fillna(len(self.artist_classes_csv))
        self.data_csv['genre'] = self.data_csv['genre'].fillna(len(self.genre_classes_csv))
        self.data_csv['style'] = self.data_csv['style'].fillna(len(self.style_classes_csv))
        
        self.imgs_path = self.data_csv["filename"].tolist()
        labels_artist = self.data_csv["artist"].tolist()
        labels_genre = self.data_csv["genre"].tolist()
        labels_style = self.data_csv["style"].tolist()
        self.labels = [(labels_artist[idx], labels_genre[idx], labels_style[idx]) for idx in range(len(self.imgs_path))]

        # --------------------------- Custom transforms for train dataset
        self.val_transforms = transforms.Compose([
            transforms.RandomResizedCrop((224,224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        
    def __len__(self):
        return len(self.imgs_path)

    def get_num_classes(self):
        return (len(self.artist_classes_csv)+1, 
                len(self.genre_classes_csv)+1, 
                len(self.style_classes_csv)+1)

    def get_label_artist(self, label):
        if label == len(self.artist_classes_csv):
            return "unknown artist"
        return self.artist_classes_csv['name'][label]
        
    def get_label_genre(self, label):
        if label == len(self.genre_classes_csv):
            return "unknown genre"
        return self.genre_classes_csv['name'][label]
        
    def get_label_style(self, label):
        if label == len(self.style_classes_csv):
            return "unknown style"
        return self.style_classes_csv['name'][label]
    
    def transform(self, img):
        img = self.val_transforms(img)
        return img
    
    def __getitem__(self, idx):
        img = Image.open(self.data_path / self.imgs_path[idx])
        label = self.labels[idx] # (artist, genre, style)
        img = self.transform(img)
        return {
            'images': img,
            'labels': torch.tensor(label) # (3)
        }


path = Path('/kaggle/input/wikiart') 
train_dataset = TRAIN_GeneralArtGANDataset(data_path=path)
test_dataset = VAL_GeneralArtGANDataset(data_path=path)

In [2]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split

batch_size = 64
num_workers = 0
generator1 = torch.Generator().manual_seed(86)

train_dataset = TRAIN_GeneralArtGANDataset(data_path=path)
val_dataset = VAL_GeneralArtGANDataset(data_path=path)
    
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    num_workers=num_workers, 
    shuffle=True
)
val_dataloader = DataLoader(
    val_dataset, 
    batch_size=batch_size, 
    num_workers=num_workers, 
    shuffle=False
)

In [3]:
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights

class CNN_RNN(nn.Module):
    def __init__(self, num_artist_classes, num_genre_classes, num_style_classes, lstm_hidden_dim=512):
        super().__init__()

        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_artist_classes = num_artist_classes
        self.num_genre_classes = num_genre_classes
        self.num_style_classes = num_style_classes

        resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.cnn = nn.Sequential(*(list(resnet.children())[:-2]))
        self.pool = nn.AdaptiveAvgPool2d(7)

        self.lstm = nn.LSTM(input_size=2048, 
                            hidden_size=lstm_hidden_dim, 
                            num_layers=1, 
                            batch_first=True)
        
        self.artist_fc = nn.Linear(lstm_hidden_dim, num_artist_classes)
        self.genre_fc = nn.Linear(lstm_hidden_dim, num_genre_classes)
        self.style_fc = nn.Linear(lstm_hidden_dim, num_style_classes)

    def forward(self, images, labels=None):
        batch_size = images.shape[0]

        # CNN
        cnn_features = self.pool(self.cnn(images)) # (B, 2048, 7, 7)
        cnn_features = cnn_features.reshape(batch_size, 49, 2048) # (B, 49, 2048)

        # RNN
        _, (h_n, _) = self.lstm(cnn_features) # h_n (seq_len, )
        final_feature = h_n[-1]

        out_artist = self.artist_fc(final_feature)
        out_genre = self.genre_fc(final_feature)
        out_style = self.style_fc(final_feature)
        return out_artist, out_genre, out_style

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_artist_classes, num_genre_classes, num_style_classes = train_dataset.get_num_classes()
print(num_artist_classes, num_genre_classes, num_style_classes)
model = CNN_RNN(num_artist_classes, num_genre_classes, num_style_classes)

24 11 28


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 211MB/s]


In [4]:
def calculate_metrics(all_labels, all_pred_probs):
    '''
    Args:
        all_labels: true labels (len(val_dataset), n_labels)
        all_pred_probs: probability of each class for each labels Tuple<(len(val_dataset), n_classes), 
                                                                        (len(val_dataset), n_classes) 
                                                                        (len(val_dataset), n_classes)>
    '''

    global val_dataset
    global num_artist_classes, num_genre_classes, num_style_classes
    all_labels = all_labels.cpu().numpy()
    all_pred_labels = [all_pred_probs[i].argmax(axis=1) for i in range(3)]
    
    # TP[i][j] = true positive of class j in category (0: artist, 1: genre, 2: style)
    TP = [[0] * 28 for _ in range(3)]
    FP = [[0] * 28 for _ in range(3)]
    TN = [[0] * 28 for _ in range(3)]
    FN = [[0] * 28 for _ in range(3)]
    for i in range(3): # artist, genre, style
        for cls in range(val_dataset.get_num_classes()[i]): # class [0..n_classes]
            for idx in range(len(all_labels)):
                if all_labels[idx][i] == cls: 
                    if all_pred_labels[i][idx] == cls:
                        TP[i][cls] += 1
                    else:
                        FN[i][cls] += 1 
                else: 
                    if all_pred_labels[i][idx] == cls:
                        FP[i][cls] += 1
                    else:
                        TN[i][cls] += 1 
        
    # P: Per class, O: Overall 
    metrics = {}
    tp = 0.0
    tp_fn = 0.0
    tp_fp = 0.0
    # Artist
    for cls in range(num_artist_classes):
        cls_name = val_dataset.get_label_artist(cls)
        metrics[f'PP_{cls_name}'] = TP[0][cls] / max(1, TP[0][cls] + FN[0][cls])
        metrics[f'PR_{cls_name}'] = TP[0][cls] / max(1, TP[0][cls] + FP[0][cls])
        metrics[f'PF1_{cls_name}'] = 2 * (metrics[f'PP_{cls_name}'] * metrics[f'PR_{cls_name}']) / max(1, (metrics[f'PP_{cls_name}'] + metrics[f'PR_{cls_name}']))
        tp = tp + TP[0][cls]
        tp_fn = tp_fn + max(1e-5, TP[0][cls] + FN[0][cls])
        tp_fp = tp_fp + max(1e-5, TP[0][cls] + FP[0][cls])
    
    # Genre
    for cls in range(num_genre_classes):
        cls_name = val_dataset.get_label_genre(cls)
        metrics[f'PP_{cls_name}'] = TP[1][cls] / max(1, TP[1][cls] + FN[1][cls])
        metrics[f'PR_{cls_name}'] = TP[1][cls] / max(1, TP[1][cls] + FP[1][cls])
        metrics[f'PF1_{cls_name}'] = 2 * (metrics[f'PP_{cls_name}'] * metrics[f'PR_{cls_name}']) / max(1, (metrics[f'PP_{cls_name}'] + metrics[f'PR_{cls_name}']))
        tp = tp + TP[1][cls]
        tp_fn = tp_fn + max(1e-5, TP[1][cls] + FN[1][cls])
        tp_fp = tp_fp + max(1e-5, TP[1][cls] + FP[1][cls])
    
    # Style
    for cls in range(num_style_classes):
        cls_name = val_dataset.get_label_style(cls)
        metrics[f'PP_{cls_name}'] = TP[2][cls] / max(1, TP[2][cls] + FN[2][cls])
        metrics[f'PR_{cls_name}'] = TP[2][cls] / max(1, TP[2][cls] + FP[2][cls])
        metrics[f'PF1_{cls_name}'] = 2 * (metrics[f'PP_{cls_name}'] * metrics[f'PR_{cls_name}']) / max(1, (metrics[f'PP_{cls_name}'] + metrics[f'PR_{cls_name}']))
        tp = tp + TP[2][cls]
        tp_fn = tp_fn + max(1e-5, TP[2][cls] + FN[2][cls])
        tp_fp = tp_fp + max(1e-5, TP[2][cls] + FP[2][cls])
    
    metrics[f'OP'] = tp / tp_fn
    metrics[f'OR'] = tp / tp_fp
    metrics[f'OF1'] = 2 * (metrics[f'OP'] * metrics[f'OR']) / (metrics[f'OP'] + metrics[f'OR'])
    return metrics

def evaluate(model, dataloader, criterion, device):
    model.eval()
    losses = []
    all_labels = []
    all_pred_probs_artist, all_pred_probs_genre, all_pred_probs_style = [], [], []
    with torch.no_grad():
        with tqdm(total=len(val_dataset), desc=f'Validating', unit='img') as pbar:
            for batch in dataloader:
                images, labels = batch['images'], batch['labels']
                images = images.to(device, dtype=torch.float32, memory_format=torch.channels_last)
                labels = labels.to(device, dtype=torch.long) # (batch_size, 3)
                
                logits_pred = model(images, labels) # (batch_size, 3)
                loss_artist = criterion(logits_pred[0], labels[:, 0])
                loss_genre = criterion(logits_pred[1], labels[:, 1])
                loss_style = criterion(logits_pred[2], labels[:, 2])
                loss = loss_artist + loss_genre + loss_style
                
                losses.append(loss.item())
                all_pred_probs_artist.append(torch.softmax(logits_pred[0], dim=1))
                all_pred_probs_genre.append(torch.softmax(logits_pred[1], dim=1))
                all_pred_probs_style.append(torch.softmax(logits_pred[2], dim=1))
                all_labels.append(labels)
                pbar.update(images.shape[0])

    model.train()
    avg_loss = sum(losses) / len(losses)
    all_pred_probs_artist = torch.cat(all_pred_probs_artist)
    all_pred_probs_genre = torch.cat(all_pred_probs_genre)
    all_pred_probs_style = torch.cat(all_pred_probs_style)
    all_labels = torch.cat(all_labels)
    metrics = {}
    metrics = calculate_metrics(all_labels, (all_pred_probs_artist, all_pred_probs_genre, all_pred_probs_style))
    metrics['loss'] = avg_loss

    # Artist
    print(" \n================================ Artist ================================")
    for cls in range(num_artist_classes):
        cls_name = val_dataset.get_label_artist(cls)
        print(f"{cls_name} (Precision, Recall, F1 score): ({ metrics[f'PP_{cls_name}'] }, { metrics[f'PR_{cls_name}'] }, { metrics[f'PF1_{cls_name}'] })")
    
    # Genre
    print(" \n================================ Genre ================================")
    for cls in range(num_genre_classes):
        cls_name = val_dataset.get_label_genre(cls)
        print(f"{cls_name} (Precision, Recall, F1 score): ({ metrics[f'PP_{cls_name}'] }, { metrics[f'PR_{cls_name}'] }, { metrics[f'PF1_{cls_name}'] })")
    
    # Style
    print(" \n================================ Style ================================")
    for cls in range(num_style_classes):
        cls_name = val_dataset.get_label_style(cls)
        print(f"{cls_name} (Precision, Recall, F1 score): ({ metrics[f'PP_{cls_name}'] }, { metrics[f'PR_{cls_name}'] }, { metrics[f'PF1_{cls_name}'] })")
    
    print(" \n================================ Overall ================================")
    print(f"Overall Precision: {metrics['OP']}")
    print(f"Overall Recall: {metrics['OR']}")
    print(f"Overall F1: {metrics['OF1']}")
    return metrics

In [None]:
from tqdm import tqdm

n_epochs = 4
learning_rate = 5e-4
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.000025)
criterion = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.66)

global_step = 0
n_train = len(train_dataset)
threshold = 0.7

# Per epoch
train_losses = []
val_losses = []
val_f1 = []

# Per step
train_losses_steps = []

model.to(device)
best_model = model
best_val_f1 = 0.0

In [6]:
model_path = '/kaggle/input/wikiart_general_epoch2/pytorch/default/1/wikiart_general_CNN_RNN_epoch2.pt'
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict['model_state_dict'])
optimizer.load_state_dict(state_dict['optim_state_dict'])
scheduler.load_state_dict(state_dict['scheduler_state_dict'])
global_step = state_dict['global_step']
last_epoch = state_dict['current_epochs']
n_epochs = state_dict['n_epochs']

  state_dict = torch.load(model_path, map_location=device)


In [7]:
n_epochs = n_epochs + 2
for epoch in range(last_epoch+1, n_epochs+1):
    model.train()
    epoch_loss = 0.0
    with tqdm(total=n_train, desc=f'Epoch {epoch}/{n_epochs}', unit='img') as pbar:
        for batch in train_dataloader:
            images, labels = batch['images'], batch['labels']
            images = images.to(device, dtype=torch.float32, memory_format=torch.channels_last)
            labels = labels.to(device, dtype=torch.long) # (batch_size, 3)

            logits_pred = model(images, labels) # (3, batch_size)
            loss_artist = criterion(logits_pred[0], labels[:, 0]) # (batch_size, 24), (batch_size)
            loss_genre = criterion(logits_pred[1], labels[:, 1]) # (batch_size, 11), (batch_size)
            loss_style = criterion(logits_pred[2], labels[:, 2]) # (batch_size, 28), (batch_size)
            loss = loss_artist + loss_genre + loss_style
            
            epoch_loss += loss.item()
            train_losses_steps.append(loss.item() / batch_size)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            global_step += 1
            pbar.update(images.shape[0])
            pbar.set_postfix(**{f'loss (batch)': loss.item()})

    train_losses.append(epoch_loss / len(train_dataloader))
    val_metrics = evaluate(model, val_dataloader, criterion, device)
    scheduler.step(val_metrics['loss'])
    val_losses.append(val_metrics['loss'])
    val_f1.append(val_metrics['OF1'])
    if val_f1[-1] > best_val_f1:
        best_model = model
        best_val_f1 = val_f1[-1]

Epoch 3/4: 100%|██████████| 57338/57338 [59:09<00:00, 16.16img/s, loss (batch)=3.32]
Validating: 100%|██████████| 24673/24673 [11:37<00:00, 35.38img/s]


 
Albrecht_Durer (Precision, Recall, F1 score): (0.46774193548387094, 0.8345323741007195, 0.599483204134367)
Boris_Kustodiev (Precision, Recall, F1 score): (0.0, 0.0, 0.0)
Camille_Pissarro (Precision, Recall, F1 score): (0.12406015037593984, 0.868421052631579, 0.21547289275821133)
Childe_Hassam (Precision, Recall, F1 score): (0.04242424242424243, 0.5384615384615384, 0.04568764568764569)
Claude_Monet (Precision, Recall, F1 score): (0.4225, 0.4970588235294118, 0.42001470588235296)
Edgar_Degas (Precision, Recall, F1 score): (0.2786885245901639, 0.7183098591549296, 0.40036942969291156)
Eugene_Boudin (Precision, Recall, F1 score): (0.6024096385542169, 0.6896551724137931, 0.6430868167202574)
Gustave_Dore (Precision, Recall, F1 score): (0.92, 0.8697478991596639, 0.8941684665226783)
Ilya_Repin (Precision, Recall, F1 score): (0.0, 0.0, 0.0)
Ivan_Aivazovsky (Precision, Recall, F1 score): (0.5780346820809249, 0.8695652173913043, 0.6944444444444444)
Ivan_Shishkin (Precision, Recall, F1 score): (0.

Epoch 4/4: 100%|██████████| 57338/57338 [50:38<00:00, 18.87img/s, loss (batch)=3.09]
Validating: 100%|██████████| 24673/24673 [09:23<00:00, 43.78img/s]


 
Albrecht_Durer (Precision, Recall, F1 score): (0.657258064516129, 0.5970695970695971, 0.6257197696737044)
Boris_Kustodiev (Precision, Recall, F1 score): (0.021164021164021163, 1.0, 0.04145077720207253)
Camille_Pissarro (Precision, Recall, F1 score): (0.33458646616541354, 0.717741935483871, 0.4564102564102564)
Childe_Hassam (Precision, Recall, F1 score): (0.13333333333333333, 0.55, 0.14666666666666667)
Claude_Monet (Precision, Recall, F1 score): (0.42, 0.6339622641509434, 0.5052631578947369)
Edgar_Degas (Precision, Recall, F1 score): (0.30601092896174864, 0.7887323943661971, 0.4409448818897638)
Eugene_Boudin (Precision, Recall, F1 score): (0.5903614457831325, 0.8305084745762712, 0.6901408450704225)
Gustave_Dore (Precision, Recall, F1 score): (0.9333333333333333, 0.8641975308641975, 0.8974358974358974)
Ilya_Repin (Precision, Recall, F1 score): (0.031055900621118012, 0.5555555555555556, 0.03450655624568668)
Ivan_Aivazovsky (Precision, Recall, F1 score): (0.7630057803468208, 0.7252747252

In [8]:
# model = best_model
state_dict = {
    'lr': learning_rate,
    'global_step': global_step,
    'current_epochs': epoch,
    'n_epochs': n_epochs,
    'model_state_dict': model.state_dict(),
    'optim_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict()
}

torch.save(state_dict, str('/kaggle/working/wikiart_general_CNN_RNN_epoch{}.pt'.format(epoch)))

# Find outliers

Idea: Extract features -> UMAP -> Plot