In [None]:
import copy
import gc
import os
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from sklearn.metrics import f1_score, accuracy_score
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.models import vit_b_16, vit_b_32, mobilenet_v2, resnet18, alexnet, vgg16

In [None]:
# Pick the compute backend: CUDA first, then Apple MPS, otherwise fall back to CPU.
# The checks short-circuit, so MPS is only probed when CUDA is unavailable.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f"device: {device}")

In [None]:
# Topic model whose pre-processed splits are used; the name is interpolated into
# the pickle / checkpoint file names further down.
topicModel = 'LDA'
# Number of output units given to each classifier head.
nClasses = 40

In [None]:
# dataset = np.load('LDA_dataset.npy')
# # dataset = np.load('NMF_dataset.npy')
# train_dataset, test_dataset = train_test_split(dataset, train_size=0.85 , random_state=42)
# train_dataset, val_dataset = train_test_split(train_dataset, train_size=0.8 , random_state=42)

In [None]:
# # def load_dataset(topicModel): # 'LDA' or 'NMF'
# #     dataset = np.load(topicModel + '_dataset.npy').tolist()
# def load_dataset(dataset):
#     dataset = dataset.tolist()
#     for i in range(len(dataset)):
#         if (i+1) % 100 == 0:
#             print('loading image {} of {}'.format(i+1, len(dataset)))
#         image = Image.open(dataset[i][0])
#         image = image.resize((224, 224))  # Resize image to 224x224 (adjust as needed)
#         image = image.convert('RGB')
#         image = np.array(image)
#         image = np.transpose(image, (2, 0, 1))
#         image = image / 255.0
#         image_tensor = torch.tensor(image, dtype=torch.float32)
#         dataset[i].append(image_tensor)
#     return dataset        

In [None]:
# train_dataset = load_dataset(train_dataset)
# val_dataset = load_dataset(val_dataset)
# test_dataset = load_dataset(test_dataset)

# with open(f'./{topicModel}_train.pkl', 'wb') as f:
#         pickle.dump(train_dataset, f)
# with open(f'./{topicModel}_val.pkl', 'wb') as f:
#         pickle.dump(val_dataset, f)
# with open(f'./{topicModel}_test.pkl', 'wb') as f:
#         pickle.dump(test_dataset, f)

In [None]:
# Load the pre-processed train / validation / test splits for the selected topic model.
# Each split is later indexed as a sequence of (image_path, label, image) records
# by CustomDataset.__getitem__.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only load
# files that were produced by this project.
with open(f'./data/processed/{topicModel}_train.pkl', 'rb') as f:
    train_dataset = pickle.load(f) 
with open(f'./data/processed/{topicModel}_val.pkl', 'rb') as f:
    val_dataset = pickle.load(f) 
with open(f'./data/processed/{topicModel}_test.pkl', 'rb') as f:
    test_dataset = pickle.load(f) 

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset over pre-loaded ``(image_path, label, image)`` records.

    Args:
        data_array: Indexable collection whose items unpack into
            ``(image_path, label, image)``.
        transform: Optional callable applied to ``image`` on every access.
    """

    def __init__(self, data_array, transform=None):
        self.data_array = data_array
        self.transform = transform

    def __len__(self):
        """Return the number of records."""
        return len(self.data_array)

    def __getitem__(self, index):
        """Return ``(image_path, label, image)`` for the record at ``index``.

        The image goes through ``transform`` when one is set, and a label
        stored as a string is converted to ``int``.
        """
        path, target, img = self.data_array[index]

        img = self.transform(img) if self.transform else img
        target = int(target) if isinstance(target, str) else target

        return path, target, img
    
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
# ])

# Wrap every raw split in a CustomDataset (no transform). The right-hand tuple is
# built from the raw splits before any of the three names is rebound.
train_dataset, val_dataset, test_dataset = (
    CustomDataset(split, transform=None)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# Display the number of samples in the train / validation / test splits.
len(train_dataset), len(val_dataset), len(test_dataset)

In [None]:
def new_vit_b_16(nClasses, device):
    """Build an ImageNet-pretrained ViT-B/16 with a fresh ``nClasses``-way head.

    All pretrained parameters are frozen. The replacement head is created after
    the freeze, so it is the only part that still requires gradients.

    Args:
        nClasses: Number of output units of the new head.
        device: Device the model is moved to.

    Returns:
        The model, placed on ``device``.
    """
    net = vit_b_16(weights='IMAGENET1K_V1')
    # Disabled alternative: random initialisation instead of pretrained weights.
    #     net = vit_b_16(weights=None)
    # Disabled alternative: keep only the last encoder block trainable.
    #     for name, param in net.named_parameters():
    #         param.requires_grad = 'encoder_layer_11' in name

    # Freeze every pretrained parameter in a single call.
    net.requires_grad_(False)

    # Swap the classification head for one sized to this task.
    head_inputs = net.heads.head.in_features
    net.heads.head = torch.nn.Linear(head_inputs, nClasses)
    return net.to(device)

In [None]:
def new_mobilenet_v2(nClasses, device):
    """Build a MobileNetV2 without pretrained weights and an ``nClasses``-way output.

    Args:
        nClasses: Number of output units of the replaced classifier layer.
        device: Device the model is moved to.

    Returns:
        The model, placed on ``device``; all parameters remain trainable.
    """
    net = mobilenet_v2(weights=None)
    # Replace the layer at classifier[1], keeping its input width.
    old_layer = net.classifier[1]
    net.classifier[1] = torch.nn.Linear(old_layer.in_features, nClasses)
    return net.to(device)

In [None]:
def new_resnet18(nClasses, device):
    """Build an ImageNet-pretrained ResNet-18 that fine-tunes ``layer4`` and a new head.

    Args:
        nClasses: Number of output units of the new ``fc`` layer.
        device: Device the model is moved to.

    Returns:
        The model, placed on ``device``.
    """
    net = resnet18(weights='IMAGENET1K_V1')
    # Disabled alternative: freeze every parameter instead of keeping layer4 trainable.

    # Of the top-level children, only 'layer4' keeps its gradients; the rest are frozen.
    for child_name, module in net.named_children():
        module.requires_grad_(child_name == 'layer4')

    # The new fc layer is created after the freeze, so it is trainable.
    net.fc = torch.nn.Linear(net.fc.in_features, nClasses)
    return net.to(device)

In [None]:
def new_alexnet(nClasses, device):
    """Build an ImageNet-pretrained AlexNet with a fresh ``nClasses``-way output layer.

    All pretrained parameters are frozen. The layer placed at ``classifier[6]``
    is created afterwards and is therefore the only trainable one.

    Args:
        nClasses: Number of output units of the replaced layer.
        device: Device the model is moved to.

    Returns:
        The model, placed on ``device``.
    """
    net = alexnet(weights='IMAGENET1K_V1')
    net.requires_grad_(False)
    # Disabled alternative: also re-initialise the preceding layer.
    #     net.classifier[4] = torch.nn.Linear(4096, 4096)
    old_layer = net.classifier[6]
    net.classifier[6] = torch.nn.Linear(old_layer.in_features, nClasses)
    return net.to(device)

In [None]:
# model = new_vit_b_16(nClasses, device)
# model

In [None]:
def new_vgg16(nClasses, device):
    """Build a VGG-16 without pretrained weights and an ``nClasses``-way output layer.

    Args:
        nClasses: Number of output units of the replaced classifier layer.
        device: Device the model is moved to.

    Returns:
        The model, placed on ``device``; all parameters remain trainable.
    """
    net = vgg16(weights=None)
    # Replace the layer at classifier[6], keeping its input width.
    old_layer = net.classifier[6]
    net.classifier[6] = torch.nn.Linear(old_layer.in_features, nClasses)
    return net.to(device)

In [None]:
def train_model(model, train_loader, criterion, optimizer, device='cpu'):
    """Run one training epoch over ``train_loader``.

    Args:
        model: Network to train; it is switched to training mode.
        train_loader: Iterable yielding ``(image_paths, labels, inputs)`` batches.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device that labels and inputs are moved to.

    Returns:
        Tuple ``(train_loss, acc, f1)``: the mean of the per-batch losses, the
        accuracy, and the macro-averaged F1 over every sample seen in the epoch.
    """
    model.train()
    all_preds, all_targets = [], []
    running_loss = 0.0

    # The image paths in each batch are not needed for training.
    for _paths, labels, inputs in train_loader:
        labels = labels.to(device)
        inputs = inputs.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Predicted class = index of the largest logit along dim 1.
        batch_pred = torch.max(logits, 1).indices
        all_preds.extend(batch_pred.cpu().numpy())
        all_targets.extend(labels.cpu().numpy())
        running_loss += batch_loss.item()

    mean_loss = running_loss / len(train_loader)
    return (
        mean_loss,
        accuracy_score(all_targets, all_preds),
        f1_score(all_targets, all_preds, average='macro'),
    )

In [None]:
def evaluate_model(model, test_loader, criterion, device='cpu'):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_loader: Iterable yielding ``(image_paths, labels, inputs)`` batches.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        device: Device that labels and inputs are moved to.

    Returns:
        Tuple ``(val_loss, acc, f1, top5_acc, top5preds, imagepaths)``:
        the mean of the per-batch losses, accuracy, macro-averaged F1, the
        fraction of samples whose label is among the top-k predictions, the
        per-sample top-k class indices (a list of arrays), and the image paths
        in the same order. ``k`` is 5, or the number of model outputs when the
        head has fewer than 5 classes.
    """
    model.eval()
    preds = []
    top5preds = []
    imagepaths = []
    targets = []
    epoch_loss = 0.0

    with torch.no_grad():
        for imagepath, labels, inputs in test_loader:
            labels, inputs = labels.to(device), inputs.to(device)
            out = model(inputs)
            loss = criterion(out, labels)

            _, pred = torch.max(out, 1)
            preds.extend(pred.cpu().numpy())

            # torch.topk raises when k exceeds the size of the class dimension,
            # so clamp k for heads with fewer than 5 outputs.
            k = min(5, out.size(1))
            top5preds.extend(torch.topk(out, k).indices.cpu().numpy())

            imagepaths.extend(list(imagepath))
            targets.extend(labels.cpu().numpy())
            epoch_loss += loss.item()

    val_loss = epoch_loss / len(test_loader)
    acc = accuracy_score(targets, preds)
    f1 = f1_score(targets, preds, average='macro')

    # Compare as explicit arrays: (N, k) predictions against an (N, 1) label column.
    top5_matrix = np.asarray(top5preds)
    label_column = np.asarray(targets)[:, None]
    top5_acc = np.mean(np.any(top5_matrix == label_column, axis=1))

    return val_loss, acc, f1, top5_acc, top5preds, imagepaths

In [None]:
def train_and_eval(train_data, val_data, nClasses, batch_size, num_epochs, lr=0.001, weight_decay=0, device='cpu'):
    """Train a ViT-B/16 classifier and keep the weights of its best validation epoch.

    Args:
        train_data: Dataset used for training.
        val_data: Dataset used for validation after every epoch.
        nClasses: Number of output classes for the model head.
        batch_size: Batch size for both data loaders.
        num_epochs: Number of epochs to run.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        device: Device used for the model and the batches.

    Returns:
        Tuple ``(train_losses, valid_losses, train_scores, val_scores,
        best_score, best_model, best_top5_acc)``. The ``*_losses`` and
        ``*_scores`` lists hold one value per epoch (scores are macro-F1).
        ``best_score`` / ``best_top5_acc`` are the validation macro-F1 and
        top-5 accuracy of the best epoch, and ``best_model`` is a snapshot of
        the model at that epoch (``None`` if no epoch scored above 0).
    """
    train_scores = []
    val_scores = []
    train_losses = []
    valid_losses = []
    best_score = 0.0
    best_top5_acc = 0.0
    best_model = None

    gc.collect()
    torch.cuda.empty_cache()

    # Reshuffle the training samples every epoch; validation order stays fixed.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size)

    model = new_vit_b_16(nClasses, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, num_epochs + 1):
        start_time = time.time()
        print(f'epoch {epoch}')
        # Train
        train_loss, train_acc, train_f1 = train_model(model, train_loader, criterion, optimizer, device)
        print(f'train loss: {train_loss}, train accuracy: {train_acc}, train f1score: {train_f1}')

        # Eval
        val_loss, val_acc, val_f1, top5_acc, _, _ = evaluate_model(model, val_loader, criterion, device)
        print(f'val loss: {val_loss}, val accuracy: {val_acc}, val f1score: {val_f1}, training time: {time.time() - start_time}')

        if val_f1 > best_score:
            best_score = val_f1
            # Snapshot the weights: `model` keeps being updated in later epochs,
            # so storing a bare reference would always return the last epoch.
            best_model = copy.deepcopy(model)
            best_top5_acc = top5_acc

        train_losses.append(train_loss)
        valid_losses.append(val_loss)
        train_scores.append(train_f1)
        val_scores.append(val_f1)

    return train_losses, valid_losses, train_scores, val_scores, best_score, best_model, best_top5_acc

In [None]:
# Run the garbage collector and release PyTorch's cached CUDA memory before the
# hyper-parameter search starts.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Hyper-parameter grid: the search loop trains one model for num_epochs epochs
# for every (batch size, learning rate, weight decay) combination listed here.
num_epochs = 10
batch_sizes = [32]
learning_rate = [0.0005]
weight_decay = [0.0005]

In [None]:
# Grid search over batch size, learning rate and weight decay.
# The best_* names track the configuration with the highest validation macro-F1
# (the best_score value returned by train_and_eval).
# NOTE: despite their names, train_acc / val_acc (and best_train_acc /
# best_val_acc) receive the per-epoch macro-F1 lists that train_and_eval
# returns as train_scores / val_scores, not accuracies.
best_train_losses = []
best_valid_losses = []
best_train_acc = [0]
best_val_acc = [0]
best_batch_size = 0
best_params = {}
best_model_tuned = None
best_score_tuned = 0.0
best_top5_acc_tuned = 0.0

for batch_size in batch_sizes:
    for lr in learning_rate:
        for wd in weight_decay:
            train_losses, valid_losses, train_acc, val_acc, best_score, best_model, best_top5_acc = train_and_eval(train_dataset, val_dataset, nClasses,
                                                                                                  batch_size, num_epochs, lr=lr, 
                                                                                                  weight_decay=wd, device=device)
            # Keep this run only if it beats the best validation macro-F1 so far.
            if best_score > best_score_tuned:
                best_train_losses = train_losses
                best_valid_losses = valid_losses
                best_train_acc = train_acc
                best_val_acc = val_acc
                best_batch_size = batch_size
                best_params = {"lr": lr, "weight_decay": wd}
                best_model_tuned = best_model
                best_score_tuned = best_score
                best_top5_acc_tuned = best_top5_acc

In [None]:
# Best validation macro-F1 found by the search and the top-5 accuracy of that epoch.
best_score_tuned, best_top5_acc_tuned

In [None]:
# Plot per-epoch training and validation loss of the best configuration.
# The x axis is derived from the recorded losses rather than from num_epochs, so
# the x and y lengths always agree (e.g. if num_epochs was edited after training).
xrange = list(range(len(best_train_losses)))

fig, ax = plt.subplots()
ax.plot(xrange, best_train_losses, label="Training Loss")
# The losses come from a single hold-out validation split, not cross-validation.
ax.plot(xrange, best_valid_losses, label="Validation Loss")
ax.set_title(f"Loss Curves ({topicModel})")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend(loc="best")
ax.grid()
plt.show()

In [None]:
# Run the garbage collector and release PyTorch's cached CUDA memory once the
# search has finished.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# torch.save(best_model_tuned.to('cpu'), f'./models/best_model_{topicModel}.pth')
# Replace the in-memory model with the checkpoint stored on disk and move it to
# the active device. The save above is disabled, so the file must already exist.
# NOTE(review): torch.load unpickles a whole model object here - only load trusted
# files. Newer PyTorch releases may need weights_only=False for this - TODO confirm.
best_model_tuned = torch.load(f'./models/best_model_{topicModel}.pth').to(device)

In [None]:
# topicModel, num_epochs, best_batch_size, best_params

In [None]:
# Build sequential (unshuffled) loaders over all three splits for the final evaluation.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=32)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# best_model_tuned = best_model_tuned.to(device)

In [None]:
# Final evaluation of the selected model on all three splits.
# evaluate_model returns (loss, accuracy, macro-F1, top-5 accuracy,
# per-sample top-5 class indices, image paths) for each loader.
criterion = nn.CrossEntropyLoss()
train_loss, train_acc, train_f1, train_top5_acc, train_top5preds, train_imagepaths = evaluate_model(best_model_tuned, train_loader, criterion, device=device)
val_loss, val_acc, val_f1, val_top5_acc, val_top5preds, val_imagepaths = evaluate_model(best_model_tuned, val_loader, criterion, device=device)
test_loss, test_acc, test_f1, test_top5_acc, test_top5preds, test_imagepaths = evaluate_model(best_model_tuned, test_loader, criterion, device=device)
# Print one summary line per split.
print(f'train loss: {train_loss}, train accuracy {train_acc}, train f1score {train_f1}, train top5_acc {train_top5_acc}')
print(f'val loss: {val_loss}, val accuracy {val_acc}, val f1score {val_f1}, val top5_acc {val_top5_acc}')
print(f'test loss: {test_loss}, test accuracy {test_acc}, test f1score {test_f1}, test top5_acc {test_top5_acc}')

In [None]:
# np.savez('./outputs/model2_trainResult_' + topicModel + '.npz', trainResult_labels=train_top5preds, trainResult_imagePath=train_imagepaths)
# np.savez('./outputs/model2_valResult_' + topicModel + '.npz', valResult_labels=val_top5preds, valResult_imagePath=val_imagepaths)
# np.savez('./outputs/model2_testResult_' + topicModel + '.npz', testResult_labels=test_top5preds, testResult_imagePath=test_imagepaths)